Chapter 19

Batch and Layer Normalization

Modern Training Techniques

Why Neural Networks Need Normalization

Training a deep network is a negotiation between layers that keep changing each other's inputs. Layer 10 is trying to learn a function of the activations produced by layer 9, but layer 9 is simultaneously updating its weights, which shifts the distribution layer 10 sees at every training step. In a 24-layer Transformer or a 100-layer ResNet the effect compounds: by the time gradients flow back from the loss, each layer is chasing a moving target.

Without intervention this leads to three concrete problems that most practitioners have run into:

  1. Saturating activations. If pre-activations drift to large positive or negative values, sigmoid/tanh flatten out and gradients become numerically zero ($\frac{\partial \sigma}{\partial z} = \sigma(z)(1-\sigma(z)) \approx 0$). The network stops learning.
  2. Exploding or vanishing gradients. The Jacobian of a deep network is a product of per-layer Jacobians. If their magnitudes stay above 1, the product grows exponentially with depth; if they stay below 1, it shrinks exponentially.
  3. Step-size fragility. The optimal learning rate depends on the scale of activations. If that scale drifts during training, a learning rate that worked at epoch 1 is catastrophically wrong at epoch 10.
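Here is a minimal NumPy sketch of problems 1 and 2. It assumes a plain deep linear network (width 256, no nonlinearity, no normalization) with Gaussian weights scaled by a hypothetical `gain` factor. In such a network the input-output Jacobian is exactly the product of the weight matrices, so the activation scale after many layers tracks that product. The helper names are illustrative and do not come from any library.

```python
import numpy as np


def sigmoid_grad(z):
    """Derivative of the logistic sigmoid: sigma(z) * (1 - sigma(z))."""
    s = 1.0 / (1.0 + np.exp(-z))
    return s * (1.0 - s)


def activation_std_after(depth, gain, width=256, batch=64, seed=0):
    """Std of the activations after `depth` linear layers with no normalization.

    Weights are drawn from N(0, gain**2 / width), so each layer multiplies the
    activation variance by roughly gain**2 (the `gain` knob is illustrative).
    """
    rng = np.random.default_rng(seed)
    h = rng.standard_normal((batch, width))      # unit-variance inputs
    for _ in range(depth):
        W = rng.standard_normal((width, width)) * gain / np.sqrt(width)
        h = h @ W                                # linear layer, no bias, no activation
    return float(h.std())


print("sigmoid'(0)  =", sigmoid_grad(0.0))       # 0.25, the largest possible value
print("sigmoid'(10) =", sigmoid_grad(10.0))      # about 4.5e-5: saturated
for gain in (0.5, 1.0, 1.5):
    print(f"gain={gain}: activation std after 20 layers = {activation_std_after(20, gain):.3g}")
```

With gain 1.0 the scale stays near 1. Moving the gain by 50% in either direction shifts the scale by several orders of magnitude after only 20 layers.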

Normalization layers address all three by imposing a contract on the activations that flow between layers: no matter what weights come before, the activation entering the next layer will have a predictable mean and variance. That contract is the single idea behind BatchNorm, LayerNorm, InstanceNorm, GroupNorm, and RMSNorm. They differ mainly in which axes they normalize over. RMSNorm also skips mean-centering and rescales by the root mean square alone.

Mental model. Think of normalization as inserting a variance-control valve between layers. It standardizes the signal, then offers the next layer a pair of knobs ($\gamma, \beta$) to pick whatever distribution it actually wants. The optimizer becomes a thermostat: gradients correct the knobs, not the (already well-conditioned) pre-activations.

Internal Covariate Shift: The Original Motivation

Ioffe and Szegedy coined the term internal covariate shift (ICS) in 2015 to name the phenomenon: as weights update, the distribution of inputs to each internal layer changes. Their paper showed that forcing each layer's inputs to be zero-mean and unit-variance dramatically stabilized training. It also let them train ImageNet models with much larger learning rates, up to 30× their baseline in the Inception experiments.

The mathematical reading is subtler. Let $h^{(\ell)}$ be the activations at layer $\ell$, and let $W^{(\ell)}$ be the weights that produce them from $h^{(\ell-1)}$. The gradient update to $W^{(\ell)}$ depends on both $h^{(\ell-1)}$ (its input) and the local Jacobian. If $h^{(\ell-1)}$ has a time-varying distribution, SGD is solving a non-stationary optimization problem: each step optimizes a slightly different loss surface than the previous one.

Santurkar et al. (2018) later showed that BN's benefits extend beyond reducing ICS. It smooths the loss landscape (smaller Lipschitz constants of both the loss and its gradients), which they argue is the more fundamental reason larger step sizes stay stable. Either framing leads to the same practical recipe: standardize before the next linear layer.

The demo below simulates the statistics of four stacked layers over 50 training epochs, with and without BatchNorm. Toggle BN on and watch the layer means and variances stay tightly around $(0, 1)$; toggle it off and watch the distributions drift apart as depth increases.

Internal Covariate Shift: The Problem Normalization Solves

Without Batch Normalization (ICS Problem)

As training progresses, the distribution of inputs to each layer changes (Internal Covariate Shift). Deeper layers see more unstable distributions, making optimization harder.

Figure panels: layer mean over training and layer standard deviation over training for Layers 1 to 4 (green dashed line: target σ = 1), and training loss. All three are plotted against epoch (0 to 50).

Key Insight: What is Internal Covariate Shift?

As network parameters update during training, the distribution of inputs to each layer changes. This forces later layers to continuously adapt to new input distributions, slowing down training. Batch Normalization fixes this by normalizing layer inputs, so each layer receives data with consistent statistics (μ = 0, σ = 1) regardless of changes in earlier layers.

Without BatchNorm:
  • Layer statistics drift during training
  • Deeper layers see more instability
  • Slower, noisier convergence
  • Requires careful initialization
With BatchNorm:
  • Layer statistics stay stable
  • All layers see consistent inputs
  • Faster, smoother convergence
  • Enables higher learning rates

The Core Idea: Standardize, Then Re-Learn the Distribution

Every normalization layer in this family follows the same four-step recipe, with one exception: RMSNorm skips the mean-centering step. Beyond that, the differences between BN, LN, IN, GN and RMSNorm come down to the set of elements over which the statistics are computed.

  1. Choose a reduction axis set $\mathcal{A}$: the elements that share a $(\mu, \sigma)$ pair.
  2. Compute the mean: $\mu = \frac{1}{|\mathcal{A}|} \sum_{i \in \mathcal{A}} x_i$
  3. Compute the variance: $\sigma^{2} = \frac{1}{|\mathcal{A}|} \sum_{i \in \mathcal{A}} (x_i - \mu)^{2}$
  4. Standardize and re-scale: $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^{2} + \varepsilon}}$, then $y_i = \gamma_i \hat{x}_i + \beta_i$.
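The four steps translate almost line for line into NumPy. The sketch below uses our own naming (`normalize`, `axes`), which is not any library's API. It computes the biased ($1/|\mathcal{A}|$) variance used in the formulas above and checks the result on a BatchNorm-style reduction set.

```python
import numpy as np


def normalize(x, axes, gamma=1.0, beta=0.0, eps=1e-5):
    """Generic normalization: `axes` plays the role of the reduction set A."""
    mu = x.mean(axis=axes, keepdims=True)                    # step 2: mean over A
    var = ((x - mu) ** 2).mean(axis=axes, keepdims=True)     # step 3: biased variance over A
    x_hat = (x - mu) / np.sqrt(var + eps)                    # step 4a: standardize
    return gamma * x_hat + beta                              # step 4b: learnable affine


rng = np.random.default_rng(0)
x = 3.0 + 2.0 * rng.standard_normal((8, 4, 5, 5))   # (N, C, H, W) with mean 3, std 2

y = normalize(x, axes=(0, 2, 3))                     # BatchNorm-style reduction set
print(y.mean(axis=(0, 2, 3)))                        # about 0 for every channel
print(y.var(axis=(0, 2, 3)))                         # about 1 for every channel
```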

The small constant $\varepsilon$ guards against division by zero for dead (constant) features. The learnable affine pair $(\gamma, \beta)$ is the critical escape hatch: it guarantees the network can recover the original distribution if needed. Setting $\gamma = \sqrt{\sigma^{2} + \varepsilon} \approx \sigma$ and $\beta = \mu$ makes the layer the identity. Normalization therefore does not reduce expressivity; it only conditions the optimization problem.
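Here is a quick numerical check of the escape-hatch claim. It assumes per-feature statistics over the rows of a 2-D array, a BatchNorm1d-style layout chosen only for brevity. Setting gamma to sqrt(var + eps) and beta to the mean undoes the standardization.

```python
import numpy as np

eps = 1e-5
rng = np.random.default_rng(1)
x = rng.normal(loc=5.0, scale=3.0, size=(32, 4))    # 32 samples, 4 features

mu = x.mean(axis=0)                                  # per-feature mean
var = x.var(axis=0)                                  # per-feature biased variance
x_hat = (x - mu) / np.sqrt(var + eps)                # standardized activations

gamma = np.sqrt(var + eps)                           # scale that undoes the division
beta = mu                                            # shift that undoes the centering
y = gamma * x_hat + beta

print(np.allclose(y, x))                             # True: the layer acts as the identity
```

In a trained network, $\gamma$ and $\beta$ are fixed while $\mu$ and $\sigma^2$ change from batch to batch. This exact identity therefore holds only for the batch whose statistics were copied.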

The only remaining question is which elements share statistics. Answering it differently gives each normalization variant.

Compactly, every variant in the family computes $\text{Norm}(x) = \gamma \cdot \frac{x - \mu_{\mathcal{A}}}{\sqrt{\sigma^{2}_{\mathcal{A}} + \varepsilon}} + \beta$; only the choice of $\mathcal{A}$ changes. RMSNorm is the exception: it drops $\mu_{\mathcal{A}}$ and divides by the root mean square instead.

| Variant | Reduction set 𝒜 for an (N, C, H, W) tensor | Stats per | Where you meet it |
| --- | --- | --- | --- |
| BatchNorm | (N, H, W) for each channel C | Channel | CNNs (ResNet, VGG, EfficientNet) |
| LayerNorm | (C, H, W) for each sample N | Sample | Transformers, RNNs, LayerNorm-BERT |
| InstanceNorm | (H, W) for each (N, C) | Sample × Channel | Style transfer, GAN generators |
| GroupNorm | (C/G, H, W) for each (N, G) | Sample × Group | Small-batch vision (detection, 3D) |
| RMSNorm | Last axis, no mean centering | Sample (or token) | LLaMA, Gemma, Mistral, Qwen, T5 |
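To make the table concrete, the sketch below computes each variant's statistics on one tensor and prints their shapes. GroupNorm is expressed by reshaping C into (G, C/G), one common way to implement it. The group count G = 2, the tensor sizes and the RMSNorm eps of 1e-6 are arbitrary, assumed values. RMSNorm is shown on an (N, T, D) token tensor, the layout where it is normally used.

```python
import numpy as np

N, C, H, W, G = 4, 6, 3, 3, 2          # G = number of GroupNorm groups (arbitrary choice)
rng = np.random.default_rng(0)
x = rng.standard_normal((N, C, H, W))

# Each reduction set A written as NumPy axes; keepdims=True keeps the stats broadcastable.
stats = {
    "BatchNorm": x.mean(axis=(0, 2, 3), keepdims=True),      # one mu per channel
    "LayerNorm": x.mean(axis=(1, 2, 3), keepdims=True),      # one mu per sample
    "InstanceNorm": x.mean(axis=(2, 3), keepdims=True),      # one mu per (sample, channel)
    # GroupNorm: view C as (G, C // G) and reduce inside each group.
    "GroupNorm": x.reshape(N, G, C // G, H, W).mean(axis=(2, 3, 4), keepdims=True),
}
for name, mu in stats.items():
    print(f"{name:12s} -> stats shape {mu.shape}")

# RMSNorm on a token tensor (N, T, D): no mean, one root-mean-square per token.
tokens = rng.standard_normal((N, 5, 8))
rms = np.sqrt((tokens ** 2).mean(axis=-1, keepdims=True) + 1e-6)   # eps = 1e-6 is assumed
y_rms = tokens / rms                                                # gamma omitted (all ones)
print(f"{'RMSNorm':12s} -> stats shape {rms.shape}")
```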

Batch Normalization: Statistics Along the Batch Axis

BatchNorm was the breakthrough. For a convolutional activation tensor $X \in \mathbb{R}^{N \times C \times H \times W}$, it computes one mean and one variance per channel, reducing across the batch, height and width axes: $\mu_c = \frac{1}{N H W}\sum_{n,h,w} x_{n,c,h,w}$ and $\sigma_c^{2} = \frac{1}{N H W}\sum_{n,h,w} (x_{n,c,h,w} - \mu_c)^{2}$.

Intuitively, every channel is a feature detector; BN normalizes each detector's response across all the locations and samples it sees in the batch, so a single pixel in a single image never dominates the channel's scale.

Why BN needs running statistics

During training, $\mu_c$ and $\sigma_c$ are computed from the current batch. At inference time you often process a single example, so there is no batch to average over. BN solves this by maintaining an exponential moving average (EMA) of the mini-batch statistics during training, then using the frozen EMA means and variances at inference:

The running mean is updated as $\mu^{\text{run}}_c \leftarrow (1 - m)\,\mu^{\text{run}}_c + m\,\mu^{\text{batch}}_c$ with momentum $m \approx 0.1$, and the running variance is updated the same way. This train/eval discrepancy is the classic source of BatchNorm footguns. Fine-tuning with batch size 1 makes the batch statistics extremely noisy and corrupts the running averages. And if the data seen at inference differs from the training distribution, the frozen statistics no longer match it.
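Here is a small sketch of how quickly a zero-initialized running mean catches up. It assumes every batch has exactly the same mean μ, a simplification since real batch means are noisy. After k updates the running mean equals $(1 - (1-m)^k)\,\mu$. With m = 0.1 it therefore takes 22 updates to get within 10% of μ, because $0.9^{22} \approx 0.098$.

```python
def steps_to_converge(m=0.1, mu=1.0, tol=0.1):
    """Count EMA updates until the running mean is within `tol` (relative) of mu.

    Uses the PyTorch-style update running <- (1 - m) * running + m * batch,
    starting from running = 0 and feeding the same batch mean every step.
    """
    running, k = 0.0, 0
    while abs(running - mu) > tol * abs(mu):
        running = (1 - m) * running + m * mu
        k += 1
    return k


print(steps_to_converge(m=0.1))    # 22
print(steps_to_converge(m=0.01))   # 230: a smaller momentum warms up far more slowly
```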

Batch Normalization Step-by-Step

Figure: raw activations $x_i$ for batch 1, plotted against sample index, with their distribution alongside. The eight values are 4.45, 9.42, 0.68, 6.55, 2.22, 3.99, 7.23, 4.48.

Notice how different batches have different means (internal covariate shift). BatchNorm normalizes each batch to have mean = 0 and variance = 1.

Batch statistics for batch 1: mean μ = 4.880, variance σ² = 6.843, standard deviation σ = 2.616. After normalization the mean is 0.
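The batch-1 statistics can be checked by hand. This sketch uses the eight activations exactly as displayed, which are already rounded to two decimals. Its results therefore agree with the figure's μ = 4.880, σ² = 6.843 and σ = 2.616 only to within about 0.005. The ε = 1e-5 is an assumed value, matching the nn.BatchNorm2d default.

```python
import numpy as np

# Batch-1 activations as displayed in the figure (rounded to two decimals).
x = np.array([4.45, 9.42, 0.68, 6.55, 2.22, 3.99, 7.23, 4.48])
eps = 1e-5

mu = x.mean()                            # batch mean
var = x.var()                            # biased variance (divide by n), as BN uses
x_hat = (x - mu) / np.sqrt(var + eps)    # standardized activations

print(mu, var, np.sqrt(var))             # close to 4.880, 6.843, 2.616
print(x_hat.mean(), x_hat.var())         # about 0 and about 1
```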

BatchNorm From Scratch: Code Trace

The math above compresses into one page of NumPy. Below is a class-based implementation that closely mirrors nn.BatchNorm2d. It holds per-channel learnable parameters $\gamma, \beta$, maintains EMA buffers for inference, and exposes a training-mode forward pass. Each line is annotated with numerical values from one illustrative run on a small (4, 3, 2, 2) input.

BatchNorm2d from scratch — NumPy
🐍batch_norm_from_scratch.py
1import numpy as np

NumPy is the numerical backbone. Every reduction below — axis-tuple means, variances, broadcasting — is implemented in C inside NumPy, so the whole from-scratch BatchNorm runs at native speed without any Python loops.

EXECUTION STATE
📚 numpy = Provides the ndarray, broadcasting, and reductions we need: ndarray.mean(axis=...), ndarray.var(axis=...), np.sqrt, np.ones, np.zeros.
3class BatchNorm2dScratch:

We wrap BatchNorm in a class because it holds STATE across calls — the running mean and running variance are updated every forward pass. nn.BatchNorm2d does the same, we just make every piece visible.

EXECUTION STATE
why a class = Stateful modules can't be pure functions: running_mean / running_var must persist. PyTorch's nn.Module uses the same trick (registered buffers via self.register_buffer).
6def __init__(self, C, eps=1e-5, momentum=0.1)

Constructor. C = number of channels — BatchNorm has one (μ, σ², γ, β) tuple PER CHANNEL, so every parameter vector has shape (C,). eps and momentum mirror PyTorch's nn.BatchNorm2d defaults exactly.

EXECUTION STATE
⬇ arg: C (int) = Number of channels in the (N, C, H, W) input. Example: a ResNet block operating on a 256-channel feature map would pass C=256.
⬇ arg: eps (float, default 1e-5) = Numerical safety in the denominator of 1 / sqrt(var + eps). Matches nn.BatchNorm2d default. Values up to 1e-3 are sometimes used for FP16 stability.
⬇ arg: momentum (float, default 0.1) = PyTorch-style momentum: fraction of the CURRENT batch stat that replaces the running stat. Note this is the OPPOSITE of the classical EMA convention where momentum is on the running side.
7self.gamma = np.ones((C,), dtype=np.float32)

Per-channel learnable scale, initialized to ones so the layer starts as the identity transform after normalization (γ · x_hat = 1 · x_hat = x_hat).

EXECUTION STATE
gamma (initial) = [1.0, 1.0, 1.0] (C=3 in our demo)
why all ones? = With gamma=1 and beta=0 the affine step is a no-op at init (1 · x_hat + 0 = x_hat), so the layer starts out emitting the plain standardized activations; optimization then learns any scale it needs.
8self.beta = np.zeros((C,), dtype=np.float32)

Per-channel learnable shift, initialized to zero so the layer has no bias at init.

EXECUTION STATE
beta (initial) = [0.0, 0.0, 0.0]
9self.running_mean = np.zeros((C,), dtype=np.float32)

EMA-tracked per-channel mean used AT INFERENCE. Initialized to zero. In nn.BatchNorm2d this is a registered buffer — saved with the model, not trained by gradient descent.

EXECUTION STATE
running_mean (initial) = [0.0, 0.0, 0.0]
10self.running_var = np.ones((C,), dtype=np.float32)

EMA-tracked per-channel variance. Initialized to ONES (not zeros!) so eval-mode division doesn't blow up before any batches have been seen: 1 / sqrt(1 + eps) ≈ 1.

EXECUTION STATE
running_var (initial) = [1.0, 1.0, 1.0] — note: 1s, not 0s
why 1s at init? = If running_var started at 0, eval-mode inference on a freshly initialized model would divide by ≈sqrt(eps) ≈ 0.003 and produce huge activations. Initializing to 1 (with running_mean = 0) makes the untrained BN an approximate identity in eval mode.
11self.eps, self.m = eps, momentum

Store hyperparameters on self so forward_train() can read them. Tuple unpack is just a compact way to write two assignments.

EXECUTION STATE
self.eps = 1e-5
self.m (momentum) = 0.1
13def forward_train(self, x):

Training-mode forward pass. Computes stats from the CURRENT batch, updates the EMA, normalizes with BATCH stats (not EMA), and applies the affine. Eval mode would skip the stats computation and divide by the frozen running stats instead — that's the fundamental train/eval difference in BN.

EXECUTION STATE
⬇ arg: x shape (N, C, H, W) = N=4 images, C=3 channels, H=W=2 pixels. 48 floats total.
14# x shape: (N, C, H, W) — image-style conv activations

Comment fixing the tensor layout. BatchNorm2d reduces over axes (0, 2, 3) — batch, height, width — leaving the channel axis untouched. This is what gives BN its per-channel statistics.

15mu = x.mean(axis=(0, 2, 3))

Per-channel mean. For each of the C channels, average the N·H·W = 4·2·2 = 16 values across the entire batch and spatial extent. The result is a length-C vector.

EXECUTION STATE
📚 ndarray.mean(axis=tuple) = Passing a tuple of axes reduces over all of them simultaneously. Equivalent to x.mean(0).mean(-1).mean(-1) but faster and without reshape.
⬇ arg: axis = (0, 2, 3) = Reduce over batch (0), height (2), width (3). Leaves only axis 1 — the channel — so the result has shape (C,)=(3,).
→ example (one illustrative run) = The listing does not fix a seed, so exact values vary from run to run. In the run annotated here, x.mean(axis=(0,2,3)) ≈ [0.0465, 0.1215, -0.3021]: the 3 channel means over 16 entries each.
⬆ mu (C,) = (3,) = [0.0465, 0.1215, -0.3021]
16var = x.var(axis=(0, 2, 3))

Per-channel variance using NumPy's BIASED estimator (divides by N·H·W, not N·H·W − 1). This matches PyTorch's nn.BatchNorm2d, which also uses the biased estimator during training for the normalization step.

EXECUTION STATE
→ example (same illustrative run) = x.var(axis=(0,2,3)) ≈ [0.8911, 1.1732, 0.9504]. Each channel has 16 draws from N(0,1), so variances cluster around 1.
⬆ var (C,) = (3,) = [0.8911, 1.1732, 0.9504]
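subtle PyTorch detail = PyTorch uses the UNBIASED variance (N·H·W − 1 divisor) when updating running_var, but the BIASED variance in the normalization step. Our code uses the biased variance for both, so its batch term in the running_var update is smaller than PyTorch's by a factor of (n − 1)/n, with n = N·H·W = 16 here. The original BN paper also applies the unbiased correction to its inference statistics.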
17# EMA update — PyTorch uses m on the batch side, (1-m) on running

Comment flagging the easy-to-miss convention flip. Classical EMA is running ← α·running + (1−α)·batch with α near 1. PyTorch writes running ← (1 − m)·running + m·batch, so its momentum=0.1 corresponds to classical α=0.9. Easy source of bugs when porting between frameworks.

18self.running_mean = (1 - self.m) * self.running_mean + self.m * mu

Blend 90% of the old running mean with 10% of the current batch mean. Over thousands of batches this converges to a smoothed estimate of the dataset-wide per-channel mean, which is then used AT INFERENCE to avoid the small-batch problem.

EXECUTION STATE
before = running_mean = [0.0, 0.0, 0.0]
0.9 * running_mean = [0.0, 0.0, 0.0] (still zero on first step)
0.1 * mu = [0.00465, 0.01215, -0.03021]
⬆ running_mean (after) = [0.00465, 0.01215, -0.03021]
19self.running_var = (1 - self.m) * self.running_var + self.m * var

Same EMA for variance. Starts at [1,1,1] so after one step it sits at 0.9·[1,1,1] + 0.1·var_batch, which is a conservative mix — gradually the running_var drifts toward the true dataset variance.

EXECUTION STATE
before = running_var = [1.0, 1.0, 1.0]
0.9 * running_var + 0.1 * var = [0.9 + 0.08911, 0.9 + 0.11732, 0.9 + 0.09504]
⬆ running_var (after) = [0.9891, 1.0173, 0.9950]
20# Normalize with the CURRENT batch stats (not running)

Comment spelling out the defining property of train mode: the normalization uses THIS BATCH'S μ and σ², not the EMA. At eval time the opposite holds — EMA is used, batch stats are ignored.

21x_hat = (x - mu[None, :, None, None]) / np.sqrt(var[None, :, None, None] + self.eps)

The standardization step. The [None, :, None, None] slice reshapes mu from shape (C,) to (1, C, 1, 1) so it broadcasts correctly against x of shape (N, C, H, W). Same for var. After this line, each channel has mean≈0 and variance≈1 across the (batch × spatial) axes.

EXECUTION STATE
📚 x[None, :, None, None] trick = None is an alias for np.newaxis. Adding axes of size 1 lets NumPy stretch the tensor during broadcasting without copying data.
mu[None, :, None, None] shape = (1, 3, 1, 1) — broadcastable against (4, 3, 2, 2). Every (n, c, h, w) entry subtracts mu[c].
sqrt(var + eps) shape = (1, 3, 1, 1). For our numbers: [sqrt(0.8911+1e-5), sqrt(1.1732+1e-5), sqrt(0.9504+1e-5)] ≈ [0.9440, 1.0831, 0.9749].
⬆ x_hat shape = (4, 3, 2, 2) — same as x. Per-channel: mean(x_hat[:, c, :, :]) ≈ 0.0, var(x_hat[:, c, :, :]) ≈ 1.0 by construction.
22return self.gamma[None, :, None, None] * x_hat + self.beta[None, :, None, None]

Apply the learnable affine per channel. With gamma=1, beta=0 at init this is a pass-through; later the optimizer pushes gamma and beta to whatever values make downstream layers happy.

EXECUTION STATE
gamma reshaped = (1, 3, 1, 1) — per-channel scale, broadcast across every (n, h, w) position.
beta reshaped = (1, 3, 1, 1) — per-channel shift.
⬆ return = Same shape (4, 3, 2, 2) as x, standardized per channel and then affine-transformed. Identity at init (gamma=1, beta=0).
24bn = BatchNorm2dScratch(C=3)

Instantiate the BN layer for a 3-channel feature map. Allocates gamma, beta, running_mean, running_var as length-3 vectors.

EXECUTION STATE
bn.gamma = [1.0, 1.0, 1.0]
bn.beta = [0.0, 0.0, 0.0]
bn.running_mean = [0.0, 0.0, 0.0]
bn.running_var = [1.0, 1.0, 1.0]
25x = np.random.randn(4, 3, 2, 2).astype(np.float32)

Sample 48 i.i.d. draws from N(0, 1), shaped as a (batch=4, channels=3, H=2, W=2) tensor. The listing does not seed the RNG, so the numbers annotated in this trace come from one illustrative run. Call np.random.seed(...) before this line if you need reproducible values.

EXECUTION STATE
📚 np.random.randn(*shape) = Standard-normal sampler. Each element is independent N(0,1). Ensures roughly zero mean and unit variance per channel in the large-batch limit.
⬇ arg: shape tuple (4, 3, 2, 2) = N=4 images, C=3 channels, H=W=2 pixels. 48 total floats.
⬇ arg: astype(np.float32) = Cast from float64 (NumPy default) to float32 to match GPU/PyTorch precision. Halves memory.
→ first 4 values (illustrative run, channel 0) = x[:, 0, 0, 0] ≈ [1.7641, 1.8676, -0.1034, 0.0500]
26y = bn.forward_train(x)

Run one training-mode forward pass. Inside: compute batch μ and σ² (lines 15–16), EMA-update running stats (lines 18–19), standardize (line 21), apply affine (line 22). The output y has the same shape as x.

EXECUTION STATE
⬆ y.shape = (4, 3, 2, 2) — identical to x.
⬆ side effect = bn.running_mean was updated from [0,0,0] to ≈[0.00465, 0.01215, -0.03021]. bn.running_var from [1,1,1] to ≈[0.9891, 1.0173, 0.9950].
27print("y.shape =", y.shape)

Sanity-check the output shape.

EXECUTION STATE
expected stdout = y.shape = (4, 3, 2, 2)
28print("batch μ[c=0] =", round(float(x[:, 0].mean()), 4))

Print the mean of the input channel 0 across the 16 positions in the batch. This is exactly what mu[0] was computed as on line 15 — we print it to confirm the batch-stats computation is doing what the code claims.

EXECUTION STATE
x[:, 0] = All channel-0 entries: shape (4, 2, 2) = 16 floats.
x[:, 0].mean() = ≈ 0.0465 (matches mu[0])
expected stdout = batch μ[c=0] = 0.0465
29print("running μ =", bn.running_mean)

Print the running mean buffer AFTER the first forward pass. Notice that running_mean = 0.1 * batch_mu: with momentum=0.1, the running stat barely moves on the first step. If every batch had the same mean, it would take 22 batches (0.9^22 ≈ 0.098) before running_mean came within 10% of it.

EXECUTION STATE
expected stdout = running μ = [ 0.00465 0.01215 -0.03021]
→ insight = This slow EMA warm-up is why eval-mode outputs are unreliable early in training: until enough batches have been seen, the running stats are still pulled toward their initial values (0 for the mean, 1 for the variance).
1import numpy as np
2
3class BatchNorm2dScratch:
4    """Training-mode forward pass that closely mirrors nn.BatchNorm2d."""
5
6    def __init__(self, C, eps=1e-5, momentum=0.1):
7        self.gamma = np.ones((C,), dtype=np.float32)
8        self.beta  = np.zeros((C,), dtype=np.float32)
9        self.running_mean = np.zeros((C,), dtype=np.float32)
10        self.running_var  = np.ones((C,),  dtype=np.float32)
11        self.eps, self.m = eps, momentum
12
13    def forward_train(self, x):
14        # x shape: (N, C, H, W) — image-style conv activations
15        mu  = x.mean(axis=(0, 2, 3))                    # (C,)
16        var = x.var(axis=(0, 2, 3))                     # (C,)
17        # EMA update — PyTorch uses m on the batch side, (1-m) on running
18        self.running_mean = (1 - self.m) * self.running_mean + self.m * mu
19        self.running_var  = (1 - self.m) * self.running_var  + self.m * var
20        # Normalize with the CURRENT batch stats (not running)
21        x_hat = (x - mu[None, :, None, None]) / np.sqrt(var[None, :, None, None] + self.eps)
22        return self.gamma[None, :, None, None] * x_hat + self.beta[None, :, None, None]
23
24bn = BatchNorm2dScratch(C=3)
25x  = np.random.randn(4, 3, 2, 2).astype(np.float32)
26y  = bn.forward_train(x)
27print("y.shape      =", y.shape)
28print("batch μ[c=0] =", round(float(x[:, 0].mean()), 4))
29print("running μ    =", bn.running_mean)

The same computation in PyTorch is a single line, but the built-in module hides the train/eval switch that powers BN. The next trace makes that switch explicit: we run the SAME input through the SAME layer twice, once in train mode and once in eval mode, and show the outputs disagree.

BatchNorm2d — PyTorch train vs eval
🐍batch_norm_pytorch.py
1import torch

PyTorch's core tensor library. Provides torch.Tensor and autograd. We'll use torch.randn and torch.allclose below.

EXECUTION STATE
📚 torch = torch.randn(*shape) is the PyTorch analog of np.random.randn. torch.allclose(a, b) checks element-wise closeness with a tolerance.
2import torch.nn as nn

nn contains the stateful layer modules. nn.BatchNorm2d is the GPU-friendly, autograd-aware counterpart of our BatchNorm2dScratch class from the previous trace.

EXECUTION STATE
📚 nn.BatchNorm2d = Implements the same math as our scratch class, except that it uses the unbiased batch variance when updating running_var. It registers gamma (weight) and beta (bias) as Parameters and running_mean / running_var as buffers, so they all move with the model to the GPU.
4bn = nn.BatchNorm2d(num_features=3, eps=1e-5, momentum=0.1)

Build the BN layer. Creates four length-C tensors internally: two Parameters (weight=gamma, bias=beta) that autograd will update, and two buffers (running_mean, running_var) that get EMA-updated in-place during training.

EXECUTION STATE
📚 nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) = Main constructor arguments; recent versions also accept device and dtype. affine=False disables gamma/beta. track_running_stats=False skips the EMA and uses batch statistics in both modes. Defaults are affine=True, track_running_stats=True.
⬇ arg: num_features = 3 = Number of channels C. Allocates gamma, beta, running_mean, running_var each of shape (3,).
⬇ arg: eps = 1e-5 = Numerical safety constant in the denominator. Identical to our scratch class default.
⬇ arg: momentum = 0.1 = PyTorch-style momentum: fraction of the batch stat blended into the running stat each step. momentum=None would switch to a simple-average estimator instead.
⬆ result: bn = BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) bn.weight = [1., 1., 1.] (Parameter, requires_grad=True) bn.bias = [0., 0., 0.] (Parameter, requires_grad=True) bn.running_mean = [0., 0., 0.] (buffer) bn.running_var = [1., 1., 1.] (buffer) bn.num_batches_tracked = 0
5x = torch.randn(4, 3, 2, 2)

Sample a random input. torch.randn defaults to float32 on CPU. Shape matches the NumPy demo so the computation is directly comparable.

EXECUTION STATE
📚 torch.randn(*shape) = Draws from N(0, 1). Returns a float32 CPU tensor unless torch.set_default_dtype / device overrides apply.
⬆ x.shape = torch.Size([4, 3, 2, 2])
→ x.mean(dim=(0,2,3)) = Near [0, 0, 0], but with only 16 draws from N(0,1) per channel each mean typically deviates from 0 by about ±0.25 (one standard error).
7bn.train()

Switch the module to TRAINING mode. Sets the internal flag `bn.training = True`. Its forward pass will (a) compute μ and σ² from the current batch, (b) EMA-update running_mean / running_var in place, (c) increment num_batches_tracked, and (d) normalize using batch stats.

EXECUTION STATE
📚 Module.train(mode=True) = Sets self.training = True and recursively sets it on all children. Counterpart to .eval(). BatchNorm and Dropout inspect this flag to switch behavior; LayerNorm does not, which is why it has no train/eval mismatch.
side effect = bn.training is True afterwards. Modules are created in training mode, so this call is a no-op here, but it makes the mode explicit.
8y_train = bn(x)

Calling the module runs its __call__ → forward. Because bn.training=True, forward computes batch μ/σ², updates running stats in-place, and returns the normalized + affine output.

EXECUTION STATE
internal: batch μ (C,) = ≈ [random channel means] — one number per channel, computed over N·H·W = 16 positions.
internal: batch σ² (C,) = ≈ [random channel variances]
side effect: running_mean = 0.9·[0,0,0] + 0.1·batch_μ ≈ batch_μ/10
side effect: running_var = 0.9·[1,1,1] + 0.1·batch_σ²_unbiased ≈ [0.9 + 0.1·var_c], where PyTorch uses the N·H·W − 1 divisor for this update
⬆ y_train.shape = torch.Size([4, 3, 2, 2]) — per channel, y_train has mean≈0, var≈1.
10bn.eval()

Switch to EVALUATION mode. Sets bn.training = False. The forward pass now IGNORES the current batch's statistics and uses the frozen running_mean / running_var instead. This is essential at inference, where a single example cannot provide reliable batch statistics.

EXECUTION STATE
📚 Module.eval() = Shorthand for .train(False). Recursively sets self.training=False on all submodules. Always call before inference for BN / Dropout layers.
side effect = bn.training = False. Subsequent forward passes skip the EMA update.
11y_eval = bn(x)

Second forward pass, SAME input tensor, but now in eval mode. It uses the frozen running_mean and running_var from the training pass. After only one training batch, those buffers still sit close to their init values (0 and 1) rather than the batch's own μ and σ², so the eval output differs from the train output.

EXECUTION STATE
internal: μ used = bn.running_mean ≈ [batch_μ/10] (NOT the current batch's μ)
internal: σ² used = bn.running_var ≈ [0.9 + 0.1·batch_σ²] (NOT the current batch's σ²)
→ per-channel y_eval is NOT unit variance = Because running_var ≠ batch_var, the channel-wise variance of y_eval drifts from 1.0 until enough training batches have stabilized the EMA.
⬆ y_eval.shape = torch.Size([4, 3, 2, 2]) — same shape as y_train, different values.
13print("y_train shape :", y_train.shape)

Shape check — both train and eval produce (4, 3, 2, 2).

EXECUTION STATE
expected stdout = y_train shape : torch.Size([4, 3, 2, 2])
14print("running_mean :", bn.running_mean)

Shows the running_mean buffer AFTER one training step. Each entry is exactly 0.1 × the corresponding batch mean — the EMA is still dominated by its initial zero.

EXECUTION STATE
expected stdout = running_mean : tensor([~0.01, ~0.01, ~-0.03]) (depends on random seed)
15print("running_var :", bn.running_var)

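Running variance after one step: 0.9·1 + 0.1·batch_var, where PyTorch uses the unbiased (N·H·W − 1) batch variance for this update. For standard-normal input this lands near 1.0 per channel, with small deviations due to the random batch.

EXECUTION STATE
expected stdout = running_var : tensor([~1.0, ~1.0, ~1.0]) (depends on random seed)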
16print("train ≠ eval? :", not torch.allclose(y_train, y_eval))

Confirm the train/eval divergence. torch.allclose returns True only if every pair of entries is within a tolerance, so `not torch.allclose(...)` is True whenever any element differs materially. Here it is True because eval mode normalizes with the running stats (≈ batch_μ/10 and ≈ 0.9 + 0.1·batch_σ²) instead of the batch's own μ and σ² used in train mode.

EXECUTION STATE
📚 torch.allclose(a, b, rtol=1e-5, atol=1e-8) = Element-wise closeness check. Returns a Python bool. Useful for numerical equivalence tests.
expected stdout = train ≠ eval? : True
→ why this matters = Forgetting to call model.eval() before inference is one of the most common BatchNorm bugs: the model silently produces different predictions for the same input depending on mode. This short listing is a minimal reproducible example.
1import torch
2import torch.nn as nn
3
4bn = nn.BatchNorm2d(num_features=3, eps=1e-5, momentum=0.1)
5x  = torch.randn(4, 3, 2, 2)
6
7bn.train()                       # enable running-stats update
8y_train = bn(x)
9
10bn.eval()                        # freeze running stats, use them for forward
11y_eval  = bn(x)
12
13print("y_train shape   :", y_train.shape)
14print("running_mean    :", bn.running_mean)
15print("running_var     :", bn.running_var)
16print("train ≠ eval?   :", not torch.allclose(y_train, y_eval))

The divergence is the whole point. In training mode BN normalizes with the CURRENT batch's statistics and updates its running buffers. In eval mode it freezes those buffers and uses them instead, so the forward pass becomes a pure function of the input. Early in training the two modes disagree because the running stats lag behind the true dataset distribution. Forgetting model.eval() before inference is the classic BatchNorm footgun, and this visualizer makes it unmissable.
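The same contrast can be shown without PyTorch. This NumPy sketch uses a single channel, with hypothetical running statistics standing in for values accumulated during training. One fixed example is placed in two different batches: its train-mode output changes with its batchmates, while its eval-mode output does not.

```python
import numpy as np

eps = 1e-5
running_mean, running_var = 0.5, 2.0      # hypothetical values accumulated during training


def bn_train(batch):
    """Train mode: normalize with the statistics of this batch."""
    return (batch - batch.mean()) / np.sqrt(batch.var() + eps)


def bn_eval(batch):
    """Eval mode: normalize with the frozen running statistics."""
    return (batch - running_mean) / np.sqrt(running_var + eps)


sample = 1.0
batch_a = np.array([sample, 0.0, 0.2, -0.4])
batch_b = np.array([sample, 3.0, 2.5, 4.0])

print(bn_train(batch_a)[0], bn_train(batch_b)[0])   # differ: output depends on batchmates
print(bn_eval(batch_a)[0], bn_eval(batch_b)[0])     # identical: pure function of the input
```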

BatchNorm — Train vs Eval Statistics

In train mode, BatchNorm normalizes with the noisy per-batch mean and variance. In eval mode, it uses the running EMA that was accumulated during training — scrub the slider to watch the EMA lag behind the batch stats.

Figure: per-batch μ vs running EMA μ across training steps (0 to 200), with the true μ(t) for reference. A second panel shows the activation distribution at step t = 100, input x vs normalized x̂ using batch stats (μ_batch = 0.370, σ_batch = 0.899; running EMA: μ_running = 0.223, σ_running = 0.885).
What you're seeing. The blue EMA curve is a low-pass filter over the red per-batch means, updated as μ_running ← (1 − m)·μ_running + m·μ_batch. Shrink the batch size and the red curve becomes wildly noisy — but the blue EMA stays smooth. In eval mode inference uses that stabilized blue estimate, so a single input always produces the same output. In train mode the same input is normalized against whatever batch it happened to land in, so outputs on the same data can differ between the two modes — especially early in training, or when m is too small for the EMA to have caught up.
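The low-pass behaviour is easy to reproduce. The sketch below assumes batch means that are noisy draws around a fixed true mean of 0.3, an arbitrary value. Their noise standard deviation shrinks as 1/sqrt(batch size). The sketch compares the spread of the raw batch means with the spread of the EMA after warm-up.

```python
import numpy as np


def batch_and_ema_spread(batch_size, m=0.1, steps=2000, true_mu=0.3, seed=0):
    """Return (std of per-batch means, std of the EMA) after a warm-up period."""
    rng = np.random.default_rng(seed)
    # Mean of `batch_size` unit-variance samples has std 1 / sqrt(batch_size).
    batch_mu = true_mu + rng.standard_normal(steps) / np.sqrt(batch_size)
    running, ema = 0.0, np.empty(steps)
    for t, b in enumerate(batch_mu):
        running = (1 - m) * running + m * b     # PyTorch-style momentum update
        ema[t] = running
    warm = steps // 10                           # skip the warm-up from the zero init
    return batch_mu[warm:].std(), ema[warm:].std()


for bs in (4, 64):
    noisy, smooth = batch_and_ema_spread(bs)
    print(f"batch={bs:3d}: std of batch mu = {noisy:.3f}, std of EMA mu = {smooth:.3f}")
```

For independent noise, the EMA's variance is m/(2 − m) times the batch-mean variance. With m = 0.1 that cuts the standard deviation to roughly a quarter, whatever the batch size.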

Layer Normalization: Statistics Along the Feature Axis

Layer Normalization, introduced by Ba, Kiros and Hinton (2016), reduces in the opposite direction from BatchNorm. Consider a Transformer activation of shape $(N, T, D)$: a batch of $N$ sequences, each of length $T$, with each token a $D$-dimensional vector. LayerNorm computes one mean and variance per token, reducing across the feature dimension: $\mu_{n,t} = \frac{1}{D}\sum_{d=1}^{D} x_{n,t,d}$ and $\sigma_{n,t}^{2} = \frac{1}{D}\sum_{d=1}^{D}(x_{n,t,d} - \mu_{n,t})^{2}$.
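The per-token formulas map directly onto a reduction over the last axis. This is a minimal NumPy sketch, not PyTorch's implementation. It assumes per-feature gamma and beta of shape (D,), shared by every token, as in a LayerNorm over the last dimension.

```python
import numpy as np


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last (feature) axis of an (N, T, D) tensor."""
    mu = x.mean(axis=-1, keepdims=True)                    # (N, T, 1): one mean per token
    var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)     # (N, T, 1): biased variance
    return gamma * (x - mu) / np.sqrt(var + eps) + beta    # gamma, beta broadcast over (N, T)


N, T, D = 2, 5, 8
rng = np.random.default_rng(0)
x = 4.0 + 3.0 * rng.standard_normal((N, T, D))
y = layer_norm(x, gamma=np.ones(D), beta=np.zeros(D))

print(y.mean(axis=-1))   # about 0 for every (n, t)
print(y.var(axis=-1))    # about 1 for every (n, t)
```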

Critically, the statistics depend only on the token itself, not on the batch and not on other tokens. This has three consequences that make LN the natural choice for Transformers:

  1. Batch-size independent. Train with batch 256 or infer with batch 1 — LayerNorm does the same math either way. There is no train/eval mismatch, no running statistics to maintain.
  2. Length-independent. Each token is normalized by itself, so sequences of length 4 and 4096 are both handled correctly. BatchNorm would average across the batch and sequence dimensions, coupling every token's output to the statistics of other positions and other sequences.
  3. Causal-safe. In an autoregressive decoder, token $t$ must not see information from tokens $> t$. LN's statistics come only from position $t$'s own features, so nothing leaks from the future.
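The three properties can be checked numerically. In this sketch `batch_norm_over_tokens` is a hypothetical BatchNorm-style baseline that pools one mean and variance per feature over all (N, T) positions; it is not a library layer. Dropping samples or future tokens leaves LayerNorm's outputs for the remaining tokens unchanged, but changes the baseline's.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm without the affine: statistics over the last axis only."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def batch_norm_over_tokens(x, eps=1e-5):
    """Hypothetical BatchNorm-style baseline: one mean/var per feature over (N, T)."""
    mu = x.mean(axis=(0, 1), keepdims=True)
    var = x.var(axis=(0, 1), keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 6, 8))       # (N, T, D)

# 1. Batch independence: normalizing the first sequence alone gives the same output.
print(np.allclose(layer_norm(x)[:1], layer_norm(x[:1])))                                 # True
# 2./3. Length and causal safety: removing future tokens leaves tokens 0..2 unchanged.
print(np.allclose(layer_norm(x)[:, :3], layer_norm(x[:, :3])))                           # True
# The BatchNorm-style baseline changes when future tokens are removed.
print(np.allclose(batch_norm_over_tokens(x)[:, :3], batch_norm_over_tokens(x[:, :3])))   # False
```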

Batch Norm vs Layer Norm: Visualization

Layer Normalization

Normalizes across the feature dimension for each sample. Statistics (μ, σ) are computed per sample row. Works with any batch size (even 1) and behaves the same at train and test time.

Figure panels: raw activations [batch = 4, features = 6] and the same activations after Layer Normalization, normalized across each row so the values are centered around 0. A further panel shows the learnable parameters (γ and β). Selected row, sample 1: mean μ = 2.322, variance σ² = 1.291.

| Property | Batch Norm | Layer Norm |
|---|---|---|
| Normalizes over | Batch dimension | Feature dimension |
| Batch size dependency | Needs large batches | Works with batch=1 |
| Train vs test | Different (running stats) | Same behavior |
| Best for | CNNs, large-batch training | Transformers, RNNs |

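Geometric reading. LayerNorm first centers each sample vector, which projects it onto the hyperplane orthogonal to the all-ones vector in $\mathbb{R}^{D}$, and then rescales it to norm $\sqrt{D}$ (the $D$ standardized entries have mean square 1, up to $\epsilon$). The direction of the centered vector is preserved; only its magnitude is rewritten. For attention this matters: the logits $QK^{\top}$ are dot products of linear projections of these normalized vectors, so (before $\gamma$ rescales individual features) no token's query or key can dominate the scores merely because its activations grew large upstream.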
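A NumPy sketch that checks this geometric reading numerically for one arbitrary random vector with D = 512 (the values are illustrative). After LayerNorm without γ and β the vector has norm √D, sums to zero (so it is orthogonal to the all-ones vector), and points in the same direction as the centered input.

```python
import numpy as np

rng = np.random.default_rng(1)
D = 512
x = rng.normal(loc=3.0, scale=7.0, size=D)      # arbitrary token vector

centered = x - x.mean()
x_hat = centered / np.sqrt(x.var() + 1e-5)      # LayerNorm without gamma / beta

print(np.linalg.norm(x_hat), np.sqrt(D))        # both approximately sqrt(512)
print(abs(x_hat.sum()) < 1e-8)                  # orthogonal to the all-ones vector
cos = centered @ x_hat / (np.linalg.norm(centered) * np.linalg.norm(x_hat))
print(cos)                                      # approximately 1.0: direction kept
```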

BatchNorm vs LayerNorm: The Dimensional Story

The fastest way to internalize the difference is to see which cells of an $(N, C, H, W)$ tensor share their mean and variance. BN colors a whole $(N, H, W)$ slab of one channel; LN colors a whole $(C, H, W)$ slab of one sample; InstanceNorm colors one $(H, W)$ feature map; GroupNorm interpolates between LN and IN.
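The same picture expressed as NumPy reductions, a sketch on an arbitrary (2, 4, 2, 2) tensor with a hypothetical G = 2 groups. Each method is just a different choice of axes for `mean`, and `keepdims=True` makes the shape of each statistic show which cells share it.

```python
import numpy as np

N, C, H, W, G = 2, 4, 2, 2, 2
x = np.random.default_rng(0).normal(size=(N, C, H, W))

bn_mu = x.mean(axis=(0, 2, 3), keepdims=True)   # (1, C, 1, 1): one per channel
ln_mu = x.mean(axis=(1, 2, 3), keepdims=True)   # (N, 1, 1, 1): one per sample
in_mu = x.mean(axis=(2, 3), keepdims=True)      # (N, C, 1, 1): one per (sample, channel)
gn_mu = x.reshape(N, G, C // G, H, W).mean(axis=(2, 3, 4))  # (N, G): one per (sample, group)

for name, m in [("BN", bn_mu), ("LN", ln_mu), ("IN", in_mu), ("GN", gn_mu)]:
    print(name, m.shape, "->", x.size // m.size, "elements share each mean")
```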

Figure (interactive): Normalization Methods Comparison. Example tensor of shape [N=2, C=4, H=2, W=2]. In the BatchNorm view, statistics are computed over (N, H, W) separately for each channel, and each channel's normalization region is shown with its own mean.

| Method | Norm dims | Stats per | Use case |
|---|---|---|---|
| Batch | (N, H, W) | Channel | CNNs, large batches |
| Layer | (C, H, W) | Sample | Transformers, RNNs |
| Instance | (H, W) | Sample + channel | Style transfer, GANs |
| Group | (C/G, H, W) | Sample + group | Small-batch CNNs |


| Property | BatchNorm | LayerNorm |
|---|---|---|
| Stats computed per | Channel (across batch + spatial) | Sample / token (across features) |
| Depends on batch size? | Yes, unstable at small batch | No |
| Train/eval behaviour | Different (running stats) | Identical |
| Causal / autoregressive safe? | No (averages across sequence) | Yes |
| Primary use | CNNs | Transformers, RNNs |
| Parameters | γ, β each of size C | γ, β each of size D (the last dim) |
| Extra buffers | running mean / var (one per channel) | None |

InstanceNorm and GroupNorm: The Interpolation

Between BatchNorm (statistics pooled over the entire batch per channel) and LayerNorm (statistics over the full feature vector per sample) lie two useful interpolations. InstanceNorm, introduced by Ulyanov, Vedaldi and Lempitsky (2016) in Instance Normalization: The Missing Ingredient for Fast Stylization, computes one $(\mu, \sigma^{2})$ pair per $(n, c)$: each feature map is normalized in isolation, independently of the batch and of the other channels. It became a standard choice for style-transfer networks and GAN generators because it strips per-image statistics out of each feature map, letting the learned $\gamma, \beta$ reinject the desired style.

GroupNorm, proposed by Wu and He (ECCV 2018, Group Normalization), sits in between. It partitions the $C$ channels into $G$ groups and computes statistics over $(C/G, H, W)$ per sample. The per-group mean is $\mu_{n,g} = \frac{1}{(C/G)HW}\sum_{c \in g,\, h,\, w} x_{n,c,h,w}$, with an analogous variance. Two limits recover the neighbors: $G = 1$ reduces to LayerNorm (all channels share statistics) and $G = C$ reduces to InstanceNorm (each channel alone). For vision tasks that force small per-GPU batch sizes (detection, segmentation, 3D medical imaging), $G = 32$ is the default Wu and He found to work well, and in that small-batch regime GroupNorm consistently outperforms BatchNorm.
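Both limits can be checked against PyTorch's functional ops, sketched here on random data with an arbitrary shape. With no affine parameters passed, `torch.nn.functional.group_norm` with one group should match `layer_norm` over (C, H, W), and with C groups it should match `instance_norm`. In module form these correspond to `nn.GroupNorm(1, C)` and `nn.GroupNorm(C, C)`, whose per-channel γ and β start at 1 and 0.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C, H, W = 2, 6, 4, 4
x = torch.randn(N, C, H, W)

# G = 1: one group holding all channels -> LayerNorm over (C, H, W) of each sample.
gn1 = F.group_norm(x, num_groups=1)
ln = F.layer_norm(x, normalized_shape=(C, H, W))

# G = C: every channel is its own group -> InstanceNorm over (H, W).
gnC = F.group_norm(x, num_groups=C)
inn = F.instance_norm(x)

print(torch.allclose(gn1, ln, atol=1e-4), torch.allclose(gnC, inn, atol=1e-4))
```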

GroupNorm from scratch — NumPy (recovers LN and IN as limits)
group_norm_from_scratch.py
Line 1: `import numpy as np`

Same NumPy backbone as before. We'll use reshape (the defining move of GroupNorm), reductions, and broadcasting.

EXECUTION STATE
📚 numpy = ndarray.reshape(...) returns a view when possible — no data copy. That makes the group-split step essentially free.
Line 3: `def group_norm(x, gamma, beta, G, eps=1e-5)`

Single-function GroupNorm. x is the conv activation, gamma and beta are per-channel affine params (same shape as BN's), G is the group count, eps is numerical safety. When G=1 the C channels are merged into one group → that's LayerNorm; when G=C each channel is its own group → that's InstanceNorm.

EXECUTION STATE
⬇ arg: x = Activation tensor of shape (N, C, H, W).
⬇ arg: gamma = Per-channel scale, shape (C,). Broadcast across (N, H, W).
⬇ arg: beta = Per-channel shift, shape (C,).
⬇ arg: G (int) = Number of groups. Must divide C evenly. Classic choice is G=32 for ResNet — empirically best for detection / segmentation at small batch sizes.
⬇ arg: eps = 1e-5 = Denominator safety. Same default as BN / LN.
Line 5: `N, C, H, W = x.shape`

Unpack the four dimensions so the subsequent lines read cleanly.

EXECUTION STATE
with x.shape = (2, 6, 2, 2) = N=2 samples, C=6 channels, H=W=2 pixels. 48 floats total.
Line 6: `assert C % G == 0`

Groups must evenly partition the channels. If C=6 and G=4 we'd need to drop or pad channels, so we fail fast with an AssertionError.

EXECUTION STATE
example: C=6, G=2 = 6 % 2 == 0 → passes. Channels per group = 3.
example: C=6, G=4 = 6 % 4 == 2 ≠ 0 → AssertionError.
Line 7: `x_g = x.reshape(N, G, C // G, H, W)`

The signature GroupNorm move. We re-view the channel axis as (G, C/G). With N=2, C=6, H=W=2 and G=2, this turns shape (2, 6, 2, 2) into shape (2, 2, 3, 2, 2) — two samples, two groups, three channels per group, two × two spatial.

EXECUTION STATE
📚 ndarray.reshape(*new_shape) = Reinterprets the underlying buffer under a new shape. The product of dimensions must match. Returns a view when the array is contiguous (no data copied).
⬆ x_g.shape = (2, 2, 3, 2, 2) — (N, G, C/G, H, W)
what this lets us do = Now axis 2 (C/G), axis 3 (H), axis 4 (W) together span exactly the elements that share a (μ, σ²) pair. One reduction is enough.
Line 8: `mu = x_g.mean(axis=(2, 3, 4), keepdims=True)`

Compute the mean over (C/G, H, W) for each (sample, group). For G=2: each group has 3 channels × 2 × 2 = 12 elements per sample. Two samples × two groups = four means total.

EXECUTION STATE
⬇ arg: axis = (2, 3, 4) = Reduce over the channel-per-group, height, and width axes. Leaves (N, G) — one μ per (sample, group).
⬇ arg: keepdims = True = Preserves the reduced axes as size 1, so mu has shape (N, G, 1, 1, 1) and broadcasts back against x_g in the next step.
⬆ mu.shape = (2, 2, 1, 1, 1) — four means, each collapsing 12 input elements.
Line 9: `var = x_g.var(axis=(2, 3, 4), keepdims=True)`

Biased variance over the same (C/G, H, W) axes. Same shape as mu: (N, G, 1, 1, 1).

EXECUTION STATE
⬆ var.shape = (2, 2, 1, 1, 1)
formula reminder = μ_{n,g} = (1 / ((C/G)·H·W)) Σ_{c∈g, h, w} x_{n,c,h,w}. Same for var with the squared deviation.
Line 10: `x_hat = (x_g - mu) / np.sqrt(var + eps)`

Standardize within each (sample, group). After this line, for every (n, g) pair the 12 elements in that group have mean≈0 and variance≈1.

EXECUTION STATE
broadcasting = (2, 2, 3, 2, 2) − (2, 2, 1, 1, 1) → (2, 2, 3, 2, 2). μ_{n,g} is subtracted from every element of its group.
⬆ x_hat.shape = (2, 2, 3, 2, 2)
Line 11: `x_hat = x_hat.reshape(N, C, H, W)`

Un-group: collapse (G, C/G) back into a single channel axis so the affine transform below sees the usual (N, C, H, W) layout.

EXECUTION STATE
⬆ x_hat.shape = (2, 6, 2, 2)
Line 12: `return gamma[None, :, None, None] * x_hat + beta[None, :, None, None]`

Per-CHANNEL affine (not per-group!). The gamma and beta live on the channel axis just like BatchNorm — GN only changes HOW the statistics are computed, not how they're re-scaled afterward.

EXECUTION STATE
gamma shape after reshape = (1, 6, 1, 1) — broadcasts across (N, H, W).
⬆ return shape = (2, 6, 2, 2) — same as input x.
Line 14: `np.random.seed(0)`

Pin the RNG so the numbers below are reproducible across runs and across machines.

Line 15: `x = np.random.randn(2, 6, 2, 2).astype(np.float32)`

48 float32 draws from N(0, 1), shaped (N=2, C=6, H=2, W=2).

EXECUTION STATE
⬆ x shape = (2, 6, 2, 2) — 48 values
x[0, 0] (seed 0) =
[[1.7641, 0.4002], [0.9787, 2.2409]]  — channel 0 of sample 0
Line 16: `gamma = np.ones((6,), dtype=np.float32)`

Identity scale — makes the affine step a pass-through, so we see the pure normalization effect.

EXECUTION STATE
gamma = [1, 1, 1, 1, 1, 1]
Line 17: `beta = np.zeros((6,), dtype=np.float32)`

Zero shift — again, no affine modification.

EXECUTION STATE
beta = [0, 0, 0, 0, 0, 0]
Line 18: `y_ln = group_norm(x, gamma, beta, G=1)`

G=1 means ALL 6 channels share a (μ, σ²) per sample. That's exactly LayerNorm — proof that GN generalizes LN. Each sample's 24 values (6·2·2) get normalized as one block.

EXECUTION STATE
⬆ y_ln per-sample stats = y_ln[n].mean() ≈ 0, y_ln[n].var() ≈ 1 for every n
identity = GN(G=1) == LN over the last three axes (C, H, W).
Line 19: `y_gn = group_norm(x, gamma, beta, G=2)`

G=2 splits the 6 channels into two groups of 3. Each group (per sample) has 12 elements and is standardized in isolation. Four means and four variances total (2 samples × 2 groups).

EXECUTION STATE
⬆ y_gn = (2, 6, 2, 2) — per-group, 12-element stats applied.
→ comparison = y_gn ≠ y_ln even though both have the same shape: LN sees 24-element groups, GN(G=2) sees 12-element groups, so the standardization scales differ.
Line 20: `y_in = group_norm(x, gamma, beta, G=6)`

G=6 means each channel is its own group. Stats per (sample, channel) over only H·W = 4 elements. That's InstanceNorm — GN(G=C) == IN.

EXECUTION STATE
⬆ y_in = (2, 6, 2, 2) — tiny 4-element groups.
→ use case = IN is the classic choice for style transfer: normalizing each feature map in isolation strips content statistics, leaving style to be reinjected by learned gamma/beta.
Line 21: `print("Per-row mean ≈ 0?", y_gn.mean(axis=(1,2,3)).round(4))`

Sanity-check that the sample-wise mean of y_gn is near zero. Because each sample has two groups of 12 elements and BOTH groups were centered, their pooled mean is exactly 0 (up to float rounding).

EXECUTION STATE
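expected stdout = Per-row mean ≈ 0? followed by two values that round to zero (NumPy may print each as 0. or -0., depending on the sign of the tiny rounding residue).
→ interpretation = Consistent with GN's per-group centering. Computing the per-group means directly would show each of them is ≈ 0 as well, up to float rounding.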
```python
import numpy as np

def group_norm(x, gamma, beta, G, eps=1e-5):
    """x shape (N, C, H, W). Normalize within each (sample, group)."""
    N, C, H, W = x.shape
    assert C % G == 0
    x_g = x.reshape(N, G, C // G, H, W)
    mu  = x_g.mean(axis=(2, 3, 4), keepdims=True)
    var = x_g.var(axis=(2, 3, 4),  keepdims=True)
    x_hat = (x_g - mu) / np.sqrt(var + eps)
    x_hat = x_hat.reshape(N, C, H, W)
    return gamma[None, :, None, None] * x_hat + beta[None, :, None, None]

np.random.seed(0)
x = np.random.randn(2, 6, 2, 2).astype(np.float32)
gamma = np.ones((6,), dtype=np.float32)
beta  = np.zeros((6,), dtype=np.float32)
y_ln = group_norm(x, gamma, beta, G=1)   # equivalent to LayerNorm
y_gn = group_norm(x, gamma, beta, G=2)   # 2 groups of 3 channels
y_in = group_norm(x, gamma, beta, G=6)   # equivalent to InstanceNorm
print("Per-row mean ≈ 0?", y_gn.mean(axis=(1,2,3)).round(4))
```
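The sanity check on the last line of the listing only looks at the pooled per-sample mean. A slightly stronger check, sketched below, inspects every (sample, group) directly and compares the NumPy function with PyTorch's `torch.nn.functional.group_norm`; the function is repeated so the snippet runs on its own, and the tolerances are assumptions chosen for float32.

```python
import numpy as np
import torch
import torch.nn.functional as F

def group_norm(x, gamma, beta, G, eps=1e-5):
    # Same NumPy GroupNorm as in the listing above.
    N, C, H, W = x.shape
    assert C % G == 0
    x_g = x.reshape(N, G, C // G, H, W)
    mu = x_g.mean(axis=(2, 3, 4), keepdims=True)
    var = x_g.var(axis=(2, 3, 4), keepdims=True)
    x_hat = ((x_g - mu) / np.sqrt(var + eps)).reshape(N, C, H, W)
    return gamma[None, :, None, None] * x_hat + beta[None, :, None, None]

np.random.seed(0)
x = np.random.randn(2, 6, 2, 2).astype(np.float32)
gamma = np.ones(6, dtype=np.float32)
beta = np.zeros(6, dtype=np.float32)
y_gn = group_norm(x, gamma, beta, G=2)

# Per-(sample, group) statistics: channels 0-2 and 3-5 are contiguous, 12 elements each.
groups = y_gn.reshape(2, 2, -1)
print(groups.mean(axis=-1).round(4))   # every entry approximately 0
print(groups.var(axis=-1).round(4))    # every entry approximately 1 (a hair below, due to eps)

# Cross-check against PyTorch's functional GroupNorm.
ref = F.group_norm(torch.from_numpy(x), num_groups=2,
                   weight=torch.from_numpy(gamma), bias=torch.from_numpy(beta), eps=1e-5)
print(np.allclose(y_gn, ref.numpy(), atol=1e-5))
```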

LayerNorm From Scratch in Python

The entire LayerNorm forward pass fits in five lines of NumPy. We'll step through it with concrete numbers so there's no black box. Our input is a $3 \times 4$ matrix (three samples, four features). Each step below explains what one line computes and why.

LayerNorm from scratch — NumPy
layer_norm_from_scratch.py
Line 1: `import numpy as np`

NumPy is Python's numerical library. It provides the ndarray — a contiguous, typed, N-dimensional array. All the arithmetic on this page (broadcasting, reductions like .mean() and .var(), element-wise operators) runs as optimized C under the hood, not as Python loops. We alias it as np by universal convention.

EXECUTION STATE
📚 numpy = Library for fast multi-dimensional array math. Provides ndarray, linear algebra (np.linalg), random sampling (np.random), and universal functions (np.sqrt, np.exp, ...).
as np = Creates the alias np so we can write np.array() instead of numpy.array(). Universal Python convention.
Line 3: `def layer_norm(x, gamma, beta, eps=1e-5)`

Defines LayerNorm from first principles. For EACH sample (row) in x, it subtracts that sample's mean, divides by that sample's standard deviation, then scales by gamma and shifts by beta. The output has the SAME shape as x but each row is zero-mean, unit-variance before the affine transform.

EXECUTION STATE
⬇ input: x (3×4) — activations =
        f0    f1    f2    f3
samp0  2.0   4.0   6.0   8.0
samp1  1.0   3.0   2.0   6.0
samp2  5.0   7.0   3.0   9.0
→ x purpose = The layer's activations. Rows are samples (or tokens in a Transformer); columns are features/channels. LN treats each row as an independent vector to be standardized.
⬇ input: gamma (4,) — learnable scale = [2.0, 1.0, 0.5, 1.0]
→ gamma purpose = Per-feature scaling factor applied AFTER normalization. Without gamma, the network would be forced to keep variance=1 everywhere, which reduces expressiveness. gamma lets the network recover any variance it needs per feature.
⬇ input: beta (4,) — learnable shift = [0.0, 0.0, 0.0, 0.5]
→ beta purpose = Per-feature bias applied AFTER normalization. Combined with gamma, the affine transform y = gamma * x_hat + beta can undo the standardization if the network needs to — so normalization never removes representational power, it just gives the optimizer a well-conditioned starting point.
⬇ input: eps = 1e-5 = A tiny positive constant added under the square root. Prevents division by zero when var ≈ 0 (a dead feature). Example: if var = 0, sqrt(var + 1e-5) ≈ 3.16e-3 — small but finite, so x/sqrt stays well-defined.
⬆ returns = np.ndarray of shape (3,4) — each row has been standardized (mean≈0, var≈1) and then re-scaled by gamma and shifted by beta.
Line 4: docstring, per-sample normalization across the feature axis

This docstring pins down the DIRECTION of normalization. In NumPy, axis=-1 means 'the last axis' — for a (3,4) matrix that's the feature axis (4 features). So LayerNorm reduces across features, computing one (mean, variance) per sample. Contrast with BatchNorm, which would reduce across samples (axis=0) to get one (mean, variance) per feature.

Line 5: `mu = x.mean(axis=-1, keepdims=True)`

Compute the per-sample mean along the feature axis. Every row of x is reduced to a single number — its average. This is the first statistic LayerNorm needs: the center of each sample's distribution.

EXECUTION STATE
📚 ndarray.mean(axis, keepdims) = NumPy reduction: sums the elements along `axis` and divides by their count. Example: np.array([1,2,3]).mean() = 2.0. With axis=-1 on a 2-D matrix, it averages across columns and returns a 1-D (or column) result.
⬇ arg: axis = -1 = Reduce along the LAST dimension (here, the feature axis with 4 elements). For a 3×4 matrix, this gives 3 means — one per sample. axis=0 would instead give 4 means — one per feature (that's BatchNorm's axis).
→ axis example = x = [[1,2,3,4],[5,6,7,8]] (shape 2×4): x.mean(axis=-1) → [2.5, 6.5] (row averages); x.mean(axis=0) → [3.0, 4.0, 5.0, 6.0] (column averages)
⬇ arg: keepdims = True = Keep the reduced axis with size 1 instead of dropping it. Without keepdims, mean has shape (3,); with keepdims, it has shape (3,1). The (3,1) shape broadcasts correctly against x (3,4) in the next step — otherwise we'd need to reshape manually.
→ keepdims example = x.mean(axis=-1, keepdims=False) → shape (3,). Subtracting it from x (3,4) raises a broadcasting ValueError, because NumPy aligns (3,) with the trailing axis of size 4; for a square input it would instead silently subtract along the WRONG axis. x.mean(axis=-1, keepdims=True) → shape (3,1) → [[μ0], [μ1], [μ2]], which broadcasts along the 4 features of each row: CORRECT.
── computed mu per sample ── =
samp0 μ = (2+4+6+8)/4 = 20/4 = 5.0000
samp1 μ = (1+3+2+6)/4 = 12/4 = 3.0000
samp2 μ = (5+7+3+9)/4 = 24/4 = 6.0000
⬆ result: mu (3×1) =
[[5.0000],
 [3.0000],
 [6.0000]]
Line 6: `var = x.var(axis=-1, keepdims=True)`

Compute the per-sample variance along the feature axis. Variance measures how spread out each sample's features are — large variance = some features dominate, small variance = features are tightly grouped. LayerNorm will use it to rescale every sample to variance 1.

EXECUTION STATE
📚 ndarray.var(axis, keepdims) = Population variance: mean of squared deviations from the mean. var(x) = mean((x - mean(x))²). NumPy uses the BIASED estimator by default (divides by N, not N-1) — this matches PyTorch's LayerNorm exactly.
⬇ arg: axis = -1 = Same as .mean(axis=-1): reduce along the feature axis so each sample gets its own variance.
⬇ arg: keepdims = True = Same reasoning as for mu: produces shape (3,1) so broadcasting against x (3,4) works in the next line.
── computed variance per sample ── =
samp0 var = mean((2-5)², (4-5)², (6-5)², (8-5)²) = mean(9, 1, 1, 9) = 20/4 = 5.0000
samp1 var = mean((1-3)², (3-3)², (2-3)², (6-3)²) = mean(4, 0, 1, 9) = 14/4 = 3.5000
samp2 var = mean((5-6)², (7-6)², (3-6)², (9-6)²) = mean(1, 1, 9, 9) = 20/4 = 5.0000
⬆ result: var (3×1) =
[[5.0000],
 [3.5000],
 [5.0000]]
Line 7: `x_hat = (x - mu) / np.sqrt(var + eps)`

The standardization step. For every element, subtract the row mean (centering) and divide by the row standard deviation (scaling). After this line, EACH ROW of x_hat has mean≈0 and variance≈1 — regardless of what scale the row had before.

EXECUTION STATE
📚 np.sqrt() = Element-wise square root. For a (3,1) input, returns a (3,1) output. Example: np.sqrt(np.array([4., 9.])) = [2., 3.].
⬇ x - mu (broadcasting 3×4 − 3×1 → 3×4) =
        f0     f1     f2     f3
samp0 -3.0   -1.0   +1.0   +3.0
samp1 -2.0    0.0   -1.0   +3.0
samp2 -1.0   +1.0   -3.0   +3.0
→ broadcasting rule = NumPy expands (3,1) to (3,4) by duplicating the single column across 4 positions. So mu[0][0]=5 is subtracted from ALL four features of sample 0.
⬇ var + eps =
[[5.00001],
 [3.50001],
 [5.00001]]   (eps=1e-5 is added for numerical safety)
⬇ sqrt(var + eps) (3×1) =
[[2.2361],
 [1.8708],
 [2.2361]]
── x_hat row by row ── =
samp0 x_hat = [-3/2.2361, -1/2.2361, 1/2.2361, 3/2.2361] = [-1.3416, -0.4472, 0.4472, 1.3416]
samp1 x_hat = [-2/1.8708, 0/1.8708, -1/1.8708, 3/1.8708] = [-1.0690, 0.0000, -0.5345, 1.6036]
samp2 x_hat = [-1/2.2361, 1/2.2361, -3/2.2361, 3/2.2361] = [-0.4472, 0.4472, -1.3416, 1.3416]
⬆ result: x_hat (3×4) =
         f0       f1       f2       f3
samp0  -1.3416  -0.4472   0.4472   1.3416
samp1  -1.0690   0.0000  -0.5345   1.6036
samp2  -0.4472   0.4472  -1.3416   1.3416
→ verify = Every row of x_hat sums to ≈0 (mean=0) and has squared mean ≈1 (variance=1). That's the entire point of the standardization step.
Line 8: `return gamma * x_hat + beta`

The learnable affine transform. After standardization every feature is unit-variance, but the NETWORK may want some features to be large (gamma>1) or biased toward positive (beta>0). The broadcast is (3,4) * (4,) + (4,), which NumPy expands row-wise: gamma is applied identically to every sample.

EXECUTION STATE
* (element-wise multiply) = NumPy's broadcasting multiply. A shape (3,4) times shape (4,) stretches gamma across the 3 rows: each row is element-wise multiplied by gamma. So gamma[0] scales column 0 of every row.
+ (element-wise add) = Same broadcasting as above: beta is added to every row element-wise.
⬇ gamma * x_hat (3×4) =
         f0       f1       f2       f3
samp0  -2.6833  -0.4472   0.2236   1.3416
samp1  -2.1381   0.0000  -0.2673   1.6036
samp2  -0.8944   0.4472  -0.6708   1.3416
(column 0 doubled, column 2 halved, rest unchanged)
⬆ return: gamma * x_hat + beta (3×4) =
         f0       f1       f2       f3
samp0  -2.6833  -0.4472   0.2236   1.8416
samp1  -2.1381   0.0000  -0.2673   2.1036
samp2  -0.8944   0.4472  -0.6708   1.8416
(beta=0.5 added only to the last column)
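→ why affine is essential = If gamma were fixed at 1 and beta at 0, LayerNorm would force every layer's output to be zero-mean and unit-variance. That's a restriction: the next layer may need larger activations. With learnable gamma and beta, the network can rescale and shift each feature as it needs. Note, however, that LN cannot in general learn the identity map: μ and σ are computed per sample, while gamma and beta are per-feature constants shared by all samples, so setting gamma=σ, beta=μ undoes the normalization only for one particular input.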
Line 10: `x = np.array([[2., 4., 6., 8.], ...`

Construct the input activation matrix as a NumPy ndarray. We use three samples of four features so every step can be verified by hand. dtype=np.float32 matches PyTorch's default (and GPU-friendly) precision.

EXECUTION STATE
📚 np.array(object, dtype) = Creates an ndarray from a nested Python list (or any sequence-of-sequences). Infers shape from nesting depth — a list of 3 lists of 4 numbers becomes shape (3,4).
⬇ arg: nested list = 3 rows × 4 columns — each inner list is one sample with 4 features.
⬇ arg: dtype = np.float32 = 32-bit floating point (4 bytes per element). Half the memory of float64 and the precision GPUs are typically used with. PyTorch's default tensor dtype is also float32, so the PyTorch block below produces matching results (to printed precision).
⬆ result: x (3,4) =
        f0    f1    f2    f3
samp0  2.0   4.0   6.0   8.0
samp1  1.0   3.0   2.0   6.0
samp2  5.0   7.0   3.0   9.0
Line 11: row 1 of the matrix literal, `[1., 3., 2., 6.]`

Second sample — four feature values. Note that this row has a smaller range (1 to 6) than row 0 (2 to 8). LayerNorm will still map it to mean 0, variance 1, regardless of the original scale.

EXECUTION STATE
samp1 = [1.0, 3.0, 2.0, 6.0] — small magnitude, will be scaled up by LN since its variance is lower.
Line 12: row 2 of the matrix literal, `[5., 7., 3., 9.]`

Third sample. Same scale as row 0 (variance = 5.0) but centered at a different mean (6 instead of 5). LayerNorm erases the mean difference — both rows become zero-mean after standardization.

EXECUTION STATE
samp2 = [5.0, 7.0, 3.0, 9.0] — mean=6, var=5.
Line 13: `gamma = np.array([2.0, 1.0, 0.5, 1.0], dtype=np.float32)`

Learnable scale vector. In a real network these would be nn.Parameters initialized to 1.0 and updated by gradient descent; here we hard-code them to show the affine effect clearly.

EXECUTION STATE
gamma = [2.0, 1.0, 0.5, 1.0] = Feature 0 will be scaled ×2; feature 1 unchanged; feature 2 halved; feature 3 unchanged. If the network decided feature 0 needs to be amplified, it would learn gamma[0]>1.
Line 14: `beta = np.array([0.0, 0.0, 0.0, 0.5], dtype=np.float32)`

Learnable shift vector. Only feature 3 gets a nonzero bias, so all outputs in the last column will be shifted up by 0.5 AFTER the gamma scaling.

EXECUTION STATE
beta = [0.0, 0.0, 0.0, 0.5] = Only the last feature shifts: after the gamma scaling, every row's last output is offset by +0.5. A network learns such a shift when the next layer works better with a nonzero baseline in that feature.
Line 16: `y = layer_norm(x, gamma, beta)`

Calls the LayerNorm function. Internally it runs line 5 (mean), line 6 (variance), line 7 (standardization) and line 8 (affine). The returned y has the same shape as x (3×4) but a different distribution: every row was first standardized to zero mean and unit variance, and the affine parameters then reshaped it, so the rows of y are no longer exactly zero-mean.

EXECUTION STATE
⬆ y (3×4) =
         f0       f1       f2       f3
samp0  -2.6833  -0.4472   0.2236   1.8416
samp1  -2.1381   0.0000  -0.2673   2.1036
samp2  -0.8944   0.4472  -0.6708   1.8416
→ invariant = If we called layer_norm(x*1000, gamma, beta), y would be the same up to the negligible effect of eps. LN is invariant to positive rescaling (and to adding a constant) of each row, a defining property that keeps training stable across different input magnitudes.
Line 17: `print(y)`

Prints the final normalized activations. These numbers match PyTorch's nn.LayerNorm(4) output (with the same gamma and beta) to the printed precision; we verify this in the PyTorch block below.

EXECUTION STATE
expected stdout (shown rounded to 4 decimals; NumPy's default print of a float32 array shows more digits) =
[[-2.6833 -0.4472  0.2236  1.8416]
 [-2.1381  0.     -0.2673  2.1036]
 [-0.8944  0.4472 -0.6708  1.8416]]
```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    """Per-sample normalization across the feature axis."""
    mu    = x.mean(axis=-1, keepdims=True)
    var   = x.var(axis=-1, keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta

x     = np.array([[2., 4., 6., 8.],
                  [1., 3., 2., 6.],
                  [5., 7., 3., 9.]], dtype=np.float32)
gamma = np.array([2.0, 1.0, 0.5, 1.0], dtype=np.float32)
beta  = np.array([0.0, 0.0, 0.0, 0.5], dtype=np.float32)

y = layer_norm(x, gamma, beta)
print(y)
```
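Two properties claimed in the walkthrough can be checked directly. The sketch below reuses the same toy matrix (in float64 for simplicity) and shows that the output barely changes when every row is rescaled by a positive factor and shifted by a constant (the residual difference comes from eps), and that setting γ = σ and β = μ from one row restores that row but not the others, because γ and β are shared across rows while μ and σ are per row. The scale 1000 and shift 7 are arbitrary choices.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

x = np.array([[2., 4., 6., 8.],
              [1., 3., 2., 6.],
              [5., 7., 3., 9.]])
gamma = np.array([2.0, 1.0, 0.5, 1.0])
beta = np.array([0.0, 0.0, 0.0, 0.5])
y = layer_norm(x, gamma, beta)

# Positive rescaling plus a constant shift of each row: output unchanged up to eps effects.
print(np.allclose(layer_norm(1000 * x + 7, gamma, beta), y, atol=1e-4))

# gamma = sigma, beta = mu taken from row 0 restores row 0 only.
sigma0 = np.sqrt(x[0].var() + 1e-5)
mu0 = x[0].mean()
z = layer_norm(x, gamma=np.full(4, sigma0), beta=np.full(4, mu0))
print(np.allclose(z[0], x[0]), np.allclose(z[1], x[1]))   # True False
```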

LayerNorm in PyTorch

PyTorch ships `nn.LayerNorm` with a fused CUDA kernel that computes mean, variance, standardization and affine in a single kernel, which can be substantially faster (on the order of $3\times$) than running the naive sequence of separate operations. Importantly, the math is the same as in our NumPy implementation, and the outputs agree to within floating-point rounding, which we verify below.

LayerNorm — PyTorch nn.LayerNorm
layer_norm_pytorch.py
Line 1: `import torch`

PyTorch's core package. Provides torch.Tensor — a GPU-capable ndarray — plus autograd (automatic differentiation), optimizers, and neural-network building blocks. Unlike NumPy arrays, tensors can be moved to a CUDA device and track gradients for backprop.

EXECUTION STATE
📚 torch = Root module. torch.tensor(...) constructs tensors; torch.no_grad() disables autograd tracking; torch.float32 is the default dtype.
Line 2: `import torch.nn as nn`

torch.nn is the neural-network module library. Every learnable layer (Linear, Conv2d, LayerNorm, BatchNorm, TransformerEncoderLayer, …) is an nn.Module subclass. nn.LayerNorm is what we'll use on line 8 — it wraps the exact math from our from-scratch function plus learnable nn.Parameters for gamma and beta.

EXECUTION STATE
📚 nn = Namespace holding all layers. nn.Module is the base class — every custom model subclasses it.
Line 4: `x = torch.tensor([[...], [...], [...]])`

Constructs a 2-D float32 tensor of shape (3,4) from a nested Python list. Identical numbers to the NumPy example — this lets us verify that our from-scratch LayerNorm and PyTorch's built-in produce the same output.

EXECUTION STATE
📚 torch.tensor(data) = Creates a new Tensor and copies `data` into it. Infers the shape from the nesting and the dtype from the Python values as a whole: if any element is a float, the result gets the default float dtype (torch.float32).
⬇ data: nested list 3×4 =
        f0    f1    f2    f3
samp0  2.0   4.0   6.0   8.0
samp1  1.0   3.0   2.0   6.0
samp2  5.0   7.0   3.0   9.0
→ dtype inferred = torch.float32 (because we wrote 2., not 2). If we'd written [[2,4,6,8],...] the dtype would be int64 and LayerNorm would error.
⬆ result: x.shape = torch.Size([3, 4]) — requires_grad=False (it's an input, not a parameter).
Line 5: row 1 of the tensor literal

The second sample, same numbers as the NumPy example. PyTorch reads the Python list top-to-bottom so this line becomes x[1].

EXECUTION STATE
x[1] = [1.0, 3.0, 2.0, 6.0] — smallest-variance sample in the batch.
Line 6: row 2 of the tensor literal

The third sample. The closing brackets `]])` end the nested list and the torch.tensor() call.

EXECUTION STATE
x[2] = [5.0, 7.0, 3.0, 9.0]
Line 8: `ln = nn.LayerNorm(normalized_shape=4, eps=1e-5)`

Constructs a LayerNorm module that will normalize over the LAST dimension of size 4. Internally this creates two nn.Parameters — weight (gamma) initialized to ones and bias (beta) initialized to zeros — and registers them so the optimizer updates them during training.

EXECUTION STATE
📚 nn.LayerNorm(normalized_shape, eps, elementwise_affine) = Module that computes y = (x − μ) / √(σ² + ε) · γ + β. Statistics are computed over the last len(normalized_shape) dimensions of the input, independently for every position in the earlier dimensions.
⬇ arg: normalized_shape = 4 = Normalize over the last dimension, which has size 4. Equivalent to normalized_shape=(4,). For transformer activations of shape (batch, seq, d_model) you'd pass d_model — then stats are computed per (batch, seq) position.
→ normalized_shape example = (batch=2, seq=3, d=4) with normalized_shape=4 → 6 means & 6 variances, one per (batch, seq). A single vector of length 4 is standardized in isolation — exactly the per-token LN used in Transformers.
⬇ arg: eps = 1e-5 = Same role as in the NumPy version: numerical safety in the denominator. 1e-5 is PyTorch's default for nn.LayerNorm; published models pick their own value (T5 and Gemma, for example, use 1e-6 in their RMS-style norms), so check the model config when porting weights.
⬇ arg: elementwise_affine (default True) = Controls whether gamma/beta exist. True → module has learnable weight and bias; False → the affine step is skipped and y = x_hat. We leave the default so we can set weight/bias manually.
⬆ result: ln = nn.LayerNorm module with ln.weight = Parameter([1., 1., 1., 1.]) (requires_grad=True), ln.bias = Parameter([0., 0., 0., 0.]) (requires_grad=True), ln.eps = 1e-5
Line 9: `with torch.no_grad():`

Opens a context in which NO operations are tracked for autograd. Inside the block, assignments to parameters don't build a gradient graph and don't trigger errors about 'leaf tensor requires_grad'. We use it to hand-set ln.weight and ln.bias so our outputs match the NumPy version exactly.

EXECUTION STATE
📚 torch.no_grad() = Context manager that temporarily disables autograd. Common uses: manual parameter surgery, inference (no backward pass needed), memory savings in eval mode. Inside the block, tensor.grad_fn is None even for ops on leaf parameters.
→ without no_grad = ln.weight.copy_(...) would fail with 'a leaf Variable that requires grad is being used in an in-place operation'. Autograd guards leaves to prevent breaking the graph.
Line 10: `ln.weight.copy_(torch.tensor([2.0, 1.0, 0.5, 1.0]))`

In-place overwrite of the LayerNorm's gamma parameter. The trailing underscore in .copy_() means 'modify in place' — the tensor keeps its identity (same memory address, still a Parameter, still registered) but its values change.

EXECUTION STATE
📚 tensor.copy_(src) = In-place copy: writes src's values into tensor's existing storage. Shapes must be broadcast-compatible. The underscore suffix is PyTorch's universal marker for in-place ops (add_, mul_, zero_, ...).
⬇ arg: torch.tensor([2.0, 1.0, 0.5, 1.0]) = A temporary 1-D tensor of length 4 on the CPU. This gets copied element-by-element into ln.weight.
⬆ after: ln.weight = Parameter containing: tensor([2.0000, 1.0000, 0.5000, 1.0000])
Line 11: `ln.bias.copy_(torch.tensor([0.0, 0.0, 0.0, 0.5]))`

Same idea for beta: set ln.bias in-place to match our NumPy example. In a real training run, the optimizer would learn these values; here we're pinning them to prove the numerical equivalence.

EXECUTION STATE
⬆ after: ln.bias = Parameter containing: tensor([0.0000, 0.0000, 0.0000, 0.5000])
Line 13: `y = ln(x)`

Runs the forward pass. Calling a Module as a function triggers its __call__, which dispatches to .forward(). Internally PyTorch computes μ, σ² along the last dim, standardizes, then multiplies by ln.weight and adds ln.bias — the exact four steps we wrote by hand.

EXECUTION STATE
📚 Module.__call__(input) = Runs input through the layer. Wraps .forward() with hooks (pre-forward, forward, full-backward). Autograd tracks the ops unless we're inside torch.no_grad().
⬇ input: x (3,4) = Our original activations — unchanged from line 4.
internal: μ (3,1) =
[[5.0],
 [3.0],
 [6.0]]  — same as NumPy.
internal: σ² (3,1) =
[[5.00],
 [3.50],
 [5.00]]
internal: x_hat (3,4) =
         f0       f1       f2       f3
samp0  -1.3416  -0.4472   0.4472   1.3416
samp1  -1.0690   0.0000  -0.5345   1.6036
samp2  -0.4472   0.4472  -1.3416   1.3416
⬆ result: y = γ · x_hat + β (3,4) =
         f0       f1       f2       f3
samp0  -2.6833  -0.4472   0.2236   1.8416
samp1  -2.1381   0.0000  -0.2673   2.1036
samp2  -0.8944   0.4472  -0.6708   1.8416
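→ numerical match = Matches the NumPy from-scratch output to the printed 4 decimals. PyTorch's LayerNorm uses the SAME formula, run by a heavily optimized CUDA kernel when x is on GPU; exact bit-level equality is not guaranteed, because the reductions may be accumulated in a different order.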
Line 14: `print(y)`

Prints the tensor. PyTorch's __repr__ adds a `tensor(...)` wrapper and, by default, rounds values to 4 decimals for display.

EXECUTION STATE
expected stdout =
tensor([[-2.6833, -0.4472,  0.2236,  1.8416],
        [-2.1381,  0.0000, -0.2673,  2.1036],
        [-0.8944,  0.4472, -0.6708,  1.8416]],
       grad_fn=<NativeLayerNormBackward0>)
→ grad_fn = The grad_fn suffix tells us PyTorch recorded the LayerNorm op for backprop. If we called y.sum().backward() (backward() needs a scalar here, since y is a 3×4 tensor), gradients would flow all the way back to ln.weight and ln.bias.
import torch
import torch.nn as nn

# 3 samples x 4 features: the same activations as the NumPy example
x = torch.tensor([[2., 4., 6., 8.],
                  [1., 3., 2., 6.],
                  [5., 7., 3., 9.]])

ln = nn.LayerNorm(normalized_shape=4, eps=1e-5)   # normalize over the last dim (size 4)
with torch.no_grad():                              # pin gamma/beta without recording autograd ops
    ln.weight.copy_(torch.tensor([2.0, 1.0, 0.5, 1.0]))   # gamma
    ln.bias.copy_(torch.tensor([0.0, 0.0, 0.0, 0.5]))     # beta

y = ln(x)
print(y)

Pre-LN vs Post-LN in Transformers

The original Attention Is All You Need (Vaswani et al. 2017) placed LayerNorm after the residual addition, so one sub-block was $x \mapsto \text{LN}\bigl(x + \text{Attention}(x)\bigr)$.

This worked for the original Transformer (6 encoder and 6 decoder layers), but Post-LN grows harder to train as depth increases and relies on a carefully tuned learning-rate warmup. Xiong et al. (ICML 2020, On Layer Normalization in the Transformer Architecture) show that at initialization the expected gradients of the parameters near the output layer of a Post-LN Transformer are large, with a bound (Theorem 1) that does not shrink with depth, whereas the corresponding Pre-LN bound decreases as $1/\sqrt{L}$. Their layer-by-layer measurements agree: Post-LN gradient norms grow toward the output, while Pre-LN gradient norms stay roughly uniform across depth. The practical consequence is that Post-LN requires a small learning rate plus a long warmup schedule, whereas Pre-LN, which places LayerNorm inside the residual branch, can be trained without the warmup stage. The figure below plots per-layer gradient norms for the two placements.

[Figure: Pre-LN vs Post-LN, gradient magnitude by depth. Panels: Post-LN (x → F(x) → + → LayerNorm → y) and Pre-LN (x → LayerNorm → F(·) → + → y). Axes: layer index ℓ vs gradient norm, log scale.]

Caption: at initialization, Post-LN concentrates large gradients in the layers nearest the output (Xiong et al., ICML 2020), while Pre-LN keeps per-layer gradient norms roughly uniform across depth.

Because Post-LN's largest gradients sit in the layers closest to the output, large updates early in training can diverge, so Post-LN needs a long learning-rate warmup. Pre-LN keeps per-layer gradient norms roughly flat, and Xiong et al. report that Pre-LN Transformers train well without the warmup stage. This stability at depth is a core reason modern LLMs (GPT-3 onwards, LLaMA, PaLM, Mistral) use Pre-LN, although large training runs still typically include a short warmup.

Pre-LN moves the normalization inside the residual branch, so the rule becomes $x \mapsto x + \text{Attention}\bigl(\text{LN}(x)\bigr)$.

The residual stream becomes an unnormalized highway: every sub-block reads a normalized copy of the stream but writes its output back to the raw stream. The result is roughly uniform gradient magnitude across depth, and GPT-2, GPT-3, LLaMA, PaLM, and most modern LLMs use Pre-LN. The commonly reported tradeoff is slightly lower final quality than a Post-LN model of the same size that does train successfully, a price worth paying for the ability to train stably at 70B+ parameters without elaborate warmup schedules.

| Scheme | Formula | Stability at depth | Used by |
|---|---|---|---|
| Post-LN | LN(x + F(x)) | Needs LR warmup; harder to train as depth grows | Original Transformer, BERT |
| Pre-LN | x + F(LN(x)) | Stable to hundreds of layers | GPT-2/3, LLaMA, PaLM |
| Sandwich-LN | x + LN(F(LN(x))) | Very stable; costs one extra LN per sub-block | CogView, some multimodal models |
| DeepNorm | LN(αx + F(x)) with scaled init | Trained up to 1,000 layers | DeepNet (Wang et al. 2022) |
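The four placements in the table differ only in where LN sits relative to the residual add. A minimal NumPy sketch, assuming a toy sub-block `f` (a tanh-activated random projection standing in for attention or the FFN), LN without learnable γ/β, and an illustrative DeepNorm α of 2.0 (DeepNet derives α from depth and also rescales the sub-block initialization, which is omitted here). The input stream is given a deliberately large scale so the contrast is visible: Post-LN and DeepNorm renormalize the stream after every block, while Pre-LN and Sandwich-LN leave it unnormalized.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Per-token LayerNorm over the last axis (gamma = 1, beta = 0 for simplicity)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def residual_step(x, f, scheme, alpha=1.0):
    """Apply one residual sub-block f under the given normalization placement."""
    if scheme == "post":        # LN(x + F(x))
        return layer_norm(x + f(x))
    if scheme == "pre":         # x + F(LN(x))
        return x + f(layer_norm(x))
    if scheme == "sandwich":    # x + LN(F(LN(x)))
        return x + layer_norm(f(layer_norm(x)))
    if scheme == "deepnorm":    # LN(alpha * x + F(x)); DeepNet's init rescaling is omitted
        return layer_norm(alpha * x + f(x))
    raise ValueError(f"unknown scheme: {scheme}")


def stream_rms(scheme, n_blocks=12, d=16, seed=0):
    """Mean per-token RMS of the residual stream after n_blocks toy sub-blocks."""
    rng = np.random.default_rng(seed)
    w = rng.normal(scale=1 / np.sqrt(d), size=(d, d))

    def f(h):
        return np.tanh(h @ w)               # toy sub-block; every output entry lies in [-1, 1]

    h = 100.0 * rng.normal(size=(4, d))     # 4 tokens with a deliberately large scale
    for _ in range(n_blocks):
        h = residual_step(h, f, scheme, alpha=2.0)
    return np.sqrt((h ** 2).mean(axis=-1)).mean()


for scheme in ["post", "pre", "sandwich", "deepnorm"]:
    print(f"{scheme:9s} stream RMS after 12 blocks: {stream_rms(scheme):.2f}")
```

Post-LN and DeepNorm print an RMS of about 1, while Pre-LN and Sandwich-LN keep roughly the input's large scale. Because the Pre-LN stream is never normalized, Pre-LN models such as GPT-2 add one final LayerNorm after the last block, before the output head.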

RMSNorm: The Modern LLM Default

Zhang and Sennrich (2019) hypothesized that LayerNorm's benefit comes mainly from re-scaling rather than re-centering, and found that removing the mean-subtraction step leaves model quality essentially unchanged. The result is RMSNorm, defined as $\text{RMSNorm}(x) = \gamma \cdot \frac{x}{\sqrt{\tfrac{1}{D}\sum_{d} x_d^{2} + \varepsilon}}$.

Three changes from LayerNorm: (i) no mean subtraction, (ii) the variance is replaced by the raw second moment (the mean of $x^{2}$), and (iii) no $\beta$. That removes one reduction, one subtraction, and the $\beta$ parameter vector. Zhang and Sennrich report running-time reductions of 7% to 64% across the models they tested.

Most flagship open-weights LLMs from 2023 onward use RMSNorm: LLaMA, LLaMA 2, LLaMA 3, Mistral, Mixtral, Qwen, Gemma, Phi-3. T5 used it earlier. The original LayerNorm survives mostly in older stacks (the BERT family, GPT-2). Epsilons differ slightly: the published HuggingFace configs of LLaMA 2, LLaMA 3, and Mistral set $\varepsilon = 10^{-5}$, while LLaMA 1, Gemma, and T5 use $\varepsilon = 10^{-6}$. In every case $\varepsilon$ only keeps the denominator finite for near-dead features. Its exact value rarely matters during training, but inference code should use the same value the model was trained with.
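The claim that ε only matters for near-dead features, and the advice to match it at inference, can be checked directly. A small sketch, assuming γ = 1 and two hand-picked 4-element vectors: one with a mean square of 1 (a healthy activation) and one scaled down by 1000 (a near-dead feature).

```python
import numpy as np


def rms_norm(x, gamma, eps):
    """RMSNorm: divide by the root-mean-square over the last axis, then scale by gamma."""
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return gamma * (x / rms)


def eps_sensitivity(x):
    """Largest relative change in the output when eps goes from 1e-6 to 1e-5."""
    gamma = np.ones_like(x)
    a = rms_norm(x, gamma, eps=1e-5)
    b = rms_norm(x, gamma, eps=1e-6)
    return np.max(np.abs(a - b) / np.abs(b))


typical = np.array([1.0, -1.0, 1.0, -1.0])   # mean(x^2) = 1: a healthy activation
tiny = 1e-3 * typical                         # mean(x^2) = 1e-6: a near-dead feature vector

print(f"typical:   {eps_sensitivity(typical):.1e}")  # a few parts per million
print(f"near-dead: {eps_sensitivity(tiny):.2f}")     # output shrinks by more than half
```

For the healthy vector the two epsilons are indistinguishable; for the near-dead one, ε is comparable to the mean square itself and dominates the denominator.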

RMSNorm — what LLaMA, Mistral, Gemma use
🐍rms_norm.py
def rms_norm(x, gamma, eps=1e-5):

RMSNorm replaces LayerNorm in most modern LLMs (LLaMA, Mistral, Gemma, Qwen). It drops two pieces of LayerNorm, the mean subtraction and the β shift: it only divides by the root-mean-square, then scales by gamma. Empirically it matches LayerNorm's quality while running faster (Zhang and Sennrich report 7% to 64% lower running time, depending on the model).

EXECUTION STATE
⬇ input: x — 1-D activations = [2.0, 4.0, 6.0, 8.0]
⬇ input: gamma — learnable scale = [1.0, 1.0, 1.0, 1.0] (identity for this demo)
⬇ input: eps = 1e-5 = RMSNorm epsilons vary by model: LLaMA 2, LLaMA 3, and Mistral use 1e-5 in their HuggingFace configs, while LLaMA 1, Gemma, and T5's released configs use 1e-6. In every case the value only guards against division by zero for near-dead features.
rms = sqrt(mean(x²) + eps)

The root-mean-square statistic. Instead of subtracting the mean and dividing by σ, RMSNorm simply divides by the RMS. This keeps the operation invariant to rescaling x but not to shifting it; Zhang and Sennrich's experiments suggest that the re-centering invariance LayerNorm provides contributes little to its benefit.

EXECUTION STATE
x² = [4.0, 16.0, 36.0, 64.0]
mean(x²) = (4+16+36+64)/4 = 120/4 = 30.0
sqrt(30 + 1e-5) = 5.4772
⬆ rms = 5.4772
return gamma * (x / rms)

Divide x by its RMS, then scale by gamma. With gamma=1 the output magnitude is bounded: the RMS of the result is always ≈ 1, regardless of input scale.

EXECUTION STATE
x / rms = [2/5.4772, 4/5.4772, 6/5.4772, 8/5.4772] = [0.3651, 0.7303, 1.0954, 1.4606]
⬆ gamma * (x / rms) = [0.3651, 0.7303, 1.0954, 1.4606] (RMS of output ≈ 1.0)
→ compare to LN = LN of same x: [-1.3416, -0.4472, 0.4472, 1.3416] — centered at 0. RMSNorm: [0.3651, 0.7303, 1.0954, 1.4606] — NOT centered, but bounded in magnitude. The difference is exactly the mean-subtraction step.
x = np.array([2., 4., 6., 8.], dtype=np.float32)

A single-row activation. Same numbers as sample 0 in the LayerNorm example so you can compare outputs directly.

EXECUTION STATE
x = [2.0, 4.0, 6.0, 8.0]
gamma = np.array([1., 1., 1., 1.], dtype=np.float32)

Identity scale. In LLaMA this would be a learnable parameter initialized to ones.

EXECUTION STATE
gamma = [1.0, 1.0, 1.0, 1.0]
print(rms_norm(x, gamma))

Runs the function and prints the result.

EXECUTION STATE
expected stdout (rounded to 4 decimals here; NumPy prints up to 8 digits by default) = [0.3651 0.7303 1.0954 1.4606]
import numpy as np

def rms_norm(x, gamma, eps=1e-5):
    """No mean subtraction: divide by the root-mean-square only."""
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return gamma * (x / rms)

x     = np.array([2., 4., 6., 8.], dtype=np.float32)
gamma = np.array([1., 1., 1., 1.], dtype=np.float32)   # identity scale
print(rms_norm(x, gamma))
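The walkthrough's closing comparison says the gap between the RMSNorm and LayerNorm outputs is exactly the mean-subtraction step. A quick check of that statement, assuming γ = 1 and β = 0 and reusing the same 4-element row: RMSNorm applied to the mean-centered row reproduces LayerNorm, because the mean square of a centered vector is its variance.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm with gamma = 1, beta = 0: subtract the mean, divide by the std."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def rms_norm(x, eps=1e-5):
    """RMSNorm with gamma = 1: divide by the root-mean-square, no centering."""
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)


x = np.array([2.0, 4.0, 6.0, 8.0])
centered = x - x.mean()                       # [-3, -1, 1, 3]; its mean square is var(x) = 5

print(np.round(rms_norm(x), 4))               # all positive: magnitude is controlled, not the center
print(np.round(layer_norm(x), 4))             # centered at 0
print(np.allclose(rms_norm(centered), layer_norm(x)))  # True
```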

Connections to Flash Attention, MHA, PE, KV-Cache, and Scaling

LayerNorm is quiet infrastructure that the other Transformer components take for granted: attention scaling, positional encodings, KV-caching, and the scaling behavior of large models all assume well-behaved inputs. Below is how each of these systems leans on it.

Multi-Head Attention: uniform input scale for every head

In MHA, the input token $x \in \mathbb{R}^{D}$ is projected into $H$ heads via $Q_h = x W_h^{Q}$, and the scores are scaled by $1 / \sqrt{d_k}$. That scaling gives roughly unit-variance scores only when the components of the queries and keys have roughly unit variance, which at initialization requires $x$ itself to arrive at a controlled scale. LN enforces that precondition on the input to every projection, so every head sees queries and keys at a similar statistical scale and no head's softmax saturates prematurely.
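A toy measurement of that precondition, assuming a single head whose projection weights are drawn from N(0, 1/D) (a common variance-preserving initialization, not any particular model's exact scheme) and LN with γ = 1, β = 0. The raw activations are given a drifted mean and scale to mimic an unnormalized residual stream.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Per-token LayerNorm over the last axis, gamma = 1, beta = 0."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def score_variance(x, w_q, w_k):
    """Variance of the scaled dot-product scores q.k / sqrt(d_k) over all token pairs."""
    q, k = x @ w_q, x @ w_k
    d_k = w_q.shape[1]
    return (q @ k.T / np.sqrt(d_k)).var()


rng = np.random.default_rng(0)
T, D, d_k = 256, 512, 64                                   # tokens, model width, head width
raw = rng.normal(loc=3.0, scale=5.0, size=(T, D))          # drifted mean and scale
w_q = rng.normal(scale=1 / np.sqrt(D), size=(D, d_k))      # N(0, 1/D) keeps q near unit scale
w_k = rng.normal(scale=1 / np.sqrt(D), size=(D, d_k))

var_ln = score_variance(layer_norm(raw), w_q, w_k)
var_raw = score_variance(raw, w_q, w_k)
print(f"score variance with LN:    {var_ln:.2f}")    # close to 1
print(f"score variance without LN: {var_raw:.1f}")   # orders of magnitude larger
```

With LN the $1/\sqrt{d_k}$ factor does its job; without it, the same weights produce scores large enough to push the softmax toward one-hot outputs.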

Positional Encodings: LN washes out arbitrary magnitudes

Sinusoidal PE has entries in $[-1, 1]$, but learned embeddings can have much larger magnitudes. When we compute $h = \text{Embed}(t) + \text{PE}(t)$ and pass it through LN, the layer strips out the mean and overall scale of $h$ and keeps only its direction, so the next layer sees a standardized vector no matter how large the sum was. The relative scale of $\text{Embed}(t)$ and $\text{PE}(t)$ still matters, because it decides how much each term contributes to that direction.
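A small illustration of both halves of that argument, assuming LN without γ/β and hypothetical stand-in vectors: a random embedding and a single sine wave in place of a real sinusoidal PE, chosen only so its entries lie in [-1, 1].

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """LayerNorm over the last axis, gamma = 1, beta = 0."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
D = 64
embed = rng.normal(size=D)            # stand-in for Embed(t)
pe = np.sin(0.1 * np.arange(D))       # stand-in for PE(t): entries in [-1, 1]
h = embed + pe

# Rescaling the whole sum leaves the LN output essentially unchanged ...
same_dir = np.allclose(layer_norm(h), layer_norm(50.0 * h), atol=1e-4)
# ... but rescaling only the embedding changes the direction LN keeps.
diff_dir = np.allclose(layer_norm(h), layer_norm(50.0 * embed + pe), atol=1e-4)
print(same_dir, diff_dir)             # True False
```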

Rotary Position Embedding (RoPE) is applied to the queries and keys after the Q and K projections, which themselves read the LN output. Because RoPE is a pure rotation in 2-D subspaces, it preserves the norms of $q$ and $k$: it injects position without undoing the scale control that LN established upstream.
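A minimal RoPE sketch for one query and one key, assuming the interleaved pairing of the RoFormer paper and base 10000 (some implementations pair the two halves of the vector instead). It checks that the rotation preserves length, plus the property RoPE is built for: the score depends only on the relative offset between positions.

```python
import numpy as np


def rope(v, pos, base=10000.0):
    """Rotate each pair (v[2i], v[2i+1]) by the angle pos * base**(-2i/d)."""
    d = v.shape[-1]
    theta = pos * base ** (-np.arange(0, d, 2) / d)   # one angle per 2-D subspace
    cos, sin = np.cos(theta), np.sin(theta)
    even, odd = v[0::2], v[1::2]
    out = np.empty_like(v)
    out[0::2] = even * cos - odd * sin
    out[1::2] = even * sin + odd * cos
    return out


rng = np.random.default_rng(0)
q, k = rng.normal(size=64), rng.normal(size=64)   # a query and a key after their projections

# Rotation preserves length, whatever scale q arrived with.
norm_kept = np.isclose(np.linalg.norm(rope(q, 7)), np.linalg.norm(q))
# The score depends only on the offset m - n: (7, 3) and (104, 100) give the same value.
offset_only = np.isclose(rope(q, 7) @ rope(k, 3), rope(q, 104) @ rope(k, 100))
print(norm_kept, offset_only)                     # True True
```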

Flash Attention: LN is the kernel boundary

FlashAttention (Dao et al. 2022) fuses $\text{softmax}(Q K^{\top} / \sqrt{d_k})\, V$ into a single SRAM-resident CUDA kernel, eliminating the $O(T^{2})$ HBM traffic of materializing the attention matrix. FlashAttention does not include the LayerNorm that precedes the Q/K/V projections: LN sits at the kernel boundary, and its output feeds the projections whose results FlashAttention consumes.

Optimized stacks commonly fuse LN+Linear (or RMSNorm+Linear) into a kernel of their own, so a Pre-LN attention sub-block can run as:

  • Kernel 1: RMSNorm + QKV projection (fused)
  • Kernel 2: FlashAttention (fused softmax + matmul)
  • Kernel 3: output projection + residual add (fused)

That is roughly three kernels instead of a dozen or so unfused ones, with the normalization fused into the first.
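An unfused NumPy reference for a single-head, causal Pre-LN attention sub-block, with comments marking where the three fused kernels listed above would draw their boundaries. It is a sketch under simplifying assumptions (one head, RMSNorm, no biases, no dropout); a real FlashAttention kernel never builds the full T × T score matrix that this version materializes.

```python
import numpy as np


def rms_norm(x, gamma, eps=1e-5):
    """Per-token RMSNorm over the last axis."""
    return gamma * x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)


def softmax(s):
    s = s - s.max(axis=-1, keepdims=True)    # subtract the row max for numerical stability
    e = np.exp(s)
    return e / e.sum(axis=-1, keepdims=True)


def pre_ln_attention_block(x, gamma, w_qkv, w_o):
    """Unfused single-head causal Pre-LN attention sub-block: x + Attn(RMSNorm(x))."""
    T, D = x.shape
    # Kernel 1 (fused in optimized stacks): RMSNorm + QKV projection.
    q, k, v = np.split(rms_norm(x, gamma) @ w_qkv, 3, axis=-1)
    # Kernel 2 (FlashAttention): scores, causal mask, softmax, weighted sum of V.
    # This reference builds the full T x T score matrix, which FlashAttention avoids.
    scores = q @ k.T / np.sqrt(D)
    causal = np.tril(np.ones((T, T), dtype=bool))
    attn = softmax(np.where(causal, scores, -np.inf)) @ v
    # Kernel 3 (fused): output projection + residual add.
    return x + attn @ w_o


rng = np.random.default_rng(0)
T, D = 5, 8
x = rng.normal(size=(T, D))
gamma = np.ones(D)
w_qkv = rng.normal(scale=1 / np.sqrt(D), size=(D, 3 * D))
w_o = rng.normal(scale=1 / np.sqrt(D), size=(D, D))

y = pre_ln_attention_block(x, gamma, w_qkv, w_o)
x_edit = x.copy()
x_edit[-1] += 10.0                               # change only the last token
y_edit = pre_ln_attention_block(x_edit, gamma, w_qkv, w_o)
print(y.shape)                                   # (5, 8)
print(np.allclose(y_edit[:-1], y[:-1]))          # True: earlier tokens are unaffected
```

The final check confirms causality: editing the last token leaves every earlier output unchanged, which is also what lets the earlier K/V rows be cached.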

KV-Cache: LN is per-token, so caches stay valid

Autoregressive inference caches past keys and values to avoid recomputing them at every step. For the cache to stay valid, the operation that produces $K_t$ and $V_t$ must depend only on tokens up to and including $t$ in the same sequence. LayerNorm satisfies this: its statistics come from the token being processed and nothing else. BatchNorm with batch statistics would break it, since a cached $K_t$ would depend on which other sequences shared its batch and would no longer match a recomputation in a batch of different size or content (see Section 3 for another reason BatchNorm breaks under gradient accumulation). In eval mode BatchNorm switches to fixed running statistics, which sidesteps the caching problem but means inference normalizes differently from training.

This is one of the strongest architectural reasons Transformers use per-token normalization: once you want autoregressive generation with caching, LN's per-sample (per-token) statistics remove any dependence on batch composition.
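The contrast fits in a few lines, assuming LN without γ/β and a BatchNorm-style normalization that pools training-mode statistics over all tokens it is given (a simplification of how BatchNorm would pool over batch and sequence positions).

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Statistics over the feature axis: each token is normalized on its own."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def batch_stat_norm(x, eps=1e-5):
    """BatchNorm-style training-mode statistics: pooled over tokens (axis 0), per feature."""
    mu = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


rng = np.random.default_rng(0)
prefix = rng.normal(size=(6, 8))                    # 6 tokens already processed and cached
new_token = rng.normal(loc=2.0, size=(1, 8))        # the next token to arrive
full = np.concatenate([prefix, new_token])

ln_cache_ok = np.allclose(layer_norm(full)[:6], layer_norm(prefix))
bn_cache_ok = np.allclose(batch_stat_norm(full)[:6], batch_stat_norm(prefix))
print(ln_cache_ok, bn_cache_ok)                     # True False
```

Appending one token leaves the LayerNorm rows of the prefix untouched, while the pooled statistics shift every earlier row, so anything cached from them would be stale.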

Grouped-Query and Multi-Query Attention

GQA and MQA shrink the KV-cache by sharing keys and values across several query heads. Both still apply LN to the full $D$-dimensional input before the projections, so the smaller K/V projections still read a normalized input. The savings are in cache memory and bandwidth, not in the normalization step, so normalization becomes a relatively larger share of the remaining per-token cost: one more small argument for the cheaper RMSNorm.
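A shape-level sketch of GQA, assuming a model width of 512, 8 query heads, and random projection weights (all illustrative values). LN runs on the full D-dimensional token no matter how many K/V heads there are; only the size of the K/V tensors, which is what the cache stores, changes.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Per-token LayerNorm over the last axis, gamma = 1, beta = 0."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def gqa_kv_shapes(T, D, n_heads, n_kv_heads, seed=0):
    """Return the per-layer K and V shapes that a KV-cache would store (GQA sketch)."""
    rng = np.random.default_rng(seed)
    head_dim = D // n_heads
    x = layer_norm(rng.normal(size=(T, D)))          # LN always sees the full D-dim token
    w_q = rng.normal(scale=D ** -0.5, size=(D, n_heads * head_dim))
    w_kv = rng.normal(scale=D ** -0.5, size=(D, 2 * n_kv_heads * head_dim))
    q = (x @ w_q).reshape(T, n_heads, head_dim)
    k, v = np.split((x @ w_kv).reshape(T, 2 * n_kv_heads, head_dim), 2, axis=1)
    # Each group of n_heads // n_kv_heads query heads attends with the same K/V head.
    k_per_query_head = np.repeat(k, n_heads // n_kv_heads, axis=1)
    assert k_per_query_head.shape == q.shape
    return k.shape, v.shape


for name, n_kv in [("MHA", 8), ("GQA", 2), ("MQA", 1)]:
    k_shape, v_shape = gqa_kv_shapes(T=16, D=512, n_heads=8, n_kv_heads=n_kv)
    print(name, k_shape, v_shape)
```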

Transformer scaling: LN pays for itself

A LLaMA-3-70B forward pass applies two normalizations per layer (one before attention, one before the FFN) across $L = 80$ layers, i.e. $2L = 160$ per token, plus a final one before the output head, each over $D = 8192$ values. Using RMSNorm instead of LayerNorm saves one reduction per normalization and drops the $\beta$ vector; across the more than 15 trillion training tokens used for LLaMA-3, these small per-call savings add up.
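The count is simple arithmetic. Spelling it out with the figures above, and including the final normalization that Pre-LN stacks like LLaMA apply before the output head, gives the per-token workload:

```python
# Per-token normalization workload for a LLaMA-3-70B-shaped model.
n_layers = 80            # L
d_model = 8192           # D: values per normalization
norms_per_layer = 2      # one before attention, one before the FFN
final_norms = 1          # the extra norm before the output head in Pre-LN stacks

norm_calls = n_layers * norms_per_layer + final_norms
values_normalized = norm_calls * d_model
print(norm_calls)          # 161
print(values_normalized)   # 1318912
```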

Scaling-law studies (Kaplan et al. 2020; Hoffmann et al. 2022) fit their curves on normalized Transformers, so those curves presuppose a network that stays well-conditioned as it grows. Normalization-free training is possible with specialized initialization schemes, but in mainstream LLM practice normalization is part of what makes training at scale reliable, not merely a speed optimization.

The one-line summary. LayerNorm (and its RMSNorm descendant) is the precondition that lets the rest of the modern Transformer (scaled dot-product attention, multi-head projection, positional encoding, KV-cache reuse, deep residual stacking, and trillion-token training) work as designed. Without it, training deep stacks stably at scale becomes far harder.

References

  • Ioffe, S., Szegedy, C. (2015). Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift. ICML.
  • Ba, J. L., Kiros, J. R., Hinton, G. E. (2016). Layer Normalization. arXiv:1607.06450.
  • Santurkar, S. et al. (2018). How Does Batch Normalization Help Optimization? NeurIPS.
  • Vaswani, A. et al. (2017). Attention Is All You Need. NeurIPS.
  • Xiong, R. et al. (2020). On Layer Normalization in the Transformer Architecture. ICML.
  • Zhang, B., Sennrich, R. (2019). Root Mean Square Layer Normalization. NeurIPS.
  • Dao, T. et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS.
  • Wang, H. et al. (2022). DeepNet: Scaling Transformers to 1,000 Layers. arXiv:2203.00555.

Check Your Understanding

A short self-check to see that the key distinctions have stuck.


What problem does Batch Normalization primarily address?
