Mixture of Experts (MoE) is the architectural principle behind the scaling of Mixtral and Switch Transformer, and reportedly GPT-4 (OpenAI has not confirmed its architecture). Rather than activating every parameter for every input, MoE routes each token to a small subset of specialized sub-networks (experts), increasing model capacity without proportionally increasing compute.
This first installment derives the mathematical foundations: gating functions, top-$k$ routing, and the load balancing loss that prevents expert collapse.
The Scaling Problem
Dense neural networks have a fundamental coupling between model capacity and computational cost. If a model has $P$ parameters, then processing a single input requires $\mathcal{O}(P)$ floating-point operations. Doubling the parameter count doubles the compute spent on every input.
This coupling becomes impractical at scale. A dense transformer with 1.8 trillion parameters would require proportionally enormous compute for every token -- whether that token is the word "the" or a complex reasoning step.
Conditional computation breaks this coupling: maintain a large number of parameters for capacity, but only activate a small subset of them for each input.
Mixture of Experts implements conditional computation through two components:
- Expert networks: $N$ independent sub-networks, each capable of processing the input.
- Gating network: A lightweight router that decides which experts process each input.
Gating Functions
Let $\mathbf{x} \in \mathbb{R}^d$ be an input vector. The gating network $G$ produces a probability distribution over $N$ experts:
$$G(\mathbf{x}) = \text{softmax}(\mathbf{W}_g \mathbf{x})$$
where $\mathbf{W}_g \in \mathbb{R}^{N \times d}$ is the gating weight matrix. Each component $G(\mathbf{x})_i$ represents the probability (or weight) assigned to expert $i$.
The output of the MoE layer is a weighted combination of expert outputs:
$$\mathbf{y} = \sum_{i=1}^{N} G(\mathbf{x})_i \, E_i(\mathbf{x})$$
where $E_i(\mathbf{x})$ is the output of the $i$-th expert.
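As a rough illustration of the dense formulation, the sketch below computes the softmax gate and the weighted sum of expert outputs in plain Python. The toy gating matrix and the "experts" (which simply scale the input) are assumptions made up for this example, not part of any real model.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(W, x):
    """Multiply an (N x d) matrix, stored as a list of rows, by a length-d vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def dense_moe(x, W_g, experts):
    """Dense MoE output y = sum_i G(x)_i * E_i(x); every expert is evaluated."""
    gates = softmax(matvec(W_g, x))
    outputs = [expert(x) for expert in experts]
    y = [0.0] * len(outputs[0])
    for g, out in zip(gates, outputs):
        for j, v in enumerate(out):
            y[j] += g * v
    return gates, y

# Hypothetical toy setup: d = 2, N = 3, and expert i scales its input by a constant.
W_g = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
experts = [lambda x, s=s: [s * v for v in x] for s in (1.0, 2.0, 3.0)]

gates, y = dense_moe([0.5, -0.5], W_g, experts)
print(gates)  # three positive weights that sum to 1
print(y)      # gate-weighted blend of the three expert outputs
```

Note that all three experts are called for the single input, which is exactly the cost the sparse variant is meant to avoid.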
The Problem with Dense Gating
If we compute the full weighted sum above, we have gained nothing -- every expert runs on every input, and the compute cost is $N$ times that of a single expert. The entire point of MoE is to make this sum sparse.
Top-$k$ Routing
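The fix is top-$k$ routing. Instead of consulting all $N$ experts, we select only the $k$ experts with the highest gate values:
$$\mathbf{y} = \sum_{i \in \text{TopK}(G(\mathbf{x}), k)} \tilde{G}(\mathbf{x})_i \, E_i(\mathbf{x})$$
where $\text{TopK}(G(\mathbf{x}), k)$ returns the indices of the $k$ largest components of $G(\mathbf{x})$, and $\tilde{G}$ denotes the renormalized gate values:
$$\tilde{G}(\mathbf{x})_i = \frac{G(\mathbf{x})_i}{\sum_{j \in \text{TopK}(G(\mathbf{x}), k)} G(\mathbf{x})_j}$$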
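A minimal sketch of top-$k$ selection with renormalization might look like the following. The gate probabilities and the toy experts are assumed values chosen for illustration; the `calls` list exists only to show which experts actually ran.

```python
def top_k_gates(gate_probs, k):
    """Return {expert_index: renormalized weight} for the k largest gate values."""
    order = sorted(range(len(gate_probs)), key=lambda i: gate_probs[i], reverse=True)
    top = order[:k]
    norm = sum(gate_probs[i] for i in top)
    return {i: gate_probs[i] / norm for i in top}

def sparse_moe(x, gate_probs, experts, k):
    """Sparse MoE output: only the k selected experts are evaluated."""
    routed = top_k_gates(gate_probs, k)
    y = None
    for i, weight in routed.items():
        out = experts[i](x)
        if y is None:
            y = [0.0] * len(out)
        y = [acc + weight * v for acc, v in zip(y, out)]
    return routed, y

calls = []  # records which experts were actually evaluated

def make_expert(index):
    """Hypothetical expert that scales its input by (index + 1)."""
    def expert(x):
        calls.append(index)
        return [(index + 1) * v for v in x]
    return expert

experts = [make_expert(i) for i in range(4)]
routed, y = sparse_moe([1.0, 2.0], [0.5, 0.1, 0.3, 0.1], experts, k=2)
print(routed)  # approximately {0: 0.625, 2: 0.375}
print(y)       # approximately [1.75, 3.5]
print(calls)   # [0, 2]: experts 1 and 3 never ran
```

The renormalized weights sum to 1 over the selected experts, so the output stays on the same scale as a single expert's output.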
Computational Savings
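With top-$k$ routing, the compute cost per input scales as $\mathcal{O}(k \cdot P_{\text{expert}})$ rather than $\mathcal{O}(N \cdot P_{\text{expert}})$. For a typical configuration of $N=8$ experts and $k=2$:
$$\frac{\text{sparse compute}}{\text{dense compute}} = \frac{k \cdot P_{\text{expert}}}{N \cdot P_{\text{expert}}} = \frac{2}{8} = 25\%$$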
The model has $8\times$ the capacity of a single expert but only $2\times$ the compute cost.
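To put rough numbers on this, the sketch below counts total and per-token active weights for one MoE layer. The layer shape is an assumption made for illustration: each expert is taken to be a two-layer MLP ($d \to 4d \to d$, biases ignored) with a hypothetical $d = 512$, and the router is the $N \times d$ gating matrix.

```python
def moe_layer_params(d_model, n_experts, k, hidden_mult=4):
    """Count total vs. per-token active weights in one MoE layer (biases ignored)."""
    per_expert = 2 * d_model * (hidden_mult * d_model)  # d -> 4d -> d MLP weights
    router = n_experts * d_model                         # gating matrix W_g
    total = n_experts * per_expert + router
    active = k * per_expert + router
    return per_expert, total, active

per_expert, total, active = moe_layer_params(d_model=512, n_experts=8, k=2)
print(total / per_expert)   # capacity relative to one expert, roughly 8x
print(active / per_expert)  # per-token compute relative to one expert, roughly 2x
print(active / total)       # fraction of weights touched per token, roughly 0.25
```

Under these assumptions the router adds only a few thousand weights, which is why its cost is usually ignored in the $k/N$ estimate.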
Noisy Gating for Exploration
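Shazeer et al. (2017) add tunable noise to the gating logits during training:
$$G(\mathbf{x}) = \text{softmax}\big(\mathbf{W}_g \mathbf{x} + \epsilon \odot \text{softplus}(\mathbf{w}_{\text{noise}})\big)$$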
where $\epsilon \sim \mathcal{N}(0, \mathbf{I})$ and $\mathbf{w}_{\text{noise}} \in \mathbb{R}^N$ are learnable noise scale parameters. Without this noise, the gating network quickly converges to always selecting the same experts, causing the remaining experts to receive no gradient signal and effectively die.
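The effect of the noise can be sketched with a small simulation. This is an illustrative assumption-laden example: the clean logits and noise parameters are made up, the noise scale is passed through a softplus to keep it positive, and top-1 selection is done directly on the logits (softmax preserves their order).

```python
import math
import random

def softplus(z):
    """Numerically stable softplus log(1 + exp(z)); always positive."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def noisy_logits(clean_logits, w_noise, rng, training=True):
    """Add per-expert Gaussian noise scaled by softplus(w_noise) during training."""
    if not training:
        return list(clean_logits)
    return [h + rng.gauss(0.0, 1.0) * softplus(w)
            for h, w in zip(clean_logits, w_noise)]

def top1_counts(clean_logits, w_noise, training, trials=1000, seed=0):
    """Count how often each expert has the largest (possibly noisy) logit."""
    rng = random.Random(seed)
    counts = [0] * len(clean_logits)
    for _ in range(trials):
        logits = noisy_logits(clean_logits, w_noise, rng, training)
        winner = max(range(len(logits)), key=lambda i: logits[i])
        counts[winner] += 1
    return counts

clean = [1.0, 0.9, 0.0]    # hypothetical clean logits W_g x for one token
w_noise = [0.0, 0.0, 0.0]  # hypothetical noise parameters; softplus(0) = ln 2

print(top1_counts(clean, w_noise, training=False))  # [1000, 0, 0]
print(top1_counts(clean, w_noise, training=True))   # other experts win some trials
```

Without noise the slightly preferred expert wins every time; with noise, the runner-up experts are selected often enough to keep receiving gradient signal.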
Load Balancing Loss
Even with noisy gating, there is a strong tendency for the model to develop expert collapse: a few experts become preferred, receive more training signal, become even better, and attract even more tokens -- a positive feedback loop.
The load balancing auxiliary loss explicitly counteracts this:
$$\mathcal{L}_{\text{balance}} = N \sum_{i=1}^{N} f_i \, p_i$$
where:
- $f_i = \frac{1}{B} \sum_{j=1}^{B} \mathbb{1}[i \in \text{TopK}(G(\mathbf{x}_j), k)]$ is the fraction of tokens in the batch routed to expert $i$.
- $p_i = \frac{1}{B} \sum_{j=1}^{B} G(\mathbf{x}_j)_i$ is the average gate probability for expert $i$ across the batch.
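The two statistics and the resulting loss $N \sum_i f_i p_i$ can be computed directly from a batch of gate probabilities. The sketch below is a plain-Python illustration; the two toy batches (one balanced, one collapsed onto a single expert) are invented for the example.

```python
def load_balance_loss(gate_probs, k):
    """Compute (f, p, loss) for a batch of gate distributions.

    gate_probs: list of B lists, each holding N gate probabilities for one token.
    f[i]: fraction of tokens whose top-k set contains expert i.
    p[i]: mean gate probability assigned to expert i.
    """
    B = len(gate_probs)
    N = len(gate_probs[0])
    f = [0.0] * N
    p = [0.0] * N
    for probs in gate_probs:
        top = sorted(range(N), key=lambda i: probs[i], reverse=True)[:k]
        for i in top:
            f[i] += 1.0 / B
        for i in range(N):
            p[i] += probs[i] / B
    loss = N * sum(fi * pi for fi, pi in zip(f, p))
    return f, p, loss

# Balanced toy batch: each of 4 tokens prefers a different expert.
balanced = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
]
# Collapsed toy batch: every token prefers expert 0.
collapsed = [[0.7, 0.1, 0.1, 0.1] for _ in range(4)]

_, _, loss_balanced = load_balance_loss(balanced, k=1)
_, _, loss_collapsed = load_balance_loss(collapsed, k=1)
print(loss_balanced)   # approximately 1.0
print(loss_collapsed)  # approximately 2.8
```

In a real training loop only $p_i$ carries gradient, since the indicator inside $f_i$ is not differentiable; $f_i$ acts as a per-expert weight on $p_i$.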
Why This Loss Works
Consider the ideal case where routing is perfectly uniform. Each expert receives a fraction $f_i = k/N$ of tokens, and the average gate probability is $p_i = 1/N$. Then:
$$\mathcal{L}_{\text{balance}} = N \sum_{i=1}^{N} \frac{k}{N} \cdot \frac{1}{N} = k$$
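Concentrating traffic on a few experts increases the loss: tokens are routed to the experts with the highest gate values, so $f_i$ and $p_i$ tend to grow together, and in the idealized case $f_i = k\,p_i$ the Cauchy-Schwarz inequality shows that uniform routing is the unique minimizer. Multiplying by $N$ ensures the loss scale does not diminish as the number of experts grows.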
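To make the Cauchy-Schwarz step explicit, here is a short derivation under the simplifying assumption $f_i = k\,p_i$ (an idealization introduced for this sketch; real batches only approximate it), using $\mathcal{L}_{\text{balance}} = N \sum_i f_i p_i$ and the fact that the averaged gate probabilities satisfy $\sum_i p_i = 1$:

$$
\begin{aligned}
\mathcal{L}_{\text{balance}} &= N \sum_{i=1}^{N} f_i \, p_i = kN \sum_{i=1}^{N} p_i^2, \\
1 = \Big(\sum_{i=1}^{N} p_i \cdot 1\Big)^2 &\le \Big(\sum_{i=1}^{N} p_i^2\Big)\Big(\sum_{i=1}^{N} 1^2\Big) = N \sum_{i=1}^{N} p_i^2, \\
\Rightarrow \quad \mathcal{L}_{\text{balance}} &\ge kN \cdot \frac{1}{N} = k,
\end{aligned}
$$

with equality exactly when $p_i = 1/N$ for every expert. For $k = 1$ the minimum is $1$, independent of $N$.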
Total Training Objective
The total loss combines the task-specific loss with the auxiliary load balancing loss:
$$\mathcal{L} = \mathcal{L}_{\text{task}} + \alpha \, \mathcal{L}_{\text{balance}}$$
where $\alpha$ is a hyperparameter (typically $0.01$--$0.1$) controlling the strength of the balancing penalty. Too small and experts collapse; too large and the model sacrifices task performance for routing uniformity.
Summary
Three components define the MoE framework:
- Sparse gating via top-$k$ selection reduces compute from $\mathcal{O}(N)$ to $\mathcal{O}(k)$ expert evaluations.
- Noisy exploration via learnable Gaussian noise prevents premature convergence to a fixed routing pattern.
- Load balancing loss via $\mathcal{L}_{\text{balance}} = N \sum f_i p_i$ prevents expert collapse by penalizing non-uniform routing.
Part 2 implements all of this in pure PyTorch: Expert MLP, TopKGating, MoELayer, and a full MoE classifier for MNIST.