Chapter 5

Normalization Layers

Neural Network Building Blocks

Learning Objectives

By the end of this section, you will:

  • Understand why normalization is essential for training deep neural networks
  • Grasp the concept of Internal Covariate Shift and how normalization addresses it
  • Master Batch Normalization—its mathematics, implementation, and training vs inference behavior
  • Understand Layer Normalization and why it's the standard for Transformers
  • Know when to use Instance Normalization, Group Normalization, and RMS Normalization
  • Implement all normalization variants in PyTorch
  • Choose the right normalization for your specific architecture and task

The Big Picture

In the previous sections, we covered linear layers, activation functions, and loss functions. These form the computational backbone of neural networks. But there's a critical challenge we haven't addressed: as networks get deeper, they become increasingly difficult to train.

The Training Problem: As gradients flow backward through many layers, two things can go wrong: gradients can vanish (becoming too small to update early layers) or explode (becoming unstably large). But there's another, subtler problem: the distribution of inputs to each layer changes during training, forcing layers to constantly adapt.

Normalization layers address this by ensuring that layer inputs have consistent statistical properties (typically zero mean and unit variance) throughout training. This seemingly simple intervention has dramatic effects:

  1. Faster convergence — Training reaches good performance in fewer iterations
  2. Higher learning rates — You can train more aggressively without instability
  3. Less sensitivity to initialization — The network is more robust to weight initialization choices
  4. Regularization effect — Normalization adds noise that can reduce overfitting

Historical Context

Batch Normalization was introduced by Ioffe and Szegedy in 2015 and immediately became one of the most important techniques in deep learning. It enabled training of much deeper networks and became a standard component in almost every CNN architecture. Later, Layer Normalization (2016) emerged as the preferred choice for Transformers and sequence models, followed by variants like Group Normalization and RMS Normalization for specific use cases.


Internal Covariate Shift

Before diving into normalization techniques, we need to understand the problem they solve: Internal Covariate Shift (ICS).

What is Internal Covariate Shift?

Consider a deep neural network during training. When we update the weights in layer $L$, the outputs of layer $L$ change. But these outputs are the inputs to layer $L+1$. So layer $L+1$ now receives inputs from a different distribution than before!

This cascades through the network: every layer is trying to learn a mapping while its input distribution is constantly shifting. It's like trying to hit a moving target.

Covariate Shift in machine learning refers to the situation where the distribution of inputs changes between training and testing. Internal Covariate Shift is when this happens inside the network, at every layer, during training itself.

Effects of Internal Covariate Shift

| Problem | Effect on Training |
| --- | --- |
| Slower convergence | Layers must constantly re-adapt to new input distributions |
| Requires lower learning rates | Large updates cause distributions to shift too much |
| Careful initialization needed | Bad initial distributions compound through layers |
| Saturated activations | Inputs drift into saturation regions of sigmoid/tanh |

The Normalization Solution

The key insight is simple: if we normalize the inputs to each layer to have zero mean and unit variance, the distribution stays constant regardless of how earlier layers change. Each layer sees a stable, well-behaved input distribution.

$$\hat{x} = \frac{x - \mu}{\sigma}$$

Where $\mu$ is the mean and $\sigma$ is the standard deviation, computed over some set of activations (which set depends on the normalization type).
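The normalization step is easy to check numerically. A quick sketch (illustrative, not from the text), normalizing per feature over the batch axis:

```python
import torch

# Normalize a batch of activations to zero mean and unit variance
# per feature, as in the formula above (statistics over the batch axis).
x = torch.randn(4, 10) * 3.0 + 5.0                  # activations with mean ~5, std ~3
mu = x.mean(dim=0, keepdim=True)                    # per-feature mean
sigma = x.std(dim=0, unbiased=False, keepdim=True)  # per-feature std
x_hat = (x - mu) / sigma

print(x_hat.mean(dim=0))                 # all ~0
print(x_hat.std(dim=0, unbiased=False))  # all ~1
```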


Interactive: ICS Visualization

Watch Internal Covariate Shift in action. See how layer statistics change during training without normalization, and how Batch Normalization keeps them stable.


Quick Check

What is Internal Covariate Shift?


Batch Normalization

Batch Normalization (BatchNorm) was the first widely successful normalization technique. The key idea: normalize activations using statistics computed over the batch dimension.

The BatchNorm Algorithm

For a mini-batch $\mathcal{B} = \{x_1, ..., x_m\}$, BatchNorm performs:

$$
\begin{aligned}
\mu_{\mathcal{B}} &= \frac{1}{m} \sum_{i=1}^{m} x_i && \text{(batch mean)} \\
\sigma^2_{\mathcal{B}} &= \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\mathcal{B}})^2 && \text{(batch variance)} \\
\hat{x}_i &= \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma^2_{\mathcal{B}} + \epsilon}} && \text{(normalize)} \\
y_i &= \gamma \hat{x}_i + \beta && \text{(scale and shift)}
\end{aligned}
$$

Where:

  • $\epsilon$ is a small constant (e.g., $10^{-5}$) for numerical stability
  • $\gamma$ and $\beta$ are learnable parameters
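The four steps can be reproduced by hand and checked against PyTorch's functional API. A sketch, assuming a 2D (batch, features) input, PyTorch's default eps, and γ, β at their initial values:

```python
import torch
import torch.nn.functional as F

# The four BatchNorm steps computed by hand for a (batch, features)
# input, checked against F.batch_norm in training mode.
x = torch.randn(8, 5)
eps = 1e-5
gamma = torch.ones(5)   # initial scale
beta = torch.zeros(5)   # initial shift

mu = x.mean(dim=0)                        # batch mean
var = x.var(dim=0, unbiased=False)        # batch variance (biased, as in the paper)
x_hat = (x - mu) / torch.sqrt(var + eps)  # normalize
y = gamma * x_hat + beta                  # scale and shift

y_ref = F.batch_norm(x, None, None, weight=gamma, bias=beta,
                     training=True, eps=eps)
print(torch.allclose(y, y_ref, atol=1e-5))  # True
```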

Why the Learnable Parameters?

After normalization, the output has zero mean and unit variance. But what if the optimal representation for the next layer isn't zero-mean and unit-variance? The learnable parameters $\gamma$ (scale) and $\beta$ (shift) allow the network to undo the normalization if that's beneficial!

Representational Power

Setting $\gamma = \sigma$ and $\beta = \mu$ recovers the original, unnormalized activations. So BatchNorm can learn to do nothing if that's optimal, but it defaults to normalized activations, which helps training.
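A quick numeric check of this claim (a sketch, using the biased batch variance as in the algorithm above):

```python
import torch

# With gamma = sigma and beta = mu, the scale-and-shift step
# undoes the normalization (up to eps).
x = torch.randn(16, 4) * 2.0 + 3.0
eps = 1e-5
mu = x.mean(dim=0)
sigma = torch.sqrt(x.var(dim=0, unbiased=False) + eps)

x_hat = (x - mu) / sigma
y = sigma * x_hat + mu  # gamma = sigma, beta = mu

print(torch.allclose(y, x, atol=1e-4))  # True
```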

Training vs Inference

BatchNorm behaves differently during training and inference:

| Phase | Statistics Used | Behavior |
| --- | --- | --- |
| Training | Batch statistics (μ_B, σ²_B) | Compute from current batch, update running stats |
| Inference | Running statistics (μ_running, σ²_running) | Use accumulated statistics, deterministic output |

During training, BatchNorm maintains running averages of the batch statistics using exponential moving average:

$$
\begin{aligned}
\mu_{\text{running}} &\leftarrow (1 - \text{momentum}) \cdot \mu_{\text{running}} + \text{momentum} \cdot \mu_{\mathcal{B}} \\
\sigma^2_{\text{running}} &\leftarrow (1 - \text{momentum}) \cdot \sigma^2_{\text{running}} + \text{momentum} \cdot \sigma^2_{\mathcal{B}}
\end{aligned}
$$
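The update is easy to verify with PyTorch's default momentum of 0.1. This sketch checks only the running mean, which starts at its initial value of 0:

```python
import torch
import torch.nn as nn

# After one training-mode forward pass, BatchNorm's running mean
# moves from 0 toward the batch mean by exactly `momentum`.
bn = nn.BatchNorm1d(3, momentum=0.1)
x = torch.randn(32, 3) + 5.0

bn.train()
bn(x)

# (1 - momentum) * 0 + momentum * batch_mean
expected = 0.9 * torch.zeros(3) + 0.1 * x.mean(dim=0)
print(torch.allclose(bn.running_mean, expected, atol=1e-6))  # True
```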

Common Bug: Forgetting to Switch Modes

Always call model.train() before training and model.eval() before inference! Using training mode at inference gives noisy outputs (batch statistics vary). Using eval mode during training prevents running stats from updating.

Interactive: Batch Norm Step-by-Step

Explore how Batch Normalization works step by step. Watch the transformation from raw activations to normalized and scaled outputs.



Layer Normalization

Layer Normalization (LayerNorm) normalizes across the feature dimension instead of the batch dimension. This seemingly small change has profound implications.

The LayerNorm Algorithm

For a single input $x$ with $H$ features:

$$
\begin{aligned}
\mu &= \frac{1}{H} \sum_{i=1}^{H} x_i && \text{(mean over features)} \\
\sigma^2 &= \frac{1}{H} \sum_{i=1}^{H} (x_i - \mu)^2 && \text{(variance over features)} \\
\hat{x}_i &= \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} && \text{(normalize)} \\
y_i &= \gamma_i \hat{x}_i + \beta_i && \text{(per-feature scale and shift)}
\end{aligned}
$$
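Same recipe as BatchNorm, but with statistics along the feature axis. A sketch checking a hand computation against nn.LayerNorm (with its default γ=1, β=0):

```python
import torch
import torch.nn as nn

# LayerNorm by hand for a (batch, features) input: one mean and
# variance per sample, computed over the feature dimension.
x = torch.randn(4, 8)
eps = 1e-5

mu = x.mean(dim=-1, keepdim=True)                  # mean over features
var = x.var(dim=-1, unbiased=False, keepdim=True)  # variance over features
x_hat = (x - mu) / torch.sqrt(var + eps)

ln = nn.LayerNorm(8, eps=eps)
print(torch.allclose(x_hat, ln(x), atol=1e-5))  # True
```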

Key Differences from BatchNorm

| Property | Batch Normalization | Layer Normalization |
| --- | --- | --- |
| Normalizes over | Batch dimension (N) | Feature dimension (H) |
| Statistics per | One per channel/feature | One per sample |
| Batch size dependency | Needs batch > 1 | Works with batch = 1 |
| Train/Eval difference | Uses running stats at eval | Same at train and eval |
| Best for | CNNs with large batches | Transformers, RNNs |

Why LayerNorm for Transformers?

Transformers have several properties that make LayerNorm the preferred choice:

  1. Variable sequence lengths: Different samples may have different sequence lengths. LayerNorm works independently per token position.
  2. Batch size = 1 inference: At inference, you often process one sample at a time. LayerNorm works fine; BatchNorm would need stored statistics.
  3. Autoregressive generation: When generating text token by token, you need consistent behavior. LayerNorm provides this naturally.

Pre-LN vs Post-LN Transformers

The original Transformer applied LayerNorm after the residual connection (Post-LN). Modern Transformers often apply it before (Pre-LN), which improves training stability and enables larger learning rates.
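A minimal sketch of the two placements for one residual sublayer. Here `sublayer` is a stand-in of my own (a plain linear map) for attention or the feed-forward block:

```python
import torch
import torch.nn as nn

# The two LayerNorm placements around one residual sublayer.
d = 16
norm = nn.LayerNorm(d)
sublayer = nn.Linear(d, d)  # stand-in for attention or the MLP

def post_ln(x):
    # Original Transformer: normalize after the residual add
    return norm(x + sublayer(x))

def pre_ln(x):
    # Modern variant: normalize before the sublayer
    return x + sublayer(norm(x))

x = torch.randn(2, 5, d)
print(post_ln(x).shape, pre_ln(x).shape)  # both torch.Size([2, 5, 16])
```

With Pre-LN, the residual path carries the raw signal untouched, which is one intuition for its improved training stability.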

Interactive: Batch Norm vs Layer Norm

Compare how BatchNorm and LayerNorm compute statistics differently. Toggle between them to see which dimensions each method normalizes over.


Quick Check

You're building a Transformer for language modeling. Which normalization should you use?


Instance Normalization

Instance Normalization (InstanceNorm) normalizes each sample and channel independently, computing statistics only over the spatial dimensions (H, W).

$$\hat{x}_{nchw} = \frac{x_{nchw} - \mu_{nc}}{\sqrt{\sigma^2_{nc} + \epsilon}}$$

Where $\mu_{nc}$ and $\sigma^2_{nc}$ are computed over the spatial dimensions for each (sample, channel) pair.
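This can be checked by computing the per-(sample, channel) statistics by hand and comparing with nn.InstanceNorm2d (which defaults to affine=False):

```python
import torch
import torch.nn as nn

# InstanceNorm by hand: one mean/variance per (sample, channel) pair,
# computed over the spatial dims (H, W).
x = torch.randn(2, 3, 4, 4)
eps = 1e-5

mu = x.mean(dim=(2, 3), keepdim=True)                  # shape (2, 3, 1, 1)
var = x.var(dim=(2, 3), unbiased=False, keepdim=True)
x_hat = (x - mu) / torch.sqrt(var + eps)

inorm = nn.InstanceNorm2d(3, eps=eps)  # affine=False by default
print(torch.allclose(x_hat, inorm(x), atol=1e-5))  # True
```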

Instance Norm and Style Transfer

InstanceNorm became famous for its use in style transfer. The key insight: the mean and variance of feature maps capture style information (textures, colors), while the normalized features capture content (shapes, objects). By normalizing each instance independently, InstanceNorm removes style information, making it easier to transfer new styles.

| Use Case | Why Instance Norm Helps |
| --- | --- |
| Style transfer | Removes source style, enables style injection |
| Image-to-image translation | Normalizes contrast/brightness variations |
| GANs | Stabilizes training by normalizing generated samples |

Affine Parameters

By default, InstanceNorm in PyTorch has affine=False (no learned γ, β). For style transfer, this is often desired. For other applications, you may want affine=True.

Group Normalization

Group Normalization (GroupNorm) divides channels into groups and normalizes within each group. It's a flexible interpolation between LayerNorm and InstanceNorm.

$$\hat{x} = \frac{x - \mu_g}{\sqrt{\sigma^2_g + \epsilon}}$$

Where $g$ is the group index, and statistics are computed over (channels in group, H, W) for each sample.

The Group Norm Spectrum

| num_groups | Equivalent To | Behavior |
| --- | --- | --- |
| 1 | Layer Normalization | All channels in one group, normalized together |
| num_channels | Instance Normalization | Each channel is its own group |
| 2–64 (typical) | Group Normalization | Balance between LN and IN |

Why Group Normalization? When batch sizes are small (common in object detection, segmentation, and 3D convolutions), BatchNorm's batch statistics become noisy. GroupNorm provides stable normalization independent of batch size, while still capturing some channel correlations (unlike InstanceNorm).

Choosing Number of Groups

Common choices are 8, 16, or 32 groups. The channels should be evenly divisible by num_groups. Too few groups (close to LayerNorm) may lose spatial specificity; too many (close to InstanceNorm) may lose channel correlations.
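The two endpoints of the spectrum can be verified numerically. A small sketch, with all affine parameters at their γ=1, β=0 defaults:

```python
import torch
import torch.nn as nn

# GroupNorm with 1 group matches LayerNorm over (C, H, W);
# GroupNorm with C groups matches InstanceNorm2d.
N, C, H, W = 2, 8, 4, 4
x = torch.randn(N, C, H, W)

gn_as_ln = nn.GroupNorm(1, C)(x)
ln = nn.LayerNorm([C, H, W], elementwise_affine=False)(x)
print(torch.allclose(gn_as_ln, ln, atol=1e-4))  # True

gn_as_in = nn.GroupNorm(C, C)(x)
inorm = nn.InstanceNorm2d(C)(x)
print(torch.allclose(gn_as_in, inorm, atol=1e-4))  # True
```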

RMS Normalization

Root Mean Square Normalization (RMSNorm) is a simplified version of LayerNorm that has become popular in large language models like LLaMA, Mistral, and Gemma.

The RMSNorm Formula

$$\text{RMSNorm}(x) = \frac{x}{\text{RMS}(x)} \cdot \gamma \quad \text{where} \quad \text{RMS}(x) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2}$$

Note what's different from LayerNorm:

  • No mean subtraction: We don't center the data, just scale it
  • No bias parameter: Only γ (scale), no β (shift)
  • Simpler gradient: Fewer terms in the backward pass

Why RMSNorm for LLMs?

Research showed that the mean subtraction step in LayerNorm provides minimal benefit for large language models, while adding computational cost. RMSNorm achieves the same training dynamics with ~15% fewer operations.

| Property | LayerNorm | RMSNorm |
| --- | --- | --- |
| Centering | Yes (subtract mean) | No |
| Scaling | By standard deviation | By RMS |
| Parameters | γ (scale) + β (shift) | γ (scale) only |
| Compute cost | Higher | ~15% lower |
| Performance | Excellent | Comparable |

Spectral Normalization

Spectral Normalization is a weight normalization technique (not activation normalization) that constrains the spectral norm (largest singular value) of weight matrices. It's critical for training stable GANs.

The Problem: Lipschitz Constraint

In GANs, the discriminator tends to produce extreme outputs, leading to training instability. Spectral normalization ensures the discriminator is 1-Lipschitz—its outputs can't change too drastically for small input changes.

$$W_{\text{SN}} = \frac{W}{\sigma(W)}$$

Where $\sigma(W)$ is the spectral norm (largest singular value) of the weight matrix $W$. After normalization, the spectral norm equals 1.

Power Iteration Method

Computing the exact spectral norm via SVD is expensive. Instead, we use power iteration to approximate it efficiently:

🐍spectral_norm.py

import torch
import torch.nn as nn
import torch.nn.functional as F

# Apply spectral normalization to a layer
conv = nn.Conv2d(64, 128, kernel_size=3, padding=1)
conv_sn = nn.utils.spectral_norm(conv)

# The weight is now normalized by its spectral norm.
# Internally, PyTorch maintains u and v vectors for power iteration.

# Use in a discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.utils.spectral_norm(nn.Conv2d(3, 64, 4, 2, 1))
        self.conv2 = nn.utils.spectral_norm(nn.Conv2d(64, 128, 4, 2, 1))
        self.conv3 = nn.utils.spectral_norm(nn.Conv2d(128, 256, 4, 2, 1))
        self.fc = nn.utils.spectral_norm(nn.Linear(256 * 4 * 4, 1))

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.conv2(x), 0.2)
        x = F.leaky_relu(self.conv3(x), 0.2)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# Remove spectral normalization if needed
conv_original = nn.utils.remove_spectral_norm(conv_sn)
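For intuition, the power-iteration estimate itself can be sketched in a few lines. This is illustrative only; PyTorch's spectral_norm runs a single iteration per forward pass, amortizing the cost across training steps:

```python
import torch

# Estimate the largest singular value of W by power iteration:
# alternately apply W and W^T to a vector and renormalize.
torch.manual_seed(0)
W = torch.randn(64, 32)

u = torch.randn(64)
u = u / u.norm()
for _ in range(200):          # many steps here; spectral_norm amortizes 1/step
    v = W.t() @ u
    v = v / v.norm()
    u = W @ v
    u = u / u.norm()
sigma_est = u @ W @ v         # estimated spectral norm

sigma_true = torch.linalg.svdvals(W)[0]  # exact, via SVD
print(float(sigma_est), float(sigma_true))  # the two agree closely
```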

When to Use Spectral Normalization

| Use Case | Apply To | Benefit |
| --- | --- | --- |
| GAN discriminators | All layers | Stabilizes training, prevents mode collapse |
| GAN generators | Optional | Can help but less critical than discriminator |
| Contrastive learning | Projection heads | Prevents feature collapse |
| Regular classifiers | Usually not needed | Standard BatchNorm is sufficient |

Key Insight

Spectral normalization normalizes weights, not activations. It can be combined with BatchNorm or other activation normalizations. In fact, the combination of spectral norm + batch norm is common in modern GAN architectures.

Synchronized BatchNorm

When training on multiple GPUs, standard BatchNorm computes statistics independently on each GPU's local batch. This can cause problems when the per-GPU batch size is small.

The Multi-GPU Problem

With 4 GPUs and a total batch size of 32, each GPU sees only 8 samples. BatchNorm statistics computed from 8 samples can be noisy, hurting training. Synchronized BatchNorm (SyncBatchNorm) aggregates statistics across all GPUs before normalizing.

🐍sync_batchnorm.py

import torch
import torch.nn as nn

# Convert BatchNorm layers to SyncBatchNorm
model = YourModel()
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

# Move to distributed setup
model = nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

# Alternative: create SyncBatchNorm directly
sync_bn = nn.SyncBatchNorm(num_features=64)

# Statistics are now computed across ALL GPUs:
# - GPU 0: batch of 8 → shares stats with other GPUs
# - GPU 1: batch of 8 → shares stats with other GPUs
# - GPU 2: batch of 8 → shares stats with other GPUs
# - GPU 3: batch of 8 → shares stats with other GPUs
# Effective batch size for BN statistics: 32 (not 8!)

When to Use SyncBatchNorm

| Scenario | Use SyncBatchNorm? | Reason |
| --- | --- | --- |
| Multi-GPU with small per-GPU batch | ✅ Yes | Aggregates stats for stability |
| Single GPU training | ❌ No | No benefit, slight overhead |
| Large per-GPU batch (32+) | 🔄 Optional | Local stats already reliable |
| Object detection/segmentation | ✅ Usually yes | Often trained with small batches |

Performance Consideration

SyncBatchNorm requires communication between GPUs at every forward pass, adding overhead. Only use it when per-GPU batch sizes are small enough that local BatchNorm would be unreliable (typically < 16 samples per GPU).

Alternative: GroupNorm

If SyncBatchNorm's overhead is too high, consider GroupNorm instead. It provides batch-size-independent normalization without requiring GPU communication. Many modern architectures (like those using ViT backbones) prefer GroupNorm or LayerNorm for this reason.

Quick Check

You're training a GAN discriminator and experiencing mode collapse. Which normalization might help?


Interactive: All Normalization Methods

Compare all four normalization methods side by side. See which dimensions each method normalizes over and how they partition the tensor.



PyTorch Implementation

Let's see how to use all these normalization methods in PyTorch.

Batch Normalization

BatchNorm1d and BatchNorm2d
🐍batch_norm.py
Key points in the code below:

  • BatchNorm2d for CNNs: use BatchNorm2d for 4D inputs (N, C, H, W). num_features = number of channels. Normalizes over N, H, W for each channel.
  • Input shape: batch of 32 images, 64 channels, 28x28 spatial. Statistics are computed per channel across the entire batch and spatial dimensions.
  • Gamma (weight): learnable scale parameter γ. Shape [64] = one per channel. Initialized to 1.
  • Beta (bias): learnable shift parameter β. Shape [64] = one per channel. Initialized to 0.
  • Running mean: exponential moving average of batch means. Used during inference (model.eval()). Updated during training with momentum.
  • Running variance: exponential moving average of batch variances. Ensures consistent behavior at inference regardless of batch size.
  • BatchNorm1d for MLPs: use BatchNorm1d for 2D inputs (N, features). Common in fully-connected layers. Normalizes over the batch dimension.

import torch
import torch.nn as nn

# --- Batch Normalization for CNNs ---
# Input shape: (N, C, H, W)
batch_norm_2d = nn.BatchNorm2d(num_features=64)

# Example: batch of 32 images, 64 channels, 28x28
x = torch.randn(32, 64, 28, 28)
out = batch_norm_2d(x)  # Same shape: (32, 64, 28, 28)

# Access learned parameters
print(f"Gamma (weight): {batch_norm_2d.weight.shape}")  # [64]
print(f"Beta (bias): {batch_norm_2d.bias.shape}")       # [64]

# Access running statistics (used at inference)
print(f"Running mean: {batch_norm_2d.running_mean.shape}")  # [64]
print(f"Running var: {batch_norm_2d.running_var.shape}")    # [64]

# --- Batch Normalization for 1D data (MLPs) ---
batch_norm_1d = nn.BatchNorm1d(num_features=256)
x_1d = torch.randn(32, 256)  # (batch_size, features)
out_1d = batch_norm_1d(x_1d)

Layer Normalization

LayerNorm for Transformers
🐍layer_norm.py
Key points in the code below:

  • LayerNorm: normalizes over the last D dimensions specified by normalized_shape; here, over the 768-feature dimension.
  • Transformer input: typical shape (batch, sequence_length, hidden_dim). Each token is normalized independently over its 768 features.
  • Per-sample normalization: unlike BatchNorm, LayerNorm computes statistics per sample. No batch-level interaction, so it works with any batch size.
  • Multi-dim LayerNorm: can normalize over multiple dimensions, e.g. (C, H, W) for each sample. Less common in practice.
  • Learnable parameters: same γ and β as BatchNorm, but applied after normalizing over features, not the batch.
  • No running stats: LayerNorm has no running_mean or running_var. Behavior is identical in train() and eval() modes.

import torch
import torch.nn as nn

# --- Layer Normalization (for Transformers) ---
# Normalize over the last dimension(s)
layer_norm = nn.LayerNorm(normalized_shape=768)

# Example: batch of 32 sequences, 128 tokens, 768 features
x = torch.randn(32, 128, 768)
out = layer_norm(x)  # Same shape: (32, 128, 768)

# For each token position, normalize over 768 features
# Statistics are computed independently per sample & position

# Normalizing over multiple dimensions
layer_norm_2d = nn.LayerNorm([64, 28, 28])  # For CNN-like data
x_2d = torch.randn(32, 64, 28, 28)
out_2d = layer_norm_2d(x_2d)

# Access learnable parameters
print(f"Gamma: {layer_norm.weight.shape}")  # [768]
print(f"Beta: {layer_norm.bias.shape}")     # [768]

# Key differences from BatchNorm:
# - No running statistics (same behavior train/eval)
# - Works with batch_size=1

All Normalization Types

Comparing All Normalization Methods
🐍all_norms.py
Key points in the code below:

  • BatchNorm2d: standard for CNNs. Normalizes across batch and spatial dims. Needs batch > 1 for meaningful statistics.
  • LayerNorm: standard for Transformers. Normalizes each sample over all features. Batch-size independent.
  • InstanceNorm2d: popular in style transfer and GANs. Each sample+channel is normalized independently, which removes style information.
  • GroupNorm: hybrid that splits channels into groups. G=1 → LayerNorm, G=C → InstanceNorm. Good for small-batch training.
  • Statistics summary: each method computes a different number of means/variances. BatchNorm: C, LayerNorm: N, InstanceNorm: N×C, GroupNorm: N×G.

import torch
import torch.nn as nn

# --- All Normalization Types Comparison ---
batch_size, channels, height, width = 8, 64, 28, 28
x = torch.randn(batch_size, channels, height, width)

# 1. Batch Normalization
# Normalizes over (N, H, W) for each C
batch_norm = nn.BatchNorm2d(channels)
out_bn = batch_norm(x)  # Stats: 64 means, 64 vars

# 2. Layer Normalization
# Normalizes over (C, H, W) for each N
layer_norm = nn.LayerNorm([channels, height, width])
out_ln = layer_norm(x)  # Stats: 8 means, 8 vars

# 3. Instance Normalization
# Normalizes over (H, W) for each (N, C)
instance_norm = nn.InstanceNorm2d(channels)
out_in = instance_norm(x)  # Stats: 512 means, 512 vars

# 4. Group Normalization
# Normalizes over (C/G, H, W) for each (N, G)
num_groups = 8  # 64 channels / 8 groups = 8 channels/group
group_norm = nn.GroupNorm(num_groups, channels)
out_gn = group_norm(x)  # Stats: 64 means, 64 vars

# Summary of dimensions normalized:
# BatchNorm:    (N, H, W)   → 1 stat per channel
# LayerNorm:    (C, H, W)   → 1 stat per sample
# InstanceNorm: (H, W)      → 1 stat per sample+channel
# GroupNorm:    (C/G, H, W) → 1 stat per sample+group

RMS Normalization (Custom)

RMSNorm for Modern LLMs
🐍rms_norm.py
Key points in the code below:

  • RMSNorm class: Root Mean Square Normalization, used in modern LLMs like LLaMA, Mistral, and Gemma for efficiency.
  • Only gamma: RMSNorm has just a scale parameter γ (weight); there is no bias β, simplifying the computation.
  • RMS calculation: RMS = √(mean(x²)), a measure of the vector's magnitude. Note: no mean subtraction step.
  • Normalize and scale: divide the input by its RMS, then multiply by the learned γ. Simpler than LayerNorm.
  • Efficiency: roughly 15% faster than LayerNorm in practice; the mean subtraction in LayerNorm is computationally redundant for many architectures.
  • Comparison: LayerNorm needs four operations (mean, subtract, std, scale+shift); RMSNorm needs two (rms, scale), with comparable expressiveness for Transformers.

import torch
import torch.nn as nn

# --- RMS Normalization (used in LLaMA, Mistral, etc.) ---
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # γ only, no β

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # RMS = sqrt(mean(x^2))
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        # Normalize and scale
        return x / rms * self.weight

# Usage
rms_norm = RMSNorm(dim=768)
x = torch.randn(32, 128, 768)  # (batch, seq_len, hidden)
out = rms_norm(x)

# Why RMSNorm?
# 1. Simpler: no mean subtraction, just divide by RMS
# 2. Faster: fewer operations than LayerNorm
# 3. Empirically works as well as LayerNorm for LLMs

# Comparison with LayerNorm:
# LayerNorm: (x - mean) / std * γ + β
# RMSNorm:   x / RMS * γ  (no mean, no β)

Training vs Evaluation Mode

Switching Between Train and Eval Modes
🐍train_eval_mode.py
Key points in the code below:

  • model.train(): sets all modules to training mode. BatchNorm uses batch statistics and updates running stats.
  • model.eval(): sets all modules to evaluation mode. BatchNorm uses running (accumulated) statistics. Also affects Dropout.
  • Training loop: always call model.train() at the start of training to ensure correct BatchNorm behavior.
  • Inference: always call model.eval() and use torch.no_grad() for inference, so stable running statistics are used.
  • Common bug: forgetting to switch modes causes inconsistent results. Training in eval() mode hurts convergence; inference in train() mode gives noisy outputs.

import torch
import torch.nn as nn

# --- Training vs Evaluation Mode ---
model = nn.Sequential(
    nn.Linear(784, 256),
    nn.BatchNorm1d(256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

# Training mode (default)
model.train()
# BatchNorm uses batch statistics
# Updates running_mean and running_var

# Evaluation mode
model.eval()
# BatchNorm uses running statistics
# Does NOT update running_mean/running_var

# Example training loop
model.train()
for inputs, targets in dataloader:
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, targets)
    loss.backward()
    optimizer.step()

# Inference
model.eval()
with torch.no_grad():
    predictions = model(test_data)

# Key: Always call model.train() before training
#      Always call model.eval() before inference
# Forgetting this is a common bug!

Choosing the Right Normalization

With so many options, how do you choose? Here's a practical guide:

Decision Tree

  1. Building a Transformer or RNN? → Use LayerNorm (or RMSNorm for efficiency)
  2. Building a CNN with large batch sizes? → Use BatchNorm
  3. Building a CNN with small batch sizes? → Use GroupNorm
  4. Doing style transfer or image-to-image? → Use InstanceNorm
  5. Building a modern LLM? → Use RMSNorm

| Architecture | Recommended Norm | Why |
| --- | --- | --- |
| ResNet, VGG, etc. | BatchNorm | Standard for CNNs, enables higher LR |
| Transformers (BERT, GPT) | LayerNorm / RMSNorm | Batch-independent, same train/eval |
| LLaMA, Mistral, Gemma | RMSNorm | ~15% faster, comparable performance |
| YOLO, Faster R-CNN | GroupNorm | Small batch training common |
| StyleGAN, CycleGAN | InstanceNorm | Removes style info for transfer |
| U-Net for segmentation | GroupNorm | Often trained with small batches |
| GAN discriminators | SpectralNorm + BatchNorm | Stabilizes training, 1-Lipschitz |
| Multi-GPU detection/seg | SyncBatchNorm | Aggregates stats across GPUs |

When in Doubt

For CNNs, start with BatchNorm. For Transformers, start with LayerNorm. These are the established defaults with extensive empirical validation. Only switch if you have a specific reason (small batches, style transfer, etc.).

Test Your Understanding

Test your knowledge of normalization layers with this comprehensive quiz.



Summary

Key Takeaways

  1. Normalization addresses Internal Covariate Shift—the change in layer input distributions during training that slows convergence
  2. Batch Normalization normalizes over the batch dimension, requires batch > 1, and uses running statistics at inference
  3. Layer Normalization normalizes over features, works with any batch size, and is identical at train and test time
  4. Instance Normalization normalizes each sample and channel independently, useful for style transfer
  5. Group Normalization is a flexible hybrid that works well with small batches
  6. RMS Normalization simplifies LayerNorm for efficiency in modern LLMs
  7. Always switch between train() and eval() modes when using BatchNorm!

| Normalization | Normalizes Over | Best For |
| --- | --- | --- |
| BatchNorm | (N, H, W) per C | CNNs, large batches |
| LayerNorm | (C, H, W) per N | Transformers, RNNs |
| InstanceNorm | (H, W) per N, C | Style transfer, GANs |
| GroupNorm | (C/G, H, W) per N, G | Small batch CNNs |
| RMSNorm | Like LayerNorm, no mean | Modern LLMs |
| SpectralNorm | Weights (by spectral norm) | GAN discriminators |
| SyncBatchNorm | Same as BatchNorm, cross-GPU | Multi-GPU small batches |

Exercises

Conceptual Questions

  1. Explain why BatchNorm acts as a regularizer during training. What happens to the regularization effect at inference time?
  2. Why does LayerNorm work better than BatchNorm for autoregressive language models (where you generate one token at a time)?
  3. In RMSNorm, why is it acceptable to skip the mean subtraction step? Under what conditions might this be problematic?

Coding Exercises

  1. Implement BatchNorm from Scratch: Write a class that implements BatchNorm2d with training/eval modes and running statistics. Verify against PyTorch.
  2. Compare Normalization Methods: Train the same CNN architecture (e.g., ResNet-18) on CIFAR-10 with BatchNorm, LayerNorm, GroupNorm (G=32), and no normalization. Plot training curves and compare final accuracy.
  3. Small Batch Experiment: Train a model with batch sizes of 4, 8, 16, 32 using BatchNorm and GroupNorm. Compare stability and final performance.
  4. RMSNorm Implementation: Implement RMSNorm and replace LayerNorm in a small Transformer. Measure the speed difference and compare training dynamics.

Challenge: Adaptive Normalization

Implement Conditional Batch Normalization where the γ and β parameters are predicted by another network based on some conditioning input (e.g., a class label or style embedding). This is used in conditional image generation:

  • Take a base BatchNorm layer
  • Add a small MLP that takes a conditioning vector and predicts γ and β
  • Use these predicted parameters instead of learned parameters
  • Train on a conditional generation task

Hint

Look up "Adaptive Instance Normalization (AdaIN)" and "Conditional Batch Normalization" papers for inspiration. These techniques are widely used in GANs and diffusion models.

Coming Up in Chapter 9: While this section covers what normalization layers are and how they work, Chapter 9 (Training Neural Networks) explores the practical aspects: how normalization interacts with learning rate selection, combining normalization with other regularization techniques, and debugging normalization-related training issues.

In the next section, we'll explore Dropout and Regularization—techniques for preventing overfitting and improving generalization in neural networks.