Chapter 4

Implementing MLA in PyTorch

Multi-Head Latent Attention (MLA)

From Theory to Code

Sections 4.3 and 4.4 derived MLA on paper — joint low-rank compression of the KV cache, decoupled RoPE for position, and the algebraic identity that lets us absorb the up-projection into the query matrix at inference. This section turns that derivation into code you can train and serve.

We are going to build the layer in three passes. First a minimal NumPy version where every intermediate is a real matrix the reader can inspect. Then a complete nn.Module with decoupled RoPE wired in, written the way you would actually ship it. Finally an inference-only fused path that exploits the absorption trick to delete two of the largest tensors in the decode hot loop.

What you should already know

The MLA latent $c_{KV} = h W^{DKV}$, the two up-projections $W^{UK}$ and $W^{UV}$, and the decoupled-RoPE head split into a content part of dim $d_h^c$ and a RoPE part of dim $d_h^R$ (per head on the query side, shared across heads on the key side). If those symbols look unfamiliar, read sections 4.3 and 4.4 first; this one assumes the math.


The Shape Contract

Most implementation bugs in attention code are, at heart, shape bugs. Before writing any code, pin down the contract for every tensor that crosses a function boundary. If you cannot fill in this table from memory for your own layer, you will spend tomorrow debugging einsums.

| Tensor | Shape | Lives where | Notes |
|---|---|---|---|
| h (hidden-state input) | (B, T_new, d_model) | per-step input | T_new = 1 in decode, T in prefill |
| c_KV (latent) | (B, T, d_c) | KV cache (persistent) | the only large cache write |
| k_rope (shared RoPE key) | (B, T, d_rope) | KV cache (persistent) | shared across heads, not per-head |
| Q_content | (B, T_new, n_heads, d_h_content) | transient | rebuilt every step from h |
| Q_rope | (B, T_new, n_heads, d_rope) | transient | rebuilt every step from h |
| K_content (decompressed) | (B, T, n_heads, d_h_content) | transient, training path | FUSED AWAY at inference |
| V (decompressed) | (B, T, n_heads, d_h_content) | transient, training path | FUSED AWAY at inference |
| attention weights | (B, n_heads, T_new, T) | transient | softmax over the last dim |
| y (output) | (B, T_new, d_model) | passed to next layer | feeds the residual |
Read the third column carefully. Only $c_{KV}$ and $k^R$ (k_rope in code) are persistent. Everything else is transient and lives only for the duration of one forward pass. The whole MLA story is about moving as much of the work as possible into the "transient" column.
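The same contract can be written as executable checks. This is a minimal sketch. The helper names mla_shapes and check_shapes are hypothetical and not part of the layer below. The usage line describes one decode step with 5 cached tokens and arbitrary toy dims.

```python
import numpy as np

# Hypothetical helper: the shape contract above as a lookup table.
def mla_shapes(B, T_new, past, d_model, n_heads, d_h_content, d_rope, d_c):
    T = past + T_new  # every position attention can see: cached + new
    return {
        "h":         (B, T_new, d_model),
        "c_KV":      (B, T, d_c),                       # persistent
        "k_rope":    (B, T, d_rope),                    # persistent, shared across heads
        "Q_content": (B, T_new, n_heads, d_h_content),
        "Q_rope":    (B, T_new, n_heads, d_rope),
        "K_content": (B, T, n_heads, d_h_content),      # training path only
        "V":         (B, T, n_heads, d_h_content),      # training path only
        "attn":      (B, n_heads, T_new, T),
        "y":         (B, T_new, d_model),
    }

def check_shapes(tensors, expected):
    """Raise on the first tensor whose shape breaks the contract."""
    for name, t in tensors.items():
        if tuple(t.shape) != expected[name]:
            raise ValueError(f"{name}: got {tuple(t.shape)}, expected {expected[name]}")

# One decode step (T_new = 1) with 5 tokens already cached.
exp = mla_shapes(B=2, T_new=1, past=5, d_model=64, n_heads=4,
                 d_h_content=16, d_rope=8, d_c=32)
check_shapes({"c_KV": np.zeros((2, 6, 32)), "attn": np.zeros((2, 4, 1, 6))}, exp)
```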

Minimal MLA in Plain NumPy

The smallest possible MLA forward pass — no RoPE, no cache, no batching tricks — fits on one screen of NumPy. Build it first, run it, print the shapes. Then go look at the explanation panel to see what each line is actually doing.

Plain NumPy MLA forward pass
mla_numpy.py
Line 6: Why these dimensions

B=1 keeps the trace readable. T=4 gives us enough tokens to see the causal mask do something. d_model=8 splits cleanly into 2 heads of 4. We pick small numbers because every intermediate is going to be a real matrix the reader could verify by hand.

Line 8: The compression target, d_c

The naive cache holds K and V — that is 2·n_heads·d_h = 16 scalars per token. MLA caches a single d_c-wide latent — 4 scalars per token. The compression ratio here is exactly 16/4 = 4×. DeepSeek-V3 pushes this to ~56×.

EXAMPLE
naive: 2*n_heads*d_h scalars  |  MLA: d_c scalars
Line 11: W_DKV, the compressor

Shape (d_model, d_c). This is the single most important matrix in MLA. Every hidden state h_t passes through it once to produce the cached latent c_KV[t]. Without it there is no compression.

Line 12: W_UK, the K decompressor

Shape (d_c, n_heads * d_h). At forward time we need full-width keys for the dot product. W_UK and W_UV pull K and V back out of the shared latent — that is what 'joint compression' means.

Line 14: W_Q stays full-width

Queries are computed fresh from the current hidden state every step. There is nothing to cache for Q, so the down-projection trick does not apply here. In the full DeepSeek architecture a query-side latent c_Q exists, but only to reduce activation memory during training, not to shrink any cache.

Line 21: c_KV, THE cache

This single line is the entire point of MLA. We replace 'store K_t and V_t per token' with 'store c_KV[t]'. Every other line below recreates the standard attention computation; only this one is the compression.

EXAMPLE
c_KV.shape == (B, T, d_c) == (1, 4, 4)
Lines 22-23: Up-project on the fly

K is rebuilt from the latent at forward time. Notice that K is NOT stored; only c_KV is. The reshape splits the trailing dim of size n_heads*d_h into (n_heads, d_h), so each head gets its own d_h-wide slice (4 values here).

EXAMPLE
(B, T, 8) -reshape-> (B, T, 2, 4)
Lines 27-29: Why we transpose

Putting the head axis at position 1 lets us treat each (T, d_h) slice as an independent 2-D matrix. The matmul on line 32 then becomes a clean batched (T, d_h) @ (d_h, T) per head.

Line 32: Scaled dot-product scores

Same arithmetic as standard MHA. The √d_h scale keeps the softmax inputs in a sensible range so gradients do not vanish. MLA changes how K is produced — not how attention is scored.

Lines 33-34: Causal mask

Upper triangle (strictly above diagonal) is masked — a token at position i cannot attend to positions > i. We set those scores to -inf so softmax assigns them zero weight without any branching.

Lines 37-38: Numerically stable softmax

We subtract the row max before exponentiating to keep exp() from overflowing. This is the same trick PyTorch's F.softmax does internally.

EXAMPLE
exp(scores - scores.max()) keeps values in [0, 1]
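This small NumPy comparison on made-up logits shows why the max subtraction matters. The names softmax_naive and softmax_stable are illustrative. The masked -inf slot behaves the same way it does under the causal mask.

```python
import numpy as np

def softmax_naive(x):
    e = np.exp(x)
    return e / e.sum(-1, keepdims=True)

def softmax_stable(x):
    # Subtract the row max: the largest exponent becomes exp(0) = 1, so nothing overflows.
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

scores = np.array([[1000.0, 999.0, -np.inf]])   # large logits plus one masked slot
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_naive(scores)               # exp(1000) overflows to inf, inf/inf -> nan
stable = softmax_stable(scores)                 # finite; the masked slot gets exactly 0
```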
Lines 42-43: Output projection

Concatenate heads (reshape) then project back to d_model. W_O is the only reason heads are not completely independent — it mixes information across them at the end of every layer.

1import numpy as np
2
3np.random.seed(0)
4
5# ------- Shapes (toy, but real proportions) ---------------------------------
6B, T, d_model = 1, 4, 8        # batch, tokens, hidden
7n_heads, d_h = 2, 4            # 2 heads of 4 dims each (n_heads * d_h = 8)
8d_c = 4                        # MLA latent dim — what we cache per token
9
10# ------- Parameters (one-time) ----------------------------------------------
11W_DKV = np.random.randn(d_model, d_c) * 0.1            # down-proj  (8 -> 4)
12W_UK  = np.random.randn(d_c, n_heads * d_h) * 0.1      # K up-proj  (4 -> 8)
13W_UV  = np.random.randn(d_c, n_heads * d_h) * 0.1      # V up-proj  (4 -> 8)
14W_Q   = np.random.randn(d_model, n_heads * d_h) * 0.1  # Q up-proj  (8 -> 8)
15W_O   = np.random.randn(n_heads * d_h, d_model) * 0.1  # output     (8 -> 8)
16
17# ------- Input ---------------------------------------------------------------
18H = np.random.randn(B, T, d_model)          # hidden states from prior layer
19
20# ------- MLA forward pass ----------------------------------------------------
21c_KV = H @ W_DKV                            # (B, T, d_c)   <- CACHE THIS
22K = (c_KV @ W_UK).reshape(B, T, n_heads, d_h)
23V = (c_KV @ W_UV).reshape(B, T, n_heads, d_h)
24Q = (H    @ W_Q ).reshape(B, T, n_heads, d_h)
25
26# Reorder to (B, head, T, d_h) so each head is a clean 2-D matmul
27Q = Q.transpose(0, 2, 1, 3)
28K = K.transpose(0, 2, 1, 3)
29V = V.transpose(0, 2, 1, 3)
30
31# Scores & causal mask
32scores = (Q @ K.transpose(0, 1, 3, 2)) / np.sqrt(d_h)        # (B, H, T, T)
33mask = np.triu(np.ones((T, T), dtype=bool), k=1)
34scores = np.where(mask, -np.inf, scores)
35
36# Softmax over last dim, then weighted sum of values
37attn = np.exp(scores - scores.max(-1, keepdims=True))
38attn = attn / attn.sum(-1, keepdims=True)
39out  = attn @ V                                              # (B, H, T, d_h)
40
41# Concat heads -> output projection
42out = out.transpose(0, 2, 1, 3).reshape(B, T, n_heads * d_h)
43y   = out @ W_O                                              # (B, T, d_model)
44
45print("c_KV (the only thing we cache):", c_KV.shape)        # (1, 4, 4)
46print("y (layer output):              ", y.shape)           # (1, 4, 8)

Notice what is missing. There is no position information yet: $c_{KV}$ is the same whether the token is at position 1 or position 10,000. That is the gap decoupled RoPE will fill. And we recompute $K$ and $V$ from $c_{KV}$ every step, which is fine for training but wasteful at decode. The absorption trick further down deletes both of those rebuilds.

What you can do with this code

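  1. Set $d_c = n_h \cdot d_h$ (here $d_c = 8$). Write a plain MHA next to it whose key and value weights are $W^K = W^{DKV} W^{UK}$ and $W^V = W^{DKV} W^{UV}$, and compare: the outputs match. At full rank these products can represent any MHA key/value projection, so MLA degenerates into MHA up to a basis change of the latent.
  2. Shrink $d_c$ to 1. Every key of a head now lies on the same line through the origin, so each query's scores are a single scalar times the per-token latent. The score matrix becomes rank 1, and every query in that head sees the same ordering of tokens (or its exact reverse).
  3. Factor a reference key matrix $W^K$ into $W^{DKV} W^{UK}$ with a rank-$d_c$ truncated SVD and print $K - K_{\text{reconstructed}}$. The error is zero (up to floating point) at full rank. It appears only once $d_c$ drops below the rank of $W^K$, which is at most $n_h \cdot d_h$.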
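Suggestions 1 and 3 can be checked in a few lines. This sketch assumes a random reference MHA key projection W_K. It uses a truncated SVD as one convenient way to split W_K into W_DKV @ W_UK; a trained MLA layer learns its own factors instead. The name factor is a hypothetical helper.

```python
import numpy as np

rng = np.random.default_rng(0)
B, T, d_model, n_heads, d_h = 1, 4, 8, 2, 4
H = rng.standard_normal((B, T, d_model))

# A reference MHA key projection (illustrative random weights).
W_K = rng.standard_normal((d_model, n_heads * d_h)) * 0.1
K_ref = H @ W_K

def factor(W, d_c):
    """Rank-d_c factorization W ~= W_DKV @ W_UK via truncated SVD (one choice of many)."""
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    W_DKV = U[:, :d_c] * S[:d_c]     # (d_model, d_c): compressor
    W_UK = Vt[:d_c]                  # (d_c, n_heads * d_h): K decompressor
    return W_DKV, W_UK

errors = {}
for d_c in (8, 4, 1):
    W_DKV, W_UK = factor(W_K, d_c)
    K_reconstructed = (H @ W_DKV) @ W_UK     # keys rebuilt from the latent c_KV
    errors[d_c] = np.abs(K_ref - K_reconstructed).max()
```

At d_c = 8 the MLA keys equal the MHA keys, which is suggestion 1 in miniature. Values behave the same way with W_UV.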

The pipeline, visualized

The pipeline below is the same arithmetic you just read, run on a slightly larger toy. Pick a token, slide $d_c$, and watch the reconstruction quality of $K'$ degrade as the latent gets squeezed.


The reconstruction never has to be perfect. Training co-adapts $W^{DKV}$ and the two up-projections to the downstream attention pattern. The model learns to compress along directions that matter to the softmax, not to minimize L2 error on the keys. That is how DeepSeek-V3 can use $d_c = 512$ against an MHA-equivalent key width of $128 \cdot 128 = 16{,}384$ per token (32,768 counting K and V) without a reported quality loss.


PyTorch: The Module Skeleton

Translating the NumPy version into PyTorch is mostly bookkeeping. Wrap the matrices in nn.Linear so they are tracked by nn.Module and autograd, hand off the softmax to F.softmax, and let einsum or @ do the per-head batching.
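One nn.Linear detail matters later for the absorption trick. The layer stores its weight as (out_features, in_features) and computes x @ weight.T. Here is a quick check with toy sizes:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_c = 8, 4
W_DKV = nn.Linear(d_model, d_c, bias=False)

# nn.Linear stores its weight as (out_features, in_features) and computes x @ weight.T.
h = torch.randn(2, 3, d_model)                   # (B, T, d_model)
c_KV = W_DKV(h)                                  # (B, T, d_c)
manual = h @ W_DKV.weight.T                      # the same product, written out

n_params = sum(p.numel() for p in W_DKV.parameters())   # autograd-tracked: 8 * 4
```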

We jump straight to the production version below — but you should understand what the skeleton looks like before all the RoPE and cache plumbing arrives. In short:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLA_Skeleton(nn.Module):
    def __init__(self, d_model, n_heads, d_h, d_c):
        super().__init__()
        self.n_heads, self.d_h, self.d_c = n_heads, d_h, d_c
        self.W_DKV = nn.Linear(d_model, d_c, bias=False)          # compressor
        self.W_UK  = nn.Linear(d_c, n_heads * d_h, bias=False)    # K decompressor
        self.W_UV  = nn.Linear(d_c, n_heads * d_h, bias=False)    # V decompressor
        self.W_Q   = nn.Linear(d_model, n_heads * d_h, bias=False)
        self.W_O   = nn.Linear(n_heads * d_h, d_model, bias=False)

    def forward(self, h):                          # h: (B, T, d_model), no cache yet
        B, T, _ = h.shape
        c_KV = self.W_DKV(h)                       # <-- the cache write, (B, T, d_c)
        # Rebuild per-head K and V from the latent; heads go to dim 1: (B, H, T, d_h)
        K = self.W_UK(c_KV).view(B, T, self.n_heads, self.d_h).transpose(1, 2)
        V = self.W_UV(c_KV).view(B, T, self.n_heads, self.d_h).transpose(1, 2)
        Q = self.W_Q(h).view(B, T, self.n_heads, self.d_h).transpose(1, 2)
        # Scaled dot-product attention with a causal mask (scale 1/sqrt(d_h)).
        out = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
        y = self.W_O(out.transpose(1, 2).reshape(B, T, -1))   # concat heads, project
        return y, c_KV
```

That is the entire scaffold. Every additional line below exists for exactly one of three reasons: handling the KV cache across multiple forward passes, integrating decoupled RoPE without breaking the absorption trick, or wiring up the inference fast path.


RoPE: The Implementation Problem

Section 4.4 explained why RoPE has to be decoupled from the latent. This is what it looks like as an implementation constraint, before we know the fix.

The naive thing to do is rotate $K$ right after decompressing it from the latent:

```python
# WRONG: naive attempt to add RoPE on top of vanilla MLA (excerpt from forward()).
# It runs and even trains, but it silently breaks the absorption trick.

# Step 1: build full K from the latent (as in plain MLA)
K = self.W_UK(c_KV).view(B, T, self.n_heads, self.d_h)

# Step 2: rotate K by absolute position. cos_all/sin_all are (T, d_h/2), so add a
# head axis to broadcast them over (B, T, n_heads, d_h/2).
K = apply_rope(K, cos_all[:, None, :], sin_all[:, None, :])   # <-- the bug lives here

# Why this is wrong:
# Rebuilding K from the cached c_KV and re-rotating it every step gives correct
# scores. But for a query at position m and a key at position n the score is now
#     (h W_Q R_m) (c_KV W_UK R_n)^T  =  h W_Q R_m R_n^T W_UK^T c_KV^T,
# and the middle factor depends on the positions. No single W_QK can be
# precomputed, so we lose the absorption trick AND still pay rotation work on
# every head's full d_h per token. We get the worst of both worlds.
```

This compiles and runs. It even trains. What it quietly destroys is the absorption identity $(h W^Q)(c_{KV} W^{UK})^\top = h \big(W^Q W^{UK\top}\big) c_{KV}^\top$. Once a position-dependent rotation sits between $W^Q$ and $W^{UK}$, you can no longer fuse them into one matrix. At inference you are then forced to materialize the full $K_{\text{content}}$ every decode step, which is exactly the cost MLA exists to avoid.
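Here is a NumPy sketch of the failure for a single head. It assumes RoPE is written as an explicit block-diagonal rotation matrix acting on row vectors. The name rope_matrix is a hypothetical helper; production code rotates with elementwise cos/sin as in apply_rope. The scores still come out right, but the would-be fused matrix changes with the key position:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_h, d_c = 8, 4, 3
W_Q = rng.standard_normal((d_model, d_h))
W_UK = rng.standard_normal((d_c, d_h))

def rope_matrix(pos, d, base=10_000.0):
    """Block-diagonal RoPE rotation acting on row vectors: x_rot = x @ R."""
    R = np.zeros((d, d))
    for i in range(d // 2):
        theta = pos * base ** (-2 * i / d)
        c, s = np.cos(theta), np.sin(theta)
        # pair (2i, 2i+1): (a, b) -> (a*c - b*s, a*s + b*c), same as apply_rope
        R[2 * i, 2 * i], R[2 * i, 2 * i + 1] = c, s
        R[2 * i + 1, 2 * i], R[2 * i + 1, 2 * i + 1] = -s, c
    return R

h = rng.standard_normal(d_model)    # hidden state of the query token at position m
c = rng.standard_normal(d_c)        # cached latent of a key at position n
m = 7

def naive_score(n):
    q = h @ W_Q @ rope_matrix(m, d_h)
    k = c @ W_UK @ rope_matrix(n, d_h)
    return q @ k

def middle(n):
    # The would-be fused matrix, with the rotation stuck in the middle: (d_model, d_c)
    return W_Q @ rope_matrix(m, d_h) @ rope_matrix(n, d_h).T @ W_UK.T

score_ok = np.isclose(naive_score(3), h @ middle(3) @ c)
differs = not np.allclose(middle(3), middle(5))   # depends on the key position
```

The rotations compose so that R_m R_n^T is itself a rotation by the offset m - n. A fused matrix would therefore be needed for every relative offset, rather than once per checkpoint.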

The structural fix

Decoupled RoPE splits every head into two subspaces:

  1. A content subspace of dim $d_h^c$ that is fed by the latent and stays rotation-free, so absorption still works on it.
  2. A RoPE subspace of dim $d_h^R$ with its own dedicated projection from $h$, rotated by absolute position. It is shared across all heads on the key side, so the per-token cache footprint stays $d_c + d_h^R$ instead of $d_c + n_h \cdot d_h^R$.
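The cache arithmetic behind item 2 fits in a tiny calculator. The function names are hypothetical. The dims plugged in are DeepSeek-V3's (128 heads of 128, d_c = 512, d_h^R = 64). The MHA number describes a hypothetical model with the same head layout.

```python
def mha_cache_scalars(n_heads: int, d_h: int) -> int:
    """Per-token, per-layer cache size for plain MHA: full K and V for every head."""
    return 2 * n_heads * d_h

def mla_cache_scalars(n_heads: int, d_c: int, d_rope: int, shared_rope_key: bool = True) -> int:
    """Per-token, per-layer cache size for MLA with decoupled RoPE."""
    rope = d_rope if shared_rope_key else n_heads * d_rope
    return d_c + rope

mha = mha_cache_scalars(128, 128)                                          # 32768
mla = mla_cache_scalars(128, 512, 64)                                      # 576
mla_per_head_rope = mla_cache_scalars(128, 512, 64, shared_rope_key=False) # 8704
ratio = mha / mla                                                          # about 56.9x
```

Sharing the rope key is what keeps the rope term at 64 scalars per token instead of 8,192.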

The visual below shows the head split, the rotation on the small RoPE subspace, and how the per-token cache changes as you slide the dials. Drop $d_h^R$ to zero and watch the rotation panel disappear; raise $d_c$ and watch the cache grow linearly.


The Complete MLA Layer

Here is the full PyTorch layer. It is about 130 lines. It handles prefill and decode in the same forward(), applies decoupled RoPE on the small shared subspace, and returns a clean cache tuple you can thread through your generation loop.

MLA with decoupled RoPE — full layer
mla.py
Line 6: Why a separate apply_rope helper

RoPE pairs adjacent dims (2i, 2i+1) and rotates each pair. Pulling this into a function keeps the main forward() readable and lets us apply the same rotation to Q (per head) and K (shared across heads) without duplicating code.

Lines 13-15: Interleaved rotation (even/odd split)

x1 takes the even dims and x2 the odd dims. Treat each (even, odd) pair as a complex number a + ib and multiply it by cos + i·sin: (a+ib)·(cos+i·sin) = (a·cos − b·sin) + i·(a·sin + b·cos). This is exactly how RoPE is defined: a position-dependent 2-D rotation on every pair.

EXAMPLE
[a, b, c, d] -> rotate (a,b), rotate (c,d)
Lines 16-17: Re-interleave with stack + flatten

torch.stack(..., dim=-1) builds a new (..., d_rope/2, 2) tensor. flatten(-2) then turns it back into shape (..., d_rope) with the rotated even/odd values interleaved. The stack allocates once, and flattening its contiguous result is a free view, so the whole interleave costs a single copy.

Lines 31-32: Why two head dims

MLA splits each head into a content part (d_h_content) handled by latent compression, and a small RoPE part (d_rope) that carries position. Total per-head dim during attention is d_h_content + d_rope. DeepSeek-V3 uses 128 + 64.

EXAMPLE
d_h_content=128, d_rope=64 -> attention head dim = 192
Line 42: The 1/√d scale

Total head dim is d_h_content + d_rope. Using the combined dim in the scale keeps the variance of the scores constant regardless of the content/rope split — important when sweeping d_rope.

Lines 45-47: Three matrices, one shared latent

W_DKV produces the latent once. W_UK and W_UV each pull a full keys/values bank out of the same c_KV. The model is free to learn DIFFERENT decoders for K vs V from a SHARED code — that is what makes joint compression cheap.

Lines 50-51: Why two query projections

Queries also need a content part (matched against K_content) and a rope part (matched against K_rope). Both are per-head: a (B, T_new, n_heads * d_h_content) projection and a (B, T_new, n_heads * d_rope) projection.

Line 54: The crucial asymmetry

W_K_rope outputs only d_rope features, NOT n_heads * d_rope. There is ONE rope-key per token, and every head shares it. This is why MLA's decode cache is d_c + d_rope per token, not d_c + n_heads · d_rope.

EXAMPLE
k_rope.shape == (B, T, d_rope)   # not per-head!
Lines 60-64: Pre-compute RoPE angles once

inv_freq holds one frequency per rotation pair (d_rope/2 of them). Multiplying by position via einsum gives a (max_seq, d_rope/2) table of angles. We compute cos/sin once and slice them per forward pass, so we never recompute trig inside the hot loop.

Lines 65-66: register_buffer vs nn.Parameter

rope_cos/sin are NOT learned. register_buffer makes them follow the module through .to(device) without showing up in .parameters(). The persistent=False flag additionally keeps them out of state_dict(), so they never land in checkpoints. They are reproducible from rope_base and d_rope alone.

Line 76: Compress the new tokens (the cache write)

c_KV_new is the entire MLA write per step — d_c scalars per new token. Combined with k_rope_new (d_rope scalars per new token), this is the FULL information that goes into the cache for these tokens. K and V themselves never enter the cache.

EXAMPLE
wrote 1 token: d_c + d_rope scalars   (vs MHA: 2*n_heads*d_h)
Line 77: k_rope is shared across heads

Producing one d_rope-wide rope-key per token costs d_rope extra cache scalars per token. With DeepSeek-V3's d_rope = 64 and 128-wide heads, that is half of one head's key. This is the price MLA pays to keep position information in the cache without re-introducing per-head K. It is a deliberate trade: content matching gets compression, and position matching pays full price on a tiny dim.

Lines 85-88: Stitch past + new (the cache read)

torch.cat is fine here for clarity, but it reallocates and copies the whole cache on every step. Production systems use ring buffers or paged allocators (vLLM's PagedAttention, for example) to avoid that cost. The math is identical; we just need every past latent visible to attention.

Lines 92-93: Up-project keys ON THE FLY

K is rebuilt from c_KV every forward pass — it is never stored. This is the central trade: we spend O(T · n_heads · d_h · d_c) FLOPs per step to avoid storing O(T · 2 · n_heads · d_h) scalars in HBM. Compute is cheap; HBM bandwidth is the scarce resource at decode.

Line 107: RoPE on QUERY only for NEW positions

Queries are produced from the new tokens, so we rotate them by their absolute positions [past, past+T_new). Keys must end up rotated by their own absolute positions. RoPE's key property is that the score q·k then depends only on the RELATIVE position (m−n). So we rotate q at position m and k at position n, and the dot product handles the rest.

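Line 108: RoPE on the shared K-rope across ALL positions

We rotate the entire k_rope tensor, past and new, by their absolute positions. Past rotations are recomputed every step (cheap: d_rope is small, 64 in DeepSeek-V3), so the cache holds unrotated rope-keys. The rotated result must therefore not overwrite the tensor that goes back into the cache; otherwise the past keys would be rotated twice on the next step.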

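Line 110: Broadcast, not materialize

Calling .unsqueeze(2).expand(...) on the rotated rope-key creates a VIEW, not a copy: every head reads the same rope-key through a stride of 0. The torch.cat on lines 113-114 then copies it into the per-head K. This training path therefore does materialize n_heads copies transiently, while the cache itself stays shared. This is the runtime expression of MLA's shared-rope design.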

EXAMPLE
k_rope:(B,T,d_rope) -view-> K_rope:(B,T,H,d_rope)
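The view claim can be checked directly with toy sizes. expand shares storage through a zero stride, and the later torch.cat allocates a fresh per-head tensor.

```python
import torch

B, T, n_heads, d_rope, d_hc = 2, 5, 4, 8, 16
k_rope = torch.randn(B, T, d_rope)
K_rope = k_rope.unsqueeze(2).expand(B, T, n_heads, d_rope)   # view: no new memory

shares_storage = K_rope.data_ptr() == k_rope.data_ptr()
head_stride = K_rope.stride(2)                              # 0: every head reads the same row

K_content = torch.randn(B, T, n_heads, d_hc)
K = torch.cat([K_content, K_rope], dim=-1)                  # cat allocates and copies
```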
Lines 113-114: Concat along head dim

Q and K both become shape (B, T*, H, d_h_content + d_rope). The attention dot product then mixes content and rope contributions naturally: q · k = q_c · k_c + q_r · k_r — exactly what section 4.4 derived.

Lines 123-126: Causal mask with offset

triu(diagonal=past+1) masks entry (i, j) exactly when key position j > past+i, the absolute position of query i. This works without branches whether T_new == 1 (decode) or T_new == T (prefill).

EXAMPLE
T_new=1, past=5 -> mask only positions > 5
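The offset rule can be checked against a brute-force definition. causal_mask below just wraps the listing's triu expression, and reference is a hypothetical loop-based oracle. The T_new=3, past=2 case is a prompt processed in chunks, which the one-liner also covers.

```python
import torch

def causal_mask(T_new, past, device=None):
    """True where query i (absolute position past + i) must NOT see key j."""
    T = past + T_new
    return torch.ones(T_new, T, dtype=torch.bool, device=device).triu(diagonal=past + 1)

def reference(T_new, past):
    # Brute force: mask (i, j) iff j > past + i
    return torch.tensor([[j > past + i for j in range(past + T_new)] for i in range(T_new)])

decode = causal_mask(T_new=1, past=5)     # (1, 6): the newest token sees everything
prefill = causal_mask(T_new=4, past=0)    # (4, 4): strict upper triangle
```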
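Line 127: softmax, then .type_as(V)

F.softmax returns a tensor in the dtype of its input, so bf16/fp16 scores yield bf16/fp16 probabilities. The safer pattern is to upcast the scores to fp32 first (scores.float()). Then cast the probabilities back to V's dtype with .type_as(V), so the next matmul stays on tensor cores. Skipping the upcast leaves long attention rows normalized in low precision, which can cost accuracy in mixed-precision runs.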

Line 133: Return the NEW cache

We return the full (c_KV, k_rope) so the caller can pass it back on the next step. Returning a tuple keeps the layer pure — no hidden state mutates inside the module.

1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4
5
6def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
7    """
8    Rotate the last-dim of x by (cos, sin). Works on the RoPE subspace only.
9    x:   (..., T, d_rope)  with d_rope even
10    cos: broadcastable to x[..., ::2], e.g. (T, d_rope/2) or (T, 1, d_rope/2)
11    sin: same shape as cos
12    """
13    x1, x2 = x[..., ::2], x[..., 1::2]
14    rotated_even = x1 * cos - x2 * sin
15    rotated_odd  = x1 * sin + x2 * cos
16    out = torch.stack((rotated_even, rotated_odd), dim=-1)
17    return out.flatten(-2)
18
19
20class MLA(nn.Module):
21    """
22    Multi-Head Latent Attention with decoupled RoPE — production layout.
23
24    Per-token KV cache: d_c + d_rope scalars  (independent of n_heads).
25    """
26
27    def __init__(
28        self,
29        d_model: int,
30        n_heads: int,
31        d_h_content: int,   # per-head content dim (no RoPE)
32        d_rope: int,        # shared RoPE-head dim
33        d_c: int,           # MLA latent dim
34        max_seq: int = 8192,
35        rope_base: float = 10_000.0,
36    ):
37        super().__init__()
38        self.n_heads     = n_heads
39        self.d_h_content = d_h_content
40        self.d_rope      = d_rope
41        self.d_c         = d_c
42        self.scale       = (d_h_content + d_rope) ** -0.5
43
44        # Content path — joint compression.
45        self.W_DKV = nn.Linear(d_model, d_c, bias=False)
46        self.W_UK  = nn.Linear(d_c, n_heads * d_h_content, bias=False)
47        self.W_UV  = nn.Linear(d_c, n_heads * d_h_content, bias=False)
48
49        # Query path — content part is per-head, RoPE part is per-head too.
50        self.W_Q_content = nn.Linear(d_model, n_heads * d_h_content, bias=False)
51        self.W_Q_rope    = nn.Linear(d_model, n_heads * d_rope,    bias=False)
52
53        # RoPE key path — produces ONE shared rope-key per token, NOT per head.
54        self.W_K_rope = nn.Linear(d_model, d_rope, bias=False)
55
56        # Output.
57        self.W_O = nn.Linear(n_heads * d_h_content, d_model, bias=False)
58
59        # Pre-baked rope angles for all positions we will ever see.
60        inv_freq = 1.0 / (
61            rope_base ** (torch.arange(0, d_rope, 2).float() / d_rope)
62        )
63        pos = torch.arange(max_seq).float()
64        freqs = torch.einsum("t,f->tf", pos, inv_freq)        # (max_seq, d_rope/2)
65        self.register_buffer("rope_cos", freqs.cos(), persistent=False)
66        self.register_buffer("rope_sin", freqs.sin(), persistent=False)
67
68    def forward(
69        self,
70        h: torch.Tensor,                              # (B, T_new, d_model)
71        cache: tuple[torch.Tensor, torch.Tensor] | None = None,
72    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
73        B, T_new, _ = h.shape
74
75        # --- 1. Compress the new tokens. THIS is what enters the cache. -----
76        c_KV_new = self.W_DKV(h)                      # (B, T_new, d_c)
77        k_rope_new = self.W_K_rope(h)                 # (B, T_new, d_rope) — SHARED
78
79        # --- 2. Stitch cache + new tokens ------------------------------------
80        if cache is None:
81            c_KV   = c_KV_new
82            k_rope = k_rope_new
83            past   = 0
84        else:
85            c_KV_past, k_rope_past = cache
86            c_KV   = torch.cat([c_KV_past,   c_KV_new],   dim=1)
87            k_rope = torch.cat([k_rope_past, k_rope_new], dim=1)
88            past   = c_KV_past.shape[1]
89        T = c_KV.shape[1]
90
91        # --- 3. Up-project keys/values from the FULL latent -----------------
92        K_content = self.W_UK(c_KV).view(B, T, self.n_heads, self.d_h_content)
93        V         = self.W_UV(c_KV).view(B, T, self.n_heads, self.d_h_content)
94
95        # --- 4. Queries from the NEW tokens only ----------------------------
96        Q_content = (
97            self.W_Q_content(h).view(B, T_new, self.n_heads, self.d_h_content)
98        )
99        Q_rope = self.W_Q_rope(h).view(B, T_new, self.n_heads, self.d_rope)
100
101        # --- 5. Apply RoPE on the rope subspace -----------------------------
102        cos_new = self.rope_cos[past : past + T_new]
103        sin_new = self.rope_sin[past : past + T_new]
104        cos_all = self.rope_cos[: T]
105        sin_all = self.rope_sin[: T]
106
107        Q_rope = apply_rope(Q_rope, cos_new[:, None, :], sin_new[:, None, :])
108        k_rope_rot = apply_rope(k_rope, cos_all, sin_all)   # (B, T, d_rope); cache keeps k_rope unrotated
109        # Broadcast the shared rope-key to every head.
110        K_rope = k_rope_rot.unsqueeze(2).expand(B, T, self.n_heads, self.d_rope)
111
112        # --- 6. Concatenate content + rope along head dim -------------------
113        Q = torch.cat([Q_content, Q_rope], dim=-1)        # (B, T_new, H, d_h_content+d_rope)
114        K = torch.cat([K_content, K_rope], dim=-1)        # (B, T,     H, d_h_content+d_rope)
115
116        # --- 7. Attention with a causal mask ---------------------------------
117        Q = Q.transpose(1, 2)                              # (B, H, T_new, d)
118        K = K.transpose(1, 2)
119        V = V.transpose(1, 2)
120
121        scores = (Q @ K.transpose(-2, -1)) * self.scale    # (B, H, T_new, T)
122
123        causal = torch.ones(T_new, T, dtype=torch.bool, device=h.device).triu(
124            diagonal=past + 1
125        )
126        scores = scores.masked_fill(causal, float("-inf"))
127        attn = F.softmax(scores.float(), dim=-1).type_as(V)   # fp32 softmax, back to V's dtype
128
129        out = attn @ V                                     # (B, H, T_new, d_h_content)
130        out = out.transpose(1, 2).reshape(B, T_new, -1)
131        y = self.W_O(out)
132
133        return y, (c_KV, k_rope)

This is the training path. It is correct, but at decode time it rebuilds $K_{\text{content}}$ and $V$ on every step. It reads $c_{KV}$ from HBM, runs two matmuls against $W^{UK}$ and $W^{UV}$, and writes the results back out before attention can even start. The next two sections remove that overhead for inference.


Prefill vs Decode: Two Code Paths

A serving system runs the same MLA layer in two very different regimes. Treating them identically is correct; treating them identically is also a 5–20× performance loss.

| | Prefill | Decode |
|---|---|---|
| T_new | T (full prompt) | 1 (single token) |
| Cache state | None at layer entry | (c_KV_past, k_rope_past) of length T−1 |
| Dominant cost | Compute (attention matmuls scale with T²) | Bandwidth (read the full cache every step) |
| Optimal strategy | Fused attention kernel, large tile sizes | Absorbed projections, minimum HBM traffic |
| Arithmetic intensity | High: compute-bound | Low (~1 FLOP/byte): bandwidth-bound |

Prefill happily uses the layer above: the matmuls are huge, and $T^2$ scaling dominates the cache read. Reconstructing $K$ from the latent costs $O(T \cdot n_h d_h d_c)$ FLOPs, against $O(T^2 \cdot n_h d_h)$ for the attention matmuls. It is negligible once the prompt is much longer than $d_c$.

Decode is a different animal. Each step processes ONE token, so compute per step is tiny. What is not tiny is the cache: at $T = 128\text{k}$ you must read every cached entry once per layer per step. For a 70B-parameter model that is multiple gigabytes of HBM traffic per generated token. The absorption trick exists for exactly this regime.
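Here are back-of-the-envelope numbers for that claim. This sketch makes three assumptions. It uses bf16 (2 bytes per cached scalar). It counts only KV-cache reads, not weights. It uses DeepSeek-V3's depth and head layout (61 layers, 128 heads of 128, d_c = 512, d_rope = 64). The MHA line describes a hypothetical model with the same layout, not a real checkpoint.

```python
def cache_bytes_per_step(T: int, n_layers: int, scalars_per_token: int,
                         bytes_per_scalar: int = 2) -> int:
    """KV-cache bytes read for ONE decode step: every layer reads every cached token."""
    return T * n_layers * scalars_per_token * bytes_per_scalar

T = 128 * 1024            # 128K context
n_layers = 61             # DeepSeek-V3 depth, used as an example
mha = cache_bytes_per_step(T, n_layers, 2 * 128 * 128)   # full K and V, 128 heads of 128
mla = cache_bytes_per_step(T, n_layers, 512 + 64)        # d_c + d_rope
```

That comes to roughly 524 GB versus 9.2 GB of cache reads per generated token at 128K context.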


The Absorption Trick — Where the Inference Speedup Lives

Recall the content-side score for one head:

$$s_{\text{content}} = Q_{\text{content}} \cdot K_{\text{content}}^\top = (h W^Q_{\text{content}}) \cdot (c_{KV} W^{UK})^\top.$$

Using $(AB)^\top = B^\top A^\top$ and associativity:

$$s_{\text{content}} = h \big(W^Q_{\text{content}} W^{UK\top}\big) c_{KV}^\top = h \, W^{QK} \, c_{KV}^\top.$$

Same dot product, different parenthesization. The fused matrix $W^{QK} = W^Q_{\text{content}} W^{UK\top}$ (one $d_{\text{model}} \times d_c$ matrix per head) depends only on the trained weights, so you can pre-compute it once at load time. At decode time you compute the query directly in the latent space and dot it against $c_{KV}$. The full-width $K_{\text{content}}$ tensor is never allocated.
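Here is a NumPy check of the key-side identity for one head. It uses random weights in the row-vector convention of the equations above, with W_Q standing for one head's $W^Q_{\text{content}}$.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, d_h, d_c = 6, 16, 8, 4
h = rng.standard_normal((1, d_model))        # the new token's hidden state
c_KV = rng.standard_normal((T, d_c))         # cached latents
W_Q = rng.standard_normal((d_model, d_h))    # one head's W^Q_content
W_UK = rng.standard_normal((d_c, d_h))       # one head's W^{UK}

# Unfused: materialize K_content (T, d_h), then score.
K_content = c_KV @ W_UK
s_unfused = (h @ W_Q) @ K_content.T          # (1, T)

# Fused: W_QK is built once from the weights; K_content never exists.
W_QK = W_Q @ W_UK.T                          # (d_model, d_c)
q_latent = h @ W_QK                          # (1, d_c): query in latent space
s_fused = q_latent @ c_KV.T                  # (1, T)
```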

The same identity applies on the value side. Writing $W^O_i$ for the rows of $W^O$ that act on head $i$, the output of the attention block is a sum over heads:

$$y = \sum_i \text{attn}_i \, V_i \, W^O_i = \sum_i \text{attn}_i \, c_{KV} \big(W^{UV}_i W^O_i\big) = \sum_i \text{attn}_i \, c_{KV} \, W^{VO}_i.$$

Fuse $W^{UV}$ with the output projection and the value path also stays in latent space until the very last step.
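The value side can be checked the same way, using several heads because $W^O$ mixes them. Random weights and a row-normalized attn stand in for a real layer. W_O_heads slices the output projection into per-head row blocks, the rows that read head i.

```python
import numpy as np

rng = np.random.default_rng(1)
n_heads, T, d_model, d_h, d_c = 3, 5, 12, 4, 6
c_KV = rng.standard_normal((T, d_c))
W_UV = rng.standard_normal((n_heads, d_c, d_h))         # per-head value up-projections
W_O = rng.standard_normal((n_heads * d_h, d_model))     # output projection over concatenated heads
attn = rng.random((n_heads, 1, T))
attn /= attn.sum(-1, keepdims=True)                     # rows sum to 1, like softmax output

# Unfused: build V per head, concatenate heads, project.
V = np.einsum("tc,icd->itd", c_KV, W_UV)                # (n_heads, T, d_h)
heads = attn @ V                                        # (n_heads, 1, d_h)
y_unfused = heads.transpose(1, 0, 2).reshape(1, n_heads * d_h) @ W_O

# Fused: W_VO[i] = W_UV[i] @ (rows of W_O that read head i).
W_O_heads = W_O.reshape(n_heads, d_h, d_model)
W_VO = W_UV @ W_O_heads                                 # (n_heads, d_c, d_model)
ctx = attn @ c_KV                                       # (n_heads, 1, d_c): stays in latent space
y_fused = (ctx @ W_VO).sum(0)                           # sum over heads -> (1, d_model)
```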

Absorbed inference path
mla_decode.py
Lines 6-8: The identity that makes this work

Matrix multiplication is associative. Consider the query projection (Q = h @ W_Q) followed by the dot product against decompressed keys (K = c_KV @ W_UK). It factors into a query-side projection h @ (W_Q @ W_UK.T), dotted against the latents themselves. Same arithmetic, different parenthesization, vastly different memory pattern.

EXAMPLE
(h W_Q)(c W_UK)^T = h (W_Q W_UK^T) c^T
Lines 14-17: Why this saves bandwidth, not FLOPs

The total FLOP count is similar; you still multiply matrices of comparable size. The win is that you NEVER allocate K_content in HBM during decode. At T=128K, n_heads=128 and d_h_content=128 in bf16, that is about 4 GB of avoided traffic per layer per token. Decode is bandwidth-bound (see the simulator below), so bandwidth IS speed.

Lines 20-21: Run ONCE at load time

absorb() is a function of the trained weights only — call it after loading a checkpoint, store the fused tensors, then throw away W_UK and W_Q_content if you want to. The training graph still uses the unfused form (with autograd through both matrices); inference uses the fused form.

Lines 29-30: View, don't copy

.view(n_h, d_h, d_m) reinterprets the same memory — no copy, no extra allocation. It splits the (n_heads*d_h) axis of W_Q into per-head slices so we can fuse one head at a time. Same idea for W_UK.
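The no-copy claim can be confirmed on a toy layer with arbitrary sizes:

```python
import torch
import torch.nn as nn

n_h, d_h, d_m = 4, 8, 32
W_Q_content = nn.Linear(d_m, n_h * d_h, bias=False)

W_Q = W_Q_content.weight.view(n_h, d_h, d_m)      # reinterpret (n_h*d_h, d_m) as (n_h, d_h, d_m)
same_memory = W_Q.data_ptr() == W_Q_content.weight.data_ptr()
head_1_matches = torch.equal(W_Q[1], W_Q_content.weight[d_h:2 * d_h])   # head 1 = rows 8..15
```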

Line 33: The fusion einsum

'hdm,hdc->hmc' sums over d_h (the per-head dim that gets dot-producted in attention) for each head h. The output W_QK has shape (n_heads, d_model, d_c); it maps a hidden state directly into latent space, per head.

Lines 39-41: The same trick on the value side

(attn @ V) @ W_O = attn @ c_KV @ (W_UV @ W_O). Fusing W_UV with the output projection lets us compute the attention-weighted context in LATENT space, then up-project ONCE to d_model — instead of up-projecting per-head per-token.

Line 46: @torch.no_grad on the decode path

Inference does not need the autograd tape. Disabling it saves memory and some kernel overhead. If you forget this on a long generation, activation memory grows linearly with tokens — a common cause of slow / OOM decode.

Lines 54-57: Cache write is unchanged

The cache entries (c_KV_new, k_rope_new) are exactly the same as in the unfused forward. Absorption changes how we READ the cache, not what we WRITE — so the cache format on disk and across kernels stays identical.

Line 63: Project hidden state directly into latent space

Q_latent[b, h, t, c] tells you, for head h at the new token, how strongly this token looks for latent dimension c. We never build the d_h_content-wide Q_content, and so never need its partner K_content: the query meets c_KV directly.

Line 65: Score in latent space

Dot product between Q_latent (B, H, 1, d_c) and c_KV (B, T, d_c) gives the content attention scores directly. The arithmetic is identical to the unfused path — only the memory layout differs.

Line 89: Context vector lives in latent space too

ctx = attn @ c_KV stays in d_c dimensions. We never materialize V_content as a (B, H, T, d_h) tensor. The final einsum with W_VO does the up-projection from latent to d_model in one fused step.

1# ----------------------------------------------------------------------
2# Inference-only: bake W_UK into W_Q_content ONCE per checkpoint, then
3# you never reconstruct K_content again during decode.
4# ----------------------------------------------------------------------
5#
6# Math (per head; nn.Linear stores weights as (out, in), so x @ W.T):
7#   q . k_content  =  (h @ W_Q.T) @ (c_KV @ W_UK.T).T
8#                  =  h @ (W_Q.T @ W_UK) @ c_KV.T        # associativity
9# Let W_QK = W_Q.T @ W_UK   (d_model x d_c per head; stored as n_heads x d_model x d_c)
10# Then at decode time:
11#     Q_latent = h @ W_QK          # NEVER materialize K_content
12#     score    = Q_latent @ c_KV.T # dot product in the LATENT space
13#
14# Memory savings: K_content (n_heads * d_h_content per token) never exists.
15# FLOPs: no saving. Each head's score width grows from d_h_content to d_c
16#        (DeepSeek-V3: 128 -> 512). The win is bandwidth: all heads read the
17#        same shared c_KV instead of a per-head K_content.
18# ----------------------------------------------------------------------
19
20def absorb(mla: MLA) -> dict:
21    """Pre-compute the fused projections for inference. Run ONCE at load."""
22    n_h = mla.n_heads
23    d_c = mla.d_c
24    d_h = mla.d_h_content
25    d_m = mla.W_DKV.in_features
26
27    # W_Q_content: (n_h*d_h, d_m)  --reshape-->  (n_h, d_h, d_m)
28    # W_UK       : (n_h*d_h, d_c)  --reshape-->  (n_h, d_h, d_c)
29    W_Q = mla.W_Q_content.weight.view(n_h, d_h, d_m)
30    W_UK = mla.W_UK.weight.view(n_h, d_h, d_c)
31
32    # Fuse: for each head h,  W_QK[h] = W_Q[h].T @ W_UK[h]    (d_m x d_c)
33    W_QK = torch.einsum("hdm,hdc->hmc", W_Q, W_UK)            # (n_h, d_m, d_c)
34
35    # Same trick for the value side at the OUTPUT projection:
36    # out = (attn @ V) @ W_O, with V = c_KV @ W_UV
37    #     = attn @ (c_KV @ W_UV) @ W_O
38    #     = attn @ c_KV @ (W_UV @ W_O)          # fuse W_UV . W_O
39    W_UV = mla.W_UV.weight.view(n_h, d_h, d_c)
40    W_O  = mla.W_O.weight.view(d_m, n_h, d_h).permute(1, 2, 0)  # (n_h, d_h, d_m)
41    W_VO = torch.einsum("hdc,hdm->hcm", W_UV, W_O)             # (n_h, d_c, d_m)
42
43    return {"W_QK": W_QK, "W_VO": W_VO}
44
45
46@torch.no_grad()
47def mla_decode_step(mla: MLA, fused: dict, h_new: torch.Tensor,
48                    cache: tuple[torch.Tensor, torch.Tensor]):
49    """One token of decode using the absorbed projections."""
50    B = h_new.shape[0]
51    c_KV_past, k_rope_past = cache
52
53    # --- 1. Cache write (the only new entry per step) ----------------------
54    c_KV_new   = mla.W_DKV(h_new)                  # (B, 1, d_c)
55    k_rope_new = mla.W_K_rope(h_new)               # (B, 1, d_rope)
56    c_KV   = torch.cat([c_KV_past,   c_KV_new],   dim=1)   # (B, T, d_c)
57    k_rope = torch.cat([k_rope_past, k_rope_new], dim=1)   # (B, T, d_rope)
58    T = c_KV.shape[1]
59    past = T - 1
60
61    # --- 2. Query is computed in LATENT space directly ----------------------
62    # Q_latent[h] = h_new @ W_QK[h]      shape: (B, 1, d_c)
63    Q_latent = torch.einsum("btm,hmc->bhtc", h_new, fused["W_QK"])
64    # content score: Q_latent[h] @ c_KV.T          shape: (B, H, 1, T)
65    content_score = torch.einsum("bhtc,bsc->bhts", Q_latent, c_KV)
66
67    # --- 3. RoPE path (cheap — d_rope is small) -----------------------------
68    Q_rope = mla.W_Q_rope(h_new).view(B, 1, mla.n_heads, mla.d_rope)
69    Q_rope = apply_rope(Q_rope, mla.rope_cos[past:past+1, None, :],
70                                mla.rope_sin[past:past+1, None, :])
71    k_rope_rot = apply_rope(k_rope, mla.rope_cos[:T], mla.rope_sin[:T])  # (B,T,d_rope)
72    rope_score = torch.einsum("bthd,bsd->bhts", Q_rope, k_rope_rot)
73
74    # --- 4. Combined attention (no K_content ever materialized) -------------
75    scores = (content_score + rope_score) * mla.scale
76    attn = F.softmax(scores, dim=-1)              # (B, H, 1, T)
77
78    # --- 5. Output via fused W_VO -------------------------------------------
79    # out = (attn @ c_KV) @ W_VO[h]   — V_content never materialized either
80    ctx = torch.einsum("bhts,bsc->bhtc", attn, c_KV)              # (B, H, 1, d_c)
81    y   = torch.einsum("bhtc,hcm->btm", ctx, fused["W_VO"])       # (B, 1, d_model)
82    return y, (c_KV, k_rope)
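Before wiring this into a decode loop, it is worth checking the `absorb` einsums against the unfused computation on random weights. The sketch below is self-contained and rests on a few assumptions: standalone bias-free `nn.Linear` layers standing in for `W_DKV`, `W_Q_content`, `W_UK`, `W_UV`, and `W_O` (same weight layouts `absorb` reads), RoPE switched off, float64 to keep rounding out of the comparison, and arbitrary small sizes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, T, d_m, d_c, n_h, d_h = 2, 5, 32, 8, 4, 6

# Hypothetical stand-ins for the MLA projections (bias-free, float64).
W_DKV = nn.Linear(d_m, d_c, bias=False).double()
W_Q = nn.Linear(d_m, n_h * d_h, bias=False).double()
W_UK = nn.Linear(d_c, n_h * d_h, bias=False).double()
W_UV = nn.Linear(d_c, n_h * d_h, bias=False).double()
W_O = nn.Linear(n_h * d_h, d_m, bias=False).double()

with torch.no_grad():
    h_all = torch.randn(B, T, d_m, dtype=torch.float64)  # hidden states for T tokens
    h_new = h_all[:, -1:]                                # the token being decoded
    c_KV = W_DKV(h_all)                                  # (B, T, d_c) latent cache

    # Unfused reference: decompress K and V per head, then attend.
    Q = W_Q(h_new).view(B, 1, n_h, d_h).transpose(1, 2)   # (B, H, 1, d_h)
    K = W_UK(c_KV).view(B, T, n_h, d_h).transpose(1, 2)   # (B, H, T, d_h)
    V = W_UV(c_KV).view(B, T, n_h, d_h).transpose(1, 2)   # (B, H, T, d_h)
    scores_ref = Q @ K.transpose(-1, -2)                  # (B, H, 1, T)
    attn_ref = torch.softmax(scores_ref / d_h ** 0.5, dim=-1)
    ctx_ref = (attn_ref @ V).transpose(1, 2).reshape(B, 1, n_h * d_h)
    y_ref = W_O(ctx_ref)                                  # (B, 1, d_m)

    # Absorbed path: the same einsums as absorb() and mla_decode_step().
    W_QK = torch.einsum("hdm,hdc->hmc",
                        W_Q.weight.view(n_h, d_h, d_m), W_UK.weight.view(n_h, d_h, d_c))
    W_O_h = W_O.weight.view(d_m, n_h, d_h).permute(1, 2, 0)  # (H, d_h, d_m)
    W_VO = torch.einsum("hdc,hdm->hcm", W_UV.weight.view(n_h, d_h, d_c), W_O_h)

    Q_latent = torch.einsum("btm,hmc->bhtc", h_new, W_QK)   # (B, H, 1, d_c)
    scores = torch.einsum("bhtc,bsc->bhts", Q_latent, c_KV) # (B, H, 1, T)
    attn = torch.softmax(scores / d_h ** 0.5, dim=-1)
    ctx = torch.einsum("bhts,bsc->bhtc", attn, c_KV)        # (B, H, 1, d_c)
    y = torch.einsum("bhtc,hcm->btm", ctx, W_VO)            # (B, 1, d_m)

print(torch.allclose(scores, scores_ref), torch.allclose(y, y_ref))  # True True
```

The same comparison, run against a real checkpoint's layers in fp32 with a looser tolerance, makes a cheap regression test for any serving-side rewrite of the absorbed path.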

The whole point in one sentence

Absorption converts a decompress-then-attend pipeline into an attend-in-latent-space pipeline, removing two of the three largest tensors that ever touch HBM during decode ($K_{\text{content}}$ and $V$) while computing the same scores and the same output, up to floating-point rounding.


The Decode Loop in Full

Wired together, generation looks like this:

```python
import torch

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens=128):
    # Sketch: model.embed, model.layers[i].attn / .ffn, model.lm_head, model.eos
    # and sample() (returns (B,) token ids) are assumed; norms omitted for brevity.

    # ---- 1. Prefill with the unfused (training) path -----------------------
    h = model.embed(prompt_ids)                            # (B, T_prompt, d_model)
    cache_per_layer = [None] * len(model.layers)
    for i, layer in enumerate(model.layers):
        a, cache_per_layer[i] = layer.attn(h, cache=None)
        h = h + a                                          # attention residual
        h = h + layer.ffn(h)                               # FFN residual

    next_token = sample(model.lm_head(h[:, -1]))           # (B,)

    # ---- 2. Build absorbed projections (in a server: once, at load time) ----
    fused = [absorb(layer.attn) for layer in model.layers]

    # ---- 3. Decode loop with the absorbed path ------------------------------
    out_ids = [next_token]
    for _ in range(max_new_tokens - 1):                    # prefill made token 1
        h_new = model.embed(next_token[:, None])           # (B, 1, d_model)
        for i, layer in enumerate(model.layers):
            a, cache_per_layer[i] = mla_decode_step(
                layer.attn, fused[i], h_new, cache_per_layer[i],
            )
            h_new = h_new + a                              # attention residual
            h_new = h_new + layer.ffn(h_new)               # FFN residual
        next_token = sample(model.lm_head(h_new[:, -1]))
        out_ids.append(next_token)
        if (next_token == model.eos).all():
            break
    return torch.stack(out_ids, dim=1)                     # (B, n_generated)
```

Two passes through the cache, two different code paths, one shared cache format. The prefill writes $(c_{KV}, k^R)$ in the natural form; the decode path reads it in latent space and never materializes $K_{\text{content}}$ or $V$.

Real serving frameworks (vLLM, SGLang, TGI) replace the torch.cat cache stitch with a paged allocator: KV blocks live in fixed-size pages, and a per-sequence block table maps logical token positions to physical pages. This makes batched decode over variable-length sequences memory-efficient, because no sequence needs a contiguous max-length reservation. The MLA math does not change; only the allocator, and the attention kernel's gather through the block table, do.
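To make the page-table idea concrete, here is a toy sketch for one layer's latent rows. It is not the data structure vLLM, SGLang, or TGI actually use; the class `PagedLatentCache`, its method names, and the page and pool sizes are hypothetical. A preallocated pool holds fixed-size pages, each sequence owns a list of page ids (its block table), and a read gathers the pages back into logical order.

```python
import torch

class PagedLatentCache:
    """Toy paged store for c_KV rows of one layer (hypothetical, illustration only)."""

    def __init__(self, num_pages: int, page_size: int, d_c: int):
        self.page_size = page_size
        self.pool = torch.zeros(num_pages, page_size, d_c)   # physical pages
        self.free = list(range(num_pages))                    # free-page list
        self.block_table: dict[int, list[int]] = {}           # seq id -> page ids
        self.lengths: dict[int, int] = {}                     # seq id -> tokens

    def append(self, seq: int, row: torch.Tensor) -> None:
        """Write one token's latent row at the sequence's next logical position."""
        pos = self.lengths.get(seq, 0)
        pages = self.block_table.setdefault(seq, [])
        if pos % self.page_size == 0:            # current page full (or none yet)
            pages.append(self.free.pop())
        page, slot = pages[pos // self.page_size], pos % self.page_size
        self.pool[page, slot] = row
        self.lengths[seq] = pos + 1

    def gather(self, seq: int) -> torch.Tensor:
        """Return the sequence's (T, d_c) latent history in logical order."""
        T = self.lengths[seq]
        pages = torch.tensor(self.block_table[seq])
        return self.pool[pages].reshape(-1, self.pool.shape[-1])[:T]

    def release(self, seq: int) -> None:
        """Return a finished sequence's pages to the free list."""
        self.free.extend(self.block_table.pop(seq))
        del self.lengths[seq]


cache = PagedLatentCache(num_pages=8, page_size=4, d_c=3)
for t in range(6):                        # sequence 0 spans two pages
    cache.append(0, torch.full((3,), float(t)))
cache.append(1, torch.ones(3))            # sequence 1 gets its own page
print(cache.gather(0)[:, 0])              # tensor([0., 1., 2., 3., 4., 5.])
```

Physical pages for one sequence need not be adjacent; only the block table knows the order, which is exactly why the attention kernel has to gather through it.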


Manual Numerical Walkthrough

One decode step, by hand

We will run a single decode step at position $t = 2$ with $n_h = 1$, $d_h^c = 2$, $d_h^R = 0$ (RoPE off, to keep the arithmetic small), and $d_c = 2$.

Cache from prior steps: $c_{KV}^{(0)} = [1, 0]$, $c_{KV}^{(1)} = [0, 1]$.

New hidden state: $h = [1, 1]$.

Weights: $W^{DKV} = I$, $W^{UK} = I$, $W^Q = I$, $W^{UV} = I$, $W^O = I$.

Step 1: write the cache. $c_{KV}^{(2)} = h W^{DKV} = [1, 1]$. The cache is now $[[1,0],[0,1],[1,1]]$.

Step 2: absorbed query. $W^{QK} = W^Q (W^{UK})^\top = I$, so $Q_{\text{latent}} = h W^{QK} = [1, 1]$.

Step 3: scores in latent space.
$s_0 = [1,1] \cdot [1,0] = 1$
$s_1 = [1,1] \cdot [0,1] = 1$
$s_2 = [1,1] \cdot [1,1] = 2$

Step 4: scale by $1/\sqrt{d_h^c + d_h^R} = 1/\sqrt{2}$, giving $s = [0.707, 0.707, 1.414]$.

Step 5: softmax. $\exp(s) = [2.028, 2.028, 4.113]$, sum $= 8.169$.
$\text{attn} = [0.248, 0.248, 0.503]$.

Step 6: value path in latent space. $W^{VO} = W^{UV} W^O = I$, so the output is just $\text{attn} \cdot c_{KV}$:
$y = 0.248 \cdot [1,0] + 0.248 \cdot [0,1] + 0.503 \cdot [1,1]$
$y \approx [0.752, 0.752]$ (0.7517 with unrounded attention weights).

The check: if you run the unfused path (decompress $K = c_{KV} W^{UK} = c_{KV}$, compute $Q = h W^Q$, do the dot products explicitly, then take the attention-weighted sum over $V = c_{KV} W^{UV} = c_{KV}$ and apply $W^O$), you get the same $y$ to floating-point precision. Two different paths, identical answer. The general correctness argument is the associativity identity behind $W^{QK}$ and $W^{VO}$; this example makes it visible by hand.
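The hand computation is easy to replay in NumPy. This sketch uses the walkthrough's toy values (identity weights, RoPE off, scale $1/\sqrt{2}$) and runs both the absorbed path and the unfused path; the variable names are ours, not from the layer code.

```python
import numpy as np

I = np.eye(2)
W_DKV = W_UK = W_Q = W_UV = W_O = I          # all identity, as in the walkthrough
cache = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]
h = np.array([1.0, 1.0])

c_KV = np.vstack(cache + [h @ W_DKV])        # step 1: append new latent -> (3, 2)
scale = 1 / np.sqrt(2)                       # 1 / sqrt(d_h^c + d_h^R)

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

# Absorbed path: score and contextualize in latent space.
W_QK = W_Q @ W_UK.T
attn = softmax((h @ W_QK) @ c_KV.T * scale)
y_absorbed = (attn @ c_KV) @ (W_UV @ W_O)

# Unfused path: decompress K and V explicitly, then attend.
K, V = c_KV @ W_UK, c_KV @ W_UV
attn_ref = softmax((h @ W_Q) @ K.T * scale)
y_unfused = (attn_ref @ V) @ W_O

print(np.round(attn, 3), np.round(y_absorbed, 3))  # [0.248 0.248 0.503] [0.752 0.752]
```

Swapping the identity matrices for random ones leaves the two paths in agreement, which is the point: the equality comes from associativity, not from the special weights.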


KV Cache at Scale

The implementation above is half of the story. The other half is scale: at real model sizes, the absolute size of the KV cache determines whether long-context serving is economical.

The cache grows linearly in sequence length $T$, in batch size, in the number of layers, and (for MHA/GQA) in the number of KV heads. MLA pulls $n_h$ out of that product entirely: its per-token, per-layer footprint is $d_c + d_h^R$ scalars, full stop.


Concrete numbers for DeepSeek-V3 (61 layers, 128 heads, head dim 128, $d_c = 512$, $d_h^R = 64$) at sequence length 128K (131,072 tokens), batch 1, fp16 (2 bytes per scalar):

| Mechanism | Per token, per layer | Total cache (128K, B=1) |
|---|---|---|
| MHA (full) | 2 · 128 · 128 = 32,768 scalars | ≈ 488 GiB |
| GQA (8 KV heads) | 2 · 8 · 128 = 2,048 scalars | ≈ 30.5 GiB |
| MLA | 512 + 64 = 576 scalars | ≈ 8.6 GiB |

MLA is $\approx 57\times$ smaller than the equivalent MHA cache and $\approx 3.6\times$ smaller than aggressive GQA. On the inference side, the implementation cost is fusing $W^{UK}$ into $W^Q$ and $W^{UV}$ into $W^O$ once, at load time.
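These totals are plain multiplication, so a small helper makes it easy to redo them for other configurations. The function name and arguments below are illustrative, not from any library; it counts 2 bytes per scalar (fp16) and treats 128K as 131,072 tokens.

```python
def kv_cache_bytes(scalars_per_token_layer: int, n_layers: int,
                   seq_len: int, batch: int = 1, bytes_per_scalar: int = 2) -> int:
    """Total KV-cache size in bytes: linear in every argument."""
    return scalars_per_token_layer * n_layers * seq_len * batch * bytes_per_scalar

# DeepSeek-V3 shapes from the text: 61 layers, 128 heads, head dim 128,
# d_c = 512, d_h^R = 64; 128K tokens, batch 1, fp16.
L, H, D, T = 61, 128, 128, 128 * 1024
per_token = {
    "MHA": 2 * H * D,          # K and V for every head
    "GQA-8": 2 * 8 * D,        # K and V for 8 shared KV heads
    "MLA": 512 + 64,           # latent c_KV plus the shared rope key
}
for name, n in per_token.items():
    gib = kv_cache_bytes(n, L, T) / 2**30
    print(f"{name:6s} {n:6d} scalars/token/layer  {gib:7.2f} GiB")
# MHA 488.00 GiB, GQA-8 30.50 GiB, MLA 8.58 GiB
```

Because every factor enters linearly, the MHA-to-MLA ratio is just the per-token ratio, 32,768 / 576 ≈ 57, at any length or batch.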


Where the Speed Actually Comes From

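It is tempting to attribute MLA's decode speedup to "fewer FLOPs." That is not the right model. Modern accelerators deliver roughly 200 FLOPs of fp16 compute per byte of HBM bandwidth, while MHA-style attention decode has an arithmetic intensity of roughly 1 FLOP per byte: each cached fp16 key or value (2 bytes) feeds a single multiply-add (2 FLOPs). The hardware leaves about 199/200 of its peak compute idle, waiting on memory. (Absorbed MLA raises this intensity, because every head reuses the same latent entries, so treat the bandwidth model below as a first-order estimate rather than an exact prediction.)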

So decode latency is, to first order:

$$\text{latency} \approx \frac{\text{KV cache bytes}}{\text{HBM bandwidth}}.$$
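A sketch of this first-order model, which ignores weight reads, kernel overheads, and compute entirely: the per-token decode time is the time to stream the cache once. The helper name is ours, and the 3.35 TB/s figure is the H100 bandwidth quoted below; note the MHA cache here would not even fit in a single H100's memory, so that line is purely a bandwidth comparison.

```python
def decode_step_latency_ms(cache_bytes: float, hbm_bytes_per_s: float) -> float:
    """First-order lower bound: time to stream the whole KV cache once, in ms."""
    return cache_bytes / hbm_bytes_per_s * 1e3

# DeepSeek-V3-like shapes at 128K tokens, batch 1, fp16 (61 layers;
# per-token, per-layer scalars as in the table above).
T, L, BYTES = 128 * 1024, 61, 2
mha_bytes = 2 * 128 * 128 * L * T * BYTES
mla_bytes = (512 + 64) * L * T * BYTES

H100_BW = 3.35e12  # bytes/s
print(f"MHA: {decode_step_latency_ms(mha_bytes, H100_BW):.1f} ms/token")  # 156.4
print(f"MLA: {decode_step_latency_ms(mla_bytes, H100_BW):.2f} ms/token")  # 2.75
```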

To first order, then: cut the cache, cut the latency. The same calculation works for any accelerator; combine the model, batch, and sequence length with its bandwidth (H100 ≈ 3.35 TB/s, MI300X ≈ 5.3 TB/s, H200 ≈ 4.8 TB/s, B200 ≈ 8 TB/s).


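Two things stand out when you vary the inputs. First, fix the model and sweep the sequence length from 1K to 128K: both caches grow linearly, but MLA's slope is about 57× shallower than MHA's, so on a shared axis its bar barely moves. Second, hold the sequence length fixed and increase the batch: the cache ratio stays the same, but the absolute gap in bytes and milliseconds grows with every sequence you add, because each sequence carries its own cache. Once weight reads (paid once per step, regardless of batch) are counted, the cache becomes a larger share of each step's traffic, so MLA's relative advantage grows too.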

The other lever that falls out of this view is serving batch size. Bytes per token are fixed, so the maximum batch that fits is roughly $(\text{HBM size} - \text{weight bytes}) / (\text{bytes per token} \cdot T)$. When the cache is the binding constraint, MLA's 57× compression allows up to ~57× more concurrent users on the same hardware; at long context that is the dominant economic factor for an inference service.
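A sketch of that capacity estimate. The helper name is ours, and the 100 GiB of HBM left over after weights is a hypothetical budget chosen only to make the comparison concrete; the model shapes are the DeepSeek-V3 numbers used above.

```python
def max_concurrent_sequences(free_bytes: int, bytes_per_token: int, seq_len: int) -> int:
    """How many full-length sequences' caches fit in the memory left after weights."""
    return free_bytes // (bytes_per_token * seq_len)

L, T, BYTES = 61, 128 * 1024, 2            # DeepSeek-V3 layers, 128K context, fp16
free = 100 * 2**30                          # hypothetical: 100 GiB free after weights
for name, scalars in {"MHA": 2 * 128 * 128, "GQA-8": 2 * 8 * 128, "MLA": 512 + 64}.items():
    per_token = scalars * L * BYTES         # bytes per token, summed over all layers
    print(name, max_concurrent_sequences(free, per_token, T))
# MHA 0, GQA-8 3, MLA 11
```

Integer flooring is why the ratio in concurrent users is not exactly 57× here: with small budgets, one sequence's cache is a coarse unit.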


Production Notes & Common Bugs

Initialization

The composition $h W^{DKV} W^{UK}$ plays the same statistical role as a single Linear from $d_{\text{model}} \to n_h d_h$, so treat the pair as ONE projection for variance scaling. A clean recipe: initialize $W^{DKV}$ with standard Xavier on $(d_{\text{model}}, d_c)$, then draw $W^{UK}, W^{UV}$ with entries of standard deviation $1/\sqrt{d_c}$ (unit-variance entries scaled by $1/\sqrt{d_c}$). With unscaled unit-variance entries, the standard deviation of K and V is off by a factor of $\sqrt{d_c}$ at step 0, and you spend the first few thousand steps just re-normalizing.
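A minimal PyTorch sketch of that recipe, assuming bias-free `nn.Linear` layers named like the ones in this section and roughly unit-variance hidden states. The sizes are arbitrary; the point is to compare the variance of decompressed keys with and without the $1/\sqrt{d_c}$ factor.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_c, n_h, d_h = 1024, 256, 8, 64

W_DKV = nn.Linear(d_model, d_c, bias=False)
W_UK = nn.Linear(d_c, n_h * d_h, bias=False)
W_UV = nn.Linear(d_c, n_h * d_h, bias=False)

with torch.no_grad():
    nn.init.xavier_uniform_(W_DKV.weight)               # Xavier on (d_model, d_c)
    for up in (W_UK, W_UV):                              # std 1/sqrt(d_c) entries
        nn.init.normal_(up.weight, mean=0.0, std=1.0 / math.sqrt(d_c))

    h = torch.randn(4096, d_model)                       # unit-variance hidden states
    c_KV = W_DKV(h)
    k_scaled = W_UK(c_KV)

    # Counterfactual: unit-variance up-projection entries, no 1/sqrt(d_c) factor.
    k_unscaled = c_KV @ torch.randn(d_c, n_h * d_h)

print(f"var(c_KV)={c_KV.var():.2f}  var(K)={k_scaled.var():.2f}  "
      f"var(K, unscaled)={k_unscaled.var():.1f}")
```

With these sizes the scaled keys keep roughly the latent's variance, while the unscaled ones come out larger by a factor of about $d_c$ (a standard deviation off by $\sqrt{d_c}$).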

Mixed precision

  • Run matmuls in bf16 (preferred) or fp16. Run the softmax and the rope angle accumulation in fp32 — both are sensitive to dynamic range.
  • Cache $c_{KV}$ and $k^R$ in the SAME dtype you intend to attend in. Mixed-dtype caches force a dtype-cast kernel on every decode step.
  • The absorbed $W^{QK}$ can be stored in bf16 even if you trained in fp32: the precision floor at decode is the dtype of the cache, so going higher buys you nothing and costs bandwidth. Do the fusion itself in fp32 and cast the result, so the stored matrix is rounded only once.
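A sketch of producing a bf16 $W^{QK}$ along those lines: run the fusion einsum in fp32 and cast only the result, so the stored matrix picks up a single rounding step. `fuse_qk_fp32` is a hypothetical helper; the per-head layouts match the ones `absorb` builds, and the random bf16 inputs stand in for checkpoint weights.

```python
import torch

def fuse_qk_fp32(W_Q: torch.Tensor, W_UK: torch.Tensor,
                 out_dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    """W_Q: (n_h, d_h, d_m), W_UK: (n_h, d_h, d_c) -> W_QK: (n_h, d_m, d_c)."""
    W_QK = torch.einsum("hdm,hdc->hmc", W_Q.float(), W_UK.float())  # fp32 accumulate
    return W_QK.to(out_dtype)                                         # round once

torch.manual_seed(0)
W_Q = torch.randn(4, 16, 64).bfloat16()       # stand-in bf16 checkpoint weights
W_UK = torch.randn(4, 16, 32).bfloat16()
W_QK = fuse_qk_fp32(W_Q, W_UK)

ref = torch.einsum("hdm,hdc->hmc", W_Q.double(), W_UK.double())
print(W_QK.dtype, (W_QK.double() - ref).abs().max().item())
```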

RoPE gotchas

  • Make sure $d_h^R$ is even: RoPE rotates in pairs. An odd dim will silently drop the last component.
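  • The shared $k^R$ must be rotated by its absolute position and dot-producted against $q^R$, which is also rotated by its absolute position. Caching the already-rotated $k^R$ is correct for plain RoPE and common in production, but it bakes the position in at write time: any scheme that changes positions or rotation frequencies after the write (sliding-window position reassignment, RoPE rescaling for context extension) needs the unrotated key, rotated at read time as in the decode step above. Get this wrong and the bug is invisible until eval.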
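  • When extending context past the original max_seq, regenerate the cos/sin tables before the first long forward pass. Slicing a PyTorch tensor past its end does not raise; it silently returns a truncated (possibly empty) slice, and the failure surfaces later as a confusing shape error or wrong output far from its cause.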
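The first two bullets are easy to check with a minimal rotate-half RoPE. This is a sketch, not this section's `apply_rope`: the name `rope`, the base of 10000, and the rotate-half pairing are assumptions. It rejects odd dimensions up front, and it shows that rotating $q$ and $k$ by their absolute positions yields a score that depends only on the position offset, which is why both sides must use consistent absolute positions.

```python
import torch

def rope(x: torch.Tensor, pos: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Rotate-half RoPE. x: (..., T, d) with d even; pos: (T,) absolute positions."""
    d = x.shape[-1]
    if d % 2:
        raise ValueError(f"RoPE needs an even dim, got {d}")
    inv_freq = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)  # (d/2,)
    ang = pos.to(torch.float64)[:, None] * inv_freq[None, :]   # (T, d/2), angles in fp64
    cos, sin = torch.cos(ang).to(x.dtype), torch.sin(ang).to(x.dtype)
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

torch.manual_seed(0)
q = torch.randn(1, 64, dtype=torch.float64)
k = torch.randn(1, 64, dtype=torch.float64)
s_a = (rope(q, torch.tensor([10])) * rope(k, torch.tensor([7]))).sum()     # offset 3
s_b = (rope(q, torch.tensor([503])) * rope(k, torch.tensor([500]))).sum()  # offset 3
print(torch.allclose(s_a, s_b))   # True: the score depends only on the offset
```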

Cache layout

  • Store $c_{KV}$ as (layer, B, T, d_c) contiguous in the last dim: that is the layout decode reads in.
  • For paged caches, keep pages aligned to a multiple of the warp size on your accelerator (32 or 64 tokens per page is typical).
  • The cache for $k^R$ is small. Do not bother paging it separately; just keep it contiguous next to $c_{KV}$.
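A minimal sketch of those layout bullets, assuming one preallocated bf16 buffer per model; the sizes are arbitrary and `write_token` is a hypothetical helper. Each token row stores $c_{KV}$ and $k^R$ side by side, and the two caches are views into that buffer, so both are unit-stride in the last dim and $k^R$ sits right next to $c_{KV}$ in memory.

```python
import torch

n_layers, B, T_max, d_c, d_rope = 4, 2, 1024, 512, 64

# One buffer per model: each token row holds [c_KV | k_rope] side by side.
kv_buf = torch.zeros(n_layers, B, T_max, d_c + d_rope, dtype=torch.bfloat16)
c_KV_cache = kv_buf[..., :d_c]        # view, (layer, B, T, d_c), unit stride in last dim
k_rope_cache = kv_buf[..., d_c:]      # view, (layer, B, T, d_rope)

def write_token(layer: int, pos: int, c_kv: torch.Tensor, k_rope: torch.Tensor) -> None:
    """Write one decode step's entries; c_kv: (B, d_c), k_rope: (B, d_rope)."""
    kv_buf[layer, :, pos, :d_c] = c_kv        # cast to the cache dtype on copy
    kv_buf[layer, :, pos, d_c:] = k_rope

write_token(0, 5, torch.ones(B, d_c), torch.full((B, d_rope), 2.0))
print(c_KV_cache.stride()[-1], k_rope_cache[0, 0, 5, 0].item())  # 1 2.0
```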

When MLA does NOT help

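  • Short context, small batch: the cache is small enough that reading it is not the bottleneck (weight reads dominate each decode step), so you pay for the extra up-projection and latent-space compute for no bandwidth payoff. GQA is a better fit below ~4K context.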
  • Training from scratch on tiny models: the latent bottleneck slightly degrades language-modeling loss until the model is large enough to learn good compression directions. The break-even is roughly $d_{\text{model}} \geq 1024$ in practice.
  • Caches that are rewound often (e.g., speculative decoding, where rejected draft tokens are rolled back): MLA still works, but the absorbed path's fixed per-step overhead becomes a larger fraction of total time.

Summary

  1. MLA replaces per-token $(K, V)$ storage with a single shared latent $c_{KV} \in \mathbb{R}^{d_c}$ plus a small shared rope key $k^R \in \mathbb{R}^{d_h^R}$. The implementation cost is one extra Linear per attention block.
  2. The training forward pass rebuilds $K$ and $V$ from the latent every step. That is fine: prefill is compute-bound, and the rebuild is cheap next to the attention matmul.
  3. The inference forward pass should absorb $W^{UK}$ into $W^Q_{\text{content}}$ and $W^{UV}$ into $W^O$ at load time. Decode then scores and contextualizes entirely in latent space; $K_{\text{content}}$ and $V$ never appear in HBM.
  4. Decode is bandwidth-bound, so to first order the size of the cache sets the latency: a ~57× smaller cache means up to ~57× more concurrent users at long context on the same accelerator.
  5. Decoupled RoPE is the structural fix that lets the absorption identity survive position encoding. Keep the rotated subspace small and shared across heads — it is the only part of the cache that does not compress.

The next chapter (MoE) shows the other DeepSeek-V3 lever: replace the dense FFN with a sparsely-activated expert mixture, so the parameter count scales independently of compute per token. MLA shrinks the cache; MoE shrinks the compute. Used together they are how a 671B-parameter model fits on a handful of GPUs and serves long-context users at interactive latency.
