Section 1 left us with a sharp number. For a model shaped like LLaMA-2 70B and served with plain MHA, the KV cache for a single 4,096-token request occupies roughly $2 \cdot H \cdot T \cdot d_k \cdot L \cdot 2\,\text{B} = 2 \cdot 64 \cdot 4096 \cdot 128 \cdot 80 \cdot 2 \approx 10.7$ GB. That is per request, before any batching across users. With the model weights already taking ~140 GB across two 80 GB H100s, it is the cache that exhausts memory long before the GPU runs out of FLOPs.
The structural cause is that vanilla Multi-Head Attention stores one independent (Kh,Vh) pair per head. With H=64 heads, the cache scales linearly in H. Worse, every autoregressive decoding step has to read the entire cache: for token t, attention must gather all t−1 previous keys and values from HBM. Modern decoding is memory-bandwidth-bound, not compute-bound — the GPU spends its time moving bytes, not multiplying numbers.
The bottleneck reframed. At training time the cost of MHA is its FLOPs. At inference time the cost of MHA is its cache size and its cache read bandwidth. Cutting those two numbers is what every attention variant in this chapter is really fighting for.
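To make "bandwidth-bound" concrete, here is a rough back-of-envelope sketch of how long it takes just to stream the cache once per generated token. The bandwidth figure is an assumption (roughly the peak HBM bandwidth of an H100 SXM), real kernels achieve less than peak, and the helper name `cache_read_ms` is made up for this illustration.

```python
# Back-of-envelope: time to read the whole KV cache once per decoded token.
# ASSUMPTION: ~3.35e12 bytes/s peak HBM bandwidth (roughly an H100 SXM).
# Achieved bandwidth in practice is lower, so treat the result as a floor.
HBM_BYTES_PER_S = 3.35e12


def cache_read_ms(cache_bytes: float, bandwidth: float = HBM_BYTES_PER_S) -> float:
    """Lower bound on per-token latency from cache reads alone, in milliseconds."""
    return cache_bytes / bandwidth * 1e3


# Bytes for the MHA example from Section 1: 2 * H * T * d_k * L * 2 bytes.
mha_cache = 2 * 64 * 4096 * 128 * 80 * 2

for label, cache in [("MHA", mha_cache), ("8x smaller", mha_cache / 8)]:
    print(f"{label:>10}: {cache / 1e9:5.2f} GB -> {cache_read_ms(cache):.2f} ms per token")
```

The point is only the proportionality: under this model, per-token read time scales directly with cache bytes, so any factor taken off the cache comes straight off this floor.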
Grouped-Query Attention (GQA) and its limiting case Multi-Query Attention (MQA) attack this with a single, almost embarrassingly simple idea. The next section gives the intuition; the rest of the page makes it rigorous.
Intuition: One Reader, Many Lenses
Picture each attention head as a person reading the same library of books. In standard MHA, every reader brings their own private library — 64 readers, 64 libraries, all stored on the GPU. They all read the same text on the page (the token's residual stream), but each builds their own index of what to remember (K) and what to retrieve (V).
That feels wasteful. In a really long document, do 64 readers truly need 64 different memories? GQA's wager is no: most of the variation between heads lives on the query side — the lens through which a head looks at the past — not in what they want to remember about the past. So we keep all 64 query lenses, but share the library across small groups of readers.
Multi-Query Attention pushes this to the extreme: 64 readers, one library. Every head queries the same $(K, V)$ with its own $W_Q^h$. The cache shrinks by 64×. The question, which we'll return to further down, is how much quality you pay for that compression.
Keep two pictures in your head. Queries are what you ask this token to look for; queries are recomputed from scratch every step, so they cost FLOPs but no cache. Keys and values are what you wrote down about past tokens; they live in the cache for the rest of the sequence, so every byte you spend on K/V is paid back at every future decoding step.
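The asymmetry between queries and K/V can be sketched as a toy decode loop. This is a minimal illustration with stand-in numbers instead of real projections; the class name `ToyKVCache` and its `floats_read` counter are hypothetical, introduced only to count traffic.

```python
# Toy decode loop: K/V rows are appended once and re-read at every later step;
# the query is used once and thrown away. Values are stand-ins, not projections.
class ToyKVCache:
    def __init__(self):
        self.keys, self.values = [], []
        self.floats_read = 0  # running total of cached floats fetched so far

    def step(self, k_row, v_row):
        """Append this token's K/V, then 'read' the whole cache as attention would."""
        self.keys.append(k_row)
        self.values.append(v_row)
        self.floats_read += sum(len(r) for r in self.keys)
        self.floats_read += sum(len(r) for r in self.values)
        return len(self.keys)


d_k = 2
cache = ToyKVCache()
for t in range(1, 5):
    q = [0.0] * d_k  # recomputed each step, never stored
    n = cache.step([float(t)] * d_k, [float(-t)] * d_k)
    print(f"step {t}: cached tokens = {n}, total floats read so far = {cache.floats_read}")
```

Each step reads `2 * t * d_k` floats, so the cumulative traffic grows quadratically with the number of generated tokens even though the cache itself grows linearly.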
The Mathematical Idea
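Let $x_t \in \mathbb{R}^{d_{\text{model}}}$ be the residual-stream vector for token $t$. Standard MHA defines, for each head $h = 1, \dots, H$:
$$Q_t^h = W_Q^h x_t, \qquad K_t^h = W_K^h x_t, \qquad V_t^h = W_V^h x_t$$
with $W_Q^h, W_K^h, W_V^h \in \mathbb{R}^{d_k \times d_{\text{model}}}$ and the attention output of head $h$ at position $t$ being
$$O_t^h = \operatorname{softmax}\!\left(\frac{Q_t^h \,(K_{\le t}^h)^\top}{\sqrt{d_k}}\right) V_{\le t}^h$$
The cache size, per layer, per token, is $2 \cdot H \cdot d_k$: two slabs (K and V), $H$ heads each of width $d_k$.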
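Now partition the $H$ query heads into $G$ non-overlapping groups of equal size $H/G$. Write $g(h) \in \{1, \dots, G\}$ for the group of head $h$. GQA replaces the per-head K and V projections with per-group ones:
$$K_t^g = W_K^g x_t, \qquad V_t^g = W_V^g x_t \qquad \text{for } g = 1, \dots, G$$
Head $h$ then attends to its group's keys and values:
$$O_t^h = \operatorname{softmax}\!\left(\frac{Q_t^h \,(K_{\le t}^{g(h)})^\top}{\sqrt{d_k}}\right) V_{\le t}^{g(h)}$$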
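The per-head output formula can be sketched in plain Python for a single query position. This is a minimal illustration, not an efficient implementation; the function name `gqa_head_output` and the list-of-lists cache layout are assumptions made for readability.

```python
import math


def gqa_head_output(q, K_cache, V_cache, head, head_to_group):
    """Output of one query head at the current position.

    q             : list[float], this head's query (length d_k)
    K_cache[g]    : list of cached key rows for group g (positions <= t)
    V_cache[g]    : same layout for values
    head_to_group : list mapping head index -> group index, i.e. g(h)
    """
    g = head_to_group[head]
    keys, values = K_cache[g], V_cache[g]
    d_k = len(q)
    # Scaled dot-product scores against every cached key of the group.
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d_k) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    width = len(values[0])
    return [sum(w / z * v[j] for w, v in zip(weights, values)) for j in range(width)]


# Two groups with two cached tokens each; heads 0,1 -> group 0, heads 2,3 -> group 1.
K_cache = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]
V_cache = [[[1.0, 0.0], [0.0, 1.0]], [[5.0, 5.0], [7.0, 7.0]]]
g_of_h = [0, 0, 1, 1]

out_h0 = gqa_head_output([1.0, 1.0], K_cache, V_cache, head=0, head_to_group=g_of_h)
out_h3 = gqa_head_output([1.0, 1.0], K_cache, V_cache, head=3, head_to_group=g_of_h)
print(out_h0, out_h3)
```

The same query vector produces different outputs for heads 0 and 3 purely because `head_to_group` routes them to different cached slabs; nothing else in the computation changes.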
Varying G sweeps out the whole spectrum, with two specialisations recovering its endpoints:
G=H — every group has exactly one head — recovers vanilla MHA.
G=1 — one big group containing every head — is MQA.
1<G<H with G∣H — the intermediate regime — is GQA.
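The routing $g(h)$ for all three regimes can be sketched with one small function. The text only requires equal-sized, non-overlapping groups; assigning consecutive heads to the same group (and indexing from 0) is an assumption here, chosen to match the usual implementation convention.

```python
def head_to_group(H: int, G: int) -> list:
    """Fixed routing g(h), 0-indexed: consecutive blocks of H // G heads share a group."""
    if H % G != 0:
        raise ValueError("G must divide H")
    heads_per_group = H // G
    return [h // heads_per_group for h in range(H)]


print("MHA  (G=8):", head_to_group(8, 8))
print("GQA  (G=2):", head_to_group(8, 2))
print("MQA  (G=1):", head_to_group(8, 1))
```

With G = H the mapping is the identity, with G = 1 it is all zeros, and anything in between produces runs of equal group indices.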
The cache size becomes 2⋅G⋅dk per token per layer, i.e. a factor of H/G smaller than MHA. The query projection is untouched, so the total parameter count drops only slightly, but the runtime cache — the thing that actually pins your serving cost — drops by exactly that factor.
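The two effects in this paragraph, a large cache shrink and a smaller parameter shrink, can be checked with a few lines of arithmetic. The shapes below are illustrative 70B-like values and are an assumption, not numbers read from any checkpoint; both helper names are hypothetical.

```python
def kv_floats_per_token_per_layer(G: int, d_k: int) -> int:
    """Cached scalars per token per layer: one K slab and one V slab per group."""
    return 2 * G * d_k


def attn_params(d_model: int, H: int, G: int, d_k: int) -> int:
    """Weights in W_Q, W_K, W_V and W_O, ignoring biases."""
    w_q = d_model * H * d_k
    w_kv = 2 * d_model * G * d_k
    w_o = H * d_k * d_model
    return w_q + w_kv + w_o


# ASSUMPTION: illustrative 70B-like shapes, not taken from a real config file.
d_model, H, d_k = 8192, 64, 128
mha_cache = kv_floats_per_token_per_layer(H, d_k)
mha_params = attn_params(d_model, H, H, d_k)

for G in (64, 8, 1):
    cache = kv_floats_per_token_per_layer(G, d_k)
    params = attn_params(d_model, H, G, d_k)
    print(f"G={G:>2}: cache/token/layer = {cache:>6} floats ({mha_cache // cache}x smaller), "
          f"attention params = {params / 1e6:.1f}M ({mha_params / params:.2f}x smaller)")
```

Under these shapes the cache shrinks by exactly H/G, while the attention weights shrink by less than 2× because $W_Q$ and $W_O$ are untouched. Measured against a whole model, which also carries feed-forward weights not counted here, the parameter saving is smaller still.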
Why each symbol matters. $x_t$ is what came out of the previous block at position $t$; $W_Q^h$ is one of $H$ learned lenses for asking questions; $W_K^g, W_V^g$ are $G \le H$ learned indexers for building shared memory; $g(h)$ is the fixed (not learned) routing of head $h$ to its group; and the cache size formula $2 \cdot G \cdot d_k$ is the bottom line your serving infrastructure pays.
MHA, MQA, and the Spectrum Between
The cleanest way to internalise GQA is to see how it slides between the two endpoints. The visualisation below maps each of the H query heads (top row) to a coloured KV group (bottom row). Drag G down to 1 to watch the whole thing collapse into MQA; drag it up to H to recover MHA.
Interactive: Head-to-Group Mapping
Production models live somewhere in the middle of this slider. LLaMA-2 7B uses H=32, G=32 (still pure MHA). LLaMA-2 70B (H=64) and LLaMA-3 8B (H=32) both use G=8, giving an 8× and a 4× cache shrink respectively while preserving most of MHA's expressive power, as the GQA paper showed empirically. Mistral 7B (H=32, G=8) follows the same recipe.
Interactive: MHA vs MQA Architecture
Slide the head count to compare the two architectures side-by-side. The left panel keeps per-head K,V boxes; the right collapses them into a single shared pair that every query head reads from.
Manual Numerical Walkthrough
We'll work the smallest example that still distinguishes the three mechanisms: H=4 query heads, dk=2, T=3 tokens. We'll compute the keys for one token under each scheme and watch the cache size change.
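Setup. Pick a single token vector $x_1 = (1, 0, -1, 2)$ with $d_{\text{model}} = 4$.
Step 1: MHA keys (4 heads, 4 weight matrices). Each $W_K^h \in \mathbb{R}^{2 \times 4}$ is a different matrix. Take, for instance:
$$W_K^1 = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \end{bmatrix}, \qquad W_K^2 = \begin{bmatrix} 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix}$$
$$W_K^3 = \begin{bmatrix} 1 & 1 & 0 & 0 \\ 0 & 0 & 1 & 1 \end{bmatrix}, \qquad W_K^4 = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \end{bmatrix}$$
Then $K_1^1 = (1, 0)$, $K_1^2 = (-1, 2)$, $K_1^3 = (1, 1)$, $K_1^4 = (0, 2)$. Cache cost for this one token: $2 \cdot H \cdot d_k = 2 \cdot 4 \cdot 2 = 16$ floats (K and V together).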
Step 2: GQA with G = 2. We now keep only two K weight matrices. Reuse $W_K^1$ and $W_K^2$ from above, and bind heads {1, 2} to group 1 and heads {3, 4} to group 2. Then $K_1^{g=1} = W_K^1 x_1 = (1, 0)$ and $K_1^{g=2} = W_K^2 x_1 = (-1, 2)$. Heads 1 and 2 both query against $(1, 0)$; heads 3 and 4 both query against $(-1, 2)$. The query lenses are still distinct, but the "memory" is shared. Cache cost: $2 \cdot G \cdot d_k = 2 \cdot 2 \cdot 2 = 8$ floats, exactly half.
Step 3: MQA (G = 1). Keep only $W_K^1$. Every head reads $K_1^{g=1} = (1, 0)$. Cache cost: $2 \cdot 1 \cdot 2 = 4$ floats, a quarter of MHA, an H = 4× compression.
Step 4: What the queries see. Pick query head $h = 1$ with $Q_1^1 = (1, 1)$. Under MHA it dot-products against $K_{\le 1}^1$; under GQA-2 against $K_{\le 1}^{g=1}$; under MQA against the single shared $K_{\le 1}$. Notice that the computation at the attention head is identical in shape (a Q × K dot-product divided by $\sqrt{d_k}$); what differs is which K matrix gets fetched from HBM.
Sanity check. Plug H=4, G=2, d_k=2, T=3 into the Python script below. It should report 48, 24 and 12 cache floats for MHA, GQA-2 and MQA respectively. The ratio always equals H : G : 1.
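The hand computation above can be replayed in a few lines of plain Python. The matrices are the walkthrough's four $W_K^h$ written out as 2×4 arrays; the helper `matvec` is introduced here only for the check.

```python
import math


def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return tuple(sum(w * xi for w, xi in zip(row, x)) for row in W)


x1 = (1, 0, -1, 2)
W_K = {
    1: [[1, 0, 0, 0], [0, 1, 0, 0]],
    2: [[0, 0, 1, 0], [0, 0, 0, 1]],
    3: [[1, 1, 0, 0], [0, 0, 1, 1]],
    4: [[1, 0, 1, 0], [0, 1, 0, 1]],
}
keys = {h: matvec(W, x1) for h, W in W_K.items()}
print("MHA keys:", keys)

# Per-token cache cost (K and V together) for H = 4, d_k = 2.
H, d_k = 4, 2
cost = {G: 2 * G * d_k for G in (4, 2, 1)}
print("floats per token for G = 4, 2, 1:", cost)

# Step 4: head 1's score against the group-1 key, shared by MHA head 1, GQA-2 and MQA.
q = (1, 1)
score = sum(a * b for a, b in zip(q, keys[1])) / math.sqrt(d_k)
print("scaled score:", score)
```

Only the dictionary lookup would change between the three schemes: MHA reads `keys[h]`, GQA-2 reads `keys[1]` or `keys[2]` depending on the group, and MQA always reads `keys[1]`.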
The KV-Cache Budget, Made Concrete
Toy numbers are convincing only up to a point. Let's plug the formula cache=2⋅G⋅T⋅dk⋅L⋅bytes into real configurations and see what serving infrastructure actually buys with each G.
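As a static stand-in for the formula, the sketch below evaluates it for a few group counts. The layer counts, head counts and head dims are the commonly quoted shapes and are treated as assumptions here; FP16 (2 bytes per scalar) and a batch of one are also assumed.

```python
def kv_cache_bytes(G: int, T: int, d_k: int, L: int, bytes_per_scalar: int = 2) -> int:
    """cache = 2 * G * T * d_k * L * bytes, for a single request."""
    return 2 * G * T * d_k * L * bytes_per_scalar


# name -> (layers L, query heads H, head dim d_k); ASSUMED shapes for illustration.
configs = {
    "70B-class (L=80, H=64)": (80, 64, 128),
    "7B-class  (L=32, H=32)": (32, 32, 128),
}

T = 4096
for name, (L, H, d_k) in configs.items():
    for G in (H, 8, 1):
        gb = kv_cache_bytes(G, T, d_k, L) / 1e9
        print(f"{name}  G={G:>2}: {gb:6.3f} GB  (shrink {H // G}x)")
```

Changing `T` or `bytes_per_scalar` rescales every row by the same factor, which is why the relative saving between group counts does not depend on context length or dtype.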
Interactive: KV-Cache Memory by G
Pick a model, pick a sequence length, and watch the bars for different group counts. The cache at G groups is exactly G/H the size of the MHA cache, i.e. a shrink factor of H/G.
Interactive: MHA vs MQA Cache Explorer
For the same scenarios, this view focuses on the two endpoints. Notice how the gap widens with the head count: a 12-head GPT-2 saves 11/12; a 64-head LLaMA-2 70B saves 63/64. MQA's win is greatest precisely where MHA's pain is worst.
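The two fractions quoted above follow from one line of arithmetic, sketched here with exact rationals. The function name `mqa_saving` is made up for the illustration.

```python
from fractions import Fraction


def mqa_saving(H: int) -> Fraction:
    """Fraction of the MHA cache removed by collapsing H KV heads into one."""
    return 1 - Fraction(1, H)


for H in (12, 64):
    print(f"H={H}: MQA removes {mqa_saving(H)} of the MHA cache")
```

The saving approaches 1 as H grows, which is the sense in which MQA helps most where the MHA cache is largest.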
The KV cache is a tax on context. Doubling context length doubles cache memory and doubles the bytes fetched per decoding step. Compression schemes like GQA give you a fixed multiplicative discount on that tax, which is useful, but they don't change the linear-in-T growth. The next section's MLA goes after the per-token width itself, shrinking what each cached token costs well below what GQA stores.
What Do We Lose? The Quality Trade-off
Compressing K/V cannot be free. Three things actually happen when G drops:
Less head diversity in memory. Two heads in the same group see the same (K,V) and can only differentiate their behaviour through their Q. The model has fewer degrees of freedom for specialising what to remember.
Sharper trade-off at MQA. The Shazeer 2019 MQA paper and the 2023 GQA paper both report a measurable quality dip at G=1, though one that is largely recoverable with light extra training. With G=8 the dip is essentially in the noise, which is why most modern open LLMs picked GQA over MQA.
Training-time effect. Models trained from scratch with GQA from token 0 lose almost nothing. Models retrofitted by averaging per-head K/V weights into per-group weights and fine-tuning lose more — though recoverable cheaply. DeepSeek V3 trains its attention scheme (MLA) end-to-end from scratch.
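The averaging step used when retrofitting can be sketched on nested lists. This shows only the weight pooling, not the fine-tuning that follows; grouping consecutive heads together and the function name `mean_pool_heads` are assumptions of this sketch.

```python
def mean_pool_heads(W_heads, G):
    """Average H per-head weight matrices into G per-group matrices.

    W_heads: list of H matrices, each a list of rows of floats.
    Heads are grouped in consecutive blocks of H // G (an assumed convention).
    """
    H = len(W_heads)
    if H % G != 0:
        raise ValueError("G must divide H")
    size = H // G
    pooled = []
    for g in range(G):
        block = W_heads[g * size:(g + 1) * size]
        rows, cols = len(block[0]), len(block[0][0])
        pooled.append([[sum(W[r][c] for W in block) / size for c in range(cols)]
                       for r in range(rows)])
    return pooled


# Four toy "heads", each a 1 x 2 matrix.
W_K_heads = [[[1.0, 2.0]], [[3.0, 4.0]], [[10.0, 0.0]], [[0.0, 10.0]]]
W_K_gqa2 = mean_pool_heads(W_K_heads, G=2)
W_K_mqa = mean_pool_heads(W_K_heads, G=1)
print(W_K_gqa2)
print(W_K_mqa)
```

The same function would be applied to the V weights. The pooled matrices are only a starting point: heads that used to have their own memory now share an average, which is why a fine-tuning pass is needed afterwards.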
| Model | H (query heads) | G (KV groups) | Cache shrink | Notes |
|---|---|---|---|---|
| GPT-3 175B (2020) | 96 | 96 | 1× (MHA) | MHA, pre-GQA era |
| LLaMA-2 7B | 32 | 32 | 1× (MHA) | Small model, cache not yet painful |
| LLaMA-2 70B | 64 | 8 | 8× | First widely-deployed GQA |
| LLaMA-3 8B / 70B | 32 / 64 | 8 / 8 | 4× / 8× | GQA-8 became the de facto default |
| Mistral 7B | 32 | 8 | 4× | GQA + sliding window attention |
| PaLM (Google, 2022) | — | 1 | H× (MQA) | Pure MQA at 540B scale |
| DeepSeek V3 | 128 | — | see §4.3 | Switches to MLA, the main act of Chapter 4 |
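The "cache shrink" column is just H divided by G, which the following sketch recomputes for the rows where both numbers are given. The rows marked with a dash in the table are left out rather than guessed.

```python
# (H query heads, G KV groups) as listed in the table above.
models = {
    "GPT-3 175B": (96, 96),
    "LLaMA-2 7B": (32, 32),
    "LLaMA-2 70B": (64, 8),
    "LLaMA-3 8B": (32, 8),
    "LLaMA-3 70B": (64, 8),
    "Mistral 7B": (32, 8),
}

shrink = {name: H // G for name, (H, G) in models.items()}
for name, factor in shrink.items():
    kind = "MHA" if factor == 1 else "GQA"
    print(f"{name:<12} {factor}x  ({kind})")
```

Note that the same G = 8 gives different shrink factors depending on H, which is why the 8B and 70B LLaMA-3 rows differ.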
Plain Python: Three Mechanisms, Side by Side
Before we touch PyTorch, here is the bare math in NumPy. The point is to see exactly which tensor shrinks when we move from MHA → GQA → MQA, and where the routing table g(h) is encoded.
gqa_mqa_demo.py — MHA, GQA-2 and MQA in 70 lines of NumPy
Notes on the listing:
Toy shapes, chosen so we can read them out loud. We use T=3 tokens, d_model=4 features per token, d_k=2 features per head, and H=4 query heads. Every following shape is built from these four numbers, so when you see (4, 3, 2) you can immediately read it as (head, token, head-dim). Example: x.shape = (3, 4), W_Q.shape = (4, 4, 2).
Input activations x. One row per token. In a real transformer this is the output of the previous block (or the embedding layer for layer 0). Three tokens × four features.
Query weights W_Q. The number of query weight matrices never changes when we move from MHA → GQA → MQA. Each query head keeps its own d_model × d_k projection. Only K and V shrink. Example: W_Q[h] : d_model × d_k = 4 × 2.
project(): the K/V factory. Given a single weight tensor of shape (G, d_model, d_k), we contract over d_model with the input x to produce one (T, d_k) key (or value) slab per group. Setting G = H gives MHA; G = 1 gives MQA; anything in between is GQA. Example: einsum('td,gdk->gtk', x, W) → shape (G, T, d_k).
attention(): group-aware lookup. Heart of GQA/MQA: each query head h does its own scaled dot-product softmax, but it reads from K[g] and V[g] where g = head_to_group[h]. Many heads can share the same (K, V) slab, and that is the entire memory saving. Example: head_to_group for GQA-2 with 4 heads: [0, 0, 1, 1].
Scaled dot-product per head. Q[h] @ K[g].T gives a (T, T) score matrix. Dividing by √d_k controls the variance so softmax does not saturate. This step is identical to vanilla attention; GQA changes only WHICH K and V we plug in. Example: Q[h].shape (3, 2) @ K[g].T (2, 3) → (3, 3).
Softmax, written by hand. Numerically stable softmax: subtract the row max, exponentiate, then divide by the row sum. Exactly the same routine you would use in MHA.
MHA: one K and one V per head. Here W_K_mha has shape (H, d_model, d_k) = (4, 4, 2). After project() we get K_m of shape (4, 3, 2): four independent key slabs, one per head. head_to_group = arange(H) means head h reads from group h. Example: K_m.shape = (4, 3, 2) → 24 floats of cache for K alone (48 with V).
GQA-2: half the K, V weight matrices. W_K_gqa has only G=2 slices. Heads [0, 1] share group 0; heads [2, 3] share group 1. The query side is unchanged, so the model still has 4 attention heads; they just SEE through a coarser K/V lens.
MQA: a single shared K, V. G = 1. Now head_to_group is all zeros: every query head reads from the same (K, V). This is the most aggressive compression in this family; the KV cache shrinks by a factor of H.
cache_floats(): counting the cache. The factor 2 covers both K and V, and the function already multiplies in the group count, sequence length and head dim. Multiply further by layer count, batch size, and bytes-per-float to get total bytes. In a real 70B model these numbers are many orders of magnitude larger. Example: MHA: 2 · 4 · 3 · 2 = 48 | MQA: 2 · 1 · 3 · 2 = 12.
```python
"""
Three attention KV-projection schemes, side by side, in pure NumPy.

We deliberately keep every tensor tiny so you can read shapes out loud:

    sequence length  T       = 3   (three tokens)
    model dim        d_model = 4
    head dim         d_k     = 2   (so 4 / 2 = 2 head-slots in d_model)
    query heads      H       = 4
    KV groups        G in {4, 2, 1}   (MHA, GQA-2, MQA)
"""

import numpy as np

T, d_model, d_k, H = 3, 4, 2, 4
rng = np.random.default_rng(0)

# Input activations: one row per token.
x = rng.standard_normal((T, d_model))          # shape (T, d_model) = (3, 4)

# Query weights: always H of them, one per query head.
W_Q = rng.standard_normal((H, d_model, d_k))   # shape (H, d_model, d_k)


def project(x, W_KV_G):
    """
    Project x into G group-level K (or V) matrices.

    W_KV_G has shape (G, d_model, d_k).
    Returns a tensor of shape (G, T, d_k): one (T, d_k) slab per group.
    """
    return np.einsum("td,gdk->gtk", x, W_KV_G)


def attention(Q, K, V, head_to_group):
    """
    Q : (H, T, d_k)   per-head queries
    K : (G, T, d_k)   per-GROUP keys
    V : (G, T, d_k)   per-GROUP values
    head_to_group : (H,) which group each query head reads from

    Each query head h looks up K[head_to_group[h]] and V[head_to_group[h]].
    Note: no causal mask is applied; the demo keeps attention dense for brevity.
    """
    out = np.zeros_like(Q)                                       # (H, T, d_k)
    for h in range(Q.shape[0]):
        g = head_to_group[h]
        scores = Q[h] @ K[g].T / np.sqrt(Q.shape[-1])            # (T, T)
        probs = np.exp(scores - scores.max(-1, keepdims=True))   # stable softmax
        probs /= probs.sum(-1, keepdims=True)
        out[h] = probs @ V[g]                                    # (T, d_k)
    return out


# ---------------- MHA: G = H = 4 ----------------
W_K_mha = rng.standard_normal((H, d_model, d_k))
W_V_mha = rng.standard_normal((H, d_model, d_k))
Q = np.einsum("td,hdk->htk", x, W_Q)
K_m = project(x, W_K_mha)                      # (4, 3, 2)
V_m = project(x, W_V_mha)                      # (4, 3, 2)
mha_out = attention(Q, K_m, V_m, head_to_group=np.arange(H))

# ---------------- GQA: G = 2 (groups of 2 heads share K, V) ----------------
G = 2
W_K_gqa = rng.standard_normal((G, d_model, d_k))
W_V_gqa = rng.standard_normal((G, d_model, d_k))
K_g = project(x, W_K_gqa)                      # (2, 3, 2)  <-- HALF the cache
V_g = project(x, W_V_gqa)                      # (2, 3, 2)
gqa_out = attention(Q, K_g, V_g, head_to_group=np.array([0, 0, 1, 1]))

# ---------------- MQA: G = 1 (every head shares ONE K, V) ----------------
W_K_mqa = rng.standard_normal((1, d_model, d_k))
W_V_mqa = rng.standard_normal((1, d_model, d_k))
K_q = project(x, W_K_mqa)                      # (1, 3, 2)  <-- 1/H the cache
V_q = project(x, W_V_mqa)                      # (1, 3, 2)
mqa_out = attention(Q, K_q, V_q, head_to_group=np.zeros(H, dtype=int))


# Cache cost in floats (no batch; FP16 would be 2 bytes per float).
def cache_floats(G_, T_, d_k_):
    return 2 * G_ * T_ * d_k_                  # factor 2 for K and V


print("MHA  cache floats:", cache_floats(4, T, d_k))   # 48
print("GQA2 cache floats:", cache_floats(2, T, d_k))   # 24  (50% saving)
print("MQA  cache floats:", cache_floats(1, T, d_k))   # 12  (75% saving)
```
Run the file and you will see exactly 48 / 24 / 12 floats of cache for MHA / GQA-2 / MQA. The compression ratio is the only difference; the attention math is bit-for-bit identical once the right K and V are looked up.
PyTorch: GQA via repeat_interleave
Production GQA layers don't loop over heads; they let Flash-Attention or scaled_dot_product_attention do the work. The standard trick is to store K, V at group granularity but expand them to head granularity at the last possible moment with repeat_interleave. That keeps the cache small while letting the attention kernel keep its (B, H, T, d_k) contract.
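Why the expansion has to be an interleaved repeat rather than a tiled one can be seen with plain lists. The two helpers below are list analogues written for this illustration; they mimic the ordering of the tensor operations along the head axis but are not the PyTorch functions themselves.

```python
def interleave(groups, reps):
    """List analogue of repeat_interleave: each element repeated `reps` times in place."""
    return [g for g in groups for _ in range(reps)]


def tile(groups, reps):
    """List analogue of a tiled repeat: the whole sequence repeated `reps` times."""
    return list(groups) * reps


groups = ["g0", "g1"]
reps = 4                                   # heads per group for 8 heads, 2 groups
wanted = [f"g{h // reps}" for h in range(8)]  # head h must read group h // reps

print("interleave:", interleave(groups, reps))
print("tile      :", tile(groups, reps))
print("interleave matches routing:", interleave(groups, reps) == wanted)
print("tile matches routing      :", tile(groups, reps) == wanted)
```

Only the interleaved order lines up with the contiguous head-to-group routing, so a tiled repeat would silently hand most heads the wrong group's keys and values.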
Notes on the listing:
Constructor arguments. Three knobs that uniquely identify a GQA layer: model width, number of query heads, and number of KV groups. Setting n_kv_groups = n_heads degenerates to MHA; setting n_kv_groups = 1 degenerates to MQA.
Divisibility assert. Heads-per-group must be an integer. If it isn't, repeat_interleave won't line up cleanly and some heads would end up bound to a partial group. Crash early at construction time, not deep inside the forward pass.
reps = heads per group. This is the replication factor. With n_heads=8 and n_kv_groups=2, each cached K/V slab gets duplicated 4 times so the kernel sees 8 head-aligned slabs. Example: reps = 4 means each K_g feeds 4 query heads.
W_q stays full width. The query projection still produces n_heads * d_k features per token, exactly like MHA. GQA does not save anything on the query side, only on the K/V side. This is by design: queries are recomputed every step, KV is *cached*. Example: Linear(4096, 32*128 = 4096).
W_k is narrower than W_q. Here is the entire memory win. W_k projects into n_kv_groups * d_k instead of n_heads * d_k. With 32 query heads and 8 groups, that's 4× fewer key features per token, and 4× less cache memory at inference. Example: Linear(4096, 8*128 = 1024) vs MHA's Linear(4096, 4096).
Project, reshape, then transpose. After Linear, .view splits the last axis into (heads-or-groups, d_k), and .transpose(1, 2) puts heads in the second axis. After this line Q lives at (B, H, T, d_k) and K, V live at (B, G, T, d_k). The asymmetry is fully visible in the shapes. Example: Q.shape = (B, 8, T, 128), K.shape = (B, 2, T, 128).
repeat_interleave: the GQA trick in one line. We duplicate each group's K (and V) along the head axis `reps` times. After this line K has shape (B, n_heads, T, d_k), so the kernel can pretend it is just doing plain multi-head attention. The cache stays group-sized; the compute stays head-sized. Example: K_g of shape (B, 2, T, 128) → K of shape (B, 8, T, 128).
Why repeat_interleave and not .repeat(). repeat_interleave groups copies together: [g0, g0, g0, g0, g1, g1, g1, g1]. That matches head_to_group = [0,0,0,0,1,1,1,1]. .repeat() would have given [g0, g1, g0, g1, …], the wrong mapping.
Plain SDPA, no GQA awareness needed. Once we've expanded K and V, scaled_dot_product_attention is a black box. It can be Flash-Attention, the math-only fallback, or any future kernel. This expand-then-attend pattern is common in reference implementations of LLaMA-style stacks, for example the repeat_kv helper in Hugging Face's LLaMA code.
Merge heads and project out. Transpose heads and tokens back so the last two axes are (heads, d_k); .contiguous() so .view can flatten them into one axis of width d_model. Then W_o projects back to d_model. Identical to plain MHA.
Parameter count check. W_q has 64 × 64 = 4096 parameters; W_k and W_v each have only 64 × 16 = 1024. That 4× ratio is exactly n_heads / n_kv_groups, the same factor by which the KV cache shrinks at inference.
```python
"""
GQA in PyTorch: production-flavoured implementation.

Key idea: store K and V at GROUP granularity in the cache, then
*replicate* them up to head granularity right before the matmul. This
keeps the cache small but lets the SDPA kernel keep its (B, H, T, d_k)
shape contract.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, n_kv_groups: int):
        super().__init__()
        assert n_heads % n_kv_groups == 0, \
            "n_heads must be divisible by n_kv_groups"

        self.n_heads = n_heads
        self.n_kv_groups = n_kv_groups
        self.d_k = d_model // n_heads
        self.reps = n_heads // n_kv_groups      # heads per group

        # Notice the asymmetry: Q is n_heads * d_k wide,
        # but K and V are only n_kv_groups * d_k wide.
        self.W_q = nn.Linear(d_model, n_heads * self.d_k, bias=False)
        self.W_k = nn.Linear(d_model, n_kv_groups * self.d_k, bias=False)
        self.W_v = nn.Linear(d_model, n_kv_groups * self.d_k, bias=False)
        self.W_o = nn.Linear(n_heads * self.d_k, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape

        # 1. Project. Note the different fan-outs.
        Q = self.W_q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(x).view(B, T, self.n_kv_groups, self.d_k).transpose(1, 2)
        V = self.W_v(x).view(B, T, self.n_kv_groups, self.d_k).transpose(1, 2)
        # Q : (B, n_heads,     T, d_k)
        # K : (B, n_kv_groups, T, d_k)   <-- this is what we cache!
        # V : (B, n_kv_groups, T, d_k)

        # 2. Replicate K, V along the head axis so the kernel
        #    sees a clean (B, n_heads, T, d_k) on both sides.
        K = K.repeat_interleave(self.reps, dim=1)
        V = V.repeat_interleave(self.reps, dim=1)

        # 3. Standard scaled dot-product attention: no GQA awareness needed
        #    once K and V have been expanded.
        out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

        # 4. Merge heads and project back to d_model.
        out = out.transpose(1, 2).contiguous().view(B, T, -1)
        return self.W_o(out)


if __name__ == "__main__":
    B, T, d_model = 2, 5, 64
    layer = GroupedQueryAttention(d_model=d_model, n_heads=8, n_kv_groups=2)
    x = torch.randn(B, T, d_model)
    y = layer(x)
    print("output:", y.shape)                   # (2, 5, 64)

    # Confirm the cache asymmetry by counting parameters.
    p_q = layer.W_q.weight.numel()
    p_k = layer.W_k.weight.numel()
    p_v = layer.W_v.weight.numel()
    print(f"W_q params: {p_q}, W_k params: {p_k}, W_v params: {p_v}")
    # W_q : 64 * 64 = 4096
    # W_k : 64 * 16 = 1024   <-- 4x smaller than W_q
    # W_v : 64 * 16 = 1024
```
What autograd sees. repeat_interleave materialises a copy rather than a view, and its backward pass sums the gradients of all copies back into the group's (K, V). So every group's K/V weights are trained by the accumulated gradient from its H/G query heads; they really do learn a consensus memory for their group.
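The gradient bookkeeping can be mimicked without any autograd machinery. This is a hand-written sketch of the backward rule for the expansion step, using one scalar gradient per head for readability; the function name and the numbers are made up for the illustration.

```python
def grad_wrt_groups(grad_heads, reps):
    """Backward of the expand step: sum each run of `reps` per-head gradients.

    If K_expanded[h] = K_group[h // reps], then
    dLoss/dK_group[g] = sum of dLoss/dK_expanded[h] over the heads h in group g.
    """
    if len(grad_heads) % reps != 0:
        raise ValueError("number of heads must be a multiple of reps")
    G = len(grad_heads) // reps
    return [sum(grad_heads[g * reps:(g + 1) * reps]) for g in range(G)]


# Pretend upstream gradients for 8 heads (integers keep the arithmetic exact).
grad_heads = [1, 2, 3, 4, 10, 10, 10, 10]
grad_groups = grad_wrt_groups(grad_heads, reps=4)
print(grad_groups)
```

Each group receives the sum over its own heads only, so the two groups are trained by disjoint sets of query heads.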
What Changes at 70B Scale
At toy size the cache numbers above are amusing curiosities. At 70B parameters they decide whether your serving stack is even runnable.
HBM budget per request. LLaMA-2 70B with MHA would burn ~10.7 GB of KV cache per 4 096-token request; with GQA-8 that drops to ~1.3 GB. The same H100 that could serve 2 concurrent users in MHA can serve 16 in GQA-8 — an 8× throughput win on identical hardware.
Bytes fetched per decoding step. For each generated token, attention must read the entire cache from HBM. Cutting cache by 8× cuts those reads by 8×. Since decoding is memory-bandwidth-bound, latency drops nearly linearly with cache size.
Tensor-parallel sharding. When the model is sharded across P GPUs along the head axis, each GPU stores G/P groups of K, V (instead of H/P). For LLaMA-3 70B with G=8 and P=8, each GPU owns just one KV group — trivial to shard, no cross-GPU communication for attention.
Speculative decoding friendliness. Speculative decoders run a small draft model and verify with the large model. The large model's verification pass still has to read its KV cache, so an 8× smaller cache cuts that read traffic by about 8×, a saving that stacks on top of the speculative-decoding speedup.
Batched serving. Engines like vLLM keep many users' KV caches resident simultaneously. Smaller per-user caches means more concurrent users on the same GPU — this is the single biggest unit-economics lever in LLM serving today.
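Two of the points above reduce to integer arithmetic, sketched below. The free-HBM budget is hypothetical (set to exactly two MHA-sized caches so the numbers line up with the 2-versus-16 example), and both helper names are made up for this illustration.

```python
def max_concurrent(free_bytes: int, per_request_bytes: int) -> int:
    """How many per-request KV caches fit in the free HBM budget."""
    return free_bytes // per_request_bytes


def kv_groups_per_gpu(G: int, P: int) -> int:
    """KV groups owned by each GPU when G groups are split evenly over P GPUs."""
    if G % P != 0:
        raise ValueError("this sketch only handles P dividing G")
    return G // P


mha_bytes = 2 * 64 * 4096 * 128 * 80 * 2   # ~10.7 GB per 4,096-token request
gqa8_bytes = mha_bytes // 8                # ~1.3 GB
free_hbm = 2 * mha_bytes                   # HYPOTHETICAL budget, ~21.5 GB

print("concurrent requests, MHA  :", max_concurrent(free_hbm, mha_bytes))
print("concurrent requests, GQA-8:", max_concurrent(free_hbm, gqa8_bytes))
print("KV groups per GPU (G=8, P=8):", kv_groups_per_gpu(8, 8))
```

With a different budget the two counts will not always differ by exactly 8× because of the integer floor, but the ratio stays close to H/G once the budget holds more than a handful of requests.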
Why DeepSeek went further. Even GQA-8 leaves the cache scaling linearly in sequence length. At 128K context, that linear factor still hurts. MLA — the rest of this chapter — compresses the K/V representation itself into a much smaller latent vector, attacking the constant and the per-token width simultaneously.
Engineering Reality and Common Pitfalls
Pick G so H / G is an integer. Almost every implementation in the wild uses G∈{1,2,4,8}. Non-divisor group counts force ragged repeat patterns that break Flash kernels.
Don't materialise the replicated K/V if you can avoid it. Production kernels (FlashAttention-2, vLLM's PagedAttention, FlashInfer) accept the group-shaped K/V directly and do the index mapping inside the kernel. The repeat_interleave version in the PyTorch code above is correct but pessimistic — it makes a full-sized copy. The real serving stack skips that copy.
Sliding-window + GQA composes. Mistral combines GQA-8 with a 4K sliding window so the cache is bounded in sequence length too. Order of operations: window first (caps T), then GQA (caps per-token width).
Quantising the KV cache is the next lever. GQA shrinks the count of cached scalars; INT8 / FP8 KV cache shrinks each scalar. They multiply: GQA-8 + FP8-KV gives a 16× total cache shrink versus MHA-FP16. Modern serving stacks turn both knobs.
Retrofitting an MHA checkpoint. If you inherit an MHA-trained model and want GQA at serve time, the standard recipe is: average the per-head K (and V) weights inside each target group, then fine-tune for ~5% of pre-training compute. Quality largely recovers. This is the uptraining recipe from the GQA paper; LLaMA-2 70B, by contrast, was trained with GQA from the start.
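The way the window, group and quantisation knobs multiply can be sketched by extending the cache formula. Treating the sliding window as a hard cap on cached tokens is an assumption of this sketch, and any per-block scale metadata a quantised cache might carry is ignored; the shapes are illustrative.

```python
def kv_cache_bytes(G, T, d_k, L, bytes_per_scalar, window=None):
    """2 * G * T_eff * d_k * L * bytes, where T_eff is capped by the window if given."""
    T_eff = T if window is None else min(T, window)
    return 2 * G * T_eff * d_k * L * bytes_per_scalar


# ASSUMED illustrative shapes: 80 layers, 64 query heads, head dim 128, 32K context.
L, H, d_k, T = 80, 64, 128, 32768

base = kv_cache_bytes(H, T, d_k, L, bytes_per_scalar=2)               # MHA, FP16
gqa8 = kv_cache_bytes(8, T, d_k, L, bytes_per_scalar=2)               # GQA-8, FP16
gqa8_fp8 = kv_cache_bytes(8, T, d_k, L, bytes_per_scalar=1)           # GQA-8, 1-byte KV
gqa8_fp8_win = kv_cache_bytes(8, T, d_k, L, bytes_per_scalar=1, window=4096)

for label, size in [("MHA FP16", base), ("GQA-8 FP16", gqa8),
                    ("GQA-8 FP8", gqa8_fp8), ("GQA-8 FP8 + 4K window", gqa8_fp8_win)]:
    print(f"{label:<22} {size / 1e9:7.2f} GB  ({base // size}x smaller than MHA FP16)")
```

The window only helps once T exceeds it, whereas the group and dtype factors apply at every context length.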
Bridge to MLA
GQA is a beautifully simple idea: same query lenses, fewer memories. It buys you a constant-factor cache shrink, free of training-time wizardry, with negligible quality loss at sensible group counts. That is why most major open LLMs released in 2024–2025 ship with GQA, very often with 8 KV groups.
But notice what GQA does not do. It still stores a full dk-dimensional K and V per group per token. As dk grows (DeepSeek V3 uses dk=128) and as context grows (V3 trains at T=128000), the cache balloons again. GQA gives you a constant divisor; it doesn't change the geometry of the cache.
Multi-Head Latent Attention, the subject of §4.3, replaces the cached K and V with a low-rank latent vector $c_t \in \mathbb{R}^{d_c}$ with $d_c \ll H \cdot d_k$, and absorbs the up-projections back into the attention math. The cache per token becomes a single small vector regardless of head count. GQA was the warm-up; MLA is the real surgery.
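To get a feel for the sizes involved, the sketch below compares cached scalars per token per layer under the three schemes. The latent width `d_c = 512` is a purely illustrative assumption, not a figure taken from this text, and any extra per-token state a real MLA implementation keeps is ignored here.

```python
def per_token_floats_grouped(G: int, d_k: int) -> int:
    """MHA / GQA / MQA: one K row and one V row of width d_k per KV group."""
    return 2 * G * d_k


def per_token_floats_latent(d_c: int) -> int:
    """Latent scheme as described above: a single vector of width d_c."""
    return d_c


H, d_k = 128, 128      # head count and head dim quoted for DeepSeek V3 in this section
d_c = 512              # HYPOTHETICAL latent width, chosen only for illustration

rows = {
    "MHA (G=H)": per_token_floats_grouped(H, d_k),
    "GQA-8": per_token_floats_grouped(8, d_k),
    "latent c_t": per_token_floats_latent(d_c),
}
for name, floats in rows.items():
    print(f"{name:<11} {floats:>6} floats per token per layer")
```

The grouped schemes scale with G, while the latent width is a free design choice that does not depend on the head count at all.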
Summary
The problem. MHA's KV cache scales linearly in heads, and inference is memory-bandwidth-bound, so the cache is the dominant serving cost.
The fix. Share (K,V) across groups of query heads. G groups instead of H: cache shrinks by H/G, query side is untouched.
The spectrum. G=H is MHA; G=1 is MQA; 1<G<H is GQA. Many modern models land at G=8.
The code. One repeat_interleave line turns group-shaped K,V into head-shaped K,V so existing Flash kernels run unchanged.
The win. 8× cache shrink → 8× concurrent users on the same GPU at near-zero quality cost.
What it can't do. It doesn't break the linear-in-T growth of the cache, and it doesn't shrink the per-group cache width below d_k. That per-token width is what MLA is built to attack.