Section 7.2 ended on a sharp note. Meta's naive parallel MTP saved compute by predicting all future tokens at once from a single shared trunk — and paid for it by silently breaking the causal chain that makes autoregressive language modeling work. Predictions for distant horizons drifted, the heads stopped agreeing with each other, and downstream metrics fell behind a single-token baseline at large model sizes. The lesson was unambiguous: if you want extra prediction heads, they must respect causality. DeepSeek's sequential MTP is the answer that fell out of taking that lesson seriously.
The promise of sequential causal MTP. Each depth is its own small transformer block. Depth $k$ at position $i$ reads the previous depth's hidden state at the same position, mixes it with the embedding of the token at position $i+k$, and predicts the token at position $i+k+1$. The causal chain survives intact. The signal stays sharp. The trick is the architecture, not a loss-function patch.
Why DeepSeek Went Sequential
Recall the parallel design from section 7.2: $D$ independent output heads sit on top of the same final hidden state $h_i$, and head $k$ tries to predict the token at position $i+k$. The trunk gets a richer gradient signal, since every position contributes $D$ cross-entropy terms instead of one, but the heads cannot talk to each other. Head 2 has no idea what head 1 just predicted. The predictions become marginal distributions of the joint, not a chain-rule factorization. Empirically, that gap kills the win.
DeepSeek reasoned backwards from what the chain rule actually wants. The true joint over the next $D$ future tokens is

$$P(t_{i+1}, \dots, t_{i+D} \mid t_{\le i}) = \prod_{k=1}^{D} P(t_{i+k} \mid t_{\le i+k-1}).$$

Notice what the conditioning set looks like. To predict $t_{i+k}$ properly, the model needs everything up to and including $t_{i+k-1}$, that is, the previously predicted future tokens. Parallel MTP throws this away: it conditions every head on $t_{\le i}$ alone, which is the wrong marginal for every $k \ge 2$. Sequential MTP gives every depth the conditioning set it needs.
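The gap between a product of marginals and a chain-rule factorization is easy to see on a toy distribution. The sketch below is only an illustration: the 2×2 joint table is made up for the example and stands in for the distribution over two future tokens given a fixed prefix.

```python
import numpy as np

# Hypothetical joint P(t1, t2 | prefix) over a 2-token vocabulary.
# Rows index t1, columns index t2; the two tokens are strongly correlated.
joint = np.array([[0.45, 0.05],
                  [0.05, 0.45]])

# Marginals: the best two independent heads could learn.
p_t1 = joint.sum(axis=1)             # P(t1 | prefix)
p_t2 = joint.sum(axis=0)             # P(t2 | prefix), with t1 marginalized out

# Parallel-style factorization: product of the two marginals.
parallel = np.outer(p_t1, p_t2)

# Chain-rule factorization: P(t1 | prefix) * P(t2 | prefix, t1).
p_t2_given_t1 = joint / p_t1[:, None]
sequential = p_t1[:, None] * p_t2_given_t1

print(parallel)     # every pair gets probability 0.25
print(sequential)   # recovers the joint table
```

The product of marginals spreads mass evenly over all four pairs, including the two that almost never occur together; conditioning the second factor on the first token is what restores the correlation.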
| Property | Parallel MTP (Meta) | Sequential MTP (DeepSeek) |
|---|---|---|
| Conditioning at depth k | t≤i (same for every k) | t≤i+k−1 (proper chain rule) |
| Causal chain across depths | Broken — heads independent | Preserved — depth k reads depth k−1 |
| Heads see each other's outputs | No | Yes (via the hidden state h_i^{k-1}) |
| Extra params per depth | 1 output head | 1 transformer block + 1 projection |
| Marginal cost | Almost free | ~D× the per-token forward of the head |
| Quality at scale | Drops vs. 1-head baseline (>3B) | Improves vs. 1-head baseline |
| Speculative decoding ready | Weakly | Directly — see section 7.5 |
Anatomy of a Single MTP Module
An MTP module is the smallest unit you can imagine that still does useful autoregressive work. It is, almost literally, "one transformer block plus a few wires." Let us list every part:
- Two inputs per position. The previous depth's hidden state $h_i^{k-1}$ and the embedding of the token at position $i+k$, which we write $\mathrm{Embed}(t_{i+k})$. The hidden state is the past; the embedding is the next-known future token.
- RMSNorm on each input. Two independent RMSNorms (one per stream) strip away differing magnitudes so the downstream projection sees inputs on a comparable scale.
- Concatenation, then projection. The two normalized vectors are concatenated along the feature axis (giving a length-$2d$ vector) and then projected back down to $d$ by a learned matrix $M_k \in \mathbb{R}^{d \times 2d}$. This is where past and future meet.
- One transformer block. A single transformer block (attention + MLP + norms) is applied to the projected stream. It uses the same causal mask as the main model, so position $i$ still cannot see positions $j > i$ within the block.
- Shared output head. The block's output is run through the main model's output head (the linear layer that maps $d$-dim hidden vectors to $V$-dim logits), producing a vocabulary distribution.
What is shared, what is per-depth
The sharing story is what makes the design cheap. Two pieces are shared with the main model across every depth:
- The embedding table — the same one the main model uses on its inputs.
- The output head — tied to the embedding (weight tying) on most LLM setups.
Three pieces are owned per depth:
- The two RMSNorms ($2d$ parameters).
- The projection ($2d^2$ parameters).
- One transformer block (~$12d^2$ parameters at standard FFN expansion). This is the bulk of the per-depth cost.
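To put numbers on the per-depth budget, here is a rough counting sketch. It assumes a dense block with standard multi-head attention and a 4× FFN expansion, and it ignores biases and the block's internal norms; the function name is made up for the example. A mixture-of-experts block, as used in DeepSeek-V3, would have a very different total count.

```python
def mtp_params_per_depth(d, ffn_mult=4):
    """Rough per-depth parameter count for a dense block (assumption: no biases)."""
    norms = 2 * d                    # two RMSNorm gain vectors
    proj = 2 * d * d                 # the (d x 2d) projection
    attn = 4 * d * d                 # Q, K, V, and output projections
    mlp = 2 * ffn_mult * d * d       # up- and down-projection of the FFN
    return {"norms": norms, "proj": proj, "block": attn + mlp}

counts = mtp_params_per_depth(d=7168)
for name, n in counts.items():
    print(f"{name:>6}: {n:,}")
```

At $d = 7168$ the two norms are negligible, the projection is on the order of $10^8$ parameters, and the dense block is about six times larger than the projection.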
The Math: A Causal Chain Through Depths
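Fix a position $i$ in the sequence. Let $h_i^0$ denote the main model's final hidden state at position $i$; this is the starting point for all MTP depths and is, by construction, a function only of $t_{\le i}$. For depths $k = 1, \dots, D$, define the MTP module recursively:

$$h_i^k = \mathrm{TRM}_k\big(M_k\,[\,\mathrm{RMSNorm}(h_i^{k-1})\,;\ \mathrm{RMSNorm}(\mathrm{Embed}(t_{i+k}))\,]\big)$$

Every symbol: $\mathrm{TRM}_k$ is a single per-depth transformer block; $M_k \in \mathbb{R}^{d \times 2d}$ is the per-depth projection; $[\,\cdot\,;\,\cdot\,]$ is concatenation along the feature axis; $\mathrm{RMSNorm}$ is the root-mean-square norm with two independent learnable scales (one per stream); $\mathrm{Embed}$ is the shared embedding table; and $t_{i+k}$ is the actual ground-truth token at position $i+k$ in the training sequence (known at training time because the whole sequence is given).
The depth-$k$ prediction at position $i$ is then

$$P_{i+k+1}^k = \mathrm{softmax}\big(W_{\text{out}}\, h_i^k\big),$$

a distribution over the vocabulary, predicting the token at position $i+k+1$. $W_{\text{out}} \in \mathbb{R}^{V \times d}$ is the shared (tied) output head. The cross-entropy loss at this depth compares $P_{i+k+1}^k$ to the one-hot of $t_{i+k+1}$; we will write the full training objective in section 7.4.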
Why the recursion is the whole story
The single equation above hides three independent design choices, each of which corrects one specific failure of parallel MTP:
- The recursion through $h_i^{k-1}$. This is the chain rule made concrete. Depth $k$ conditions on everything depth $k-1$ knew, plus the new token $t_{i+k}$. Parallel MTP had every depth read $h_i^0$ directly, with no chain.
- The embedding of the future token. At training time we know $t_{i+k}$, so we feed it in. Without it, depth $k$ would have no information beyond what depth $k-1$ already used, and the module would be a redundant copy of its predecessor.
- The independent transformer block per depth. Each depth gets its own parameters because each depth solves a slightly different problem — depth 1 maps "past + next token" → "next-next token", depth 2 maps "past + next-next token" → "next-next-next token". The tasks are related but not identical.
Manual Numerical Walkthrough
Let us pin down a single depth-1 MTP step end-to-end with tiny numbers. We will use $d = 4$, $V = 6$, and look at the prediction made for position $i + 2 = 4$ from position $i = 2$.
Worked example: depth-1 MTP step at position i = 2, d = 4, V = 6
Setup. Imagine a sequence of six tokens. The main model has already produced its hidden states. We focus on position $i = 2$:
- h_2^0 = [ 0.50, -0.30, 0.80, -0.10 ] (main model output)
- Embed(t_3) = [ 0.20, 0.40, 0.10, 0.30 ] (future-token embedding)
- target token at position i + 2 = 4 is t_4, ID = 3
Step 1: RMSNorm both inputs. For a vector $x \in \mathbb{R}^d$, $\mathrm{RMSNorm}(x) = x / \sqrt{\mathrm{mean}(x^2)}$ (taking the learnable gain = 1 for clarity).
mean(h_2^0²) = (0.25 + 0.09 + 0.64 + 0.01) / 4 = 0.2475
√0.2475 ≈ 0.4975 → h_norm ≈ [ 1.005, -0.603, 1.608, -0.201 ]
mean(e²) = (0.04 + 0.16 + 0.01 + 0.09) / 4 = 0.075, √0.075 ≈ 0.274
→ e_norm ≈ [ 0.730, 1.461, 0.365, 1.095 ]
Note the two streams now have comparable scale — RMS ≈ 1 each. The unnormalized h had a feature jumping to 0.80 while e had one at 0.10; without RMSNorm the projection would have been dominated by h.
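Step 1 is small enough to check directly. The snippet below is a minimal sketch with the gain fixed at 1 and no epsilon term, matching the simplification used in the walkthrough; the helper names are this example's own.

```python
import numpy as np

def rms(x):
    """Root-mean-square of a vector."""
    return np.sqrt(np.mean(x ** 2))

def rms_norm(x, gain=1.0):
    """RMSNorm with a scalar gain and no epsilon (walkthrough simplification)."""
    return gain * x / rms(x)

h = np.array([0.50, -0.30, 0.80, -0.10])    # h_2^0
e = np.array([0.20, 0.40, 0.10, 0.30])      # Embed(t_3)

h_norm, e_norm = rms_norm(h), rms_norm(e)
print("RMS before:", round(float(rms(h)), 4), round(float(rms(e)), 4))
print("RMS after: ", round(float(rms(h_norm)), 4), round(float(rms(e_norm)), 4))
print(np.round(h_norm, 3).tolist())
print(np.round(e_norm, 3).tolist())
```

Before normalization the hidden stream is almost twice as large as the embedding stream in RMS terms; afterwards both sit at 1.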
Step 2: concatenate, then project. Concatenation yields a length-8 vector:
cat = [ 1.005, -0.603, 1.608, -0.201, 0.730, 1.461, 0.365, 1.095 ]
Multiply by a toy 4 × 8 projection M (a believable random init):
M = [[ 0.10,  0.20, -0.10,  0.05,  0.15, -0.05,  0.10,  0.20],
     [-0.20,  0.10,  0.30, -0.10,  0.05,  0.20, -0.10,  0.10],
     [ 0.15, -0.10,  0.05,  0.20, -0.10,  0.10,  0.30, -0.05],
     [ 0.05,  0.20, -0.10,  0.15,  0.20, -0.10,  0.05,  0.10]]
Row 0 of M · cat:
0.10·1.005 + 0.20·(-0.603) + (-0.10)·1.608 + 0.05·(-0.201)
+ 0.15·0.730 + (-0.05)·1.461 + 0.10·0.365 + 0.20·1.095
≈ 0.1005 - 0.1206 - 0.1608 - 0.0101 + 0.1095 - 0.0731 + 0.0365 + 0.219
≈ 0.101
Computing all four rows similarly:
combined ≈ [ 0.101, 0.643, 0.379, -0.134 ]
Step 3: one transformer block (toy stand-in). For this walkthrough we collapse the full block into a single tanh non-linearity (the real block has attention and an MLP):
h_2^1 = tanh(combined) ≈ [ 0.101, 0.567, 0.362, -0.133 ]
(tanh is close to the identity for small inputs and compresses larger ones, as 0.643 → 0.567 shows; that is fine for this toy. In the real model attention contextualizes across positions and the MLP introduces a more substantial non-linear transformation.)
Step 4: shared output head, then softmax. With $W_{\text{out}} \in \mathbb{R}^{6 \times 4}$ (rows are the 6 vocabulary embeddings), say:
W_out = [[ 0.5,  0.1, -0.2,  0.3],   # vocab id 0
         [-0.3,  0.4,  0.2,  0.1],   # vocab id 1
         [ 0.2, -0.1,  0.5, -0.2],   # vocab id 2
         [ 0.4,  0.3,  0.1,  0.6],   # vocab id 3 <- TARGET
         [-0.1,  0.2, -0.3,  0.4],   # vocab id 4
         [ 0.1, -0.4,  0.2, -0.1]]   # vocab id 5
Compute logits = W_out · h_2^1:
logit_0 = 0.5·0.101 + 0.1·0.567 + (-0.2)·0.362 + 0.3·(-0.133) ≈ -0.0051
logit_1 = -0.3·0.101 + 0.4·0.567 + 0.2·0.362 + 0.1·(-0.133) ≈ 0.256
logit_2 = 0.2·0.101 + (-0.1)·0.567 + 0.5·0.362 + (-0.2)·(-0.133) ≈ 0.171
logit_3 = 0.4·0.101 + 0.3·0.567 + 0.1·0.362 + 0.6·(-0.133) ≈ 0.167
logit_4 = -0.1·0.101 + 0.2·0.567 + (-0.3)·0.362 + 0.4·(-0.133) ≈ -0.0585
logit_5 = 0.1·0.101 + (-0.4)·0.567 + 0.2·0.362 + (-0.1)·(-0.133) ≈ -0.131
Subtract the max (0.256) and exponentiate:
exp ≈ [0.771, 1.000, 0.919, 0.915, 0.730, 0.679]
sum ≈ 5.014
p ≈ [0.154, 0.199, 0.183, 0.182, 0.146, 0.135]
Step 5: read off the prediction and the loss. The argmax is vocab id 1, not the target id 3, which is unsurprising for untrained toy weights; the target still receives the third-highest probability. The cross-entropy at this position is
L_MTP^1 at i = 2 = -log(p[3]) = -log(0.182) ≈ 1.70
The full MTP-1 loss averages this over every valid position $i$, just like any cross-entropy.
What just happened, conceptually. The hidden state $h_2^0$ carried "everything the main model knew up to position 2." The embedding $\mathrm{Embed}(t_3)$ carried the new information that the next token would be $t_3$. The MTP module fused them, ran them through a small transformer block, and produced a distribution over what comes after $t_3$, that is, over $t_4$. The whole module is one chain-rule step.
If we had stacked a depth-2 module. $h_2^1$ would feed into MTP-2, concatenated with $\mathrm{Embed}(t_4)$, and the depth-2 output would predict $t_5$. The chain extends one step further with one more module. The arithmetic is identical; only the indices shift.
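The five steps can be replayed in a few lines of NumPy, which is a convenient way to check the arithmetic by machine. This is a sketch of the toy only: the tanh stands in for the transformer block exactly as in the walkthrough, the gain is 1, and the variable names are this example's own.

```python
import numpy as np

def rms_norm(x):
    return x / np.sqrt(np.mean(x ** 2))     # gain = 1, no epsilon

h_prev = np.array([0.50, -0.30, 0.80, -0.10])    # h_2^0
e_next = np.array([0.20, 0.40, 0.10, 0.30])      # Embed(t_3)
target = 3                                       # vocab id of t_4

M = np.array([[ 0.10,  0.20, -0.10,  0.05,  0.15, -0.05,  0.10,  0.20],
              [-0.20,  0.10,  0.30, -0.10,  0.05,  0.20, -0.10,  0.10],
              [ 0.15, -0.10,  0.05,  0.20, -0.10,  0.10,  0.30, -0.05],
              [ 0.05,  0.20, -0.10,  0.15,  0.20, -0.10,  0.05,  0.10]])
W_out = np.array([[ 0.5,  0.1, -0.2,  0.3],
                  [-0.3,  0.4,  0.2,  0.1],
                  [ 0.2, -0.1,  0.5, -0.2],
                  [ 0.4,  0.3,  0.1,  0.6],
                  [-0.1,  0.2, -0.3,  0.4],
                  [ 0.1, -0.4,  0.2, -0.1]])

cat = np.concatenate([rms_norm(h_prev), rms_norm(e_next)])   # length 8
combined = M @ cat                                           # length 4
h_next = np.tanh(combined)            # toy stand-in for the block
logits = W_out @ h_next               # length 6
p = np.exp(logits - logits.max())     # stable softmax
p /= p.sum()
loss = -np.log(p[target])

print("combined:", np.round(combined, 3).tolist())
print("p:       ", np.round(p, 3).tolist())
print("argmax:  ", int(p.argmax()), " loss:", round(float(loss), 3))
```

Feeding `h_next` back in as `h_prev`, with the embedding of $t_4$ and a second projection, would give the depth-2 step.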
Visualizing the Sequential Forward Pass
[Figure: sequential forward pass over a six-token sequence, with one main transformer and two MTP modules stacked on top; the depths fill in left to right.]
The key observation: cell $(k, i)$ always reads cell $(k-1, i)$ from the row above and the embedding of the token $k$ positions to its right, never any cell to its right at its own depth. Causality holds at every depth.
Plain Python: One MTP Module by Hand
Before reaching for PyTorch, let us implement a depth-1 MTP module in plain NumPy. The goal is to expose every shape, every matmul, every normalization — no autograd, no broadcast magic. If you understand these lines, you understand the architecture.
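A minimal sketch of such a module is given here. It makes two simplifying assumptions that should be read as such: the transformer block is replaced by a position-wise residual MLP (no attention), and the output head is tied to the embedding table. Positions are 0-indexed, and all function and variable names are hypothetical.

```python
import numpy as np

def rms_norm(x, gain, eps=1e-6):
    """Scale x to unit root-mean-square, then apply a learnable gain."""
    return gain * x / np.sqrt(np.mean(x ** 2) + eps)

def softmax(z):
    e = np.exp(z - z.max())               # subtract the max for stability
    return e / e.sum()

def mtp_depth1(h0, tokens, E, M, g_h, g_e, W1, W2):
    """Depth-1 MTP over one sequence, one position at a time.

    h0:     (T, d)  main-model hidden states
    tokens: (T,)    token ids
    E:      (V, d)  shared embedding table, reused as the tied output head
    M:      (d, 2d) projection that mixes the two streams
    W1, W2: toy position-wise MLP standing in for the transformer block
    """
    T, d = h0.shape
    h1, losses = [], []
    for i in range(T - 2):                # needs tokens[i+1] and tokens[i+2]
        h_n = rms_norm(h0[i], g_h)                    # past stream
        e_n = rms_norm(E[tokens[i + 1]], g_e)         # future-token lookup
        combined = M @ np.concatenate([h_n, e_n])     # (2d,) -> (d,)
        # Stand-in block: residual MLP without attention (sketch assumption).
        h = combined + W2 @ np.tanh(W1 @ combined)
        p = softmax(E @ h)                            # (V, d) @ (d,) -> (V,)
        losses.append(-np.log(p[tokens[i + 2]]))      # target is token i+2
        h1.append(h)
    return np.stack(h1), float(np.mean(losses))

rng = np.random.default_rng(0)
T, d, V = 6, 4, 6
tokens = rng.integers(0, V, size=T)
h0 = rng.normal(size=(T, d))
E = rng.normal(scale=0.5, size=(V, d))
M = rng.normal(scale=0.2, size=(d, 2 * d))
W1 = rng.normal(scale=0.2, size=(4 * d, d))
W2 = rng.normal(scale=0.2, size=(d, 4 * d))

h1, loss = mtp_depth1(h0, tokens, E, M, np.ones(d), np.ones(d), W1, W2)
print(h1.shape, round(loss, 3))
```

Because the stand-in block has no attention, each output row here depends only on its own position's inputs; a real block would also attend to earlier positions under the causal mask.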
Three observations are worth pulling out. First: the loop body is position-independent, so every position runs the exact same arithmetic with different inputs. Second: the only place the future-token information enters is the embedding lookup at $i+1$. Third: the projection $M$ is the only weight that mixes the two streams; everything else either normalizes a single stream or operates on the already-mixed combined vector.
PyTorch: A Reusable MTP Module Class
In production code the MTP module is just an `nn.Module`. It is small, vectorized over $(B, T)$, and takes the shared embedding and output head as forward-time arguments so they cannot be silently duplicated.
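One way such a class could look is sketched below. The names `MTPModule` and `RMSNorm` are hypothetical, `nn.TransformerEncoderLayer` is used only as a convenient stand-in for the model's real block class, and the toy sizes at the bottom are arbitrary. The argument order of `forward` follows the call `mtp(h_prev, tok_fut, embed, output_head, attn_mask)` used by the parent model.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.gain = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return self.gain * x / rms

class MTPModule(nn.Module):
    """One MTP depth: two norms, one projection, one transformer block."""

    def __init__(self, d, n_heads=4, ffn_mult=4):
        super().__init__()
        self.norm_h = RMSNorm(d)                      # hidden-state stream
        self.norm_e = RMSNorm(d)                      # embedding stream
        self.proj = nn.Linear(2 * d, d, bias=False)   # the (d x 2d) projection
        # Stand-in block (assumption); a real model would reuse its own class.
        self.block = nn.TransformerEncoderLayer(
            d_model=d, nhead=n_heads, dim_feedforward=ffn_mult * d,
            dropout=0.0, batch_first=True, norm_first=True)

    def forward(self, h_prev, tok_fut, embed, output_head, attn_mask):
        # embed and output_head arrive as arguments, so they are never
        # registered as sub-modules of this class.
        x = torch.cat([self.norm_h(h_prev), self.norm_e(embed(tok_fut))], dim=-1)
        h = self.block(self.proj(x), src_mask=attn_mask)
        return h, output_head(h)

# Toy usage with arbitrary sizes.
B, T, d, V = 2, 6, 16, 50
embed = nn.Embedding(V, d)
output_head = nn.Linear(d, V, bias=False)
output_head.weight = embed.weight                     # weight tying
mtp = MTPModule(d)
causal = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
h0 = torch.randn(B, T, d)
tok_fut = torch.randint(0, V, (B, T))
h1, logits1 = mtp(h0, tok_fut, embed, output_head, causal)
print(h1.shape, logits1.shape)
```

Passing the shared layers at call time keeps `mtp.parameters()` limited to the per-depth weights, so ownership of the embedding and output head stays with the parent model.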
How the parent model wires it together
The parent model owns the embedding, the main transformer stack, the output head, and a small list of MTP module instances. The forward pass looks roughly like:
# main model
h = embed(tokens)
for blk in main_blocks:
    h = blk(h, attn_mask)
h0 = h                                   # (B, T, d)
logits0 = output_head(h0)                # main next-token logits

# MTP depths
h_prev = h0
mtp_logits = []
for k, mtp in enumerate(mtp_modules, start=1):
    tok_fut = shift_left(tokens, k)      # (B, T) tokens at i + k
    h_prev, logits_k = mtp(h_prev, tok_fut, embed, output_head, attn_mask)
    mtp_logits.append(logits_k)

Two PyTorch-specific subtleties show up here. First, `shift_left` is not just a Python slice: it must respect padding tokens so the loss is not computed on garbage positions; almost every real implementation passes a position mask alongside. Second, the optimizer sees each MTP module's parameters but not a duplicate copy of the embedding or output head, because those were passed as arguments rather than registered as sub-modules. That is the whole point of the explicit-arg sharing convention.
What Changes at Massive Scale
At a scale where the main model has 61 transformer blocks, 7168 hidden dimensions, and 671B parameters in total (with about 37B active per token), the MTP module's budget changes character. Let us walk through the costs.
| Resource | Per MTP module | At DeepSeek-V3 scale (D = 1) |
|---|---|---|
| Parameters | ~1 transformer block + 2d² + 2d | ~11B (vs. ~671B main) |
| Active params/token | ~1 block worth | ~600M extra |
| FLOPs / forward token | ~1/61 of main forward | +1.6% training FLOPs |
| Activations memory | 1 extra block worth | +1 block of activations to checkpoint |
| Wall-clock latency | Sequential — adds 1 block depth | +1.6% per-step latency |
| Inference-time cost | Off — modules dropped or used for speculation | 0 if dropped |
The memory story
The dominant memory cost during training is activations, not parameters. An MTP module carries one block's worth of activations per token, and those activations live until the backward pass for the MTP loss runs. With activation checkpointing the cost can be amortized, but the simpler picture is: budget for $L + D$ blocks of activations instead of $L$ blocks, where $L$ is the number of main blocks. At $L = 61$ and $D = 1$, that is a 1.6% bump.
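The 1.6% figure is just a ratio of block counts. The helper below is a back-of-the-envelope sketch under a strong assumption: every block costs the same, and the extra projection and the second pass through the output head are ignored. The function name is made up for the example.

```python
def mtp_block_overhead(n_main_blocks, n_mtp_depths):
    """Extra block depth from MTP modules, as a fraction of the main stack."""
    return n_mtp_depths / n_main_blocks

for depths in (1, 2, 4):
    print(f"D = {depths}: +{mtp_block_overhead(61, depths):.1%}")
```

With 61 main blocks, one MTP depth adds about 1.6%, and the overhead grows linearly with the number of depths.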
The communication story
In data-parallel training, gradient all-reduces dominate the network budget. MTP modules add their parameters to the bucket — about 1.6% more grad data per step. In tensor-parallel or pipeline-parallel layouts, the MTP module's transformer block can be placed on the same shard as the final main block; no extra cross-device traffic. In FSDP, the MTP module is one more wrap unit — trivial to integrate.
The throughput story
Sequential MTP is sequential, full stop. Depth $k$ waits for depth $k-1$. There is no pipeline trick that breaks this dependency without breaking causality; avoiding that dependency is exactly what made the parallel design tempting in the first place. At $D = 1$ the cost is one block of extra depth, which on H100s at 4k sequence length is in the low single-digit percent of step time. At larger $D$ the latency cost starts mattering and the per-depth quality return diminishes, which is exactly what DeepSeek's ablations show.
Engineering Reality and Gotchas
Sequence length and the boundary problem
Depth $k$ at position $i$ needs the token at position $i + k$ as input and the token at position $i + k + 1$ as target. For the last $k + 1$ positions of the sequence that target does not exist. Two correct ways to handle this:
- Mask the loss on positions $i > T - k - 1$ at depth $k$. Cleanest, costs a couple of token-positions per sequence at $D = 1$.
- Right-pad the input by $D + 1$ dummy tokens so every position has a valid future. Simpler at the call site but wastes a hair of compute.
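The masking option can be sketched as a small helper that builds the shifted inputs, the shifted targets, and a validity mask in one place. This is an illustration under assumptions: positions are 0-indexed, `pad_id = 0` is an arbitrary filler, and the function name is hypothetical.

```python
import numpy as np

def mtp_inputs_and_targets(tokens, k, pad_id=0):
    """Shifted ids and a loss mask for depth k (0-indexed positions).

    Position i feeds tokens[i + k] into the module and is scored against
    tokens[i + k + 1]. Missing entries are filled with pad_id, and `valid`
    marks the positions that have a real target.
    """
    T = len(tokens)
    fut = np.full(T, pad_id, dtype=tokens.dtype)
    tgt = np.full(T, pad_id, dtype=tokens.dtype)
    valid = np.zeros(T, dtype=bool)
    n_in = max(T - k, 0)          # positions with a real future input token
    n_tgt = max(T - k - 1, 0)     # positions with a real target token
    fut[:n_in] = tokens[k:]
    tgt[:n_tgt] = tokens[k + 1:]
    valid[:n_tgt] = True
    return fut, tgt, valid

tokens = np.array([11, 12, 13, 14, 15, 16])
fut, tgt, valid = mtp_inputs_and_targets(tokens, k=1)
print(fut.tolist())     # [12, 13, 14, 15, 16, 0]
print(tgt.tolist())     # [13, 14, 15, 16, 0, 0]
print(valid.tolist())   # [True, True, True, True, False, False]
```

Multiplying the per-position loss by `valid` (and dividing by `valid.sum()`) keeps the two trailing positions out of the depth-1 average; a batch-level padding mask would be combined with `valid` in the same way.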
The mask-tying trap
The transformer block inside the MTP module needs its own causal mask — the same shape as the main model's. A common bug is reusing the main model's pre-computed mask after the embedding or positional bias has been folded in; the MTP block ends up attending differently than the main one, sometimes leaking future info, sometimes silently masking the present position. Make the MTP block compute its own mask, or pass an unbiased causal mask explicitly.
RMSNorm scale init
The two RMSNorm gain vectors inside an MTP module initialize to 1, same as everywhere else. But because RMSNorm sits on a stream whose magnitude depends on the main model's drift during training, empirical wisdom is to warm-start the gain on the hidden stream at a slightly lower value for the first ~1000 steps. This prevents an early stage where the projection sees huge h-stream norms and tiny e-stream norms; without the warm-start the projection's gradient blows up.
Loss scaling and lambda
The full training objective will be detailed in section 7.4, but a preview: the MTP loss is added with a coefficient $\lambda$ that is typically well below 1. Setting it too high crowds out the main next-token loss; setting it too low makes the MTP module learn slowly and underperform at inference-time speculation. DeepSeek-V3 sets $\lambda = 0.3$ for the first 10T training tokens and lowers it to 0.1 for the remaining 4.8T.
What "sharing" really means at training time
Weight sharing is a contract enforced by the parent module, not by the MTP class itself. The cleanest way to enforce it in PyTorch is to never store the embedding or the output head as a child module of the MTP instance; pass them in instead. Otherwise a parameter list assembled by hand (for example, by concatenating the parent's and the module's parameters) can count the shared tensors twice, and the optimizer may apply two updates per step. Most production bugs in MTP implementations come from accidentally registering the shared tensors twice.
The takeaway. Sequential causal MTP is what you get when you refuse to compromise on the chain rule and refuse to spend more than a few percent of the main model's budget. One small transformer block per depth, two RMSNorms, one projection, shared embedding and output head, run in a strict causal sequence. The rest of the chapter — the loss formulation in section 7.4 and the speculative-decoding payoff in section 7.5 — is just downstream of getting this one architectural choice right.