Deconstructing MoE from Scratch

Part 2: PyTorch Implementation

Introduction

In Part 1 we derived the mathematical foundations of Mixture of Experts: top-$k$ gating, noisy exploration, and load balancing loss. This installment translates every equation into working PyTorch code. We build four modules from scratch -- Expert, TopKGating, MoELayer, and MoEClassifier -- with no external MoE libraries.

Architecture Overview

The full classifier architecture:

$$ \text{Linear}(784 \to 256) \to \text{ReLU} \to \text{MoE}(256 \to 256) \to \text{ReLU} \to \text{Linear}(256 \to 10) $$

Expert Design

Each expert is deliberately simple -- a two-layer MLP with ReLU activation:

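import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """A two-layer MLP: input_dim -> hidden_dim -> output_dim."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)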

Each expert maps $\mathbb{R}^{256} \to \mathbb{R}^{128} \to \mathbb{R}^{256}$ with a bottleneck hidden dimension. With 8 experts, the total expert parameters are $8 \times 65{,}920 = 527{,}360$.
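
The per-expert figure can be checked by instantiating one expert and summing its parameter sizes. This is a minimal sketch that repeats the Expert definition so it runs on its own; the helper name count_params is introduced here for illustration and is not part of the article's code.

```python
import torch.nn as nn


class Expert(nn.Module):
    """Two-layer MLP expert, as defined above."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)


def count_params(module):
    # Hypothetical helper: total number of trainable scalars in a module.
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


expert = Expert(256, 128, 256)
per_expert = count_params(expert)   # (256*128 + 128) + (128*256 + 256)
total_experts = 8 * per_expert
print(per_expert, total_experts)
```

Each nn.Linear contributes in_features × out_features weights plus out_features biases, so the two layers give 32,896 + 33,024 = 65,920 parameters per expert.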

Gate Computation

class TopKGating(nn.Module):
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k  # read by forward() when calling torch.topk
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        self.noise_weights = nn.Parameter(torch.zeros(num_experts))

No bias. The gate is created with bias=False. A bias term would give certain experts a default preference independent of the input, which would work against the load balancing objective.

Learnable noise. The noise_weights parameter holds one learnable value per expert, which controls how much exploration noise is added to that expert's logit. The values are initialized to zero and learned during training.

Noisy Softmax Gating

# Method of TopKGating (continues the class above)
def forward(self, x):
    logits = self.gate(x)  # (batch, num_experts)

    if self.training:
        # Per-expert Gaussian noise, scaled by a learned positive factor
        noise = torch.randn_like(logits)
        noise_scale = F.softplus(self.noise_weights)
        logits = logits + noise * noise_scale.unsqueeze(0)

    gate_probs = F.softmax(logits, dim=-1)
    # Keep the k largest probabilities per input, then renormalize them to sum to 1
    top_k_values, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
    top_k_values = top_k_values / (top_k_values.sum(dim=-1, keepdim=True) + 1e-8)

    return top_k_values, top_k_indices, gate_probs

softplus ($\log(1 + e^x)$) ensures the noise scale is always positive. Noise is added only while the module is in training mode (self.training is True); after model.eval(), routing is deterministic.
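
One consequence of combining a zero initialization with softplus is that the initial noise scale is $\log 2 \approx 0.69$ per expert, not zero. The sketch below checks this and shows how the per-expert scale broadcasts over the batch. The batch size of 4, the all-zero stand-in logits, and the seed are arbitrary choices for illustration.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_experts = 8
noise_weights = torch.zeros(num_experts)      # same initialization as TopKGating
noise_scale = F.softplus(noise_weights)       # log(1 + e^0) = log 2 for every expert

logits = torch.zeros(4, num_experts)          # stand-in gate logits for a batch of 4
noise = torch.randn_like(logits)
# noise_scale has shape (8,); unsqueeze(0) makes it (1, 8) so it broadcasts over the batch
noisy_logits = logits + noise * noise_scale.unsqueeze(0)

print(noise_scale[0].item(), math.log(2.0))
print(noisy_logits.shape)
```

So exploration noise is already active at the first training step, and training can raise or lower it per expert from there.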

Top-$k$ Selection and Sparse Routing

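The torch.topk operation is the core sparsification step: for each input it keeps the $k$ largest gate probabilities and discards the rest. The selected gate values are then renormalized so that they sum to 1: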

$$ \tilde{g}_i = \frac{g_i}{\sum_{j \in \text{TopK}} g_j} $$
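
A small worked example may make the renormalization concrete. The gate probabilities below are made up for illustration, and the $10^{-8}$ stabilizer used in the forward pass is omitted to keep the arithmetic exact.

```python
import torch

# Hypothetical gate probabilities for one token over 4 experts.
gate_probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])

top_k_values, top_k_indices = torch.topk(gate_probs, k=2, dim=-1)
renormalized = top_k_values / top_k_values.sum(dim=-1, keepdim=True)

print(top_k_indices.tolist())   # experts 0 and 1 are selected
print(renormalized)             # 0.5/0.8 and 0.3/0.8
```

The two surviving weights become 0.625 and 0.375, so the output of the MoE layer stays a convex combination of the selected experts' outputs.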

Computing the Sparse Output

# Compute all expert outputs
expert_outputs = torch.stack(
    [expert(x) for expert in self.experts], dim=1
)  # (batch, num_experts, output_dim)

# Gather selected expert outputs
indices_expanded = top_k_indices.unsqueeze(-1).expand(
    -1, -1, expert_outputs.size(-1)
)
selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)

# Weighted sum
weights = top_k_values.unsqueeze(-1)  # (batch, top_k, 1)
output = (selected_outputs * weights).sum(dim=1)

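Implementation note: For clarity, this code computes every expert's output and then selects the top-$k$, so it saves no computation relative to running all experts densely. Production systems (MegaBlocks, Tutel) compute only the selected experts, using custom CUDA kernels.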
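
To illustrate what computing only the selected experts means, here is a plain Python loop that sends each expert just the tokens routed to it, and checks the result against the dense compute-then-gather version. This is a sketch only: it is not how MegaBlocks or Tutel are implemented, and the single-layer stand-in experts, random routing, and small dimensions are assumptions made to keep it short.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

batch, dim, num_experts, top_k = 6, 16, 4, 2
# Stand-in experts (single linear layers) to keep the example small.
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])
x = torch.randn(batch, dim)

# Stand-in routing decisions (normally produced by TopKGating).
gate_probs = torch.softmax(torch.randn(batch, num_experts), dim=-1)
top_k_values, top_k_indices = torch.topk(gate_probs, top_k, dim=-1)
top_k_values = top_k_values / top_k_values.sum(dim=-1, keepdim=True)

# Dense reference: run every expert on every token, then gather.
expert_outputs = torch.stack([e(x) for e in experts], dim=1)
idx = top_k_indices.unsqueeze(-1).expand(-1, -1, dim)
dense_out = (torch.gather(expert_outputs, 1, idx) * top_k_values.unsqueeze(-1)).sum(dim=1)

# Sparse loop: each expert only processes the tokens that selected it.
sparse_out = torch.zeros(batch, dim)
for e, expert in enumerate(experts):
    # token_idx: which tokens chose expert e; slot_idx: in which top-k slot
    token_idx, slot_idx = torch.where(top_k_indices == e)
    if token_idx.numel() == 0:
        continue  # no token was routed to this expert
    weighted = expert(x[token_idx]) * top_k_values[token_idx, slot_idx].unsqueeze(-1)
    sparse_out.index_add_(0, token_idx, weighted)

print(torch.allclose(dense_out, sparse_out, atol=1e-5))
```

The loop performs batch × top_k expert evaluations instead of batch × num_experts, at the cost of irregular per-expert batch sizes, which is the part specialized kernels are designed to handle efficiently.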

Auxiliary Load Balancing Loss

batch_size = x.size(0)

# f_i: fraction of tokens that selected expert i (sums to top_k over experts)
expert_mask = torch.zeros(batch_size, self.num_experts, device=x.device)
expert_mask.scatter_add_(1, top_k_indices,
    torch.ones_like(top_k_indices, dtype=torch.float))
f = expert_mask.sum(dim=0) / batch_size

# p_i: average gate probability for expert i
p = gate_probs.mean(dim=0)

# Load balancing loss
auxiliary_loss = self.num_experts * (f * p).sum()

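scatter_add_ builds a multi-hot matrix: row $b$ has a 1 in each of the top_k columns selected for input $b$ and 0 elsewhere. Because this mask is built from integer indices, $f$ carries no gradient; the loss is differentiable only through $p$.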
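
To get a feel for the scale of this loss, the sketch below evaluates it on two hand-built routing patterns: perfectly balanced and fully collapsed onto two experts. The helper load_balancing_loss and both routing patterns are constructed here for illustration; the computation inside the helper follows the snippet above.

```python
import torch


def load_balancing_loss(top_k_indices, gate_probs, num_experts):
    # Same computation as the snippet above, wrapped in a hypothetical helper.
    batch_size = top_k_indices.size(0)
    expert_mask = torch.zeros(batch_size, num_experts)
    expert_mask.scatter_add_(1, top_k_indices,
                             torch.ones_like(top_k_indices, dtype=torch.float))
    f = expert_mask.sum(dim=0) / batch_size   # sums to top_k: each token picks top_k experts
    p = gate_probs.mean(dim=0)                # sums to 1
    return num_experts * (f * p).sum()


num_experts, batch_size = 8, 8

# Balanced: token b picks experts b and (b + 1) mod 8; gate probabilities are uniform.
tokens = torch.arange(batch_size)
balanced_idx = torch.stack([tokens, (tokens + 1) % num_experts], dim=1)
balanced_probs = torch.full((batch_size, num_experts), 1.0 / num_experts)

# Collapsed: every token picks experts 0 and 1, with all probability mass on them.
collapsed_idx = torch.tensor([[0, 1]] * batch_size)
collapsed_probs = torch.zeros(batch_size, num_experts)
collapsed_probs[:, :2] = 0.5

balanced_loss = load_balancing_loss(balanced_idx, balanced_probs, num_experts)
collapsed_loss = load_balancing_loss(collapsed_idx, collapsed_probs, num_experts)
print(balanced_loss.item(), collapsed_loss.item())
```

With this definition of $f$, balanced routing gives a loss of top_k (2.0 here) rather than 1, and the collapsed pattern gives 8.0, so the loss penalizes concentration as intended.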

The Complete MoE Classifier

class MoEClassifier(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256,
                 num_classes=10, num_experts=8, top_k=2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.moe = MoELayer(hidden_dim, hidden_dim,
                            num_experts=num_experts, top_k=top_k,
                            expert_hidden_dim=128)
        self.output_proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = F.relu(self.input_proj(x))
        x, aux_loss, expert_counts = self.moe(x)
        x = F.relu(x)
        logits = self.output_proj(x)
        return logits, aux_loss, expert_counts
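
The classifier relies on MoELayer, whose pieces appear above only as separate fragments. The following is one way those fragments might be assembled, offered as a sketch rather than the article's definitive implementation. The constructor signature is inferred from the call in MoEClassifier, and the meaning of expert_counts (number of tokens routed to each expert) is an assumption.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)


class TopKGating(nn.Module):
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        self.noise_weights = nn.Parameter(torch.zeros(num_experts))

    def forward(self, x):
        logits = self.gate(x)
        if self.training:
            logits = logits + torch.randn_like(logits) * F.softplus(self.noise_weights)
        gate_probs = F.softmax(logits, dim=-1)
        top_k_values, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        top_k_values = top_k_values / (top_k_values.sum(dim=-1, keepdim=True) + 1e-8)
        return top_k_values, top_k_indices, gate_probs


class MoELayer(nn.Module):
    # Assumed signature, inferred from the MoEClassifier call above.
    def __init__(self, input_dim, output_dim, num_experts=8, top_k=2,
                 expert_hidden_dim=128):
        super().__init__()
        self.num_experts = num_experts
        self.gating = TopKGating(input_dim, num_experts, top_k)
        self.experts = nn.ModuleList(
            [Expert(input_dim, expert_hidden_dim, output_dim) for _ in range(num_experts)]
        )

    def forward(self, x):
        batch_size = x.size(0)
        top_k_values, top_k_indices, gate_probs = self.gating(x)

        # Dense compute-then-gather, as in "Computing the Sparse Output".
        expert_outputs = torch.stack([e(x) for e in self.experts], dim=1)
        idx = top_k_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(-1))
        selected = torch.gather(expert_outputs, 1, idx)
        output = (selected * top_k_values.unsqueeze(-1)).sum(dim=1)

        # Load balancing loss, as in "Auxiliary Load Balancing Loss".
        expert_mask = torch.zeros(batch_size, self.num_experts, device=x.device)
        expert_mask.scatter_add_(1, top_k_indices,
                                 torch.ones_like(top_k_indices, dtype=torch.float))
        expert_counts = expert_mask.sum(dim=0)   # assumption: tokens routed to each expert
        f = expert_counts / batch_size
        p = gate_probs.mean(dim=0)
        auxiliary_loss = self.num_experts * (f * p).sum()
        return output, auxiliary_loss, expert_counts


layer = MoELayer(256, 256, num_experts=8, top_k=2, expert_hidden_dim=128)
out, aux, counts = layer(torch.randn(32, 256))
print(out.shape, aux.item(), counts.sum().item())
```

For a batch of 32 with top_k=2, the counts sum to 64 routing assignments, and the output keeps the (batch, 256) shape that the classifier's output projection expects.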

Parameter Accounting

| Component | Parameters | Active per Input |
|---|---|---|
| Input projection ($784 \to 256$) | 200,960 | 200,960 |
| Gating network ($256 \to 8$) | 2,056 | 2,056 |
| 8 Experts ($256 \to 128 \to 256$ each) | 527,360 | 131,840 (2 of 8) |
| Output projection ($256 \to 10$) | 2,570 | 2,570 |
| Total | 732,946 | 337,426 |
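
The totals can be reproduced with a few lines of arithmetic. In this sketch the helper linear_params is introduced for illustration, and the gating count is taken to be the 2,048 bias-free gate weights plus the 8 noise_weights, which matches the 2,056 listed.

```python
def linear_params(in_features, out_features, bias=True):
    # Weights plus optional biases of a fully connected layer.
    return in_features * out_features + (out_features if bias else 0)


num_experts, top_k = 8, 2

input_proj = linear_params(784, 256)
gating = linear_params(256, num_experts, bias=False) + num_experts  # gate weights + noise_weights
per_expert = linear_params(256, 128) + linear_params(128, 256)
output_proj = linear_params(256, 10)

total = input_proj + gating + num_experts * per_expert + output_proj
active = input_proj + gating + top_k * per_expert + output_proj
print(total, active, f"{active / total:.1%}")
```

Roughly 46% of the parameters take part in any single input's forward pass when only the two selected experts are counted.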

Summary

Part 3 trains this model on MNIST and compares it against a dense baseline -- accuracy, expert specialization, and parameter efficiency.