Explain why a single attention head creates a representational bottleneck and why multiple heads running in parallel overcome it.
Derive the multi-head attention formula MultiHead(Q,K,V)=Concat(head1,…,headH)WO from first principles, understanding every symbol and matrix dimension.
Compute per-head attention weights and outputs by hand for our shared 5-token sentence using H=2 heads.
Compare how different heads learn different attention patterns from the same input — syntactic relationships in one head, semantic proximity in another.
Implement a complete, runnable multi-head attention class in both NumPy and PyTorch that you can use to simulate any number of heads on any input.
Connect multi-head attention to Flash Attention, KV-cache, positional encodings, and transformer scaling laws.
Where this appears: Multi-head attention, or a close variant of it, is the standard attention module in virtually every transformer. Published architectures such as GPT-3, LLaMA, ViT, DALL-E, AlphaFold, and Codex use it as a core computational building block, and closed models such as GPT-4, Claude, and Gemini are widely understood to build on the same mechanism. A single GPT-3 layer uses 96 heads; the head count of GPT-4 has not been published. Understanding how and why heads split the representation is essential for understanding every modern architecture.
The Real Problem
The Single-Head Bottleneck
In Chapter 1, we computed scaled dot-product attention using a single set of Q, K, V matrices. The result was a single 5×5 attention weight matrix — one pattern describing which tokens attend to which.
But a single pattern is a severe bottleneck. Consider the word “cat” in “The cat sat on mat”. A human instantly recognizes that “cat” is:
Syntactically the subject of the verb “sat” (subject-verb relationship)
Grammatically a noun that follows the determiner “The” (determiner-noun agreement)
Semantically an animate entity related to the location “mat” (entity-location relation)
Positionally close to “The” and “sat” (local context)
A single attention head computes one weighted combination of all keys for each query. In our Chapter 1 computation, “cat” assigned weight 0.4026 to “The”, 0.2442 to “sat”, and only 0.0898 to itself. That is one compromise pattern — it cannot simultaneously attend heavily to “The” (for grammar) and to “sat” (for semantics) through separate, specialized lenses. The softmax forces a single probability distribution over all keys.
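The Chapter 1 weights quoted above can be reproduced in a few lines of NumPy. This is a minimal sketch that assumes the shared Q and K matrices listed in the implementation section of this chapter; the variable names are illustrative.

```python
import numpy as np

tokens = ["The", "cat", "sat", "on", "mat"]

# Shared example matrices (assumed to be the same Q and K used throughout the chapter).
Q = np.array([[1.0, 0.0, 1.0, 0.0],   # The
              [0.0, 2.0, 0.0, 1.0],   # cat
              [1.0, 1.0, 1.0, 0.0],   # sat
              [0.0, 0.0, 1.0, 1.0],   # on
              [1.0, 0.0, 0.0, 1.0]])  # mat
K = np.array([[0.0, 1.0, 0.0, 1.0],
              [1.0, 0.0, 1.0, 0.0],
              [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.5, 0.5]])

# Single head: the query for "cat" is compared with every key in the full 4D space.
scores = Q[1] @ K.T / np.sqrt(Q.shape[1])      # divide by sqrt(d_model) = 2
exp_scores = np.exp(scores - scores.max())     # shift by the max for stability
cat_weights = exp_scores / exp_scores.sum()    # one distribution over all keys

for token, weight in zip(tokens, cat_weights):
    print(f"cat -> {token}: {weight:.4f}")
```

Whatever the values in Q and K, the row always sums to 1, so raising the weight on one key necessarily lowers it on the others.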
Vaswani et al. (2017) articulated this limitation precisely in the original transformer paper: “Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. With a single attention head, averaging inhibits this.”
The Committee Analogy
Think of multi-head attention as a committee of experts reviewing a document. A single reviewer reads the entire document through one lens and produces one summary. A committee of H experts — a grammarian, a semanticist, a coreference resolver, a positional tracker — each reads the same document but focuses on their specialty. Their reports are concatenated into a richer, more nuanced understanding than any single reviewer could produce.
Crucially, the experts do not talk to each other while reading. Each head operates independently and in parallel on a different subspace of the representation. Only after all heads produce their outputs are the results combined.
From Intuition to Mathematics
The Projection Idea
The core mathematical insight is that a dmodel-dimensional vector space can be decomposed into H lower-dimensional subspaces. If dmodel=512 and H=8, then each head operates in a dk=512/8=64-dimensional subspace.
Each head i gets its own set of learned projection matrices: WiQ∈Rdmodel×dk, WiK∈Rdmodel×dk, WiV∈Rdmodel×dv. These projections are learned during training — the model discovers which subspaces are useful for which kinds of relationships.
Split vs. Learned Projection
In our teaching example, we split the dimensions directly: Head 1 gets dims [0,1], Head 2 gets dims [2,3]. This is equivalent to using fixed selection matrices as projections: W1Q keeps columns 0-1 and W2Q keeps columns 2-3. In real transformers, the projection matrices are dense learned parameters that can mix dimensions in arbitrary ways, discovering subspaces not aligned with any coordinate axis.
Why the split is mathematically equivalent
A general WiQ∈Rdmodel×dk can be any linear map. Setting it to a selection matrix (zeros and ones) that picks certain columns is a special case. The splitting trick lets us do multi-head on our tiny example without introducing new parameters, but the mechanism is identical.
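The equivalence between slicing and projecting can be checked numerically. In the sketch below, W1_Q and W2_Q are hypothetical selection matrices built for illustration, and W_dense is a random stand-in for a learned projection.

```python
import numpy as np

Q = np.array([[1.0, 0.0, 1.0, 0.0],
              [0.0, 2.0, 0.0, 1.0],
              [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0],
              [1.0, 0.0, 0.0, 1.0]])
d_model, d_k = 4, 2

# Hypothetical selection matrices: each column has a single 1 that picks one input dim.
W1_Q = np.zeros((d_model, d_k))
W1_Q[0, 0] = W1_Q[1, 1] = 1.0      # head 1 keeps dims 0 and 1
W2_Q = np.zeros((d_model, d_k))
W2_Q[2, 0] = W2_Q[3, 1] = 1.0      # head 2 keeps dims 2 and 3

# Projecting with a selection matrix gives exactly the column slice.
assert np.array_equal(Q @ W1_Q, Q[:, 0:2])
assert np.array_equal(Q @ W2_Q, Q[:, 2:4])

# A dense projection (random here, learned in a real model) mixes all four dims.
rng = np.random.default_rng(0)
W_dense = rng.normal(size=(d_model, d_k))
Q_mixed = Q @ W_dense
print(Q_mixed.shape)
```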
The Mathematical Definition
The multi-head attention formula from Vaswani et al. (2017):
$$\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_H)\, W^O$$
where $\mathrm{head}_i = \mathrm{Attention}(Q W_i^Q,\; K W_i^K,\; V W_i^V)$
Symbol-by-Symbol Breakdown
| Symbol | Shape | Meaning |
| --- | --- | --- |
| Q, K, V | (N, d_model) | Input query, key, value matrices (same as Chapter 1) |
| H | scalar | Number of parallel attention heads |
| d_k = d_model / H | scalar | Per-head key/query dimension |
| W_i^Q, W_i^K | (d_model, d_k) | Per-head query/key projection (learned) |
| W_i^V | (d_model, d_v) | Per-head value projection (learned) |
| head_i | (N, d_v) | Output of the i-th attention head |
| Concat | (N, H·d_v) | Column-wise concatenation of all head outputs |
| W^O | (H·d_v, d_model) | Output projection matrix (mixes information across heads) |
What the Formula Says in Plain English
Project the input into H different subspaces using learned weight matrices WiQ,WiK,WiV.
Attend independently in each subspace using the scaled dot-product attention from Chapter 1: softmax(QhKh⊤/√dk)Vh.
Concatenate the H output vectors into a single vector of the original dimensionality.
Project through WO to allow the model to mix and recombine information from different heads.
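The four steps above map onto code almost line for line. The following is a minimal sketch, not a production implementation: the function name multi_head, the random weights, and the choice to feed the same matrix X as Q, K, and V (self-attention) are assumptions made for illustration.

```python
import numpy as np

def softmax(x):
    # Row-wise softmax, shifted by the row max for numerical stability.
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def multi_head(Q, K, V, W_Q, W_K, W_V, W_O):
    """Q, K, V: (N, d_model). W_Q, W_K, W_V: lists of H matrices (d_model, d_k).
    W_O: (H * d_k, d_model)."""
    heads = []
    for Wq, Wk, Wv in zip(W_Q, W_K, W_V):
        Qh, Kh, Vh = Q @ Wq, K @ Wk, V @ Wv              # 1. project
        d_k = Qh.shape[-1]
        weights = softmax(Qh @ Kh.T / np.sqrt(d_k))      # 2. attend per head
        heads.append(weights @ Vh)
    concat = np.concatenate(heads, axis=-1)              # 3. concatenate
    return concat @ W_O                                  # 4. output projection

# Random stand-ins for learned parameters (hypothetical values).
rng = np.random.default_rng(0)
N, d_model, H = 5, 4, 2
d_k = d_model // H
X = rng.normal(size=(N, d_model))
W_Q = [rng.normal(size=(d_model, d_k)) for _ in range(H)]
W_K = [rng.normal(size=(d_model, d_k)) for _ in range(H)]
W_V = [rng.normal(size=(d_model, d_k)) for _ in range(H)]
W_O = rng.normal(size=(H * d_k, d_model))

out = multi_head(X, X, X, W_Q, W_K, W_V, W_O)
print(out.shape)   # (5, 4)
```

The explicit Python loop keeps the correspondence with the formula visible; real implementations usually stack the per-head weights into one tensor and use a single batched matrix multiply.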
Why Multiple Heads Work
The Representational Subspace Argument
Each head operates on a dk-dimensional subspace. In this subspace, the dot product QhKh⊤ measures a different notion of similarity than the full-space dot product. A token pair that has a high score in Head 1's subspace (capturing syntactic proximity) may have a low score in Head 2's subspace (capturing semantic type).
In our example, consider the “on” token. In Head 1 (dims 0-1), Qon=[0.0,0.0] — the zero vector. Every dot product is zero, so Head 1 assigns uniform attention (0.2 to each token). But in Head 2 (dims 2-3), Qon=[1.0,1.0] actively attends to “on” itself (0.3673) where the key strongly matches. Two completely different behaviors from the same token, enabled by subspace specialization.
Capacity vs. Specialization Trade-off
There is a fundamental trade-off: more heads means more specialization, but each head has fewer dimensions to work with. With H=1, the single head has dk=4 dimensions for computing similarity. With H=4, each head has only dk=1 — a scalar dot product with very limited discriminative power. The sweet spot depends on the task and model size. In practice, dk=64 or dk=128 per head is typical.
| Configuration | H | d_k | Patterns | Per-Head Capacity |
| --- | --- | --- | --- | --- |
| Our Chapter 1 | 1 | 4 | 1 pattern | Full 4D space |
| Our Chapter 2 | 2 | 2 | 2 patterns | 2D subspace each |
| Extreme split | 4 | 1 | 4 patterns | Scalar dot product |
| GPT-3 (d=12288) | 96 | 128 | 96 patterns | 128D subspace each |
| LLaMA-2 70B | 64 | 128 | 64 patterns | 128D subspace each |
Step-by-Step Calculation
We now compute multi-head attention with H=2 heads on our shared example. Head 1 operates on dims [0,1] and Head 2 operates on dims [2,3], so dk=2 for both. The scaling factor for each head is √dk=√2≈1.4142.
Token: “The” (row 0)
Head 1 — Qh1[0]=[1.0,0.0], operating on Kh1 (columns 0-1 of K):
| Key Token | Dot Product | Scaled (/√2) | After Softmax |
| --- | --- | --- | --- |
| The | 1.0×0.0 + 0.0×1.0 = 0.0 | 0.0000 | 0.1237 |
| cat | 1.0×1.0 + 0.0×0.0 = 1.0 | 0.7071 | 0.2509 |
| sat | 1.0×1.0 + 0.0×1.0 = 1.0 | 0.7071 | 0.2509 |
| on | 1.0×0.0 + 0.0×0.0 = 0.0 | 0.0000 | 0.1237 |
| mat | 1.0×1.0 + 0.0×0.0 = 1.0 | 0.7071 | 0.2509 |
Head 1 output for “The”: Oh1[0]=[0.2491,0.3763]
Head 2 — Qh2[0]=[1.0,0.0], operating on Kh2 (columns 2-3 of K):
| Key Token | Dot Product | Scaled (/√2) | After Softmax |
| --- | --- | --- | --- |
| The | 1.0×0.0 + 0.0×1.0 = 0.0 | 0.0000 | 0.1337 |
| cat | 1.0×1.0 + 0.0×0.0 = 1.0 | 0.7071 | 0.2711 |
| sat | 1.0×0.0 + 0.0×0.0 = 0.0 | 0.0000 | 0.1337 |
| on | 1.0×1.0 + 0.0×1.0 = 1.0 | 0.7071 | 0.2711 |
| mat | 1.0×0.5 + 0.0×0.5 = 0.5 | 0.3536 | 0.1904 |
Head 2 output for “The”: Oh2[0]=[0.2289,0.3663]
Key observation: Head 1 distributes attention equally across “cat”, “sat”, and “mat” (all 0.2509). Head 2 focuses on “cat” and “on” (both 0.2711) while giving “sat” much less weight (0.1337). These are genuinely different attention patterns extracted from different dimensions of the same representation.
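The two rows computed by hand for “The” can be checked with a short script. This sketch assumes the shared K matrix and the query row for “The” from the chapter's example; head_weights is a helper name introduced here for illustration.

```python
import numpy as np

K = np.array([[0.0, 1.0, 0.0, 1.0],   # The
              [1.0, 0.0, 1.0, 0.0],   # cat
              [1.0, 1.0, 0.0, 0.0],   # sat
              [0.0, 0.0, 1.0, 1.0],   # on
              [1.0, 0.0, 0.5, 0.5]])  # mat
q_the = np.array([1.0, 0.0, 1.0, 0.0])   # Q row for "The"
d_k = 2

def head_weights(q_h, K_h):
    """Attention weights of one query over all keys inside one head's subspace."""
    scaled = K_h @ q_h / np.sqrt(d_k)
    e = np.exp(scaled - scaled.max())
    return e / e.sum()

w_head1 = head_weights(q_the[0:2], K[:, 0:2])   # Head 1 sees dims 0-1
w_head2 = head_weights(q_the[2:4], K[:, 2:4])   # Head 2 sees dims 2-3

print("Head 1:", np.round(w_head1, 4))
print("Head 2:", np.round(w_head2, 4))
```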
Token: “cat” (row 1)
Head 1: Qh1[1]=[0.0,2.0]. This query strongly activates dim 1, matching keys with high values in dim 1.

| Key Token | Dot Product | Scaled (/√2) | After Softmax |
| --- | --- | --- | --- |
| The | 0.0×0.0 + 2.0×1.0 = 2.0 | 1.4142 | 0.3664 |
| cat | 0.0×1.0 + 2.0×0.0 = 0.0 | 0.0000 | 0.0891 |
| sat | 0.0×1.0 + 2.0×1.0 = 2.0 | 1.4142 | 0.3664 |
| on | 0.0×0.0 + 2.0×0.0 = 0.0 | 0.0000 | 0.0891 |
| mat | 0.0×1.0 + 2.0×0.0 = 0.0 | 0.0000 | 0.0891 |
Head 1: “cat” strongly attends to “The” (0.3664) and “sat” (0.3664) — its grammatical neighbours. Output: Oh1[1]=[0.4109,0.1336].
Head 2: Qh2[1]=[0.0,1.0]. This query activates dim 3 (the second dim in Head 2's subspace).

| Key Token | Dot Product | Scaled (/√2) | After Softmax |
| --- | --- | --- | --- |
| The | 0.0×0.0 + 1.0×1.0 = 1.0 | 0.7071 | 0.2711 |
| cat | 0.0×1.0 + 1.0×0.0 = 0.0 | 0.0000 | 0.1337 |
| sat | 0.0×0.0 + 1.0×0.0 = 0.0 | 0.0000 | 0.1337 |
| on | 0.0×1.0 + 1.0×1.0 = 1.0 | 0.7071 | 0.2711 |
| mat | 0.0×0.5 + 1.0×0.5 = 0.5 | 0.3536 | 0.1904 |
Head 2: “cat” attends to “The” (0.2711) and “on” (0.2711) — different from Head 1. Output: Oh2[1]=[0.2289,0.3663].
Concatenation
For each token, concatenate the two head outputs to recover the original dmodel=4 dimensionality:
| Token | Head 1 Output | Head 2 Output | Concatenated |
| --- | --- | --- | --- |
| The | [0.2491, 0.3763] | [0.2289, 0.3663] | [0.2491, 0.3763, 0.2289, 0.3663] |
| cat | [0.4109, 0.1336] | [0.2289, 0.3663] | [0.4109, 0.1336, 0.2289, 0.3663] |
| sat | [0.2717, 0.2717] | [0.2289, 0.3663] | [0.2717, 0.2717, 0.2289, 0.3663] |
| on | [0.3000, 0.3000] | [0.1799, 0.4579] | [0.3000, 0.3000, 0.1799, 0.4579] |
| mat | [0.2491, 0.3763] | [0.2289, 0.3663] | [0.2491, 0.3763, 0.2289, 0.3663] |
Where is W^O?
In our simplified example, we skip the final WO projection. In production transformers, WO∈Rdmodel×dmodel is a learned matrix that mixes information across heads. Without it, each dimension of the output would only carry information from one head. The WO projection is what allows the final representation to integrate insights from all heads.
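To see what the output projection adds, the sketch below applies a hypothetical W_O (random here, learned in a real model) to the concatenated output from the table above, then removes Head 2's contribution and checks which output dimensions change.

```python
import numpy as np

# Concatenated multi-head output (rows: The, cat, sat, on, mat).
concat = np.array([[0.2491, 0.3763, 0.2289, 0.3663],
                   [0.4109, 0.1336, 0.2289, 0.3663],
                   [0.2717, 0.2717, 0.2289, 0.3663],
                   [0.3000, 0.3000, 0.1799, 0.4579],
                   [0.2491, 0.3763, 0.2289, 0.3663]])

# Hypothetical output projection: random values stand in for learned weights.
rng = np.random.default_rng(0)
W_O = rng.normal(size=(4, 4))
mixed = concat @ W_O

# Remove Head 2's contribution (dims 2-3) and compare column by column.
head1_only = concat.copy()
head1_only[:, 2:] = 0.0

changed_without_WO = ~np.isclose(concat, head1_only).all(axis=0)
changed_with_WO = ~np.isclose(mixed, head1_only @ W_O).all(axis=0)

print("without W_O:", changed_without_WO)   # only dims 2-3 depend on Head 2
print("with W_O:   ", changed_with_WO)      # a dense W_O spreads Head 2 into every dim
```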
Full Attention Weights and Output
Head 1 Attention Weights (5×5)
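| Query / Key | The | cat | sat | on | mat |
| --- | --- | --- | --- | --- | --- |
| The | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |
| cat | 0.3664 | 0.0891 | 0.3664 | 0.0891 | 0.0891 |
| sat | 0.1811 | 0.1811 | 0.3673 | 0.0893 | 0.1811 |
| on | 0.2000 | 0.2000 | 0.2000 | 0.2000 | 0.2000 |
| mat | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |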
Head 1 patterns: “cat” strongly attends to “The” and “sat” (both 0.3664). “sat” attends most to itself (0.3673). “on” is perfectly uniform (0.2000) because Qon=[0,0] in this subspace.
Head 2 Attention Weights (5×5)
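| Query / Key | The | cat | sat | on | mat |
| --- | --- | --- | --- | --- | --- |
| The | 0.1337 | 0.2711 | 0.1337 | 0.2711 | 0.1904 |
| cat | 0.2711 | 0.1337 | 0.1337 | 0.2711 | 0.1904 |
| sat | 0.1337 | 0.2711 | 0.1337 | 0.2711 | 0.1904 |
| on | 0.1811 | 0.1811 | 0.0893 | 0.3673 | 0.1811 |
| mat | 0.2711 | 0.1337 | 0.1337 | 0.2711 | 0.1904 |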
Head 2 patterns: “on” now strongly attends to itself (0.3673) — completely different from its uniform pattern in Head 1. “sat” attends to “cat” and “on” (both 0.2711), not to itself.
Averaged Attention Weights (for comparison)
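| Query / Key | The | cat | sat | on | mat |
| --- | --- | --- | --- | --- | --- |
| The | 0.1287 | 0.2610 | 0.1923 | 0.1974 | 0.2206 |
| cat | 0.3188 | 0.1114 | 0.2500 | 0.1801 | 0.1397 |
| sat | 0.1574 | 0.2261 | 0.2505 | 0.1802 | 0.1858 |
| on | 0.1906 | 0.1906 | 0.1447 | 0.2837 | 0.1906 |
| mat | 0.1974 | 0.1923 | 0.1923 | 0.1974 | 0.2206 |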
Output Matrix — Multi-Head (5×4)
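| Token | dim-0 | dim-1 | dim-2 | dim-3 |
| --- | --- | --- | --- | --- |
| The | 0.2491 | 0.3763 | 0.2289 | 0.3663 |
| cat | 0.4109 | 0.1336 | 0.2289 | 0.3663 |
| sat | 0.2717 | 0.2717 | 0.2289 | 0.3663 |
| on | 0.3000 | 0.3000 | 0.1799 | 0.4579 |
| mat | 0.2491 | 0.3763 | 0.2289 | 0.3663 |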
Compare with the single-head output from Chapter 1. The multi-head output has dims 0-1 shaped by Head 1's attention pattern and dims 2-3 shaped by Head 2's pattern. Each pair of dimensions carries contextual information computed through a different lens.
Applications Across Domains
Natural Language Processing
In large language models, different heads specialize for different linguistic phenomena. Clark et al. (2019) found that in BERT certain heads track specific syntactic relations (for example, direct objects attending to their verbs), others resolve coreference (“she” → “Mary”), and others attend mainly to the previous or next token (positional heads). Voita et al. (2019) showed that many heads can be pruned with little loss in translation quality, and that the heads that survive pruning tend to have specialized, interpretable roles.
Computer Vision
In Vision Transformers (ViT), image patches become tokens. Dosovitskiy et al. (2021) observed that some heads attend mostly to nearby patches (local structure) while others attend across most of the image, even in the lowest layers. This loosely mirrors how CNNs learn different filters, but the multi-head mechanism learns these patterns from data rather than requiring architectural inductive biases.
Code Generation
In code models like Codex and CodeLLaMA, heads can specialize for code-specific patterns such as variable-definition tracking (where was x defined?), bracket matching (which opening bracket does this closing bracket match?), type flow (what type does this expression evaluate to?), and import resolution. The parallel nature of multi-head attention lets the model track several of these at once.
Scientific Modeling
In AlphaFold 2, multi-head attention over amino acid residues gives different heads room to capture different signals: spatial proximity in 3D (nearby in folded structure), sequence proximity (nearby in primary sequence), evolutionary co-variation (residues that mutate together), and chemical interactions (hydrogen bonds, hydrophobic contacts). Having several heads per attention layer lets the model represent fundamentally different notions of “relevance” between residue pairs.
Connection to Modern Systems
Flash Attention
Flash Attention (Dao et al., 2022) does not change the mathematics of multi-head attention: it computes exact attention, so its output matches the standard implementation up to floating-point rounding. What it changes is the memory access pattern. Standard multi-head attention materializes the full (N×N) score matrix for each head, requiring O(H⋅N²) memory. Flash Attention tiles the computation so that blocks of Q, K, and V are processed in on-chip SRAM with an incrementally updated softmax, and the full matrix is never written to GPU memory. For H=96 heads and N=8192, the score matrices of one layer hold 96×8192² ≈ 6.4 billion entries, roughly 12.9 GB in fp16; with tiling, the extra memory grows only linearly in N.
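The tiling idea can be sketched in NumPy for a single head. This is an illustration of the incremental (online) softmax only, under simplifying assumptions: it tiles over keys and values but not over queries, and it is not the fused GPU kernel that Flash Attention actually uses. The function names and block size are hypothetical.

```python
import numpy as np

def attention_full(Q, K, V):
    # Reference version: materializes the full (N, N) score matrix.
    S = Q @ K.T / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    return (P / P.sum(axis=-1, keepdims=True)) @ V

def attention_tiled(Q, K, V, block=4):
    """Process keys/values block by block, keeping only running statistics per query."""
    N, d = Q.shape
    out = np.zeros((N, V.shape[-1]))
    row_max = np.full((N, 1), -np.inf)   # running maximum of the scores
    row_sum = np.zeros((N, 1))           # running softmax denominator
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = Q @ Kb.T / np.sqrt(d)        # only an (N, block) tile of scores
        new_max = np.maximum(row_max, S.max(axis=-1, keepdims=True))
        correction = np.exp(row_max - new_max)   # rescale what was accumulated so far
        P = np.exp(S - new_max)
        row_sum = row_sum * correction + P.sum(axis=-1, keepdims=True)
        out = out * correction + P @ Vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(16, 8)) for _ in range(3))
print(np.allclose(attention_full(Q, K, V), attention_tiled(Q, K, V)))
```

The two functions agree up to rounding, which is the sense in which tiling leaves the mathematics unchanged.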
KV-Cache Optimization
During autoregressive generation, the KV-cache stores previously computed key and value tensors. With multi-head attention, each head in each layer has its own KV-cache of shape (seq_len,dk). The total cache size is 2×H×L×N×dk values (2 for K and V, L layers, N tokens). For GPT-3 with 96 heads, 96 layers, dk=128, and 2048 tokens, that is about 4.8 billion values, or roughly 9.7 GB per sequence in fp16. This memory pressure is what motivates Multi-Query Attention (Chapter 3) and Grouped-Query Attention (Chapter 4), which share keys/values across heads.
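Plugging a configuration into the cache-size formula is a one-line calculation. The helper below is a sketch; kv_cache_bytes and its parameter names are introduced here for illustration, and the second call assumes a single shared KV head purely to show why sharing keys and values helps.

```python
def kv_cache_bytes(num_kv_heads, d_k, num_layers, seq_len,
                   bytes_per_value=2, batch_size=1):
    """Bytes needed to cache K and V: 2 x heads x layers x tokens x d_k values."""
    values = 2 * batch_size * num_layers * num_kv_heads * seq_len * d_k
    return values * bytes_per_value

# GPT-3 shape from this chapter: 96 heads, 96 layers, d_k = 128, 2048 tokens, fp16.
gpt3 = kv_cache_bytes(num_kv_heads=96, d_k=128, num_layers=96, seq_len=2048)
print(f"multi-head:     {gpt3 / 1e9:.2f} GB")    # 9.66 GB

# Same shape with one shared K/V head (illustrative multi-query setting).
shared = kv_cache_bytes(num_kv_heads=1, d_k=128, num_layers=96, seq_len=2048)
print(f"one shared K/V: {shared / 1e9:.2f} GB")  # 0.10 GB
```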
Positional Encodings (RoPE)
Rotary Position Embeddings (Su et al., 2021) apply position-dependent rotations to Q and K inside each head before the dot product. Because each head operates in its own dk-dimensional subspace, RoPE rotates pairs of dimensions within that subspace, with each pair using a different frequency. Low-frequency pairs change slowly with position and carry long-range information, while high-frequency pairs capture local position. The same frequency schedule is applied in every head, so each head has access to the full range of positional frequencies.
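A per-head rotary embedding can be sketched in a few lines. This version assumes the interleaved pairing (dims 0-1, 2-3, ...) and the base of 10000 from the original RoPE formulation; in this common formulation the frequency varies across dimension pairs within a head and the same schedule is reused by every head. The function name rope is illustrative.

```python
import numpy as np

def rope(x, base=10000.0):
    """Apply rotary position embedding to x of shape (H, N, d_k), with d_k even."""
    H, N, d_k = x.shape
    # One frequency per dimension pair; identical for every head.
    freqs = base ** (-np.arange(0, d_k, 2) / d_k)        # (d_k/2,)
    angles = np.arange(N)[:, None] * freqs[None, :]       # (N, d_k/2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin           # 2D rotation of each pair
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

# Use the same query/key content at every position so only position differs.
rng = np.random.default_rng(0)
H, N, d_k = 2, 6, 4
q = np.repeat(rng.normal(size=(H, 1, d_k)), N, axis=1)
k = np.repeat(rng.normal(size=(H, 1, d_k)), N, axis=1)

scores = rope(q) @ rope(k).transpose(0, 2, 1)             # (H, N, N)
# The score depends only on the relative offset between query and key positions.
print(np.allclose(scores[:, 0, 2], scores[:, 1, 3]))
```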
Transformer Scaling
As model size grows, both dmodel and H increase. GPT-3 (175B) uses dmodel=12288 with H=96, giving dk=128. LLaMA-2 70B uses dmodel=8192 with H=64 heads, also dk=128. The scaling laws of Kaplan et al. (2020) indicate that loss depends mainly on total parameter count and only weakly on architectural shape such as the number of heads. In practice, the per-head dimension dk rarely drops below 64, which preserves per-head expressiveness.
Complexity Analysis
| Metric | Single-Head (Ch. 1) | Multi-Head (Ch. 2) |
| --- | --- | --- |
| Time complexity | O(N²·d_model) | O(N²·d_model) (same total) |
| Per-head time | O(N²·d_model) | O(N²·d_k) (per head) |
| Memory for scores | O(N²) | O(H·N²) |
| Parallelism | 1 operation | H independent operations |
| Parameters (projections) | 3·d_model² | 3·d_model² + d_model² (includes W^O) |
The key insight: multi-head attention has the same asymptotic time complexity as single-head attention because H⋅dk=dmodel. What changes is that the H per-head computations can run in parallel on GPUs, and the model gains H independent attention patterns instead of one.
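The claim that the total cost does not depend on H can be checked with a rough operation count. The sketch below counts only the two attention matrix multiplies (scores and weighted sum), ignores the projections and the softmax, and uses the convention of 2 FLOPs per multiply-add; the function names are illustrative.

```python
def attention_flops(N, d_model, H):
    """Approximate FLOPs for QK^T and weights @ V across all H heads."""
    d_k = d_model // H
    per_head = 2 * N * N * d_k + 2 * N * N * d_k   # scores + weighted sum
    return H * per_head

def score_entries(N, H):
    """Number of attention-score entries that standard attention materializes."""
    return H * N * N

for H in (1, 2, 8, 96):
    print(H, attention_flops(N=2048, d_model=12288, H=H), score_entries(2048, H))
```

The FLOP count is identical for every H because H·d_k = d_model, while the score memory grows linearly with H.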
Python Implementation
The full Python class below implements multi-head attention with the same structure as Chapter 1. The annotated walkthrough that follows explains the key lines and shows their execution state with actual values.
Multi-Head Attention — NumPy Implementation
🐍multi_head_attention.py
1import numpy as np
NumPy provides vectorized matrix operations. All matrix multiplications (Q @ K.T) execute as optimized C code, not Python loops.
2import math
Python standard library. We use math.sqrt() to precompute the scaling factor for each head.
4class MultiHeadAttention
Wraps multi-head attention in a reusable class. Compared to Chapter 1 (single-head), this class adds head splitting, per-head attention, and concatenation. Every chapter follows this same class structure so you can compare implementations side by side.
16def __init__(self, d_model, num_heads)
Constructor. Takes the total model dimension (d_model=4) and number of heads (H=2). Computes per-head dimension d_k = d_model / H and precomputes the scaling factor.
EXECUTION STATE
⬇ input: d_model = 4
⬇ input: num_heads = 2
22self.d_model = d_model
Store the total model dimension. This is the full width of Q, K, V before splitting.
EXECUTION STATE
self.d_model = 4
23self.num_heads = num_heads
Store H, the number of parallel attention heads.
EXECUTION STATE
self.num_heads = 2
24self.d_k = d_model // num_heads
Per-head dimension. Each head operates on d_k dimensions of the representation. d_model must be divisible by H.
EXECUTION STATE
d_model // num_heads = 4 // 2 = 2
self.d_k = 2
25self.scale = math.sqrt(self.d_k)
Precompute sqrt(d_k) once. Each head divides its dot products by this value. With d_k=2, the scale is smaller than Chapter 1 (sqrt(4)=2.0), meaning less compression per head.
EXECUTION STATE
math.sqrt(2) = 1.4142
self.scale = 1.4142
27def _softmax(self, x) -> np.ndarray
Numerically stable softmax. Takes a matrix of scaled scores and returns probabilities where each row sums to 1.0. Identical to the Chapter 1 implementation.
EXECUTION STATE
⬇ input: x = shape (5, 5) — scaled scores for one head
29x_shifted = x - np.max(x, axis=-1, keepdims=True)
Subtract row-wise max for numerical stability. exp(500) overflows, but exp(0) = 1.0. The subtraction does not change the softmax result because the constant cancels in the ratio.
EXECUTION STATE
axis=-1 = operate along the LAST axis — find max of each row independently, not the global max.
keepdims=True = keep shape (5,1) so broadcasting x(5×5) - max(5×1) works — each row subtracts its own max.
30exp_x = np.exp(x_shifted)
Exponentiate every element. The largest value per row becomes exp(0)=1.0 — no overflow.
31return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
Divide each element by its row sum to normalize into probabilities. Each row sums to 1.0.
EXECUTION STATE
axis=-1 = sum along last axis — sum each row independently.
keepdims=True = sum returns (5,1) so broadcasting exp_x(5×5) / sum(5×1) works correctly.
33def split_heads(self, M) -> list
Split a (N, d_model) matrix into H matrices of (N, d_k) each. This is the core mechanical step that enables multi-head attention — each head gets its own subspace slice.
EXECUTION STATE
⬇ input: M (5×4) = any of Q, K, or V — full d_model=4 width
⬆ returns = list of 2 matrices, each (5, 2)
35heads = []
Accumulator list. Will hold H=2 sub-matrices after the loop.
EXECUTION STATE
heads = [] (empty list)
36for h in range(self.num_heads):
Loop over H=2 heads. Each iteration computes the column slice for one head.
LOOP TRACE · 2 iterations
h=0 (Head 1)
start = 0 * 2 = 0
end = 0 + 2 = 2
M[:, 0:2] = columns 0 and 1 — first subspace
h=1 (Head 2)
start = 1 * 2 = 2
end = 2 + 2 = 4
M[:, 2:4] = columns 2 and 3 — second subspace
37start = h * self.d_k
Starting column index for this head. Head 0 starts at 0, Head 1 at 2.
38end = start + self.d_k
Ending column index (exclusive). Head 0 takes [0:2], Head 1 takes [2:4].
39heads.append(M[:, start:end])
Slice all rows, columns start:end. This is a view (no copy) in NumPy, so it is O(1).
EXECUTION STATE
[:, start:end] = NumPy slice — all rows (:), columns start to end-1. Returns a (5, d_k) sub-matrix.
40return heads
Return the list of H sub-matrices.
EXECUTION STATE
⬆ return: heads = list of 2 arrays, each shape (5, 2)
42def attention(self, Qh, Kh, Vh)
Single-head scaled dot-product attention. Identical to the full Chapter 1 computation, but now operating on a d_k=2 subspace instead of d_model=4.
EXECUTION STATE
⬇ input: Qh (5×2) = query sub-matrix for one head
⬇ input: Kh (5×2) = key sub-matrix for one head
⬇ input: Vh (5×2) = value sub-matrix for one head
⬆ returns = (weights, output) — shapes (5,5) and (5,2)
44scores = Qh @ Kh.T
Matrix multiply Qh (5×2) with Kh transposed (2×5). Each entry (i,j) is the dot product of query_i with key_j in this head's subspace.
EXECUTION STATE
Qh @ Kh.T = (5×2) @ (2×5) → (5×5)
── Head 1 scores ── =
scores (Head 1) =
The cat sat on mat
The 0.00 1.00 1.00 0.00 1.00
cat 2.00 0.00 2.00 0.00 0.00
sat 1.00 1.00 2.00 0.00 1.00
on 0.00 0.00 0.00 0.00 0.00
mat 0.00 1.00 1.00 0.00 1.00
── Head 2 scores ── =
scores (Head 2) =
The cat sat on mat
The 0.00 1.00 0.00 1.00 0.50
cat 1.00 0.00 0.00 1.00 0.50
sat 0.00 1.00 0.00 1.00 0.50
on 1.00 1.00 0.00 2.00 1.00
mat 1.00 0.00 0.00 1.00 0.50
45scaled = scores / self.scale
Divide all scores by sqrt(d_k) = sqrt(2) = 1.4142. Compared to Chapter 1 where we divided by sqrt(4) = 2.0, here we divide by a smaller number — less compression because d_k is smaller.
EXECUTION STATE
self.scale = 1.4142 (√2, not √4)
── Head 1 scaled ── =
scaled (Head 1) =
The cat sat on mat
The 0.0000 0.7071 0.7071 0.0000 0.7071
cat 1.4142 0.0000 1.4142 0.0000 0.0000
sat 0.7071 0.7071 1.4142 0.0000 0.7071
on 0.0000 0.0000 0.0000 0.0000 0.0000
mat 0.0000 0.7071 0.7071 0.0000 0.7071
── Head 2 scaled ── =
scaled (Head 2) =
The cat sat on mat
The 0.0000 0.7071 0.0000 0.7071 0.3536
cat 0.7071 0.0000 0.0000 0.7071 0.3536
sat 0.0000 0.7071 0.0000 0.7071 0.3536
on 0.7071 0.7071 0.0000 1.4142 0.7071
mat 0.7071 0.0000 0.0000 0.7071 0.3536
46weights = self._softmax(scaled)
Apply softmax row-wise to get attention probabilities. Each row sums to 1.0. These are the two different attention patterns — the whole point of multi-head.
EXECUTION STATE
── Head 1 weights ── =
weights (Head 1) =
The cat sat on mat
The 0.1237 0.2509 0.2509 0.1237 0.2509
cat 0.3664 0.0891 0.3664 0.0891 0.0891
sat 0.1811 0.1811 0.3673 0.0893 0.1811
on 0.2000 0.2000 0.2000 0.2000 0.2000
mat 0.1237 0.2509 0.2509 0.1237 0.2509
── Head 2 weights ── =
weights (Head 2) =
The cat sat on mat
The 0.1337 0.2711 0.1337 0.2711 0.1904
cat 0.2711 0.1337 0.1337 0.2711 0.1904
sat 0.1337 0.2711 0.1337 0.2711 0.1904
on 0.1811 0.1811 0.0893 0.3673 0.1811
mat 0.2711 0.1337 0.1337 0.2711 0.1904
47output = weights @ Vh
Weighted sum of value vectors. Each output row is a blend of all 5 value vectors (in this head's 2D subspace), weighted by attention probabilities.
EXECUTION STATE
── Head 1 output ── =
output (Head 1) =
d0 d1
The 0.2491 0.3763
cat 0.4109 0.1336
sat 0.2717 0.2717
on 0.3000 0.3000
mat 0.2491 0.3763
── Head 2 output ── =
output (Head 2) =
d2 d3
The 0.2289 0.3663
cat 0.2289 0.3663
sat 0.2289 0.3663
on 0.1799 0.4579
mat 0.2289 0.3663
48return weights, output
Return the attention weight matrix (5×5) for visualization and the context-enriched output (5×2) for concatenation.
50def forward(self, Q, K, V)
Full multi-head forward pass. Splits Q, K, V into heads, runs attention on each, then concatenates outputs. This is the main entry point.
EXECUTION STATE
⬇ input: Q (5×4) =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
⬇ input: K (5×4) =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
⬇ input: V (5×4) =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
⬆ returns = (all_weights, output) — list of H weight matrices + concatenated output (5,4)
62Q_heads = self.split_heads(Q)
Split Q (5×4) into [Q_h1 (5×2), Q_h2 (5×2)]. Head 1 gets dims [0:2], Head 2 gets dims [2:4].
EXECUTION STATE
Q_heads[0] (5×2) =
d0 d1
The 1.0 0.0
cat 0.0 2.0
sat 1.0 1.0
on 0.0 0.0
mat 1.0 0.0
Q_heads[1] (5×2) =
d2 d3
The 1.0 0.0
cat 0.0 1.0
sat 1.0 0.0
on 1.0 1.0
mat 0.0 1.0
63K_heads = self.split_heads(K)
Split K (5×4) into [K_h1 (5×2), K_h2 (5×2)].
EXECUTION STATE
K_heads[0] (5×2) =
d0 d1
The 0.0 1.0
cat 1.0 0.0
sat 1.0 1.0
on 0.0 0.0
mat 1.0 0.0
K_heads[1] (5×2) =
d2 d3
The 0.0 1.0
cat 1.0 0.0
sat 0.0 0.0
on 1.0 1.0
mat 0.5 0.5
64V_heads = self.split_heads(V)
Split V (5×4) into [V_h1 (5×2), V_h2 (5×2)].
EXECUTION STATE
V_heads[0] (5×2) =
d0 d1
The 1.0 0.0
cat 0.0 1.0
sat 0.0 0.0
on 0.0 0.0
mat 0.5 0.5
V_heads[1] (5×2) =
d2 d3
The 0.0 0.0
cat 0.0 0.0
sat 1.0 0.0
on 0.0 1.0
mat 0.5 0.5
69for h in range(self.num_heads):
Loop over H=2 heads. Each iteration computes independent attention on one subspace.
LOOP TRACE · 2 iterations
h=0 (Head 1)
Q_heads[0] = dims [0:2] — shape (5, 2)
K_heads[0] = dims [0:2] — shape (5, 2)
V_heads[0] = dims [0:2] — shape (5, 2)
w (Head 1 weights) = shape (5, 5) — cat→The is 0.3664 (highest)
o (Head 1 output) = shape (5, 2)
h=1 (Head 2)
Q_heads[1] = dims [2:4] — shape (5, 2)
K_heads[1] = dims [2:4] — shape (5, 2)
V_heads[1] = dims [2:4] — shape (5, 2)
w (Head 2 weights) = shape (5, 5) — on→on is 0.3673 (highest)
o (Head 2 output) = shape (5, 2)
70w, o = self.attention(Q_heads[h], K_heads[h], V_heads[h])
Call single-head attention for this head. Returns weights (5×5) and output (5×2). Each head sees a completely different slice of the representation.
71all_weights.append(w)
Collect this head's attention weight matrix for later visualization/analysis.
72all_outputs.append(o)
Collect this head's (5×2) output for concatenation.
74output = np.hstack(all_outputs)
Concatenate Head 1 output (5×2) with Head 2 output (5×2) along the column axis to get (5×4). This is the Concat() operation in the formula. The result has the same shape as the original input.
126mha = MultiHeadAttention(d_model=4, num_heads=2)
Instantiate with 4 dimensions and 2 heads. This gives d_k = 4/2 = 2 per head, scale = √2 ≈ 1.4142.
EXECUTION STATE
mha.d_model = 4
mha.num_heads = 2
mha.d_k = 2
mha.scale = 1.4142
127all_weights, output = mha.forward(Q, K, V)
Run the full multi-head pipeline: split → per-head attention × 2 → concatenate.
EXECUTION STATE
all_weights[0] (Head 1) =
The cat sat on mat
The 0.1237 0.2509 0.2509 0.1237 0.2509
cat 0.3664 0.0891 0.3664 0.0891 0.0891
sat 0.1811 0.1811 0.3673 0.0893 0.1811
on 0.2000 0.2000 0.2000 0.2000 0.2000
mat 0.1237 0.2509 0.2509 0.1237 0.2509
all_weights[1] (Head 2) =
The cat sat on mat
The 0.1337 0.2711 0.1337 0.2711 0.1904
cat 0.2711 0.1337 0.1337 0.2711 0.1904
sat 0.1337 0.2711 0.1337 0.2711 0.1904
on 0.1811 0.1811 0.0893 0.3673 0.1811
mat 0.2711 0.1337 0.1337 0.2711 0.1904
output (5×4) =
d0 d1 d2 d3
The 0.2491 0.3763 0.2289 0.3663
cat 0.4109 0.1336 0.2289 0.3663
sat 0.2717 0.2717 0.2289 0.3663
on 0.3000 0.3000 0.1799 0.4579
mat 0.2491 0.3763 0.2289 0.3663
136mha.explain(Q, K, V, tokens, query_idx=0)
Print detailed trace for 'The' (token 0). Shows per-head dot products and weights.
EXECUTION STATE
query_idx = 0 → tracing 'The'
```python
import numpy as np
import math


class MultiHeadAttention:
    """
    Multi-Head Attention (Vaswani et al., 2017)

    Splits Q, K, V into H heads, runs scaled dot-product attention
    on each head independently, then concatenates the results.

    MultiHead(Q, K, V) = Concat(head_1, ..., head_H) @ W_O
    head_i = Attention(Q @ W_i^Q, K @ W_i^K, V @ W_i^V)
    """

    def __init__(self, d_model: int, num_heads: int):
        """
        Args:
            d_model: Total model dimension (4 in our example)
            num_heads: Number of attention heads H (2 in our example)
        """
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = math.sqrt(self.d_k)

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax along last axis."""
        x_shifted = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(x_shifted)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def split_heads(self, M: np.ndarray) -> list:
        """Split matrix M (N, d_model) into H matrices of (N, d_k)."""
        heads = []
        for h in range(self.num_heads):
            start = h * self.d_k
            end = start + self.d_k
            heads.append(M[:, start:end])
        return heads

    def attention(self, Qh: np.ndarray, Kh: np.ndarray, Vh: np.ndarray):
        """Single-head scaled dot-product attention."""
        scores = Qh @ Kh.T
        scaled = scores / self.scale
        weights = self._softmax(scaled)
        output = weights @ Vh
        return weights, output

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Full multi-head forward pass.

        Args:
            Q: Query matrix (N, d_model)
            K: Key matrix (N, d_model)
            V: Value matrix (N, d_model)

        Returns:
            all_weights: List of H attention matrices, each (N, N)
            output: Concatenated output (N, d_model)
        """
        Q_heads = self.split_heads(Q)
        K_heads = self.split_heads(K)
        V_heads = self.split_heads(V)

        all_weights = []
        all_outputs = []

        for h in range(self.num_heads):
            w, o = self.attention(Q_heads[h], K_heads[h], V_heads[h])
            all_weights.append(w)
            all_outputs.append(o)

        output = np.hstack(all_outputs)
        return all_weights, output

    def explain(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                tokens: list, query_idx: int = 0):
        """Print a detailed trace for a specific query token."""
        all_weights, output = self.forward(Q, K, V)
        Q_heads = self.split_heads(Q)
        K_heads = self.split_heads(K)

        token = tokens[query_idx]
        print(f"\n=== Multi-Head trace for '{token}' (row {query_idx}) ===")

        for h in range(self.num_heads):
            start = h * self.d_k
            end = start + self.d_k
            print(f"\n--- Head {h+1} (dims {start}:{end}, d_k={self.d_k}) ---")
            print(f"  Q_h{h+1}[{token}] = {Q_heads[h][query_idx]}")
            for j, t in enumerate(tokens):
                dot = Q_heads[h][query_idx] @ K_heads[h][j]
                print(f"    Q[{token}] . K[{t}] = {dot:.4f} -> /{self.scale:.2f} = {dot/self.scale:.4f}")
            print(f"  weights = {np.round(all_weights[h][query_idx], 4)}")

        print(f"\nConcatenated output[{token}] = {np.round(output[query_idx], 4)}")


# ── Shared Example (used in every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],  # The
    [0.0, 2.0, 0.0, 1.0],  # cat
    [1.0, 1.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.0, 1.0],  # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],  # The
    [1.0, 0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.5, 0.5],  # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],  # The
    [0.0, 1.0, 0.0, 0.0],  # cat
    [0.0, 0.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 0.0, 1.0],  # on
    [0.5, 0.5, 0.5, 0.5],  # mat
])

# ── Run ──
mha = MultiHeadAttention(d_model=4, num_heads=2)
all_weights, output = mha.forward(Q, K, V)

print("Head 1 Attention Weights (5x5):")
print(np.round(all_weights[0], 4))

print("\nHead 2 Attention Weights (5x5):")
print(np.round(all_weights[1], 4))

print("\nConcatenated Output (5x4):")
print(np.round(output, 4))

# Detailed trace for "The" (token 0)
mha.explain(Q, K, V, tokens, query_idx=0)
```
PyTorch Implementation
The PyTorch version adds GPU support, automatic differentiation, batched inputs, and the learned projection matrices (W^Q, W^K, W^V, W^O) that real transformers use. We set use_projections=False so the results can be checked against the NumPy version; the two agree to the four decimal places printed (PyTorch tensors default to float32, NumPy arrays to float64).
Multi-Head Attention — PyTorch Implementation
🐍multi_head_attention_torch.py
Import PyTorch
torch is the core tensor library. nn provides neural network modules. F provides stateless ops like softmax. math provides sqrt.
class MultiHeadAttention(nn.Module)
By subclassing nn.Module, we get parameter tracking, .cuda() for GPU, autograd for gradients, and integration with training loops. The NumPy version is a plain class — this is a trainable component.
def __init__(self, d_model, num_heads)
Constructor. Sets up model dimensions, head count, and learnable projection matrices. super().__init__() initializes the nn.Module bookkeeping, so the nn.Linear layers assigned afterwards are registered as parameters.
self.W_Q = nn.Linear(d_model, d_model, bias=False)
Learnable query projection matrix W^Q of shape (4, 4). In production, this transforms input embeddings into queries. We skip it in our example (use_projections=False) to compare directly with the NumPy version.
EXECUTION STATE
nn.Linear(4, 4, bias=False) = creates a 4×4 weight matrix (no bias term). Equivalent to Q_proj = X @ W_Q.T in math.
self.W_O = nn.Linear(d_model, d_model, bias=False)
Output projection W^O. Applied after concatenation to mix information across heads. This is the final linear layer in the multi-head formula.
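A small NumPy sketch of what "mixing across heads" means. The two rows below are copied from this chapter's concatenated output; the matrix `W_O` is a hypothetical equal-weight averaging matrix chosen purely for illustration, not a trained weight.

```python
import numpy as np

# Concatenated output for two tokens of the shared example
# (dims 0:2 come from head 1, dims 2:4 from head 2).
concat = np.array([
    [0.2491, 0.3763, 0.2289, 0.3663],  # The
    [0.3000, 0.3000, 0.1799, 0.4579],  # on
])

# Hypothetical W_O (an assumption for illustration): every output
# dimension is an equal-weight average of all four input dimensions.
W_O = np.full((4, 4), 0.25)
projected = concat @ W_O

# Remove head 2's contribution and project again.
head1_only = concat.copy()
head1_only[:, 2:] = 0.0
projected_head1_only = head1_only @ W_O

# Before projection, dims 0:2 do not depend on head 2 at all.
print(np.array_equal(concat[:, :2], head1_only[:, :2]))  # True
# After projection, every output dimension changes when head 2 is removed.
print(np.all(projected != projected_head1_only))         # True
```

Any matrix with nonzero entries in both the top and bottom halves of each column would show the same effect; the averaging matrix just makes the arithmetic easy to follow.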
def split_heads(self, x) -> torch.Tensor
Reshape (B, N, d_model) into (B, H, N, d_k). This is the PyTorch-idiomatic way to split heads — using view + transpose instead of slicing, so it works with batched inputs.
return x.view(B, N, self.num_heads, self.d_k).transpose(1, 2)
Two-step reshape: view changes (1,5,4) to (1,5,2,2) by splitting the last dimension into (H, d_k). Then transpose swaps dims 1 and 2: (1,5,2,2) → (1,2,5,2). Now the head dimension is before the sequence dimension, which is what matmul expects.
EXECUTION STATE
.view(B, N, H, d_k) = (1,5,4) → (1,5,2,2) — split d_model=4 into H=2 × d_k=2
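.transpose(1, 2) = (1,5,2,2) → (1,2,5,2): swap the sequence and head dims. Each head's (N, d_k) matrix now occupies the last two dims, so batched matmul treats batch and head as batch dimensions. The result is a strided, non-contiguous view; no data is copied.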
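The reshape-then-swap step can be checked without PyTorch. The sketch below uses NumPy's `reshape` and `transpose` as stand-ins for `view` and `transpose(1, 2)` (an assumption made so the example runs with NumPy alone) and confirms that the result is the same column split used in the hand computation.

```python
import numpy as np

# Shared-example queries, shape (N, d_model) = (5, 4)
Q = np.array([
    [1.0, 0.0, 1.0, 0.0],  # The
    [0.0, 2.0, 0.0, 1.0],  # cat
    [1.0, 1.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.0, 1.0],  # mat
])
B, N, H, d_k = 1, 5, 2, 2

# (N, d_model) -> (B, N, d_model) -> (B, N, H, d_k) -> (B, H, N, d_k)
Q_h = Q[None].reshape(B, N, H, d_k).transpose(0, 2, 1, 3)

print(Q_h.shape)                             # (1, 2, 5, 2)
print(np.array_equal(Q_h[0, 0], Q[:, 0:2]))  # True: head 1 sees columns 0:2
print(np.array_equal(Q_h[0, 1], Q[:, 2:4]))  # True: head 2 sees columns 2:4
```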
def forward(self, Q, K, V, mask, use_projections)
Main entry point. Handles optional batching, optional projections, head splitting, per-head attention, and concatenation. The mask parameter enables causal or padding masking.
EXECUTION STATE
⬇ input: Q = torch.Size([5, 4]) — unbatched
⬇ input: K = torch.Size([5, 4])
⬇ input: V = torch.Size([5, 4])
⬇ input: mask = None (no masking)
⬇ input: use_projections = False (skip W_Q/K/V/O for direct comparison)
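To make the mask argument concrete, here is a minimal causal-mask sketch in NumPy. It assumes the convention that True marks a position to block, and uses `np.where` as a stand-in for `masked_fill`; the all-zero scores are placeholders, not values from the shared example.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    """Row-wise softmax; subtracting the row max keeps exp() stable."""
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

N = 5
scores = np.zeros((N, N))  # placeholder scores: every key equally relevant

# Causal mask: True above the diagonal marks "future" keys to block.
mask = np.triu(np.ones((N, N), dtype=bool), k=1)

# Blocked positions get -inf, so their softmax weight becomes exactly 0.
masked_scores = np.where(mask, -np.inf, scores)
weights = softmax(masked_scores)

print(np.round(weights, 2))  # lower-triangular; row i spreads weight over keys 0..i
```

An (N, N) mask like this one broadcasts over the batch and head dimensions of a (B, H, N, N) score tensor, so a single mask can serve every head.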
Q = Q.unsqueeze(0) (handle unbatched input)
If input is 2D (N, d_model), add a batch dimension to make it 3D (1, N, d_model). This allows the same code to handle both batched and unbatched inputs.
EXECUTION STATE
.unsqueeze(0) = add dim at position 0. (5,4) → (1,5,4)
Q.shape after = torch.Size([1, 5, 4])
Q_h = self.split_heads(Q)
Split Q (1,5,4) into (1,2,5,2). Each of the 2 heads now has its own 2D query subspace.
scores = torch.matmul(Q_h, K_h.transpose(-2, -1)) / self.scale
Compute scaled dot-product scores for ALL heads simultaneously via batched matmul. K_h.transpose(-2,-1) swaps the last two dims: (1,2,5,2) → (1,2,2,5). Then matmul: (1,2,5,2) @ (1,2,2,5) → (1,2,5,5). Both heads are computed in one operation, which is why the reshape approach is faster than looping.
EXECUTION STATE
K_h.transpose(-2, -1) = swap dims -2 and -1. (1,2,5,2) → (1,2,2,5). The batch and head dims are untouched.
weights = F.softmax(scores, dim=-1)
Softmax along the last dimension (key positions). Operates independently on each (batch, head, query) row. Numerically stable by default.
EXECUTION STATE
dim=-1 = normalize along key positions. Each query row in each head independently sums to 1.0.
weights.shape = torch.Size([1, 2, 5, 5])
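The batched score and softmax steps can be reproduced in NumPy as a cross-check. This is a sketch, with the helper `split_heads` written here for illustration (it omits the batch dimension); the printed rows match the Head 1 and Head 2 tables in this section.

```python
import numpy as np

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],  # The
    [0.0, 2.0, 0.0, 1.0],  # cat
    [1.0, 1.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.0, 1.0],  # mat
])
K = np.array([
    [0.0, 1.0, 0.0, 1.0],  # The
    [1.0, 0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.5, 0.5],  # mat
])
H, d_k = 2, 2

def split_heads(x: np.ndarray) -> np.ndarray:
    """(N, d_model) -> (H, N, d_k); illustrative helper, no batch dim."""
    return x.reshape(x.shape[0], H, d_k).transpose(1, 0, 2)

Q_h, K_h = split_heads(Q), split_heads(K)

# One batched matmul covers both heads: (H, N, d_k) @ (H, d_k, N) -> (H, N, N)
scores = Q_h @ K_h.transpose(0, 2, 1) / np.sqrt(d_k)

# Softmax over the key axis, with the usual max-subtraction for stability
e = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = e / e.sum(axis=-1, keepdims=True)

print(weights.shape)               # (2, 5, 5)
print(np.round(weights[0, 1], 4))  # head 1, "cat": 0.3664 0.0891 0.3664 0.0891 0.0891
print(np.round(weights[1, 3], 4))  # head 2, "on":  0.1811 0.1811 0.0893 0.3673 0.1811
```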
head_out = torch.matmul(weights, V_h)
Weighted sum of value vectors per head. (1,2,5,5) @ (1,2,5,2) → (1,2,5,2). Each head produces its own contextual output.
EXECUTION STATE
head_out.shape = torch.Size([1, 2, 5, 2])
output = head_out.transpose(1, 2).contiguous().view(B, N, H * d_k)
Reverse the split: transpose heads back (1,2,5,2) → (1,5,2,2), then view merges the last two dims: (1,5,2,2) → (1,5,4). This is the Concat() operation. contiguous() is needed because transpose creates a non-contiguous view.
EXECUTION STATE
.transpose(1, 2) = (1,2,5,2) → (1,5,2,2) — swap head and seq dims back
.contiguous() = make memory layout contiguous after transpose (required before view)
.view(B, N, H*d_k) = (1,5,2,2) → (1,5,4) — merge head and d_k dims
output.shape = torch.Size([1, 5, 4])
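A NumPy sketch of the merge step, using the per-head outputs from the shared example (rounded to four decimals). NumPy's `reshape` copies the data when it has to, so no explicit equivalent of `contiguous()` appears here; that difference from PyTorch's `view` is the only liberty taken.

```python
import numpy as np

# Per-head outputs from the shared example, shape (H, N, d_k) = (2, 5, 2)
head_out = np.array([
    [[0.2491, 0.3763], [0.4109, 0.1336], [0.2717, 0.2717],
     [0.3000, 0.3000], [0.2491, 0.3763]],                    # head 1
    [[0.2289, 0.3663], [0.2289, 0.3663], [0.2289, 0.3663],
     [0.1799, 0.4579], [0.2289, 0.3663]],                    # head 2
])
H, N, d_k = head_out.shape

# (H, N, d_k) -> (N, H, d_k) -> (N, H * d_k)
merged = head_out.transpose(1, 0, 2).reshape(N, H * d_k)

# Same result as placing the two heads side by side
concat = np.hstack([head_out[0], head_out[1]])

print(merged.shape)                    # (5, 4)
print(np.array_equal(merged, concat))  # True
```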
return weights.squeeze(0), output.squeeze(0)
Remove the batch dimension we added for unbatched input. Returns weights (2,5,5) and output (5,4).
weights, output = mha(Q, K, V, use_projections=False)
Call the module (triggers forward()). use_projections=False skips the learned W matrices so the output matches the NumPy version to the printed precision.
EXECUTION STATE
weights[0] (Head 1) =
       The     cat     sat     on      mat
The    0.1237  0.2509  0.2509  0.1237  0.2509
cat    0.3664  0.0891  0.3664  0.0891  0.0891
sat    0.1811  0.1811  0.3673  0.0893  0.1811
on     0.2000  0.2000  0.2000  0.2000  0.2000
mat    0.1237  0.2509  0.2509  0.1237  0.2509
weights[1] (Head 2) =
       The     cat     sat     on      mat
The    0.1337  0.2711  0.1337  0.2711  0.1904
cat    0.2711  0.1337  0.1337  0.2711  0.1904
sat    0.1337  0.2711  0.1337  0.2711  0.1904
on     0.1811  0.1811  0.0893  0.3673  0.1811
mat    0.2711  0.1337  0.1337  0.2711  0.1904
output (5×4) =
       d0      d1      d2      d3
The    0.2491  0.3763  0.2289  0.3663
cat    0.4109  0.1336  0.2289  0.3663
sat    0.2717  0.2717  0.2289  0.3663
on     0.3000  0.3000  0.1799  0.4579
mat    0.2491  0.3763  0.2289  0.3663
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention (Vaswani et al., 2017), PyTorch version.

    Splits Q, K, V into H heads, runs scaled dot-product attention
    on each, then concatenates and projects with W_O.

    Supports GPU, automatic differentiation, and batched inputs.
    """

    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        # Each head gets an equal slice of the model dimension
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.scale = math.sqrt(self.d_k)

        # Learned projection matrices (not used in this example,
        # but included for production completeness)
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape (B, N, d_model) -> (B, H, N, d_k)."""
        B, N, _ = x.shape
        return x.view(B, N, self.num_heads, self.d_k).transpose(1, 2)

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: torch.Tensor | None = None,
        use_projections: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            Q: (B, N, d_model) or (N, d_model)
            K: (B, N, d_model) or (N, d_model)
            V: (B, N, d_model) or (N, d_model)
            mask: Optional (B, 1, N, N) or (N, N) boolean;
                True marks positions to mask out
            use_projections: If True, apply W_Q/W_K/W_V/W_O

        Returns:
            all_weights: (B, H, N, N) attention weights
            output: (B, N, d_model) multi-head output
            (the batch dimension is dropped for unbatched input)
        """
        # Handle unbatched input
        unbatched = Q.dim() == 2
        if unbatched:
            Q = Q.unsqueeze(0)
            K = K.unsqueeze(0)
            V = V.unsqueeze(0)

        # Optional learned projections
        if use_projections:
            Q = self.W_Q(Q)
            K = self.W_K(K)
            V = self.W_V(V)

        # Split into heads: (B, N, d_model) -> (B, H, N, d_k)
        Q_h = self.split_heads(Q)
        K_h = self.split_heads(K)
        V_h = self.split_heads(V)

        # Scaled dot-product per head
        scores = torch.matmul(Q_h, K_h.transpose(-2, -1)) / self.scale

        if mask is not None:
            scores = scores.masked_fill(mask, float("-inf"))

        weights = F.softmax(scores, dim=-1)    # (B, H, N, N)
        head_out = torch.matmul(weights, V_h)  # (B, H, N, d_k)

        # Concatenate heads: (B, H, N, d_k) -> (B, N, d_model)
        B, H, N, d_k = head_out.shape
        output = head_out.transpose(1, 2).contiguous().view(B, N, H * d_k)

        if use_projections:
            output = self.W_O(output)

        if unbatched:
            # Drop the batch dimension only if we added it above,
            # so a genuine batch of size 1 keeps its shape
            return weights.squeeze(0), output.squeeze(0)
        return weights, output


# ── Shared Example ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])
K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])
V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

# ── Run (without learned projections for direct comparison) ──
mha = MultiHeadAttention(d_model=4, num_heads=2)
weights, output = mha(Q, K, V, use_projections=False)

print("Head 1 weights (5x5):")
print(weights[0].round(decimals=4))

print("\nHead 2 weights (5x5):")
print(weights[1].round(decimals=4))

print("\nOutput (5x4):")
print(output.round(decimals=4))
Key Takeaways
The problem: A single attention head produces one attention pattern — a single compromise that cannot simultaneously specialize for syntax, semantics, and position.
The solution: Run H independent heads in parallel, each operating on a d_k-dimensional subspace of the representation.
Same cost: Total computation is roughly the same as single-head attention at full dimensionality because H · d_k = d_model. The matrix multiplications cost the same in total; the work is just distributed across heads.
Concatenation recovers dimensionality: Each head outputs (N, d_k); concatenating H heads gives (N, d_model).
W^O is essential: The output projection mixes information across heads. Without it, each output dimension carries information from only one head.
Heads specialize: In trained models, different heads learn to capture syntactic dependencies, semantic similarity, positional proximity, and coreference — empirically verified by Clark et al. (2019) and Voita et al. (2019).
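The "same cost" point can be checked with a few lines of arithmetic. The sketch below counts only the multiply-adds in the QK^T product, using a hypothetical helper `qk_multiply_adds`; it ignores the projections and the value product, which scale with H in the same way.

```python
def qk_multiply_adds(n_tokens: int, dim: int) -> int:
    """Multiply-adds for one (n_tokens x dim) @ (dim x n_tokens) product."""
    return n_tokens * n_tokens * dim

N, d_model = 5, 4  # sizes from the shared example

counts = {}
for H in (1, 2, 4):
    d_k = d_model // H
    counts[H] = H * qk_multiply_adds(N, d_k)  # H heads, each of width d_k
    softmax_rows = H * N                      # one softmax per (head, query) pair
    print(f"H={H}: QK^T multiply-adds={counts[H]}, softmax rows={softmax_rows}")

# H=1: QK^T multiply-adds=100, softmax rows=5
# H=2: QK^T multiply-adds=100, softmax rows=10
# H=4: QK^T multiply-adds=100, softmax rows=20
```

The matrix-multiply count stays fixed because H · d_k = d_model. The number of softmax rows does grow with H, but that term is small next to the matrix products.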
Exercises
Exercise 1: Compute for “sat”
Compute the multi-head attention output for the token “sat” (row 2) by hand. Show the per-head dot products, scaled scores, softmax weights, and per-head outputs. Verify that Head 1 attends most to “sat” itself (0.3673) while Head 2 distributes attention to “cat” and “on” (both 0.2711). Concatenate the outputs and check against the output matrix.
Exercise 2: H=4 Heads
Repeat the computation with H=4 heads, each with d_k = 1. What happens to the attention patterns when each head computes only a scalar dot product? Compare the expressiveness with H=2. Use the Head Count Explorer to verify your answers.
Exercise 3: Identical Subspaces
What happens if you assign both heads the same columns (e.g., both use dims [0,1])? Compute the attention weights for both heads and verify they are identical. What does this tell you about the role of diverse subspaces?
Exercise 4: W^O Projection
Create a 4×4 output projection matrix W^O (e.g., a random orthogonal matrix or a permutation matrix) and apply it to the concatenated output. How does the final output change? Why is this step necessary for mixing information across heads?
Exercise 5: Head Pruning
Set one head's output to zero (simulating pruning). How much does the final output change? Compute the L2 norm of the difference. Voita et al. (2019) found that many heads can be pruned with minimal accuracy loss — does this match your observation?
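One possible starting point for this exercise, offered as a sketch rather than a reference solution. It works on the concatenated output from the shared example, with no W^O applied, and the helper `prune_head` is a name introduced here for illustration.

```python
import numpy as np

# Concatenated output from the shared example (no W^O applied)
output = np.array([
    [0.2491, 0.3763, 0.2289, 0.3663],  # The
    [0.4109, 0.1336, 0.2289, 0.3663],  # cat
    [0.2717, 0.2717, 0.2289, 0.3663],  # sat
    [0.3000, 0.3000, 0.1799, 0.4579],  # on
    [0.2491, 0.3763, 0.2289, 0.3663],  # mat
])
d_k = 2

def prune_head(out: np.ndarray, head: int) -> np.ndarray:
    """Return a copy of `out` with one head's d_k columns set to zero."""
    pruned = out.copy()
    pruned[:, head * d_k:(head + 1) * d_k] = 0.0
    return pruned

for h in range(2):
    # Per-token L2 norm of the change caused by zeroing head h
    diff = np.linalg.norm(output - prune_head(output, h), axis=1)
    print(f"Head {h + 1} pruned, per-token L2 change: {np.round(diff, 4)}")
```

Keep in mind that this toy example has only two untrained heads, so it says little about how pruning behaves in a trained model with many heads.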
References
Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A.N., Kaiser, L., & Polosukhin, I. (2017). Attention Is All You Need. Advances in Neural Information Processing Systems, 30. The paper that introduced multi-head attention as part of the transformer architecture.
Clark, K., Khandelwal, U., Levy, O., & Manning, C.D. (2019). What Does BERT Look At? An Analysis of BERT's Attention. BlackboxNLP Workshop at ACL 2019. Empirical analysis showing that different BERT heads specialize for different linguistic phenomena.
Voita, E., Talbot, D., Moiseev, F., Sennrich, R., & Titov, I. (2019). Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned. ACL 2019. Showed that a small number of heads are important for specific tasks and the rest can be pruned.
Dao, T., Fu, D.Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022. IO-aware algorithm for exact attention computation without materializing the full score matrix.
Su, J., Lu, Y., Pan, S., Murtadha, A., Wen, B., & Liu, Y. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. Per-head rotary embeddings that encode position through rotation in each head's subspace.
Dosovitskiy, A., Beyer, L., Kolesnikov, A., Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M., Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J., & Houlsby, N. (2021). An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. ICLR 2021. Vision Transformer showing multi-head attention on image patches.
Kaplan, J., McCandlish, S., Henighan, T., Brown, T.B., Chess, B., Child, R., Gray, S., Radford, A., Wu, J., & Amodei, D. (2020). Scaling Laws for Neural Language Models. Empirical scaling laws showing how performance improves with model size, dataset size, and compute.
Michel, P., Levy, O., & Neubig, G. (2019). Are Sixteen Heads Really Better than One? NeurIPS 2019. Systematic study of head importance, showing that many heads can be removed at test time with negligible accuracy loss.