Chapter 3

Implementing MultiHeadAttention as nn.Module

Introduction

Now we combine everything we've learned into a complete, production-ready MultiHeadAttention module. This implementation follows best practices and includes comprehensive documentation for every operation.


Module Architecture

Component Overview

MultiHeadAttention Component Structure
MultiHeadAttention
├── Input projections (W_Q, W_K, W_V)
├── Head splitting
├── Scaled dot-product attention
├── Head combining
├── Output projection (W_O)
└── Dropout

Data Flow

Data Flow Through MultiHeadAttention
query_input ─┐
             ├─→ W_Q ─→ Q ─┐
key_input ───┼─→ W_K ─→ K ─┼─→ split_heads ─→ attention ─→ combine_heads ─→ W_O ─→ output
             ├─→ W_V ─→ V ─┘
value_input ─┘

Complete Implementation

Complete MultiHeadAttention Module
multi_head_attention.py
1"""
2Multi-Head Attention Implementation
3===================================
4
5A complete, production-ready implementation of multi-head attention
6as described in "Attention Is All You Need" (Vaswani et al., 2017).
7"""
8
9import torch
10import torch.nn as nn
11import torch.nn.functional as F
12import math
13from typing import Optional, Tuple
14
15
16class MultiHeadAttention(nn.Module):
17    """
18    Multi-Head Attention module.
19
20    Computes attention over multiple heads in parallel, allowing the model
21    to jointly attend to information from different representation subspaces.
22
23    Args:
24        d_model: Total dimension of the model (must be divisible by num_heads)
25        num_heads: Number of parallel attention heads
26        dropout: Dropout probability for attention weights and output
27        bias: Whether to include bias in linear projections
28
29    Attributes:
30        d_k: Dimension per head (d_model // num_heads)
31        W_Q: Query projection layer
32        W_K: Key projection layer
33        W_V: Value projection layer
34        W_O: Output projection layer
35
36    Example:
37        >>> mha = MultiHeadAttention(d_model=512, num_heads=8)
38        >>> x = torch.randn(2, 10, 512)  # [batch, seq_len, d_model]
39        >>> output, attn_weights = mha(x, x, x)
40        >>> output.shape
41        torch.Size([2, 10, 512])
42        >>> attn_weights.shape
43        torch.Size([2, 8, 10, 10])
44    """
45
46    def __init__(
47        self,
48        d_model: int,
49        num_heads: int,
50        dropout: float = 0.0,
51        bias: bool = True
52    ):
53        super().__init__()
54
55        # Validate inputs
56        assert d_model % num_heads == 0, \
57            f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
58
59        self.d_model = d_model
60        self.num_heads = num_heads
61        self.d_k = d_model // num_heads
62        self.scale = math.sqrt(self.d_k)
63
64        # Linear projections
65        self.W_Q = nn.Linear(d_model, d_model, bias=bias)
66        self.W_K = nn.Linear(d_model, d_model, bias=bias)
67        self.W_V = nn.Linear(d_model, d_model, bias=bias)
68        self.W_O = nn.Linear(d_model, d_model, bias=bias)
69
70        # Dropout
71        self.dropout = nn.Dropout(dropout)
72
73        # Initialize parameters
74        self._reset_parameters()
75
76    def _reset_parameters(self):
77        """Initialize parameters with Xavier uniform distribution."""
78        nn.init.xavier_uniform_(self.W_Q.weight)
79        nn.init.xavier_uniform_(self.W_K.weight)
80        nn.init.xavier_uniform_(self.W_V.weight)
81        nn.init.xavier_uniform_(self.W_O.weight)
82
83        if self.W_Q.bias is not None:
84            nn.init.zeros_(self.W_Q.bias)
85            nn.init.zeros_(self.W_K.bias)
86            nn.init.zeros_(self.W_V.bias)
87            nn.init.zeros_(self.W_O.bias)
88
89    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
90        """
91        Split the last dimension into (num_heads, d_k).
92
93        Args:
94            x: [batch, seq_len, d_model]
95
96        Returns:
97            [batch, num_heads, seq_len, d_k]
98        """
99        batch_size, seq_len, _ = x.shape
100
101        # [batch, seq_len, d_model] -> [batch, seq_len, num_heads, d_k]
102        x = x.view(batch_size, seq_len, self.num_heads, self.d_k)
103
104        # [batch, seq_len, num_heads, d_k] -> [batch, num_heads, seq_len, d_k]
105        return x.transpose(1, 2)
106
107    def _combine_heads(self, x: torch.Tensor) -> torch.Tensor:
108        """
109        Combine head outputs back into single tensor.
110
111        Args:
112            x: [batch, num_heads, seq_len, d_k]
113
114        Returns:
115            [batch, seq_len, d_model]
116        """
117        batch_size, _, seq_len, _ = x.shape
118
119        # [batch, num_heads, seq_len, d_k] -> [batch, seq_len, num_heads, d_k]
120        x = x.transpose(1, 2)
121
122        # [batch, seq_len, num_heads, d_k] -> [batch, seq_len, d_model]
123        return x.contiguous().view(batch_size, seq_len, self.d_model)
124
125    def _scaled_dot_product_attention(
126        self,
127        Q: torch.Tensor,
128        K: torch.Tensor,
129        V: torch.Tensor,
130        mask: Optional[torch.Tensor] = None
131    ) -> Tuple[torch.Tensor, torch.Tensor]:
132        """
133        Compute scaled dot-product attention.
134
135        Args:
136            Q: [batch, num_heads, seq_len_q, d_k]
137            K: [batch, num_heads, seq_len_k, d_k]
138            V: [batch, num_heads, seq_len_k, d_k]
139            mask: [batch, 1, 1, seq_len_k] or [batch, 1, seq_len_q, seq_len_k]
140
141        Returns:
142            output: [batch, num_heads, seq_len_q, d_k]
143            attention_weights: [batch, num_heads, seq_len_q, seq_len_k]
144        """
145        # Compute attention scores
146        # Q @ K^T: [batch, heads, seq_q, d_k] @ [batch, heads, d_k, seq_k]
147        #        = [batch, heads, seq_q, seq_k]
148        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
149
150        # Apply mask
151        if mask is not None:
152            scores = scores.masked_fill(mask == 0, float('-inf'))
153
154        # Softmax over keys
155        attention_weights = F.softmax(scores, dim=-1)
156
157        # Handle all-masked rows (prevents NaN)
158        attention_weights = torch.nan_to_num(attention_weights, nan=0.0)
159
160        # Apply dropout to attention weights
161        attention_weights = self.dropout(attention_weights)
162
163        # Compute output
164        # weights @ V: [batch, heads, seq_q, seq_k] @ [batch, heads, seq_k, d_k]
165        #            = [batch, heads, seq_q, d_k]
166        output = torch.matmul(attention_weights, V)
167
168        return output, attention_weights
169
170    def forward(
171        self,
172        query: torch.Tensor,
173        key: torch.Tensor,
174        value: torch.Tensor,
175        mask: Optional[torch.Tensor] = None,
176        return_attention: bool = True
177    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
178        """
179        Forward pass for multi-head attention.
180
181        Args:
182            query: Query tensor [batch, seq_len_q, d_model]
183            key: Key tensor [batch, seq_len_k, d_model]
184            value: Value tensor [batch, seq_len_k, d_model]
185            mask: Optional attention mask [batch, 1, seq_len_q, seq_len_k]
186                  or [batch, 1, 1, seq_len_k]
187                  0 = masked (ignore), 1 = attend
188            return_attention: Whether to return attention weights
189
190        Returns:
191            output: [batch, seq_len_q, d_model]
192            attention_weights: [batch, num_heads, seq_len_q, seq_len_k] (if return_attention)
193        """
194        batch_size = query.size(0)
195
196        # 1. Linear projections
197        # [batch, seq_len, d_model] -> [batch, seq_len, d_model]
198        Q = self.W_Q(query)
199        K = self.W_K(key)
200        V = self.W_V(value)
201
202        # 2. Split into heads
203        # [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]
204        Q = self._split_heads(Q)
205        K = self._split_heads(K)
206        V = self._split_heads(V)
207
208        # 3. Scaled dot-product attention
209        # Q, K, V: [batch, num_heads, seq_len, d_k]
210        # attention_output: [batch, num_heads, seq_len_q, d_k]
211        # attention_weights: [batch, num_heads, seq_len_q, seq_len_k]
212        attention_output, attention_weights = self._scaled_dot_product_attention(
213            Q, K, V, mask
214        )
215
216        # 4. Combine heads
217        # [batch, num_heads, seq_len, d_k] -> [batch, seq_len, d_model]
218        combined = self._combine_heads(attention_output)
219
220        # 5. Output projection
221        # [batch, seq_len, d_model] -> [batch, seq_len, d_model]
222        output = self.W_O(combined)
223
224        # Apply dropout to output
225        output = self.dropout(output)
226
227        if return_attention:
228            return output, attention_weights
229        return output, None
230
231    def extra_repr(self) -> str:
232        """String representation for printing."""
233        return f'd_model={self.d_model}, num_heads={self.num_heads}, d_k={self.d_k}'

Testing the Implementation

Basic Functionality Tests

Comprehensive MultiHeadAttention Tests
test_attention.py
import torch

from multi_head_attention import MultiHeadAttention


def test_multi_head_attention():
    """Comprehensive tests for MultiHeadAttention."""

    print("Testing MultiHeadAttention...")
    print("-" * 50)

    # Configuration
    d_model = 512
    num_heads = 8
    batch_size = 2
    seq_len = 10

    # Create module
    mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads, dropout=0.0)

    # Test 1: Self-attention shapes
    print("Test 1: Self-attention shapes")
    x = torch.randn(batch_size, seq_len, d_model)
    output, weights = mha(x, x, x)

    assert output.shape == (batch_size, seq_len, d_model), \
        f"Expected output shape {(batch_size, seq_len, d_model)}, got {output.shape}"
    assert weights.shape == (batch_size, num_heads, seq_len, seq_len), \
        f"Expected weights shape {(batch_size, num_heads, seq_len, seq_len)}, got {weights.shape}"
    print(f"  Output shape: {output.shape} ✓")
    print(f"  Weights shape: {weights.shape} ✓")

    # Test 2: Cross-attention shapes
    print("\nTest 2: Cross-attention shapes")
    seq_len_q = 8
    seq_len_k = 12
    query = torch.randn(batch_size, seq_len_q, d_model)
    key = torch.randn(batch_size, seq_len_k, d_model)
    value = torch.randn(batch_size, seq_len_k, d_model)

    output, weights = mha(query, key, value)

    assert output.shape == (batch_size, seq_len_q, d_model), \
        f"Expected output shape {(batch_size, seq_len_q, d_model)}, got {output.shape}"
    assert weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k), \
        f"Expected weights shape {(batch_size, num_heads, seq_len_q, seq_len_k)}, got {weights.shape}"
    print(f"  Query shape: {query.shape}")
    print(f"  Key shape: {key.shape}")
    print(f"  Output shape: {output.shape} ✓")
    print(f"  Weights shape: {weights.shape} ✓")

    # Test 3: Attention weights sum to 1
    print("\nTest 3: Attention weights sum to 1")
    x = torch.randn(batch_size, seq_len, d_model)
    _, weights = mha(x, x, x)

    weight_sums = weights.sum(dim=-1)
    assert torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5), \
        "Attention weights don't sum to 1"
    print(f"  Weight sums: all ≈ 1.0 ✓")

    # Test 4: Masking
    print("\nTest 4: Masking")
    x = torch.randn(batch_size, seq_len, d_model)

    # Create causal mask
    mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0).unsqueeze(0)
    # [1, 1, seq_len, seq_len]

    _, weights = mha(x, x, x, mask=mask)

    # Check upper triangle is zero (masked)
    for b in range(batch_size):
        for h in range(num_heads):
            for i in range(seq_len):
                for j in range(i + 1, seq_len):
                    assert weights[b, h, i, j] < 1e-5, \
                        f"Masked position ({i},{j}) has non-zero weight: {weights[b, h, i, j]}"
    print(f"  Causal mask applied correctly ✓")

    # Test 5: Gradient flow
    print("\nTest 5: Gradient flow")
    x = torch.randn(batch_size, seq_len, d_model, requires_grad=True)
    output, _ = mha(x, x, x)
    loss = output.sum()
    loss.backward()

    assert x.grad is not None, "Gradients not computed"
    assert not torch.isnan(x.grad).any(), "NaN in gradients"
    print(f"  Gradients flow correctly ✓")

    # Test 6: Parameter count
    print("\nTest 6: Parameter count")
    total_params = sum(p.numel() for p in mha.parameters())
    expected_params = 4 * d_model * d_model + 4 * d_model  # 4 projections with bias
    assert total_params == expected_params, \
        f"Expected {expected_params:,} parameters, got {total_params:,}"
    print(f"  Total parameters: {total_params:,}")
    print(f"  Expected (with bias): {expected_params:,}")

    print("\n" + "-" * 50)
    print("All tests passed! ✓")


# Run tests
test_multi_head_attention()

Usage Examples

Self-Attention (Encoder)

Self-Attention Example
usage_examples.py
import torch

from multi_head_attention import MultiHeadAttention

# Encoder self-attention: query, key, value are all the same
mha = MultiHeadAttention(d_model=512, num_heads=8)

# Input sequence
x = torch.randn(2, 20, 512)  # [batch, seq_len, d_model]

# Self-attention
output, weights = mha(x, x, x)

print(f"Input: {x.shape}")
print(f"Output: {output.shape}")
print(f"Weights: {weights.shape}")

Cross-Attention (Decoder)

Cross-Attention Example
usage_examples.py
# Cross-attention: query from decoder, key/value from encoder
mha = MultiHeadAttention(d_model=512, num_heads=8)

# Encoder output
encoder_output = torch.randn(2, 30, 512)  # [batch, src_len, d_model]

# Decoder state
decoder_state = torch.randn(2, 20, 512)  # [batch, tgt_len, d_model]

# Cross-attention
output, weights = mha(
    query=decoder_state,
    key=encoder_output,
    value=encoder_output
)

print(f"Encoder output: {encoder_output.shape}")
print(f"Decoder state: {decoder_state.shape}")
print(f"Cross-attention output: {output.shape}")
print(f"Cross-attention weights: {weights.shape}")  # [batch, heads, tgt_len, src_len]

With Masking

Masked Self-Attention Example
usage_examples.py
def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Create causal mask for decoder self-attention."""
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]


# Masked self-attention
mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)

mask = create_causal_mask(10)
output, weights = mha(x, x, x, mask=mask)

# Verify masking
print("Attention weights for position 0:")
print(weights[0, 0, 0, :])  # Should only attend to position 0
print("\nAttention weights for position 5:")
print(weights[0, 0, 5, :])  # Should only attend to positions 0-5
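The forward pass also accepts a [batch, 1, 1, seq_len_k] padding mask, as noted in its docstring. Here is a minimal sketch of building one from per-sequence lengths; the create_padding_mask helper and its signature are illustrative, not part of the chapter's module:

```python
import torch

def create_padding_mask(lengths: list, max_len: int) -> torch.Tensor:
    """Build a [batch, 1, 1, max_len] mask: 1 = real token, 0 = padding."""
    positions = torch.arange(max_len).unsqueeze(0)         # [1, max_len]
    mask = positions < torch.tensor(lengths).unsqueeze(1)  # [batch, max_len]
    return mask.unsqueeze(1).unsqueeze(2).float()          # [batch, 1, 1, max_len]

# Two sequences with true lengths 3 and 5, padded to length 5
mask = create_padding_mask([3, 5], max_len=5)
print(mask.shape)     # torch.Size([2, 1, 1, 5])
print(mask[0, 0, 0])  # tensor([1., 1., 1., 0., 0.])
```

The two singleton dimensions let the mask broadcast across heads and query positions inside the attention computation.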

Integration with PyTorch's Built-in Module

Comparison

Comparing with PyTorch Built-in
pytorch_comparison.py
import torch
import torch.nn as nn

from multi_head_attention import MultiHeadAttention


def compare_with_pytorch():
    """Compare our implementation with PyTorch's nn.MultiheadAttention."""

    d_model = 512
    num_heads = 8
    batch_size = 2
    seq_len = 10

    # Our implementation
    our_mha = MultiHeadAttention(d_model, num_heads, dropout=0.0, bias=True)

    # PyTorch's implementation
    # Note: PyTorch expects [seq_len, batch, d_model] unless batch_first=True
    pytorch_mha = nn.MultiheadAttention(d_model, num_heads, dropout=0.0, batch_first=True)

    # Input
    x = torch.randn(batch_size, seq_len, d_model)

    # Our output
    our_output, our_weights = our_mha(x, x, x)

    # PyTorch output
    # Note: nn.MultiheadAttention averages attention weights across heads by
    # default, so pytorch_weights is [batch, seq_q, seq_k], not
    # [batch, num_heads, seq_q, seq_k]
    pytorch_output, pytorch_weights = pytorch_mha(x, x, x)

    print(f"Our output shape: {our_output.shape}")
    print(f"PyTorch output shape: {pytorch_output.shape}")
    print(f"Our weights shape: {our_weights.shape}")
    print(f"PyTorch weights shape: {pytorch_weights.shape}")

    # Note: Outputs will differ due to different random initialization,
    # but the output shapes should match.


compare_with_pytorch()
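The head-averaging behavior is worth seeing directly: by default nn.MultiheadAttention returns weights of shape [batch, seq_q, seq_k], and passing average_attn_weights=False (available in reasonably recent PyTorch releases) recovers the per-head shape our module returns:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
x = torch.randn(2, 10, 512)

_, avg_w = mha(x, x, x)                               # averaged over heads
_, head_w = mha(x, x, x, average_attn_weights=False)  # per-head weights

print(avg_w.shape)   # torch.Size([2, 10, 10])
print(head_w.shape)  # torch.Size([2, 8, 10, 10])
```

The averaged tensor is simply the mean of the per-head weights over the head dimension.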

Performance Optimization

Fused Projection

Efficient MultiHeadAttention with Fused QKV
efficient_attention.py
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class EfficientMultiHeadAttention(nn.Module):
    """
    Efficient multi-head attention with fused QKV projection.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.0):
        super().__init__()

        assert d_model % num_heads == 0, \
            "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = math.sqrt(self.d_k)

        # Fused QKV projection (more efficient)
        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Forward pass for self-attention only."""
        batch_size, seq_len, _ = x.shape

        # Fused QKV projection
        qkv = self.qkv_proj(x)  # [batch, seq, 3*d_model]
        qkv = qkv.view(batch_size, seq_len, 3, self.num_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, heads, seq, d_k]
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # Attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        weights = self.dropout(weights)

        # Output
        out = torch.matmul(weights, V)
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        out = self.out_proj(out)

        return out, weights
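A quick way to convince yourself the fused projection is equivalent to three separate ones: stack the three weight matrices row-wise into a single linear layer and compare the outputs. This sketch uses throwaway layers (w_q, w_k, w_v, fused are illustrative names, not the classes above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 64

# Three separate projections
w_q = nn.Linear(d_model, d_model, bias=False)
w_k = nn.Linear(d_model, d_model, bias=False)
w_v = nn.Linear(d_model, d_model, bias=False)

# One fused layer whose weight is the three matrices stacked row-wise
fused = nn.Linear(d_model, 3 * d_model, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([w_q.weight, w_k.weight, w_v.weight], dim=0))

x = torch.randn(2, 5, d_model)
q2, k2, v2 = fused(x).chunk(3, dim=-1)

print(torch.allclose(w_q(x), q2, atol=1e-6))  # True
print(torch.allclose(w_k(x), k2, atol=1e-6))  # True
print(torch.allclose(w_v(x), v2, atol=1e-6))  # True
```

The efficiency gain comes from launching one large matrix multiply instead of three smaller ones.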

Summary

Module Structure

Component       Shape Transformation
Input           [batch, seq, d_model]
W_Q/W_K/W_V     [batch, seq, d_model] → [batch, seq, d_model]
split_heads     [batch, seq, d_model] → [batch, heads, seq, d_k]
attention       [batch, heads, seq, d_k] → [batch, heads, seq, d_k]
combine_heads   [batch, heads, seq, d_k] → [batch, seq, d_model]
W_O             [batch, seq, d_model] → [batch, seq, d_model]
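The table can be replayed with plain tensor ops; this sketch uses a random matrix in place of the learned W_Q and skips the attention step, which leaves the shape unchanged:

```python
import torch

batch, seq, d_model, heads = 2, 10, 512, 8
d_k = d_model // heads

x = torch.randn(batch, seq, d_model)                            # Input
projected = x @ torch.randn(d_model, d_model)                   # after W_Q
split = projected.view(batch, seq, heads, d_k).transpose(1, 2)  # split_heads
# attention would keep [batch, heads, seq, d_k]
combined = split.transpose(1, 2).contiguous().view(batch, seq, d_model)  # combine_heads

print(projected.shape)  # torch.Size([2, 10, 512])
print(split.shape)      # torch.Size([2, 8, 10, 64])
print(combined.shape)   # torch.Size([2, 10, 512])
```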

Key Implementation Points

  1. Validate dimensions: d_model must be divisible by num_heads
  2. Initialize properly: Xavier uniform for the projection weights, zeros for biases
  3. Handle masks: Apply before softmax, using -inf at masked positions
  4. Memory layout: Call contiguous() before view after a transpose
  5. Return weights: Useful for visualization and debugging
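Point 4 is easy to demonstrate: after a transpose the tensor's strides no longer describe a flat memory layout, so view fails until contiguous() materializes a copy (reshape would also work, at the cost of hiding that copy):

```python
import torch

x = torch.randn(2, 8, 10, 64)  # [batch, heads, seq, d_k]
t = x.transpose(1, 2)          # [batch, seq, heads, d_k], non-contiguous

try:
    t.view(2, 10, 512)
except RuntimeError:
    print("view failed on the non-contiguous tensor")

y = t.contiguous().view(2, 10, 512)
print(y.shape)  # torch.Size([2, 10, 512])
```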

Exercises

Implementation Exercises

  1. Add a need_weights parameter that skips storing attention weights when False (for memory efficiency).
  2. Implement rotary position embeddings (RoPE) inside the attention computation.
  3. Add flash attention support using PyTorch 2.0's F.scaled_dot_product_attention.

Testing Exercises

  1. Write a test that verifies the module is equivariant to permutation of heads.
  2. Benchmark memory usage as sequence length increases from 128 to 4096.
  3. Test numerical stability with very long sequences and FP16.