Chapter 7
Section 36 of 117

The Single-Token Prediction Bottleneck

Multi-Token Prediction (MTP)

Every transformer trained today does an enormous amount of work and is rewarded with a tiny amount of feedback. A 671-billion-parameter model spends roughly four trillion FLOPs to process a single token through its full forward and backward pass, and at the end of that journey it is graded on one question: what is the next token? One categorical target. A handful of bits of information per step. Most of the network's machinery is, in a precise information-theoretic sense, going to waste. This chapter is about why the wastage exists, what it costs, and how DeepSeek-V3 — along with a generation of speculative-decoding work — claws some of the lost signal back. This first section is just the diagnosis.

The bottleneck in one sentence. Next-token prediction makes the forward pass do the work of N layers and extracts the signal of one categorical guess. The compute scales with the model; the supervisory signal scales only with the number of tokens (and, logarithmically, with the vocabulary). The gap between the two is the largest underused resource in large-model training.

The Asymmetry: Huge Forward Pass, Tiny Signal

Walk through what one training step actually does. A batch of sequences enters the model. For each of the $T$ positions in each sequence, every one of the $N$ transformer layers runs its full forward computation: self-attention, MLP, normalization, residual adds. Then the gradients flow back through the same $N$ layers. The compute is proportional to $N \cdot T$, and it grows with the model.

At the very top of that stack, the final hidden state at position $t$ is projected to a vocabulary-sized logit vector and compared, by cross-entropy, against exactly one ground-truth token: the token at position $t+1$. The supervisory signal that this step delivers is bounded by $T \cdot \log_2(V)$ bits, and in practice it is much smaller, because natural text is nowhere near maximum entropy. The signal scales with the data, not the model.

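Per-sequence figures for one training step (T = 4096 tokens; 50k vocabulary for the 7B model, 128k for the 671B model; FLOPs counted over all parameters):

| Quantity | Scales with | 7B example | 671B example |
|---|---|---|---|
| FLOPs per step | model size × tokens | ~1.7 × 10¹⁴ | ~1.6 × 10¹⁶ |
| Supervisory bits (ceiling) | tokens × log₂(vocab) | ~64,000 bits | ~70,000 bits |
| Effective bits (real text) | tokens × ~3 bits | ~12,000 bits | ~12,000 bits |
| Bits / GFLOP (effective) | DECREASES with model size | ~7 × 10⁻² | ~7.5 × 10⁻⁴ |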
Read the last row. The bigger the model, the worse the ratio of supervisory signal to compute. Doubling parameters doubles compute but does nothing to the signal. By the time you reach a frontier-scale 671B model, every gigaflop of training compute produces less than a thousandth of a bit of supervisory information. The model has the capacity to learn far more from each step than the objective is giving it.
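To check the table's arithmetic, here is a minimal sketch that recomputes each row. It assumes the 6·N·T FLOP rule, T = 4096, and roughly 3 bits per token of real-text entropy; `step_budget` is a hypothetical helper written for this illustration, not part of any library.

```python
import math

def step_budget(n_params: float, t: int, vocab: int, eff_bits_per_token: float = 3.0) -> dict:
    """Per-sequence FLOPs and supervisory bits for one NTP training step.

    Assumes the 6*N*T FLOP rule and one cross-entropy target per position;
    eff_bits_per_token is an assumed entropy rate for real text.
    """
    flops = 6 * n_params * t                    # forward + backward
    ceiling_bits = t * math.log2(vocab)         # at most log2(V) bits per target
    effective_bits = t * eff_bits_per_token     # what natural text roughly carries
    return {
        "flops": flops,
        "ceiling_bits": ceiling_bits,
        "effective_bits": effective_bits,
        "eff_bits_per_gflop": effective_bits / (flops / 1e9),
    }

small = step_budget(7e9, 4096, 50_000)
large = step_budget(6.71e11, 4096, 128_000)
for name, row in [("7B", small), ("671B", large)]:
    print(f"{name:>5}: FLOPs={row['flops']:.2e}  ceiling={row['ceiling_bits']:,.0f} bits  "
          f"effective={row['effective_bits']:,.0f} bits  "
          f"bits/GFLOP={row['eff_bits_per_gflop']:.1e}")
```

Changing `eff_bits_per_token` moves the effective row proportionally; the ratio between the two columns stays roughly the ratio of parameter counts.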

Intuition: The 100-Step Exam Graded on One Number

Imagine a student sitting an exam that takes 100 steps of arithmetic to complete. The student does every step carefully. The teacher grades the work by reading only the final number — they never look at any intermediate result — and writes either a check or a cross at the bottom of the page. Over the course of a semester, the student improves on the final answer, but slowly, because each round of feedback is one bit of signal compressed from a hundred steps of effort. If the teacher had instead written a check or cross at every intermediate step, the student would have learned in ten exams what previously took a hundred.

A transformer is exactly this student. The forward pass computes a long chain of intermediate representations: early-layer hidden states contain low-level features, mid-layer hidden states contain syntactic structure, late-layer hidden states contain semantic content. By the time the final layer produces its logits, the model has done something analogous to a hundred steps of reasoning. The standard objective reads only the final logit vector and grades it on one cross-entropy target. No intermediate representation is graded directly; each one learns only from the gradient that flows back from that single target.

The physical picture. Think of the forward pass as a freight train moving through N layer-stations, each station performing nontrivial physics on the cargo. The training objective is a single inspector at the last station who weighs the train once. The inspector does not measure the cargo at any earlier station, and does not measure how far ahead the train could have predicted what comes next. All of those measurements would be cheap; we simply do not make them.

Why has the field tolerated this for so long? Two reasons. First, the single-target objective is mathematically clean — cross-entropy on the last hidden state has a tidy gradient and works out of the box. Second, for a long time the data side was the bottleneck: there was always more text to train on, so the question of "more signal per FLOP" was less urgent than the question of "more data, please." That changed once frontier labs started running out of high-quality tokens and once total training FLOPs began to dwarf data scaling. With both data and capital expensive, the bits-per-FLOP ratio became the new lever.

The Math: Supervisory Density per FLOP

Fix the model at $N$ total parameters and one training step that processes $T$ tokens. The standard estimate for FLOPs per step is:

$$F_{\text{step}} = 6 \cdot N \cdot T$$

The factor of 6 comes from 2 FLOPs per parameter in the forward pass (one multiply plus one add, counted as two operations) and roughly 4 FLOPs per parameter in the backward pass (one multiply-add for the activation gradient and one for the weight gradient). Some analyses use a different constant, such as 8 when activations are recomputed during the backward pass (one extra forward pass); the optimizer update costs only $O(N)$ per step and is usually ignored. The constant 6 is the standard in scaling-law literature.
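The accounting behind the constant can be written out term by term. This is a sketch under the stated simplifications (dense matmul FLOPs only, attention-score FLOPs and the optimizer update ignored); `train_flops` is a hypothetical helper, and the `recompute` flag models full activation recomputation as one extra forward pass.

```python
def train_flops(n_params: float, tokens: int, recompute: bool = False) -> float:
    """Approximate training FLOPs for one step over `tokens` tokens.

    forward : 2 FLOPs/param/token (one multiply-add)
    backward: 4 FLOPs/param/token (activation grad + weight grad, each a multiply-add)
    recompute: optionally one extra forward pass, turning 6*N*T into 8*N*T
    """
    forward = 2 * n_params * tokens
    backward = 4 * n_params * tokens
    extra_forward = forward if recompute else 0.0
    return forward + backward + extra_forward

print(f"{train_flops(7e9, 4096):.3e}")                  # 1.720e+14  (6*N*T)
print(f"{train_flops(7e9, 4096, recompute=True):.3e}")  # 2.294e+14  (8*N*T)
```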

Now the supervisory side. Next-token prediction emits one cross-entropy target per position. The maximum information content of a categorical distribution over $V$ classes is $\log_2 V$ bits, achieved when the true distribution is uniform. Real text falls well below this ceiling, but as an upper bound:

$$B_{\text{step}} = T \cdot \log_2 V$$

Define the supervisory density $\rho$ as bits per FLOP:

$$\rho = \frac{B_{\text{step}}}{F_{\text{step}}} = \frac{T \log_2 V}{6 N T} = \frac{\log_2 V}{6 N}$$

The $T$'s cancel: a longer context does not change the density. Vocabulary buys only a logarithmic improvement. Apart from that logarithmic term, the only quantity that survives is $N$ in the denominator, which is exactly the observation we want: supervisory density falls in inverse proportion to model size. Big models are the worst possible thing for bits per FLOP.
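Both claims in that paragraph, that $T$ cancels and that vocabulary helps only logarithmically, can be checked numerically. A minimal sketch; `density_bits_per_flop` is a hypothetical helper that evaluates the unsimplified ratio $B_{\text{step}} / F_{\text{step}}$.

```python
import math

def density_bits_per_flop(n_params: float, tokens: int, vocab: int) -> float:
    """Ceiling supervisory density rho = (T * log2 V) / (6 * N * T)."""
    bits = tokens * math.log2(vocab)
    flops = 6 * n_params * tokens
    return bits / flops

# T cancels: a 512-token and a 32k-token context give the same density.
short_ctx = density_bits_per_flop(7e9, 512, 50_000)
long_ctx = density_bits_per_flop(7e9, 32_768, 50_000)
closed_form = math.log2(50_000) / (6 * 7e9)
print(short_ctx, long_ctx, closed_form)

# Doubling the vocabulary adds exactly one bit per target: a ~6% gain, not 2x.
vocab_gain = density_bits_per_flop(7e9, 4096, 100_000) / density_bits_per_flop(7e9, 4096, 50_000)
print(f"{vocab_gain:.3f}")
```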

The MTP lever, in one line

Suppose we attach $k$ output heads instead of one, so position $t$ predicts the tokens at $t+1, t+2, \dots, t+k$. The extra heads cost $O(k \cdot d \cdot V)$ FLOPs per token, where $d$ is the hidden size; that is small compared with the trunk's $O(N)$ per token, so to a first approximation the FLOPs are unchanged. The bits scale linearly:

$$\rho_k \approx \frac{k \cdot \log_2 V}{6 N}$$

We multiplied the density by $k$ almost for free. Whether the model can use the $k-1$ extra targets without confusing itself is the engineering question that the remaining sections of this chapter answer. Here we are only establishing the prize.
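The "almost" can be made precise under one simplifying assumption: each extra head is an untied $d \times V$ output projection and nothing else, so under the same 6-FLOPs-per-parameter rule it adds $6dV$ training FLOPs per token. With $k$ heads, $k-1$ of them are extra:

```latex
\rho_k = \frac{k \, T \log_2 V}{6 \, \bigl(N + (k-1)\, d V\bigr) \, T}
       = \frac{k \log_2 V}{6 \, \bigl(N + (k-1)\, d V\bigr)}
       \approx \frac{k \log_2 V}{6N}
       \qquad \text{when } (k-1)\, d V \ll N .
```

The relative FLOP overhead is $(k-1)\,dV/N$. For $d = 7168$, $V = 128{,}000$, $N = 6.71 \times 10^{11}$ and $k = 2$, that is about 0.14%; real MTP modules with extra layers cost more than this bare-projection estimate.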

Why this matters more as models get bigger. The denominator $6N$ grows with model size; the numerator $\log_2 V$ is essentially fixed. So $\rho$ shrinks every time you scale the model. The bigger the model, the more aggressively you want $k > 1$. This is the structural reason MTP shows up first in 100B+ models like DeepSeek-V3 and not in early small transformers.

Manual Numerical Walkthrough

Numbers make the asymmetry vivid. Let us compute the supervisory density for three concrete configurations, by hand.

Density calculation for 7B, 70B, and 671B models

Common inputs. Context length $T = 4096$, vocabulary $V = 50{,}000$, batch size $B = 1$. We compute everything per processed sequence so the numbers stay on a human scale.

Step 1: Llama-7B class, k = 1. $N = 7 \times 10^9$ parameters.

  • FLOPs/step = 6 · 7e9 · 4096 ≈ 1.72 × 10¹⁴ (172 TFLOPs)
  • log₂(50000) ≈ 15.61 bits per target
  • Bits/step (ceiling) = 4096 · 15.6096 ≈ 63,937 bits
  • ρ = 63,937 / 1.72e14 ≈ 3.72 × 10⁻¹⁰ bits/FLOP
  • ρ ≈ 0.372 bits per GFLOP (multiply bits/FLOP by 10⁹)

Read: every gigaflop of training compute is rewarded with less than half a bit of supervisory information; put the other way, one bit costs about 2.7 GFLOPs. That is the optimistic ceiling. Real text only fills 2–4 bits of the 15.6 available, so the actual figure is closer to $7 \times 10^{-2}$ bits/GFLOP.
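The ceiling-versus-effective gap in Step 1 takes only a few lines to reproduce. This sketch assumes the 2–4 bits/token range quoted above and converts a loss measured in nats to bits by dividing by ln 2; nothing here runs a model.

```python
import math

N, T, V = 7e9, 4096, 50_000
gflops = 6 * N * T / 1e9                      # ≈ 1.72e5 GFLOPs per sequence

ceiling = T * math.log2(V) / gflops           # every target carries log2(V) bits
print(f"ceiling: {ceiling:.3f} bits/GFLOP")

# Assumed entropy of real text, in bits per token. A loss of 2.0 nats
# corresponds to 2.0 / ln(2) ≈ 2.89 bits.
effective = {}
for bits_per_token in (2.0, 2.0 / math.log(2), 3.0, 4.0):
    effective[bits_per_token] = T * bits_per_token / gflops
    print(f"{bits_per_token:.2f} bits/token -> {effective[bits_per_token]:.3f} bits/GFLOP")
```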

Step 2: A 70B model, k = 1. $N = 7 \times 10^{10}$.

  • FLOPs/step = 6 · 7e10 · 4096 ≈ 1.72 × 10¹⁵
  • Bits/step (ceiling) unchanged ≈ 63,937 bits
  • ρ ≈ 63,937 / 1.72e15 ≈ 3.72 × 10⁻¹¹ bits/FLOP ≈ 3.7 × 10⁻² bits/GFLOP

We multiplied parameters by 10× and density dropped by 10×. The formula $\rho = \log_2(V) / (6N)$ predicted exactly this, and any model of this size lands at this ratio unless the objective changes.

Step 3: DeepSeek-V3, k = 1. $N = 6.71 \times 10^{11}$ total parameters (with MoE; ~37B activated per token, but the gradient touches all parameters via expert routing across the dataset).

  • FLOPs/step (effective, activated) ≈ 6 · 3.7e10 · 4096 ≈ 9.1 × 10¹⁴
  • FLOPs/step (full, naïve) ≈ 6 · 6.71e11 · 4096 ≈ 1.65 × 10¹⁶
  • Bits/step ceiling (128k vocabulary) ≈ 4096 · log₂(128000) ≈ 4096 · 16.97 ≈ 69,500 bits
  • ρ (effective) ≈ 7.6 × 10⁻² bits/GFLOP
  • ρ (full) ≈ 4.2 × 10⁻³ bits/GFLOP

Either way you count it, by activated FLOPs or by the full sparse footprint, the density is at most a few hundredths of a bit per gigaflop: roughly 5× to 90× below the 7B ceiling. The compute spend per unit of supervisory information has collapsed.
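A minimal sketch of the two counting conventions in Step 3. The 37B activated-parameter figure and the 128k vocabulary are the approximations used above; real MoE step costs also include attention and routing overheads that this ignores.

```python
import math

T, V = 4096, 128_000
bits = T * math.log2(V)                          # ceiling ≈ 69.5k bits per sequence

params = {"activated": 3.7e10, "full": 6.71e11}  # ~37B active per token vs all experts
density = {}
for label, n in params.items():
    gflops = 6 * n * T / 1e9
    density[label] = bits / gflops
    print(f"{label:>9}: {gflops:.2e} GFLOPs/step, {density[label]:.1e} bits/GFLOP")
```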

Step 4: Same DeepSeek-V3, but k = 2. The trunk cost is essentially unchanged. The extra LM head has roughly $d \cdot V \approx 7168 \cdot 128000 \approx 9.2 \times 10^8$ parameters: about 0.14% of the full 671B parameter count, or about 2.5% of the ~37B parameters activated per token. Under the 6-FLOPs-per-parameter rule, FLOPs/step therefore grow by somewhere between a fraction of a percent and a few percent, depending on which count you use, while bits/step double.

  • Bits/step ≈ 2 · 69,500 ≈ 139,000 bits
  • ρ (effective) ≈ 1.5 × 10⁻¹ bits/GFLOP, about a 2× improvement
  • FLOP overhead vs k = 1: ~0.14% (full count) to ~2.5% (activated count)

The headline. Moving from k = 1 to k = 2 gives a roughly 2× boost in supervisory density for at most a few percent more compute. This is the economic case for MTP at frontier scale, and the next four sections describe how to actually realize that gain without breaking training stability.
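Step 4's trade-off under both counting conventions can be sketched as follows. It assumes the second prediction adds one untied $d \times V$ projection and nothing else; DeepSeek-V3's actual MTP module also contains a transformer block, which this sketch deliberately ignores.

```python
import math

T, V, D_MODEL = 4096, 128_000, 7168
PARAMS = {"full": 6.71e11, "activated": 3.7e10}

# Assumption: the extra prediction costs one d x V projection and nothing else.
head_params = D_MODEL * V                      # ≈ 9.2e8
bits_k1 = T * math.log2(V)
bits_k2 = 2 * bits_k1                          # two targets per position

overhead, gain = {}, {}
for label, n in PARAMS.items():
    rho_k1 = bits_k1 / (6 * n * T)
    rho_k2 = bits_k2 / (6 * (n + head_params) * T)
    overhead[label] = head_params / n          # extra FLOPs relative to the k = 1 step
    gain[label] = rho_k2 / rho_k1
    print(f"{label:>9}: FLOP overhead {overhead[label]:.2%}, density gain {gain[label]:.3f}x")
```

The density gain stays just under 2× in both cases; only the overhead figure depends on whether you count all experts or only the activated ones.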

Sanity check. If you set N = 100M and run the formula, $\rho \approx 26$ bits/GFLOP, about 70× (nearly two orders of magnitude) better than the 7B case. That is why early small transformers never felt this pressure: the bottleneck only becomes crippling once N gets large.
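The sanity check is a one-line evaluation of the closed form. `rho_bits_per_gflop` is a hypothetical helper; because $\rho \propto 1/N$, the ratio between any two models is simply the inverse ratio of their parameter counts.

```python
import math

def rho_bits_per_gflop(n_params: float, vocab: int = 50_000) -> float:
    """Ceiling density log2(V) / (6N), converted from bits/FLOP to bits/GFLOP."""
    return math.log2(vocab) / (6 * n_params) * 1e9

small, base = rho_bits_per_gflop(1e8), rho_bits_per_gflop(7e9)
print(f"100M: {small:.1f} bits/GFLOP, 7B: {base:.3f} bits/GFLOP, ratio {small / base:.0f}x")
```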

Visualizing the Bottleneck

The diagram below makes the compute-vs-signal asymmetry visceral. Each token column is a full forward + backward pass through $N$ layers, with every amber cell lit at every step. The thin green lines at the bottom are the only supervisory signal that pass extracts. Hover any column to see what one position is asked to learn from one round of compute.


Three behaviors to try out. First, push the model-size slider up to 671B and watch the "bits per GFLOP" meter shrink — even though you have not changed the supervisory side at all, the compute side ballooned and the ratio collapsed. Second, push the predictions-per-position slider from 1 to 4. Compute barely budges; the supervisory bar grows in lockstep with k. The density meter climbs. Third, try the preset buttons: "DeepSeek-V3 / NTP" gives the tiny ratio that motivated MTP, and "DeepSeek-V3 / MTP-2" shows the resulting density without redrawing a single layer of the trunk.

The asymmetry to remember. When you double a model's parameter count you double its compute and gain nothing on the supervisory side. When you double k you double the supervisory side and gain (nearly) nothing on the compute side. The first is what scaling laws have been doing for a decade. The second is what this chapter is going to do.

Plain Python: Counting the Signal

Before we touch a real loss function, it is worth writing the bookkeeping directly. The code below does no training — it simply counts, in NumPy, what the standard NTP step costs and what it produces. The same formulas scale, unchanged, from the toy case to DeepSeek-V3.

ntp_bottleneck_numpy.py

Setup: pick a concrete model + data shape.

We use Llama-7B-class numbers because they are easy to reason about: 7 billion parameters, a 4k context window, and a 50k vocabulary. The same arithmetic applies to DeepSeek-V3 (671B total parameters, about 37B activated per token); counting all parameters, we just multiply N_PARAMS by ~100×.

Execution state: N_PARAMS = 7e9, T = 4096, V = 50000.

(a) FLOPs per step: the canonical estimate.

The widely cited rule from Kaplan et al. (2020), also used by Hoffmann et al. (2022): forward FLOPs per processed token ≈ 2 · N_params (one multiply-add per parameter, counted as two FLOPs). Backward adds another ~4 · N_params, giving a total of ~6 · N_params per token. Multiply by T tokens per sequence and by batch size.

Execution state: flops_per_step ≈ 1.72e14 (172 TFLOPs).

(b) Bits per step: the supervisory ceiling.

Each cross-entropy target is a categorical over V classes. Information theory caps the bits it can carry at log₂(V). With T positions per sequence and one target per position, the per-step ceiling is T · log₂(V). It is only a ceiling because real text is far below maximum entropy; effective bits are even smaller, which makes the bottleneck worse than this calculation suggests.

Execution state: bits_per_step ≈ 6.39e4 bits (~64 Kbits).

(c) Density: the ratio that hurts.

Divide bits by gigaflops to land on an intuitive number. The result for Llama-7B-NTP is ~0.37 bits per GFLOP. That is the headline: roughly one bit of supervisory signal per 2.7 gigaflops of compute. The transformer trunk does enormous work; we extract a sliver of it.

Execution state: density ≈ 0.372 bits/GFLOP.

(d) Vary k: preview MTP's lever.

If we attached k LM heads instead of one, each position would emit k cross-entropy targets. The forward+backward cost of the heads is small compared to the trunk, so flops_per_step is essentially unchanged. Bits scale linearly with k, so the density also scales linearly with k, which is the entire economic argument for multi-token prediction.

(d, loop) Density across k = 1..4.

You will see density values of roughly 0.37, 0.74, 1.11, and 1.49 bits/GFLOP. Relative to k = 1, predicting 2, 3, or 4 positions doubles, triples, or quadruples the supervisory density, and that is what could let the same compute budget produce a better model. Whether the model can actually USE the extra supervisory bits is the engineering question Chapter 7 spends the rest of its sections answering.
```python
import numpy as np

# A toy transformer's training step. We do NOT run the network here; we just
# count the FLOPs it would do and the supervisory bits we extract from one step.

# Model + data dimensions.
N_PARAMS  = 7e9          # 7B parameters
T         = 4096         # context length (tokens per sequence)
V         = 50_000       # vocabulary size
BATCH     = 1            # per-GPU micro-batch for the demo

# (a) FLOPs per training step. Standard estimate:
#     forward + backward ≈ 6 * N_params per processed token.
flops_per_step = 6 * N_PARAMS * T * BATCH

# (b) Supervisory signal. Standard next-token prediction supplies exactly
#     one cross-entropy target per position. Upper bound per target is log2(V)
#     bits (the maximum entropy a categorical with V outcomes can carry).
bits_per_step = T * BATCH * np.log2(V)

# (c) Supervisory density: how many bits of signal per gigaflop of compute.
density = bits_per_step / (flops_per_step / 1e9)

print(f"FLOPs/step      = {flops_per_step:.3e}")
print(f"Bits/step (max) = {bits_per_step:.3e}")
print(f"Density         = {density:.6f} bits / GFLOP")

# (d) What if we predicted k targets per position instead of 1?
#     The FLOP count barely moves (extra LM heads ≪ the full transformer trunk).
#     The supervisory bits scale linearly with k.
for k in [1, 2, 3, 4]:
    bits_k    = k * T * BATCH * np.log2(V)
    density_k = bits_k / (flops_per_step / 1e9)
    print(f"k = {k}: density = {density_k:.6f} bits / GFLOP")
```

Two structural points are worth marking. First, the loop over k is the cheapest possible preview of MTP: nothing about the network changes; we simply multiply the bits term by k and re-divide. The fact that this trivial edit produces a multiplicatively better ratio is the entire reason MTP is a serious research direction. Second, the FLOPs estimate uses the canonical $6 \cdot N \cdot T$ rule because that is what scaling-law papers use; if you have read Hoffmann et al.'s Chinchilla paper, this is exactly their accounting.

Sanity check. If you set N_PARAMS = 1 and V = 2, the bits/step becomes $T \cdot 1 = T$ and the FLOPs/step becomes $6T$. Density = 1/6 bits/FLOP: one bit per six operations. That is the formula's value for the smallest possible model on the smallest nontrivial vocabulary, a degenerate reference point. Any real model lands many orders of magnitude below it, and the gap is the price of using a large transformer as the function approximator.

PyTorch: The Standard NTP Loss in One Line

Now the same idea, instantiated in PyTorch. This is the loss every modern LLM is trained with, give or take a label-smoothing parameter. Notice how short it is — that brevity is exactly what we are about to enrich.

ntp_loss_pytorch.py

Inputs to the loss (the ntp_loss signature).

logits has shape (B, T, V): one length-V vector of unnormalized scores per position per sequence. inputs has shape (B, T): the raw token IDs. NTP wants logits[:, t] to assign high probability to inputs[:, t+1], so the two tensors must be offset by one position before they can be compared.

(a) Drop the last logit position.

Position T-1 (the last one, counting from 0) has no successor to predict, because there is no token at position T. We slice it off, leaving (B, T-1, V): one prediction for every position except the last.

Execution state: shift_logits.shape = (B, T-1, V).

(b) Drop the first input position.

Position 0 is never a target, because no earlier position exists to predict it. We drop inputs[:, 0] and keep inputs[:, 1:], landing at (B, T-1). Now shift_logits[:, t] and shift_labels[:, t] are aligned: the model's logits at position t should match the token at position t+1.

Execution state: shift_labels.shape = (B, T-1).

(c) Cross-entropy in one call.

F.cross_entropy expects 2D logits and 1D targets, so we flatten. The default reduction is 'mean', so each of the B·(T-1) positions contributes 1/(B·(T-1)) of the final scalar. Averaging does not change how many targets the step extracts: there are still exactly B·(T-1) of them, compressed into one scalar.

Execution state: loss = scalar tensor.

Targets: count them explicitly (n_targets).

B · (T-1) = 2 · 7 = 14 targets are reduced into one scalar. log₂(100) ≈ 6.64, so the upper bound on bits extracted by this step is 14 · 6.64 ≈ 93 bits. The toy uses random logits and so spends no model FLOPs at all, but even a tiny real model would spend millions of FLOPs producing those logits, so the asymmetry is already visible at toy scale.

Execution state: n_targets = 14.

Output: what a real LM would print here.

A well-trained 7B model on natural text reaches a loss of roughly 2.0 nats, ≈ 2.9 bits per token. That is the effective bits figure, much smaller than the log₂(V) ≈ 15.6 ceiling for V = 50000. The bottleneck is even tighter than our optimistic count suggests, which is why MTP, distillation, and richer auxiliary objectives all show up later in this book.
```python
import torch
import torch.nn.functional as F

def ntp_loss(
    logits: torch.Tensor,  # (B, T, V): model output for every position
    inputs: torch.Tensor,  # (B, T): token IDs
) -> torch.Tensor:
    """
    Standard next-token prediction loss. The hidden state at position t is
    asked to predict the token at position t+1. The last position has no
    target, so we shift logits and inputs and align them.
    """
    # (a) Drop the last position of logits: it has no successor to predict.
    shift_logits = logits[:, :-1, :]              # (B, T-1, V)
    # (b) Drop the first position of inputs: no earlier position predicts it.
    shift_labels = inputs[:, 1:]                  # (B, T-1)

    # (c) Cross-entropy over the full vocabulary, averaged across all
    #     (B * (T-1)) positions. Each averaged term is ONE supervisory target.
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
    )
    return loss

# Sanity check on the toy: 2 sequences, 8 tokens, vocab of 100.
torch.manual_seed(0)
B, T, V = 2, 8, 100
logits  = torch.randn(B, T, V)
inputs  = torch.randint(0, V, (B, T))

loss = ntp_loss(logits, inputs)
# How many supervisory targets contributed to this scalar?
n_targets = B * (T - 1)
print(f"loss = {loss.item():.4f}")
print(f"supervisory targets reduced into this scalar = {n_targets}")
print(f"bits/target upper bound = {torch.log2(torch.tensor(float(V))).item():.3f}")
```

Two subtleties worth marking, both about how this loss interacts with the rest of the training step:

  1. The reduction is a mean, not a sum. Each of the $B(T-1)$ positions contributes $1/(B(T-1))$ of the final scalar. This means the loss magnitude does not grow with context length, which is convenient, but it also means each token's contribution to the gradient is scaled down by $1/(B(T-1))$. The asymmetry we counted at the FLOP level carries over to the gradient level: the more positions share one step, the smaller each one's slice of the backward signal.
  2. The shift is non-negotiable. Logits at position t predict the token at t+1. If you forget to shift, a causal model is trained to reproduce the token it has already seen at position t: near-trivial accuracy and no useful learning signal. This is a common bug in hand-rolled training loops and the reason most frameworks ship a wrapped loss with the shift baked in.
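Subtlety 1 can be seen directly in the gradient. Below is a NumPy sketch (our own, not PyTorch's implementation) of the softmax cross-entropy gradient with respect to the logits, $p - \text{onehot}(y)$, under both reductions; `softmax` and `ce_grad` are hypothetical helpers, and the toy shapes match the PyTorch example.

```python
import numpy as np

def softmax(x: np.ndarray) -> np.ndarray:
    z = x - x.max(axis=-1, keepdims=True)       # stabilise before exponentiating
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def ce_grad(logits: np.ndarray, labels: np.ndarray, reduction: str = "mean") -> np.ndarray:
    """Gradient of cross-entropy w.r.t. logits of shape (n, V)."""
    grad = softmax(logits)                      # fresh array, safe to modify in place
    grad[np.arange(len(labels)), labels] -= 1.0 # d(-log p[y]) / d logits = p - onehot(y)
    if reduction == "mean":
        grad /= len(labels)                     # every target is weighted by 1/n
    return grad

rng = np.random.default_rng(0)
B, T, V = 2, 8, 100
logits = rng.standard_normal((B * (T - 1), V))  # already-shifted, flattened logits
labels = rng.integers(0, V, size=B * (T - 1))

g_mean = ce_grad(logits, labels, "mean")
g_sum = ce_grad(logits, labels, "sum")
print(np.allclose(g_mean * B * (T - 1), g_sum)) # mean gradient = sum gradient / (B*(T-1))
```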

What the next sections will change. The whole point of this chapter is to replace the single F.cross_entropy line above with a structured family of k-target losses. The shift will go from $+1$ to $+1, +2, \dots, +k$; the averaging will become a weighted sum over heads; and the model architecture will grow shallow MTP modules that produce the extra predictions. The bookkeeping changes; the supervisory density $\rho_k \approx k \cdot \log_2(V) / (6N)$ is the prize.
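The bookkeeping half of that change can be sketched already. `k_target_loss` below is a hypothetical NumPy helper: it assumes one logit tensor per prediction depth, applies shift $i$ to depth $i$, and sums the per-depth mean cross-entropies with caller-chosen weights. It covers only the loss arithmetic, not DeepSeek-V3's sequential MTP architecture, and the weights in the usage lines are illustrative.

```python
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    z = x - x.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def k_target_loss(head_logits: list, inputs: np.ndarray, weights: list) -> float:
    """Weighted sum of shifted cross-entropies, one per prediction depth.

    head_logits[i] has shape (B, T, V) and is assumed to predict the token
    at position t + (i + 1); inputs has shape (B, T) and holds token IDs.
    """
    T = inputs.shape[1]
    total = 0.0
    for i, (logits, w) in enumerate(zip(head_logits, weights)):
        shift = i + 1
        lp = log_softmax(logits[:, : T - shift, :])   # positions that still have a target
        labels = inputs[:, shift:]                    # the token `shift` steps ahead
        nll = -np.take_along_axis(lp, labels[..., None], axis=-1)[..., 0]
        total += w * nll.mean()                       # mean over B * (T - shift) targets
    return float(total)

rng = np.random.default_rng(0)
B, T, V, K = 2, 8, 100, 3
inputs = rng.integers(0, V, size=(B, T))
heads = [rng.standard_normal((B, T, V)) for _ in range(K)]

ntp_only = k_target_loss(heads[:1], inputs, [1.0])    # k = 1 reduces to the NTP loss
mtp = k_target_loss(heads, inputs, [1.0, 0.3, 0.3])
print(f"NTP: {ntp_only:.4f}  weighted 3-target: {mtp:.4f}")
```

Note that depth $i$ loses $i$ positions at the end of each sequence, so deeper heads see slightly fewer targets per step.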

What Changes at Massive Scale

At toy scale the bottleneck is a curiosity. At frontier scale it is a defining constraint of the training economy. Three things change as N grows from millions to hundreds of billions of parameters.

Compute becomes the dominant cost

At 100M parameters, a training run might cost a few thousand GPU-hours and the data pipeline is the bottleneck: you tune tokenization, dedup, and domain mixing because that is where the gains are. At 671B parameters, a single training run costs millions of GPU-hours (DeepSeek-V3 reports about 2.8M H800 GPU-hours for its full training), and the data pipeline is mature enough that every additional bit of supervisory signal per FLOP translates directly into a large compute saving. The density ratio stops being academic.

The model can absorb more signal than it gets

Probing studies suggest that the hidden states of NTP-trained models already encode predictive information about tokens several positions ahead, but the model is never directly asked to use it. In that sense, multi-token prediction is less about teaching the model new tricks than about supervising tricks it has already partly learned in service of a one-step objective. The capacity is largely there; the supervisory signal is what is missing.

Inference cost lines up with training-time choices

A model trained to predict $k$ tokens ahead at every position is, almost by accident, also a model whose internal representations are well suited for speculative decoding at inference: draft k tokens and accept the longest prefix that a verifier agrees with. Section 7.5 makes this connection explicit. What looks like a training-time density argument in this section turns into a real inference speedup later (DeepSeek-V3 reports about 1.8× more tokens per second), reusing the MTP module as the drafter instead of a separately trained draft model.
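The "accept the prefix that matches" step is simple enough to sketch. This is the greedy-verification variant only: `accept_draft` is a hypothetical helper, and it assumes the verifier has already scored every draft position in one parallel pass. Sampling-based speculative decoding replaces the equality test with a rejection-sampling rule.

```python
from typing import List, Sequence

def accept_draft(draft: Sequence[int], verifier: Sequence[int]) -> List[int]:
    """Greedy speculative-decoding acceptance (simplified sketch).

    draft[i]   : token the drafter proposed at step i.
    verifier[i]: the full model's greedy token at step i, given the accepted
                 prefix; len(verifier) == len(draft) + 1, where the last entry
                 is the model's next token after the whole draft.
    """
    accepted: List[int] = []
    for d, v in zip(draft, verifier):
        if d != v:
            accepted.append(v)          # first mismatch: keep the verifier's token, stop
            return accepted
        accepted.append(d)
    accepted.append(verifier[len(draft)])  # every draft token matched: take the bonus token
    return accepted

print(accept_draft([5, 9, 2], [5, 9, 7, 4]))  # two accepted, then the correction
print(accept_draft([5, 9, 2], [5, 9, 2, 4]))  # all three accepted, plus one bonus token
```

Each verifier pass therefore yields between one and `len(draft) + 1` tokens, which is where the speedup comes from when the drafter is usually right.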

| Era | Typical N | Bottleneck | Status of NTP |
|---|---|---|---|
| GPT-2 (2019) | 1.5B | Data quality | Ceiling density ~1.7 bits/GFLOP: plenty |
| GPT-3 (2020) | 175B | Data scale | Density dropped, but data outpaced compute |
| Chinchilla / Llama-2 (2022–23) | 7B–70B | Compute–data balance | Density tight but tolerable |
| DeepSeek-V3 / GPT-4 class (2024–25) | 300B–1T+ | Bits per FLOP | Density collapsed; MTP becomes necessary |
The era we are now in. Multi-token prediction, speculative drafting, and richer objective families are active research directions at frontier labs. DeepSeek-V3's paper is the cleanest public derivation, which is why this chapter follows it line by line. The bottleneck is universal; the fix is what differs.

Engineering Reality and the Door MTP Opens

Three practical observations sit on top of the math, and each one will recur through the rest of this chapter.

  1. The bottleneck is not a bug; it is a default. F.cross_entropy on shifted logits is so easy and so robust that it became the universal recipe. There is nothing "wrong" with it — every model in production uses it. The point is that the default leaves information on the table, and the cost of leaving that information on the table grows with model scale. MTP is not fixing a bug; it is collecting on a debt.
  2. Naïve parallel k-target prediction does not work. The obvious move is to attach k independent output heads, each predicting $t+i$ from the same hidden state. This is what Section 7.2 examines, and it fails: heads compete for the same representation, gradients clash, and the model often does worse than NTP. DeepSeek-V3's sequential causal MTP from Section 7.3 is the construction that actually works, and the rest of the chapter is about why.
  3. The right loss coefficient is small. Even when the architecture is right, the MTP heads are auxiliary: the main objective is still next-token prediction. DeepSeek-V3 weights the MTP loss with $\lambda = 0.3$ for the first 10T training tokens and $\lambda = 0.1$ for the remaining 4.8T. Section 7.4 derives why this schedule matters, and how the wrong choice silently destroys the gains from the architecture in Section 7.3.
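As a rough sketch of where that coefficient enters the objective: the DeepSeek-V3 report, as we read it, averages the MTP losses over the $D$ prediction depths and scales the average by $\lambda$, using $\lambda = 0.3$ for the first 10T tokens and 0.1 for the remaining 4.8T. `mtp_loss_weight` and `total_loss` are hypothetical helper names, and the loss values in the usage lines are made up.

```python
def mtp_loss_weight(tokens_seen: float) -> float:
    """Step schedule for lambda: 0.3 for the first 10T tokens, 0.1 afterwards."""
    return 0.3 if tokens_seen < 10e12 else 0.1

def total_loss(main_loss: float, mtp_losses: list, tokens_seen: float) -> float:
    """L = L_main + (lambda / D) * sum_k L_MTP^k, where D = len(mtp_losses)."""
    lam = mtp_loss_weight(tokens_seen)
    depth = len(mtp_losses)
    return main_loss + lam / depth * sum(mtp_losses)

# Made-up loss values with a single MTP depth (D = 1):
print(total_loss(2.0, [2.6], tokens_seen=5e12))    # early: 2.0 + 0.3 * 2.6
print(total_loss(2.0, [2.6], tokens_seen=12e12))   # late:  2.0 + 0.1 * 2.6
```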

One sentence to carry forward into the rest of the chapter: the single-token objective is a compute-efficient learner of a signal-inefficient problem; multi-token prediction is the cheapest known way to fix the supervisory side without touching the trunk. Everything that follows in Chapter 7 is the careful engineering of that fix.

Where we go from here. Section 7.2 examines the naïve parallel-MTP construction and shows precisely why it breaks. Section 7.3 derives DeepSeek-V3's sequential causal MTP module, the construction that actually works. Section 7.4 covers training-objective design and the MTP loss-weight schedule. Section 7.5 closes the loop by showing how the same architecture, used at inference, produces the speculative-decoding speedups DeepSeek-V3 reports in production.