Chapter 6

Residual Connections

Feed Forward and Normalization

Introduction

Residual connections (skip connections) are one of the most important innovations in deep learning. They allow training of very deep networks by providing a direct path for gradients to flow backward, mitigating the vanishing gradient problem.

This section explains why residuals are essential and how they work in transformers.


3.1 The Problem: Vanishing Gradients

Training Deep Networks

Consider a deep network with many layers:

πŸ“text
1Input β†’ Layer1 β†’ Layer2 β†’ ... β†’ Layer100 β†’ Output
2                    ↑
3              Deep network

During backpropagation, gradients must flow from output to input:

πŸ“text
1βˆ‚Loss/βˆ‚x = βˆ‚Loss/βˆ‚y Γ— βˆ‚y/βˆ‚x
2
3For chain of layers:
4βˆ‚Loss/βˆ‚Input = βˆ‚L/βˆ‚L100 Γ— βˆ‚L100/βˆ‚L99 Γ— ... Γ— βˆ‚L2/βˆ‚L1 Γ— βˆ‚L1/βˆ‚Input
5              = Product of 100+ terms!

The Vanishing Problem

If each gradient term is < 1:

πŸ“text
10.9 Γ— 0.9 Γ— 0.9 Γ— ... (100 times) = 0.9^100 β‰ˆ 0.00003

Gradients become vanishingly small β†’ early layers don't learn!

The Exploding Problem

If each gradient term is > 1:

πŸ“text
11.1 Γ— 1.1 Γ— 1.1 Γ— ... (100 times) = 1.1^100 β‰ˆ 13,781

Gradients become explosively large β†’ training becomes unstable!
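
These two products can be sanity-checked in plain Python, no framework needed:

```python
# Numeric check of the vanishing/exploding products above.
depth = 100
vanishing = 0.9 ** depth
exploding = 1.1 ** depth
print(f"0.9^{depth} β‰ˆ {vanishing:.1e}")  # tiny: gradients vanish
print(f"1.1^{depth} β‰ˆ {exploding:.1e}")  # huge: gradients explode
```

Even a per-layer factor very close to 1 compounds dramatically over 100 layers.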


3.2 The Solution: Residual Connections

The Key Idea

Instead of learning a direct mapping F(x), learn a residual R(x):

πŸ“text
1Standard:  output = F(x)
2Residual:  output = x + R(x)

Where R(x) = F(x) - x is the "residual" that F needs to learn.

Visual Representation

πŸ“text
1Standard Connection:             Residual Connection:
2
3    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”                      β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4    β”‚    x    β”‚                      β”‚    x    β”‚
5    β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”˜                      β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”˜
6         β”‚                                β”‚
7         β”‚                         β”Œβ”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”
8         β”‚                         β”‚             β”‚
9         β–Ό                         β–Ό             β”‚
10    β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”               β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”       β”‚
11    β”‚  F(x)   β”‚               β”‚  R(x)   β”‚       β”‚
12    β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”˜               β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”˜       β”‚
13         β”‚                         β”‚             β”‚
14         β”‚                         β–Ό             β”‚
15         β”‚                     β”Œβ”€β”€β”€β”€β”€β”€β”€β”        β”‚
16         β”‚                     β”‚   +   β”‚β—„β”€β”€β”€β”€β”€β”€β”€β”˜
17         β”‚                     β””β”€β”€β”€β”¬β”€β”€β”€β”˜
18         β”‚                         β”‚
19         β–Ό                         β–Ό
20      output                    output
21      F(x)                     x + R(x)

Why Residuals Help

Gradient flow:

πŸ“text
1βˆ‚(x + R(x))/βˆ‚x = 1 + βˆ‚R(x)/βˆ‚x

Even if βˆ‚R(x)/βˆ‚x is tiny, the identity path always contributes a constant 1, so the gradient through the block cannot vanish!

Identity mapping as default:

  • If R(x) = 0, then output = x (identity)
  • Network can easily learn "do nothing" if needed
  • Deeper layers can be added without hurting performance
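
The "easy identity" point can be verified directly. In the sketch below, the residual branch is a linear layer zero-initialized purely for illustration, so R(x) = 0 and the block reduces to the identity:

```python
import torch
import torch.nn as nn

# With the residual branch zeroed out, x + R(x) is exactly the identity.
branch = nn.Linear(8, 8)
nn.init.zeros_(branch.weight)
nn.init.zeros_(branch.bias)

x = torch.randn(4, 8)
out = x + branch(x)            # residual form: x + R(x), here R(x) = 0
print(torch.allclose(out, x))  # True — the block "does nothing"
```

A standard (non-residual) layer would instead have to learn the identity mapping explicitly, which is surprisingly hard for a deep stack.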

3.3 Mathematical Analysis

Gradient Flow Comparison

Without residual:

πŸ“text
1y = F(x)
2βˆ‚y/βˆ‚x = βˆ‚F/βˆ‚x  (can vanish!)

With residual:

πŸ“text
1y = x + F(x)
2βˆ‚y/βˆ‚x = 1 + βˆ‚F/βˆ‚x  (always at least 1!)
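
This can be confirmed with autograd. The check below uses F(x) = 0.1Β·tanh(x) as an arbitrary stand-in sublayer; since tanh'(0) = 1, the gradient at x = 0 should be 1 + 0.1:

```python
import torch

# Autograd check: for y = x + F(x), dy/dx = 1 + F'(x).
x = torch.tensor(0.0, requires_grad=True)
y = x + 0.1 * torch.tanh(x)   # F(x) = 0.1 * tanh(x), so F'(0) = 0.1
y.backward()
print(x.grad.item())          # β‰ˆ 1.1 = 1 + F'(0)
```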

For Deep Networks

Consider L layers with residual connections:

πŸ“text
1x_L = x_{L-1} + F_{L-1}(x_{L-1})
2    = x_{L-2} + F_{L-2}(x_{L-2}) + F_{L-1}(x_{L-1})
3    = ...
4    = x_0 + Ξ£α΅’ Fα΅’(xα΅’)
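
The unrolled sum can be checked numerically. The sketch below again uses 0.1Β·tanh(x) as a stand-in for each sublayer Fα΅’ and accumulates the branch outputs alongside the forward pass:

```python
import torch

# Numeric check of the unrolled form: x_L = x_0 + Ξ£α΅’ Fα΅’(xα΅’).
torch.manual_seed(0)
x0 = torch.randn(4)

x = x0.clone()
branch_sum = torch.zeros(4)
for _ in range(5):               # five residual "layers"
    f = 0.1 * torch.tanh(x)      # stand-in for the i-th sublayer Fα΅’
    branch_sum = branch_sum + f
    x = x + f

print(torch.allclose(x, x0 + branch_sum))  # True
```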

The gradient becomes:

πŸ“text
1βˆ‚Loss/βˆ‚x_0 = βˆ‚Loss/βˆ‚x_L Γ— (1 + Ξ£α΅’ βˆ‚Fα΅’/βˆ‚xα΅’ Γ— ...)

The "1" provides a direct path from loss to early layers!

Highway Networks Perspective

Residual connections are a simplified case of Highway Networks:

πŸ“text
1Highway:  y = T(x) Γ— H(x) + (1 - T(x)) Γ— x
2          where T(x) is a learned "transform gate"
3
4Residual: y = x + F(x)
5          (T(x) = 1 always, much simpler!)
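
As an illustrative sketch (not code from the Highway Networks paper), a highway layer with its learned gate might look like:

```python
import torch
import torch.nn as nn

# Minimal highway-layer sketch: the gate T(x) learns how much of the
# transformed signal H(x) vs. the raw input x to pass through.
class HighwayLayer(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        self.h = nn.Linear(d_model, d_model)  # transform H(x)
        self.t = nn.Linear(d_model, d_model)  # transform gate T(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = torch.sigmoid(self.t(x))       # values in (0, 1)
        return gate * torch.relu(self.h(x)) + (1 - gate) * x

x = torch.randn(2, 16)
print(HighwayLayer(16)(x).shape)  # torch.Size([2, 16])
```

The residual connection drops the gate entirely, which is one reason it scales so well: there are no extra parameters on the identity path.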

3.4 Residuals in Transformers

Where Residuals Appear

In each transformer layer, there are two residual connections:

πŸ“text
1β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
2β”‚                 Input x                  β”‚
3β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
4                 β”‚
5                 β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
6                 β”‚                          β”‚
7                 β–Ό                          β”‚
8β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”     β”‚
9β”‚      Multi-Head Attention          β”‚     β”‚
10β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜     β”‚
11                 β”‚                          β”‚
12                 β–Ό                          β”‚
13               (+) β—„β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜  Residual #1
14                 β”‚
15                 β–Ό
16β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
17β”‚           LayerNorm                β”‚
18β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
19                 β”‚
20                 β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
21                 β”‚                          β”‚
22                 β–Ό                          β”‚
23β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”     β”‚
24β”‚    Feed-Forward Network            β”‚     β”‚
25β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜     β”‚
26                 β”‚                          β”‚
27                 β–Ό                          β”‚
28               (+) β—„β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜  Residual #2
29                 β”‚
30                 β–Ό
31β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
32β”‚           LayerNorm                β”‚
33β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
34                 β”‚
35                 β–Ό
36              Output

The Pattern

🐍python
# Residual around attention
x = x + MultiHeadAttention(x)

# Residual around FFN
x = x + FeedForward(x)
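
The same pattern can be made runnable with PyTorch built-ins. The sketch below uses `nn.MultiheadAttention` for self-attention and omits the LayerNorms that normally accompany each residual (they are covered in the Add & Norm section):

```python
import torch
import torch.nn as nn

# Two residual connections of a transformer layer, norms omitted.
d_model, n_heads = 64, 4
attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
ffn = nn.Sequential(
    nn.Linear(d_model, 4 * d_model),
    nn.ReLU(),
    nn.Linear(4 * d_model, d_model),
)

x = torch.randn(2, 10, d_model)  # (batch, seq, d_model)
attn_out, _ = attn(x, x, x)      # self-attention: query = key = value = x
x = x + attn_out                 # residual #1
x = x + ffn(x)                   # residual #2
print(x.shape)                   # torch.Size([2, 10, 64])
```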

3.5 Why Residuals Enable Deep Transformers

Depth and Performance

Modern transformers can be very deep:

  • BERT: 12 or 24 layers
  • GPT-3: 96 layers
  • PaLM: 118 layers

Without residuals, training at such depth would be practically impossible — gradients would vanish long before reaching the early layers.

Information Preservation

Residuals ensure information isn't lost:

🐍python
def demonstrate_info_preservation():
    """Show how residuals preserve information through layers."""
    import torch

    # Create input: expected norm is sqrt(1 * 10 * 512) β‰ˆ 71.6
    x = torch.randn(1, 10, 512)

    # Simulate 100 layers WITHOUT residuals
    x_no_res = x.clone()
    for _ in range(100):
        # Simulated layer (shrinks signal)
        x_no_res = 0.9 * x_no_res

    print("After 100 layers without residuals:")
    print(f"  Input norm: {x.norm():.4f}")
    print(f"  Output norm: {x_no_res.norm():.10f}")  # Nearly 0!

    # Simulate 100 layers WITH residuals
    x_with_res = x.clone()
    for _ in range(100):
        # Simulated layer with residual
        x_with_res = x_with_res + 0.01 * x_with_res  # Small update

    print("\nAfter 100 layers with residuals:")
    print(f"  Input norm: {x.norm():.4f}")
    print(f"  Output norm: {x_with_res.norm():.4f}")  # Preserved!


demonstrate_info_preservation()

Output (approximate — exact values depend on the random input):

πŸ“text
After 100 layers without residuals:
  Input norm: ~71.55
  Output norm: ~0.0019     (71.55 Γ— 0.9^100)

After 100 layers with residuals:
  Input norm: ~71.55
  Output norm: ~193.5      (71.55 Γ— 1.01^100)

3.6 Implementation Details

Basic Residual Block

🐍python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """
    A generic residual block.

    Computes: output = x + sublayer(x)

    Args:
        sublayer: The transformation to apply
    """

    def __init__(self, sublayer: nn.Module):
        super().__init__()
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.sublayer(x)


# Example usage
class SimpleFFN(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model),
        )

    def forward(self, x):
        return self.net(x)


# Create residual FFN
d_model = 512
residual_ffn = ResidualBlock(SimpleFFN(d_model))

x = torch.randn(2, 10, d_model)
output = residual_ffn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Same shape: {x.shape == output.shape}")

With Dropout

Dropout is typically applied before the residual addition:

🐍python
class ResidualWithDropout(nn.Module):
    """
    Residual connection with dropout on the sublayer output.

    Computes: output = x + dropout(sublayer(x))
    """

    def __init__(self, sublayer: nn.Module, dropout: float = 0.1):
        super().__init__()
        self.sublayer = sublayer
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.dropout(self.sublayer(x))

3.7 Residual Connection Requirements

Shape Compatibility

The residual requires input and output to have the same shape:

🐍python
# This works:
x = torch.randn(2, 10, 512)
sublayer_output = torch.randn(2, 10, 512)  # Same shape!
output = x + sublayer_output  # βœ“

# This fails:
x = torch.randn(2, 10, 512)
sublayer_output = torch.randn(2, 10, 256)  # Different shape!
output = x + sublayer_output  # βœ— RuntimeError!

Handling Shape Mismatches

When shapes don't match (e.g., downsampling), use a projection:

🐍python
class ProjectedResidual(nn.Module):
    """
    Residual with projection for shape mismatches.

    Used when sublayer changes dimensions.
    """

    def __init__(
        self,
        sublayer: nn.Module,
        in_features: int,
        out_features: int
    ):
        super().__init__()
        self.sublayer = sublayer

        # Projection for shape matching
        if in_features != out_features:
            self.projection = nn.Linear(in_features, out_features)
        else:
            self.projection = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.projection(x) + self.sublayer(x)

In transformers, we design sublayers to preserve shape, so projection isn't needed.


3.8 Visualization: Gradient Flow

Comparing Gradient Magnitude

🐍python
import torch
import torch.nn as nn


def analyze_gradient_flow():
    """Compare gradient flow with and without residuals."""

    class DeepNetwork(nn.Module):
        def __init__(self, depth, use_residual):
            super().__init__()
            self.use_residual = use_residual
            self.layers = nn.ModuleList([
                nn.Linear(64, 64) for _ in range(depth)
            ])

        def forward(self, x):
            for layer in self.layers:
                if self.use_residual:
                    x = x + 0.1 * torch.tanh(layer(x))  # Residual
                else:
                    x = torch.tanh(layer(x))  # No residual
            return x

    depth = 50
    x = torch.randn(1, 64, requires_grad=True)

    # Without residuals
    net_no_res = DeepNetwork(depth, use_residual=False)
    y_no_res = net_no_res(x)
    loss_no_res = y_no_res.sum()
    loss_no_res.backward()
    grad_no_res = x.grad.norm().item()

    # Reset gradient
    x.grad = None

    # With residuals
    net_with_res = DeepNetwork(depth, use_residual=True)
    y_with_res = net_with_res(x)
    loss_with_res = y_with_res.sum()
    loss_with_res.backward()
    grad_with_res = x.grad.norm().item()

    print(f"Depth: {depth} layers")
    print("\nWithout residuals:")
    print(f"  Gradient norm at input: {grad_no_res:.8f}")
    print("\nWith residuals:")
    print(f"  Gradient norm at input: {grad_with_res:.4f}")
    print(f"\nRatio: {grad_with_res / grad_no_res:.0f}x stronger with residuals")


analyze_gradient_flow()

Example output (exact values vary with random initialization):

πŸ“text
Depth: 50 layers

Without residuals:
  Gradient norm at input: 0.00000003

With residuals:
  Gradient norm at input: 8.9245

Ratio: 297483647x stronger with residuals

Summary

The Residual Connection

πŸ“text
1output = x + F(x)

Key Benefits

Benefit                  | Explanation
-------------------------|---------------------------------------
Gradient flow            | Direct path (gradient = 1 + ...)
Easy identity            | If F(x) = 0, output = x (do nothing)
Trainability             | Enables very deep networks
Information preservation | Input bypasses transformations

In Transformers

  • One residual around attention sublayer
  • One residual around FFN sublayer
  • Combined with LayerNorm (Add & Norm)

Requirements

  • Input and sublayer output must have same shape
  • Typically combined with dropout before addition
  • Works with any differentiable sublayer

Exercises

Conceptual Questions

1. Without residual connections, what is the maximum practical depth for a transformer? Why?

2. Explain why the gradient of x + F(x) with respect to x is "at least 1".

3. Could we use multiplication instead of addition: output = x * F(x)? What problems might occur?

Analysis Exercises

4. Train a 6-layer transformer with and without residuals. Compare training curves.

5. Visualize how activations change through layers with and without residuals.

6. Experiment with different residual scaling: output = x + Ξ± * F(x). How does Ξ± affect training?

Implementation Exercises

7. Implement a "gated" residual: output = x + gate(x) * F(x) where gate(x) is learned.

8. Create a visualization showing gradient magnitude at each layer with/without residuals.


Next Section Preview

In the next section, we'll combine residuals and layer normalization into the Add & Norm pattern. We'll implement both Post-LN (original transformer) and Pre-LN (modern preference) variants, understanding when to use each.