Chapter 7
Section 38 of 117

Sequential Causal MTP: DeepSeek's Implementation

Multi-Token Prediction (MTP)

Section 7.2 ended on a sharp note. Meta's naive parallel MTP saved compute by predicting all $D$ future tokens at once from a single shared trunk, and paid for it by silently breaking the causal chain that makes autoregressive language modeling work. Predictions for distant horizons drifted, the heads stopped agreeing with each other, and downstream metrics fell behind a single-token baseline at large model sizes. The lesson was unambiguous: if you want extra prediction heads, they must respect causality. DeepSeek's sequential MTP is the answer that fell out of taking that lesson seriously.

The promise of sequential causal MTP. Each depth is its own small transformer block. Depth $k$ at position $i$ reads the previous depth's hidden state at the same position, mixes it with the embedding of the token at position $i + k$, and predicts the token at position $i + k + 1$. The causal chain survives intact. The signal stays sharp. The trick is the architecture, not a loss-function patch.

Why DeepSeek Went Sequential

Recall the parallel design from section 7.2: $D$ independent output heads sit on top of the same final hidden state $h_i$, and head $k$ tries to predict the token at position $i + k$. The trunk gets a richer gradient signal (every position contributes $D$ cross-entropy terms instead of one), but the heads cannot talk to each other. Head 2 has no idea what head 1 just predicted. Each head therefore learns a marginal of the joint distribution rather than one factor of a chain-rule factorization. Empirically, that gap kills the win.

DeepSeek reasoned backwards from what the chain rule actually wants. The true joint over future tokens is

$$P(t_{i+1}, t_{i+2}, \dots, t_{i+D} \mid t_{\le i}) = \prod_{k=1}^{D} P(t_{i+k} \mid t_{\le i + k - 1})$$

Notice what the conditioning set looks like. To predict $t_{i+k}$ properly, the model needs everything up to and including $t_{i+k-1}$, which includes the intervening future tokens (at training time, the ground-truth ones). Parallel MTP throws this away: it conditions every head on $t_{\le i}$ alone, so for $k > 1$ head $k$ learns the marginal $P(t_{i+k} \mid t_{\le i})$ instead of the chain-rule factor $P(t_{i+k} \mid t_{\le i+k-1})$. Sequential MTP gives every depth the conditioning set it needs.
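A toy numeric check makes the difference concrete. The sketch below uses a hypothetical joint over the next two tokens (vocabulary {0, 1}, numbers invented for illustration) in which the two tokens are strongly correlated. The chain-rule product reproduces the joint exactly; the product of per-head marginals, which is what independent parallel heads amount to when you sample from them together, spreads the mass evenly and loses the correlation.

```python
import itertools

# Hypothetical joint P(t_{i+1}, t_{i+2} | t_{<=i}) over a 2-token vocabulary.
# The two future tokens usually agree, so they are strongly correlated.
joint = {(0, 0): 0.45, (0, 1): 0.05, (1, 0): 0.05, (1, 1): 0.45}

def p_first(a):
    """P(t_{i+1} = a | t_{<=i})."""
    return sum(p for (x, _), p in joint.items() if x == a)

def p_second_marginal(b):
    """What a parallel head for t_{i+2} can learn: P(t_{i+2} = b | t_{<=i})."""
    return sum(p for (_, y), p in joint.items() if y == b)

def p_second_conditional(b, a):
    """What a sequential depth learns: P(t_{i+2} = b | t_{<=i}, t_{i+1} = a)."""
    return joint[(a, b)] / p_first(a)

for a, b in itertools.product([0, 1], repeat=2):
    chain = p_first(a) * p_second_conditional(b, a)
    parallel = p_first(a) * p_second_marginal(b)
    print(f"pair ({a},{b}): joint={joint[(a, b)]:.2f}  chain={chain:.2f}  parallel={parallel:.2f}")
```

Any joint in which the two tokens are independent makes the two columns agree; the gap appears exactly when future tokens depend on each other, which is the normal case in language.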

| Property | Parallel MTP (Meta) | Sequential MTP (DeepSeek) |
|---|---|---|
| Conditioning at depth $k$ | $t_{\le i}$ (same for every $k$) | $t_{\le i+k-1}$ (proper chain rule) |
| Causal chain across depths | Broken: heads independent | Preserved: depth $k$ reads depth $k-1$ |
| Heads see each other's outputs | No | Yes (via the hidden state $h_i^{k-1}$) |
| Extra params per depth | 1 output head | 1 transformer block + 1 projection |
| Marginal cost | Almost free | ~$D\times$ the per-token forward of one MTP module |
| Quality at scale | Drops vs. 1-head baseline (>3B) | Improves vs. 1-head baseline |
| Speculative decoding ready | Weakly | Directly (see section 7.5) |

The architectural insight. A parallel head asks the same trunk to predict farther into the future. A sequential module asks a small, dedicated network to extend the trunk's prediction by one more step, reusing the trunk's answer as input. The first treats the future as a set of independent samples. The second treats it as a causal continuation — exactly how language actually unfolds.

Anatomy of a Single MTP Module

An MTP module is the smallest unit you can imagine that still does useful autoregressive work. It is, almost literally, "one transformer block plus a few wires." Let us list every part:

  1. Two inputs per position. The previous depth's hidden state $h_i^{k-1}$ and the embedding of the token at position $i + k$, which we write $\mathrm{Embed}(t_{i+k})$. The hidden state is the past; the embedding is the next-known future token.
  2. RMSNorm on each input. Two independent RMSNorms — one per stream — strip away differing magnitudes so the downstream projection sees inputs on a comparable scale.
  3. Concatenation, then projection. The two normalized vectors are concatenated along the feature axis (giving a length-$2d$ vector) and then projected back down to $d$ by a learned matrix $M_k$. This is where past and future meet.
  4. One transformer block. A single transformer block (attention + MLP + norms) is applied to the projected stream. It uses the same causal mask as the main model, so position $i$ still cannot see positions $j > i$ within the block.
  5. Shared output head. The block's output is run through the main model's output head (the linear layer that maps $d$-dim hidden vectors to $V$-dim logits), producing a vocabulary distribution.

What is shared, what is per-depth

The sharing story is what makes the design cheap. Two pieces are shared with the main model across every depth:

  • The embedding table $\mathrm{Embed} \in \mathbb{R}^{V \times d}$, the same one the main model uses on its inputs.
  • The output head $W_{\text{out}} \in \mathbb{R}^{V \times d}$, which many LLM setups tie to the embedding (weight tying); the MTP module reuses the main model's head whether or not it is tied.

Three pieces are owned per depth:

  • The two RMSNorms ($\sim 2d$ parameters).
  • The projection $M_k \in \mathbb{R}^{d \times 2d}$ ($2d^2$ parameters).
  • One transformer block ($\sim 12d^2$ parameters at standard FFN expansion). This is the bulk of the per-depth cost.
How small is a single MTP module? At DeepSeek-V3 scale ($d = 7168$, 61 transformer blocks in the main model), one MTP module is roughly $1/61$ of the main model in parameters, about 1.6%. The shared embedding and output head are huge, but they cost nothing extra because they are already in the main model.
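The per-depth budget listed above can be tallied directly. A minimal sketch, assuming a dense transformer block with $4d^2$ attention weights and a 4× FFN (the $\sim 12d^2$ figure above), no biases, and ignoring the block's own norms. DeepSeek-V3's blocks are mixture-of-experts layers with far more total parameters, so treat the dense number as an illustration of the proportions, not of V3's absolute size.

```python
def mtp_params_per_depth(d: int, ffn_mult: int = 4) -> dict[str, int]:
    """Per-depth parameter count for one MTP module with a dense block.

    Assumes 4*d^2 attention weights (Q, K, V, O) plus an MLP of 2*ffn_mult*d^2,
    no biases, and ignores the block's internal norms. The shared embedding and
    output head are excluded: the main model already owns them.
    """
    norms = 2 * d                                  # two RMSNorm gain vectors
    proj = 2 * d * d                               # M_k maps 2d -> d, no bias
    block = 4 * d * d + 2 * ffn_mult * d * d       # ~12 d^2 when ffn_mult = 4
    return {"norms": norms, "proj": proj, "block": block,
            "total": norms + proj + block}

counts = mtp_params_per_depth(7168)
for name, n in counts.items():
    print(f"{name:>6}: {n / 1e6:9.1f} M")
print(f"projection / block = {counts['proj'] / counts['block']:.3f}")
```

The projection comes out at a sixth of a dense block, so the block really is the bulk of the per-depth cost.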

The Math: A Causal Chain Through Depths

Fix a position $i$ in the sequence. Let $h_i^0$ denote the main model's final hidden state at position $i$; this is the starting point for all MTP depths and is, by construction, a function only of $t_{\le i}$. For depths $k = 1, 2, \dots, D$, define the MTP module recursively:

$$h_i^k = \mathrm{TRM}_k\Big(\, M_k \, \mathrm{concat}\big[\mathrm{RMSNorm}(h_i^{k-1}),\; \mathrm{RMSNorm}(\mathrm{Embed}(t_{i+k}))\big]\,\Big)$$

Every symbol: $\mathrm{TRM}_k$ is a single per-depth transformer block; $M_k \in \mathbb{R}^{d \times 2d}$ is the per-depth projection; $\mathrm{RMSNorm}$ is the root-mean-square norm with two independent learnable scales (one per stream); $\mathrm{Embed}$ is the shared embedding table; and $t_{i+k}$ is the actual ground-truth token at position $i + k$ in the training sequence (known at training time because the whole sequence is given).

The depth-$k$ prediction at position $i$ is then

$$p^k_{i+k+1} = \mathrm{softmax}\!\big(W_{\text{out}} \, h_i^k\big) \;\in\; \Delta^{V-1}$$

This is a distribution over the vocabulary, predicting the token at position $i + k + 1$. $W_{\text{out}}$ is the shared output head. The cross-entropy loss at this depth compares $p^k_{i+k+1}$ to the one-hot of $t_{i+k+1}$; we will write the full training objective in section 7.4.
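The recursion translates almost line for line into code. Below is a NumPy sketch for a single position $i$ that runs depths $k = 1, \dots, D$. As an assumption for brevity, each $\mathrm{TRM}_k$ is a position-wise tanh stand-in rather than a real transformer block, so it ignores attention to other positions; all weights are random.

```python
import numpy as np

def rmsnorm(x, eps=1e-6):
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def mtp_chain(h0_i, tokens, i, Embed, W_out, depth_params):
    """Run depths k = 1..D at one position i.

    depth_params[k - 1] = (M_k, trm_k): M_k has shape (d, 2d) and trm_k is a
    callable R^d -> R^d standing in for the per-depth transformer block.
    Returns [(k, p^k_{i+k+1}), ...], one vocabulary distribution per depth.
    """
    h_prev, out = h0_i, []
    for k, (M_k, trm_k) in enumerate(depth_params, start=1):
        e = Embed[tokens[i + k]]                                  # Embed(t_{i+k})
        x = M_k @ np.concatenate([rmsnorm(h_prev), rmsnorm(e)])   # mix past + future
        h_prev = trm_k(x)                                         # h_i^k
        out.append((k, softmax(W_out @ h_prev)))                  # predicts t_{i+k+1}
    return out

rng = np.random.default_rng(0)
V, d, D = 10, 4, 2
Embed = rng.normal(size=(V, d))
W_out = rng.normal(size=(V, d))
params = [(rng.normal(size=(d, 2 * d)) * 0.1, np.tanh) for _ in range(D)]
tokens = np.array([1, 5, 2, 7, 3, 9])
h0_i = rng.normal(size=d)

for k, p in mtp_chain(h0_i, tokens, i=1, Embed=Embed, W_out=W_out, depth_params=params):
    print(f"depth {k}: distribution over t_{1 + k + 1}, argmax = {int(p.argmax())}")
```

Because h_prev is overwritten on every iteration, depth $k$ can only see what depth $k-1$ produced; that single reassignment is the chain-rule dependency.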

Why the recursion is the whole story

The single equation above hides three independent design choices, each of which corrects one specific failure of parallel MTP:

  1. The recursion through $h_i^{k-1}$. This is the chain rule made concrete. Depth $k$ conditions on everything depth $k - 1$ knew, plus the new token $t_{i+k}$. Parallel MTP had every depth read $h_i^0$ directly: no chain.
  2. The embedding of the future token. At training time we know $t_{i+k}$, so we feed it in. Without it, depth $k$ would have no information beyond what depth $k - 1$ already used, and the module would be a redundant copy of its predecessor.
  3. The independent transformer block per depth. Each depth gets its own parameters because each depth solves a slightly different problem — depth 1 maps "past + next token" → "next-next token", depth 2 maps "past + next-next token" → "next-next-next token". The tasks are related but not identical.
Where does the gradient go? Backprop through the recursion is exactly what you would write down. The loss at depth $k$ creates gradient on $h_i^k$, which propagates through $\mathrm{TRM}_k$, $M_k$, and the two RMSNorms; then a piece of it flows back into $h_i^{k-1}$, from there into the MTP module of depth $k - 1$, and ultimately into the main model. So the deeper MTP losses also reshape the main model's representations: a free regularizer on the trunk.
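A small autograd experiment confirms the gradient path. In this sketch (stand-ins, not the real architecture) the trunk output $h_i^0$ is a leaf tensor, each depth is an RMSNorm, concat, Linear, tanh step, and only the depth-2 loss is backpropagated. Gradient still reaches the depth-1 projection and the trunk.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d, V = 8, 16

def rmsnorm(x, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

h0 = torch.randn(d, requires_grad=True)        # stand-in for the trunk's h_i^0
embed = nn.Embedding(V, d)                      # shared embedding
head = nn.Linear(d, V, bias=False)              # shared output head
proj = nn.ModuleList([nn.Linear(2 * d, d, bias=False) for _ in range(2)])  # M_1, M_2

t_next, t_next2, target = torch.tensor(3), torch.tensor(5), torch.tensor(7)

# Depth 1 and depth 2 of the recursion, with tanh standing in for TRM_k.
h1 = torch.tanh(proj[0](torch.cat([rmsnorm(h0), rmsnorm(embed(t_next))])))
h2 = torch.tanh(proj[1](torch.cat([rmsnorm(h1), rmsnorm(embed(t_next2))])))

# Backpropagate ONLY the depth-2 loss.
loss2 = F.cross_entropy(head(h2).unsqueeze(0), target.unsqueeze(0))
loss2.backward()

print("grad norm on trunk h0:", h0.grad.norm().item())
print("grad norm on M_1:     ", proj[0].weight.grad.norm().item())
print("grad norm on M_2:     ", proj[1].weight.grad.norm().item())
```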

Manual Numerical Walkthrough

Let us pin down a single depth-1 MTP step end-to-end with tiny numbers. We will use $d = 4$, $V = 6$, and look at the prediction made for position $i + 2 = 4$ from position $i = 2$.

Worked example: depth-1 MTP step at position i = 2, d = 4, V = 6

Setup. Imagine a sequence of six tokens. The main model has already produced its hidden states. We focus on position $i = 2$:

  • h_2^0 = [ 0.50, -0.30, 0.80, -0.10 ] (main model output)
  • Embed(t_3) = [ 0.20, 0.40, 0.10, 0.30 ] (future-token embedding)
  • target token at position i + 2 = 4 is t_4, ID = 3

Step 1 — RMSNorm both inputs. For $x \in \mathbb{R}^d$, $\mathrm{RMSNorm}(x) = x / \sqrt{\tfrac{1}{d}\sum_j x_j^2 + \varepsilon}$ (taking the learnable gain = 1 for clarity).

mean(h_2^0²) = (0.25 + 0.09 + 0.64 + 0.01) / 4 = 0.2475

√0.2475 ≈ 0.4975 → h_norm ≈ [ 1.005, -0.603, 1.608, -0.201 ]

mean(e²) = (0.04 + 0.16 + 0.01 + 0.09) / 4 = 0.075, √0.075 ≈ 0.274

→ e_norm ≈ [ 0.730, 1.461, 0.365, 1.095 ]

Note the two streams now have comparable scale (RMS ≈ 1 each). Before normalization, h had RMS ≈ 0.50 against ≈ 0.27 for e, almost twice as large; without RMSNorm the projection would have been weighted toward h.

Step 2 — concatenate, then project. Concatenation yields a length-8 vector:

cat = [ 1.005, -0.603, 1.608, -0.201, 0.730, 1.461, 0.365, 1.095 ]

Multiply by a toy 4 × 8 projection (a believable random init):

M = [[ 0.10,  0.20, -0.10,  0.05,  0.15, -0.05,  0.10,  0.20],
     [-0.20,  0.10,  0.30, -0.10,  0.05,  0.20, -0.10,  0.10],
     [ 0.15, -0.10,  0.05,  0.20, -0.10,  0.10,  0.30, -0.05],
     [ 0.05,  0.20, -0.10,  0.15,  0.20, -0.10,  0.05,  0.10]]

Row 0 of $M \cdot \text{cat}$:

0.10·1.005 + 0.20·(-0.603) + (-0.10)·1.608 + 0.05·(-0.201)

+ 0.15·0.730 + (-0.05)·1.461 + 0.10·0.365 + 0.20·1.095

≈ 0.1005 - 0.1206 - 0.1608 - 0.0101 + 0.1095 - 0.0731 + 0.0365 + 0.219

≈ 0.101

Computing all four rows similarly:

combined ≈ [ 0.101, 0.643, 0.379, -0.134 ]

Step 3 — one transformer block (toy stand-in). For this walkthrough we collapse the full block into a single tanh non-linearity (the real block has attention and an MLP):

h_2^1 = tanh(combined) ≈ [ 0.101, 0.567, 0.362, -0.133 ]

(tanh barely changes small inputs but visibly compresses larger ones, e.g. 0.643 → 0.567. In the real model attention contextualizes across positions and the MLP introduces a more substantial non-linear transformation.)

Step 4 — shared output head, then softmax. With $W_{\text{out}} \in \mathbb{R}^{6 \times 4}$ (rows are the 6 vocabulary embeddings), say:

W_out = [[ 0.5,  0.1, -0.2,  0.3],   # vocab id 0
         [-0.3,  0.4,  0.2,  0.1],   # vocab id 1
         [ 0.2, -0.1,  0.5, -0.2],   # vocab id 2
         [ 0.4,  0.3,  0.1,  0.6],   # vocab id 3 <- TARGET
         [-0.1,  0.2, -0.3,  0.4],   # vocab id 4
         [ 0.1, -0.4,  0.2, -0.1]]   # vocab id 5

Compute $W_{\text{out}} \cdot h_2^1$:

logit_0 = 0.5·0.101 + 0.1·0.567 + (-0.2)·0.362 + 0.3·(-0.133) ≈ -0.005

logit_1 = -0.3·0.101 + 0.4·0.567 + 0.2·0.362 + 0.1·(-0.133) ≈ 0.256

logit_2 = 0.2·0.101 + (-0.1)·0.567 + 0.5·0.362 + (-0.2)·(-0.133) ≈ 0.171

logit_3 = 0.4·0.101 + 0.3·0.567 + 0.1·0.362 + 0.6·(-0.133) ≈ 0.167

logit_4 = -0.1·0.101 + 0.2·0.567 + (-0.3)·0.362 + 0.4·(-0.133) ≈ -0.058

logit_5 = 0.1·0.101 + (-0.4)·0.567 + 0.2·0.362 + (-0.1)·(-0.133) ≈ -0.131

Subtract the max (0.256, vocab id 1) and exponentiate:

exp ≈ [0.771, 1.000, 0.919, 0.915, 0.731, 0.679]

sum ≈ 5.014

p ≈ [0.154, 0.199, 0.183, 0.183, 0.146, 0.135]

Step 5 — read off the prediction and the loss. The argmax is vocab id 1, not the target id 3. That is expected from untrained toy weights: the target gets p[3] ≈ 0.183, only slightly above the uniform 1/6 ≈ 0.167. The cross-entropy at this position is

L_MTP^1 at i = 2 = -log(p[3]) = -log(0.183) ≈ 1.70

just under the uniform-guess loss ln 6 ≈ 1.79. Gradient descent on this loss raises logit_3 and lowers the others.

The full MTP-1 loss averages this over every valid position $i = 0, 1, \dots, T - 3$, just like any cross-entropy.

What just happened, conceptually. The hidden state $h_2^0$ carried "everything the main model knew up to position 2." The embedding $\mathrm{Embed}(t_3)$ carried the new information that the next token would be $t_3$. The MTP module fused them, ran them through a small transformer block, and produced a distribution over what comes after $t_3$, that is, over $t_4$. The whole module is one chain-rule step.

If we had stacked a depth-2 module. $h_2^1$ would feed into MTP-2, concatenated with $\mathrm{Embed}(t_4)$, and the depth-2 output would predict $t_5$. The chain extends one step further with one more module. The arithmetic is identical; only the indices shift.
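Hand arithmetic is easy to get wrong, so here is a NumPy sketch that reruns Steps 1 to 5 with exactly the toy inputs of the worked example (gain = 1 and ε = 0, to match the hand calculation) and prints every intermediate value for comparison.

```python
import numpy as np

def rmsnorm(x, eps=0.0):
    # Gain fixed to 1 and eps = 0 to mirror the hand calculation.
    return x / np.sqrt((x ** 2).mean() + eps)

h20 = np.array([0.50, -0.30, 0.80, -0.10])     # h_2^0 from the main model
e3 = np.array([0.20, 0.40, 0.10, 0.30])        # Embed(t_3)
M = np.array([[ 0.10,  0.20, -0.10,  0.05,  0.15, -0.05,  0.10,  0.20],
              [-0.20,  0.10,  0.30, -0.10,  0.05,  0.20, -0.10,  0.10],
              [ 0.15, -0.10,  0.05,  0.20, -0.10,  0.10,  0.30, -0.05],
              [ 0.05,  0.20, -0.10,  0.15,  0.20, -0.10,  0.05,  0.10]])
W_out = np.array([[ 0.5,  0.1, -0.2,  0.3],
                  [-0.3,  0.4,  0.2,  0.1],
                  [ 0.2, -0.1,  0.5, -0.2],
                  [ 0.4,  0.3,  0.1,  0.6],
                  [-0.1,  0.2, -0.3,  0.4],
                  [ 0.1, -0.4,  0.2, -0.1]])
target = 3

cat = np.concatenate([rmsnorm(h20), rmsnorm(e3)])   # Steps 1-2: normalize, concat
combined = M @ cat                                  # Step 2: project 2d -> d
h21 = np.tanh(combined)                             # Step 3: toy block
logits = W_out @ h21                                # Step 4: shared output head
p = np.exp(logits - logits.max())
p /= p.sum()
loss = -np.log(p[target])                           # Step 5: cross-entropy

print("combined:", combined.round(3))
print("h_2^1:   ", h21.round(3))
print("logits:  ", logits.round(3))
print("p:       ", p.round(3))
print(f"argmax = {int(p.argmax())}, loss = {loss:.3f}")
```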

Visualizing the Sequential Forward Pass

The diagram below walks through a six-token sequence with one main transformer and two MTP modules stacked on top, filling in the depths left to right. The key observation: cell $h_i^k$ always reads cell $h_i^{k-1}$ from the row above and the embedding of the token $k$ positions to its right, never any cell to its right at its own depth. Causality holds at every depth.

[Figure: sequential MTP visualizer, six-token sequence with the main model and MTP depths 1 and 2]
What to look for. Notice that MTP-1 covers one fewer position than the main model (its target $t_{i+2}$ runs off the end of the sequence one position earlier than the main model's $t_{i+1}$), and MTP-2 covers two fewer. Each added depth therefore removes one usable position per sequence. At sequence length 4096 that is about 0.02% of positions per depth, which is negligible; at sequence length 128 it is closer to 1% per depth, and the share grows as sequences get shorter.
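The position count is simple enough to tabulate. A sketch, assuming 0-indexed positions $0, \dots, T-1$, with the main model predicting $t_{i+1}$ and depth $k$ predicting $t_{i+k+1}$:

```python
def valid_positions(T: int, k: int) -> int:
    """Number of positions i in [0, T) with an in-sequence target at depth k.

    Depth 0 is the main model (target t_{i+1}); depth k >= 1 targets t_{i+k+1},
    so it needs i + k + 1 <= T - 1.
    """
    return max(T - 1 - k, 0)

for T in (128, 4096):
    main = valid_positions(T, 0)
    for k in (1, 2):
        n = valid_positions(T, k)
        print(f"T={T:5d}  depth {k}: {n} positions, {1 - n / main:.2%} fewer than the main model")
```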

Plain Python: One MTP Module by Hand

Before reaching for PyTorch, let us implement a depth-1 MTP module in plain NumPy. The goal is to expose every shape, every matmul, every normalization: no autograd, no broadcast magic. If you understand these $\sim 30$ lines, you understand the architecture.

Depth-1 MTP module in NumPy

Notes on the listing below, in the order the code runs:

  • Toy dimensions (V = 8, d = 4, T = 5). V is vocabulary size, d is hidden dim, T is sequence length. Real DeepSeek-V3 uses V ≈ 129k, d = 7168, and T up to 4096; the shape of the math is identical to this toy.
  • Shared embedding table, Embed.shape = (V, d) = (8, 4). Critical design choice: every MTP module reuses the SAME Embed matrix as the main model. Sharing it saves V·d parameters per module and ties the input geometry across depths: the MTP module sees the same vector for the word 'cat' that the main model does.
  • Shared output head, W_out.shape = (V, d) = (8, 4). The MTP module's output head is also the main model's. This forces every depth to produce a hidden vector in the same target geometry; it cannot drift into a private feature space.
  • Toy hidden states from the main model, h0.shape = (T, d) = (5, 4). In the real model these come out of the final transformer block. Here we generate them randomly; the MTP module does not care where h0 came from, only that it is a length-d vector per position.
  • Projection matrix M, M.shape = (d, 2d) = (4, 8). The unique-per-depth piece. M takes the concatenation of (previous hidden, future-token embedding) and projects it back down to d. It has 2d² parameters, about a sixth of the ~12d² in a dense transformer block and a far smaller share of a large MoE block.
  • Toy transformer block. A single linear layer followed by tanh stands in for the real transformer block (attention + MLP + norms). In DeepSeek-V3 this is one full transformer block, the only deep computation per MTP module.
  • Token IDs, tokens = [3, 1, 4, 0, 5]. tokens[i] is the vocabulary index of the word that appeared at position i. During training we know the entire sequence, so future tokens (positions i+1, i+2, …) are available to feed into the MTP modules.
  • RMSNorm definition. Root-mean-square normalization, the same one used inside the transformer blocks. It divides by the RMS of the feature vector, keeping its direction but standardizing its magnitude. DeepSeek applies it to BOTH inputs of the MTP module before concatenation, so neither side dominates the projection.
  • Loop over predictable positions (T − 2 = 3 iterations). Position i predicts the token at position i + 2 (the next-next token). The last 2 positions have no valid target inside the sequence and are skipped during training.
  • Normalize the previous hidden state, h_norm.shape = (d,) = (4,). h_norm = RMSNorm(h0[i]). The main model's hidden state arrives with whatever scale it happens to have; RMSNorm puts it on a comparable footing with the freshly looked-up embedding.
  • Normalize the future-token embedding, e_norm.shape = (d,) = (4,). We look up Embed[tokens[i + 1]], the embedding of the token at position i + 1, and RMSNorm it. Note position i + 1, not i + 2: the future-token embedding tells the module which token came next; the target it must predict is one further still.
  • Concatenate hidden + embedding, cat.shape = (2d,) = (8,). Concatenation glues the two length-d vectors into a single length-2d vector. This is where the past (h_norm) meets the future (e_norm). They are not added: addition would force them onto the same axes, while concatenation lets the projection M learn its own mixing.
  • Project back to dimension d, combined.shape = (d,) = (4,). M @ cat collapses the 2d concatenation back to a length-d vector. This is the only place in the MTP module where the two information streams genuinely mix; M is the per-depth learnable bottleneck.
  • One transformer block (toy stand-in), h1_i.shape = (d,) = (4,). In the real implementation this is a full transformer block: attention over the same sequence positions plus an MLP. In this toy we use tanh(W_block @ combined) as a stand-in. The output h1_i is the depth-1 hidden state at position i.
  • Logits via the shared output head, logits.shape = (V,) = (8,). W_out @ h1_i produces a length-V vector of unnormalized scores over the vocabulary. Because W_out is shared with the main model, the geometry of the logit space is identical at every depth; the MTP module cannot 'invent' a new vocabulary.
  • Numerically stable softmax, p.shape = (V,) = (8,). Subtract the max before exponentiating to avoid overflow, then divide by the sum. The result p is a probability distribution over the V vocabulary tokens: the model's prediction of the token at position i + 2.
  • Record (position, argmax prediction, ground-truth target). At training time we would compute cross-entropy between p and the one-hot of tokens[i + 2]; that gives the MTP loss for this position. At inference time we instead sample from p to propose the next-next token.
```python
import numpy as np

# Tiny toy: V = 8 vocab tokens, d = 4 hidden dim, T = 5 sequence positions.
np.random.seed(0)
V, d, T = 8, 4, 5

# (a) Shared embedding table: the same one the main model uses.
Embed = np.random.randn(V, d) * 0.1

# (b) Shared output head: also identical to the main model's.
W_out = np.random.randn(V, d) * 0.1

# (c) The toy "main model": pretend its hidden states are already computed.
#     h0[i] is what the main model produced at position i.
h0 = np.random.randn(T, d) * 0.5

# (d) Per-depth MTP parameters. We do D = 1 (one MTP module).
M = np.random.randn(d, 2 * d) * 0.1        # projection: 2d -> d
W_block = np.random.randn(d, d) * 0.1      # toy 1-layer "transformer block"

# Token IDs in the sequence (which words appeared at each position).
tokens = np.array([3, 1, 4, 0, 5])         # shape (T,)

def rmsnorm(x, eps=1e-6):
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    return x / rms

# ---- Sequential MTP, depth k = 1 ----
# For each position i, predict the token at position i + 2 (one beyond next).
# Module 1 needs: (h0[i], Embed[token at position i + 1]).
predictions = []
for i in range(T - 2):                       # last 2 positions have no target
    # (1) RMSNorm both inputs.
    h_norm = rmsnorm(h0[i])                  # (d,)
    e_norm = rmsnorm(Embed[tokens[i + 1]])   # (d,)

    # (2) Concatenate along the feature axis and project back to d.
    cat = np.concatenate([h_norm, e_norm])   # (2d,)
    combined = M @ cat                       # (d,)

    # (3) One "transformer block": toy linear + tanh stand-in.
    h1_i = np.tanh(W_block @ combined)       # (d,)

    # (4) Shared output head -> vocabulary logits -> softmax.
    logits = W_out @ h1_i                    # (V,)
    z = logits - logits.max()
    p = np.exp(z) / np.exp(z).sum()          # (V,)
    predictions.append((i, int(p.argmax()), int(tokens[i + 2])))

for i, pred, target in predictions:
    print(f"pos {i}: predict t+2 -> {pred:<3d}  target -> {target}")
```

Three observations are worth pulling out. First: the loop body is position-independent; every position runs the exact same arithmetic with different inputs. Second: the only place the future-token information enters is the e_norm line, the embedding lookup Embed[tokens[i + 1]] at $i + 1$. Third: the projection $M$ is the only weight that mixes the two streams; everything else either normalizes a single stream or operates on the already-mixed combined vector.
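The first observation means the Python loop can go entirely. Below is a vectorized sketch of the same depth-1 computation that handles all $T - 2$ positions with batched matrix products; the function name is hypothetical, and with the same seed and draw order as the listing it should reproduce the loop's predictions.

```python
import numpy as np

def rmsnorm(x, eps=1e-6):
    return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

def mtp_depth1_vectorized(h0, tokens, Embed, W_out, M, W_block):
    """All positions at once: same arithmetic as the per-position loop.

    h0: (T, d); tokens: (T,). Returns probabilities of shape (T - 2, V),
    row i being the prediction for the token at position i + 2.
    """
    T = h0.shape[0]
    h_norm = rmsnorm(h0[: T - 2])                     # (T-2, d)
    e_norm = rmsnorm(Embed[tokens[1 : T - 1]])        # embeddings of t_{i+1}
    cat = np.concatenate([h_norm, e_norm], axis=-1)   # (T-2, 2d)
    h1 = np.tanh((cat @ M.T) @ W_block.T)             # row-wise M @ cat, then toy block
    logits = h1 @ W_out.T                             # (T-2, V)
    z = logits - logits.max(axis=-1, keepdims=True)
    return np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)

# Same toy inputs, drawn in the same order as the loop listing.
np.random.seed(0)
V, d, T = 8, 4, 5
Embed = np.random.randn(V, d) * 0.1
W_out = np.random.randn(V, d) * 0.1
h0 = np.random.randn(T, d) * 0.5
M = np.random.randn(d, 2 * d) * 0.1
W_block = np.random.randn(d, d) * 0.1
tokens = np.array([3, 1, 4, 0, 5])

P = mtp_depth1_vectorized(h0, tokens, Embed, W_out, M, W_block)
for i, (pred, target) in enumerate(zip(P.argmax(axis=-1), tokens[2:])):
    print(f"pos {i}: predict t+2 -> {int(pred):<3d}  target -> {int(target)}")
```

This is the form the PyTorch module uses: every operation acts on a whole (positions, features) array, so batching over $(B, T)$ comes for free.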

PyTorch: A Reusable MTP Module Class

In production code the MTP module is just an $\texttt{nn.Module}$. It is small, vectorized over $(B, T)$, and takes the shared embedding and output head as forward-time arguments so they cannot be silently duplicated.

MTPModule: DeepSeek-style sequential MTP, one depth

Notes on the listing below, in the order the code runs:

  • MTPModule is a tiny nn.Module. Every depth of MTP is just one of these. DeepSeek-V3 ships with a single depth (D = 1), but the architecture generalizes cleanly to D = 2, 3, … by stacking instances.
  • Owned vs. shared parameters. Note what the constructor does NOT take: the embedding table and the output head. Those are passed in at forward time from the parent model so they remain shared. The module owns only two RMSNorms, one Linear(2d → d) projection, and one transformer block.
  • Separate RMSNorms for h and e. Two independent RMSNorms, one for the hidden stream and one for the embedding stream. Sharing one norm would force both streams to use the same learned gain vector, which empirically hurts. Two norms = two independent gain knobs.
  • Projection has no bias. bias=False matches DeepSeek-V3's convention for linear layers inside the MTP module, and omitting biases is standard in modern LLMs. RMSNorm rescales but does not re-center, so the bias is dropped by convention rather than because the norm makes it redundant; it also saves d parameters.
  • Forward takes h_prev and tok_fut. h_prev is (B, T, d): the previous depth's hidden states for the entire sequence. tok_fut is (B, T) of token IDs: tok_fut[b, i] = token at position i + k in batch element b. The parent model builds tok_fut by shifting the input sequence left by k positions.
  • embed and output_head are explicit args. Passing them in makes sharing explicit and hard to break by accident. They are not registered as submodules of this instance, so the module can never quietly hold its own copy, and the optimizer receives each shared tensor exactly once, from the parent.
  • Normalize the hidden stream, h_norm.shape = (B, T, d). RMSNorm leaves the direction of h_prev but standardizes its magnitude per token. This matters because the main model's hidden states have an emergent scale that drifts during training; the projection layer is much more stable if both inputs share a comparable RMS.
  • Look up + normalize the future-token embedding, e_norm.shape = (B, T, d). embed(tok_fut) does a vocabulary-indexed gather whose output is (B, T, d), followed by RMSNorm. The lookup is the same one the main model uses on its own inputs, which is why the module has no embedding parameters of its own.
  • Concatenate, not add, cat.shape = (B, T, 2d). torch.cat along dim=-1 doubles the feature dimension to 2d. We intentionally do NOT add the two streams: addition would commit to a single shared subspace before the projection has a chance to learn the best mixing.
  • Project back to d, x.shape = (B, T, d). self.proj(cat) collapses (B, T, 2d) → (B, T, d). This is the only place in the module where the two streams genuinely mix.
  • One causal transformer block, h_next.shape = (B, T, d). x flows through a single transformer block (attention + MLP + norms) with the SAME causal mask as the main model. This is what makes MTP sequential-causal: position i still cannot peek at positions > i within this depth. The listing assumes block's forward accepts x plus an attn_mask keyword.
  • Shared output head, logits.shape = (B, T, V). For every position in every batch element the head produces V unnormalized scores over the vocabulary; the cross-entropy loss in the next section applies the softmax.
  • Return both h_next and logits. h_next is what the next depth (k + 1) will consume; logits is what the loss head consumes. Returning both lets the parent model decide whether to stack another MTP module on top.
  • Version assumptions. nn.RMSNorm requires PyTorch 2.4 or later, and the torch.Tensor | None annotation requires Python 3.10 or later.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MTPModule(nn.Module):
    """
    One depth of DeepSeek-style sequential MTP.

    Forward signature:
        h_prev  : (B, T, d)   hidden states from the previous depth.
        tok_fut : (B, T)      token IDs at position i + k for each i.

    Outputs:
        h_next  : (B, T, d)   hidden states at this depth.
        logits  : (B, T, V)   unnormalized scores over the vocabulary.

    Sharing contract (set by the parent model, NOT by this module):
      • embed          : shared with the main model (nn.Embedding(V, d))
      • output_head    : shared with the main model (Linear(d, V), optionally tied to embed)
    Per-depth (owned by this module):
      • proj           : Linear(2d, d)
      • block          : one transformer block
      • norm_h, norm_e : RMSNorm for each input stream
    """

    def __init__(self, d_model: int, block: nn.Module):
        super().__init__()
        self.norm_h = nn.RMSNorm(d_model)            # for the previous hidden state
        self.norm_e = nn.RMSNorm(d_model)            # for the future-token embedding
        self.proj = nn.Linear(2 * d_model, d_model, bias=False)
        self.block = block                           # one transformer block

    def forward(
        self,
        h_prev: torch.Tensor,                        # (B, T, d)
        tok_fut: torch.Tensor,                       # (B, T)
        embed: nn.Embedding,                         # shared with main model
        output_head: nn.Linear,                      # shared with main model
        attn_mask: torch.Tensor | None = None,       # causal mask for self-attn
    ):
        # (1) Normalize both streams BEFORE mixing them.
        h_norm = self.norm_h(h_prev)                 # (B, T, d)
        e_norm = self.norm_e(embed(tok_fut))         # (B, T, d)

        # (2) Concatenate along the feature axis, then project back to d.
        cat = torch.cat([h_norm, e_norm], dim=-1)    # (B, T, 2d)
        x = self.proj(cat)                           # (B, T, d)

        # (3) One transformer block. The causal mask MUST be active here:
        #     the MTP module is still autoregressive within the sequence.
        h_next = self.block(x, attn_mask=attn_mask)  # (B, T, d)

        # (4) Logits via the shared output head.
        logits = output_head(h_next)                 # (B, T, V)

        return h_next, logits
```

How the parent model wires it together

The parent model owns the embedding, the main transformer stack, the output head, and a small list of $\texttt{MTPModule}$ instances. The forward pass looks roughly like:

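```python
import torch
import torch.nn.functional as F


def shift_left(tokens: torch.Tensor, k: int, pad_id: int = 0) -> torch.Tensor:
    # Hypothetical helper: column i holds tokens[:, i + k]; the last k columns
    # are pad_id, so pair this with a loss mask over those positions.
    return F.pad(tokens[:, k:], (0, k), value=pad_id)


def forward_with_mtp(tokens, embed, main_blocks, output_head, mtp_modules, attn_mask=None):
    # main model
    h = embed(tokens)
    for blk in main_blocks:
        h = blk(h, attn_mask=attn_mask)
    h0 = h                                          # (B, T, d)
    logits0 = output_head(h0)                       # main next-token logits

    # MTP depths: strictly sequential, depth k consumes depth k-1's states
    h_prev = h0
    mtp_logits = []
    for k, mtp in enumerate(mtp_modules, start=1):
        tok_fut = shift_left(tokens, k)             # (B, T) tokens at i + k
        h_prev, logits_k = mtp(h_prev, tok_fut, embed, output_head, attn_mask)
        mtp_logits.append(logits_k)
    return logits0, mtp_logits
```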

Two PyTorch-specific subtleties show up here. First, $\texttt{shift\_left}$ is not just a Python slice: it must respect padding tokens so the loss is not computed on garbage positions, and almost every real implementation passes a position mask alongside. Second, the optimizer sees each MTP module's parameters but NOT a duplicate copy of the embedding or output head, because those were passed as forward arguments rather than registered as sub-modules. That is the whole point of the explicit-arg sharing convention.

Sequential, not parallel, in the code too. The loop over $\texttt{mtp\_modules}$ is serial: depth $k$ cannot start before depth $k - 1$ finishes, because it needs $\texttt{h\_prev}$. This is a real latency cost at training time. With $D = 1$ it is one extra block, about 1.6% of the main forward. With $D = 4$ it would be $\sim 6.6\%$. DeepSeek-V3 ships $D = 1$, consistent with the view that the marginal quality gain past one depth does not justify the marginal latency.
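The percentages above are just $D / L$ in serial block depth. A sketch that counts blocks only (it ignores the extra output-head matmul and any attention-length effects):

```python
L = 61  # main transformer blocks in DeepSeek-V3

for D in range(1, 5):
    # Each MTP depth adds one block that must run after the previous one.
    print(f"D={D}: {L + D} serial blocks instead of {L}, +{D / L:.1%}")
```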

What Changes at Massive Scale

At a scale where the main model has 61 transformer blocks, 7168 hidden dimensions, and 671B parameters in total (with about 37B active per token), the MTP module's budget changes character. Let us walk through the costs.

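| Resource | Per MTP module | At DeepSeek-V3 scale ($D = 1$) |
|---|---|---|
| Parameters | ~1 transformer block + $2d^2$ + $2d$ | ~11B (vs. ~671B main) |
| Active params/token | ~1 block's worth | ~600M extra for the block |
| FLOPs / forward token | ~1/61 of main forward for the block, plus a second pass through the $V \times d$ output head | +1.6% training FLOPs from the block alone; more once the repeated output-head matmul is counted |
| Activations memory | 1 extra block's worth | +1 block of activations to checkpoint |
| Wall-clock latency | Sequential: adds 1 block of depth | +1.6% per-step latency |
| Inference-time cost | Off (modules dropped) or used for speculation | 0 if dropped |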
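A back-of-envelope check on the FLOPs row, under loudly stated assumptions: forward FLOPs scale with the active parameters touched per token, the MTP block's share is the crude average $37\text{B}/61$, and the vocabulary is ~129k as quoted earlier. The point is the last term: the MTP module evaluates the full $V \times d$ output head again at every position.

```python
# Rough per-token accounting; every constant here is an approximation.
d = 7168             # hidden size
V = 129_000          # vocabulary size, ~129k
L = 61               # main transformer blocks
active_main = 37e9   # active parameters per token in the main model

block = active_main / L   # crude per-block share of active params (~0.6B)
proj = 2 * d * d          # M_1: 2d -> d
head = V * d              # shared output head, evaluated a second time

for label, extra in [("block only", block), ("block + proj + head", block + proj + head)]:
    print(f"{label:>20}: {extra / 1e9:.2f}B extra active params, ~{extra / active_main:.1%} of main")
```

Either way the total stays at a few percent, which is the design's point; the block-only 1.6% is a lower bound rather than the full bill.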

The memory story

The dominant memory cost during training is activations, not parameters. An MTP module carries one block's worth of activations per token, and those activations live until the backward pass for the MTP loss runs. With activation checkpointing the cost can be amortized, but the simpler picture is: budget for $L + D$ blocks of activations instead of $L$. At $L = 61$ and $D = 1$, that is a 1.6% bump.

The communication story

In data-parallel training, gradient all-reduces dominate the network budget. MTP modules add their parameters to the bucket — about 1.6% more grad data per step. In tensor-parallel or pipeline-parallel layouts, the MTP module's transformer block can be placed on the same shard as the final main block; no extra cross-device traffic. In FSDP, the MTP module is one more wrap unit — trivial to integrate.

The throughput story

Sequential MTP is sequential, full stop. Depth $k$ waits for depth $k - 1$. There is no pipeline trick that breaks this dependency without breaking causality; removing exactly this dependency is what made the parallel design tempting in the first place. At $D = 1$ the cost is one block of extra depth, which on H100s at 4k sequence length is in the low single-digit percent of step time. Past $D = 2$ the latency cost starts mattering and the per-depth quality return plausibly diminishes; DeepSeek-V3's own MTP ablations use a single depth, so this part is an extrapolation rather than a published result.

The accounting that drove the design. About +1.6% FLOPs and +1.6% latency for the extra block (somewhat more once the second output-head pass is counted) in exchange for a second training target at almost every position (each position contributes 2 cross-entropy losses instead of 1, the second weighted by $\lambda$), plus a speculative-decoding head you get for free at inference. The cost-benefit tilts hard toward sequential MTP.

Engineering Reality and Gotchas

Sequence length and the boundary problem

Depth $k$ needs the token at position $i + k$ as input and the token at $i + k + 1$ as its target. For the last $k + 1$ positions of the sequence that target does not exist. Two correct ways to handle this:

  • Mask the loss on positions $i \ge T - k - 1$ at depth $k$. Cleanest; costs a couple of token-positions per sequence at $D = 1$.
  • Right-pad the input by $D$ dummy tokens so every position has an input at every depth, and give the padded targets the loss's ignore index. Simpler at the call site but wastes a hair of compute.
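The masking option maps naturally onto the ignore_index of PyTorch's F.cross_entropy (default -100). A minimal sketch, with mtp_targets as a hypothetical helper name: column $i$ holds $t_{i+k+1}$ and the last $k + 1$ columns are ignored, matching the $i \ge T - k - 1$ rule.

```python
import torch
import torch.nn.functional as F

IGNORE = -100  # the default ignore_index of F.cross_entropy

def mtp_targets(tokens: torch.Tensor, k: int) -> torch.Tensor:
    """Depth-k targets: column i holds t_{i+k+1}; the last k + 1 columns are IGNORE.

    tokens: (B, T) integer IDs. Returns a (B, T) tensor of the same dtype.
    """
    B, T = tokens.shape
    tgt = torch.full((B, T), IGNORE, dtype=tokens.dtype)
    if T > k + 1:
        tgt[:, : T - k - 1] = tokens[:, k + 1 :]
    return tgt

tokens = torch.tensor([[3, 1, 4, 0, 5, 2]])
tgt = mtp_targets(tokens, k=1)
print(tgt)  # positions i >= T - k - 1 = 4 are IGNORE

torch.manual_seed(0)
logits = torch.randn(1, tokens.shape[1], 8)          # (B, T, V) from a depth-1 module
loss = F.cross_entropy(logits.transpose(1, 2), tgt, ignore_index=IGNORE)
print(loss.item())
```

F.cross_entropy expects the class dimension second, hence the transpose from (B, T, V) to (B, V, T); ignored positions contribute neither loss nor gradient.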

The mask-tying trap

The transformer block inside the MTP module needs its own causal mask — the same shape as the main model's. A common bug is reusing the main model's pre-computed mask after the embedding or positional bias has been folded in; the MTP block ends up attending differently than the main one, sometimes leaking future info, sometimes silently masking the present position. Make the MTP block compute its own mask, or pass an unbiased causal mask explicitly.
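A sketch of the second option, assuming the MTP block's attention is built on PyTorch's F.scaled_dot_product_attention: build a fresh boolean mask from the sequence length alone and sanity-check it against the built-in is_causal=True path. causal_mask is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F

def causal_mask(T: int, device=None) -> torch.Tensor:
    """Plain (T, T) boolean causal mask: True means "may attend".

    Row i allows columns j <= i, so every position sees itself and its past,
    with no positional bias or padding folded in.
    """
    return torch.ones(T, T, dtype=torch.bool, device=device).tril()

torch.manual_seed(0)
B, H, T, Dh = 2, 4, 6, 8
query, key, value = (torch.randn(B, H, T, Dh) for _ in range(3))

mask = causal_mask(T)
out_mask = F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
out_causal = F.scaled_dot_product_attention(query, key, value, is_causal=True)
print(torch.allclose(out_mask, out_causal, atol=1e-5))
```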

RMSNorm scale init

The two RMSNorm gain vectors inside an MTP module initialize to 1, same as everywhere else. Because RMSNorm divides out each stream's raw magnitude, the projection initially sees both streams at RMS ≈ 1 no matter how the main model's hidden-state scale drifts; after that, the learned gains are the only per-stream scale knob. Empirical wisdom suggests warm-starting the hidden-stream gain at a slightly lower value (e.g. 0.5) for the first ~1000 steps, so the projection leans on the token embedding while the main model's representations are still changing quickly. Treat this as a heuristic to ablate rather than a rule.

Loss scaling and lambda

The full training objective will be detailed in section 7.4, but a preview: the MTP loss is added with a coefficient $\lambda$ typically in the range $[0.1, 0.3]$. Setting it too high crowds out the main next-token loss; setting it too low makes the MTP module learn slowly and underperform at inference-time speculation. DeepSeek-V3 reports $\lambda = 0.3$ for the first 10T training tokens and $\lambda = 0.1$ for the remaining 4.8T, a step change rather than a gradual decay.
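A sketch of how the coefficient enters the objective, ahead of section 7.4. It assumes the form $\mathcal{L} = \mathcal{L}_{\text{main}} + \frac{\lambda}{D}\sum_{k=1}^{D}\mathcal{L}^k_{\text{MTP}}$ (average over depths, then scale by $\lambda$) and, as an assumption about the schedule, DeepSeek-V3's reported step from 0.3 to 0.1 after 10T tokens. Both helper names are hypothetical.

```python
def mtp_lambda(tokens_seen: float) -> float:
    # Step schedule DeepSeek-V3 reports: 0.3 for the first 10T tokens, 0.1 after.
    return 0.3 if tokens_seen < 10e12 else 0.1


def total_loss(main_loss: float, mtp_losses: list[float], tokens_seen: float) -> float:
    """L = L_main + (lambda / D) * sum over depths k of L_MTP^k."""
    D = len(mtp_losses)
    if D == 0:
        return main_loss
    return main_loss + mtp_lambda(tokens_seen) / D * sum(mtp_losses)


print(total_loss(2.0, [2.5], tokens_seen=1e12))    # 2.0 + 0.3 * 2.5 = 2.75
print(total_loss(2.0, [2.5], tokens_seen=12e12))   # 2.0 + 0.1 * 2.5 = 2.25
```

Dividing by $D$ keeps the total MTP weight at $\lambda$ regardless of depth count, so adding a depth does not silently shift weight away from the main loss.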

What "sharing" really means at training time

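Weight sharing is a contract enforced by the parent module, not by the MTP class itself. The cleanest way to enforce it in PyTorch is to never store the embedding or the output head as a child module of the MTP instance; pass them in. The failure mode to watch for is an accidental copy: if the MTP module builds its own nn.Embedding or receives a deep copy (for example during checkpoint conversion or re-initialization), the "shared" weights silently become a separate tensor that trains independently. Registering the very same module twice is less dangerous than it looks, because $\texttt{parameters()}$ de-duplicates shared tensors, but $\texttt{state\_dict()}$ still lists them under both names. Most production bugs in MTP implementations come from shared tensors that quietly stopped being shared.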
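A quick way to see what the optimizer will actually receive is to count parameters. In this sketch (Holder is a hypothetical module, not part of any real codebase) the same nn.Embedding registered twice is counted once, while a deepcopy is a genuinely separate table that would train on its own.

```python
import copy

import torch.nn as nn


class Holder(nn.Module):
    """Hypothetical parent that registers an embedding under two names."""

    def __init__(self, emb: nn.Embedding, mtp_emb: nn.Embedding):
        super().__init__()
        self.emb = emb
        self.mtp_emb = mtp_emb


def n_params(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters())


emb = nn.Embedding(100, 16)                    # 1,600 weights

same = Holder(emb, emb)                        # one tensor, registered twice
copied = Holder(emb, copy.deepcopy(emb))       # an accidental copy

print(n_params(same))                          # 1600: parameters() yields the tensor once
print(n_params(copied))                        # 3200: a second, independent table
print(len(same.state_dict()))                  # 2: state_dict() keeps both names
```

The last line is why checkpoint save and load paths deserve a check: a loader that materializes each state_dict key as its own tensor breaks the tie without any error.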

The takeaway. Sequential causal MTP is what you get when you refuse to compromise on the chain rule and refuse to spend more than a few percent of the main model's budget. One small transformer block per depth, two RMSNorms, one projection, shared embedding and output head, run in a strict causal sequence. The rest of the chapter — the loss formulation in section 7.4 and the speculative-decoding payoff in section 7.5 — is just downstream of getting this one architectural choice right.