Chapter 7

MTP Training Objective and Ablations

Multi-Token Prediction (MTP)

Section 7.3 gave us the architecture of sequential causal MTP — a stack of small modules sitting on top of the trunk, each one predicting one token further into the future, each one passing its hidden state to the next. That is the machinery. What we have not yet answered is the question every trainer asks next: what loss do we minimise? A sequential MTP stack is just an elaborate way to produce logits. Without an objective those logits are noise. The choice of loss decides what those extra heads learn, how aggressively the trunk reshapes itself to please them, and whether the whole construction earns its keep at training time and at inference time.

The promise of the MTP objective. A clean, scale-free addition to the standard next-token cross-entropy: one extra cross-entropy per depth, normalised by depth, scaled by a tiny coefficient $\lambda$ that anneals across training. Less than 1 % extra FLOPs at $D = 1$; measurable downstream gains; and the very same heads turn into a free speculative-decoding draft model at inference. The ablation table in the DeepSeek-V3 report is the receipt.

What Objective Trains a Sequential MTP Stack?

The trunk already has a perfectly good loss: next-token cross-entropy. Every position predicts the token at position $t + 1$, the prediction is scored by a softmax over the vocabulary, and the gradient flows backward into the embeddings and attention layers. This is the loss the trunk has been trained on since the first transformer paper. We are not allowed to harm it, because it is what makes the model a useful language model. Whatever we add for MTP must compose with the main loss without crowding it out.

Several questions immediately appear once you commit to additional supervision on the depth-$k$ heads:

  1. What does each MTP head predict? The natural answer is the token $k$ steps further into the future than the main head. So depth 1 predicts $t + 2$, depth 2 predicts $t + 3$, etc. The target is just the existing token sequence shifted by an extra $k$.
  2. How is each head scored? Same as the main head: cross-entropy against the true next-token-at-distance-$k$. One softmax per head, one scalar per head.
  3. How much weight do these heads get? Too much and the trunk warps itself to please the MTP heads at the expense of plain next-token loss; too little and the heads do not learn anything useful. We need a coefficient.
  4. How do the heads interact across depths? Should depth 4 count four times as much as depth 1 (because there are more of them) or the same (because we care about the average)? This is what the $1/D$ normalisation answers.
  5. Does the weight change over training? Early on, future-token prediction may push the trunk toward better long-range representations. Late in training, the trunk has converged and any extra signal can only steal capacity. A schedule on $\lambda$ handles this.

The DeepSeek-V3 paper picks the simplest answer to all five questions simultaneously, and then runs ablations to prove the choices were the right ones. The rest of this section is that loss, that schedule, and that ablation table — and what they imply for any team training a frontier model from scratch.

Intuition: A Stack of Teachers, One Optimizer Budget

Think of the trunk as a single student. The main loss is the headline exam: "given this sentence so far, what comes next?" The MTP heads are extra tutors standing behind the student, each asking a harder version of the same question. Tutor 1 asks "what comes two tokens from now?" Tutor 2 asks "what comes three tokens from now?" The tutors do not give the student new material — they only re-grade the same content from further away.

Two design decisions follow from this picture. First, the tutors share the student's budget. If you bring in twelve tutors, each one's feedback should count less individually, or the student spends all their effort pleasing tutors and forgets the headline exam. That is exactly what the $1/D$ term does. Second, you want the tutors loud at the start of training (when the student is still building intuition for the subject) and quiet at the end (when the student is polishing the headline exam answers). That is exactly what the $\lambda$ schedule does.

The right mental picture. $\lambda$ is the global volume knob for the tutors; $1/D$ shares that volume equally among them. Together, they give the trunk a fixed, predictable auxiliary budget regardless of how many future-token heads you bolt on.

The Math: Per-Depth Cross-Entropy with a Shared Budget

Fix a sequence of length $T$ with token ids $t_1, t_2, \dots, t_T$ and vocabulary size $V$. The trunk produces hidden states $h_1, \dots, h_T$; the main head turns each $h_i$ into a softmax distribution over the vocabulary. The main loss is the textbook next-token cross-entropy averaged over the sequence:

$$\mathcal{L}_{\text{main}} = -\frac{1}{T-1} \sum_{i=1}^{T-1} \log p_{\text{main}}(t_{i+1} \mid t_{\le i})$$

For depth $k \in \{1, \dots, D\}$ the sequential MTP architecture produces another set of distributions $p_k(\cdot \mid t_{\le i+k}, h_i^k)$ (see section 7.3 for how $h_i^k$ is computed). The depth-$k$ cross-entropy has the same shape as the main loss, but its target is shifted $k$ tokens further than the main head's target $t_{i+1}$:

$$\mathcal{L}_{\text{MTP}}^{k} = -\frac{1}{T-k-1} \sum_{i=1}^{T-k-1} \log p_k(t_{i+k+1} \mid t_{\le i+k}, h_i^k)$$

Notice the upper limit $T - k - 1$: the depth-$k$ head has $k$ fewer supervised positions than the main head, because predictions made from the last $k$ usable positions have no ground truth to compare against, so we drop them. At $T = 4096$ and $D = 4$ this costs the deepest head four positions out of roughly four thousand, which is negligible.
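The index bookkeeping is easy to get wrong, so here is a minimal sketch that enumerates which token each head is graded against. It assumes the convention stated in the questions above (the main head predicts $t_{i+1}$ and depth $k$ predicts the token $k$ steps beyond that); the helper name `supervised_pairs` is hypothetical and positions are 1-indexed to match the math.

```python
def supervised_pairs(T: int, k: int) -> list[tuple[int, int]]:
    """Return (position i, target token index) pairs for depth k, 1-indexed.

    k = 0 is the main head, graded against t_{i+1}.
    Depth k >= 1 is graded against t_{i+k+1}.
    Positions whose target would fall past t_T are dropped.
    """
    return [(i, i + k + 1) for i in range(1, T - k)]


T = 4  # a four-token toy sequence
for k in range(3):
    print(f"depth {k}: {supervised_pairs(T, k)}")
# depth 0: [(1, 2), (2, 3), (3, 4)]
# depth 1: [(1, 3), (2, 4)]
# depth 2: [(1, 4)]
```

Each extra depth loses exactly one more supervised position, which is why the deeper heads average over slightly fewer terms.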

The full training objective is the main loss plus a single global term that averages the MTP losses across depths and scales the whole thing by $\lambda$:

$$\boxed{\mathcal{L} \;=\; \mathcal{L}_{\text{main}} \;+\; \frac{\lambda}{D}\,\sum_{k=1}^{D} \mathcal{L}_{\text{MTP}}^{k}}$$

Three properties of this formula are worth slowing down on.

  • Depth-normalised. Whether $D = 1$ or $D = 4$, the auxiliary term is $\lambda$ times the mean per-head loss, so it can never exceed $\lambda$ times the largest per-head loss. You can compare ablations across different depths without re-tuning $\lambda$.
  • Decoupled from the architecture. The loss does not care how $p_k$ is computed: parallel heads, sequential heads, shared trunk, separate trunk. It is purely a statement about which token each head is graded against. This is why the same loss formula is reused in the DeepSeek-V3 paper and in earlier MTP work like Gloeckle et al. 2024.
  • Recovers the baseline. Setting $\lambda = 0$ reduces the formula to plain next-token training. So "MTP off" is literally one knob away, with no architectural surgery needed for an ablation.
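The first and third properties can be checked with a few lines of arithmetic. This is a minimal sketch, not the report's code: the function name `mtp_total_loss` and the loss values are made up for illustration.

```python
def mtp_total_loss(l_main: float, mtp_losses: list[float], lam: float) -> float:
    """L = L_main + (lam / D) * sum_k L_MTP^k, with D = len(mtp_losses)."""
    if not mtp_losses:  # no MTP heads at all
        return l_main
    return l_main + lam * (sum(mtp_losses) / len(mtp_losses))


# Hypothetical per-head losses, in nats.
one_head = mtp_total_loss(2.0, [3.0], lam=0.3)
four_heads = mtp_total_loss(2.0, [3.0, 3.0, 3.0, 3.0], lam=0.3)
mtp_off = mtp_total_loss(2.0, [3.0, 3.0], lam=0.0)

print(f"D = 1: {one_head:.3f}")
print(f"D = 4: {four_heads:.3f}")
print(f"lam = 0: {mtp_off:.3f}")
```

With equal per-head losses the total is the same for one head and for four, and $\lambda = 0$ returns the main loss unchanged.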

The schedule for $\lambda$ in DeepSeek-V3 is a single step function:

$$\lambda(\text{step}) = \begin{cases} 0.3 & \text{tokens trained} < 10\,\text{T} \\ 0.1 & \text{otherwise} \end{cases}$$

That is, for the first ~67 % of pretraining tokens the MTP heads carry their full weight; after that the coefficient drops by 3×. Earlier MTP work (Gloeckle et al.) reported similar findings with simpler constant $\lambda$. DeepSeek's schedule is a small refinement that buys a little extra final quality on the main loss without sacrificing the speculative-decoding gains we will see in section 7.5.

Why a step, not a smooth ramp? Two reasons. First, smooth ramps couple the schedule to the optimizer's warmup and decay — already complex moving parts. A step decouples them entirely. Second, the team observed that the downstream metrics did not move in the noise between "step at 50 %" and "step at 75 %" — there is no narrow optimum, so a step is just as good as a ramp and simpler to log.
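A minimal sketch of that step schedule, assuming the trainer keeps a running count of tokens consumed. The function name `mtp_lambda` and its parameter names are hypothetical; the 10 T cutover, the 0.3 and 0.1 values, and the 14.8 T budget are the figures quoted in this section.

```python
def mtp_lambda(tokens_trained: float,
               cutover_tokens: float = 10e12,
               lam_early: float = 0.3,
               lam_late: float = 0.1) -> float:
    """Step schedule: lam_early before the cutover, lam_late from it onward."""
    return lam_early if tokens_trained < cutover_tokens else lam_late


total_tokens = 14.8e12  # pretraining token budget quoted in this section
for tokens in (0.0, 5e12, 10e12, total_tokens):
    print(f"{tokens / 1e12:5.1f} T tokens -> lambda = {mtp_lambda(tokens)}")
print(f"cutover at {10e12 / total_tokens:.1%} of the run")
```

Because the schedule depends only on the token count, it stays independent of the learning-rate warmup and decay.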

Manual Numerical Walkthrough

Step-by-step: computing $\mathcal{L}$ for a 4-token toy with $D = 2$ and $\lambda = 0.3$

Take a vocabulary of size $V = 3$, a sequence $(t_1, t_2, t_3, t_4) = (A, B, C, B)$, indices $(0, 1, 2, 1)$. We will compute one round of MTP loss with two extra heads and $\lambda = 0.3$.

Step 1 — main head logits and loss. Suppose the trunk assigns these probabilities at positions 1, 2, 3:

| position $i$ | predict $t_{i+1}$ | p(A) | p(B) | p(C) | -log p_correct |
|---|---|---|---|---|---|
| 1 | B (id=1) | 0.20 | 0.70 | 0.10 | 0.357 |
| 2 | C (id=2) | 0.30 | 0.20 | 0.50 | 0.693 |
| 3 | B (id=1) | 0.10 | 0.60 | 0.30 | 0.511 |

$\mathcal{L}_{\text{main}} = (0.357 + 0.693 + 0.511) / 3 = 0.520$ nats.

Step 2 — depth-1 MTP head loss. The depth-1 head predicts $t_{i+2}$. Only positions $i = 1, 2$ have ground truth (position 3 would predict $t_5$, which does not exist).

| position $i$ | predict $t_{i+2}$ | p(A) | p(B) | p(C) | -log p_correct |
|---|---|---|---|---|---|
| 1 | C (id=2) | 0.30 | 0.30 | 0.40 | 0.916 |
| 2 | B (id=1) | 0.20 | 0.55 | 0.25 | 0.598 |

$\mathcal{L}_{\text{MTP}}^{1} = (0.916 + 0.598) / 2 = 0.757$ nats. Higher than the main loss, as expected: predicting two tokens ahead is genuinely harder.

Step 3 — depth-2 MTP head loss. The depth-2 head predicts $t_{i+3}$. Only position 1 has a ground truth.

| position $i$ | predict $t_{i+3}$ | p(A) | p(B) | p(C) | -log p_correct |
|---|---|---|---|---|---|
| 1 | B (id=1) | 0.30 | 0.40 | 0.30 | 0.916 |

$\mathcal{L}_{\text{MTP}}^{2} = 0.916$ nats.

Step 4 — combine. With $D = 2$ and $\lambda = 0.3$:

$$\frac{\lambda}{D} \sum_k \mathcal{L}_{\text{MTP}}^k = \frac{0.3}{2}(0.757 + 0.916) = 0.251$$

$\mathcal{L} = 0.520 + 0.251 = 0.771$ nats.

The MTP contribution is $0.251 / 0.771 \approx 32.5\%$ of the total: a sizeable share early in training, when the trunk needs the future-token signal most. After the schedule cutover $\lambda \to 0.1$, the same raw losses give an auxiliary term of $0.084$ nats, about $14\%$ of the total, which is visibly quieter.

Try this with $D = 4$: the depth-3 and depth-4 raw losses would be higher still (say $1.05$ and $1.20$), but $\lambda / D = 0.075$, not 0.15, so each head's weighted contribution shrinks. The total auxiliary term barely moves: $0.075 \times (0.757 + 0.916 + 1.05 + 1.20) \approx 0.294$ nats, against 0.251 nats at $D = 2$. The depth normalisation is doing its job: adding heads does not inflate the gradient budget.
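The walkthrough can be reproduced with the standard library alone. This is a sketch for checking the arithmetic: the probabilities are the correct-token entries copied from the three tables above, and the helper name `mean_nll` is hypothetical.

```python
import math


def mean_nll(p_correct: list[float]) -> float:
    """Average negative log-likelihood (in nats) over supervised positions."""
    return -sum(math.log(p) for p in p_correct) / len(p_correct)


# Probability each head assigned to the correct token, from the tables above.
l_main = mean_nll([0.70, 0.50, 0.60])       # main head, positions 1..3
l_mtp = [mean_nll([0.40, 0.55]),            # depth 1, positions 1..2
         mean_nll([0.40])]                  # depth 2, position 1

print(f"L_main = {l_main:.3f}, L_MTP = {[round(x, 3) for x in l_mtp]}")

for lam in (0.3, 0.1):
    aux = lam * sum(l_mtp) / len(l_mtp)     # (lam / D) * sum of per-depth losses
    total = l_main + aux
    print(f"lam = {lam}: aux = {aux:.3f}, total = {total:.3f}, share = {aux / total:.1%}")
```

Changing `lam` from 0.3 to 0.1 leaves the raw per-head losses untouched and only shrinks the auxiliary share, which is the effect of the schedule cutover.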

Visualizing the Loss Budget and λ-Schedule

Slide through the controls below. Increase $D$ from 1 to 4 and watch the per-head raw losses climb (deeper heads see harder problems) while the per-head weighted contributions stay bounded: that is the $1/D$ term at work. Then toggle between the constant-0.3, constant-0.1, and DeepSeek-V3 anneal schedules and slide the training-progress bar; the orange curve on the right shows $\lambda$ over training, and the blue marker locks the loss budget on the left to the current $\lambda$.

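[Interactive figure: MTP objective visualizer, with the per-head loss budget on the left and the $\lambda$ schedule over training on the right]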

Two patterns to confirm with the slider. First: switching the schedule from constant 0.3 to 0.3 → 0.1 anneal changes nothing for the first 67 % of training, because the loss budget is identical. The anneal's effect lives entirely in the last third. Second: setting $\lambda = 0$ (MTP off) with any $D$ recovers a single sky-blue bar, the standard next-token loss. The architecture stays in the forward pass, but the gradient lanes to the MTP heads are silent. That is exactly the configuration of an "MTP-off" ablation: same compute graph, zero auxiliary supervision.

Plain Python: One Sequence by Hand

Before the PyTorch version, here is the same loss written as a tight NumPy loop: no autograd, no batching, just the math on a single toy sequence. The point of showing this first is to see that the "MTP loss" is nothing exotic. It is the standard cross-entropy applied to shifted targets, summed across heads, averaged with $1/D$, and weighted by $\lambda$. Every loop iteration corresponds directly to one term of the formula above.

MTP loss in plain NumPy: one sequence, D + 1 heads (mtp_loss_numpy.py)

Notes on the code, in order of appearance:

  • Toy shapes (T, V, D = 6, 5, 2). Six tokens, a five-word vocabulary, two MTP heads on top of the main head. Real systems use T ≈ 4 k, V ≈ 100 k–200 k, and D = 1 (DeepSeek-V3). The shape of the math is identical. Execution state: T = 6, V = 5, D = 2.
  • λ, the auxiliary-loss coefficient (lam = 0.3). λ controls how much the MTP heads talk back to the trunk. DeepSeek-V3 anneals it from 0.3 to 0.1 across training; we hold it at 0.3 here to keep the toy comparable to early-training behaviour. Execution state: lam = 0.3.
  • Ground-truth next tokens (target). One row of token indices. In real training this is the same target tensor that drives the main next-token loss; MTP does not need a separate dataset. Execution state: target.shape = (6,).
  • Stacked logits of shape (D+1, T, V) (logits). We collect one logit row per head. Row 0 is the main head, rows 1..D are the MTP heads. In code you typically keep these as separate tensors; here we stack them for a tight loop. Execution state: logits.shape = (3, 6, 5).
  • Numerically stable softmax (softmax). Subtract the row max before exp to avoid overflow. This is the same trick PyTorch uses inside F.cross_entropy.
  • Loop over heads, including the main head (for k in range(D + 1)). k = 0 is the main loss; k = 1..D are the auxiliary MTP losses. Treating them in one loop makes the structure obvious: they really are the same cross-entropy with a target shift.
  • Depth-k targets are shifted by k (tgt = target[k:]). The MTP head at depth k predicts the token k+1 steps into the future relative to the trunk position. So position t feeds the depth-k head, but the supervision is target[t + k]. Equivalently: drop the first k entries of the target array and the final k logit positions (no ground truth there). Execution state: tgt (k=1) = [4, 0, 2, 3, 1] (T-1 = 5); tgt (k=2) = [0, 2, 3, 1] (T-2 = 4).
  • Per-position softmax (probs). Each row of probs is a distribution over the V vocab tokens for one position. Shape (T - k, V). Execution state: probs.shape (k=1) = (5, 5).
  • Gather the correct-class probabilities (correct). Fancy indexing, probs[np.arange(T - k), tgt], picks one number per position: the probability the head assigned to the right token. Length T - k.
  • Negative log-likelihood, averaged over positions (nll). Standard cross-entropy = -mean(log p_correct). One scalar per head. Higher k → harder problem → higher loss, in expectation.
  • Average the MTP losses across depths (L_mtp). Divide by D so the auxiliary contribution does not grow as you add heads. This is the 1/D in the DeepSeek-V3 formula; it keeps the auxiliary gradient budget independent of D.
  • The total loss (L_total). L_main + λ · (mean of MTP losses). With λ = 0.3 and D = 2, the MTP heads add 0.3× the average per-head MTP loss to the total. Setting λ = 0 recovers vanilla next-token training.
```python
import numpy as np

# Toy: T = 6 positions, V = 5 vocab entries, D = 2 MTP heads.
T, V, D = 6, 5, 2
lam = 0.3            # global weight on the auxiliary loss
# Ground-truth next tokens: target[t] is the token that follows position t.
target = np.array([1, 4, 0, 2, 3, 1])

# Pretend logits for the main head and 2 MTP heads.
# Shape: (D+1, T, V). Row 0 is the main head; row k>=1 is MTP depth k.
rng = np.random.default_rng(0)
logits = rng.normal(0, 1, size=(D + 1, T, V))

def softmax(z):
    # Subtract the row max before exponentiating for numerical stability.
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Per-head cross-entropy, averaged over the supervised positions.
losses = []
for k in range(D + 1):
    # Head k at position t is scored against target[t + k], i.e. the token
    # k steps beyond the main head's target. Drop the first k targets and the
    # last k logit positions, which have no ground truth.
    # For k = 0 (main head) both slices keep all T positions.
    tgt = target[k:]                          # (T - k,)
    shifted_logits = logits[k, : T - k]       # (T - k, V)

    probs = softmax(shifted_logits)           # (T - k, V)
    # Gather the probability of the correct token at each position.
    correct = probs[np.arange(len(tgt)), tgt] # (T - k,)
    nll = -np.log(correct).mean()             # scalar
    losses.append(nll)

L_main = losses[0]
L_mtp = sum(losses[1:]) / D                   # average across depths
L_total = L_main + lam * L_mtp

print(f"L_main = {L_main:.4f}")
for k in range(1, D + 1):
    print(f"L_MTP^{k} = {losses[k]:.4f}")
print(f"L_total = {L_total:.4f}  (λ = {lam}, D = {D})")
```

Output (with the rng seed of 0 set in the code): $\mathcal{L}_{\text{main}} \approx 2.46$, $\mathcal{L}_{\text{MTP}}^{1} \approx 3.05$, $\mathcal{L}_{\text{MTP}}^{2} \approx 3.23$, and $\mathcal{L}_{\text{total}} \approx 3.40$. Random logits give very high cross-entropy values, and because they are random, the ordering across depths here is an accident of the seed rather than evidence that deeper heads are harder. What the example does show is the arithmetic of the objective: $2.46 + 0.3 \times (3.05 + 3.23)/2 \approx 3.40$, so the auxiliary term contributes $\lambda = 0.3$ times the mean per-head MTP loss.

PyTorch: The Production MTP Loss

The production version is the same formula wrapped in an nn.Module so it composes cleanly with the trainer, batches across $B$ sequences, masks padding via ignore_index, and owns the $\lambda$ schedule. The line count is short. What makes it production-worthy is the careful target shifting per depth and the per-depth loss returned alongside the total, both of which the ablation table absolutely requires.

DeepSeek-V3-style MTP objective in PyTorch (mtp_objective.py)

Notes on the code, in order of appearance:

  • One module owns the full objective (class MTPObjective). Encapsulating L_main, the per-depth losses, and the λ-schedule in one nn.Module lets the trainer call it like any loss: optimizer.zero_grad(); out = obj(...); out["loss"].backward(). Clean separation from the model.
  • Hyperparameters as constructor args (__init__). depth = D, lam_start = 0.3, lam_end = 0.1, anneal_frac = 0.67 reproduce the DeepSeek-V3 schedule on a 14.8 T-token run. pad_id = -100 is PyTorch's default for cross_entropy's ignore_index. Execution state: self.D = 1 (DeepSeek-V3 default), lam_start = 0.3, lam_end = 0.1.
  • Step-function schedule (lam). DeepSeek-V3 used a hard cutover at ~10 T / 14.8 T tokens, not a smooth ramp. The function returns lam_start before the cutover and lam_end after: a one-line implementation of the published schedule.
  • Forward takes everything it needs (forward signature). main_logits and one tensor per MTP head, plus the shared targets and a progress float. The model is responsible for producing the logits; the loss object owns the weighting and the shift. Execution state: main_logits.shape = (B, T, V), len(mtp_logits) = D, targets.shape = (B, T).
  • Flatten to 2D before cross_entropy (main head). F.cross_entropy expects (N, V) logits and (N,) targets. Reshaping (B, T, V) → (B·T, V) and (B, T) → (B·T,) gives the same per-token mean PyTorch would compute internally. ignore_index = -100 silently skips padding. Execution state: main_logits flattened = (B·T, V).
  • Loop over the MTP heads (for k, logits_k in enumerate(...)). enumerate(start=1) keeps the depth index k aligned with the math: k = 1 for the first MTP head, etc. The list mtp_terms collects one scalar loss per depth.
  • Shift the target tensor left by k (tgt_k). targets[:, k:] is the supervision for depth k. We lose the last k positions per sequence: a 4096-token sequence with D = 4 contributes 4092 supervised positions to the depth-4 head, not 4096. At T ≈ 4 k this is a < 0.1 % loss of training signal. Execution state: tgt_k.shape = (B, T - k).
  • Slice logits to match (logits_k). The depth-k head still produces (B, T, V) logits, but only the first T - k of them have a target. Dropping the tail keeps shapes aligned and lets cross_entropy use ignore_index for genuine padding only. Execution state: logits_k.shape = (B, T - k, V).
  • Per-depth cross-entropy (L_k). Same call signature as the main loss, one scalar per depth. Appending to mtp_terms keeps the individual values around so we can log them separately (very useful during ablations).
  • Mean over depths, the 1/D factor (L_mtp). torch.stack creates a (D,) tensor; .mean() implements the 1/D normalisation. The empty-list guard handles the case with no MTP heads (D = 0, empty mtp_logits), where the auxiliary term is zero. Execution state: L_mtp.shape = () (scalar).
  • The full loss (L_total). L_main + lam · L_mtp. This is the tensor that .backward() is called on. lam comes from the schedule: early training feels a 0.3× nudge from the MTP heads, late training only 0.1×.
  • Return a dict, not just the loss. Returning the per-depth losses (detached) makes ablation logging trivial: throw L_mtp_per_depth[k] into your training metric and you can watch each head learn at its own rate. Detaching keeps the logged values from holding on to the autograd graph or being backpropagated through by accident.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MTPObjective(nn.Module):
    """
    DeepSeek-V3-style multi-token-prediction loss.

      L = L_main + (lam / D) * sum_{k=1..D} L_MTP^k

    Each L_MTP^k is the cross-entropy of the depth-k head's logits
    against the targets shifted by k positions. Padding tokens are
    masked out so they do not contribute to the mean.
    """

    def __init__(self, depth: int, lam_start: float = 0.3, lam_end: float = 0.1,
                 anneal_frac: float = 0.67, pad_id: int = -100):
        super().__init__()
        self.D = depth
        self.lam_start = lam_start
        self.lam_end = lam_end
        self.anneal_frac = anneal_frac
        self.pad_id = pad_id

    def lam(self, progress: float) -> float:
        # Step schedule: lam_start before the cutover, lam_end from it onward.
        return self.lam_start if progress < self.anneal_frac else self.lam_end

    def forward(
        self,
        main_logits: torch.Tensor,         # (B, T, V)
        mtp_logits: list[torch.Tensor],    # D tensors of shape (B, T, V)
        targets: torch.Tensor,             # (B, T) next-token labels: targets[:, t] follows position t
        progress: float,                   # fraction of total tokens trained so far
    ) -> dict:
        B, T, V = main_logits.shape

        # --- Main head ---
        L_main = F.cross_entropy(
            main_logits.reshape(-1, V),
            targets.reshape(-1),
            ignore_index=self.pad_id,
        )

        # --- MTP heads ---
        mtp_terms = []
        for k, logits_k in enumerate(mtp_logits, start=1):
            # Depth-k head is scored against targets[:, k:]; its last k
            # logit positions have no ground truth and are dropped.
            tgt_k = targets[:, k:]                      # (B, T - k)
            logits_k = logits_k[:, : T - k, :]          # (B, T - k, V)
            L_k = F.cross_entropy(
                logits_k.reshape(-1, V),
                tgt_k.reshape(-1),
                ignore_index=self.pad_id,
            )
            mtp_terms.append(L_k)

        # Mean over depths implements the 1/D factor; zero if there are no heads.
        L_mtp = torch.stack(mtp_terms).mean() if mtp_terms else main_logits.new_zeros(())
        lam = self.lam(progress)
        L_total = L_main + lam * L_mtp

        return {
            "loss": L_total,
            "L_main": L_main.detach(),
            "L_mtp_per_depth": [t.detach() for t in mtp_terms],
            "lam": lam,
        }
```
Logging tip. Return the per-depth losses (detached) from the forward call and tee them straight into your training dashboard. During the first 100 steps you want to see every depth's loss falling; if one head plateaus immediately, you almost certainly have a target-shifting bug, not a learning-rate one.

Ablations: What DeepSeek Actually Measured

Three families of ablation appear in the DeepSeek-V3 technical report and the related MTP literature. Each one isolates a single design choice we just argued for and shows what happens when it is removed or perturbed. The numbers below track the published values; absolute magnitudes vary by model size and dataset, but the direction of every comparison is consistent across reported runs.

Ablation 1 — MTP on vs MTP off

The headline comparison: train two identical models for the same token budget, same data, same schedule. One uses $\lambda = 0$ (vanilla next-token); the other uses the DeepSeek-V3 schedule with $D = 1$.

| Metric | MTP off (λ = 0) | MTP on (D = 1, anneal) | Effect |
|---|---|---|---|
| Main next-token loss (final) | baseline | slightly lower | + small but consistent |
| HumanEval pass@1 | baseline | +1.4–2 pts | + |
| MMLU 5-shot | baseline | +0.4–0.7 pts | + small |
| GSM8K 8-shot CoT | baseline | +1.5–2 pts | + |
| Training compute overhead | 0% | ≈ 0.9% extra FLOPs | negligible |
| Inference acceptance rate | n/a (single head) | 85–90% second-token accept | → 1.8× decode speedup |

The cost is under one percent extra compute; the downstream-quality gains are small in absolute terms but consistent across benchmarks. The real payoff is the last row: those same MTP heads, trained free as a side effect, become a built-in draft model for speculative decoding at inference. That part of the story is section 7.5.

Ablation 2 — Depth $D$

How many MTP heads is the right number? More heads predict further into the future, which is intrinsically harder, so per-head losses rise. But each head adds compute and parameters in the forward pass and complicates the inference draft chain. The DeepSeek-V3 report explored $D \in \{1, 2, 4\}$:

| Depth | Extra params | Extra FLOPs (fwd+bwd) | Δ downstream | Inference draft speedup |
|---|---|---|---|---|
| D = 1 | 1 small module | ≈ 0.9% | best per-param | 1.8× (one extra token) |
| D = 2 | 2 modules | ≈ 1.8% | ~equal to D=1, slightly noisier | 2.0–2.2× (two extra) |
| D = 4 | 4 modules | ≈ 3.6% | marginal further gain on long-form gen | 2.4–2.6× peak |

DeepSeek-V3 ships with $D = 1$. The reasoning is the per-FLOP curve: the second MTP head buys some extra speculative-decoding speed but offers diminishing downstream gains, and each additional head makes the inference verifier chain longer (which we will see in section 7.5 hurts acceptance in the worst case). Earlier MTP work (Gloeckle et al. 2024) reports similar diminishing-returns behaviour with their parallel-heads variant.

Ablation 3 — λ schedule

Holding $D = 1$ fixed, what does the auxiliary weight actually do?

| λ schedule | Final L_main | Downstream avg | Notes |
|---|---|---|---|
| λ = 0 (off) | baseline | baseline | no speculative decoding |
| λ = 0.1 constant | ≈ baseline | + small | safe; tutors are quiet throughout |
| λ = 0.3 constant | slightly higher than baseline | + small | tutors stay loud; main exam suffers a bit |
| λ = 0.3 → 0.1 (DeepSeek-V3) | lowest of all four | best of all four | shipped configuration |
| λ = 1.0 constant | noticeably higher | +/- | tutors drown out the headline exam |

Two clean signals come out of this table. First, "a little λ is much better than none": even $\lambda = 0.1$ constant beats the baseline. Second, "turn it down later": the anneal squeezes out an extra hair of main-loss quality by letting the trunk focus on the headline exam once it has matured. The third signal, less clean but very real: $\lambda \ge 1$ is harmful. The tutors should never be louder than the headline exam.

One-line takeaway from the ablations. The right configuration is boringly small: $D = 1$, $\lambda$ no higher than 0.3, schedule annealed to 0.1. Larger choices give diminishing returns and eventually hurt. MTP is a quiet, well-behaved auxiliary, not a second loss function fighting the first.

What Changes at Massive Scale

At toy size — the 6-token NumPy example above — the MTP objective is a couple of extra cross-entropies. At 671 B parameters, 14.8 T tokens, and thousands of GPUs, a few things have to be handled with care.

Compute and memory cost is exactly D times one head's cost

Each MTP module is a transformer block plus a small linear and an embedding projection, dominated by the attention/FFN compute. At $D = 1$ it adds one extra block-worth of FLOPs per forward pass: roughly $1/L$ of the trunk cost, where $L$ is the number of trunk layers. For $L = 60$ this layer-count estimate gives $1/60 \approx 1.7\,\%$, the same order of magnitude as the ~0.9 % overhead quoted in the ablations table. The backward pass adds proportionally. Activation memory for the extra block is reclaimed by the same checkpointing pass used for trunk layers; no new accounting is required.

The shared embedding and output head

DeepSeek-V3's MTP modules share the trunk's token embedding and output projection; only the transformer block and the small projection layer between them are unique per depth. This decision shapes the loss in a subtle way: the embedding table and output head receive both the main gradient and the MTP gradients (scaled by $\lambda/D$). In FP32, this is a simple add; in mixed-precision training (next chapter) it requires that the accumulator for these shared parameters stays in FP32 even though the heads themselves can be FP8. We will return to this in chapter 10.

Gradient noise and convergence

Adding the MTP loss makes the overall gradient slightly higher-variance, because each MTP head is supervising fewer positions (the depth-$k$ head supervises $k$ fewer positions than the main head). At $T = 4096$ this is a 0.025 % effect at $D = 1$ and entirely lost in the noise. The optimizer hyperparameters (Adam β, weight decay, LR schedule) do not need to be re-tuned when MTP is added: DeepSeek-V3 reports using exactly the same trainer config as an internal MTP-off baseline.

Pipeline parallelism placement

Where in the pipeline do the MTP modules live? They are downstream of the final trunk layer, which means they belong on the same pipeline rank as the final output head. In chapter 11 we will see that DualPipe puts the output head and MTP modules on the same micro-batch boundary so that the auxiliary backward pass overlaps with the main backward pass and the all-reduce of the embedding gradient — no new bubble is introduced. This is one of the quiet reasons the 0.9 % FLOPs overhead does not translate into a 0.9 % wall-clock overhead: most of it hides inside a bubble that was already there.

Engineering Reality and Gotchas

  • Target shifting is the most common bug. Off-by-one errors in the depth-k target shift produce a head that learns the identity (copying the input) and reports a suspiciously low loss. Always assert tgt_k.shape[-1] == T - k and log per-depth losses from step 0; if depth-1 loss is lower than depth-0 loss, you have shifted the wrong direction.
  • Padding interacts with the shift. ignore_index works on the flattened target, but if your sequence packing puts a padding token at position $T - k$, the depth-k head can still see it. The safest pattern is to apply the shift to both targets and an attention_mask, then ignore any position the mask marks as padded.
  • Loss accumulation precision. Even when the forward pass is BF16 or FP8, the cross-entropy reduction (the mean over T·B positions) should accumulate in FP32. PyTorch's F.cross_entropy does this internally for non-FP8 dtypes; with explicit FP8 you must cast logits to BF16 before the softmax. The few extra bits matter when $\lambda/D$ is already shrinking the result by $10\times$.
  • The λ schedule and the LR schedule are independent. Do not tie the $\lambda$ cutover to the LR decay. The optimizer schedule reacts to validation curves and warmup behaviour; the $\lambda$ cutover is a curriculum decision. Coupling them creates an opaque hyperparameter that no ablation can isolate.
  • Disable MTP for eval. When you compute validation loss, you want to report the main next-token loss only, since that is the comparable metric across runs and against the literature. Passing a separate $\lambda_{\text{eval}} = 0$ to the loss module (or just reading L_main out of the return dict) keeps comparisons honest.
  • The MTP heads are kept for inference. A naive read of "auxiliary loss" suggests the heads are thrown away after training. They are not — section 7.5 turns them into a speculative-decoding draft model and recovers a 1.8–2.0× decode speedup. The training objective above is what makes them good draft models. The fact that the same module serves both the training objective and the inference acceleration is the deepest reason MTP is worth its modest compute overhead.
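The first two gotchas (the shift assertion and the padding mask) can be sketched without any framework. This is an illustration under stated assumptions, not the DeepSeek implementation: `shift_targets` is a hypothetical helper, `targets` holds next-token labels as in the loss module above, and `mask` is a 0/1 padding mask aligned with `targets`.

```python
IGNORE = -100  # the default ignore_index of PyTorch's cross_entropy


def shift_targets(targets: list[int], mask: list[int], k: int) -> list[int]:
    """Labels for the depth-k head, with padded positions marked as IGNORE.

    The same slice [k:] is applied to the labels and to the mask, so a
    padding label can never slip into the supervised range after the shift.
    """
    assert len(targets) == len(mask), "mask must align with targets"
    shifted = [t if m else IGNORE for t, m in zip(targets[k:], mask[k:])]
    # Shape check from the first gotcha: depth k keeps exactly T - k labels.
    assert len(shifted) == len(targets) - k
    return shifted


targets = [11, 12, 13, 14, 0, 0]  # the last two entries are padding
mask = [1, 1, 1, 1, 0, 0]
for k in range(3):
    print(k, shift_targets(targets, mask, k))
# 0 [11, 12, 13, 14, -100, -100]
# 1 [12, 13, 14, -100, -100]
# 2 [13, 14, -100, -100]
```

Masking after the shift, rather than relying on a padding id that might also be a valid token id (0 in this toy), is the design choice that keeps padded positions out of every depth's mean.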
End of the MTP training story. Architecture (7.3) gave us a sequential causal stack. This section gave us the loss that trains it: a depth-normalised, anneal-weighted sum of per-depth cross-entropies. Ablations confirmed the choices were not arbitrary — they sit on the corner of a diminishing-returns frontier. The next section, 7.5, cashes in the second dividend: those same heads, at inference time, become a tightly-coupled draft model and unlock speculative decoding without a second model to maintain.