From Theory to Code
Sections 4.3 and 4.4 derived MLA on paper — joint low-rank compression of the KV cache, decoupled RoPE for position, and the algebraic identity that lets us absorb the up-projection into the query matrix at inference. This section turns that derivation into code you can train and serve.
We are going to build the layer in three passes. First a minimal NumPy version where every intermediate is a real matrix the reader can inspect. Then a complete nn.Module with decoupled RoPE wired in, written the way you would actually ship it. Finally an inference-only fused path that exploits the absorption trick to delete two of the largest tensors in the decode hot loop.
What you should already know
The MLA latent c_KV, the two up-projections W_UK and W_UV, and the decoupled-RoPE head split into a content part of dim d_h_content and a shared rope part of dim d_rope. If those symbols look unfamiliar, read sections 4.3 and 4.4 first; this one assumes the math.
The Shape Contract
Every implementation bug in attention code is, at heart, a shape bug. Before any keystroke, pin down the contract for every tensor that crosses a function boundary. If you cannot fill this table from memory for your own layer, you will spend tomorrow debugging einsums.
| Tensor | Shape | Lives where | Notes |
|---|---|---|---|
| h (hidden state input) | (B, T_new, d_model) | per-step input | T_new = 1 in decode, T in prefill |
| c_KV (latent) | (B, T, d_c) | KV cache (persistent) | the only large cache write |
| k_rope (shared rope key) | (B, T, d_rope) | KV cache (persistent) | shared across heads — not per-head |
| Q_content | (B, T_new, n_heads, d_h_content) | transient | rebuilt every step from h |
| Q_rope | (B, T_new, n_heads, d_rope) | transient | rebuilt every step from h |
| K_content (decompressed) | (B, T, n_heads, d_h_content) | transient — training path | FUSED AWAY at inference |
| V (decompressed) | (B, T, n_heads, d_h_content) | transient — training path | FUSED AWAY at inference |
| attn weights | (B, n_heads, T_new, T) | transient | softmax over last dim |
| y (output) | (B, T_new, d_model) | passed to next layer | feeds the residual |
Read the third column carefully. Only c_KV and k_rope are persistent. Everything else is a transient that lives for the duration of one forward pass. The whole MLA story is moving as much of the work into the "transient" column as possible.
Minimal MLA in Plain NumPy
The smallest possible MLA forward pass (no RoPE, no cache, no batching tricks) fits on one screen of NumPy. Build it first, run it, print the shapes, and make sure you can say what each line is actually doing.
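A minimal sketch of such a forward pass, under stated assumptions: one sequence, toy sizes (d_model = 16, n_heads = 2, d_h = 4, d_c = 4), random weights, and a fixed seed. The weight names follow the shape contract; the sizes, the `init` helper, and the seed are illustrative choices, not values from the text.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)      # subtract the row max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def mla_forward(h, W_DKV, W_UK, W_UV, W_Q, W_O, n_heads):
    """MLA forward for one sequence: no RoPE, no cache, no batch dimension."""
    T = h.shape[0]
    d_h = W_Q.shape[1] // n_heads
    c_KV = h @ W_DKV                                   # (T, d_c): all a cache would store
    K = (c_KV @ W_UK).reshape(T, n_heads, d_h)         # decompressed keys
    V = (c_KV @ W_UV).reshape(T, n_heads, d_h)         # decompressed values
    Q = (h @ W_Q).reshape(T, n_heads, d_h)
    scores = np.einsum("tnd,snd->nts", Q, K) / np.sqrt(d_h)    # (n_heads, T, T)
    future = np.triu(np.ones((T, T), dtype=bool), k=1)         # True where key comes after query
    scores = np.where(future, -np.inf, scores)                 # causal mask
    attn = softmax(scores)
    ctx = np.einsum("nts,snd->tnd", attn, V).reshape(T, n_heads * d_h)
    return ctx @ W_O, c_KV


rng = np.random.default_rng(0)
T, d_model, n_heads, d_h, d_c = 5, 16, 2, 4, 4


def init(rows, cols):
    # Illustrative 1/sqrt(fan_in) scaling so activations stay O(1).
    return rng.normal(size=(rows, cols)) / np.sqrt(rows)


W_DKV, W_UK, W_UV = init(d_model, d_c), init(d_c, n_heads * d_h), init(d_c, n_heads * d_h)
W_Q, W_O = init(d_model, n_heads * d_h), init(n_heads * d_h, d_model)
h = rng.normal(size=(T, d_model))

y, c_KV = mla_forward(h, W_DKV, W_UK, W_UV, W_Q, W_O, n_heads)
print("y:", y.shape, "c_KV:", c_KV.shape)
```

Every intermediate here is a plain array, so you can print K, V, or attn at any point and compare its shape against the contract table.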
Notice what is missing. There is no position information yet: a token's key is the same whether the token is at position 1 or position 10,000. That is the gap decoupled RoPE will fill. And we recompute K and V from c_KV every step, which is fine for training and wasteful at decode. The absorption trick further down deletes both of those rebuilds.
What you can do with this code
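- Set d_c to full rank, meaning d_c ≥ min(d_model, n_heads · d_h), and verify the output matches a plain MHA you write next to it whose key and value matrices are the products W_DKV · W_UK and W_DKV · W_UV. At full rank, MLA degenerates into MHA up to a basis change.
- Shrink d_c to 1 and watch the attention collapse onto a single direction: every key now lives on the same line through the origin, so each score is just the scalar latent of the cached token times one number derived from the query, and attention can only rank tokens along that single axis.
- Print the reconstruction error of K against the MHA keys to confirm it is zero at full rank; the loss appears only when d_c drops below full rank.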
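These experiments can be tried without any training. The sketch below assumes a hypothetical MHA key matrix W_K and builds the MLA pair by truncating its SVD. That is one convenient way to get a rank-d_c factorization for a demonstration; it is not how trained MLA weights arise. Sizes and the seed are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, n_heads, d_h = 6, 8, 2, 4
h = rng.normal(size=(T, d_model))
W_K = rng.normal(size=(d_model, n_heads * d_h))    # hypothetical MHA key projection
K_mha = h @ W_K

# Factor W_K once; keeping the top d_c singular directions gives a rank-d_c pair.
U, S, Vt = np.linalg.svd(W_K, full_matrices=False)


def mla_keys(d_c):
    W_DKV = U[:, :d_c] * S[:d_c]        # (d_model, d_c) down-projection
    W_UK = Vt[:d_c]                     # (d_c, n_heads * d_h) up-projection
    c_KV = h @ W_DKV                    # the latent that would be cached
    return c_KV @ W_UK                  # decompressed keys


errors = {d_c: np.linalg.norm(mla_keys(d_c) - K_mha) for d_c in (8, 4, 2, 1)}
for d_c, err in errors.items():
    print(f"d_c={d_c}: key reconstruction error {err:.3e}")

# At d_c = 1 every key is a scalar multiple of one vector, so K has rank 1.
rank_at_one = np.linalg.matrix_rank(mla_keys(1))
print("rank of K at d_c=1:", rank_at_one)
```

The error at d_c = 8 sits at floating-point noise because 8 = min(d_model, n_heads · d_h) is full rank for this toy; it grows as d_c shrinks.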
The pipeline, visualized
The pipeline below is the same arithmetic you just read, run on a slightly larger toy. Pick a token, slide d_c, and watch the reconstruction quality of K degrade as the latent gets squeezed.
The reconstruction never has to be perfect. Training co-adapts W_DKV and the two up-projections to the downstream attention pattern: the model learns to compress in directions that matter to the softmax, not to minimize the L2 error of the keys. That is why DeepSeek-V3 can push d_c = 512 against an MHA-equivalent key width of 128 · 128 = 16,384 without measurable quality loss.
PyTorch: The Module Skeleton
Translating the NumPy version into PyTorch is mostly bookkeeping. Wrap the matrices in nn.Linear so they are tracked by nn.Module and autograd, hand off the softmax to F.softmax, and let einsum or @ do the per-head batching.
We jump straight to the production version below — but you should understand what the skeleton looks like before all the RoPE and cache plumbing arrives. In short:
```python
import torch.nn as nn
import torch.nn.functional as F


class MLA_Skeleton(nn.Module):
    def __init__(self, d_model, n_heads, d_h, d_c):
        super().__init__()
        self.n_heads, self.d_h, self.d_c = n_heads, d_h, d_c
        self.W_DKV = nn.Linear(d_model, d_c, bias=False)          # h -> latent
        self.W_UK = nn.Linear(d_c, n_heads * d_h, bias=False)     # latent -> keys
        self.W_UV = nn.Linear(d_c, n_heads * d_h, bias=False)     # latent -> values
        self.W_Q = nn.Linear(d_model, n_heads * d_h, bias=False)
        self.W_O = nn.Linear(n_heads * d_h, d_model, bias=False)

    def forward(self, h):                                         # h: (B, T, d_model)
        c_KV = self.W_DKV(h)                                      # <-- the cache write
        K = self.W_UK(c_KV).view(*h.shape[:-1], self.n_heads, self.d_h)
        V = self.W_UV(c_KV).view(*h.shape[:-1], self.n_heads, self.d_h)
        Q = self.W_Q(h).view(*h.shape[:-1], self.n_heads, self.d_h)
        # Scaled dot-product attention with a causal mask; heads move to dim 1.
        ctx = F.scaled_dot_product_attention(
            Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2), is_causal=True)
        y = self.W_O(ctx.transpose(1, 2).flatten(-2))             # merge heads -> (B, T, d_model)
        return y, c_KV
```

That is the entire scaffold. Every additional line below exists for exactly one of three reasons: handling the KV cache across multiple forward passes, integrating decoupled RoPE without breaking the absorption trick, or wiring up the inference fast path.
RoPE: The Implementation Problem
Section 4.4 explained why RoPE has to be decoupled from the latent. This is what it looks like as an implementation constraint, before we know the fix.
The naive thing to do is rotate K right after decompressing it from the latent:
```python
# WRONG: naive attempt to add RoPE on top of vanilla MLA.
# This silently breaks the cache invariant.

# Step 1: build full K from the latent (as in plain MLA)
K = self.W_UK(c_KV).view(B, T, self.n_heads, self.d_h)

# Step 2: rotate K by absolute position
K = apply_rope(K, cos_all, sin_all)   # <-- the bug lives here

# Why this is wrong:
# RoPE inserts position into K. If we cache c_KV but reconstruct K every
# step and re-rotate, the rotation itself is fine. But now we can NO LONGER
# absorb W_UK into W_Q, because Q @ rotate(c_KV @ W_UK).T cannot be rewritten
# as (Q @ M) @ c_KV.T with one position-independent matrix M: the rotation
# sits between W_Q and W_UK and does not commute with W_UK. So we lose the
# absorption trick AND still pay d_h_content per head per token of rotation
# work. We get the worst of both worlds.
```

This compiles and runs. It even trains. What it quietly destroys is the absorption identity: once a rotation matrix lives between W_Q and W_UK, you can no longer fuse them. At inference you are then forced to materialize the full K every decode step, which is exactly the cost MLA exists to avoid.
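To see the failure numerically, the sketch below writes RoPE as an explicit rotation matrix and forms the matrix that would have to be precomputed for absorption. It assumes a single head, row-vector conventions (q = h W_Q, k = c W_UK), and toy sizes; `rope_matrix` and `fused_with_rope` are hypothetical helpers written for this illustration.

```python
import numpy as np


def rope_matrix(pos, d, base=10000.0):
    """Explicit (d, d) block-diagonal RoPE rotation for one position."""
    R = np.zeros((d, d))
    for i in range(d // 2):
        theta = pos * base ** (-2.0 * i / d)
        c, s = np.cos(theta), np.sin(theta)
        R[2 * i:2 * i + 2, 2 * i:2 * i + 2] = [[c, -s], [s, c]]
    return R


rng = np.random.default_rng(0)
d_model, d_c, d_h = 8, 3, 4
W_Q = rng.normal(size=(d_model, d_h))
W_UK = rng.normal(size=(d_c, d_h))

# Without RoPE, the matrix between h_t and c_j is a constant: precompute it once.
fused_plain = W_Q @ W_UK.T


def fused_with_rope(t, j):
    # (h_t W_Q R_t)(c_j W_UK R_j)^T = h_t [W_Q R_t R_j^T W_UK^T] c_j^T
    return W_Q @ rope_matrix(t, d_h) @ rope_matrix(j, d_h).T @ W_UK.T


same_offset = np.allclose(fused_with_rope(5, 3), fused_with_rope(9, 7))   # both t - j = 2
other_offset = np.allclose(fused_with_rope(5, 3), fused_with_rope(5, 4))  # t - j = 2 vs 1
print("same matrix at equal offsets:   ", same_offset)
print("same matrix at different offset:", other_offset)
```

With rotation in the path, the "fused" matrix depends on the query-key offset, so there is no single matrix to fold into W_Q at load time.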
The structural fix
Decoupled RoPE splits every head into two subspaces:
- A content subspace of dim d_h_content that is fed by the latent c_KV and stays rotation-free, so absorption still works on it.
- A RoPE subspace of dim d_rope with its own dedicated projection from h, rotated by absolute position, and shared across all heads on the key side so the per-token cache footprint stays d_c + d_rope instead of d_c + n_heads · d_rope.
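The two subspaces combine at score time by simple addition. The sketch below is one way to write that, assuming interleaved (even, odd) rotation pairs, random stand-ins for the projected tensors, and a scale of sqrt(d_h_content + d_rope); `apply_rope` is a helper defined here, and all sizes are illustrative.

```python
import numpy as np


def apply_rope(x, positions, base=10000.0):
    """Rotate (even, odd) pairs of the last dim by position-dependent angles.

    x: (T, ..., d_rope) with d_rope even; positions: (T,) absolute positions.
    """
    d = x.shape[-1]
    assert d % 2 == 0, "RoPE rotates in pairs, so d_rope must be even"
    inv_freq = base ** (-np.arange(0, d, 2) / d)                      # (d/2,)
    ang = positions[:, None] * inv_freq                               # (T, d/2)
    ang = ang.reshape(ang.shape[0], *([1] * (x.ndim - 2)), d // 2)    # broadcast over heads
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


rng = np.random.default_rng(0)
T, n_heads, d_h_content, d_rope, d_c = 4, 2, 4, 2, 3
pos = np.arange(T)

q_content = rng.normal(size=(T, n_heads, d_h_content))
k_content = rng.normal(size=(T, n_heads, d_h_content))              # from c_KV, never rotated
q_rope = apply_rope(rng.normal(size=(T, n_heads, d_rope)), pos)     # per-head rope query
k_rope = apply_rope(rng.normal(size=(T, d_rope)), pos)              # ONE rope key for all heads

# Content score plus rope score, scaled by the full per-head width.
scores = (np.einsum("tnd,snd->nts", q_content, k_content)
          + np.einsum("tnd,sd->nts", q_rope, k_rope)) / np.sqrt(d_h_content + d_rope)

cache_shared = d_c + d_rope                 # scalars per token per layer, shared rope key
cache_per_head = d_c + n_heads * d_rope     # what a per-head rope key would cost
print(scores.shape, cache_shared, cache_per_head)
```

Note that k_rope has no head axis: the second einsum broadcasts the same rotated key against every head's query, which is what keeps the cache term at d_rope rather than n_heads · d_rope.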
The visual below shows the head split, the rotation on the small RoPE subspace, and how the per-token cache changes as you slide the dials. Drop d_rope to zero and watch the rotation panel disappear; raise d_c and watch the cache grow linearly.
The Complete MLA Layer
Here is the full PyTorch layer. It is roughly 100 lines, handles prefill and decode in the same forward(), applies decoupled RoPE on the small shared subspace, and returns a clean cache tuple you can thread through your generation loop.
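What follows is a compact sketch of such a layer rather than a tuned production implementation. Assumptions to flag: the query is projected directly from h (no query compression), the rope projections are named W_KR and W_QR here for illustration, the cache holds k_rope before rotation and rotates it on read, scores are scaled by sqrt(d_h_content + d_rope), and norms, dropout, and padding masks are left out.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


def rope_tables(max_seq, d_rope, base=10000.0):
    """cos/sin tables of shape (max_seq, d_rope // 2), built in fp32."""
    inv_freq = base ** (-torch.arange(0, d_rope, 2, dtype=torch.float32) / d_rope)
    ang = torch.arange(max_seq, dtype=torch.float32)[:, None] * inv_freq
    return ang.cos(), ang.sin()


def apply_rope(x, cos, sin):
    """Rotate (even, odd) pairs. x: (B, T, ..., d_rope); cos/sin: (T, d_rope // 2)."""
    shape = (1, cos.shape[0]) + (1,) * (x.dim() - 3) + (cos.shape[1],)
    cos, sin = cos.reshape(shape), sin.reshape(shape)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.stack((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)
    return out.flatten(-2).to(x.dtype)


class MLA(nn.Module):
    def __init__(self, d_model, n_heads, d_h_content, d_rope, d_c, max_seq=4096):
        super().__init__()
        assert d_rope % 2 == 0, "RoPE rotates in pairs"
        self.n_heads, self.d_hc, self.d_rope = n_heads, d_h_content, d_rope
        self.W_DKV = nn.Linear(d_model, d_c, bias=False)                  # h -> c_KV
        self.W_KR = nn.Linear(d_model, d_rope, bias=False)                # h -> shared k_rope
        self.W_UK = nn.Linear(d_c, n_heads * d_h_content, bias=False)     # c_KV -> K_content
        self.W_UV = nn.Linear(d_c, n_heads * d_h_content, bias=False)     # c_KV -> V
        self.W_Q = nn.Linear(d_model, n_heads * d_h_content, bias=False)  # h -> Q_content
        self.W_QR = nn.Linear(d_model, n_heads * d_rope, bias=False)      # h -> Q_rope
        self.W_O = nn.Linear(n_heads * d_h_content, d_model, bias=False)
        cos, sin = rope_tables(max_seq, d_rope)
        self.register_buffer("cos", cos, persistent=False)
        self.register_buffer("sin", sin, persistent=False)

    def forward(self, h, cache=None):
        B, T_new, _ = h.shape
        c_KV, k_rope = self.W_DKV(h), self.W_KR(h)            # the only cache writes
        if cache is not None:                                 # decode: append to the past
            c_KV = torch.cat([cache[0], c_KV], dim=1)
            k_rope = torch.cat([cache[1], k_rope], dim=1)
        T = c_KV.shape[1]
        if T > self.cos.shape[0]:
            raise ValueError("sequence exceeds the RoPE table; rebuild with a larger max_seq")
        past = T - T_new

        # Content path: rotation-free, decompressed from the latent.
        Q_c = self.W_Q(h).view(B, T_new, self.n_heads, self.d_hc)
        K_c = self.W_UK(c_KV).view(B, T, self.n_heads, self.d_hc)
        V = self.W_UV(c_KV).view(B, T, self.n_heads, self.d_hc)
        # RoPE path: per-head query, one key shared by every head.
        Q_r = self.W_QR(h).view(B, T_new, self.n_heads, self.d_rope)
        Q_r = apply_rope(Q_r, self.cos[past:T], self.sin[past:T])
        K_r = apply_rope(k_rope, self.cos[:T], self.sin[:T])

        scores = (torch.einsum("bqnd,bknd->bnqk", Q_c, K_c)
                  + torch.einsum("bqnd,bkd->bnqk", Q_r, K_r)) / math.sqrt(self.d_hc + self.d_rope)
        # Causal mask: the query at absolute position past + i sees keys 0..past + i.
        q_pos = torch.arange(past, T, device=h.device)[:, None]
        k_pos = torch.arange(T, device=h.device)[None, :]
        scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
        attn = F.softmax(scores.float(), dim=-1).to(V.dtype)  # softmax in fp32
        ctx = torch.einsum("bnqk,bknd->bqnd", attn, V).reshape(B, T_new, -1)
        return self.W_O(ctx), (c_KV, k_rope)


torch.manual_seed(0)
layer = MLA(d_model=32, n_heads=4, d_h_content=8, d_rope=4, d_c=8)
x = torch.randn(2, 6, 32)
y_full, _ = layer(x)                              # one-shot prefill over 6 tokens
_, cache = layer(x[:, :5])                        # prefill 5 tokens ...
y_dec, cache = layer(x[:, 5:], cache=cache)       # ... then decode the 6th from the cache
```

The last three lines are the basic sanity check for any cached attention layer: decoding token 6 from the cache should agree with position 6 of a one-shot forward pass, up to floating-point error.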
This is the training path. It is correct, but at decode time it rebuilds K and V on every step, which means we read c_KV from HBM, run two matmuls against W_UK and W_UV, and write the result back out before attention can even start. The next two sections remove that overhead for inference.
Prefill vs Decode: Two Code Paths
A serving system runs the same MLA layer in two very different regimes. Treating them identically is correct; treating them identically is also a 5–20× performance loss.
| | Prefill | Decode |
|---|---|---|
| T_new | T (full prompt) | 1 (single token) |
| Cache state | None at layer entry | (c_KV_past, k_rope_past) of length T−1 |
| Dominant cost | Compute (matmuls scale with T²) | Bandwidth (read full cache per step) |
| Optimal strategy | Fused attention kernel, large tile sizes | Absorbed projections, minimum HBM traffic |
| Arithmetic intensity | High — compute-bound | Low (~1 FLOP/byte) — bandwidth-bound |
Prefill happily uses the layer above: the matmuls are huge, the T² scaling of attention dominates the cache read, and the cost of reconstructing K and V from the latent is negligible compared to the attention matmul itself.
Decode is a different animal. Each step processes ONE token. Compute per step is tiny. What is not tiny is the cache: at context length T you must read all T cached entries once per layer per step. For a 70B-parameter model with a conventional per-head cache at long context, that is multiple gigabytes of HBM traffic per generated token. The absorption trick exists for exactly this regime.
The Absorption Trick — Where the Inference Speedup Lives
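Recall the content-side score for one head $i$, between the query at step $t$ and the cached token $j$. Writing hidden states and latents as row vectors, with $W_Q^{(i)} \in \mathbb{R}^{d_{\text{model}} \times d_h}$ and $W_{UK}^{(i)} \in \mathbb{R}^{d_c \times d_h}$:

$$s_{t,j}^{(i)} = q_t^{(i)} \, k_j^{(i)\top} = \left(h_t W_Q^{(i)}\right)\left(c^{KV}_j W_{UK}^{(i)}\right)^{\top}$$

By associativity:

$$s_{t,j}^{(i)} = h_t \left(W_Q^{(i)} W_{UK}^{(i)\top}\right) c^{KV\top}_j$$

Same dot product, different parenthesization. The fused matrix $W_Q^{(i)} W_{UK}^{(i)\top} \in \mathbb{R}^{d_{\text{model}} \times d_c}$ depends only on the trained weights, so pre-compute it once at load time. At decode time you compute the query directly in the latent space and dot it against c_KV; the full-width K tensor is never allocated.

The same identity applies on the value side. With attention weights $a_{t,j}^{(i)}$ and the per-head block $W_O^{(i)} \in \mathbb{R}^{d_h \times d_{\text{model}}}$ of the output projection, the output of the attention block can be written:

$$y_t = \sum_i \Big(\sum_j a_{t,j}^{(i)} \, c^{KV}_j W_{UV}^{(i)}\Big) W_O^{(i)} = \sum_i \Big(\sum_j a_{t,j}^{(i)} \, c^{KV}_j\Big) \left(W_{UV}^{(i)} W_O^{(i)}\right)$$

Fuse W_UV with the output projection W_O and the value path also stays in latent space until the very last step.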
The whole point in one sentence
Absorption converts a decompress-then-attend pipeline into an attend-in-latent-space pipeline, removing two of the three largest tensors that ever touch HBM during decode while computing exactly the same scores and exactly the same output.
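The claim is easy to check numerically. The sketch below runs one decode step both ways on random weights and compares the outputs. It assumes the content path only (no RoPE term), row-vector conventions, per-head weight tensors with a leading head axis, and illustrative sizes; the names W_QK and W_VO for the fused matrices are introduced here.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
T, d_model, n_heads, d_h, d_c = 7, 16, 4, 8, 6
W_Q = 0.3 * rng.normal(size=(n_heads, d_model, d_h))    # per-head query projection
W_UK = 0.3 * rng.normal(size=(n_heads, d_c, d_h))       # per-head key up-projection
W_UV = 0.3 * rng.normal(size=(n_heads, d_c, d_h))       # per-head value up-projection
W_O = 0.3 * rng.normal(size=(n_heads, d_h, d_model))    # per-head block of the output projection
c_KV = rng.normal(size=(T, d_c))      # cached latents, newest token last
h_t = rng.normal(size=(d_model,))     # hidden state of the token being decoded


def decode_unfused():
    """Decompress K and V to full width, then attend."""
    K = np.einsum("tc,ncd->ntd", c_KV, W_UK)            # (n_heads, T, d_h)
    V = np.einsum("tc,ncd->ntd", c_KV, W_UV)
    q = np.einsum("m,nmd->nd", h_t, W_Q)                # (n_heads, d_h)
    attn = softmax(np.einsum("nd,ntd->nt", q, K) / np.sqrt(d_h))
    ctx = np.einsum("nt,ntd->nd", attn, V)
    return np.einsum("nd,ndm->m", ctx, W_O)


# Load-time fusion: fold W_UK into W_Q and W_UV into W_O.
W_QK = np.einsum("nmd,ncd->nmc", W_Q, W_UK)             # (n_heads, d_model, d_c)
W_VO = np.einsum("ncd,ndm->ncm", W_UV, W_O)             # (n_heads, d_c, d_model)


def decode_absorbed():
    """Score and mix in latent space; K and V are never built."""
    q_lat = np.einsum("m,nmc->nc", h_t, W_QK)           # query expressed in latent space
    attn = softmax(np.einsum("nc,tc->nt", q_lat, c_KV) / np.sqrt(d_h))
    ctx_lat = attn @ c_KV                               # (n_heads, d_c), still latent
    return np.einsum("nc,ncm->m", ctx_lat, W_VO)


y_unfused, y_absorbed = decode_unfused(), decode_absorbed()
print("max abs difference:", np.abs(y_unfused - y_absorbed).max())
```

The scale stays 1/sqrt(d_h) in both paths, since absorption changes the parenthesization and not the score. Whether to materialize W_QK or to apply W_Q and W_UK one after the other at run time is a separate memory trade-off; the sketch materializes it to mirror the algebra.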
The Decode Loop in Full
Wired together, generation looks like this:
```python
import torch

# Assumed to be defined elsewhere: sample(logits) -> (B,) token ids,
# absorb(attn) -> fused weights, and
# mla_decode_step(attn, fused, h_new, cache) -> (attn_out, new_cache).


@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens=128):
    # ---- 1. Prefill with the unfused (training) path -----------------------
    h = model.embed(prompt_ids)                        # (B, T_prompt, d_model)
    cache_per_layer = [None] * len(model.layers)
    for i, layer in enumerate(model.layers):
        a, cache_per_layer[i] = layer.attn(h, cache=None)
        h = h + a                                      # attention residual (norms omitted)
        h = h + layer.ffn(h)

    next_token = sample(model.lm_head(h[:, -1]))       # (B,)

    # ---- 2. Build absorbed projections once ---------------------------------
    fused = [absorb(layer.attn) for layer in model.layers]

    # ---- 3. Decode loop with the absorbed path ------------------------------
    out_ids = [next_token]
    while len(out_ids) < max_new_tokens and not (next_token == model.eos).all():
        h_new = model.embed(next_token[:, None])       # (B, 1, d_model)
        for i, layer in enumerate(model.layers):
            a, cache_per_layer[i] = mla_decode_step(
                layer.attn, fused[i], h_new, cache_per_layer[i],
            )
            h_new = h_new + a
            h_new = h_new + layer.ffn(h_new)
        next_token = sample(model.lm_head(h_new[:, -1]))
        out_ids.append(next_token)
    return torch.stack(out_ids, dim=1)                 # (B, n_generated)
```

Two passes through the cache, two different code paths, one shared cache format. The prefill writes (c_KV, k_rope) in the natural form; the decode path reads it in latent space and never materializes K or V.
Real serving frameworks (vLLM, SGLang, TGI) replace the torch.cat in the cache stitch with a paged allocator — KV blocks live in fixed-size pages, and a token-to-page table maps logical positions to physical pages. This makes batched decode with variable-length sequences memory-efficient. The MLA layer itself does not change; only the allocator does.
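To make the page-table idea concrete, here is a toy paged cache for the concatenated [c_KV | k_rope] row of each token. It is a sketch under simplifying assumptions (one layer, NumPy storage, no eviction or sharing), and `PagedLatentCache` is a hypothetical class written for this illustration, not the API of vLLM, SGLang, or TGI.

```python
import numpy as np


class PagedLatentCache:
    """Fixed-size pages in one pool, plus a per-sequence page table."""

    def __init__(self, n_pages, page_size, width):
        self.pool = np.zeros((n_pages, page_size, width), dtype=np.float32)
        self.page_size = page_size
        self.free = list(range(n_pages))      # physical pages not yet handed out
        self.tables = {}                      # seq_id -> list of physical page ids
        self.lengths = {}                     # seq_id -> number of tokens stored

    def append(self, seq_id, row):
        """Store one token's row, allocating a new page when the last one is full."""
        n = self.lengths.get(seq_id, 0)
        table = self.tables.setdefault(seq_id, [])
        if n % self.page_size == 0:           # no page yet, or the current page is full
            if not self.free:
                raise MemoryError("page pool exhausted")
            table.append(self.free.pop())
        self.pool[table[n // self.page_size], n % self.page_size] = row
        self.lengths[seq_id] = n + 1

    def gather(self, seq_id):
        """Follow the page table and return the logical (T, width) array."""
        n = self.lengths[seq_id]
        pages = self.pool[self.tables[seq_id]]            # (pages_used, page_size, width)
        return pages.reshape(-1, pages.shape[-1])[:n]

    def release(self, seq_id):
        """Hand a finished sequence's pages back to the pool."""
        self.free.extend(self.tables.pop(seq_id))
        del self.lengths[seq_id]


cache = PagedLatentCache(n_pages=8, page_size=4, width=6)   # width = d_c + d_rope in this toy
for t in range(5):
    cache.append("a", np.full(6, t, dtype=np.float32))
    if t < 2:
        cache.append("b", np.full(6, 10 + t, dtype=np.float32))
print(cache.tables)
```

Sequence "a" ends up spread over two non-adjacent physical pages while "b" occupies one, and neither needs a contiguous reallocation as it grows. A real allocator would gather inside the attention kernel instead of copying rows out as `gather` does here.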
Manual Numerical Walkthrough
One decode step, by hand
We will run a single decode step at position with , , (RoPE off — keep your head free), and .
Cache from prior steps: , .
New hidden state: .
Weights: , , , , .
Step 1 — write cache: . Cache is now .
Step 2 — absorbed query: , so .
Step 3 — scores in latent space:
Step 4 — scale by : .
Step 5 — softmax: , sum = 8.169.
.
Step 6 — value path in latent space: , so the context is just :
.
The check: if you run the unfused path — decompress , compute , do the dot products explicitly, then run the attention sum against — you get the same to floating-point precision. Two different paths, identical answer. That is the entire correctness argument for absorption.
KV Cache at Scale
The implementation above is half of the story. The other half is scale: at a real model size, the absolute size of the KV cache determines whether long-context serving is economic.
Pick a model preset and sweep the sequence length. The growth is linear in T, in batch, in number of layers, and (for MHA/GQA) in the number of KV heads. MLA pulls the KV-head count out of that product entirely: its per-token, per-layer footprint is d_c + d_rope, full stop.
Concrete numbers for DeepSeek-V3 (61 layers, 128 heads, head dim 128, d_c = 512, d_rope = 64) at sequence length 128K (131,072 tokens), batch 1, fp16 (2 bytes per scalar, decimal gigabytes):
| Mechanism | Per token, per layer | Total cache (128K, B=1) |
|---|---|---|
| MHA (full) | 2 · 128 · 128 = 32,768 scalars | ≈ 524 GB |
| GQA (8 KV heads) | 2 · 8 · 128 = 2,048 scalars | ≈ 32.7 GB |
| MLA | 512 + 64 = 576 scalars | ≈ 9.2 GB |
MLA is about 57× smaller than the equivalent MHA cache and about 3.6× smaller than aggressive GQA. The implementation cost is one extra matrix per head fused into W_Q at load time.
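The arithmetic behind such a table is a single product, sketched below so you can plug in your own model. The assumptions are that "128K" means 131,072 tokens, that fp16 costs 2 bytes per scalar, and that gigabytes are decimal (1e9 bytes); a different layer count or reading of "128K" moves the totals by a few percent and leaves the ratios untouched. `cache_bytes` is a helper defined here.

```python
def cache_bytes(scalars_per_token_per_layer, n_layers, seq_len, batch=1, bytes_per_scalar=2):
    """Total KV-cache size in bytes: one product, linear in every factor."""
    return scalars_per_token_per_layer * n_layers * seq_len * batch * bytes_per_scalar


n_layers, n_heads, d_h, d_c, d_rope = 61, 128, 128, 512, 64
seq_len = 128 * 1024          # assumption: "128K" read as 131,072 tokens

per_token = {
    "MHA": 2 * n_heads * d_h,     # K and V for every head
    "GQA-8": 2 * 8 * d_h,         # K and V for 8 shared KV heads
    "MLA": d_c + d_rope,          # one latent plus one shared rope key
}
totals_gb = {name: cache_bytes(s, n_layers, seq_len) / 1e9 for name, s in per_token.items()}

for name, scalars in per_token.items():
    ratio = per_token["MHA"] / scalars
    print(f"{name:6s} {scalars:6d} scalars/token/layer  {totals_gb[name]:8.2f} GB  ({ratio:.1f}x vs MHA)")
```

Because every factor enters linearly, doubling the batch or the context doubles each row, and the ratio between rows depends only on the per-token scalar counts.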
Where the Speed Actually Comes From
It is tempting to attribute MLA's decode speedup to "fewer FLOPs." That is not the right model. Modern accelerators have roughly 200 FLOPs per byte of HBM bandwidth at fp16; attention decode has an arithmetic intensity of roughly 1 FLOP per byte. The hardware spends 199/200 of its potential compute waiting on memory.
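So decode latency is, to first order:

$$t_{\text{step}} \approx \frac{\text{bytes read from HBM per step}}{\text{HBM bandwidth}} = \frac{\text{weight bytes} + \text{cache bytes}}{\text{HBM bandwidth}}$$

At long context and large batch the cache term dominates. Cut the cache, cut the latency, end of story. The simulator below does this calculation for you: pick a model, a batch, a sequence length, and the bandwidth of your accelerator (H100 ≈ 3.35 TB/s, MI300X ≈ 5.3 TB/s, H200 ≈ 4.8 TB/s, B200 ≈ 8 TB/s) and watch the bars move.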
Two things to try in the simulator. First, fix the model and slide the sequence length from 1K to 128K — the MLA bar barely moves on the visual scale, while MHA grows linearly. Second, hold the sequence length fixed and crank up the batch — MLA's win actually grows, because larger batches stretch the cache while the MLA per-token footprint stays tiny.
The other lever that drops out of this view: batch size for serving. Cache bytes per token are constant, so the maximum batch you can fit in HBM is roughly (HBM capacity - weight bytes) / (T · cache bytes per token). MLA's 57× compression over MHA gives you up to 57× more concurrent users on the same hardware; at long context that is the dominant economic factor for an inference service.
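Both back-of-envelope estimates fit in a few lines. The sketch below uses the per-token cache sizes for the DeepSeek-V3 shapes quoted above and the H100 bandwidth figure from the text; the 80 GB device capacity and the 40 GB of resident weights are hypothetical inputs chosen only to make the example concrete, and the model ignores compute, kernel launch overhead, and multi-GPU sharding.

```python
def decode_step_seconds(cache_bytes, hbm_bytes_per_s, weight_bytes=0.0):
    """First-order bandwidth model: every byte is read from HBM once per step."""
    return (cache_bytes + weight_bytes) / hbm_bytes_per_s


def max_batch(hbm_capacity_bytes, weight_bytes, cache_bytes_per_token, seq_len):
    """How many sequences of length seq_len fit in HBM next to the weights."""
    free = hbm_capacity_bytes - weight_bytes
    return max(0, int(free // (cache_bytes_per_token * seq_len)))


N_LAYERS, BYTES_PER_SCALAR = 61, 2
cache_per_token = {                       # bytes per token, summed over all layers
    "MHA": 2 * 128 * 128 * N_LAYERS * BYTES_PER_SCALAR,
    "MLA": (512 + 64) * N_LAYERS * BYTES_PER_SCALAR,
}
seq_len = 32 * 1024
hbm_bandwidth = 3.35e12      # bytes/s, the H100 figure quoted in the text
hbm_capacity = 80e9          # assumption: a single 80 GB device
weight_bytes = 40e9          # hypothetical resident weights, for illustration only

results = {}
for name, per_tok in cache_per_token.items():
    step = decode_step_seconds(per_tok * seq_len, hbm_bandwidth)   # cache term only
    fit = max_batch(hbm_capacity, weight_bytes, per_tok, seq_len)
    results[name] = (step, fit)
    print(f"{name}: cache-only step {step * 1e3:.2f} ms, max batch {fit}")
```

Pass weight_bytes to `decode_step_seconds` to include the weight read; leaving it at zero isolates the part of the latency that the cache format controls.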
Production Notes & Common Bugs
Initialization
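The composition W_DKV · W_UK (and likewise W_DKV · W_UV) has the same statistical role as a single Linear from d_model to n_heads · d_h. Treat the pair as ONE projection for variance scaling. A clean recipe: initialize the down-projection with standard Xavier, then scale the up-projection so that the variance of the composed output matches what a single Xavier-initialized Linear of that shape would produce. Otherwise the variance of K and V is off by a constant factor at step 0 and you spend the first few thousand steps just re-normalizing.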
Mixed precision
- Run matmuls in bf16 (preferred) or fp16. Run the softmax and the rope angle accumulation in fp32 — both are sensitive to dynamic range.
- Cache c_KV and k_rope in the SAME dtype you intend to attend in. Mixed-dtype caches cause expensive dtype-cast kernels on every decode step.
- The absorbed matrices (W_Q fused with W_UK, and W_UV fused with W_O) can be stored in bf16 even if you trained in fp32: the precision floor at decode is the dtype of the cache, so going higher buys you nothing and costs bandwidth.
RoPE gotchas
- Make sure d_rope is even, because RoPE rotates in pairs. With an odd dim, a pairwise implementation can silently drop or mis-pair the last component.
- The shared k_rope must be rotated by absolute position, then dot-producted against q_rope, which is also rotated by absolute position. Do not try to "cache pre-rotated" k_rope: sliding-window or partial-rope schemes will break this, and the bug is invisible until eval.
- When extending context past the original max_seq, regenerate the cos/sin tables before the first long forward pass — silently slicing past the buffer length gives wrong rotations rather than an error.
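One way to turn the first and third gotchas into loud failures is to hide the tables behind a small lookup object. The sketch below is an assumption-laden illustration: `RopeTable` is a hypothetical helper, it rebuilds the table with the same base (plain extension, no frequency rescaling scheme), and it accumulates angles in float64 because NumPy makes that free.

```python
import numpy as np


class RopeTable:
    """cos/sin lookup that rejects odd dims and grows instead of mis-indexing."""

    def __init__(self, d_rope, max_seq, base=10000.0):
        if d_rope % 2 != 0:
            raise ValueError(f"d_rope must be even, got {d_rope}")
        self.d_rope, self.base = d_rope, base
        self._build(max_seq)

    def _build(self, max_seq):
        inv_freq = self.base ** (-np.arange(0, self.d_rope, 2) / self.d_rope)
        angles = np.arange(max_seq, dtype=np.float64)[:, None] * inv_freq
        self.cos, self.sin = np.cos(angles), np.sin(angles)    # (max_seq, d_rope // 2)
        self.max_seq = max_seq

    def lookup(self, positions):
        """Return cos/sin rows for absolute positions, rebuilding if they run past the table."""
        positions = np.asarray(positions)
        if positions.min() < 0:
            raise IndexError("negative positions would wrap around the table")
        if positions.max() >= self.max_seq:
            self._build(int(positions.max()) + 1)              # regenerate before indexing
        return self.cos[positions], self.sin[positions]


table = RopeTable(d_rope=4, max_seq=8)
cos, sin = table.lookup(np.arange(6, 12))     # positions 8..11 are past the original table
print(cos.shape, table.max_seq)
```

Growing on demand is a convenience for experiments; in a serving stack you would more likely size the table once at load time and treat an out-of-range position as a hard error.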
Cache layout
- Store c_KV as (layer, B, T, d_c), contiguous in the last dim; that is the layout decode reads in.
- For paged caches, keep pages aligned to a multiple of the warp size on your accelerator (32 or 64 tokens per page is typical).
- The cache for k_rope is small, so do not bother paging it separately; just keep it contiguous next to c_KV.
When MLA does NOT help
- Short context, small batch — the cache fits in SRAM anyway; you eat the W_UK / W_UV compute for no bandwidth payoff. GQA is a better fit below ~4K context.
- Training from scratch on tiny models — the latent bottleneck slightly degrades language modeling loss until the model is large enough to learn good compression directions. The break-even is roughly in practice.
- Per-token-personal caches (e.g., speculative decoding with rejection that rewinds the cache often) — MLA still works, but the absorption's fixed overhead becomes a larger fraction of total time.
Summary
- MLA replaces per-head, per-token K and V storage with a single shared latent c_KV plus a small shared rope key k_rope. The implementation cost is one extra Linear per attention block.
- The training forward pass rebuilds K and V from the latent every step. That is fine: prefill is compute-bound and the rebuild is dominated by the attention matmul.
- The inference forward pass should absorb W_UK into W_Q and W_UV into W_O at load time. Decode then scores and contextualizes entirely in latent space; K and V never appear in HBM.
- Decode is bandwidth-bound. The size of the cache literally is the latency: ~57× smaller cache means ~57× more concurrent users at long context on the same accelerator.
- Decoupled RoPE is the structural fix that lets the absorption identity survive position encoding. Keep the rotated subspace small and shared across heads — it is the only part of the cache that does not compress.
The next chapter (MoE) shows the other DeepSeek-V3 lever: replace the dense FFN with a sparsely-activated expert mixture, so the parameter count scales independently of compute per token. MLA shrinks the cache; MoE shrinks the compute. Used together they are how a 671B-parameter model fits on a handful of GPUs and serves long-context users at interactive latency.