Learning Objectives
By the end of this section, you will:
- Distinguish batch normalization from layer normalization and when to use each
- Master the layer normalization formula and its application
- Understand why layer norm is preferred for RNNs
- Integrate layer normalization with the BiLSTM output
- Configure layer norm parameters in PyTorch
Why This Matters: Batch normalization transformed CNN training, but it has limitations for recurrent networks. Layer normalization provides similar benefits without the batch dependency, making it the preferred choice for LSTM architectures.
Batch Norm vs Layer Norm
Both normalization techniques stabilize training, but they normalize across different dimensions.
Batch Normalization (Recap)
For a batch of activations, batch norm normalizes across the batch dimension:
```
Input: (N, T, H) ← Batch, Time, Hidden

For each feature h:
  μ_h = mean across all N samples and T timesteps
  σ_h = std across all N samples and T timesteps
  normalize: x̂_{n,t,h} = (x_{n,t,h} - μ_h) / σ_h
```
Layer Normalization
Layer norm normalizes across the feature dimension within each sample:
```
Input: (N, T, H) ← Batch, Time, Hidden

For each sample n and timestep t:
  μ_{n,t} = mean across all H features
  σ_{n,t} = std across all H features
  normalize: x̂_{n,t,h} = (x_{n,t,h} - μ_{n,t}) / σ_{n,t}
```
Visual Comparison
```
Tensor shape: (N, T, H) = (Batch, Time, Hidden)

BatchNorm:                      LayerNorm:
    H ──────────►                   H ──────────►
  ┌─────────────┐                 ┌─────────────┐
T │ ■ ■ ■ ■ ■ ■ │               T │ ▒ ▒ ▒ ▒ ▒ ▒ │ ← normalize this row
  │ ■ ■ ■ ■ ■ ■ │                 │ ░ ░ ░ ░ ░ ░ │ ← normalize this row
  │ ■ ■ ■ ■ ■ ■ │                 │ ▓ ▓ ▓ ▓ ▓ ▓ │ ← normalize this row
  └─────────────┘                 └─────────────┘
        ↑                               ↑
  normalize this column         each row normalized
  across samples                  independently
```
Key Differences
| Aspect | BatchNorm | LayerNorm |
|---|---|---|
| Normalizes across | Batch dimension (N) | Feature dimension (H) |
| Batch dependency | Yes (needs batch stats) | No (per-sample) |
| Inference behavior | Uses running stats | Same as training |
| Variable sequences | Problematic | Natural handling |
| Best for | CNNs with large batches | RNNs, Transformers |
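The table above can be made concrete with a small PyTorch sketch (tensor shapes are illustrative) showing which dimensions each technique computes its statistics over:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 30, 256)  # (N, T, H) = (Batch, Time, Hidden)

# BatchNorm-style statistics: one mean/std per feature h,
# computed across all N samples and T timesteps
bn_mean = x.mean(dim=(0, 1))   # shape (256,)
bn_std = x.std(dim=(0, 1))     # shape (256,)

# LayerNorm-style statistics: one mean/std per (sample, timestep),
# computed across the H features
ln_mean = x.mean(dim=-1)       # shape (8, 30)
ln_std = x.std(dim=-1)         # shape (8, 30)

print(bn_mean.shape, ln_mean.shape)  # torch.Size([256]) torch.Size([8, 30])
```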
Layer Normalization Formula
Layer normalization normalizes each sample independently across features.
Mathematical Formulation
For a single sample at timestep t with hidden dimension H:

```
x̂_{t,h} = (x_{t,h} - μ_t) / √(σ_t² + ε)
y_{t,h} = γ_h · x̂_{t,h} + β_h
```

Where:
- μ_t = (1/H) Σ_h x_{t,h}: Mean across features at timestep t
- σ_t² = (1/H) Σ_h (x_{t,h} - μ_t)²: Variance across features at timestep t
- ε: Small constant (1e-5) for numerical stability
- γ, β: Learnable scale and shift per feature
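The formula can be verified against PyTorch directly. A minimal sketch (shapes illustrative) computes the statistics by hand and compares with `nn.LayerNorm` at its default initialization (γ = 1, β = 0):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 30, 256)  # (B, T, H)

# Manual layer norm: statistics across the last (feature) dimension
mu = x.mean(dim=-1, keepdim=True)                  # μ_t per (sample, timestep)
var = x.var(dim=-1, unbiased=False, keepdim=True)  # σ_t² (biased, matching nn.LayerNorm)
x_hat = (x - mu) / torch.sqrt(var + 1e-5)

# PyTorch's LayerNorm with default γ = 1, β = 0 should agree
ln = nn.LayerNorm(256, eps=1e-5)
assert torch.allclose(x_hat, ln(x), atol=1e-5)
```

Note the biased variance estimator (`unbiased=False`): `nn.LayerNorm` divides by H, not H − 1.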
Why Layer Norm for RNNs?
Several properties make layer normalization superior to batch normalization for recurrent networks.
Problem 1: Variable Sequence Lengths
RNNs often process sequences of different lengths. With batch norm:
```
Batch with varying lengths:
  Sequence 1: [x₁, x₂, x₃, x₄, x₅]     (length 5)
  Sequence 2: [x₁, x₂, x₃, PAD, PAD]   (length 3 + padding)

BatchNorm at position 4:
  Only uses Sequence 1's value (Sequence 2 is padding)
  Statistics become unreliable!
```
Layer norm avoids this—each position is normalized independently.
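A small sketch of the effect (shapes illustrative): zeros from padding get pulled into batch statistics, while per-sample layer norm statistics for one sequence never involve any other sequence:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 5, 4)   # (N, T, H): two sequences, max length 5
x[1, 3:] = 0.0             # sequence 2 is padding at positions 4 and 5

# Batch statistics are computed over everything, padding included
mean_with_pad = x.mean(dim=(0, 1))
mean_real_only = x[0].mean(dim=0)  # sequence 1 alone gives different stats
print(torch.allclose(mean_with_pad, mean_real_only))  # False: padding skewed μ

# Layer norm statistics for sequence 1 depend only on sequence 1
ln_mean = x[0].mean(dim=-1)  # one μ per timestep, untouched by padding
```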
Problem 2: Sequential Dependency
RNN hidden states at different timesteps are correlated. Batch normalizing across time can damage this structure:
- State at t depends on state at t-1
- Normalizing all t positions together ignores temporal order
- Layer norm respects the sequential structure
Problem 3: Small Batch Sizes
RNN training often uses smaller batches (memory constraints). Batch norm statistics become noisy with small N:
| Batch Size | BatchNorm Quality | LayerNorm Quality |
|---|---|---|
| N = 128 | Good | Good |
| N = 32 | Acceptable | Good |
| N = 8 | Noisy | Good |
| N = 1 | Fails | Good |
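The N = 1 failure is easy to reproduce. A minimal sketch (the feature size matches our 256-dim output, but any size behaves the same):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 256)    # a single sample: N = 1

bn = nn.BatchNorm1d(256).train()
try:
    bn(x)                  # batch variance is undefined for one sample
except ValueError as e:
    print("BatchNorm failed:", e)

ln = nn.LayerNorm(256).train()
out = ln(x)                # per-sample statistics: unaffected by N
print(out.shape)           # torch.Size([1, 256])
```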
Inference Consistency
Layer norm computes the same statistics during training and inference—no running mean/variance to track. This eliminates a common source of train-test discrepancy in RNN models.
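A quick sketch verifying this: the same input passed through `nn.LayerNorm` in train and eval mode produces identical outputs, because there are no running statistics to diverge:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(4, 30, 256)
ln = nn.LayerNorm(256)

out_train = ln.train()(x)
out_eval = ln.eval()(x)

# LayerNorm keeps no running statistics, so train and eval agree exactly
assert torch.equal(out_train, out_eval)
```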
Integration with BiLSTM
We apply layer normalization to the BiLSTM output before passing to attention.
Placement
```
BiLSTM Output: (B, 30, 256)
        ↓
  LayerNorm(256)    ← Normalize across 256 features
        ↓
Normalized: (B, 30, 256)
        ↓
  Attention Layer
```
Why After BiLSTM?
- Stabilizes attention inputs: Prevents attention scores from exploding/vanishing
- Consistent scale: All 30 timesteps have comparable magnitudes
- Faster convergence: Better-conditioned optimization landscape
Parameter Count
LayerNorm for 256 features:
- γ: 256 scale parameters
- β: 256 shift parameters
- Total: 512 learnable parameters
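The count can be confirmed directly in PyTorch:

```python
import torch.nn as nn

ln = nn.LayerNorm(256)
n_params = sum(p.numel() for p in ln.parameters())
print(n_params)  # 512 = 256 (γ, stored as `weight`) + 256 (β, stored as `bias`)
```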
PyTorch Configuration
```python
import torch.nn as nn

# Layer normalization after BiLSTM
self.layer_norm = nn.LayerNorm(
    normalized_shape=256,     # Hidden dimension (2 × 128)
    eps=1e-5,                 # Numerical stability
    elementwise_affine=True   # Learnable γ and β
)

# In forward pass
lstm_output, _ = self.bilstm(cnn_features)  # (B, 30, 256)
normalized = self.layer_norm(lstm_output)   # (B, 30, 256)
```
Pre-LayerNorm vs Post-LayerNorm
Our architecture applies layer norm after the BiLSTM (post-norm). Pre-norm variants apply it before each sub-layer. Post-norm is the original formulation and works well for our depth.
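As an illustrative sketch (the `nn.Linear` here is a stand-in for an arbitrary sub-layer, not part of our architecture), the two placements differ only in where the normalization sits relative to the sub-layer:

```python
import torch
import torch.nn as nn

# Post-norm (our choice): the sub-layer's output is normalized
post_norm = nn.Sequential(nn.Linear(256, 256), nn.LayerNorm(256))

# Pre-norm: the input is normalized before entering the sub-layer
pre_norm = nn.Sequential(nn.LayerNorm(256), nn.Linear(256, 256))

x = torch.randn(4, 30, 256)
print(post_norm(x).shape, pre_norm(x).shape)  # both torch.Size([4, 30, 256])
```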
Summary
In this section, we integrated layer normalization with the BiLSTM:
- Layer vs Batch norm: Layer norm is per-sample, batch-independent
- Formula: Normalize across features, then scale and shift
- RNN advantages: Handles variable lengths, small batches, sequential structure
- Placement: After BiLSTM output, before attention
- Parameters: 512 (γ and β for 256 dimensions)
| Property | Value |
|---|---|
| Normalized dimension | 256 (BiLSTM output) |
| Normalization type | Per-sample, per-timestep |
| Learnable parameters | 512 (γ: 256, β: 256) |
| Epsilon | 1e-5 |
| Placement | After BiLSTM, before attention |
Looking Ahead: We have designed all components of the BiLSTM encoder: two bidirectional layers with dropout and layer normalization. The next section brings everything together with the complete PyTorch implementation.
With all BiLSTM components designed, we now implement the complete encoder module.