Chapter 6

Layer Normalization Deep Dive

Feed Forward and Normalization

Introduction

Deep neural networks suffer from internal covariate shift: the distribution of layer inputs changes during training, making optimization difficult. Layer Normalization (LayerNorm) solves this by normalizing across features for each sample independently.

This section explains why LayerNorm is essential for transformers and implements it from scratch.


2.1 The Need for Normalization

The Problem: Internal Covariate Shift

As training progresses, each layer's input distribution changes because earlier layers are updating:

πŸ“text
1Epoch 1: Layer 3 sees inputs with mean=0.5, std=1.0
2Epoch 10: Layer 3 sees inputs with mean=2.0, std=3.0
3Epoch 100: Layer 3 sees inputs with mean=-1.0, std=0.5
4
5Layer 3 must constantly readjust to changing input statistics!

Why This Hurts Training

  • Slow convergence: Network constantly chasing moving target
  • Careful initialization: Must prevent gradients from exploding/vanishing
  • Small learning rates: Large updates cause instability
  • Training deep networks: Problem compounds with depth

The Solution: Normalization

Normalize activations to have consistent statistics:

πŸ“text
1Before: x = [varying mean, varying std]
2After:  x_norm = (x - ΞΌ) / Οƒ β†’ [mean β‰ˆ 0, std β‰ˆ 1]

2.2 Types of Normalization

Batch Normalization

Normalizes across the batch dimension:

πŸ“text
1Input: [batch, features]
2       ↓
3Compute mean and std across batch (for each feature)
4       ↓
5Normalize each feature to mean=0, std=1
🐍python
1# For input x of shape [batch, features]
2mean = x.mean(dim=0)      # [features]
3std = x.std(dim=0)        # [features]
4x_norm = (x - mean) / std  # [batch, features]

Problems for transformers:

  • Variable sequence lengths in batch
  • Batch statistics unreliable for small batches
  • Different behavior during training vs inference

Layer Normalization

Normalizes across the feature dimension:

πŸ“text
1Input: [batch, seq_len, features]
2       ↓
3Compute mean and std across features (for each position in each sample)
4       ↓
5Normalize each position independently
🐍python
1# For input x of shape [batch, seq_len, features]
2mean = x.mean(dim=-1, keepdim=True)   # [batch, seq_len, 1]
3std = x.std(dim=-1, keepdim=True)      # [batch, seq_len, 1]
4x_norm = (x - mean) / std              # [batch, seq_len, features]

Why better for transformers:

  • Works with variable sequence lengths
  • No batch size dependency
  • Same behavior train/inference
  • Each token normalized independently
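The batch-independence point can be checked directly (a minimal sketch, not from the original text): LayerNorm produces the same output for a sample whether it is normalized alone or alongside other samples, while BatchNorm in training mode does not.

```python
import torch
import torch.nn as nn

d_model = 8
ln = nn.LayerNorm(d_model)
bn = nn.BatchNorm1d(d_model)

sample = torch.randn(1, 4, d_model)          # one sequence of 4 tokens
others = torch.randn(3, 4, d_model) * 5 + 2  # batch-mates with different stats

# LayerNorm: the sample's output is identical alone or in a batch
alone = ln(sample)
batched = ln(torch.cat([sample, others]))[0:1]
print(torch.allclose(alone, batched))  # True

# BatchNorm (training mode): output depends on the batch-mates,
# because the normalization statistics are computed across the batch
bn_alone = bn(sample.transpose(1, 2)).transpose(1, 2)
bn_batched = bn(torch.cat([sample, others]).transpose(1, 2)).transpose(1, 2)[0:1]
print(torch.allclose(bn_alone, bn_batched))  # False
```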

Visual Comparison

πŸ“text
1Batch Norm (normalize vertically):
2β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
3β”‚ Sample 1: [f1, f2, f3]  β”‚  Compute mean/std of f1 across all samples
4β”‚ Sample 2: [f1, f2, f3]  β”‚  Compute mean/std of f2 across all samples
5β”‚ Sample 3: [f1, f2, f3]  β”‚  Compute mean/std of f3 across all samples
6β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
7
8Layer Norm (normalize horizontally):
9β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
10β”‚ Sample 1: [f1, f2, f3]  β”‚ β†’ Compute mean/std across [f1, f2, f3]
11β”‚ Sample 2: [f1, f2, f3]  β”‚ β†’ Compute mean/std across [f1, f2, f3]
12β”‚ Sample 3: [f1, f2, f3]  β”‚ β†’ Compute mean/std across [f1, f2, f3]
13β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

2.3 LayerNorm Mathematics

The Formula

πŸ“text
1LayerNorm(x) = Ξ³ Γ— (x - ΞΌ) / √(σ² + Ξ΅) + Ξ²

Where:

  • μ: Mean across features (last dimension)
  • σ²: Variance across features
  • ε: Small constant for numerical stability (e.g., 1e-6)
  • γ (gamma): Learnable scale parameter
  • β (beta): Learnable shift parameter

Step-by-Step Computation

πŸ“text
1Input x: [batch, seq_len, d_model]
2
31. Compute mean:
4   ΞΌ = mean(x, dim=-1)
5   Shape: [batch, seq_len, 1]
6
72. Compute variance:
8   σ² = var(x, dim=-1)
9   Shape: [batch, seq_len, 1]
10
113. Normalize:
12   x_norm = (x - ΞΌ) / √(σ² + Ξ΅)
13   Shape: [batch, seq_len, d_model]
14
154. Scale and shift:
16   output = Ξ³ Γ— x_norm + Ξ²
17   Shape: [batch, seq_len, d_model]
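The four steps can be verified numerically on a tiny tensor (values chosen arbitrarily for illustration):

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])  # [1, d_model] with d_model=4
eps = 1e-6

# Step 1: mean, Step 2: (biased) variance
mu = x.mean(dim=-1, keepdim=True)                  # 2.5
var = x.var(dim=-1, keepdim=True, unbiased=False)  # 1.25

# Step 3: normalize
x_norm = (x - mu) / torch.sqrt(var + eps)

# Step 4: scale and shift with the initial gamma=1, beta=0
gamma = torch.ones(4)
beta = torch.zeros(4)
out = gamma * x_norm + beta

print(mu.item(), var.item())  # 2.5 1.25
print(out)                    # ≈ [-1.3416, -0.4472, 0.4472, 1.3416]
```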

Why γ and β?

Pure normalization (mean=0, std=1) might be too restrictive:

πŸ“text
1Without Ξ³, Ξ²:
2- Output always centered at 0
3- Output always has std of 1
4- Reduces model expressiveness
5
6With Ξ³, Ξ²:
7- Model can learn to "undo" normalization if needed
8- Each feature can have its own scale and shift
9- Maintains representational power
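The "undo" point can be made concrete for a single token (an illustrative sketch; in a real model γ and β are learned per feature, and ε is omitted here for clarity): choosing γ equal to the token's std and β equal to its mean recovers the original input exactly.

```python
import torch

x = torch.randn(6) * 4 + 7  # one token with non-standard statistics
mu = x.mean()
sigma = x.std(unbiased=False)

x_norm = (x - mu) / sigma        # normalized: mean 0, std 1
recovered = sigma * x_norm + mu  # gamma=sigma, beta=mu undoes normalization

print(torch.allclose(recovered, x, atol=1e-5))  # True
```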

2.4 Implementation from Scratch

Basic LayerNorm

```python
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    """
    Layer Normalization implemented from scratch.

    Normalizes the input across the last dimension (features).

    Args:
        d_model: The dimension to normalize (feature dimension)
        eps: Small constant for numerical stability

    Example:
        >>> ln = LayerNorm(d_model=512)
        >>> x = torch.randn(2, 10, 512)
        >>> output = ln(x)  # Same shape, normalized
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()

        self.d_model = d_model
        self.eps = eps

        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(d_model))   # Scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # Shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply layer normalization.

        Args:
            x: Input tensor [..., d_model]

        Returns:
            Normalized tensor [..., d_model]
        """
        # Compute mean across last dimension
        mean = x.mean(dim=-1, keepdim=True)

        # Compute (biased) variance across last dimension
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        # Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        # Scale and shift
        return self.gamma * x_norm + self.beta

    def extra_repr(self) -> str:
        return f"{self.d_model}, eps={self.eps}"


# Test
def test_layer_norm():
    d_model = 512
    batch_size = 2
    seq_len = 10

    ln = LayerNorm(d_model)
    x = torch.randn(batch_size, seq_len, d_model) * 5 + 3  # Non-standard stats

    output = ln(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # Check statistics before/after
    print("\nBefore LayerNorm:")
    print(f"  Mean: {x.mean(dim=-1).mean():.4f}")
    print(f"  Std:  {x.std(dim=-1).mean():.4f}")

    print("\nAfter LayerNorm:")
    print(f"  Mean: {output.mean(dim=-1).mean():.4f} (should be ≈ 0)")
    print(f"  Std:  {output.std(dim=-1).mean():.4f} (should be ≈ 1)")

    # Verify against PyTorch's implementation. Both initialize gamma=1,
    # beta=0; only eps differs slightly (ours 1e-6 vs PyTorch's 1e-5)
    pytorch_ln = nn.LayerNorm(d_model)
    pytorch_output = pytorch_ln(x)
    assert torch.allclose(output, pytorch_output, atol=1e-4)

    print(f"\nParameter count: {sum(p.numel() for p in ln.parameters())}")

    print("\n✓ LayerNorm test passed!")


test_layer_norm()
```

Output:

πŸ“text
1Input shape: torch.Size([2, 10, 512])
2Output shape: torch.Size([2, 10, 512])
3
4Before LayerNorm:
5  Mean: 3.0012
6  Std:  4.9987
7
8After LayerNorm:
9  Mean: 0.0000 (should be β‰ˆ 0)
10  Std:  1.0010 (should be β‰ˆ 1)
11
12Parameter count: 1024
13
14βœ“ LayerNorm test passed!

2.5 RMSNorm: A Simplified Alternative

Root Mean Square Normalization

RMSNorm (used in LLaMA, T5) skips the mean centering:

πŸ“text
1RMSNorm(x) = Ξ³ Γ— x / RMS(x)
2
3Where RMS(x) = √(mean(x²))

Why RMSNorm?

  • Faster: No mean computation
  • Similar performance: Often matches LayerNorm
  • Simpler gradient: Easier to analyze mathematically

Implementation

```python
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Simplifies LayerNorm by removing mean centering.
    Used by LLaMA and other modern models.

    Args:
        d_model: Feature dimension
        eps: Numerical stability constant
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()

        self.d_model = d_model
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization.

        Args:
            x: Input tensor [..., d_model]

        Returns:
            Normalized tensor [..., d_model]
        """
        # Compute RMS (eps added inside the root for stability)
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)

        # Normalize and scale
        return self.weight * (x / rms)


# Test RMSNorm
def test_rms_norm():
    d_model = 512
    rms_norm = RMSNorm(d_model)

    x = torch.randn(2, 10, d_model)
    output = rms_norm(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # RMSNorm doesn't center at 0, but normalizes magnitude
    print(f"\nInput RMS: {torch.sqrt((x**2).mean(dim=-1)).mean():.4f}")
    print(f"Output RMS: {torch.sqrt((output**2).mean(dim=-1)).mean():.4f}")

    print("\n✓ RMSNorm test passed!")


test_rms_norm()
```

2.6 Comparison: LayerNorm vs BatchNorm

Detailed Comparison

```python
def compare_normalizations():
    """Compare LayerNorm and BatchNorm behavior."""

    batch_size = 4
    seq_len = 10
    d_model = 8

    # Create input
    x = torch.randn(batch_size, seq_len, d_model) * 3 + 2

    # Layer Norm
    layer_norm = nn.LayerNorm(d_model)
    ln_output = layer_norm(x)

    # Batch Norm (requires reshaping for sequences)
    batch_norm = nn.BatchNorm1d(d_model)
    # BatchNorm1d expects [batch, features, seq_len]
    x_transposed = x.transpose(1, 2)
    bn_output = batch_norm(x_transposed).transpose(1, 2)

    print("Comparison of Normalizations")
    print("=" * 50)

    print("\nLayerNorm statistics (per position):")
    for b in range(min(2, batch_size)):
        for s in range(min(3, seq_len)):
            mean = ln_output[b, s].mean().item()
            std = ln_output[b, s].std().item()
            print(f"  Batch {b}, Pos {s}: mean={mean:.4f}, std={std:.4f}")

    print("\nBatchNorm statistics (per feature across batch):")
    for f in range(min(4, d_model)):
        mean = bn_output[:, :, f].mean().item()
        std = bn_output[:, :, f].std().item()
        print(f"  Feature {f}: mean={mean:.4f}, std={std:.4f}")


compare_normalizations()
```

When to Use Each

```text
Scenario                     Best Choice
Transformers                 LayerNorm
CNNs                         BatchNorm
RNNs                         LayerNorm
Small batches                LayerNorm
Variable sequence lengths    LayerNorm
Image data                   BatchNorm
```

2.7 LayerNorm Placement

Where Does LayerNorm Go?

In a transformer layer, there are two common placements:

Post-LN (Original Transformer):

πŸ“text
1output = LayerNorm(x + Sublayer(x))

Pre-LN (Modern Preference):

πŸ“text
1output = x + Sublayer(LayerNorm(x))

Visual Comparison

πŸ“text
1Post-LN:                        Pre-LN:
2β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
3β”‚    x     β”‚                    β”‚    x     β”‚
4β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜                    β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜
5     β”‚                               β”‚
6     β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                  β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
7     β”‚            β”‚                  β”‚            β”‚
8     β–Ό            β”‚                  β–Ό            β”‚
9β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”      β”‚             β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”     β”‚
10β”‚ Sublayer β”‚      β”‚             β”‚LayerNorm β”‚     β”‚
11β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜      β”‚             β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜     β”‚
12     β”‚            β”‚                  β”‚            β”‚
13     β–Ό            β”‚                  β–Ό            β”‚
14    Add β—„β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜             β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”     β”‚
15     β”‚                          β”‚ Sublayer β”‚     β”‚
16     β–Ό                          β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜     β”‚
17β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”                         β”‚            β”‚
18β”‚LayerNorm β”‚                         β–Ό            β”‚
19β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜                        Add β—„β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
20     β”‚                               β”‚
21     β–Ό                               β–Ό
22  Output                          Output

Why Pre-LN is Often Better

  • Gradient flow: Gradients flow more directly through residuals
  • Training stability: Less sensitive to hyperparameters
  • Deeper networks: Enables training very deep transformers
  • No warmup: Often doesn't need learning rate warmup
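The two placements can be sketched as small wrapper modules (a hedged sketch; `PostLNBlock`, `PreLNBlock`, and the feed-forward sublayer are illustrative names, not from the original text):

```python
import torch
import torch.nn as nn


class PostLNBlock(nn.Module):
    """Post-LN: normalize after the residual add (original Transformer)."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.sublayer(x))


class PreLNBlock(nn.Module):
    """Pre-LN: normalize before the sublayer; the residual path stays identity."""
    def __init__(self, d_model: int, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return x + self.sublayer(self.norm(x))


# Example with a feed-forward sublayer (illustrative dimensions)
d_model = 64
ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(),
                    nn.Linear(4 * d_model, d_model))
x = torch.randn(2, 10, d_model)
print(PostLNBlock(d_model, ffn)(x).shape)  # torch.Size([2, 10, 64])
print(PreLNBlock(d_model, ffn)(x).shape)   # torch.Size([2, 10, 64])
```

In Pre-LN, gradients reach early layers through the untouched `x + ...` residual path, which is one way to see the stability benefit listed above.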

We'll explore this more in Section 4.


2.8 Numerical Stability

The Epsilon Parameter

Without epsilon, we risk division by zero:

```python
# Dangerous: if variance is 0, we get inf/nan
x_norm = (x - mean) / torch.sqrt(var)

# Safe: eps prevents division by zero
x_norm = (x - mean) / torch.sqrt(var + eps)
```

When Variance is Near Zero

```python
def demonstrate_eps_importance():
    """Show why eps matters."""

    # Create input where row 1 has zero variance (all values identical)
    x = torch.full((2, 4), 1.0)
    x[0, 0] = 1.001  # Tiny variation in row 0 only

    print(f"Input:\n{x}")
    print(f"\nVariance: {x.var(dim=-1)}")

    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)

    # Without eps: row 1 has var=0, so we divide 0 by 0.
    # PyTorch doesn't raise an exception; it silently produces nan
    print("\nWithout eps (var can be 0):")
    x_norm_unsafe = (x - mean) / torch.sqrt(var)
    print(f"  Result: {x_norm_unsafe[1]}")  # nan values

    # With eps (safe)
    eps = 1e-6
    x_norm_safe = (x - mean) / torch.sqrt(var + eps)
    print(f"\nWith eps={eps}:")
    print(f"  Result: {x_norm_safe[1]}")  # All zeros: stable


demonstrate_eps_importance()
```

Summary

LayerNorm Formula

πŸ“text
1LayerNorm(x) = Ξ³ Γ— (x - ΞΌ) / √(σ² + Ξ΅) + Ξ²

Key Properties

```text
Property                 Value
Normalization axis       Last dimension (features)
Learnable parameters     γ (scale), β (shift)
Typical eps              1e-6 or 1e-5
Batch dependency         None (works with batch_size=1)
```

Implementation Checklist

  • Compute mean across features (dim=-1)
  • Compute variance across features
  • Add epsilon before square root
  • Initialize gamma to 1, beta to 0
  • Apply gamma * normalized + beta

Exercises

Implementation Exercises

1. Implement LayerNorm that normalizes over multiple dimensions (e.g., last 2 dims).

2. Create a "conditional" LayerNorm where gamma/beta depend on some condition.

3. Implement GroupNorm (normalize within groups of channels).

Analysis Exercises

4. Compare training curves with LayerNorm vs BatchNorm on a transformer.

5. Visualize how the learned gamma/beta parameters evolve during training.

6. Test the effect of different epsilon values on numerical stability.

Conceptual Questions

7. Why does LayerNorm have learnable parameters? What would happen without them?

8. Could we use Instance Normalization for transformers? Why or why not?


Next Section Preview

In the next section, we'll explore Residual Connectionsβ€”the skip connections that allow gradients to flow through deep networks. Combined with LayerNorm, they form the "Add & Norm" pattern essential to transformer architecture.