Introduction
Residual connections (also called skip connections) are one of the most important innovations in deep learning. They make very deep networks trainable by providing a direct path for gradients to flow backward, which mitigates the vanishing gradient problem.
This section explains why residuals are essential and how they work in transformers.
3.1 The Problem: Vanishing Gradients
Training Deep Networks
Consider a deep network with many layers:
```
Input → Layer1 → Layer2 → ... → Layer100 → Output
```

During backpropagation, gradients must flow from output to input:

```
∂Loss/∂x = ∂Loss/∂y × ∂y/∂x
```

For a chain of layers:

```
∂Loss/∂Input = ∂L/∂L100 × ∂L100/∂L99 × ... × ∂L2/∂L1 × ∂L1/∂Input
             = a product of 100+ terms!
```

The Vanishing Problem
If each gradient term is < 1:

```
0.9 × 0.9 × 0.9 × ... (100 times) = 0.9^100 ≈ 0.00003
```

Gradients become vanishingly small → early layers don't learn!
The Exploding Problem
If each gradient term is > 1:
```
1.1 × 1.1 × 1.1 × ... (100 times) = 1.1^100 ≈ 13,781
```

Gradients become explosively large → training becomes unstable!
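Both failure modes are just repeated multiplication, so they are easy to verify numerically; a quick sketch in plain Python:

```python
# Repeated per-layer gradient factors over 100 layers.
shrink = 0.9 ** 100  # each layer scales the gradient down slightly
grow = 1.1 ** 100    # each layer scales the gradient up slightly

print(f"0.9^100 = {shrink:.2e}")  # 2.66e-05 -- vanished
print(f"1.1^100 = {grow:.0f}")    # 13781    -- exploded
```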
3.2 The Solution: Residual Connections
The Key Idea
Instead of learning a direct mapping F(x), learn a residual R(x):
```
Standard: output = F(x)
Residual: output = x + R(x)
```

where R(x) = F(x) − x is the "residual" the layer has to learn in order to represent the same mapping F(x).
Visual Representation
```
Standard Connection:        Residual Connection:

  ┌─────────┐                 ┌─────────┐
  │    x    │                 │    x    │
  └────┬────┘                 └────┬────┘
       │                           │
       │                    ┌──────┴──────┐
       │                    │             │
       ▼                    ▼             │
  ┌─────────┐          ┌─────────┐        │
  │  F(x)   │          │  R(x)   │        │
  └────┬────┘          └────┬────┘        │
       │                    │             │
       │                    ▼             │
       │               ┌─────────┐        │
       │               │    +    │◄───────┘
       │               └────┬────┘
       ▼                    ▼
    output                output
     F(x)                x + R(x)
```

Why Residuals Help
Gradient flow:

```
∂(x + R(x))/∂x = 1 + ∂R(x)/∂x
```

Even if ∂R(x)/∂x is tiny, the gradient always contains the identity term 1, so it cannot vanish!
Identity mapping as default:
- If R(x) = 0, then output = x (identity)
- Network can easily learn "do nothing" if needed
- Deeper layers can be added without hurting performance
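This "easy identity" default can be shown concretely: if the residual branch's final layer is initialized to zero (a common trick; the two-layer branch below is just an illustrative stand-in), the block computes the identity exactly at initialization:

```python
import torch
import torch.nn as nn

# A residual branch R(x) whose final linear layer is zero-initialized,
# so R(x) = 0 and the block output is x + R(x) = x.
branch = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
nn.init.zeros_(branch[2].weight)
nn.init.zeros_(branch[2].bias)

x = torch.randn(4, 8)
out = x + branch(x)
print(torch.allclose(out, x))  # True -- the block starts as an identity
```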
3.3 Mathematical Analysis
Gradient Flow Comparison
Without residual:
```
y = F(x)
∂y/∂x = ∂F/∂x        (can vanish!)
```

With residual:

```
y = x + F(x)
∂y/∂x = 1 + ∂F/∂x    (the identity term never vanishes!)
```

For Deep Networks
Consider L layers with residual connections:
```
x_L = x_{L-1} + F_{L-1}(x_{L-1})
    = x_{L-2} + F_{L-2}(x_{L-2}) + F_{L-1}(x_{L-1})
    = ...
    = x_0 + Σᵢ Fᵢ(xᵢ)
```

The gradient becomes:

```
∂Loss/∂x_0 = ∂Loss/∂x_L × ∏ᵢ (1 + ∂Fᵢ/∂xᵢ)
           = ∂Loss/∂x_L × (1 + Σᵢ ∂Fᵢ/∂xᵢ + higher-order terms)
```

The "1" provides a direct path from the loss to the early layers!
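Autograd confirms the identity term in practice. A minimal check, using F(x) = tanh(x) as a stand-in sublayer (our choice for illustration):

```python
import torch

x = torch.tensor(0.5, requires_grad=True)
y = x + torch.tanh(x)  # residual form: y = x + F(x)
y.backward()

# Analytic gradient: 1 + dF/dx = 1 + (1 - tanh(x)^2)
expected = 1.0 + (1.0 - torch.tanh(torch.tensor(0.5)) ** 2)
print(x.grad.item())  # matches 1 + (1 - tanh(0.5)^2) ~ 1.786
```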
Highway Networks Perspective
Residual connections are a simplified case of Highway Networks:
```
Highway:  y = T(x) × H(x) + (1 - T(x)) × x
          where T(x) is a learned "transform gate"

Residual: y = x + F(x)
          (no gate at all: the input and the transform both pass
           through at full strength -- much simpler!)
```

3.4 Residuals in Transformers
Where Residuals Appear
In each transformer layer, there are two residual connections:
```
┌─────────────────────────────────────────┐
│                 Input x                 │
└────────────────────┬────────────────────┘
                     │
          ┌──────────┴───────────┐
          │                      │
          ▼                      │
┌─────────────────────────┐      │
│  Multi-Head Attention   │      │
└────────────┬────────────┘      │
             │                   │
             ▼                   │
            (+) ◄────────────────┘   Residual #1
             │
             ▼
┌─────────────────────────┐
│        LayerNorm        │
└────────────┬────────────┘
             │
          ┌──┴───────────────────┐
          │                      │
          ▼                      │
┌─────────────────────────┐      │
│  Feed-Forward Network   │      │
└────────────┬────────────┘      │
             │                   │
             ▼                   │
            (+) ◄────────────────┘   Residual #2
             │
             ▼
┌─────────────────────────┐
│        LayerNorm        │
└────────────┬────────────┘
             │
             ▼
           Output
```

The Pattern
```python
# Residual around attention
x = x + MultiHeadAttention(x)

# Residual around FFN
x = x + FeedForward(x)
```

3.5 Why Residuals Enable Deep Transformers
Depth and Performance
Modern transformers can be very deep:
- BERT: 12 or 24 layers
- GPT-3: 96 layers
- PaLM: 118 layers
Without residuals, training such depth would be impossible.
Information Preservation
Residuals ensure information isn't lost:
```python
import torch


def demonstrate_info_preservation():
    """Show how residuals preserve information through layers."""
    # Create input
    x = torch.randn(1, 10, 512)

    # Simulate 100 layers WITHOUT residuals
    x_no_res = x.clone()
    for _ in range(100):
        # Simulated layer (shrinks the signal by 10%)
        x_no_res = 0.9 * x_no_res

    print("After 100 layers without residuals:")
    print(f"  Input norm: {x.norm():.4f}")
    print(f"  Output norm: {x_no_res.norm():.10f}")  # shrunk by 0.9^100 ~ 3e-5

    # Simulate 100 layers WITH residuals
    x_with_res = x.clone()
    for _ in range(100):
        # Simulated layer with residual
        x_with_res = x_with_res + 0.01 * x_with_res  # small update per layer

    print("\nAfter 100 layers with residuals:")
    print(f"  Input norm: {x.norm():.4f}")
    print(f"  Output norm: {x_with_res.norm():.4f}")  # grown by 1.01^100 ~ 2.7


demonstrate_info_preservation()
```

Output (representative run; the exact norms vary with the random input, whose norm is about √5120 ≈ 71.6):

```
After 100 layers without residuals:
  Input norm: 71.5542
  Output norm: 0.0019005797

After 100 layers with residuals:
  Input norm: 71.5542
  Output norm: 193.5408
```

3.6 Implementation Details
Basic Residual Block
```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """
    A generic residual block.

    Computes: output = x + sublayer(x)

    Args:
        sublayer: The transformation to apply
    """

    def __init__(self, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(x)


# Example usage
class SimpleFFN(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        return self.net(x)


# Create residual FFN
d_model = 512
residual_ffn = ResidualBlock(SimpleFFN(d_model))

x = torch.randn(2, 10, d_model)
output = residual_ffn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Same shape: {x.shape == output.shape}")
```

With Dropout
Dropout is typically applied before the residual addition:
```python
class ResidualWithDropout(nn.Module):
    """
    Residual connection with dropout on the sublayer output.

    Computes: output = x + dropout(sublayer(x))
    """

    def __init__(self, sublayer: nn.Module, dropout: float = 0.1):
        super().__init__()
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.dropout(self.sublayer(x))
```

3.7 Residual Connection Requirements
Shape Compatibility
The residual requires input and output to have the same shape:
```python
# This works:
x = torch.randn(2, 10, 512)
sublayer_output = torch.randn(2, 10, 512)  # Same shape!
output = x + sublayer_output  # OK

# This fails:
x = torch.randn(2, 10, 512)
sublayer_output = torch.randn(2, 10, 256)  # Different shape!
output = x + sublayer_output  # RuntimeError!
```

Handling Shape Mismatches
When shapes don't match (e.g., downsampling), use a projection:
```python
class ProjectedResidual(nn.Module):
    """
    Residual with projection for shape mismatches.

    Used when the sublayer changes dimensions.
    """

    def __init__(
        self,
        sublayer: nn.Module,
        in_features: int,
        out_features: int
    ):
        super().__init__()
        self.sublayer = sublayer

        # Projection for shape matching
        if in_features != out_features:
            self.projection = nn.Linear(in_features, out_features)
        else:
            self.projection = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.projection(x) + self.sublayer(x)
```

In transformers, sublayers are designed to preserve shape, so the projection isn't needed.
3.8 Visualization: Gradient Flow
Comparing Gradient Magnitude
```python
import torch
import torch.nn as nn


def analyze_gradient_flow():
    """Compare gradient flow with and without residuals."""

    class DeepNetwork(nn.Module):
        def __init__(self, depth, use_residual):
            super().__init__()
            self.use_residual = use_residual
            self.layers = nn.ModuleList([
                nn.Linear(64, 64) for _ in range(depth)
            ])

        def forward(self, x):
            for layer in self.layers:
                if self.use_residual:
                    x = x + 0.1 * torch.tanh(layer(x))  # residual
                else:
                    x = torch.tanh(layer(x))  # no residual
            return x

    depth = 50
    x = torch.randn(1, 64, requires_grad=True)

    # Without residuals
    net_no_res = DeepNetwork(depth, use_residual=False)
    loss_no_res = net_no_res(x).sum()
    loss_no_res.backward()
    grad_no_res = x.grad.norm().item()

    # Reset gradient
    x.grad = None

    # With residuals
    net_with_res = DeepNetwork(depth, use_residual=True)
    loss_with_res = net_with_res(x).sum()
    loss_with_res.backward()
    grad_with_res = x.grad.norm().item()

    print(f"Depth: {depth} layers")
    print("\nWithout residuals:")
    print(f"  Gradient norm at input: {grad_no_res:.8f}")
    print("\nWith residuals:")
    print(f"  Gradient norm at input: {grad_with_res:.4f}")
    print(f"\nRatio: {grad_with_res / grad_no_res:.0f}x stronger with residuals")


analyze_gradient_flow()
```

Output (representative run; values vary with random initialization):

```
Depth: 50 layers

Without residuals:
  Gradient norm at input: 0.00000003

With residuals:
  Gradient norm at input: 8.9245

Ratio: 297483647x stronger with residuals
```

Summary
The Residual Connection
```
output = x + F(x)
```

Key Benefits
| Benefit | Explanation |
|---|---|
| Gradient flow | Direct path (gradient = 1 + ...) |
| Easy identity | If F(x)=0, output=x (do nothing) |
| Trainability | Enables very deep networks |
| Information preservation | Input bypasses transformations |
In Transformers
- One residual around attention sublayer
- One residual around FFN sublayer
- Combined with LayerNorm (Add & Norm)
Requirements
- Input and sublayer output must have same shape
- Typically combined with dropout before addition
- Works with any differentiable sublayer
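Putting the pieces together, here is a minimal sketch of one transformer-style block with both residual connections (the class name `MiniBlock` is ours, and LayerNorm is deliberately left out, since Add & Norm is the next section's topic):

```python
import torch
import torch.nn as nn


class MiniBlock(nn.Module):
    """One transformer-style block: two sublayers, two residuals."""

    def __init__(self, d_model: int = 64, n_heads: int = 4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        attn_out, _ = self.attn(x, x, x)  # self-attention
        x = x + attn_out                  # Residual #1
        x = x + self.ffn(x)               # Residual #2
        return x


x = torch.randn(2, 10, 64)
out = MiniBlock()(x)
print(out.shape)  # torch.Size([2, 10, 64])
```

Both residuals require the sublayers to preserve shape, which is why d_model is the same on the way in and out of each sublayer.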
Exercises
Conceptual Questions
1. Without residual connections, what is the maximum practical depth for a transformer? Why?
2. Explain in what sense the gradient of x + F(x) with respect to x "always contains a 1". Is it literally always at least 1 in magnitude?
3. Could we use multiplication instead of addition: output = x * F(x)? What problems might occur?
Analysis Exercises
4. Train a 6-layer transformer with and without residuals. Compare training curves.
5. Visualize how activations change through layers with and without residuals.
6. Experiment with different residual scaling: output = x + Ξ± * F(x). How does Ξ± affect training?
Implementation Exercises
7. Implement a "gated" residual: output = x + gate(x) * F(x) where gate(x) is learned.
8. Create a visualization showing gradient magnitude at each layer with/without residuals.
Next Section Preview
In the next section, we'll combine residuals and layer normalization into the Add & Norm pattern. We'll implement both Post-LN (original transformer) and Pre-LN (modern preference) variants, understanding when to use each.