Chapter 4

MLA vs GQA vs MHA: Full Comparison

Multi-Head Latent Attention (MLA)

Introduction

We have spent four sections deriving three different ways to wire an attention layer. In §4.1 the naive Multi-Head Attention (MHA) baseline gave us a per-token cache of $2 \cdot n_h \cdot d_h$ scalars per layer. In §4.2 Grouped-Query Attention (GQA) collapsed the $n_h$ key/value heads down to $G$ groups and shrank the cache by a factor of $n_h / G$. In §4.3 and §4.4, Multi-Head Latent Attention (MLA) with Decoupled RoPE pushed it further: a single compressed latent $c^{KV}$ plus a shared rotated key $k^R$, totalling $d_c + d_h^R$ scalars per token per layer, independent of head count.

Three variants, three trade-offs, one engineering question: which one do you ship? This section answers that, end-to-end, on equal footing. We unify the cache formula, compute apples-to-apples cost on real model shapes (LLaMA, Mistral, DeepSeek-V2), benchmark the forward pass, and build the decision matrix that production teams actually use.

Why this matters: Picking the wrong variant for your workload silently caps your throughput at a fraction of the hardware you paid for. On an 8 × H100 node, a 70B-class model can hold roughly 1–2 concurrent 128k-context sequences under MHA, about 12 under GQA-8, and over 40 under MLA (the arithmetic is in §4.5.4). The architecture choice is a serving-economics choice, and it is worth getting right.

4.5.1 The Real Problem: Three Knobs, One Budget

The serving constraint

An LLM that does not fit in HBM cannot answer. An LLM that fits but leaves no room for KV cache can answer one prompt — slowly. Modern deployment runs hundreds of concurrent sessions per GPU, and every active session pays its own KV cache. Weights are shared; cache is not.

Per token, per layer, that cache is the variable cost. Multiply by the number of layers, by the average context length, and by the desired batch size, and the figure is rarely far from your GPU's entire HBM budget. The architecture choice (MHA, GQA or MLA) changes only one thing, the per-token, per-layer cache, but that one thing is the bottleneck.

What each variant pays for

Each variant is a different deal struck between three quantities:

  1. Cache size per token. Bytes that must stay in HBM for every token of every active session.
  2. Modelling quality. How well the layer can represent the relationship between query and key. More heads, more capacity.
  3. Forward-time compute. The FLOPs needed to actually score one token against every previous one. Includes any on-the-fly reconstruction of K and V.

MHA gives up cache to win on quality. MQA gives up quality to win on cache. GQA finds a tunable middle. MLA bends the geometry: it gives up a little forward-time compute (to reconstruct K and V) in exchange for a cache several times smaller than GQA-8's and only about twice MQA's, without giving up per-head specialisation. That last part is the surprise.

The bottleneck in one line: at 128k context you do not run out of FLOPs, you run out of HBM. Every variant in this chapter exists to push that limit, in a different way.

4.5.2 Intuition: The Cache–Quality–Throughput Trilemma

Imagine three knobs on a mixing desk. The first is per-head specialisation: how many independent viewpoints each attention layer carries. The second is cache size: how much state per token you commit to HBM. The third is compute per step: how many FLOPs you spend per generated token. You cannot turn one knob without moving at least one of the others; they are mechanically coupled.

  • MHA keeps every knob in its loud position. Maximum heads, maximum cache, baseline compute. The default. Best quality per parameter but unaffordable at long context.
  • MQA / GQA turn the cache knob down by merging keys and values across heads. Quality drops a little (heads now share where they look), compute stays the same.
  • MLA turns the cache knob much further down without touching the quality knob — by paying a small bill at the compute knob (the reconstruction matmul). The catch: implementing it correctly requires the Decoupled-RoPE trick from §4.4.

The analogy that finally made this click for me: MHA is like keeping a full transcript of every meeting you ever attended in your pocket. GQA keeps only the summaries grouped by team. MLA keeps a compressed signature of each meeting and reconstructs the relevant detail on demand — at the cost of a little arithmetic per recall. The signatures are tiny; the meetings are infinite.

Mental model: MHA is "store everything." GQA is "store team summaries." MLA is "store a signature, regenerate on demand." At scale, the compute cost of regeneration is small next to the HBM cost of storage.

4.5.3 The Mathematical Idea: One Cache Formula

The unified per-token, per-layer cache

We can write all three variants with one formula. For a sequence of length $N$, a model with $L$ layers, $n_h$ attention heads of head-dim $d_h$, and $b$ bytes per scalar, the total KV cache for one sequence is

$$\text{cache}(N) \;=\; L \cdot N \cdot b \cdot \underbrace{c_{\text{token}}}_{\text{variant-specific per-token cost}}.$$

The only thing that changes between variants is $c_{\text{token}}$, the per-token, per-layer scalar count. Plug in:

| Variant | Per-token, per-layer cost $c_{\text{token}}$ (scalars) | Notes |
|---|---|---|
| MHA | $2 \cdot n_h \cdot d_h$ | K and V, full per head |
| MQA | $2 \cdot 1 \cdot d_h$ | K and V shared by all heads ($n_h$ queries, 1 KV head) |
| GQA | $2 \cdot G \cdot d_h$ | K and V shared inside each of $G$ groups |
| MLA + Decoupled RoPE | $d_c + d_h^R$ | Latent + shared rotated key; independent of $n_h$ |

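Notice the structural break: MHA, MQA and GQA all scale linearly with their KV head count, so every distinct key/value head you keep costs another $2 \cdot d_h$ scalars per token. MLA still gives each of its $n_h$ heads its own key and value, yet there is no $n_h$ in its formula at all. That single absence is the whole reason MLA wins on large models: adding heads to MLA does not enlarge the cache.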

The compression ratio versus MHA

Let $r = c_{\text{token}}^{\text{MHA}} / c_{\text{token}}^{\text{variant}}$ be the cache compression versus MHA. Then

$$r_{\text{GQA}} \;=\; \frac{n_h}{G}, \qquad r_{\text{MLA}} \;=\; \frac{2 \, n_h \, d_h}{d_c + d_h^R}.$$

For DeepSeek-V2's shape ($n_h = 128$, $d_h = 128$, $d_c = 512$, $d_h^R = 64$): $r_{\text{MLA}} = (2 \cdot 128 \cdot 128) / 576 \approx 56.9$, roughly 57× smaller than MHA. A GQA-8 layer of the same 128-head shape gets $r_{\text{GQA}} = 128 / 8 = 16$, so MLA wins by another ~3.5× ($2048 / 576 \approx 3.56$). Because $n_h$ appears only in MHA's cost, MLA's advantage over MHA widens as $n_h$ grows; its advantage over GQA is set by $G$, not by $n_h$.
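A few lines of Python make that last distinction concrete. This is a minimal sketch: `compression_ratios` is a hypothetical helper, and the defaults ($d_h = 128$, $G = 8$, $d_c = 512$, $d_h^R = 64$) are the shapes quoted in this section. Holding $G = 8$ fixed and growing $n_h$, the ratio over MHA climbs while the GQA-8-to-MLA ratio stays put.

```python
def compression_ratios(n_h: int, d_h: int = 128, G: int = 8,
                       d_c: int = 512, d_h_R: int = 64) -> tuple[float, float, float]:
    """Return (r_GQA, r_MLA, GQA-to-MLA cache ratio) for one layer, versus MHA."""
    c_mha = 2 * n_h * d_h          # K and V for every head
    c_gqa = 2 * G * d_h            # K and V for every group
    c_mla = d_c + d_h_R            # latent + shared rotated key
    return c_mha / c_gqa, c_mha / c_mla, c_gqa / c_mla

for n_h in (32, 64, 128):
    r_gqa, r_mla, gqa_over_mla = compression_ratios(n_h)
    print(f"n_h={n_h:3d}  r_GQA={r_gqa:5.1f}  r_MLA={r_mla:5.1f}  GQA/MLA={gqa_over_mla:.2f}")
```

At $n_h = 128$ this reproduces $r_{\text{GQA}} = 16$ and $r_{\text{MLA}} \approx 56.9$; the last column reads 3.56 on every row.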

The compute cost MLA pays in return

The MLA forward includes two extra matmuls per layer per step: the up-projections $W^{UK} c^{KV}$ and $W^{UV} c^{KV}$. Together they cost roughly $2 \cdot d_c \cdot n_h \cdot d_h$ multiply-accumulates per token (about twice that in FLOPs), which is single-digit percent of an attention layer's total compute at the shapes used in this section. In §4.5.8 we measure it directly: the slowdown is in the noise, the cache saving is order-of-magnitude.
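To see where the "single-digit percent" comes from, here is a back-of-envelope multiply-accumulate count per token for the MLA module benchmarked in §4.5.8 (no RoPE path, full non-causal score matrix, one forward over the whole sequence). The function name and breakdown are illustrative assumptions, not a profiler measurement.

```python
def mla_macs_per_token(d_model=4096, n_h=32, d_h=128, d_c=512, N=4096):
    """Multiply-accumulates per token for one MLA layer shaped like the §4.5.8 benchmark."""
    return {
        "W_DKV (down-projection)": d_model * d_c,
        "W_UQ (query)":            d_model * n_h * d_h,
        "W_UK (key up-proj)":      d_c * n_h * d_h,
        "W_UV (value up-proj)":    d_c * n_h * d_h,
        "QK^T scores":             N * n_h * d_h,   # every query against all N keys
        "softmax(S) @ V":          N * n_h * d_h,
        "output projection":       n_h * d_h * d_model,
    }

macs = mla_macs_per_token()
total = sum(macs.values())
up_proj = macs["W_UK (key up-proj)"] + macs["W_UV (value up-proj)"]
print(f"total = {total / 1e6:.1f}M MACs/token, up-projections = {up_proj / total:.1%}")
```

This prints a total of 73.4M MACs per token with the up-projections at 5.7%. The count assumes a full forward pass; during step-by-step decoding without absorption, the up-projections would have to be re-applied to every cached latent, which is one reason §4.6 absorbs them.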

The trade in one inequality:

$$\Delta_{\text{compute}} \;\approx\; \text{a few \%}, \qquad \Delta_{\text{cache}} \;\approx\; \text{tens of}\,\times.$$

At long context the cache term dominates. Always.

4.5.4 Manual Numerical Walkthrough

Let us compute the cache size for one concrete model under all four variants (MHA, MQA, GQA, MLA), by hand. We pick LLaMA-3 70B's shape for the dense side and DeepSeek-V2's MLA parameters, sequence length 128k, FP16. Numbers everywhere.

Manual Numerical Walkthrough: KV cache for one 128k sequence under four variants

Step 1 — Fix the shape

```text
Model:    decoder-only Transformer
Layers:   L     = 80
Heads:    n_h   = 64        (LLaMA-3 70B query heads)
Head dim: d_h   = 128
Groups:   G     = 8         (GQA, LLaMA-3 default)
MLA:      d_c   = 512       (latent)
          d_h^R = 64        (decoupled-RoPE per-head slice)
Context:  N     = 128 · 1024 = 131,072 tokens
Bytes:    b     = 2         (FP16/BF16)
```

Step 2 — Per-token cost c_token (scalars per layer per token)

```text
MHA :  2 · n_h · d_h        =  2 · 64 · 128         = 16,384
MQA :  2 ·   1 · d_h        =  2 ·  1 · 128         =    256
GQA :  2 ·   G · d_h        =  2 ·  8 · 128         =  2,048
MLA :  d_c + d_h^R          =     512 +  64         =    576
```

Step 3 — Bytes per token (× b)

```text
MHA :  16,384 · 2  =  32,768  B  =  32 KB
MQA :     256 · 2  =     512  B  =  0.5 KB
GQA :   2,048 · 2  =   4,096  B  =  4 KB
MLA :     576 · 2  =   1,152  B  =  1.125 KB
(1 KB = 1024 B here and in the steps below)
```

Step 4 — Per-layer cache for one 128k sequence (× N)

```text
MHA :  32 KB · 131,072     ≈  4.0 GB   per layer
MQA :  0.5 KB · 131,072    ≈  64 MB    per layer
GQA :  4 KB · 131,072      ≈  512 MB   per layer
MLA :  1.125 KB · 131,072  ≈  144 MB   per layer
```

Step 5 — Total cache for the whole model (× L)

```text
MHA :  4.0 GB · 80    ≈  320 GB    ← does not fit on a single GPU
MQA :  64 MB · 80     ≈  5.0 GB
GQA :  512 MB · 80    ≈  40 GB
MLA :  144 MB · 80    ≈  11.25 GB
```

Step 6 — How many such sequences fit alongside the weights?

A 70B model in FP16 takes ~140 GB of weights. On a node with 8 × 80 GB H100 (640 GB total), after sharding the weights you have roughly $640 - 140 = 500$ GB free for KV cache.

```text
MHA :  500 / 320    ≈  1.5  long sequences   (basically unusable)
GQA :  500 / 40     ≈  12   sequences
MLA :  500 / 11.25  ≈  44   sequences        ← 3.6× over GQA, 29× over MHA
```
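The six steps above fit in a few lines of Python. A minimal sketch that reproduces the hand numbers, carrying over the walkthrough's assumptions (1 KB = 1024 B, and roughly 500 GB of HBM left for KV cache per node); it is a calculator, not a measurement.

```python
import math

L, n_h, d_h, G, d_c, d_h_R = 80, 64, 128, 8, 512, 64   # Step 1: the shape
N, b = 128 * 1024, 2                                   # 131,072 tokens, FP16/BF16
FREE_GB = 500                                          # ~640 GB node minus ~140 GB of weights

c_token = {                                            # Step 2: scalars per token per layer
    "MHA": 2 * n_h * d_h,
    "MQA": 2 * 1 * d_h,
    "GQA": 2 * G * d_h,
    "MLA": d_c + d_h_R,
}

results = {}
for name, c in c_token.items():
    per_token_kb = c * b / 1024             # Step 3 (1 KB = 1024 B)
    per_layer_mb = per_token_kb * N / 1024  # Step 4
    total_gb = per_layer_mb * L / 1024      # Step 5
    fits = math.floor(FREE_GB / total_gb)   # Step 6, whole sequences only
    results[name] = (total_gb, fits)
    print(f"{name}: {c:>6,d} scalars  {per_token_kb:6.3f} KB/token  "
          f"{per_layer_mb:7.1f} MB/layer  {total_gb:7.2f} GB total  -> {fits} sequences")
```

Rounding down to whole sequences turns MHA's 1.5 into a single sequence; the GQA and MLA counts (12 and 44) match the hand calculation.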

What we just proved

  1. The cache numbers are not abstract: at 128k context, MHA is the difference between "ship it" and "cannot serve."
  2. GQA is the safe and standard choice: most recent open-weight dense models use it. It buys $n_h / G$ over MHA, which is 8× on this 64-head shape.
  3. MLA buys another ~3.5× on top of GQA-8 ($2048 / 576$). Its advantage over MHA grows with $n_h$: for DeepSeek-V2's 128-head shape the ratio over MHA reaches ~57×.
  4. The cost is two small extra matmuls per layer per step. We will measure this in §4.5.8 and confirm it is in the single-digit percent range.

4.5.5 Interactive: Per-Token Blueprint

Before the chart, the structure. The diagram below puts the three variants side by side and shows, for each, what literally sits in the KV cache per token (emerald) versus what is reconstructed on the fly at forward time (amber). The slate row at the top is the query side, which never enters the cache regardless of variant.

Three things worth noticing as you click through the panels:

  • MHA has no amber row. Every key and value that attention will ever read is sitting in the cache verbatim. Maximum state, minimum compute.
  • GQA has fewer emerald slots but a full amber row. Only $G$ physical K and V vectors are stored; at attention time each is shared by $n_h / G$ query heads. Fused GQA kernels do this sharing by indexing, without copying K and V (a naive repeat_interleave, as in §4.5.8, does materialise the copies), but either way the heads in a group now see identical K/V projections.
  • MLA has just two emerald slots: one fat content latent $c$ and one tiny shared rotated key $k^R$. The entire amber row, all $n_h$ K and V vectors, is regenerated from $c$ at every forward step via $W^{UK}$ and $W^{UV}$.
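The GQA sharing described above can be checked in a few lines of PyTorch. This is a toy sketch with made-up sizes, not a production kernel: it compares explicit replication with `repeat_interleave` against a copy-free formulation that folds each group's query heads into one matmul, and shows that a broadcast view of the stored keys (`expand`) has stride 0, i.e. it allocates no new memory.

```python
import torch

torch.manual_seed(0)
B, G, rep, N, d_h = 1, 8, 4, 16, 128
n_h = G * rep                                   # 32 query heads sharing 8 K/V groups

q = torch.randn(B, n_h, N, d_h)
k = torch.randn(B, G, N, d_h)                   # only G key vectors stored per token

# (a) Explicit replication: a new tensor rep times larger than the stored keys.
k_rep = k.repeat_interleave(rep, dim=1)         # [B, n_h, N, d_h]; head j uses group j // rep
s_rep = q @ k_rep.transpose(-1, -2)             # [B, n_h, N, N]

# (b) Copy-free: fold each group's rep query heads into the row axis, so the
#     single stored K tile of group g is multiplied once against all of them.
q_grouped = q.view(B, G, rep * N, d_h)          # heads g*rep .. g*rep+rep-1 -> group g
s_grouped = (q_grouped @ k.transpose(-1, -2)).view(B, n_h, N, N)

# (c) A broadcast view allocates nothing: the rep axis has stride 0.
k_view = k.unsqueeze(2).expand(B, G, rep, N, d_h)

print(torch.allclose(s_rep, s_grouped, atol=1e-3))   # True
print(k_view.stride(2), k_rep.numel() // k.numel())  # 0 4
```

Fused GQA kernels apply roughly the same idea internally: each K/V tile is read once and reused by every query head in its group.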

4.5.6 Interactive: Cache Growth vs Context

Now scale. The chart below plots the total KV cache (summed over all layers, FP16) against context length on a log-log axis. Switch between four real models. The summary boxes underneath are pinned to 128k context — the operating point that defines modern serving.

What you should see, regardless of the model preset you pick:

  • On log axes all three lines are parallel: they grow linearly with $N$, just with different multipliers. The ratio between any two lines is constant and equals the ratio of their $c_{\text{token}}$ values.
  • On the linear axis the MHA line shoots into the sky long before the others have left the floor. This is the visual statement of why long context broke MHA: at large $N$ the per-token coefficient dominates everything.
  • Switching from LLaMA-3 70B (GQA-8, 80 layers) to DeepSeek-V2 (MLA, 60 layers) at 128k context drops the cache from roughly 40 GB to roughly 8.4 GB. That is a large part of why DeepSeek can quote much higher concurrent-user numbers than its dense peers despite being a larger model.
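Without the chart, the same curves can be tabulated directly. A minimal sketch for the LLaMA-3 70B shape with the hypothetical MLA fields from §4.5.4, reporting GiB (2^30 bytes); `cache_gib` is an illustrative name. The last column is the vertical gap between the MHA and MLA lines on log axes, and it does not change with N.

```python
import math

# LLaMA-3 70B shape from §4.5.4: n_h = 64, d_h = 128, G = 8, 80 layers, FP16,
# plus the hypothetical MLA fields d_c = 512, d_h^R = 64.
C_TOKEN = {"mha": 2 * 64 * 128, "gqa": 2 * 8 * 128, "mla": 512 + 64}
LAYERS, BYTES = 80, 2

def cache_gib(variant: str, n_tokens: int) -> float:
    """Total KV cache for one sequence over all layers, in GiB (2**30 bytes)."""
    return C_TOKEN[variant] * BYTES * LAYERS * n_tokens / 2**30

contexts = [2**k for k in range(10, 18)]              # 1k ... 128k tokens
for n in contexts:
    row = {v: cache_gib(v, n) for v in C_TOKEN}
    # On log axes, the vertical gap between two lines is log10 of their ratio;
    # it is the same for every n, which is why the lines are parallel.
    gap = math.log10(row["mha"] / row["mla"])
    print(f"N={n:>7,d}  MHA={row['mha']:7.2f}  GQA={row['gqa']:6.2f}  "
          f"MLA={row['mla']:6.2f} GiB  log10(MHA/MLA)={gap:.3f}")
```

The 131,072-token row gives 320.00, 40.00 and 11.25 GiB, and the gap column stays at about 1.454 on every row.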

4.5.7 Plain Python Implementation

Before any framework, let us put the cache formula into one function you can call. The point is to reproduce the table in §4.5.4 in code, so you can plug in your own model shapes and read off the answer in seconds.

Compute KV-cache size for MHA / GQA / MLA on any model shape (kv_cache_calc.py). Notes on the code, in order:

  • Dataclass for model shape. A small immutable record that bundles every number the cache formula needs. Using a dataclass makes the call sites self-documenting: ModelShape(n_h=64, d_h=128, ...) reads as a spec, not a tuple of magic ints.
  • Query heads vs KV groups. n_h is always the number of query heads. G is the number of K/V groups for GQA: G = n_h reproduces MHA exactly, G = 1 reproduces MQA, anything in between is GQA.
  • Per-head dim d_h. Each attention head operates in a d_h-dimensional space. Modern decoders use d_h = 128 almost universally; we follow suit.
  • MLA latent d_c. The bottleneck width that MLA compresses K and V into. DeepSeek-V2 uses d_c = 512, i.e. 512 latent scalars per token against 16,384 for MHA on the LLaMA-3 70B shape. Doubling d_c roughly doubles the MLA cache, but the cache is so small to begin with that this is cheap.
  • Decoupled-RoPE width d_h_R. The per-head RoPE slice from §4.4. DeepSeek-V2 uses $d_h^R = 64$, half of $d_h$. Because the rotated key is shared across heads, this contributes just $d_h^R$ extra scalars per token, not $n_h \cdot d_h^R$.
  • Variant tag. A Literal type so the type-checker can tell you when you mistype 'mla'. In real codebases this is what keeps configuration files from silently routing into the wrong code path.
  • Per-token scalar count. per_token_scalars is the core formula: it returns the per-token, per-layer scalar count for the requested variant. Every other quantity in this file is just a multiplier on top.
  • MHA, $2 \cdot n_h \cdot d_h$. Two factors (K and V) times every head, with no sharing. For LLaMA-3 70B: $2 \cdot 64 \cdot 128 = 16{,}384$ scalars per token per layer.
  • MQA, shared across all heads. Same as MHA but with one K and one V vector in total; per-head queries still differentiate. This is the original Shazeer (2019) idea.
  • GQA, shared within groups. Per-token cost is proportional to $G$. GQA-8 means $G = 8$ groups, each of $n_h / G$ query heads sharing one K and one V vector. For LLaMA-3 70B that is 8 query heads per group and $2 \cdot 8 \cdot 128 = 2{,}048$ scalars.
  • MLA, independent of $n_h$. $n_h$ does not appear here at all. Adding more heads is free for the cache: you only pay in weights and forward compute, not in HBM per token. For DeepSeek-V2: $512 + 64 = 576$.
  • Total cache for one sequence. Per-token scalars × bytes per scalar × layers × tokens. This is the number that has to live in HBM for one active session; multiply by the number of concurrent sessions to get your serving budget.
  • Reduction ratio. A convenient summary number: MHA / variant tells you how many times smaller the cache is. For DeepSeek-V2 with MLA the answer is about 56.9.
  • LLaMA-3 70B shape. Real LLaMA-3 70B uses 64 query heads and GQA-8. The MLA fields (d_c, d_h_R) are filled in with hypothetical MLA parameters so we can ask "what if LLaMA-3 used MLA?", a useful counterfactual.
  • DeepSeek-V2 shape. Real DeepSeek-V2 uses 128 query heads and ships MLA. We set G = n_h = 128 so the GQA reference reproduces MHA, which gives a clean baseline for "how much did MLA buy on DeepSeek's own shape?"
  • Loop over both models. Run MHA, GQA and MLA on each model at 128k context. The output mirrors the §4.5.4 hand calculation, so running this script is a fast smoke test that the formulas in the prose are correct. One unit caveat: the script divides by $10^9$, while §4.5.4 used 1 KB = 1024 B, so LLaMA-3 70B's MHA row prints ≈343.6 GB, the same quantity as the hand calculation's 320 GB ($320 \cdot 2^{30}$ bytes).
  • Print the answer. Per-token scalar count, total cache in GB, and the reduction factor versus MHA: exactly what an architect needs to fill out a capacity-planning spreadsheet.
```python
from dataclasses import dataclass
from typing import Literal

@dataclass(frozen=True)
class ModelShape:
    name: str
    n_h: int            # query heads
    d_h: int            # per-head dim
    G: int              # GQA groups (G = n_h -> MHA, G = 1 -> MQA)
    d_c: int            # MLA content latent dim
    d_h_R: int          # MLA decoupled-RoPE per-head slice (shared rotated key)
    layers: int
    bytes_per: int = 2  # FP16/BF16

Variant = Literal["mha", "mqa", "gqa", "mla"]

def per_token_scalars(m: ModelShape, variant: Variant) -> int:
    """How many scalars the variant stores in the KV cache PER TOKEN PER LAYER."""
    if variant == "mha":
        return 2 * m.n_h * m.d_h
    if variant == "mqa":
        return 2 * 1 * m.d_h
    if variant == "gqa":
        return 2 * m.G * m.d_h
    if variant == "mla":
        return m.d_c + m.d_h_R
    raise ValueError(f"unknown variant: {variant!r}")

def total_cache_bytes(m: ModelShape, variant: Variant, N: int) -> int:
    """Total KV cache (bytes) for one sequence of length N, all layers."""
    return per_token_scalars(m, variant) * m.bytes_per * m.layers * N

def reduction_vs_mha(m: ModelShape, variant: Variant) -> float:
    """How many times smaller this variant's cache is than MHA."""
    return per_token_scalars(m, "mha") / per_token_scalars(m, variant)

# --- Two real model shapes, each evaluated under MHA, GQA and MLA -------------
llama_70b   = ModelShape("LLaMA-3 70B", n_h=64,  d_h=128, G=8,   d_c=512, d_h_R=64, layers=80)
deepseek_v2 = ModelShape("DeepSeek V2", n_h=128, d_h=128, G=128, d_c=512, d_h_R=64, layers=60)

for m in (llama_70b, deepseek_v2):
    print(f"\n=== {m.name}, context 128K, FP16 ===")
    for v in ("mha", "gqa", "mla"):
        cache_gb = total_cache_bytes(m, v, N=131_072) / 1e9   # decimal GB (10**9 bytes)
        print(f"  {v.upper():4s}  per-token={per_token_scalars(m, v):>5d}  "
              f"cache={cache_gb:6.2f} GB  reduction={reduction_vs_mha(m, v):5.1f}x")
```

Running the script gives you, for DeepSeek-V2 at 128k context, FP16: MHA cache ≈ 515 GB, GQA-128 ≈ 515 GB, MLA ≈ 9.1 GB. Because DeepSeek's GQA reference is set with $G = n_h$, it equals MHA: this is the apples-to-apples comparison on DeepSeek's own head budget, and on that budget MLA buys ~57× of headroom.


4.5.8 PyTorch Implementation

Now we measure. The script below builds a minimal forward pass for each of the three variants on identical input shapes, then prints both the cache size and the forward-time latency. This is the apples-to-apples micro-benchmark we have been promising.

Apples-to-apples forward pass: MHA vs GQA vs MLA (mha_gqa_mla_bench.py). Notes on the code, in order:

  • Common shapes. Single batch, 4096-token sequence, 4096 model width, 32 query heads, head dim 128. These are deliberately middle-of-the-road: small enough to fit on most current GPUs (the FP32 score matrix alone is 32 × 4096 × 4096 × 4 B = 2 GiB), large enough that the three variants' cache costs diverge meaningfully.
  • Variant-specific parameters. GQA uses 8 groups (4 query heads per group). MLA uses a d_c = 512 latent and a d_h_R = 64 decoupled-RoPE width, the DeepSeek-V2 numbers. Everything else is identical across variants so the comparison is fair. To keep the benchmark small, the MLA module omits the decoupled-RoPE path and DeepSeek-V2's query compression; d_h_R only enters the cache accounting.
  • MHA: classical Q, K, V projections. One big linear maps to $3 \cdot n_h \cdot d_h$ channels and we slice them into Q, K, V. This is the textbook MHA layout that every attention paper since 2017 starts from.
  • Permute to [3, B, n_h, N, d_h]. Reshape-then-permute is the standard idiom for splitting heads. The result is one batched tensor per role: qkv[0] = Q, qkv[1] = K, qkv[2] = V. Subsequent matmuls then operate on all heads at once.
  • Scaled dot product + causal mask. Standard attention math: scale by $\sqrt{d_h}$, mask out the strict upper triangle so each token only sees its own past, softmax, multiply by V, merge heads, output projection.
  • GQA: smaller K/V projection. The query projection is unchanged. The K/V projection is now $2 \cdot G \cdot d_h$ channels instead of $2 \cdot n_h \cdot d_h$; that is the entire architectural difference versus MHA.
  • repeat_interleave to recover $n_h$. After we have $G$ K/V vectors per token, we replicate each group $n_h / G$ times so the per-head matmul still works (with $G = 8$ and $n_h = 32$, each K/V tile is shared by 4 heads). repeat_interleave materialises those copies; fused attention kernels with native GQA support share the tiles by indexing instead, without the extra memory.
  • MLA: latent + reconstructed K, V. W_DKV produces the single shared latent c of width d_c = 512. Both K and V are then up-projected from c via separate matrices, restoring the per-head shape [B, n_h, N, d_h]. These reconstructed K and V are exactly what the inference-time cache skips; only c (plus $k^R$) is stored.
  • Same attention math after reconstruction. Once K and V are reconstructed from c, the dot-product score, mask, softmax and weighted sum are identical to MHA's. MLA does NOT change how attention is computed, only how K and V are stored.
  • Cache size accounting helper. Returns bytes per token for one layer, assuming FP16 storage (2 bytes per scalar) even though the benchmark itself runs in FP32. Multiplied by sequence length, this gives the per-sequence cache for one layer. The Decoupled-RoPE $k^R$ term is folded into $d_c + d_h^R$ for MLA; that is the operationally honest figure.
  • Bench loop with warmup. Three warmup steps let kernel selection and the CUDA memory allocator settle, then ten timed forward passes are averaged. torch.cuda.synchronize() before and after the timed region prevents asynchronous kernel launches from skewing the wall-clock measurement.
  • Print latency + cache per sequence. The pair of numbers that tell the whole story: how long the forward took (differences of a few percent) and how much HBM one sequence pays for KV cache in one layer (differences of an order of magnitude). The asymmetry is the entire reason MLA exists.
  • Run all three. Same input, same model width, same head count; only the attention internals differ. On an H100 you should see roughly MHA ≈ GQA ≈ MLA in forward time (within 5–10%), while MLA's per-sequence cache is about 14× smaller than MHA's at this 32-head shape ($16{,}384 / 1{,}152$ bytes per token).
```python
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Common shape for all three variants
B, N, d_model, n_h, d_h = 1, 4096, 4096, 32, 128
G, d_c, d_h_R = 8, 512, 64

class MHA(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * n_h * d_h, bias=False)
        self.o   = nn.Linear(n_h * d_h, d_model, bias=False)
    def forward(self, h):
        B, N, _ = h.shape
        qkv = self.qkv(h).view(B, N, 3, n_h, d_h).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]            # each [B, n_h, N, d_h]
        s = (q @ k.transpose(-1, -2)) / d_h**0.5
        s = s.masked_fill(torch.triu(torch.ones(N, N, device=h.device), 1).bool(), -1e4)
        return self.o((F.softmax(s, -1) @ v).transpose(1, 2).reshape(B, N, n_h * d_h))

class GQA(nn.Module):
    def __init__(self):
        super().__init__()
        self.q  = nn.Linear(d_model, n_h * d_h, bias=False)
        self.kv = nn.Linear(d_model, 2 * G * d_h, bias=False)
        self.o  = nn.Linear(n_h * d_h, d_model, bias=False)
    def forward(self, h):
        B, N, _ = h.shape
        q  = self.q(h).view(B, N, n_h, d_h).transpose(1, 2)                 # [B, n_h, N, d_h]
        kv = self.kv(h).view(B, N, 2, G, d_h).permute(2, 0, 3, 1, 4)        # [2, B, G, N, d_h]
        k, v = kv[0], kv[1]
        # Replicate each group across n_h/G heads. repeat_interleave copies the data;
        # fused GQA attention kernels share the tiles by indexing instead.
        rep = n_h // G
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
        s = (q @ k.transpose(-1, -2)) / d_h**0.5
        s = s.masked_fill(torch.triu(torch.ones(N, N, device=h.device), 1).bool(), -1e4)
        return self.o((F.softmax(s, -1) @ v).transpose(1, 2).reshape(B, N, n_h * d_h))

class MLA(nn.Module):
    # Simplified for benchmarking: no decoupled-RoPE path and no query compression.
    def __init__(self):
        super().__init__()
        self.W_DKV = nn.Linear(d_model, d_c, bias=False)
        self.W_UQ  = nn.Linear(d_model, n_h * d_h, bias=False)
        self.W_UK  = nn.Linear(d_c,     n_h * d_h, bias=False)
        self.W_UV  = nn.Linear(d_c,     n_h * d_h, bias=False)
        self.o     = nn.Linear(n_h * d_h, d_model, bias=False)
    def forward(self, h):
        B, N, _ = h.shape
        c  = self.W_DKV(h)                                                 # [B, N, d_c], what gets cached
        q  = self.W_UQ(h).view(B, N, n_h, d_h).transpose(1, 2)             # [B, n_h, N, d_h]
        k  = self.W_UK(c).view(B, N, n_h, d_h).transpose(1, 2)             # reconstructed keys
        v  = self.W_UV(c).view(B, N, n_h, d_h).transpose(1, 2)             # reconstructed values
        s  = (q @ k.transpose(-1, -2)) / d_h**0.5
        s  = s.masked_fill(torch.triu(torch.ones(N, N, device=h.device), 1).bool(), -1e4)
        return self.o((F.softmax(s, -1) @ v).transpose(1, 2).reshape(B, N, n_h * d_h))

def cache_bytes_per_token(variant: str) -> int:
    """KV-cache bytes per token for ONE layer, assuming FP16 storage (2 bytes/scalar)."""
    if variant == "mha": return 2 * n_h * d_h * 2
    if variant == "gqa": return 2 * G   * d_h * 2
    if variant == "mla": return (d_c + d_h_R) * 2      # latent + shared rotated key
    raise ValueError(f"unknown variant: {variant!r}")

def bench(model: nn.Module, name: str, iters: int = 10):
    model = model.to(device).eval()
    h = torch.randn(B, N, d_model, device=device)
    with torch.no_grad():
        for _ in range(3): model(h)                # warmup
        if device == "cuda": torch.cuda.synchronize()
        t0 = time.perf_counter()
        for _ in range(iters): model(h)
        if device == "cuda": torch.cuda.synchronize()
        ms = (time.perf_counter() - t0) / iters * 1000
    cache_kb = cache_bytes_per_token(name) * N / 1024   # one layer, one sequence
    print(f"{name.upper():4s}  forward={ms:7.2f} ms  per-seq KV={cache_kb:7.1f} KB")

bench(MHA(), "mha")
bench(GQA(), "gqa")
bench(MLA(), "mla")
```

Sample output on an H100, single sequence of 4096 tokens (KV figures are for one layer):

```text
MHA   forward=   4.81 ms  per-seq KV=65536.0 KB
GQA   forward=   4.62 ms  per-seq KV=16384.0 KB
MLA   forward=   4.95 ms  per-seq KV= 4608.0 KB
```

Forward times are within noise of each other; cache size collapses by more than 14× from MHA to MLA, and 3.6× from GQA to MLA. The full inference-time MLA implementation in the next section (§4.6) will squeeze the forward gap even further by absorbing $W^{UQ\top} W^{UK}$ at load time. But even without that, the benchmark already shows the trade is uneven in MLA's favour.
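The absorption mentioned above rests on one identity: per head, $q^\top k = (W^{UQ} h)^\top (W^{UK} c) = h^\top (W^{UQ\top} W^{UK})\, c$. Below is a toy single-query sketch with made-up sizes and random weights. It ignores RoPE (which is exactly why the decoupled $k^R$ is needed) and DeepSeek's query compression, and it checks that scoring against reconstructed keys and scoring directly against the cached latent give the same numbers.

```python
import torch

torch.manual_seed(0)
d_model, d_c, n_h, d_h, N = 256, 64, 4, 32, 10         # toy sizes, not DeepSeek's

W_UQ  = torch.randn(n_h, d_h, d_model) / d_model**0.5  # per-head query projection
W_DKV = torch.randn(d_c, d_model) / d_model**0.5       # shared down-projection
W_UK  = torch.randn(n_h, d_h, d_c) / d_c**0.5          # per-head key up-projection

h_t = torch.randn(d_model)       # hidden state of the current (query) token
H   = torch.randn(N, d_model)    # hidden states of the N past tokens
c   = H @ W_DKV.T                # [N, d_c]: the only per-token key state that is cached

# Path 1 (naive): reconstruct every head's keys from the latent, then score.
q = torch.einsum("hdm,m->hd", W_UQ, h_t)            # [n_h, d_h]
k = torch.einsum("hdc,nc->hnd", W_UK, c)            # [n_h, N, d_h]
scores_naive = torch.einsum("hd,hnd->hn", q, k)     # [n_h, N]

# Path 2 (absorbed): fold W_UK into the query side once per head (W_UQ^T W_UK).
W_absorbed = torch.einsum("hdm,hdc->hmc", W_UQ, W_UK)   # [n_h, d_model, d_c], precomputable
q_latent = torch.einsum("m,hmc->hc", h_t, W_absorbed)   # [n_h, d_c]
scores_absorbed = q_latent @ c.T                        # [n_h, N]; no K is ever built

print(torch.allclose(scores_naive, scores_absorbed, atol=1e-3))   # True
```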


4.5.9 Connection to Massive Model Training

The serving economy

At the scale where this comparison actually matters, say a 70B-class or larger model serving 128k context to thousands of concurrent users, the per-node accounting (using the 70B-class shape from §4.5.4) looks like this:

| Variant | Per-active-sequence cache (128k) | Sequences per 8×80 GB node (after weights) | Approx. tokens/s/user |
|---|---|---|---|
| MHA (64h, 80L) | ~320 GB | ~1.5 | low (memory-bound) |
| GQA-8 (64h, 80L) | ~40 GB | ~12 | good |
| MLA (64h, 80L, $d_c = 512$, $d_h^R = 64$) | ~11 GB | ~44 | best |

The sessions-per-node column is the decisive one. Decoding is memory-bandwidth-bound: each generated token has to read that sequence's entire KV cache from HBM, so a smaller cache means both faster steps and more sessions per node. More sessions means higher effective utilisation of the GPU and a lower per-token cost. MLA does not just save HBM; it lifts the whole serving curve.

The training story is different

During training the cache is irrelevant: every layer recomputes all of K and V from scratch on every forward step, and gradients flow back through the full activations. The architectural choice does cost something at training time:

  • MHA: baseline FLOPs and memory.
  • GQA: slightly fewer FLOPs in the K/V projection (since the projection matrix is smaller). Quality drop versus MHA is small but real — DeepSeek's own ablations show ~0.1 to 0.3 loss-points worse on language modelling at small scale, vanishing at large scale.
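  • MLA: the K/V path is factored through the latent ($W^{DKV}$, then $W^{UK}$ and $W^{UV}$), plus the small decoupled-RoPE projections $W^{QR}$ and $W^{KR}$. Whether this costs more or fewer FLOPs and parameters than MHA depends on the shape: when $d_c$ is well below $d_{\text{model}}$, the factored path is actually smaller than MHA's direct K/V projection (in the §4.5.8 benchmark shape, $4096 \cdot 512 + 2 \cdot 512 \cdot 4096 \approx 6.3$M weights against $2 \cdot 4096 \cdot 4096 \approx 33.6$M). Quality is on par with vanilla MHA in DeepSeek-V2's ablation table; the latent does not bottleneck because $d_c$ is generous.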

Net result: training-time cost is approximately neutral; serving-time cost is dominated by the cache-size term, where MLA wins by an order of magnitude. Teams pay for training once; they pay for serving forever.
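One quick way to check the parameter side of this comparison is to count the projection weights of the three §4.5.8 benchmark modules. The helper below is a hypothetical arithmetic sketch (no RoPE projections, no query compression, no biases), so it describes that benchmark shape only, not DeepSeek-V2's real layer.

```python
def projection_params(d_model=4096, n_h=32, d_h=128, G=8, d_c=512):
    """Weights in the attention projections of the §4.5.8 benchmark modules."""
    hd = n_h * d_h
    return {
        "MHA": d_model * 3 * hd + hd * d_model,                             # fused QKV + output
        "GQA": d_model * hd + d_model * 2 * G * d_h + hd * d_model,         # Q + K/V + output
        "MLA": d_model * d_c + d_model * hd + 2 * d_c * hd + hd * d_model,  # W_DKV, W_UQ, W_UK+W_UV, output
    }

for name, p in projection_params().items():
    print(f"{name}: {p / 1e6:5.1f}M parameters")
```

At this shape the counts come out to 67.1M (MHA), 41.9M (GQA) and 39.8M (MLA): the query and output projections dominate, and the factored MLA path is the smallest of the three.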

Distributed-training implications

Under tensor parallelism, each variant splits cleanly:

  • MHA / GQA: shard along $n_h$ (heads). Each TP rank owns a slice of the heads and the corresponding KV cache. Communication is the standard all-reduce over the output projection.
  • MLA: shard along $n_h$ for the up-projections $W^{UK}$, $W^{UV}$, and replicate the latent $c^{KV}$ on every TP rank. Replication is cheap to compute ($d_c$ is small) and removes the need to all-gather the cache across ranks during decoding, though each rank then stores the full $d_c + d_h^R$ scalars per token rather than a $1/\text{TP}$ share. The shared $k^R$ is also replicated.

That last point is subtle but important: MLA's cache layout maps naturally onto tensor parallelism without introducing new collective operations on the hot decoding path. Production teams report no unusual TP overhead from MLA in DeepSeek-V2.


4.5.10 Engineering Reality: When to Pick Which

The decision matrix

| Pick | When | Why |
|---|---|---|
| MHA | Only when you must reproduce a legacy spec or run very short contexts (<2k) on tiny models | No serving advantage at any scale; its quality ceiling only matters if you cannot afford the parameters to match it elsewhere. |
| MQA | Mobile / on-device, single-user latency-bound inference, very strict HBM | Smallest cache among the classical variants. Quality cost is visible at large scale but acceptable for small models. |
| GQA-8 | Default for open-weight dense decoders (LLaMA-3, Qwen-2, Mistral, Gemma-2) | $n_h / 8$ cache reduction (8× on 64-head models) with negligible quality loss. Drop-in replacement for MHA. No new infra needed. |
| MLA | Frontier models, very long context (≥32k), high-concurrency serving, MoE bases | Several times smaller cache than GQA-8 (≈3.6× at $d_c = 512$) without losing per-head specialisation. Requires Decoupled RoPE and absorbed inference projections. |

Common pitfalls

  • Comparing cache without comparing quality. A variant that halves the cache but costs 1 loss-point of language modelling can be a regression at fixed compute budget. Always ablate on real data at training scale; MLA's reputation rests on DeepSeek showing it holds up at 236B parameters.
  • Forgetting that MLA needs Decoupled RoPE. Plain MLA with naive RoPE inflates the cache to MHA size, undoing the whole point. The dual-score split from §4.4 is not optional.
  • Benchmarking forward time without accounting for absorption. The clean training-time MLA forward looks slightly slower than MHA. The inference-time MLA forward (with $W^{UQ\top} W^{UK}$ precomputed) is comparable to MHA in speed and far ahead on memory. Don't confuse the two regimes when sizing capacity.
  • GQA group count must divide $n_h$. GQA-7 is not a thing on 64 heads. The group count is a divisor of the query-head count; common choices are 1, 2, 4, 8.
  • MLA $d_c$ is a quality lever, not just a cache one. Shrinking $d_c$ below ~256 starts to hurt: the latent becomes the bottleneck instead of the head count. Stay near the DeepSeek-V2 default of 512 unless you have ablation data saying otherwise.
  • Flash-Attention compatibility. All three variants have working Flash-Attention paths today, but MLA needs care with its two score terms. The DeepSeek choice is to concatenate the content and RoPE slices into one head, so a single dot product yields the sum of both scores. Two separate Flash calls cannot simply be summed, because each normalises its own softmax. Mainstream serving stacks such as vLLM and SGLang wrap this for you, but homemade inference loops need care.
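The algebra behind the concatenation trick is easy to check without any Flash kernel. A plain-PyTorch sketch with toy sizes; the `attend` helper is a hypothetical stand-in for a fused attention call, not Flash-Attention itself. Concatenating content and RoPE slices reproduces the summed-score reference, while summing two separately normalised attention outputs does not.

```python
import torch

torch.manual_seed(0)
n_h, N, d_h, d_h_R = 4, 8, 32, 16                       # toy sizes

def attend(q, k, v, scale):
    """Plain softmax attention (no mask) for [heads, N, dim] tensors."""
    return torch.softmax((q @ k.transpose(-1, -2)) * scale, dim=-1) @ v

q_c, k_c = torch.randn(n_h, N, d_h), torch.randn(n_h, N, d_h)   # content slices
q_r = torch.randn(n_h, N, d_h_R)                                # per-head rotated queries
k_r = torch.randn(1, N, d_h_R).expand(n_h, N, d_h_R)            # one shared rotated key
v = torch.randn(n_h, N, d_h)
scale = (d_h + d_h_R) ** -0.5

# Reference: add the two score terms, then apply ONE softmax.
s = (q_c @ k_c.transpose(-1, -2) + q_r @ k_r.transpose(-1, -2)) * scale
ref = torch.softmax(s, dim=-1) @ v

# Concatenated heads: [q_c; q_r] . [k_c; k_r] = q_c . k_c + q_r . k_r, so one call suffices.
fused = attend(torch.cat([q_c, q_r], -1), torch.cat([k_c, k_r], -1), v, scale)

# Two separate calls, summed: each normalises its own softmax, so this does not match.
two_calls = attend(q_c, k_c, v, scale) + attend(q_r, k_r, v, scale)

print(torch.allclose(fused, ref, atol=1e-4))      # True
print(torch.allclose(two_calls, ref, atol=1e-2))  # False
```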

What the table looks like for real shipped models

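| Model | Variant | Heads ($n_h$, $G$ or $d_c$) | Per-token, per-layer cache @ FP16 | Layers |
|---|---|---|---|---|
| LLaMA-3 8B | GQA-8 | $n_h = 32$, $G = 8$ | 4 KB | 32 |
| LLaMA-3 70B | GQA-8 | $n_h = 64$, $G = 8$ | 4 KB | 80 |
| Mistral 7B | GQA-8 | $n_h = 32$, $G = 8$ | 4 KB | 32 |
| Qwen2 72B | GQA-8 | $n_h = 64$, $G = 8$ | 4 KB | 80 |
| Gemma 2 27B | GQA-16 (2 query heads per KV head) | $n_h = 32$, $G = 16$ (alternating sliding-window layers) | 8 KB | 46 |
| DeepSeek-V2 | MLA | $n_h = 128$, $d_c = 512$, $d_h^R = 64$ | 1.125 KB | 60 |
| DeepSeek-V3 | MLA | $n_h = 128$, $d_c = 512$, $d_h^R = 64$ | 1.125 KB | 61 |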

Two patterns stand out. First, GQA-8 is the dominant default among recent dense open-weight decoders. Second, the frontier teams that need to serve very long context cheaply (DeepSeek, Yi-Lightning, Kimi's newer releases) have moved to MLA. The market has split into "ship GQA, it just works" and "invest in MLA, it pays back at serving."
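The per-token column can be turned into per-sequence totals with a short sketch. It deliberately ignores each model's native context limit and any sliding-window caps, so the 128k figures are hypothetical upper bounds for comparison; Gemma 2 is left out because its sliding-window layers make a flat per-layer multiply misleading. Names are illustrative.

```python
# (layers, per-token per-layer scalars) for models in the table; FP16 = 2 bytes/scalar.
MODELS = {
    "LLaMA-3 8B":  (32, 2 * 8 * 128),
    "LLaMA-3 70B": (80, 2 * 8 * 128),
    "Mistral 7B":  (32, 2 * 8 * 128),
    "Qwen2 72B":   (80, 2 * 8 * 128),
    "DeepSeek-V2": (60, 512 + 64),
    "DeepSeek-V3": (61, 512 + 64),
}
N, BYTES = 128 * 1024, 2

def seq_cache_gib(layers: int, c_token: int, n_tokens: int = N) -> float:
    """KV cache for one sequence across all layers, in GiB (2**30 bytes)."""
    return layers * c_token * BYTES * n_tokens / 2**30

for name, (layers, c) in MODELS.items():
    print(f"{name:12s} {c * BYTES / 1024:6.3f} KB/token/layer  "
          f"{seq_cache_gib(layers, c):6.2f} GiB per 128k sequence")
```

LLaMA-3 70B lands at 40 GiB per 128k sequence, DeepSeek-V2 at about 8.4 GiB despite having more heads.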


Summary

  • One formula covers all variants: $\text{cache}(N) = L \cdot N \cdot b \cdot c_{\text{token}}$. The variant only sets $c_{\text{token}}$.
  • MHA, MQA and GQA scale $c_{\text{token}}$ with the KV head count. MLA does not: its formula $d_c + d_h^R$ has no $n_h$ at all. That is the structural break.
  • On real shapes, GQA-8 gives $n_h / 8$ compression over MHA (8× on 64-head models); MLA gives another ~3.5× on top of GQA-8 (~7× on top of GQA-16), and its ratio over MHA grows with $n_h$.
  • Forward-time cost differences are single-digit percent; cache cost differences are order-of-magnitude. The trade is uneven in the cache direction, which is the bottleneck at long context.
  • Default to GQA-8 for dense decoders. Reach for MLA when you serve long contexts to many concurrent users, or when you are at the frontier of model size. Never ship MHA at large scale.
  • MLA requires Decoupled RoPE (§4.4) to keep its compression. Treat them as a unit, not as separable choices.

The next section pulls all of this into a single working PyTorch implementation of MLA — the full training and inference path, absorbed projections included — so you can run a real layer and measure the numbers yourself.
