Introduction
A single encoder layer performs one round of self-attention and feed-forward processing. To build powerful representations, we stack N identical layers (N=6 in the original base Transformer).
This section builds the complete TransformerEncoder by stacking encoder layers correctly.
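Before the full implementation, here is the core idea in miniature, as a toy sketch: nn.Identity() stands in for a real encoder layer, and the only contract is that each layer maps [batch, seq_len, d_model] back to the same shape.

```python
import torch
import torch.nn as nn

# Toy stand-in stack: nn.Identity() plays the role of an encoder layer.
layers = nn.ModuleList([nn.Identity() for _ in range(6)])

x = torch.randn(2, 10, 512)   # [batch, seq_len, d_model]
for layer in layers:
    x = layer(x)              # shape preserved at every step
print(x.shape)                # torch.Size([2, 10, 512])
```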
Why Stack Layers?
Hierarchical Representation Learning
Each layer builds on the previous one:
```
Layer 1: Low-level patterns
 - Word types (noun, verb, adjective)
 - Simple local dependencies

Layers 2-3: Mid-level patterns
 - Phrase structure
 - Subject-verb agreement
 - Modifier relationships

Layers 4-6: High-level patterns
 - Semantic relationships
 - Long-range dependencies
 - Abstract concepts
```

Information Refinement
```
Input: "The quick brown fox jumps over the lazy dog"

After Layer 1:
 - "fox" knows it's near "quick" and "brown"

After Layer 3:
 - "fox" understands it's the subject of "jumps"
 - "jumps" knows its object is related to "dog"

After Layer 6:
 - Full semantic understanding
 - "fox" has a rich representation capturing all relationships
```

Using nn.ModuleList
Why Not a Python List?
```python
# WRONG: Python list - parameters not registered
class BadEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [EncoderLayer() for _ in range(6)]
        # These layers won't be saved and won't receive gradients properly!

# CORRECT: nn.ModuleList - parameters registered
class GoodEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(6)])
        # All layer parameters properly registered
```
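A quick check makes the difference concrete (this sketch assumes EncoderLayer is defined, as in the previous section): the plain-list version exposes no parameters at all, while the ModuleList version registers every layer and also supports list-style indexing and iteration.

```python
bad = BadEncoder()
good = GoodEncoder()

print(len(list(bad.parameters())))    # 0 -- PyTorch never sees the list's layers
print(len(list(good.parameters())))   # > 0 -- every layer is registered

# ModuleList behaves like a list for access patterns:
first_layer = good.layers[0]
for layer in good.layers:
    pass  # iterate layers in order
```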
Benefits of nn.ModuleList

1. Parameter registration: All parameters appear in model.parameters()
2. Device movement: model.to(device) moves all layers
3. Serialization: model.state_dict() includes all layers
4. Iteration: Can iterate like a list
```python
# Check parameters
encoder = GoodEncoder()
print(f"Total parameters: {sum(p.numel() for p in encoder.parameters())}")

# Move to GPU
encoder = encoder.to('cuda')

# Save/load works
torch.save(encoder.state_dict(), 'encoder.pt')
```

TransformerEncoder Implementation
Complete Implementation
```python
import torch
import torch.nn as nn
from typing import Optional


class TransformerEncoder(nn.Module):
    """
    Complete Transformer Encoder.

    Stacks N encoder layers and applies final layer normalization.

    Args:
        d_model: Model dimension (default: 512)
        num_heads: Number of attention heads (default: 8)
        num_layers: Number of encoder layers (default: 6)
        d_ff: Feed-forward dimension (default: 2048)
        dropout: Dropout probability (default: 0.1)
        pre_norm: Use Pre-LN if True (default: True)

    Shape:
        Input: [batch, seq_len, d_model]
        Output: [batch, seq_len, d_model]

    Example:
        >>> encoder = TransformerEncoder(d_model=512, num_layers=6)
        >>> x = torch.randn(2, 10, 512)
        >>> output = encoder(x)  # [2, 10, 512]
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
        pre_norm: bool = True
    ):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers
        self.pre_norm = pre_norm

        # Stack of encoder layers
        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
                dropout=dropout,
                pre_norm=pre_norm
            )
            for _ in range(num_layers)
        ])

        # Final layer normalization (needed for Pre-LN)
        self.norm = nn.LayerNorm(d_model) if pre_norm else nn.Identity()

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize weights using Xavier initialization."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass through all encoder layers.

        Args:
            x: Input tensor [batch, seq_len, d_model]
            src_mask: Boolean padding mask [batch, seq_len], True at
                padding positions (forwarded to each layer, which passes
                it to attention as key_padding_mask)

        Returns:
            Encoded output [batch, seq_len, d_model]
        """
        # Pass through each encoder layer
        for layer in self.layers:
            x = layer(x, src_mask=src_mask)

        # Final normalization (for Pre-LN)
        x = self.norm(x)

        return x

    def extra_repr(self) -> str:
        return f"num_layers={self.num_layers}, d_model={self.d_model}"


# Include the encoder layer class for completeness
class TransformerEncoderLayer(nn.Module):
    """Single encoder layer (from Section 2)."""

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1,
        pre_norm: bool = True
    ):
        super().__init__()

        self.pre_norm = pre_norm

        # Self-attention
        self.self_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        # Feed-forward
        self.feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
        )

        # Layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass."""
        if self.pre_norm:
            # Pre-LN: normalize before each sublayer
            normed = self.norm1(x)
            attn_out, _ = self.self_attention(normed, normed, normed, key_padding_mask=src_mask)
            x = x + self.dropout1(attn_out)

            normed = self.norm2(x)
            ff_out = self.feed_forward(normed)
            x = x + self.dropout2(ff_out)
        else:
            # Post-LN: normalize after each residual connection
            attn_out, _ = self.self_attention(x, x, x, key_padding_mask=src_mask)
            x = self.norm1(x + self.dropout1(attn_out))

            ff_out = self.feed_forward(x)
            x = self.norm2(x + self.dropout2(ff_out))

        return x
```

Testing the Complete Encoder
Comprehensive Tests
```python
def test_transformer_encoder():
    """Test the complete TransformerEncoder."""

    print("TransformerEncoder Test")
    print("=" * 60)

    # Configuration
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    batch_size = 4
    seq_len = 50

    # Create encoder
    encoder = TransformerEncoder(
        d_model=d_model,
        num_heads=num_heads,
        num_layers=num_layers,
        d_ff=d_ff,
        dropout=0.1,
        pre_norm=True
    )

    # Test 1: Basic forward pass
    x = torch.randn(batch_size, seq_len, d_model)
    output = encoder(x)

    print(f"\nTest 1: Basic forward pass")
    print(f"  Input: {x.shape}")
    print(f"  Output: {output.shape}")
    assert output.shape == x.shape
    print("  ✓ Passed")

    # Test 2: With padding mask
    # Mask format for nn.MultiheadAttention: True = ignore
    valid_lengths = [50, 40, 30, 20]  # Number of real tokens per sequence
    mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    for i, length in enumerate(valid_lengths):
        mask[i, length:] = True  # Mark padding positions

    output_masked = encoder(x, src_mask=mask)

    print(f"\nTest 2: With padding mask")
    print(f"  Mask shape: {mask.shape}")
    assert output_masked.shape == x.shape
    print("  ✓ Passed")

    # Test 3: Parameter count
    total_params = sum(p.numel() for p in encoder.parameters())
    trainable_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)

    print(f"\nTest 3: Parameter count")
    print(f"  Total: {total_params:,}")
    print(f"  Trainable: {trainable_params:,}")
    print(f"  Per layer: ~{total_params // num_layers:,}")

    # Test 4: Check layer registration
    print(f"\nTest 4: Module structure")
    print(f"  Number of layers: {len(encoder.layers)}")
    print(f"  Final norm: {encoder.norm}")
    print("  ✓ Passed")

    # Test 5: Gradient flow through all layers
    x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
    output = encoder(x)
    loss = output.sum()
    loss.backward()

    print(f"\nTest 5: Gradient flow")
    print(f"  Input gradient norm: {x.grad.norm():.4f}")

    # Check gradients in first and last layers
    first_layer_grad = list(encoder.layers[0].parameters())[0].grad
    last_layer_grad = list(encoder.layers[-1].parameters())[0].grad

    print(f"  First layer gradient norm: {first_layer_grad.norm():.4f}")
    print(f"  Last layer gradient norm: {last_layer_grad.norm():.4f}")
    print("  ✓ Gradients flowing through all layers")

    # Test 6: Different sequence lengths
    print(f"\nTest 6: Variable sequence lengths")
    for test_len in [10, 100, 200]:
        x_test = torch.randn(2, test_len, d_model)
        out_test = encoder(x_test)
        assert out_test.shape == (2, test_len, d_model)
        print(f"  seq_len={test_len}: ✓")

    print("\n" + "=" * 60)
    print("All encoder tests passed! ✓")

    return encoder


encoder = test_transformer_encoder()
```

Visualizing Layer Outputs
Tracking Representations Through Layers
```python
import torch.nn.functional as F  # needed for cosine_similarity below


class TransformerEncoderWithHooks(TransformerEncoder):
    """Encoder that can capture intermediate layer outputs."""

    def forward(
        self,
        x: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        return_all_layers: bool = False
    ):
        """
        Args:
            return_all_layers: If True, return outputs from all layers

        Returns:
            If return_all_layers=False: Final output [batch, seq_len, d_model]
            If return_all_layers=True: List of outputs from each layer
        """
        all_outputs = [x]  # Include input as "layer 0"

        for layer in self.layers:
            x = layer(x, src_mask=src_mask)
            all_outputs.append(x)

        x = self.norm(x)
        all_outputs.append(x)  # Final normalized output

        if return_all_layers:
            return all_outputs
        return x


def analyze_layer_outputs():
    """Analyze how representations change through layers."""

    encoder = TransformerEncoderWithHooks(
        d_model=512, num_layers=6, pre_norm=True
    )
    encoder.eval()  # Disable dropout

    # Create input
    x = torch.randn(1, 20, 512)

    # Get all layer outputs
    with torch.no_grad():
        all_outputs = encoder(x, return_all_layers=True)

    print("Layer-by-Layer Analysis")
    print("=" * 50)

    for i, output in enumerate(all_outputs):
        layer_name = f"Layer {i}" if i > 0 else "Input"
        if i == len(all_outputs) - 1:
            layer_name = "Final (normalized)"

        mean = output.mean().item()
        std = output.std().item()
        norm = output.norm().item()

        print(f"{layer_name:20s}: mean={mean:+.4f}, std={std:.4f}, norm={norm:.2f}")

    # Compute similarity between layers
    print("\nLayer Similarity (cosine):")
    for i in range(len(all_outputs) - 1):
        curr = all_outputs[i].view(-1)
        next_layer = all_outputs[i + 1].view(-1)
        similarity = F.cosine_similarity(curr.unsqueeze(0), next_layer.unsqueeze(0))
        print(f"  Layer {i} -> Layer {i+1}: {similarity.item():.4f}")


analyze_layer_outputs()
```
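Despite the class name above, no actual PyTorch hooks are involved; the capture is done by overriding forward. For reference, here is a sketch of the same idea using register_forward_hook on the TransformerEncoder defined earlier, which leaves forward untouched:

```python
captured = []

def save_output(module, inputs, output):
    # Called automatically after each hooked layer's forward pass
    captured.append(output.detach())

encoder = TransformerEncoder(d_model=512, num_layers=6)
handles = [layer.register_forward_hook(save_output) for layer in encoder.layers]

with torch.no_grad():
    encoder(torch.randn(1, 20, 512))

print(len(captured))  # 6 -- one captured tensor per layer

for h in handles:
    h.remove()  # detach hooks when done
```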
Weight Sharing Options
Shared Weights Across Layers (Optional)
Some models share one set of weights across all layers to cut the parameter count (ALBERT is the best-known example):
```python
class SharedWeightEncoder(nn.Module):
    """
    Encoder with shared weights across layers.

    Applies the same layer N times instead of stacking N different layers.
    Reduces parameters significantly but may reduce capacity.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()

        self.num_layers = num_layers

        # Single layer shared across all depths
        self.shared_layer = TransformerEncoderLayer(
            d_model=d_model,
            num_heads=num_heads,
            d_ff=d_ff,
            dropout=dropout,
            pre_norm=True
        )

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, src_mask=None):
        # Apply the same layer N times
        for _ in range(self.num_layers):
            x = self.shared_layer(x, src_mask=src_mask)

        return self.norm(x)


# Compare parameter counts
regular = TransformerEncoder(d_model=512, num_layers=6)
shared = SharedWeightEncoder(d_model=512, num_layers=6)

print(f"Regular encoder params: {sum(p.numel() for p in regular.parameters()):,}")
print(f"Shared encoder params:  {sum(p.numel() for p in shared.parameters()):,}")
```

Output:
```
Regular encoder params: 18,915,328
Shared encoder params:  3,153,408
```

Note that only parameter memory shrinks (roughly 6x here); compute cost is unchanged, since the shared layer still runs num_layers times.

Memory Optimization: Gradient Checkpointing
For Large Models
```python
from torch.utils.checkpoint import checkpoint


class MemoryEfficientEncoder(nn.Module):
    """
    Encoder using gradient checkpointing to reduce memory.

    Trades compute for memory: recomputes the forward pass during backward.
    Useful for training large models with limited GPU memory.
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        dropout: float = 0.1,
        checkpoint_layers: bool = True
    ):
        super().__init__()

        self.checkpoint_layers = checkpoint_layers

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, src_mask=None):
        for layer in self.layers:
            if self.checkpoint_layers and self.training:
                # Checkpoint this layer (saves memory, costs compute)
                x = checkpoint(layer, x, src_mask, use_reentrant=False)
            else:
                x = layer(x, src_mask=src_mask)

        return self.norm(x)
```
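A brief usage sketch (assuming the class above): checkpointing is active only in training mode, so inference pays no recomputation cost.

```python
enc = MemoryEfficientEncoder(d_model=512, num_layers=6)
x = torch.randn(2, 128, 512)

enc.train()
out = enc(x)           # layer activations are discarded during forward...
out.sum().backward()   # ...and recomputed here, during the backward pass

enc.eval()
with torch.no_grad():
    out = enc(x)       # runs normally, no checkpointing overhead
```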
Summary

TransformerEncoder Structure
```
TransformerEncoder
├── layers: ModuleList
│   ├── EncoderLayer 0
│   ├── EncoderLayer 1
│   ├── ...
│   └── EncoderLayer N-1
└── norm: LayerNorm (for Pre-LN)
```

Key Implementation Points
| Aspect | Implementation |
|---|---|
| Layer storage | nn.ModuleList |
| Iteration | Simple for loop |
| Final norm | Required for Pre-LN |
| Mask propagation | Pass to each layer |
Parameter Count Formula
```
Per layer:
- Attention: 4 × d_model² weights + 4 × d_model biases
- FFN: 2 × d_model × d_ff weights + (d_model + d_ff) biases
- LayerNorms: 4 × d_model

Total for N layers (plus 2 × d_model for the final norm):
N × (4d² + 2·d·d_ff + 9d + d_ff) + 2d

Example (d=512, d_ff=2048, N=6):
≈ 18.9M parameters
```
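As a sanity check, the formula can be compared against the actual module defined earlier in this section:

```python
d, d_ff, N = 512, 2048, 6

# Weights + biases per layer, matching the formula above
per_layer = 4 * d**2 + 2 * d * d_ff + 9 * d + d_ff
formula_total = N * per_layer + 2 * d  # plus the final LayerNorm

encoder = TransformerEncoder(d_model=d, d_ff=d_ff, num_layers=N)
actual_total = sum(p.numel() for p in encoder.parameters())

print(f"Formula: {formula_total:,}")  # 18,915,328
print(f"Actual:  {actual_total:,}")   # should match
```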
Exercises

Implementation Exercises
1. Implement an encoder that returns attention weights from all layers for visualization.
2. Create an encoder with "layer dropout" that randomly skips layers during training.
3. Implement a progressive encoder that starts with 2 layers and gradually adds more.
Analysis Exercises
4. Compare training speed and memory usage between regular and checkpointed encoders.
5. Visualize how attention patterns change from layer 1 to layer 6.
6. Experiment with different numbers of layers (4, 6, 8, 12). What's the tradeoff?
In the next section, we'll do a complete forward pass walkthrough with specific dimensions, tracing exact shapes at every stage of a 6-layer encoder.