Part 1 made the structural argument: every normalization layer is a particular choice of which axis to reduce over, whether to subtract the mean, and how to apply a learnable affine. Part 2 implements all four — BatchNorm1d, LayerNorm, RMSNorm, GroupNorm — in pure PyTorch. Total code: about $60$ lines.
The differences in code are tiny: RMSNorm is LayerNorm with three small pieces removed. The differences in behaviour, as Part 3 will demonstrate, are not tiny. Reading the four implementations side by side makes precise what those design decisions actually are.
The Common Shape
Every normalization layer has the same skeleton:
- Compute summary statistics along some axis of the input tensor.
- Subtract the mean (optionally).
- Divide by the square root of the variance plus a small epsilon.
- Apply a learnable affine transformation $\gamma \odot \hat{x} + \beta$ to the normalized result $\hat{x}$.
The four normalizations differ mainly in two places: which axis the statistics are computed over (batch dimension? feature dimension? groups of features?), and whether the mean subtraction step happens at all (RMSNorm skips it). Beyond that, BatchNorm additionally tracks running statistics for inference and RMSNorm drops $\beta$; everything else is shared.
For all four implementations below, we assume the input has shape $(B, C)$ where $B$ is batch size and $C$ is the number of features. In practice, normalization layers are usually applied to higher-dimensional tensors (e.g., $(B, T, C)$ for sequences or $(B, C, H, W)$ for images), but the core math is the same — only the axes change.
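To make "only the axes change" concrete, here is a minimal sketch of the reduction axes for two higher-dimensional inputs. The tensor sizes are arbitrary choices for illustration, and the comparison against PyTorch's built-in layers (with the affine step disabled) is only a sanity check, not part of the implementations in this post.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
eps = 1e-5

# Sequence input (B, T, C): LayerNorm-style statistics reduce over the feature axis only.
x_seq = torch.randn(4, 7, 16)
mean_ln = x_seq.mean(dim=-1, keepdim=True)                  # shape (4, 7, 1)
var_ln = x_seq.var(dim=-1, keepdim=True, unbiased=False)
ln_manual = (x_seq - mean_ln) / torch.sqrt(var_ln + eps)
ln_builtin = nn.LayerNorm(16, eps=eps, elementwise_affine=False)(x_seq)

# Image input (B, C, H, W): BatchNorm-style statistics reduce over batch and spatial axes.
x_img = torch.randn(4, 3, 5, 5)
mean_bn = x_img.mean(dim=(0, 2, 3), keepdim=True)           # shape (1, 3, 1, 1)
var_bn = x_img.var(dim=(0, 2, 3), keepdim=True, unbiased=False)
bn_manual = (x_img - mean_bn) / torch.sqrt(var_bn + eps)
bn_builtin = nn.BatchNorm2d(3, eps=eps, affine=False)(x_img)  # a fresh module is in training mode

print(mean_ln.shape, mean_bn.shape)
print(torch.allclose(ln_manual, ln_builtin, atol=1e-5))
print(torch.allclose(bn_manual, bn_builtin, atol=1e-5))
```

The arithmetic is the same in both cases; the only thing that moves is the `dim` argument.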
BatchNorm1d — Normalize Across the Batch
BatchNorm computes statistics per feature, averaged across the batch dimension:
import torch
import torch.nn as nn

class BatchNorm1d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # Learnable affine parameters.
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        # Running statistics: saved with the model, not trained.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        if self.training:
            # Per-feature statistics of the current mini-batch.
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            # EMA update of the running statistics, outside autograd.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean, var = self.running_mean, self.running_var
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
The interesting design decisions are not in the math, which is rote, but in the bookkeeping.
The train/eval split. During training, we use statistics computed from the current mini-batch. During inference (where you may not have a meaningful batch), we use running statistics accumulated during training. The model switches between these via model.train() and model.eval(). Forgetting to call model.eval() at inference time is one of the most common PyTorch bugs: the layer keeps normalizing with whatever batch it is given (for a single isolated input the batch variance is zero, so this implementation returns just $\beta$), and it keeps overwriting the running statistics with inference data.
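The effect of the mode switch is easy to see directly. The sketch below uses PyTorch's built-in nn.BatchNorm1d rather than the class above so that it runs on its own; the data distribution (mean 3, standard deviation 2) and the 200 warm-up steps are arbitrary assumptions for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)

# A few training steps on data with mean 3 and std 2 so the running statistics move.
bn.train()
for _ in range(200):
    bn(torch.randn(64, 4) * 2.0 + 3.0)

x = torch.randn(8, 4) * 2.0 + 3.0

bn.train()
before = bn.running_mean.clone()
out_train = bn(x)                                     # normalized with this batch's own statistics
updated = not torch.equal(before, bn.running_mean)    # and the buffers were modified

bn.eval()
frozen = bn.running_mean.clone()
out_eval = bn(x)                                      # normalized with the running statistics
unchanged = torch.equal(frozen, bn.running_mean)

print(torch.allclose(out_train, out_eval))            # the two modes disagree on the same input
print(updated, unchanged)
```

The same input produces different outputs in the two modes, and only training mode touches the buffers.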
The EMA update. self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean) is an exponential moving average update. After many training batches, running_mean converges to roughly the long-run average of mini-batch means, weighted toward recent batches. The default momentum=0.1 means the EMA has an effective horizon of about $1/0.1 = 10$ batches.
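A small sketch of the same update in isolation shows what the "horizon of about 10 batches" means. It assumes, purely for illustration, that every mini-batch mean is exactly 5, so the running value can be compared against the closed form $5\,(1 - 0.9^n)$.

```python
import torch

momentum = 0.1
running_mean = torch.zeros(1)
target = torch.tensor([5.0])   # pretend every mini-batch mean is exactly 5

history = []
for step in range(1, 51):
    # Same in-place EMA update as in the forward pass above.
    running_mean.mul_(1 - momentum).add_(momentum * target)
    history.append(running_mean.item())

# Closed form for a constant input: target * (1 - (1 - momentum) ** n)
closed_form = [5.0 * (1 - 0.9 ** n) for n in range(1, 51)]

print(history[9], closed_form[9])     # after 10 steps
print(history[49], closed_form[49])   # after 50 steps
```

After 10 steps the remaining gap is $0.9^{10} \approx 0.35$ of the initial one, so the buffer has covered roughly 65% of the distance; after 50 steps it is within about 1% of the target.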
Why register_buffer instead of nn.Parameter. The running statistics are not trainable — they are computed from the data, not learned by gradient descent. register_buffer tells PyTorch this is a stateful tensor that should move with .to(device) and be saved in checkpoints, but should not receive gradients and should not appear in parameters(). The optimizer ignores it.
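The distinction can be checked on a toy module. `Stats` below is a hypothetical class written only for this illustration; it holds one parameter and one buffer.

```python
import torch
import torch.nn as nn

class Stats(nn.Module):  # hypothetical module, for illustration only
    def __init__(self, n):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(n))
        self.register_buffer("running_mean", torch.zeros(n))

m = Stats(3)
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
state_keys = list(m.state_dict().keys())

print(param_names)                   # what an optimizer built from m.parameters() would see
print(buffer_names)                  # tensors tracked as buffers
print(state_keys)                    # everything that is saved in a checkpoint
print(m.running_mean.requires_grad)  # buffers do not receive gradients
```

Only `gamma` shows up among the parameters, while both tensors appear in the state dict.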
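The unbiased=False on variance. By default PyTorch's var() uses Bessel's correction ($N-1$ in the denominator). The original paper normalizes each training batch with the biased estimator ($N$ in the denominator), and we match that. The two estimators differ by a factor of $(N-1)/N$, which matters for very small batches and is under $2\%$ for $B \geq 64$. One difference from torch.nn.BatchNorm1d: the built-in layer stores the unbiased estimate in running_var, whereas this implementation stores the biased one.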
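How quickly the two estimators converge can be seen with a short loop. The batch sizes and the feature count of 16 are arbitrary choices for this sketch.

```python
import torch

torch.manual_seed(0)
for B in (2, 8, 64, 1024):
    x = torch.randn(B, 16)
    biased = x.var(dim=0, unbiased=False)    # divides by B
    unbiased = x.var(dim=0, unbiased=True)   # divides by B - 1
    ratio = (biased / unbiased).mean().item()
    # The measured ratio should match (B - 1) / B up to floating-point error.
    print(B, round(ratio, 4), round((B - 1) / B, 4))
```

At $B = 2$ the biased estimate is half the unbiased one; at $B = 64$ the gap is about $1.6\%$.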
BatchNorm dominates CNN training on ImageNet-scale data. It struggles on small batches and on architectures with variable-length inputs (sequences, attention), which is why Transformers use LayerNorm instead.
LayerNorm — Normalize Across Features (Per Sample)
LayerNorm flips the reduction axis: statistics are computed per sample, across the feature dimension.
class LayerNorm(nn.Module):
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
Ten lines total. No running statistics, no train/eval split, no batch dependence. The only state is the learnable affine $\gamma$ and $\beta$.
Why no running statistics? Because LayerNorm's statistics are computed from a single sample, the "batch of 1 at inference" problem disappears entirely. The single sample at inference time has exactly the same kind of statistics as a sample during training — just compute them on the fly.
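A quick way to see the per-sample property is to normalize the same row once inside a batch and once on its own. The sketch uses PyTorch's built-in layers so it runs standalone; the shapes and the shifted second batch are assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(16)
bn = nn.BatchNorm1d(16)   # fresh module, so it is in training mode

batch = torch.randn(32, 16)
sample = batch[:1]        # first sample, kept as a batch of one

# LayerNorm: the first row is the same whether or not the other 31 rows are present.
ln_same = torch.allclose(ln(batch)[:1], ln(sample), atol=1e-6)

# BatchNorm in training mode: the first row depends on the rest of the batch.
bn_a = bn(batch)[:1]
bn_b = bn(torch.cat([sample, torch.randn(31, 16) + 5.0]))[:1]
bn_same = torch.allclose(bn_a, bn_b, atol=1e-6)

print(ln_same, bn_same)
```

LayerNorm's output for a sample is unaffected by its neighbours; BatchNorm's output for the same sample changes when the rest of the batch changes.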
The dim=-1, keepdim=True idiom. Reducing along the last axis with keepdim=True preserves the broadcast-compatible shape. If x is $(B, T, C)$, then mean is $(B, T, 1)$, which broadcasts correctly when we subtract (x - mean) to get back a $(B, T, C)$ tensor. Without keepdim, mean would be $(B, T)$ and the subtraction would either fail with a shape error or, when the sizes happen to line up, silently subtract the wrong values.
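Both failure modes of dropping keepdim can be reproduced in a few lines. The shapes below are picked for illustration: the first triggers a shape error, the second (batch of one with $T = C$) is a case where broadcasting succeeds but along the wrong axis.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 5, 8)                      # (B, T, C)

mean_keep = x.mean(dim=-1, keepdim=True)      # (2, 5, 1)
centered = x - mean_keep                      # broadcasts back to (2, 5, 8)

mean_flat = x.mean(dim=-1)                    # (2, 5)
try:
    x - mean_flat                             # trailing dims 8 vs 5 do not line up
    failed = False
except RuntimeError:
    failed = True

# With B = 1 and T == C the shapes are compatible, so the bug is silent instead.
sq = torch.randn(1, 8, 8)
right = sq - sq.mean(dim=-1, keepdim=True)    # subtracts each row's own mean
wrong = sq - sq.mean(dim=-1)                  # subtracts row j's mean from column j
silently_wrong = wrong.shape == right.shape and not torch.allclose(wrong, right)

print(centered.shape, failed, silently_wrong)
```

The silent case is the dangerous one: the output has the expected shape, and nothing signals that the wrong means were subtracted.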
Why Transformers love LayerNorm. Transformer training is full of mini-batches with wildly varying sequence lengths and effective batch sizes. With sequence-packed inputs, the "batch" is sometimes a few thousand tokens; with single-example inference, it is one. BatchNorm's batch-dependent statistics would fluctuate enormously. LayerNorm's per-sample statistics are stable regardless.
RMSNorm — LayerNorm Without the Mean
RMSNorm is LayerNorm with the mean subtraction removed and the additive bias dropped:
class RMSNorm(nn.Module):
    def __init__(self, num_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(num_features))
    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.gamma * x / rms
Eight lines total. Three structural simplifications relative to LayerNorm:
(1) No mean reduction. LayerNorm computes both the mean and the variance, requiring two reductions over the feature dimension. RMSNorm only needs the mean of $x^2$ (which gives the squared RMS), so one reduction instead of two.
(2) No subtraction step. LayerNorm subtracts the mean before dividing. RMSNorm divides directly, leaving the data's mean intact.
(3) No additive bias. LayerNorm has both $\gamma$ (multiplicative) and $\beta$ (additive) learnable parameters. RMSNorm drops $\beta$ entirely. The bias's role would have been to allow the normalized output to be shifted; RMSNorm pushes that responsibility to subsequent layers (which can usually shift via their own biases).
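One way to see exactly what is removed: on input that has already been centered, the biased variance equals the mean of squares, so RMSNorm and LayerNorm (both without affine parameters and with the same eps) should coincide. The sketch below checks this numerically; `rms_norm` is a hypothetical helper with $\gamma$ fixed to 1, and the input offset of 3 is an arbitrary choice to make the mean clearly non-zero.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-6

def rms_norm(x, eps):
    # RMSNorm without the learnable scale (gamma fixed to 1); illustration only.
    return x / torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(4, 64) + 3.0                        # deliberately non-zero mean
centered = x - x.mean(dim=-1, keepdim=True)

ln = F.layer_norm(x, (64,), eps=eps)                # LayerNorm without affine
same_on_centered = torch.allclose(rms_norm(centered, eps), ln, atol=1e-5)
same_on_raw = torch.allclose(rms_norm(x, eps), ln, atol=1e-5)

row_means = rms_norm(x, eps).mean(dim=-1)           # RMSNorm leaves the mean in place
print(same_on_centered, same_on_raw)
print(row_means)
```

On the raw input the two layers disagree, and the RMSNorm output keeps a clearly non-zero per-row mean, which is the "leaving the data's mean intact" behaviour described in (2).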
The cumulative savings: roughly $30\%$ to $50\%$ fewer FLOPs per call depending on how operations are counted, $\frac{1}{2}$ the parameter count of the normalization layer itself, and reduced memory bandwidth pressure during the forward pass. These savings are uninteresting on a $4$-layer MLP. They are very interesting on an $80$-layer Transformer applied to trillions of tokens, which is one reason most recent open LLMs, including Llama, Mistral, and Gemma, use RMSNorm.
The empirical claim from Zhang & Sennrich (2019), confirmed at scale by Llama, Mistral, and Gemma: the mean subtraction makes no measurable difference to downstream model quality. Part 3 verifies this on a controlled benchmark.
GroupNorm — Normalize Within Groups
GroupNorm splits the feature dimension into $G$ groups and normalizes within each group, per sample:
class GroupNorm(nn.Module):
    def __init__(self, num_features, num_groups=8, eps=1e-5):
        super().__init__()
        assert num_features % num_groups == 0
        self.num_groups = num_groups
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
    def forward(self, x):
        B, C = x.shape
        x = x.view(B, self.num_groups, C // self.num_groups)
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.gamma * x_hat.view(B, C) + self.beta
The structural trick is the reshape: we view the $C$-dimensional feature vector as $G$ groups of $C/G$ features each, normalize within each group, then reshape back. The math is identical to LayerNorm applied within each of the $G$ groups.
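The "LayerNorm within each group" claim can be checked numerically. This sketch repeats the reshape-and-normalize step without the affine parameters and compares it with F.layer_norm applied to the last axis of the grouped view; the sizes $B = 4$, $C = 64$, $G = 8$ are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, G, eps = 4, 64, 8, 1e-5
x = torch.randn(B, C)

# Reshape-and-normalize, as in the forward pass above (affine omitted).
g = x.view(B, G, C // G)
mean = g.mean(dim=-1, keepdim=True)
var = g.var(dim=-1, keepdim=True, unbiased=False)
manual = ((g - mean) / torch.sqrt(var + eps)).view(B, C)

# LayerNorm over the last axis of the grouped view gives the same numbers.
per_group_ln = F.layer_norm(g, (C // G,), eps=eps).view(B, C)

# Each group of 8 features ends up with (approximately) zero mean.
group_means = manual.view(B, G, C // G).mean(dim=-1)

print(torch.allclose(manual, per_group_ln, atol=1e-5))
print(group_means.abs().max().item())
```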
What GroupNorm buys you over LayerNorm. LayerNorm normalizes using statistics from all $C$ features at once. GroupNorm only uses $C/G$ features. This makes the normalization more local: features in different groups don't influence each other's normalization. For images, where nearby channels often represent related features (e.g., edge detectors at different orientations), grouping makes structural sense.
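The num_groups choice. Typical values are $G \in \{8, 16, 32\}$. With $C = 64$ and $G = 8$, each group has $8$ features. With $G = 1$, GroupNorm reduces to LayerNorm. With $G = C$ (one channel per group) on image tensors of shape $(B, C, H, W)$, each channel is normalized over its own $H \times W$ positions and GroupNorm becomes InstanceNorm, which has its own use cases in style transfer. On the flat $(B, C)$ inputs used here, $G = C$ is degenerate: each group holds a single value, so the normalized output is zero and only $\beta$ survives.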
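The two limiting cases can be checked with PyTorch's built-in layers. Note that the InstanceNorm equivalence relies on spatial dimensions, so this sketch uses an image-shaped tensor (an assumption beyond the $(B, C)$ setting of the implementations above) and disables the affine parameters throughout.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 8, 5, 5
x = torch.randn(B, C, H, W)

# G = 1: one group spanning all channels and positions, i.e. LayerNorm over (C, H, W).
gn_one = nn.GroupNorm(1, C, affine=False)(x)
ln_all = F.layer_norm(x, (C, H, W))
matches_layernorm = torch.allclose(gn_one, ln_all, atol=1e-5)

# G = C: each channel is its own group, normalized over H x W, i.e. InstanceNorm.
gn_per_channel = nn.GroupNorm(C, C, affine=False)(x)
inst = nn.InstanceNorm2d(C)(x)
matches_instancenorm = torch.allclose(gn_per_channel, inst, atol=1e-5)

# On a flat (B, C) input, G = C leaves one value per group: variance 0, output all zeros.
flat = torch.randn(B, C)
g = flat.view(B, C, 1)
degenerate = (g - g.mean(dim=-1, keepdim=True)) / torch.sqrt(
    g.var(dim=-1, keepdim=True, unbiased=False) + 1e-5
)

print(matches_layernorm, matches_instancenorm)
print(degenerate.abs().max().item())
```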
GroupNorm is the standard in modern image diffusion (U-Net architectures of DDPM, Latent Diffusion, Stable Diffusion). Two reasons: image diffusion typically uses very small batch sizes (often $1$–$4$ per GPU), where BatchNorm's batch statistics are too noisy; and within-group normalization respects the channel structure of CNN feature maps better than full LayerNorm.
FLOPs per Call
Approximate FLOP counts for $B = 128$, $C = 64$:
- BatchNorm: ~$65{,}536$ FLOPs (mean + var + scale + shift).
- LayerNorm: ~$65{,}536$ FLOPs (same operations, different axes).
- RMSNorm: ~$32{,}896$ FLOPs (no mean, no subtraction, no additive bias).
- GroupNorm: ~$65{,}536$ FLOPs (essentially LayerNorm applied in pieces).
The FLOP savings of RMSNorm look modest on this $128 \times 64$ example. They become significant at LLM scale. A $70$B-parameter Llama applies RMSNorm twice per Transformer block, across $80$ blocks, for every token, across trillions of training tokens. By the approximate counts above, RMSNorm needs about half the FLOPs of LayerNorm per call. Saving that much on each call accumulates into a meaningful amount of compute, and an even more meaningful share of the memory bandwidth, which is often the actual bottleneck at scale.
The Common Skeleton
Side-by-side, the four implementations look almost identical: same imports, same nn.Module inheritance, same forward-pass structure (compute statistic, subtract, divide, affine). The differences live in three places.
First: which axis the statistics are computed over. dim=0 for BatchNorm (batch axis), dim=-1 for LayerNorm and RMSNorm (feature axis), dim=-1 after reshape for GroupNorm (within groups).
Second: whether the mean subtraction happens. Yes for BatchNorm, LayerNorm, GroupNorm. No for RMSNorm.
Third: whether running statistics are tracked. Only for BatchNorm.
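These three switches define all four normalization layers, and they cover most other activation-normalization variants you may encounter (InstanceNorm, ScaleNorm). WeightNorm is the exception: it reparameterizes the weights rather than normalizing activations, so it sits outside this skeleton. The space of useful normalization layers is small. Knowing the skeleton lets you read any new normalization paper in about thirty seconds.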
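As a rough sketch of the first two switches, the hypothetical helper `normalize` below takes the reduction axis and a mean-subtraction flag and reproduces the normalization step of all four layers. It deliberately omits the affine parameters and the third switch (running statistics), and the comparisons against PyTorch's built-ins are only sanity checks under those assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def normalize(x, dim, subtract_mean=True, eps=1e-5):
    """Shared skeleton without affine or running statistics (hypothetical helper)."""
    if subtract_mean:
        x = x - x.mean(dim=dim, keepdim=True)
    # After centering, the mean of squares is the biased variance.
    return x / torch.sqrt(x.pow(2).mean(dim=dim, keepdim=True) + eps)

torch.manual_seed(0)
B, C, G = 16, 32, 4
x = torch.randn(B, C)

batch_norm = normalize(x, dim=0)                                   # statistics over the batch
layer_norm = normalize(x, dim=-1)                                  # statistics over features
rms_norm = normalize(x, dim=-1, subtract_mean=False)               # no centering
group_norm = normalize(x.view(B, G, C // G), dim=-1).view(B, C)    # statistics per group

grouped_ln = F.layer_norm(x.view(B, G, C // G), (C // G,)).view(B, C)
checks = {
    "batch": torch.allclose(batch_norm, nn.BatchNorm1d(C, affine=False)(x), atol=1e-5),
    "layer": torch.allclose(layer_norm, F.layer_norm(x, (C,)), atol=1e-5),
    "group": torch.allclose(group_norm, grouped_ln, atol=1e-5),
}
print(checks)
print(rms_norm.pow(2).mean(dim=-1))   # each row has mean square close to 1
```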
What Part 3 Tests
With all four normalizations implemented, Part 3 runs them head-to-head on a $20$-layer non-residual MLP from identical initialisation. The result is genuinely surprising — no-normalization wins on this toy benchmark, BatchNorm actively hurts, and RMSNorm beats LayerNorm on every measured axis. The "norms are essential for deep networks" wisdom turns out to be regime-dependent.
Full code on GitHub: github.com/soveshmohapatra/Normalization-Layers