Section 7.1 ended on a frustrating observation. Every transformer forward pass produces, at position $t$, a hidden state $h_t$ that already encodes a great deal about what is going to come next, yet we only ever extract one token's worth of supervision from it: the loss against the single next token $x_{t+1}$. The most obvious fix is also the most seductive: just bolt $K$ parallel prediction heads on top of the backbone. Head 1 predicts $x_{t+1}$, head 2 predicts $x_{t+2}$, and so on. $K$ times the supervision, no architectural changes downstream.
Meta proposed exactly this design in 2024, and it does help, modestly. But when DeepSeek built MTP into V3 they took a sharper, more complicated path. This section explains why. The naive parallel design has a mathematical defect that grows with the prediction horizon $k$, and that defect is not an implementation bug. It is baked into the statement of the problem.
The single-sentence verdict. Naive parallel MTP forces every head to predict a marginal distribution over the future token, while the language really lives on a chain of conditional distributions. The deeper the head, the farther the marginal drifts from the conditional, and the less supervision the loss actually carries.
The Obvious Idea, and Why It Cannot Work
Here is the design verbatim. Run a normal transformer up to the final layer. Read off the hidden states $h_t$, one per position. For each position you have one hidden state but you want $K$ predictions. Solution: attach $K$ independent output projections $W_1, \dots, W_K$, each mapping the $D$-dimensional hidden state to a logit vector over the vocabulary. The $k$-th head predicts the token at offset $k$ ahead.
Cross-entropy losses for the $K$ heads are averaged or summed, and that becomes the training objective. Inference is unchanged: production decoding can still use the head-1 next-token path; heads 2 through $K$ can be dropped at deploy time or repurposed for speculative decoding. The cost is the extra output matrices (~50–200 million extra parameters at LLM scale) and a slightly slower training step.
| Property | Single-token (baseline) | Naive parallel MTP |
|---|---|---|
| Supervision per forward | 1 token ($x_{t+1}$) | K tokens |
| Output heads | 1 | K independent |
| Cross-head information flow | n/a | None |
| Compute cost | 1× | ≈ 1× backbone + K× head |
| Param cost | ≈ V·D | ≈ K · V·D |
| What each head predicts | $P(x_{t+1} \mid x_{\le t})$ | $P(x_{t+k} \mid x_{\le t})$, the marginal |
That last row is the entire story. Every head receives the same $h_t$; it has no access to the previous head's prediction, no recurrence, no attention bridging the future. The only thing the model can do is predict, at each horizon $k$, the marginal distribution of $x_{t+k}$ given the context up to $t$, summed over all the possible token paths that could connect them. Language is not built that way.
Intuition: The Marginal Trap
Imagine a sentence that starts with the word "the". Ask yourself: what is the very next word? Probably a noun — "cat", "dog", "car", "house". You could rank candidates with reasonable confidence; the distribution is sharp. Now ask: what is the word two positions ahead? It depends completely on which noun came at position 1. If the noun was "cat", the second word is most likely a verb like "sat" or "purred". If the noun was "car", the second word is more likely "is" or "was". The distribution at position 2 is a weighted mixture over all the nouns that could have appeared at position 1, each contributing its own downstream language.
That mixture is what naive parallel MTP is forced to predict. A single hidden state cannot encode which specific noun came at position 1, because no token has been committed yet; the model is supposed to choose it at decode time. So head 2 sees only the prior — all possible nouns and their downstream verbs — and produces the marginal distribution of the second word, integrated over the entire English-language tree branching from the context. The same spreading happens at every deeper horizon, and it compounds. By horizon 4 the distribution is so flat it is almost uninformative.
A model trained to match the $k$-step marginal is being asked, in effect, to memorize statistical averages of long-range futures rather than to learn rules. Worse, the loss signal from horizon $k$ degrades quickly as $k$ grows: when the target distribution is close to flat, no single parameter update can raise the log-likelihood of the observed token by much. So each extra head adds compute cost but contributes a rapidly diminishing amount of useful learning signal.
The Math: Independence vs. Chain Rule
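Language modeling is the task of estimating the joint distribution $P(x_{t+1}, \dots, x_{t+K} \mid x_{\le t})$. Factorization makes this tractable, and there is a right way and a wrong way. The chain rule factorization is the right one:

$$P(x_{t+1}, \dots, x_{t+K} \mid x_{\le t}) = \prod_{k=1}^{K} P(x_{t+k} \mid x_{\le t}, x_{t+1}, \dots, x_{t+k-1})$$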
Each factor on the right conditions on everything that came before, including the previously predicted future tokens. This is what a standard left-to-right LM learns at training time, and it is what gives the model both calibrated sharpness and the ability to commit to one branch of the language tree.
Naive parallel MTP instead factors the joint as a product of marginals:

$$P_{\text{naive}}(x_{t+1}, \dots, x_{t+K} \mid x_{\le t}) = \prod_{k=1}^{K} P(x_{t+k} \mid x_{\le t})$$

Note the conditioning bar: every factor stops at $x_{\le t}$. The future tokens never enter the condition. This is the classical independence assumption, and it is wrong about language. The two factorizations agree only when the future tokens are conditionally independent of one another given the context $x_{\le t}$, which natural language emphatically is not.
The Kullback–Leibler gap
We can measure exactly how wrong by computing the KL divergence between the true joint and the naive factorization. A useful identity:

$$D_{\mathrm{KL}}\left(P \,\|\, P_{\text{naive}}\right) = \sum_{k=2}^{K} I\left(x_{t+k};\, x_{t+1}, \dots, x_{t+k-1} \mid x_{\le t}\right)$$

where $I(\cdot\,;\cdot \mid \cdot)$ is the conditional mutual information. Read it directly: the KL gap is the total mutual information between each future token and its predecessors. For natural language that quantity is enormous, because words depend on the words that immediately precede them. Naive parallel MTP throws away all of that information and asks the model to compensate by stuffing every future moment into a single hidden state.
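The identity can be checked numerically in its smallest case, $K = 2$. The sketch below is an illustration only: it borrows the toy first-order Markov transition matrix from the walkthrough later in this section, fixes the context token to `the`, and compares the KL divergence between the chain joint and the product of marginals against the mutual information $I(x_{t+2}; x_{t+1} \mid x_{\le t})$. The names `T`, `p1`, `q2`, `kl`, and `mi` are chosen here for illustration and are not from the original text.

```python
import math

# Toy first-order Markov chain over the vocabulary (the, cat, sat, purr),
# borrowed from the walkthrough later in this section.
T = [
    [0.05, 0.55, 0.25, 0.15],  # from "the"
    [0.10, 0.05, 0.55, 0.30],  # from "cat"
    [0.50, 0.10, 0.05, 0.35],  # from "sat"
    [0.40, 0.20, 0.30, 0.10],  # from "purr"
]
V = len(T)


def entropy(p):
    """Shannon entropy in nats."""
    return -sum(x * math.log(x) for x in p if x > 0)


p1 = T[0]  # P(x_{t+1} | context = "the")
# Marginal of x_{t+2}: sum out the unknown x_{t+1}.
q2 = [sum(p1[i] * T[i][j] for i in range(V)) for j in range(V)]

# Left side: KL(chain joint || product of marginals) for K = 2.
kl = sum(
    p1[i] * T[i][j] * math.log((p1[i] * T[i][j]) / (p1[i] * q2[j]))
    for i in range(V)
    for j in range(V)
)

# Right side: I(x_{t+2}; x_{t+1} | context) = H(x_{t+2}) - H(x_{t+2} | x_{t+1}).
mi = entropy(q2) - sum(p1[i] * entropy(T[i]) for i in range(V))

print(f"KL gap = {kl:.4f} nats, mutual information = {mi:.4f} nats")
```

The two printed numbers should agree up to floating-point error, which is all the identity claims for $K = 2$; larger $K$ adds one mutual-information term per extra horizon.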
Why deeper heads suffer more
The marginal at horizon $k$ has the form:

$$P(x_{t+k} \mid x_{\le t}) = \sum_{x_{t+1}, \dots, x_{t+k-1}} \; \prod_{j=1}^{k} P(x_{t+j} \mid x_{\le t}, x_{t+1}, \dots, x_{t+j-1})$$

It is a sum over all $V^{k-1}$ possible intermediate sequences. Each successive horizon adds a marginalization axis, and entropy grows roughly logarithmically with the size of that sum until the marginal saturates near uniform. The visualizer below and the numerical walkthrough show this entropy growth precisely.
Manual Numerical Walkthrough
Let us watch the gap open in a toy world. Take a vocabulary of $V = 4$ tokens (call them the, cat, sat, purr) and a fixed first-order Markov chain that we will pretend the model has learned perfectly. The context token is the. The true future sequence is (cat, sat, purr, the).
K = 4 horizons, every number by hand
Setup. The Markov transition matrix $T$ (rows are the current token, columns the next token):

| from \ to | the | cat | sat | purr |
|---|---|---|---|---|
| the | 0.05 | 0.55 | 0.25 | 0.15 |
| cat | 0.10 | 0.05 | 0.55 | 0.30 |
| sat | 0.50 | 0.10 | 0.05 | 0.35 |
| purr | 0.40 | 0.20 | 0.30 | 0.10 |
Horizon 1. Both factorizations agree at the first step (they both condition only on the context token "the"):
P(x_t+1 | the) = [0.05, 0.55, 0.25, 0.15]; H = 1.110 nats
Target is "cat", so $\log P = \log 0.55 = -0.598$ under both factorizations. No gap yet.
Horizon 2. The chain factorization conditions on the true previous token ("cat"), so it just reads off row 2 of $T$:
P_chain(x_t+2 | cat) = [0.10, 0.05, 0.55, 0.30]; H = 1.070 nats
The naive marginal has to sum out the unknown $x_{t+1}$:
P_naive(x_t+2 | the) = P(x_t+1 | the) · T
= 0.05·[0.05,0.55,0.25,0.15]
+ 0.55·[0.10,0.05,0.55,0.30]
+ 0.25·[0.50,0.10,0.05,0.35]
+ 0.15·[0.40,0.20,0.30,0.10]
= [0.243, 0.110, 0.373, 0.275]; H = 1.309 nats
Target "sat": $\log P_{\text{chain}} = -0.598$, $\log P_{\text{naive}} = -0.987$. The chain factorization assigns nearly 50% more probability mass to the correct token (0.55 versus 0.373).
Horizon 3. Chain reads row 3 of $T$ (from "sat"):
P_chain(x_t+3 | sat) = [0.50, 0.10, 0.05, 0.35]; H = 1.094 nats
Naive marginal applies $T$ one more time:
P_naive(x_t+3) = [0.319, 0.231, 0.222, 0.227]; H = 1.374 nats
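That entropy of 1.374 nats is within 1% of $\log 4 = 1.386$: the horizon-3 marginal has collapsed almost to uniform. Target "purr": $\log P_{\text{chain}} = -1.050$, $\log P_{\text{naive}} = -1.482$.
Horizon 4. Chain reads row 4 (from "purr"): [0.40, 0.20, 0.30, 0.10]. Target "the": $\log P_{\text{chain}} = \log 0.40 = -0.916$. The naive marginal is now essentially uniform: we get $P_{\text{naive}}(\text{the}) = 0.241$ and $\log P_{\text{naive}} = -1.423$.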
Total log-likelihood across K = 4 heads.
log P_naive total = -0.598 - 0.987 - 1.482 - 1.423 = -4.490
log P_chain total = -0.598 - 0.598 - 1.050 - 0.916 = -3.162
gap (chain - naive) = 1.328 nats
That 1.328-nat gap is exactly the supervision signal naive parallel MTP loses. Per token that is $1.328 / 4 \approx 0.332$ nats, so the chain factorization is on average about $e^{0.332} \approx 1.39\times$ more confident on the true future.
Where the gap lives, by horizon.
| Horizon k | Chain log P | Naive log P | Gap (nats) | Marginal entropy (nats) |
|---|---|---|---|---|
| 1 | -0.598 | -0.598 | 0.000 | 1.110 |
| 2 | -0.598 | -0.987 | 0.389 | 1.309 |
| 3 | -1.050 | -1.482 | 0.432 | 1.374 |
| 4 | -0.916 | -1.423 | 0.507 | 1.381\* |
| Total | -3.162 | -4.490 | 1.328 | |

\* Approaching log 4 = 1.386 (uniform).

Read the trend. The gap grows monotonically with $k$, while the marginal entropy saturates near $\log 4 \approx 1.386$. The model has nothing left to learn at deep horizons: the loss is dominated by the irreducible entropy of the future, not by the information actually available in $h_t$.
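The hand computation above can be replayed in a few lines of standard-library Python. This is a sketch for checking the arithmetic, not part of any training code: it hard-codes the toy transition matrix, the context token `the`, and the true future `(cat, sat, purr, the)` from the walkthrough, and the helper names `entropy` and `step` are illustrative.

```python
import math

VOCAB = ["the", "cat", "sat", "purr"]
T = [
    [0.05, 0.55, 0.25, 0.15],  # from "the"
    [0.10, 0.05, 0.55, 0.30],  # from "cat"
    [0.50, 0.10, 0.05, 0.35],  # from "sat"
    [0.40, 0.20, 0.30, 0.10],  # from "purr"
]


def entropy(p):
    """Shannon entropy in nats."""
    return -sum(x * math.log(x) for x in p if x > 0)


def step(p, T):
    """One application of the transition matrix to a row vector."""
    n = len(T)
    return [sum(p[i] * T[i][j] for i in range(n)) for j in range(n)]


context = "the"
future = ["cat", "sat", "purr", "the"]

rows = []
marginal = [1.0 if w == context else 0.0 for w in VOCAB]  # one-hot on the context
prev = context
for k, target in enumerate(future, start=1):
    marginal = step(marginal, T)      # naive head: k-step marginal from the context
    chain = T[VOCAB.index(prev)]      # chain rule: condition on the true previous token
    j = VOCAB.index(target)
    rows.append((k, math.log(chain[j]), math.log(marginal[j]), entropy(marginal)))
    prev = target

chain_total = sum(r[1] for r in rows)
naive_total = sum(r[2] for r in rows)
for k, c, n, h in rows:
    print(f"k={k}  chain={c:.3f}  naive={n:.3f}  gap={c - n:.3f}  H={h:.3f}")
print(f"total gap = {chain_total - naive_total:.3f} nats")
```

Running it should reproduce the per-horizon table up to rounding in the last digit, since the table was built from already-rounded log-probabilities.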
Visualizing the Coherence Collapse
The visualizer below shows the K parallel heads, each producing its own distribution over the toy vocabulary. Toggle between Naive marginal and Chain conditional to see how the same horizon looks under the two factorizations. Step $k$ up from 1 to 4 and watch the naive bars flatten out while the chain bars stay sharp.
Three things to lock in. First, at $k = 1$ both factorizations agree perfectly: this is just standard next-token prediction, and naive parallel MTP is free of defect there. Second, by $k = 3$ the naive marginal is already nearly uniform: entropy 1.374 nats out of a maximum of $\log 4 = 1.386$ nats. The model has nothing to predict because the average over all possible prefixes is flat. Third, the bottom bar shows the accumulated log-likelihood gap, which is the training signal that naive MTP cannot recover with any amount of extra parameters.
Plain Python: Two Heads From One Hidden State
The full implementation of naive parallel MTP is short, and that is part of its appeal. It builds $K$ independent output projections, runs them all on the same $h_t$, and computes a summed cross-entropy. Every line maps directly to one factor in $\prod_{k=1}^{K} P(x_{t+k} \mid x_{\le t})$.
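The original listing for this section is not reproduced here, so what follows is a minimal stand-in written for this text, assuming random weights and a random hidden state. It uses only the standard library, keeps the `W[k]` and `preds[k]` names that the discussion refers to, and treats `targets[k]` as the token id at offset `k + 1` (a 0-indexed convention chosen for this sketch).

```python
import math
import random


def softmax(z):
    """Numerically stable softmax over a list of logits."""
    m = max(z)
    e = [math.exp(x - m) for x in z]
    s = sum(e)
    return [x / s for x in e]


def matvec(W, h):
    """Multiply a V x D matrix (list of rows) by a length-D vector."""
    return [sum(w * x for w, x in zip(row, h)) for row in W]


def naive_mtp_loss(h_t, W, targets):
    """Summed cross-entropy of independent heads that all read the same h_t.

    W[k] is the V x D output projection of head k; targets[k] is the id of
    the token at offset k + 1. No head sees any other head's output.
    """
    preds = [softmax(matvec(W[k], h_t)) for k in range(len(W))]
    loss = -sum(math.log(preds[k][targets[k]]) for k in range(len(W)))
    return loss, preds


random.seed(0)
K, V, D = 2, 4, 8  # two heads, toy vocabulary and hidden size (illustrative)
h_t = [random.gauss(0, 1) for _ in range(D)]
W = [[[random.gauss(0, 0.1) for _ in range(D)] for _ in range(V)] for _ in range(K)]

loss, preds = naive_mtp_loss(h_t, W, targets=[1, 2])
print(f"summed loss over {K} heads = {loss:.3f}")
```

Passing a single-element `W` and `targets` reduces the function to ordinary next-token cross-entropy, which is the $K = 1$ case.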
Two structural details to notice. First, there is nothing in the code that prevents the heads from sharing information; we simply do not provide a path. The list comprehension over `W[k]` is the entire story: $K$ independent matmuls from the same vector. Second, the gradient of head $k$'s loss flows back through $W_k$ and into $h_t$, and from there into the backbone. So the backbone receives a sum of $K$ gradient signals at every position. This is the upside that drove the idea in the first place: K times the supervision pressure on the backbone.
Sanity check. Set $K = 1$ and the code becomes ordinary next-token cross-entropy. Set $K$ very large and you will see `preds[k]` for big $k$ flatten toward uniform, exactly the marginal collapse the math predicts. The architecture cannot escape it.
PyTorch: K Parallel Heads, Vectorized
The production implementation runs in a batched, GPU-vectorized form. The new ideas are minimal: an `nn.ModuleList` to keep the K head parameters discoverable by the optimizer, and a single `torch.stack` to give the head axis a place in the tensor.
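A minimal PyTorch sketch of that batched form is given below. It is an illustration under stated assumptions rather than any team's actual module: the class name `NaiveParallelMTPHeads`, the function `naive_mtp_loss`, the bias-free linear heads, and the toy sizes are all choices made for this example. Targets are assumed to arrive with shape `(B, T, K)` and to use `-100` for positions without a valid label.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveParallelMTPHeads(nn.Module):
    """K independent output projections over one shared hidden state."""

    def __init__(self, d_model, vocab_size, num_heads):
        super().__init__()
        # ModuleList registers every head's weights with the optimizer.
        self.heads = nn.ModuleList(
            [nn.Linear(d_model, vocab_size, bias=False) for _ in range(num_heads)]
        )

    def forward(self, hidden):
        # hidden: (B, T, D) -> logits: (B, T, K, V); dim=2 is the head axis.
        return torch.stack([head(hidden) for head in self.heads], dim=2)


def naive_mtp_loss(logits, targets, ignore_index=-100):
    """Mean cross-entropy over all (batch, position, head) slots with a valid target."""
    vocab_size = logits.size(-1)
    return F.cross_entropy(
        logits.reshape(-1, vocab_size),   # (B*T*K, V)
        targets.reshape(-1),              # (B*T*K,)
        ignore_index=ignore_index,
        reduction="mean",
    )


torch.manual_seed(0)
B, T, D, V, K = 2, 5, 16, 11, 3           # toy sizes, illustrative only
heads = NaiveParallelMTPHeads(D, V, K)
hidden = torch.randn(B, T, D, requires_grad=True)  # stands in for backbone output
targets = torch.randint(0, V, (B, T, K))

logits = heads(hidden)
loss = naive_mtp_loss(logits, targets)
loss.backward()                           # gradients from all K heads sum into hidden
print(logits.shape, float(loss))
```

Flattening to `(B*T*K, V)` before the loss is one simple design; it makes the horizon-imbalance issue discussed next easy to see, because every head's tokens share a single mean.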
Three subtleties worth marking, all about how this module interacts with the rest of the training stack:
- FSDP cares about the K heads. At billion-parameter scale, each head matrix has $V \cdot D$ parameters. Three or four heads push you into multi-billion extra parameters that have to be sharded just like the rest of the model. Engineering choice: shard each head separately (one device per head) or shard each head's rows across the data-parallel group. Either way, the cost is real.
- Loss imbalance across horizons. Because deeper horizons have higher entropy, their losses are systematically larger. With `reduction='mean'` over the flattened `(B*T*K)` axis, the deeper heads dominate the gradient direction. A common fix is to down-weight deeper horizons with a per-horizon coefficient or some manual schedule, but this introduces another hyperparameter and does not address the underlying marginalization problem.
- The label-shift trick. `targets[:, :, k]` is constructed by shifting the input ids by $k$ positions and padding the tail with `-100`. The last K positions of every sequence have no valid horizon-K target. Forget to set the ignore index and you train the deepest head on garbage from the next batch's first tokens.
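The label-shift bullet above can be sketched as follows. This is one possible construction, not the original code: the function name `build_mtp_targets` is hypothetical, and the head index is 0-based here, so slice `k` holds the token at offset `k + 1`. The value `-100` matches the default `ignore_index` of `torch.nn.functional.cross_entropy`.

```python
import torch


def build_mtp_targets(input_ids, num_heads, ignore_index=-100):
    """Build targets of shape (B, T, K) for K prediction horizons.

    targets[:, :, k] holds the token at offset k + 1 from each position.
    Positions near the end of the sequence that have no such token keep
    ignore_index, so the loss skips them.
    """
    B, T = input_ids.shape
    targets = torch.full((B, T, num_heads), ignore_index, dtype=torch.long)
    for k in range(num_heads):
        shift = k + 1
        if shift < T:
            # Position t gets the token that sits `shift` steps later.
            targets[:, : T - shift, k] = input_ids[:, shift:]
    return targets


input_ids = torch.arange(10).reshape(2, 5)   # two toy sequences: 0..4 and 5..9
targets = build_mtp_targets(input_ids, num_heads=3)
print(targets[0, :, 0].tolist())  # horizon 1 for the first sequence
print(targets[0, :, 2].tolist())  # horizon 3 for the first sequence
```

Because the padding is written per sequence, no target ever reaches across a sequence boundary, which is the failure the bullet warns about.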
What Changes at Massive Scale
At toy scale, naive parallel MTP looks like a tax on training compute that buys some extra supervision. At production scale, the cost accounting flips:
| Quantity | Single-token baseline | Naive parallel MTP (K=4) | Why it matters |
|---|---|---|---|
| Backbone FLOPs / step | 1× | 1× | Heads add a fixed cost on top; they do not multiply the backbone cost |
| Output-head FLOPs / step | 1× | 4× | V·D matmul repeated K times per position |
| Parameter count | ~600B | ~602B | K × V·D ≈ 2B extra at LLM scale |
| Activation memory | 1× | K× | K logit tensors of shape (B, T, V) |
| Effective loss signal / step | 1× | ~1.5–2× | Diminishing returns from marginal collapse |
| Wall-clock per step | 1× | 1.05–1.15× | Backbone dominates; heads are relatively cheap |
Two patterns to read here. First, the cost is modest — backbone matmuls dwarf the output heads at LLM scale, so adding 4 heads adds maybe 10% to per-step time. That is the upside Meta's original paper claimed: a cheap way to extract more supervision per forward.
Second, the effective gain is much less than $K\times$. The numerical walkthrough showed the per-horizon marginal entropy saturating near $\log V$, which means the target for head 4 is almost indistinguishable from uniform and its gradient carries little information beyond that small deviation. Empirically, naive parallel MTP buys about a 1.5×–2× signal density per step rather than the naive $K\times$. That is real but underwhelming for the engineering investment.
The activation-memory cost
A single logits tensor of shape $(B, T, V)$ takes $2 \cdot B \cdot T \cdot V$ bytes in bf16. With $K$ heads, that is $K$ times as much for the head outputs. Even with chunked cross-entropy (computing the loss for slices of the logits and discarding them afterwards) the head memory becomes a real constraint in tight FSDP setups. This is the second reason teams looked for a better MTP design: they wanted the supervision without the K-fold blowup in head activations.
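To get a feel for the magnitudes, the footprint can be computed directly. The sizes below are hypothetical and chosen only for illustration; they are not the figures from the original text. The helper `logits_bytes` is likewise an assumption of this sketch, and it counts only the raw logits at 2 bytes per element (bf16), ignoring any softmax or gradient buffers.

```python
def logits_bytes(batch, seq_len, vocab, num_heads=1, bytes_per_elem=2):
    """Bytes needed to hold num_heads logits tensors of shape (batch, seq_len, vocab)."""
    return batch * seq_len * vocab * num_heads * bytes_per_elem


# Hypothetical sizes, for illustration only.
B, T, V, K = 8, 4096, 128_000, 4

single = logits_bytes(B, T, V)
total = logits_bytes(B, T, V, num_heads=K)

gib = 2 ** 30
print(f"one head: {single / gib:.2f} GiB, {K} heads: {total / gib:.2f} GiB")
```

The only point of the calculation is the linear scaling in $K$; the absolute numbers move with whatever batch, sequence length, and vocabulary a given setup uses.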
Why DeepSeek Did Not Use Naive Parallel MTP
DeepSeek-V3's team measured the naive design carefully before committing to an alternative. Three findings, all flowing from the marginal-vs-conditional issue, drove their choice:
- Token-acceptance rate at decode time is poor. One benefit of MTP is using head $k$ at inference to speculate $k$-ahead tokens for speculative decoding. Naive heads give acceptance rates in the 20–40% range beyond the first head, because the marginal distribution diverges from what the actual chain produces. Each rejected speculation is a wasted forward.
- The deeper heads degrade backbone quality. Head $k$'s loss pressures the backbone to encode marginal-future information in $h_t$, information that is not actionable at inference, so the backbone's representation is pulled away from the chain-rule optimum that single-token training would have produced. DeepSeek measured a small but consistent regression in downstream evals when naive MTP was added.
- The fix is not architectural depth. Doubling head size, tying head weights, sharing more parameters: none of these touch the conditioning structure. The only way to recover the chain-rule factorization is to let head $k$ see what head $k-1$ predicted. That requires a causal dependency between the heads, which is exactly what section 7.3 builds.
The one sentence to carry forward: K parallel heads on a shared hidden state can only ever learn the marginal future distribution, and the marginal collapses toward uniform as the horizon grows — so the extra heads cost compute, parameters, and activation memory without buying the proportional supervision they appear to promise. Section 7.3 turns this failure into the recipe for a fix.