Chapter 5

Batch Normalization for Training Stability

CNN Feature Extractor

Learning Objectives

By the end of this section, you will:

  1. Understand internal covariate shift and why it hinders training
  2. Master the batch normalization formula including learnable parameters
  3. Distinguish training from inference behavior with running statistics
  4. Apply BatchNorm1d correctly to time series convolutional layers
  5. Position batch normalization optimally within CNN blocks
Why This Matters: Batch normalization is one of the most important innovations in deep learning. It stabilizes training, allows higher learning rates, and provides mild regularization. Understanding its mechanics is essential for training deep networks effectively.

Internal Covariate Shift

Internal covariate shift refers to the change in the distribution of layer inputs during training, as the parameters of preceding layers change.

The Problem

Consider a deep network where each layer transforms its input. As training progresses:

  1. Layer 1 parameters update → its output distribution shifts
  2. Layer 2 must adapt to this new input distribution
  3. This adaptation is undone when Layer 1 updates again
  4. The cycle continues, slowing convergence
```text
Training dynamics without normalization:

Epoch 1:  Layer 1 output ~ N(μ₁, σ₁)  →  Layer 2 adapts
Epoch 2:  Layer 1 output ~ N(μ₂, σ₂)  →  Layer 2 re-adapts
Epoch 3:  Layer 1 output ~ N(μ₃, σ₃)  →  Layer 2 re-adapts again

Each layer is "chasing" a moving target!
```

Consequences

| Effect | Impact on Training |
| --- | --- |
| Slower convergence | More epochs needed to reach the optimum |
| Lower learning rates required | Large steps cause instability |
| Gradient vanishing/exploding | Activations drift into saturation regions |
| Sensitive initialization | Poor initialization causes divergence |

The Solution: Normalize Inputs

Batch normalization addresses this by normalizing layer inputs to have zero mean and unit variance, making each layer's job easier.

$$\hat{x} = \frac{x - \mu}{\sigma} \implies \hat{x} \sim N(0, 1)$$

Batch Normalization Formulation

Batch normalization normalizes activations across the batch dimension, then applies a learnable affine transformation.

Step 1: Compute Batch Statistics

For a mini-batch $\mathcal{B} = \{x_1, ..., x_m\}$:

$$\mu_{\mathcal{B}} = \frac{1}{m} \sum_{i=1}^{m} x_i$$

$$\sigma_{\mathcal{B}}^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_{\mathcal{B}})^2$$

Step 2: Normalize

$$\hat{x}_i = \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}$$

Where $\epsilon$ (typically $10^{-5}$) is a small constant for numerical stability.

Step 3: Scale and Shift

$$y_i = \gamma \hat{x}_i + \beta$$

Where:

  • $\gamma$: Learnable scale parameter (initialized to 1)
  • $\beta$: Learnable shift parameter (initialized to 0)

Why Scale and Shift?

Pure normalization to N(0, 1) would limit the network's representational power. The learnable $\gamma$ and $\beta$ allow the network to recover any mean and variance if that is optimal, including undoing the normalization entirely.
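The three steps can be combined into a short NumPy sketch (a minimal illustration of the math, not PyTorch's implementation; the batch size and feature count are arbitrary):

```python
import numpy as np

def batch_norm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=0)                     # Step 1: batch mean
    var = x.var(axis=0)                     # Step 1: batch variance (biased, 1/m)
    x_hat = (x - mu) / np.sqrt(var + eps)   # Step 2: normalize
    return gamma * x_hat + beta             # Step 3: scale and shift

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(32, 4))   # batch of 32, 4 features
y = batch_norm(x, gamma=np.ones(4), beta=np.zeros(4))

print(y.mean(axis=0))  # ≈ 0 for every feature
print(y.std(axis=0))   # ≈ 1 for every feature
```

With γ = 1 and β = 0 the output is simply the normalized activations; during training the network is free to move these two parameters away from their initial values.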


Training vs Inference

Batch normalization behaves differently during training and inference.

During Training

  • Use batch statistics (μ_B, σ_B²) for normalization
  • Update running estimates with exponential moving average (EMA)
$$\mu_{\text{running}} \leftarrow (1 - m) \cdot \mu_{\text{running}} + m \cdot \mu_{\mathcal{B}}$$

$$\sigma^2_{\text{running}} \leftarrow (1 - m) \cdot \sigma^2_{\text{running}} + m \cdot \sigma^2_{\mathcal{B}}$$

Where $m$ is the momentum (typically 0.1 in PyTorch).
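A minimal sketch of the EMA update for the running mean, using made-up batch means and PyTorch's default momentum of 0.1:

```python
# EMA update of the running mean, as in batch normalization.
# The batch means below are hypothetical values for illustration.
momentum = 0.1
running_mean = 0.0                      # PyTorch initializes running_mean to 0
for mu_batch in [2.0, 2.2, 1.8, 2.1]:
    running_mean = (1 - momentum) * running_mean + momentum * mu_batch

print(round(running_mean, 3))  # 0.696 -- slowly drifting toward the true mean ~2.0
```

Because each batch contributes only 10%, the running estimate changes slowly and smooths out batch-to-batch noise; it takes many batches to converge to the population statistics.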

During Inference

  • Use running statistics (population estimates)
  • No batch statistics computed—inference is deterministic
  • Single sample inference is well-defined
$$\hat{x} = \frac{x - \mu_{\text{running}}}{\sqrt{\sigma^2_{\text{running}} + \epsilon}}$$

PyTorch Behavior

```python
# Training mode: uses batch statistics, updates running stats
model.train()
output = model(batch)  # Uses μ_B, σ_B²

# Evaluation mode: uses running statistics, frozen
model.eval()
output = model(single_sample)  # Uses μ_running, σ²_running
```

Always Set Correct Mode

Forgetting to call model.eval() before inference is a common bug. With a batch size of 1, batch statistics become meaningless (the variance of a single sample is zero), causing incorrect predictions.
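The batch-size-1 failure is easy to demonstrate with the normalization formula alone (a NumPy sketch; the input value is made up):

```python
import numpy as np

x = np.array([[3.7]])                   # a single sample with one feature
mu = x.mean(axis=0)                     # equals the sample itself: 3.7
var = x.var(axis=0)                     # variance of a single point is 0
x_hat = (x - mu) / np.sqrt(var + 1e-5)

print(x_hat)  # [[0.]] -- the input value is destroyed; the output is just β
```

Whatever the input, normalizing a sample against itself yields zero, so every single-sample prediction in training mode collapses to the same value.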

Training vs Inference Comparison

| Aspect | Training | Inference |
| --- | --- | --- |
| Statistics used | Batch (μ_B, σ_B²) | Running (μ_run, σ²_run) |
| Running stats | Updated via EMA | Frozen |
| Batch dependency | Yes (needs mini-batch) | No (single sample OK) |
| Deterministic | No (varies with batch) | Yes |
| PyTorch mode | model.train() | model.eval() |

BatchNorm1d for Time Series

For 1D convolutional layers processing time series, we use BatchNorm1d, which normalizes across the batch and time dimensions for each channel.

Input Shape Convention

PyTorch's BatchNorm1d expects input shape $(N, C, L)$:

  • $N$: Batch size
  • $C$: Number of channels (features)
  • $L$: Sequence length (time)

Normalization Dimension

For each channel $c$, statistics are computed over all batch samples and all time positions:

$$\mu_c = \frac{1}{N \cdot L} \sum_{n=1}^{N} \sum_{t=1}^{L} x_{n,c,t}$$

$$\sigma_c^2 = \frac{1}{N \cdot L} \sum_{n=1}^{N} \sum_{t=1}^{L} (x_{n,c,t} - \mu_c)^2$$
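The per-channel statistics can be checked with a NumPy sketch (the shapes here are arbitrary examples):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 3, 50))   # (N, C, L): 8 samples, 3 channels, 50 time steps

mu_c = x.mean(axis=(0, 2))        # one mean per channel, shape (3,)
var_c = x.var(axis=(0, 2))        # one variance per channel, shape (3,)
x_hat = (x - mu_c[None, :, None]) / np.sqrt(var_c[None, :, None] + 1e-5)

print(mu_c.shape)                 # (3,)
print(x_hat.mean(axis=(0, 2)))    # ≈ 0 for every channel
```

Each channel is normalized by a single mean/variance pair shared across all samples and all time steps, which is why the statistics have shape $(C,)$.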

Parameter Count

For $C$ channels:

| Parameter | Count | Purpose |
| --- | --- | --- |
| γ (gamma) | C | Learnable scale |
| β (beta) | C | Learnable shift |
| running_mean | C | Population mean estimate |
| running_var | C | Population variance estimate |

Learnable parameters: $2C$. For our layer with 64 channels: 128 learnable parameters.
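A quick check in PyTorch (the channel count 64 matches the layer mentioned above):

```python
import torch.nn as nn

bn = nn.BatchNorm1d(64)
learnable = sum(p.numel() for p in bn.parameters())   # γ (weight) + β (bias)
buffers = {name: b.numel() for name, b in bn.named_buffers()}

print(learnable)  # 128
print(buffers)    # running_mean: 64, running_var: 64, num_batches_tracked: 1
```

The running statistics are buffers, not parameters: they are saved with the model but never receive gradients.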


Placement in CNN Blocks

The placement of batch normalization within the block affects training dynamics.

Standard Placement: After Convolution, Before Activation

```text
Recommended order (what we use):
  Conv1D → BatchNorm1d → ReLU → Dropout

The convolution output is normalized before the non-linearity.
```

Alternative: After Activation

```text
Alternative order:
  Conv1D → ReLU → BatchNorm1d → Dropout

Normalizes after the non-linearity. Less common.
```

Comparison

| Placement | Pros | Cons |
| --- | --- | --- |
| Before ReLU | Controls pre-activation scale; standard practice | Normalized values may be clipped by ReLU |
| After ReLU | Normalizes actual activations | ReLU zeros may skew statistics |

Our choice: Before ReLU. This is the original formulation and works well in practice.

Complete Block Order

```text
Input
  ↓
Conv1D(in_channels, out_channels, kernel_size=3, padding=1)
  ↓
BatchNorm1d(out_channels)  ← Normalize here
  ↓
ReLU()
  ↓
Dropout(p=0.2)
  ↓
Output
```

Bias in Convolution

When BatchNorm follows Conv1D, the convolution bias is redundant—the batch norm's β parameter serves the same purpose. Many implementations use bias=False in the convolution to save parameters.
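A sketch of such a block in PyTorch; the channel sizes (32 in, 64 out) are illustrative assumptions:

```python
import torch
import torch.nn as nn

block = nn.Sequential(
    nn.Conv1d(32, 64, kernel_size=3, padding=1, bias=False),  # bias=False: BN's β takes its place
    nn.BatchNorm1d(64),   # normalizes the 64 conv output channels
    nn.ReLU(),
    nn.Dropout(p=0.2),
)

x = torch.randn(4, 32, 100)   # (batch, channels, time)
y = block(x)
print(y.shape)                # torch.Size([4, 64, 100])
print(sum(p.numel() for p in block[0].parameters()))  # 6144 = 32 × 64 × 3, no bias term
```

Dropping the bias saves one parameter per output channel; any constant offset the layer needs is absorbed by batch norm's β anyway, since BN subtracts the mean immediately after the convolution.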


Summary

In this section, we explained batch normalization for CNN training:

  1. Internal covariate shift: Changing input distributions slow training
  2. Batch normalization: Normalizes, then applies learnable γ and β
  3. Training vs inference: Batch stats vs running stats
  4. BatchNorm1d: Normalizes over batch and time for each channel
  5. Placement: After convolution, before activation
| Property | Value |
| --- | --- |
| Input shape | (batch, channels, time) |
| Statistics computed over | Batch and time dimensions |
| Learnable parameters | 2 × channels (γ and β) |
| Running stats | Mean and variance per channel |
| Momentum (PyTorch) | 0.1 (for EMA) |
| Epsilon | 1e-5 (numerical stability) |
Looking Ahead: Batch normalization helps training stability but doesn't prevent overfitting. For that, we need regularization. The next section introduces dropout—randomly zeroing activations during training to prevent co-adaptation and improve generalization.

With batch normalization understood, we now examine dropout strategies for regularization.