Chapter 7
Section 37 of 117

Naive Parallel MTP and Its Failure Mode

Multi-Token Prediction (MTP)

Section 7.1 ended on a frustrating observation. Every transformer forward pass produces, at position $t$, a hidden state $h_t$ that already encodes a great deal about what is going to come next, yet we only ever extract a single target's worth of supervision from it: the loss against the next token $x_{t+1}$. The most obvious fix is also the most seductive: just bolt $K$ parallel prediction heads on top of the backbone. Head 1 predicts $x_{t+1}$, head 2 predicts $x_{t+2}$, and so on. $K$ times the supervision, no architectural changes downstream.

Meta proposed exactly this design in 2024, and it does help, modestly. But when DeepSeek built MTP into V3 they took a sharper, more complicated path. This section explains why. The naive parallel design has a mathematical defect that grows with $K$, and that defect is not an implementation bug. It is baked into the statement of the problem.

The single-sentence verdict. Naive parallel MTP forces every head to predict a marginal distribution over the future token, while the language really lives on a chain of conditional distributions. The deeper the head, the farther the marginal drifts from the conditional, and the less supervision the loss actually carries.

The Obvious Idea, and Why It Cannot Work

Here is the design in full. Run a normal transformer up to the final layer. Read off the hidden states $h_1, h_2, \dots, h_T$, one per position. For each position you have one hidden state but you want $K$ predictions. Solution: attach $K$ independent output projections $W_1, \dots, W_K$, each mapping the $D$-dimensional hidden state to a logit vector over the vocabulary. The $k$-th head predicts the token at offset $k$ ahead.

Cross-entropy losses for the $K$ heads are averaged or summed, and that becomes the training objective. Inference is unchanged: production decoding can still use the head-1 next-token path, and heads 2 through $K$ can be dropped at deploy time or repurposed for speculative decoding. The cost is $K - 1$ extra output matrices of $V \cdot D$ parameters each (about 130M per head at $V = 32\text{k}$, $D = 4\text{k}$, and about 500M at $V \approx 100\text{k}$, $D \approx 5\text{k}$) and a slightly slower training step.

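| Property | Single-token (baseline) | Naive parallel MTP |
|---|---|---|
| Supervision per forward | 1 token ($x_{t+1}$) | $K$ tokens |
| Output heads | 1 | $K$ independent |
| Cross-head information flow | n/a | None |
| Compute cost | ≈ 1× backbone + 1× head | ≈ 1× backbone + $K$× head |
| Param cost (output heads) | ≈ $V \cdot D$ | ≈ $K \cdot V \cdot D$ |
| What each head predicts | $P(x_{t+1} \mid x_{\le t})$ | $P(x_{t+k} \mid x_{\le t})$, the marginal |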

That last row is the entire story. Every head receives the same $h_t$; it has no access to the previous head's prediction, no recurrence, no attention bridging the future. The only thing the model can do is predict, at each horizon, the marginal distribution of $x_{t+k}$ given the context up to $t$, summed over all the possible token paths $x_{t+1}, \dots, x_{t+k-1}$ that could connect them. Language is not built that way.

Where the problem lives. The bug is not in the loss, not in the head architecture, not in the data pipeline. It is in the independence assumption the architecture forces on the $K$ head outputs. The math below makes this precise.

Intuition: The Marginal Trap

Imagine a sentence that starts with the word "the". Ask yourself: what is the very next word? Probably a noun: "cat", "dog", "car", "house". You could rank candidates with reasonable confidence; the distribution is sharp. Now ask: what is the word two positions ahead? It depends completely on which noun came at position 1. If the noun was "cat", the second word is most likely a verb like "sat" or "purred". If the noun was "car", the second word is more likely "is" or "was". The distribution at position 2 is a weighted mixture over all the nouns that could have appeared at position 1, each contributing its own continuations.

That mixture is what naive parallel MTP is forced to predict. A single hidden state $h_t$ cannot encode which specific noun came at position 1, because no token has been committed yet; the model is supposed to choose it at decode time. So head 2 sees only the prior (all possible nouns and their downstream verbs) and produces the marginal distribution of the second word, integrated over the entire tree of English continuations branching from the context. The same spreading happens at every deeper horizon, and it compounds. By horizon 4 the distribution can be so flat that it is almost uninformative.

The right mental picture. Think of language as a tree. At time $t$ the model is sitting on one node. The next-token distribution is the set of branches leaving that node: sharp and informative. The two-step distribution is spread over the leaves of all 2-deep subtrees, which is broader. The $k$-step marginal is the cross-section of the entire $k$-deep tree at depth $k$, which has exponentially many leaves and therefore exponentially less structure visible from any one of them.

A model trained to match the $k$-step marginal is being asked, in effect, to memorize statistical averages of long-range futures rather than to learn rules. Worse, the learnable part of the loss shrinks quickly as $k$ grows. Cross-entropy against the horizon-$k$ target splits into that target's entropy, which no model can remove, plus the KL divergence between the model and the target; as the marginal approaches uniform, almost all of the loss is the irreducible entropy term. So each extra head adds compute cost but contributes a rapidly diminishing amount of useful learning signal.
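A rough way to quantify that decay: a head that predicts uniformly has expected loss $\log V$, and the best any head conditioned only on the context can reach against the horizon-$k$ marginal is that marginal's entropy $H_k$. So $\log V - H_k$ (the KL divergence from the marginal to uniform) is the most loss training can remove at that horizon, measured from a uniform start. The sketch below assumes the toy four-token Markov chain used in the numerical walkthrough later in this section; it is an illustration, not a measurement on a real model.

```python
import numpy as np

# Toy chain from the walkthrough: T[i, j] = P(next = j | current = i),
# tokens ordered (the, cat, sat, purr).
T = np.array([
    [0.05, 0.55, 0.25, 0.15],
    [0.10, 0.05, 0.55, 0.30],
    [0.50, 0.10, 0.05, 0.35],
    [0.40, 0.20, 0.30, 0.10],
])

def entropy(p):
    """Shannon entropy in nats (all entries here are strictly positive)."""
    return float(-(p * np.log(p)).sum())

log_v = np.log(T.shape[0])        # expected loss of a head that predicts uniformly
marginal = np.eye(4)[0]           # context x_t = "the" as a one-hot vector
headroom = []
for k in range(1, 5):
    marginal = marginal @ T       # horizon-k marginal: the naive head's target
    H = entropy(marginal)         # best achievable expected loss at this horizon
    headroom.append(log_v - H)    # = KL(marginal || uniform), the removable loss
    print(f"k={k}: entropy {H:.3f} nats, removable loss {log_v - H:.3f} nats")
```

In this toy the removable loss falls by more than an order of magnitude between horizon 1 and horizon 4.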

The Math: Independence vs. Chain Rule

Language modeling is the task of estimating the joint distribution $P(x_{t+1}, x_{t+2}, \dots, x_{t+K} \mid x_{\le t})$. Factorization makes this tractable, and there is a right way and a wrong way. The chain rule factorization is the right one:

$$P(x_{t+1}, \dots, x_{t+K} \mid x_{\le t}) = \prod_{k=1}^{K} P(x_{t+k} \mid x_{\le t+k-1})$$

Each factor on the right conditions on everything that came before, including the earlier future tokens $x_{t+1}, \dots, x_{t+k-1}$. This is what a standard left-to-right LM learns at training time, and it is what gives the model both calibrated sharpness and the ability to commit to one branch of the language tree.

Naive parallel MTP instead factors the joint as a product of marginals:

$$P_{\text{naive}}(x_{t+1}, \dots, x_{t+K} \mid x_{\le t}) = \prod_{k=1}^{K} P(x_{t+k} \mid x_{\le t})$$

Note the conditioning bar: every factor stops at $x_{\le t}$. The future tokens never enter the condition. This is the classical independence assumption, and it is wrong about language. The two factorizations agree only when $x_{t+1}, \dots, x_{t+K}$ are conditionally independent of one another given $x_{\le t}$, so that knowing one future token tells you nothing more about another. Natural language emphatically does not work that way; even a simple first-order Markov chain violates it.

The Kullback–Leibler gap

We can measure exactly how wrong by computing the KL divergence between the true joint and the naive factorization. A useful identity:

$$\text{KL}(P \,\|\, P_{\text{naive}}) = \sum_{k=2}^{K} I(x_{t+k} ; x_{t+1}, \dots, x_{t+k-1} \mid x_{\le t})$$

where $I$ is the conditional mutual information. Read it directly: the KL gap is the total mutual information between each future token and its predecessors, given the context. For natural language that quantity is large, because words depend heavily on the words that immediately precede them. Naive parallel MTP throws away all of that information and asks the model to compensate by stuffing every future moment into a single hidden state.
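The identity is easy to check numerically on a small example. The sketch below assumes the toy four-token Markov chain from the numerical walkthrough later in this section (any conditional distribution would do). It enumerates the exact joint over $K = 3$ future tokens given $x_t =$ the, builds $P_{\text{naive}}$ from the exact per-horizon marginals (the best a naive model could do), and compares $\text{KL}(P \,\|\, P_{\text{naive}})$ with the sum of mutual-information terms, each computed as $H(x_k) + H(x_{<k}) - H(x_{\le k})$. A chain whose rows are all identical makes the future tokens independent, and both quantities should then vanish. The helper names are illustrative.

```python
import itertools
import numpy as np

# Toy chain over (the, cat, sat, purr): T[i, j] = P(next = j | current = i).
T = np.array([
    [0.05, 0.55, 0.25, 0.15],
    [0.10, 0.05, 0.55, 0.30],
    [0.50, 0.10, 0.05, 0.35],
    [0.40, 0.20, 0.30, 0.10],
])

def future_joint(T, start, K):
    """Exact P(x_{t+1}, ..., x_{t+K} | x_t = start) as a K-dim array, via the chain rule."""
    V = T.shape[0]
    joint = np.zeros((V,) * K)
    for seq in itertools.product(range(V), repeat=K):
        p, prev = 1.0, start
        for tok in seq:
            p *= T[prev, tok]
            prev = tok
        joint[seq] = p
    return joint

def entropy(p):
    """Entropy in nats of a (possibly multi-dimensional) probability array."""
    p = p[p > 0]
    return float(-(p * np.log(p)).sum())

def kl_vs_mi(joint):
    """Return KL(joint || product of marginals) and sum_{k>=2} I(x_k ; x_1..x_{k-1})."""
    K = joint.ndim
    marginals = [joint.sum(axis=tuple(a for a in range(K) if a != k)) for k in range(K)]
    naive = marginals[0]
    for m in marginals[1:]:
        naive = np.multiply.outer(naive, m)      # naive factorization: product of marginals
    mask = joint > 0
    kl = float((joint[mask] * np.log(joint[mask] / naive[mask])).sum())
    mi_sum = 0.0
    for k in range(1, K):
        # I(X_k ; X_<k) = H(X_k) + H(X_<k) - H(X_<=k), all from sums of the joint.
        upto_k = joint.sum(axis=tuple(range(k + 1, K))) if k + 1 < K else joint
        before_k = joint.sum(axis=tuple(range(k, K)))
        mi_sum += entropy(marginals[k]) + entropy(before_k) - entropy(upto_k)
    return kl, mi_sum

kl, mi = kl_vs_mi(future_joint(T, start=0, K=3))
print(f"KL = {kl:.6f}  sum of MI terms = {mi:.6f}")

# Identical rows: future tokens are independent given x_t, so both should be ~0.
iid = np.tile(T[0], (4, 1))
print(kl_vs_mi(future_joint(iid, start=0, K=3)))
```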

Why deeper heads suffer more

The marginal at horizon $k$ has the form:

$$P(x_{t+k} \mid x_{\le t}) = \sum_{x_{t+1}, \dots, x_{t+k-1}} \prod_{j=1}^{k-1} P(x_{t+j} \mid x_{\le t+j-1}) \cdot P(x_{t+k} \mid x_{\le t+k-1})$$

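It is a sum over $V^{k-1}$ possible intermediate sequences. Each successive horizon adds a marginalization axis, and as $k$ grows the marginal typically drifts toward a context-free token distribution (for a Markov chain, its stationary distribution), losing the sharpness the context provided. When that distribution is close to uniform, as in the toy example, the entropy climbs toward $\log V$. The numerical walkthrough below shows this entropy growth step by step.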
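For the toy first-order Markov chain used in the numerical walkthrough, the $V^{k-1}$-term sum collapses to a row of a matrix power: $P(x_{t+k} \mid x_t)$ is row $x_t$ of $T^k$. The sketch below is a check under that Markov assumption (a real LM has no such closed form): it enumerates every intermediate path explicitly, exactly as the equation above is written, confirms the result matches `np.linalg.matrix_power`, and prints the marginal's entropy for $k = 1, \dots, 6$.

```python
import itertools
import numpy as np

# Toy chain over (the, cat, sat, purr); T[i, j] = P(next = j | current = i).
T = np.array([
    [0.05, 0.55, 0.25, 0.15],
    [0.10, 0.05, 0.55, 0.30],
    [0.50, 0.10, 0.05, 0.35],
    [0.40, 0.20, 0.30, 0.10],
])

def marginal_brute_force(T, start, k):
    """P(x_{t+k} | x_t = start) by summing over all V**(k-1) intermediate paths."""
    V = T.shape[0]
    out = np.zeros(V)
    for path in itertools.product(range(V), repeat=k - 1):
        p, prev = 1.0, start
        for tok in path:              # prod_j P(x_{t+j} | x_{t+j-1})
            p *= T[prev, tok]
            prev = tok
        out += p * T[prev]            # times P(x_{t+k} | x_{t+k-1}), for every x_{t+k}
    return out

def marginal_matrix_power(T, start, k):
    """Same quantity as a row of T**k (valid only for a first-order Markov chain)."""
    return np.linalg.matrix_power(T, k)[start]

def entropy(p):
    return float(-(p * np.log(p)).sum())

for k in range(1, 7):
    m = marginal_brute_force(T, start=0, k=k)
    assert np.allclose(m, marginal_matrix_power(T, start=0, k=k))
    print(f"k={k}  paths={T.shape[0] ** (k - 1):5d}  H={entropy(m):.3f} nats"
          f"  (log V = {np.log(4):.3f})")
```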

The asymmetry is not negotiable. No matter how expressive head $k$ is, it can only learn the marginal, because the conditional information it would need (the actual $x_{t+1}, \dots, x_{t+k-1}$) is simply not present in $h_t$ at train time. You cannot fix this with a wider head, a deeper head, or more parameters. You have to fix it by changing the conditioning structure, which is exactly what DeepSeek's sequential causal MTP does in the next section.

Manual Numerical Walkthrough

Let us watch the gap open in a toy world. Take a vocabulary of $V = 4$ tokens (call them the, cat, sat, purr) and a fixed first-order Markov chain that we will pretend the model has learned perfectly. The context token is $x_t =$ the. The true future sequence is (cat, sat, purr, the).

K = 4 horizons, every number by hand.

Setup. The Markov transition matrix $T_{ij} = P(x_{n+1}=j \mid x_n=i)$:

            the    cat    sat    purr
from the:  [0.05,  0.55,  0.25,  0.15]
from cat:  [0.10,  0.05,  0.55,  0.30]
from sat:  [0.50,  0.10,  0.05,  0.35]
from purr: [0.40,  0.20,  0.30,  0.10]

Horizon 1. Both factorizations agree at the first step (they both condition only on $x_t$):

P(x_t+1 | the) = [0.05, 0.55, 0.25, 0.15]; H = 1.110 nats

Target is "cat", so $\log P_{\text{naive}}^{(1)} = \log P_{\text{chain}}^{(1)} = \log 0.55 = -0.598$. No gap yet.

Horizon 2. The chain factorization conditions on the true previous token ("cat"), so it just reads off row 2 of $T$:

P_chain(x_t+2 | cat) = [0.10, 0.05, 0.55, 0.30]; H = 1.070 nats

The naive marginal has to sum out the unknown $x_{t+1}$:

P_naive(x_t+2 | the) = P(x_t+1 | the) · T
= 0.05·[0.05,0.55,0.25,0.15]
+ 0.55·[0.10,0.05,0.55,0.30]
+ 0.25·[0.50,0.10,0.05,0.35]
+ 0.15·[0.40,0.20,0.30,0.10]
= [0.243, 0.110, 0.373, 0.275]; H = 1.309 nats

Target "sat": $\log P_{\text{chain}}^{(2)} = \log 0.55 = -0.598$, $\log P_{\text{naive}}^{(2)} = \log 0.373 = -0.987$. The chain factorization assigns nearly $50\%$ more probability mass to the correct token.

Horizon 3. Chain reads row 3 of $T$ (from "sat"):

P_chain(x_t+3 | sat) = [0.50, 0.10, 0.05, 0.35]; H = 1.094 nats

Naive marginal applies $T$ one more time:

P_naive(x_t+3) = [0.319, 0.231, 0.222, 0.227]; H = 1.374 nats

That entropy of 1.374 nats is within $1\%$ of $\log 4 = 1.386$: the horizon-3 marginal has collapsed almost to uniform. Target "purr": $\log P_{\text{chain}}^{(3)} = \log 0.35 = -1.050$, $\log P_{\text{naive}}^{(3)} = \log 0.227 = -1.482$.

Horizon 4. Chain reads row 4 (from "purr"): [0.40, 0.20, 0.30, 0.10]. Target "the": $\log P_{\text{chain}}^{(4)} = \log 0.40 = -0.916$. The naive marginal is now essentially uniform: we get $P_{\text{naive}}(\text{the}) \approx 0.241$ and $\log 0.241 = -1.423$.

Total log-likelihood across K = 4 heads.

log P_naive total = -0.598 - 0.987 - 1.482 - 1.423 = -4.490
log P_chain total = -0.598 - 0.598 - 1.050 - 0.916 = -3.162
gap (chain - naive) = 1.328 nats

On this one sequence, that 1.328-nat gap is the supervision signal naive parallel MTP loses. Per token, the chain factorization assigns $e^{1.328/4} \approx e^{0.33} \approx 1.4\times$ more probability to the true future (a geometric mean over the four horizons).

Where the gap lives, by horizon.

horizon k   chain logP   naive logP   gap (nats)   marginal entropy
   1          -0.598       -0.598       0.000          1.110
   2          -0.598       -0.987       0.389          1.309
   3          -1.050       -1.482       0.432          1.374
   4          -0.916       -1.423       0.507          1.381*
   total      -3.162       -4.490       1.328
   * approaching log 4 = 1.386 (uniform)

Read the trend. In this example the gap grows with every horizon, while the marginal entropy saturates near $\log V$. At deep horizons the naive head has almost nothing left to learn: its loss is dominated by the irreducible entropy of the future, not by information actually available in $h_t$.
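Every number in the walkthrough can be regenerated in a few lines. This sketch assumes only the transition matrix $T$ above and the true future (cat, sat, purr, the): the chain log-probability reads the row of the previous true token, the naive log-probability reads the horizon-$k$ marginal (row "the" of $T^k$), and the entropy is that of the marginal. At full precision some printed values differ from the hand-rounded table in the last digit, because the table subtracts already-rounded log-probabilities.

```python
import numpy as np

VOCAB = ["the", "cat", "sat", "purr"]
T = np.array([
    [0.05, 0.55, 0.25, 0.15],
    [0.10, 0.05, 0.55, 0.30],
    [0.50, 0.10, 0.05, 0.35],
    [0.40, 0.20, 0.30, 0.10],
])

def walkthrough(T, context, future):
    """Per-horizon (chain logP, naive logP, gap, marginal entropy) for one true future."""
    rows = []
    prev = context
    marginal = np.eye(T.shape[0])[context]       # one-hot on x_t
    for tok in future:
        marginal = marginal @ T                   # one more marginalization step
        chain_lp = np.log(T[prev, tok])           # P(x_{t+k} | true x_{t+k-1})
        naive_lp = np.log(marginal[tok])          # P(x_{t+k} | x_t) only
        H = float(-(marginal * np.log(marginal)).sum())
        rows.append((chain_lp, naive_lp, chain_lp - naive_lp, H))
        prev = tok
    return rows

ids = [VOCAB.index(w) for w in ["cat", "sat", "purr", "the"]]
rows = walkthrough(T, VOCAB.index("the"), ids)
for k, (c, n, g, H) in enumerate(rows, start=1):
    print(f"{k}  chain {c:7.3f}  naive {n:7.3f}  gap {g:6.3f}  H {H:.3f}")
total_chain = sum(r[0] for r in rows)
total_naive = sum(r[1] for r in rows)
print(f"total chain {total_chain:.3f}  naive {total_naive:.3f}"
      f"  gap {total_chain - total_naive:.3f}")
```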

Visualizing the Coherence Collapse

The interactive visualizer for this section shows the K parallel heads, each producing its own distribution over the toy vocabulary. Toggle between Naive marginal and Chain conditional to see how the same horizon looks under the two factorizations. Step $K$ up from 1 to 4 and watch the naive bars flatten out while the chain bars stay sharp.


Three things to lock in. First, at $K = 1$ both factorizations agree perfectly: this is just standard next-token prediction, and naive parallel MTP is free of defect there. Second, by $K = 3$ the naive marginal at horizon 3 is already nearly uniform, with entropy $\approx 1.37$ nats out of a maximum of $\log 4 \approx 1.39$ nats. The model has nothing to predict because the average over all possible prefixes is flat. Third, the bottom bar shows the accumulated log-likelihood gap, which is the training signal that naive MTP cannot recover with any amount of extra parameters.
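Without the interactive widget, a few lines of Python give a rough text rendering of the same comparison. This is an illustrative sketch: it assumes the toy chain and true future from the walkthrough, draws one bar per token (40 characters correspond to probability 1), and puts the naive marginal next to the chain conditional at each horizon. The function name `bars` is hypothetical.

```python
import numpy as np

VOCAB = ["the", "cat", "sat", "purr"]
# Toy chain from the walkthrough: T[i, j] = P(next = j | current = i).
T = np.array([
    [0.05, 0.55, 0.25, 0.15],
    [0.10, 0.05, 0.55, 0.30],
    [0.50, 0.10, 0.05, 0.35],
    [0.40, 0.20, 0.30, 0.10],
])

def bars(p, width=40):
    """One text bar per token; bar length is proportional to probability."""
    return [f"{w:>4} {'#' * int(round(float(width * q))):<{width}} {q:.3f}"
            for w, q in zip(VOCAB, p)]

future = [1, 2, 3, 0]              # true future: cat, sat, purr, the
prev = 0                           # x_t = "the"
marginal = np.eye(4)[prev]         # one-hot on the context token
for k, tok in enumerate(future, start=1):
    marginal = marginal @ T        # naive head's target: P(x_{t+k} | x_t)
    print(f"horizon {k}: naive marginal  |  chain conditional")
    for left, right in zip(bars(marginal), bars(T[prev])):
        print(f"  {left}  |  {right}")
    prev = tok                     # the chain conditions on the true previous token
```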

Plain Python: Three Heads From One Hidden State

The full implementation of naive parallel MTP is short, which is part of its appeal. The code below builds $K = 3$ independent output projections, runs them all on the same $h_t$, and computes a summed cross-entropy. Each head maps directly to one factor in $\prod_k P(x_{t+k} \mid x_{\le t})$.

What each part of naive_parallel_mtp_numpy.py does:

Toy dimensions. A tiny world: a 4-token vocabulary, an 8-dimensional hidden state, and K = 3 future heads (V = 4, D = 8, K = 3). Real models use V around 100k, D around 5k or more, and K = 1 (single-token) or K = 2–4 (MTP).

The shared hidden state h_t. This is the output of the backbone transformer at position t, with h_t.shape = (8,). In naive parallel MTP this single vector is the only signal every head receives; it must somehow encode information about x_(t+1), x_(t+2), and x_(t+3) simultaneously.

Per-head output projections. Each head k owns its own matrix W[k] of shape (V, D), so len(W) = 3 and W[0].shape = (4, 8). The K heads are structurally independent: there is no recurrence, no attention, and no information flow between them. They are K parallel linear maps from the same input.

Per-head biases. A standard bias term added before the softmax. It is a tiny detail compared with W, included only so the toy code parallels real implementations.

K independent softmaxes. The list comprehension runs softmax(W[k] @ h_t + b[k]) for k = 0, 1, 2, so len(preds) = 3, with preds[0] = P(x_t+1 | h_t), preds[1] = P(x_t+2 | h_t), and preds[2] = P(x_t+3 | h_t). Each result is a length-V probability vector. Because each W[k] sees only h_t, head k cannot condition on whatever head k-1 predicted; that is the entire failure mode of this design.

Ground-truth future tokens. y_true = [1, 2, 3] holds the true tokens at positions t+1, t+2, t+3. In real training these come from the data stream; here they are fixed so we can compute a loss.

Sum of per-head negative log-likelihoods. The MTP loss is just K cross-entropy losses added together. Each term pulls preds[k][y_true[k]] toward 1. Critically, the gradient of the horizon-k term flows back only into W[k], b[k], and h_t; there is no gradient path between heads.

Why the deeper heads will always underperform. With the random weights in this listing the three distributions are just noise, but in a model trained on real text preds[2] comes out much flatter than preds[0], even though all three heads have identical architecture and parameter count. The reason is not in the code; it is in the data. h_t simply does not encode enough information to pin down x_(t+3) without knowing what x_(t+1) and x_(t+2) were going to be. The deeper the head, the more it is forced to predict the marginal of the future given x_(<=t) rather than the conditional given the prefix, as the math section above showed.

```python
import numpy as np

# Toy LM: V=4 vocab, D=8 hidden dim, K=3 future heads.
V, D, K = 4, 8, 3
np.random.seed(0)

# One transformer forward gives h_t, the hidden state at position t.
# In naive parallel MTP, EVERY head reads this single vector.
h_t = np.array([0.5, -0.2, 0.8, 0.3, -0.4, 0.6, 0.1, -0.7])

# Each head has its own output projection (V, D) and bias (V,).
W = [np.random.randn(V, D) * 0.3 for _ in range(K)]
b = [np.random.randn(V) * 0.1 for _ in range(K)]

def softmax(x):
    # Subtract the max for numerical stability.
    z = x - x.max()
    e = np.exp(z)
    return e / e.sum()

# K independent predictions from the same h_t.
preds = [softmax(W[k] @ h_t + b[k]) for k in range(K)]

# Ground-truth future tokens at positions t+1, t+2, t+3.
y_true = [1, 2, 3]

# Total cross-entropy loss across all K horizons.
loss = -sum(np.log(preds[k][y_true[k]]) for k in range(K))

for k, p in enumerate(preds):
    print(f"head {k+1} -> x_(t+{k+1}): "
          f"P = {p.round(3)}  P(target) = {p[y_true[k]]:.3f}")
print(f"Total NLL across K heads: {loss:.3f}")
```

Two structural details to notice. First, there is nothing in the code that prevents the heads from sharing information; we simply do not provide a path. The list comprehension over W[k] is the entire story: $K$ independent matmuls from the same vector. Second, the gradient of head $k$'s loss flows back through $W_k$ and into $h_t$, and from there into the backbone. So the backbone receives a sum of $K$ gradient signals at every position. This is the upside that drove the idea in the first place: $K$ times the supervision pressure on the backbone.
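The claim that the backbone receives a sum of $K$ gradient signals can be checked directly. For softmax cross-entropy, the gradient of head $k$'s loss with respect to its logits is $p_k - \text{onehot}(y_k)$, so the gradient with respect to $h_t$ is $\sum_k W_k^\top (p_k - \text{onehot}(y_k))$. The sketch below rebuilds the toy setup from the NumPy listing above (same seed and shapes, an assumption made only for reproducibility) and compares that analytic sum against central finite differences of the total loss.

```python
import numpy as np

V, D, K = 4, 8, 3
np.random.seed(0)
h_t = np.array([0.5, -0.2, 0.8, 0.3, -0.4, 0.6, 0.1, -0.7])
W = [np.random.randn(V, D) * 0.3 for _ in range(K)]
b = [np.random.randn(V) * 0.1 for _ in range(K)]
y_true = [1, 2, 3]

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def total_nll(h):
    """Summed cross-entropy of all K heads as a function of the shared hidden state."""
    return -sum(np.log(softmax(W[k] @ h + b[k])[y_true[k]]) for k in range(K))

# Analytic gradient: each head contributes W[k].T @ (p_k - onehot(y_k)).
per_head = []
for k in range(K):
    p = softmax(W[k] @ h_t + b[k])
    p[y_true[k]] -= 1.0
    per_head.append(W[k].T @ p)
grad_h = sum(per_head)

# Central finite differences on h_t for comparison.
eps = 1e-6
fd = np.zeros(D)
for i in range(D):
    step = np.zeros(D)
    step[i] = eps
    fd[i] = (total_nll(h_t + step) - total_nll(h_t - step)) / (2 * eps)

print("max |analytic - finite diff| =", np.abs(grad_h - fd).max())
print("per-head gradient norms:", [round(float(np.linalg.norm(g)), 4) for g in per_head])
```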

Sanity check. Set $K = 1$ and the code becomes ordinary next-token cross-entropy. With randomly initialized heads, as in this listing, every head's output is just noise; train the heads on real text with a large $K$, though, and preds[k] for large $k$ flattens toward the marginal, exactly the collapse the math predicts. The architecture cannot escape it.

PyTorch: K Parallel Heads, Vectorized

The production implementation runs in a batched, GPU-vectorized form. The new ideas are minimal: an nn.ModuleList to keep the K head parameters discoverable by the optimizer, and a single torch.stack to give the head axis a place in the tensor.

What each part of naive_parallel_mtp_pytorch.py does:

An nn.Module wrapping the K heads. We register K linear layers under one parent so PyTorch tracks their parameters for the optimizer and for FSDP sharding. The shared backbone lives outside this module; its output is passed in as `h` at forward time.

ModuleList of K independent Linear layers. nn.ModuleList registers every contained module as a child so its parameters are seen by `model.parameters()`. We use it instead of a plain Python list because modules stored in a plain list attribute are not registered, which hides their parameters from the optimizer. With K = 3, for example, there are three heads and self.heads[0].weight.shape = (V, D).

bias=False follows GPT-style convention. Modern LMs frequently drop the output bias; it adds parameters without a measurable quality gain. Set bias=True if you want to mirror older codebases.

Input: shared backbone hidden states. h is the output of the entire transformer stack, with every layer's contribution baked in. Its shape is (B, T, D): B sequences, T tokens each, D hidden dimensions. Every head reads the same h. This is the structural choice that defines naive parallel MTP.

Targets: K future tokens per position. targets[b, t, k] is the token at position t+k+1 in sequence b, so targets.shape = (B, T, K). It is built from the input ids by shifting K times and stacking. Positions whose future runs past the end of the sequence are set to -100 so cross_entropy ignores them.

Per-head linear projection. The list comprehension calls each head's linear layer on the same h; head(h) has shape (B, T, V). torch.stack on dim=-2 inserts a new K axis, producing logits.shape = (B, T, K, V). The heads share no computation: each does one matmul with its own weights.

Reading V from the last logits axis. This avoids hardcoding the vocabulary size, so the same module works for any V at load time.

Cross-entropy reduction. F.cross_entropy combines log_softmax and nll_loss. Reshaping logits to (B*T*K, V) and targets to (B*T*K,) treats every (sequence, position, head) triple as one independent classification example.

logits.reshape(-1, V) flattens the leading three dimensions into one. The K head dimension is treated identically to the batch and time dimensions for loss purposes, which is what makes naive MTP cheap: it simply multiplies the cross-entropy bill by K.

targets.reshape(-1) matches the flattened logits. Each entry is the integer class label for one (b, t, k) cell.

ignore_index=-100 for padding. cross_entropy skips any target equal to ignore_index. We pad targets with -100 wherever the future runs off the end of the sequence: head k (counting from 0) has no valid target in the last k+1 positions, so the deepest head loses the last K.

reduction='mean'. The per-example losses are averaged over all non-ignored cells. Switch to 'none' if you want to weight horizons differently, for example by down-weighting deeper heads. The naive implementation gives every cell equal weight, which is itself a questionable choice: deeper horizons are harder, so they incur larger losses and dominate the gradient.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NaiveParallelMTP(nn.Module):
    """K independent prediction heads on top of one shared backbone."""

    def __init__(self, d_model: int, vocab_size: int, K: int):
        super().__init__()
        self.K = K
        # K separate output projections, no parameter sharing.
        self.heads = nn.ModuleList(
            nn.Linear(d_model, vocab_size, bias=False) for _ in range(K)
        )

    def forward(
        self,
        h: torch.Tensor,        # (B, T, D) backbone hidden states
        targets: torch.Tensor,  # (B, T, K) tokens at offsets +1..+K
    ) -> torch.Tensor:
        # All heads read the SAME h. No mixing between heads.
        logits = torch.stack(
            [head(h) for head in self.heads], dim=-2
        )                                              # (B, T, K, V)

        # Flatten and run one cross-entropy over all K * B * T positions.
        V = logits.size(-1)
        loss = F.cross_entropy(
            logits.reshape(-1, V),
            targets.reshape(-1),
            ignore_index=-100,
            reduction="mean",
        )
        return loss
```
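The module above averages every valid (b, t, k) cell equally. One way to weight horizons instead is to compute per-cell losses (reduction='none' in F.cross_entropy), average each head over its valid cells, and combine the heads with a weight vector. The NumPy sketch below mirrors that pipeline so it runs without PyTorch; `per_cell_nll` and `weighted_mtp_loss` are hypothetical helper names, and the 1/k weights in the usage lines are an arbitrary example schedule, not a recommendation.

```python
import numpy as np

def per_cell_nll(logits, targets, ignore_index=-100):
    """NLL for every (b, t, k) cell; NaN where target == ignore_index.
    Plays the role of F.cross_entropy(..., reduction='none') plus a padding mask."""
    z = logits - logits.max(axis=-1, keepdims=True)             # stable log-softmax
    log_probs = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    valid = targets != ignore_index
    safe = np.where(valid, targets, 0)                          # any legal index for padded cells
    nll = -np.take_along_axis(log_probs, safe[..., None], axis=-1)[..., 0]
    return np.where(valid, nll, np.nan)

def weighted_mtp_loss(logits, targets, weights):
    """Average each head over its valid cells, then combine heads with `weights`."""
    nll = per_cell_nll(logits, targets)                         # (B, T, K)
    per_head = np.nanmean(nll.reshape(-1, nll.shape[-1]), axis=0)   # (K,)
    w = np.asarray(weights, dtype=float)
    return float((w * per_head).sum() / w.sum()), per_head

rng = np.random.default_rng(0)
B, T, K, V = 2, 6, 3, 5
logits = rng.normal(size=(B, T, K, V))
targets = rng.integers(0, V, size=(B, T, K))
for k in range(K):
    targets[:, T - (k + 1):, k] = -100       # head k has no target in the last k+1 slots
loss, per_head = weighted_mtp_loss(logits, targets, weights=[1.0, 1 / 2, 1 / 3])
print("per-head NLL:", per_head.round(3), " weighted loss:", round(loss, 3))
```

With weights proportional to each head's number of valid cells, this reproduces the flat reduction='mean' result; equal weights do not, because deeper heads have more padded positions and therefore fewer valid cells.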

Three subtleties worth marking, all about how this module interacts with the rest of the training stack:

  1. FSDP cares about the K heads. At LLM scale, each head matrix is $V \cdot D \approx 100\text{k} \cdot 5\text{k} = 500\text{M}$ parameters. Three or four heads add 1.5–2B parameters that have to be sharded just like the rest of the model. Engineering choice: place each head on its own device, or shard each head's rows across the data-parallel group. Either way, the cost is real.
  2. Loss imbalance across horizons. Because deeper horizons have higher entropy, their losses are systematically larger. With reduction='mean' over the flattened (B*T*K) cells, the deeper heads dominate the gradient direction. A common fix is to weight each horizon by $1 / H_k$ (the inverse of that horizon's entropy or loss) or by some manual schedule, but this introduces another hyperparameter and does not address the underlying marginalization problem.
  3. The label-shift trick. targets[:, :, k] is constructed by shifting the input ids by $k + 1$ positions and padding the tail with -100. The last K positions of every sequence have no valid target for the deepest head. Forget the padding (or the matching ignore_index) and the deepest head trains on whatever filler the shift leaves in the tail, such as wrapped-around, pad, or neighboring-sequence tokens, instead of real future tokens.
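Implementation note. Some open-source codebases save memory by reusing the same output embedding matrix across all K heads (tied weights), typically behind a small per-head layer, since identical linear maps applied to the same $h_t$ would produce identical predictions. This is parameter-efficient but does not change the marginal-collapse problem at all: the failure is in the conditioning structure, not in the size of the heads.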
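The label-shift trick from item 3 can be sketched as follows, assuming input ids of shape (B, T) and the convention targets[b, t, k] = ids[b, t + k + 1] used by the module above. It is written in NumPy so it runs standalone; torch.from_numpy or torch.as_tensor converts the int64 result into the long tensor F.cross_entropy expects. `build_mtp_targets` is a hypothetical name.

```python
import numpy as np

def build_mtp_targets(ids, K, ignore_index=-100):
    """targets[b, t, k] = ids[b, t + k + 1], or ignore_index where that runs off the end."""
    B, T = ids.shape
    targets = np.full((B, T, K), ignore_index, dtype=np.int64)
    for k in range(K):
        shift = k + 1
        if shift >= T:
            continue                    # sequence too short: this head gets no targets at all
        # Positions 0 .. T-shift-1 have a real token `shift` steps ahead; the tail stays padded.
        targets[:, : T - shift, k] = ids[:, shift:]
    return targets

ids = np.array([[10, 11, 12, 13, 14, 15]])
# Row t holds ids[t+1 : t+4], with -100 filling the positions past the end.
print(build_mtp_targets(ids, K=3)[0])
```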

What Changes at Massive Scale

At toy scale, naive parallel MTP looks like a tax on training compute that buys some extra supervision. At production scale, the cost accounting flips:

| Quantity | Single-token baseline | Naive parallel MTP (K=4) | Why it matters |
|---|---|---|---|
| Backbone FLOPs / step | 1× | 1× | Heads add a fixed cost on top, not multiply |
| Output-head FLOPs / step | 1× | 4× | V·D matmul repeated K times per position |
| Parameter count | ~600B | ~602B | K × V·D ≈ 2B extra at LLM scale |
| Activation memory (logits) | 1× | 4× | K logit tensors of shape (B, T, V) |
| Effective loss signal / step | 1× | ~1.5–2× | Diminishing returns from marginal collapse |
| Wall-clock per step | 1× | 1.05–1.15× | Backbone dominates; heads are cheap by comparison |

Two patterns to read here. First, the cost is modest — backbone matmuls dwarf the output heads at LLM scale, so adding 4 heads adds maybe 10% to per-step time. That is the upside Meta's original paper claimed: a cheap way to extract more supervision per forward.

Second, the effective gain is much less than $K\times$. The numerical walkthrough showed the per-horizon marginal entropy saturating near $\log V$, which means the learnable signal from head 4 is only the small deviation of its target from uniform. Empirically, naive parallel MTP buys roughly 1.5–2× the signal density per step rather than the naive $K\times$. That is real but underwhelming for the engineering investment.

The activation-memory cost

At $B = 8$, $T = 4096$, $V = 102\text{k}$, a single logits tensor takes $8 \cdot 4096 \cdot 102\text{k} \cdot 2 \text{ bytes} \approx 6.7\text{ GB}$ in bf16. At $K = 4$, that is $\approx 27\text{ GB}$ for the head outputs alone. Even with chunked cross-entropy (computing the loss on slices of the logits and discarding each slice before moving on), head memory becomes a real constraint in tight FSDP setups. This is the second reason teams looked for a better MTP design: they wanted the supervision without the K-fold blowup in head activations.
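The arithmetic is easy to parameterize. The helper below is illustrative (`logits_bytes` is a hypothetical name); it assumes 2-byte bf16 logits and counts only the materialized logit tensors, ignoring softmax intermediates, any fp32 upcast, and gradients. The chunked line shows the forward-pass footprint of one 512-position slice; whether that saving survives the backward pass depends on freeing or recomputing each slice, which this sketch does not model.

```python
def logits_bytes(batch, seq_len, vocab, heads=1, bytes_per_elem=2):
    """Memory for `heads` logit tensors of shape (batch, seq_len, vocab)."""
    return heads * batch * seq_len * vocab * bytes_per_elem

one = logits_bytes(8, 4096, 102_000)
four = logits_bytes(8, 4096, 102_000, heads=4)
print(f"K=1: {one / 1e9:.1f} GB   K=4: {four / 1e9:.1f} GB")

# Forward-pass footprint if the loss is computed on 512-position slices at a time.
chunk = logits_bytes(8, 512, 102_000, heads=4)
print(f"K=4, one 512-position slice: {chunk / 1e9:.2f} GB")
```

The first line prints 6.7 GB for one head and 26.7 GB for four, matching the figures above.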

Why DeepSeek Did Not Use Naive Parallel MTP

DeepSeek-V3's team measured the naive design carefully before committing to an alternative. Three findings, all flowing from the marginal-vs-conditional issue, drove their choice:

  1. Token-acceptance rate at decode time is poor. One benefit of MTP is using head $k$ at inference to speculate $k$-ahead tokens for speculative decoding. Naive heads give acceptance rates in the 20–40% range beyond $k = 1$, because the marginal distribution diverges from what the actual chain produces. Each rejected speculation is a wasted forward.
  2. The deeper heads degrade backbone quality. Because head $k$'s loss pressures the backbone to encode marginal-future information in $h_t$ (information that is not actionable at inference), the backbone's representation is pulled away from the chain-rule optimum that single-token training would have produced. DeepSeek measured a small but consistent regression in downstream evals when naive MTP was added.
  3. The fix is not architectural depth. Doubling head size, tying head weights, sharing more parameters: none of these touch the conditioning structure. The only way to recover the chain-rule factorization is to let head $k$ see what head $k-1$ predicted. That requires a causal dependency between the heads, which is exactly what section 7.3 builds.
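Finding 1 can be made concrete with the standard speculative-sampling acceptance rule: a token drafted from $q$ and verified against the target $p$ is accepted with probability $\min(1, p(x)/q(x))$, so the expected acceptance is $\sum_x \min(p(x), q(x))$. The sketch below assumes the toy chain from the numerical walkthrough, assumes the first drafted token is accepted, and compares a horizon-2 draft taken from the naive marginal with one taken from the true conditional. With only four tokens the toy acceptance rate is much higher than the 20–40% range quoted above for real models; the point is the drop relative to a conditional draft, not the absolute value.

```python
import numpy as np

# Toy chain from the walkthrough: T[i, j] = P(next = j | current = i).
T = np.array([
    [0.05, 0.55, 0.25, 0.15],
    [0.10, 0.05, 0.55, 0.30],
    [0.50, 0.10, 0.05, 0.35],
    [0.40, 0.20, 0.30, 0.10],
])

def acceptance(p, q):
    """Expected acceptance when a token drafted from q is verified against target p."""
    return float(np.minimum(p, q).sum())

p1 = T[0]                      # P(x_{t+1} | x_t = "the"); the first token is accepted
naive_q2 = p1 @ T              # what a naive horizon-2 head learns: the marginal
# Average over which x_{t+1} actually occurred; the target for x_{t+2} is then row T[x].
naive_rate = sum(p1[x] * acceptance(T[x], naive_q2) for x in range(4))
chain_rate = sum(p1[x] * acceptance(T[x], T[x]) for x in range(4))
print(f"horizon-2 acceptance: naive {naive_rate:.3f}, conditional {chain_rate:.3f}")
```

In this toy the naive draft is accepted about 75% of the time, against 100% for a draft that conditions on the accepted first token.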
The reframe. Naive parallel MTP fails because it treats the K future tokens as if they came from K independent random variables given the context, when in reality they come from one chain with strong serial dependencies. The fix is to add a thin causal connection between the heads — preserving most of the parallelism and cheapness, while restoring the chain-rule factorization that language actually obeys. That is the architecture DeepSeek built.

The one sentence to carry forward: K parallel heads on a shared hidden state can only ever learn the marginal future distribution, and the marginal collapses toward uniform as the horizon grows — so the extra heads cost compute, parameters, and activation memory without buying the proportional supervision they appear to promise. Section 7.3 turns this failure into the recipe for a fix.
