Chapter 3

Reshaping and Transposing for Parallel Heads

Multi-Head Attention

Introduction

The trickiest part of multi-head attention isn't the math; it's the tensor reshaping. We need to transform our projected Q, K, V tensors to compute all heads in parallel, then reverse the transformation afterward.

This section provides a detailed, step-by-step guide to these tensor manipulations with explicit shape tracking at every step.


The Goal: Parallel Head Computation

What We Want

Instead of looping over heads:

Sequential Head Computation (Slow)
🐍attention.py
# SLOW: Sequential head computation
outputs = []
for head in range(num_heads):
    q_head = Q[:, :, head*d_k:(head+1)*d_k]
    k_head = K[:, :, head*d_k:(head+1)*d_k]
    v_head = V[:, :, head*d_k:(head+1)*d_k]
    output = attention(q_head, k_head, v_head)
    outputs.append(output)
result = torch.cat(outputs, dim=-1)

We want parallel computation:

Parallel Head Computation (Fast)
🐍attention.py
# FAST: All heads in parallel
Q_heads = split_heads(Q)  # [batch, num_heads, seq_len, d_k]
K_heads = split_heads(K)
V_heads = split_heads(V)
output_heads = attention(Q_heads, K_heads, V_heads)  # Batched over heads
result = combine_heads(output_heads)  # [batch, seq_len, d_model]
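A quick sanity check that the two routes agree — a sketch assuming a minimal scaled dot-product `attention`, with the `split_heads`/`combine_heads` steps inlined since the full helpers appear later in this section:

```python
import math
import torch
import torch.nn.functional as F

def attention(q, k, v):
    # Scaled dot-product attention; broadcasts over any leading batch dims
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return F.softmax(scores, dim=-1) @ v

batch, seq_len, num_heads, d_k = 2, 5, 2, 4
d_model = num_heads * d_k
Q, K, V = (torch.randn(batch, seq_len, d_model) for _ in range(3))

# Sequential route: one head at a time
outputs = []
for h in range(num_heads):
    sl = slice(h * d_k, (h + 1) * d_k)
    outputs.append(attention(Q[..., sl], K[..., sl], V[..., sl]))
sequential = torch.cat(outputs, dim=-1)

# Parallel route: all heads batched in a single call
def split(x):
    return x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

heads_out = attention(split(Q), split(K), split(V))   # [batch, heads, seq, d_k]
parallel = heads_out.transpose(1, 2).contiguous().view(batch, seq_len, d_model)

print(torch.allclose(sequential, parallel, atol=1e-6))  # True
```

The results match to floating-point tolerance; the parallel route simply lets the batched matmul do the looping.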

The Reshape-Transpose Pattern

Step-by-Step Transformation

Starting shape: [batch, seq_len, d_model]

Target shape: [batch, num_heads, seq_len, d_k]

Step-by-Step Transformation
πŸ“text
5 lines without explanation
1Step 1: View/Reshape
2[batch, seq_len, d_model] β†’ [batch, seq_len, num_heads, d_k]
3
4Step 2: Transpose
5[batch, seq_len, num_heads, d_k] β†’ [batch, num_heads, seq_len, d_k]

Why This Order?

Why reshape then transpose (not transpose then reshape)?

Memory layout matters. After projection, data is laid out as:

Memory Layout Explanation
πŸ“text
8 lines without explanation
1For d_model=8, num_heads=2, d_k=4:
2
3Memory: [q0_h0, q0_h1, q1_h0, q1_h1, ..., qn_h0, qn_h1]
4        |--d_model=8--|
5
6Where:
7- q0_h0 = query for position 0, head 0 (4 values)
8- q0_h1 = query for position 0, head 1 (4 values)

Reshape groups by head: [batch, seq_len, num_heads, d_k]

Transpose then moves the head dimension in front of seq_len, so batched matmul treats (batch, num_heads) as batch dimensions and multiplies each head's [seq_len, d_k] matrix independently.
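As a sketch (using the same toy sizes as above), you can confirm that the reshape-then-transpose route selects exactly the columns the slow loop sliced out:

```python
import torch

batch, seq_len, num_heads, d_k = 2, 4, 2, 4
d_model = num_heads * d_k  # 8

Q = torch.randn(batch, seq_len, d_model)

# Reshape-then-transpose route
Q_heads = Q.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

# Head h lives in columns [h*d_k, (h+1)*d_k) of the projected tensor
for h in range(num_heads):
    assert torch.equal(Q_heads[:, h], Q[:, :, h * d_k:(h + 1) * d_k])
print("reshape-transpose matches per-head slicing")
```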


Implementation: split_heads

The Function

split_heads Function Implementation
🐍reshape_utils.py
import torch

def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    """
    Split the last dimension into (num_heads, d_k).

    Args:
        x: Input tensor of shape [batch, seq_len, d_model]
        num_heads: Number of attention heads

    Returns:
        Tensor of shape [batch, num_heads, seq_len, d_k]
    """
    batch_size, seq_len, d_model = x.shape
    d_k = d_model // num_heads

    # Step 1: Reshape [batch, seq_len, d_model] → [batch, seq_len, num_heads, d_k]
    x = x.view(batch_size, seq_len, num_heads, d_k)

    # Step 2: Transpose [batch, seq_len, num_heads, d_k] → [batch, num_heads, seq_len, d_k]
    x = x.transpose(1, 2)

    return x


# Example with explicit shape tracking
batch_size = 2
seq_len = 4
d_model = 8
num_heads = 2
d_k = d_model // num_heads  # 4

x = torch.randn(batch_size, seq_len, d_model)
print(f"Input shape: {x.shape}")  # [2, 4, 8]

# Step by step
x_reshaped = x.view(batch_size, seq_len, num_heads, d_k)
print(f"After view: {x_reshaped.shape}")  # [2, 4, 2, 4]

x_transposed = x_reshaped.transpose(1, 2)
print(f"After transpose: {x_transposed.shape}")  # [2, 2, 4, 4]

# Or in one function call
x_split = split_heads(x, num_heads)
print(f"split_heads output: {x_split.shape}")  # [2, 2, 4, 4]

Visual Representation

Visual Representation of split_heads
πŸ“text
22 lines without explanation
1Input: [batch=2, seq_len=4, d_model=8]
2
3Batch 0:
4β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
5β”‚ Pos0: [a0 a1 a2 a3 | a4 a5 a6 a7]  β”‚  ← d_model=8
6β”‚ Pos1: [b0 b1 b2 b3 | b4 b5 b6 b7]  β”‚
7β”‚ Pos2: [c0 c1 c2 c3 | c4 c5 c6 c7]  β”‚
8β”‚ Pos3: [d0 d1 d2 d3 | d4 d5 d6 d7]  β”‚
9β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
10                ↓ view(2, 4, 2, 4)
11β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
12β”‚ Pos0: [[a0 a1 a2 a3], [a4 a5 a6 a7]]β”‚  ← [num_heads=2, d_k=4]
13β”‚ Pos1: [[b0 b1 b2 b3], [b4 b5 b6 b7]]β”‚
14β”‚ ...                                  β”‚
15β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
16                ↓ transpose(1, 2)
17β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
18β”‚ Head0: [[a0..a3], [b0..b3], ...]   β”‚  ← All pos for head 0
19β”‚ Head1: [[a4..a7], [b4..b7], ...]   β”‚  ← All pos for head 1
20β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
21
22Output: [batch=2, num_heads=2, seq_len=4, d_k=4]

Implementation: combine_heads

The Reverse Operation

After attention, we need to reverse the transformation:

combine_heads Shape Transformation
πŸ“text
2 lines without explanation
1Input: [batch, num_heads, seq_len, d_k]
2Output: [batch, seq_len, d_model]

The Function

combine_heads Function Implementation
🐍reshape_utils.py
def combine_heads(x: torch.Tensor) -> torch.Tensor:
    """
    Combine head outputs back into single tensor.

    Args:
        x: Input tensor of shape [batch, num_heads, seq_len, d_k]

    Returns:
        Tensor of shape [batch, seq_len, d_model]
    """
    batch_size, num_heads, seq_len, d_k = x.shape
    d_model = num_heads * d_k

    # Step 1: Transpose [batch, num_heads, seq_len, d_k] → [batch, seq_len, num_heads, d_k]
    x = x.transpose(1, 2)

    # Step 2: Reshape [batch, seq_len, num_heads, d_k] → [batch, seq_len, d_model]
    # Use contiguous() to ensure memory layout is correct for view
    x = x.contiguous().view(batch_size, seq_len, d_model)

    return x


# Example
x_heads = torch.randn(2, 2, 4, 4)  # [batch, heads, seq, d_k]
print(f"Input shape: {x_heads.shape}")  # [2, 2, 4, 4]

x_combined = combine_heads(x_heads)
print(f"Output shape: {x_combined.shape}")  # [2, 4, 8]

Why contiguous()?

After transpose, the tensor may not have contiguous memory layout:

Understanding contiguous() in PyTorch
🐍contiguous_demo.py
x = torch.randn(2, 4, 2, 4)
x_t = x.transpose(1, 2)
print(x_t.is_contiguous())  # False!

# view() requires contiguous memory
# x_t.view(2, 4, 8)  # This would fail!

# Solution: call contiguous() first
x_t = x_t.contiguous()
print(x_t.is_contiguous())  # True
x_t.view(2, 4, 8)  # Now works!
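Strides make the failure concrete. transpose() only swaps two stride entries and never moves data, so the logical element order stops matching memory order — a small sketch:

```python
import torch

x = torch.randn(2, 4, 2, 4)
print(x.stride())            # (32, 8, 4, 1): row-major, last dim tightly packed

x_t = x.transpose(1, 2)      # swaps two stride entries; no data movement
print(x_t.stride())          # (32, 4, 8, 1): traversal now jumps around in memory
print(x_t.is_contiguous())   # False

x_c = x_t.contiguous()       # copies the data into row-major order
print(x_c.stride())          # (32, 16, 4, 1): contiguous again for the new shape
```

view() needs the row-major layout that contiguous() restores, which is why it fails on x_t but works on x_c.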

view() vs reshape()

The Difference

view():

  • Requires contiguous memory
  • Returns a view (shares data with original)
  • Fails if tensor is not contiguous

reshape():

  • Works on non-contiguous tensors
  • May return a copy if needed
  • More flexible but potentially slower
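The aliasing difference is easy to demonstrate. In this sketch, view() shares storage with the original, while reshape() on a non-contiguous tensor silently falls back to a copy:

```python
import torch

x = torch.arange(6).view(2, 3)

# view() always aliases the original storage
v = x.view(3, 2)
v[0, 0] = 99
print(x[0, 0])   # tensor(99): the write is visible through x

# reshape() falls back to a copy when no view is possible
t = x.t()            # non-contiguous transpose
r = t.reshape(6)     # t.view(6) would raise; reshape() copies here
r[0] = -1
print(t[0, 0])       # tensor(99): original unchanged, r was a copy
```

This is why reshape() is "more flexible but potentially slower": whether you get a view or a copy depends on the tensor's strides.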

Recommendation

Safe Reshape Approaches
🐍reshape_utils.py
# Safe approach: use reshape() or contiguous().view()
x = x.transpose(1, 2).contiguous().view(batch, seq_len, d_model)

# Alternative: reshape() handles it automatically
x = x.transpose(1, 2).reshape(batch, seq_len, d_model)

For performance-critical code, contiguous().view() is more explicit about what's happening.


Complete Example with Shape Tracking

Complete Multi-Head Attention Shape Demo
🐍shape_tracking_demo.py
import torch
import torch.nn.functional as F
import math

def multi_head_attention_shapes_demo():
    """
    Complete demonstration of shape transformations in multi-head attention.
    """
    # Configuration
    batch_size = 2
    seq_len = 6
    d_model = 512
    num_heads = 8
    d_k = d_model // num_heads  # 64

    print("=" * 60)
    print("MULTI-HEAD ATTENTION SHAPE TRANSFORMATIONS")
    print("=" * 60)
    print(f"\nConfiguration:")
    print(f"  batch_size = {batch_size}")
    print(f"  seq_len = {seq_len}")
    print(f"  d_model = {d_model}")
    print(f"  num_heads = {num_heads}")
    print(f"  d_k = {d_k}")

    # Input
    x = torch.randn(batch_size, seq_len, d_model)
    print(f"\n1. Input X: {x.shape}")
    print(f"   [batch={batch_size}, seq_len={seq_len}, d_model={d_model}]")

    # Projections (simulated)
    Q = torch.randn(batch_size, seq_len, d_model)
    K = torch.randn(batch_size, seq_len, d_model)
    V = torch.randn(batch_size, seq_len, d_model)
    print(f"\n2. After projection:")
    print(f"   Q: {Q.shape}")
    print(f"   K: {K.shape}")
    print(f"   V: {V.shape}")

    # Split heads
    def split_heads(x):
        b, s, d = x.shape
        return x.view(b, s, num_heads, d_k).transpose(1, 2)

    Q_heads = split_heads(Q)
    K_heads = split_heads(K)
    V_heads = split_heads(V)
    print(f"\n3. After split_heads:")
    print(f"   Q_heads: {Q_heads.shape}")
    print(f"   K_heads: {K_heads.shape}")
    print(f"   V_heads: {V_heads.shape}")
    print(f"   [batch={batch_size}, num_heads={num_heads}, seq_len={seq_len}, d_k={d_k}]")

    # Attention scores
    scores = torch.matmul(Q_heads, K_heads.transpose(-2, -1)) / math.sqrt(d_k)
    print(f"\n4. Attention scores (QK^T / sqrt(d_k)):")
    print(f"   scores: {scores.shape}")
    print(f"   [batch={batch_size}, num_heads={num_heads}, seq_len={seq_len}, seq_len={seq_len}]")

    # Attention weights
    attention_weights = F.softmax(scores, dim=-1)
    print(f"\n5. Attention weights (softmax):")
    print(f"   weights: {attention_weights.shape}")

    # Attention output
    attention_output = torch.matmul(attention_weights, V_heads)
    print(f"\n6. Attention output (weights @ V):")
    print(f"   output: {attention_output.shape}")
    print(f"   [batch={batch_size}, num_heads={num_heads}, seq_len={seq_len}, d_k={d_k}]")

    # Combine heads
    def combine_heads(x):
        b, h, s, dk = x.shape
        return x.transpose(1, 2).contiguous().view(b, s, h * dk)

    combined = combine_heads(attention_output)
    print(f"\n7. After combine_heads:")
    print(f"   combined: {combined.shape}")
    print(f"   [batch={batch_size}, seq_len={seq_len}, d_model={d_model}]")

    # Output projection (simulated)
    output = torch.randn(batch_size, seq_len, d_model)  # stands in for the W_O projection of combined
    print(f"\n8. After output projection (W_O):")
    print(f"   output: {output.shape}")
    print(f"   [batch={batch_size}, seq_len={seq_len}, d_model={d_model}]")

    print("\n" + "=" * 60)
    print("Shape transformation complete!")
    print("=" * 60)

multi_head_attention_shapes_demo()

Output:

Shape Tracking Demo Output
πŸ“text
47 lines without explanation
1============================================================
2MULTI-HEAD ATTENTION SHAPE TRANSFORMATIONS
3============================================================
4
5Configuration:
6  batch_size = 2
7  seq_len = 6
8  d_model = 512
9  num_heads = 8
10  d_k = 64
11
121. Input X: torch.Size([2, 6, 512])
13   [batch=2, seq_len=6, d_model=512]
14
152. After projection:
16   Q: torch.Size([2, 6, 512])
17   K: torch.Size([2, 6, 512])
18   V: torch.Size([2, 6, 512])
19
203. After split_heads:
21   Q_heads: torch.Size([2, 8, 6, 64])
22   K_heads: torch.Size([2, 8, 6, 64])
23   V_heads: torch.Size([2, 8, 6, 64])
24   [batch=2, num_heads=8, seq_len=6, d_k=64]
25
264. Attention scores (QK^T / sqrt(d_k)):
27   scores: torch.Size([2, 8, 6, 6])
28   [batch=2, num_heads=8, seq_len=6, seq_len=6]
29
305. Attention weights (softmax):
31   weights: torch.Size([2, 8, 6, 6])
32
336. Attention output (weights @ V):
34   output: torch.Size([2, 8, 6, 64])
35   [batch=2, num_heads=8, seq_len=6, d_k=64]
36
377. After combine_heads:
38   combined: torch.Size([2, 6, 512])
39   [batch=2, seq_len=6, d_model=512]
40
418. After output projection (W_O):
42   output: torch.Size([2, 6, 512])
43   [batch=2, seq_len=6, d_model=512]
44
45============================================================
46Shape transformation complete!
47============================================================

Common Mistakes and Fixes

Mistake 1: Wrong Transpose Dimensions

Wrong Transpose Dimensions
🐍common_mistakes.py
# ❌ WRONG: transpose(0, 1) swaps batch and seq_len
x = x.view(batch, seq_len, num_heads, d_k).transpose(0, 1)

# ✅ CORRECT: transpose(1, 2) swaps seq_len and num_heads
x = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

Mistake 2: Forgetting contiguous()

Forgetting contiguous()
🐍common_mistakes.py
# ❌ WRONG: view() on non-contiguous tensor
x = x.transpose(1, 2).view(batch, seq_len, d_model)  # Error!

# ✅ CORRECT: call contiguous() first
x = x.transpose(1, 2).contiguous().view(batch, seq_len, d_model)

Mistake 3: Wrong Dimension in view()

Wrong Dimension Order in view()
🐍common_mistakes.py
# ❌ WRONG: wrong order of dimensions
x = x.view(batch, num_heads, seq_len, d_k)  # Same memory, wrong interpretation: mixes positions and heads!

# ✅ CORRECT: reshape first, then transpose
x = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

Mistake 4: d_model Not Divisible by num_heads

Divisibility Check
🐍common_mistakes.py
d_model = 512
num_heads = 6  # 512 / 6 = 85.33... not an integer!

# ✅ Always check
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

Efficient Implementation

Using einops (Optional)

The einops library makes reshaping more readable:

Using einops for Readable Reshaping
🐍einops_example.py
# pip install einops
from einops import rearrange

# Split heads
Q_heads = rearrange(Q, 'b s (h d) -> b h s d', h=num_heads)

# Combine heads
combined = rearrange(attention_output, 'b h s d -> b s (h d)')

Performance Tips

  1. Minimize reshapes: Combine operations when possible
  2. Avoid unnecessary copies: Use contiguous() only when needed
  3. Batch operations: Always include batch dimension in matmul

Summary

The Complete Pattern

Complete split_heads and combine_heads Pattern
🐍reshape_utils.py
# Split heads: [batch, seq, d_model] → [batch, heads, seq, d_k]
def split_heads(x, num_heads):
    batch, seq_len, d_model = x.shape
    d_k = d_model // num_heads
    return x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)

# Combine heads: [batch, heads, seq, d_k] → [batch, seq, d_model]
def combine_heads(x):
    batch, heads, seq_len, d_k = x.shape
    return x.transpose(1, 2).contiguous().view(batch, seq_len, heads * d_k)

Shape Flow Summary

Shape Flow Summary
πŸ“text
6 lines without explanation
1Input:          [batch, seq_len, d_model]
2After proj:     [batch, seq_len, d_model]
3After split:    [batch, num_heads, seq_len, d_k]
4After attn:     [batch, num_heads, seq_len, d_k]
5After combine:  [batch, seq_len, d_model]
6After W_O:      [batch, seq_len, d_model]

Exercises

Implementation Exercises

  1. Implement split_heads and combine_heads using torch.reshape() instead of view().
  2. Write a test that verifies combine_heads(split_heads(x)) reproduces x exactly (the round trip only moves data, so torch.equal should hold with no floating-point tolerance).
  3. Implement using einops.rearrange() and compare readability.

Debugging Exercises

  1. What error do you get if d_model is not divisible by num_heads? Write code to catch this.
  2. Create a tensor, transpose it, then try to view it without contiguous(). What's the error message?
  3. Print intermediate shapes at every step of multi-head attention to verify your understanding.