Introduction
Deep neural networks suffer from internal covariate shift: the distribution of layer inputs changes during training, making optimization difficult. Layer Normalization (LayerNorm) addresses this by normalizing across features for each sample independently.
This section explains why LayerNorm is essential for transformers and implements it from scratch.
2.1 The Need for Normalization
The Problem: Internal Covariate Shift
As training progresses, each layer's input distribution changes because earlier layers are updating:
```
Epoch 1:   Layer 3 sees inputs with mean=0.5, std=1.0
Epoch 10:  Layer 3 sees inputs with mean=2.0, std=3.0
Epoch 100: Layer 3 sees inputs with mean=-1.0, std=0.5

Layer 3 must constantly readjust to changing input statistics!
```
Why This Hurts Training
- Slow convergence: Network constantly chasing moving target
- Careful initialization: Must prevent gradients from exploding/vanishing
- Small learning rates: Large updates cause instability
- Training deep networks: Problem compounds with depth
The Solution: Normalization
Normalize activations to have consistent statistics:
```
Before: x = [varying mean, varying std]
After:  x_norm = (x - μ) / σ  →  [mean ≈ 0, std ≈ 1]
```
2.2 Types of Normalization
Batch Normalization
Normalizes across the batch dimension:
```
Input: [batch, features]
        ↓
Compute mean and std across batch (for each feature)
        ↓
Normalize each feature to mean=0, std=1
```

```python
# For input x of shape [batch, features]
mean = x.mean(dim=0)       # [features]
std = x.std(dim=0)         # [features]
x_norm = (x - mean) / std  # [batch, features]
```
Problems for transformers:
- Variable sequence lengths in batch
- Batch statistics unreliable for small batches
- Different behavior during training vs inference
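The train/inference discrepancy is easy to demonstrate. The sketch below (illustrative, not from the original text) runs the same input through `BatchNorm1d` in both modes; in contrast, LayerNorm has no running statistics and behaves identically either way:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 3 + 2  # non-standard statistics

bn.train()
train_out = bn(x)  # normalizes with *this batch's* mean/std

bn.eval()
eval_out = bn(x)   # normalizes with the *running* mean/std tracked so far

# Same input, different output depending on the mode:
print(torch.allclose(train_out, eval_out))  # False

# LayerNorm has no batch statistics, so train/eval are identical:
ln = nn.LayerNorm(4)
ln.train(); a = ln(x)
ln.eval();  b = ln(x)
print(torch.allclose(a, b))  # True
```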
Layer Normalization
Normalizes across the feature dimension:
```
Input: [batch, seq_len, features]
        ↓
Compute mean and std across features (for each position in each sample)
        ↓
Normalize each position independently
```

```python
# For input x of shape [batch, seq_len, features]
mean = x.mean(dim=-1, keepdim=True)  # [batch, seq_len, 1]
std = x.std(dim=-1, keepdim=True)    # [batch, seq_len, 1]
x_norm = (x - mean) / std            # [batch, seq_len, features]
```
Why better for transformers:
- Works with variable sequence lengths
- No batch size dependency
- Same behavior train/inference
- Each token normalized independently
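The batch independence can be checked directly: normalizing one sample alone gives exactly the same result as normalizing it inside a larger batch (a quick sanity check, not from the original text):

```python
import torch
import torch.nn as nn

ln = nn.LayerNorm(16)
x = torch.randn(4, 10, 16)

full = ln(x)         # normalize every token in the batch
single = ln(x[0:1])  # normalize the first sample on its own

# Each position is normalized independently, so batch size has no effect
assert torch.allclose(full[0:1], single)
print("batch-independent: OK")
```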
Visual Comparison
```
Batch Norm (normalize vertically):
┌─────────────────────────┐
│ Sample 1: [f1, f2, f3]  │   Compute mean/std of f1 across all samples
│ Sample 2: [f1, f2, f3]  │   Compute mean/std of f2 across all samples
│ Sample 3: [f1, f2, f3]  │   Compute mean/std of f3 across all samples
└─────────────────────────┘

Layer Norm (normalize horizontally):
┌─────────────────────────┐
│ Sample 1: [f1, f2, f3]  │ ← Compute mean/std across [f1, f2, f3]
│ Sample 2: [f1, f2, f3]  │ ← Compute mean/std across [f1, f2, f3]
│ Sample 3: [f1, f2, f3]  │ ← Compute mean/std across [f1, f2, f3]
└─────────────────────────┘
```
2.3 LayerNorm Mathematics
The Formula
```
LayerNorm(x) = γ × (x - μ) / √(σ² + ε) + β
```
Where:
- μ: Mean across features (last dimension)
- σ²: Variance across features
- ε: Small constant for numerical stability (e.g., 1e-6)
- γ (gamma): Learnable scale parameter
- β (beta): Learnable shift parameter
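To make the formula concrete, here is the computation on a small 4-dimensional vector with γ = 1, β = 0, and ε omitted for readability (an illustrative walkthrough, not part of the original text):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])

mu = x.mean()                 # (1 + 2 + 3 + 4) / 4 = 2.5
var = ((x - mu) ** 2).mean()  # (2.25 + 0.25 + 0.25 + 2.25) / 4 = 1.25

x_norm = (x - mu) / torch.sqrt(var)  # divide by √1.25 ≈ 1.1180
print(x_norm)  # ≈ [-1.3416, -0.4472, 0.4472, 1.3416]
```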
Step-by-Step Computation
```
Input x: [batch, seq_len, d_model]

1. Compute mean:
   μ = mean(x, dim=-1)
   Shape: [batch, seq_len, 1]

2. Compute variance:
   σ² = var(x, dim=-1)
   Shape: [batch, seq_len, 1]

3. Normalize:
   x_norm = (x - μ) / √(σ² + ε)
   Shape: [batch, seq_len, d_model]

4. Scale and shift:
   output = γ × x_norm + β
   Shape: [batch, seq_len, d_model]
```
Why γ and β?
Pure normalization (mean=0, std=1) might be too restrictive:
```
Without γ, β:
- Output always centered at 0
- Output always has std of 1
- Reduces model expressiveness

With γ, β:
- Model can learn to "undo" normalization if needed
- Each feature can have its own scale and shift
- Maintains representational power
```
2.4 Implementation from Scratch
Basic LayerNorm
```python
import torch
import torch.nn as nn


class LayerNorm(nn.Module):
    """
    Layer Normalization implemented from scratch.

    Normalizes the input across the last dimension (features).

    Args:
        d_model: The dimension to normalize (feature dimension)
        eps: Small constant for numerical stability

    Example:
        >>> ln = LayerNorm(d_model=512)
        >>> x = torch.randn(2, 10, 512)
        >>> output = ln(x)  # Same shape, normalized
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()

        self.d_model = d_model
        self.eps = eps

        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(d_model))   # Scale
        self.beta = nn.Parameter(torch.zeros(d_model))   # Shift

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply layer normalization.

        Args:
            x: Input tensor [..., d_model]

        Returns:
            Normalized tensor [..., d_model]
        """
        # Compute mean across last dimension
        mean = x.mean(dim=-1, keepdim=True)

        # Compute variance across last dimension
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        # Normalize
        x_norm = (x - mean) / torch.sqrt(var + self.eps)

        # Scale and shift
        output = self.gamma * x_norm + self.beta

        return output

    def extra_repr(self) -> str:
        return f"{self.d_model}, eps={self.eps}"


# Test
def test_layer_norm():
    d_model = 512
    batch_size = 2
    seq_len = 10

    ln = LayerNorm(d_model)
    x = torch.randn(batch_size, seq_len, d_model) * 5 + 3  # Non-standard stats

    output = ln(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # Check statistics before/after
    print(f"\nBefore LayerNorm:")
    print(f"  Mean: {x.mean(dim=-1).mean():.4f}")
    print(f"  Std: {x.std(dim=-1).mean():.4f}")

    print(f"\nAfter LayerNorm:")
    print(f"  Mean: {output.mean(dim=-1).mean():.4f} (should be ≈ 0)")
    print(f"  Std: {output.std(dim=-1).mean():.4f} (should be ≈ 1)")

    # Verify against PyTorch's implementation. Both initialize gamma to
    # ones and beta to zeros, so with matching eps the outputs agree.
    pytorch_ln = nn.LayerNorm(d_model, eps=1e-6)
    pytorch_output = pytorch_ln(x)
    assert torch.allclose(output, pytorch_output, atol=1e-5)

    print(f"\nParameter count: {sum(p.numel() for p in ln.parameters())}")

    print("\n✓ LayerNorm test passed!")


test_layer_norm()
```
Output:
```
Input shape: torch.Size([2, 10, 512])
Output shape: torch.Size([2, 10, 512])

Before LayerNorm:
  Mean: 3.0012
  Std: 4.9987

After LayerNorm:
  Mean: 0.0000 (should be ≈ 0)
  Std: 1.0010 (should be ≈ 1)

Parameter count: 1024

✓ LayerNorm test passed!
```
2.5 RMSNorm: A Simplified Alternative
Root Mean Square Normalization
RMSNorm (used in LLaMA, T5) skips the mean centering:
```
RMSNorm(x) = γ × x / RMS(x)

where RMS(x) = √(mean(x²))
```
Why RMSNorm?
- Faster: No mean computation
- Similar performance: Often matches LayerNorm
- Simpler gradient: Easier to analyze mathematically
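One way to see the relationship: on inputs that already have zero mean per token, var(x) = mean(x²), so RMSNorm and the LayerNorm core (γ = 1, β = 0, same ε placement inside the square root) produce identical outputs. A quick check, assuming the RMSNorm definition above:

```python
import torch

eps = 1e-6
x = torch.randn(2, 5, 8)
x = x - x.mean(dim=-1, keepdim=True)  # force zero mean per token

# LayerNorm core (gamma = 1, beta = 0)
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
ln_out = (x - mean) / torch.sqrt(var + eps)

# RMSNorm core (weight = 1), eps inside the square root as above
rms_out = x / torch.sqrt((x ** 2).mean(dim=-1, keepdim=True) + eps)

assert torch.allclose(ln_out, rms_out, atol=1e-5)
print("identical on zero-mean inputs")
```

On general (non-centered) inputs the two differ, since RMSNorm skips the subtraction of μ entirely.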
Implementation
```python
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Simplifies LayerNorm by removing mean centering.
    Used by LLaMA and other modern models.

    Args:
        d_model: Feature dimension
        eps: Numerical stability constant
    """

    def __init__(self, d_model: int, eps: float = 1e-6):
        super().__init__()

        self.d_model = d_model
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization.

        Args:
            x: Input tensor [..., d_model]

        Returns:
            Normalized tensor [..., d_model]
        """
        # Compute RMS
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)

        # Normalize and scale
        return self.weight * (x / rms)


# Test RMSNorm
def test_rms_norm():
    d_model = 512
    rms_norm = RMSNorm(d_model)

    x = torch.randn(2, 10, d_model)
    output = rms_norm(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")

    # RMSNorm doesn't center at 0, but normalizes magnitude
    print(f"\nInput RMS: {torch.sqrt((x**2).mean(dim=-1)).mean():.4f}")
    print(f"Output RMS: {torch.sqrt((output**2).mean(dim=-1)).mean():.4f}")

    print("\n✓ RMSNorm test passed!")


test_rms_norm()
```
2.6 Comparison: LayerNorm vs BatchNorm
Detailed Comparison
```python
def compare_normalizations():
    """Compare LayerNorm and BatchNorm behavior."""

    batch_size = 4
    seq_len = 10
    d_model = 8

    # Create input
    x = torch.randn(batch_size, seq_len, d_model) * 3 + 2

    # Layer Norm
    layer_norm = nn.LayerNorm(d_model)
    ln_output = layer_norm(x)

    # Batch Norm (requires reshaping for sequences)
    batch_norm = nn.BatchNorm1d(d_model)
    # BatchNorm1d expects [batch, features, seq_len]
    x_transposed = x.transpose(1, 2)
    bn_output = batch_norm(x_transposed).transpose(1, 2)

    print("Comparison of Normalizations")
    print("=" * 50)

    print("\nLayerNorm statistics (per position):")
    for b in range(min(2, batch_size)):
        for s in range(min(3, seq_len)):
            mean = ln_output[b, s].mean().item()
            std = ln_output[b, s].std().item()
            print(f"  Batch {b}, Pos {s}: mean={mean:.4f}, std={std:.4f}")

    print("\nBatchNorm statistics (per feature across batch):")
    for f in range(min(4, d_model)):
        mean = bn_output[:, :, f].mean().item()
        std = bn_output[:, :, f].std().item()
        print(f"  Feature {f}: mean={mean:.4f}, std={std:.4f}")


compare_normalizations()
```
When to Use Each
| Scenario | Best Choice |
|---|---|
| Transformers | LayerNorm |
| CNNs | BatchNorm |
| RNNs | LayerNorm |
| Small batches | LayerNorm |
| Variable sequence lengths | LayerNorm |
| Image data | BatchNorm |
2.7 LayerNorm Placement
Where Does LayerNorm Go?
In a transformer layer, there are two common placements:
Post-LN (Original Transformer):

```
output = LayerNorm(x + Sublayer(x))
```

Pre-LN (Modern Preference):

```
output = x + Sublayer(LayerNorm(x))
```

Visual Comparison
```
Post-LN:                      Pre-LN:
┌──────────┐                  ┌──────────┐
│    x     │                  │    x     │
└────┬─────┘                  └────┬─────┘
     │                             │
     ├────────────┐                ├────────────┐
     │            │                │            │
     ▼            │                ▼            │
┌──────────┐      │           ┌──────────┐      │
│ Sublayer │      │           │LayerNorm │      │
└────┬─────┘      │           └────┬─────┘      │
     │            │                │            │
     ▼            │                ▼            │
    Add ◄─────────┘           ┌──────────┐      │
     │                        │ Sublayer │      │
     ▼                        └────┬─────┘      │
┌──────────┐                       │            │
│LayerNorm │                       ▼            │
└────┬─────┘                  Add ◄─────────────┘
     │                             │
     ▼                             ▼
  Output                        Output
```
Why Pre-LN is Often Better
- Gradient flow: Gradients flow more directly through residuals
- Training stability: Less sensitive to hyperparameters
- Deeper networks: Enables training very deep transformers
- No warmup: Often doesn't need learning rate warmup
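The two placements can be sketched as residual blocks. This is an illustrative skeleton only: the sublayer here is a plain linear layer standing in for attention or a feed-forward network:

```python
import torch
import torch.nn as nn


class PostLNBlock(nn.Module):
    """output = LayerNorm(x + Sublayer(x))"""
    def __init__(self, d_model: int):
        super().__init__()
        self.sublayer = nn.Linear(d_model, d_model)  # stand-in for attention/FFN
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # Normalize *after* the residual addition
        return self.norm(x + self.sublayer(x))


class PreLNBlock(nn.Module):
    """output = x + Sublayer(LayerNorm(x))"""
    def __init__(self, d_model: int):
        super().__init__()
        self.sublayer = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        # Normalize *before* the sublayer; the residual path stays an
        # identity, which helps gradients flow through deep stacks
        return x + self.sublayer(self.norm(x))


x = torch.randn(2, 10, 64)
print(PostLNBlock(64)(x).shape)  # torch.Size([2, 10, 64])
print(PreLNBlock(64)(x).shape)   # torch.Size([2, 10, 64])
```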
We'll explore this more in Section 4.
2.8 Numerical Stability
The Epsilon Parameter
Without epsilon, we risk division by zero:
```python
# Dangerous: if the variance is 0, we get inf/nan
x_norm = (x - mean) / torch.sqrt(var)

# Safe: eps prevents division by zero
x_norm = (x - mean) / torch.sqrt(var + eps)
```
When Variance is Near Zero
```python
def demonstrate_eps_importance():
    """Show why eps matters."""

    # Create input with nearly identical values
    x = torch.full((2, 4), 1.0)
    x[0, 0] = 1.001  # Tiny variation

    print(f"Input:\n{x}")
    print(f"\nVariance: {x.var(dim=-1, unbiased=False)}")

    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)

    # Without eps (dangerous): row 1 has variance exactly 0,
    # so we divide 0 by 0 -- no exception is raised, just silent nan
    print(f"\nWithout eps (var can be 0):")
    x_norm_unsafe = (x - mean) / torch.sqrt(var)
    print(f"  Result: {x_norm_unsafe[1]}")  # nan values

    # With eps (safe)
    eps = 1e-6
    x_norm_safe = (x - mean) / torch.sqrt(var + eps)
    print(f"\nWith eps={eps}:")
    print(f"  Result: {x_norm_safe[1]}")  # Stable: all zeros


demonstrate_eps_importance()
```
Summary
LayerNorm Formula
```
LayerNorm(x) = γ × (x - μ) / √(σ² + ε) + β
```
Key Properties
| Property | Value |
|---|---|
| Normalization axis | Last dimension (features) |
| Learnable parameters | γ (scale), β (shift) |
| Typical eps | 1e-6 or 1e-5 |
| Batch dependency | None (works with batch_size=1) |
Implementation Checklist
- Compute mean across features (dim=-1)
- Compute variance across features
- Add epsilon before square root
- Initialize gamma to 1, beta to 0
- Apply gamma * normalized + beta
Exercises
Implementation Exercises
1. Implement LayerNorm that normalizes over multiple dimensions (e.g., last 2 dims).
2. Create a "conditional" LayerNorm where gamma/beta depend on some condition.
3. Implement GroupNorm (normalize within groups of channels).
Analysis Exercises
4. Compare training curves with LayerNorm vs BatchNorm on a transformer.
5. Visualize how the learned gamma/beta parameters evolve during training.
6. Test the effect of different epsilon values on numerical stability.
Conceptual Questions
7. Why does LayerNorm have learnable parameters? What would happen without them?
8. Could we use Instance Normalization for transformers? Why or why not?
Next Section Preview
In the next section, we'll explore Residual Connections: the skip connections that allow gradients to flow through deep networks. Combined with LayerNorm, they form the "Add & Norm" pattern essential to transformer architecture.