Chapter 7

MTP as Speculative Decoding at Inference

Multi-Token Prediction (MTP)

Sections 7.1 through 7.4 made the case for Multi-Token Prediction as a training objective: it pulls more supervision out of every forward pass, sharpens the model's long-range planning, and costs almost nothing in extra parameters. But the same machinery has a second life that DeepSeek-V3 quietly cashes in at deployment time. MTP heads trained to predict tokens $2, 3, \dots, K+1$ positions ahead already know how to draft $K$ speculative tokens for any prefix. Hand those drafts to the main model for one verification forward pass and you get a wall-clock speedup at almost no extra deployment cost. That is MTP-driven speculative decoding, and it is why a 671B-parameter mixture-of-experts model can serve tokens at a price that makes commercial sense.

The promise of MTP at inference. Reuse the training-time draft heads as a free, in-distribution draft model. One main-model forward pass per round, $K$ draft tokens proposed cheaply, between 1 and $K+1$ tokens committed. The arithmetic is striking: with $K = 3$ and a 75% acceptance rate, the same hardware emits roughly $2.7\times$ more tokens per second than the standard autoregressive loop, with an identical output distribution.

The Inference Bottleneck

Standard autoregressive decoding does one forward pass per generated token. For a 67B-parameter dense model, that means well over a hundred gigabytes of weights (134 GB in FP16), plus the KV-cache state, pulled through GPU memory for every token. On a modern H100 the bottleneck is not compute, since the tensor cores sit mostly idle. The bottleneck is the memory bandwidth needed to stream the model weights and the cache into the SMs on each forward pass.

Concretely, a single H100 has roughly 3 TB/s of HBM bandwidth, and a 67B FP16 model is about 134 GB. If every token required a full read of the weights, the ceiling would be around 22 tokens per second, and that ceiling does not budge no matter how fast the tensor cores are. (In practice 134 GB does not fit in a single 80 GB H100, so the weights are sharded across GPUs, but the bytes-per-token arithmetic is the same.) At batch size 1, compute utilisation often sits below 5%: the most expensive part of the chip is idle roughly 95% of the time.
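The ceiling is a single division, sketched below under the paragraph's simplifying assumption that weight reads dominate memory traffic (KV-cache reads and kernel overheads are ignored, so real throughput is lower). The function name is illustrative; the 3 TB/s figure is the text's round number, not a measured value.

```python
def decode_ceiling_tokens_per_s(n_params: float, bytes_per_param: float,
                                hbm_bytes_per_s: float) -> float:
    """Batch-1 decode ceiling if every token must stream all weights from HBM once.
    Ignores KV-cache reads and kernel overheads, so real throughput is lower."""
    return hbm_bytes_per_s / (n_params * bytes_per_param)


HBM_BYTES_PER_S = 3e12  # ~3 TB/s, the round H100 figure used in the text
fp16 = decode_ceiling_tokens_per_s(67e9, 2, HBM_BYTES_PER_S)  # 134 GB of weights
fp8 = decode_ceiling_tokens_per_s(67e9, 1, HBM_BYTES_PER_S)   # 67 GB of weights
print(f"FP16 ceiling ~{fp16:.1f} tokens/s, FP8 ceiling ~{fp8:.1f} tokens/s")
```

Halving the bytes per parameter doubles the ceiling, which is the same lever FP8 inference pulls later in this section.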

| Resource | Used in training | Used in autoregressive inference |
|---|---|---|
| Tensor-core FLOPs | Near-saturated (large batch, long sequences) | ~3–5% utilised |
| HBM bandwidth | Mostly hidden behind compute | The hard ceiling |
| KV-cache reads | Once per layer per step | Once per layer per token |
| Effective throughput | Hundreds of TFLOPS | Tens of tokens/s per GPU |
| What is wasted | Almost nothing | Nearly the entire compute budget |
The opportunity hidden inside the waste. If we could pack more than one token of useful work into each forward pass — without sacrificing the sampling distribution — we would convert idle tensor cores directly into tokens per second. That is the entire premise of speculative decoding, and MTP heads turn out to be the cheapest available source of in-distribution drafts.

Intuition: Two Models, One Forward Pass

The mental image is a typist and a proof-reader. The typist (the MTP heads) is quick and not always right: they bash out the next few words at speed. The proof-reader (the main model) is slow and authoritative: they read the whole proposal and stop at the first word they themselves would not have written. Everything before that word is kept. The proof-reader then writes the corrected word and hands the draft back.

Crucially, the proof-reader does not re-read the page once for each word. They read the entire proposed continuation in a single sweep, and because the proof-reader is a causal transformer, that sweep already yields their own prediction at every position simultaneously. One forward pass delivers K+1 verdicts. The more often the typist guesses right, the more words each sweep locks in, and the throughput of the whole system rises accordingly.

Two properties make this work without compromising output quality. First, under greedy decoding the accept-or-reject rule is exact: a draft token is accepted if and only if it equals what the main model would have produced at that position. Second, the rejected position is repaired using the main model's own argmax at that position, which is already sitting in the verification forward pass. So the committed sequence is identical to what pure greedy autoregressive decoding would have produced. Sampling at nonzero temperature swaps in a probabilistic accept rule that preserves the output distribution in the same way. Speculative decoding is a speedup, not an approximation.

The right mental picture. Standard decoding is a single worker walking down a corridor opening one door at a time. Speculative decoding is one worker plus a fast scout: the scout marks K doors ahead, the worker checks them in one sweep, keeps every door the scout got right, opens the first wrong door themselves, and the pair restarts from there.

The Math of Acceptance and Speedup

Let $p$ be the per-token probability that a draft from the MTP heads agrees with what the main model would have emitted at the same position, and let $K$ be the number of draft tokens proposed each round. For simplicity, assume each draft agrees independently with the same probability $p$. Acceptance is sequential: the moment one draft is rejected, all subsequent drafts are discarded because they were conditioned on a prefix the main model would not have written. So the number of accepted drafts $J$ in a round follows a truncated geometric distribution:

$$\Pr[J = j] = \begin{cases} p^{j}(1 - p) & 0 \le j < K \\ p^{K} & j = K \end{cases}$$

Every round commits $J + 1$ tokens: the $+1$ is the main model's free prediction at the rejection point, and even when $J = K$ the forward pass yields one extra position. The expected number of tokens committed per round is therefore:

$$\mathbb{E}[J + 1] = 1 + \sum_{j=1}^{K} \Pr[J \ge j] = 1 + \sum_{j=1}^{K} p^{j} = \frac{1 - p^{K+1}}{1 - p}$$

That expectation is, to a first approximation, the wall-clock speedup over autoregressive decoding: every round costs essentially one main forward pass, and the standard loop would need $\mathbb{E}[J+1]$ such passes to produce the same number of tokens. A few regimes are worth burning into intuition:

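| K | p | E[J+1] | Wall-clock speedup (ideal) |
|---|---|---|---|
| 1 | 0.80 | 1.80 | 1.80× |
| 2 | 0.80 | 2.44 | 2.44× |
| 3 | 0.80 | 2.95 | 2.95× |
| 4 | 0.80 | 3.36 | 3.36× |
| 3 | 0.50 | 1.88 | 1.88× |
| 3 | 0.90 | 3.44 | 3.44× |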

Why the speedup saturates with K

The sum $\sum_{j=0}^{K} p^j$ is a geometric series whose limit as $K \to \infty$ is $1/(1-p)$. With $p = 0.8$ that ceiling is exactly $5\times$. Pushing K past 3 or 4 barely moves the needle: the $k$-th draft only counts if all $k-1$ drafts before it were accepted, so its expected contribution is $p^k$, which shrinks geometrically. This is one reason DeepSeek-V3 trains only a single MTP module: at $p \approx 0.85$–$0.9$ that one draft already delivers about $1.85$–$1.9\times$, and each further depth adds a smaller increment while costing more to train well and to run.
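The saturation argument can be checked numerically: the sketch below accumulates the marginal gain $p^k$ of each extra draft and reports how much of the $1/(1-p)$ ceiling has been reached. It assumes the same independent, fixed-$p$ acceptance model as the formula; `marginal_gain` is an illustrative name.

```python
def marginal_gain(p: float, k: int) -> float:
    """Expected extra tokens per round from the k-th draft: it only counts if
    drafts 1..k are all accepted, which happens with probability p**k."""
    return p ** k


p = 0.8
ceiling = 1 / (1 - p)   # limit of E[J + 1] as K -> infinity
total = 1.0             # K = 0: the main model's own token is always committed
for K in range(1, 9):
    total += marginal_gain(p, K)
    print(f"K={K}: E[J+1] = {total:.2f}, {total / ceiling:.0%} of the {ceiling:.0f}x ceiling")
```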

Why the cost of a draft is not zero — but close

The formula above treats each round as exactly one main-model forward pass. The MTP heads do cost something, but each head is small next to the trunk (in DeepSeek-V3, a single transformer block plus a projection, sharing the embedding and output head with the main model), so a $K$-head draft costs roughly $K \cdot c_{\text{head}}$ with $c_{\text{head}} \ll c_{\text{trunk}}$. As a rough guide, take each head at about 1% of one main-model forward. The $\mathbb{E}[J+1] = 2.95$ entry in the table then lands at about $2.95 / 1.03 \approx 2.87\times$ after honest accounting, which is still decisive.
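The overhead-adjusted speedup is just tokens per round divided by cost per round, with the verification pass counted as 1 and each head as an assumed fraction `head_cost` of a main forward. The 1% figure below is the text's illustrative number, not a measurement.

```python
def effective_speedup(p: float, K: int, head_cost: float) -> float:
    """Tokens per main-forward-equivalent, charging each draft head `head_cost`
    (a fraction of one main-model forward pass) on top of the verification pass."""
    tokens_per_round = (1 - p ** (K + 1)) / (1 - p)
    cost_per_round = 1.0 + K * head_cost
    return tokens_per_round / cost_per_round


print(f"{effective_speedup(0.8, 3, 0.00):.2f}x")  # ideal: drafts are free
print(f"{effective_speedup(0.8, 3, 0.01):.2f}x")  # assumed 1% of a main forward per head
```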

What $p$ actually is in production. DeepSeek-V3 reports that the acceptance rate of its second-token (MTP) prediction falls between 85% and 90% across generation topics ($p \approx 0.85$–$0.90$), which translates into about 1.8× more tokens per second. Acceptance plausibly varies by domain, with code generation toward the high end (lots of predictable syntax) and free-form chat toward the low end (more semantic branching). The acceptance rate is the single most important number to monitor in a speculative-decoding deployment.

Manual Numerical Walkthrough

Let us work through the rounds by hand with $K = 3$, $p = 0.8$, writing the toy sentence "The small boat drifted into the harbour". Every number is worked out by hand so the mechanism is fully exposed.

Worked example: K = 3, p = 0.8, by hand

Setup. The main model would (if asked one token at a time) emit the sentence "The small boat drifted into the harbour". We will simulate the MTP heads drawing $K = 3$ drafts per round, each agreeing with the main model with probability $p = 0.8$.

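Round 1 (prefix is empty). The MTP heads draft 3 tokens; suppose they emit ["The", "small", "ship"]. The main model is then run once on the concatenation ["The", "small", "ship"] and returns its own prediction at every position, each conditioned only on the tokens to its left: ["The", "small", "boat", "drifted"]. That is $K + 1 = 4$ next-token predictions for the price of one forward pass. (The fourth prediction follows the draft "ship", so it is discarded below.)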

Compare draft vs. main, left to right: position 0 agrees (The), position 1 agrees (small), position 2 disagrees (ship vs. boat). We accept $J = 2$ drafts and commit the main model's third prediction. Tokens committed this round: ["The", "small", "boat"], three tokens for one main-model forward pass.

Round 2 (prefix is "The small boat"). The MTP heads draft 3 more: ["drifted", "into", "the"]. All three agree with the main model's argmax at positions 3, 4, 5, so $J = K = 3$. We accept all three drafts and commit the main model's free prediction at position 6: "harbour". Tokens committed: ["drifted", "into", "the", "harbour"], four tokens for one forward pass. This is the $K+1$ best-case round.

Round 3 (sentence complete). All 7 tokens are committed, so the loop terminates after 2 rounds. Standard autoregressive decoding would have needed 7 forward passes to emit the same 7 tokens. Wall-clock speedup on this micro-example: $7 / 2 = 3.5\times$.

Check against the formula. With $p = 0.8$, $K = 3$: $\mathbb{E}[J + 1] = (1 - 0.8^{4}) / (1 - 0.8) = (1 - 0.4096) / 0.2 = 2.952$. Our two rounds averaged $(3 + 4) / 2 = 3.5$ tokens per round, a bit lucky compared with the long-run mean of 2.952 but well within the noise of a two-round sample.

The MTP-head cost. In each round we ran K = 3 small heads on a single trunk output. If one head costs 1% of one main forward, the three drafts add 3% per round, so the effective speedup is $2.952 / 1.03 \approx 2.87\times$: not the naive 2.952, but still a transformative win on the inference cost sheet.
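The two rounds worked by hand can be replayed mechanically. The sketch below hard-codes the drafts and main-model predictions from the walkthrough (the helper name `accepted_prefix_len` is illustrative) and applies the accept-longest-prefix rule to each round.

```python
def accepted_prefix_len(drafts, targets):
    """Number of leading drafts that equal the main model's own predictions."""
    j = 0
    while j < len(drafts) and drafts[j] == targets[j]:
        j += 1
    return j


# The two rounds from the walkthrough: (drafts, main-model predictions).
rounds = [
    (["The", "small", "ship"], ["The", "small", "boat", "drifted"]),
    (["drifted", "into", "the"], ["drifted", "into", "the", "harbour"]),
]

sentence = []
for drafts, targets in rounds:
    j = accepted_prefix_len(drafts, targets)
    sentence += drafts[:j] + [targets[j]]  # accepted drafts + the main model's token

print(" ".join(sentence))            # The small boat drifted into the harbour
print(len(sentence) / len(rounds))   # 3.5 tokens per main-model forward
```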

Visualizing One Verification Round

The visualizer below builds the sentence one round at a time. Each round shows the K draft tokens (green = the main model would have written this, red = rejected) and the corrected main-model token tacked on at the rejection point. Toggle between Autoregressive (one token per round) and Speculative to feel the round-by-round difference. Slide K between 1 and 4, and the acceptance probability between 0.30 and 0.95, to explore the regime around $p \approx 0.85$–$0.9$ where DeepSeek-V3 actually operates.


Three things to lock in. First, the "tokens / round (observed)" stat is the empirical version of $\mathbb{E}[J + 1]$. With $p = 0.8$ and $K = 3$ it hovers around 3, as the geometric series predicts. With $K = 1$ it can never exceed 2, the speedup ceiling for any one-token draft. Second, every round commits at least one token, even when the very first draft was wrong, because the main model's corrected token always lands: per main-model forward pass, speculative decoding never commits fewer tokens than autoregressive decoding would. Third, the speedup is bounded above by $1/(1-p)$, which is $5\times$ at $p = 0.8$: $K = 4$ already reaches 3.36, and each further draft closes a geometrically smaller share of the remaining gap.

Plain Python: A Speculative Decoding Loop

Below is the entire mechanism in plain Python — no PyTorch, no model. We replace the main model with a deterministic lookup table (it "wants to write" a fixed sentence) and the MTP heads with a coin flip per token. The loop counts rounds versus tokens so the speedup is visible.

speculative_loop.py

```python
import random

random.seed(0)

# A toy "vocabulary": what the main model wants to say, plus wrong guesses.
TRUE = ["The", "small", "boat", "drifted", "into", "the", "harbour"]
WRONG = ["a", "ship", "of", "with", "in"]


def main_model_next(prefix):
    """The expensive main model: returns the next token given a prefix.
    Cost = 1 forward pass per token in the standard loop."""
    return TRUE[len(prefix)]


def draft_model_chunk(prefix, K, p):
    """The cheap MTP heads: return K speculative tokens.
    Each one is 'right' with probability p (it agrees with the main model)."""
    out = []
    for i in range(K):
        right = TRUE[len(prefix) + i]
        out.append(right if random.random() < p else random.choice(WRONG))
    return out


def main_model_verify(prefix, drafts):
    """One main-model forward pass over prefix + drafts. Returns the main
    model's next token at each of the len(drafts) + 1 positions, which a
    causal transformer produces in a single forward."""
    return [TRUE[len(prefix) + i] for i in range(len(drafts) + 1)]


def speculative_step(prefix, K=3, p=0.8):
    # 0. Near the end of the sentence, draft fewer tokens so the "+1"
    #    target still exists (otherwise we would index past TRUE).
    k = min(K, len(TRUE) - len(prefix) - 1)
    # 1. Draft k tokens with the cheap MTP heads.
    drafts = draft_model_chunk(prefix, k, p)
    # 2. ONE forward pass through the main model verifies them all.
    targets = main_model_verify(prefix, drafts)
    # 3. Accept the longest matching prefix; commit the corrected token.
    j = 0
    while j < k and drafts[j] == targets[j]:
        j += 1
    committed = drafts[:j] + [targets[j]]
    return committed, j, k


prefix = []
rounds = 0
while len(prefix) < len(TRUE):
    committed, j, k = speculative_step(prefix, K=3, p=0.8)
    prefix += committed
    rounds += 1
    print(f"round {rounds:>2}: accepted {j}/{k}, committed {len(committed)} -> {prefix}")
print(f"done in {rounds} rounds vs {len(TRUE)} for autoregressive")
```

What each piece does:

  1. TRUE, a frozen "ground truth" sentence. We pretend the main model has decided to write a fixed 7-token sentence (len(TRUE) = 7). In reality this is whatever the model would emit token by token; hard-coding it keeps the demo deterministic.
  2. main_model_next, the expensive main model. One call is one forward pass through the huge transformer. Autoregressive decoding calls it once per generated token, sequentially, and that cost dominates wall-clock time at inference. The speculative loop never calls it; it is shown for contrast.
  3. draft_model_chunk, the cheap draft model (the MTP heads). It returns K speculative tokens per call. As section 7.3 showed, DeepSeek's MTP heads are small modules that share the main trunk's representations, so producing K drafts costs a small fraction of one main-model forward pass. Typical values: K = 3, p ≈ 0.7–0.9 in production.
  4. main_model_verify, one main forward over prefix + drafts. This is the critical step. The main model processes the concatenation in ONE forward pass; because the transformer is causal, the output at position t is its prediction for token t+1 conditioned on everything up to t. One pass therefore yields len(drafts) + 1 "what would I have said" targets (4 when K = 3).
  5. The inner while loop in speculative_step accepts the longest matching prefix. It walks left to right comparing drafts[j] with targets[j] and stops at the first disagreement; everything before it is a valid main-model output, so it is safe to commit (j ends between 0 and K). This greedy comparison is the simplest accept rule; sampling variants use a probability-ratio test instead.
  6. committed = drafts[:j] + [targets[j]]. After j accepts, either drafts[j] was wrong or all drafts matched; in both cases targets[j] is the main model's own prediction at that position. Every round therefore commits between 1 token (first draft wrong) and K+1 tokens (all drafts right).
  7. k = min(K, ...) shortens the draft near the end of the sentence, keeping room for the +1 token so that neither helper indexes past the end of TRUE.
  8. The outer loop counts rounds, not tokens. Each iteration is one main-model forward pass. The standard loop needs len(TRUE) = 7 forwards; with K = 3 and p ≈ 0.8 we expect roughly 7 / 2.95 ≈ 2.4 rounds, i.e. 2 or 3 in a given run. That ratio is the wall-clock speedup.

Two structural details deserve a second pass. First, the verification call in main_model_verify intentionally returns K+1 targets, not K. That extra slot — the main model's prediction at position L+K, conditioned on all the drafts — is the "free" token that gets committed even when every single draft was correct. Forget the +1 and your best-case round only commits K tokens instead of K+1.

Second, the accept rule is greedy equality: drafts are accepted iff they exactly match the main model's argmax. This works for greedy decoding (temperature 0). For temperature or top-p sampling, the rule is replaced by the speculative-sampling probability-ratio test from Chen et al., 2023: draw the draft token $x$ from $q_{\text{draft}}$, accept it with probability $\min(1, q_{\text{main}}(x) / q_{\text{draft}}(x))$, and on rejection sample the replacement token from the normalised residual $\max(0, q_{\text{main}} - q_{\text{draft}})$, where the $q$ are the per-position output distributions. That rule preserves the main model's sampling distribution exactly: sampled outputs are still drawn from the main model, just much faster.
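The probability-ratio rule is easy to verify empirically for a single position. The sketch below uses toy dictionaries as stand-ins for the per-position distributions (the vocabulary and probabilities are made up for illustration) and checks that accepted-or-resampled tokens follow $q_{\text{main}}$, not $q_{\text{draft}}$. It also shows that the per-position acceptance rate equals $\sum_x \min(q_{\text{main}}(x), q_{\text{draft}}(x))$.

```python
import random


def speculative_sample_token(q_main, q_draft, rng):
    """One position of speculative sampling (probability-ratio test).

    q_main, q_draft: dicts mapping token -> probability, each summing to 1.
    Returns (token, accepted). The returned token is distributed as q_main.
    """
    vocab = sorted(set(q_main) | set(q_draft))
    # The draft model proposes x ~ q_draft.
    x = rng.choices(vocab, weights=[q_draft.get(t, 0.0) for t in vocab])[0]
    # Accept with probability min(1, q_main(x) / q_draft(x)).
    if rng.random() < min(1.0, q_main.get(x, 0.0) / q_draft[x]):
        return x, True
    # On rejection, resample from the normalised residual max(0, q_main - q_draft).
    residual = [max(0.0, q_main.get(t, 0.0) - q_draft.get(t, 0.0)) for t in vocab]
    return rng.choices(vocab, weights=residual)[0], False


rng = random.Random(0)
q_main = {"boat": 0.6, "ship": 0.3, "raft": 0.1}   # toy main-model distribution
q_draft = {"boat": 0.4, "ship": 0.5, "raft": 0.1}  # toy draft distribution
n = 100_000
counts = {t: 0 for t in q_main}
accepted = 0
for _ in range(n):
    tok, ok = speculative_sample_token(q_main, q_draft, rng)
    counts[tok] += 1
    accepted += ok

freqs = {t: c / n for t, c in counts.items()}
print(freqs)         # close to q_main, not q_draft
print(accepted / n)  # close to sum_x min(q_main(x), q_draft(x)) = 0.8
```

Dropping the residual resampling and simply substituting the main model's argmax would bias the output toward the main model's mode; the residual step is what makes the marginal distribution exact.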

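Sanity-check the formula. The 7-token script is too short to estimate $\mathbb{E}[J+1]$ well: each run has only a few rounds, and the last one is truncated at the end of the sentence. Use a much longer target sentence, so truncation is negligible, and average tokens per round over many rounds. With $K = 3$, $p = 0.8$ you should see roughly 2.95 tokens per round, the geometric series talking back to you in code.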
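A more direct check skips the sentence entirely and simulates the acceptance process itself, so no end-of-sequence truncation biases the average. This is a minimal Monte Carlo sketch assuming independent per-draft acceptance with probability $p$; the function name is illustrative.

```python
import random


def tokens_per_round(K: int, p: float, n_rounds: int, seed: int = 0) -> float:
    """Monte Carlo estimate of E[J + 1]: draft i is accepted with probability p,
    but only if drafts 0..i-1 were all accepted (independence assumed)."""
    rng = random.Random(seed)
    total = 0
    for _ in range(n_rounds):
        j = 0
        while j < K and rng.random() < p:
            j += 1
        total += j + 1  # accepted drafts + the main model's own token
    return total / n_rounds


est = tokens_per_round(K=3, p=0.8, n_rounds=200_000)
exact = (1 - 0.8 ** 4) / (1 - 0.8)
print(f"simulated {est:.3f} vs closed form {exact:.3f}")
```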

PyTorch: Verifying K Drafts in One Forward Pass

The production version replaces the deterministic table with real tensors: argmax over the vocabulary at K+1 positions of a single forward pass. The accept-or-reject decision is a handful of tensor operations on the GPU rather than a per-draft Python loop.

speculative_pytorch.py

```python
import torch


@torch.no_grad()
def speculative_decode_step(
    main_model,   # huge transformer; one call = one big forward pass
    mtp_heads,    # K small heads, sharing the trunk with main_model
    input_ids,    # (1, L)  current prefix
    K: int = 4,
):
    """One round of MTP-driven speculative decoding with greedy acceptance.

    Assumed (hypothetical) interface:
      main_model.trunk(ids) -> hidden states, shape (1, T, d)
      main_model.lm_head(h) -> logits, shape (..., V)
      each head in mtp_heads maps (1, d) -> logits (1, V)

    Simplifications: the K heads read the same last hidden state in parallel
    (DeepSeek-V3 chains its MTP modules sequentially, section 7.3), and there is
    no KV cache, so the trunk runs over the full prefix twice per round. A real
    server drafts from the previous verification pass and reuses its cache.
    """
    assert len(mtp_heads) == K, "one MTP head per draft position"

    # 1) Draft K tokens from the trunk's hidden state at the last position.
    hidden = main_model.trunk(input_ids)                          # (1, L, d)
    last_h = hidden[:, -1, :]                                     # (1, d)
    draft_logits = torch.stack(
        [head(last_h) for head in mtp_heads], dim=1
    )                                                             # (1, K, V)
    draft_ids = draft_logits.argmax(dim=-1)                       # (1, K)

    # 2) Verify: ONE main-model forward over [input_ids, draft_ids].
    #    Positions L-1 .. L+K-1 (the last K+1) predict tokens L .. L+K.
    verify_in = torch.cat([input_ids, draft_ids], dim=1)          # (1, L+K)
    verify_h = main_model.trunk(verify_in)                        # (1, L+K, d)
    main_logits = main_model.lm_head(verify_h[:, -(K + 1):, :])   # (1, K+1, V)
    main_ids = main_logits.argmax(dim=-1)                         # (1, K+1)

    # 3) Accept the longest matching prefix; commit the corrected token.
    agree = draft_ids[0] == main_ids[0, :K]                       # (K,) bool
    j = int(agree.long().cumprod(dim=0).sum().item())             # count of leading Trues
    committed = torch.cat(
        [draft_ids[:, :j], main_ids[:, j:j + 1]], dim=1
    )                                                             # (1, j+1)
    return torch.cat([input_ids, committed], dim=1), j
```

What each piece does:

  1. @torch.no_grad(): inference needs no gradients. Turning autograd off means no activations are stored for a backward pass, which saves memory and leaves room for larger K and batch sizes. It does not switch the model to eval mode; call main_model.eval() once before decoding.
  2. K, the speculation depth (typically 1–4). DeepSeek-V3 trains a single MTP module, which at inference drafts one extra token per forward pass (K = 1 in this section's notation). Speculative-decoding stacks with separate draft models often push K to 4 or 8. K is the lever between "small win, low risk" and "big win, high risk".
  3. main_model.trunk runs the transformer body, up to but not including the final LM head, and returns hidden states of shape (1, L, d). In a KV-cached server this is also where the cache for positions 0..L-1 is populated for the verification step to reuse.
  4. last_h = hidden[:, -1, :], shape (1, d). Each MTP head is conditioned on the trunk's representation of the last input token, as section 7.3 derived.
  5. The heads produce K logit vectors, stacked into one (1, K, V) tensor. Each head is small compared with the trunk and shares the output vocabulary with the main model; that is the whole "draft is cheap" premise. This sketch calls the heads in parallel on the same hidden state for brevity.
  6. draft_ids = argmax along V gives K greedy drafts, shape (1, K). For temperature-based sampling, sample each draft from torch.softmax(draft_logits / temperature, dim=-1) with torch.multinomial (which takes 1-D or 2-D input, so reshape to (K, V) first); the accept rule then needs the probability-ratio test from the speculative-sampling paper.
  7. verify_in = [prefix, drafts], shape (1, L+K). Because the main model is causal, the prediction for draft i is read at position L+i-1, which attends only to positions 0..L+i-1: the prefix plus drafts 0..i-1, exactly the context draft i was proposed in.
  8. One forward pass returns K+1 next-token predictions. Whatever the main model would have predicted for positions L, L+1, …, L+K comes from the last K+1 hidden states, at the cost of a single forward pass instead of K+1 sequential ones. With a KV cache, the entries for positions L..L+K-1 are built in the same pass.
  9. agree compares the K draft IDs with the main model's first K predictions; agree[i] is True iff draft i equals the main model's argmax at that position. Shape (K,), bool.
  10. agree.long().cumprod(0).sum() is the "count of leading Trues" idiom. The cumulative product stays 1 until the first 0 and is 0 afterwards, so the sum is exactly the number of drafts to accept, j ∈ 0..K, with no per-draft Python loop or branch.
  11. committed = drafts[:j] + main_ids[j:j+1] appends the main model's own prediction at the rejection point to the accepted drafts. The prefix grows by exactly j+1 tokens, never zero, even when every draft is wrong.
  12. The function returns the extended prefix and j. j is the per-round telemetry: its average across rounds gives the effective acceptance rate, the quantity that determines wall-clock speedup. Production systems log it continuously and tune K to maximise tokens per second.

Three subtleties worth marking, all about how this loop interacts with the rest of an LLM serving stack:

  1. The KV cache is reused across rounds. The verification forward pass populates the cache for positions $L \dots L + K - 1$. Only the entries past the rejection point, positions $L + j$ onward, need to be invalidated; the entries for the $j$ accepted positions remain valid because they were computed on exactly the prefix that was committed. A well-implemented serving stack therefore never recomputes the trunk for those $j$ tokens.
  2. Sampling preserves the main-model distribution. Speculative decoding is not an approximation. With greedy decoding the committed sequence is byte-for-byte identical to standard decoding; with temperature sampling the accept rule is the probability-ratio test and the marginal distribution of every committed token is still exactly the main model's. You cannot tell from the output whether a server is using speculative decoding under the hood.
  3. Batching changes the arithmetic. If the server batches N requests, the main-model forward pass already amortises bandwidth across N sequences — the memory-bandwidth bottleneck loosens and the tensor cores warm up. Speculative decoding still helps, but the speedup shrinks because the autoregressive baseline was already faster per token. The win is largest in single-stream / low-batch interactive serving, which is also where latency matters most.
Implementation note. Real inference servers (vLLM, TGI, DeepSeek's own serving stack) implement speculative decoding as a scheduler-level concept, not a Python loop. The scheduler decides per step how many draft tokens to request, packs verifications across the active batch, and can tune K dynamically based on the running acceptance rate. The code above is the kernel of that machinery; the production version is scheduling plumbing around it.
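The cache bookkeeping in point 1 amounts to a rollback after each verification. The sketch below uses a hypothetical `ToyKVCache` whose string entries stand in for per-layer key/value tensors; a real server would more likely reset a per-sequence length counter (or free cache blocks) than delete tensors, but the positions kept and dropped are the same.

```python
from dataclasses import dataclass, field


@dataclass
class ToyKVCache:
    """Hypothetical per-sequence KV cache: one entry per position.
    Strings stand in for the per-layer key/value tensors."""
    entries: list = field(default_factory=list)

    def append(self, new_entries: list) -> None:
        self.entries.extend(new_entries)

    def truncate(self, length: int) -> None:
        """Roll back: drop every entry at position >= length."""
        del self.entries[length:]


def rollback_after_verification(cache: ToyKVCache, L: int, K: int, j: int) -> int:
    """Verification appended entries for the K drafts at positions L..L+K-1.
    Keep the j accepted ones and discard the rest. The corrected token at
    position L+j has no entry yet; the next round's forward pass creates it."""
    assert len(cache.entries) == L + K, "cache should hold prefix + all drafts"
    cache.truncate(L + j)
    return len(cache.entries)


cache = ToyKVCache([f"kv{i}" for i in range(10)])             # prefix of L = 10 positions
cache.append([f"draft_kv{i}" for i in range(3)])              # verification of K = 3 drafts
new_len = rollback_after_verification(cache, L=10, K=3, j=1)  # only draft 0 accepted
print(new_len, cache.entries[-1])                             # 11 draft_kv0
```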

At Massive Scale: Memory-Bandwidth Wins

Speculative decoding looks like a clever programming trick until you do the arithmetic for a real 671B-parameter MoE model. The DeepSeek-V3 numbers turn it into a deployment necessity.

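| Model | Active params / token | AR tokens/s (per-H100 equivalent) | Spec tokens/s (K=2) | Speedup |
|---|---|---|---|---|
| Llama-3 70B (dense) | 70B | ≈18 | ≈32 | 1.8× |
| DeepSeek-V3 (MoE) | 37B | ≈28 | ≈54 | 1.9× |
| DeepSeek-V3 + MTP, 1 draft token (paper) | 37B | n/a | n/a | ≈1.8× (reported) |
| Mistral-Large (dense) | 123B | ≈11 | ≈22 | 2.0× |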

Two patterns to read. First, for each model the question is where the drafts come from. A Llama-3 70B deployment can be paired with a separately trained small draft model such as Llama-3 8B. DeepSeek-V3 uses the MTP heads that already exist because they were part of the training objective. The MTP route saves gigabytes of draft-model memory and avoids the distribution-shift problem that hurts external draft models (a separately trained 8B does not always agree with what the 70B would have written).

Second, MoE models change the bandwidth arithmetic rather than simply amplifying the formula. A single DeepSeek-V3 token activates only a sliver of the parameters (37B of 671B). The weights every token shares (attention, dense layers, the shared expert) are read once per forward pass no matter how many tokens it verifies, so speculative decoding amortises that traffic over up to $K+1$ tokens and raises its bandwidth ceiling by roughly $\mathbb{E}[J+1]$. Routed experts behave differently: the $K+1$ tokens in a verification pass may route to different experts, so the expert weights read per pass grow with $K$, and that share of the traffic is amortised less.
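How much routed-expert weight a verification pass reads depends on how many distinct experts its tokens touch. DeepSeek-V3 routes each token to 8 of its 256 routed experts (alongside one shared expert). The sketch below counts the expected number of distinct experts for $n$ tokens under a deliberately crude assumption, uniform and independent routing, which real routers do not follow (tokens in one sequence often share experts, which would push the count down).

```python
def expected_distinct_experts(n_tokens: int, n_experts: int = 256, top_k: int = 8) -> float:
    """Expected number of distinct routed experts touched by n_tokens tokens,
    ASSUMING each token picks top_k of n_experts uniformly at random and
    independently of the other tokens (real routers are neither)."""
    p_miss = 1 - top_k / n_experts  # chance a given expert is not picked by one token
    return n_experts * (1 - p_miss ** n_tokens)


for n in range(1, 5):  # 1 token (autoregressive) up to K + 1 = 4 tokens (verification)
    print(f"{n} token(s): ~{expected_distinct_experts(n):.1f} distinct experts")
```

Under this worst-case-leaning model, verifying 4 tokens touches about 30.5 experts versus 8 for one token, so routed-expert reads grow almost linearly with the number of verified tokens.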

The interaction with expert parallelism

When experts are spread across GPUs (Chapter 5), each MoE forward pass triggers an all-to-all shuffle of tokens to the right experts and back. Speculative decoding multiplies the "tokens" in that shuffle by K+1, which sounds bad until you remember that, at the small batch sizes where speculative decoding matters most, the all-to-all cost is dominated by fixed latency rather than per-token data volume. Sending 1 token and sending 4 tokens through the same all-to-all costs almost the same wall-clock time. So speculative decoding turns the all-to-all latency into a smaller fraction of the per-token cost. The serving cluster gets cheaper, not more expensive.

The interaction with FP8 inference

DeepSeek-V3's FP8 training (Chapter 10) carries over to inference, cutting the model weight size roughly in half versus FP16. That raises the memory-bandwidth ceiling by about $2\times$ on its own. Speculative decoding then stacks on top, so in the ideal case a serving node that produced 18 tokens/s per H100 in FP16 autoregressive mode could produce roughly $18 \cdot 2 \cdot 2 \approx 72$ tokens/s per H100 in FP8 + spec-decode mode. That roughly 4× compound speedup is what actually makes a 671B-parameter MoE model affordable to serve.

Engineering Reality and Gotchas

MTP-driven speculative decoding looks like a clean win on paper. Three production failure modes earn their flags:

  1. Draft / main distribution drift. The MTP heads were trained alongside the main model on the same data — but the main model continued to evolve through any fine-tuning, RLHF, or DPO stage that followed pre-training. If the alignment phase did not include the MTP heads in its optimisation, the heads' acceptance rate can degrade sharply on the new distribution (chat, instruction-following, refusal behaviour). The fix is to either re-tune the heads after alignment or back-propagate the alignment loss through them. DeepSeek-V3 reports a ≈4-point acceptance drop on chat traffic that they recover via head re-tuning.
  2. K too large hurts more than it helps. Pushing K past 4 rarely improves throughput on real workloads, and it can hurt: every extra draft adds head-compute, and the geometric saturation means the last drafts are mostly wasted. Production schedulers monitor the running acceptance rate per request and cap K dynamically — typically K=2 for chat, K=3 for code, K=4 for highly templated outputs.
  3. The verification forward pass needs careful masking. The K draft tokens appended to the input must still be attended to causally: the position holding draft i attends to the prefix and drafts 0..i, never to drafts i+1..K-1. If the attention mask accidentally lets later drafts leak into earlier positions, the verification logits become contaminated and the acceptance rate craters silently. This is the single most common bug in custom speculative-decoding kernels. Always unit-test by comparing spec-decode output against the autoregressive baseline at greedy temperature: they must produce identical token sequences.
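The unit test recommended in point 3 can be prototyped without a GPU. In the sketch below a toy deterministic function (`main_next`, hypothetical) plays the main model, `masked_verify` mimics a correctly masked forward (each position sees only tokens to its left), and `leaky_verify` mimics a broken mask that lets every position see the whole draft block. The correct version must match greedy autoregressive decoding on every prompt; the leaky one should diverge.

```python
import random

VOCAB_SIZE = 50


def main_next(seq):
    """Toy deterministic 'main model': the greedy next token for a token list."""
    return (sum(seq) * 31 + len(seq) * 7 + 3) % VOCAB_SIZE


def masked_verify(prefix, drafts):
    """Correct causal verification: each of the len(drafts) + 1 predictions
    depends only on the tokens to its left."""
    seq = prefix + drafts
    return [main_next(seq[: len(prefix) + i]) for i in range(len(drafts) + 1)]


def leaky_verify(prefix, drafts):
    """Deliberately broken mask: every position 'sees' the whole draft block."""
    return [main_next(prefix + drafts)] * (len(drafts) + 1)


def noisy_drafts(prefix, k, rng, p=0.7):
    """Toy MTP heads: each draft matches the main model with probability ~p."""
    drafts = []
    for _ in range(k):
        guess = main_next(prefix + drafts)
        drafts.append(guess if rng.random() < p else rng.randrange(VOCAB_SIZE))
    return drafts


def greedy_autoregressive(prompt, n):
    seq = list(prompt)
    for _ in range(n):
        seq.append(main_next(seq))
    return seq


def greedy_speculative(prompt, n, K, rng, verify=masked_verify):
    seq, target_len = list(prompt), len(prompt) + n
    while len(seq) < target_len:
        k = min(K, target_len - len(seq) - 1)  # leave room for the +1 token
        drafts = noisy_drafts(seq, k, rng)
        targets = verify(seq, drafts)
        j = 0
        while j < k and drafts[j] == targets[j]:
            j += 1
        seq += drafts[:j] + [targets[j]]
    return seq


rng = random.Random(0)
leaks_caught = 0
for _ in range(100):
    prompt = [rng.randrange(VOCAB_SIZE) for _ in range(rng.randint(1, 5))]
    reference = greedy_autoregressive(prompt, 20)
    assert greedy_speculative(prompt, 20, 3, rng) == reference
    leaks_caught += greedy_speculative(prompt, 20, 3, rng, verify=leaky_verify) != reference
print(f"masked: 100/100 identical; leaky mask diverged on {leaks_caught}/100 prompts")
```

In a real stack the same comparison runs the actual model twice, once through the speculative path and once token by token, on a fixed set of prompts.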
How a deployment can monitor the speedup at runtime. Three per-request metrics are worth logging: empirical $p$ (the running mean of acceptance), tokens per round (the realised $J + 1$), and end-to-end latency. If $p$ drops below a threshold such as 0.6, the system reduces K toward 1, because at low acceptance the draft overhead starts to outweigh the amortisation gain. Speculative decoding is not a static knob; it is a feedback controller calibrated against the live distribution.
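One way to build that feedback loop is sketched below: a decayed estimate of $p$ feeds a rule that picks the K maximising expected tokens per unit cost under the geometric model. Everything here is hypothetical, including `AcceptanceTracker`, `best_k`, and especially `cost_per_draft`, an assumed per-draft charge (head compute plus the extra verified position, as a fraction of one main forward). With this cost model, K falls smoothly as $p$ drops instead of switching at a fixed threshold.

```python
def best_k(p: float, k_max: int, cost_per_draft: float, k_min: int = 1) -> int:
    """K in [k_min, k_max] maximising E[J+1] / (1 + K * cost_per_draft) under the
    geometric acceptance model; cost_per_draft is an assumed knob."""
    def score(k: int) -> float:
        tokens = sum(p ** i for i in range(k + 1))  # E[J+1] = 1 + p + ... + p^k
        return tokens / (1 + k * cost_per_draft)
    return max(range(k_min, k_max + 1), key=score)


class AcceptanceTracker:
    """Exponentially decayed estimate of the per-draft acceptance rate p."""

    def __init__(self, decay: float = 0.05, prior_p: float = 0.8, prior_weight: float = 10.0):
        self.decay = decay
        self.accepted = prior_p * prior_weight  # pseudo-counts from a prior
        self.compared = prior_weight

    def update(self, j: int, k: int) -> None:
        # Drafts actually compared this round: j matches, plus one mismatch if j < k.
        compared = j + (1 if j < k else 0)
        self.accepted = (1 - self.decay) * self.accepted + j
        self.compared = (1 - self.decay) * self.compared + compared

    @property
    def p(self) -> float:
        return self.accepted / self.compared


tracker = AcceptanceTracker()
k = best_k(tracker.p, k_max=4, cost_per_draft=0.06)
print("start:", round(tracker.p, 2), k)
for _ in range(200):  # a burst of traffic where the first draft always misses
    tracker.update(j=0, k=k)
    k = best_k(tracker.p, k_max=4, cost_per_draft=0.06)
print("after:", round(tracker.p, 3), k)
```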

The one sentence to carry forward: the MTP heads paid for themselves twice — once during training by improving sample efficiency, once at inference by turning every otherwise-idle tensor core into committed tokens — and that double-dip is the engineering reason a 671B-parameter MoE model can match dense-70B serving cost.

Where we go from here. Chapter 7 closes here. Across the five sections we built up MTP from the single-token bottleneck (7.1), rejected the naive parallel formulation (7.2), derived DeepSeek's sequential causal MTP (7.3), studied the training objective and the ablations that justify it (7.4), and now finished by cashing in the same heads at inference time. Chapter 8 leaves architecture entirely and turns to the data pipeline — the 14.8T-token foundation on which all the architectural cleverness in Chapters 4–7 actually trains.