Chapter 4

Decoupled RoPE

Multi-Head Latent Attention (MLA)

Introduction

In the previous section we built MLA from scratch: every token's key and value are reconstructed on the fly from a single low-rank latent vector $c_t^{KV} = W^{DKV} h_t$, and the cache shrinks from $2 \cdot n_h \cdot d_h$ scalars per token to just $d_c$. Beautiful, until we try to add rotary position embeddings (RoPE) and the whole trick falls apart.

RoPE is not a small detail. It is how most modern decoders (LLaMA, Qwen, DeepSeek, Mistral, Yi) inject position into attention, so dropping it is not an option. But naively bolting RoPE onto MLA inflates the cache back to roughly MHA size, throwing away the entire compression win.

DeepSeek's answer is a small, surgical idea called Decoupled RoPE: split each attention head into two slices — a large "content" slice that uses MLA compression without position, and a small "position" slice that carries the RoPE rotation as a single shared key. This section explains exactly why the split is needed, how the math forces it, and how to implement it.

Why this matters: Without Decoupled RoPE there is no practical MLA. Every production MLA model (DeepSeek-V2, V3, R1) uses this trick. It is the bridge between "a clever paper idea" and "a 671B model whose KV cache actually fits in GPU memory at 128k context."

4.1 The Real Problem: RoPE Breaks MLA Absorption

What MLA needs to work

Recall the magic of MLA. For a query at position $m$ and a key at position $n$, the attention score is

$$s_{m,n} \;=\; (W^Q h_m)^\top \, (W^{UK} c_n^{KV}) \;=\; h_m^\top \,\big(W^{Q\,\top} W^{UK}\big)\, c_n^{KV}.$$

The parenthesised product $W^{Q\top} W^{UK}$ is position-independent: it is one fixed matrix that we can precompute at load time and absorb into a single up-projection. That is precisely why MLA only needs to cache $c_n^{KV}$ per token; no per-head $K$ is ever materialised at inference time.
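The absorption identity is easy to check numerically. The sketch below is a minimal plain-Python illustration with random toy matrices; the sizes (d_model = 6, d_c = 3, d_h = 4) and helper names are arbitrary choices for this example, not part of any library. It computes the score once by materialising q and k, and once through the precomputed product $W^{Q\top} W^{UK}$.

```python
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matvec(W, x):
    return [dot(row, x) for row in W]

def matmul(A, B):
    return [[dot(row, col) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def rand_matrix(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d_model, d_c, d_h = 6, 3, 4            # toy sizes, one head
W_Q  = rand_matrix(d_h, d_model)       # h    -> q  (d_h x d_model)
W_UK = rand_matrix(d_h, d_c)           # c^KV -> k  (d_h x d_c)
h_m  = [random.gauss(0, 1) for _ in range(d_model)]
c_n  = [random.gauss(0, 1) for _ in range(d_c)]

# Explicit: materialise q and k, then take their dot product.
explicit = dot(matvec(W_Q, h_m), matvec(W_UK, c_n))

# Absorbed: precompute W_Q^T W_UK once (d_model x d_c); k is never built.
W_absorbed = matmul(transpose(W_Q), W_UK)
absorbed = dot(h_m, matvec(W_absorbed, c_n))

print(f"explicit={explicit:.6f}  absorbed={absorbed:.6f}")
```

The two numbers agree up to floating-point rounding, which is the whole reason MLA can cache only $c_n^{KV}$.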

What RoPE does to that score

RoPE applies a position-dependent rotation matrix $R_m$ to the query and $R_n$ to the key. The score becomes

$$s_{m,n} \;=\; \big(R_m W^Q h_m\big)^\top \big(R_n W^{UK} c_n^{KV}\big) \;=\; h_m^\top \, W^{Q\,\top} \, R_m^{\top} R_n \, W^{UK} \, c_n^{KV}.$$

Look closely at the middle: $R_m^{\top} R_n$. RoPE has the lovely property that this product depends only on the relative position: $R_m^{\top} R_n = R_{n-m}$. Wonderful for modelling, but catastrophic for MLA, because that matrix is wedged between $W^Q$ and $W^{UK}$. We can no longer multiply them together once and forget about it: every query-key offset yields a different effective matrix $W^{Q\top} R_{n-m} W^{UK}$, which means we would have to either:

  1. materialise the full per-head key $K_n = W^{UK} c_n^{KV}$ and apply $R_n$ to it, which means either caching the rotated keys (re-introducing the full-size cache we just got rid of) or recomputing them for every prefix token at every decoding step, or
  2. drop RoPE entirely and lose explicit relative-position information, which typically hurts quality, especially on long contexts.

Neither option is acceptable. We need a third path.

The bottleneck in one line: RoPE's $R_m^{\top} R_n$ sits in the middle of the score and does not commute with the projections around it. It blocks the matrix absorption that makes MLA cheap.
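To see the obstruction concretely, here is a small plain-Python sketch using a single 2-D rotation ($\theta = 1$) and two arbitrary 2×2 matrices standing in for $W^Q$ and $W^{UK}$ (their values are made up for illustration). It checks that $R_m^\top R_n = R_{n-m}$, then shows that the would-be absorbed matrix $W^{Q\top} R_m^\top R_n W^{UK}$ is identical for equal offsets but changes whenever the offset changes, so no single precomputed matrix can serve every query-key pair.

```python
import math

def rot(p, theta=1.0):
    """2-D RoPE rotation matrix R_p for angle p * theta."""
    a = p * theta
    return [[math.cos(a), -math.sin(a)],
            [math.sin(a),  math.cos(a)]]

def matmul(A, B):
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def close(A, B, tol=1e-9):
    return all(abs(a - b) < tol for ra, rb in zip(A, B) for a, b in zip(ra, rb))

# R_m^T R_n depends only on the offset n - m.
print(close(matmul(transpose(rot(5)), rot(3)), rot(3 - 5)))   # True

# Hypothetical stand-ins for W^Q and W^UK (any invertible 2x2 matrices would do).
W_Q  = [[1.0, 2.0], [0.5, -1.0]]
W_UK = [[0.3, 0.0], [1.0, 1.0]]

def effective(m, n):
    """The matrix that would sit between h_m and c_n^KV: W_Q^T R_m^T R_n W_UK."""
    return matmul(matmul(transpose(W_Q), matmul(transpose(rot(m)), rot(n))), W_UK)

same_offset = close(effective(5, 3), effective(9, 7))   # both have n - m = -2
diff_offset = close(effective(5, 3), effective(5, 4))   # n - m = -2 vs -1
print(same_offset, diff_offset)   # True False
```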

4.2 Intuition: Cut the Head in Two

Here is the trick. Position information does not need to live in every dimension of every head. The model only needs some channels that carry position; the rest can be pure content. So we physically split each head's dimension into two pieces:

  • NoPE part (size $d_h^c$): no rotation at all. This part flows through MLA normally: content-only, absorbed up-projection, beautifully cheap.
  • RoPE part (size $d_h^R$): rotated by $R_m$ on the query side and $R_n$ on the key side. This part is computed in the "old" way, but because it is small we can afford it.

Think of an attention head as a wide highway. We are reserving one narrow lane on the right for position-carrying traffic, and letting all the other lanes carry content. The two lanes never interact within a head — they are concatenated and their scores simply add up.

Then DeepSeek adds one more squeeze: the entire RoPE key $k_n^R$ is shared across all heads, one single $d_h^R$-dimensional rotated vector per token, like Multi-Query Attention but only for the tiny position slice. Now the cache is essentially $d_c + d_h^R$ scalars per token, regardless of how many heads the model has.

Mental model: MLA handles "what the token means" via a fat compressed latent. Decoupled RoPE handles "where the token is" via a thin shared rotated key. The two scores are added — and the cache stays tiny.

4.3 The Mathematical Idea

Split the head dimension

For each attention head $i$ we define two query projections and two key projections:

$$q_{m,i} = \begin{bmatrix} q_{m,i}^{c} \\ q_{m,i}^{R} \end{bmatrix} \in \mathbb{R}^{d_h^c + d_h^R}, \qquad k_{n,i} = \begin{bmatrix} k_{n,i}^{c} \\ k_{n}^{R} \end{bmatrix} \in \mathbb{R}^{d_h^c + d_h^R}.$$

Each symbol:

  • $q_{m,i}^{c} = W_i^{UQ}\, c_m^{Q}$: content part of the query for head $i$, produced via MLA's query down-up path (or directly from $h_m$).
  • $q_{m,i}^{R} = \mathrm{RoPE}\big( W_i^{QR}\, c_m^{Q} \big)_m$: RoPE-rotated part of the query, one per head.
  • $k_{n,i}^{c} = W_i^{UK}\, c_n^{KV}$: content part of the key, reconstructed from the latent.
  • $k_n^{R} = \mathrm{RoPE}\big( W^{KR}\, h_n \big)_n$: RoPE-rotated key. Note the missing head index $i$: this vector is shared across all heads.

The attention score

Because the head vector is a concatenation, the dot product splits into two independent terms:

$$s_{m,n,i} \;=\; q_{m,i}^{\top} k_{n,i} \;=\; \underbrace{q_{m,i}^{c\,\top} k_{n,i}^{c}}_{\text{MLA-absorbable content score}} \;+\; \underbrace{q_{m,i}^{R\,\top} k_n^{R}}_{\text{small RoPE score (shared } k^R\text{)}}.$$

Read this slowly. The first term has no rotation between $W^{UQ}$ and $W^{UK}$, so MLA absorption still works and we only need to cache $c_n^{KV}$. The second term has rotations, but it lives in a tiny $d_h^R$-dimensional subspace and the key is the same across heads, so we only need to cache one $d_h^R$-dim vector $k_n^R$ per token. Total cache per token:

$$\boxed{\;\text{KV cache per token} \;=\; d_c \;+\; d_h^R\;}$$

DeepSeek-V2 uses $d_c = 512$, $d_h^R = 64$, $n_h = 128$, $d_h = 128$. So per token and per layer, MHA would store $2 \cdot 128 \cdot 128 = 32{,}768$ scalars, whereas Decoupled-RoPE MLA stores $512 + 64 = 576$. That is roughly a 57× compression ($32{,}768 / 576 \approx 56.9$), and it is fully RoPE-compatible.

Why sharing $k^R$ across heads is OK

In Multi-Query Attention we saw that sharing keys across heads costs very little quality if the shared key is given enough capacity. Here we share something even smaller — only the position-carrying slice. The content lanes are still fully per-head, so per-head specialisation survives where it matters. Empirically (DeepSeek-V2 §4.2), this MLA variant matches or beats vanilla MHA on language-modelling losses while slashing the cache.


4.4 Manual Numerical Walkthrough

Let us compute one score end-to-end with absurdly small dimensions so every number is visible. We will use 1 head, $d_h^c = 2$, $d_h^R = 2$, a sequence of 3 tokens, and $m = 2$, $n = 1$.

Decoupled RoPE score for (m=2, n=1)

Inputs

The MLA latents have already been computed. We pick them by hand:

```text
c_m^Q  = [1.0,  0.0]      # query latent at m=2
c_n^KV = [0.5,  0.5]      # KV latent  at n=1
h_n    = [1.0,  0.0]      # token hidden at n=1

# Projection matrices (1 head, d_h^c = d_h^R = 2)
W^UQ = [[ 1.0, 0.0],      # content query up-proj
        [ 0.0, 1.0]]

W^UK = [[ 1.0, 1.0],      # content key   up-proj
        [ 1.0,-1.0]]

W^QR = [[ 0.0, 1.0],      # rope query  up-proj
        [ 1.0, 0.0]]

W^KR = [[ 1.0, 0.0],      # rope key    proj (from h, not latent)
        [ 0.0, 1.0]]
```

Step 1 — Content query and content key

```text
q_m^c = W^UQ · c_m^Q = [[1,0],[0,1]] · [1,0]
                     = [1.0, 0.0]

k_n^c = W^UK · c_n^KV = [[1,1],[1,-1]] · [0.5, 0.5]
                      = [1.0, 0.0]

content score  =  q_m^c · k_n^c  =  1·1 + 0·0  =  1.0
```

Step 2 — Pre-rotation RoPE query and key

```text
q_m^{R,pre} = W^QR · c_m^Q = [[0,1],[1,0]] · [1,0] = [0.0, 1.0]
k_n^{R,pre} = W^KR · h_n    = [[1,0],[0,1]] · [1,0] = [1.0, 0.0]
```

Step 3 — Apply RoPE rotation (one 2-D pair)

For a single 2-D pair, RoPE at position $p$ rotates by angle $p \cdot \theta$, where $\theta = 10000^{-2 \cdot 0 / d_h^R} = 1$ for the first pair (we keep it simple).

```text
m = 2  → angle_m = 2 · 1 = 2 rad
n = 1  → angle_n = 1 · 1 = 1 rad

cos(2)=−0.4161   sin(2)= 0.9093
cos(1)= 0.5403   sin(1)= 0.8415

R_m · [0, 1] = [−sin(2),  cos(2)] = [−0.9093, −0.4161]
R_n · [1, 0] = [ cos(1),  sin(1)] = [ 0.5403,  0.8415]
```

Step 4 — RoPE score

```text
q_m^R · k_n^R  = (−0.9093)·0.5403 + (−0.4161)·0.8415
               = −0.4913 − 0.3502      # products of the unrounded cos/sin values
               = −0.8415
```

Sanity check: because $R_m^\top R_n = R_{n-m}$, this equals $\big(q_m^{R,pre}\big)^\top R_{n-m}\, k_n^{R,pre}$. The relative offset is $n - m = -1$, and $R_{-1} k_n^{R,pre} = [\cos(-1), \sin(-1)]$, so the score is $0 \cdot \cos(-1) + 1 \cdot \sin(-1) = -0.8415$. Position is correctly baked into the score and depends only on the offset between $m$ and $n$.

Step 5 — Final score

```text
s_{m,n} = content + rope = 1.0 + (−0.8415) = 0.1585
```

What we just proved

  1. The two scores really do add. No cross terms, no interference.
  2. The RoPE term depends only on the relative position $m - n$, just like vanilla RoPE.
  3. The content term used only the cached latent $c_n^{KV}$; no full-size $K$ needed.
  4. The RoPE term used only the cached $k_n^R$ (2 scalars), and that vector would be the same for every head.
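Point 2 can be checked by shifting both positions by the same amount. This minimal sketch reuses the toy pre-rotation slices from Step 2 of §4.4 (one 2-D pair, $\theta = 1$); the helper `rope_score` is a name introduced here for illustration. The score stays at $\sin(n - m) = \sin(-1)$ however far the pair is moved.

```python
import math

def rope_pair(v, pos, theta=1.0):
    """Rotate a 2-D vector by angle pos * theta (one RoPE pair)."""
    a = pos * theta
    c, s = math.cos(a), math.sin(a)
    return [v[0] * c - v[1] * s, v[0] * s + v[1] * c]

def rope_score(q_pre, k_pre, m, n):
    """RoPE score after rotating the query by m and the key by n."""
    q, k = rope_pair(q_pre, m), rope_pair(k_pre, n)
    return q[0] * k[0] + q[1] * k[1]

q_pre, k_pre = [0.0, 1.0], [1.0, 0.0]   # pre-rotation slices from Step 2
for m, n in [(2, 1), (12, 11), (102, 101)]:
    print(m, n, round(rope_score(q_pre, k_pre, m, n), 4))   # -0.8415 for every shift
print(round(math.sin(-1), 4))                                # -0.8415
```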

4.5 Interactive Visualization

The visualizer below lets you slide $d_h^c$, $d_h^R$, $d_c$ and the number of heads. The four bars show what the KV cache actually has to store under different schemes: you can watch the "broken" configuration (RoPE on every dimension) blow up to MHA size, and watch the shared-$k^R$ configuration stay flat.

Things to try:

  • Set $d_h^R = 0$: the rotation panel disappears and the cache is minimal, but the model would have no position information.
  • Push $n_h$ up to 64. Vanilla MHA scales linearly; the shared-$k^R$ MLA cache does not move, because the RoPE key does not depend on the head count.
  • Slide the token position. The 2-D rotation snapshot shows the same $k_n^R$ rotating to a different angle, exactly what RoPE does in a real model.
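The same head-count comparison can be tabulated in a few lines. This is a rough accounting sketch of scalars cached per token per layer, with the DeepSeek-V2 sizes from §4.3 as defaults. The middle scheme, `mla_per_head_rope` (every head keeps its own RoPE key), is a hypothetical intermediate added only to show what sharing $k^R$ saves; it is not a configuration described in the text.

```python
def cache_scalars_per_token(scheme, n_h, d_h=128, d_c=512, d_h_R=64):
    """Scalars cached per token per layer under a given scheme (rough accounting)."""
    if scheme == "mha":
        return 2 * n_h * d_h           # full K and V for every head
    if scheme == "mla_per_head_rope":
        return d_c + n_h * d_h_R       # hypothetical: latent + one RoPE key per head
    if scheme == "mla_shared_rope":
        return d_c + d_h_R             # latent + ONE shared RoPE key
    raise ValueError(f"unknown scheme: {scheme}")

schemes = ("mha", "mla_per_head_rope", "mla_shared_rope")
for n_h in (8, 32, 64, 128):
    print(n_h, {s: cache_scalars_per_token(s, n_h) for s in schemes})

mha = cache_scalars_per_token("mha", 128)
mla = cache_scalars_per_token("mla_shared_rope", 128)
print(mha, mla, round(mha / mla, 1))   # 32768 576 56.9
```

Only the first two columns grow with `n_h`; the shared-$k^R$ column is constant, which is exactly the flat bar described above.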

4.6 Plain Python Implementation

Before any framework, let us implement the score in plain Python with math and lists, exactly mirroring section 4.4. The point is to expose every loop and every shape.

Decoupled RoPE score from scratch (decoupled_rope.py). Notes on each part of the listing that follows:

  • rope_pair (RoPE on one 2-D pair): implements the core RoPE operation in its simplest form, rotating a 2-D vector by angle pos * theta. Real RoPE applies this to every consecutive pair of dimensions in the head, each with its own frequency theta_k = base^(−2k/d). We keep one pair so you can see every multiplication.
  • c_m_Q (query MLA latent c_m^Q): the already-compressed query latent at position m, produced by MLA's query down-projection (W^DQ · h_m). Its size is just d_c' = 2 here; in a real model it is hundreds of dims.
  • c_n_KV (KV latent c_n^KV): the ONLY thing the KV cache holds for the content path at position n. d_c scalars per token, no per-head storage. This is the MLA win.
  • h_n (raw hidden state): W^KR projects from h_n directly, NOT from the latent. This keeps the RoPE key path independent of MLA absorption, a deliberate design choice that also keeps caching simple.
  • W_UQ and W_UK (content up-projections): reconstruct the content slice of q and k from the latents. These are the matrices that MLA absorbs into one effective projection at inference time, since nothing position-dependent sits between them.
  • W_QR and W_KR (RoPE-side projections): W^QR makes the rotating query slice from the query latent. W^KR makes the SHARED rotating key slice from the hidden state: only one key per token, regardless of how many heads the model has.
  • m, n = 2, 1 (pick two positions): m = 2 attends to n = 1. The RoPE term only cares about the relative position m − n = 1.
  • q_c (content query q_m^c): [1.0, 0.0]. No rotation is applied here; this is what makes the absorption W^UQ^T · W^UK valid at inference time.
  • k_c (content key k_n^c): [1.0, 0.0]. Built from the cached latent c_n^KV; at inference we never store or build this full vector, because the absorbed projection scores directly against c_n^KV.
  • content_score (the absorbable term): dot product = 1.0. This is exactly the term that benefits from MLA. In production it is computed as c_m^Q^T · (W^UQ^T · W^UK) · c_n^KV, with the parenthesised matrix precomputed once.
  • q_R_pre (pre-rotation RoPE query): [0.0, 1.0]. No position information yet.
  • k_R_pre (pre-rotation RoPE key, shared): [1.0, 0.0]. The single shared key slice: at inference this is what the RoPE path caches, just d_h^R scalars per token.
  • q_R = rope_pair(q_R_pre, m): injects position m into the query side, once per query token. On the key side, caching k_R after rotation means each key is rotated exactly once; caching it before rotation means re-rotating at attention time (frameworks differ, see §4.7).
  • k_R = rope_pair(k_R_pre, n): the same rotation, but only ONE vector per token regardless of head count. This is the second compression win on top of MLA.
  • rope_score: dot product after rotation, ≈ −0.8415 in our example. Because R_m^T R_n only depends on m − n, the RoPE score is a function of relative position, just like in vanilla RoPE attention.
  • score (final attention score): the sum of the two terms. In real code you would then divide by √(d_h^c + d_h^R), apply the causal mask, and softmax; the rest of attention is unchanged.
```python
import math

def rope_pair(v, pos, theta=1.0):
    """Rotate a single 2-D vector by angle pos * theta."""
    a = pos * theta
    cos_a, sin_a = math.cos(a), math.sin(a)
    return [v[0] * cos_a - v[1] * sin_a,
            v[0] * sin_a + v[1] * cos_a]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matvec(W, x):
    return [dot(row, x) for row in W]

# --- Inputs (1 head, d_h^c = 2, d_h^R = 2) ----------------------
c_m_Q   = [1.0, 0.0]            # MLA query latent  at m
c_n_KV  = [0.5, 0.5]            # MLA KV latent     at n
h_n     = [1.0, 0.0]            # raw hidden state  at n

W_UQ = [[1.0, 0.0], [0.0,  1.0]]
W_UK = [[1.0, 1.0], [1.0, -1.0]]
W_QR = [[0.0, 1.0], [1.0,  0.0]]
W_KR = [[1.0, 0.0], [0.0,  1.0]]

m, n = 2, 1

# --- Content path (MLA-absorbable, NO rotation) -----------------
q_c = matvec(W_UQ, c_m_Q)       # [1.0, 0.0]
k_c = matvec(W_UK, c_n_KV)      # [1.0, 0.0]
content_score = dot(q_c, k_c)   # 1.0

# --- RoPE path (small, rotated, shared key) ---------------------
q_R_pre = matvec(W_QR, c_m_Q)   # [0.0, 1.0]
k_R_pre = matvec(W_KR, h_n)     # [1.0, 0.0]
q_R = rope_pair(q_R_pre, m)     # rotated by m
k_R = rope_pair(k_R_pre, n)     # rotated by n  (shared across heads)
rope_score = dot(q_R, k_R)

# --- Final attention score (still needs / sqrt(d_h) in real code) ---
score = content_score + rope_score
print(f"content={content_score:.4f}  rope={rope_score:.4f}  total={score:.4f}")
```

Run it and you should see roughly content=1.0000 rope=-0.8415 total=0.1585 — within rounding, matching the hand calculation from §4.4. That is the entire mechanism. Everything in PyTorch is just this, batched and vectorised.
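To see the head sharing at work, the same toy can be extended to two heads. In this sketch head 0 reuses the single-head weights above, while head 1's weights are arbitrary made-up values. The cache entry for token $n$ is just $(c_n^{KV}, k_n^R)$, and both heads score against the same rotated $k_n^R$ with their own queries.

```python
import math

def rope_pair(v, pos, theta=1.0):
    a = pos * theta
    c, s = math.cos(a), math.sin(a)
    return [v[0] * c - v[1] * s, v[0] * s + v[1] * c]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matvec(W, x):
    return [dot(row, x) for row in W]

m, n = 2, 1
c_m_Q, c_n_KV, h_n = [1.0, 0.0], [0.5, 0.5], [1.0, 0.0]

# Per-head weights: head 0 reuses the single-head example, head 1 is made up.
W_UQ = [[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]]
W_UK = [[[1.0, 1.0], [1.0, -1.0]], [[2.0, 0.0], [0.0, 2.0]]]
W_QR = [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]]]
W_KR = [[1.0, 0.0], [0.0, 1.0]]        # ONE matrix, no head index

# The cache entry for token n: the latent plus ONE rotated RoPE key.
cache_n = (c_n_KV, rope_pair(matvec(W_KR, h_n), n))
cached = len(cache_n[0]) + len(cache_n[1])
print("cached scalars per token:", cached)   # 4 = d_c + d_h^R

totals = []
for i in range(2):
    content = dot(matvec(W_UQ[i], c_m_Q), matvec(W_UK[i], cache_n[0]))
    rope = dot(rope_pair(matvec(W_QR[i], c_m_Q), m), cache_n[1])   # same k_R for both heads
    totals.append(content + rope)
    print(f"head {i}: content={content:.4f}  rope={rope:.4f}  total={content + rope:.4f}")
```

Adding heads adds per-head weights, never per-head cache entries.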


4.7 PyTorch Implementation

Now the production shape: batches, heads, sequence length, and the full RoPE applied to every consecutive pair of dimensions. We will write the forward of one MLA attention layer with Decoupled RoPE. It is a simplified version of the DeepSeek-V2 attention in transformers, which additionally compresses the query through its own latent and normalises the latents.

MLA with Decoupled RoPE (single layer, training-time forward), mla_decoupled_rope.py. Notes on each part of the listing that follows:

  • __init__ (layer construction): we pass model width d_model, number of heads, the MLA latent dim d_c, the per-head NoPE width d_h_c, and the per-head RoPE width d_h_R. In DeepSeek-V2: d_c=512, d_h_c=128, d_h_R=64, n_heads=128.
  • self.d_h = d_h_c + d_h_R: each head's effective query/key dimension is the SUM of the two slices. Scaling scores by sqrt(d_h) keeps softmax variance under control.
  • W_DKV (MLA latent down-projection): h → c^KV. This is the only thing we cache for the content path at inference time. Shape per token: d_c (e.g. 512), versus 2 · n_h · d_h_c = 32,768 for plain MHA.
  • W_UK (content key up-projection): c^KV → per-head k^c. At inference we don't actually materialise this; we fold W_UK into the score via h_m^T · (W_UQ^T · W_UK) · c_n^KV. Training keeps the explicit version because, with all positions processed in parallel, materialising k^c once per layer is cheaper than pushing every head's query through the wider d_c-dim absorbed path.
  • W_QR (per-head RoPE query): n_heads × d_h_R values per token. Each head still has its own rotated query, so diversity in WHERE each head looks is preserved.
  • W_KR (SHARED RoPE key): crucially NOT multiplied by n_heads; it produces just d_h_R values per token, regardless of head count. This is the second cache compression on top of MLA itself.
  • inv_freq (inverse-frequency buffer): precomputed inverse frequencies for the RoPE rotation. Registered as a buffer so it moves with the model to GPU and survives state_dict round-trips, but it is NOT a learnable parameter.
  • _build_freqs (frequency builder): standard RoPE frequencies, theta_k = base^(−2k/d) for k = 0, 1, ..., d/2 − 1. base = 10,000 is the original RoPE choice; long-context models often rescale it (e.g. NTK-aware scaling or YaRN).
  • _rope (RoPE rotation): implements R_p · x for every position p in the sequence in one vectorised shot. We build per-position cos/sin tables and apply the classic [x0·cos − x1·sin, x0·sin + x1·cos] pairing across consecutive dimensions.
  • repeat_interleave(2, -1): we need [cos_0, cos_0, cos_1, cos_1, ...] so each consecutive pair (x0, x1) rotates with the same angle. Applying it to the half-sized frequency table gives exactly that.
  • x_rot (pair-swapped vector): constructs [−x1, x0, −x3, x2, ...]. Combined with the cos/sin tables on the next line, this evaluates the 2x2 rotation on every adjacent pair simultaneously.
  • forward(h, positions): h has shape [B, L, d_model]. positions is a 1-D tensor [L] of integer positions (0, 1, 2, ...), supplied by the model wrapper, typically arange(L) plus any KV-cache offset at inference.
  • c_KV = self.W_DKV(h): one projection per token. At inference this is what gets pushed into the KV cache for the content path: d_c floats per token.
  • q_c (content query split into heads): the standard reshape-transpose dance [B, L, H·d_h_c] → [B, L, H, d_h_c] → [B, H, L, d_h_c], so per-head matmuls run as one batched call.
  • v (value also rebuilt from c_KV): V shares the same latent as K; that is the 'joint compression' in MLA. The cache stores ONE c_KV and reconstructs both K and V on demand.
  • q_R (per-head RoPE query): shape [B, H, L, d_h_R]. Each head has its own RoPE query slice, preserving head-level position diversity.
  • k_R (shared RoPE key): the view uses 1 in the head dimension, not n_heads. Broadcasting in the matmul replicates this single key across all heads, with zero extra memory.
  • self._rope(q_R, positions): applies R_m position by position; afterwards q_R encodes 'query at position m for head h.'
  • self._rope(k_R, positions): applies R_n to the shared key. Many implementations store k_R AFTER rotation in the KV cache, so the same key is never re-rotated during incremental decoding.
  • scores (content matmul + RoPE matmul): two batched matmuls, summed. PyTorch broadcasts the size-1 head dim of the RoPE term to all H heads, so every head sees the same RoPE key, paired with its own RoPE query. This is the whole math from §4.3 in one line.
  • scores / sqrt(d_h): we scale by the FULL head dimension d_h = d_h_c + d_h_R, because the score is the dot product of full-length head vectors. Dividing by sqrt(d_h_c) alone would make the softmax slightly too sharp.
  • masked_fill (causal mask): strictly upper-triangular −inf, so each token only attends to itself and earlier tokens, as in any decoder-only model.
  • out and o_proj (standard attention output): attn @ v, then merge heads. v is per-head content of size d_h_c, so o_proj maps from H·d_h_c back to d_model.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoupledRoPEAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_c, d_h_c, d_h_R, max_pos=8192):
        super().__init__()
        self.n_heads = n_heads
        self.d_h_c   = d_h_c
        self.d_h_R   = d_h_R
        self.d_h     = d_h_c + d_h_R

        # MLA down/up projections
        self.W_DKV = nn.Linear(d_model, d_c, bias=False)             # h → c^KV
        self.W_UQ  = nn.Linear(d_model, n_heads * d_h_c, bias=False) # h → Q^c
        self.W_UK  = nn.Linear(d_c,     n_heads * d_h_c, bias=False) # c → K^c
        self.W_UV  = nn.Linear(d_c,     n_heads * d_h_c, bias=False) # c → V

        # Decoupled RoPE projections
        self.W_QR  = nn.Linear(d_model, n_heads * d_h_R, bias=False) # h → Q^R (per head)
        self.W_KR  = nn.Linear(d_model,           d_h_R, bias=False) # h → k^R (SHARED)

        self.o_proj = nn.Linear(n_heads * d_h_c, d_model, bias=False)
        self.register_buffer("inv_freq", self._build_freqs(d_h_R, max_pos))

    def _build_freqs(self, d, max_pos, base=10000.0):
        # Standard RoPE inverse frequencies, one per dimension pair (max_pos is unused here)
        i = torch.arange(0, d, 2, dtype=torch.float32)
        return 1.0 / (base ** (i / d))

    def _rope(self, x, positions):
        # x: [..., L, d_h_R], positions: [L]
        freqs = positions[:, None].float() * self.inv_freq          # [L, d/2]
        # Keep the tables at [L, d] so they broadcast over [..., L, d];
        # an extra [L, 1, d] axis would misalign L with the head dimension.
        cos = freqs.cos().repeat_interleave(2, dim=-1)              # [L, d]
        sin = freqs.sin().repeat_interleave(2, dim=-1)              # [L, d]
        # Pair-swapped vector [-x1, x0, -x3, x2, ...]
        x_rot = torch.stack((-x[..., 1::2], x[..., 0::2]), dim=-1).flatten(-2)
        return x * cos + x_rot * sin

    def forward(self, h, positions):
        B, L, _ = h.shape
        H, dc, dR = self.n_heads, self.d_h_c, self.d_h_R

        # --- Content path (no rotation) ---
        c_KV = self.W_DKV(h)                                      # [B, L, d_c]
        q_c  = self.W_UQ(h).view(B, L, H, dc).transpose(1, 2)     # [B, H, L, d_h_c]
        k_c  = self.W_UK(c_KV).view(B, L, H, dc).transpose(1, 2)  # [B, H, L, d_h_c]
        v    = self.W_UV(c_KV).view(B, L, H, dc).transpose(1, 2)  # [B, H, L, d_h_c]

        # --- RoPE path (rotated, k^R shared across heads) ---
        q_R = self.W_QR(h).view(B, L, H, dR).transpose(1, 2)      # [B, H, L, d_h_R]
        k_R = self.W_KR(h).view(B, L, 1, dR).transpose(1, 2)      # [B, 1, L, d_h_R]
        q_R = self._rope(q_R, positions)
        k_R = self._rope(k_R, positions)

        # --- Score = content + rope, broadcasting k_R across heads ---
        scores  = (q_c @ k_c.transpose(-1, -2)) + (q_R @ k_R.transpose(-1, -2))
        scores  = scores / (self.d_h ** 0.5)
        scores  = scores.masked_fill(
            torch.triu(torch.ones(L, L, device=h.device), 1).bool(),
            float('-inf'),
        )
        attn = F.softmax(scores, dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, L, H * dc)
        return self.o_proj(out)
```

What changes at inference time

The training-time forward above keeps things explicit. At inference, two efficiencies kick in:

  1. Content path uses absorbed projection. Instead of computing k_c = W_UK · c_n^KV per token and per head, we fold W_UQ^T · W_UK into a single matrix per head and score the query directly against c_n^KV. The per-head $k^c$ is never materialised.
  2. KV cache stores $(c_n^{KV}, k_n^R)$ only. That is $d_c + d_h^R$ scalars per token per layer, with no per-head factor. For DeepSeek-V2 with $d_c = 512$, $d_h^R = 64$: 576 floats per token. For MHA with the same head budget: 32,768.

The forward code stays clean during training because the absorbed-matrix trick is purely an inference optimisation: it is mathematically equivalent to the explicit computation (up to floating-point rounding), so it has no effect on gradients or learning dynamics.
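Both inference-time efficiencies can be sketched end to end in plain Python. This is a minimal illustration under toy assumptions (random weights, made-up sizes, one layer, no scaling, masking or softmax), mirroring the PyTorch layer's layout: W_UQ reads $h$ directly and RoPE uses the base^(−2k/d) frequencies. Per head it precomputes the absorbed matrix $W_i^{UK\top} W_i^{UQ}$ once, stores only $(c_t^{KV}, \text{rotated } k_t^R)$ per token, and checks the decode-time scores against the explicit path that materialises $k^c$.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def matvec(W, x):
    return [dot(row, x) for row in W]

def matmul(A, B):
    return [[dot(row, col) for col in zip(*B)] for row in A]

def transpose(A):
    return [list(col) for col in zip(*A)]

def rand(rows, cols):
    return [[random.gauss(0, 1) for _ in range(cols)] for _ in range(rows)]

def rope(v, pos, base=10000.0):
    """Rotate each consecutive pair of v by pos * base^(-k/d), k = 0, 2, 4, ..."""
    d, out = len(v), []
    for k in range(0, d, 2):
        a = pos * base ** (-k / d)
        c, s = math.cos(a), math.sin(a)
        out += [v[k] * c - v[k + 1] * s, v[k] * s + v[k + 1] * c]
    return out

random.seed(0)
d_model, d_c, H, d_hc, d_hR, L = 8, 4, 2, 3, 2, 5
W_DKV, W_KR = rand(d_c, d_model), rand(d_hR, d_model)     # shared across heads
W_UQ = [rand(d_hc, d_model) for _ in range(H)]            # per-head content query
W_UK = [rand(d_hc, d_c) for _ in range(H)]                # per-head content key
W_QR = [rand(d_hR, d_model) for _ in range(H)]            # per-head RoPE query
hs = rand(L, d_model)                                     # hidden states at positions 0..L-1

# Load time: absorb W_UK_i^T W_UQ_i into one d_c x d_model matrix per head.
W_abs = [matmul(transpose(W_UK[i]), W_UQ[i]) for i in range(H)]

# Prefill: the cache holds ONLY (c_KV, rotated k_R) per token.
cache = [(matvec(W_DKV, h), rope(matvec(W_KR, h), t)) for t, h in enumerate(hs)]

# Decode step: the last token attends to every cached token.
m, h_m = L - 1, hs[-1]
max_err = 0.0
for i in range(H):
    q_abs = matvec(W_abs[i], h_m)                  # d_c-dim "absorbed" query
    q_R = rope(matvec(W_QR[i], h_m), m)
    for c_t, k_R_t in cache:
        absorbed = dot(q_abs, c_t) + dot(q_R, k_R_t)
        explicit = dot(matvec(W_UQ[i], h_m), matvec(W_UK[i], c_t)) + dot(q_R, k_R_t)
        max_err = max(max_err, abs(absorbed - explicit))

print("cached scalars per token:", d_c + d_hR)
print("max |absorbed - explicit| =", max_err)
```

The same reasoning lets the value up-projection be folded into the output projection, so neither $k^c$ nor $v$ has to be materialised per head during decoding.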


4.8 Connection to Massive Model Training

Where the bottleneck moves

At the scale where MLA matters, say a 100B+ parameter model serving at 128k context to thousands of concurrent users, the KV cache, not the weights, is the constraint. A 671B DeepSeek-V3 model has its weights sharded across many GPUs, but each active sequence carries its own KV cache that scales linearly with sequence length and, apart from shared prompt prefixes, is not shared between users.

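Concretely, for an MHA model with $n_h = 128$ heads, $d_h = 128$, 64 layers, at 128k tokens and FP16 (taking 128k = 128,000 tokens and 1 GB = 10^9 bytes; the last column ignores weights and activations):

| Architecture | Cache per token per layer (scalars) | Cache per 128k-token sequence, 64 layers | Sequences per 80 GB GPU |
|---|---|---|---|
| MHA | 32,768 | ≈ 537 GB (≈ 8.4 GB per layer) | 0 (does not fit) |
| GQA (8 groups) | 2,048 | ≈ 33.6 GB | ~2 |
| MLA + Decoupled RoPE | 576 | ≈ 9.4 GB | ~8 |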

Those last two columns are why Decoupled RoPE exists. Roughly a 4× gain in concurrent long-context sequences over GQA (and a workable cache where MHA does not fit at all) is not a micro-optimisation; it is the difference between a viable product and an unaffordable one.
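The table's numbers follow from simple arithmetic. A small sketch, assuming FP16 (2 bytes per scalar), 64 layers, 128,000 tokens, decimal gigabytes, and an 80 GB budget that ignores weights and activations; `kv_cache_gb` is a helper name introduced here:

```python
def kv_cache_gb(scalars_per_token, layers=64, tokens=128_000, bytes_per_scalar=2):
    """KV-cache size of one sequence in decimal GB (1 GB = 1e9 bytes)."""
    return scalars_per_token * layers * tokens * bytes_per_scalar / 1e9

schemes = {
    "MHA": 2 * 128 * 128,            # K and V, 128 heads of width 128
    "GQA (8 groups)": 2 * 8 * 128,   # K and V, 8 shared groups
    "MLA + Decoupled RoPE": 512 + 64,
}
for name, scalars in schemes.items():
    gb = kv_cache_gb(scalars)
    print(f"{name:22s} {scalars:6d} scalars/token  {gb:7.1f} GB/seq  {int(80 // gb)} seqs/80GB")
```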

How training scales with the trick

  • Memory: Training-time activations grow with the full K and V — the savings only kick in at inference. During training MLA+RoPE is roughly as expensive as MHA, sometimes slightly more because of the extra projections.
  • Throughput: Because the score is now two matmuls per layer instead of one, training FLOPs go up by ~5 to 8% versus MHA at equal width. The team accepts this cost because training is paid once, while the serving-time savings recur on every request.
  • Stability: Splitting the head means the absorbed content projection sees lower-rank gradients than vanilla MHA. In practice this is fine — DeepSeek-V2 trained 0.1B → 236B with no unusual instability — but it does mean LR warmup and AdamW betas are retuned. The chapter on optimization will revisit this.
  • Communication: Under tensor parallelism the per-head shards still split cleanly along $n_h$. The shared $k^R$ is replicated on every TP rank, which is a tiny cost (it is at most $d_h^R$ scalars per token).

Why every modern MLA model uses this

DeepSeek-V2, V3, R1 and the Yi-Lightning variants all ship Decoupled RoPE. The choice is not a research curiosity: it is the standard way to combine the strongest serving-time architecture (MLA) with the dominant position scheme (RoPE). DeepSeek-V2 sets $d_h^R = d_h / 2 = 64$ (with $d_h = 128$), and DeepSeek-V3 keeps $d_h^R = 64$.


4.9 Engineering Reality and Pitfalls

  • RoPE base for long context. A base of 10,000 works within the context lengths a model was trained on but extrapolates poorly far beyond them, so for 128k context the RoPE frequencies are typically rescaled (YaRN, NTK-aware scaling). DeepSeek-V2 applies YaRN to the decoupled shared key $k^R$, since that is the part carrying RoPE. The decoupling itself is orthogonal to this; it works with any RoPE variant.
  • Cache $k^R$ before or after rotation? Both work. DeepSeek caches after rotation so the rotation runs once per key token and never repeats during streaming decode. The cost: if the RoPE base or scaling is changed after keys were cached (which it shouldn't be), the cached keys are stale.
  • Numerical precision. The RoPE term is small (only $d_h^R$ dims), and computing it in FP16 is fine. The absorbed content matmul, whose inner dimension is the full $d_c$, should accumulate in FP32, the same rule as standard mixed-precision attention.
  • Flash-Attention compatibility. Flash-Attention does not natively know about the dual-score MLA layout. Two common approaches: (a) concatenate $q_c$ with $q_R$, and $k_c$ with a per-head copy of the shared $k_R$, so the kernel sees ordinary heads of width $d_h$ (the value head may need zero-padding to that width when the kernel requires equal head dims); (b) use a kernel written for the MLA layout, such as DeepSeek's FlashMLA for decoding. Calling a fused kernel twice and summing its outputs does not work, because the softmax is applied inside the kernel and the raw scores are never exposed. Pattern (a) is simpler and is the common drop-in choice.
  • Shared $k^R$ is not MQA. It looks like MQA because there is one RoPE key vector per token, but only the $d_h^R$-wide position slice is shared: the content keys are still reconstructed per head from the latent, so heads keep distinct keys. Calling it MQA over-states the sharing.
  • Don't fold RoPE into the content latent. A tempting "optimisation" is to push the rotation back into $c^{KV}$. It breaks absorption all over again, and we are back to square one. The decoupling must stay decoupled.
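Pattern (a) from the Flash-Attention item works because a dot product over concatenated vectors equals the sum of the dot products over the pieces. A minimal plain-Python check with random toy vectors (sizes are arbitrary): concatenating each head's $[q_c; q_R]$ and $[k_c; k_R]$, with the same shared $k_R$ copied into every head, reproduces the two-term score up to rounding.

```python
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

random.seed(1)
d_hc, d_hR, H = 4, 2, 3
k_R = [random.gauss(0, 1) for _ in range(d_hR)]      # ONE shared rotated RoPE key

diffs = []
for i in range(H):
    q_c = [random.gauss(0, 1) for _ in range(d_hc)]  # per-head content query
    q_R = [random.gauss(0, 1) for _ in range(d_hR)]  # per-head rotated RoPE query
    k_c = [random.gauss(0, 1) for _ in range(d_hc)]  # per-head content key
    split = dot(q_c, k_c) + dot(q_R, k_R)            # two score terms, summed
    fused = dot(q_c + q_R, k_c + k_R)                # list concat: one head of width d_hc + d_hR
    diffs.append(abs(split - fused))

print(max(diffs))   # floating-point rounding only
```

Because the identity holds before the softmax, the concatenated heads can use the same $1/\sqrt{d_h}$ scale as the split form.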

Summary

  • MLA's cache compression depends on $W^{Q\top} W^{UK}$ being a single position-independent matrix. Vanilla RoPE puts $R_m^\top R_n$ between them and breaks absorption.
  • The fix is structural: cut each head into a NoPE content slice (handled by MLA) and a small RoPE position slice. The two scores add.
  • The RoPE key $k^R$ is shared across all heads, so the KV cache costs only $d_c + d_h^R$ scalars per token per layer, independent of head count.
  • The mechanism is exactly the plain-Python and PyTorch code we walked through. There is no hidden complexity: just a head split, two matmuls, and a sum.
  • In production this is the standard way to get MLA-grade cache compression while keeping RoPE's position quality. Every deployed MLA model uses this exact trick.

With this in place, the next section can finally compare MLA against GQA and MHA on equal footing — knowing that the RoPE compatibility problem is fully solved.
