In Part 1 we derived the mathematical foundations of Mixture of Experts: top-$k$ gating, noisy exploration, and load balancing loss. This installment translates every equation into working PyTorch code. We build four modules from scratch -- Expert, TopKGating, MoELayer, and MoEClassifier -- with no external MoE libraries.
## Architecture Overview

The full classifier is a three-stage pipeline: an input projection ($784 \to 256$) with ReLU, a sparse MoE layer with 8 experts and top-2 routing, and an output projection ($256 \to 10$) that produces class logits.
## Expert Design
Each expert is deliberately simple -- a two-layer MLP with ReLU activation:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)
```
Each expert maps $\mathbb{R}^{256} \to \mathbb{R}^{128} \to \mathbb{R}^{256}$ with a bottleneck hidden dimension. With 8 experts, the total expert parameters are $8 \times 65{,}920 = 527{,}360$. The philosophy of MoE is that many small specialists outperform one large generalist when combined with intelligent routing.
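To sanity-check those figures, the per-expert parameter count can be reproduced with a throwaway snippet (a standalone check, not part of the model code):

```python
import torch.nn as nn

# One expert: 256 -> 128 -> 256 bottleneck MLP
expert = nn.Sequential(
    nn.Linear(256, 128),  # 256*128 + 128 = 32,896 params
    nn.ReLU(),
    nn.Linear(128, 256),  # 128*256 + 256 = 33,024 params
)
n_params = sum(p.numel() for p in expert.parameters())
print(n_params)      # 65920 per expert
print(8 * n_params)  # 527360 across 8 experts
```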
## Gate Computation
```python
class TopKGating(nn.Module):
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        self.noise_weights = nn.Parameter(torch.zeros(num_experts))
```
**No bias.** The gate uses `bias=False` because a bias term would create a default preference for certain experts independent of the input, which would fight against the load balancing objective.

**Learnable noise.** The `noise_weights` parameter controls how much exploration noise is added to each expert's logit; it is initialized to zero and learned during training.
### Noisy Softmax Gating
```python
    def forward(self, x):
        logits = self.gate(x)  # (batch, num_experts)

        if self.training:
            noise = torch.randn_like(logits)
            noise_scale = F.softplus(self.noise_weights)
            logits = logits + noise * noise_scale.unsqueeze(0)

        gate_probs = F.softmax(logits, dim=-1)
        top_k_values, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
        # Renormalize the selected gate values so they sum to 1
        top_k_values = top_k_values / (top_k_values.sum(dim=-1, keepdim=True) + 1e-8)
        return top_k_values, top_k_indices, gate_probs
```
The softplus function ($\log(1 + e^x)$) ensures the noise scale is always positive. Note that noise is only added during training (self.training check). At inference time, routing is deterministic.
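A quick standalone sketch (made-up weight values, not from the model) shows the softplus transform in action. One subtlety worth noting: since $\text{softplus}(0) = \log 2 \approx 0.693$, the zero-initialized noise weights produce a nonzero noise scale, so some exploration happens from the very first training step.

```python
import torch
import torch.nn.functional as F

# softplus(x) = log(1 + e^x) is strictly positive for every real input,
# so the learned noise scale can never go negative.
noise_weights = torch.tensor([-2.0, 0.0, 2.0])  # hypothetical values
scale = F.softplus(noise_weights)
print(scale)  # approximately [0.127, 0.693, 2.127]
```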
## Top-$k$ Selection and Sparse Routing
The `torch.topk` operation is the core sparsification step; the selected gate values are then renormalized so the routing weights sum to one.
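A toy example makes the selection and renormalization concrete (k=2 over 4 experts; the probabilities are invented for illustration):

```python
import torch

# A single token's gate distribution over 4 experts
gate_probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
top_k_values, top_k_indices = torch.topk(gate_probs, k=2, dim=-1)
renormalized = top_k_values / top_k_values.sum(dim=-1, keepdim=True)
print(top_k_indices)  # tensor([[0, 1]])
print(renormalized)   # tensor([[0.6250, 0.3750]]), i.e. 0.50/0.80 and 0.30/0.80
```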
### Computing the Sparse Output
```python
# Compute all expert outputs
expert_outputs = torch.stack(
    [expert(x) for expert in self.experts], dim=1
)  # (batch, num_experts, output_dim)

# Gather the outputs of the selected experts
indices_expanded = top_k_indices.unsqueeze(-1).expand(
    -1, -1, expert_outputs.size(-1)
)
selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)

# Weighted sum over the top-k experts
weights = top_k_values.unsqueeze(-1)  # (batch, top_k, 1)
output = (selected_outputs * weights).sum(dim=1)
```
Implementation note: In this educational implementation, we compute all expert outputs and then select. In production systems (e.g., Megablocks, Tutel), only the selected experts are actually computed using custom CUDA kernels.
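To convince yourself the gather-based combination does the right thing, it can be checked against a naive per-token loop on random tensors (a standalone sketch with made-up shapes, independent of the model classes):

```python
import torch

batch, num_experts, dim, top_k = 4, 8, 16, 2
expert_outputs = torch.randn(batch, num_experts, dim)
top_k_indices = torch.randint(0, num_experts, (batch, top_k))
top_k_values = torch.rand(batch, top_k)

# Vectorized: expand indices, gather, weighted sum
idx = top_k_indices.unsqueeze(-1).expand(-1, -1, dim)
gathered = torch.gather(expert_outputs, 1, idx)
output = (gathered * top_k_values.unsqueeze(-1)).sum(dim=1)

# Reference: explicit loop over tokens and their selected experts
ref = torch.zeros(batch, dim)
for b in range(batch):
    for k in range(top_k):
        ref[b] += top_k_values[b, k] * expert_outputs[b, top_k_indices[b, k]]

print(torch.allclose(output, ref, atol=1e-6))  # True
```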
## Auxiliary Load Balancing Loss
```python
# f_i: fraction of tokens routed to expert i
expert_mask = torch.zeros(batch_size, self.num_experts, device=x.device)
expert_mask.scatter_add_(1, top_k_indices,
                         torch.ones_like(top_k_indices, dtype=torch.float))
f = expert_mask.sum(dim=0) / batch_size

# p_i: average gate probability for expert i
p = gate_probs.mean(dim=0)

# Load balancing loss
auxiliary_loss = self.num_experts * (f * p).sum()
```
The scatter_add_ operation efficiently builds a one-hot-like matrix indicating which experts were selected for each token in the batch.
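On a toy batch the mechanics are easy to trace (hypothetical routing indices: 4 tokens, 4 experts, top-2). Note that with top-$k$ routing, $f$ sums to $k$ rather than 1, since each token counts toward two experts:

```python
import torch

batch_size, num_experts = 4, 4
# Each row lists the two experts chosen for that token
top_k_indices = torch.tensor([[0, 1], [0, 2], [0, 1], [3, 1]])

expert_mask = torch.zeros(batch_size, num_experts)
expert_mask.scatter_add_(1, top_k_indices,
                         torch.ones_like(top_k_indices, dtype=torch.float))
f = expert_mask.sum(dim=0) / batch_size
print(f)  # tensor([0.7500, 0.7500, 0.2500, 0.2500]) -- sums to top_k = 2
```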
## The Complete MoE Classifier
```python
class MoEClassifier(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256,
                 num_classes=10, num_experts=8, top_k=2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.moe = MoELayer(hidden_dim, hidden_dim,
                            num_experts=num_experts, top_k=top_k,
                            expert_hidden_dim=128)
        self.output_proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = F.relu(self.input_proj(x))
        x, aux_loss, expert_counts = self.moe(x)
        x = F.relu(x)
        logits = self.output_proj(x)
        return logits, aux_loss, expert_counts
```
## Parameter Accounting
| Component | Parameters | Active per Input |
|---|---|---|
| Input projection ($784 \to 256$) | 200,960 | 200,960 |
| Gating network ($256 \to 8$) | 2,056 | 2,056 |
| 8 Experts ($256 \to 128 \to 256$ each) | 527,360 | 131,840 (2 of 8) |
| Output projection ($256 \to 10$) | 2,570 | 2,570 |
| Total | 732,946 | 337,426 |
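The table's arithmetic can be cross-checked with plain Python (a standalone recomputation, not model code):

```python
# Parameter counts: weights + biases for each nn.Linear
input_proj  = 784 * 256 + 256                        # 200,960
gating      = 256 * 8 + 8                            # 2,048 gate weights (no bias) + 8 noise params
one_expert  = (256 * 128 + 128) + (128 * 256 + 256)  # 65,920
experts     = 8 * one_expert                         # 527,360
output_proj = 256 * 10 + 10                          # 2,570

total  = input_proj + gating + experts + output_proj
active = input_proj + gating + 2 * one_expert + output_proj  # only 2 of 8 experts run
print(total, active)  # 732946 337426
```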
## Summary and Next Steps
- Expert: 2-layer MLP with bottleneck ($256 \to 128 \to 256$).
- TopKGating: Linear projection + noisy softmax + top-$k$ selection.
- MoELayer: Sparse routing, weighted expert combination, load balancing loss.
- MoEClassifier: Complete MNIST classifier with 732,946 total params and 337,426 active.
In Part 3, we train this model on MNIST and analyze the results: accuracy compared to a dense baseline, expert specialization patterns, and parameter efficiency.