Chapter 6

Layer Normalization Integration

Bidirectional LSTM Encoder

Learning Objectives

By the end of this section, you will:

  1. Distinguish batch normalization from layer normalization and when to use each
  2. Master the layer normalization formula and its application
  3. Understand why layer norm is preferred for RNNs
  4. Integrate layer normalization with the BiLSTM output
  5. Configure layer norm parameters in PyTorch
Why This Matters: Batch normalization transformed CNN training, but it has limitations for recurrent networks. Layer normalization provides similar benefits without the batch dependency, making it the preferred choice for LSTM architectures.

Batch Norm vs Layer Norm

Both normalization techniques stabilize training, but they normalize across different dimensions.

Batch Normalization (Recap)

For a batch of activations, batch norm normalizes across the batch dimension:

\text{BatchNorm}: \text{normalize across } N \text{ (batch samples)}

```text
Input: (N, T, H)  ← Batch, Time, Hidden

For each feature h:
  μ_h = mean across all N samples and T timesteps
  σ_h = std across all N samples and T timesteps
  normalize: x̂_{n,t,h} = (x_{n,t,h} - μ_h) / σ_h
```

Layer Normalization

Layer norm normalizes across the feature dimension within each sample:

\text{LayerNorm}: \text{normalize across } H \text{ (features)}

```text
Input: (N, T, H)  ← Batch, Time, Hidden

For each sample n and timestep t:
  μ_{n,t} = mean across all H features
  σ_{n,t} = std across all H features
  normalize: x̂_{n,t,h} = (x_{n,t,h} - μ_{n,t}) / σ_{n,t}
```

Visual Comparison

```text
Tensor shape: (N, T, H) = (Batch, Time, Hidden)

BatchNorm:                    LayerNorm:
    H ──────────►                 H ──────────►
  ┌─────────────┐               ┌─────────────┐
T │ ■ ■ ■ ■ ■ ■ │             T │ ▒ ▒ ▒ ▒ ▒ ▒ │  ← normalize this row
  │ ■ ■ ■ ■ ■ ■ │               │ ░ ░ ░ ░ ░ ░ │  ← normalize this row
  │ ■ ■ ■ ■ ■ ■ │               │ ▓ ▓ ▓ ▓ ▓ ▓ │  ← normalize this row
  └─────────────┘               └─────────────┘
        ↑                             ↑
    normalize                   each row normalized
    this column                  independently
  across samples
```

Key Differences

| Aspect | BatchNorm | LayerNorm |
| --- | --- | --- |
| Normalizes across | Batch dimension (N) | Feature dimension (H) |
| Batch dependency | Yes (needs batch stats) | No (per-sample) |
| Inference behavior | Uses running stats | Same as training |
| Variable sequences | Problematic | Natural handling |
| Best for | CNNs with large batches | RNNs, Transformers |
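The difference between the two reduces to which axes the statistics are computed over. A minimal sketch contrasting the two on a `(N, T, H)` tensor (the tensor values here are random placeholders):

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 30, 256)  # (batch, time, hidden)

# LayerNorm: statistics over the feature axis H, per sample and timestep
mu_ln = x.mean(dim=-1, keepdim=True)                    # (4, 30, 1)
var_ln = x.var(dim=-1, unbiased=False, keepdim=True)
x_ln = (x - mu_ln) / torch.sqrt(var_ln + 1e-5)

# BatchNorm (as applied to sequences): statistics over N and T, per feature
mu_bn = x.mean(dim=(0, 1), keepdim=True)                # (1, 1, 256)
var_bn = x.var(dim=(0, 1), unbiased=False, keepdim=True)
x_bn = (x - mu_bn) / torch.sqrt(var_bn + 1e-5)

print(x_ln.shape, x_bn.shape)  # both (4, 30, 256); only the statistics differ
```

After layer norm, each `(n, t)` row has zero mean across its 256 features; after batch norm, each feature has zero mean across all samples and timesteps.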

Layer Normalization Formula

Layer normalization normalizes each sample independently across features.

Mathematical Formulation

For a single sample at timestep t with hidden dimension H:

\mu_t = \frac{1}{H} \sum_{h=1}^{H} x_{t,h}

\sigma_t^2 = \frac{1}{H} \sum_{h=1}^{H} (x_{t,h} - \mu_t)^2

\hat{x}_{t,h} = \frac{x_{t,h} - \mu_t}{\sqrt{\sigma_t^2 + \epsilon}}

y_{t,h} = \gamma_h \hat{x}_{t,h} + \beta_h

Where:

  • \mu_t: Mean across features at timestep t
  • \sigma_t^2: Variance across features at timestep t
  • \epsilon: Small constant (1e-5) for numerical stability
  • \gamma_h, \beta_h: Learnable scale and shift per feature
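The four equations above can be checked directly against PyTorch. A sketch computing layer norm by hand and comparing it to a freshly initialized `nn.LayerNorm` (whose defaults are γ = 1, β = 0):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
H = 8
x = torch.randn(2, 5, H)  # (N, T, H)

eps = 1e-5
mu = x.mean(dim=-1, keepdim=True)                  # μ_t per sample and timestep
var = x.var(dim=-1, unbiased=False, keepdim=True)  # σ_t² (population variance)
x_hat = (x - mu) / torch.sqrt(var + eps)           # normalize
gamma, beta = torch.ones(H), torch.zeros(H)        # default-initialized γ, β
y_manual = gamma * x_hat + beta

ln = nn.LayerNorm(H, eps=eps)
y_torch = ln(x)

print(torch.allclose(y_manual, y_torch, atol=1e-6))  # True
```

Note the `unbiased=False`: layer norm uses the population variance (divide by H), not the sample variance.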

Why Layer Norm for RNNs?

Several properties make layer normalization superior to batch normalization for recurrent networks.

Problem 1: Variable Sequence Lengths

RNNs often process sequences of different lengths. With batch norm:

```text
Batch with varying lengths:
  Sequence 1: [x₁, x₂, x₃, x₄, x₅]  (length 5)
  Sequence 2: [x₁, x₂, x₃, PAD, PAD]  (length 3 + padding)

BatchNorm at position 4:
  Only uses Sequence 1's value (Sequence 2 is padding)
  Statistics become unreliable!
```

Layer norm avoids this—each position is normalized independently.
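This independence is easy to verify: because statistics are computed per sample and per timestep, the normalized output at a real position is untouched by whatever sits in the padded positions. A small sketch with illustrative shapes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(6)

seq = torch.randn(1, 5, 6)   # one sequence of length 5
padded = seq.clone()
padded[0, 3:, :] = 0.0       # pretend the last two timesteps are PAD

out_full = ln(seq)
out_padded = ln(padded)

# The first three (real) timesteps normalize identically either way.
print(torch.allclose(out_full[0, :3], out_padded[0, :3]))  # True
```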

Problem 2: Sequential Dependency

RNN hidden states at different timesteps are correlated. Batch normalizing across time can damage this structure:

  • State at t depends on state at t-1
  • Normalizing all t positions together ignores temporal order
  • Layer norm respects the sequential structure

Problem 3: Small Batch Sizes

RNN training often uses smaller batches (memory constraints). Batch norm statistics become noisy with small N:

| Batch Size | BatchNorm Quality | LayerNorm Quality |
| --- | --- | --- |
| N = 128 | Good | Good |
| N = 32 | Acceptable | Good |
| N = 8 | Noisy | Good |
| N = 1 | Fails | Good |

Inference Consistency

Layer norm computes the same statistics during training and inference—no running mean/variance to track. This eliminates a common source of train-test discrepancy in RNN models.
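A quick sketch of this contrast: `nn.LayerNorm` is unaffected by `.train()`/`.eval()`, while `nn.BatchNorm1d` switches from batch statistics to running statistics at inference.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 256)

ln = nn.LayerNorm(256)
ln.train()
y_train = ln(x)
ln.eval()
y_eval = ln(x)
print(torch.allclose(y_train, y_eval))  # True: no running stats to diverge

bn = nn.BatchNorm1d(256)
bn.train()
z_train = bn(x)  # uses this batch's statistics, updates running stats
bn.eval()
z_eval = bn(x)   # uses the accumulated running statistics instead
print(torch.allclose(z_train, z_eval))  # False in general
```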


Integration with BiLSTM

We apply layer normalization to the BiLSTM output before passing to attention.

Placement

```text
BiLSTM Output: (B, 30, 256)
        ↓
   LayerNorm(256)  ← Normalize across 256 features
        ↓
Normalized: (B, 30, 256)
        ↓
   Attention Layer
```

Why After BiLSTM?

  • Stabilizes attention inputs: Prevents attention scores from exploding/vanishing
  • Consistent scale: All 30 timesteps have comparable magnitudes
  • Faster convergence: Better-conditioned optimization landscape

Parameter Count

LayerNorm for 256 features:

  • \gamma \in \mathbb{R}^{256}: Scale parameters
  • \beta \in \mathbb{R}^{256}: Shift parameters
  • Total: 512 learnable parameters
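The count is straightforward to confirm by inspecting the module's parameters:

```python
import torch.nn as nn

ln = nn.LayerNorm(256)
n_params = sum(p.numel() for p in ln.parameters())
print(n_params)  # 512
print([name for name, _ in ln.named_parameters()])  # ['weight', 'bias'], i.e. γ and β
```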

PyTorch Configuration

```python
import torch.nn as nn

# Layer normalization after BiLSTM
self.layer_norm = nn.LayerNorm(
    normalized_shape=256,    # Hidden dimension (2 × 128)
    eps=1e-5,                # Numerical stability
    elementwise_affine=True  # Learnable γ and β
)

# In forward pass
lstm_output, _ = self.bilstm(cnn_features)  # (B, 30, 256)
normalized = self.layer_norm(lstm_output)   # (B, 30, 256)
```
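To see those two snippets in context, here is a minimal self-contained encoder sketch. The class name `SequenceEncoder` and the input dimension of 64 are illustrative placeholders, not the book's actual module:

```python
import torch
import torch.nn as nn

class SequenceEncoder(nn.Module):
    """Illustrative BiLSTM encoder with post-LSTM layer normalization."""

    def __init__(self, input_dim=64, hidden_dim=128):
        super().__init__()
        self.bilstm = nn.LSTM(
            input_dim, hidden_dim,
            batch_first=True,
            bidirectional=True,  # output dimension = 2 × 128 = 256
        )
        self.layer_norm = nn.LayerNorm(2 * hidden_dim, eps=1e-5)

    def forward(self, x):
        lstm_output, _ = self.bilstm(x)      # (B, T, 256)
        return self.layer_norm(lstm_output)  # (B, T, 256), normalized per timestep

encoder = SequenceEncoder()
out = encoder(torch.randn(4, 30, 64))
print(out.shape)  # torch.Size([4, 30, 256])
```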

Pre-LayerNorm vs Post-LayerNorm

Our architecture applies layer norm after the BiLSTM (post-norm). Pre-norm variants apply it before each sub-layer. Post-norm is the original formulation and works well for our depth.


Summary

In this section, we integrated layer normalization with the BiLSTM:

  1. Layer vs Batch norm: Layer norm is per-sample, batch-independent
  2. Formula: Normalize across features, then scale and shift
  3. RNN advantages: Handles variable lengths, small batches, sequential structure
  4. Placement: After BiLSTM output, before attention
  5. Parameters: 512 (γ and β for 256 dimensions)
| Property | Value |
| --- | --- |
| Normalized dimension | 256 (BiLSTM output) |
| Normalization type | Per-sample, per-timestep |
| Learnable parameters | 512 (γ: 256, β: 256) |
| Epsilon | 1e-5 |
| Placement | After BiLSTM, before attention |
Looking Ahead: We have designed all components of the BiLSTM encoder: two bidirectional layers with dropout and layer normalization. The next section brings everything together with the complete PyTorch implementation.

With all BiLSTM components designed, we now implement the complete encoder module.