Introduction
In this section we combine everything covered so far into a complete, production-ready MultiHeadAttention module. The implementation follows common PyTorch conventions and documents the tensor shapes at every step.
Module Architecture
Component Overview
MultiHeadAttention Component Structure
```text
MultiHeadAttention
├── Input projections (W_Q, W_K, W_V)
├── Head splitting
├── Scaled dot-product attention
├── Head combining
├── Output projection (W_O)
└── Dropout
```

Data Flow
Data Flow Through MultiHeadAttention
```text
query_input ─┐
             ├─→ W_Q ─→ Q ─┐
key_input ───┼─→ W_K ─→ K ─┼─→ split_heads ─→ attention ─→ combine_heads ─→ W_O ─→ output
             ├─→ W_V ─→ V ─┘
value_input ─┘
```

Complete Implementation
Complete MultiHeadAttention Module
🐍multi_head_attention.py
```python
"""
Multi-Head Attention Implementation
===================================

A complete, production-ready implementation of multi-head attention
as described in "Attention Is All You Need" (Vaswani et al., 2017).
"""

import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention module.

    Computes attention over multiple heads in parallel, allowing the model
    to jointly attend to information from different representation subspaces.

    Args:
        d_model: Total dimension of the model (must be divisible by num_heads)
        num_heads: Number of parallel attention heads
        dropout: Dropout probability for attention weights and output
        bias: Whether to include bias in linear projections

    Attributes:
        d_k: Dimension per head (d_model // num_heads)
        W_Q: Query projection layer
        W_K: Key projection layer
        W_V: Value projection layer
        W_O: Output projection layer

    Example:
        >>> mha = MultiHeadAttention(d_model=512, num_heads=8)
        >>> x = torch.randn(2, 10, 512)  # [batch, seq_len, d_model]
        >>> output, attn_weights = mha(x, x, x)
        >>> output.shape
        torch.Size([2, 10, 512])
        >>> attn_weights.shape
        torch.Size([2, 8, 10, 10])
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True
    ):
        super().__init__()

        # Validate inputs
        assert d_model % num_heads == 0, \
            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = math.sqrt(self.d_k)

        # Linear projections
        self.W_Q = nn.Linear(d_model, d_model, bias=bias)
        self.W_K = nn.Linear(d_model, d_model, bias=bias)
        self.W_V = nn.Linear(d_model, d_model, bias=bias)
        self.W_O = nn.Linear(d_model, d_model, bias=bias)

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Initialize parameters
        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize parameters with Xavier uniform distribution."""
        nn.init.xavier_uniform_(self.W_Q.weight)
        nn.init.xavier_uniform_(self.W_K.weight)
        nn.init.xavier_uniform_(self.W_V.weight)
        nn.init.xavier_uniform_(self.W_O.weight)

        if self.W_Q.bias is not None:
            nn.init.zeros_(self.W_Q.bias)
            nn.init.zeros_(self.W_K.bias)
            nn.init.zeros_(self.W_V.bias)
            nn.init.zeros_(self.W_O.bias)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Split the last dimension into (num_heads, d_k).

        Args:
            x: [batch, seq_len, d_model]

        Returns:
            [batch, num_heads, seq_len, d_k]
        """
        batch_size, seq_len, _ = x.shape

        # [batch, seq_len, d_model] -> [batch, seq_len, num_heads, d_k]
        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)

        # [batch, seq_len, num_heads, d_k] -> [batch, num_heads, seq_len, d_k]
        return x.transpose(1, 2)

    def _combine_heads(self, x: torch.Tensor) -> torch.Tensor:
        """
        Combine head outputs back into a single tensor.

        Args:
            x: [batch, num_heads, seq_len, d_k]

        Returns:
            [batch, seq_len, d_model]
        """
        batch_size, _, seq_len, _ = x.shape

        # [batch, num_heads, seq_len, d_k] -> [batch, seq_len, num_heads, d_k]
        x = x.transpose(1, 2)

        # [batch, seq_len, num_heads, d_k] -> [batch, seq_len, d_model]
        return x.contiguous().view(batch_size, seq_len, self.d_model)

    def _scaled_dot_product_attention(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute scaled dot-product attention.

        Args:
            Q: [batch, num_heads, seq_len_q, d_k]
            K: [batch, num_heads, seq_len_k, d_k]
            V: [batch, num_heads, seq_len_k, d_k]
            mask: [batch, 1, 1, seq_len_k] or [batch, 1, seq_len_q, seq_len_k]

        Returns:
            output: [batch, num_heads, seq_len_q, d_k]
            attention_weights: [batch, num_heads, seq_len_q, seq_len_k]
        """
        # Compute attention scores
        # Q @ K^T: [batch, heads, seq_q, d_k] @ [batch, heads, d_k, seq_k]
        #        = [batch, heads, seq_q, seq_k]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax over keys
        attention_weights = F.softmax(scores, dim=-1)

        # Handle all-masked rows (prevents NaN)
        attention_weights = torch.nan_to_num(attention_weights, nan=0.0)

        # Apply dropout to attention weights
        attention_weights = self.dropout(attention_weights)

        # Compute output
        # weights @ V: [batch, heads, seq_q, seq_k] @ [batch, heads, seq_k, d_k]
        #            = [batch, heads, seq_q, d_k]
        output = torch.matmul(attention_weights, V)

        return output, attention_weights

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attention: bool = True
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass for multi-head attention.

        Args:
            query: Query tensor [batch, seq_len_q, d_model]
            key: Key tensor [batch, seq_len_k, d_model]
            value: Value tensor [batch, seq_len_k, d_model]
            mask: Optional attention mask [batch, 1, seq_len_q, seq_len_k]
                or [batch, 1, 1, seq_len_k]
                0 = masked (ignore), 1 = attend
            return_attention: Whether to return attention weights

        Returns:
            output: [batch, seq_len_q, d_model]
            attention_weights: [batch, num_heads, seq_len_q, seq_len_k] (if return_attention)
        """
        # 1. Linear projections
        # [batch, seq_len, d_model] -> [batch, seq_len, d_model]
        Q = self.W_Q(query)
        K = self.W_K(key)
        V = self.W_V(value)

        # 2. Split into heads
        # [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]
        Q = self._split_heads(Q)
        K = self._split_heads(K)
        V = self._split_heads(V)

        # 3. Scaled dot-product attention
        # Q, K, V: [batch, num_heads, seq_len, d_k]
        # attention_output: [batch, num_heads, seq_len_q, d_k]
        # attention_weights: [batch, num_heads, seq_len_q, seq_len_k]
        attention_output, attention_weights = self._scaled_dot_product_attention(
            Q, K, V, mask
        )

        # 4. Combine heads
        # [batch, num_heads, seq_len, d_k] -> [batch, seq_len, d_model]
        combined = self._combine_heads(attention_output)

        # 5. Output projection
        # [batch, seq_len, d_model] -> [batch, seq_len, d_model]
        output = self.W_O(combined)

        # Apply dropout to output
        output = self.dropout(output)

        if return_attention:
            return output, attention_weights
        return output, None

    def extra_repr(self) -> str:
        """String representation for printing."""
        return f'd_model={self.d_model}, num_heads={self.num_heads}, d_k={self.d_k}'
```

Testing the Implementation
Basic Functionality Tests
Comprehensive MultiHeadAttention Tests
🐍test_attention.py
```python
import torch

# Assumes the MultiHeadAttention class above is importable from multi_head_attention.py
from multi_head_attention import MultiHeadAttention


def test_multi_head_attention():
    """Comprehensive tests for MultiHeadAttention."""

    print("Testing MultiHeadAttention...")
    print("-" * 50)

    # Configuration
    d_model = 512
    num_heads = 8
    batch_size = 2
    seq_len = 10

    # Create module
    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)

    # Test 1: Self-attention shapes
    print("Test 1: Self-attention shapes")
    x = torch.randn(batch_size, seq_len, d_model)
    output, weights = mha(x, x, x)

    assert output.shape == (batch_size, seq_len, d_model), \
        f"Expected output shape {(batch_size, seq_len, d_model)}, got {output.shape}"
    assert weights.shape == (batch_size, num_heads, seq_len, seq_len), \
        f"Expected weights shape {(batch_size, num_heads, seq_len, seq_len)}, got {weights.shape}"
    print(f"  Output shape: {output.shape} ✓")
    print(f"  Weights shape: {weights.shape} ✓")

    # Test 2: Cross-attention shapes
    print("\nTest 2: Cross-attention shapes")
    seq_len_q = 8
    seq_len_k = 12
    query = torch.randn(batch_size, seq_len_q, d_model)
    key = torch.randn(batch_size, seq_len_k, d_model)
    value = torch.randn(batch_size, seq_len_k, d_model)

    output, weights = mha(query, key, value)

    assert output.shape == (batch_size, seq_len_q, d_model), \
        f"Expected output shape {(batch_size, seq_len_q, d_model)}, got {output.shape}"
    assert weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k), \
        f"Expected weights shape {(batch_size, num_heads, seq_len_q, seq_len_k)}, got {weights.shape}"
    print(f"  Query shape: {query.shape}")
    print(f"  Key shape: {key.shape}")
    print(f"  Output shape: {output.shape} ✓")
    print(f"  Weights shape: {weights.shape} ✓")

    # Test 3: Attention weights sum to 1
    print("\nTest 3: Attention weights sum to 1")
    x = torch.randn(batch_size, seq_len, d_model)
    _, weights = mha(x, x, x)

    weight_sums = weights.sum(dim=-1)
    assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), \
        "Attention weights don't sum to 1"
    print(f"  Weight sums: all ≈ 1.0 ✓")

    # Test 4: Masking
    print("\nTest 4: Masking")
    x = torch.randn(batch_size, seq_len, d_model)

    # Create causal mask
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
    # [1, 1, seq_len, seq_len]

    _, weights = mha(x, x, x, mask=mask)

    # Check upper triangle is zero (masked)
    for b in range(batch_size):
        for h in range(num_heads):
            for i in range(seq_len):
                for j in range(i + 1, seq_len):
                    assert weights[b, h, i, j] < 1e-5, \
                        f"Masked position ({i},{j}) has non-zero weight: {weights[b, h, i, j]}"
    print(f"  Causal mask applied correctly ✓")

    # Test 5: Gradient flow
    print("\nTest 5: Gradient flow")
    x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
    output, _ = mha(x, x, x)
    loss = output.sum()
    loss.backward()

    assert x.grad is not None, "Gradients not computed"
    assert not torch.isnan(x.grad).any(), "NaN in gradients"
    print(f"  Gradients flow correctly ✓")

    # Test 6: Parameter count
    print("\nTest 6: Parameter count")
    total_params = sum(p.numel() for p in mha.parameters())
    expected_params = 4 * d_model * d_model + 4 * d_model  # 4 projections with bias
    assert total_params == expected_params, "Parameter count mismatch"
    print(f"  Total parameters: {total_params:,}")
    print(f"  Expected (with bias): {expected_params:,}")

    print("\n" + "-" * 50)
    print("All tests passed! ✓")


# Run tests
test_multi_head_attention()
```

Usage Examples
Self-Attention (Encoder)
Self-Attention Example
🐍usage_examples.py
```python
import torch

# Assumes the MultiHeadAttention class above is importable from multi_head_attention.py
from multi_head_attention import MultiHeadAttention

# Encoder self-attention: query, key, value are all the same
mha = MultiHeadAttention(d_model=512, num_heads=8)

# Input sequence
x = torch.randn(2, 20, 512)  # [batch, seq_len, d_model]

# Self-attention
output, weights = mha(x, x, x)

print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Weights: {weights.shape}")
```

Cross-Attention (Decoder)
Cross-Attention Example
🐍usage_examples.py
```python
# Cross-attention: query from decoder, key/value from encoder
mha = MultiHeadAttention(d_model=512, num_heads=8)

# Encoder output
encoder_output = torch.randn(2, 30, 512)  # [batch, src_len, d_model]

# Decoder state
decoder_state = torch.randn(2, 20, 512)  # [batch, tgt_len, d_model]

# Cross-attention
output, weights = mha(
    query=decoder_state,
    key=encoder_output,
    value=encoder_output
)

print(f"Encoder output: {encoder_output.shape}")
print(f"Decoder state: {decoder_state.shape}")
print(f"Cross-attention output: {output.shape}")
print(f"Cross-attention weights: {weights.shape}")  # [batch, heads, tgt_len, src_len]
```

With Masking
Masked Self-Attention Example
🐍usage_examples.py
```python
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create causal mask for decoder self-attention."""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]


# Masked self-attention
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)

mask = create_causal_mask(10)
output, weights = mha(x, x, x, mask=mask)

# Verify masking
print("Attention weights for position 0:")
print(weights[0, 0, 0, :])  # Should only attend to position 0
print("\nAttention weights for position 5:")
print(weights[0, 0, 5, :])  # Should only attend to positions 0-5
```

Integration with PyTorch's Built-in
Comparison
Comparing with PyTorch Built-in
🐍pytorch_comparison.py
```python
import torch
import torch.nn as nn

# Assumes the MultiHeadAttention class above is importable from multi_head_attention.py
from multi_head_attention import MultiHeadAttention


def compare_with_pytorch():
    """Compare our implementation with PyTorch's nn.MultiheadAttention."""

    d_model = 512
    num_heads = 8
    batch_size = 2
    seq_len = 10

    # Our implementation
    our_mha = MultiHeadAttention(d_model, num_heads, dropout=0.0, bias=True)

    # PyTorch's implementation
    # Note: PyTorch expects [seq_len, batch, d_model] unless batch_first=True
    pytorch_mha = nn.MultiheadAttention(d_model, num_heads, dropout=0.0, batch_first=True)

    # Input
    x = torch.randn(batch_size, seq_len, d_model)

    # Our output
    our_output, our_weights = our_mha(x, x, x)

    # PyTorch output
    # nn.MultiheadAttention averages attention weights over heads by default;
    # pass average_attn_weights=False to get per-head weights for comparison
    pytorch_output, pytorch_weights = pytorch_mha(x, x, x, average_attn_weights=False)

    print(f"Our output shape: {our_output.shape}")
    print(f"PyTorch output shape: {pytorch_output.shape}")
    print(f"Our weights shape: {our_weights.shape}")
    print(f"PyTorch weights shape: {pytorch_weights.shape}")

    # Note: Outputs will differ due to different random initialization,
    # but shapes should match.


compare_with_pytorch()
```

Performance Optimization
Fused Projection
Efficient MultiHeadAttention with Fused QKV
🐍efficient_attention.py
```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class EfficientMultiHeadAttention(nn.Module):
    """
    Efficient multi-head attention with fused QKV projection.

    Self-attention only: a single linear layer produces Q, K, and V in
    one matrix multiply instead of three separate projections.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()

        assert d_model % num_heads == 0, \
            "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = math.sqrt(self.d_k)

        # Fused QKV projection (more efficient)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Forward pass for self-attention only."""
        batch_size, seq_len, _ = x.shape

        # Fused QKV projection
        qkv = self.qkv_proj(x)  # [batch, seq, 3*d_model]
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, heads, seq, d_k]
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # Output
        out = torch.matmul(weights, V)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        out = self.out_proj(out)

        return out, weights
```

Summary
Module Structure
| Component | Shape Transformation |
|---|---|
| Input | [batch, seq, d_model] |
| W_Q/W_K/W_V | [batch, seq, d_model] → [batch, seq, d_model] |
| split_heads | [batch, seq, d_model] → [batch, heads, seq, d_k] |
| attention | [batch, heads, seq, d_k] → [batch, heads, seq, d_k] |
| combine_heads | [batch, heads, seq, d_k] → [batch, seq, d_model] |
| W_O | [batch, seq, d_model] → [batch, seq, d_model] |
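The shape transformations in the table can be spot-checked with plain tensor operations. This is a minimal sketch, independent of the module above; the names batch, heads, and d_k are local to the example, and Q is reused in place of V purely for shape checking:

```python
import torch

batch, seq, d_model, heads = 2, 10, 512, 8
d_k = d_model // heads

x = torch.randn(batch, seq, d_model)                     # input
q = x.view(batch, seq, heads, d_k).transpose(1, 2)       # split_heads
assert q.shape == (batch, heads, seq, d_k)

scores = q @ q.transpose(-2, -1) / d_k ** 0.5            # attention scores
assert scores.shape == (batch, heads, seq, seq)

out = scores.softmax(dim=-1) @ q                          # weights @ V
combined = out.transpose(1, 2).contiguous().view(batch, seq, d_model)  # combine_heads
assert combined.shape == (batch, seq, d_model)            # back to input shape
```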
Key Implementation Points
- Validate dimensions: d_model must be divisible by num_heads
- Initialize properly: Xavier uniform for attention weights
- Handle masks: Apply before softmax, use -inf for masked positions
- Memory layout: Use contiguous() before view after transpose
- Return weights: Useful for visualization and debugging
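The memory-layout point is easy to verify in isolation: after transpose, a tensor is no longer contiguous, so view raises an error until contiguous() (or reshape) is used. A short demonstration with the shapes used in this chapter:

```python
import torch

x = torch.randn(2, 8, 10, 64)        # [batch, heads, seq, d_k]
t = x.transpose(1, 2)                # [batch, seq, heads, d_k], non-contiguous

try:
    t.view(2, 10, 512)               # fails: strides no longer match a flat layout
except RuntimeError as e:
    print("view failed:", e)

ok = t.contiguous().view(2, 10, 512)  # copy into contiguous memory, then view
assert ok.shape == (2, 10, 512)
# t.reshape(2, 10, 512) also works: it copies only when necessary
```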
Exercises
Implementation Exercises
- Add a need_weights parameter that skips storing attention weights when False (for memory efficiency).
- Implement rotary position embeddings (RoPE) inside the attention computation.
- Add flash attention support using PyTorch 2.0's F.scaled_dot_product_attention.
Testing Exercises
- Write a test that verifies the module is equivariant to permutation of heads.
- Benchmark memory usage as sequence length increases from 128 to 4096.
- Test numerical stability with very long sequences and FP16.