Chapter 5

Multi-Query Attention (MQA)


Noam Shazeer, "Fast Transformer Decoding: One Write-Head is All You Need", 2019


Learning Objectives

By the end of this chapter, you will be able to:

  1. Explain why the KV-cache becomes the dominant memory bottleneck during autoregressive inference and why standard multi-head attention (MHA) makes it worse.
  2. Describe the core insight of Multi-Query Attention: all heads can share a single set of Keys and Values while retaining per-head Query projections.
  3. Derive the KV-cache savings formula and calculate the exact memory reduction for real-world models like GPT-3 and LLaMA.
  4. Compute MQA attention weights and outputs by hand for “The cat sat on mat” and compare them to the MHA results from Chapter 2.
  5. Implement a complete MQA class in both NumPy and PyTorch that you can run on any input.
  6. Connect MQA to its generalization, Grouped-Query Attention (GQA), and understand why most modern LLMs use GQA rather than pure MQA.
Where this appears: MQA was introduced at Google and used in PaLM (540B parameters). Its direct descendant, GQA, is used in LLaMA 2 (34B and 70B), LLaMA 3, Mistral, Gemma 2, and many other production LLMs. Understanding MQA is essential for understanding how modern language models achieve fast inference at scale: it explains how models with tens of billions of parameters keep the memory cost of token-by-token generation low enough to serve many users at interactive speeds.

The Real Problem

Chapters 1 through 4 established the core attention mechanisms: scaled dot-product, multi-head, causal masking, and cross-attention. These are mathematically elegant and produce excellent results during training. But there is a brutal engineering reality that these mechanisms ignore: inference is fundamentally different from training, and multi-head attention has a hidden cost that only reveals itself at inference time.

The KV-Cache Explosion

During autoregressive generation (producing text token by token), the model generates one new token at each step. To produce token $t$, the model must compute attention over all previous tokens $1, 2, \ldots, t{-}1$. Without caching, this means recomputing the Key and Value projections for every previous token at every step, an $O(t^2)$ total cost for generating a sequence of length $t$.

The standard solution is the KV-cache: store the Key and Value tensors for all previously generated tokens, so they only need to be computed once. The new token computes its Query, looks up the cached Keys and Values, and produces the next output. This reduces per-step computation to $O(t)$, but at the cost of memory.
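The caching idea can be sketched in a few lines of NumPy. This is a minimal single-head illustration under stated assumptions: the function names `decode_with_cache` and `full_causal_attention` are hypothetical, and random rows of `K` and `V` stand in for the already-projected Keys and Values. It checks that attending from each new query to the growing cache gives the same result as recomputing causal attention from scratch.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def full_causal_attention(Q, K, V):
    # Reference: attention for every position at once, with a causal mask
    n, d_k = Q.shape
    scores = Q @ K.T / np.sqrt(d_k)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)   # True above the diagonal
    scores = np.where(future, -np.inf, scores)
    return softmax(scores) @ V

def decode_with_cache(Q, K, V):
    # Incremental decoding: each token's key/value row is appended to the
    # cache once, then the new query attends over everything cached so far.
    d_k = Q.shape[1]
    k_cache, v_cache, outputs = [], [], []
    for t in range(Q.shape[0]):
        k_cache.append(K[t])
        v_cache.append(V[t])
        K_t, V_t = np.stack(k_cache), np.stack(v_cache)   # (t+1, d_k) each
        w = softmax(Q[t] @ K_t.T / np.sqrt(d_k))          # (t+1,) weights
        outputs.append(w @ V_t)
    return np.stack(outputs)

rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((6, 4)) for _ in range(3))
print(np.allclose(decode_with_cache(Q, K, V), full_causal_attention(Q, K, V)))  # True
```

Each step appends one row per cached head, so the cache grows linearly with the number of generated tokens; MQA's contribution is to append one shared K/V row per token instead of one per head.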

In standard multi-head attention (MHA, Chapter 2), every head maintains its own independent Key and Value matrices. Per layer, the KV-cache for $H$ heads, sequence length $N$, and per-head dimension $d_k$ requires:

$$\text{MHA KV-cache} = 2 \times H \times N \times d_k \text{ floats}$$

The factor of 2 accounts for both K and V. Consider a model with LLaMA 2 70B's dimensions ($H = 64$ heads, $d_k = 128$, 80 layers) if it used standard MHA, at sequence length $N = 4096$ in float16 (2 bytes per float):

$$2 \times 64 \times 4096 \times 128 \times 2 \text{ bytes} = 128 \text{ MB per layer}$$

With 80 transformer layers, the total KV-cache is ~10 GB for a single request (here 1 MB = $2^{20}$ bytes and 1 GB = $2^{30}$ bytes). Serving 32 concurrent users at sequence length 8192 would require about 640 GB of KV-cache alone, exceeding the memory of even an 8×A100 (80 GB) node. The model weights themselves (~140 GB in float16) become the smaller memory consumer.
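The arithmetic above is easy to reproduce. The helper below is a hypothetical sketch (not from any library) that multiplies out the cache formula for the hypothetical MHA variant of a 70B-scale model described above, using binary units.

```python
def kv_cache_bytes(n_kv_heads, seq_len, d_k, n_layers=1, batch=1, bytes_per_float=2):
    # 2 (K and V) x KV heads x tokens x per-head dim x layers x batch x bytes
    return 2 * n_kv_heads * seq_len * d_k * n_layers * batch * bytes_per_float

MiB, GiB = 2**20, 2**30

# Hypothetical MHA with LLaMA-2-70B-like dimensions: 64 heads, d_k = 128, 80 layers
print(kv_cache_bytes(64, 4096, 128) / MiB)                            # 128.0 per layer
print(kv_cache_bytes(64, 4096, 128, n_layers=80) / GiB)               # 10.0 total
print(kv_cache_bytes(64, 8192, 128, n_layers=80, batch=32) / GiB)     # 640.0
# Same dimensions with MQA: a single shared K/V head
print(kv_cache_bytes(1, 4096, 128, n_layers=80) / MiB)                # 160.0
```

Because every argument enters the product linearly, switching from `n_kv_heads=H` to `n_kv_heads=1` divides the result by exactly $H$, whatever the other settings.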

The Inference Bottleneck

The KV-cache creates two problems simultaneously:

  1. Memory capacity. The cache must fit in GPU memory alongside model weights and activations. Larger caches mean fewer concurrent users, directly reducing serving throughput.
  2. Memory bandwidth. At each decoding step, the model must read the entire KV-cache from GPU memory to compute attention. Modern GPUs (like the A100) have ~2 TB/s of memory bandwidth. Reading the 10 GB cache above (all 80 layers together) takes ~5 ms, and this happens at every decoding step. The attention computation itself (the dot products and softmax) is cheap by comparison: decoding is memory-bandwidth bound, not compute bound.
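A rough lower bound on the time spent streaming the cache can be computed directly. This sketch assumes a peak bandwidth of 2 TB/s ($2 \times 10^{12}$ bytes/s) and counts only KV-cache reads, ignoring model weights, activations, and the gap between peak and achievable bandwidth, so real step times will be higher. The helper `cache_read_ms` is hypothetical.

```python
def cache_read_ms(cache_bytes, bandwidth_bytes_per_s=2e12):
    # Time (ms) just to stream the KV-cache once from GPU memory
    return cache_bytes / bandwidth_bytes_per_s * 1e3

mha_cache = 2 * 64 * 4096 * 128 * 80 * 2   # bytes: 64 K/V heads, 80 layers, float16
mqa_cache = mha_cache // 64                # one shared K/V head instead of 64

print(round(cache_read_ms(mha_cache), 2))  # 5.37
print(round(cache_read_ms(mqa_cache), 3))  # 0.084
```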

Shazeer's insight was that the per-head K and V projections are the root cause. If all heads could share a single set of Keys and Values, the KV-cache would shrink by a factor of $H$: from 10 GB to ~160 MB for the 70B-scale configuration above. The question was: would the quality survive?

Training vs Inference

During training, all tokens are processed in parallel and no KV-cache exists. MQA and MHA have nearly identical training costs. The entire motivation for MQA is inference speed — reducing the memory footprint of the KV-cache so that more users can be served simultaneously and each token is generated faster.

From Intuition to Mathematics

The Shared Dictionary Insight

Recall the library metaphor from Chapter 1. In multi-head attention, imagine $H$ researchers in a library, each with their own question (query). In MHA, each researcher also has their own private catalogue of book labels (keys) and their own private set of books (values). This means the library must maintain $H$ separate catalogues and $H$ separate book collections: an enormous amount of storage.

Shazeer's insight: the researchers need to ask different questions, but they can all look up answers in the same catalogue and read from the same books. Each researcher still has their own unique perspective (their own Query projection), but the dictionary they consult (the Keys) and the content they read (the Values) are shared. The information is the same; only the way each head searches for it differs.

This is not as restrictive as it sounds. Shazeer's results suggest that much of the useful diversity in multi-head attention comes from the Query projections rather than the Key/Value projections: different heads learn to “ask different questions” about the same content. His experiments on translation benchmarks showed that sharing K and V caused surprisingly small quality drops (often less than 0.5 BLEU) while cutting inference memory and latency dramatically.

What Changes, What Stays

| Component | MHA (Chapter 2) | MQA |
|---|---|---|
| $Q$ projections | $H$ separate $W_i^Q$ matrices | $H$ separate $W_i^Q$ matrices (unchanged) |
| $K$ projections | $H$ separate $W_i^K$ matrices | 1 shared $W^K$ matrix |
| $V$ projections | $H$ separate $W_i^V$ matrices | 1 shared $W^V$ matrix |
| KV-cache per layer | $2 \times H \times N \times d_k$ | $2 \times 1 \times N \times d_k$ ($H\times$ smaller) |

The Mathematical Definition

For each head $i \in \{1, \ldots, H\}$:

$$\text{head}_i = \text{softmax}\!\left(\frac{Q \cdot W_i^Q \cdot (K \cdot W_{\text{shared}}^K)^\top}{\sqrt{d_k}}\right) \cdot V \cdot W_{\text{shared}}^V$$

The final output is the concatenation of all heads followed by a linear projection:

$$\text{MQA}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H) \cdot W^O$$
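The two equations translate almost line for line into NumPy. In this minimal sketch, which is not the chapter's reference class, the inputs $Q$, $K$, $V$ of the formula are all the same sequence `X` (self-attention), the weights are random placeholders, and stacking the per-head query projections into one `(H, d_model, d_k)` array is an assumption of the sketch. Broadcasting lets all $H$ query heads attend against the single shared $K$ and $V$ in one matrix multiply.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def mqa_forward(X, W_q, W_k, W_v, W_o):
    """Multi-query self-attention for one sequence.

    X: (N, d_model); W_q: (H, d_model, d_k) per-head query projections;
    W_k, W_v: (d_model, d_k) shared by all heads; W_o: (H*d_k, d_model).
    """
    H, _, d_k = W_q.shape
    Q = np.einsum("nd,hdk->hnk", X, W_q)    # (H, N, d_k): one query set per head
    K = X @ W_k                              # (N, d_k): single shared key matrix
    V = X @ W_v                              # (N, d_k): single shared value matrix
    scores = Q @ K.T / np.sqrt(d_k)          # (H, N, N): K.T broadcasts over heads
    heads = softmax(scores) @ V              # (H, N, d_k)
    concat = heads.transpose(1, 0, 2).reshape(X.shape[0], H * d_k)  # Concat(head_1..head_H)
    return concat @ W_o                      # (N, d_model)

rng = np.random.default_rng(0)
N, d_model, H = 5, 8, 4
d_k = d_model // H
X = rng.standard_normal((N, d_model))
W_q = rng.standard_normal((H, d_model, d_k))
W_k = rng.standard_normal((d_model, d_k))
W_v = rng.standard_normal((d_model, d_k))
W_o = rng.standard_normal((H * d_k, d_model))
print(mqa_forward(X, W_q, W_k, W_v, W_o).shape)  # (5, 8)
```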

Symbol-by-Symbol Breakdown

| Symbol | Shape | Meaning |
|---|---|---|
| $W_i^Q$ | $d_{\text{model}} \times d_k$ | Per-head query projection. Each head has its own $W_i^Q$. This is where heads learn to “ask different questions.” |
| $W_{\text{shared}}^K$ | $d_{\text{model}} \times d_k$ | Single shared key projection. All heads use the same $W_{\text{shared}}^K$. This is the MQA innovation. |
| $W_{\text{shared}}^V$ | $d_{\text{model}} \times d_k$ | Single shared value projection. All heads read from the same value space. |
| $K_{\text{shared}}$ | $N \times d_k$ | The single key matrix: $K \cdot W_{\text{shared}}^K$. In our example: $5 \times 2$. |
| $V_{\text{shared}}$ | $N \times d_k$ | The single value matrix: $V \cdot W_{\text{shared}}^V$. In our example: $5 \times 2$. |
| $H$ | scalar | Number of query heads. In our example: 2. In LLaMA 2 70B: 64. |
| $W^O$ | $d_{\text{model}} \times d_{\text{model}}$ | Output projection (same as MHA). Maps concatenated head outputs back to the model dimension. |

KV-Cache Size Formula

| Mechanism | Formula (per layer) | Our example ($H=2$, $N=5$, $d_k=2$) | LLaMA 2 70B dimensions ($H=64$, $N=4096$, $d_k=128$) |
|---|---|---|---|
| MHA | $2 \times H \times N \times d_k$ | 40 floats | 128 MB (float16) |
| MQA | $2 \times 1 \times N \times d_k$ | 20 floats ($2\times$ smaller) | 2 MB ($64\times$ smaller) |

MHA vs MQA Architecture

Structurally, the difference is small. In MHA, each head has its own K and V projection matrices: separate dictionaries for each researcher. In MQA, all heads share a single K and V: one dictionary consulted by everyone. The savings therefore scale with the number of heads.

KV-Cache Memory Scaling

KV-cache memory grows linearly with the number of layers, the sequence length, and the batch size, but the MQA reduction factor is always equal to $H$ (the number of heads), regardless of those other parameters.

Step-by-Step Calculation

Setting Up the Shared K and V

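In our example with $H = 2$ heads and $d_k = 2$ per head, MQA uses only the first $d_k = 2$ dimensions of $K$ and $V$ as the shared Key and Value matrices. This slicing is a simplification that stands in for learned shared projections $W_{\text{shared}}^K$ and $W_{\text{shared}}^V$. Both heads use these same matrices.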

$K_{\text{shared}}$ (5×2), used by ALL heads

| Token | d0 | d1 |
|---|---|---|
| The | 0.0 | 1.0 |
| cat | 1.0 | 0.0 |
| sat | 1.0 | 1.0 |
| on | 0.0 | 0.0 |
| mat | 1.0 | 0.0 |

$V_{\text{shared}}$ (5×2), used by ALL heads

| Token | d0 | d1 |
|---|---|---|
| The | 1.0 | 0.0 |
| cat | 0.0 | 1.0 |
| sat | 0.0 | 0.0 |
| on | 0.0 | 0.0 |
| mat | 0.5 | 0.5 |

The queries remain per-head: $Q_{h1} = Q[:, 0{:}2]$ and $Q_{h2} = Q[:, 2{:}4]$.

Step-by-Step for "The" (row 0)

Head 1: $Q_{h1}[\text{The}] = [1.0,\; 0.0]$

Dot products with each shared key, divided by $\sqrt{2} \approx 1.414$:

| Token | Dot Product | Scaled Score |
|---|---|---|
| The | $1.0 \times 0.0 + 0.0 \times 1.0 = 0.0$ | 0.0000 |
| cat | $1.0 \times 1.0 + 0.0 \times 0.0 = 1.0$ | 0.7071 |
| sat | $1.0 \times 1.0 + 0.0 \times 1.0 = 1.0$ | 0.7071 |
| on | $1.0 \times 0.0 + 0.0 \times 0.0 = 0.0$ | 0.0000 |
| mat | $1.0 \times 1.0 + 0.0 \times 0.0 = 1.0$ | 0.7071 |

Softmax: $[0.1237,\; 0.2509,\; 0.2509,\; 0.1237,\; 0.2509]$

Output (a weighted sum of the rows of $V_{\text{shared}}$; “sat” and “on” have zero value vectors and drop out): $O_{h1}[\text{The}] = 0.1237\,[1, 0] + 0.2509\,[0, 1] + 0.2509\,[0.5, 0.5] \approx [0.2491,\; 0.3763]$
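The row for “The” can be checked numerically. This short sketch hard-codes $K_{\text{shared}}$ and $V_{\text{shared}}$ from the tables above and reproduces Head 1's scaled scores, softmax weights, and output.

```python
import numpy as np

# Shared keys/values from the tables above (rows: The, cat, sat, on, mat)
K_shared = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])
V_shared = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.5, 0.5]])
q_the = np.array([1.0, 0.0])                  # Q_h1["The"]

scores = K_shared @ q_the / np.sqrt(2)        # one scaled score per token
weights = np.exp(scores) / np.exp(scores).sum()
output = weights @ V_shared                   # blend of the shared value rows

print(weights.round(4))   # ≈ [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
print(output.round(4))    # ≈ [0.2491, 0.3763]
```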

Head 2: $Q_{h2}[\text{The}] = [1.0,\; 0.0]$

Interestingly, $Q_{h1}[\text{The}]$ and $Q_{h2}[\text{The}]$ happen to be identical ($[1, 0]$), so both heads produce the same attention pattern for “The.” This is a coincidence of our fixed example; in practice, learned projections make them different.

Head 2 softmax: $[0.1237,\; 0.2509,\; 0.2509,\; 0.1237,\; 0.2509]$ (identical to Head 1)

Concatenated output for “The”: $O[\text{The}] = [0.2491,\; 0.3763,\; 0.2491,\; 0.3763]$

Step-by-Step for "cat" (row 1)

Head 1: $Q_{h1}[\text{cat}] = [0.0,\; 2.0]$

| Token | Dot Product | Scaled Score |
|---|---|---|
| The | $0.0 \times 0.0 + 2.0 \times 1.0 = 2.0$ | 1.4142 |
| cat | $0.0 \times 1.0 + 2.0 \times 0.0 = 0.0$ | 0.0000 |
| sat | $0.0 \times 1.0 + 2.0 \times 1.0 = 2.0$ | 1.4142 |
| on | $0.0 \times 0.0 + 2.0 \times 0.0 = 0.0$ | 0.0000 |
| mat | $0.0 \times 1.0 + 2.0 \times 0.0 = 0.0$ | 0.0000 |

Softmax: $[0.3664,\; 0.0891,\; 0.3664,\; 0.0891,\; 0.0891]$

Head 2: $Q_{h2}[\text{cat}] = [0.0,\; 1.0]$

Here the heads do diverge. Head 2's query $[0, 1]$ points in the same direction as Head 1's $[0, 2]$ but has half the length, so every score is halved and the softmax is flatter:

Softmax: $[0.2874,\; 0.1417,\; 0.2874,\; 0.1417,\; 0.1417]$

Head 1 concentrates strongly on “The” and “sat” (36.6% each), while Head 2 gives them only 28.7% each. This is MQA working as designed: different queries, same dictionary, different lookup patterns.

Concatenated output for “cat”: $O[\text{cat}] = [0.4109,\; 0.1336,\; 0.3583,\; 0.2126]$
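The magnitude effect for “cat” is easy to see in code. In this sketch the helper `head_row` is hypothetical; both heads attend over the same hard-coded shared keys and values, and only the query differs.

```python
import numpy as np

K_shared = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])
V_shared = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.5, 0.5]])

def head_row(q):
    # One query row attends over the SAME shared keys and values
    s = K_shared @ q / np.sqrt(2)
    w = np.exp(s - s.max())
    w /= w.sum()
    return w, w @ V_shared

w1, o1 = head_row(np.array([0.0, 2.0]))   # Q_h1["cat"]
w2, o2 = head_row(np.array([0.0, 1.0]))   # Q_h2["cat"]: same direction, half the length
print(w1.max().round(4), w2.max().round(4))   # ≈ 0.3664 vs 0.2874: smaller query, flatter softmax
print(np.concatenate([o1, o2]).round(4))      # ≈ [0.4109, 0.1336, 0.3583, 0.2126]
```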


Full Attention Weights and Output

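Head 1 Attention Weights (5×5)

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |
| cat | 0.3664 | 0.0891 | 0.3664 | 0.0891 | 0.0891 |
| sat | 0.1811 | 0.1811 | 0.3673 | 0.0893 | 0.1811 |
| on | 0.2000 | 0.2000 | 0.2000 | 0.2000 | 0.2000 |
| mat | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |

Head 2 Attention Weights (5×5)

| | The | cat | sat | on | mat |
|---|---|---|---|---|---|
| The | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |
| cat | 0.2874 | 0.1417 | 0.2874 | 0.1417 | 0.1417 |
| sat | 0.1237 | 0.2509 | 0.2509 | 0.1237 | 0.2509 |
| on | 0.1811 | 0.1811 | 0.3673 | 0.0893 | 0.1811 |
| mat | 0.2874 | 0.1417 | 0.2874 | 0.1417 | 0.1417 |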

Interesting pattern for "on"

In Head 1, “on” has $Q_{h1} = [0, 0]$, a zero query. Since every dot product with $K_{\text{shared}}$ is zero, softmax produces a uniform distribution: all five tokens get exactly 0.2000. In contrast, Head 2 has $Q_{h2}[\text{on}] = [1, 1]$, which favors “sat” (raw score 2.0, weight 0.3673) because its key $[1, 1]$ matches the query exactly.
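Both behaviors for “on” can be reproduced with the same shared keys. The helper `weights_for` in this sketch is hypothetical; it applies scaled dot products and a softmax to one query row.

```python
import numpy as np

K_shared = np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]])

def weights_for(q):
    # Scaled dot products against the shared keys, then a stable softmax
    s = K_shared @ q / np.sqrt(2)
    w = np.exp(s - s.max())
    return w / w.sum()

head1_on = weights_for(np.array([0.0, 0.0]))   # zero query: every score is 0
head2_on = weights_for(np.array([1.0, 1.0]))   # matches the key of "sat" exactly
print(head1_on)             # uniform: 0.2 for each of the 5 tokens
print(head2_on.round(4))    # ≈ [0.1811, 0.1811, 0.3673, 0.0893, 0.1811]
```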

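MQA Output Matrix (5×4, heads concatenated)

| Token | d0 | d1 | d2 | d3 |
|---|---|---|---|---|
| The | 0.2491 | 0.3763 | 0.2491 | 0.3763 |
| cat | 0.4109 | 0.1336 | 0.3583 | 0.2126 |
| sat | 0.2717 | 0.2717 | 0.2491 | 0.3763 |
| on | 0.3000 | 0.3000 | 0.2717 | 0.2717 |
| mat | 0.2491 | 0.3763 | 0.3583 | 0.2126 |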
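The whole output matrix can also be computed without a Python loop over heads. This sketch reuses the chapter's toy $Q$, $K$, $V$ and its slicing stand-in for the shared projections; reshaping $Q$ to `(H, N, d_k)` and broadcasting against the single $K_{\text{shared}}$ is a design choice of the sketch, not the chapter's class.

```python
import numpy as np

# The chapter's toy inputs (rows: The, cat, sat, on, mat)
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]])
N, H, d_k = 5, 2, 2

Qh = Q.reshape(N, H, d_k).transpose(1, 0, 2)     # (H, N, d_k): per-head query slices
K_shared, V_shared = K[:, :d_k], V[:, :d_k]      # chapter's stand-in for the shared K, V
scores = Qh @ K_shared.T / np.sqrt(d_k)          # (H, N, N): one K_shared for all heads
w = np.exp(scores - scores.max(axis=-1, keepdims=True))
w /= w.sum(axis=-1, keepdims=True)               # softmax over keys
out = (w @ V_shared).transpose(1, 0, 2).reshape(N, H * d_k)   # concat heads -> (N, 4)
print(out.round(4))
```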

MQA vs MHA: Side-by-Side Comparison

Compare the MQA output with the MHA output from Chapter 2 (both using $H = 2$ heads):

| Token | MHA Output | MQA Output | Difference (MQA − MHA) |
|---|---|---|---|
| The | [0.2491, 0.3763, 0.2289, 0.3663] | [0.2491, 0.3763, 0.2491, 0.3763] | [0.000, 0.000, +0.020, +0.010] |
| cat | [0.4109, 0.1336, 0.2289, 0.3663] | [0.4109, 0.1336, 0.3583, 0.2126] | [0.000, 0.000, +0.129, −0.154] |
| sat | [0.2717, 0.2717, 0.2289, 0.3663] | [0.2717, 0.2717, 0.2491, 0.3763] | [0.000, 0.000, +0.020, +0.010] |
| on | [0.3000, 0.3000, 0.1799, 0.4579] | [0.3000, 0.3000, 0.2717, 0.2717] | [0.000, 0.000, +0.092, −0.186] |
| mat | [0.2491, 0.3763, 0.2289, 0.3663] | [0.2491, 0.3763, 0.3583, 0.2126] | [0.000, 0.000, +0.129, −0.154] |

Interpreting the Differences

Head 1 outputs (columns d0, d1) are identical between MHA and MQA because Head 1 uses $K[:, 0{:}2]$ and $V[:, 0{:}2]$ in both cases; these are the same matrices.

Head 2 outputs (columns d2, d3) differ because MHA's Head 2 uses $K[:, 2{:}4]$ and $V[:, 2{:}4]$ (its own keys and values), while MQA's Head 2 uses $K[:, 0{:}2]$ and $V[:, 0{:}2]$ (the shared ones). Different keys produce different attention patterns, and different values change what gets blended, so the outputs differ. The maximum absolute difference is 0.186 (for “on,” d3): noticeable but not dramatic.
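The comparison can be reproduced by running Head 2 twice: once with its own MHA keys and values (columns 2 to 4), once with the shared MQA ones (columns 0 to 2). The helper `attend` in this sketch is hypothetical.

```python
import numpy as np

Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]])

def attend(q, k, v):
    # Scaled dot-product attention for one head (rows = tokens)
    s = q @ k.T / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v

mha_h2 = attend(Q[:, 2:4], K[:, 2:4], V[:, 2:4])   # MHA: Head 2 has its own K, V
mqa_h2 = attend(Q[:, 2:4], K[:, 0:2], V[:, 0:2])   # MQA: Head 2 reuses the shared K, V
diff = mqa_h2 - mha_h2                              # Head 1 is identical, so only d2, d3 differ

row, col = np.unravel_index(np.abs(diff).argmax(), diff.shape)
print(round(float(np.abs(diff).max()), 3), row, col)  # 0.186 3 1  ("on", d3)
```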

Quality vs efficiency tradeoff

In our tiny 2-head example, the differences are visible because Head 2 loses its unique key space entirely. In production models with 64 or more heads, the shared K and V are trained jointly with all the query projections, so they can converge to a general-purpose representation that serves every head reasonably well; this is one plausible reason the measured quality impact is small. Shazeer's experiments showed a drop of less than 0.5 BLEU on WMT translation with MQA.

Applications Across Domains

Large Language Model Inference

MQA's primary application is accelerating autoregressive text generation. Google's PaLM (540B parameters, Chowdhery et al., 2022) used MQA in all 118 layers. With 48 heads and $d_k = 256$, the KV-cache reduction was $48\times$, enabling significantly higher serving throughput. For a context length of 2048 tokens, PaLM's MQA KV-cache per layer is ~2 MB (float16) compared to ~96 MB for equivalent MHA, a difference that compounds across 118 layers to save over 10 GB per request.
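The PaLM figures follow from the same per-layer formula. This sketch (the helper name is hypothetical) plugs in 2048 tokens, $d_k = 256$, float16, and either 1 or 48 K/V heads, using binary units.

```python
MiB, GiB = 2**20, 2**30

def kv_cache_bytes(n_kv_heads, seq_len=2048, d_k=256, n_layers=1, bytes_per_float=2):
    # 2 (K and V) x KV heads x tokens x per-head dim x layers x bytes
    return 2 * n_kv_heads * seq_len * d_k * n_layers * bytes_per_float

mqa_layer = kv_cache_bytes(1)     # PaLM-style MQA: one shared K/V head
mha_layer = kv_cache_bytes(48)    # hypothetical MHA with 48 K/V heads
print(mqa_layer / MiB, mha_layer / MiB)                 # 2.0 96.0
print(round((mha_layer - mqa_layer) * 118 / GiB, 2))    # 10.83 (GB saved over 118 layers)
```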

Code Generation

Code generation models such as StarCoder, which uses MQA with an 8,192-token context window, face especially long contexts (full source files with thousands of tokens). MQA allows these models to process such contexts while maintaining interactive response times. The reduction factor is still $H$, but the absolute savings grow with sequence length, and code sequences tend to be much longer than natural-language prompts.

Vision Transformers

Vision transformers (ViT) typically use MHA, but MQA-style K/V sharing can also reduce memory in vision models. When a vision transformer processes a 1024×1024 image as 16×16 patches (4096 tokens), the K and V tensors become substantial. Diffusion denoising loops (often 50+ steps) recompute attention over all patches at every step rather than reusing a KV-cache, so there the benefit comes from smaller K/V activations and less memory traffic rather than from a smaller cache.

Scientific Computing

Protein structure prediction models (like ESMFold) and molecular dynamics transformers process sequences of amino acids or atomic coordinates that can reach thousands of tokens. Sharing K and V across heads is one way to reduce attention memory for such long sequences at inference time, which matters when these models run on memory-limited hardware.


Connection to Modern Systems

Grouped-Query Attention (GQA) — Chapter 6

MQA uses 1 KV group. MHA uses $H$ KV groups. GQA (Ainslie et al., 2023) interpolates between these extremes with $G$ groups ($1 < G < H$). LLaMA 2 70B uses $G = 8$ groups for its 64 heads, so each group serves 8 heads. This achieves an $8\times$ cache reduction (vs MQA's $64\times$) while preserving near-MHA quality. GQA has become the common choice because it offers a favorable quality-efficiency tradeoff: LLaMA 2 used it for its 34B and 70B models, and later models such as Mistral 7B and LLaMA 3 use it at smaller sizes as well.

| Mechanism | KV Groups | Cache Savings | Quality |
|---|---|---|---|
| MHA | $G = H$ | 1× (baseline) | Best |
| GQA | $1 < G < H$ | $H/G\times$ (e.g., 8×) | Near-MHA |
| MQA | $G = 1$ | $H\times$ (e.g., 64×) | Slightly lower |
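One common way to implement the interpolation is to store only $G$ K/V heads and repeat each one for the query heads in its group. The sketch below assumes consecutive query heads share a group and uses random inputs; `grouped_query_attention` is a hypothetical helper. Setting $G = H$ recovers MHA and $G = 1$ recovers MQA.

```python
import numpy as np

def grouped_query_attention(Q, K, V):
    """Q: (H, N, d_k) query heads; K, V: (G, N, d_k) key/value groups, H % G == 0."""
    H, G = Q.shape[0], K.shape[0]
    assert H % G == 0, "number of query heads must be a multiple of KV groups"
    # Each KV group serves H // G consecutive query heads
    K_rep = np.repeat(K, H // G, axis=0)   # (H, N, d_k)
    V_rep = np.repeat(V, H // G, axis=0)
    s = Q @ K_rep.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    w /= w.sum(axis=-1, keepdims=True)
    return w @ V_rep                        # (H, N, d_k)

rng = np.random.default_rng(0)
H, N, d_k = 8, 6, 4
Q = rng.standard_normal((H, N, d_k))
K = rng.standard_normal((H, N, d_k))
V = rng.standard_normal((H, N, d_k))

mha = grouped_query_attention(Q, K, V)            # G = H: every head has its own K, V
gqa = grouped_query_attention(Q, K[:2], V[:2])    # G = 2: the cache holds 2 K/V heads
mqa = grouped_query_attention(Q, K[:1], V[:1])    # G = 1: a single shared K, V
```

Only the `G` stored K/V heads would live in the cache; the repetition happens on the fly at attention time.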

Flash Attention — Chapter 13

Flash Attention (Dao et al., 2022) optimizes the computation of attention by tiling the score matrix to avoid materializing the full $N \times N$ matrix in GPU HBM. MQA optimizes the storage of keys and values. These are orthogonal improvements: you can use MQA (or GQA) to shrink the KV-cache and Flash Attention to speed up the attention computation itself, and many production systems combine K/V sharing with FlashAttention-style kernels.
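The orthogonality can be illustrated with the online-softmax trick that FlashAttention builds on. This NumPy sketch shows only that idea, not FlashAttention itself (which also tiles the queries and works in on-chip SRAM); `tiled_attention_row` is a hypothetical helper. It walks over the shared K and V in tiles without ever holding a full score row, and every query head reuses the same K and V.

```python
import numpy as np

def tiled_attention_row(q, K, V, tile=2):
    """One query's attention over K, V, visited in tiles with an online softmax.

    Keeps a running max `m`, running denominator `l`, and running weighted
    sum `acc`, rescaling them by exp(m_old - m_new) at each tile.
    """
    d_k = q.shape[0]
    m, l, acc = -np.inf, 0.0, np.zeros(V.shape[1])
    for start in range(0, K.shape[0], tile):
        s = K[start:start + tile] @ q / np.sqrt(d_k)   # scores for this tile only
        m_new = max(m, s.max())
        rescale = np.exp(m - m_new)                    # 0 on the first tile
        p = np.exp(s - m_new)
        l = l * rescale + p.sum()
        acc = acc * rescale + p @ V[start:start + tile]
        m = m_new
    return acc / l

rng = np.random.default_rng(0)
K_shared = rng.standard_normal((7, 4))    # one K/V shared by every head (MQA)
V_shared = rng.standard_normal((7, 4))
queries = rng.standard_normal((3, 4))     # one query row per head

for q in queries:
    s = K_shared @ q / np.sqrt(4)
    w = np.exp(s - s.max())
    w /= w.sum()
    print(np.allclose(tiled_attention_row(q, K_shared, V_shared, tile=3), w @ V_shared))  # True
```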

Multi-Head Latent Attention (MLA) — Chapter 15

DeepSeek-V2 (2024) takes a different approach to KV-cache compression. Instead of sharing raw K,V across heads, MLA compresses them into a low-rank latent representation and decompresses per-head at attention time. The DeepSeek-V2 paper reports a per-token cache comparable to GQA with only about 2.25 groups (larger than MQA's single K/V head) while matching or exceeding MHA quality in its experiments. MLA can be loosely seen as a learned generalization of MQA: instead of simply sharing K,V, it learns a shared representation from which per-head K,V can be reconstructed.
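The compress-then-reconstruct idea can be sketched with random matrices. This is a simplified illustration under stated assumptions: the shapes, the latent width `d_c = 4 * d_k`, and the names `W_dkv`, `W_uk`, `W_uv` are illustrative, and the sketch omits DeepSeek-V2's decoupled rotary-position keys and its absorption of the up-projections into other weights. Only the latent matrix `C` would be cached.

```python
import numpy as np

rng = np.random.default_rng(0)
N, d_model, H, d_k = 16, 64, 8, 8
d_c = 4 * d_k                                 # latent width (illustrative choice)

X = rng.standard_normal((N, d_model))
W_dkv = rng.standard_normal((d_model, d_c)) / np.sqrt(d_model)   # shared down-projection
W_uk = rng.standard_normal((H, d_c, d_k)) / np.sqrt(d_c)         # per-head key up-projection
W_uv = rng.standard_normal((H, d_c, d_k)) / np.sqrt(d_c)         # per-head value up-projection

C = X @ W_dkv                                 # (N, d_c): the only tensor that is cached
K = np.einsum("nc,hck->hnk", C, W_uk)         # (H, N, d_k): reconstructed per-head keys
V = np.einsum("nc,hck->hnk", C, W_uv)         # (H, N, d_k): reconstructed per-head values

# Cached floats: latent vs MHA (2*H*N*d_k) vs MQA (2*N*d_k)
print(C.size, 2 * H * N * d_k, 2 * N * d_k)   # 512 2048 256
```

With this latent width the cache sits between MQA and MHA, while each head still gets its own distinct K and V.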


Complexity Analysis

| Resource | MHA | MQA | Savings Factor |
|---|---|---|---|
| KV-cache memory | $O(H \cdot N \cdot d_k)$ | $O(N \cdot d_k)$ | $H\times$ |
| KV projection params | $2H \cdot d_{\text{model}} \cdot d_k$ | $2 \cdot d_{\text{model}} \cdot d_k$ | $H\times$ |
| Attention FLOPs (per step) | $O(H \cdot N \cdot d_k)$ | $O(H \cdot N \cdot d_k)$ | 1× (same) |
| KV-cache read bandwidth (per step) | $O(H \cdot N \cdot d_k)$ | $O(N \cdot d_k)$ | $H\times$ |

The attention FLOPs are unchanged because each head still computes its own attention scores; it just does so against the shared K,V instead of per-head K,V. The savings come from reduced KV-cache memory and bandwidth, plus a modest reduction in K/V projection parameters.
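The parameter row of the table can be checked for one attention layer. This sketch assumes a 70B-scale layer with $d_{\text{model}} = 8192$ and 64 heads of $d_k = 128$; the helper `projection_params` is hypothetical and counts weight matrices only (no biases).

```python
def projection_params(d_model, n_heads, n_kv_heads):
    # Weight counts for one attention layer's projections (no biases)
    d_k = d_model // n_heads
    q = d_model * n_heads * d_k            # H per-head query projections
    kv = 2 * d_model * n_kv_heads * d_k    # key + value projections
    o = (n_heads * d_k) * d_model          # output projection W^O
    return q, kv, o

mha = projection_params(8192, 64, n_kv_heads=64)
mqa = projection_params(8192, 64, n_kv_heads=1)
print(mha[1] // mqa[1])         # 64
print(sum(mha), sum(mqa))       # 268435456 136314880
```

Only the K/V term shrinks; the query and output projections are untouched, so total attention parameters per layer roughly halve rather than shrinking by $H$.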


Python Implementation

The class below implements MQA from scratch. The critical method is compute_shared_kv(), which extracts a single K,V pair shared by all heads. Compare this with the MHA class from Chapter 2, where each head computes its own K,V slice.

Multi-Query Attention: NumPy Implementation (multi_query_attention.py)
1import numpy as np

NumPy provides vectorized matrix operations. Qh @ K_shared.T runs as optimized C code, not Python loops.

2import math

Python standard library. We use math.sqrt() to precompute the scaling factor √d_k.

4class MultiQueryAttention

Wraps MQA in a reusable class. The key difference from MHA: compute_shared_kv() extracts ONE K,V pair used by ALL heads, instead of per-head K,V slices.

16def __init__(self, d_model, n_heads)

Constructor. Takes the full model dimension and number of heads. Computes per-head dimension d_k = d_model / n_heads and precomputes the scaling factor √d_k.

EXECUTION STATE
⬇ input: d_model = 4
⬇ input: n_heads = 2
17self.d_model = d_model

Store full model dimension (4 in our example).

EXECUTION STATE
self.d_model = 4
18self.n_heads = n_heads

Store number of query heads. In MQA, all n_heads share one K,V — this is the core idea.

EXECUTION STATE
self.n_heads = 2
19self.d_k = d_model // n_heads

Per-head query dimension. Each head sees d_k=2 dimensions of the query vector. The shared K,V also have d_k=2 columns.

EXECUTION STATE
d_model // n_heads = 4 // 2 = 2
self.d_k = 2
20self.scale = math.sqrt(self.d_k)

Precompute √d_k once. Divides every dot product to control variance (same as Chapter 1).

EXECUTION STATE
math.sqrt(2) = 1.4142
self.scale = 1.4142
22def _softmax(self, x) → np.ndarray

Numerically stable softmax. Takes a score matrix (5×5) and returns probabilities per row. Each row sums to 1.0.

EXECUTION STATE
⬇ input: x = shape (5, 5) — scaled score matrix for one head
⬆ returns = np.ndarray (5, 5) — softmax probabilities per row
24x_shifted = x - np.max(x, axis=-1, keepdims=True)

Subtract row-wise max for numerical stability. exp(x - max) prevents overflow while preserving the softmax result.

EXECUTION STATE
axis=-1 = find max along last axis — each row gets its own max, not a global max
keepdims=True = result has shape (5,1) instead of (5,) so broadcasting x(5×5) - max(5×1) works — each row subtracts its own max
25exp_x = np.exp(x_shifted)

Exponentiate every element. The largest value per row is exp(0)=1.0 — no overflow possible.

26return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

Divide each element by its row sum to normalize into probabilities. Each row sums to exactly 1.0.

EXECUTION STATE
axis=-1 = sum along last axis — sum each row independently
keepdims=True = sum returns shape (5,1) so broadcasting works correctly
28def compute_shared_kv(self, K, V)

THE KEY MQA OPERATION: extract a single K,V pair from the first d_k=2 dimensions. All heads will use these same matrices. In MHA, each head gets its own slice — here, everyone shares one.

EXECUTION STATE
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
self.d_k = 2 — only first 2 columns used
⬆ returns = (K_shared (5×2), V_shared (5×2))
30return K[:, :self.d_k], V[:, :self.d_k]

Slice first 2 columns from K and V. This is the shared representation that all heads look up from.

EXECUTION STATE
K[:, :2] =
     d0   d1
The  0.0  1.0
cat  1.0  0.0
sat  1.0  1.0
on   0.0  0.0
mat  1.0  0.0
V[:, :2] =
     d0   d1
The  1.0  0.0
cat  0.0  1.0
sat  0.0  0.0
on   0.0  0.0
mat  0.5  0.5
⬆ return: (K_shared, V_shared) = both shape (5, 2)
32def compute_head_query(self, Q, head_idx)

Extract the query slice for a specific head. Head 0 gets Q[:, 0:2], Head 1 gets Q[:, 2:4]. Each head asks DIFFERENT questions but looks them up in the SAME dictionary.

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: head_idx = 0 or 1
⬆ returns = np.ndarray (5, 2) — query slice for this head
34start = head_idx * self.d_k

Compute start column index for this head's query slice.

EXECUTION STATE
head_idx=0 → start = 0 × 2 = 0
head_idx=1 → start = 1 × 2 = 2
35end = start + self.d_k

Compute end column index.

EXECUTION STATE
head_idx=0 → end = 0 + 2 = 2
head_idx=1 → end = 2 + 2 = 4
36return Q[:, start:end]

Return the 2-column query slice for this head.

EXECUTION STATE
⬆ return: Q_h1 (head 0) =
     d0   d1
The  1.0  0.0
cat  0.0  2.0
sat  1.0  1.0
on   0.0  0.0
mat  1.0  0.0
⬆ return: Q_h2 (head 1) =
     d2   d3
The  1.0  0.0
cat  0.0  1.0
sat  1.0  0.0
on   1.0  1.0
mat  0.0  1.0
38def compute_scores(self, Qh, K_shared) → np.ndarray

Raw dot-product scores: Qh (5×2) @ K_shared.T (2×5) → (5×5). Each entry measures how much a query aligns with a shared key.

EXECUTION STATE
⬇ input: Qh = shape (5, 2) — one head's queries
⬇ input: K_shared = shape (5, 2) — shared keys (same for all heads)
⬆ returns = np.ndarray (5, 5) — raw scores
40return Qh @ K_shared.T

Matrix multiply Qh (5×2) with K_shared transposed (2×5). Result is 5×5: entry (i,j) = dot product of query_i with shared_key_j.

EXECUTION STATE
.T = transpose — K_shared (5×2) becomes (2×5)
⬆ return (Head 1) =
     The  cat  sat   on  mat
The  0.0  1.0  1.0  0.0  1.0
cat  2.0  0.0  2.0  0.0  0.0
sat  1.0  1.0  2.0  0.0  1.0
on   0.0  0.0  0.0  0.0  0.0
mat  0.0  1.0  1.0  0.0  1.0
⬆ return (Head 2) =
     The  cat  sat   on  mat
The  0.0  1.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0  0.0
sat  0.0  1.0  1.0  0.0  1.0
on   1.0  1.0  2.0  0.0  1.0
mat  1.0  0.0  1.0  0.0  0.0
42def scale_scores(self, scores) → np.ndarray

Divide every score by √d_k = √2 ≈ 1.4142 to prevent softmax saturation.

EXECUTION STATE
⬇ input: scores = shape (5, 5) — raw dot products
self.scale = 1.4142 (√2)
⬆ returns = np.ndarray (5, 5) — scores ÷ 1.4142
44return scores / self.scale

Elementwise division. Every score is divided by 1.4142.

EXECUTION STATE
⬆ return (Head 1) =
       The     cat     sat      on     mat
The  0.0000  0.7071  0.7071  0.0000  0.7071
cat  1.4142  0.0000  1.4142  0.0000  0.0000
sat  0.7071  0.7071  1.4142  0.0000  0.7071
on   0.0000  0.0000  0.0000  0.0000  0.0000
mat  0.0000  0.7071  0.7071  0.0000  0.7071
46def compute_weights(self, scaled_scores) → np.ndarray

Apply softmax row-wise to get attention probabilities. Each row sums to 1.0.

EXECUTION STATE
⬇ input: scaled_scores = shape (5, 5)
⬆ returns = np.ndarray (5, 5) — each row sums to 1.0
48return self._softmax(scaled_scores)

Calls _softmax. All 5 rows become probability distributions.

EXECUTION STATE
⬆ return (Head 1) =
       The      cat      sat       on      mat
The  0.1237   0.2509   0.2509   0.1237   0.2509
cat  0.3664   0.0891   0.3664   0.0891   0.0891
sat  0.1811   0.1811   0.3673   0.0893   0.1811
on   0.2000   0.2000   0.2000   0.2000   0.2000
mat  0.1237   0.2509   0.2509   0.1237   0.2509
⬆ return (Head 2) =
       The      cat      sat       on      mat
The  0.1237   0.2509   0.2509   0.1237   0.2509
cat  0.2874   0.1417   0.2874   0.1417   0.1417
sat  0.1237   0.2509   0.2509   0.1237   0.2509
on   0.1811   0.1811   0.3673   0.0893   0.1811
mat  0.2874   0.1417   0.2874   0.1417   0.1417
50def compute_output(self, weights, V_shared)

Weighted sum of SHARED value vectors. weights (5×5) @ V_shared (5×2) → (5×2). This is the 2-dim output for one head.

EXECUTION STATE
⬇ input: weights = shape (5, 5) — attention probabilities
⬇ input: V_shared (5×2) =
     d0   d1
The  1.0  0.0
cat  0.0  1.0
sat  0.0  0.0
on   0.0  0.0
mat  0.5  0.5
⬆ returns = np.ndarray (5, 2) — weighted sum for one head
52return weights @ V_shared

Matrix multiply weights (5×5) with V_shared (5×2). Each output row is a blend of all 5 value vectors, weighted by attention.

EXECUTION STATE
⬆ return (Head 1) =
       d0       d1
The  0.2491   0.3763
cat  0.4109   0.1336
sat  0.2717   0.2717
on   0.3000   0.3000
mat  0.2491   0.3763
⬆ return (Head 2) =
       d0       d1
The  0.2491   0.3763
cat  0.3583   0.2126
sat  0.2491   0.3763
on   0.2717   0.2717
mat  0.3583   0.2126
54def forward(self, Q, K, V)

Main entry point. Extracts shared K,V once, then loops over heads computing per-head Q projections. Each head uses the SAME K_shared and V_shared. Returns concatenated outputs and per-head weights.

EXECUTION STATE
⬇ input: Q (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
⬇ input: K (5×4) =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
⬇ input: V (5×4) =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
⬆ returns = (output (5,4), all_weights list of 2 matrices)
67K_shared, V_shared = self.compute_shared_kv(K, V)

Extract the single shared K,V pair. This is done ONCE, not per-head — the core MQA savings.

EXECUTION STATE
K_shared (5×2) =
     d0   d1
The  0.0  1.0
cat  1.0  0.0
sat  1.0  1.0
on   0.0  0.0
mat  1.0  0.0
V_shared (5×2) =
     d0   d1
The  1.0  0.0
cat  0.0  1.0
sat  0.0  0.0
on   0.0  0.0
mat  0.5  0.5
68head_outputs = []

Will collect each head's (5×2) output to concatenate at the end.

EXECUTION STATE
head_outputs = [] (empty, will hold 2 matrices)
69all_weights = []

Will collect each head's (5×5) attention weight matrix.

EXECUTION STATE
all_weights = [] (empty, will hold 2 matrices)
71for h in range(self.n_heads):

Loop over all heads (h=0 and h=1). Each head gets its own Q slice but uses the SAME K_shared and V_shared.

LOOP TRACE · 2 iterations
h=0 (Head 1)
Qh = Q[:, 0:2] = d0 d1 The 1.0 0.0 cat 0.0 2.0 sat 1.0 1.0 on 0.0 0.0 mat 1.0 0.0
K_shared = SAME as above (shared!)
output_h1 = d0 d1 The 0.2491 0.3763 cat 0.4109 0.1336 sat 0.2717 0.2717 on 0.3000 0.3000 mat 0.2491 0.3763
h=1 (Head 2)
Qh = Q[:, 2:4] = d2 d3 The 1.0 0.0 cat 0.0 1.0 sat 1.0 0.0 on 1.0 1.0 mat 0.0 1.0
K_shared = SAME as above (shared!)
output_h2 = d0 d1 The 0.2491 0.3763 cat 0.3583 0.2126 sat 0.2491 0.3763 on 0.2717 0.2717 mat 0.3583 0.2126
72Qh = self.compute_head_query(Q, h)

Get head-specific query. Head 0 → Q[:, 0:2], Head 1 → Q[:, 2:4].

73raw_scores = self.compute_scores(Qh, K_shared)

Dot products of head-specific Q with SHARED K.

74scaled_scores = self.scale_scores(raw_scores)

Divide by √2 ≈ 1.4142.

75weights = self.compute_weights(scaled_scores)

Softmax row-wise → probabilities.

76output = self.compute_output(weights, V_shared)

Weighted sum of SHARED V.

77head_outputs.append(output)

Store this head's (5×2) output.

78all_weights.append(weights)

Store this head's (5×5) weight matrix.

80return np.hstack(head_outputs), all_weights

Concatenate all head outputs horizontally: [(5×2), (5×2)] → (5×4). Returns the full output and per-head weights.

EXECUTION STATE
np.hstack() = horizontal stack — concatenates along columns. Two (5×2) matrices become one (5×4) matrix.
⬆ return: output (5×4) =
        d0       d1       d2       d3
The  0.2491   0.3763   0.2491   0.3763
cat  0.4109   0.1336   0.3583   0.2126
sat  0.2717   0.2717   0.2491   0.3763
on   0.3000   0.3000   0.2717   0.2717
mat  0.2491   0.3763   0.3583   0.2126
103tokens = [...]

The 5 tokens used in every chapter.

EXECUTION STATE
tokens = ['The', 'cat', 'sat', 'on', 'mat']
105Q = np.array([...])

Query matrix — same as Chapter 1. Each row encodes what that token looks for.

EXECUTION STATE
Q =
      d0   d1   d2   d3
The  1.0  0.0  1.0  0.0
cat  0.0  2.0  0.0  1.0
sat  1.0  1.0  1.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.0  1.0
113K = np.array([...])

Key matrix — same as Chapter 1. MQA only uses the first 2 columns (K_shared).

EXECUTION STATE
K =
      d0   d1   d2   d3
The  0.0  1.0  0.0  1.0
cat  1.0  0.0  1.0  0.0
sat  1.0  1.0  0.0  0.0
on   0.0  0.0  1.0  1.0
mat  1.0  0.0  0.5  0.5
121V = np.array([...])

Value matrix — same as Chapter 1. MQA only uses the first 2 columns (V_shared).

EXECUTION STATE
V =
      d0   d1   d2   d3
The  1.0  0.0  0.0  0.0
cat  0.0  1.0  0.0  0.0
sat  0.0  0.0  1.0  0.0
on   0.0  0.0  0.0  1.0
mat  0.5  0.5  0.5  0.5
131mqa = MultiQueryAttention(d_model=4, n_heads=2)

Instantiate MQA with 4 model dims and 2 heads. Sets d_k=2 and scale=√2.

EXECUTION STATE
mqa.d_k = 2
mqa.scale = 1.4142
132output, all_weights = mqa.forward(Q, K, V)

Run the full MQA pipeline. Extracts shared K,V once, then runs 2 heads with different Q slices.

EXECUTION STATE
output (5×4) =
        d0       d1       d2       d3
The  0.2491   0.3763   0.2491   0.3763
cat  0.4109   0.1336   0.3583   0.2126
sat  0.2717   0.2717   0.2491   0.3763
on   0.3000   0.3000   0.2717   0.2717
mat  0.2491   0.3763   0.3583   0.2126
len(all_weights) = 2 (one 5×5 matrix per head)
141mqa.explain(Q, K, V, tokens, query_idx=0)

Print detailed trace for 'The' (token 0). Shows shared K,V then per-head Q, scores, and output.

EXECUTION STATE
query_idx = 0 → tracing 'The'
142mqa.explain(Q, K, V, tokens, query_idx=1)

Print detailed trace for 'cat' (token 1). Heads produce DIFFERENT attention patterns because their Q slices differ.

EXECUTION STATE
query_idx = 1 → tracing 'cat'
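The 'cat' trace can also be reproduced by hand. The sketch below uses only the standard library, hard-codes the first two columns of the chapter's K and V as `K_shared` and `V_shared`, and defines illustrative helpers `softmax` and `attend` that are not part of the class. Its rounded outputs should match the 'cat' row of the output table: [0.4109, 0.1336] for Head 1 and [0.3583, 0.2126] for Head 2.

```python
import math

# First d_k = 2 columns of the chapter's K and V: the single set every head shares.
K_shared = [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 0.0]]
V_shared = [[1.0, 0.0], [0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.5, 0.5]]

# Q['cat'] = [0, 2, 0, 1]: Head 1 reads columns 0:2, Head 2 reads columns 2:4.
q_cat = {"head1": [0.0, 2.0], "head2": [0.0, 1.0]}
scale = math.sqrt(2)  # sqrt(d_k)

def softmax(xs):
    m = max(xs)                                  # subtract the max for stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q):
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in K_shared]
    weights = softmax(scores)
    # Every head takes a weighted sum of the SAME shared value rows.
    out = [sum(w * v[d] for w, v in zip(weights, V_shared)) for d in range(2)]
    return weights, out

results = {head: attend(q) for head, q in q_cat.items()}
for head, (weights, out) in results.items():
    print(head, [round(w, 4) for w in weights], [round(o, 4) for o in out])
```

The two heads disagree only because their query slices differ; the keys and values they read are identical.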
import numpy as np
import math

class MultiQueryAttention:
    """
    Multi-Query Attention (Shazeer, 2019)

    All H query heads share a SINGLE set of Keys and Values.
    Only the Query projections remain per-head.

    KV-cache savings vs. MHA: H× (e.g., 64× with 64 query heads).
    Note: LLaMA 2 70B has 64 query heads but uses GQA with 8 KV heads (8×).
    """

    def __init__(self, d_model: int, n_heads: int):
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.scale = math.sqrt(self.d_k)

    def _softmax(self, x: np.ndarray) -> np.ndarray:
        """Numerically stable softmax along last axis."""
        x_shifted = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(x_shifted)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def compute_shared_kv(self, K: np.ndarray, V: np.ndarray):
        """Extract the single shared K, V (first d_k dimensions)."""
        # Toy stand-in: a real model would apply one shared learned W_k / W_v here.
        return K[:, :self.d_k], V[:, :self.d_k]

    def compute_head_query(self, Q: np.ndarray, head_idx: int):
        """Extract query slice for a specific head."""
        start = head_idx * self.d_k
        end = start + self.d_k
        return Q[:, start:end]

    def compute_scores(self, Qh: np.ndarray, K_shared: np.ndarray):
        """Raw dot-product scores: Qh @ K_shared^T."""
        return Qh @ K_shared.T

    def scale_scores(self, scores: np.ndarray) -> np.ndarray:
        """Divide by sqrt(d_k) to control variance."""
        return scores / self.scale

    def compute_weights(self, scaled_scores: np.ndarray) -> np.ndarray:
        """Apply softmax to get attention weights."""
        return self._softmax(scaled_scores)

    def compute_output(self, weights: np.ndarray, V_shared: np.ndarray):
        """Weighted sum of shared value vectors."""
        return weights @ V_shared

    def forward(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray):
        """
        Full forward pass.

        Args:
            Q: Query matrix  (N, d_model)
            K: Key matrix    (N, d_model)
            V: Value matrix  (N, d_model)

        Returns:
            output:      Concatenated head outputs  (N, d_model)
            all_weights: List of weight matrices per head
        """
        K_shared, V_shared = self.compute_shared_kv(K, V)
        head_outputs = []
        all_weights = []

        for h in range(self.n_heads):
            Qh = self.compute_head_query(Q, h)
            raw_scores = self.compute_scores(Qh, K_shared)
            scaled_scores = self.scale_scores(raw_scores)
            weights = self.compute_weights(scaled_scores)
            output = self.compute_output(weights, V_shared)
            head_outputs.append(output)
            all_weights.append(weights)

        return np.hstack(head_outputs), all_weights

    def explain(self, Q: np.ndarray, K: np.ndarray, V: np.ndarray,
                tokens: list, query_idx: int = 0):
        """Print a detailed trace for a specific query token."""
        K_shared, V_shared = self.compute_shared_kv(K, V)
        token = tokens[query_idx]
        print(f"\n=== MQA trace for '{token}' (row {query_idx}) ===")
        print("K_shared (ALL heads use this):")
        for i, t in enumerate(tokens):
            print(f"  {t}: {K_shared[i]}")
        print("V_shared (ALL heads use this):")
        for i, t in enumerate(tokens):
            print(f"  {t}: {V_shared[i]}")

        for h in range(self.n_heads):
            Qh = self.compute_head_query(Q, h)
            raw = self.compute_scores(Qh, K_shared)
            scaled = self.scale_scores(raw)
            w = self.compute_weights(scaled)
            out = self.compute_output(w, V_shared)
            print(f"\n--- Head {h+1} ---")
            print(f"Q_h{h+1}[{token}] = {Qh[query_idx]}")
            for j, t in enumerate(tokens):
                print(f"  score[{t}] = {raw[query_idx,j]:.4f}"
                      f" -> scaled = {scaled[query_idx,j]:.4f}")
            print(f"  softmax = {np.round(w[query_idx], 4)}")
            print(f"  output  = {np.round(out[query_idx], 4)}")


# ── Shared Example (used in every chapter) ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],   # The
    [0.0, 2.0, 0.0, 1.0],   # cat
    [1.0, 1.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.0, 1.0],   # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],   # The
    [1.0, 0.0, 1.0, 0.0],   # cat
    [1.0, 1.0, 0.0, 0.0],   # sat
    [0.0, 0.0, 1.0, 1.0],   # on
    [1.0, 0.0, 0.5, 0.5],   # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],   # The
    [0.0, 1.0, 0.0, 0.0],   # cat
    [0.0, 0.0, 1.0, 0.0],   # sat
    [0.0, 0.0, 0.0, 1.0],   # on
    [0.5, 0.5, 0.5, 0.5],   # mat
])

# ── Run ──
mqa = MultiQueryAttention(d_model=4, n_heads=2)
output, all_weights = mqa.forward(Q, K, V)

print("MQA Output (5x4):")
print(np.round(output, 4))

print("\nHead 1 Weights (5x5):")
print(np.round(all_weights[0], 4))

print("\nHead 2 Weights (5x5):")
print(np.round(all_weights[1], 4))

# Detailed trace for "The" and "cat"
mqa.explain(Q, K, V, tokens, query_idx=0)
mqa.explain(Q, K, V, tokens, query_idx=1)
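The per-head Python loop in `forward` is easy to follow, but the same computation can be written without a loop: reshape Q into (H, N, d_k) and let NumPy broadcast the single `K_shared` and `V_shared` across every head. `mqa_vectorized` is a hypothetical helper, not part of the class above, and it keeps the chapter's toy convention of slicing columns instead of applying learned projections. Its output should match the 5×4 table printed by the listing.

```python
import numpy as np

def mqa_vectorized(Q, K, V, n_heads):
    """Loop-free MQA for this chapter's toy setup (column slices stand in for projections)."""
    N, d_model = Q.shape
    d_k = d_model // n_heads
    K_shared, V_shared = K[:, :d_k], V[:, :d_k]              # (N, d_k): one K/V set for all heads
    Qh = Q.reshape(N, n_heads, d_k).transpose(1, 0, 2)       # (H, N, d_k): per-head query slices
    scores = Qh @ K_shared.T / np.sqrt(d_k)                  # (H, N, N): K_shared broadcasts over H
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    out = weights @ V_shared                                 # (H, N, d_k)
    # (H, N, d_k) -> (N, H * d_k): same head order as np.hstack in the loop version
    return out.transpose(1, 0, 2).reshape(N, d_model), weights

Q = np.array([[1.0, 0.0, 1.0, 0.0], [0.0, 2.0, 0.0, 1.0], [1.0, 1.0, 1.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]])
K = np.array([[0.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0, 0.0], [1.0, 1.0, 0.0, 0.0],
              [0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.5, 0.5]])
V = np.array([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0],
              [0.0, 0.0, 0.0, 1.0], [0.5, 0.5, 0.5, 0.5]])

output, weights = mqa_vectorized(Q, K, V, n_heads=2)
print(np.round(output, 4))
```

Because K_shared has no head axis, broadcasting is exactly where the "shared" in MQA shows up: the H query slices are multiplied against one key matrix.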

PyTorch Implementation

The PyTorch version supports GPU acceleration and automatic differentiation. The core logic is identical: extract shared K,V once, then loop over heads with per-head Q. In production, the projection matrices ($W^Q, W^K, W^V$) would be learnable nn.Linear layers.
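The PyTorch listing keeps the projections as comments. As a framework-neutral sketch of what the projected version computes, here is a NumPy class in which random matrices `W_q`, `W_k`, `W_v`, `W_o` stand in for trained `nn.Linear` weights (biases omitted). The class name `MQAWithProjections` and its `causal` flag are hypothetical. It accepts batched input of shape (B, N, d_model) and shows the parameter saving: only `W_k` and `W_v` shrink to d_model × d_k.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)   # stable: subtract the row max
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

class MQAWithProjections:
    """Hypothetical sketch: MQA with projection matrices (random stand-ins, no biases)."""

    def __init__(self, d_model, n_heads, seed=0):
        rng = np.random.default_rng(seed)
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        std = 1.0 / np.sqrt(d_model)
        self.W_q = rng.normal(0.0, std, (d_model, d_model))    # H query heads
        self.W_k = rng.normal(0.0, std, (d_model, self.d_k))   # ONE shared key head
        self.W_v = rng.normal(0.0, std, (d_model, self.d_k))   # ONE shared value head
        self.W_o = rng.normal(0.0, std, (d_model, d_model))    # output projection

    def n_params(self):
        return sum(W.size for W in (self.W_q, self.W_k, self.W_v, self.W_o))

    def __call__(self, x, causal=False):
        B, N, d_model = x.shape
        q = (x @ self.W_q).reshape(B, N, self.n_heads, self.d_k).transpose(0, 2, 1, 3)  # (B, H, N, d_k)
        k = x @ self.W_k                                    # (B, N, d_k): what a KV-cache would store
        v = x @ self.W_v                                    # (B, N, d_k)
        scores = q @ k[:, None].swapaxes(-1, -2) / np.sqrt(self.d_k)  # (B, H, N, N)
        if causal:
            future = np.triu(np.ones((N, N), dtype=bool), k=1)
            scores = np.where(future, -np.inf, scores)      # block attention to later tokens
        out = softmax(scores) @ v[:, None]                  # (B, H, N, d_k)
        out = out.transpose(0, 2, 1, 3).reshape(B, N, d_model)  # concatenate heads
        return out @ self.W_o

d_model, n_heads = 512, 8
mqa = MQAWithProjections(d_model, n_heads)
mha_params = 4 * d_model * d_model   # MHA: W_q, W_k, W_v, W_o are all d_model x d_model
print(mqa.n_params(), "vs", mha_params)

x = np.random.default_rng(1).normal(size=(2, 6, d_model))
y = mqa(x, causal=True)
print(y.shape)
```

With d_model = 512 and 8 heads, this counts 589,824 projection weights versus 1,048,576 for MHA. The KV-cache saving is larger still, since `k` and `v` carry d_k = 64 features per token instead of 512.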

Multi-Query Attention — PyTorch Implementation
multi_query_attention_torch.py
Import PyTorch

torch is the core tensor library. torch.nn provides neural network building blocks (nn.Module), and torch.nn.functional provides stateless operations such as softmax; matrix multiplication comes from torch.matmul. math is used for sqrt.

class MultiQueryAttention(nn.Module)

nn.Module subclass. Unlike the NumPy version, this supports GPU acceleration via .cuda(), automatic gradient computation via autograd, and integration with PyTorch training loops.

def __init__(self, d_model, n_heads)

Constructor. Computes d_k = d_model // n_heads (integer division) and scale = √d_k. In production, this would also define the learned projection matrices (shown in comments).

EXECUTION STATE
⬇ input: d_model = 4
⬇ input: n_heads = 2
self.d_k = 4 // 2 = 2
self.scale = √2 = 1.4142
Production projection matrices (commented)

In a real model, W_q projects to d_model (all H heads), while W_k and W_v project to only d_k (one shared set). MHA needs a d_model × d_model projection for K and another for V (H heads of size d_k each); MQA needs only d_model × d_k for each, an H× reduction in K/V projection weights. The larger payoff is the matching H× reduction in the KV-cache.

EXECUTION STATE
W_q params (MHA and MQA, weights only) = d_model × d_model = 4 × 4 = 16
W_k params (MQA) = d_model × d_k = 4 × 2 = 8 (vs 4 × 4 = 16 in MHA)
W_v params (MQA) = d_model × d_k = 4 × 2 = 8 (vs 4 × 4 = 16 in MHA)
def forward(self, Q, K, V, mask)

Main forward pass. Extracts shared K_shared, V_shared from first d_k columns, then loops over heads. Each head has its own Q slice but shares K,V.

EXECUTION STATE
⬇ input: Q = torch.Size([5, 4]) — full query tensor
⬇ input: K = torch.Size([5, 4]) — only first 2 cols used
⬇ input: V = torch.Size([5, 4]) — only first 2 cols used
⬇ input: mask = None (no masking)
⬆ returns = (output (5,4), list of 2 weight tensors)
K_shared = K[:, :self.d_k]

Extract shared keys — first d_k=2 columns of K. All heads use this same K_shared. This is the defining operation of MQA.

EXECUTION STATE
K[:, :2] =
     d0   d1
The  0.0  1.0
cat  1.0  0.0
sat  1.0  1.0
on   0.0  0.0
mat  1.0  0.0
V_shared = V[:, :self.d_k]

Extract shared values — first d_k=2 columns of V. Same shared-once principle.

EXECUTION STATE
V[:, :2] =
     d0   d1
The  1.0  0.0
cat  0.0  1.0
sat  0.0  0.0
on   0.0  0.0
mat  0.5  0.5
Qh = Q[:, h * self.d_k : (h + 1) * self.d_k]

Extract head-specific query. This is the ONLY thing that differs per head in MQA. Head 0 → Q[:, 0:2], Head 1 → Q[:, 2:4].

EXECUTION STATE
h=0 → Qh = Q[:, 0:2] — shape (5, 2)
h=1 → Qh = Q[:, 2:4] — shape (5, 2)
scores = torch.matmul(Qh, K_shared.transpose(-2, -1))

Compute dot products between head-specific Q and SHARED K. torch.matmul supports batched inputs; .transpose(-2, -1) swaps last two dims.

EXECUTION STATE
K_shared.transpose(-2, -1) = transpose (5,2) → (2,5). -2 = second-to-last, -1 = last.
scores.shape = torch.Size([5, 5])
scores = scores / self.scale

Divide by √2 ≈ 1.4142. Same scaling as all other chapters.

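if mask is not None: scores = scores.masked_fill(mask, float("-inf"))

Optional masking. For causal MQA (used in autoregressive LLMs), the mask sets the scores of future tokens to -inf, so their attention weights become 0 after softmax. masked_fill returns a new tensor, hence the reassignment. Skipped here (mask is None).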

EXECUTION STATE
masked_fill(mask, val) = replace masked positions with -inf → exp(-inf) = 0 after softmax
mask = None → skipped
weights = F.softmax(scores, dim=-1)

PyTorch's built-in numerically stable softmax. dim=-1 normalizes each row independently.

EXECUTION STATE
dim=-1 = normalize along last dimension — same as axis=-1 in NumPy. Each row becomes a probability distribution.
output = torch.matmul(weights, V_shared)

Weighted sum of shared V. weights (5×5) @ V_shared (5×2) → (5×2) output for this head.

return torch.cat(head_outputs, dim=-1), all_weights

Concatenate head outputs along the last dimension. torch.cat with dim=-1 is equivalent to np.hstack for 2D tensors.

EXECUTION STATE
torch.cat(dim=-1) = concatenate along last dim. Two (5,2) tensors → one (5,4) tensor.
⬆ return: output =
        d0       d1       d2       d3
The  0.2491   0.3763   0.2491   0.3763
cat  0.4109   0.1336   0.3583   0.2126
sat  0.2717   0.2717   0.2491   0.3763
on   0.3000   0.3000   0.2717   0.2717
mat  0.2491   0.3763   0.3583   0.2126
mqa = MultiQueryAttention(d_model=4, n_heads=2)

Instantiate the MQA module. It has no learnable parameters in this version, since the projections appear only as comments.

EXECUTION STATE
mqa.d_k = 2
mqa.scale = 1.4142
output, all_weights = mqa(Q, K, V)

Run MQA. Calling mqa(Q, K, V) goes through nn.Module.__call__, which runs forward() plus any registered hooks; prefer it over calling .forward() directly. Autograd records the operations either way.

EXECUTION STATE
output.shape = torch.Size([5, 4])
output =
        d0       d1       d2       d3
The  0.2491   0.3763   0.2491   0.3763
cat  0.4109   0.1336   0.3583   0.2126
sat  0.2717   0.2717   0.2491   0.3763
on   0.3000   0.3000   0.2717   0.2717
mat  0.2491   0.3763   0.3583   0.2126
Heads match for 'The'?

Both heads produce identical weights for 'The' because Q_h1[The] = [1,0] and Q_h2[The] = [1,0] happen to coincide in this toy example. In a trained model, the learned per-head Q projections would generally make them differ.

EXECUTION STATE
all_weights[0][0] = [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
all_weights[1][0] = [0.1237, 0.2509, 0.2509, 0.1237, 0.2509]
allclose? = True ✓
GPU acceleration

Moving to GPU only requires calling .cuda() on the module and on the input tensors; the same MQA code runs on CPU or GPU without other changes. At production scale (e.g., H=64, seq_len=8192), GPUs are typically far faster than CPUs for these matrix multiplications.

EXECUTION STATE
torch.cuda.is_available() = True/False (depends on hardware)
from __future__ import annotations  # lets the `Tensor | None` and tuple[...] hints work on Python < 3.10

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MultiQueryAttention(nn.Module):
    """
    Multi-Query Attention (Shazeer, 2019) — PyTorch

    All H query heads share a SINGLE set of Keys and Values.
    Supports GPU and automatic differentiation; inputs are unbatched (N, d_model).
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.scale = math.sqrt(self.d_k)

        # In production: learned projection matrices
        # H separate Q projections, but only 1 K and 1 V projection
        # self.W_q = nn.Linear(d_model, d_model)     # H heads
        # self.W_k = nn.Linear(d_model, self.d_k)     # 1 shared
        # self.W_v = nn.Linear(d_model, self.d_k)     # 1 shared
        # self.W_o = nn.Linear(d_model, d_model)

    def forward(
        self,
        Q: torch.Tensor,
        K: torch.Tensor,
        V: torch.Tensor,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
        """
        Args:
            Q: (N, d_model)  — full query matrix
            K: (N, d_model)  — full key matrix (only first d_k used)
            V: (N, d_model)  — full value matrix (only first d_k used)
            mask: Optional (N, N) boolean mask; True marks positions to block

        Returns:
            output:      (N, d_model) — concatenated heads
            all_weights: list of (N, N) weight tensors per head
        """
        # Extract shared K, V (first d_k dimensions only)
        K_shared = K[:, :self.d_k]    # (N, d_k)
        V_shared = V[:, :self.d_k]    # (N, d_k)

        head_outputs = []
        all_weights = []

        for h in range(self.n_heads):
            # Head-specific query
            Qh = Q[:, h * self.d_k : (h + 1) * self.d_k]  # (N, d_k)

            # Scaled dot-product with shared K
            scores = torch.matmul(Qh, K_shared.transpose(-2, -1))
            scores = scores / self.scale

            if mask is not None:
                # Blocked positions get -inf, so softmax gives them weight 0
                scores = scores.masked_fill(mask, float("-inf"))

            weights = F.softmax(scores, dim=-1)
            output = torch.matmul(weights, V_shared)

            head_outputs.append(output)
            all_weights.append(weights)

        return torch.cat(head_outputs, dim=-1), all_weights


# ── Shared Example ──
tokens = ["The", "cat", "sat", "on", "mat"]

Q = torch.tensor([
    [1.0, 0.0, 1.0, 0.0],
    [0.0, 2.0, 0.0, 1.0],
    [1.0, 1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.0, 1.0],
])

K = torch.tensor([
    [0.0, 1.0, 0.0, 1.0],
    [1.0, 0.0, 1.0, 0.0],
    [1.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 1.0],
    [1.0, 0.0, 0.5, 0.5],
])

V = torch.tensor([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5, 0.5],
])

# ── Run ──
mqa = MultiQueryAttention(d_model=4, n_heads=2)
output, all_weights = mqa(Q, K, V)

print("MQA Output (5x4):")
print(output.round(decimals=4))

print("\nHead 1 weights match Head 2 for 'The'?",
      torch.allclose(all_weights[0][0], all_weights[1][0]))

# ── GPU: move the module and the tensors ──
if torch.cuda.is_available():
    mqa_gpu = mqa.cuda()
    Q_gpu, K_gpu, V_gpu = Q.cuda(), K.cuda(), V.cuda()
    out_gpu, _ = mqa_gpu(Q_gpu, K_gpu, V_gpu)
    print("GPU matches CPU?",
          torch.allclose(output, out_gpu.cpu(), atol=1e-4))
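To see why MQA needs to cache only one key row and one value row per token, the sketch below runs causal MQA two ways: over the whole sequence at once, and token by token with a growing cache. The random projection matrices and the helper names `mqa_full_causal` and `mqa_decode_step` are illustrative assumptions, not part of the listing above. Each decode step appends d_k numbers to the K cache and d_k to the V cache, and all H query heads read that same cache.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, n_heads = 16, 4
d_k = d_model // n_heads
W_q = rng.normal(size=(d_model, d_model)) / np.sqrt(d_model)
W_k = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)   # single shared key projection
W_v = rng.normal(size=(d_model, d_k)) / np.sqrt(d_model)   # single shared value projection

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def mqa_full_causal(x):
    """Reference: causal MQA over the whole sequence at once. x: (N, d_model)."""
    N = x.shape[0]
    q = (x @ W_q).reshape(N, n_heads, d_k).transpose(1, 0, 2)   # (H, N, d_k)
    k, v = x @ W_k, x @ W_v                                      # (N, d_k) each
    scores = q @ k.T / np.sqrt(d_k)                              # (H, N, N)
    future = np.triu(np.ones((N, N), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    out = softmax(scores) @ v                                    # (H, N, d_k)
    return out.transpose(1, 0, 2).reshape(N, d_model)

def mqa_decode_step(x_t, cache):
    """Process one new token x_t of shape (d_model,); append ONE K row and ONE V row."""
    cache["k"].append(x_t @ W_k)
    cache["v"].append(x_t @ W_v)
    k = np.stack(cache["k"])                                     # (t, d_k)
    v = np.stack(cache["v"])                                     # (t, d_k)
    q = (x_t @ W_q).reshape(n_heads, d_k)                        # (H, d_k): every head reads the same cache
    weights = softmax(q @ k.T / np.sqrt(d_k))                    # (H, t)
    return (weights @ v).reshape(d_model)                        # (H, d_k) -> (d_model,)

x = rng.normal(size=(6, d_model))
cache = {"k": [], "v": []}
steps = np.stack([mqa_decode_step(x[t], cache) for t in range(len(x))])
print(np.allclose(steps, mqa_full_causal(x)))   # incremental decoding matches full causal attention
print(np.stack(cache["k"]).shape)               # (6, 4): N x d_k, not N x H x d_k
```

Under MHA the same cache would hold a (6, 4, 4) array for K and another for V, H times more memory to read at every step.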

Key Takeaways

  1. MQA shares K and V across all heads while keeping per-head Q projections. This reduces the KV-cache by a factor of $H$.
  2. The bottleneck is memory bandwidth, not compute. During autoregressive decoding, every new token requires reading the whole KV-cache from GPU memory, and for long contexts and large batches this traffic dominates per-token latency. Shrinking the cache directly reduces per-token latency and increases serving throughput.
  3. Quality degrades gracefully. Shazeer reported a drop of less than 0.5 BLEU on translation. Heads can still attend differently because each keeps its own Query projection.
  4. MQA is the extreme case of GQA with $G = 1$. Most production models now use GQA with $G > 1$ for a better quality-efficiency tradeoff, but MQA remains the theoretical foundation.
  5. MQA composes with other optimizations. FlashAttention (IO-aware kernels), KV-cache quantization (lower precision), and MQA/GQA (fewer K/V heads) are orthogonal improvements whose savings compound.
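The GQA relationship in takeaway 4 can be checked numerically. The sketch below implements grouped-query attention with G key/value heads shared by H query heads, assuming contiguous groups (query head h reads KV head h // (H // G)) and sharing by simple repetition; `gqa` and `one_head` are hypothetical helper names. It confirms that G = H reproduces per-head MHA and G = 1 reproduces MQA.

```python
import numpy as np

def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def gqa(q, k, v):
    """Grouped-query attention sketch.

    q: (H, N, d_k) per-head queries; k, v: (G, N, d_k) with H divisible by G.
    Query head h reads KV head h // (H // G), so G == 1 is MQA and G == H is MHA.
    """
    H, G, d_k = q.shape[0], k.shape[0], q.shape[-1]
    assert H % G == 0, "query heads must be a multiple of KV heads"
    group = np.arange(H) // (H // G)          # KV head used by each query head
    k_rep, v_rep = k[group], v[group]         # (H, N, d_k): share by repetition
    weights = softmax(q @ k_rep.swapaxes(-1, -2) / np.sqrt(d_k))
    return weights @ v_rep                    # (H, N, d_k)

rng = np.random.default_rng(0)
H, N, d_k = 8, 5, 4
q, k, v = (rng.normal(size=(H, N, d_k)) for _ in range(3))

def one_head(qh, kh, vh):
    return softmax(qh @ kh.T / np.sqrt(d_k)) @ vh

mha_ref = np.stack([one_head(q[h], k[h], v[h]) for h in range(H)])  # each head has its own K, V
mqa_ref = np.stack([one_head(q[h], k[0], v[0]) for h in range(H)])  # every head uses K[0], V[0]

print(np.allclose(gqa(q, k, v), mha_ref))            # G = H
print(np.allclose(gqa(q, k[:1], v[:1]), mqa_ref))    # G = 1
for G in (8, 2, 1):
    print(f"G={G}: {2 * G * d_k} cached values per token per layer")
```

The cached-values line makes the tradeoff concrete: the cache scales linearly with G, so G sits on a dial between MQA's smallest cache and MHA's full one.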

Exercises

Exercise 1: Compute for "sat"

Using the $K_{\text{shared}}$ and $V_{\text{shared}}$ matrices from this chapter, compute the MQA attention weights and output for “sat” (row 2) for both heads. Verify that Head 1 gives $[0.1811,\; 0.1811,\; 0.3673,\; 0.0893,\; 0.1811]$ and Head 2 gives $[0.1237,\; 0.2509,\; 0.2509,\; 0.1237,\; 0.2509]$.

Exercise 2: Why "on" is uniform in Head 1

Explain mathematically why the attention weights for “on” in Head 1 are exactly uniform (all 0.2000). What property of $Q_{h1}[\text{on}]$ causes this? What would happen if $Q_{h1}[\text{on}] = [1, 0]$ instead?

Exercise 3: Memory at scale

For a model with $H = 96$ heads, $d_k = 128$, sequence length $N = 32{,}768$, and 96 layers using float16: compute the total KV-cache for MHA vs MQA. At what batch size does the MHA cache exceed 80 GB (the memory of a single 80 GB A100)?
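A small helper makes this arithmetic easy to check. It assumes the cache stores one K and one V vector of length d_k per KV head, per layer, per token, and it ignores framework overhead; `kv_cache_bytes` is a hypothetical name. It is shown on a hypothetical 32-layer, 32-head model so you can plug in the exercise's numbers yourself. Passing `bytes_per_elem=1` approximates an 8-bit quantized cache, the kind of orthogonal saving mentioned in the Key Takeaways.

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_k, seq_len, batch=1, bytes_per_elem=2):
    """KV-cache size in bytes: 2 (K and V) x layers x KV heads x d_k x tokens x batch x bytes."""
    return 2 * n_layers * n_kv_heads * d_k * seq_len * batch * bytes_per_elem

# Hypothetical 32-layer, 32-head model, d_k = 128, 4,096-token context, float16 (2 bytes).
cfg = dict(n_layers=32, d_k=128, seq_len=4096, bytes_per_elem=2)
mha = kv_cache_bytes(n_kv_heads=32, **cfg)   # MHA: one K/V head per query head
mqa = kv_cache_bytes(n_kv_heads=1, **cfg)    # MQA: a single shared K/V head

print(f"MHA: {mha / 2**30:.2f} GiB per sequence")   # 2.00 GiB
print(f"MQA: {mqa / 2**20:.0f} MiB per sequence")   # 64 MiB
print(f"reduction: {mha // mqa}x")                  # 32x
```

Note that GB (10^9 bytes) and GiB (2^30 bytes) differ by about 7%, which matters when comparing against an accelerator's advertised memory.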

Exercise 4: Converting MHA to MQA

Suppose you have a trained MHA model and want to convert it to MQA without retraining from scratch (Ainslie et al., 2023, study this kind of checkpoint conversion). Propose a strategy: which head's K,V weights should you keep as the shared set? What about averaging all heads' K,V weights? Would fine-tuning be needed?

Exercise 5: When MQA = MHA

Prove that when $H = 1$ (single head), MQA is identical to standard attention. Then show that MQA's output approaches MHA's output as the learned K,V projections converge across heads (i.e., when all heads' key and value projections become similar).
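For the second part, a numerical illustration (not a proof) can help build intuition. The sketch below assumes random projections and uses hypothetical helper names: it builds an MHA model whose per-head K/V projections equal the MQA projections plus a perturbation scaled by `eps`, then prints the largest output gap as `eps` shrinks. At `eps = 0` the two models coincide.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
N, d_model, H = 5, 8, 4
d_k = d_model // H
x = rng.normal(size=(N, d_model))
W_q = rng.normal(size=(H, d_model, d_k))          # per-head query projections (same in both models)
W_k_shared = rng.normal(size=(d_model, d_k))      # MQA's single key projection
W_v_shared = rng.normal(size=(d_model, d_k))      # MQA's single value projection
noise_k = rng.normal(size=(H, d_model, d_k))      # per-head deviations for the MHA model
noise_v = rng.normal(size=(H, d_model, d_k))

def multi_head(W_k, W_v):
    """Per-head attention using key/value projections W_k[h], W_v[h]; returns (N, H * d_k)."""
    heads = []
    for h in range(H):
        q, k, v = x @ W_q[h], x @ W_k[h], x @ W_v[h]
        heads.append(softmax(q @ k.T / np.sqrt(d_k)) @ v)
    return np.hstack(heads)

# MQA: every head reuses the same K/V projection.
mqa_out = multi_head([W_k_shared] * H, [W_v_shared] * H)

diffs = []
for eps in (1.0, 0.1, 0.01, 0.0):
    # MHA whose per-head K/V projections sit eps away from the shared ones.
    mha_out = multi_head(W_k_shared + eps * noise_k, W_v_shared + eps * noise_v)
    diffs.append(np.abs(mha_out - mqa_out).max())
    print(f"eps={eps:<5} max |MHA - MQA| = {diffs[-1]:.2e}")
```

Because softmax and matrix products are continuous, the gap shrinks with `eps`; the exercise asks you to make that argument precise.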


References

  1. Shazeer, N. (2019). “Fast Transformer Decoding: One Write-Head is All You Need.” arXiv:1911.02150.
  2. Chowdhery, A., et al. (2022). “PaLM: Scaling Language Modeling with Pathways.” arXiv:2204.02311.
  3. Ainslie, J., et al. (2023). “GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints.” arXiv:2305.13245.
  4. Vaswani, A., et al. (2017). “Attention Is All You Need.” Advances in Neural Information Processing Systems (NeurIPS), 30.
  5. Dao, T., et al. (2022). “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv:2205.14135.
  6. DeepSeek-AI (2024). “DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model.” arXiv:2405.04434.