Chapter 2

Multi-Head Attention

Learning Objectives

By the end of this chapter, you will be able to:

  1. Explain why a single attention head creates a representational bottleneck and why multiple heads running in parallel overcome it.
  2. Derive the multi-head attention formula $\text{MultiHead}(Q,K,V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H) W^O$ from first principles, understanding every symbol and matrix dimension.
  3. Compute per-head attention weights and outputs by hand for our shared 5-token sentence using $H = 2$ heads.
  4. Compare how different heads learn different attention patterns from the same input — syntactic relationships in one head, semantic proximity in another.
  5. Implement a complete, runnable multi-head attention class in both NumPy and PyTorch that you can use to simulate any number of heads on any input.
  6. Connect multi-head attention to Flash Attention, KV-cache, positional encodings, and transformer scaling laws.
Where this appears: Multi-head attention is the standard attention module used in virtually every transformer: GPT-4, Claude, Gemini, LLaMA, ViT, DALL-E, AlphaFold, and Codex all build on multi-head attention or a close variant of it as their core computational building block. GPT-3, for example, uses 96 heads per layer. Understanding how and why heads split the representation is essential for understanding every modern architecture.

The Real Problem

The Single-Head Bottleneck

In Chapter 1, we computed scaled dot-product attention using a single set of $Q$, $K$, $V$ matrices. The result was a single $5 \times 5$ attention weight matrix: one pattern describing which tokens attend to which.

But a single pattern is a severe bottleneck. Consider the word “cat” in “The cat sat on mat”. A human instantly recognizes that “cat” is:

  1. Syntactically the subject of the verb “sat” (subject-verb relationship)
  2. Grammatically a noun that follows the determiner “The” (determiner-noun agreement)
  3. Semantically an animate entity related to the location “mat” (entity-location relation)
  4. Positionally close to “The” and “sat” (local context)

A single attention head computes one weighted combination of all keys for each query. In our Chapter 1 computation, “cat” assigned weight 0.4026 to “The”, 0.2442 to “sat”, and only 0.0898 to itself. That is one compromise pattern — it cannot simultaneously attend heavily to “The” (for grammar) and to “sat” (for semantics) through separate, specialized lenses. The softmax forces a single probability distribution over all keys.
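
As a quick check of the numbers quoted above, the sketch below recomputes the single-head attention row for “cat” with NumPy. It assumes the chapter's shared example Q and K matrices and the Chapter 1 scaling by the square root of the full dimension; the helper and variable names are illustrative only.

```python
import numpy as np

# Shared example matrices (rows: The, cat, sat, on, mat).
Q = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 2.0, 0.0, 1.0],
              [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0, 1.0]])
K = np.array([[0.0, 1.0, 0.0, 1.0],
              [1.0, 0.0, 1.0, 0.0],
              [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.5, 0.5]])


def softmax(x):
    """Row-wise softmax with max subtraction for numerical stability."""
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)


# One head over the full 4D space: a single distribution per query token.
single_head_weights = softmax(Q @ K.T / np.sqrt(Q.shape[1]))
cat_row = single_head_weights[1]

# Weights of "cat" over The, cat, sat, on, mat.
print(np.round(cat_row, 4))
```

Whatever the values of Q and K, each row is one probability distribution, so raising the weight on one key necessarily lowers it on the others.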

Vaswani et al. (2017) articulated this limitation precisely in the original transformer paper: “Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.”

The Committee Analogy

Think of multi-head attention as a committee of experts reviewing a document. A single reviewer reads the entire document through one lens and produces one summary. A committee of $H$ experts (a grammarian, a semanticist, a coreference resolver, a positional tracker) each reads the same document but focuses on their specialty. Their reports are concatenated into a richer, more nuanced understanding than any single reviewer could produce.

Crucially, the experts do not talk to each other while reading. Each head operates independently and in parallel on a different subspace of the representation. Only after all heads produce their outputs are the results combined.


From Intuition to Mathematics

The Projection Idea

The core mathematical insight is that a $d_{\text{model}}$-dimensional vector space can be decomposed into $H$ lower-dimensional subspaces. If $d_{\text{model}} = 512$ and $H = 8$, then each head operates in a $d_k = 512/8 = 64$-dimensional subspace.

Each head $i$ gets its own set of learned projection matrices: $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$. These projections are learned during training: the model discovers which subspaces are useful for which kinds of relationships.

Split vs. Learned Projection

In our teaching example, we split the dimensions directly: Head 1 gets dims $[0,1]$, Head 2 gets dims $[2,3]$. This is equivalent to using fixed selection matrices as projections, where $W_1^Q$ selects columns 0-1 and $W_2^Q$ selects columns 2-3. In real transformers, the projection matrices are dense learned parameters that can mix dimensions in arbitrary ways, discovering subspaces not aligned with any coordinate axis.

Why the split is mathematically equivalent

A general $W_i^Q \in \mathbb{R}^{d_{\text{model}} \times d_k}$ can be any linear map. Setting it to a selection matrix (zeros and ones) that picks certain columns is a special case. The splitting trick lets us do multi-head on our tiny example without introducing new parameters, but the mechanism is identical.
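
The equivalence can be checked numerically. The sketch below builds two selection matrices and confirms that projecting with them gives the same result as slicing columns. The helper `selection_matrix` is a hypothetical name introduced for illustration, and Q is the chapter's shared example matrix.

```python
import numpy as np

# Shared example query matrix (rows: The, cat, sat, on, mat).
Q = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 2.0, 0.0, 1.0],
              [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0, 1.0]])
d_model = Q.shape[1]


def selection_matrix(columns, d_model):
    """Return a (d_model, len(columns)) matrix of zeros and ones that picks `columns`."""
    W = np.zeros((d_model, len(columns)))
    for out_dim, in_dim in enumerate(columns):
        W[in_dim, out_dim] = 1.0
    return W


W1_Q = selection_matrix([0, 1], d_model)  # Head 1: dims 0-1
W2_Q = selection_matrix([2, 3], d_model)  # Head 2: dims 2-3

# Projecting with a selection matrix is the same as slicing the columns.
print(np.allclose(Q @ W1_Q, Q[:, 0:2]))
print(np.allclose(Q @ W2_Q, Q[:, 2:4]))
```

A trained model would replace these fixed zero-one matrices with dense learned ones; nothing else in the pipeline changes.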

The Mathematical Definition

The multi-head attention formula from Vaswani et al. (2017):

$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H) \, W^O$$

where $\text{head}_i = \text{Attention}(Q W_i^Q, \; K W_i^K, \; V W_i^V)$

Symbol-by-Symbol Breakdown

| Symbol | Shape | Meaning |
|---|---|---|
| $Q, K, V$ | $(N, d_{\text{model}})$ | Input query, key, value matrices (same as Chapter 1) |
| $H$ | scalar | Number of parallel attention heads |
| $d_k = d_{\text{model}} / H$ | scalar | Per-head key/query dimension (the value dimension $d_v$ is usually set equal to $d_k$) |
| $W_i^Q, W_i^K$ | $(d_{\text{model}}, d_k)$ | Per-head query/key projection (learned) |
| $W_i^V$ | $(d_{\text{model}}, d_v)$ | Per-head value projection (learned) |
| $\text{head}_i$ | $(N, d_v)$ | Output of the $i$-th attention head |
| Concat | $(N, H \cdot d_v)$ | Column-wise concatenation of all head outputs |
| $W^O$ | $(H \cdot d_v, d_{\text{model}})$ | Output projection matrix (mixes information across heads) |

What the Formula Says in Plain English

  1. Project the input into $H$ different subspaces using learned weight matrices $W_i^Q, W_i^K, W_i^V$.
  2. Attend independently in each subspace using the scaled dot-product attention from Chapter 1: $\text{softmax}(Q_h K_h^\top / \sqrt{d_k}) \, V_h$.
  3. Concatenate the $H$ output vectors into a single vector of the original dimensionality.
  4. Project through $W^O$ to allow the model to mix and recombine information from different heads.
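
The four steps above can be sketched in a few lines of NumPy. This is a shape-level illustration only: the projection matrices are random stand-ins for learned parameters, the input `X` is random, and the same `X` is used as Q, K, and V (the self-attention case), all of which are assumptions made here for the demo.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_model, H = 5, 4, 2
d_k = d_v = d_model // H

X = rng.normal(size=(N, d_model))  # stand-in input, used as Q, K and V

# Random stand-ins for the learned per-head projections and for W^O.
W_Q = rng.normal(size=(H, d_model, d_k))
W_K = rng.normal(size=(H, d_model, d_k))
W_V = rng.normal(size=(H, d_model, d_v))
W_O = rng.normal(size=(H * d_v, d_model))


def softmax(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)


head_weights, head_outputs = [], []
for i in range(H):
    # Step 1: project into head i's subspace.
    Q_i, K_i, V_i = X @ W_Q[i], X @ W_K[i], X @ W_V[i]
    # Step 2: scaled dot-product attention inside that subspace.
    weights_i = softmax(Q_i @ K_i.T / np.sqrt(d_k))
    head_weights.append(weights_i)
    head_outputs.append(weights_i @ V_i)

# Step 3: concatenate along the feature axis, (N, H * d_v).
concat = np.concatenate(head_outputs, axis=-1)
# Step 4: mix information across heads with W^O, (N, d_model).
multi_head_output = concat @ W_O

print(concat.shape, multi_head_output.shape)
```

Production implementations usually fuse the per-head projections into single large matrices and reshape, but the loop form mirrors the formula most directly.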

Why Multiple Heads Work

The Representational Subspace Argument

Each head operates on a $d_k$-dimensional subspace. In this subspace, the dot product $Q_h K_h^\top$ measures a different notion of similarity than the full-space dot product. A token pair that has a high score in Head 1's subspace (capturing syntactic proximity) may have a low score in Head 2's subspace (capturing semantic type).

In our example, consider the “on” token. In Head 1 (dims 0-1), $Q_{\text{on}} = [0.0, 0.0]$, the zero vector. Every dot product is zero, so Head 1 assigns uniform attention (0.2 to each token). But in Head 2 (dims 2-3), $Q_{\text{on}} = [1.0, 1.0]$ actively attends to “on” itself (0.3673) where the key strongly matches. Two completely different behaviors from the same token, enabled by subspace specialization.

Capacity vs. Specialization Trade-off

There is a fundamental trade-off: more heads means more specialization, but each head has fewer dimensions to work with. With $H = 1$, the single head has $d_k = 4$ dimensions for computing similarity. With $H = 4$, each head has only $d_k = 1$: a scalar dot product with very limited discriminative power. The sweet spot depends on the task and model size. In practice, $d_k = 64$ or $d_k = 128$ per head is typical.

| Configuration | H | d_k | Patterns | Per-Head Capacity |
|---|---|---|---|---|
| Our Chapter 1 | 1 | 4 | 1 pattern | Full 4D space |
| Our Chapter 2 | 2 | 2 | 2 patterns | 2D subspace each |
| Extreme split | 4 | 1 | 4 patterns | Scalar dot product |
| GPT-3 (d=12288) | 96 | 128 | 96 patterns | 128D subspace each |
| LLaMA-2 70B | 64 | 128 | 64 patterns | 128D subspace each |
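
The per-head dimensions in the table follow from one integer division. The short sketch below recomputes them; `per_head_dim` is a hypothetical helper name, and the LLaMA-2 70B width of 8192 is the value quoted later in this chapter.

```python
def per_head_dim(d_model: int, num_heads: int) -> int:
    """Return d_k = d_model / H, insisting that the split is exact."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    return d_model // num_heads


# (d_model, H) for each configuration in the table.
configs = {
    "Our Chapter 1": (4, 1),
    "Our Chapter 2": (4, 2),
    "Extreme split": (4, 4),
    "GPT-3": (12288, 96),
    "LLaMA-2 70B": (8192, 64),
}

head_dims = {name: per_head_dim(d, h) for name, (d, h) in configs.items()}
for name, d_k in head_dims.items():
    print(f"{name}: d_k = {d_k}")
```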

Step-by-Step Calculation

We now compute multi-head attention with $H = 2$ heads on our shared example. Head 1 operates on dims $[0, 1]$ (so $d_k = 2$) and Head 2 operates on dims $[2, 3]$. The scaling factor for each head is $\sqrt{d_k} = \sqrt{2} \approx 1.4142$.

Token: “The” (row 0)

Head 1: $Q_{h1}[0] = [1.0, 0.0]$, operating on $K_{h1}$ (columns 0-1 of K):

| Key Token | Dot Product | Scaled (/√2) | After Softmax |
|---|---|---|---|
| The | $1.0 \times 0.0 + 0.0 \times 1.0 = 0.0$ | 0.0000 | 0.1237 |
| cat | $1.0 \times 1.0 + 0.0 \times 0.0 = 1.0$ | 0.7071 | 0.2509 |
| sat | $1.0 \times 1.0 + 0.0 \times 1.0 = 1.0$ | 0.7071 | 0.2509 |
| on | $1.0 \times 0.0 + 0.0 \times 0.0 = 0.0$ | 0.0000 | 0.1237 |
| mat | $1.0 \times 1.0 + 0.0 \times 0.0 = 1.0$ | 0.7071 | 0.2509 |

Head 1 output for “The”: $O_{h1}[0] = [0.2491, 0.3763]$

Head 2: $Q_{h2}[0] = [1.0, 0.0]$, operating on $K_{h2}$ (columns 2-3 of K):

| Key Token | Dot Product | Scaled (/√2) | After Softmax |
|---|---|---|---|
| The | $1.0 \times 0.0 + 0.0 \times 1.0 = 0.0$ | 0.0000 | 0.1337 |
| cat | $1.0 \times 1.0 + 0.0 \times 0.0 = 1.0$ | 0.7071 | 0.2711 |
| sat | $1.0 \times 0.0 + 0.0 \times 0.0 = 0.0$ | 0.0000 | 0.1337 |
| on | $1.0 \times 1.0 + 0.0 \times 1.0 = 1.0$ | 0.7071 | 0.2711 |
| mat | $1.0 \times 0.5 + 0.0 \times 0.5 = 0.5$ | 0.3536 | 0.1904 |

Head 2 output for “The”: $O_{h2}[0] = [0.2289, 0.3663]$

Key observation: Head 1 distributes attention equally across “cat”, “sat”, and “mat” (all 0.2509). Head 2 focuses on “cat” and “on” (both 0.2711) while giving “sat” much less weight (0.1337). These are genuinely different attention patterns extracted from different dimensions of the same representation.

Concatenation: $O[0] = [O_{h1}[0], O_{h2}[0]] = [0.2491, 0.3763, 0.2289, 0.3663]$
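
The hand computation for “The” can be reproduced with a few lines of NumPy. This is a minimal sketch using the chapter's shared Q, K, V matrices; the helper `attention_row` is a name introduced here for illustration.

```python
import numpy as np

# Shared example matrices (rows: The, cat, sat, on, mat).
Q = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 2.0, 0.0, 1.0],
              [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0, 1.0]])
K = np.array([[0.0, 1.0, 0.0, 1.0],
              [1.0, 0.0, 1.0, 0.0],
              [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.5, 0.5]])
V = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0],
              [0.5, 0.5, 0.5, 0.5]])


def attention_row(q, K_h, V_h):
    """Attention weights and output for one query vector inside one head."""
    scaled = K_h @ q / np.sqrt(q.shape[0])   # dot products divided by sqrt(d_k)
    w = np.exp(scaled - scaled.max())
    w = w / w.sum()                          # softmax over the five keys
    return w, w @ V_h


# Row 0 is "The"; Head 1 uses dims 0-1, Head 2 uses dims 2-3.
w_h1, o_h1 = attention_row(Q[0, 0:2], K[:, 0:2], V[:, 0:2])
w_h2, o_h2 = attention_row(Q[0, 2:4], K[:, 2:4], V[:, 2:4])
the_output = np.concatenate([o_h1, o_h2])

print(np.round(w_h1, 4))
print(np.round(w_h2, 4))
print(np.round(the_output, 4))
```

Changing the row index from 0 to 1 reproduces the “cat” computation in the same way.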

Token: “cat” (row 1)

Head 1: $Q_{h1}[1] = [0.0, 2.0]$. This query strongly activates dim 1, matching keys with high values in dim 1.

| Key Token | Dot Product | Scaled | After Softmax |
|---|---|---|---|
| The | $0.0 \times 0.0 + 2.0 \times 1.0 = 2.0$ | 1.4142 | 0.3664 |
| cat | $0.0 \times 1.0 + 2.0 \times 0.0 = 0.0$ | 0.0000 | 0.0891 |
| sat | $0.0 \times 1.0 + 2.0 \times 1.0 = 2.0$ | 1.4142 | 0.3664 |
| on | $0.0 \times 0.0 + 2.0 \times 0.0 = 0.0$ | 0.0000 | 0.0891 |
| mat | $0.0 \times 1.0 + 2.0 \times 0.0 = 0.0$ | 0.0000 | 0.0891 |

Head 1: “cat” strongly attends to “The” (0.3664) and “sat” (0.3664), its grammatical neighbours. Output: $O_{h1}[1] = [0.4109, 0.1336]$.

Head 2: $Q_{h2}[1] = [0.0, 1.0]$. This query activates dim 3 (the second dim in Head 2's subspace).

| Key Token | Dot Product | Scaled | After Softmax |
|---|---|---|---|
| The | $0.0 \times 0.0 + 1.0 \times 1.0 = 1.0$ | 0.7071 | 0.2711 |
| cat | $0.0 \times 1.0 + 1.0 \times 0.0 = 0.0$ | 0.0000 | 0.1337 |
| sat | $0.0 \times 0.0 + 1.0 \times 0.0 = 0.0$ | 0.0000 | 0.1337 |
| on | $0.0 \times 1.0 + 1.0 \times 1.0 = 1.0$ | 0.7071 | 0.2711 |
| mat | $0.0 \times 0.5 + 1.0 \times 0.5 = 0.5$ | 0.3536 | 0.1904 |

Head 2: “cat” attends to “The” (0.2711) and “on” (0.2711), a different pattern from Head 1. Output: $O_{h2}[1] = [0.2289, 0.3663]$.

Concatenation

For each token, concatenate the two head outputs to recover the original $d_{\text{model}} = 4$ dimensionality:

| Token | Head 1 Output | Head 2 Output | Concatenated |
|---|---|---|---|
| The | [0.2491, 0.3763] | [0.2289, 0.3663] | [0.2491, 0.3763, 0.2289, 0.3663] |
| cat | [0.4109, 0.1336] | [0.2289, 0.3663] | [0.4109, 0.1336, 0.2289, 0.3663] |
| sat | [0.2717, 0.2717] | [0.2289, 0.3663] | [0.2717, 0.2717, 0.2289, 0.3663] |
| on | [0.3000, 0.3000] | [0.1799, 0.4579] | [0.3000, 0.3000, 0.1799, 0.4579] |
| mat | [0.2491, 0.3763] | [0.2289, 0.3663] | [0.2491, 0.3763, 0.2289, 0.3663] |

Where is W^O?

In our simplified example, we skip the final $W^O$ projection. In production transformers, $W^O \in \mathbb{R}^{d_{\text{model}} \times d_{\text{model}}}$ is a learned matrix that mixes information across heads. Without it, each dimension of the output would only carry information from one head. The $W^O$ projection is what allows the final representation to integrate insights from all heads.
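
To make the mixing role concrete, the sketch below applies a small output projection to the concatenated outputs from the table above. The matrix `W_O` here is a hand-picked hypothetical example, not a trained parameter; it is chosen only so that every output dimension combines one dimension from each head.

```python
import numpy as np

# Concatenated head outputs from the table (rows: The, cat, sat, on, mat).
concat = np.array([[0.2491, 0.3763, 0.2289, 0.3663],
                   [0.4109, 0.1336, 0.2289, 0.3663],
                   [0.2717, 0.2717, 0.2289, 0.3663],
                   [0.3000, 0.3000, 0.1799, 0.4579],
                   [0.2491, 0.3763, 0.2289, 0.3663]])

# Hypothetical W^O: output dim 0 = 0.5 * (head-1 dim 0 + head-2 dim 0), and so on.
W_O = 0.5 * np.array([[1.0, 0.0, 1.0, 0.0],
                      [0.0, 1.0, 0.0, 1.0],
                      [1.0, 0.0, -1.0, 0.0],
                      [0.0, 1.0, 0.0, -1.0]])

mixed = concat @ W_O

# Erase Head 2's contribution and check whether output dim 0 notices.
concat_without_h2 = concat.copy()
concat_without_h2[:, 2:4] = 0.0

dim0_depends_on_h2_without_WO = not np.allclose(concat[:, 0], concat_without_h2[:, 0])
dim0_depends_on_h2_with_WO = not np.allclose(mixed[:, 0], (concat_without_h2 @ W_O)[:, 0])

print(dim0_depends_on_h2_without_WO)  # False: dim 0 only ever sees Head 1
print(dim0_depends_on_h2_with_WO)     # True: W^O lets dim 0 see both heads
```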

Full Attention Weights and Output

Head 1 Attention Weights ($5 \times 5$)

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |
| cat | 0.3664 | 0.0891 | 0.3664 | 0.0891 | 0.0891 |
| sat | 0.1811 | 0.1811 | 0.3673 | 0.0893 | 0.1811 |
| on | 0.2000 | 0.2000 | 0.2000 | 0.2000 | 0.2000 |
| mat | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |

Head 1 patterns: “cat” strongly attends to “The” and “sat” (both 0.3664). “sat” attends most to itself (0.3673). “on” is perfectly uniform (0.2000) because $Q_{\text{on}} = [0, 0]$ in this subspace.

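Head 2 Attention Weights ($5 \times 5$)

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1337 | 0.2711 | 0.1337 | 0.2711 | 0.1904 |
| cat | 0.2711 | 0.1337 | 0.1337 | 0.2711 | 0.1904 |
| sat | 0.1337 | 0.2711 | 0.1337 | 0.2711 | 0.1904 |
| on | 0.1811 | 0.1811 | 0.0893 | 0.3673 | 0.1811 |
| mat | 0.2711 | 0.1337 | 0.1337 | 0.2711 | 0.1904 |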

Head 2 patterns: “on” now strongly attends to itself (0.3673) — completely different from its uniform pattern in Head 1. “sat” attends to “cat” and “on” (both 0.2711), not to itself.

Averaged Attention Weights (for comparison)

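| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1287 | 0.2610 | 0.1923 | 0.1974 | 0.2206 |
| cat | 0.3188 | 0.1114 | 0.2500 | 0.1801 | 0.1397 |
| sat | 0.1574 | 0.2261 | 0.2505 | 0.1802 | 0.1858 |
| on | 0.1906 | 0.1906 | 0.1447 | 0.2837 | 0.1906 |
| mat | 0.1974 | 0.1923 | 0.1923 | 0.1974 | 0.2206 |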
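
The averaged matrix is simply the element-wise mean of the two per-head weight matrices. The sketch below recomputes it from the shared Q and K. Averaging is done here only for comparison; the model itself concatenates head outputs rather than averaging weights.

```python
import numpy as np

# Shared example matrices (rows: The, cat, sat, on, mat).
Q = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 2.0, 0.0, 1.0],
              [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0, 1.0]])
K = np.array([[0.0, 1.0, 0.0, 1.0],
              [1.0, 0.0, 1.0, 0.0],
              [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.5, 0.5]])


def softmax(x):
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)


d_k = 2
head_slices = (slice(0, 2), slice(2, 4))  # Head 1: dims 0-1, Head 2: dims 2-3
head_weights = [softmax(Q[:, s] @ K[:, s].T / np.sqrt(d_k)) for s in head_slices]

# Element-wise mean across heads; each row is still a probability distribution.
averaged = np.mean(head_weights, axis=0)
print(np.round(averaged, 4))
```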

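Output Matrix: Multi-Head ($5 \times 4$)

| | dim-0 | dim-1 | dim-2 | dim-3 |
|---|---|---|---|---|
| The | 0.2491 | 0.3763 | 0.2289 | 0.3663 |
| cat | 0.4109 | 0.1336 | 0.2289 | 0.3663 |
| sat | 0.2717 | 0.2717 | 0.2289 | 0.3663 |
| on | 0.3000 | 0.3000 | 0.1799 | 0.4579 |
| mat | 0.2491 | 0.3763 | 0.2289 | 0.3663 |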

Compare with the single-head output from Chapter 1. The multi-head output has dims 0-1 shaped by Head 1's attention pattern and dims 2-3 shaped by Head 2's pattern. Each pair of dimensions carries contextual information computed through a different lens.
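
For readers who do not have the Chapter 1 numbers at hand, the sketch below computes both outputs side by side from the shared matrices. It assumes the single-head version scales by the square root of the full dimension (4) and the two-head version by the square root of 2, as described in this chapter.

```python
import numpy as np

# Shared example matrices (rows: The, cat, sat, on, mat).
Q = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 2.0, 0.0, 1.0],
              [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0, 1.0]])
K = np.array([[0.0, 1.0, 0.0, 1.0],
              [1.0, 0.0, 1.0, 0.0],
              [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.5, 0.5]])
V = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0],
              [0.5, 0.5, 0.5, 0.5]])


def attention(Qh, Kh, Vh):
    """Scaled dot-product attention; the scale follows the width of Qh."""
    scores = Qh @ Kh.T / np.sqrt(Qh.shape[-1])
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return (e / e.sum(axis=-1, keepdims=True)) @ Vh


single_head_output = attention(Q, K, V)
multi_head_output = np.hstack([
    attention(Q[:, s], K[:, s], V[:, s]) for s in (slice(0, 2), slice(2, 4))
])
max_abs_difference = np.max(np.abs(single_head_output - multi_head_output))

print(np.round(single_head_output, 4))
print(np.round(multi_head_output, 4))
```

The two outputs have the same shape but different values, because each pair of dimensions in the multi-head result was weighted by its own attention pattern.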


Applications Across Domains

Natural Language Processing

In large language models, different heads specialize for different linguistic phenomena. Clark et al. (2019) found that in BERT certain heads track specific syntactic relations (for example, direct objects attending to their verbs), others resolve coreference (“she” → “Mary”), and others attend to the previous token (positional heads). Voita et al. (2019) showed that many heads can be pruned with minimal impact on accuracy and that the heads surviving pruning tend to be the specialized ones, which supports the view that a subset of heads takes on well-defined roles.

Computer Vision

In Vision Transformers (ViT), image patches become tokens. Dosovitskiy et al. (2021) observed that attention distance varies across heads: some heads attend mostly to nearby patches (local texture), while others attend to distant patches and integrate information across the whole image, even in the lowest layers. This loosely mirrors how CNNs learn different filters, but the multi-head mechanism learns these patterns from data rather than requiring architectural inductive biases.

Code Generation

In code models like Codex and CodeLLaMA, heads can specialize for code-specific patterns: variable-definition tracking (where was $x$ defined?), bracket matching (which opening bracket does this closing bracket match?), type flow (what type does this expression evaluate to?), and import resolution. The parallel nature of multi-head attention lets the model track all of these simultaneously.

Scientific Modeling

In AlphaFold 2, multi-head attention over amino acid residues lets different heads capture: spatial proximity in 3D (nearby in folded structure), sequence proximity (nearby in primary sequence), evolutionary co-variation (residues that mutate together), and chemical interactions (hydrogen bonds, hydrophobic contacts). Across its heads, the architecture can capture fundamentally different notions of “relevance” between residue pairs.


Connection to Modern Systems

Flash Attention

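Flash Attention (Dao et al., 2022) does not change the mathematics of multi-head attention: it computes exact attention, and its output matches the standard implementation up to floating-point rounding. What it changes is the memory access pattern. Standard multi-head attention materializes the full $(N \times N)$ score matrix for each head, requiring $O(H \cdot N^2)$ memory. Flash Attention tiles the computation so that each block loaded into on-chip SRAM contributes a partial softmax, never materializing the full matrix. For $H = 96$ heads and $N = 8192$, the score matrices hold $96 \times 8192^2 \approx 6.4$ billion entries, about 13 GB in fp16 for a single sequence in a single layer. Flash Attention avoids storing them, so its extra memory grows linearly in $N$ rather than quadratically.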
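
The tiling idea can be illustrated on the CPU with a running ("online") softmax. The sketch below is not the Flash Attention kernel: it tiles only over keys and values, runs in plain NumPy, and uses hypothetical function names and a random test input. It shows only that processing key blocks one at a time, with rescaling, gives the same result as materializing the full score matrix.

```python
import numpy as np


def naive_attention(Q, K, V):
    """Reference version: builds the full (N, N) score matrix."""
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V


def tiled_attention(Q, K, V, block_size=3):
    """Visit keys/values block by block, keeping only running statistics."""
    N, d_k = Q.shape
    out = np.zeros((N, V.shape[-1]))
    running_max = np.full((N, 1), -np.inf)   # largest score seen so far per query
    running_sum = np.zeros((N, 1))           # softmax denominator so far
    for start in range(0, K.shape[0], block_size):
        K_blk = K[start:start + block_size]
        V_blk = V[start:start + block_size]
        scores = Q @ K_blk.T / np.sqrt(d_k)  # only an (N, block) tile
        new_max = np.maximum(running_max, scores.max(axis=-1, keepdims=True))
        correction = np.exp(running_max - new_max)  # rescale earlier partial results
        p = np.exp(scores - new_max)
        running_sum = running_sum * correction + p.sum(axis=-1, keepdims=True)
        out = out * correction + p @ V_blk
        running_max = new_max
    return out / running_sum


rng = np.random.default_rng(0)
Q_demo = rng.normal(size=(7, 4))
K_demo = rng.normal(size=(7, 4))
V_demo = rng.normal(size=(7, 4))

naive_out = naive_attention(Q_demo, K_demo, V_demo)
tiled_out = tiled_attention(Q_demo, K_demo, V_demo)
print(np.allclose(naive_out, tiled_out))
```

The two results agree to floating-point tolerance rather than bit for bit, because the tiled version sums terms in a different order.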

KV-Cache Optimization

During autoregressive generation, the KV-cache stores previously computed key and value tensors. With multi-head attention, each head has its own KV-cache of shape $(\text{seq\_len}, d_k)$ for keys and another for values. The total cache size is $2 \times H \times L \times N \times d_k$ values (2 for K and V, $L$ layers, $N$ tokens). For GPT-3 with 96 heads of dimension 128, 96 layers, and 2048 tokens, that is about 4.8 billion values, or roughly 9.7 GB in fp16 per sequence. This memory pressure is what motivates Multi-Query Attention (Chapter 3) and Grouped-Query Attention (Chapter 4), which share keys/values across heads.
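
The cache-size formula is easy to evaluate directly. The sketch below does so for the GPT-3 configuration quoted in this chapter (96 heads, 96 layers, 2048 tokens, $d_k = 128$), assuming 2 bytes per value for fp16 and a single sequence; `kv_cache_bytes` is a hypothetical helper name.

```python
def kv_cache_bytes(num_heads, num_layers, seq_len, d_k, bytes_per_value=2):
    """Size of the KV-cache for one sequence: 2 * H * L * N * d_k values."""
    num_values = 2 * num_heads * num_layers * seq_len * d_k  # 2 = keys + values
    return num_values * bytes_per_value


gpt3_bytes = kv_cache_bytes(num_heads=96, num_layers=96, seq_len=2048, d_k=128)
print(f"{gpt3_bytes / 1e9:.2f} GB")  # 9.66 GB
```

Batching multiplies this figure by the number of concurrent sequences, which is why sharing keys and values across heads pays off so quickly.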

Positional Encodings (RoPE)

Rotary Position Embeddings (Su et al., 2021) apply position-dependent rotations to Q and K before the dot product, separately within each head. Because each head operates in its own $d_k$-dimensional subspace, RoPE rotates pairs of dimensions within each head, with a different frequency for each pair. Low-frequency pairs change slowly with position and can encode long-range offsets, while high-frequency pairs encode local position. The same frequency schedule is typically reused in every head, so each head has access to the full range of frequencies.
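
A minimal sketch of the rotation for a single head is shown below. It assumes the interleaved pairing (dims 0-1, 2-3, ...) and the base of 10000 from the RoFormer paper; some implementations pair dimension $i$ with $i + d_k/2$ instead. The function name `rope` and the test vectors are illustrative.

```python
import numpy as np


def rope(x, base=10000.0):
    """Rotate consecutive dimension pairs of x, shape (N, d_k), by position-dependent angles."""
    N, d_k = x.shape
    if d_k % 2 != 0:
        raise ValueError("d_k must be even")
    freqs = base ** (-np.arange(0, d_k, 2) / d_k)    # one frequency per pair
    angles = np.arange(N)[:, None] * freqs[None, :]  # (N, d_k / 2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out


# Place the same query and key vector at every position, then rotate.
q = np.array([1.0, 0.5, -0.3, 2.0])
k = np.array([0.2, -1.0, 0.7, 0.4])
num_positions = 6
Q_rot = rope(np.tile(q, (num_positions, 1)))
K_rot = rope(np.tile(k, (num_positions, 1)))
scores = Q_rot @ K_rot.T

# The score between positions m and n depends only on the offset m - n.
print(np.isclose(scores[1, 0], scores[4, 3]))
print(np.isclose(scores[0, 2], scores[3, 5]))
```

Because rotations preserve vector length, RoPE changes the angle between queries and keys without changing their norms.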

Transformer Scaling

As model size grows, both $d_{\text{model}}$ and $H$ typically increase. GPT-3 (175B) uses $d_{\text{model}} = 12288$ with $H = 96$, giving $d_k = 128$. LLaMA-2 70B uses $d_{\text{model}} = 8192$ with $H = 64$ heads, also $d_k = 128$. The scaling laws of Kaplan et al. (2020) indicate that larger models are more sample-efficient and that performance depends only weakly on shape choices such as the number of heads. In practice the per-head dimension $d_k$ rarely drops below 64, which maintains per-head expressiveness.


Complexity Analysis

| Metric | Single-Head (Ch. 1) | Multi-Head (Ch. 2) |
|---|---|---|
| Time complexity | $O(N^2 \cdot d_{\text{model}})$ | $O(N^2 \cdot d_{\text{model}})$ (same total) |
| Per-head time | $O(N^2 \cdot d_{\text{model}})$ | $O(N^2 \cdot d_k)$ (per head) |
| Memory for scores | $O(N^2)$ | $O(H \cdot N^2)$ |
| Parallelism | 1 operation | $H$ independent operations |
| Parameters (projections) | $3 \cdot d_{\text{model}}^2$ | $3 \cdot d_{\text{model}}^2 + d_{\text{model}}^2$ (includes $W^O$) |

The key insight: multi-head attention has the same asymptotic time complexity as single-head attention because $H \cdot d_k = d_{\text{model}}$. What changes is that the $H$ per-head computations can run in parallel on GPUs, and the model gains $H$ independent attention patterns instead of one.
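
The cancellation $H \cdot d_k = d_{\text{model}}$ can be checked with a rough operation count. The sketch below counts multiply-adds for the score computation only and tallies projection parameters as in the table; the function names are hypothetical, and constant factors are simplified.

```python
def score_multiply_adds(seq_len, d_model, num_heads):
    """Multiply-adds needed to form Q_h K_h^T for every head."""
    d_k = d_model // num_heads
    per_head = seq_len * seq_len * d_k  # one (N, d_k) x (d_k, N) product
    return num_heads * per_head


def projection_params(d_model, include_output=True):
    """Parameters in the Q, K, V projections, optionally plus W^O."""
    return (4 if include_output else 3) * d_model ** 2


single = score_multiply_adds(seq_len=2048, d_model=12288, num_heads=1)
multi = score_multiply_adds(seq_len=2048, d_model=12288, num_heads=96)
print(single == multi)  # True: the total work does not depend on H
print(projection_params(12288, include_output=False), projection_params(12288))
```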


Python Implementation

The full Python class below implements multi-head attention with the same structure as Chapter 1. Each line of code is followed by an explanation and, where useful, its execution trace with actual values.

Multi-Head Attention — NumPy Implementation
multi_head_attention.py
import numpy as np

NumPy provides vectorized matrix operations. All matrix multiplications (Q @ K.T) execute as optimized C code, not Python loops.

import math

Python standard library. We use math.sqrt() to precompute the scaling factor for each head.

class MultiHeadAttention

Wraps multi-head attention in a reusable class. Compared to Chapter 1 (single-head), this class adds head splitting, per-head attention, and concatenation. Every chapter follows this same class structure so you can compare implementations side by side.

def __init__(self, d_model, num_heads)

Constructor. Takes the total model dimension (d_model=4) and number of heads (H=2). Computes per-head dimension d_k = d_model / H and precomputes the scaling factor.

EXECUTION STATE
⬇ input: d_model = 4
⬇ input: num_heads = 2
self.d_model = d_model

Store the total model dimension. This is the full width of Q, K, V before splitting.

EXECUTION STATE
self.d_model = 4
self.num_heads = num_heads

Store H, the number of parallel attention heads.

EXECUTION STATE
self.num_heads = 2
self.d_k = d_model // num_heads

Per-head dimension. Each head operates on d_k dimensions of the representation. d_model must be divisible by H.

EXECUTION STATE
d_model // num_heads = 4 // 2 = 2
self.d_k = 2
self.scale = math.sqrt(self.d_k)

Precompute sqrt(d_k) once. Each head divides its dot products by this value. With d_k=2, the scale is smaller than Chapter 1 (sqrt(4)=2.0), meaning less compression per head.

EXECUTION STATE
math.sqrt(2) = 1.4142
self.scale = 1.4142
def _softmax(self, x) -> np.ndarray

Numerically stable softmax. Takes a matrix of scaled scores and returns probabilities where each row sums to 1.0. Identical to the Chapter 1 implementation.

EXECUTION STATE
⬇ input: x = shape (5, 5) — scaled scores for one head
⬆ returns = np.ndarray (5, 5) — softmax probabilities per row
x_shifted = x - np.max(x, axis=-1, keepdims=True)

Subtract row-wise max for numerical stability. exp(1000) overflows in float64, but exp(0) = 1.0. The subtraction does not change the softmax result because the constant cancels in the ratio.

EXECUTION STATE
axis=-1 = operate along the LAST axis — find max of each row independently, not the global max.
keepdims=True = keep shape (5,1) so broadcasting x(5×5) - max(5×1) works — each row subtracts its own max.
exp_x = np.exp(x_shifted)

Exponentiate every element. The largest value per row becomes exp(0)=1.0 — no overflow.

return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

Divide each element by its row sum to normalize into probabilities. Each row sums to 1.0.

EXECUTION STATE
axis=-1 = sum along last axis — sum each row independently.
keepdims=True = sum returns (5,1) so broadcasting exp_x(5×5) / sum(5×1) works correctly.
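
The effect of the max subtraction is easiest to see with scores that are far too large for a direct exponential. The sketch below compares a naive softmax with the shifted version; the function names and the test values are illustrative, and NumPy's default float64 is assumed.

```python
import numpy as np


def naive_softmax(x):
    """Softmax without the max shift; overflows for large inputs."""
    e = np.exp(x)
    return e / np.sum(e, axis=-1, keepdims=True)


def stable_softmax(x):
    """Softmax with the row-wise max subtracted first."""
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / np.sum(e, axis=-1, keepdims=True)


large_scores = np.array([[1000.0, 1001.0, 1002.0]])
small_scores = np.array([[0.0, 0.7071, 1.4142]])

# Silence the overflow warnings that the naive version triggers on purpose.
with np.errstate(over="ignore", invalid="ignore"):
    naive_result = naive_softmax(large_scores)
stable_result = stable_softmax(large_scores)

print(naive_result)   # exp overflows to inf, and inf / inf gives nan
print(stable_result)  # finite probabilities that sum to 1
```

On inputs of ordinary size the two functions agree, which is why the shift can be applied unconditionally.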
def split_heads(self, M) -> list

Split a (N, d_model) matrix into H matrices of (N, d_k) each. This is the core mechanical step that enables multi-head attention — each head gets its own subspace slice.

EXECUTION STATE
⬇ input: M (5×4) = any of Q, K, or V — full d_model=4 width
⬆ returns = list of 2 matrices, each (5, 2)
heads = []

Accumulator list. Will hold H=2 sub-matrices after the loop.

EXECUTION STATE
heads = [] (empty list)
for h in range(self.num_heads):

Loop over H=2 heads. Each iteration computes the column slice for one head.

LOOP TRACE · 2 iterations
h=0 (Head 1)
start = 0 * 2 = 0
end = 0 + 2 = 2
M[:, 0:2] = columns 0 and 1 — first subspace
h=1 (Head 2)
start = 1 * 2 = 2
end = 2 + 2 = 4
M[:, 2:4] = columns 2 and 3 — second subspace
start = h * self.d_k

Starting column index for this head. Head 0 starts at 0, Head 1 at 2.

end = start + self.d_k

Ending column index (exclusive). Head 0 takes [0:2], Head 1 takes [2:4].

heads.append(M[:, start:end])

Slice all rows, columns start:end. This is a view (no copy) in NumPy, so it is O(1).

EXECUTION STATE
[:, start:end] = NumPy slice — all rows (:), columns start to end-1. Returns a (5, d_k) sub-matrix.
return heads

Return the list of H sub-matrices.

EXECUTION STATE
⬆ return: heads = list of 2 arrays, each shape (5, 2)
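
The loop-and-slice split traced above has a common vectorized counterpart based on reshape and transpose. The sketch below shows both and checks that they agree. The function names are hypothetical, and the reshape version is an alternative technique, not the one used in this chapter's class.

```python
import numpy as np


def split_heads_loop(M, num_heads):
    """Column slicing, mirroring the loop traced above."""
    d_k = M.shape[1] // num_heads
    return [M[:, h * d_k:(h + 1) * d_k] for h in range(num_heads)]


def split_heads_reshape(M, num_heads):
    """Reshape (N, d_model) to (N, H, d_k), then move the head axis first."""
    N, d_model = M.shape
    d_k = d_model // num_heads
    return M.reshape(N, num_heads, d_k).transpose(1, 0, 2)  # (H, N, d_k)


# Shared example query matrix (rows: The, cat, sat, on, mat).
Q = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 2.0, 0.0, 1.0],
              [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0, 1.0]])

loop_heads = split_heads_loop(Q, num_heads=2)
stacked_heads = split_heads_reshape(Q, num_heads=2)

print(stacked_heads.shape)
print(all(np.array_equal(loop_heads[h], stacked_heads[h]) for h in range(2)))
```

The stacked form keeps all heads in one array, which allows a single batched matrix multiplication to replace the Python loop over heads.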
def attention(self, Qh, Kh, Vh)

Single-head scaled dot-product attention. Identical to the full Chapter 1 computation, but now operating on a d_k=2 subspace instead of d_model=4.

EXECUTION STATE
⬇ input: Qh (5×2) = query sub-matrix for one head
⬇ input: Kh (5×2) = key sub-matrix for one head
⬇ input: Vh (5×2) = value sub-matrix for one head
⬆ returns = (weights, output) — shapes (5,5) and (5,2)
scores = Qh @ Kh.T

Matrix multiply Qh (5×2) with Kh transposed (2×5). Each entry (i,j) is the dot product of query_i with key_j in this head's subspace.

EXECUTION STATE
Qh @ Kh.T = (5×2) @ (2×5) → (5×5)
── Head 1 scores ── =
scores (Head 1) =
      The   cat   sat    on   mat
The  0.00  1.00  1.00  0.00  1.00
cat  2.00  0.00  2.00  0.00  0.00
sat  1.00  1.00  2.00  0.00  1.00
on   0.00  0.00  0.00  0.00  0.00
mat  0.00  1.00  1.00  0.00  1.00
── Head 2 scores ── =
scores (Head 2) =
      The   cat   sat    on   mat
The  0.00  1.00  0.00  1.00  0.50
cat  1.00  0.00  0.00  1.00  0.50
sat  0.00  1.00  0.00  1.00  0.50
on   1.00  1.00  0.00  2.00  1.00
mat  1.00  0.00  0.00  1.00  0.50
scaled = scores / self.scale

Divide all scores by sqrt(d_k) = sqrt(2) = 1.4142. Compared to Chapter 1 where we divided by sqrt(4) = 2.0, here we divide by a smaller number — less compression because d_k is smaller.

EXECUTION STATE
self.scale = 1.4142 (√2, not √4)
── Head 1 scaled ── =
scaled (Head 1) =
       The     cat     sat      on     mat
The  0.0000  0.7071  0.7071  0.0000  0.7071
cat  1.4142  0.0000  1.4142  0.0000  0.0000
sat  0.7071  0.7071  1.4142  0.0000  0.7071
on   0.0000  0.0000  0.0000  0.0000  0.0000
mat  0.0000  0.7071  0.7071  0.0000  0.7071
── Head 2 scaled ── =
scaled (Head 2) =
       The     cat     sat      on     mat
The  0.0000  0.7071  0.0000  0.7071  0.3536
cat  0.7071  0.0000  0.0000  0.7071  0.3536
sat  0.0000  0.7071  0.0000  0.7071  0.3536
on   0.7071  0.7071  0.0000  1.4142  0.7071
mat  0.7071  0.0000  0.0000  0.7071  0.3536
weights = self._softmax(scaled)

Apply softmax row-wise to get attention probabilities. Each row sums to 1.0. These are the two different attention patterns — the whole point of multi-head.

EXECUTION STATE
── Head 1 weights ── =
weights (Head 1) =
       The      cat      sat       on      mat
The  0.1237   0.2509   0.2509   0.1237   0.2509
cat  0.3664   0.0891   0.3664   0.0891   0.0891
sat  0.1811   0.1811   0.3673   0.0893   0.1811
on   0.2000   0.2000   0.2000   0.2000   0.2000
mat  0.1237   0.2509   0.2509   0.1237   0.2509
── Head 2 weights ── =
weights (Head 2) =
       The      cat      sat       on      mat
The  0.1337   0.2711   0.1337   0.2711   0.1904
cat  0.2711   0.1337   0.1337   0.2711   0.1904
sat  0.1337   0.2711   0.1337   0.2711   0.1904
on   0.1811   0.1811   0.0893   0.3673   0.1811
mat  0.2711   0.1337   0.1337   0.2711   0.1904
output = weights @ Vh

Weighted sum of value vectors. Each output row is a blend of all 5 value vectors (in this head's 2D subspace), weighted by attention probabilities.

EXECUTION STATE
── Head 1 output ── =
output (Head 1) =
       d0       d1
The  0.2491   0.3763
cat  0.4109   0.1336
sat  0.2717   0.2717
on   0.3000   0.3000
mat  0.2491   0.3763
── Head 2 output ── =
output (Head 2) =
       d2       d3
The  0.2289   0.3663
cat  0.2289   0.3663
sat  0.2289   0.3663
on   0.1799   0.4579
mat  0.2289   0.3663
return weights, output

Return the attention weight matrix (5×5) for visualization and the context-enriched output (5×2) for concatenation.

EXECUTION STATE
⬆ return: weights = shape (5, 5) — attention probabilities
⬆ return: output = shape (5, 2) — context-enriched per-head output
def forward(self, Q, K, V)

Full multi-head forward pass. Splits Q, K, V into heads, runs attention on each, then concatenates outputs. This is the main entry point.

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = (all_weights, output) — list of H weight matrices + concatenated output (5,4)
Q_heads = self.split_heads(Q)

Split Q (5×4) into [Q_h1 (5×2), Q_h2 (5×2)]. Head 1 gets dims [0:2], Head 2 gets dims [2:4].

EXECUTION STATE
Q_heads[0] (5×2) =
      d0   d1
The  1.0  0.0
cat  0.0  2.0
sat  1.0  1.0
on   0.0  0.0
mat  1.0  0.0
Q_heads[1] (5×2) =
      d2   d3
The  1.0  0.0
cat  0.0  1.0
sat  1.0  0.0
on   1.0  1.0
mat  0.0  1.0
K_heads = self.split_heads(K)

Split K (5×4) into [K_h1 (5×2), K_h2 (5×2)].

EXECUTION STATE
K_heads[0] (5×2) =
      d0   d1
The  0.0  1.0
cat  1.0  0.0
sat  1.0  1.0
on   0.0  0.0
mat  1.0  0.0
K_heads[1] (5×2) =
      d2   d3
The  0.0  1.0
cat  1.0  0.0
sat  0.0  0.0
on   1.0  1.0
mat  0.5  0.5
V_heads = self.split_heads(V)

Split V (5×4) into [V_h1 (5×2), V_h2 (5×2)].

EXECUTION STATE
V_heads[0] (5×2) =
      d0   d1
The  1.0  0.0
cat  0.0  1.0
sat  0.0  0.0
on   0.0  0.0
mat  0.5  0.5
V_heads[1] (5×2) =
      d2   d3
The  0.0  0.0
cat  0.0  0.0
sat  1.0  0.0
on   0.0  1.0
mat  0.5  0.5
for h in range(self.num_heads):

Loop over H=2 heads. Each iteration computes independent attention on one subspace.

LOOP TRACE · 2 iterations
h=0 (Head 1)
Q_heads[0] = dims [0:2] — shape (5, 2)
K_heads[0] = dims [0:2] — shape (5, 2)
V_heads[0] = dims [0:2] — shape (5, 2)
w (Head 1 weights) = shape (5, 5); cat→The is 0.3664 (tied with cat→sat as the largest weight in the cat row)
o (Head 1 output) = shape (5, 2)
h=1 (Head 2)
Q_heads[1] = dims [2:4] — shape (5, 2)
K_heads[1] = dims [2:4] — shape (5, 2)
V_heads[1] = dims [2:4] — shape (5, 2)
w (Head 2 weights) = shape (5, 5) — on→on is 0.3673 (highest)
o (Head 2 output) = shape (5, 2)
w, o = self.attention(Q_heads[h], K_heads[h], V_heads[h])

Call single-head attention for this head. Returns weights (5×5) and output (5×2). Each head sees a completely different slice of the representation.

all_weights.append(w)

Collect this head's attention weight matrix for later visualization/analysis.

all_outputs.append(o)

Collect this head's (5×2) output for concatenation.

output = np.hstack(all_outputs)

Concatenate Head 1 output (5×2) with Head 2 output (5×2) along the column axis to get (5×4). This is the Concat() operation in the formula. The result has the same shape as the original input.

EXECUTION STATE
np.hstack() = horizontal stack — concatenate arrays column-wise. [shape(5,2), shape(5,2)] → shape(5,4).
⬆ return: output (5×4) =
        d0       d1       d2       d3
The  0.2491   0.3763   0.2289   0.3663
cat  0.4109   0.1336   0.2289   0.3663
sat  0.2717   0.2717   0.2289   0.3663
on   0.3000   0.3000   0.1799   0.4579
mat  0.2491   0.3763   0.2289   0.3663
return all_weights, output

Return all per-head weight matrices (for visualization) and the concatenated output (for the next layer).

EXECUTION STATE
⬆ return: all_weights = list of 2 matrices, each (5, 5)
⬆ return: output = shape (5, 4) — concatenated multi-head output
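
Putting the lines traced so far together gives roughly the following class. This is a reconstruction assembled from the walkthrough, not the verbatim file: docstrings and the explain() helper are omitted, and the initialization of `all_weights` and `all_outputs` is an assumption, since the trace does not show those lines.

```python
import math
import numpy as np


class MultiHeadAttention:
    def __init__(self, d_model, num_heads):
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads      # per-head dimension
        self.scale = math.sqrt(self.d_k)     # sqrt(d_k), shared by all heads

    def _softmax(self, x):
        x_shifted = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(x_shifted)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def split_heads(self, M):
        heads = []
        for h in range(self.num_heads):
            start = h * self.d_k
            end = start + self.d_k
            heads.append(M[:, start:end])    # columns belonging to head h
        return heads

    def attention(self, Qh, Kh, Vh):
        scores = Qh @ Kh.T
        scaled = scores / self.scale
        weights = self._softmax(scaled)
        output = weights @ Vh
        return weights, output

    def forward(self, Q, K, V):
        Q_heads = self.split_heads(Q)
        K_heads = self.split_heads(K)
        V_heads = self.split_heads(V)
        all_weights, all_outputs = [], []    # assumed; not shown in the trace
        for h in range(self.num_heads):
            w, o = self.attention(Q_heads[h], K_heads[h], V_heads[h])
            all_weights.append(w)
            all_outputs.append(o)
        output = np.hstack(all_outputs)      # Concat(head_1, ..., head_H)
        return all_weights, output


# Shared example matrices (rows: The, cat, sat, on, mat).
Q = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 2.0, 0.0, 1.0],
              [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0, 1.0]])
K = np.array([[0.0, 1.0, 0.0, 1.0],
              [1.0, 0.0, 1.0, 0.0],
              [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.5, 0.5]])
V = np.array([[1.0, 0.0, 0.0, 0.0],
              [0.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0],
              [0.5, 0.5, 0.5, 0.5]])

mha = MultiHeadAttention(d_model=4, num_heads=2)
all_weights, output = mha.forward(Q, K, V)
print(np.round(output, 4))
```

Like the worked example, this version has no learned projections and no $W^O$; it splits columns directly.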
def explain(self, Q, K, V, tokens, query_idx=0)

Diagnostic function. Recomputes multi-head attention and prints a detailed per-head trace for one query token. Returns nothing — prints to stdout.

EXECUTION STATE
⬇ input: Q, K, V = same shared example matrices (5×4 each)
⬇ input: tokens = ['The', 'cat', 'sat', 'on', 'mat']
⬇ input: query_idx = 0 → traces token 'The'
all_weights, output = self.forward(Q, K, V)

Run the full multi-head pipeline. Needed because explain() is self-contained.

EXECUTION STATE
all_weights = list of 2 weight matrices (5×5 each)
output = shape (5, 4) — concatenated
Q_heads = self.split_heads(Q)

Re-split to access per-head Q vectors for printing.

K_heads = self.split_heads(K)

Re-split to access per-head K vectors for printing.

token = tokens[query_idx]

Look up the token name for the index we are tracing.

EXECUTION STATE
token = 'The'
for h in range(self.num_heads): (trace loop)

Loop over each head and print its dot products and weights for the query token.

LOOP TRACE · 2 iterations
h=0 (Head 1)
Q_h1[The] = [1.0, 0.0]
Q[The]·K[The] = 1.0×0.0 + 0.0×1.0 = 0.00 → /1.41 = 0.0000
Q[The]·K[cat] = 1.0×1.0 + 0.0×0.0 = 1.00 → /1.41 = 0.7071
Q[The]·K[sat] = 1.0×1.0 + 0.0×1.0 = 1.00 → /1.41 = 0.7071
Q[The]·K[on] = 1.0×0.0 + 0.0×0.0 = 0.00 → /1.41 = 0.0000
Q[The]·K[mat] = 1.0×1.0 + 0.0×0.0 = 1.00 → /1.41 = 0.7071
weights[The] = [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
h=1 (Head 2)
Q_h2[The] = [1.0, 0.0]
Q[The]·K[The] = 1.0×0.0 + 0.0×1.0 = 0.00 → /1.41 = 0.0000
Q[The]·K[cat] = 1.0×1.0 + 0.0×0.0 = 1.00 → /1.41 = 0.7071
Q[The]·K[sat] = 1.0×0.0 + 0.0×0.0 = 0.00 → /1.41 = 0.0000
Q[The]·K[on] = 1.0×1.0 + 0.0×1.0 = 1.00 → /1.41 = 0.7071
Q[The]·K[mat] = 1.0×0.5 + 0.0×0.5 = 0.50 → /1.41 = 0.3536
weights[The] = [0.1337, 0.2711, 0.1337, 0.2711, 0.1904]
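The trace above can be reproduced by hand in a few lines. The following is a minimal sketch, assuming only the Q row for 'The' and the K matrix from the shared example; the helper name `softmax` is introduced here for illustration and is not part of the class.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over a 1-D array (illustrative helper)."""
    e = np.exp(x - np.max(x))
    return e / e.sum()

d_k = 2
q_the = np.array([1.0, 0.0, 1.0, 0.0])   # Q row for 'The'
K = np.array([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])

# Head 1 uses columns 0:2, Head 2 uses columns 2:4.
head1_weights = softmax(K[:, 0:2] @ q_the[0:2] / np.sqrt(d_k))
head2_weights = softmax(K[:, 2:4] @ q_the[2:4] / np.sqrt(d_k))

print(np.round(head1_weights, 4))
print(np.round(head2_weights, 4))
```

The two printed rows should agree with the weights[The] values in the trace to four decimal places.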
print concatenated output

Print the final 4-dimensional output vector for the traced token — the combination of both heads' views.

EXECUTION STATE
output[The] = [0.2491, 0.3763, 0.2289, 0.3663]
tokens = [...]

The 5 tokens used in every chapter. Identical to Chapter 1.

EXECUTION STATE
tokens = ['The', 'cat', 'sat', 'on', 'mat']
Q = np.array([...])

Query matrix — identical to Chapter 1. Each row is what that token 'looks for'.

EXECUTION STATE
Q =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
K = np.array([...])

Key matrix — identical to Chapter 1. Each row is what that token 'advertises'.

EXECUTION STATE
K =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
V = np.array([...])

Value matrix — identical to Chapter 1. The actual content retrieved when a token is attended to.

EXECUTION STATE
V =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
mha = MultiHeadAttention(d_model=4, num_heads=2)

Instantiate with 4 dimensions and 2 heads. This gives d_k = 4/2 = 2 per head, scale = √2 ≈ 1.4142.

EXECUTION STATE
mha.d_model = 4
mha.num_heads = 2
mha.d_k = 2
mha.scale = 1.4142
all_weights, output = mha.forward(Q, K, V)

Run the full multi-head pipeline: split → per-head attention × 2 → concatenate.

EXECUTION STATE
all_weights[0] (Head 1) =
       The      cat      sat       on      mat
The  0.1237   0.2509   0.2509   0.1237   0.2509
cat  0.3664   0.0891   0.3664   0.0891   0.0891
sat  0.1811   0.1811   0.3673   0.0893   0.1811
on   0.2000   0.2000   0.2000   0.2000   0.2000
mat  0.1237   0.2509   0.2509   0.1237   0.2509
all_weights[1] (Head 2) =
       The      cat      sat       on      mat
The  0.1337   0.2711   0.1337   0.2711   0.1904
cat  0.2711   0.1337   0.1337   0.2711   0.1904
sat  0.1337   0.2711   0.1337   0.2711   0.1904
on   0.1811   0.1811   0.0893   0.3673   0.1811
mat  0.2711   0.1337   0.1337   0.2711   0.1904
output (5×4) =
        d0       d1       d2       d3
The  0.2491   0.3763   0.2289   0.3663
cat  0.4109   0.1336   0.2289   0.3663
sat  0.2717   0.2717   0.2289   0.3663
on   0.3000   0.3000   0.1799   0.4579
mat  0.2491   0.3763   0.2289   0.3663
mha.explain(Q, K, V, tokens, query_idx=0)

Print a detailed trace for 'The' (token 0), showing per-head dot products and weights. The per-iteration values are listed in the loop trace above.

EXECUTION STATE
query_idx = 0 → tracing 'The'
import numpy as np
import math

class MultiHeadAttention:
    """
    Multi-Head Attention (Vaswani et al., 2017)

    Splits Q, K, V into H heads, runs scaled dot-product attention
    on each head independently, then concatenates the results.

    MultiHead(Q, K, V) = Concat(head_1, ..., head_H) @ W_O
    head_i = Attention(Q @ W_i^Q, K @ W_i^K, V @ W_i^V)
    """

    def __init__(self, d_model: int, num_heads: int):
        """
        Args:
            d_model:   Total model dimension (4 in our example)
            num_heads: Number of attention heads H (2 in our example)
        """
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = math.sqrt(self.d_k)

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax along last axis."""
        x_shifted = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(x_shifted)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def split_heads(self, M: np.ndarray) -> list:
        """Split matrix M (N, d_model) into H matrices of (N, d_k)."""
        heads = []
        for h in range(self.num_heads):
            start = h * self.d_k
            end = start + self.d_k
            heads.append(M[:, start:end])
        return heads

    def attention(self, Qh: np.ndarray, Kh: np.ndarray, Vh: np.ndarray):
        """Single-head scaled dot-product attention."""
        scores = Qh @ Kh.T
        scaled = scores / self.scale
        weights = self._softmax(scaled)
        output = weights @ Vh
        return weights, output

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Full multi-head forward pass.

        Args:
            Q: Query matrix  (N, d_model)
            K: Key matrix    (N, d_model)
            V: Value matrix  (N, d_model)

        Returns:
            all_weights: List of H attention matrices, each (N, N)
            output:      Concatenated output (N, d_model)
        """
        Q_heads = self.split_heads(Q)
        K_heads = self.split_heads(K)
        V_heads = self.split_heads(V)

        all_weights = []
        all_outputs = []

        for h in range(self.num_heads):
            w, o = self.attention(Q_heads[h], K_heads[h], V_heads[h])
            all_weights.append(w)
            all_outputs.append(o)

        output = np.hstack(all_outputs)
        return all_weights, output

    def explain(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                tokens: list, query_idx: int = 0):
        """Print a detailed trace for a specific query token."""
        all_weights, output = self.forward(Q, K, V)
        Q_heads = self.split_heads(Q)
        K_heads = self.split_heads(K)

        token = tokens[query_idx]
        print(f"\n=== Multi-Head trace for '{token}' (row {query_idx}) ===")

        for h in range(self.num_heads):
            start = h * self.d_k
            end = start + self.d_k
            print(f"\n--- Head {h+1} (dims {start}:{end}, d_k={self.d_k}) ---")
            print(f"  Q_h{h+1}[{token}] = {Q_heads[h][query_idx]}")
            for j, t in enumerate(tokens):
                dot = Q_heads[h][query_idx] @ K_heads[h][j]
                print(f"  Q[{token}] . K[{t}] = {dot:.4f} -> /{self.scale:.2f} = {dot/self.scale:.4f}")
            print(f"  weights = {np.round(all_weights[h][query_idx], 4)}")

        print(f"\nConcatenated output[{token}] = {np.round(output[query_idx], 4)}")


# ── Shared Example (used in every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],   # The
    [0.0, 2.0, 0.0, 1.0],   # cat
    [1.0, 1.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.0, 1.0],   # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],   # The
    [0.0, 1.0, 0.0, 0.0],   # cat
    [0.0, 0.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 0.0, 1.0],   # on
    [0.5, 0.5, 0.5, 0.5],   # mat
])

# ── Run ──
mha = MultiHeadAttention(d_model=4, num_heads=2)
all_weights, output = mha.forward(Q, K, V)

print("Head 1 Attention Weights (5x5):")
print(np.round(all_weights[0], 4))

print("\nHead 2 Attention Weights (5x5):")
print(np.round(all_weights[1], 4))

print("\nConcatenated Output (5x4):")
print(np.round(output, 4))

# Detailed trace for "The" (token 0)
mha.explain(Q, K, V, tokens, query_idx=0)
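The per-head loop in the listing can also be written without a Python loop. The following is a rough sketch, not the chapter's implementation: the function name `multi_head_vectorized` is hypothetical, and like the NumPy class it assumes no learned projections. It reshapes the inputs to (H, N, d_k) so one batched matmul covers every head.

```python
import numpy as np

def multi_head_vectorized(Q, K, V, num_heads):
    """Loop-free multi-head attention without learned projections (sketch)."""
    N, d_model = Q.shape
    d_k = d_model // num_heads

    def split(M):
        # (N, d_model) -> (H, N, d_k): head h gets columns h*d_k:(h+1)*d_k
        return M.reshape(N, num_heads, d_k).transpose(1, 0, 2)

    Qh, Kh, Vh = split(Q), split(K), split(V)
    scores = Qh @ Kh.transpose(0, 2, 1) / np.sqrt(d_k)      # (H, N, N)
    scores = scores - scores.max(axis=-1, keepdims=True)    # stability shift
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)          # softmax over keys
    heads = weights @ Vh                                    # (H, N, d_k)
    output = heads.transpose(1, 0, 2).reshape(N, d_model)   # concatenate heads
    return weights, output

Q = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 2.0, 0.0, 1.0], [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]])
K = np.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.5, 0.5]])
V = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.5]])

weights, output = multi_head_vectorized(Q, K, V, num_heads=2)
print(np.round(output, 4))
```

With the shared example, the printed matrix should match the concatenated output shown in the walkthrough; the reshape-based split is the same idea the PyTorch version uses.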

PyTorch Implementation

The PyTorch version adds GPU support, automatic differentiation, and batched inputs, and it includes the learned projection matrices ($W^Q, W^K, W^V, W^O$) that real transformers use. We set use_projections=False to verify that the results match the NumPy version exactly.

Multi-Head Attention — PyTorch Implementation
🐍multi_head_attention_torch.py
Import PyTorch

torch is the core tensor library. nn provides neural network modules. F provides stateless ops like softmax. math provides sqrt.

class MultiHeadAttention(nn.Module)

By subclassing nn.Module, we get parameter tracking, .cuda() for GPU, autograd for gradients, and integration with training loops. The NumPy version is a plain class — this is a trainable component.

def __init__(self, d_model, num_heads)

Constructor. Sets up model dimensions, head count, and learnable projection matrices. super().__init__() registers this as an nn.Module.

EXECUTION STATE
⬇ input: d_model = 4
⬇ input: num_heads = 2
self.d_k = 4 // 2 = 2
self.scale = √2 = 1.4142
self.W_Q = nn.Linear(d_model, d_model, bias=False)

Learnable query projection matrix W^Q of shape (4, 4). In production, this transforms input embeddings into queries. We skip it in our example (use_projections=False) to compare directly with the NumPy version.

EXECUTION STATE
nn.Linear(4, 4, bias=False) = creates a 4×4 weight matrix (no bias term). Equivalent to Q_proj = X @ W_Q.T in math.
self.W_K = nn.Linear(d_model, d_model, bias=False)

Learnable key projection. Same shape as W_Q.

self.W_V = nn.Linear(d_model, d_model, bias=False)

Learnable value projection.

self.W_O = nn.Linear(d_model, d_model, bias=False)

Output projection W^O. Applied after concatenation to mix information across heads. This is the final linear layer in the multi-head formula.

def split_heads(self, x) -> torch.Tensor

Reshape (B, N, d_model) into (B, H, N, d_k). This is the PyTorch-idiomatic way to split heads — using view + transpose instead of slicing, so it works with batched inputs.

EXECUTION STATE
⬇ input: x = shape (1, 5, 4) — batch=1, tokens=5, dims=4
⬆ returns = shape (1, 2, 5, 2) — batch=1, heads=2, tokens=5, d_k=2
x.view(B, N, self.num_heads, self.d_k).transpose(1, 2)

Two-step reshape: view changes (1,5,4) to (1,5,2,2) by splitting the last dimension into (H, d_k). Then transpose swaps dims 1 and 2: (1,5,2,2) → (1,2,5,2). Now the head dimension comes before the sequence dimension, so batched matmul treats the heads as a batch dimension.

EXECUTION STATE
.view(B, N, H, d_k) = (1,5,4) → (1,5,2,2): split d_model=4 into H=2 × d_k=2
.transpose(1, 2) = (1,5,2,2) → (1,2,5,2): swap sequence and head dims. The result is a non-contiguous view; batched matmul then handles each head's (N, d_k) slice as a separate matrix.
def forward(self, Q, K, V, mask, use_projections)

Main entry point. Handles optional batching, optional projections, head splitting, per-head attention, and concatenation. The mask parameter enables causal or padding masking.

EXECUTION STATE
⬇ input: Q = torch.Size([5, 4]) — unbatched
⬇ input: K = torch.Size([5, 4])
⬇ input: V = torch.Size([5, 4])
⬇ input: mask = None (no masking)
⬇ input: use_projections = False (skip W_Q/K/V/O for direct comparison)
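The mask argument is not exercised in this walkthrough (mask = None), so a small illustration may help. The following is a minimal NumPy sketch of causal masking on the Head 1 slices of the shared example. It assumes the same convention as the listing's masked_fill call, where True marks a position to hide; the names `masked_softmax` and `causal_mask` are illustrative only.

```python
import numpy as np

def masked_softmax(scores, mask):
    """Softmax over the last axis; positions where mask is True get weight 0."""
    scores = np.where(mask, -np.inf, scores)               # hide masked keys
    scores = scores - scores.max(axis=-1, keepdims=True)   # stability shift
    e = np.exp(scores)                                     # exp(-inf) == 0
    return e / e.sum(axis=-1, keepdims=True)

N, d_k = 5, 2
# Head 1 slices (columns 0:2) of the shared Q and K matrices.
Q1 = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])
K1 = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])

# True above the diagonal: token i may not attend to any token j > i.
causal_mask = np.triu(np.ones((N, N), dtype=bool), k=1)

weights = masked_softmax(Q1 @ K1.T / np.sqrt(d_k), causal_mask)
print(np.round(weights, 4))
```

Each row still sums to 1, but all weight above the diagonal is zero, so the first token can only attend to itself. This works as long as every row keeps at least one unmasked position; a fully masked row would produce NaN.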
Q = Q.unsqueeze(0) (handle unbatched)

If input is 2D (N, d_model), add a batch dimension to make it 3D (1, N, d_model). This allows the same code to handle both batched and unbatched inputs.

EXECUTION STATE
.unsqueeze(0) = add dim at position 0. (5,4) → (1,5,4)
Q.shape after = torch.Size([1, 5, 4])
Q_h = self.split_heads(Q)

Split Q (1,5,4) into (1,2,5,2). Each of the 2 heads now has its own 2D query subspace.

EXECUTION STATE
Q_h.shape = torch.Size([1, 2, 5, 2])
K_h = self.split_heads(K)

Split K into (1,2,5,2).

EXECUTION STATE
K_h.shape = torch.Size([1, 2, 5, 2])
V_h = self.split_heads(V)

Split V into (1,2,5,2).

EXECUTION STATE
V_h.shape = torch.Size([1, 2, 5, 2])
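The reshape-based split can be checked against plain column slicing. The sketch below uses NumPy's reshape and transpose as a stand-in for PyTorch's view and transpose (an assumption made so the example runs without PyTorch); the input values are arbitrary placeholders chosen so every entry is distinct.

```python
import numpy as np

B, N, H, d_k = 1, 5, 2, 2
x = np.arange(B * N * H * d_k, dtype=float).reshape(B, N, H * d_k)

# NumPy analogue of x.view(B, N, H, d_k).transpose(1, 2)
x_heads = x.reshape(B, N, H, d_k).transpose(0, 2, 1, 3)   # (B, H, N, d_k)

for h in range(H):
    # Head h should hold exactly columns h*d_k:(h+1)*d_k of the input.
    assert np.array_equal(x_heads[0, h], x[0, :, h * d_k:(h + 1) * d_k])

print(x_heads.shape)                   # (1, 2, 5, 2)
print(x_heads.flags["C_CONTIGUOUS"])   # False: transpose returns a strided view
```

The reshaped tensor holds the same per-head column slices as the NumPy version's split_heads, and the transposed result is a non-contiguous view, which is why the PyTorch code calls contiguous() before the final view.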
scores = torch.matmul(Q_h, K_h.transpose(-2, -1)) / self.scale

Compute scaled dot-product scores for ALL heads simultaneously via batched matmul. K_h.transpose(-2,-1) swaps the last two dims: (1,2,5,2) → (1,2,2,5). Then matmul: (1,2,5,2) @ (1,2,2,5) → (1,2,5,5). Both heads computed in one operation — this is why the reshape approach is faster than looping.

EXECUTION STATE
K_h.transpose(-2, -1) = swap dims -2 and -1. (1,2,5,2) → (1,2,2,5). The batch and head dims are untouched.
/ self.scale = / 1.4142 — same as NumPy version
scores.shape = torch.Size([1, 2, 5, 5]) — batch=1, heads=2, N=5, N=5
weights = F.softmax(scores, dim=-1)

Softmax along the last dimension (key positions). Operates independently on each (batch, head, query) row. Numerically stable by default.

EXECUTION STATE
dim=-1 = normalize along key positions. Each query row in each head independently sums to 1.0.
weights.shape = torch.Size([1, 2, 5, 5])
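To see why numerical stability matters for softmax, the following sketch compares a naive softmax with the max-subtraction form that the NumPy listing's _softmax uses. The score values are deliberately extreme and are not taken from the shared example.

```python
import numpy as np

scores = np.array([1000.0, 1001.0, 1002.0])

# Naive softmax: exp(1000) overflows to inf, and inf / inf is nan.
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(scores) / np.exp(scores).sum()

# Shifted softmax: subtracting the max does not change the result
# mathematically, but keeps every exponent at or below zero.
shifted = np.exp(scores - scores.max())
stable = shifted / shifted.sum()

print(naive)
print(np.round(stable, 4))
```

The naive version returns NaN for every entry, while the shifted version returns a valid distribution that sums to 1.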
head_out = torch.matmul(weights, V_h)

Weighted sum of value vectors per head. (1,2,5,5) @ (1,2,5,2) → (1,2,5,2). Each head produces its own contextual output.

EXECUTION STATE
head_out.shape = torch.Size([1, 2, 5, 2])
output = head_out.transpose(1, 2).contiguous().view(B, N, H * d_k)

Reverse the split: transpose heads back (1,2,5,2) → (1,5,2,2), then view merges the last two dims: (1,5,2,2) → (1,5,4). This is the Concat() operation. contiguous() is needed because transpose creates a non-contiguous view.

EXECUTION STATE
.transpose(1, 2) = (1,2,5,2) → (1,5,2,2) — swap head and seq dims back
.contiguous() = make memory layout contiguous after transpose (required before view)
.view(B, N, H*d_k) = (1,5,2,2) → (1,5,4) — merge head and d_k dims
output.shape = torch.Size([1, 5, 4])
return weights.squeeze(0), output.squeeze(0)

Remove the batch dimension we added for unbatched input. Returns weights (2,5,5) and output (5,4).

EXECUTION STATE
⬆ return: weights = torch.Size([2, 5, 5]) — per-head attention matrices
⬆ return: output = torch.Size([5, 4]) — concatenated multi-head output
Q = torch.tensor([...])

Same shared example as NumPy version. torch.tensor creates a PyTorch tensor that supports autograd and GPU.

EXECUTION STATE
Q.shape = torch.Size([5, 4])
mha = MultiHeadAttention(d_model=4, num_heads=2)

Instantiate the module. W_Q, W_K, W_V, W_O are randomly initialized but we skip them with use_projections=False.

EXECUTION STATE
mha.d_k = 2
mha.scale = 1.4142
weights, output = mha(Q, K, V, use_projections=False)

Call the module (triggers forward()). use_projections=False skips the learned W matrices so output matches the NumPy version exactly.

EXECUTION STATE
weights[0] (Head 1) =
       The      cat      sat       on      mat
The  0.1237   0.2509   0.2509   0.1237   0.2509
cat  0.3664   0.0891   0.3664   0.0891   0.0891
sat  0.1811   0.1811   0.3673   0.0893   0.1811
on   0.2000   0.2000   0.2000   0.2000   0.2000
mat  0.1237   0.2509   0.2509   0.1237   0.2509
weights[1] (Head 2) =
       The      cat      sat       on      mat
The  0.1337   0.2711   0.1337   0.2711   0.1904
cat  0.2711   0.1337   0.1337   0.2711   0.1904
sat  0.1337   0.2711   0.1337   0.2711   0.1904
on   0.1811   0.1811   0.0893   0.3673   0.1811
mat  0.2711   0.1337   0.1337   0.2711   0.1904
output (5×4) =
        d0       d1       d2       d3
The  0.2491   0.3763   0.2289   0.3663
cat  0.4109   0.1336   0.2289   0.3663
sat  0.2717   0.2717   0.2289   0.3663
on   0.3000   0.3000   0.1799   0.4579
mat  0.2491   0.3763   0.2289   0.3663
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention (Vaswani et al., 2017), PyTorch version

    Splits Q, K, V into H heads, runs scaled dot-product attention
    on each, then concatenates and projects with W_O.

    Supports GPU, automatic differentiation, and batched inputs.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = math.sqrt(self.d_k)

        # Learned projection matrices (not used in this example,
        # but included for production completeness)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, N, d_model) -> (B, H, N, d_k)."""
        B, N, _ = x.shape
        return x.view(B, N, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: torch.Tensor | None = None,
        use_projections: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            Q: (B, N, d_model) or (N, d_model)
            K: (B, N, d_model) or (N, d_model)
            V: (B, N, d_model) or (N, d_model)
            mask: Optional (B, 1, N, N) or (N, N) boolean
            use_projections: If True, apply W_Q/W_K/W_V/W_O

        Returns:
            all_weights: (B, H, N, N) attention weights
            output: (B, N, d_model) multi-head output
        """
        # Handle unbatched input
        if Q.dim() == 2:
            Q = Q.unsqueeze(0)
            K = K.unsqueeze(0)
            V = V.unsqueeze(0)

        # Optional learned projections
        if use_projections:
            Q = self.W_Q(Q)
            K = self.W_K(K)
            V = self.W_V(V)

        # Split into heads: (B, N, d_model) -> (B, H, N, d_k)
        Q_h = self.split_heads(Q)
        K_h = self.split_heads(K)
        V_h = self.split_heads(V)

        # Scaled dot-product per head
        scores = torch.matmul(Q_h, K_h.transpose(-2, -1)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))

        weights = F.softmax(scores, dim=-1)  # (B, H, N, N)
        head_out = torch.matmul(weights, V_h)  # (B, H, N, d_k)

        # Concatenate heads: (B, H, N, d_k) -> (B, N, d_model)
        B, H, N, d_k = head_out.shape
        output = head_out.transpose(1, 2).contiguous().view(B, N, H * d_k)

        if use_projections:
            output = self.W_O(output)

        return weights.squeeze(0), output.squeeze(0)


# ── Shared Example ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])
K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])
V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

# ── Run (without learned projections for exact comparison) ──
mha = MultiHeadAttention(d_model=4, num_heads=2)
weights, output = mha(Q, K, V, use_projections=False)

print("Head 1 weights (5x5):")
print(weights[0].round(decimals=4))

print("\nHead 2 weights (5x5):")
print(weights[1].round(decimals=4))

print("\nOutput (5x4):")
print(output.round(decimals=4))

Key Takeaways

  1. The problem: A single attention head produces one attention pattern — a single compromise that cannot simultaneously specialize for syntax, semantics, and position.
  2. The solution: Run $H$ independent heads in parallel, each operating on a $d_k$-dimensional subspace of the representation.
  3. Same cost: Total computation is roughly the same as single-head attention at full dimensionality because $H \cdot d_k = d_{\text{model}}$. The work is just distributed across heads.
  4. Concatenation recovers dimensionality: Each head outputs $(N, d_k)$; concatenating $H$ heads gives $(N, d_{\text{model}})$.
  5. $W^O$ is essential: The output projection mixes information across heads. Without it, each output dimension carries information from only one head.
  6. Heads specialize: In trained models, different heads learn to capture syntactic dependencies, semantic similarity, positional proximity, and coreference — empirically verified by Clark et al. (2019) and Voita et al. (2019).
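The "same cost" takeaway can be checked by counting operations. The following is a back-of-the-envelope sketch, assuming we count only the two matrix products inside attention (scores and weighted sum) and ignore the learned projections; the function name `attention_cost` is hypothetical.

```python
def attention_cost(N, d_model, num_heads):
    """Count multiply-adds in QK^T and weights @ V, plus softmax entries (sketch)."""
    d_k = d_model // num_heads
    score_macs = num_heads * N * N * d_k   # Q_h @ K_h^T for every head
    value_macs = num_heads * N * N * d_k   # weights_h @ V_h for every head
    softmax_entries = num_heads * N * N    # one exp per (head, query, key)
    return score_macs + value_macs, softmax_entries

for H in (1, 2, 4):
    macs, exps = attention_cost(N=5, d_model=4, num_heads=H)
    print(f"H={H}: matmul multiply-adds={macs}, softmax entries={exps}")
```

For the 5-token, 4-dimensional example the matmul count stays at 200 for every head count, because $H \cdot d_k$ is always $d_{\text{model}}$. The number of softmax entries does grow with $H$, so the equality holds for the matrix products rather than for every operation.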

Exercises

Exercise 1: Compute for “sat”

Compute the multi-head attention output for the token “sat” (row 2) by hand. Show the per-head dot products, scaled scores, softmax weights, and per-head outputs. Verify that Head 1 attends most to “sat” itself (0.3673) while Head 2 distributes attention to “cat” and “on” (both 0.2711). Concatenate the outputs and check against the output matrix.

Exercise 2: $H = 4$ Heads

Repeat the computation with $H = 4$ heads, each with $d_k = 1$. What happens to the attention patterns when each head computes only a scalar dot product? Compare the expressiveness with $H = 2$. Use the Head Count Explorer to verify your answers.

Exercise 3: Identical Subspaces

What happens if you assign both heads the same columns (e.g., both use dims $[0, 1]$)? Compute the attention weights for both heads and verify they are identical. What does this tell you about the role of diverse subspaces?

Exercise 4: $W^O$ Projection

Create a $4 \times 4$ output projection matrix $W^O$ (e.g., a random orthogonal matrix or a permutation matrix) and apply it to the concatenated output. How does the final output change? Why is this step necessary for mixing information across heads?

Exercise 5: Head Pruning

Set one head's output to zero (simulating pruning). How much does the final output change? Compute the L2 norm of the difference. Voita et al. (2019) found that many heads can be pruned with minimal accuracy loss — does this match your observation?


References

  1. Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems, 30. The paper that introduced multi-head attention as part of the transformer architecture.
  2. Clark, K., Khandelwal, U., Levy, O., & Manning, C.D. (2019). What Does BERT Look At? An Analysis of BERT's Attention. BlackboxNLP Workshop at ACL 2019. Empirical analysis showing that different BERT heads specialize for different linguistic phenomena.
  3. Voita, E., Talbot, D., Moiseev, F., Sennrich, R., & Titov, I. (2019). Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned. ACL 2019. Showed that a small number of heads are important for specific tasks and the rest can be pruned.
  4. Dao, T., Fu, D.Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. IO-aware algorithm for exact attention computation without materializing the full score matrix.
  5. Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B., & Liu, Y. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv preprint arXiv:2104.09864. Per-head rotary embeddings that encode position through rotation in each head's subspace.
  6. Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. Vision Transformer showing multi-head attention on image patches.
  7. Kaplan, J., McCandlish, S., Henighan, T., Brown, T.B., Chess, B., Child, R., Gray, S., Radford, A., Wu, J., & Amodei, D. (2020). Scaling Laws for Neural Language Models. arXiv preprint arXiv:2001.08361. Empirical scaling laws showing how performance improves with model size, dataset size, and compute.
  8. Michel, P., Levy, O., & Neubig, G. (2019). Are Sixteen Heads Really Better than One? NeurIPS 2019. Systematic study of head importance, showing that many heads can be removed at test time with negligible accuracy loss.