Introduction
Now we implement the encoder layer by combining all our previously built components: Multi-Head Attention, Feed-Forward Network, Layer Normalization, and residual connections.
This section provides a complete, well-documented implementation.
Component Imports and Setup
Required Components
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4import math
5from typing import Optional
6
7
8# We'll use our implementations from previous chapters
9# For clarity, we'll include simplified versions here
Multi-Head Attention (From Chapter 3)
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention mechanism.

    The model dimension is split across ``num_heads`` independent heads;
    scaled dot-product attention runs on every head in parallel, and the
    per-head results are concatenated and mixed by an output projection.

    Args:
        d_model: Model dimension
        num_heads: Number of attention heads
        dropout: Dropout probability applied to the attention weights
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1
    ):
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # per-head dimension

        # One linear projection per role (Q, K, V) plus the output mix.
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape [batch, seq, d_model] -> [batch, num_heads, seq, d_k]."""
        return t.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Args:
            query: [batch, seq_len, d_model]
            key: [batch, seq_len, d_model]
            value: [batch, seq_len, d_model]
            mask: [batch, 1, 1, seq_len] or [batch, 1, seq_len, seq_len];
                positions where mask == 0 are excluded from attention

        Returns:
            output: [batch, seq_len, d_model]
        """
        batch_size = query.size(0)

        # Project, then fan out into heads.
        q = self._split_heads(self.W_q(query), batch_size)
        k = self._split_heads(self.W_k(key), batch_size)
        v = self._split_heads(self.W_v(value), batch_size)

        # Scaled dot-product attention: [batch, heads, seq, seq]
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # -inf before softmax zeroes out the masked positions.
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = weights @ v

        # Merge heads back: [batch, heads, seq, d_k] -> [batch, seq, d_model]
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(merged)


# Position-wise Feed-Forward (From Chapter 6)
class PositionwiseFeedForward(nn.Module):
    """
    Position-wise Feed-Forward Network.

    The same two-layer MLP (expand to d_ff, GELU, project back to
    d_model) is applied independently at every sequence position.

    Args:
        d_model: Input/output dimension
        d_ff: Hidden layer dimension
        dropout: Dropout probability applied to the hidden activations
    """

    def __init__(
        self,
        d_model: int,
        d_ff: int,
        dropout: float = 0.1
    ):
        super().__init__()

        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [batch, seq_len, d_model]

        Returns:
            output: [batch, seq_len, d_model]
        """
        hidden = F.gelu(self.linear1(x))  # expand + non-linearity
        hidden = self.dropout(hidden)     # regularize hidden activations
        return self.linear2(hidden)       # project back to d_model


# Encoder Layer Implementation
TransformerEncoderLayer Class
class TransformerEncoderLayer(nn.Module):
    """
    Single Transformer Encoder Layer.

    Architecture (Pre-LN):
        x → LayerNorm → MultiHeadAttention → Dropout → (+x) →
          → LayerNorm → FeedForward → Dropout → (+) → output

    Args:
        d_model: Model dimension (default: 512)
        num_heads: Number of attention heads (default: 8)
        d_ff: Feed-forward dimension (default: 2048)
        dropout: Dropout probability (default: 0.1)
        pre_norm: Use Pre-LN if True, Post-LN if False (default: True)

    Shape:
        Input: [batch, seq_len, d_model]
        Output: [batch, seq_len, d_model]

    Example:
        >>> layer = TransformerEncoderLayer(d_model=512, num_heads=8)
        >>> x = torch.randn(2, 10, 512)
        >>> output = layer(x)  # [2, 10, 512]
    """

    def __init__(
        self,
        d_model: int = 512,
        num_heads: int = 8,
        d_ff: int = 2048,
        dropout: float = 0.1,
        pre_norm: bool = True
    ):
        super().__init__()

        self.d_model = d_model
        self.pre_norm = pre_norm

        # Sublayer 1: multi-head self-attention.
        self.self_attention = MultiHeadAttention(
            d_model=d_model,
            num_heads=num_heads,
            dropout=dropout
        )

        # Sublayer 2: position-wise feed-forward network.
        self.feed_forward = PositionwiseFeedForward(
            d_model=d_model,
            d_ff=d_ff,
            dropout=dropout
        )

        # One LayerNorm and one residual-path dropout per sublayer.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Forward pass through the encoder layer.

        Args:
            x: Input tensor [batch, seq_len, d_model]
            src_mask: Padding mask [batch, 1, 1, seq_len];
                1 for valid positions, 0 for padding

        Returns:
            Output tensor [batch, seq_len, d_model]
        """
        # Dispatch to the variant chosen at construction time.
        run = self._pre_norm_pass if self.pre_norm else self._post_norm_pass
        return run(x, src_mask)

    def _pre_norm_pass(
        self,
        x: torch.Tensor,
        src_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Pre-LN: normalize, run the sublayer, then add the residual."""

        # Self-attention block (Q = K = V for self-attention).
        h = self.norm1(x)
        h = self.self_attention(h, h, h, mask=src_mask)
        x = x + self.dropout1(h)

        # Feed-forward block.
        h = self.norm2(x)
        h = self.feed_forward(h)
        x = x + self.dropout2(h)

        return x

    def _post_norm_pass(
        self,
        x: torch.Tensor,
        src_mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Post-LN: run the sublayer, add the residual, then normalize."""

        # Self-attention block.
        h = self.self_attention(x, x, x, mask=src_mask)
        x = self.norm1(x + self.dropout1(h))

        # Feed-forward block.
        h = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(h))

        return x

    def extra_repr(self) -> str:
        return f"d_model={self.d_model}, pre_norm={self.pre_norm}"


# Testing the Encoder Layer
Comprehensive Tests
def test_encoder_layer():
    """
    Test TransformerEncoderLayer with various inputs.

    Covers: basic forward shape, padding-mask forward, gradient flow,
    parameter counting, and variable sequence lengths.

    Returns:
        The layer instance used for the tests.
    """
    # Configuration
    d_model = 512
    num_heads = 8
    d_ff = 2048
    batch_size = 2
    seq_len = 10

    # Create layer
    layer = TransformerEncoderLayer(
        d_model=d_model,
        num_heads=num_heads,
        d_ff=d_ff,
        dropout=0.1,
        pre_norm=True
    )

    print("TransformerEncoderLayer Test")
    print("=" * 50)

    # Test 1: Basic forward pass — output shape must match input shape.
    x = torch.randn(batch_size, seq_len, d_model)
    output = layer(x)

    print(f"\nTest 1: Basic forward pass")
    print(f"  Input shape: {x.shape}")
    print(f"  Output shape: {output.shape}")
    assert output.shape == x.shape, "Shape mismatch!"
    print("  ✓ Passed")

    # Test 2: With padding mask
    # Create mask: first sentence has 8 real tokens, second has 6
    mask = torch.ones(batch_size, 1, 1, seq_len)
    mask[0, 0, 0, 8:] = 0  # Mask last 2 positions
    mask[1, 0, 0, 6:] = 0  # Mask last 4 positions

    output_masked = layer(x, src_mask=mask)

    print(f"\nTest 2: With padding mask")
    print(f"  Mask shape: {mask.shape}")
    print(f"  Output shape: {output_masked.shape}")
    assert output_masked.shape == x.shape, "Shape mismatch with mask!"
    print("  ✓ Passed")

    # Test 3: Gradient flow
    x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
    output = layer(x)
    loss = output.sum()
    loss.backward()

    print(f"\nTest 3: Gradient flow")
    # BUG FIX: assert BEFORE touching x.grad — previously x.grad.shape was
    # printed first, so a missing gradient crashed before the assert fired.
    assert x.grad is not None, "No gradient computed!"
    print(f"  Input gradient shape: {x.grad.shape}")
    print(f"  Gradient norm: {x.grad.norm():.4f}")
    print("  ✓ Passed")

    # Test 4: Parameter count
    total_params = sum(p.numel() for p in layer.parameters())
    trainable_params = sum(p.numel() for p in layer.parameters() if p.requires_grad)

    print(f"\nTest 4: Parameter count")
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")
    print("  ✓ Passed")

    # Test 5: Different sequence lengths — the layer is length-agnostic.
    for test_seq_len in [5, 20, 100]:
        x_test = torch.randn(1, test_seq_len, d_model)
        output_test = layer(x_test)
        assert output_test.shape == x_test.shape
        print(f"\nTest 5: seq_len={test_seq_len} ✓")

    print("\n" + "=" * 50)
    print("All tests passed! ✓")

    return layer


# Run tests
layer = test_encoder_layer()


# Detailed Shape Analysis
Shape Tracing Through Layer
def trace_shapes(layer, x, mask=None):
    """
    Trace tensor shapes through an encoder layer, printing the shape
    after every sub-step.  Helpful for debugging shape mismatches.

    Args:
        layer: A TransformerEncoderLayer instance (Pre-LN or Post-LN).
        x: Input tensor [batch, seq_len, d_model].
        mask: Optional attention mask forwarded to self-attention.

    Returns:
        The output tensor after both sublayers.
    """
    print("Shape Trace Through Encoder Layer")
    print("=" * 50)
    print(f"\nInput: {x.shape}")

    if layer.pre_norm:
        # Pre-LN: norm -> sublayer -> dropout -> residual add.
        normed = layer.norm1(x)
        print(f"After norm1: {normed.shape}")

        attn_out = layer.self_attention(normed, normed, normed, mask=mask)
        print(f"After self_attention: {attn_out.shape}")

        attn_out = layer.dropout1(attn_out)
        print(f"After dropout1: {attn_out.shape}")

        x = x + attn_out
        print(f"After residual 1: {x.shape}")

        normed = layer.norm2(x)
        print(f"After norm2: {normed.shape}")

        ff_out = layer.feed_forward(normed)
        print(f"After feed_forward: {ff_out.shape}")

        ff_out = layer.dropout2(ff_out)
        print(f"After dropout2: {ff_out.shape}")

        x = x + ff_out
        print(f"After residual 2: {x.shape}")
    else:
        # BUG FIX: the original traced only Pre-LN layers and silently
        # returned the input unchanged for Post-LN.  Post-LN order is:
        # sublayer -> dropout -> residual add -> norm.
        attn_out = layer.self_attention(x, x, x, mask=mask)
        print(f"After self_attention: {attn_out.shape}")

        attn_out = layer.dropout1(attn_out)
        print(f"After dropout1: {attn_out.shape}")

        x = layer.norm1(x + attn_out)
        print(f"After residual 1 + norm1: {x.shape}")

        ff_out = layer.feed_forward(x)
        print(f"After feed_forward: {ff_out.shape}")

        ff_out = layer.dropout2(ff_out)
        print(f"After dropout2: {ff_out.shape}")

        x = layer.norm2(x + ff_out)
        print(f"After residual 2 + norm2: {x.shape}")

    print(f"\nFinal output: {x.shape}")
    return x


# Trace
layer = TransformerEncoderLayer(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)
output = trace_shapes(layer, x)

# Output:
1Shape Trace Through Encoder Layer
2==================================================
3
4Input: torch.Size([2, 10, 512])
5After norm1: torch.Size([2, 10, 512])
6After self_attention: torch.Size([2, 10, 512])
7After dropout1: torch.Size([2, 10, 512])
8After residual 1: torch.Size([2, 10, 512])
9After norm2: torch.Size([2, 10, 512])
10After feed_forward: torch.Size([2, 10, 512])
11After dropout2: torch.Size([2, 10, 512])
12After residual 2: torch.Size([2, 10, 512])
13
14Final output: torch.Size([2, 10, 512])
Memory Efficiency Considerations
Memory Usage Per Layer
def estimate_memory(batch_size, seq_len, d_model, num_heads, d_ff):
    """
    Estimate memory usage for one encoder-layer forward pass.

    Counts the dominant activations saved for the backward pass
    (input, attention scores, Q/K/V projections, FFN hidden state),
    assuming float32 storage.

    Args:
        batch_size: Number of sequences in the batch.
        seq_len: Sequence length.
        d_model: Model dimension.
        num_heads: Number of attention heads.
        d_ff: Feed-forward hidden dimension.

    Returns:
        Total estimated memory in bytes.
    """
    bytes_per_float = 4  # float32

    memory = {}

    # Activations saved for backward pass
    memory['input'] = batch_size * seq_len * d_model * bytes_per_float

    # Attention scores: [batch, heads, seq, seq] — quadratic in seq_len,
    # usually the dominant term for long sequences.
    memory['attn_scores'] = batch_size * num_heads * seq_len * seq_len * bytes_per_float

    # Q, K, V projections
    memory['qkv'] = 3 * batch_size * seq_len * d_model * bytes_per_float

    # FFN hidden: [batch, seq, d_ff]
    memory['ffn_hidden'] = batch_size * seq_len * d_ff * bytes_per_float

    # Total, in bytes.
    # BUG FIX: the original also accumulated megabyte values into `total`
    # inside the print loop below, double-counting and mixing bytes with MB.
    total = sum(memory.values())

    print("Memory Estimation (Single Layer Forward Pass)")
    print("=" * 50)
    print(f"Configuration:")
    print(f"  batch_size={batch_size}, seq_len={seq_len}")
    print(f"  d_model={d_model}, num_heads={num_heads}, d_ff={d_ff}")
    print(f"\nMemory breakdown:")
    for name, mem in memory.items():
        mb = mem / 1024**2
        print(f"  {name:15s}: {mb:.2f} MB")
    print(f"\nTotal: {total / 1024**2:.2f} MB")

    return total


# Example estimation
estimate_memory(
    batch_size=32,
    seq_len=512,
    d_model=512,
    num_heads=8,
    d_ff=2048
)


# Weight Initialization
Proper Initialization
def initialize_encoder_layer(layer: TransformerEncoderLayer):
    """
    Initialize weights for an encoder layer.

    Following common practices:
    - Linear layers: Xavier/Glorot uniform
    - LayerNorm: ones for weight, zeros for bias

    Args:
        layer: The encoder layer whose parameters are re-initialized.

    Returns:
        The same layer, initialized in place.
    """
    for name, param in layer.named_parameters():
        is_matrix_weight = 'weight' in name and param.dim() > 1
        if is_matrix_weight:
            nn.init.xavier_uniform_(param)   # projection matrices
        elif 'bias' in name:
            nn.init.zeros_(param)            # all biases start at zero
        elif 'norm' in name and 'weight' in name:
            nn.init.ones_(param)             # LayerNorm scale starts at one

    print("Layer initialized with Xavier uniform weights")
    return layer


# Initialize
layer = TransformerEncoderLayer(d_model=512, num_heads=8)
layer = initialize_encoder_layer(layer)


# Summary
TransformerEncoderLayer Components
# Condensed recap of the encoder layer — the full, documented version
# appears earlier in this chapter.
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout, pre_norm):
        # BUG FIX: the summary omitted super().__init__(), which is
        # required before registering any submodules.
        super().__init__()

        # Self-attention
        self.self_attention = MultiHeadAttention(d_model, num_heads, dropout)

        # Feed-forward
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, dropout)

        # Layer norms
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        # Pre-LN style: norm -> sublayer -> dropout -> residual add
        # (BUG FIX: the summary passed a literal `...` to self-attention;
        # Q = K = V for self-attention.)
        y = self.norm1(x)
        x = x + self.dropout1(self.self_attention(y, y, y, mask=src_mask))
        x = x + self.dropout2(self.feed_forward(self.norm2(x)))
        return x


# Key Implementation Details
| Aspect | Detail |
|---|---|
| Self-attention | Q=K=V for encoder |
| Mask | Padding mask only |
| Residual | Before dropout in Pre-LN |
| Shape | Preserved throughout |
| Initialization | Xavier uniform recommended |
Exercises
Implementation Exercises
1. Add an option to return attention weights from the encoder layer.
2. Implement gradient checkpointing to reduce memory usage.
3. Create a version with RMSNorm instead of LayerNorm.
Analysis Exercises
4. Profile the forward pass and identify the most time-consuming operation.
5. Compare memory usage between Pre-LN and Post-LN variants.
6. Experiment with different dropout placements.
In the next section, we'll stack multiple encoder layers to create the complete TransformerEncoder module, learning how to properly use nn.ModuleList and handle the final layer normalization.