Introduction

The trickiest part of multi-head attention isn't the math; it's the tensor reshaping. We need to transform our projected Q, K, V tensors to compute all heads in parallel, then reverse the transformation afterward.
This section provides a detailed, step-by-step guide to these tensor manipulations with explicit shape tracking at every step.
The Goal: Parallel Head Computation

What We Want

Instead of looping over heads:
Sequential Head Computation (Slow)
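The original code block is missing here, so the following is an illustrative sketch of the loop-based approach; the per-head weights `W_q`, `W_k`, `W_v` are hypothetical placeholders, not part of the original:

```python
import math
import torch
import torch.nn.functional as F

batch_size, seq_len, d_model, num_heads = 2, 4, 8, 2
d_k = d_model // num_heads

x = torch.randn(batch_size, seq_len, d_model)
W_q = torch.randn(num_heads, d_model, d_k)  # hypothetical per-head weights
W_k = torch.randn(num_heads, d_model, d_k)
W_v = torch.randn(num_heads, d_model, d_k)

# Slow: a Python loop launches one small matmul per head.
head_outputs = []
for h in range(num_heads):
    Q_h = x @ W_q[h]                                   # [batch, seq_len, d_k]
    K_h = x @ W_k[h]
    V_h = x @ W_v[h]
    scores = Q_h @ K_h.transpose(-2, -1) / math.sqrt(d_k)
    head_outputs.append(F.softmax(scores, dim=-1) @ V_h)

output = torch.cat(head_outputs, dim=-1)               # [batch, seq_len, d_model]
print(output.shape)  # torch.Size([2, 4, 8])
```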
We want parallel computation:
Parallel Head Computation (Fast)
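With the head dimension folded in next to the batch dimension, a single batched matmul computes every head at once. A minimal sketch, starting from Q, K, V already split into heads:

```python
import math
import torch
import torch.nn.functional as F

batch_size, seq_len, d_model, num_heads = 2, 4, 8, 2
d_k = d_model // num_heads

# Q, K, V already in split-head form: [batch, num_heads, seq_len, d_k]
Q = torch.randn(batch_size, num_heads, seq_len, d_k)
K = torch.randn(batch_size, num_heads, seq_len, d_k)
V = torch.randn(batch_size, num_heads, seq_len, d_k)

# One batched matmul computes the scores for all heads in parallel.
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # [batch, num_heads, seq_len, seq_len]
weights = F.softmax(scores, dim=-1)
output = weights @ V                               # [batch, num_heads, seq_len, d_k]
print(output.shape)  # torch.Size([2, 2, 4, 4])
```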
The Reshape-Transpose Pattern

Step-by-Step Transformation

Starting shape: [batch, seq_len, d_model]
Target shape: [batch, num_heads, seq_len, d_k]
Why This Order?

Why reshape then transpose (not transpose then reshape)?

Memory layout matters. After projection, each token's row of d_model values stores head 0's d_k features first, then head 1's, and so on.

Reshape groups these features by head: [batch, seq_len, num_heads, d_k].
Transpose then puts the head dimension before seq_len, so batched matmul treats each head as an independent batch element.
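A small labeled tensor makes the ordering concrete. This sketch shows that view-then-transpose keeps each head's features together, while viewing straight to the target shape reinterprets memory and scrambles tokens into "heads":

```python
import torch

batch, seq_len, num_heads, d_k = 1, 3, 2, 2
d_model = num_heads * d_k

# Number every element so we can see where it lands.
x = torch.arange(seq_len * d_model, dtype=torch.float32).view(batch, seq_len, d_model)

# Correct: reshape first, then transpose.
good = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)
print(good[0, 0])  # head 0: the first d_k features of every token

# Wrong: viewing directly to [batch, num_heads, seq_len, d_k]
# reinterprets memory, so "head 0" becomes the first half of the sequence.
bad = x.view(batch, num_heads, seq_len, d_k)
print(bad[0, 0])   # not a head at all -- tokens 0 and 1 mixed together
```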
Implementation: split_heads

The Function
```python
import torch

def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    Split the last dimension into (num_heads, d_k).

    Args:
        x: Input tensor of shape [batch, seq_len, d_model]
        num_heads: Number of attention heads

    Returns:
        Tensor of shape [batch, num_heads, seq_len, d_k]
    """
    batch_size, seq_len, d_model = x.shape
    d_k = d_model // num_heads

    # Step 1: Reshape [batch, seq_len, d_model] -> [batch, seq_len, num_heads, d_k]
    x = x.view(batch_size, seq_len, num_heads, d_k)

    # Step 2: Transpose [batch, seq_len, num_heads, d_k] -> [batch, num_heads, seq_len, d_k]
    x = x.transpose(1, 2)

    return x


# Example with explicit shape tracking
batch_size = 2
seq_len = 4
d_model = 8
num_heads = 2
d_k = d_model // num_heads  # 4

x = torch.randn(batch_size, seq_len, d_model)
print(f"Input shape: {x.shape}")  # [2, 4, 8]

# Step by step
x_reshaped = x.view(batch_size, seq_len, num_heads, d_k)
print(f"After view: {x_reshaped.shape}")  # [2, 4, 2, 4]

x_transposed = x_reshaped.transpose(1, 2)
print(f"After transpose: {x_transposed.shape}")  # [2, 2, 4, 4]

# Or in one function call
x_split = split_heads(x, num_heads)
print(f"split_heads output: {x_split.shape}")  # [2, 2, 4, 4]
```
Visual Representation of split_heads
Implementation: combine_heads

The Reverse Operation

After attention, we need to reverse the transformation:
[batch, num_heads, seq_len, d_k] → [batch, seq_len, num_heads, d_k] → [batch, seq_len, d_model]

The Function
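The original function listing is missing here; the sketch below mirrors split_heads and matches the inline helper used in the complete demo later in this section:

```python
import torch

def combine_heads(x: torch.Tensor) -> torch.Tensor:
    """
    Merge the heads back into a single d_model dimension.

    Args:
        x: Tensor of shape [batch, num_heads, seq_len, d_k]

    Returns:
        Tensor of shape [batch, seq_len, num_heads * d_k]
    """
    batch_size, num_heads, seq_len, d_k = x.shape

    # Step 1: Transpose [batch, num_heads, seq_len, d_k] -> [batch, seq_len, num_heads, d_k]
    x = x.transpose(1, 2)

    # Step 2: Make memory contiguous, then flatten the last two dimensions.
    x = x.contiguous().view(batch_size, seq_len, num_heads * d_k)
    return x


attention_output = torch.randn(2, 2, 4, 4)  # [batch, num_heads, seq_len, d_k]
combined = combine_heads(attention_output)
print(combined.shape)  # torch.Size([2, 4, 8])
```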
Why contiguous()?

After transpose, the tensor may not have a contiguous memory layout:
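A quick sketch of the behavior: transpose only swaps strides, leaving the data where it was, so a subsequent view() fails until contiguous() copies the data into row-major order:

```python
import torch

x = torch.randn(2, 4, 8)                 # [batch, seq_len, d_model]
y = x.view(2, 4, 2, 4).transpose(1, 2)   # [batch, num_heads, seq_len, d_k]

print(x.is_contiguous())  # True
print(y.is_contiguous())  # False: transpose swapped strides, not data

# view() can only reinterpret contiguous memory, so this raises:
try:
    y.view(2, 2, 16)
except RuntimeError as err:
    print(f"view() raised: {err}")

# contiguous() copies the data into row-major order; view() then works.
z = y.contiguous().view(2, 2, 16)
print(z.shape)  # torch.Size([2, 2, 16])
```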
view() vs reshape()

The Difference

view():
- Requires contiguous memory
- Returns a view (shares data with the original)
- Fails if the tensor is not contiguous

reshape():
- Works on non-contiguous tensors
- May return a copy if needed
- More flexible, but potentially slower

Recommendation

For performance-critical code, contiguous().view() is more explicit about what's happening.
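The difference is easy to see directly. A short sketch comparing the two on the same non-contiguous tensor, plus a check that view() shares storage:

```python
import torch

x = torch.randn(2, 4, 2, 4).transpose(1, 2)  # non-contiguous

# reshape() handles the non-contiguous input, copying if it must.
r = x.reshape(2, 2, 16)
print(r.shape)  # torch.Size([2, 2, 16])

# view() refuses: it can only reinterpret contiguous memory.
try:
    x.view(2, 2, 16)
except RuntimeError:
    print("view() requires a contiguous tensor")

# On a contiguous tensor, view() shares storage with the original.
c = torch.randn(4, 4)
v = c.view(16)
v[0] = 99.0
print(c[0, 0])  # tensor(99.) -- the same underlying data
```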
Complete Example with Shape Tracking

shape_tracking_demo.py
```python
import torch
import torch.nn.functional as F
import math

def multi_head_attention_shapes_demo():
    """
    Complete demonstration of shape transformations in multi-head attention.
    """
    # Configuration
    batch_size = 2
    seq_len = 6
    d_model = 512
    num_heads = 8
    d_k = d_model // num_heads  # 64

    print("=" * 60)
    print("MULTI-HEAD ATTENTION SHAPE TRANSFORMATIONS")
    print("=" * 60)
    print(f"\nConfiguration:")
    print(f"  batch_size = {batch_size}")
    print(f"  seq_len    = {seq_len}")
    print(f"  d_model    = {d_model}")
    print(f"  num_heads  = {num_heads}")
    print(f"  d_k        = {d_k}")

    # Input
    x = torch.randn(batch_size, seq_len, d_model)
    print(f"\n1. Input X: {x.shape}")
    print(f"   [batch={batch_size}, seq_len={seq_len}, d_model={d_model}]")

    # Projections (simulated)
    Q = torch.randn(batch_size, seq_len, d_model)
    K = torch.randn(batch_size, seq_len, d_model)
    V = torch.randn(batch_size, seq_len, d_model)
    print(f"\n2. After projection:")
    print(f"   Q: {Q.shape}")
    print(f"   K: {K.shape}")
    print(f"   V: {V.shape}")

    # Split heads
    def split_heads(x):
        b, s, d = x.shape
        return x.view(b, s, num_heads, d_k).transpose(1, 2)

    Q_heads = split_heads(Q)
    K_heads = split_heads(K)
    V_heads = split_heads(V)
    print(f"\n3. After split_heads:")
    print(f"   Q_heads: {Q_heads.shape}")
    print(f"   K_heads: {K_heads.shape}")
    print(f"   V_heads: {V_heads.shape}")
    print(f"   [batch={batch_size}, num_heads={num_heads}, seq_len={seq_len}, d_k={d_k}]")

    # Attention scores
    scores = torch.matmul(Q_heads, K_heads.transpose(-2, -1)) / math.sqrt(d_k)
    print(f"\n4. Attention scores (QK^T / sqrt(d_k)):")
    print(f"   scores: {scores.shape}")
    print(f"   [batch={batch_size}, num_heads={num_heads}, seq_len={seq_len}, seq_len={seq_len}]")

    # Attention weights
    attention_weights = F.softmax(scores, dim=-1)
    print(f"\n5. Attention weights (softmax):")
    print(f"   weights: {attention_weights.shape}")

    # Attention output
    attention_output = torch.matmul(attention_weights, V_heads)
    print(f"\n6. Attention output (weights @ V):")
    print(f"   output: {attention_output.shape}")
    print(f"   [batch={batch_size}, num_heads={num_heads}, seq_len={seq_len}, d_k={d_k}]")

    # Combine heads
    def combine_heads(x):
        b, h, s, dk = x.shape
        return x.transpose(1, 2).contiguous().view(b, s, h * dk)

    combined = combine_heads(attention_output)
    print(f"\n7. After combine_heads:")
    print(f"   combined: {combined.shape}")
    print(f"   [batch={batch_size}, seq_len={seq_len}, d_model={d_model}]")

    # Output projection (simulated)
    output = torch.randn(batch_size, seq_len, d_model)  # W_O @ combined
    print(f"\n8. After output projection (W_O):")
    print(f"   output: {output.shape}")
    print(f"   [batch={batch_size}, seq_len={seq_len}, d_model={d_model}]")

    print("\n" + "=" * 60)
    print("Shape transformation complete!")
    print("=" * 60)

multi_head_attention_shapes_demo()
```
Output:
```
============================================================
MULTI-HEAD ATTENTION SHAPE TRANSFORMATIONS
============================================================

Configuration:
  batch_size = 2
  seq_len    = 6
  d_model    = 512
  num_heads  = 8
  d_k        = 64

1. Input X: torch.Size([2, 6, 512])
   [batch=2, seq_len=6, d_model=512]

2. After projection:
   Q: torch.Size([2, 6, 512])
   K: torch.Size([2, 6, 512])
   V: torch.Size([2, 6, 512])

3. After split_heads:
   Q_heads: torch.Size([2, 8, 6, 64])
   K_heads: torch.Size([2, 8, 6, 64])
   V_heads: torch.Size([2, 8, 6, 64])
   [batch=2, num_heads=8, seq_len=6, d_k=64]

4. Attention scores (QK^T / sqrt(d_k)):
   scores: torch.Size([2, 8, 6, 6])
   [batch=2, num_heads=8, seq_len=6, seq_len=6]

5. Attention weights (softmax):
   weights: torch.Size([2, 8, 6, 6])

6. Attention output (weights @ V):
   output: torch.Size([2, 8, 6, 64])
   [batch=2, num_heads=8, seq_len=6, d_k=64]

7. After combine_heads:
   combined: torch.Size([2, 6, 512])
   [batch=2, seq_len=6, d_model=512]

8. After output projection (W_O):
   output: torch.Size([2, 6, 512])
   [batch=2, seq_len=6, d_model=512]

============================================================
Shape transformation complete!
============================================================
```
Common Mistakes and Fixes

Mistake 1: Wrong Transpose Dimensions
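The original example for this mistake is missing, so the following is an illustrative reconstruction: transposing the wrong pair of dimensions still runs, but swaps batch and seq_len instead of moving the head dimension.

```python
import torch

batch_size, seq_len, num_heads, d_k = 2, 4, 2, 4
x = torch.randn(batch_size, seq_len, num_heads * d_k)
x = x.view(batch_size, seq_len, num_heads, d_k)

# Wrong: swapping dims 0 and 1 exchanges batch and seq_len,
# giving [seq_len, batch, num_heads, d_k].
wrong = x.transpose(0, 1)
print(wrong.shape)  # torch.Size([4, 2, 2, 4])

# Right: swap dims 1 and 2 to get [batch, num_heads, seq_len, d_k].
right = x.transpose(1, 2)
print(right.shape)  # torch.Size([2, 2, 4, 4])
```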
Mistake 2: Forgetting contiguous()

Mistake 3: Wrong Dimension Order in view()
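The code for these two mistakes is missing from the extracted text; a sketch of both. Note that mistake 3 is the dangerous one: it produces the right shape but silently scrambles the data.

```python
import torch

batch_size, seq_len, num_heads, d_k = 2, 4, 2, 4
x = torch.randn(batch_size, num_heads, seq_len, d_k)

# Mistake 2: forgetting contiguous() after transpose raises an error.
try:
    x.transpose(1, 2).view(batch_size, seq_len, num_heads * d_k)
except RuntimeError:
    print("view() failed: call contiguous() after transpose")

# Mistake 3: flattening without transposing back first. The shape is
# correct, but tokens and heads are silently interleaved.
scrambled = x.contiguous().view(batch_size, seq_len, num_heads * d_k)

# Correct: transpose back, make contiguous, then flatten.
correct = x.transpose(1, 2).contiguous().view(batch_size, seq_len, num_heads * d_k)
print(torch.equal(scrambled, correct))  # False
```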
Mistake 4: d_model Not Divisible by num_heads

Integer division silently truncates, so d_k * num_heads no longer equals d_model and the later view() fails. Guard against this up front, e.g. assert d_model % num_heads == 0.

Efficient Implementation

Using einops (Optional)

The einops library makes reshaping more readable:
Using einops for Readable Reshaping
Performance Tips

- Minimize reshapes: combine operations when possible
- Avoid unnecessary copies: use contiguous() only when needed
- Batch operations: always include the batch dimension in matmul

Summary

The Complete Pattern
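A condensed restatement of the two helpers developed above, with a round-trip check showing they are exact inverses:

```python
import torch

def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    # [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]
    b, s, d = x.shape
    return x.view(b, s, num_heads, d // num_heads).transpose(1, 2)

def combine_heads(x: torch.Tensor) -> torch.Tensor:
    # [batch, num_heads, seq_len, d_k] -> [batch, seq_len, d_model]
    b, h, s, dk = x.shape
    return x.transpose(1, 2).contiguous().view(b, s, h * dk)

x = torch.randn(2, 6, 512)
heads = split_heads(x, num_heads=8)
print(heads.shape)                # torch.Size([2, 8, 6, 64])
restored = combine_heads(heads)
print(torch.equal(x, restored))   # True -- the round trip is lossless
```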
Shape Flow Summary

Exercises

Implementation Exercises

1. Implement split_heads and combine_heads using torch.reshape() instead of view().
2. Write a test that verifies combine_heads(split_heads(x)) == x (up to floating-point precision).
3. Implement both using einops.rearrange() and compare readability.

Debugging Exercises

1. What error do you get if d_model is not divisible by num_heads? Write code to catch this.
2. Create a tensor, transpose it, then try to view it without contiguous(). What's the error message?
3. Print intermediate shapes at every step of multi-head attention to verify your understanding.