In Part 1 we derived the mathematical foundations of Mixture of Experts: top-$k$ gating, noisy exploration, and load balancing loss. This installment translates every equation into working PyTorch code. We build four modules from scratch -- Expert, TopKGating, MoELayer, and MoEClassifier -- with no external MoE libraries.
Architecture Overview
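The full classifier architecture:

- Input projection: flattened 28×28 MNIST image ($784 \to 256$), followed by ReLU.
- MoE layer: a gating network routes each input to the top 2 of 8 experts ($256 \to 256$), followed by ReLU.
- Output projection: $256 \to 10$ class logits.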
Expert Design
Each expert is deliberately simple -- a two-layer MLP with ReLU activation:
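```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    """Two-layer MLP: input_dim -> hidden_dim -> output_dim."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)
```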
Each expert maps $\mathbb{R}^{256} \to \mathbb{R}^{128} \to \mathbb{R}^{256}$ with a bottleneck hidden dimension. With 8 experts, the total expert parameters are $8 \times 65{,}920 = 527{,}360$.
Gate Computation
```python
class TopKGating(nn.Module):
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k  # read by forward()
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        # Raw (pre-softplus) per-expert noise scales
        self.noise_weights = nn.Parameter(torch.zeros(num_experts))
```
No bias. The gate uses bias=False. A bias term would create a default preference for certain experts independent of the input, which would fight against the load balancing objective.
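Learnable noise. The noise_weights parameter controls how much exploration noise is added to each expert's logit. The raw parameters are initialized to zero and learned during training. The forward pass applies softplus to them, so the initial noise scale is softplus(0) $= \ln 2 \approx 0.69$, not zero.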
Noisy Softmax Gating
```python
# TopKGating.forward
def forward(self, x):
    logits = self.gate(x)  # (batch, num_experts)
    if self.training:
        noise = torch.randn_like(logits)
        noise_scale = F.softplus(self.noise_weights)
        logits = logits + noise * noise_scale.unsqueeze(0)
    gate_probs = F.softmax(logits, dim=-1)
    top_k_values, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
    top_k_values = top_k_values / (top_k_values.sum(dim=-1, keepdim=True) + 1e-8)
    return top_k_values, top_k_indices, gate_probs
```
softplus ($\log(1 + e^x)$) ensures the noise scale is always positive. Noise is only added during training; at inference time, routing is deterministic.
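Here is a minimal sketch of the two modes. Random logits stand in for self.gate(x), and noise_weights is zero-initialized. `route` is a hypothetical helper that mirrors the routing part of TopKGating.forward. The sketch shows the initial noise scale, and that eval-mode routing is repeatable while training-mode routing varies from call to call.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_experts, top_k = 8, 2
logits = torch.randn(4, num_experts)      # stand-in for self.gate(x) on 4 inputs
noise_weights = torch.zeros(num_experts)  # raw noise parameters at initialization

# softplus(0) = log(1 + e^0) = ln 2, so every expert starts with scale about 0.693
noise_scale = F.softplus(noise_weights)


def route(logits, training):
    """Hypothetical helper mirroring TopKGating.forward; returns top-k expert indices."""
    if training:
        logits = logits + torch.randn_like(logits) * noise_scale.unsqueeze(0)
    gate_probs = F.softmax(logits, dim=-1)
    return torch.topk(gate_probs, top_k, dim=-1).indices


print(noise_scale[0].item())  # about 0.6931

# Inference: no noise, so repeated calls give identical routing
clean = route(logits, training=False)
print(torch.equal(clean, route(logits, training=False)))  # True

# Training: fresh noise on every call can change which experts are selected
changed = sum(not torch.equal(route(logits, training=True), clean) for _ in range(100))
print(f"noisy routing differed from clean routing in {changed} of 100 calls")
```

In the real module the switch is controlled by self.training, which model.train() and model.eval() set.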
Top-$k$ Selection and Sparse Routing
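The torch.topk operation is the core sparsification step. It keeps the $k$ largest gate probabilities for each input and discards the rest. The selected gate values are then renormalized so they sum to one:

$$\hat{g}_i(x) = \frac{g_i(x)}{\sum_{j \in \mathcal{T}_k(x)} g_j(x)}, \qquad i \in \mathcal{T}_k(x),$$

Here $g(x)$ is the softmax gate distribution and $\mathcal{T}_k(x)$ is the set of $k$ selected expert indices. The code adds $10^{-8}$ to the denominator to guard against division by zero.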
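Here is a worked example with made-up gate probabilities for one input and four experts. Top-2 keeps 0.5 and 0.3, and renormalizing divides each by their sum, 0.8.

```python
import torch

# One input, four experts (illustrative numbers, not from a trained gate)
gate_probs = torch.tensor([[0.5, 0.3, 0.1, 0.1]])

top_k_values, top_k_indices = torch.topk(gate_probs, 2, dim=-1)
renormalized = top_k_values / (top_k_values.sum(dim=-1, keepdim=True) + 1e-8)

print(top_k_indices)  # tensor([[0, 1]])
print(renormalized)   # approximately tensor([[0.6250, 0.3750]])
```

The two experts that were not selected get weight zero, and the two survivors split the full weight in the ratio of their original probabilities.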
Computing the Sparse Output
```python
# Inside MoELayer.forward; top_k_values and top_k_indices come from the gate

# Compute all expert outputs
expert_outputs = torch.stack(
    [expert(x) for expert in self.experts], dim=1
)  # (batch, num_experts, output_dim)

# Gather selected expert outputs
indices_expanded = top_k_indices.unsqueeze(-1).expand(
    -1, -1, expert_outputs.size(-1)
)  # (batch, top_k, output_dim)
selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)

# Weighted sum
weights = top_k_values.unsqueeze(-1)  # (batch, top_k, 1)
output = (selected_outputs * weights).sum(dim=1)  # (batch, output_dim)
```
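Implementation note: here we run every expert on every input and then keep only the top-$k$ outputs. This keeps the code simple, but it costs the full dense compute, so this version does not realize the FLOP savings of sparse routing. Production systems such as MegaBlocks and Tutel dispatch each token only to its selected experts, using custom CUDA kernels.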
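One way to skip unselected experts without custom kernels is to loop over the experts and run each one only on the inputs that chose it. The sketch below is plain PyTorch and is not how MegaBlocks or Tutel work internally. `sparse_moe_output` is a hypothetical helper, and the tiny sizes (16-dim inputs, 4 experts) are arbitrary. The sketch checks its result against the compute-all-then-gather approach.

```python
import torch
import torch.nn as nn


def sparse_moe_output(x, experts, top_k_values, top_k_indices, output_dim):
    """Hypothetical helper: run each expert only on the inputs routed to it.

    x: (batch, input_dim); top_k_values, top_k_indices: (batch, top_k).
    """
    output = x.new_zeros(x.size(0), output_dim)
    for e, expert in enumerate(experts):
        # (row, slot) positions where expert e was selected
        rows, slots = (top_k_indices == e).nonzero(as_tuple=True)
        if rows.numel() == 0:
            continue  # no input chose expert e, so it is never evaluated
        expert_out = expert(x[rows])                       # (n_routed, output_dim)
        weights = top_k_values[rows, slots].unsqueeze(-1)  # (n_routed, 1)
        output.index_add_(0, rows, expert_out * weights)
    return output


torch.manual_seed(0)
experts = nn.ModuleList(
    [nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 16)) for _ in range(4)]
)
x = torch.randn(5, 16)
gate_probs = torch.softmax(torch.randn(5, 4), dim=-1)
top_k_values, top_k_indices = torch.topk(gate_probs, 2, dim=-1)
top_k_values = top_k_values / top_k_values.sum(dim=-1, keepdim=True)

# Dense reference: compute all experts, gather the top-k, weighted sum
expert_outputs = torch.stack([expert(x) for expert in experts], dim=1)
idx = top_k_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(-1))
dense = (torch.gather(expert_outputs, 1, idx) * top_k_values.unsqueeze(-1)).sum(dim=1)

sparse = sparse_moe_output(x, experts, top_k_values, top_k_indices, output_dim=16)
print(torch.allclose(dense, sparse, atol=1e-6))  # True
```

The Python loop trades the wasted expert FLOPs for one small kernel launch per expert. That is a reasonable trade-off for a handful of experts, but it is the overhead that fused kernels are designed to remove.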
Auxiliary Load Balancing Loss
```python
batch_size = x.size(0)

# f_i: fraction of tokens routed to expert i
expert_mask = torch.zeros(batch_size, self.num_experts, device=x.device)
expert_mask.scatter_add_(1, top_k_indices,
                         torch.ones_like(top_k_indices, dtype=torch.float))
f = expert_mask.sum(dim=0) / batch_size

# p_i: average gate probability for expert i
p = gate_probs.mean(dim=0)

# Load balancing loss
auxiliary_loss = self.num_experts * (f * p).sum()
```
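scatter_add_ builds a multi-hot (batch, num_experts) mask with a 1 in each of a token's $k$ selected expert columns. Summing over the batch and dividing by batch_size therefore gives, for each expert, the fraction of tokens that selected it. Because f comes from the discrete top-$k$ indices, it carries no gradient; the loss trains the gate through p, which is differentiable.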
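To see what values the loss takes, the sketch below wraps the listing above in a standalone function. `load_balancing_loss` is a hypothetical name, and the routings are built by hand with 4 experts and top-2 rather than produced by a real gate. Each input sets $k$ entries of the mask, so the entries of f sum to top_k. Perfectly balanced routing therefore gives $N \sum_{i=1}^{N} \frac{k}{N} \cdot \frac{1}{N} = k = 2.0$ rather than 1. Collapsing every input onto the same two experts gives $N = 4.0$.

```python
import torch


def load_balancing_loss(gate_probs, top_k_indices, num_experts):
    """Hypothetical standalone version of the auxiliary loss computation above."""
    batch_size = gate_probs.size(0)
    expert_mask = torch.zeros(batch_size, num_experts)
    expert_mask.scatter_add_(1, top_k_indices,
                             torch.ones_like(top_k_indices, dtype=torch.float))
    f = expert_mask.sum(dim=0) / batch_size  # fraction of inputs choosing each expert
    p = gate_probs.mean(dim=0)               # mean gate probability per expert
    return num_experts * (f * p).sum()


# Balanced: 4 inputs, 4 experts, top-2; every expert chosen by half the inputs
balanced_idx = torch.tensor([[0, 1], [2, 3], [0, 1], [2, 3]])
balanced_probs = torch.full((4, 4), 0.25)

# Collapsed: every input routes to experts 0 and 1
collapsed_idx = torch.tensor([[0, 1]] * 4)
collapsed_probs = torch.tensor([[0.5, 0.5, 0.0, 0.0]] * 4)

balanced = load_balancing_loss(balanced_probs, balanced_idx, 4).item()
collapsed = load_balancing_loss(collapsed_probs, collapsed_idx, 4).item()
print(balanced, collapsed)  # 2.0 4.0
```

Minimizing the loss therefore pushes routing toward the balanced case. Its floor is top_k, not zero.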
The Complete MoE Classifier
```python
class MoEClassifier(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256,
                 num_classes=10, num_experts=8, top_k=2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.moe = MoELayer(hidden_dim, hidden_dim,
                            num_experts=num_experts, top_k=top_k,
                            expert_hidden_dim=128)
        self.output_proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = F.relu(self.input_proj(x))
        x, aux_loss, expert_counts = self.moe(x)
        x = F.relu(x)
        logits = self.output_proj(x)
        return logits, aux_loss, expert_counts
```
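MoELayer appears above only as fragments, and its constructor is never shown. The sketch below assembles all four modules into one runnable file so the classifier can be smoke-tested. Several parts are assumptions, not from the original:

- The MoELayer signature is inferred from the MoEClassifier call.
- The gate attribute name `self.gating` is hypothetical.
- expert_counts is taken to be the per-expert selection count, `expert_mask.sum(dim=0)`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)


class TopKGating(nn.Module):
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.top_k = top_k
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        self.noise_weights = nn.Parameter(torch.zeros(num_experts))

    def forward(self, x):
        logits = self.gate(x)
        if self.training:
            noise_scale = F.softplus(self.noise_weights)
            logits = logits + torch.randn_like(logits) * noise_scale.unsqueeze(0)
        gate_probs = F.softmax(logits, dim=-1)
        values, indices = torch.topk(gate_probs, self.top_k, dim=-1)
        values = values / (values.sum(dim=-1, keepdim=True) + 1e-8)
        return values, indices, gate_probs


class MoELayer(nn.Module):
    # Assumed signature, inferred from the MoEClassifier call
    def __init__(self, input_dim, output_dim, num_experts=8, top_k=2, expert_hidden_dim=128):
        super().__init__()
        self.num_experts = num_experts
        self.gating = TopKGating(input_dim, num_experts, top_k)  # attribute name assumed
        self.experts = nn.ModuleList(
            [Expert(input_dim, expert_hidden_dim, output_dim) for _ in range(num_experts)]
        )

    def forward(self, x):
        top_k_values, top_k_indices, gate_probs = self.gating(x)
        batch_size = x.size(0)
        # Dense compute, then gather the top-k outputs and mix them
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
        idx = top_k_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(-1))
        selected = torch.gather(expert_outputs, 1, idx)
        output = (selected * top_k_values.unsqueeze(-1)).sum(dim=1)
        # Load balancing loss
        expert_mask = torch.zeros(batch_size, self.num_experts, device=x.device)
        expert_mask.scatter_add_(1, top_k_indices,
                                 torch.ones_like(top_k_indices, dtype=torch.float))
        f = expert_mask.sum(dim=0) / batch_size
        p = gate_probs.mean(dim=0)
        aux_loss = self.num_experts * (f * p).sum()
        expert_counts = expert_mask.sum(dim=0)  # assumed: selections per expert
        return output, aux_loss, expert_counts


class MoEClassifier(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, num_classes=10, num_experts=8, top_k=2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.moe = MoELayer(hidden_dim, hidden_dim, num_experts=num_experts,
                            top_k=top_k, expert_hidden_dim=128)
        self.output_proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = F.relu(self.input_proj(x))
        x, aux_loss, expert_counts = self.moe(x)
        return self.output_proj(F.relu(x)), aux_loss, expert_counts


model = MoEClassifier()
logits, aux_loss, expert_counts = model(torch.randn(32, 784))  # random stand-in batch
print(logits.shape)                                # torch.Size([32, 10])
print(expert_counts.sum().item())                  # 64.0 (32 inputs x top-2)
print(sum(p.numel() for p in model.parameters()))  # 732946
```

The forward pass on a random batch of 32 checks the output shapes. It also confirms that the parameter total agrees with the Parameter Accounting table.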
Parameter Accounting
| Component | Parameters | Active per Input |
|---|---|---|
| Input projection ($784 \to 256$) | 200,960 | 200,960 |
| Gating network ($256 \to 8$) | 2,056 | 2,056 |
| 8 Experts ($256 \to 128 \to 256$ each) | 527,360 | 131,840 (2 of 8) |
| Output projection ($256 \to 10$) | 2,570 | 2,570 |
| Total | 732,946 | 337,426 |
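The table can be reproduced from the layer shapes alone. The sketch below is plain Python, and `linear_params` is a hypothetical helper. It relies on four facts:

- An nn.Linear(n_in, n_out) has n_in × n_out weights plus n_out biases.
- The gate has no bias but adds its 8 noise_weights.
- The "active" column counts the full gate plus top_k experts, since the gate scores all 8 experts for every input.
- "Active" means contributing to the output. The compute-all-then-select implementation above still evaluates all eight experts.

```python
def linear_params(n_in, n_out, bias=True):
    """Hypothetical helper: weights plus optional bias of an nn.Linear(n_in, n_out)."""
    return n_in * n_out + (n_out if bias else 0)


num_experts, top_k = 8, 2
input_proj = linear_params(784, 256)
gating = linear_params(256, num_experts, bias=False) + num_experts  # + noise_weights
expert = linear_params(256, 128) + linear_params(128, 256)
output_proj = linear_params(256, 10)

total = input_proj + gating + num_experts * expert + output_proj
active = input_proj + gating + top_k * expert + output_proj
print(expert, total, active)  # 65920 732946 337426
```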
Summary
- Expert: 2-layer MLP with bottleneck ($256 \to 128 \to 256$).
- TopKGating: Linear projection + noisy softmax + top-$k$ selection.
- MoELayer: Sparse routing, weighted expert combination, load balancing loss.
- MoEClassifier: 732,946 total params, 337,426 active per input.
Part 3 trains this model on MNIST and compares it against a dense baseline -- accuracy, expert specialization, and parameter efficiency.