Deconstructing MoE from Scratch

Part 2: PyTorch Implementation

Introduction

In Part 1 we derived the mathematical foundations of Mixture of Experts: top-$k$ gating, noisy exploration, and load balancing loss. This installment translates every equation into working PyTorch code. We build four modules from scratch -- Expert, TopKGating, MoELayer, and MoEClassifier -- with no external MoE libraries.

Architecture Overview

The full classifier architecture:

$$ \text{Linear}(784 \to 256) \to \text{ReLU} \to \text{MoE}(256 \to 256) \to \text{ReLU} \to \text{Linear}(256 \to 10) $$

Expert Design

Each expert is deliberately simple -- a two-layer MLP with ReLU activation:

class Expert(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, x):
        return self.net(x)

Each expert maps $\mathbb{R}^{256} \to \mathbb{R}^{128} \to \mathbb{R}^{256}$ with a bottleneck hidden dimension. With 8 experts, the total expert parameters are $8 \times 65{,}920 = 527{,}360$. The philosophy of MoE is that many small specialists outperform one large generalist when combined with intelligent routing.
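The per-expert count quoted above can be verified directly with a standalone check mirroring the Expert definition:

```python
import torch.nn as nn

# Same shapes as Expert(256, 128, 256) above.
expert = nn.Sequential(
    nn.Linear(256, 128),  # 256*128 + 128 = 32,896 parameters
    nn.ReLU(),
    nn.Linear(128, 256),  # 128*256 + 256 = 33,024 parameters
)
n_params = sum(p.numel() for p in expert.parameters())
print(n_params)      # 65,920 per expert
print(8 * n_params)  # 527,360 across 8 experts
```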

Gate Computation

class TopKGating(nn.Module):
    def __init__(self, input_dim, num_experts, top_k=2):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k  # used in forward()
        self.gate = nn.Linear(input_dim, num_experts, bias=False)
        self.noise_weights = nn.Parameter(torch.zeros(num_experts))

No bias. The gate uses bias=False because a bias term would create a default preference for certain experts independent of the input, fighting the load balancing objective.

Learnable noise. The noise_weights parameter controls how much exploration noise is added to each expert's logit. These are initialized to zero and learned during training.

Noisy Softmax Gating

def forward(self, x):
    logits = self.gate(x)  # (batch, num_experts)

    if self.training:
        noise = torch.randn_like(logits)
        noise_scale = F.softplus(self.noise_weights)
        logits = logits + noise * noise_scale.unsqueeze(0)

    gate_probs = F.softmax(logits, dim=-1)
    top_k_values, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)
    top_k_values = top_k_values / (top_k_values.sum(dim=-1, keepdim=True) + 1e-8)

    return top_k_values, top_k_indices, gate_probs

The softplus function ($\log(1 + e^x)$) keeps the noise scale strictly positive. Noise is added only while self.training is true; at inference time, routing is deterministic.
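Two properties of softplus are worth checking numerically (a standalone sanity check, not part of the layer). Note in particular that zero-initialized noise_weights give an initial noise scale of $\log 2 \approx 0.693$, not zero:

```python
import torch
import torch.nn.functional as F

# softplus(0) = log(2), so the noise scale starts moderate, not zero.
print(F.softplus(torch.tensor(0.0)))    # ~0.6931

# softplus stays positive even for very negative inputs, so the learned
# scale can shrink the noise toward zero but never flip its sign.
print(F.softplus(torch.tensor(-10.0)))  # tiny but > 0
```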

Top-$k$ Selection and Sparse Routing

The torch.topk operation is the core sparsification step. The selected gate values are then renormalized:

$$ \tilde{g}_i = \frac{g_i}{\sum_{j \in \text{TopK}} g_j} $$
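A quick numeric check of this renormalization, using toy probabilities for one token and four experts:

```python
import torch

# One token, four experts, top_k = 2.
gate_probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
top_vals, top_idx = torch.topk(gate_probs, k=2, dim=-1)
print(top_vals)  # [[0.50, 0.30]] -- experts 0 and 1 are kept

# Renormalize so the kept weights sum to 1: 0.5/0.8 and 0.3/0.8.
renorm = top_vals / (top_vals.sum(dim=-1, keepdim=True) + 1e-8)
print(renorm)    # ~[[0.625, 0.375]]
```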

Computing the Sparse Output

# Compute all expert outputs
expert_outputs = torch.stack(
    [expert(x) for expert in self.experts], dim=1
)  # (batch, num_experts, output_dim)

# Gather selected expert outputs
indices_expanded = top_k_indices.unsqueeze(-1).expand(
    -1, -1, expert_outputs.size(-1)
)
selected_outputs = torch.gather(expert_outputs, 1, indices_expanded)  # (batch, top_k, output_dim)

# Weighted sum
weights = top_k_values.unsqueeze(-1)  # (batch, top_k, 1)
output = (selected_outputs * weights).sum(dim=1)

Implementation note: In this educational implementation, we compute all expert outputs and then select. In production systems (e.g., Megablocks, Tutel), only the selected experts are actually computed using custom CUDA kernels.
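A middle ground between the dense compute-all approach above and custom kernels is to dispatch tokens per expert in plain PyTorch, running each expert only on the tokens routed to it. A sketch under that assumption; the function name and argument layout are illustrative, not part of the article's MoELayer:

```python
import torch

def dispatch_forward(x, experts, top_k_values, top_k_indices, output_dim):
    # x: (batch, input_dim); top_k_values / top_k_indices: (batch, top_k)
    output = x.new_zeros(x.size(0), output_dim)
    for i, expert in enumerate(experts):
        # (token, slot) pairs routed to expert i
        token_idx, slot_idx = torch.where(top_k_indices == i)
        if token_idx.numel() == 0:
            continue  # this expert received no tokens in this batch
        weights = top_k_values[token_idx, slot_idx].unsqueeze(-1)
        # index_add_ accumulates correctly even with repeated token indices
        output.index_add_(0, token_idx, weights * expert(x[token_idx]))
    return output
```

This avoids the factor-of-num_experts/top_k wasted compute of the dense version, at the cost of a Python-level loop; production kernels fuse this dispatch on the GPU.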

Auxiliary Load Balancing Loss

# f_i: fraction of tokens routed to expert i
batch_size = x.size(0)
expert_mask = torch.zeros(batch_size, self.num_experts, device=x.device)
expert_mask.scatter_add_(1, top_k_indices,
    torch.ones_like(top_k_indices, dtype=torch.float))
f = expert_mask.sum(dim=0) / batch_size

# p_i: average gate probability for expert i
p = gate_probs.mean(dim=0)

# Load balancing loss
auxiliary_loss = self.num_experts * (f * p).sum()

The scatter_add_ operation efficiently builds a multi-hot matrix: one row per token, with ones at the top_k experts selected for that token.
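A toy example makes the mechanics concrete. Here routing is perfectly balanced (each of 4 experts receives 2 of 4 tokens) and the gate probabilities are assumed uniform, so the loss evaluates to $k = 2$, the value you expect under balanced routing:

```python
import torch

# 4 tokens, 4 experts, top_k = 2; every expert chosen exactly twice.
top_k_indices = torch.tensor([[0, 1], [2, 3], [0, 2], [1, 3]])
mask = torch.zeros(4, 4)
mask.scatter_add_(1, top_k_indices, torch.ones(4, 2))

f = mask.sum(dim=0) / 4        # fraction of tokens per expert: 0.5 each
p = torch.full((4,), 0.25)     # assumed-uniform mean gate probabilities
aux = 4 * (f * p).sum()
print(aux)  # 2.0, i.e. top_k, under perfectly balanced routing
```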

The Complete MoE Classifier

class MoEClassifier(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256,
                 num_classes=10, num_experts=8, top_k=2):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.moe = MoELayer(hidden_dim, hidden_dim,
                            num_experts=num_experts, top_k=top_k,
                            expert_hidden_dim=128)
        self.output_proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        x = F.relu(self.input_proj(x))
        x, aux_loss, expert_counts = self.moe(x)
        x = F.relu(x)
        logits = self.output_proj(x)
        return logits, aux_loss, expert_counts
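The two returned losses are combined at training time. A minimal sketch with placeholder tensors standing in for a real forward pass; alpha, the auxiliary-loss weight, is a hyperparameter this part does not fix, so the value below is an assumption:

```python
import torch
import torch.nn.functional as F

# Placeholders for one training step on a batch of 32.
logits = torch.randn(32, 10, requires_grad=True)  # model output
targets = torch.randint(0, 10, (32,))             # ground-truth labels
aux_loss = torch.tensor(2.1)                      # from the MoE layer

alpha = 0.01  # aux-loss weight; illustrative value, an assumption
total_loss = F.cross_entropy(logits, targets) + alpha * aux_loss
total_loss.backward()  # gradients flow through both terms
```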

Parameter Accounting

| Component | Parameters | Active per Input |
|---|---|---|
| Input projection ($784 \to 256$) | 200,960 | 200,960 |
| Gating network ($256 \to 8$) | 2,056 | 2,056 |
| 8 experts ($256 \to 128 \to 256$ each) | 527,360 | 131,840 (2 of 8) |
| Output projection ($256 \to 10$) | 2,570 | 2,570 |
| **Total** | 732,946 | 337,426 |
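The totals follow directly from the layer shapes; quick arithmetic confirms them:

```python
# Per-component parameter counts from the table above.
input_proj = 784 * 256 + 256                        # 200,960
gating = 256 * 8 + 8                                # 2,048 gate weights + 8 noise_weights
experts_all = 8 * (256*128 + 128 + 128*256 + 256)   # 527,360
experts_active = experts_all // 8 * 2               # 131,840 (2 of 8)
output_proj = 256 * 10 + 10                         # 2,570

print(input_proj + gating + experts_all + output_proj)     # 732,946 total
print(input_proj + gating + experts_active + output_proj)  # 337,426 active
```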

Summary and Next Steps

This part assembled the full MoE stack from scratch: simple expert MLPs, noisy top-$k$ gating, sparse combination of expert outputs, and the auxiliary load balancing loss. In Part 3, we train this model on MNIST and analyze the results: accuracy against a dense baseline, expert specialization patterns, and parameter efficiency.