Chapter 7

Stacking Encoder Layers

Transformer Encoder

Introduction

A single encoder layer provides one round of self-attention and feed-forward processing. To build powerful representations, we stack N identical layers (N=6 in the original base Transformer).

This section covers building the complete TransformerEncoder by stacking encoder layers properly.


Why Stack Layers?

Hierarchical Representation Learning

Each layer builds on the previous one:

πŸ“text
Layer 1: Low-level patterns
  - Word types (noun, verb, adjective)
  - Simple local dependencies

Layers 2-3: Mid-level patterns
  - Phrase structure
  - Subject-verb agreement
  - Modifier relationships

Layers 4-6: High-level patterns
  - Semantic relationships
  - Long-range dependencies
  - Abstract concepts

Information Refinement

πŸ“text
Input: "The quick brown fox jumps over the lazy dog"

After Layer 1:
  - "fox" knows it's near "quick" and "brown"

After Layer 3:
  - "fox" understands it's the subject of "jumps"
  - "jumps" knows its object is related to "dog"

After Layer 6:
  - Full semantic understanding
  - "fox" has a rich representation capturing all of these relationships

Using nn.ModuleList

Why Not a Python List?

🐍python
# WRONG: plain Python list - sub-module parameters are not registered
class BadEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [EncoderLayer() for _ in range(6)]
        # These layers won't appear in parameters(), won't be saved,
        # and the optimizer will never update them!

# CORRECT: nn.ModuleList - sub-module parameters are registered
class GoodEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(6)])
        # All layer parameters are properly registered

Benefits of nn.ModuleList

1. Parameter registration: All parameters appear in model.parameters()

2. Device movement: model.to(device) moves all layers

3. Serialization: model.state_dict() includes all layers

4. Iteration: Can iterate over the layers like a list (see the snippet after the example below)

🐍python
# Check parameters
encoder = GoodEncoder()
print(f"Total parameters: {sum(p.numel() for p in encoder.parameters())}")

# Move to GPU
encoder = encoder.to('cuda')

# Save/load works
torch.save(encoder.state_dict(), 'encoder.pt')
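
The iteration point is what the encoder's forward pass will rely on; an nn.ModuleList behaves like a regular list here (continuing the GoodEncoder example above):

🐍python
# Iterate over the registered layers just like a list
for i, layer in enumerate(encoder.layers):
    print(f"Layer {i}: {type(layer).__name__}")

# Indexing works too, e.g. to inspect or freeze a single layer
for p in encoder.layers[0].parameters():
    p.requires_grad = False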

TransformerEncoder Implementation

Complete Implementation

🐍python
import torch
import torch.nn as nn
from typing import Optional


class TransformerEncoder(nn.Module):
    """
    Complete Transformer Encoder.

    Stacks N encoder layers and applies final layer normalization.

    Args:
        d_model: Model dimension (default: 512)
        num_heads: Number of attention heads (default: 8)
        num_layers: Number of encoder layers (default: 6)
        d_ff: Feed-forward dimension (default: 2048)
        dropout: Dropout probability (default: 0.1)
        pre_norm: Use Pre-LN if True (default: True)

    Shape:
        Input: [batch, seq_len, d_model]
        Output: [batch, seq_len, d_model]

    Example:
        >>> encoder = TransformerEncoder(d_model=512, num_layers=6)
        >>> x = torch.randn(2, 10, 512)
        >>> output = encoder(x)  # [2, 10, 512]
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
        pre_norm: bool = True
    ):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers
        self.pre_norm = pre_norm

        # Stack of encoder layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
                dropout=dropout,
                pre_norm=pre_norm
            )
            for _ in range(num_layers)
        ])

        # Final layer normalization (needed for Pre-LN)
        self.norm = nn.LayerNorm(d_model) if pre_norm else nn.Identity()

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights using Xavier initialization."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass through all encoder layers.

        Args:
            x: Input tensor [batch, seq_len, d_model]
            src_mask: Boolean padding mask [batch, seq_len], True marks
                padding positions (passed as key_padding_mask to each layer)

        Returns:
            Encoded output [batch, seq_len, d_model]
        """
        # Pass through each encoder layer
        for layer in self.layers:
            x = layer(x, src_mask=src_mask)

        # Final normalization (for Pre-LN)
        x = self.norm(x)

        return x

    def extra_repr(self) -> str:
        return f"num_layers={self.num_layers}, d_model={self.d_model}"


# Include the encoder layer class for completeness
class TransformerEncoderLayer(nn.Module):
    """Single encoder layer (from Section 2)."""

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1,
        pre_norm: bool = True
    ):
        super().__init__()

        self.pre_norm = pre_norm

        # Self-attention
        self.self_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Feed-forward
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

        # Layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass."""
        if self.pre_norm:
            # Pre-LN
            normed = self.norm1(x)
            attn_out, _ = self.self_attention(normed, normed, normed, key_padding_mask=src_mask)
            x = x + self.dropout1(attn_out)

            normed = self.norm2(x)
            ff_out = self.feed_forward(normed)
            x = x + self.dropout2(ff_out)
        else:
            # Post-LN
            attn_out, _ = self.self_attention(x, x, x, key_padding_mask=src_mask)
            x = self.norm1(x + self.dropout1(attn_out))

            ff_out = self.feed_forward(x)
            x = self.norm2(x + self.dropout2(ff_out))

        return x

Testing the Complete Encoder

Comprehensive Tests

🐍python
def test_transformer_encoder():
    """Test the complete TransformerEncoder."""

    print("TransformerEncoder Test")
    print("=" * 60)

    # Configuration
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    batch_size = 4
    seq_len = 50

    # Create encoder
    encoder = TransformerEncoder(
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
        d_ff=d_ff,
        dropout=0.1,
        pre_norm=True
    )

    # Test 1: Basic forward pass
    x = torch.randn(batch_size, seq_len, d_model)
    output = encoder(x)

    print(f"\nTest 1: Basic forward pass")
    print(f"  Input:  {x.shape}")
    print(f"  Output: {output.shape}")
    assert output.shape == x.shape
    print("  βœ“ Passed")

    # Test 2: With padding mask
    # Mask format for nn.MultiheadAttention: True = ignore
    padding_lengths = [50, 40, 30, 20]  # Real token lengths
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    for i, length in enumerate(padding_lengths):
        mask[i, length:] = True  # Mark padding positions

    output_masked = encoder(x, src_mask=mask)

    print(f"\nTest 2: With padding mask")
    print(f"  Mask shape: {mask.shape}")
    assert output_masked.shape == x.shape
    print("  βœ“ Passed")

    # Test 3: Parameter count
    total_params = sum(p.numel() for p in encoder.parameters())
    trainable_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)

    print(f"\nTest 3: Parameter count")
    print(f"  Total: {total_params:,}")
    print(f"  Trainable: {trainable_params:,}")
    print(f"  Per layer: ~{total_params // num_layers:,}")

    # Test 4: Check layer registration
    print(f"\nTest 4: Module structure")
    print(f"  Number of layers: {len(encoder.layers)}")
    print(f"  Final norm: {encoder.norm}")
    print("  βœ“ Passed")

    # Test 5: Gradient flow through all layers
    x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
    output = encoder(x)
    loss = output.sum()
    loss.backward()

    print(f"\nTest 5: Gradient flow")
    print(f"  Input gradient norm: {x.grad.norm():.4f}")

    # Check gradients in first and last layers
    first_layer_grad = list(encoder.layers[0].parameters())[0].grad
    last_layer_grad = list(encoder.layers[-1].parameters())[0].grad

    print(f"  First layer gradient norm: {first_layer_grad.norm():.4f}")
    print(f"  Last layer gradient norm: {last_layer_grad.norm():.4f}")
    print("  βœ“ Gradients flowing through all layers")

    # Test 6: Different sequence lengths
    print(f"\nTest 6: Variable sequence lengths")
    for test_len in [10, 100, 200]:
        x_test = torch.randn(2, test_len, d_model)
        out_test = encoder(x_test)
        assert out_test.shape == (2, test_len, d_model)
        print(f"  seq_len={test_len}: βœ“")

    print("\n" + "=" * 60)
    print("All encoder tests passed! βœ“")

    return encoder


encoder = test_transformer_encoder()

Visualizing Layer Outputs

Tracking Representations Through Layers

🐍python
import torch.nn.functional as F  # needed for cosine_similarity below


class TransformerEncoderWithHooks(TransformerEncoder):
    """Encoder that can capture intermediate layer outputs."""

    def forward(
        self,
        x: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        return_all_layers: bool = False
    ):
        """
        Args:
            return_all_layers: If True, return outputs from all layers

        Returns:
            If return_all_layers=False: Final output [batch, seq_len, d_model]
            If return_all_layers=True: List of outputs from each layer
        """
        all_outputs = [x]  # Include input as "layer 0"

        for layer in self.layers:
            x = layer(x, src_mask=src_mask)
            all_outputs.append(x)

        x = self.norm(x)
        all_outputs.append(x)  # Final normalized output

        if return_all_layers:
            return all_outputs
        return x


def analyze_layer_outputs():
    """Analyze how representations change through layers."""

    encoder = TransformerEncoderWithHooks(
        d_model=512, num_layers=6, pre_norm=True
    )
    encoder.eval()  # Disable dropout

    # Create input
    x = torch.randn(1, 20, 512)

    # Get all layer outputs
    with torch.no_grad():
        all_outputs = encoder(x, return_all_layers=True)

    print("Layer-by-Layer Analysis")
    print("=" * 50)

    for i, output in enumerate(all_outputs):
        layer_name = f"Layer {i}" if i > 0 else "Input"
        if i == len(all_outputs) - 1:
            layer_name = "Final (normalized)"

        mean = output.mean().item()
        std = output.std().item()
        norm = output.norm().item()

        print(f"{layer_name:20s}: mean={mean:+.4f}, std={std:.4f}, norm={norm:.2f}")

    # Compute similarity between consecutive layers
    print("\nLayer Similarity (cosine):")
    for i in range(len(all_outputs) - 1):
        curr = all_outputs[i].view(-1)
        next_layer = all_outputs[i + 1].view(-1)
        similarity = F.cosine_similarity(curr.unsqueeze(0), next_layer.unsqueeze(0))
        print(f"  Layer {i} -> Layer {i+1}: {similarity.item():.4f}")


analyze_layer_outputs()

Weight Sharing Options

Shared Weights Across Layers (Optional)

Some models share weights between layers to reduce parameters:

🐍python
class SharedWeightEncoder(nn.Module):
    """
    Encoder with shared weights across layers.

    Uses the same layer N times instead of N different layers.
    Reduces parameters significantly but may reduce capacity.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()

        self.num_layers = num_layers

        # A single layer reused at every depth
        self.shared_layer = TransformerEncoderLayer(
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout=dropout,
            pre_norm=True
        )

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, src_mask=None):
        # Apply the same layer N times
        for _ in range(self.num_layers):
            x = self.shared_layer(x, src_mask=src_mask)

        return self.norm(x)


# Compare parameter counts
regular = TransformerEncoder(d_model=512, num_layers=6)
shared = SharedWeightEncoder(d_model=512, num_layers=6)

print(f"Regular encoder params: {sum(p.numel() for p in regular.parameters()):,}")
print(f"Shared encoder params:  {sum(p.numel() for p in shared.parameters()):,}")

Output:

πŸ“text
Regular encoder params: 18,938,368
Shared encoder params:  3,156,480

Memory Optimization: Gradient Checkpointing

For Large Models

🐍python
from torch.utils.checkpoint import checkpoint


class MemoryEfficientEncoder(nn.Module):
    """
    Encoder using gradient checkpointing to reduce memory.

    Trades compute for memory: recomputes forward pass during backward.
    Useful for training large models with limited GPU memory.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
        checkpoint_layers: bool = True
    ):
        super().__init__()

        self.checkpoint_layers = checkpoint_layers

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, src_mask=None):
        for layer in self.layers:
            if self.checkpoint_layers and self.training:
                # Checkpoint this layer (saves memory, costs compute)
                x = checkpoint(layer, x, src_mask, use_reentrant=False)
            else:
                x = layer(x, src_mask=src_mask)

        return self.norm(x)
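
A minimal usage sketch: checkpointing only engages while the module is in training mode, and gradients still come out correct because each layer is simply re-run during the backward pass.

🐍python
enc = MemoryEfficientEncoder(d_model=512, num_layers=6)
enc.train()

x = torch.randn(8, 128, 512)
out = enc(x)              # per-layer activations are not kept around
out.sum().backward()      # layers are recomputed here to produce gradients

enc.eval()
with torch.no_grad():
    out = enc(x)          # plain forward pass, no checkpointing overhead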

Summary

TransformerEncoder Structure

πŸ“text
TransformerEncoder
β”œβ”€β”€ layers: ModuleList
β”‚   β”œβ”€β”€ EncoderLayer 0
β”‚   β”œβ”€β”€ EncoderLayer 1
β”‚   β”œβ”€β”€ ...
β”‚   └── EncoderLayer N-1
└── norm: LayerNorm (for Pre-LN)

Key Implementation Points

- Layer storage: nn.ModuleList
- Iteration: a simple for loop over the layers
- Final norm: required for Pre-LN (Identity for Post-LN)
- Mask propagation: pass the same src_mask to each layer

Parameter Count Formula

πŸ“text
Per layer:
- Attention: 4 Γ— d_modelΒ²
- FFN: 2 Γ— d_model Γ— d_ff + d_model + d_ff
- LayerNorms: 4 Γ— d_model

Total for N layers:
N Γ— (4dΒ² + 2d Γ— d_ff + d_ff + 5d) + 2d (final norm)

Example (d=512, d_ff=2048, N=6):
β‰ˆ 18.9M parameters
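
As a quick sanity check, the formula can be turned into code and compared against the actual module. This is a rough sketch: the formula ignores the attention projections' bias terms, so the real count comes out slightly higher.

🐍python
def estimated_params(d_model=512, d_ff=2048, num_layers=6):
    attention = 4 * d_model ** 2                 # Q, K, V and output projection weights
    ffn = 2 * d_model * d_ff + d_model + d_ff    # two linear layers, with biases
    layer_norms = 4 * d_model                    # two LayerNorms per layer
    per_layer = attention + ffn + layer_norms
    return num_layers * per_layer + 2 * d_model  # plus the final LayerNorm

encoder = TransformerEncoder(d_model=512, d_ff=2048, num_layers=6)
actual = sum(p.numel() for p in encoder.parameters())
print(f"Estimated: {estimated_params():,}")  # about 18.9M
print(f"Actual:    {actual:,}")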

Exercises

Implementation Exercises

1. Implement an encoder that returns attention weights from all layers for visualization.

2. Create an encoder with "layer dropout" that randomly skips layers during training.

3. Implement a progressive encoder that starts with 2 layers and gradually adds more.

Analysis Exercises

4. Compare training speed and memory usage between regular and checkpointed encoders.

5. Visualize how attention patterns change from layer 1 to layer 6.

6. Experiment with different numbers of layers (4, 6, 8, 12). What's the tradeoff?


In the next section, we'll do a complete forward pass walkthrough with specific dimensions, tracing exact shapes at every stage of a 6-layer encoder.