Mixture of Experts (MoE) is the architectural principle behind the scaling of modern large language models including GPT-4, Mixtral, and Switch Transformer. Rather than activating every parameter for every input, MoE routes each token to a small subset of specialized sub-networks (experts), achieving massive model capacity at a fraction of the computational cost.
This first installment derives the mathematical foundations: gating functions, top-$k$ routing, and the load balancing loss that prevents expert collapse.
## The Scaling Problem
Dense neural networks have a fundamental coupling between model capacity and computational cost. If a model has $P$ parameters, then processing a single input requires $\mathcal{O}(P)$ floating-point operations. Doubling the model's knowledge requires doubling the compute.
This coupling becomes untenable at the frontier. A dense transformer with 1.8 trillion parameters would require proportionally enormous compute for every single token -- whether that token is a trivial article like "the" or a complex reasoning step.
Conditional computation breaks this coupling. The idea is simple: maintain a large number of parameters for capacity, but only activate a small subset of them for each input.
Mixture of Experts implements conditional computation through two components:
- Expert networks: $N$ independent sub-networks, each capable of processing the input.
- Gating network: A lightweight router that decides which experts process each input.
## Gating Functions
Let $\mathbf{x} \in \mathbb{R}^d$ be an input vector. The gating network $G$ produces a probability distribution over $N$ experts:

$$G(\mathbf{x}) = \text{softmax}(\mathbf{W}_g \mathbf{x})$$

where $\mathbf{W}_g \in \mathbb{R}^{N \times d}$ is the gating weight matrix. Each component $G(\mathbf{x})_i$ represents the probability (or weight) assigned to expert $i$.
The output of the MoE layer is a weighted combination of expert outputs:

$$\mathbf{y} = \sum_{i=1}^{N} G(\mathbf{x})_i \, E_i(\mathbf{x})$$

where $E_i(\mathbf{x})$ is the output of the $i$-th expert.
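The dense formulation above can be sketched directly in PyTorch. This is an illustrative sketch, not the implementation from Part 2: the class name `DenseMoE` and the two-layer MLP experts are assumptions made here for concreteness.

```python
import torch
import torch.nn as nn

class DenseMoE(nn.Module):
    """Dense mixture of experts: every expert runs on every input (illustrative sketch)."""

    def __init__(self, d: int, n_experts: int, d_hidden: int = 32):
        super().__init__()
        # N independent expert sub-networks E_i (here: small two-layer MLPs)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(d, d_hidden), nn.ReLU(), nn.Linear(d_hidden, d))
            for _ in range(n_experts)
        )
        # Gating weight matrix W_g in R^{N x d} (no bias, matching the formula)
        self.w_g = nn.Linear(d, n_experts, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gates = torch.softmax(self.w_g(x), dim=-1)            # G(x), shape (B, N)
        outs = torch.stack([e(x) for e in self.experts], 1)   # (B, N, d)
        return (gates.unsqueeze(-1) * outs).sum(dim=1)        # sum_i G(x)_i * E_i(x)
```

Note that `forward` evaluates every expert on every token, which is exactly the inefficiency the next subsection addresses.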
### The Problem with Dense Gating
If we compute the full weighted sum above, we have gained nothing -- every expert runs on every input, and the compute cost is $N$ times that of a single expert. The entire point of MoE is to make this sum sparse.
## Top-$k$ Routing
The key to computational efficiency is top-$k$ routing. Instead of consulting all $N$ experts, we select only the $k$ experts with the highest gate values:

$$\mathbf{y} = \sum_{i \in \text{TopK}(G(\mathbf{x}), k)} \tilde{G}(\mathbf{x})_i \, E_i(\mathbf{x})$$

where $\text{TopK}(G(\mathbf{x}), k)$ returns the indices of the $k$ largest components of $G(\mathbf{x})$, and $\tilde{G}$ denotes the renormalized gate values:

$$\tilde{G}(\mathbf{x})_i = \frac{G(\mathbf{x})_i}{\sum_{j \in \text{TopK}(G(\mathbf{x}), k)} G(\mathbf{x})_j}$$
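Selection and renormalization can be sketched as a standalone function; the name `top_k_gates` and the `(gates, indices)` return convention are assumptions for this sketch:

```python
import torch

def top_k_gates(logits: torch.Tensor, k: int):
    """Select the top-k gate values per token and renormalize them to sum to 1.

    logits: (B, N) raw gating scores W_g x. Returns (sparse_gates, indices),
    where sparse_gates is (B, N) with zeros outside the top-k positions.
    """
    gates = torch.softmax(logits, dim=-1)                         # G(x)
    topk_vals, topk_idx = torch.topk(gates, k, dim=-1)            # k largest components
    topk_vals = topk_vals / topk_vals.sum(dim=-1, keepdim=True)   # renormalize: G~(x)
    sparse = torch.zeros_like(gates).scatter(-1, topk_idx, topk_vals)
    return sparse, topk_idx
```

Only the experts listed in `topk_idx` need to be evaluated; every other gate is exactly zero.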
### Computational Savings
With top-$k$ routing, the compute cost per input scales as $\mathcal{O}(k \cdot P_{\text{expert}})$ rather than $\mathcal{O}(N \cdot P_{\text{expert}})$. For a typical configuration of $N=8$ experts and $k=2$:

$$\frac{\text{active compute}}{\text{dense compute}} = \frac{k}{N} = \frac{2}{8} = \frac{1}{4}$$

The model has $8\times$ the capacity of a single expert but only $2\times$ the compute cost.
## Noisy Gating for Exploration
A critical training trick from Shazeer et al. (2017) is to add tunable noise to the gating logits during training:

$$H(\mathbf{x}) = \mathbf{W}_g \mathbf{x} + \epsilon \odot \text{softplus}(\mathbf{w}_{\text{noise}}), \qquad G(\mathbf{x}) = \text{softmax}(H(\mathbf{x}))$$

where $\epsilon \sim \mathcal{N}(0, \mathbf{I})$ is standard Gaussian noise and $\mathbf{w}_{\text{noise}} \in \mathbb{R}^N$ is a vector of learnable noise scale parameters (the softplus keeps each scale positive). Without this noise, the gating network quickly converges to always selecting the same experts, causing the remaining experts to receive no gradient signal and effectively die.
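A minimal sketch of this noisy gating, assuming a per-expert noise scale $\mathbf{w}_{\text{noise}} \in \mathbb{R}^N$ as defined above (the function name and `training` flag are conventions chosen here):

```python
import torch
import torch.nn.functional as F

def noisy_gating_logits(x, w_g, w_noise, training=True):
    """Add per-expert Gaussian noise to the gating logits during training.

    x: (B, d) inputs; w_g: (N, d) gating weights; w_noise: (N,) learnable
    noise scales, one per expert (an assumption matching the text).
    """
    clean = x @ w_g.t()                    # (B, N) clean logits W_g x
    if not training:
        return clean                       # no exploration noise at inference
    noise_scale = F.softplus(w_noise)      # keep the scales positive
    eps = torch.randn_like(clean)          # eps ~ N(0, I), resampled per step
    return clean + eps * noise_scale       # noisy logits H(x)
```

Because `w_noise` is learnable, the model can anneal the noise away for tokens whose routing it is confident about while keeping exploration alive elsewhere.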
## Load Balancing Loss
Even with noisy gating, there is a strong tendency for the model to develop expert collapse: a few experts become preferred, receive more training signal, become even better, and attract even more tokens -- a positive feedback loop.
The load balancing auxiliary loss explicitly counteracts this:

$$\mathcal{L}_{\text{balance}} = N \sum_{i=1}^{N} f_i \, p_i$$
where:
- $f_i = \frac{1}{B} \sum_{j=1}^{B} \mathbb{1}[i \in \text{TopK}(G(\mathbf{x}_j), k)]$ is the fraction of tokens in the batch routed to expert $i$.
- $p_i = \frac{1}{B} \sum_{j=1}^{B} G(\mathbf{x}_j)_i$ is the average gate probability for expert $i$ across the batch.
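Both batch statistics can be computed in a few lines; this is a sketch under the definitions above, with the function name and argument layout chosen here for illustration:

```python
import torch

def load_balancing_loss(gates, topk_idx, n_experts):
    """L_balance = N * sum_i f_i * p_i over a batch.

    gates: (B, N) dense gate probabilities G(x_j);
    topk_idx: (B, k) indices of the experts each token was routed to.
    """
    # f_i: fraction of tokens whose top-k set contains expert i
    routed = torch.zeros_like(gates).scatter(-1, topk_idx, 1.0)  # (B, N) 0/1 mask
    f = routed.mean(dim=0)
    # p_i: average gate probability for expert i across the batch
    p = gates.mean(dim=0)
    return n_experts * (f * p).sum()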
### Why This Loss Works
Consider the ideal case where routing is perfectly uniform. Each expert receives a fraction $f_i = k/N$ of tokens, and the average gate probability is $p_i = 1/N$. Then:

$$\mathcal{L}_{\text{balance}} = N \sum_{i=1}^{N} \frac{k}{N} \cdot \frac{1}{N} = N \cdot N \cdot \frac{k}{N^2} = k$$

Any deviation from uniform routing increases the loss: since the routing fractions $f_i$ rise and fall together with the gate probabilities $p_i$, Chebyshev's sum inequality gives $\sum_i f_i p_i \geq \frac{1}{N}\left(\sum_i f_i\right)\left(\sum_i p_i\right) = k/N$, with equality exactly at uniform routing. Multiplying by $N$ ensures the loss scale does not diminish as the number of experts grows.
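This minimum can be checked numerically. The sketch below (values chosen here for illustration) compares perfectly uniform routing against a collapsed configuration in which two experts absorb all the traffic, for $N=8$, $k=2$:

```python
import torch

N, k = 8, 2

# Uniform case: every expert receives f_i = k/N of tokens, p_i = 1/N
f_uniform = torch.full((N,), k / N)
p_uniform = torch.full((N,), 1 / N)
loss_uniform = N * (f_uniform * p_uniform).sum()   # minimum value: k = 2

# Collapsed case: experts 0 and 1 take all the traffic
f_skew = torch.zeros(N); f_skew[0] = f_skew[1] = 1.0   # every token routed to both
p_skew = torch.zeros(N); p_skew[0] = p_skew[1] = 0.5   # all gate mass on them
loss_skew = N * (f_skew * p_skew).sum()                # = 8, well above k
```

Note that both configurations satisfy $\sum_i f_i = k$ and $\sum_i p_i = 1$; only the uniform one attains the minimum.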
## Total Training Objective
The total loss combines the task-specific loss with the auxiliary load balancing loss:

$$\mathcal{L} = \mathcal{L}_{\text{task}} + \alpha \, \mathcal{L}_{\text{balance}}$$
where $\alpha$ is a hyperparameter (typically $0.01$--$0.1$) controlling the strength of the balancing penalty. Too small and experts collapse; too large and the model sacrifices task performance for routing uniformity.
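For a classification task, combining the two terms is a one-liner; the function name and the choice of cross-entropy as the task loss are assumptions made here for illustration:

```python
import torch
import torch.nn.functional as F

def total_loss(logits, targets, balance_loss, alpha: float = 0.01):
    """L = L_task + alpha * L_balance, with cross-entropy as the example task loss."""
    task_loss = F.cross_entropy(logits, targets)
    return task_loss + alpha * balance_loss
```

In practice `balance_loss` is accumulated across all MoE layers in the network before being scaled by $\alpha$.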
## Summary and Next Steps
The mathematical framework for MoE rests on three pillars:
- Sparse gating via top-$k$ selection reduces compute from $\mathcal{O}(N)$ to $\mathcal{O}(k)$ expert evaluations.
- Noisy exploration via learnable Gaussian noise prevents premature convergence to a fixed routing pattern.
- Load balancing loss via $\mathcal{L}_{\text{balance}} = N \sum f_i p_i$ prevents expert collapse by penalizing non-uniform routing.
In Part 2, we implement these ideas in pure PyTorch: the Expert MLP, TopKGating with noisy exploration, MoELayer with load balancing, and a full MoE classifier ready for MNIST.