Chapter 4

Grouped-Query and Multi-Query Attention

Multi-Head Latent Attention (MLA)

Where MHA Hurts at Inference

Section 1 left us with a sharp number. Take a model shaped like LLaMA-2 70B ($H = 64$ heads of width $d_k = 128$, $L = 80$ layers) and suppose it used vanilla MHA. Serving a single 4 096-token request with an FP16 cache, the KV cache alone occupies roughly

$$2 \cdot H \cdot T \cdot d_k \cdot L \cdot 2\,\text{B} \;=\; 2 \cdot 64 \cdot 4096 \cdot 128 \cdot 80 \cdot 2\,\text{B} \;\approx\; 10.7\,\text{GB}.$$

That is per request, before a batch of users shows up. On 80 GB H100s, with the model weights already eating ~140 GB across two GPUs, the cache is what evicts you from memory long before your math runs out of FLOPs.
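The arithmetic above is easy to script. A minimal sketch, assuming an FP16 cache (2 bytes per element) and the MHA-shaped configuration from the text; `mha_kv_cache_bytes` is a hypothetical helper, not part of any library:

```python
def mha_kv_cache_bytes(n_heads: int, seq_len: int, d_k: int, n_layers: int,
                       bytes_per_elem: int = 2) -> int:
    """KV-cache size of one request under vanilla MHA.

    The leading 2 counts the K slab and the V slab; every head, token,
    head dimension and layer contributes one element to each.
    """
    return 2 * n_heads * seq_len * d_k * n_layers * bytes_per_elem


# LLaMA-2-70B-shaped attention, hypothetically run as plain MHA, FP16 cache.
one_request = mha_kv_cache_bytes(n_heads=64, seq_len=4096, d_k=128, n_layers=80)
print(f"{one_request:,} bytes = {one_request / 1e9:.1f} GB")  # 10,737,418,240 bytes = 10.7 GB

# The cache is linear in sequence length: doubling T doubles the bytes.
print(mha_kv_cache_bytes(64, 8192, 128, 80) == 2 * one_request)  # True
```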

The structural cause is that vanilla Multi-Head Attention stores one independent $(K^h, V^h)$ pair per head. With $H = 64$ heads, the cache scales linearly in $H$. Worse, every autoregressive decoding step has to read the entire cache: for token $t$, attention must gather all $t - 1$ previous keys and values from HBM. Modern decoding is memory-bandwidth-bound, not compute-bound: the GPU spends its time moving bytes, not multiplying numbers.

The bottleneck reframed. At training time the cost of MHA is its FLOPs. At inference time the cost of MHA is its cache size and its cache read bandwidth. Cutting those two numbers is what every attention variant in this chapter is really fighting for.
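One way to see why decoding is bandwidth-bound is to count FLOPs per byte of cache read in a single decode step. The sketch below is a back-of-the-envelope model under stated assumptions: FP16 cache, only the two attention matmuls counted (softmax, weight reads and query reads ignored), and a hypothetical helper name. It also previews this chapter's lever: sharing K/V across $H/G$ heads raises the ratio by exactly $H/G$.

```python
def decode_attention_intensity(n_heads: int, n_kv_groups: int, t: int, d_k: int,
                               bytes_per_elem: int = 2) -> float:
    """FLOPs per byte of KV cache read, for one new token in one layer.

    Counts only Q.K^T and P.V over t cached positions (2 FLOPs per multiply-add).
    """
    flops = 2 * n_heads * t * d_k      # one d_k-wide dot product per head per cached key
    flops += 2 * n_heads * t * d_k     # weighted sum of the cached values, per head
    bytes_read = 2 * n_kv_groups * t * d_k * bytes_per_elem  # K slab + V slab
    return flops / bytes_read


for G in (64, 8, 1):   # MHA, GQA-8, MQA for a 64-head model
    ratio = decode_attention_intensity(64, G, 4096, 128)
    print(f"G = {G:2d}: {ratio:.0f} FLOP/byte")   # 1, 8 and 64 FLOP/byte
```

Data-centre GPUs can sustain on the order of hundreds of FLOPs per byte of HBM traffic, so even the MQA figure leaves the attention core memory-bound; what shrinking $G$ buys is fewer bytes to move.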

Grouped-Query Attention (GQA) and its limiting case Multi-Query Attention (MQA) attack this with a single, almost embarrassingly simple idea. The next section gives the intuition; the rest of the page makes it rigorous.


Intuition: One Reader, Many Lenses

Picture each attention head as a person reading the same library of books. In standard MHA, every reader brings their own private library: 64 readers, 64 libraries, all stored on the GPU. They all read the same text on the page (the token's residual stream), but each builds their own index of what to remember ($K$) and what to retrieve ($V$).

That feels wasteful. In a really long document, do 64 readers truly need 64 different memories? GQA's wager is no: most of the variation between heads lives on the query side — the lens through which a head looks at the past — not in what they want to remember about the past. So we keep all 64 query lenses, but share the library across small groups of readers.

Multi-Query Attention pushes this to the extreme: 64 readers, one library. Every head queries the same $(K, V)$ with its own $W_Q^h$. The cache shrinks by $64\times$. The question, which we'll answer numerically further down, is how much quality you pay for that compression.

Keep two pictures in your head. Queries are what you ask this token to look for; queries are recomputed from scratch every step, so they cost FLOPs but no cache. Keys and values are what you wrote down about past tokens; they live in the cache for the rest of the sequence, so every byte you spend on K/V is paid back at every future decoding step.
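A toy decode loop makes the two pictures concrete. This is a minimal NumPy sketch with made-up sizes (4 query heads sharing 2 KV groups, head width 2), anticipating the grouped K/V layout formalised in the next section: the new token's query is computed, used once and discarded, while its key and value are appended to a cache that every later step reads in full.

```python
import numpy as np

# Toy sizes, chosen for illustration only.
d_model, d_k, H, G = 4, 2, 4, 2
rng = np.random.default_rng(0)
W_Q = rng.standard_normal((H, d_model, d_k))   # one query lens per head
W_K = rng.standard_normal((G, d_model, d_k))   # one key projection per group
W_V = rng.standard_normal((G, d_model, d_k))   # one value projection per group
head_to_group = np.arange(H) // (H // G)       # [0, 0, 1, 1]

k_cache = np.zeros((G, 0, d_k))   # (groups, cached positions, d_k)
v_cache = np.zeros((G, 0, d_k))

for step in range(5):
    x_t = rng.standard_normal(d_model)             # residual stream of the new token
    q_t = np.einsum("d,hdk->hk", x_t, W_Q)         # computed now, never cached
    k_t = np.einsum("d,gdk->gk", x_t, W_K)
    v_t = np.einsum("d,gdk->gk", x_t, W_V)
    # This token's K and V stay in the cache for the rest of the sequence.
    k_cache = np.concatenate([k_cache, k_t[:, None, :]], axis=1)
    v_cache = np.concatenate([v_cache, v_t[:, None, :]], axis=1)
    # Every head reads ALL cached positions of its group (causal by construction).
    out = np.empty((H, d_k))
    for h in range(H):
        g = head_to_group[h]
        scores = k_cache[g] @ q_t[h] / np.sqrt(d_k)    # one score per cached position
        p = np.exp(scores - scores.max())
        p /= p.sum()
        out[h] = p @ v_cache[g]
    print(f"step {step}: {k_cache.size + v_cache.size} cached floats")  # 8, 16, 24, 32, 40
```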

The Mathematical Idea

Let $x_t \in \mathbb{R}^{d_\text{model}}$ be the residual-stream vector for token $t$. Standard MHA defines, for each head $h = 1, \dots, H$:

$$Q^h_t = W_Q^h\, x_t, \qquad K^h_t = W_K^h\, x_t, \qquad V^h_t = W_V^h\, x_t$$

with $W_Q^h, W_K^h, W_V^h \in \mathbb{R}^{d_k \times d_\text{model}}$ and the attention output of head $h$ at position $t$ being

$$O^h_t \;=\; \mathrm{softmax}\!\left(\frac{Q^h_t\, (K^h_{\le t})^{\top}}{\sqrt{d_k}}\right)\, V^h_{\le t}$$

The cache size, per layer, per token, is $2 \cdot H \cdot d_k$: two slabs (K and V), $H$ heads each of width $d_k$.

Now partition the $H$ query heads into $G$ non-overlapping groups of equal size $H / G$. Write $g(h) \in \{1, \dots, G\}$ for the group of head $h$. GQA replaces the per-head K and V projections with per-group ones:

$$K^{g}_t = W_K^{g}\, x_t, \qquad V^{g}_t = W_V^{g}\, x_t \qquad \text{for } g = 1, \dots, G$$

Head $h$ then attends to its group's keys and values:

$$O^h_t \;=\; \mathrm{softmax}\!\left(\frac{Q^h_t\, (K^{g(h)}_{\le t})^{\top}}{\sqrt{d_k}}\right)\, V^{g(h)}_{\le t}$$

Choosing $G$ recovers both endpoints of the spectrum and everything in between:

  • $G = H$ (every group has exactly one head) recovers vanilla MHA.
  • $G = 1$ (one big group containing every head) is MQA.
  • $1 < G < H$ with $G \mid H$ (the intermediate regime) is GQA.

The cache size becomes $2 \cdot G \cdot d_k$ per token per layer, i.e. a factor of $H / G$ smaller than MHA. The query projection is untouched, so the total parameter count drops only slightly, but the runtime cache, the thing that actually pins your serving cost, drops by exactly that factor.

Why each symbol matters. $x_t$ is what came out of the previous block at position $t$; $W_Q^h$ is one of $H$ learned lenses for asking questions; $W_K^{g}, W_V^{g}$ are $G \le H$ learned indexers for building shared memory; $g(h)$ is the fixed (not learned) routing of head $h$ to its group; and the cache size formula $2 \cdot G \cdot d_k$ is the bottom line your serving infrastructure pays.

MHA, MQA, and the Spectrum Between

The cleanest way to internalise GQA is to see how it slides between the two endpoints. Map each of the $H$ query heads to a KV group: with $G = H$ every head has a group of its own (MHA); as $G$ shrinks, more heads share each group; at $G = 1$ every head reads the same group (MQA).
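The mapping itself is a one-liner. A small sketch, assuming the contiguous convention (0-indexed heads $0 \ldots H/G - 1$ in group 0, the next $H/G$ in group 1, and so on) that the code on this page also uses; `head_to_group` is a hypothetical helper name:

```python
def head_to_group(n_heads: int, n_kv_groups: int) -> list[int]:
    """0-based group index g(h) for each query head h, using contiguous groups."""
    if n_heads % n_kv_groups != 0:
        raise ValueError("the number of groups must divide the number of heads")
    heads_per_group = n_heads // n_kv_groups
    return [h // heads_per_group for h in range(n_heads)]


H = 8
for G in (8, 4, 2, 1):
    print(f"G = {G}: {head_to_group(H, G)}")
# G = 8: [0, 1, 2, 3, 4, 5, 6, 7]   <- MHA
# G = 4: [0, 0, 1, 1, 2, 2, 3, 3]
# G = 2: [0, 0, 0, 0, 1, 1, 1, 1]
# G = 1: [0, 0, 0, 0, 0, 0, 0, 0]   <- MQA
```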


Production models live somewhere in the middle of this range. LLaMA-2 7B uses $H = 32,\; G = 32$ (still pure MHA); LLaMA-2 70B ($H = 64$) and LLaMA-3 8B ($H = 32$) both use $G = 8$, an 8× and a 4× cache shrink respectively, while staying close to MHA quality, as the GQA paper showed empirically. Mistral 7B follows the same recipe ($H = 32$, $G = 8$).


Side by side, the two architectures differ in one place: MHA keeps a separate $K, V$ pair for every head, while MQA collapses them into a single shared pair that every query head reads from.


Manual Numerical Walkthrough

We'll work the smallest example that still distinguishes the three mechanisms: $H = 4$ query heads, $d_k = 2$, $T = 3$ tokens. We'll compute the keys for one token under each scheme and watch the cache size change.

Setup. Pick a single token vector $x_1 = (1,\, 0,\, -1,\, 2)$ with $d_\text{model} = 4$.
Step 1: MHA keys (4 heads, 4 weight matrices). Each $W_K^h \in \mathbb{R}^{2 \times 4}$ is a different matrix. Take, for instance:
$$W_K^1 = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix}, \quad W_K^2 = \begin{bmatrix} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix}$$
$$W_K^3 = \begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 1 \end{bmatrix}, \quad W_K^4 = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \end{bmatrix}$$
Then $K^1_1 = (1, 0)$, $K^2_1 = (-1, 2)$, $K^3_1 = (1, 1)$, $K^4_1 = (0, 2)$.
Cache cost for this one token: $2 \cdot H \cdot d_k = 2 \cdot 4 \cdot 2 = 16$ floats (K and V together).
Step 2: GQA with G = 2. We now keep only two K weight matrices. Reuse $W_K^1$ and $W_K^2$ from above, and bind heads $\{1, 2\}$ to group 1 and heads $\{3, 4\}$ to group 2.
$K^{g=1}_1 = W_K^1 x_1 = (1, 0)$, $K^{g=2}_1 = W_K^2 x_1 = (-1, 2)$.
Heads 1 and 2 will both score their queries against $(1, 0)$; heads 3 and 4 will both score against $(-1, 2)$. The query lenses are still distinct, but the "memory" is shared.
Cache cost: $2 \cdot G \cdot d_k = 2 \cdot 2 \cdot 2 = 8$ floats, exactly half.
Step 3: MQA (G = 1). Keep only $W_K^1$. Every head reads $K^{g=1}_1 = (1, 0)$.
Cache cost: $2 \cdot 1 \cdot 2 = 4$ floats, a quarter of MHA, i.e. an $H = 4$× compression.
Step 4: What the queries see. Pick query head $h = 1$ with $Q^1_1 = (1, 1)$. Under MHA it dot-products against $K^1_{\le 1}$; under GQA-2 against $K^{g=1}_{\le 1}$; under MQA against the single shared $K_{\le 1}$. Because this example reuses $W_K^1$ for group 1 and for MQA, all three keys happen to equal $(1, 0)$, so the scaled score is $1/\sqrt{2} \approx 0.707$ in every case. The computation at the attention head is identical in shape (a Q × K dot-product divided by $\sqrt{d_k}$); what differs is which K matrix gets fetched from HBM.
Sanity check. Multiply the per-token costs by $T = 3$ (with $H = 4$, $G = 2$, $d_k = 2$) and you get 48, 24 and 12 floats for MHA, GQA-2 and MQA, which is what the Python script below prints. The ratio always equals $H : G : 1$.
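The by-hand numbers can be checked in a few lines of plain Python. This sketch hard-codes the matrices from Step 1; the routing dictionary and the choice to give head 3 the same query $(1, 1)$ as head 1 are illustrative assumptions, added only so the three schemes' keys for head 3 can be compared:

```python
import math


def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector, in plain Python."""
    return tuple(sum(w * xi for w, xi in zip(row, x)) for row in W)


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


d_k = 2
x1 = (1, 0, -1, 2)
W_K = {  # the four per-head key matrices from Step 1
    1: [[1, 0, 0, 0], [0, 1, 0, 0]],
    2: [[0, 0, 1, 0], [0, 0, 0, 1]],
    3: [[1, 1, 0, 0], [0, 0, 1, 1]],
    4: [[1, 0, 1, 0], [0, 1, 0, 1]],
}

# Which key each head reads under each scheme.
mha = {h: matvec(W_K[h], x1) for h in (1, 2, 3, 4)}            # own matrix per head
group_of = {1: 1, 2: 1, 3: 2, 4: 2}                             # GQA-2 routing g(h)
gqa = {h: matvec(W_K[group_of[h]], x1) for h in (1, 2, 3, 4)}   # W_K^1, W_K^2 as group matrices
mqa = {h: matvec(W_K[1], x1) for h in (1, 2, 3, 4)}             # single shared W_K^1

print("MHA keys:", mha)   # {1: (1, 0), 2: (-1, 2), 3: (1, 1), 4: (0, 2)}
print("GQA keys:", gqa)   # {1: (1, 0), 2: (1, 0), 3: (-1, 2), 4: (-1, 2)}
print("MQA keys:", mqa)   # every head reads (1, 0)

# Cache cost per token (K and V together): 2 * G * d_k floats -> 16, 8, 4.
for name, G in (("MHA", 4), ("GQA-2", 2), ("MQA", 1)):
    print(f"{name}: {2 * G * d_k} floats per token")

# Scaled scores for query (1, 1) on heads 1 and 3.
q = (1, 1)
for name, keys in (("MHA", mha), ("GQA-2", gqa), ("MQA", mqa)):
    s1 = dot(q, keys[1]) / math.sqrt(d_k)
    s3 = dot(q, keys[3]) / math.sqrt(d_k)
    print(f"{name}: head 1 score {s1:.3f}, head 3 score {s3:.3f}")
```

Head 3 is where the schemes diverge: its score is $2/\sqrt{2} \approx 1.414$ under MHA but $1/\sqrt{2} \approx 0.707$ under GQA-2 and MQA, because it now reads a shared key instead of its own.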

The KV-Cache Budget, Made Concrete

Toy numbers are convincing only up to a point. Let's plug the formula $\text{cache} = 2 \cdot G \cdot T \cdot d_k \cdot L \cdot \text{bytes}$ into real configurations and see what serving infrastructure actually buys with each $G$.
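A sketch of the same formula swept over $G$, assuming the LLaMA-2-70B-shaped attention used earlier ($H = 64$, $d_k = 128$, $L = 80$), an FP16 cache and a single request; `kv_cache_bytes` is a hypothetical helper:

```python
def kv_cache_bytes(n_kv_groups, seq_len, d_k, n_layers, bytes_per_elem=2):
    """cache = 2 * G * T * d_k * L * bytes, where the 2 counts K and V."""
    return 2 * n_kv_groups * seq_len * d_k * n_layers * bytes_per_elem


H, d_k, L = 64, 128, 80          # LLaMA-2-70B-shaped attention
for T in (4096, 32768):
    mha = kv_cache_bytes(H, T, d_k, L)
    for G in (64, 8, 1):         # MHA, GQA-8, MQA
        b = kv_cache_bytes(G, T, d_k, L)
        print(f"T={T:>6}  G={G:>2}  {b / 1e9:6.2f} GB  ratio to MHA = {b / mha:.4f}")
```

At $T = 4096$ this gives about 10.74 GB for MHA, 1.34 GB for GQA-8 and 0.17 GB for MQA; at $T = 32768$ every figure is eight times larger, while the ratio to MHA stays $G/H$.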

Across models and sequence lengths the pattern is the same: the GQA cache is exactly $G / H$ of the MHA cache, so the saving is $1 - G/H$.


Looking only at the two endpoints, the gap widens with the head count: switching to MQA, a 12-head GPT-2 saves 11/12 of its cache; a 64-head LLaMA-2 70B saves 63/64. MQA's win is greatest precisely where MHA's pain is worst.

The KV cache is a tax on context. Doubling context length doubles cache memory and doubles the bytes fetched per decoding step. Compression schemes like GQA give you a fixed multiplicative discount on that tax; useful, but they don't change the linear-in-T growth. MLA, covered in §4.3, goes after the per-token cost itself: its cache still grows linearly in T, but each token costs far fewer bytes.

What Do We Lose? The Quality Trade-off

Compressing K/V cannot be free. Three things actually happen when $G$ drops:

  1. Less head diversity in memory. Two heads in the same group see the same $(K, V)$ and can only differentiate their behaviour through their $Q$. The model has fewer degrees of freedom for specialising what to remember.
  2. Sharper trade-off at MQA. Shazeer's 2019 MQA paper reports a small but measurable quality drop at $G = 1$, and the 2023 GQA paper finds that MQA models uptrained from MHA checkpoints still lag MHA. With $G = 8$, the GQA paper reports quality close to MHA, which is why most recent open LLMs picked GQA over MQA.
  3. Training-time effect. Models trained from scratch with GQA from token 0 lose almost nothing. Models retrofitted by averaging per-head K/V weights into per-group weights and fine-tuning lose more, though most of the gap can be recovered cheaply. DeepSeek V3 trains its attention scheme (MLA) end-to-end from scratch.
Model | H (query heads) | G (KV groups) | Cache shrink | Notes
GPT-3 175B (2020) | 96 | 96 | 1× (MHA) | MHA, pre-GQA era
LLaMA-2 7B | 32 | 32 | 1× (MHA) | Small model, cache not yet painful
LLaMA-2 70B | 64 | 8 | 8× | First widely-deployed GQA
LLaMA-3 8B / 70B | 32 / 64 | 8 / 8 | 4× / 8× | GQA-8 became the de facto default
Mistral 7B | 32 | 8 | 4× | GQA + sliding window attention
PaLM (Google, 2022) | varies by model size | 1 | H× (MQA) | Pure MQA at 540B scale
DeepSeek V3 | 128 | n/a (MLA) | see §4.3 | Switches to MLA (Chapter 4's main act)

Plain Python: Three Mechanisms, Side by Side

Before we touch PyTorch, here is the bare math in NumPy. The point is to see exactly which tensor shrinks when we move from MHA → GQA → MQA, and where the routing table $g(h)$ is encoded.

gqa_mqa_demo.py: MHA, GQA-2 and MQA in plain NumPy
Toy shapes, chosen so we can read them out loud

We use T=3 tokens, d_model=4 features per token, d_k=2 features per head, and H=4 query heads. Every following shape is built from these four numbers, so when you see (4, 3, 2) you can immediately read it as (head, token, head-dim).

Example: x.shape = (3, 4), W_Q.shape = (4, 4, 2)

Input activations x

One row per token. In a real transformer this is the output of the previous block (or the embedding layer for layer 0). Three tokens × four features.

Example (illustrative values, not the exact seed-0 draw):
x = [[ 0.13, -1.32,  0.65,  0.50],
     [-0.96, -0.13,  1.41,  0.95],
     [-0.21,  0.49,  0.46, -0.69]]

Query weights: always per-head

The number of query weight matrices never changes when we move from MHA → GQA → MQA. Each query head keeps its own d_model × d_k projection. Only K and V shrink.

Example: W_Q[h] : d_model × d_k = 4 × 2

project(): the K/V factory

Given a single weight tensor of shape (G, d_model, d_k), we contract over d_model with the input x to produce one (T, d_k) key (or value) slab per group. Setting G = H gives MHA; G = 1 gives MQA; anything in between is GQA.

Example: einsum('td,gdk->gtk', x, W) → shape (G, T, d_k)

attention(): group-aware lookup

The heart of GQA/MQA: each query head h does its own scaled dot-product softmax, but it reads from K[g] and V[g] where g = head_to_group[h]. Many heads can share the same (K, V) slab; that is the entire memory saving.

Example: head_to_group for GQA-2 with 4 heads: [0, 0, 1, 1]

Scaled dot-product per head

Q[h] @ K[g].T gives a (T, T) score matrix. Dividing by √d_k controls the variance so softmax does not saturate. This step is identical to vanilla attention; GQA changes only WHICH K and V we plug in.

Example: Q[h].shape (3, 2)  @  K[g].T (2, 3)  →  (3, 3)

Softmax, written by hand

Numerically stable softmax: subtract the row max, exponentiate, then divide by the row sum. Exactly the same routine you would use in MHA.

MHA: one K and one V per head

Here W_K_mha has shape (H, d_model, d_k) = (4, 4, 2). After project() we get K_m of shape (4, 3, 2): four independent key slabs, one per head. head_to_group = arange(H) means head h reads from group h.

Example: K_m.shape = (4, 3, 2)  → 24 floats of K cache (48 counting V)

GQA-2: half the K, V weight matrices

W_K_gqa has only G=2 slices. Heads [0, 1] share group 0; heads [2, 3] share group 1. The query side is unchanged, so the model still has 4 attention heads; they just SEE through a coarser K/V lens.

Example: K_g.shape = (2, 3, 2)  → 12 floats of K cache (50% smaller)

MQA: a single K and a single V for ALL heads

G = 1. Now head_to_group is all zeros: every query head reads from the same (K, V). This is the most aggressive compression in this family: the KV cache shrinks by a factor of H.

Example: K_q.shape = (1, 3, 2)  → 6 floats of K cache (75% smaller)

Cache cost in floats

The factor 2 covers both K and V, and cache_floats already includes sequence length and head dim; multiply further by layer count, batch size, and bytes per float to get total bytes. In a real 70B-class model the numbers are vastly larger: the 4 096-token MHA example at the top of the page needs about 5.4 billion floats, versus 48 here.

Example: MHA: 2 · 4 · 3 · 2 = 48  |  MQA: 2 · 1 · 3 · 2 = 12
"""
Three attention KV-projection schemes, side by side, in pure NumPy.

We deliberately keep every tensor tiny so you can read shapes out loud:

    sequence length  T = 3       (three tokens)
    model dim        d_model = 4
    head dim         d_k = 2     (so 4 / 2 = 2 head-slots in d_model)
    query heads      H = 4
    KV groups        G in {4, 2, 1}   (MHA, GQA-2, MQA)
"""

import numpy as np

T, d_model, d_k, H = 3, 4, 2, 4
rng = np.random.default_rng(0)

# Input activations: one row per token.
x = rng.standard_normal((T, d_model))         # shape (T, d_model) = (3, 4)

# Query weights: always H of them, one per query head.
W_Q = rng.standard_normal((H, d_model, d_k))  # shape (H, d_model, d_k)

def project(x, W_KV_G):
    """
    Project x into G group-level K (and V) matrices.

    W_KV_G has shape (G, d_model, d_k).
    Returns a tensor of shape (G, T, d_k): one (T, d_k) slab per group.
    """
    return np.einsum("td,gdk->gtk", x, W_KV_G)

def attention(Q, K, V, head_to_group):
    """
    Q : (H, T, d_k)          per-head queries
    K : (G, T, d_k)          per-GROUP keys
    V : (G, T, d_k)          per-GROUP values
    head_to_group : (H,)     which group each query head reads from

    Each query head h looks up K[head_to_group[h]] and V[head_to_group[h]].
    No causal mask is applied (every token sees all T tokens), which keeps
    the demo short; the PyTorch version later on is causal.
    """
    out = np.zeros_like(Q)                              # (H, T, d_k)
    for h in range(Q.shape[0]):
        g = head_to_group[h]
        scores = Q[h] @ K[g].T / np.sqrt(d_k)           # (T, T)
        probs  = np.exp(scores - scores.max(-1, keepdims=True))  # stable softmax
        probs /= probs.sum(-1, keepdims=True)
        out[h] = probs @ V[g]                           # (T, d_k)
    return out

# ---------------- MHA: G = H = 4 ----------------
W_K_mha = rng.standard_normal((H, d_model, d_k))
W_V_mha = rng.standard_normal((H, d_model, d_k))
Q   = np.einsum("td,hdk->htk", x, W_Q)
K_m = project(x, W_K_mha)            # (4, 3, 2)
V_m = project(x, W_V_mha)            # (4, 3, 2)
mha_out = attention(Q, K_m, V_m, head_to_group=np.arange(H))

# ---------------- GQA: G = 2 (groups of 2 heads share K, V) ----------------
G = 2
W_K_gqa = rng.standard_normal((G, d_model, d_k))
W_V_gqa = rng.standard_normal((G, d_model, d_k))
K_g = project(x, W_K_gqa)            # (2, 3, 2)   <-- HALF the cache
V_g = project(x, W_V_gqa)            # (2, 3, 2)
gqa_out = attention(Q, K_g, V_g, head_to_group=np.array([0, 0, 1, 1]))

# ---------------- MQA: G = 1 (every head shares ONE K, V) ----------------
W_K_mqa = rng.standard_normal((1, d_model, d_k))
W_V_mqa = rng.standard_normal((1, d_model, d_k))
K_q = project(x, W_K_mqa)            # (1, 3, 2)   <-- 1/H the cache
V_q = project(x, W_V_mqa)            # (1, 3, 2)
mqa_out = attention(Q, K_q, V_q, head_to_group=np.zeros(H, dtype=int))

# Cache cost in floats (no batch; in FP16 each float takes 2 bytes).
def cache_floats(G_, T_, d_k_):
    return 2 * G_ * T_ * d_k_        # factor 2 for K and V

# The formula matches the K and V tensors actually built above.
assert cache_floats(4, T, d_k) == K_m.size + V_m.size
assert cache_floats(2, T, d_k) == K_g.size + V_g.size
assert cache_floats(1, T, d_k) == K_q.size + V_q.size

print("MHA  cache floats:", cache_floats(4, T, d_k))   # 48
print("GQA2 cache floats:", cache_floats(2, T, d_k))   # 24  (50% saving)
print("MQA  cache floats:", cache_floats(1, T, d_k))   # 12  (75% saving)
Run the file and you will see exactly 48 / 24 / 12 floats of cache for MHA / GQA-2 / MQA. The same attention() routine serves all three schemes, and once the right K and V are looked up the math is identical; only the number of K/V slabs, and therefore the cache, differs.
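The demo's attention() loops over heads and indexes K[g]. The same result can be computed without building per-head copies of K and V: reshape the $H$ query heads into $(G, H/G)$ and let np.einsum pair each block of heads with its group's K and V. This is a NumPy sketch of the idea that GQA-aware kernels exploit, not a kernel; it assumes the contiguous head-to-group layout used above, and attention_grouped is a hypothetical name.

```python
import numpy as np


def attention_loop(Q, K, V, head_to_group):
    """Reference: per-head loop in which each head reads its group's K and V."""
    d_k = Q.shape[-1]
    out = np.zeros_like(Q)
    for h in range(Q.shape[0]):
        g = head_to_group[h]
        s = Q[h] @ K[g].T / np.sqrt(d_k)
        p = np.exp(s - s.max(-1, keepdims=True))
        p /= p.sum(-1, keepdims=True)
        out[h] = p @ V[g]
    return out


def attention_grouped(Q, K, V):
    """Same result without copying K/V: fold heads into (group, head-in-group).

    Assumes contiguous grouping, i.e. head h belongs to group h // (H // G).
    """
    H, T, d_k = Q.shape
    G = K.shape[0]
    Qg = Q.reshape(G, H // G, T, d_k)                       # a view for contiguous Q
    s = np.einsum("grtk,gsk->grts", Qg, K) / np.sqrt(d_k)   # each group's K serves r heads
    p = np.exp(s - s.max(-1, keepdims=True))
    p /= p.sum(-1, keepdims=True)
    out = np.einsum("grts,gsk->grtk", p, V)
    return out.reshape(H, T, d_k)


rng = np.random.default_rng(0)
H, G, T, d_k = 4, 2, 3, 2
Q = rng.standard_normal((H, T, d_k))
K = rng.standard_normal((G, T, d_k))
V = rng.standard_normal((G, T, d_k))
ref = attention_loop(Q, K, V, np.arange(H) // (H // G))
print(np.allclose(ref, attention_grouped(Q, K, V)))   # True
```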

PyTorch: GQA via repeat_interleave

Production GQA layers don't loop over heads; they let FlashAttention or scaled_dot_product_attention do the work. The simplest trick is to store $K, V$ at group granularity but expand them to head granularity at the last possible moment with repeat_interleave. That keeps the cache small while letting the attention kernel keep its $(B, H, T, d_k)$ contract.

gqa_pytorch.py: a production-shaped GroupedQueryAttention module
Constructor signature

Three knobs uniquely identify a GQA layer: model width, number of query heads, and number of KV groups. Setting n_kv_groups = n_heads degenerates to MHA; setting n_kv_groups = 1 degenerates to MQA.

Example: GroupedQueryAttention(d_model=4096, n_heads=32, n_kv_groups=8)  # LLaMA-3 8B shapes

Divisibility check

Heads-per-group must be an integer. If it isn't, repeat_interleave won't line up cleanly and some heads would end up bound to a partial group. Crash early at construction time, not deep inside the forward pass.

reps = heads per group

This is the replication factor. With n_heads=8 and n_kv_groups=2, each cached K/V slab gets duplicated 4 times so the kernel sees 8 head-aligned slabs.

Example: reps = 4 means each K_g feeds 4 query heads

W_q stays full width

The query projection still produces n_heads * d_k features per token, exactly like MHA. GQA saves nothing on the query side, only on the K/V side. This is by design: queries are recomputed every step, while K and V are *cached*.

Example: Linear(4096, 32*128 = 4096)

W_k is narrower than W_q

Here is the entire memory win. W_k projects into n_kv_groups * d_k instead of n_heads * d_k. With 32 query heads and 8 groups, that's 4× fewer key features per token, and 4× less cache memory at inference.

Example: Linear(4096, 8*128 = 1024)  vs MHA's Linear(4096, 4096)

Project, reshape, then transpose

After each Linear, .view splits the last axis into (heads-or-groups, d_k), and .transpose(1, 2) moves that axis to position 1. Afterwards Q lives at (B, H, T, d_k) while K and V live at (B, G, T, d_k); the asymmetry is fully visible in the shapes.

Example (n_heads=8, n_kv_groups=2, d_k=128): Q.shape = (B, 8, T, 128)   K.shape = (B, 2, T, 128)

repeat_interleave: the GQA trick in one line

We duplicate each group's K (and V) along the head axis reps times. After this line K has shape (B, n_heads, T, d_k), so the kernel can pretend it is just doing plain multi-head attention. The cache stays group-sized; the compute stays head-sized.

Example: K_g of shape (B, 2, T, 128)  →  K of shape (B, 8, T, 128)

Why repeat_interleave and not .repeat()

repeat_interleave groups copies together: [g0, g0, g0, g0, g1, g1, g1, g1]. That matches head_to_group = [0,0,0,0,1,1,1,1]. .repeat() would have given [g0, g1, g0, g1, …], the wrong mapping.

Plain SDPA: no GQA awareness needed

Once K and V are expanded, scaled_dot_product_attention is a black box: it can dispatch to a FlashAttention kernel, the math-only fallback, or any future backend. The same expand-then-attend pattern appears in reference LLaMA-style code, for example the repeat_kv helper in HuggingFace's LLaMA implementation.

Example: F.scaled_dot_product_attention(Q, K, V, is_causal=True)

Merge heads and output projection

Transpose heads and tokens back so the last two axes are (heads, d_k); call .contiguous() so .view can flatten them into one axis of width n_heads * d_k (which equals d_model here). Then W_o projects back to d_model. Identical to plain MHA.

Example: (B, 8, T, 128) → (B, T, 8, 128) → (B, T, 1024) → (B, T, d_model)

Parameter count tells the story

In the demo at the bottom (d_model = 64, n_heads = 8, n_kv_groups = 2, so d_k = 8), W_q has 64 × 64 = 4096 parameters while W_k and W_v each have only 64 × 16 = 1024. That 4× ratio is exactly n_heads / n_kv_groups, the same factor by which the KV cache shrinks at inference.
"""
GQA in PyTorch: production-flavoured implementation.

Key idea: store K and V at GROUP granularity in the cache, then
*replicate* them up to head granularity right before the matmul. This
keeps the cache small but lets the SDPA kernel keep its (B, H, T, d_k)
shape contract.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, n_kv_groups: int):
        super().__init__()
        assert d_model % n_heads == 0, \
            "d_model must be divisible by n_heads"
        assert n_heads % n_kv_groups == 0, \
            "n_heads must be divisible by n_kv_groups"

        self.n_heads     = n_heads
        self.n_kv_groups = n_kv_groups
        self.d_k         = d_model // n_heads
        self.reps        = n_heads // n_kv_groups   # heads per group

        # Notice the asymmetry: Q is n_heads * d_k wide,
        # but K and V are only n_kv_groups * d_k wide.
        self.W_q = nn.Linear(d_model, n_heads     * self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, n_kv_groups * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_groups * self.d_k, bias=False)
        self.W_o = nn.Linear(n_heads * self.d_k, d_model,     bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape

        # 1. Project. Note the different fan-outs.
        Q = self.W_q(x).view(B, T, self.n_heads,     self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_kv_groups, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_kv_groups, self.d_k).transpose(1, 2)
        # Q : (B, n_heads,     T, d_k)
        # K : (B, n_kv_groups, T, d_k)   <-- what a KV cache would store
        # V : (B, n_kv_groups, T, d_k)

        # 2. Replicate K, V along the head axis so the kernel
        #    sees a clean (B, n_heads, T, d_k) on both sides.
        #    Copies are contiguous: head h reads group h // reps.
        K = K.repeat_interleave(self.reps, dim=1)
        V = V.repeat_interleave(self.reps, dim=1)

        # 3. Standard scaled dot-product attention: no GQA awareness needed
        #    once K and V have been expanded.
        out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

        # 4. Merge heads and project back to d_model.
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out)


if __name__ == "__main__":
    B, T, d_model = 2, 5, 64
    layer = GroupedQueryAttention(d_model=d_model, n_heads=8, n_kv_groups=2)
    x = torch.randn(B, T, d_model)
    y = layer(x)
    print("output:", y.shape)            # torch.Size([2, 5, 64])

    # Confirm the cache asymmetry by counting parameters.
    p_q = layer.W_q.weight.numel()
    p_k = layer.W_k.weight.numel()
    p_v = layer.W_v.weight.numel()
    print(f"W_q params: {p_q}, W_k params: {p_k}, W_v params: {p_v}")
    # W_q : 64 * 64  = 4096
    # W_k : 64 * 16  = 1024   <-- 4x smaller than W_q
    # W_v : 64 * 16  = 1024
What autograd sees. repeat_interleave materialises a copy, and its backward pass sums the gradients of the $H/G$ copies back into each group's $(K, V)$. So every group's K/V weights are trained by the accumulated gradient from $H/G$ query heads: they really do learn a consensus memory for their group.
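The gradient claim is easy to verify numerically. A small NumPy sketch with arbitrary toy sizes: np.repeat along axis 0 stands in for repeat_interleave (it produces the same contiguous layout), the analytic backward pass is a per-group sum, and a central finite difference checks it.

```python
import numpy as np

rng = np.random.default_rng(0)
G, reps, d = 2, 3, 4                       # toy sizes: 2 groups, 3 query heads per group
K_group = rng.standard_normal((G, d))      # group-granular keys
dY = rng.standard_normal((G * reps, d))    # upstream gradient on each head's copy


def expand(K):
    """Head-granular copy: each group row repeated `reps` times, contiguously."""
    return np.repeat(K, reps, axis=0)


def loss(K):
    """A scalar whose gradient w.r.t. the expanded copies is exactly dY."""
    return float(np.sum(dY * expand(K)))


# Analytic backward of the replication: sum the copies' gradients per group.
grad_analytic = dY.reshape(G, reps, d).sum(axis=1)

# Central finite differences on every entry of K_group.
eps = 1e-6
grad_numeric = np.zeros_like(K_group)
for g in range(G):
    for j in range(d):
        K_plus = K_group.copy()
        K_plus[g, j] += eps
        K_minus = K_group.copy()
        K_minus[g, j] -= eps
        grad_numeric[g, j] = (loss(K_plus) - loss(K_minus)) / (2 * eps)

print(np.allclose(grad_analytic, grad_numeric, atol=1e-6))  # True
```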

What Changes at 70B Scale

At toy size the cache numbers above are amusing curiosities. At 70B parameters they decide whether your serving stack is even runnable.

  1. HBM budget per request. LLaMA-2 70B with MHA would burn ~10.7 GB of KV cache per 4 096-token request; with GQA-8 that drops to ~1.3 GB. With roughly 20 GB of HBM left for cache on a pair of 80 GB H100s (after ~140 GB of FP16 weights), that is the difference between fitting a single such request and fitting over a dozen: roughly 8× the concurrency on identical hardware.
  2. Bytes fetched per decoding step. For each generated token, attention must read the entire cache from HBM. Cutting the cache by 8× cuts those reads by 8×. Since decoding is memory-bandwidth-bound, the attention share of per-step latency drops nearly in proportion; the weight reads, which GQA does not touch, remain.
  3. Tensor-parallel sharding. When the model is sharded across $P$ GPUs along the head axis, each GPU stores $G/P$ groups of K, V (instead of $H/P$). For LLaMA-3 70B with $G = 8$ and $P = 8$, each GPU owns just one KV group: trivial to shard, with no cross-GPU communication for attention. If $P > G$, KV groups have to be replicated across GPUs instead.
  4. Speculative decoding friendliness. Speculative decoders run a small draft model and verify with the large model. The attention in the large-model verification step is bound by KV-cache reads, so making the cache 8× smaller makes that part of the verifier up to ~8× cheaper, on top of the speedup speculative decoding already provides.
  5. Batched serving. Engines like vLLM keep many users' KV caches resident simultaneously. Smaller per-user caches mean more concurrent users on the same GPU, one of the biggest unit-economics levers in LLM serving today.
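Item 1's concurrency estimate, as a sketch. The inputs are illustrative assumptions, not measurements: two 80 GB GPUs, ~140 GB of FP16 weights, every remaining byte available for KV cache, and no allowance for activations, fragmentation or framework overhead. The helper names are hypothetical.

```python
def kv_bytes(n_kv_groups, seq_len=4096, d_k=128, n_layers=80, bytes_per_elem=2):
    """Per-request KV cache: 2 (K and V) * G * T * d_k * L * bytes."""
    return 2 * n_kv_groups * seq_len * d_k * n_layers * bytes_per_elem


def max_concurrent_requests(free_hbm_bytes, cache_bytes_per_request):
    """Number of full-length requests whose caches fit in the free memory."""
    return free_hbm_bytes // cache_bytes_per_request


# Illustrative assumptions: 2 x 80 GB of HBM, ~140 GB of FP16 weights,
# everything else usable for cache (activations and overheads ignored).
free = 2 * 80 * 10**9 - 140 * 10**9        # 20 GB

for name, G in (("MHA", 64), ("GQA-8", 8)):
    per_request = kv_bytes(G)
    n = max_concurrent_requests(free, per_request)
    print(f"{name}: {per_request / 1e9:.2f} GB per request -> {n} concurrent")
# MHA: 10.74 GB per request -> 1 concurrent
# GQA-8: 1.34 GB per request -> 14 concurrent
```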
Why DeepSeek went further. Even GQA-8 leaves the cache scaling linearly in sequence length, and at 128K context that linear factor still hurts. MLA, the rest of this chapter, compresses the K/V representation itself into a much smaller latent vector, shrinking the per-token cost that the linear factor multiplies.

Engineering Reality and Common Pitfalls

  • Pick $G$ so $H / G$ is an integer. Common choices are $G \in \{1, 2, 4, 8\}$. Non-divisor group counts force ragged repeat patterns that break Flash kernels.
  • Don't materialise the replicated K/V if you can avoid it. Production kernels (FlashAttention-2, vLLM's PagedAttention, FlashInfer) accept the group-shaped K/V directly and do the index mapping inside the kernel. The repeat_interleave version in the PyTorch code above is correct but pessimistic: it makes a full-sized copy. The real serving stack skips that copy.
  • Sliding-window + GQA composes. Mistral combines GQA-8 with a 4K sliding window so the cache is bounded in sequence length too. The two savings multiply: the window caps T, and GQA caps the per-token width.
  • Quantising the KV cache is the next lever. GQA shrinks the count of cached scalars; INT8 / FP8 KV cache shrinks each scalar. They multiply: for a 64-head model, GQA-8 + FP8-KV gives a 16× total cache shrink versus MHA-FP16. Modern serving stacks turn both knobs.
  • Retrofitting an MHA checkpoint. If you inherit an MHA-trained model and want GQA at serve time, the recipe from the GQA paper is: average (mean-pool) the per-head K (and V) weights inside each target group, then continue training for ~5% of the original pre-training compute. Quality largely recovers. LLaMA-2 70B, by contrast, was not produced this way: it was trained with GQA from the start.
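A minimal sketch of the weight-averaging step only, assuming per-head projections stored as an $(H, d_\text{model}, d_k)$ array and contiguous groups; the continued training that follows is not shown, and mean_pool_kv_heads is a hypothetical helper:

```python
import numpy as np


def mean_pool_kv_heads(W_heads, n_kv_groups):
    """Collapse per-head K (or V) projections into per-group ones by averaging.

    W_heads : (H, d_model, d_k), one projection per MHA head.
    Returns (G, d_model, d_k); group g averages heads g*r .. g*r + r - 1,
    where r = H // G (contiguous grouping, matching repeat_interleave).
    """
    H, d_model, d_k = W_heads.shape
    if H % n_kv_groups != 0:
        raise ValueError("n_kv_groups must divide the number of heads")
    r = H // n_kv_groups
    return W_heads.reshape(n_kv_groups, r, d_model, d_k).mean(axis=1)


rng = np.random.default_rng(0)
W_K_mha = rng.standard_normal((8, 16, 4))        # toy MHA checkpoint: H=8, d_model=16, d_k=4
W_K_gqa = mean_pool_kv_heads(W_K_mha, n_kv_groups=2)
print(W_K_gqa.shape)                             # (2, 16, 4)
print(np.allclose(W_K_gqa[1], W_K_mha[4:8].mean(axis=0)))  # True
```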

Bridge to MLA

GQA is a beautifully simple idea: same query lenses, fewer memories. It buys you a constant-factor cache shrink, free of training-time wizardry, with little quality loss at sensible group counts. That is why so many major open LLMs released in 2024–2025 ship with GQA, most often with eight KV groups.

But notice what GQA does not do. It still stores a full $d_k$-dimensional K and V per group per token. With 128-dimensional heads (as in DeepSeek V3) and contexts of up to 128K tokens (which V3 supports), the cache balloons again. GQA gives you a constant divisor; it doesn't change the geometry of the cache.

Multi-Head Latent Attention, the subject of §4.3, replaces the cached K and V with a low-rank latent vector $c_t \in \mathbb{R}^{d_c}$ with $d_c \ll H \cdot d_k$, and absorbs the up-projections back into the attention math. The cache per token becomes a single small vector regardless of head count. GQA was the warm-up; MLA is the real surgery.


Summary

  • The problem. MHA's KV cache scales linearly in heads, and inference is memory-bandwidth-bound, so the cache is the dominant serving cost.
  • The fix. Share $(K, V)$ across groups of query heads. $G$ groups instead of $H$: the cache shrinks by $H/G$, and the query side is untouched.
  • The spectrum. $G = H$ is MHA; $G = 1$ is MQA; $1 < G < H$ is GQA. Modern models commonly land at $G = 8$.
  • The code. One repeat_interleave line turns group-shaped $K, V$ into head-shaped $K, V$ so existing attention kernels run unchanged.
  • The win. An 8× cache shrink means roughly 8× as many concurrent users on the same GPU, at small quality cost.
  • What it can't do. It still caches full $d_k$-wide keys and values for every group, and its cache still grows linearly in $T$. MLA attacks the per-token cost by caching a compressed latent instead.