Sections 4.1 and 4.2 left us in an uncomfortable place. The KV cache, not the parameter count, sets the memory wall of long-context inference. Standard multi-head attention (MHA) stores two full-width tensors per token per layer — keys and values — and the total memory grows linearly in sequence length, number of heads, head dimension, and number of layers. For a 128K-token request through a 128-head, 61-layer model that is several hundred gigabytes per sequence. No commercially viable serving stack can pay that.
Multi-query attention (MQA) and grouped-query attention (GQA) chip away at the problem by sharing K and V across heads. MQA collapses all heads to one shared KV; GQA picks an intermediate number of groups. Both give large constant-factor savings, but they impose a hard floor on the cache: you still spend $2 \cdot n_{\text{heads}}^{kv} \cdot d_h$ floats per token per layer (keys plus values, with $n_{\text{heads}}^{kv}$ the number of distinct KV heads), and pushing that floor lower means giving away expressivity.
The question that defines this section. Can we drop below the GQA/MQA floor — store far fewer numbers per token — without losing the per-head specialisation that gives multi-head attention its modelling power?
Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2 and reused in DeepSeek-V3 and DeepSeek-R1, answers yes — by reframing the cache as a low-rank joint code, not a tensor of keys and values at all.
The Intuition: An Autoencoder for the KV Cache
Look at what a single layer of MHA stores per token: the keys $K_t \in \mathbb{R}^{n_{\text{heads}} \cdot d_h}$ and the values $V_t \in \mathbb{R}^{n_{\text{heads}} \cdot d_h}$. Both are deterministic linear projections of the same hidden state $h_t \in \mathbb{R}^{d_{\text{model}}}$. So we are caching two views of the same source, and we are caching them at full rank, even though SVD analyses of trained $W^K$ and $W^V$ typically show an effective rank much lower than $n_{\text{heads}} \cdot d_h$.
Here is the MLA idea in one sentence: compress $h_t$ once into a small latent vector $c_{KV,t}$, store only that, and reconstruct the keys and values on the fly when attention needs them.
Analogy — the photographer's contact sheet
Imagine an old film photographer storing a contact sheet (a single page with thumbnail images) instead of the full negatives. The contact sheet is tiny, but it contains enough information to recognise any frame and, with a development step, recover a usable print. MLA is the contact sheet for the KV cache: a low-rank $c_{KV}$ per token, plus two learned decoders ($W^{UK}$, $W^{UV}$) that develop it back into $K_t$ and $V_t$ at attention time.
Mathematically this is a learned low-rank factorisation of the would-be key and value projections, with the factor shared between K and V. Two ideas you have already met merge here:
Linear bottleneck: the latent dim $d_c \ll n_{\text{heads}} \cdot d_h$ forces information to compress, the way the middle layer of an autoencoder does.
Joint coding: keys and values are decoded from the same code, exploiting their statistical dependence. Two separate low-rank decompositions would not see that dependence.
From this picture two questions immediately follow: how small can $d_c$ get before performance breaks, and does the up-projection at every attention call destroy the savings? The math below makes both answerable.
Standard MHA: What We Want to Cache Less Of
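To keep the derivation honest, start from the standard MHA cache. Given a token's hidden state $h_t \in \mathbb{R}^{d_{\text{model}}}$, MHA computes per head $i$:
$$q_{t,i} = h_t W_i^Q, \quad k_{t,i} = h_t W_i^K, \quad v_{t,i} = h_t W_i^V, \quad \text{with } W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_h}.$$
At decode step $t$ we must hold the keys and values for all previous tokens, across all heads: $\text{cache}_t^{\text{MHA}} = \{(k_{s,i}, v_{s,i})\}_{s \le t,\; i \le n_{\text{heads}}}$. That is $2 \cdot t \cdot n_{\text{heads}} \cdot d_h$ floats per layer.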
| Mechanism | Floats per token per layer | Comment |
|---|---|---|
| MHA | 2 · n_heads · d_h | Two full-width tensors (K and V), one slice per head. |
| MQA | 2 · d_h | One shared K and V across all heads. Hard floor. |
| GQA (g groups) | 2 · g · d_h | MQA ≤ GQA ≤ MHA. Pick g to trade quality for memory. |
| MLA | d_c (+ d_h^R, see §4.4) | Single latent; d_c smaller than 2 · d_h is allowed. |
Notice the last row: MLA breaks out of the MQA floor entirely. Where MQA must keep at least $2 d_h$ floats, MLA caches one latent of size $d_c$, and nothing forces $d_c \ge 2 d_h$. In DeepSeek-V3, $d_c = 512$ and $d_h = 128$, so MLA actually caches more per token than MQA would (512 floats against $2 d_h = 256$). The difference is that every head still decodes its own keys and values from the shared latent, instead of all heads reading one shared key and value, and that quality difference is the whole reason MLA exists.
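The table's formulas can be evaluated directly. The following is a minimal sketch in plain Python; the function name `kv_floats_per_token` and the choice of 8 GQA groups are illustrative assumptions, not values from DeepSeek-V3.

```python
def kv_floats_per_token(mechanism, n_heads, d_h, n_groups=1, d_c=0, d_rope=0):
    """Floats cached per token per layer, following the table above."""
    formulas = {
        "MHA": 2 * n_heads * d_h,   # full K and V for every head
        "MQA": 2 * d_h,             # one shared K and V
        "GQA": 2 * n_groups * d_h,  # one K and V per group
        "MLA": d_c + d_rope,        # one latent (+ optional positional channel)
    }
    return formulas[mechanism]


# Head count, head dim and latent dim as quoted in the text for DeepSeek-V3.
# n_groups=8 is a hypothetical GQA setting, included only for comparison.
v3 = dict(n_heads=128, d_h=128, d_c=512)
for name in ("MHA", "GQA", "MQA", "MLA"):
    print(f"{name}: {kv_floats_per_token(name, n_groups=8, **v3)} floats")
```

Setting `d_rope` to a non-zero value adds the extra positional channel mentioned in the last table row.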
Step 1: Down-Projection into a Latent
Introduce one new learned matrix:
$$W^{DKV} \in \mathbb{R}^{d_{\text{model}} \times d_c}, \quad \text{with } d_c \ll n_{\text{heads}} \cdot d_h.$$
For every token, compute the latent code once:
$$c_{KV,t} = h_t W^{DKV} \in \mathbb{R}^{d_c}.$$
This is the only quantity that enters the KV cache. The cache size per token, per layer, collapses from $2 n_{\text{heads}} d_h$ floats to $d_c$ floats, a compression ratio of $(2 n_{\text{heads}} d_h)/d_c$. For DeepSeek-V3 numbers that is $(2 \cdot 128 \cdot 128)/512 = 64\times$, before counting any extra positional channel.
Three properties of WDKV matter:
Shared across heads. Unlike $W_i^K$, there is no per-head copy. One encoder, one cache row.
Shared between K and V. The same latent feeds both up-projections below. That is the "joint" in joint compression: it forces the model to find statistics common to keys and values.
Bottleneck rank. $\operatorname{rank}(W^{DKV}) \le d_c$. Every downstream projection that flows through it inherits that rank ceiling.
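The rank ceiling can be observed numerically. The sketch below uses random weights and small illustrative sizes (none of them taken from a real model); `W_UK` stands in for any projection applied after the down-projection.

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative sizes only; d_c is deliberately far below n_heads * d_h.
T, d_model, n_heads, d_h, d_c = 6, 16, 4, 4, 3

H = rng.standard_normal((T, d_model))              # hidden states
W_DKV = rng.standard_normal((d_model, d_c))        # shared down-projection
W_UK = rng.standard_normal((d_c, n_heads * d_h))   # a downstream projection

c_KV = H @ W_DKV          # (T, d_c): the only per-token quantity cached
W_K_eff = W_DKV @ W_UK    # effective (d_model, n_heads * d_h) key projection

rank = np.linalg.matrix_rank(W_K_eff)
ratio = (2 * n_heads * d_h) / d_c

print("cache shape:", c_KV.shape)
print("rank of effective key projection:", rank, "(ceiling d_c =", d_c, ")")
print(f"compression ratio vs full K and V: {ratio:.2f}x")
```

However wide the effective key projection is, its rank cannot exceed `d_c`, because every key passes through the `d_c`-wide latent.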
Step 2: Up-Projection at Forward Time
Attention still needs full-width keys and values to do its job. Introduce two up-projection matrices, one for each:
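$$W^{UK} \in \mathbb{R}^{d_c \times n_{\text{heads}} d_h}, \quad W^{UV} \in \mathbb{R}^{d_c \times n_{\text{heads}} d_h}.$$
Reconstruct keys and values from the cached latent:
$$K_t = c_{KV,t} W^{UK}, \quad V_t = c_{KV,t} W^{UV}.$$
Then split each across heads: column block $i$ of $K_t$ is head $i$'s key, and similarly for values. From here, attention is exactly the standard scaled dot-product:
$$a_{t,s,i} = \operatorname{softmax}_{s \le t}\!\left(\frac{q_{t,i}\, k_{s,i}^\top}{\sqrt{d_h}}\right), \quad o_{t,i} = \sum_{s \le t} a_{t,s,i}\, v_{s,i}.$$
Reading this honestly, you should worry. We saved memory on the cache, but now every attention call does an extra $O(t \cdot d_c \cdot n_{\text{heads}} \cdot d_h)$ matmul to recover the keys and values. Has MLA only moved cost from memory to compute? The next subsection answers no.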
Step 3: The Absorption Trick (Why MLA Is Fast)
Write out a single attention score in terms of the latent. For the new query at step $t$ attending to past position $s$, head $i$, and writing $W_i^{UK} \in \mathbb{R}^{d_c \times d_h}$ for the column block of $W^{UK}$ that belongs to head $i$:
$$q_{t,i}\, k_{s,i}^\top = (h_t W_i^Q)(c_{KV,s} W_i^{UK})^\top = h_t\, W_i^Q (W_i^{UK})^\top c_{KV,s}^\top.$$
The two learned matrices $W_i^Q$ and $W_i^{UK}$ are constant once training is done. Define
$$\tilde{W}_i^Q \triangleq W_i^Q (W_i^{UK})^\top \in \mathbb{R}^{d_{\text{model}} \times d_c}.$$
Then at inference the score becomes
$$q_{t,i}\, k_{s,i}^\top = (h_t \tilde{W}_i^Q)\, c_{KV,s}^\top.$$
Read this slowly. The dot product against $k_{s,i}$ has become a dot product against $c_{KV,s}$. The up-projection $W_i^{UK}$ has been absorbed into a modified query matrix $\tilde{W}_i^Q$. At inference we never materialise the full per-head key: we project the query into latent space and dot it against the cached latent directly.
The same trick applies on the value side. The output of head $i$ mixes value vectors weighted by attention probabilities, then passes through the output projection $W_i^O$. The chain $W_i^{UV} W_i^O$ can likewise be pre-multiplied into a single matrix from the latent to the residual stream, so the up-projection of V never runs at decode time either.
What the absorption costs and saves
Cost: we permanently store the absorbed matrices $\tilde{W}_i^Q$ and (analogously) $\tilde{W}_i^{OV}$. Each maps between $d_{\text{model}}$ and $d_c$ per head, where the originals map between $d_{\text{model}}$ and $d_h$, so whenever $d_c > d_h$ (512 against 128 in DeepSeek-V3) they are larger than the matrices they replace. This is a per-layer model-weight overhead, paid once.
Saves: at every decode step, no per-head key matmul, no per-head value matmul, no full-width K and V tensors in memory. The new step computes one $d_{\text{model}} \to d_c$ projection of the query per head, plus a dot product against each of the $t$ cached latents.
At training the absorption is not done. We keep $W^{DKV}, W^{UK}, W^{UV}, W^Q, W^O$ as separate learnable matrices so each gradient has a clear signal. At inference we collapse adjacent linear maps into single matrices because nothing between them is non-linear or input-dependent.
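To put a number on the weight overhead, the parameter counts can be compared at the DeepSeek-V3 shapes quoted in this section. This is a back-of-the-envelope sketch that assumes the absorbed matrices are stored densely, one per head; real implementations may store or apply them differently.

```python
# Shapes quoted in this section for DeepSeek-V3.
d_model, n_heads, d_h, d_c = 7168, 128, 128, 512

# Original per-layer projections: W_Q is (d_model, n_heads * d_h),
# W_O is (n_heads * d_h, d_model).
w_q_params = d_model * n_heads * d_h
w_o_params = n_heads * d_h * d_model

# Absorbed projections, ASSUMING one dense matrix per head:
# (d_model, d_c) on the query side, (d_c, d_model) on the output side.
w_q_absorbed = n_heads * d_model * d_c
w_ov_absorbed = n_heads * d_c * d_model

growth = w_q_absorbed / w_q_params  # equals d_c / d_h
print(f"W_Q:  {w_q_params / 1e6:.1f} M params, absorbed {w_q_absorbed / 1e6:.1f} M")
print(f"W_O:  {w_o_params / 1e6:.1f} M params, absorbed {w_ov_absorbed / 1e6:.1f} M")
print(f"growth factor: {growth:.1f}x")
```

The growth factor is simply `d_c / d_h`, which is why the overhead only appears when the latent is wider than a single head.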
The Value Stream and the Output Projection
It is worth spelling out the value side because the asymmetry with the query side is sometimes confusing. The attention output for one token, one head, is
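$$o_{t,i} = \sum_{s \le t} a_{t,s,i}\, v_{s,i} = \sum_{s \le t} a_{t,s,i}\, c_{KV,s} W_i^{UV}$$
and the layer output mixes heads through $W^O \in \mathbb{R}^{n_{\text{heads}} d_h \times d_{\text{model}}}$:
$$y_t = \operatorname{concat}_i(o_{t,i})\, W^O = \sum_i o_{t,i} W_i^O, \quad \text{with } W_i^O \text{ the block of } W^O \text{ corresponding to head } i.$$
Substituting and grouping:
$$y_t = \sum_i \sum_{s \le t} a_{t,s,i}\, c_{KV,s}\, (W_i^{UV} W_i^O).$$
The factor $W_i^{UV} W_i^O$ is constant, so precompute it as $\tilde{W}_i^{OV}$. At inference, no value matmul ever expands the latent; the latent flows straight into the residual stream through one absorbed projection per head.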
If you have read the original DeepSeek-V2 paper, the equations above are exactly its (35)–(38), rearranged to make the absorption explicit. Different implementations sometimes absorb on only one side, or absorb at runtime via torch.compile rather than statically — the algebra is the same.
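The value-side grouping can be checked numerically. The sketch below uses random weights, small made-up dimensions, and arbitrary attention weights for a single query token; the variable names are ours, not from any DeepSeek code.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d_model, n_heads, d_h, d_c = 5, 8, 2, 4, 3  # illustrative sizes

c_KV = rng.standard_normal((T, d_c))                  # cached latents
W_UV = rng.standard_normal((d_c, n_heads * d_h))      # value up-projection
W_O = rng.standard_normal((n_heads * d_h, d_model))   # output projection

# Arbitrary attention weights for ONE query token: one row per head.
attn = rng.random((n_heads, T))
attn /= attn.sum(axis=-1, keepdims=True)

# Route 1: expand values, mix per head, concatenate heads, project.
V = (c_KV @ W_UV).reshape(T, n_heads, d_h)
o = np.einsum("hs,shd->hd", attn, V)                  # (n_heads, d_h)
y_explicit = o.reshape(n_heads * d_h) @ W_O           # (d_model,)

# Route 2: precompute W_i^UV W_i^O per head and never build V.
W_UV_h = W_UV.reshape(d_c, n_heads, d_h)              # per-head column blocks
W_O_h = W_O.reshape(n_heads, d_h, d_model)            # per-head row blocks
W_OV = np.einsum("chd,hdm->hcm", W_UV_h, W_O_h)       # (n_heads, d_c, d_model)
mixed_latent = attn @ c_KV                            # (n_heads, d_c)
y_absorbed = np.einsum("hc,hcm->m", mixed_latent, W_OV)

print("routes agree:", np.allclose(y_explicit, y_absorbed))
```

Route 2 mixes the cached latents first and applies one absorbed matrix per head afterwards, which is the order of operations the grouped equation suggests.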
Manual Numerical Walkthrough
A fully calculated three-token MLA pass with $T = 3$, $d_{\text{model}} = 4$, $n_{\text{heads}} = 2$, $d_h = 2$, $d_c = 2$. The same numbers reappear in the NumPy and PyTorch implementations below, and you can multiply every step by hand.
For reference: a standard MHA cache at these dimensions would store $2 \cdot n_{\text{heads}} \cdot d_h = 8$ floats per token, i.e. 24 floats for the three tokens. MLA stores 6. A 4× ratio in this toy; ~56× in DeepSeek-V3.
Step 2: up-project to K and V. $K = c_{KV} W^{UK}$, $V = c_{KV} W^{UV}$.
The two routes (scoring the query against the explicit key, and scoring the latent-space query against the cached latent) agree exactly. At inference we run the second one and never materialise $k_{0,h0}$.
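Both routes can be run on the toy matrices used throughout this section (the same `H`, `W_DKV`, `W_UK` and `W_Q` as in the NumPy implementation). The sketch compares raw, unscaled and unmasked dot products, since that is the quantity the absorption identity is about; the einsum layout is our choice.

```python
import numpy as np

T, d_model, n_heads, d_h, d_c = 3, 4, 2, 2, 2

H = np.array([[1., 0., 1., 0.],
              [0., 1., 0., 1.],
              [1., 1., 0., 0.]])
W_DKV = np.array([[1., 0.], [0., 1.], [1., 0.], [0., 1.]])
W_UK = np.array([[1., 0., 0., 1.],
                 [0., 1., 1., 0.]])
W_Q = np.array([[1., 0., 0., 1.],
                [0., 1., 1., 0.],
                [0., 0., 0., 0.],
                [0., 0., 0., 0.]])

c_KV = H @ W_DKV  # (T, d_c): [[2, 0], [0, 2], [1, 1]]

# Route 1: materialise per-head keys, then take query-key dot products.
Q = (H @ W_Q).reshape(T, n_heads, d_h)
K = (c_KV @ W_UK).reshape(T, n_heads, d_h)
scores_explicit = np.einsum("thd,shd->hts", Q, K)            # (n_heads, T, T)

# Route 2: absorb W_UK into the query matrix, then dot against the latent.
W_Q_h = W_Q.reshape(d_model, n_heads, d_h)
W_UK_h = W_UK.reshape(d_c, n_heads, d_h)
W_Q_tilde = np.einsum("mhd,chd->hmc", W_Q_h, W_UK_h)         # (n_heads, d_model, d_c)
q_latent = np.einsum("tm,hmc->htc", H, W_Q_tilde)            # (n_heads, T, d_c)
scores_absorbed = np.einsum("htc,sc->hts", q_latent, c_KV)   # (n_heads, T, T)

print("head 0, query 0 vs key 0:", scores_explicit[0, 0, 0], scores_absorbed[0, 0, 0])
print("all scores agree:", np.allclose(scores_explicit, scores_absorbed))
```

In Route 2 the array `K` is never needed: only `c_KV` and the precomputed `W_Q_tilde` enter the score.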
Interactive Visualization
Use the explorer below to walk through MLA on a 5-token toy sequence. The Pipeline view shows the original key, the latent code, and the reconstruction side-by-side per token. The Cache Size view scales the latent dim dc and lets you watch the per-token memory cost move past MQA and below. The Heatmap view overlays standard attention weights on MLA weights — the bottleneck warps the similarity space, but the global pattern survives.
Watch what happens at dc=1: every token is squeezed onto a single number, the reconstruction collapses, and the heatmap loses most of its structure. That is what the bottleneck looks like when it is too tight. Practical MLA picks dc well above the rank floor at which loss starts to climb.
Plain Python (NumPy) Implementation
Before wrapping anything in nn.Module, run one MLA forward pass in NumPy. Every weight is hardcoded, every value printable, no autograd in the way.
One MLA forward pass — pure NumPy
import numpy as np
We use plain NumPy so every matrix is a printable array: no autograd, no GPU, no hidden machinery. The point of this trace is to see MLA work end-to-end with values you can multiply by hand.
Execution state:
numpy = Numerical array library. Provides np.ndarray, @ for matmul, einsum, and broadcasting.

np.random.seed(0)
Pins the PRNG so any later random call is reproducible. Here we use it as a discipline marker even though we hardcode every weight matrix below.
Execution state:
np.random.seed(s) = Initialises NumPy's global random stream. After seed(0), np.random.randn(2) deterministically returns [1.7641, 0.4002].

Shared toy config
These are the dimensions we work with end-to-end. They are small enough that every printed array fits on one screen, but rich enough that every MLA mechanism (down-projection, up-projection, multi-head, causal mask) is exercised.
Execution state:
T = 3. Sequence length. Three tokens (positions 0, 1, 2).
d_model = 4. Width of the residual stream (the hidden-state dim).
n_heads = 2. Number of attention heads. Each head gets its own slice of K and V.
d_h = 2. Per-head dimension. Concatenated across heads this gives n_heads*d_h = 4, the full key/value width.
d_c = 2. MLA latent dimension. THIS is the only thing we will store in the KV cache per token. Compression ratio vs full KV is (2*n_heads*d_h)/d_c = 8/2 = 4× in this toy. In DeepSeek-V3 it is ~56×.

Hidden states H: input to the attention layer
H is the output of the previous layer (or the embedding for layer 0). One row per token, one column per channel of the residual stream. These are the values that MLA will compress into the latent c_KV.

W_DKV: the down-projection
This is the learned matrix that squeezes the d_model-wide hidden state into the d_c-wide latent. Its rank is the MLA bottleneck. In DeepSeek-V3 this is (7168, 512). Here we pick a simple permutation-like weight so the latent is easy to read.
Execution state:
W_DKV =
[[1, 0],
 [0, 1],
 [1, 0],
 [0, 1]]
shape (4, 2) = (d_model, d_c)
Role: maps every hidden state h_t to a 2-D code c_KV[t] = h_t @ W_DKV.
W_UK: the up-projection for keys (the DECOMPRESSOR for K)
When attention actually fires, we need full-width K vectors. W_UK takes the latent c_KV back up to n_heads * d_h = 4 channels. Because the latent is shared between K and V, both K and V are produced from the same c_KV; this is what 'joint compression' means.
Execution state:
W_UK =
[[1, 0, 0, 1],
 [0, 1, 1, 0]]
shape (2, 4) = (d_c, n_heads * d_h)
Role: K = c_KV @ W_UK gives the full per-head key block.

W_UV: the up-projection for values
Same idea as W_UK but for V. Two different up-projections share the latent. The model is free to learn different decoders for K and V from the same compressed code, which is why a single d_c-wide latent can carry both.
Execution state:
W_UV =
[[0, 1, 1, 0],
 [1, 0, 0, 1]]
shape (2, 4) = (d_c, n_heads * d_h)
Role: V = c_KV @ W_UV gives the full per-head value block.

W_Q: the query projection (NOT compressed in plain MLA)
Queries are computed fresh from the current token's hidden state. They are never cached, so MLA does not bother compressing them in this minimal form. In the full DeepSeek architecture there is also a query-side latent (c_Q) used purely to reduce activation memory during training; we omit it for clarity.
Execution state:
W_Q =
[[1, 0, 0, 1],
 [0, 1, 1, 0],
 [0, 0, 0, 0],
 [0, 0, 0, 0]]
shape (4, 4) = (d_model, n_heads * d_h)
Role: Q = H @ W_Q. The last two rows are zeros so the toy query depends only on the first two H columns.

c_KV = H @ W_DKV (THE KEY LINE)
This single line is the entire purpose of MLA. We replace 'cache K_t and V_t per token' (8 numbers per token in this toy) with 'cache c_KV[t]' (2 numbers per token). Everything else in MLA exists to make this compression compatible with multi-head attention.
Execution state:
input H = [[1,0,1,0],[0,1,0,1],[1,1,0,0]], shape (3, 4)
numpy.matmul / @ = Standard matrix product. For 2-D inputs (A @ B)[i, j] = sum_k A[i,k] * B[k,j]. Here (3,4) @ (4,2) → (3,2).
returns c_KV =
Row 0: 1·[1,0] + 0·[0,1] + 1·[1,0] + 0·[0,1] = [2, 0]
Row 1: 0 + [0,1] + 0 + [0,1] = [0, 2]
Row 2: [1,0] + [0,1] + 0 + 0 = [1, 1]
c_KV = [[2,0],[0,2],[1,1]], shape (3, 2)
→ THIS is what we store. 6 floats for 3 tokens, regardless of n_heads.

K = (c_KV @ W_UK).reshape(T, n_heads, d_h)
We pull keys back out of the latent and split them across heads. c_KV @ W_UK has shape (T, n_heads*d_h); .reshape splits the trailing dim into (n_heads, d_h) so each head sees its own d_h-wide key slice.
Execution state:
ndarray.reshape(T, H, d_h) = Re-interprets the same 12 floats as (3, 2, 2). No data moved. Useful so attention can index per-head with K[:, h, :].
returns K =
Per head 0 (cols 0,1): [[2,0],[0,2],[1,1]]
Per head 1 (cols 2,3): [[0,2],[2,0],[1,1]]
shape (3, 2, 2)

V = (c_KV @ W_UV).reshape(T, n_heads, d_h)
Identical pattern to K but with a different up-projection W_UV. Notice that K and V are FUNCTIONS of the same latent c_KV, so they share information. This is the joint compression that MLA is built around.
Execution state:
returns V =
Per head 0 (cols 0,1): [[0,2],[2,0],[1,1]]
Per head 1 (cols 2,3): [[2,0],[0,2],[1,1]]
shape (3, 2, 2)
def softmax(x)
Numerically stable softmax along the last axis: subtract the row max before exponentiating so the largest exponent is exp(0)=1 (no overflow), then normalise. We will apply it to the (H, T, T) attention scores.
Execution state:
x = Any tensor of scores. We softmax along the last axis so each row sums to 1, the standard attention-weight pattern.
x - x.max(axis=-1, keepdims=True) = axis=-1 reduces along the trailing dim. keepdims=True keeps a length-1 axis so broadcasting works. Result is x shifted so its row max is 0.

scores = np.einsum("thd,Thd->htT", Q, K) / np.sqrt(d_h)
Per-head dot-product attention scores. We use einsum so the head axis stays out front and the result is shaped (H, T_q, T_k), which is convenient for masking and softmax. Dividing by sqrt(d_h) prevents large logits from saturating softmax (Vaswani 2017).
Execution state:
np.einsum('thd,Thd->htT', Q, K) = Contracts the d (per-head) axis. For each head h, computes Q[:, h, :] @ K[:, h, :].T, exactly the per-head score matrix. Output shape: (n_heads, T, T).

mask = np.triu(np.ones((T, T), dtype=bool), k=1)
Builds the causal mask. np.triu returns the upper-triangular part of a matrix; k=1 means 'strictly above the diagonal'. So mask[i, j] = True exactly when j > i, and those are the future positions we must block.
Execution state:
np.triu(M, k=1) = Zeros everything below the k-th diagonal and keeps the rest. With k=1 that zeros the main diagonal and everything below it, so on an all-ones matrix we get True only strictly above the diagonal.

scores = np.where(mask, -1e9, scores)
Wherever mask is True we overwrite the score with -1e9 (effectively -infinity). After softmax these positions get exp(-1e9) ≈ 0 weight, so the token cannot attend to its future.
Execution state:
np.where(cond, a, b) = Elementwise: take a where cond is True, b where cond is False. Result shape broadcasts over the inputs.

attn = softmax(scores)
Per-row softmax across the (T,) key axis gives the attention weights. Because of the mask, row i has zeros at positions > i. This is exactly the per-token attention distribution.

out_per_head = np.einsum("htT,Thd->thd", attn, V)
Weighted sum of value vectors per head. For each query position t and head h, we mix the per-position values V[:, h, :] using the weights attn[h, t, :]. Output shape (T, H, d_h), exactly the shape we need before concatenating heads.
Execution state:
np.einsum('htT,Thd->thd', attn, V) = Contracts the T (keys) axis. For each (t, h), result[t,h,:] = sum_T attn[h,t,T] * V[T,h,:]. Same math as torch.bmm/matmul, just expressed cleanly.

out = out_per_head.reshape(T, n_heads * d_h)
Concatenate heads. .reshape((T, n_heads*d_h)) glues the d_h-wide slices back into one n_heads*d_h-wide row per token. In a full layer this would feed into the output projection W_O.
```python
# mla_numpy.py
import numpy as np

np.random.seed(0)

# --- shared toy config (matches the worked example above) ---
T, d_model, n_heads, d_h, d_c = 3, 4, 2, 2, 2

# --- hidden states for 3 tokens (T, d_model) ---
H = np.array([[1., 0., 1., 0.],
              [0., 1., 0., 1.],
              [1., 1., 0., 0.]])

# --- MLA learned matrices ---
W_DKV = np.array([[1., 0.],
                  [0., 1.],
                  [1., 0.],
                  [0., 1.]])           # (d_model, d_c)
W_UK = np.array([[1., 0., 0., 1.],
                 [0., 1., 1., 0.]])    # (d_c, n_heads * d_h)
W_UV = np.array([[0., 1., 1., 0.],
                 [1., 0., 0., 1.]])    # (d_c, n_heads * d_h)
W_Q = np.array([[1., 0., 0., 1.],
                [0., 1., 1., 0.],
                [0., 0., 0., 0.],
                [0., 0., 0., 0.]])     # (d_model, n_heads * d_h)

# --- Step 1: down-project hidden states into the latent ---
c_KV = H @ W_DKV                       # (T, d_c): THIS is what we cache

# --- Step 2: up-project keys and values from the latent ---
K = (c_KV @ W_UK).reshape(T, n_heads, d_h)   # (T, H, d_h)
V = (c_KV @ W_UV).reshape(T, n_heads, d_h)   # (T, H, d_h)

# --- Queries are independent of the cache ---
Q = (H @ W_Q).reshape(T, n_heads, d_h)       # (T, H, d_h)


# --- attention per head, with causal mask ---
def softmax(x):
    x = x - x.max(axis=-1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)


scores = np.einsum("thd,Thd->htT", Q, K) / np.sqrt(d_h)   # (H, T, T)
mask = np.triu(np.ones((T, T), dtype=bool), k=1)          # block future
scores = np.where(mask, -1e9, scores)
attn = softmax(scores)                                    # (H, T, T)

out_per_head = np.einsum("htT,Thd->thd", attn, V)         # (T, H, d_h)
out = out_per_head.reshape(T, n_heads * d_h)              # (T, n_heads*d_h)

print("c_KV (cached):\n", c_KV)
print("attn head 0:\n", attn[0])
print("out shape:", out.shape)
```
Shape table
The eight tensors a plain-MLA forward actually touches:
| Tensor | Shape | Role |
|---|---|---|
| H | (T, d_model) | Input hidden states from the previous layer |
| W_DKV | (d_model, d_c) | Down-projection; produces the cache |
| c_KV | (T, d_c) | Cached latent. THIS is what grows during decode. |
| W_UK / W_UV | (d_c, n_heads · d_h) | Up-projections (decompressors) |
| K, V | (T, n_heads, d_h) | Reconstructed per-head keys/values (training) |
| Q | (T, n_heads, d_h) | Per-head queries from current hidden state |
| attn | (n_heads, T, T) | Attention weights, row-stochastic, causal |
| out | (T, n_heads · d_h) | Layer output, pre-W_O |
PyTorch Implementation
The PyTorch version of the same forward pass is a direct transcription of the NumPy trace into nn.Module form, with one important addition: it accepts an optional past latent cache so it can be called repeatedly during autoregressive decoding. The cache it returns is the only thing the next step needs.
MultiHeadLatentAttention — PyTorch class
class MultiHeadLatentAttention(nn.Module)
Inheriting from nn.Module gives us: (1) parameter discovery via .parameters(), (2) .to(device) moves all submodules at once, (3) .eval()/.train() toggles dropout, (4) torch.save/load handles state_dict automatically. Every PyTorch attention block follows this pattern.
Execution state:
nn.Module = Base class. Any nn.Linear, nn.Parameter, or child nn.Module assigned via self.xxx = ... is registered as a parameter and tracked by autograd.

__init__(self, d_model, n_heads, d_h, d_c)
Four numbers fully describe an MLA layer's shape. Compare to standard MHA, which needs only the first three: d_c is the new knob and it is what makes the cache cheap.
Execution state:
d_model = Residual-stream width (e.g. 7168 in DeepSeek-V3).
n_heads = Number of attention heads (128 in V3).
d_h = Per-head dim (128 in V3).
d_c = Latent dim. 512 in V3, versus n_heads*d_h = 16,384 for full K (or V). Ratio = 32× per stream, 64× counting both K and V.

self.W_Q = nn.Linear(d_model, n_heads * d_h, bias=False)
Standard query projection, same as in plain MHA. We deliberately do NOT compress queries here: the inference cost MLA targets is the per-step cache, and queries are never cached.
Execution state:
nn.Linear(in, out, bias) = Stores a weight of shape (out, in) and (optionally) a bias of shape (out,). forward(x) returns x @ weight.T + bias. bias=False removes the additive shift, which is standard in transformer projections.
weight shape = (n_heads * d_h, d_model). For V3 this is (16384, 7168) ≈ 117 M params.

self.W_DKV = nn.Linear(d_model, d_c, bias=False)
The compressor. Its output is the only tensor MLA stores per token. Because d_c ≪ n_heads * d_h, this projection acts as an information bottleneck: the model is forced to learn a compact joint code for both K and V.
Execution state:
weight shape = (d_c, d_model). V3: (512, 7168) ≈ 3.7 M params per layer, tiny compared to W_Q.
Why no bias = A bias would shift every cached latent by the same constant, which is easy to absorb into the up-projection. Removing it costs nothing in expressiveness and saves d_c parameters.

self.W_UK = nn.Linear(d_c, n_heads * d_h, bias=False)
The key decoder. At training time it is used explicitly; at inference time we can pre-multiply it into W_Q so it never needs to run as its own matmul. That trick is the 'absorption' you saw in the derivation.
Execution state:
weight shape = (n_heads * d_h, d_c). V3: (16384, 512) ≈ 8.4 M params. The pair (W_DKV, W_UK) together is a learned low-rank factorisation of the would-be (d_model → n_heads*d_h) key projection.

self.W_UV = nn.Linear(d_c, n_heads * d_h, bias=False)
The value decoder. The pair (W_DKV, W_UV) is the analogous low-rank factorisation for the value projection. The same d_c-wide latent feeds both decoders (joint compression), and that is what makes MLA more than two parallel low-rank attentions.

self.W_O = nn.Linear(n_heads * d_h, d_model, bias=False)
Standard output projection. After attention concatenates heads, W_O mixes them back into the residual stream. Identical in role to plain MHA; MLA's changes are upstream of this line.
def forward(self, h, c_KV_past=None)
We accept an OPTIONAL past latent. At training c_KV_past is None and we process the whole sequence. At inference c_KV_past is the growing latent cache and h is just the new token(s).
Execution state:
h = (B, T_new, d_model). Hidden states of the NEW tokens this call should produce. At training T_new = full seq length; at autoregressive decode T_new = 1.
c_KV_past = (B, T_past, d_c) or None. The previously cached latents for everything before h. The point of MLA is that this is d_c-wide, not n_heads*d_h-wide as a normal KV cache would be.

c_KV_new = self.W_DKV(h)
One pass through the down-projection for the new hidden states. Shape (B, T_new, d_c). At decode this is computed for a single token and appended to the cache.
Execution state:
input h = (B, T_new, d_model). Toy values for one batch: H from the NumPy trace.
nn.Linear.forward(x) = Computes x @ self.weight.T + (bias if not None). Here bias is None.
returns c_KV_new = (B, T_new, d_c). For the toy: [[[2,0],[0,2],[1,1]]], the same six floats from the NumPy walkthrough.

c_KV = torch.cat([c_KV_past, c_KV_new], dim=1)
Append the new latents to the existing cache. dim=1 is the sequence axis, so we extend along time. The cache grows by d_c floats per token, independent of n_heads.
Execution state:
torch.cat(tensors, dim) = Concatenates a list of tensors along the given dim. Shapes must match in every other axis. Returns a new tensor; no inplace mutation of the cache list.
memory growth = B * d_c * sizeof(float) per new token, per layer. At V3 numbers: 1 * 512 * 2 bytes = 1 KB per token per layer.

K = self.W_UK(c_KV).view(B, T, self.n_heads, self.d_h)
Up-project the FULL latent stream (past + new) into per-head keys. .view(...) reinterprets the trailing n_heads*d_h axis as (n_heads, d_h) without copying memory.
Execution state:
Tensor.view(*shape) = Returns a new tensor sharing the same underlying storage but with a different shape. Requires the tensor to be contiguous, which it is here because nn.Linear's output is fresh.
returns K = (B, T, n_heads, d_h). Per-head key block for every position in the sequence.

V = self.W_UV(c_KV).view(B, T, self.n_heads, self.d_h)
Same pattern as K, with a different up-projection. Because both K and V are produced from c_KV here, they are highly correlated: the latent is a JOINT code, not two independent compressions.

Q = self.W_Q(h).view(B, T_new, self.n_heads, self.d_h)
Queries are computed only for the NEW tokens. At decode T_new=1 so this is the smallest matmul in the forward pass.
Execution state:
shape = (B, T_new, n_heads, d_h). Crucially T_new, not T, so we never re-query the past.
Q = Q.transpose(1, 2) (and same for K, V)
Move the head axis to position 1 so the per-head matrices land in the trailing two axes, as required by Q @ K.T below. .transpose returns a view (no copy) but the result is no longer contiguous, which is fine for matmul.
Execution state:
Tensor.transpose(d0, d1) = Swaps axes d0 and d1. Returns a view (no data movement). For a (B, T, H, d) tensor, .transpose(1, 2) gives (B, H, T, d).

scores = (Q @ K.transpose(-2, -1)) / sqrt(d_h)
Standard scaled dot-product. K.transpose(-2,-1) swaps the last two axes so the per-head shape becomes (d_h, T), making Q @ K.T contract along d_h. Division by sqrt(d_h) is the original Vaswani scaling.
Execution state:
a.transpose(-2, -1) = Swaps the last two dims. For a (B, H, T, d_h) tensor this gives (B, H, d_h, T), exactly what matmul wants on the right.
@ on 4-D tensors = PyTorch broadcasts batched matmul: the last two dims are matrix-multiplied, the leading dims are broadcast batches. Output shape: (B, H, T_new, T).

causal = torch.triu(torch.ones(T_new, T, dtype=torch.bool, device=h.device), diagonal=T - T_new + 1)
Build a causal mask that is correct even when T_new < T (the typical decode case). The diagonal argument shifts the triangle so position i in the new block can attend to keys 0…(T_past + i) but no further.
Execution state:
torch.triu(M, diagonal=k) = Returns the upper-triangular part of M starting from the k-th diagonal (k>0 moves the cut above the main diagonal). With k=T-T_new+1 we line up the diagonal exactly at the boundary between past and new.
Why this shift = If T_new=1 (one new token) and T=128 (128-token cache), we want the new token to attend to all 128 keys. diagonal = 128 - 1 + 1 = 128 → triu picks columns ≥ 128 → no future, since there is none. Correct.

scores = scores.masked_fill(causal, float("-inf"))
Set masked positions to -inf so softmax assigns them zero weight. masked_fill is the PyTorch idiom; it broadcasts the boolean mask across batch and head dims for free.
Execution state:
Tensor.masked_fill(mask, value) = Returns a new tensor identical to self except positions where mask is True are replaced with value. Standard for causal/padding masks.

attn = F.softmax(scores, dim=-1)
Softmax across the key axis (last axis). dim=-1 is the conventional pick. F.softmax already subtracts the row max internally for numerical stability, so you do not need to do it yourself.
Execution state:
F.softmax(x, dim) = Numerically stable softmax. Equivalent to (x - x.max(dim)).exp() normalised. dim=-1 means softmax over the trailing axis.

out = attn @ V
Weighted sum of value vectors. attn has shape (B, H, T_new, T) and V has shape (B, H, T, d_h), so the matmul contracts along T and returns (B, H, T_new, d_h), exactly the per-head output.

out = out.transpose(1, 2).reshape(B, T_new, -1)
Move heads next to d_h again, then merge them. -1 in .reshape tells PyTorch to infer that axis as n_heads * d_h. After this line each token has one (n_heads * d_h)-wide vector, ready for W_O.
Execution state:
Tensor.reshape(*shape) = Like .view but copies if needed (when the tensor is not contiguous). After .transpose the tensor is not contiguous, so .reshape silently copies.

return self.W_O(out), c_KV
We return BOTH the layer output (the next residual-stream update) AND the updated latent cache. The caller stores c_KV alongside any other layers' caches and passes them back next step.
Execution state:
first return = (B, T_new, d_model). Same shape as the input h, so it is drop-in compatible with the residual stream.
second return = (B, T, d_c), the new full latent cache. This is the ONLY thing the next decode step needs from this layer's past. NOT K, NOT V, just c_KV.
```python
# mla_pytorch.py
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiHeadLatentAttention(nn.Module):
    """
    Plain MLA (no RoPE; section 4.4 adds it).
    Caches only c_KV ∈ R^{d_c} per token instead of K and V.
    """

    def __init__(self, d_model: int, n_heads: int, d_h: int, d_c: int):
        super().__init__()
        self.n_heads, self.d_h, self.d_c = n_heads, d_h, d_c

        # Queries are produced fresh from the current hidden state.
        self.W_Q = nn.Linear(d_model, n_heads * d_h, bias=False)

        # Down-projection: produces the cached latent.
        self.W_DKV = nn.Linear(d_model, d_c, bias=False)

        # Up-projections: decompress the latent into K and V at forward time.
        self.W_UK = nn.Linear(d_c, n_heads * d_h, bias=False)
        self.W_UV = nn.Linear(d_c, n_heads * d_h, bias=False)

        # Output projection: same role as in standard MHA.
        self.W_O = nn.Linear(n_heads * d_h, d_model, bias=False)

    def forward(self, h: torch.Tensor, c_KV_past: torch.Tensor | None = None):
        # h: (B, T_new, d_model). c_KV_past: (B, T_past, d_c) or None.
        B, T_new, _ = h.shape

        # 1. Compress the new tokens once. THIS is what we cache.
        c_KV_new = self.W_DKV(h)                              # (B, T_new, d_c)
        c_KV = c_KV_new if c_KV_past is None \
            else torch.cat([c_KV_past, c_KV_new], dim=1)      # (B, T, d_c)
        T = c_KV.shape[1]

        # 2. Up-project the FULL latent stream to keys and values.
        K = self.W_UK(c_KV).view(B, T, self.n_heads, self.d_h)
        V = self.W_UV(c_KV).view(B, T, self.n_heads, self.d_h)

        # 3. Queries from the new hidden states only.
        Q = self.W_Q(h).view(B, T_new, self.n_heads, self.d_h)

        # 4. Standard scaled dot-product attention with a causal mask.
        Q = Q.transpose(1, 2)                                 # (B, H, T_new, d_h)
        K = K.transpose(1, 2)                                 # (B, H, T, d_h)
        V = V.transpose(1, 2)                                 # (B, H, T, d_h)

        scores = (Q @ K.transpose(-2, -1)) / (self.d_h ** 0.5)  # (B, H, T_new, T)
        causal = torch.triu(torch.ones(T_new, T, dtype=torch.bool,
                                       device=h.device),
                            diagonal=T - T_new + 1)
        scores = scores.masked_fill(causal, float("-inf"))
        attn = F.softmax(scores, dim=-1)

        out = attn @ V                                        # (B, H, T_new, d_h)
        out = out.transpose(1, 2).reshape(B, T_new, -1)       # concat heads
        return self.W_O(out), c_KV                            # return new cache
```
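The shifted diagonal in the causal mask is the easiest part of `forward` to get wrong, so it is worth looking at in isolation. The sketch below reproduces the same rule with NumPy's `np.triu`, whose `k` argument plays the role of `diagonal` in `torch.triu`; the helper name `causal_mask` is ours and is not part of the class.

```python
import numpy as np


def causal_mask(T_new, T):
    """Boolean (T_new, T) mask; True marks keys a new token must NOT see.

    Row i is the i-th NEW token, whose absolute position is (T - T_new) + i.
    """
    return np.triu(np.ones((T_new, T), dtype=bool), k=T - T_new + 1)


# Prefill: 3 new tokens, empty cache. Ordinary strictly-upper triangle.
print(causal_mask(3, 3).astype(int))

# Decode: 1 new token, 4 keys in total. Nothing is masked.
print(causal_mask(1, 4).astype(int))

# Chunked decode: 2 new tokens after 2 cached ones.
# The first new token may not see the second; the second sees everything.
print(causal_mask(2, 4).astype(int))
```

The rows always describe the new tokens only, which is why the mask has shape `(T_new, T)` and not `(T, T)`.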
Note on the absorption trick. The class above does not implement the absorbed forward path. At training and prefill, materialising K and V is fine because we process the whole sequence and there is no per-step compression payoff. The absorbed version of forward, where queries are projected straight into $d_c$ space and dotted against $c_{KV}$, is what production decoders run. Section 4.6 ("Implementing MLA in PyTorch") builds the absorbed kernel explicitly.
At Massive Scale: DeepSeek-V3 Numbers
Plug DeepSeek-V3's actual hyperparameters into the MLA cache formula and the engineering story becomes vivid:
| Quantity | Value | Where it comes from |
| --- | --- | --- |
| Layers (L) | 61 | 3 dense + 58 MoE |
| d_model | 7168 | Residual stream width |
| n_heads | 128 | Multi-head split |
| d_h | 128 | Per-head dim |
| d_c | 512 | MLA KV latent dim |
| d_h^R | 64 | Decoupled-RoPE channel (§4.4) |
| Context | 128 K | Long-context inference target |
| Precision | BF16 (2 bytes) | Standard serving dtype |
Standard MHA cache. Per token, per layer: $2 \cdot n_{heads} \cdot d_h = 32768$ values = 64 KiB at BF16. Across 61 layers: about 3.8 MiB per token. At 128 K context: ~488 GiB (~524 GB) per sequence. Not servable on a single node.
MLA cache. Per token, per layer: $d_c + d_h^R = 576$ values ≈ 1.13 KiB at BF16. Across 61 layers: about 68.6 KiB per token. At 128 K context: ~8.6 GiB (~9.2 GB) per sequence. The cache alone fits in a single H100's 80 GB.
Ratio: $32768 / 576 \approx 56.9\times$ per layer. The DeepSeek-V2 paper reports a 93.3% reduction in KV cache memory, measured against DeepSeek 67B rather than against an MHA baseline at equal head count, and its ablations report MLA matching or exceeding MHA in quality. DeepSeek-V3 reuses the same design at much larger scale.
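The cache figures above are simple enough to reproduce in a few lines. The sketch below assumes the table's hyperparameters, BF16 storage, and a context of 128 × 1024 tokens; the helper name `kv_cache_bytes` is made up for this example.

```python
def kv_cache_bytes(values_per_token_per_layer: int, n_layers: int,
                   n_tokens: int, bytes_per_value: int = 2) -> int:
    """Total KV-cache size in bytes for one sequence."""
    return values_per_token_per_layer * n_layers * n_tokens * bytes_per_value

# Hyperparameters taken from the table above.
N_LAYERS, N_HEADS, D_H, D_C, D_H_ROPE = 61, 128, 128, 512, 64
CONTEXT = 128 * 1024  # assumes "128 K" means 131072 tokens

mha_values = 2 * N_HEADS * D_H  # keys + values for every head
mla_values = D_C + D_H_ROPE     # latent + decoupled-RoPE channel

mha_bytes = kv_cache_bytes(mha_values, N_LAYERS, CONTEXT)
mla_bytes = kv_cache_bytes(mla_values, N_LAYERS, CONTEXT)

print(f"MHA: {mha_bytes / 2**30:.1f} GiB")          # 488.0 GiB
print(f"MLA: {mla_bytes / 2**30:.2f} GiB")          # 8.58 GiB
print(f"ratio: {mha_values / mla_values:.1f}x")     # 56.9x
```

Swapping in a different context length or dtype width shows how quickly the MHA figure outgrows a node while the MLA figure stays within a single accelerator.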
The implication for serving is direct. KV cache is the dominant memory cost of long-context inference, more than parameter weights when sequences run long. Cutting it by about 57× changes which models can run on which hardware, how many concurrent users a node can hold, and how aggressively a batch can be packed. In DeepSeek-V3's reported serving setup, MLA is the reason the model fits within commodity multi-GPU nodes at 128 K context lengths.
The arithmetic of attention cost itself
Memory is half the story. Count the FLOPs of one attention step at decode (one new token, $T$ tokens in the cache):
Standard MHA: $2\, n_{heads}\, d_h\, T$ for the query-key dot products plus the same for the value mix.
MLA (absorbed): $2\, n_{heads}\, d_c\, T$ for the query-latent dots plus the same for the latent mix, and a per-step projection of the new query into $d_c$ space that does not grow with $T$.
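At V3 numbers the ratio of MHA to MLA FLOPs for the dominant attention term is $d_h / d_c = 128 / 512 = 1/4$: the absorbed path does about four times more arithmetic per step than the MHA baseline, not less. It is still a good trade. Decode attention is typically bound by memory bandwidth rather than arithmetic, and the absorbed path reads roughly 57× fewer cache bytes per step. MLA follows the usual pattern of trading compute for memory; it makes the trade at the point in the pipeline where compute is cheapest.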
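Plugging the two formulas above into code, alongside the bytes each path reads from the cache, makes the trade concrete. This is a rough per-layer sketch under stated assumptions: the helper `decode_step_cost` is hypothetical, the query projection is ignored because it does not grow with the cache length, and the 64-wide RoPE channel is counted in bytes but left out of the FLOP term, as in the formulas above.

```python
def decode_step_cost(n_heads: int, width: int, cached_values_per_token: int,
                     n_cached: int, bytes_per_value: int = 2):
    """Return (FLOPs, cache bytes read) for one decode step in one layer."""
    # 2 * n_heads * width * T for the scores, and the same again for the mix.
    flops = 2 * (2 * n_heads * width * n_cached)
    bytes_read = cached_values_per_token * n_cached * bytes_per_value
    return flops, bytes_read

N_HEADS, D_H, D_C, D_H_ROPE = 128, 128, 512, 64
T = 128 * 1024  # cached tokens

mha_flops, mha_bytes = decode_step_cost(N_HEADS, D_H, 2 * N_HEADS * D_H, T)
mla_flops, mla_bytes = decode_step_cost(N_HEADS, D_C, D_C + D_H_ROPE, T)

flop_ratio = mla_flops / mha_flops    # MLA arithmetic relative to MHA
byte_ratio = mha_bytes / mla_bytes    # MHA cache traffic relative to MLA
mha_intensity = mha_flops / mha_bytes  # FLOPs per byte read
mla_intensity = mla_flops / mla_bytes

print(f"MLA does {flop_ratio:.0f}x the FLOPs and reads {byte_ratio:.1f}x fewer bytes")
print(f"FLOPs per cache byte: MHA {mha_intensity:.1f}, MLA {mla_intensity:.1f}")
```

MHA performs about one FLOP per cache byte read, which leaves a modern accelerator waiting on memory. The absorbed MLA path performs a couple of hundred, so its extra arithmetic is mostly absorbed by compute that MHA decode leaves unused.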
What Can Go Wrong in Practice
MLA looks like a free lunch on paper. Three things bite in practice.
1. The latent dimension is a real architectural knob
If $d_c$ is too small, the model loses per-head resolution: attention heatmaps blur and downstream perplexity climbs. The DeepSeek-V2 ablations sweep $d_c$ over a wide range and identify a sharp elbow below which loss degrades quickly. The chosen value ($d_c = 512$) sits well above the elbow with a safety margin.
2. RoPE does not pass through the latent cleanly
Rotary positional embedding rotates K (and Q) by position-dependent block-diagonal matrices before the dot product. If you cache $c_{KV}$ and apply RoPE to the reconstructed keys, the rotation sits between $W_Q$ and $W_{UK}$ in the score, so absorbing $W_{UK}$ into the query would need RoPE to commute with the up-projection, which it does not in general. This is not a small detail: with naive MLA + RoPE the combined matrix depends on position, so the decoder has to re-derive the keys for every cached token at each step, giving up the decode-time efficiency MLA exists to provide. The fix is the decoupled-RoPE channel covered in §4.4: an additional small positional key $k_t^R \in \mathbb{R}^{d_h^R}$, shared across heads, is cached alongside the latent and concatenated into the score computation. It is the $d_h^R = 64$ in the V3 table above.
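A two-dimensional toy case is enough to see why the rotation blocks absorption. The sketch below uses a single RoPE frequency and a hand-picked 2×2 up-projection; `theta`, `W_UK`, and the helper functions are illustrative assumptions, not values from any model.

```python
import math

def rot(angle):
    """2-D rotation matrix, the building block of RoPE."""
    c, s = math.cos(angle), math.sin(angle)
    return [[c, -s], [s, c]]

def matmul(a, b):
    """Plain-Python matrix product for small nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def close(a, b, tol=1e-9):
    """Element-wise comparison of two matrices."""
    return all(abs(x - y) < tol for ra, rb in zip(a, b) for x, y in zip(ra, rb))

W_UK = [[1.0, 2.0], [0.0, 1.0]]  # hypothetical toy up-projection (d_h = d_c = 2)
theta = 0.3                      # hypothetical rotation frequency

# With RoPE, the score against a key `delta` positions back is
#   q^T R(-delta * theta) W_UK c,
# so the matrix we would like to fold into the query is R(-delta * theta) W_UK.
absorbed = {delta: matmul(rot(-delta * theta), W_UK) for delta in (1, 2)}
depends_on_position = not close(absorbed[1], absorbed[2])

# Absorption would work if the rotation could be moved past W_UK. It cannot.
commutes = close(matmul(rot(theta), W_UK), matmul(W_UK, rot(theta)))

print(depends_on_position, commutes)  # True False
```

Because the matrix to be folded into the query changes with the distance to each cached token, no single precomputed query-side matrix can serve the whole cache.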
3. Absorption inflates the weight matrices
Multiplying $W_Q$ by $(W_{UK})^\top$ gives a larger matrix than the query projection it replaces. For V3 numbers the absorbed query matrix is $d_{model} \cdot d_c$ per head, four times the original per-head $d_{model} \cdot d_h$, because $d_c > d_h$. The weight overhead is paid once at load time; the cache savings accrue on every token. Worth it for any non-trivial context length, but a real memory tradeoff that the serving stack must budget for.
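To put a number on that overhead, the sketch below sizes the pre-multiplied query matrix for one layer. It follows the simplified class in this section (a full-rank `W_Q` with no query compression) and assumes BF16 weights, so the figures are illustrative rather than a description of any released checkpoint.

```python
D_MODEL, N_HEADS, D_H, D_C = 7168, 128, 128, 512
BYTES_PER_WEIGHT = 2  # BF16

# Per-head parameter counts.
per_head_original = D_MODEL * D_H                # W_Q slice for one head
per_head_absorbed = D_MODEL * D_C                # W_Q slice pre-multiplied by W_UK^T
per_head_factored = D_MODEL * D_H + D_H * D_C    # both factors stored separately

inflation = per_head_absorbed / per_head_original

# Per-layer footprint across all heads, in MiB.
original_mib = per_head_original * N_HEADS * BYTES_PER_WEIGHT / 2**20
absorbed_mib = per_head_absorbed * N_HEADS * BYTES_PER_WEIGHT / 2**20
factored_mib = per_head_factored * N_HEADS * BYTES_PER_WEIGHT / 2**20

print(f"inflation: {inflation:.0f}x")
print(f"per layer: original {original_mib:.0f} MiB, "
      f"absorbed {absorbed_mib:.0f} MiB, factored {factored_mib:.0f} MiB")
```

Storing the two factors separately and applying them one after the other is mathematically equivalent to the pre-multiplied matrix and costs only the small extra $d_h \cdot d_c$ term per head, at the price of one more matrix multiply per step.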
Implementations that target the absorbed forward path naively can also pessimise the prefill pass, where you do need full K and V across the whole prompt. Production stacks (FlashAttention-3, vLLM's MLA kernel, DeepSeek's own inference repo) run one code path at prefill and a different one at decode. If you are reading public implementations and an MLA forward looks like two functions rather than one, this is why.
Summary
MLA replaces the standard $(K_t, V_t)$ KV cache with a single per-token latent $c_{KV,t} = h_t W_{DKV}$ of dimension $d_c \ll n_{heads} \cdot d_h$.
Keys and values are reconstructed at training/prefill time via $W_{UK}$ and $W_{UV}$. The two up-projections share the same latent: "joint compression".
At inference, the up-projections absorb into the query and output projections. The decode-time forward never materialises full K or V. Score computation runs in latent space: the query is projected to $d_c$, then dotted against $c_{KV}$.
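On DeepSeek-V3 dimensions, MLA cuts the KV cache by ~57× per layer. The absorbed decode path spends about 4× more attention FLOPs than the MHA baseline it replaces, but reads ~57× fewer cache bytes per step, which is the better trade when decode is memory-bandwidth-bound.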
The two real complications are an extra positional channel (§4.4's decoupled RoPE) and a one-time weight inflation when the absorbed matrices are stored.
The next section adds the missing piece. Section 4.4 derives the decoupled-RoPE channel from the failure mode it fixes, and completes MLA's connection to long-context positional reasoning.