Noam Shazeer, "Fast Transformer Decoding: One Write-Head is All You Need", 2019
Learning Objectives
By the end of this chapter, you will be able to:
Explain why the KV-cache becomes the dominant memory bottleneck during autoregressive inference and why standard multi-head attention (MHA) makes it worse.
Describe the core insight of Multi-Query Attention: all heads can share a single set of Keys and Values while retaining per-head Query projections.
Derive the KV-cache savings formula and calculate the exact memory reduction for real-world model configurations like LLaMA 2 70B and PaLM.
Compute MQA attention weights and outputs by hand for “The cat sat on mat” and compare them to the MHA results from Chapter 2.
Implement a complete MQA class in both NumPy and PyTorch that you can run on any input.
Connect MQA to its generalization Grouped-Query Attention (GQA), and understand why most modern LLMs use GQA rather than pure MQA.
Where this appears: MQA was introduced at Google and used in PaLM (540B parameters). Its direct descendant, GQA, is the standard in the larger LLaMA 2 models, LLaMA 3, Mistral, Gemma 2, and most production LLMs. Understanding MQA is essential for understanding how modern language models achieve fast inference at scale: shrinking the KV-cache is a large part of what lets 70-billion-parameter models serve long contexts to many users at interactive speeds.
The Real Problem
Chapters 1 through 4 established the core attention mechanisms: scaled dot-product, multi-head, causal masking, and cross-attention. These are mathematically elegant and produce excellent results during training. But there is a brutal engineering reality that these mechanisms ignore: inference is fundamentally different from training, and multi-head attention has a hidden cost that only reveals itself at inference time.
The KV-Cache Explosion
During autoregressive generation (producing text token by token), the model generates one new token at each step. To produce token $t$, the model must compute attention over all previous tokens $1, 2, \ldots, t-1$. Without caching, this means recomputing the Key and Value projections for every previous token at every step, an $O(t^2)$ total cost for generating a sequence of length $t$.
The standard solution is the KV-cache: store the Key and Value tensors for all previously generated tokens, so they only need to be computed once. The new token computes its Query, looks up the cached Keys and Values, and produces the next output. This reduces per-step computation to O(t) — but at the cost of memory.
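A minimal NumPy sketch of the caching idea for a single attention head, with random projection matrices (the names W_q, W_k, W_v, attend, and all sizes are illustrative assumptions): each step projects only the newest token, appends its Key and Value rows to the cache, and attends over everything cached so far. The final check compares the cached result against recomputing K and V for the whole prefix.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_k, T = 8, 4, 6

# Hypothetical projection weights for one attention head
W_q = rng.standard_normal((d_model, d_k))
W_k = rng.standard_normal((d_model, d_k))
W_v = rng.standard_normal((d_model, d_k))
X = rng.standard_normal((T, d_model))  # embeddings of T generated tokens


def attend(q, K, V):
    """Scaled dot-product attention for a single query vector."""
    scores = K @ q / np.sqrt(d_k)
    w = np.exp(scores - scores.max())  # stable softmax
    w /= w.sum()
    return w @ V


# Incremental decoding: each token is projected once, then reused via the cache
K_cache = np.empty((0, d_k))
V_cache = np.empty((0, d_k))
cached_outputs = []
for t in range(T):
    x_t = X[t]
    K_cache = np.vstack([K_cache, x_t @ W_k])  # cache grows by one key row
    V_cache = np.vstack([V_cache, x_t @ W_v])  # and one value row per step
    cached_outputs.append(attend(x_t @ W_q, K_cache, V_cache))

# Without a cache: recompute K and V for the entire prefix at the last step
recomputed = attend(X[-1] @ W_q, X @ W_k, X @ W_v)
print(np.allclose(cached_outputs[-1], recomputed))  # True
```

The cache trades recomputation for memory: after T steps it holds T rows of K and T rows of V, which is exactly the storage the rest of this chapter is about.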
In standard multi-head attention (MHA, Chapter 2), every head maintains its own independent Key and Value matrices. The KV-cache for $H$ heads, sequence length $N$, and per-head dimension $d_k$ requires:
$$\text{MHA KV-cache} = 2 \times H \times N \times d_k \ \text{floats}$$
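The factor of 2 accounts for both K and V. For a model shaped like LLaMA 2 70B ($H = 64$ heads, $d_k = 128$) at sequence length $N = 4096$, using float16 (2 bytes per float) and assuming full MHA (the released model actually uses GQA, discussed later in this chapter):
$$2 \times 64 \times 4096 \times 128 \times 2 \ \text{bytes} = 128 \ \text{MB per layer}$$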
With 80 transformer layers, the total KV-cache is ~10 GB for a single request. Serving 32 concurrent users at sequence length 8192 would require about 640 GB of KV-cache alone, the entire memory of an 8×A100 (80 GB) node. The model weights themselves (~140 GB in float16) become the smaller memory consumer.
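The capacity arithmetic above can be reproduced with a small helper. This is a sketch: kv_cache_bytes is a hypothetical name, and the configuration is the MHA-shaped LLaMA 2 70B example from the text. The figures come out exactly in binary units (MiB, GiB), which is what the "MB" and "GB" figures above correspond to.

```python
def kv_cache_bytes(n_kv_heads, seq_len, d_k, n_layers=1, batch=1, bytes_per_elem=2):
    """KV-cache size: 2 (K and V) x KV heads x tokens x head dim x layers x batch x bytes."""
    return 2 * n_kv_heads * seq_len * d_k * n_layers * batch * bytes_per_elem


MiB = 2**20
GiB = 2**30

# Hypothetical full-MHA configuration with LLaMA 2 70B's shape (64 heads, d_k=128, 80 layers)
per_layer = kv_cache_bytes(64, 4096, 128)
one_request = kv_cache_bytes(64, 4096, 128, n_layers=80)
serving = kv_cache_bytes(64, 8192, 128, n_layers=80, batch=32)

print(per_layer / MiB)    # 128.0
print(one_request / GiB)  # 10.0
print(serving / GiB)      # 640.0
```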
The Inference Bottleneck
The KV-cache creates two problems simultaneously:
Memory capacity. The cache must fit in GPU memory alongside model weights and activations. Larger caches mean fewer concurrent users, directly reducing serving throughput.
Memory bandwidth. At each decoding step, the model must read the entire KV-cache from GPU memory to compute attention. Modern GPUs (like the A100) have ~2 TB/s of memory bandwidth, so reading the full 10 GB cache (summed over all 80 layers) takes ~5 ms, and this has to happen again at every decoding step. The attention computation itself (the dot products and softmax) is cheap by comparison. Decoding is memory-bandwidth bound, not compute bound.
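A back-of-the-envelope version of the bandwidth argument, assuming a nominal 2 TB/s and that the cache is streamed from memory exactly once per decoding step (read_time_ms is a hypothetical helper). It ignores the model-weight reads; those are shared by every request in a batch, whereas each request's KV-cache has to be read separately. The second print shows what a 64 times smaller cache would cost.

```python
def read_time_ms(n_bytes, bandwidth_bytes_per_s=2e12):
    """Lower bound on the time to stream n_bytes from GPU memory once."""
    return n_bytes / bandwidth_bytes_per_s * 1e3


cache_bytes = 10 * 2**30  # ~10 GiB KV-cache for one request, all 80 layers
print(round(read_time_ms(cache_bytes), 2))       # 5.37
print(round(read_time_ms(cache_bytes / 64), 3))  # 0.084
```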
Shazeer's insight was that the per-head K and V projections are the root cause. If all heads could share a single set of Keys and Values, the KV-cache would shrink by a factor of H — from 10 GB to ~160 MB for LLaMA 2 70B. The question was: would the quality survive?
Training vs Inference
During training, all tokens are processed in parallel and no KV-cache exists. MQA and MHA have nearly identical training costs. The entire motivation for MQA is inference speed — reducing the memory footprint of the KV-cache so that more users can be served simultaneously and each token is generated faster.
From Intuition to Mathematics
The Shared Dictionary Insight
Recall the library metaphor from Chapter 1. In multi-head attention, imagine H researchers in a library, each with their own question (query). In MHA, each researcher also has their own private catalogue of book labels (keys) and their own private set of books (values). This means the library must maintain H separate catalogues and H separate book collections — an enormous amount of storage.
Shazeer's insight: the researchers need to ask different questions, but they can all look up answers in the same catalogue and read from the same books. Each researcher still has their own unique perspective (their own Query projection), but the dictionary they consult (the Keys) and the content they read (the Values) are shared. The information is the same; only the way each head searches for it differs.
This is not as restrictive as it sounds. The intuition is that heads differ mainly in the questions they ask (their Query projections): different heads learn to “ask different questions” about the same content. Shazeer's experiments on translation benchmarks showed that sharing K and V caused surprisingly small quality drops (often less than 0.5 BLEU) while cutting inference memory and latency dramatically.
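Each head $i$ keeps its own query projection but attends over the shared keys and values:
$$\text{head}_i = \text{softmax}\left(\frac{(Q W_i^Q)\,(K W^K_{\text{shared}})^\top}{\sqrt{d_k}}\right) V W^V_{\text{shared}}$$
The final output is the concatenation of all heads followed by a linear projection:
$$\text{MQA}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H)\, W^O$$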
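To make the two formulas concrete, here is a minimal NumPy sketch of MQA with learned-style projections. The weights are random stand-ins, and the names (mqa_forward, W_q, W_k, W_v, W_o) are illustrative assumptions. W_q produces all H query heads at once, while W_k and W_v each produce a single d_k-wide matrix that every head reuses.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)  # subtract row max for stability
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def mqa_forward(X, W_q, W_k, W_v, W_o, n_heads):
    """Multi-query attention over a sequence X of shape (N, d_model)."""
    N, d_model = X.shape
    d_k = d_model // n_heads
    # Per-head queries: (N, d_model) -> (H, N, d_k)
    Q = (X @ W_q).reshape(N, n_heads, d_k).transpose(1, 0, 2)
    K = X @ W_k  # (N, d_k): ONE key matrix shared by all heads
    V = X @ W_v  # (N, d_k): ONE value matrix shared by all heads
    scores = Q @ K.T / np.sqrt(d_k)  # (H, N, N); K.T broadcasts across heads
    heads = softmax(scores) @ V      # (H, N, d_k)
    concat = heads.transpose(1, 0, 2).reshape(N, n_heads * d_k)  # Concat(head_1..head_H)
    return concat @ W_o              # (N, d_model)


rng = np.random.default_rng(0)
N, d_model, H = 5, 8, 4
d_k = d_model // H
X = rng.standard_normal((N, d_model))
W_q = rng.standard_normal((d_model, d_model))  # H query projections, stacked side by side
W_k = rng.standard_normal((d_model, d_k))      # single shared key projection
W_v = rng.standard_normal((d_model, d_k))      # single shared value projection
W_o = rng.standard_normal((d_model, d_model))

out = mqa_forward(X, W_q, W_k, W_v, W_o, H)
print(out.shape)  # (5, 8)
```

Because K and V have no head axis, NumPy broadcasts them across all H heads; that missing axis is exactly what the KV-cache saves.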
Symbol-by-Symbol Breakdown
| Symbol | Shape | Meaning |
|---|---|---|
| $W_i^Q$ | $d_{model} \times d_k$ | Per-head query projection. Each head has its own $W_i^Q$. This is where heads learn to “ask different questions.” |
| $W^K_{\text{shared}}$ | $d_{model} \times d_k$ | Single shared key projection. All heads use the same $W^K$. This is the MQA innovation. |
| $W^V_{\text{shared}}$ | $d_{model} \times d_k$ | Single shared value projection. All heads read from the same value space. |
| $K_{\text{shared}}$ | $N \times d_k$ | The single key matrix: $K W^K_{\text{shared}}$. In our example: 5×2. |
| $V_{\text{shared}}$ | $N \times d_k$ | The single value matrix: $V W^V_{\text{shared}}$. In our example: 5×2. |
| $H$ | scalar | Number of query heads. In our example: 2. In LLaMA 2 70B: 64. |
| $W^O$ | $d_{model} \times d_{model}$ | Output projection (same as MHA). Maps concatenated head outputs back to the model dimension. |
KV-Cache Size Formula
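| Mechanism | Formula | Our Example ($H=2$, $N=5$, $d_k=2$) | LLaMA 2 70B shape ($H=64$, $N=4096$, $d_k=128$), per layer |
|---|---|---|---|
| MHA | $2 \times H \times N \times d_k$ | 40 floats | 128 MB (float16) |
| MQA | $2 \times 1 \times N \times d_k$ | 20 floats (2× smaller) | 2 MB (64× smaller) |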
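The table entries follow directly from the formula. A quick check in Python (kv_cache_floats is a hypothetical helper; float16 is assumed at 2 bytes per element, and sizes are per layer):

```python
def kv_cache_floats(n_kv_heads, seq_len, d_k):
    """Cached floats per layer: 2 (K and V) x KV heads x tokens x head dim."""
    return 2 * n_kv_heads * seq_len * d_k


# Our example: H=2, N=5, d_k=2 (MHA caches 2 K/V heads, MQA caches 1)
print(kv_cache_floats(2, 5, 2), kv_cache_floats(1, 5, 2))  # 40 20

# LLaMA 2 70B shape: H=64, N=4096, d_k=128, float16 -> MiB per layer
mha_mib = kv_cache_floats(64, 4096, 128) * 2 / 2**20
mqa_mib = kv_cache_floats(1, 4096, 128) * 2 / 2**20
print(mha_mib, mqa_mib, mha_mib / mqa_mib)  # 128.0 2.0 64.0
```

The ratio is always the number of KV heads MHA would have cached, H, whatever the sequence length or head dimension.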
Step-by-Step Calculation
Setting Up the Shared K and V
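In our example with $H = 2$ heads and $d_k = 2$ per head, MQA uses only the first $d_k = 2$ dimensions of K and V as the shared Key and Value matrices. Both heads use these same matrices. This column slice stands in for the learned projections $W^K_{\text{shared}}$ and $W^V_{\text{shared}}$, which keeps the numbers directly comparable with Chapter 2.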
$K_{\text{shared}}$ (5×2), used by ALL heads

| | d0 | d1 |
|---|---|---|
| The | 0.0 | 1.0 |
| cat | 1.0 | 0.0 |
| sat | 1.0 | 1.0 |
| on | 0.0 | 0.0 |
| mat | 1.0 | 0.0 |
$V_{\text{shared}}$ (5×2), used by ALL heads

| | d0 | d1 |
|---|---|---|
| The | 1.0 | 0.0 |
| cat | 0.0 | 1.0 |
| sat | 0.0 | 0.0 |
| on | 0.0 | 0.0 |
| mat | 0.5 | 0.5 |
The queries remain per-head: $Q_{h1}$ = Q[:, 0:2] and $Q_{h2}$ = Q[:, 2:4].
Step-by-Step for "The" (row 0)
Head 1: $Q_{h1}[\text{The}] = [1.0, 0.0]$

Dot products with each shared key, divided by $\sqrt{2} \approx 1.414$:

| Token | Dot Product | Scaled Score |
|---|---|---|
| The | 1.0×0.0 + 0.0×1.0 = 0.0 | 0.0000 |
| cat | 1.0×1.0 + 0.0×0.0 = 1.0 | 0.7071 |
| sat | 1.0×1.0 + 0.0×1.0 = 1.0 | 0.7071 |
| on | 1.0×0.0 + 0.0×0.0 = 0.0 | 0.0000 |
| mat | 1.0×1.0 + 0.0×0.0 = 1.0 | 0.7071 |
Softmax: [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
Output: $O_{h1}[\text{The}] = [0.2491, 0.3763]$
Head 2: $Q_{h2}[\text{The}] = [1.0, 0.0]$
Interestingly, $Q_{h1}[\text{The}]$ and $Q_{h2}[\text{The}]$ happen to be identical ([1, 0]), so both heads produce the same attention pattern for “The.” This is a coincidence of our fixed example; in practice, learned projections make them different.
Head 2 softmax: [0.1237, 0.2509, 0.2509, 0.1237, 0.2509] (identical to Head 1)
Concatenated output for “The”: O[The] = [0.2491, 0.3763, 0.2491, 0.3763]
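The hand calculation for “The” can be replayed in a few lines of NumPy using the shared K and V tables above. This only checks the arithmetic; it is not part of the chapter's implementation.

```python
import numpy as np

K_shared = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])
V_shared = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.5, 0.5]])
q_the = np.array([1.0, 0.0])  # Q_h1[The], which equals Q_h2[The] in this example

scores = K_shared @ q_the / np.sqrt(2)  # one scaled score per token
weights = np.exp(scores) / np.exp(scores).sum()
output = weights @ V_shared

print(np.round(weights, 4))  # [0.1237 0.2509 0.2509 0.1237 0.2509]
print(np.round(output, 4))   # [0.2491 0.3763]
```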
Step-by-Step for "cat" (row 1)
Head 1: $Q_{h1}[\text{cat}] = [0.0, 2.0]$

| Token | Dot Product | Scaled Score |
|---|---|---|
| The | 0.0×0.0 + 2.0×1.0 = 2.0 | 1.4142 |
| cat | 0.0×1.0 + 2.0×0.0 = 0.0 | 0.0000 |
| sat | 0.0×1.0 + 2.0×1.0 = 2.0 | 1.4142 |
| on | 0.0×0.0 + 2.0×0.0 = 0.0 | 0.0000 |
| mat | 0.0×1.0 + 2.0×0.0 = 0.0 | 0.0000 |
Softmax: [0.3664,0.0891,0.3664,0.0891,0.0891]
Head 2: $Q_{h2}[\text{cat}] = [0.0, 1.0]$
Here the heads do diverge. Head 2's query [0, 1] points in the same direction as Head 1's [0, 2] but has half the length, so every score is halved and the softmax comes out flatter:
Softmax: [0.2874,0.1417,0.2874,0.1417,0.1417]
Head 1 concentrates strongly on “The” and “sat” (36.6% each), while Head 2 gives them only 28.7% each. This is MQA working as designed — different queries, same dictionary, different lookup patterns.
Concatenated output for “cat”: O[cat]=[0.4109,0.1336,0.3583,0.2126]
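The divergence comes purely from query length: doubling the query doubles every score, which acts like lowering the softmax temperature. A small NumPy check (head_weights is an illustrative helper) reproduces both rows:

```python
import numpy as np

K_shared = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])


def head_weights(q):
    """Softmax of scaled scores of one query against the shared keys."""
    s = K_shared @ q / np.sqrt(2)
    e = np.exp(s - s.max())
    return e / e.sum()


w_h1 = head_weights(np.array([0.0, 2.0]))  # Q_h1[cat]
w_h2 = head_weights(np.array([0.0, 1.0]))  # Q_h2[cat]: same direction, half the length
print(np.round(w_h1, 4))  # [0.3664 0.0891 0.3664 0.0891 0.0891]
print(np.round(w_h2, 4))  # [0.2874 0.1417 0.2874 0.1417 0.1417]
```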
Full Attention Weights and Output
Head 1 Attention Weights (5×5)

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |
| cat | 0.3664 | 0.0891 | 0.3664 | 0.0891 | 0.0891 |
| sat | 0.1811 | 0.1811 | 0.3673 | 0.0893 | 0.1811 |
| on | 0.2000 | 0.2000 | 0.2000 | 0.2000 | 0.2000 |
| mat | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |
Head 2 Attention Weights (5×5)

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |
| cat | 0.2874 | 0.1417 | 0.2874 | 0.1417 | 0.1417 |
| sat | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |
| on | 0.1811 | 0.1811 | 0.3673 | 0.0893 | 0.1811 |
| mat | 0.2874 | 0.1417 | 0.2874 | 0.1417 | 0.1417 |
Interesting pattern for "on"
In Head 1, “on” has $Q_{h1}[\text{on}] = [0, 0]$, a zero query. Since every dot product with $K_{\text{shared}}$ is zero, softmax produces a uniform distribution: all five tokens get exactly 0.2000. In contrast, Head 2 has $Q_{h2}[\text{on}] = [1, 1]$, which favors “sat” (raw score 2.0, weight 0.3673) because its key [1, 1] matches the query exactly.
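Both behaviors can be checked directly. This sketch applies the two “on” queries to the shared keys:

```python
import numpy as np

K_shared = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])

weights = {}
for name, q in [("Head 1", [0.0, 0.0]), ("Head 2", [1.0, 1.0])]:
    s = K_shared @ np.array(q) / np.sqrt(2)
    e = np.exp(s - s.max())
    weights[name] = e / e.sum()
    print(name, np.round(weights[name], 4))
# Head 1 [0.2 0.2 0.2 0.2 0.2]
# Head 2 [0.1811 0.1811 0.3673 0.0893 0.1811]
```

A zero query makes every score equal, so softmax can only return the uniform distribution, whatever the keys are.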
MQA Output Matrix (5×4, heads concatenated)

| | d0 | d1 | d2 | d3 |
|---|---|---|---|---|
| The | 0.2491 | 0.3763 | 0.2491 | 0.3763 |
| cat | 0.4109 | 0.1336 | 0.3583 | 0.2126 |
| sat | 0.2717 | 0.2717 | 0.2491 | 0.3763 |
| on | 0.3000 | 0.3000 | 0.2717 | 0.2717 |
| mat | 0.2491 | 0.3763 | 0.3583 | 0.2126 |
MQA vs MHA: Side-by-Side Comparison
Compare the MQA output with the MHA output from Chapter 2 (both using $H = 2$ heads):

| Token | MHA Output | MQA Output | Difference (MQA − MHA) |
|---|---|---|---|
| The | [0.2491, 0.3763, 0.2289, 0.3663] | [0.2491, 0.3763, 0.2491, 0.3763] | [0.000, 0.000, +0.020, +0.010] |
| cat | [0.4109, 0.1336, 0.2289, 0.3663] | [0.4109, 0.1336, 0.3583, 0.2126] | [0.000, 0.000, +0.129, −0.154] |
| sat | [0.2717, 0.2717, 0.2289, 0.3663] | [0.2717, 0.2717, 0.2491, 0.3763] | [0.000, 0.000, +0.020, +0.010] |
| on | [0.3000, 0.3000, 0.1799, 0.4579] | [0.3000, 0.3000, 0.2717, 0.2717] | [0.000, 0.000, +0.092, −0.186] |
| mat | [0.2491, 0.3763, 0.2289, 0.3663] | [0.2491, 0.3763, 0.3583, 0.2126] | [0.000, 0.000, +0.129, −0.154] |
Interpreting the Differences
Head 1 outputs (columns d0, d1) are identical between MHA and MQA because Head 1 uses K[:,0:2] and V[:,0:2] in both cases — these are the same matrices.
Head 2 outputs (columns d2, d3) differ because MHA's Head 2 uses K[:,2:4] (its own keys), while MQA's Head 2 uses K[:,0:2] (the shared keys). Different keys produce different attention patterns and therefore different outputs. The maximum absolute difference is 0.186 (for “on,” d3) — noticeable but not dramatic.
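The comparison can be reproduced from the shared example matrices. The sketch below (attention is an illustrative helper, not the chapter's class) computes Head 2 twice, once with its own MHA slice of K and V and once with the shared MQA slice, and reports the largest gap:

```python
import numpy as np

Q = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 2.0, 0.0, 1.0], [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]])
K = np.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.5, 0.5]])
V = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.5]])


def attention(q, k, v):
    """Row-wise scaled dot-product attention."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v


Qh2 = Q[:, 2:4]
mha_head2 = attention(Qh2, K[:, 2:4], V[:, 2:4])  # MHA: Head 2 has its own K, V slice
mqa_head2 = attention(Qh2, K[:, 0:2], V[:, 0:2])  # MQA: Head 2 reuses the shared K, V
diff = mqa_head2 - mha_head2
print(np.round(diff, 3))
print(round(float(np.abs(diff).max()), 3))  # 0.186
```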
Quality vs efficiency tradeoff
In our tiny 2-head example, the differences are visible because Head 2 loses its own key and value spaces entirely. In production models with 64+ heads, the shared K, V is trained jointly with all the query heads, so it learns a general-purpose representation that works well for all of them. Shazeer's experiments showed a drop of less than 0.5 BLEU on WMT translation tasks with MQA.
Applications Across Domains
Large Language Model Inference
MQA's primary application is accelerating autoregressive text generation. Google's PaLM (540B parameters, Chowdhery et al., 2022) used MQA across all 118 layers. With 48 heads and $d_k = 256$, the KV-cache reduction was 48×, enabling significantly higher serving throughput. For a context length of 2048 tokens, PaLM's MQA KV-cache is ~2 MB per layer (float16), compared with ~96 MB for equivalent MHA, a difference that compounds across 118 layers to save over 10 GB per request.
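The per-layer figures can be computed directly from the KV-cache formula with PaLM's stated shape (48 heads, d_k = 256, 2048 tokens, 118 layers), assuming float16 and a single sequence; kv_cache_mib is a hypothetical helper:

```python
def kv_cache_mib(n_kv_heads, seq_len=2048, d_k=256, bytes_per_elem=2):
    """Per-layer KV-cache in MiB for one sequence (float16 by default)."""
    return 2 * n_kv_heads * seq_len * d_k * bytes_per_elem / 2**20


n_layers = 118
mqa_layer = kv_cache_mib(1)   # MQA: one shared K/V head
mha_layer = kv_cache_mib(48)  # hypothetical MHA with 48 K/V heads
print(mqa_layer, mha_layer)                                  # 2.0 96.0
print(round((mha_layer - mqa_layer) * n_layers / 1024, 2))   # 10.83 (GiB saved per request)
```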
Code Generation
Code generation models like StarCoder face especially long contexts (full source files with thousands of tokens). MQA allows StarCoder to process 8,000+ token contexts while maintaining interactive response times. The reduction factor is still H, but the absolute savings are larger because code sequences tend to be much longer than natural-language prompts.
Vision Transformers
Vision transformers (ViT) typically use standard MHA, but MQA-style KV sharing can make attention cheaper in image generation models. When a transformer processes a 1024×1024 image as 16×16 patches (4096 tokens), the K and V tensors in every layer become large, and diffusion denoising loops repeat the full attention computation at each of 50+ steps, so shrinking K and V pays off at every step.
Scientific Computing
Protein structure prediction models (like ESMFold) and molecular dynamics transformers process sequences of amino acids or atomic coordinates that can reach thousands of tokens. Sharing K and V across heads shrinks the attention memory footprint for such long sequences, which can help full protein sequences fit in GPU memory during inference.
Connection to Modern Systems
Grouped-Query Attention (GQA) — Chapter 6
MQA uses 1 KV group. MHA uses $H$ KV groups. GQA (Ainslie et al., 2023) interpolates between these extremes with $G$ groups ($1 < G < H$). LLaMA 2 70B uses $G = 8$ groups for its 64 heads, so each group serves 8 heads. This achieves an 8× cache reduction (vs MQA's 64×) while preserving near-MHA quality. GQA has become the standard in most recent LLMs because it offers a better quality-efficiency tradeoff than either extreme.
| Mechanism | KV Groups | Cache Savings | Quality |
|---|---|---|---|
| MHA | $G = H$ | 1× (baseline) | Best |
| GQA | $1 < G < H$ | $H/G$× (e.g., 8×) | Near-MHA |
| MQA | $G = 1$ | $H$× (e.g., 64×) | Slightly lower |
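A minimal NumPy sketch of the grouping, assuming per-head queries and per-group K/V have already been projected (the function name and shapes are illustrative, not taken from any library). Each group's K and V is repeated across the heads it serves; G = 1 recovers MQA and G = H recovers MHA, while the cache only ever stores the G un-repeated copies.

```python
import numpy as np


def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


def grouped_query_attention(Q, K, V, n_groups):
    """Q: (H, N, d_k) per-head queries; K, V: (G, N, d_k), one K/V per group.

    Head h reads the K/V of group h // (H // G). G = 1 is MQA; G = H is MHA.
    """
    H, N, d_k = Q.shape
    if K.shape[0] != n_groups or H % n_groups != 0:
        raise ValueError("K/V need n_groups entries and H must be divisible by n_groups")
    heads_per_group = H // n_groups
    K_rep = np.repeat(K, heads_per_group, axis=0)  # (H, N, d_k): each group's K, shared
    V_rep = np.repeat(V, heads_per_group, axis=0)
    scores = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(d_k)  # (H, N, N)
    return softmax(scores) @ V_rep                      # (H, N, d_k)


rng = np.random.default_rng(0)
H, N, d_k = 8, 5, 4
Q = rng.standard_normal((H, N, d_k))

for G in (1, 2, 8):  # MQA, GQA, MHA
    K = rng.standard_normal((G, N, d_k))
    V = rng.standard_normal((G, N, d_k))
    out = grouped_query_attention(Q, K, V, G)
    print(G, out.shape, "cached floats per layer:", 2 * G * N * d_k)
```

np.repeat materializes the copies for clarity; an optimized kernel would typically index each group's K and V in place rather than copying them.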
Flash Attention — Chapter 13
Flash Attention (Dao et al., 2022) optimizes the computation of attention by tiling the score matrix to avoid materializing the full N×N matrix in GPU HBM. MQA optimizes the storage of keys and values. These are orthogonal improvements: you can use MQA to shrink the KV-cache and Flash Attention to speed up the attention computation itself. Most production deployments combine both, pairing KV sharing (MQA or GQA, as in Gemma and LLaMA 3) with Flash Attention-style kernels.
Multi-Head Latent Attention (MLA) — Chapter 15
DeepSeek-V2 (2024) takes a different approach to KV-cache compression. Instead of sharing raw K, V across heads, MLA compresses them into a low-rank latent representation and decompresses it per head at attention time. The cached latent is larger than MQA's single K, V pair but far smaller than MHA's cache (the DeepSeek-V2 paper likens it to GQA with only about 2.25 groups), while matching or exceeding MHA quality in their experiments. MLA can be seen as a learned generalization of MQA: instead of simply sharing K, V, it learns a compact shared representation from which per-head K, V can be reconstructed.
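A rough NumPy sketch of the low-rank idea only. It ignores DeepSeek-V2 details such as the decoupled rotary-position keys and folding the up-projections into other weights, and every name and size here (W_down, W_up_k, W_up_v, latent size d_c = 4·d_k) is an assumption for illustration. Only the latent C is cached; per-head K and V are rebuilt from it when attention runs.

```python
import numpy as np

rng = np.random.default_rng(0)
N, H, d_k = 6, 16, 8
d_model = H * d_k
d_c = 4 * d_k  # latent size (assumed for illustration)

X = rng.standard_normal((N, d_model))
W_down = rng.standard_normal((d_model, d_c))  # shared compression
W_up_k = rng.standard_normal((H, d_c, d_k))   # per-head key decompression
W_up_v = rng.standard_normal((H, d_c, d_k))   # per-head value decompression

C = X @ W_down  # (N, d_c): the only tensor kept in the cache
K = C @ W_up_k  # (H, N, d_k): per-head keys rebuilt at attention time
V = C @ W_up_v  # (H, N, d_k): per-head values rebuilt at attention time

# Cached floats per layer: MLA latent, full MHA, MQA
print(C.size, 2 * H * N * d_k, 2 * N * d_k)  # 192 1536 96
```

With these assumed sizes the latent sits between MQA and MHA, mirroring the ordering described above; the actual ratios depend on the chosen d_c.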
Complexity Analysis
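| Resource | MHA | MQA | Savings Factor |
|---|---|---|---|
| KV-cache memory | $O(H \cdot N \cdot d_k)$ | $O(N \cdot d_k)$ | $H$× |
| KV projection params | $2H \cdot d_{model} \cdot d_k$ | $2 \cdot d_{model} \cdot d_k$ | $H$× |
| Attention FLOPs (per step) | $O(H \cdot N \cdot d_k)$ | $O(H \cdot N \cdot d_k)$ | 1× (same) |
| Memory bandwidth | $O(H \cdot N \cdot d_k)$ | $O(N \cdot d_k)$ | $H$× |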
The attention FLOPs are unchanged because each head still computes its own attention scores — it just does so against the shared K,V instead of per-head K,V. The savings come entirely from reduced memory and bandwidth for the KV-cache.
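The FLOPs-versus-bytes claim can be made concrete by counting one decoding step for a single layer (decode_step_cost is a hypothetical helper; counts are approximate and ignore softmax and the projections). With MHA the attention step performs about one FLOP per byte of cache read, far below the roughly 100+ FLOPs per byte that a GPU like the A100 can sustain, which is why decoding is bandwidth bound; MQA raises that ratio by a factor of H.

```python
def decode_step_cost(n_q_heads, n_kv_heads, seq_len, d_k, bytes_per_elem=2):
    """Approximate per-layer cost of one decoding step (one new query token).

    FLOPs: each query head does q.K^T and weights.V, each seq_len * d_k
    multiply-adds (2 FLOPs per multiply-add).
    Bytes: the layer's whole K and V cache is read once.
    """
    flops = n_q_heads * 2 * (2 * seq_len * d_k)
    bytes_read = 2 * n_kv_heads * seq_len * d_k * bytes_per_elem
    return flops, bytes_read


mha = decode_step_cost(64, 64, 4096, 128)
mqa = decode_step_cost(64, 1, 4096, 128)
print(mha[0] == mqa[0])                    # True: same attention FLOPs
print(mha[1] // mqa[1])                    # 64: H-times fewer bytes read
print(mha[0] / mha[1], mqa[0] / mqa[1])    # 1.0 64.0 (FLOPs per byte)
```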
Python Implementation
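The class below implements MQA from scratch. The critical method is compute_shared_kv(), which extracts a single K, V pair shared by all heads. As in the hand calculation, this shared pair is simply the first d_k columns of K and V rather than a learned projection. Compare this with the MHA class from Chapter 2, where each head computes its own K, V slice.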
Multi-Query Attention — NumPy Implementation
import numpy as np
NumPy provides vectorized matrix operations. Qh @ K_shared.T runs as optimized C code, not Python loops.
import math
Python standard library. We use math.sqrt() to precompute the scaling factor √d_k.
class MultiQueryAttention
Wraps MQA in a reusable class. The key difference from MHA: compute_shared_kv() extracts ONE K,V pair used by ALL heads, instead of per-head K,V slices.
def __init__(self, d_model, n_heads)
Constructor. Takes the full model dimension and number of heads. Computes per-head dimension d_k = d_model / n_heads and precomputes the scaling factor √d_k.
EXECUTION STATE
⬇ input: d_model = 4
⬇ input: n_heads = 2
self.d_model = d_model
Store full model dimension (4 in our example).
EXECUTION STATE
self.d_model = 4
self.n_heads = n_heads
Store number of query heads. In MQA, all n_heads share one K,V — this is the core idea.
EXECUTION STATE
self.n_heads = 2
self.d_k = d_model // n_heads
Per-head query dimension. Each head sees d_k=2 dimensions of the query vector. The shared K,V also have d_k=2 columns.
EXECUTION STATE
d_model // n_heads = 4 // 2 = 2
self.d_k = 2
self.scale = math.sqrt(self.d_k)
Precompute √d_k once. Divides every dot product to control variance (same as Chapter 1).
EXECUTION STATE
math.sqrt(2) = 1.4142
self.scale = 1.4142
def _softmax(self, x) → np.ndarray
Numerically stable softmax. Takes a score matrix (5×5) and returns probabilities per row. Each row sums to 1.0.
EXECUTION STATE
⬇ input: x = shape (5, 5) — scaled score matrix for one head
return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
Divide each element by its row sum to normalize into probabilities. Each row sums to exactly 1.0.
EXECUTION STATE
axis=-1 = sum along last axis — sum each row independently
keepdims=True = sum returns shape (5,1) so broadcasting works correctly
def compute_shared_kv(self, K, V)
THE KEY MQA OPERATION: extract a single K,V pair from the first d_k=2 dimensions. All heads will use these same matrices. In MHA, each head gets its own slice — here, everyone shares one.
EXECUTION STATE
⬇ input: K (5×4) =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
⬇ input: V (5×4) =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
self.d_k = 2 — only first 2 columns used
⬆ returns = (K_shared (5×2), V_shared (5×2))
return K[:, :self.d_k], V[:, :self.d_k]
Slice first 2 columns from K and V. This is the shared representation that all heads look up from.
EXECUTION STATE
K[:, :2] =
d0 d1
The 0.0 1.0
cat 1.0 0.0
sat 1.0 1.0
on 0.0 0.0
mat 1.0 0.0
V[:, :2] =
d0 d1
The 1.0 0.0
cat 0.0 1.0
sat 0.0 0.0
on 0.0 0.0
mat 0.5 0.5
⬆ return: (K_shared, V_shared) = both shape (5, 2)
def compute_head_query(self, Q, head_idx)
Extract the query slice for a specific head. Head 0 gets Q[:, 0:2], Head 1 gets Q[:, 2:4]. Each head asks DIFFERENT questions but looks them up in the SAME dictionary.
EXECUTION STATE
⬇ input: Q (5×4) =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
⬇ input: head_idx = 0 or 1
⬆ returns = np.ndarray (5, 2) — query slice for this head
start = head_idx * self.d_k
Compute start column index for this head's query slice.
EXECUTION STATE
head_idx=0 → start = 0 × 2 = 0
head_idx=1 → start = 1 × 2 = 2
end = start + self.d_k
Compute end column index.
EXECUTION STATE
head_idx=0 → end = 0 + 2 = 2
head_idx=1 → end = 2 + 2 = 4
return Q[:, start:end]
Return the 2-column query slice for this head.
EXECUTION STATE
⬆ return: Q_h1 (head 0) =
d0 d1
The 1.0 0.0
cat 0.0 2.0
sat 1.0 1.0
on 0.0 0.0
mat 1.0 0.0
⬆ return: Q_h2 (head 1) =
d2 d3
The 1.0 0.0
cat 0.0 1.0
sat 1.0 0.0
on 1.0 1.0
mat 0.0 1.0
def compute_scores(self, Qh, K_shared)
Raw dot-product scores: Qh (5×2) @ K_shared.T (2×5) → (5×5). Each entry measures how much a query aligns with a shared key.
EXECUTION STATE
⬇ input: Qh = shape (5, 2) — one head's queries
⬇ input: K_shared = shape (5, 2) — shared keys (same for all heads)
⬆ returns = np.ndarray (5, 5) — raw scores
return Qh @ K_shared.T
Matrix multiply Qh (5×2) with K_shared transposed (2×5). Result is 5×5: entry (i,j) = dot product of query_i with shared_key_j.
EXECUTION STATE
.T = transpose — K_shared (5×2) becomes (2×5)
⬆ return (Head 1) =
The cat sat on mat
The 0.0 1.0 1.0 0.0 1.0
cat 2.0 0.0 2.0 0.0 0.0
sat 1.0 1.0 2.0 0.0 1.0
on 0.0 0.0 0.0 0.0 0.0
mat 0.0 1.0 1.0 0.0 1.0
⬆ return (Head 2) =
The cat sat on mat
The 0.0 1.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0 0.0
sat 0.0 1.0 1.0 0.0 1.0
on 1.0 1.0 2.0 0.0 1.0
mat 1.0 0.0 1.0 0.0 0.0
def scale_scores(self, scores) → np.ndarray
Divide every score by √d_k = √2 ≈ 1.4142 to prevent softmax saturation.
EXECUTION STATE
⬇ input: scores = shape (5, 5) — raw dot products
self.scale = 1.4142 (√2)
⬆ returns = np.ndarray (5, 5) — scores ÷ 1.4142
return scores / self.scale
Elementwise division. Every score is divided by 1.4142.
EXECUTION STATE
⬆ return (Head 1) =
The cat sat on mat
The 0.0000 0.7071 0.7071 0.0000 0.7071
cat 1.4142 0.0000 1.4142 0.0000 0.0000
sat 0.7071 0.7071 1.4142 0.0000 0.7071
on 0.0000 0.0000 0.0000 0.0000 0.0000
mat 0.0000 0.7071 0.7071 0.0000 0.7071
def compute_weights(self, scaled_scores) → np.ndarray
Apply softmax row-wise to get attention probabilities. Each row sums to 1.0.
EXECUTION STATE
⬇ input: scaled_scores = shape (5, 5)
⬆ returns = np.ndarray (5, 5) — each row sums to 1.0
return self._softmax(scaled_scores)
Calls _softmax. All 5 rows become probability distributions.
EXECUTION STATE
⬆ return (Head 1) =
The cat sat on mat
The 0.1237 0.2509 0.2509 0.1237 0.2509
cat 0.3664 0.0891 0.3664 0.0891 0.0891
sat 0.1811 0.1811 0.3673 0.0893 0.1811
on 0.2000 0.2000 0.2000 0.2000 0.2000
mat 0.1237 0.2509 0.2509 0.1237 0.2509
⬆ return (Head 2) =
The cat sat on mat
The 0.1237 0.2509 0.2509 0.1237 0.2509
cat 0.2874 0.1417 0.2874 0.1417 0.1417
sat 0.1237 0.2509 0.2509 0.1237 0.2509
on 0.1811 0.1811 0.3673 0.0893 0.1811
mat 0.2874 0.1417 0.2874 0.1417 0.1417
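def compute_output(self, weights, V_shared)
Weighted sum of SHARED value vectors. weights (5×5) @ V_shared (5×2) → (5×2). This is the 2-dim output for one head.
EXECUTION STATE
⬇ input: weights = shape (5, 5), one head's attention weights
⬇ input: V_shared (5×2) =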
d0 d1
The 1.0 0.0
cat 0.0 1.0
sat 0.0 0.0
on 0.0 0.0
mat 0.5 0.5
⬆ returns = np.ndarray (5, 2) — weighted sum for one head
return weights @ V_shared
Matrix multiply weights (5×5) with V_shared (5×2). Each output row is a blend of all 5 value vectors, weighted by attention.
EXECUTION STATE
⬆ return (Head 1) =
d0 d1
The 0.2491 0.3763
cat 0.4109 0.1336
sat 0.2717 0.2717
on 0.3000 0.3000
mat 0.2491 0.3763
⬆ return (Head 2) =
d0 d1
The 0.2491 0.3763
cat 0.3583 0.2126
sat 0.2491 0.3763
on 0.2717 0.2717
mat 0.3583 0.2126
def forward(self, Q, K, V)
Main entry point. Extracts shared K,V once, then loops over heads computing per-head Q projections. Each head uses the SAME K_shared and V_shared. Returns concatenated outputs and per-head weights.
EXECUTION STATE
⬇ input: Q (5×4) =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
⬇ input: K (5×4) =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
⬇ input: V (5×4) =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
⬆ returns = (output (5,4), all_weights list of 2 matrices)
mqa = MultiQueryAttention(d_model=4, n_heads=2)
Instantiate MQA with 4 model dims and 2 heads. Sets d_k=2 and scale=√2.
EXECUTION STATE
mqa.d_k = 2
mqa.scale = 1.4142
output, all_weights = mqa.forward(Q, K, V)
Run the full MQA pipeline. Extracts shared K,V once, then runs 2 heads with different Q slices.
EXECUTION STATE
output (5×4) =
d0 d1 d2 d3
The 0.2491 0.3763 0.2491 0.3763
cat 0.4109 0.1336 0.3583 0.2126
sat 0.2717 0.2717 0.2491 0.3763
on 0.3000 0.3000 0.2717 0.2717
mat 0.2491 0.3763 0.3583 0.2126
len(all_weights) = 2 (one 5×5 matrix per head)
mqa.explain(Q, K, V, tokens, query_idx=0)
Print detailed trace for 'The' (token 0). Shows shared K,V then per-head Q, scores, and output.
EXECUTION STATE
query_idx = 0 → tracing 'The'
mqa.explain(Q, K, V, tokens, query_idx=1)
Print detailed trace for 'cat' (token 1). Heads produce DIFFERENT attention patterns because their Q slices differ.
EXECUTION STATE
query_idx = 1 → tracing 'cat'
```python
import numpy as np
import math


class MultiQueryAttention:
    """
    Multi-Query Attention (Shazeer, 2019)

    All H query heads share a SINGLE set of Keys and Values.
    Only the Query projections remain per-head.

    KV-cache savings: H× reduction (e.g., 64× for a 64-head model
    shaped like LLaMA 2 70B)
    """

    def __init__(self, d_model: int, n_heads: int):
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.scale = math.sqrt(self.d_k)

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax along last axis."""
        x_shifted = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(x_shifted)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def compute_shared_kv(self, K: np.ndarray, V: np.ndarray):
        """Extract the single shared K, V (first d_k dimensions)."""
        return K[:, :self.d_k], V[:, :self.d_k]

    def compute_head_query(self, Q: np.ndarray, head_idx: int):
        """Extract query slice for a specific head."""
        start = head_idx * self.d_k
        end = start + self.d_k
        return Q[:, start:end]

    def compute_scores(self, Qh: np.ndarray, K_shared: np.ndarray):
        """Raw dot-product scores: Qh @ K_shared^T."""
        return Qh @ K_shared.T

    def scale_scores(self, scores: np.ndarray) -> np.ndarray:
        """Divide by sqrt(d_k) to control variance."""
        return scores / self.scale

    def compute_weights(self, scaled_scores: np.ndarray) -> np.ndarray:
        """Apply softmax to get attention weights."""
        return self._softmax(scaled_scores)

    def compute_output(self, weights: np.ndarray, V_shared: np.ndarray):
        """Weighted sum of shared value vectors."""
        return weights @ V_shared

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Full forward pass.

        Args:
            Q: Query matrix (N, d_model)
            K: Key matrix (N, d_model)
            V: Value matrix (N, d_model)

        Returns:
            output: Concatenated head outputs (N, d_model)
            all_weights: List of weight matrices per head
        """
        K_shared, V_shared = self.compute_shared_kv(K, V)  # computed ONCE for all heads
        head_outputs = []
        all_weights = []

        for h in range(self.n_heads):
            Qh = self.compute_head_query(Q, h)  # only the query differs per head
            raw_scores = self.compute_scores(Qh, K_shared)
            scaled_scores = self.scale_scores(raw_scores)
            weights = self.compute_weights(scaled_scores)
            output = self.compute_output(weights, V_shared)
            head_outputs.append(output)
            all_weights.append(weights)

        return np.hstack(head_outputs), all_weights

    def explain(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                tokens: list, query_idx: int = 0):
        """Print a detailed trace for a specific query token."""
        K_shared, V_shared = self.compute_shared_kv(K, V)
        token = tokens[query_idx]
        print(f"\n=== MQA trace for '{token}' (row {query_idx}) ===")
        print("K_shared (ALL heads use this):")
        for i, t in enumerate(tokens):
            print(f"  {t}: {K_shared[i]}")
        print("V_shared (ALL heads use this):")
        for i, t in enumerate(tokens):
            print(f"  {t}: {V_shared[i]}")

        for h in range(self.n_heads):
            Qh = self.compute_head_query(Q, h)
            raw = self.compute_scores(Qh, K_shared)
            scaled = self.scale_scores(raw)
            w = self.compute_weights(scaled)
            out = self.compute_output(w, V_shared)
            print(f"\n--- Head {h + 1} ---")
            print(f"Q_h{h + 1}[{token}] = {Qh[query_idx]}")
            for j, t in enumerate(tokens):
                print(f"  score[{t}] = {raw[query_idx, j]:.4f}"
                      f" -> scaled = {scaled[query_idx, j]:.4f}")
            print(f"  softmax = {np.round(w[query_idx], 4)}")
            print(f"  output  = {np.round(out[query_idx], 4)}")


# ── Shared Example (used in every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],  # The
    [0.0, 2.0, 0.0, 1.0],  # cat
    [1.0, 1.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.0, 1.0],  # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],  # The
    [1.0, 0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.5, 0.5],  # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],  # The
    [0.0, 1.0, 0.0, 0.0],  # cat
    [0.0, 0.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 0.0, 1.0],  # on
    [0.5, 0.5, 0.5, 0.5],  # mat
])

# ── Run ──
mqa = MultiQueryAttention(d_model=4, n_heads=2)
output, all_weights = mqa.forward(Q, K, V)

print("MQA Output (5x4):")
print(np.round(output, 4))

print("\nHead 1 Weights (5x5):")
print(np.round(all_weights[0], 4))

print("\nHead 2 Weights (5x5):")
print(np.round(all_weights[1], 4))

# Detailed trace for "The" and "cat"
mqa.explain(Q, K, V, tokens, query_idx=0)
mqa.explain(Q, K, V, tokens, query_idx=1)
```
PyTorch Implementation
The PyTorch version supports GPU acceleration and automatic differentiation. The core logic is identical: extract shared K, V once, then loop over heads with per-head Q. In production, the projection matrices ($W^Q$, $W^K$, $W^V$) would be learnable nn.Linear layers.
Multi-Query Attention — PyTorch Implementation
Import PyTorch
torch is the core tensor library. torch.nn provides neural network building blocks (nn.Module), and torch.nn.functional provides stateless operations like softmax; matrix products use torch.matmul. math is used for sqrt.
class MultiQueryAttention(nn.Module)
nn.Module subclass. Unlike the NumPy version, this supports GPU acceleration via .cuda(), automatic gradient computation via autograd, and integration with PyTorch training loops.
def __init__(self, d_model, n_heads)
Constructor. Computes d_k = d_model / n_heads and scale = √d_k. In production, this would also define the learned projection matrices (shown in comments).
EXECUTION STATE
⬇ input: d_model = 4
⬇ input: n_heads = 2
self.d_k = 4 // 2 = 2
self.scale = √2 = 1.4142
Production projection matrices (commented)
In a real model, W_q projects to d_model (all heads), while W_k and W_v project to only d_k (one shared set). This is where the H× parameter savings come from: MHA has H separate K,V projections; MQA has just 1.
return torch.cat(head_outputs, dim=-1), all_weights
Concatenate head outputs along the last dimension. torch.cat with dim=-1 is equivalent to np.hstack for 2D tensors.
EXECUTION STATE
torch.cat(dim=-1) = concatenate along last dim. Two (5,2) tensors → one (5,4) tensor.
⬆ return: output =
d0 d1 d2 d3
The 0.2491 0.3763 0.2491 0.3763
cat 0.4109 0.1336 0.3583 0.2126
sat 0.2717 0.2717 0.2491 0.3763
on 0.3000 0.3000 0.2717 0.2717
mat 0.2491 0.3763 0.3583 0.2126
mqa = MultiQueryAttention(d_model=4, n_heads=2)
Instantiate the MQA module. Calling mqa(Q, K, V) runs forward() through nn.Module's call machinery, so prefer it over calling .forward() directly.
EXECUTION STATE
mqa.d_k = 2
mqa.scale = 1.4142
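output, all_weights = mqa(Q, K, V)
Run MQA. Calling mqa(Q, K, V) invokes forward() through nn.Module.__call__, which also runs any registered forward hooks. Gradient tracking itself comes from autograd, which records operations on tensors that require gradients.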
EXECUTION STATE
output.shape = torch.Size([5, 4])
output =
d0 d1 d2 d3
The 0.2491 0.3763 0.2491 0.3763
cat 0.4109 0.1336 0.3583 0.2126
sat 0.2717 0.2717 0.2491 0.3763
on 0.3000 0.3000 0.2717 0.2717
mat 0.2491 0.3763 0.3583 0.2126
Heads match for 'The'?
Both heads produce identical weights for 'The' because Q_h1[The]=[1,0] and Q_h2[The]=[1,0] happen to be the same. In practice, learned projections make them different.
if torch.cuda.is_available():
Moving to GPU only requires calling .cuda() on the module and on the input tensors; the same MQA code runs on CPU or GPU without changes. At production scale (H=64, seq_len=8192), running on a GPU is typically much faster.
EXECUTION STATE
torch.cuda.is_available() = True/False (depends on hardware)
```python
from __future__ import annotations  # lets `torch.Tensor | None` parse on Python < 3.10

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention (Shazeer, 2019), PyTorch version.

    All H query heads share a SINGLE set of Keys and Values.
    Supports GPU, automatic differentiation, and batched inputs.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.scale = math.sqrt(self.d_k)

        # In production: learned projection matrices.
        # H separate Q projections, but only 1 K and 1 V projection.
        # self.W_q = nn.Linear(d_model, d_model)   # H heads
        # self.W_k = nn.Linear(d_model, self.d_k)  # 1 shared
        # self.W_v = nn.Linear(d_model, self.d_k)  # 1 shared
        # self.W_o = nn.Linear(d_model, d_model)

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Args:
            Q: (..., N, d_model) full query matrix
            K: (..., N, d_model) full key matrix (only the first d_k columns are used)
            V: (..., N, d_model) full value matrix (only the first d_k columns are used)
            mask: optional boolean mask broadcastable to (..., N, N); True = blocked

        Returns:
            output: (..., N, d_model) concatenated heads
            all_weights: list of (..., N, N) weight tensors, one per head
        """
        # Extract the shared K, V (first d_k columns only). This slice stands in
        # for the single learned W_k / W_v projection of a real model.
        # Indexing with `...` keeps the slice on the feature axis for batched inputs.
        K_shared = K[..., : self.d_k]  # (..., N, d_k)
        V_shared = V[..., : self.d_k]  # (..., N, d_k)

        head_outputs = []
        all_weights = []

        for h in range(self.n_heads):
            # Head-specific query
            Qh = Q[..., h * self.d_k : (h + 1) * self.d_k]  # (..., N, d_k)

            # Scaled dot-product with the shared K
            scores = torch.matmul(Qh, K_shared.transpose(-2, -1))
            scores = scores / self.scale

            if mask is not None:
                scores = scores.masked_fill(mask, float("-inf"))

            weights = F.softmax(scores, dim=-1)
            output = torch.matmul(weights, V_shared)  # (..., N, d_k)

            head_outputs.append(output)
            all_weights.append(weights)

        return torch.cat(head_outputs, dim=-1), all_weights


# ── Shared Example ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])

K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])

V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

# ── Run ──
mqa = MultiQueryAttention(d_model=4, n_heads=2)
output, all_weights = mqa(Q, K, V)

print("MQA Output (5x4):")
print(output.round(decimals=4))

print("\nHead 1 weights match Head 2 for 'The'?",
      torch.allclose(all_weights[0][0], all_weights[1][0]))

# ── GPU: move the module and the input tensors ──
if torch.cuda.is_available():
    mqa_gpu = mqa.cuda()
    Q_gpu, K_gpu, V_gpu = Q.cuda(), K.cuda(), V.cuda()
    out_gpu, _ = mqa_gpu(Q_gpu, K_gpu, V_gpu)
    print("GPU matches CPU?",
          torch.allclose(output, out_gpu.cpu(), atol=1e-4))
```
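Every head reads the same K and V, so the Python loop over heads is not required. The queries can be reshaped into a head axis and multiplied against K in one broadcasted matmul. The sketch below goes one step further and groups the heads: n_kv_heads = 1 gives MQA, and n_kv_heads = n_heads gives MHA-style attention. This is the GQA generalization discussed in this chapter. It keeps the chapter's toy convention of taking K and V directly from column slices instead of learned projections. The function name grouped_query_attention and its signature are hypothetical.

```python
import math

import torch
import torch.nn.functional as F


def grouped_query_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                            n_heads: int, n_kv_heads: int):
    """Hypothetical vectorized GQA sketch using this chapter's toy slicing.

    Q: (N, d_model), split into n_heads query heads of width d_k.
    K, V: (N, d_model); only the first n_kv_heads * d_k columns are used,
    standing in for learned K/V projections.
    n_kv_heads = 1 gives MQA; n_kv_heads = n_heads gives MHA-style attention.
    """
    N, d_model = Q.shape
    d_k = d_model // n_heads
    assert n_heads % n_kv_heads == 0, "query heads must split evenly into groups"
    group = n_heads // n_kv_heads  # query heads served by each K,V head

    # Queries: (N, H*d_k) -> (H, N, d_k) -> (G, group, N, d_k); head h uses K,V head h // group.
    Qg = Q.reshape(N, n_heads, d_k).transpose(0, 1).reshape(n_kv_heads, group, N, d_k)
    # Keys/values: (N, G*d_k) -> (G, 1, N, d_k); the size-1 axis broadcasts over each group.
    Kg = K[:, : n_kv_heads * d_k].reshape(N, n_kv_heads, d_k).transpose(0, 1).unsqueeze(1)
    Vg = V[:, : n_kv_heads * d_k].reshape(N, n_kv_heads, d_k).transpose(0, 1).unsqueeze(1)

    scores = Qg @ Kg.transpose(-2, -1) / math.sqrt(d_k)  # (G, group, N, N)
    weights = F.softmax(scores, dim=-1)
    out = weights @ Vg                                    # (G, group, N, d_k)

    # Back to (N, H*d_k), heads in order, like torch.cat(head_outputs, dim=-1).
    out = out.reshape(n_heads, N, d_k).transpose(0, 1).reshape(N, d_model)
    return out, weights.reshape(n_heads, N, N)


# The chapter's example: "The cat sat on mat", d_model = 4, H = 2.
Q = torch.tensor([[1., 0., 1., 0.], [0., 2., 0., 1.], [1., 1., 1., 0.],
                  [0., 0., 1., 1.], [1., 0., 0., 1.]])
K = torch.tensor([[0., 1., 0., 1.], [1., 0., 1., 0.], [1., 1., 0., 0.],
                  [0., 0., 1., 1.], [1., 0., 0.5, 0.5]])
V = torch.tensor([[1., 0., 0., 0.], [0., 1., 0., 0.], [0., 0., 1., 0.],
                  [0., 0., 0., 1.], [0.5, 0.5, 0.5, 0.5]])

out_mqa, w_mqa = grouped_query_attention(Q, K, V, n_heads=2, n_kv_heads=1)
print(out_mqa.round(decimals=4))  # should match the MultiQueryAttention output above
```

Broadcasting over the size-1 group axis expresses the sharing without writing an explicit per-head loop or repeat. This is also why MQA and GQA fit naturally into batched GPU kernels.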
Key Takeaways
MQA shares K and V across all heads while keeping per-head Q projections. This reduces the KV-cache by a factor of H.
The bottleneck is memory bandwidth, not compute. During autoregressive decoding, each new token requires reading the entire KV-cache from GPU memory, and for long contexts or large batches this traffic dominates per-token latency. Shrinking the cache directly reduces per-token latency and increases serving throughput.
Quality degrades gracefully. Shazeer showed <0.5 BLEU degradation on translation tasks. The diversity in attention comes primarily from the Query projections, which remain separate.
MQA is the extreme case of GQA with G=1. Most production models now use GQA with G>1 for a better quality-efficiency tradeoff, but MQA remains the theoretical foundation.
MQA composes with other optimizations. FlashAttention (how attention is computed), KV-cache quantization (how many bits each cached value uses), and MQA/GQA (how many K,V heads are cached) target different costs and can be combined, so their savings compound.
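The factor-of-H saving in the first takeaway follows directly from what the cache stores. For every layer and every cached token, the model keeps one key vector and one value vector of width d_k per K,V head. The cache size is therefore proportional to the number of K,V heads: H for MHA, G for GQA, and 1 for MQA. Below is a minimal calculator. It assumes float16 (2 bytes per value) and a 7B-class configuration (32 layers, 32 query heads, d_k = 128) with a 4,096-token context, chosen purely for illustration.

```python
def kv_cache_bytes(n_layers: int, n_kv_heads: int, d_k: int,
                   seq_len: int, batch: int = 1, bytes_per_value: int = 2) -> int:
    """Size of the K and V caches; the leading 2 counts K and V separately."""
    return 2 * n_layers * n_kv_heads * d_k * seq_len * batch * bytes_per_value


# Illustrative configuration (assumption): 32 layers, 32 heads, d_k = 128, float16.
cfg = dict(n_layers=32, d_k=128, seq_len=4096, batch=1, bytes_per_value=2)
H = 32

for name, n_kv in [("MHA", H), ("GQA, G=8", 8), ("MQA", 1)]:
    size = kv_cache_bytes(n_kv_heads=n_kv, **cfg)
    print(f"{name:>9}: {size / 2**20:8.1f} MiB")  # 2048.0, 512.0, 64.0 MiB
```

Batch size and context length multiply the result linearly. This is why the MHA cache, rather than the weights, becomes the limiting factor for long-context serving.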
Exercises
Exercise 1: Compute for "sat"
Using the shared matrices $K_{\text{shared}}$ and $V_{\text{shared}}$ from this chapter, compute the MQA attention weights and output for “sat” (row 2) for both heads. Verify that Head 1 gives $[0.1811, 0.1811, 0.3673, 0.0893, 0.1811]$ and Head 2 gives $[0.1237, 0.2509, 0.2509, 0.1237, 0.2509]$.
Exercise 2: Why "on" is uniform in Head 1
Explain mathematically why the attention weights for “on” in Head 1 are exactly uniform (all 0.2000). What property of $Q_{h1}[\text{on}]$ causes this? What would happen if $Q_{h1}[\text{on}] = [1, 0]$ instead?
Exercise 3: Memory at scale
For a model with $H = 96$ heads, $d_k = 128$, sequence length $N = 32{,}768$, and 96 layers using float16: compute the total KV-cache for MHA vs MQA. At what batch size does the MHA cache exceed 80 GB (the memory of a single 80 GB A100)?
Exercise 4: Converting MHA to MQA
Suppose you have a trained MHA model and want to convert it to MQA without retraining from scratch (Ainslie et al., 2023, study this setting). Propose a strategy: which head's K,V weights should you keep as the shared set? What about averaging all heads' K,V weights? Would fine-tuning be needed?
Exercise 5: When MQA = MHA
Prove that when $H = 1$ (single head), MQA is identical to standard attention. Then show that MQA's output approaches MHA's output as the learned K,V projections converge across heads (i.e., when all heads' key and value projections are similar).
References
Shazeer, N. (2019). “Fast Transformer Decoding: One Write-Head is All You Need.” arXiv:1911.02150.
Chowdhery, A., et al. (2022). “PaLM: Scaling Language Modeling with Pathways.” arXiv:2204.02311.
Ainslie, J., et al. (2023). “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” arXiv:2305.13245.
Vaswani, A., et al. (2017). “Attention Is All You Need.” Advances in Neural Information Processing Systems (NeurIPS), 30.
Dao, T., et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv:2205.14135.
DeepSeek-AI (2024). “DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model.” arXiv:2405.04434.