Introduction
We have spent four sections deriving three different ways to wire an attention layer. In §4.1 the naive Multi-Head Attention (MHA) baseline gave us a per-token cache of $2 n_h d_h$ scalars per layer. In §4.2 Grouped-Query Attention (GQA) collapsed the $n_h$ key/value heads down to $G$ groups and shrank the cache by a factor of $n_h / G$. In §4.3 and §4.4 Multi-head Latent Attention (MLA) with Decoupled RoPE pushed it further: a single compressed latent $\mathbf{c}^{KV}_t$ plus a shared rotated key $\mathbf{k}^R_t$, totalling $d_c + d_h^R$ scalars per token per layer, independent of head count.
Three variants, three trade-offs, one engineering question: which one do you ship? This section answers that, end-to-end, on equal footing. We unify the cache formula, compute apples-to-apples cost on real model shapes (LLaMA, Mistral, DeepSeek-V2), benchmark the forward pass, and build the decision matrix that production teams actually use.
Why this matters: Picking the wrong variant for your workload silently caps your throughput at a fraction of the hardware you paid for. On an 8 × H100 node, a 70B-scale model has room for only one or two 128k-context sequences under MHA, about 12 under GQA-8, and about 44 under MLA (worked out in §4.5.4). The architecture choice is a serving-economics choice, and it is worth getting right.
4.5.1 The Real Problem: Three Knobs, One Budget
The serving constraint
An LLM that does not fit in HBM cannot answer. An LLM that fits but leaves no room for KV cache can answer one prompt — slowly. Modern deployment runs hundreds of concurrent sessions per GPU, and every active session pays its own KV cache. Weights are shared; cache is not.
Per token, per layer, that cache is the variable cost. Multiply by the number of layers, by the average context length, and by the desired batch size, and the figure is rarely far from your GPU's entire HBM budget. The choice between MHA, GQA and MLA changes only one thing, the per-token, per-layer cache, but that one thing is the bottleneck.
What each variant pays for
Each variant is a different deal struck between three quantities:
- Cache size per token. Bytes that must stay in HBM for every token of every active session.
- Modelling quality. How well the layer can represent the relationship between query and key. More heads, more capacity.
- Forward-time compute. The FLOPs needed to actually score one token against every previous one. Includes any on-the-fly reconstruction of K and V.
MHA gives up cache to win on quality. MQA gives up quality to win on cache. GQA finds a tunable middle. MLA bends the geometry: it gives up a little forward-time compute (to reconstruct K and V) in exchange for the smallest cache anyone has shipped so far — without giving up per-head specialisation. That last part is the surprise.
The bottleneck in one line: at 128k context you do not run out of FLOPs, you run out of HBM. Every variant in this chapter exists to push that limit, in a different way.
4.5.2 Intuition: The Cache–Quality–Throughput Trilemma
Imagine three knobs on a mixing desk. The first is per-head specialisation: how many independent viewpoints each attention layer carries. The second is cache size: how much state per token you commit to HBM. The third is compute per step: how many FLOPs you spend per generated token. You cannot turn one knob without moving the other two; they are mechanically coupled.
- MHA keeps every knob in its loud position. Maximum heads, maximum cache, baseline compute. The default. Best quality per parameter but unaffordable at long context.
- MQA / GQA turn the cache knob down by merging keys and values across heads. Quality drops a little (heads now share where they look), compute stays the same.
- MLA turns the cache knob much further down without touching the quality knob — by paying a small bill at the compute knob (the reconstruction matmul). The catch: implementing it correctly requires the Decoupled-RoPE trick from §4.4.
The analogy that finally made this click for me: MHA is like keeping a full transcript of every meeting you ever attended in your pocket. GQA keeps only the summaries grouped by team. MLA keeps a compressed signature of each meeting and reconstructs the relevant detail on demand — at the cost of a little arithmetic per recall. The signatures are tiny; the meetings are infinite.
Mental model: MHA is "store everything." GQA is "store team summaries." MLA is "store a signature, regenerate on demand." The compute cost of regeneration is roughly free compared to the HBM cost of storage at scale.
4.5.3 The Mathematical Idea: One Cache Formula
The unified per-token, per-layer cache
We can write all three variants with one formula. For a sequence of length $N$, a model with $L$ layers, $n_h$ attention heads of head-dim $d_h$, and $b$ bytes per scalar, the total KV cache for one sequence is

$$M_{\text{KV}} = N \cdot L \cdot b \cdot c_{\text{token}}$$

The only thing that changes between variants is $c_{\text{token}}$, the per-token, per-layer scalar count. Plug in:
| Variant | Per-token, per-layer cost c_token (scalars) | Notes |
|---|---|---|
| MHA | 2 · n_h · d_h | K and V, full per head |
| MQA | 2 · 1 · d_h | K and V shared by all heads (n_h queries, 1 KV) |
| GQA | 2 · G · d_h | K and V shared inside each of G groups |
| MLA + Decoupled RoPE | d_c + d_h^R | Latent + shared rotated key; independent of n_h |
Notice the structural break: MHA, MQA and GQA all scale linearly with their KV head count. MLA has no $n_h$ in the formula at all. That single absence is the whole reason MLA wins on large models: adding heads to MLA does not enlarge the cache.
The compression ratio versus MHA
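Let $\rho$ be the cache compression versus MHA. Then

$$\rho = \frac{c_{\text{token}}^{\text{MHA}}}{c_{\text{token}}^{\text{MLA}}} = \frac{2 n_h d_h}{d_c + d_h^R}$$

For DeepSeek-V2's shape ($n_h = 128$, $d_h = 128$, $d_c = 512$, $d_h^R = 64$): $\rho = 32{,}768 / 576 \approx 57$, so the cache is roughly 57× smaller than MHA. A GQA-8 layer on the same 128-head shape compresses by $n_h / G = 16$, so MLA wins by another $2 G d_h / (d_c + d_h^R) = 2{,}048 / 576 \approx 3.56$. MLA's lead over MHA widens as $n_h$ grows; its lead over GQA depends on $G$, not on $n_h$.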
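These ratios can be checked in a few lines of plain Python. This is a minimal sketch. `c_token` and `ratio` are hypothetical helper names, the defaults are the DeepSeek-V2 shape quoted above, and GQA is evaluated with $G = 8$ on that shape. The loop at the end varies $n_h$ to show which ratio moves and which stays fixed.

```python
def c_token(variant, n_h=128, d_h=128, G=8, d_c=512, d_hR=64):
    """Per-token, per-layer KV-cache size in scalars (defaults: DeepSeek-V2 shape)."""
    table = {
        "MHA": 2 * n_h * d_h,   # full K and V for every head
        "MQA": 2 * d_h,         # one K/V pair shared by all heads
        "GQA": 2 * G * d_h,     # one K/V pair per group
        "MLA": d_c + d_hR,      # latent + shared rotated key: no n_h term
    }
    return table[variant]


def ratio(bigger, smaller, **shape):
    """How many times smaller `smaller`'s cache is than `bigger`'s."""
    return c_token(bigger, **shape) / c_token(smaller, **shape)


print(f"MHA / MLA   : {ratio('MHA', 'MLA'):.1f}x")   # 56.9x
print(f"MHA / GQA-8 : {ratio('MHA', 'GQA'):.1f}x")   # 16.0x  (= n_h / G)
print(f"GQA-8 / MLA : {ratio('GQA', 'MLA'):.2f}x")   # 3.56x

# Vary n_h: the lead over MHA grows, the lead over GQA-8 stays put.
for n_h in (32, 64, 128):
    print(n_h, round(ratio("MHA", "MLA", n_h=n_h), 1),
          round(ratio("GQA", "MLA", n_h=n_h), 2))
```

Increasing `G` instead of `n_h` is what moves the MLA-versus-GQA ratio. For example, `ratio("GQA", "MLA", G=16)` is about 7.1.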
The compute cost MLA pays in return
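The MLA forward includes two extra matmuls per layer per step: the up-projections $W^{UK}$ and $W^{UV}$, each mapping the $d_c$-dimensional latent to $n_h d_h$ outputs. Their cost per token is roughly $2 \times 2\, d_c\, n_h d_h = 4\, d_c\, n_h d_h$ FLOPs, a single-digit percentage of an attention layer's total compute. In §4.5.8 we measure it directly: the slowdown is in the noise, the cache saving is order-of-magnitude.
The trade in one line: a few percent more FLOPs per token, in exchange for a cache that is ~57× smaller than MHA's at DeepSeek-V2's shape. At long context the cache term dominates. Always.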
4.5.4 Manual Numerical Walkthrough
Let us compute the cache size for one concrete model under all three variants, by hand. We pick LLaMA-3 70B's shape for the dense side and DeepSeek-V2's MLA parameters, sequence length 128k, FP16. Numbers everywhere.
Manual Numerical Walkthrough: KV cache for one 128k sequence on three architectures
Step 1: Fix the shape
```text
Model:    decoder-only Transformer
Layers:   L = 80
Heads:    n_h = 64 (LLaMA-3 70B query heads)
Head dim: d_h = 128
Groups:   G = 8 (GQA, LLaMA-3 default)
MLA:      d_c = 512 (latent)
          d_h^R = 64 (decoupled-RoPE per-head slice)
Context:  N = 128 · 1024 = 131,072 tokens
Bytes:    b = 2 (FP16/BF16)
```
Step 2: Per-token cost c_token (scalars per layer per token)
```text
MHA : 2 · n_h · d_h = 2 · 64 · 128 = 16,384
MQA : 2 · 1 · d_h   = 2 · 1 · 128  = 256
GQA : 2 · G · d_h   = 2 · 8 · 128  = 2,048
MLA : d_c + d_h^R   = 512 + 64     = 576
```
Step 3: Bytes per token (× b)
```text
MHA : 16,384 · 2 = 32,768 B = 32 KB
MQA :    256 · 2 =    512 B = 0.5 KB
GQA :  2,048 · 2 =  4,096 B = 4 KB
MLA :    576 · 2 =  1,152 B = 1.125 KB
```
Step 4: Per-layer cache for one 128k sequence (× N)
```text
MHA : 32 KB · 131,072    ≈ 4.0 GB per layer
MQA : 0.5 KB · 131,072   ≈ 64 MB per layer
GQA : 4 KB · 131,072     ≈ 512 MB per layer
MLA : 1.125 KB · 131,072 ≈ 144 MB per layer
```
Step 5: Total cache for the whole model (× L)
```text
MHA : 4.0 GB · 80 ≈ 320 GB    ← does not fit on a single GPU
MQA : 64 MB · 80  ≈ 5.0 GB
GQA : 512 MB · 80 ≈ 40 GB
MLA : 144 MB · 80 ≈ 11.25 GB
```
Step 6: How many such sequences fit alongside the weights?
A 70B model in FP16 takes ~140 GB of weights. On a node with 8 × 80 GB H100s (640 GB total), after sharding the weights you have roughly 640 − 140 = 500 GB free for KV cache.
```text
MHA : 500 / 320   ≈ 1.5 long sequences (basically unusable)
GQA : 500 / 40    ≈ 12 sequences
MLA : 500 / 11.25 ≈ 44 sequences    ← 3.6× over GQA, 29× over MHA
```
What we just proved
- The cache numbers are not abstract: at 128k context, MHA is the difference between "ship it" and "cannot serve."
- GQA is the safe and standard choice — almost every open-weight model since LLaMA-3 uses it. It buys ~8× over MHA.
- MLA buys another ~3.5× on top of GQA-8. That factor, $2 G d_h / (d_c + d_h^R)$, depends on $G$ rather than $n_h$; what grows with $n_h$ is MLA's win over MHA, which reaches ~57× for DeepSeek-V2's 128-head shape.
- The cost is two small extra matmuls per layer per step. We will measure this in §4.5.8 and confirm it is in the single-digit percent range.
4.5.5 Interactive: Per-Token Blueprint
Before the chart, the structure. The diagram below puts the three variants side by side and shows, for each, what literally sits in the KV cache per token (emerald) versus what is reconstructed on the fly at forward time (amber). The slate row at the top is the query side, which never enters the cache regardless of variant.
Three things worth noticing as you click through the panels:
- MHA has no amber row. Every key and value that attention will ever read is sitting in the cache verbatim. Maximum state, minimum compute.
- GQA has fewer emerald slots but a full amber row. Only $G$ physical K and V vectors are stored; at attention time they are replicated to all $n_h$ heads. The replication is free (it is a broadcast, not a copy), but the $n_h / G$ heads in a group now see identical K/V projections.
- MLA has just two emerald slots. One fat content latent $\mathbf{c}^{KV}_t$ ($d_c$ scalars), one tiny shared rotated key $\mathbf{k}^R_t$ ($d_h^R$ scalars). The entire amber row, all $n_h$ K and V vectors, is regenerated from $\mathbf{c}^{KV}_t$ at every forward step via $W^{UK}$ and $W^{UV}$.
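The contents of each row can be sketched in NumPy. This is a toy illustration, not a real layer. The weights are random, the shapes are small hypothetical values, and the names `W_UK` and `W_UV` only mirror the $W^{UK}$ and $W^{UV}$ notation. The sketch shows what one token contributes to the cache under each variant. It also shows that GQA's replication can be a broadcast view rather than a copy, and that MLA rebuilds all $n_h$ keys and values from the latent.

```python
import numpy as np

rng = np.random.default_rng(0)
n_h, d_h, G, d_c, d_hR = 8, 16, 2, 32, 8   # small toy shapes (hypothetical)

# What one token stores in the KV cache under each variant.
mha_cache = {"K": rng.standard_normal((n_h, d_h)), "V": rng.standard_normal((n_h, d_h))}
gqa_cache = {"K": rng.standard_normal((G, d_h)), "V": rng.standard_normal((G, d_h))}
mla_cache = {"c_KV": rng.standard_normal(d_c), "k_R": rng.standard_normal(d_hR)}


def n_scalars(cache):
    return sum(x.size for x in cache.values())


print("MHA:", n_scalars(mha_cache))  # 2 * n_h * d_h = 256
print("GQA:", n_scalars(gqa_cache))  # 2 * G * d_h = 64
print("MLA:", n_scalars(mla_cache))  # d_c + d_hR = 40

# GQA: expose the G stored keys to all n_h heads through a broadcast view.
k_groups = gqa_cache["K"][:, None, :]                        # (G, 1, d_h)
k_per_head = np.broadcast_to(k_groups, (G, n_h // G, d_h))   # read-only view
assert np.shares_memory(k_per_head, gqa_cache["K"])          # no new buffer

# MLA: regenerate all n_h keys and values from the latent at forward time.
W_UK = rng.standard_normal((d_c, n_h * d_h))
W_UV = rng.standard_normal((d_c, n_h * d_h))
K_content = (mla_cache["c_KV"] @ W_UK).reshape(n_h, d_h)
V = (mla_cache["c_KV"] @ W_UV).reshape(n_h, d_h)
# Decoupled RoPE: every head appends the same shared rotated key slice.
K = np.concatenate([K_content, np.tile(mla_cache["k_R"], (n_h, 1))], axis=1)
print(K.shape, V.shape)  # (8, 24) (8, 16)
```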
4.5.6 Interactive: Cache Growth vs Context
Now scale. The chart below plots the total KV cache (summed over all layers, FP16) against context length on a log-log axis. Switch between four real models. The summary boxes underneath are pinned to 128k context — the operating point that defines modern serving.
What you should see, regardless of the model preset you pick:
- On log axes all three lines are parallel: they grow linearly with $N$, just with different multipliers. The ratio between any two lines is constant and equals the ratio of their $c_{\text{token}}$ values.
- On the linear axis the MHA line shoots into the sky long before the others have left the floor. This is the visual statement of why long context broke MHA: at large $N$ the per-token coefficient dominates everything.
- Switching from LLaMA-3 70B (GQA-8, 80 layers) to DeepSeek-V2 (MLA, 60 layers) at 128k context drops the cache from roughly 40 GB to roughly 9 GB. That is the practical reason DeepSeek can quote much higher concurrent-user numbers than its dense peers despite being a larger model.
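The parallel lines on the log-log chart follow directly from the formula. Each variant's cache is $N$ times a constant, so the ratio between two curves is the same at every context length. Below is a small sketch using the §4.5.4 shapes (80 layers, FP16). `C_TOKEN` and `total_cache_gib` are hypothetical names, and sizes are in GiB ($2^{30}$ bytes).

```python
# Per-token, per-layer scalars for the §4.5.4 shapes (n_h=64, d_h=128, G=8).
C_TOKEN = {"MHA": 2 * 64 * 128, "GQA-8": 2 * 8 * 128, "MLA": 512 + 64}
N_LAYERS, BYTES = 80, 2                      # 80 layers, FP16


def total_cache_gib(c_token, n_tokens):
    """Whole-model KV cache for one sequence, in GiB (2**30 bytes)."""
    return c_token * N_LAYERS * n_tokens * BYTES / 2**30


for n in (4_096, 32_768, 131_072):
    sizes = {v: total_cache_gib(c, n) for v, c in C_TOKEN.items()}
    print(f"N={n:>7}: " + ", ".join(f"{v} {s:.2f} GiB" for v, s in sizes.items()),
          f"| MHA/GQA {sizes['MHA'] / sizes['GQA-8']:.2f}",
          f"| GQA/MLA {sizes['GQA-8'] / sizes['MLA']:.2f}")
```

Every row prints the same two ratios (8.00 and 3.56). Only the absolute sizes move with $N$.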
4.5.7 Plain Python Implementation
Before any framework, let us put the cache formula into one function you can call. The point is to reproduce the table in §4.5.4 in code, so you can plug in your own model shapes and read off the answer in seconds.
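Here is a minimal sketch of such a function, assuming FP16 (2 bytes per scalar) by default. `c_token`, `kv_cache_bytes` and the two shape dictionaries are hypothetical names. The shapes are the §4.5.4 walkthrough values and DeepSeek-V2's 60-layer, 128-head configuration. Walkthrough totals print in GiB ($2^{30}$ bytes), which matches the binary "GB" used there. DeepSeek-V2 totals print in decimal GB.

```python
def c_token(variant, n_h, d_h, G=None, d_c=None, d_hR=None):
    """Per-token, per-layer KV-cache size in scalars."""
    if variant == "MHA":
        return 2 * n_h * d_h                 # K and V for every head
    if variant == "MQA":
        return 2 * d_h                       # one K/V pair for all heads
    if variant == "GQA":
        if n_h % G != 0:
            raise ValueError("G must divide n_h")
        return 2 * G * d_h                   # one K/V pair per group
    if variant == "MLA":
        return d_c + d_hR                    # latent + shared rotated key
    raise ValueError(f"unknown variant: {variant}")


def kv_cache_bytes(variant, n_layers, n_tokens, bytes_per_scalar=2, **shape):
    """Total KV cache for one sequence: N * L * b * c_token."""
    return n_tokens * n_layers * bytes_per_scalar * c_token(variant, **shape)


N = 128 * 1024
GiB, GB = 2**30, 10**9

# §4.5.4: LLaMA-3 70B shape with DeepSeek-V2's MLA parameters.
llama = dict(n_layers=80, n_tokens=N, n_h=64, d_h=128, G=8, d_c=512, d_hR=64)
free = 500 * GiB  # the walkthrough's ~500 GB of free HBM, taken as GiB here
for v in ("MHA", "MQA", "GQA", "MLA"):
    total = kv_cache_bytes(v, **llama)
    print(f"{v}: {total / GiB:6.2f} GiB, {free / total:6.2f} sequences fit")

# DeepSeek-V2 shape; its GQA reference uses G = n_h, so it equals MHA.
dsv2 = dict(n_layers=60, n_tokens=N, n_h=128, d_h=128, G=128, d_c=512, d_hR=64)
for v in ("MHA", "GQA", "MLA"):
    print(f"{v}: {kv_cache_bytes(v, **dsv2) / GB:.1f} GB")
```

To try your own model, add another shape dictionary. Only `n_layers`, the head counts and the latent sizes change.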
Running the script gives you, for DeepSeek-V2 at 128k context in FP16 (decimal GB, $10^9$ bytes): MHA cache ≈ 515 GB, GQA-128 ≈ 515 GB, MLA ≈ 9.1 GB. Because DeepSeek's GQA reference is set with $G = n_h = 128$, it equals MHA: exactly the apples-to-apples scenario the team faced when designing the architecture. MLA gave them ~57× headroom on the same head budget.
4.5.8 PyTorch Implementation
Now we measure. The script below builds a minimal forward pass for each of the three variants on identical input shapes, then prints both the cache size and the forward-time latency. This is the apples-to-apples micro-benchmark we have been promising.
Sample output on an H100, single sequence of 4096 tokens:
```text
MHA forward= 4.81 ms per-seq KV= 16384.0 KB
GQA forward= 4.62 ms per-seq KV= 4096.0 KB
MLA forward= 4.95 ms per-seq KV= 1152.0 KB
```
Forward times are within noise of each other. Cache size collapses by more than 14× from MHA to MLA, and by 3.6× from GQA to MLA. The full inference-time MLA implementation in §4.6 (the next section) will squeeze the forward gap even further by absorbing $W^{UK}$ into the query projection and $W^{UV}$ into the output projection at load time. But even without that, the benchmark already shows the trade is uneven in MLA's favour.
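The absorption mentioned above is associativity of matrix products. Below is a NumPy sketch with random toy weights. It covers one head's content score only; the decoupled-RoPE term is computed separately, which is exactly why it is kept out of the latent. Scoring against keys rebuilt from the cached latents gives the same numbers as folding $W^{UK}$ into the query projection once and scoring the latents directly. `W_UK_h` and `W_Q_h` are hypothetical per-head slices of $W^{UK}$ and the query projection. The same reasoning lets $W^{UV}$ fold into the output projection.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_c, d_h, N = 32, 64, 16, 10        # toy sizes (hypothetical)

c_cache = rng.standard_normal((N, d_c))      # cached latents c^KV, one row per token
W_UK_h = rng.standard_normal((d_c, d_h))     # head h's slice of W^UK: latent -> key
W_Q_h = rng.standard_normal((d_h, d_model))  # head h's (content) query projection
h_t = rng.standard_normal(d_model)           # hidden state of the token being decoded

# Naive: rebuild every cached key from its latent, then score the query.
q_h = W_Q_h @ h_t                            # (d_h,)
k_h = c_cache @ W_UK_h                       # (N, d_h): N * d_c * d_h multiply-adds
scores_naive = k_h @ q_h                     # (N,)

# Absorbed: precompute W_UK_h @ W_Q_h once (at load time) and score latents directly.
W_absorbed = W_UK_h @ W_Q_h                  # (d_c, d_model)
scores_absorbed = c_cache @ (W_absorbed @ h_t)

assert np.allclose(scores_naive, scores_absorbed)
```

The absorbed path never materialises per-head keys for the cached tokens. That is why the inference-time MLA forward can match MHA while reading only the latent.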
4.5.9 Connection to Massive Model Training
The serving economy
At the scale where this comparison actually matters — say a 100B+ model serving 128k context to a few thousand concurrent users — the accounting looks like this:
| Variant | Per-active-sequence cache (128k) | Sequences per 8×80GB node (after weights) | Approx tokens/s/user |
|---|---|---|---|
| MHA (64h, 80L) | ~320 GB | ~1.5 | low (memory-bound) |
| GQA-8 (64h, 80L) | ~40 GB | ~12 | good |
| MLA (80L, d_c=512, d_h^R=64) | ~11 GB | ~44 | best |
The right-most column is decisive. Throughput per user is gated by memory bandwidth per active sequence. The smaller each cache, the more sessions fit, the higher the effective utilisation of the GPU, the lower the per-token cost. MLA does not just save HBM — it lifts the whole serving curve.
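"Gated by memory bandwidth" can be made concrete with a back-of-envelope bound. Each decode step must read every active sequence's KV cache from HBM at least once, so the step time is at least total cache bytes divided by bandwidth. This sketch rests on stated assumptions. The node bandwidth (about 3.35 TB/s per H100 SXM, times eight) is an approximate figure, and `min_decode_step_ms` is a hypothetical helper. Weight reads, compute and communication are ignored, so real step times are higher.

```python
GiB = 2**30
NODE_BW = 8 * 3.35e12   # assumption: ~3.35 TB/s HBM bandwidth per H100 SXM, 8 GPUs


def min_decode_step_ms(cache_bytes_per_seq, n_seqs, hbm_bytes_per_s=NODE_BW):
    """Lower bound on one decode step if each active sequence's KV cache is
    read from HBM once per step (weights, compute and comms are ignored)."""
    return 1e3 * cache_bytes_per_seq * n_seqs / hbm_bytes_per_s


per_seq = {"GQA-8": 40 * GiB, "MLA": 11.25 * GiB}   # 128k context, §4.5.4 shapes
for name, size in per_seq.items():
    fit = int(500 * GiB // size)                    # sequences that fit in ~500 GiB
    print(f"{name}: >= {min_decode_step_ms(size, 12):.1f} ms/step at 12 seqs; "
          f"{fit} seqs at >= {min_decode_step_ms(size, fit):.1f} ms/step")
```

At the same batch of 12, MLA's floor is about 3.6× lower. With the node packed to capacity, both variants stream roughly the same ~500 GiB per step, but MLA spreads that cost over about 3.6× more users.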
The training story is different
During training the cache is irrelevant: every layer recomputes all of K and V from scratch on every forward step, and gradients flow back through the full activations. The architectural choice does cost something at training time:
- MHA: baseline FLOPs and memory.
- GQA: slightly fewer FLOPs in the K/V projection (since the projection matrix is smaller). Quality drop versus MHA is small but real — DeepSeek's own ablations show ~0.1 to 0.3 loss-points worse on language modelling at small scale, vanishing at large scale.
- MLA: a few percent more FLOPs per layer (the extra up-projections), and a few percent more parameters ($W^{UK}$, $W^{UV}$). Quality is on par with vanilla MHA in DeepSeek-V2's ablation table; the latent does not bottleneck because $d_c$ is generous.
Net result: training-time cost is approximately neutral; serving-time cost is dominated by the cache-size term, where MLA beats MHA by an order of magnitude. Teams pay for training once; they pay for serving forever.
Distributed-training implications
Under tensor parallelism, each variant splits cleanly:
- MHA / GQA: shard along $n_h$ (heads; for GQA, along the $G$ KV groups). Each TP rank owns a slice of the heads and the corresponding KV cache. Communication is the standard all-reduce over the output projection.
- MLA: shard along $n_h$ for the up-projections $W^{UK}$ and $W^{UV}$, and replicate the latent $\mathbf{c}^{KV}_t$ on every TP rank. Replication is cheap ($d_c$ is small) and removes the need to all-gather the cache across ranks during decoding. The shared $\mathbf{k}^R_t$ is also replicated.
That last point is subtle but important: MLA's cache layout maps naturally onto tensor parallelism without introducing new collective operations on the hot decoding path. Production teams report no unusual TP overhead from MLA in DeepSeek-V2.
4.5.10 Engineering Reality: When to Pick Which
The decision matrix
| Pick | When | Why |
|---|---|---|
| MHA | Only when you must reproduce a legacy spec or run very short contexts (<2k) on tiny models | No serving advantage at any scale; quality ceiling only matters if you cannot afford the params to match it elsewhere. |
| MQA | Mobile / on-device, single-user latency-bound inference, very strict HBM | Smallest cache among the classical variants. Quality cost is visible at large scale but acceptable for small models. |
| GQA-8 | Default for open-weight dense decoders (LLaMA-3, Qwen-2, Mistral) | ~8x cache reduction with negligible quality loss. Drop-in replacement for MHA. No new infra needed. |
| MLA | Frontier models, very long context (>=32k), high concurrency serving, MoE bases | ~3.5x smaller cache than GQA-8 (~57x smaller than MHA at 128 heads) without losing per-head specialisation. Requires Decoupled RoPE and absorbed inference projections. |
Common pitfalls
- Comparing cache without comparing quality. A variant that halves the cache but costs 1 loss-point of language modelling can be a regression at fixed compute budget. Always ablate on real data at training scale; MLA's reputation depends on DeepSeek showing it holds up at 236B.
- Forgetting that MLA needs Decoupled RoPE. Plain MLA with naive RoPE inflates the cache to MHA size, undoing the whole point. The dual-score split from §4.4 is not optional.
- Benchmarking forward time without accounting for absorption. The clean training-time MLA forward looks slightly slower than MHA. The inference-time MLA forward (with $W^{UK}$ and $W^{UV}$ absorbed into the query and output projections) is comparable to MHA and dominates on memory. Don't confuse the two regimes when sizing capacity.
- GQA group count must divide $n_h$. GQA-7 is not a thing on 64 heads. The group count is a divisor of the query-head count; common choices are 1, 2, 4, 8.
- MLA is a quality lever, not just a cache one. Shrinking $d_c$ below ~256 starts to hurt: the latent becomes the bottleneck instead of the head count. Stay near the DeepSeek-V2 default of 512 unless you have ablation data saying otherwise.
- Flash-Attention compatibility. All three variants have working Flash-Attention paths today, but the MLA path requires either concatenating content and RoPE slices into one head (the DeepSeek choice) or a kernel that adds the content and RoPE score terms before the softmax. Mainstream serving stacks (vLLM, SGLang) wrap this for you, but homemade inference loops need care.
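The divisibility rule from the pitfalls list is easy to check mechanically. In this sketch, `valid_gqa_groups` is a hypothetical helper. The two endpoints of the list are the degenerate cases: $G = 1$ is MQA and $G = n_h$ is MHA.

```python
def valid_gqa_groups(n_h):
    """Group counts G that split n_h query heads into equal-sized groups."""
    return [g for g in range(1, n_h + 1) if n_h % g == 0]


groups = valid_gqa_groups(64)
print(groups)       # [1, 2, 4, 8, 16, 32, 64]
print(7 in groups)  # False: GQA-7 cannot be built on 64 heads
```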
What the table looks like for real shipped models
| Model | Variant | Heads (n_h, G or d_c) | Per-token cache @ FP16 | Layers |
|---|---|---|---|---|
| LLaMA-3 8B | GQA-8 | n_h=32, G=8 | 4 KB | 32 |
| LLaMA-3 70B | GQA-8 | n_h=64, G=8 | 4 KB | 80 |
| Mistral 7B | GQA-8 | n_h=32, G=8 | 4 KB | 32 |
| Qwen2 72B | GQA-8 | n_h=64, G=8 | 4 KB | 80 |
| Gemma 2 27B | GQA-16 | n_h=32, G=16 (alternating sliding-window and global layers) | 8 KB | 46 |
| DeepSeek-V2 | MLA | n_h=128, d_c=512, d_h^R=64 | 1.125 KB | 60 |
| DeepSeek-V3 | MLA | n_h=128, d_c=512, d_h^R=64 | 1.125 KB | 61 |
Two patterns stand out. First, GQA-8 is the dominant default among dense open-weight decoders since LLaMA-3. Second, the frontier teams that need to serve very long context cheaply (DeepSeek, Yi-Lightning, Kimi's newer releases) have moved to MLA. The market has split into "ship GQA, it just works" and "invest in MLA, it pays back at serving."
Summary
- One formula covers all variants: $M_{\text{KV}} = N \cdot L \cdot b \cdot c_{\text{token}}$. The variant only sets $c_{\text{token}}$.
- MHA, MQA and GQA scale with the KV head count. MLA does not: its formula has no $n_h$ at all. That is the structural break.
- On real shapes, GQA-8 gives ~8× compression over MHA; MLA gives another ~3.5× on top of GQA-8 (about 7× against GQA-16), and its lead over MHA grows with $n_h$.
- Forward-time cost differences are single-digit percent; cache cost differences are order-of-magnitude. The trade is uneven in the cache direction, which is the bottleneck at long context.
- Default to GQA-8 for dense decoders. Reach for MLA when you serve long contexts to many concurrent users, or when you are at the frontier of model size. Never ship MHA at large scale.
- MLA requires Decoupled RoPE (§4.4) to keep its compression. Treat them as a unit, not as separable choices.
The next section pulls all of this into a single working PyTorch implementation of MLA — the full training and inference path, absorbed projections included — so you can run a real layer and measure the numbers yourself.