Chapter 9

Inference-Aware Scaling Laws

Scaling Laws and Compute-Optimal Training

The Chinchilla recipe says train a 70 B model on 1.4 T tokens. Llama 3 trained an 8 B model on 15 T tokens: roughly nine times smaller, on roughly ten times more data. Both teams read the same scaling-law paper. Neither is wrong. They are solving different problems: Chinchilla minimises training compute; Llama 3 minimises training plus years of inference. This section derives the inference-aware optimum and shows why modern production models are dramatically over-trained relative to Hoffmann et al. (2022).

The Real Problem: Chinchilla Ignores Inference

Hoffmann et al. (2022) asked the right question for 2022: given a fixed training compute budget, what model size and token count minimise final loss? Their answer (roughly 20 tokens per parameter, scale $N$ and $D$ together) was a revolution. Gopher (280 B params, 300 B tokens) was severely under-trained. Chinchilla (70 B params, 1.4 T tokens) beat it at one-quarter the parameter count and the same training compute. The community rebuilt its training plans around the new recipe.

Then a different question became urgent. A foundation model is not a research artefact; it is the engine of a product that will serve billions of tokens per day for years. ChatGPT alone plausibly serves on the order of a quadrillion tokens of inference per year. Llama 3 is widely downloaded and re-served across a very large number of deployments. The cost equation is then no longer $C = 6ND$; it is $C_{\text{total}} = C_{\text{train}} + C_{\text{inf}}$, and the second term can dwarf the first.

| Quantity | Chinchilla framing | Production framing |
| --- | --- | --- |
| Objective | Minimise training FLOPs at fixed loss | Minimise lifetime FLOPs at fixed loss |
| Cost per token (train) | $6N$ | $6N$ |
| Cost per token (infer) | Not counted | $2N$ per query token |
| Lifetime workload | $D$ training tokens | $D$ training tokens + $D_{\text{inf}}$ served tokens |
| Model size $N$ | Grows with budget | Shrinks if inference is heavy |
| Tokens per parameter | ≈ 20 | 100 – 500+ for modern products |

The empirical signal that motivates the formula. Llama 2 7B (2 T tokens, ~286×), Llama 3 8B (15 T tokens, ~1875×), Mistral 7B (reportedly ~8 T tokens, ~1140×), Phi-3-mini (3.8 B params, 3.3 T tokens, ~870×). None of these sits anywhere near Chinchilla's 20× ratio. They are not mistakes; they are deliberate over-training driven by inference economics. Sardana et al. (2024) gave the math a name: Beyond Chinchilla-Optimal.
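The ratios quoted above are simple divisions, and a few lines make them easy to recheck or extend to other models. This is a minimal sketch: the parameter and token counts are the figures cited in the text (in billions), and the 20× baseline is the Chinchilla rule of thumb.

```python
# Tokens-per-parameter ratios for the models quoted in the text.
# Values are (parameters in billions, training tokens in billions).
CHINCHILLA_RATIO = 20.0  # rule-of-thumb baseline, tokens per parameter

models = {
    "Llama 2 7B": (7.0, 2_000.0),
    "Llama 3 8B": (8.0, 15_000.0),
    "Mistral 7B": (7.0, 8_000.0),   # token count as quoted in the text
    "Phi-3-mini": (3.8, 3_300.0),
}


def tokens_per_param(params_b, tokens_b):
    """Training tokens divided by parameter count (both in billions)."""
    return tokens_b / params_b


for name, (params_b, tokens_b) in models.items():
    ratio = tokens_per_param(params_b, tokens_b)
    print(f"{name:12s} {ratio:7.0f} tok/param  "
          f"({ratio / CHINCHILLA_RATIO:5.1f}x the 20x rule)")
```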

Intuition: Training Is Paid Once, Inference Forever

Imagine you are running a restaurant. You can buy a giant 8-burner stove for $100 000, or a small 2-burner stove for $20 000. The big stove cooks faster — fewer prep hours per meal. The small stove is slower but cheaper. Which one wins?

It depends entirely on how many meals you plan to serve. If you cook a single banquet and close the restaurant forever, the big stove is foolish — you pay $100 000 to save twenty hours of prep. If you serve ten thousand meals a year for a decade, the big stove pays for itself in saved labour.

The mapping to models runs through the cost structure, not the physical size. The heavily trained small model plays the role of the expensive stove: you pay more up front (extra training tokens so that a small $N$ still reaches quality $L^{\star}$) in exchange for a lower cost on every meal served, because every inference token costs about $2N$ FLOPs. The training-optimal large model is the cheap stove: the least compute to reach $L^{\star}$, but more expensive on every token served. If you plan to serve trillions of tokens, you want the smallest $N$ that can still reach $L^{\star}$, and you compensate by training it on enormous amounts of data.

Why bigger isn't always better. A 70 B model has roughly nine times the per-token inference cost of an 8 B model. If both reach the same downstream quality, the 8 B model wins by a factor of nine on every served token, for as long as it is served. The extra training FLOPs spent over-training the 8 B model are a one-off cost, recovered once enough traffic has been served.
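To put a number on "once enough traffic has been served", the sketch below computes the served-token count at which a small model's extra training cost is repaid by its cheaper inference. It assumes the $6ND$ training and $2N$ per-token inference approximations and, importantly, that both models reach the same quality. The function name and the example pair are illustrative, not taken from any published recipe.

```python
def break_even_served_tokens(n_small, d_small, n_large, d_large):
    """Served tokens at which the small model's extra training cost is repaid.

    Assumes training costs 6*N*D FLOPs and inference costs 2*N FLOPs per token.
    Returns 0.0 if the small model is not more expensive to train.
    """
    saving_per_token = 2.0 * (n_large - n_small)
    if saving_per_token <= 0:
        raise ValueError("n_large must exceed n_small")
    extra_training = 6.0 * (n_small * d_small - n_large * d_large)
    return max(extra_training, 0.0) / saving_per_token


# Hypothetical pair: 8 B params on 15 T tokens vs 70 B params on 1.4 T tokens.
tokens = break_even_served_tokens(8e9, 15e12, 70e9, 1.4e12)
print(f"break-even after about {tokens:.2e} served tokens")
```

Under these assumptions the break-even point is on the order of a trillion served tokens, small next to the lifetime workloads discussed in this section.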

The Inference-Aware Cost Equation

We start from Chinchilla's loss law:

$$L(N, D) = E + \frac{A}{N^{\alpha}} + \frac{B}{D^{\beta}}$$

with $E \approx 1.69$, $\alpha \approx 0.34$, $\beta \approx 0.28$. We do not change this; the loss surface is what it is. What changes is what we minimise on top of it.

Training on $D$ tokens follows the standard $6N$ FLOPs-per-token rule (about $2N$ for the forward pass and $4N$ for the backward pass; activation recomputation, if used, adds more on top):

$$C_{\text{train}} = 6 \, N \, D$$

Inference does only one forward pass per token, with no backward pass and no optimiser update:

$$C_{\text{inf}} = 2 \, N \, D_{\text{inf}}$$

where $D_{\text{inf}}$ is the total number of tokens (input prompts + generated outputs) the model will serve over its lifetime, summed across every query, every user, every session, every redeployment. The lifetime cost we actually pay is:

$$C_{\text{total}}(N, D) = 6 N D + 2 N D_{\text{inf}}$$
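A quick consequence worth writing down: dividing the two cost terms shows when inference takes over. The ratio does not depend on $N$ at all, only on how many tokens are served relative to how many were trained on.

```latex
\frac{C_{\text{inf}}}{C_{\text{train}}}
  = \frac{2 \, N \, D_{\text{inf}}}{6 \, N \, D}
  = \frac{D_{\text{inf}}}{3 D}
\qquad\Longrightarrow\qquad
C_{\text{inf}} > C_{\text{train}} \iff D_{\text{inf}} > 3D
```

As an illustration, a model trained on 15 T tokens spends more FLOPs on inference than on training once it has served more than 45 T tokens.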

Constrained minimisation

We want a fixed quality $L^{\star}$ and we want to minimise $C_{\text{total}}$ over $(N, D)$. The constraint $L(N, D) = L^{\star}$ traces a curve in the $(N, D)$ plane, the iso-loss curve, and we slide along it looking for the lowest $C_{\text{total}}$.

Solve for $D$ from the iso-loss equation:

$$D(N, L^{\star}) = \left( \frac{B}{L^{\star} - E - A N^{-\alpha}} \right)^{1/\beta}$$

Substitute into the cost and take the derivative with respect to $N$. After some algebra (Sardana et al. 2024, eqn 5), the optimum satisfies:

$$\frac{\partial C_{\text{total}}}{\partial N} = 6 D + 6 N \frac{\partial D}{\partial N} + 2 D_{\text{inf}} = 0$$

The first two terms are exactly Chinchilla's training-only gradient; they vanish at the Chinchilla optimum. The third term, $2 D_{\text{inf}}$, is positive whenever the model is served at all. So at the Chinchilla point the total-cost gradient is positive, meaning $N$ is too large. We have to shrink $N$ until the negative $\partial D / \partial N$ term (a smaller model needs more data) balances the positive inference term.
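The sign argument can be checked numerically with a finite difference. This is a minimal sketch, assuming the constants quoted in this section ($A = 406.4$, $B = 410.7$ from the Hoffmann et al. fit) and using the fact that minimising $ND$ on the iso-loss curve requires $\alpha A N^{-\alpha} = \beta B D^{-\beta}$. The function names are illustrative.

```python
E, A, B = 1.69, 406.4, 410.7
ALPHA, BETA = 0.34, 0.28


def tokens_on_iso_loss(n, target):
    """D(N) on the iso-loss curve L(N, D) = target (raw params, raw tokens)."""
    residual = target - E - A * n ** (-ALPHA)
    if residual <= 0:
        raise ValueError("N is too small to reach the target loss")
    return (B / residual) ** (1.0 / BETA)


def total_cost(n, target, d_inf):
    """Lifetime FLOPs: 6*N*D(N) for training plus 2*N*D_inf for serving."""
    return 6.0 * n * tokens_on_iso_loss(n, target) + 2.0 * n * d_inf


def cost_slope(n, target, d_inf, rel_step=1e-4):
    """Central finite-difference estimate of dC_total/dN."""
    h = n * rel_step
    return (total_cost(n + h, target, d_inf)
            - total_cost(n - h, target, d_inf)) / (2.0 * h)


def training_only_optimum(target):
    """N minimising 6*N*D: the capacity term takes beta/(alpha+beta) of L* - E."""
    capacity_share = (target - E) * BETA / (ALPHA + BETA)
    return (A / capacity_share) ** (1.0 / ALPHA)


target = 2.00
n_chinchilla = training_only_optimum(target)
slope_train_only = cost_slope(n_chinchilla, target, d_inf=0.0)
slope_with_inference = cost_slope(n_chinchilla, target, d_inf=2e15)

print(f"N at training-only optimum: {n_chinchilla:.3e}")
print(f"slope with D_inf = 0      : {slope_train_only:.3e}")
print(f"slope with D_inf = 2e15   : {slope_with_inference:.3e}")
```

With no inference the slope is numerically zero; with inference it is positive and close to $2 D_{\text{inf}}$, so the cost-minimising $N$ must lie below the training-only optimum.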

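The shifted optimum

Define the inference-to-training token ratio at the optimum:

$$r = \frac{D_{\text{inf}}}{D^{\star}}$$

Differentiating the iso-loss constraint implicitly gives $\frac{\partial D}{\partial N} = -\frac{\alpha A N^{-\alpha}}{\beta B D^{-\beta}} \cdot \frac{D}{N}$. Substituting this into the stationarity condition and dividing by $6D$ yields the balance condition:

$$\frac{\alpha A \, (N^{\star})^{-\alpha}}{\beta B \, (D^{\star})^{-\beta}} = 1 + \frac{r}{3}$$

Two readings. First, when $r = 0$ (no inference), the condition reduces to $\alpha A N^{-\alpha} = \beta B D^{-\beta}$, which is the training-only (Chinchilla) optimum. Second, as $r$ grows, the right-hand side grows without bound, so the capacity term must claim an ever larger share of the loss budget $L^{\star} - E$: $N^{\star}$ shrinks toward the smallest model that can reach $L^{\star}$ at all, $D^{\star}$ grows, and the tokens-per-parameter ratio rises without bound. Because $r$ itself depends on $D^{\star}$, the condition is implicit and is solved numerically in practice. A product with $D_{\text{inf}} \gg D^{\star}$ therefore sits at hundreds of tokens per parameter or more, the regime where Llama 3 and Phi-3 operate.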
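One way to sanity-check a claimed optimum is to compare it against a brute-force search. Implicit differentiation of the iso-loss curve turns the stationarity condition into $\alpha A N^{-\alpha} / (\beta B D^{-\beta}) = 1 + D_{\text{inf}} / (3D)$; the sketch below finds the lifetime-cost minimum on a fine grid and checks that both sides agree there. It assumes the constants quoted in this section, and the helper names are illustrative.

```python
import math

E, A, B = 1.69, 406.4, 410.7
ALPHA, BETA = 0.34, 0.28


def tokens_on_iso_loss(n, target):
    """D(N) on the iso-loss curve, or None if N cannot reach the target."""
    residual = target - E - A * n ** (-ALPHA)
    if residual <= 0:
        return None
    return (B / residual) ** (1.0 / BETA)


def argmin_total_cost(target, d_inf, lo=1e8, hi=1e13, steps=20001):
    """Grid search for the N that minimises 6*N*D(N) + 2*N*D_inf."""
    best_n, best_cost = None, math.inf
    for i in range(steps):
        n = lo * (hi / lo) ** (i / (steps - 1))
        d = tokens_on_iso_loss(n, target)
        if d is None:
            continue
        cost = 6.0 * n * d + 2.0 * n * d_inf
        if cost < best_cost:
            best_n, best_cost = n, cost
    return best_n


def balance_terms(n, target, d_inf):
    """Left- and right-hand sides of the balance condition at a given N."""
    d = tokens_on_iso_loss(n, target)
    lhs = (ALPHA * A * n ** (-ALPHA)) / (BETA * B * d ** (-BETA))
    rhs = 1.0 + d_inf / (3.0 * d)
    return lhs, rhs


for d_inf in (0.0, 2e15):
    n_star = argmin_total_cost(2.00, d_inf)
    lhs, rhs = balance_terms(n_star, 2.00, d_inf)
    print(f"D_inf = {d_inf:.0e}: N* = {n_star:.3e}, lhs = {lhs:.3f}, rhs = {rhs:.3f}")
```

The two sides match to within the grid resolution, and the optimum with inference sits at a much smaller $N$ than the one without.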

Manual Numerical Walkthrough

Two configurations, same target loss, one pure research and one Llama-3-scale product. All numbers derived by hand from the equations above.

Chinchilla vs Llama-3-style at $L^{\star} = 2.00$

Setup. Fix the target loss at $L^{\star} = 2.00$ nats. Constants: $E = 1.69$, $A = 406.4$, $B = 410.7$, $\alpha = 0.34$, $\beta = 0.28$. The reducible loss budget on the iso-loss curve is $L^{\star} - E = 0.31$.

Configuration 1: training-only (Chinchilla) optimum. $D_{\text{inf}} = 0$ (the model is never served, just published). Minimising $6ND$ on the iso-loss curve requires $\alpha A N^{-\alpha} = \beta B D^{-\beta}$, which splits the budget as $A N^{-\alpha} = 0.31 \cdot \beta / (\alpha + \beta) = 0.14$ and $B D^{-\beta} = 0.17$. Inverting gives $N^{\star} = (406.4 / 0.14)^{1/0.34} \approx 1.5 \cdot 10^{10}$ (15 B parameters) and $D^{\star} = (410.7 / 0.17)^{1/0.28} \approx 1.2 \cdot 10^{12}$ (1.2 T tokens), about 80 tokens per parameter. With these fitted constants the parametric loss law implies a higher ratio than the headline 20× rule of thumb at this scale.

Verify the loss: $L = 1.69 + 0.14 + 0.17 = 2.00$.

Training cost: $C_{\text{train}} = 6 \cdot 1.5 \cdot 10^{10} \cdot 1.2 \cdot 10^{12} \approx 1.1 \cdot 10^{23}$ FLOPs. No inference. Lifetime cost = training cost.

Configuration 2: 8 B product point. $D_{\text{inf}} = 2 \cdot 10^{15}$ tokens (two quadrillion served, an assumed workload for a heavily used deployment over several years). At $N = 8$ B and $L^{\star} = 2.00$, the capacity term is $A \cdot (8 \cdot 10^{9})^{-0.34} \approx 0.175$. Residual loss budget for data: $0.31 - 0.175 = 0.135$. Required tokens: $D = (410.7 / 0.135)^{1/0.28} \approx 2.7 \cdot 10^{12}$, i.e. 2.7 T tokens. Llama 3 8B's actual 15 T tokens goes further still; under this fit the extra data buys a lower predicted loss (about 1.95) rather than the same one.

Tokens per parameter: $2700 / 8 \approx 340$, about 17× the 20× rule of thumb and about 4× the training-only optimum above.

Training cost: $6 \cdot 8 \cdot 10^{9} \cdot 2.7 \cdot 10^{12} \approx 1.3 \cdot 10^{23}$ FLOPs (about 20% more than Configuration 1; over-training is not free). Inference cost: $2 \cdot 8 \cdot 10^{9} \cdot 2 \cdot 10^{15} = 3.2 \cdot 10^{25}$ FLOPs (about 250× the training cost). Lifetime total: $\approx 3.2 \cdot 10^{25}$ FLOPs.

The counterfactual. Serve the 15 B training-optimal model on the same $2 \cdot 10^{15}$-token inference workload. Inference cost: $2 \cdot 1.5 \cdot 10^{10} \cdot 2 \cdot 10^{15} = 6.0 \cdot 10^{25}$ FLOPs, about 1.9× more than the 8 B over-trained model. Lifetime cost: $\approx 6.0 \cdot 10^{25}$ FLOPs.

Reading the numbers. Same loss target. Same inference workload. Training-optimal lifetime cost $6.0 \cdot 10^{25}$ FLOPs vs 8 B over-trained $3.2 \cdot 10^{25}$ FLOPs: the over-trained small model costs about half as much to own over its lifetime. Pushing $N$ lower still keeps reducing lifetime cost under this model, at the price of token counts that quickly exceed any available corpus. That gap, paid on every served token, is why production models are trained well past the Chinchilla point.
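Hand arithmetic with fractional exponents is easy to get wrong, so it is worth re-evaluating the formulas directly. The sketch below recomputes both configurations from the constants stated in the setup; the served-token count is the assumed workload from the walkthrough, and the helper names are illustrative.

```python
E, A, B = 1.69, 406.4, 410.7
ALPHA, BETA = 0.34, 0.28
TARGET = 2.00   # target loss in nats
D_INF = 2e15    # assumed lifetime served tokens


def tokens_for_loss(n):
    """Training tokens needed for an N-parameter model to reach TARGET."""
    residual = TARGET - E - A * n ** (-ALPHA)
    return (B / residual) ** (1.0 / BETA)


def lifetime_flops(n):
    """Return (D, training FLOPs, inference FLOPs, total FLOPs) for size n."""
    d = tokens_for_loss(n)
    train = 6.0 * n * d
    serve = 2.0 * n * D_INF
    return d, train, serve, train + serve


# Training-only optimum: the capacity term takes beta/(alpha+beta) of L* - E.
n_train_opt = (A / ((TARGET - E) * BETA / (ALPHA + BETA))) ** (1.0 / ALPHA)

rows = {"training-only optimum": n_train_opt, "8 B product point": 8e9}
results = {label: lifetime_flops(n) for label, n in rows.items()}

for label, n in rows.items():
    d, train, serve, total = results[label]
    print(f"{label:22s} N = {n:.2e}  D = {d:.2e}  tok/param = {d / n:6.0f}  "
          f"train = {train:.2e}  serve = {serve:.2e}  total = {total:.2e}")
```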

Visualizing the Shifted Optimum

The widget traces both cost curves as a function of $N$ at a fixed loss target. The amber curve is training cost only; its minimum is the Chinchilla point. The green curve adds inference. As you drag the inference-tokens slider, the green curve's minimum migrates left to smaller models, and the tokens-per-parameter ratio climbs past 100×, 200×, then beyond.

[Interactive widget: inference-aware scaling visualizer]

Three things to try in the widget. (1) Click Research model: the green and amber minima sit on top of each other; pure Chinchilla. (2) Click Llama-3-scale product: the green minimum jumps most of a decade to the left of amber, and the tokens/param ratio crosses 100×. (3) Slide the target loss tighter (smaller $L^{\star}$): both minima move to larger $N$ and higher cost. Quality always costs FLOPs, but once $D_{\text{inf}} > 0$ the inference-aware optimum has a lower lifetime cost than the Chinchilla point at every quality target.

Plain Python: Lifetime-Cost Optimisation

Here is the full search, written as a plain Python loop. No framework, no GPU: just the iso-loss algebra and a sweep over $N$. A calculation of this kind is a sensible first step before committing to a pre-training run.

lifetime-cost-search.py
Chinchilla scaling-law constants (E, A, B, ALPHA, BETA)

These are the fitted constants from Hoffmann et al. 2022. E is the irreducible loss, the floor no model can go below. A and B are the capacity and data scale constants. ALPHA and BETA are the exponents that say how fast loss falls as N and D grow. We reuse these unchanged from §9.1: the inference-aware law does NOT modify the loss surface, only the objective we minimise on top of it.

EXECUTION STATE
ALPHA = 0.34 (capacity exponent)
BETA = 0.28 (data exponent)

Loss for a candidate (N, D): loss()

Straight Chinchilla: predicted next-token cross-entropy in nats. The first term is the irreducible entropy of language; the second falls as N grows; the third falls as D grows. We will not minimise THIS. We hold it equal to a target L* and let (N, D) trade against each other on the iso-loss curve.

EXECUTION STATE
N = n_B × 10⁹ params
D = d_B × 10⁹ tokens

Solve D given a target loss and N: tokens_for_loss()

Algebra: L* = E + A·N^-α + B·D^-β ⇒ D = (B / (L* − E − A·N^-α))^(1/β). The capacity term eats part of the loss budget; whatever is left has to be paid in data. If the model is so small that the capacity term already exceeds L* − E, no amount of data fixes it, and we return None.

EXECUTION STATE
capacity = A · N^-α (loss from finite N)
residual = L* − E − capacity (loss budget left for D)
D = (B / residual)^(1/β), in raw tokens

Lifetime FLOPs = training + inference: lifetime_cost()

Training costs about 6·N FLOPs per token (roughly 2·N for the forward pass plus 4·N for the backward pass, the standard rule from Kaplan et al. 2020; activation recomputation adds more). Inference costs 2·N FLOPs per token (one forward pass only, with no backward pass and no optimiser update). Both scale linearly with N, but total inference cost 2·N·D_inf does not depend on the training-token count D. That asymmetry is the entire point of the section.

EXECUTION STATE
6 N D = training compute (FLOPs)
2 N D_inf = inference compute (FLOPs)

Sweep N at fixed L*, find lifetime-cost minimum: optimal()

We walk N from about 0.3 B up to about 3.2 T parameters in 1/40-decade log steps (about 6% per step). For each N, solve D from the iso-loss equation, then compute training cost (6·N·D, where the required D falls as N grows), inference cost (linear in N), and total cost. Keep the (N, D) with the smallest total. With D_inf = 0 the optimum is the training-only (Chinchilla) point. With D_inf huge, the inference term dominates and the optimum shifts toward small N, and from the iso-loss equation, smaller N forces enormous D.

LOOP TRACE · 3 iterations
Step k=0 (N ≈ 1 B)
infeasible at L* = 2.00: the capacity term alone exceeds L* − E, so the point is skipped
Step k=80 (N ≈ 100 B)
c_inf / c_train = D_inf / (3 D): with D_inf = 2 × 10¹⁵, inference dominates
Optimum
(N*, D*) = grid point nearest to where ∂c_total/∂N = 0

Two regimes, same target loss

Both calls pin the model to L* = 2.00 nats. The research call (D_inf = 0) returns the training-only optimum for these constants: N* ≈ 15 B at roughly 80 tokens/param (the parametric fit implies a higher ratio than the 20× rule of thumb at this scale). The product call (D_inf = 2 × 10¹⁵, i.e. 2 quadrillion served tokens) shifts the optimum to a much SMALLER N, about 2.4 B, with tens of thousands of tokens/param, on the order of 10¹⁴ training tokens. That is far more data than any real corpus offers, which is why practical recipes stop well short of the unconstrained optimum. Same loss, very different model, very different lifetime cost. That is the inference-aware bargain in two function calls.

EXECUTION STATE
Research (D_inf=0) = training-only optimum, N* ≈ 15 B, ~80 tok/param
Product (D_inf=2e15) = N* ≈ 2.4 B, tens of thousands of tok/param
```python
import math

# Chinchilla-fitted constants (Hoffmann et al. 2022).
E      = 1.69    # irreducible loss
A, B   = 406.4, 410.7
ALPHA  = 0.34    # capacity exponent
BETA   = 0.28    # data exponent

def loss(n_B, d_B):
    """Predict loss for N params (B) and D tokens (B)."""
    N, D = n_B * 1e9, d_B * 1e9
    return E + A * N**(-ALPHA) + B * D**(-BETA)

def tokens_for_loss(n_B, target):
    """Solve D from L* = E + A N^-a + B D^-b."""
    capacity = A * (n_B * 1e9)**(-ALPHA)
    residual = target - E - capacity
    if residual <= 0:
        return None                       # model too small even with infinite D
    return (B / residual)**(1 / BETA)     # D in raw tokens

def lifetime_cost(n_B, d_tokens, d_inf_tokens):
    """6ND training FLOPs + 2N*D_inf inference FLOPs."""
    N = n_B * 1e9
    return 6 * N * d_tokens + 2 * N * d_inf_tokens

def optimal(target, d_inf_tokens):
    """Sweep N at fixed L*, return (N*, D*, costs) minimising lifetime FLOPs."""
    best = None
    # log10(n_B) runs from -0.5 to 3.5 in steps of 1/40: ~0.3 B to ~3.2 T params.
    for k in range(-20, 141):
        n_B = 10 ** (k / 40.0)
        D = tokens_for_loss(n_B, target)
        if D is None:
            continue                      # infeasible: skip this model size
        c_train = 6 * n_B * 1e9 * D
        c_inf   = 2 * n_B * 1e9 * d_inf_tokens
        c_total = c_train + c_inf
        if best is None or c_total < best['c_total']:
            best = dict(n_B=n_B, d_B=D / 1e9,
                        c_train=c_train, c_inf=c_inf, c_total=c_total)
    return best

# Compare two regimes at the same target loss L* = 2.00
target = 2.00
research = optimal(target, d_inf_tokens=0)              # train once, never serve
product  = optimal(target, d_inf_tokens=2e15)           # 2 quadrillion served tokens

print(f"Training-only (no inference): N* = {research['n_B']:6.1f} B, "
      f"D* = {research['d_B']:,.0f} B, tokens/param = {research['d_B']/research['n_B']:,.0f}x")
print(f"Inference-aware (D_inf = 2e15): N* = {product['n_B']:6.1f} B, "
      f"D* = {product['d_B']:,.0f} B, tokens/param = {product['d_B']/product['n_B']:,.0f}x")
```

PyTorch: Searching the (N, D, D_inf) Grid

The same logic, vectorised across hundreds of candidate model sizes at once. This pattern shows up inside frontier labs' planning pipelines, where the search runs over many candidate architectures (dense, MoE, different attention variants) at every plausible quality target.

inference_aware_search.py
Why PyTorch for scaling-law search?

Scaling-law search is a tiny optimisation: a few hundred candidate Ns at fixed L*. PyTorch is overkill for the FLOP count, but it gives us three things for free: (1) vectorised arithmetic over the candidate grid via broadcasting; (2) GPU acceleration if we ever scale to millions of candidates; (3) autograd, so the same code can be plugged into a downstream loss that depends on the chosen (N*, D*). The same pattern scales to the planning sweeps a team would run before kicking off a multi-million-dollar pre-training job.

Wrapping the constants in tensors

Torch scalars participate in tensor arithmetic just like Python floats, but they carry dtype + device. Keeping ALPHA, BETA, E, A, B as tensors means loss_fn and tokens_for_loss return tensors, which we can broadcast over a whole grid of N candidates in one shot. No Python loop, no per-step overhead.

EXECUTION STATE
E, A, B = torch.tensor scalars
ALPHA, BETA = fitted exponents, dimensionless

The loss function as a tensor expression: loss_fn()

Identical to the Python version, but every operand is a tensor. If we pass N of shape (steps,) and D of shape (steps,), the output is shape (steps,): one loss per candidate, computed without a Python loop.

EXECUTION STATE
N = tensor of params, shape (steps,)
D = tensor of tokens, shape (steps,)
loss = predicted loss, shape (steps,)

Vectorised iso-loss solve with masking: tokens_for_loss()

residual goes non-positive when N is so small that the capacity term alone already exceeds L* − E. Raising a negative number to the power 1/β would produce NaN, so those entries have to be masked. The mask must make infeasible candidates look expensive, not cheap: D for those rows should be +inf, so that c_total is +inf and argmin skips them. Substituting a large sentinel for residual would do the opposite, because a large residual gives a tiny D. Masking is cleaner than try/except inside a vectorised pipeline.

EXECUTION STATE
residual = L* − E − A·N^-α
safe = residual with non-positive entries masked
D = (B / safe)^(1/β) tokens, +inf where infeasible

Building the candidate grid

torch.logspace gives 400 log-uniform Ns between 0.5 B and 2 T parameters. This is the search space: every realistic model size, from a small distilled student up to a frontier foundation model. We compute the iso-loss D and lifetime cost for ALL 400 simultaneously, then pick the argmin.

EXECUTION STATE
n_B = shape (400,) log-uniform
N = n_B × 10⁹, shape (400,)

The three cost tensors

c_train depends linearly on both N and D, but D itself depends on N through the iso-loss constraint, so c_train as a function of N alone falls steeply at small N (where a huge D is needed) and rises gently at large N. c_inf is exactly linear in N. c_total = c_train + c_inf has a U shape: training cost drives the left wall, inference cost steepens the right wall. The minimum sits where the two slopes cancel.

EXECUTION STATE
c_train = 6·N·D, shape (400,)
c_inf = 2·N·D_inf, shape (400,)
c_total = c_train + c_inf, shape (400,)

argmin picks the row

torch.argmin returns the integer index of the minimum entry. We pull out the (N, D, costs) for that index and return them as Python floats. The whole search is a few hundred elementwise operations, so its runtime is negligible next to any training run.

EXECUTION STATE
k = argmin index, integer

The headline sweep

Five workloads: pure research (no inference), early-stage deployment (10¹³), internal API (10¹⁴), product (10¹⁵), hyperscale (10¹⁶). Watch the tokens-per-parameter column climb from roughly 80 into the thousands and beyond as inference traffic grows, while N* falls toward the smallest model that can reach L* at all (about 1.5 B parameters for L* = 2.00). The unconstrained optimum quickly asks for more tokens than any corpus contains. The Llama 3 8B recipe (15 T tokens, 1875× tokens/param) sits between the Chinchilla point and that unconstrained optimum, consistent with the rationale in the Llama 3 paper for training small models well past the compute-optimal point.

LOOP TRACE · 4 iterations
D_inf = 0
tok/param ≈ 80 (training-only optimum)
D_inf = 1e14
tok/param ≈ 4 × 10³
D_inf = 1e15
tok/param ≈ 3 × 10⁴
D_inf = 1e16
tok/param ≈ 2 × 10⁵
```python
import math

import torch

# Chinchilla constants as float64 torch scalars, so the search can be
# vectorised (and differentiated) without float32 overflow on very large D.
DTYPE = torch.float64
E      = torch.tensor(1.69, dtype=DTYPE)
A, B   = torch.tensor(406.4, dtype=DTYPE), torch.tensor(410.7, dtype=DTYPE)
ALPHA, BETA = torch.tensor(0.34, dtype=DTYPE), torch.tensor(0.28, dtype=DTYPE)

def loss_fn(N, D):
    return E + A * N**(-ALPHA) + B * D**(-BETA)

def tokens_for_loss(N, target):
    residual = target - E - A * N**(-ALPHA)
    # Mask impossible (N, target) combinations. A placeholder residual of 1
    # avoids NaN; D is then set to +inf so argmin never selects those rows.
    feasible = residual > 0
    safe = torch.where(feasible, residual, torch.ones_like(residual))
    D = (B / safe) ** (1 / BETA)
    return torch.where(feasible, D, torch.full_like(D, float("inf")))

def search(target, d_inf_tokens, n_min_B=0.5, n_max_B=2000.0, steps=400):
    """Vectorised grid search over N at fixed L*; returns the best row."""
    # 400 candidate Ns spread log-uniformly between 0.5 B and 2 T.
    n_B = torch.logspace(math.log10(n_min_B), math.log10(n_max_B),
                         steps=steps, dtype=DTYPE)
    N = n_B * 1e9                                   # (steps,)
    D = tokens_for_loss(N, target)                  # (steps,)
    d_B = D / 1e9
    c_train = 6 * N * D                             # (steps,), +inf if infeasible
    c_inf   = 2 * N * d_inf_tokens                  # (steps,)
    c_total = c_train + c_inf                       # (steps,)
    k = torch.argmin(c_total).item()
    return dict(n_B=n_B[k].item(), d_B=d_B[k].item(),
                c_train=c_train[k].item(), c_inf=c_inf[k].item(),
                c_total=c_total[k].item())

# Sweep five inference workloads at L* = 2.00.
for d_inf_tokens in [0.0, 1e13, 1e14, 1e15, 1e16]:
    r = search(2.00, d_inf_tokens)
    tok_per_param = r['d_B'] / r['n_B']
    print(f"D_inf = {d_inf_tokens:9.0e} tok  ->  "
          f"N* = {r['n_B']:6.1f} B   D* = {r['d_B']:,.0f} B   "
          f"tok/param = {tok_per_param:,.0f}x")
```

At Massive Scale: Llama 3, Phi, and the Over-Training Era

Most open-weight foundation models released since 2023 are dramatically over-trained relative to Chinchilla. The pattern is not accidental; it is the direction the inference-aware optimum points.

| Model | Params | Training tokens | Tokens / param | Chinchilla ratio |
| --- | --- | --- | --- | --- |
| Chinchilla (DeepMind, 2022) | 70 B | 1.4 T | 20× | 1× (reference) |
| Llama 2 7B (Meta, 2023) | 7 B | 2 T | 286× | 14× |
| Mistral 7B (Mistral, 2023) | 7 B | ≈ 8 T | 1140× | 57× |
| Llama 3 8B (Meta, 2024) | 8 B | 15 T | 1875× | 94× |
| Phi-3-mini (Microsoft, 2024) | 3.8 B | 3.3 T | 870× | 43× |
| Llama 3 70B | 70 B | 15 T | 214× | 11× |
| DeepSeek-V3 (37 B activated) | 37 B act / 671 B tot | 14.8 T | 400× (act) | 20× (act) |

Read down the table. The small models are the most over-trained, and the more over-trained a model is, the more its lifetime cost is dominated by inference. Phi-3-mini at 3.8 B parameters is designed to run on a phone; its expected inference-to-training ratio is extreme, and Microsoft pushed its training to 3.3 T tokens to match. Llama 3 8B is small enough to serve on a single GPU; same logic, 15 T tokens.

What changes at scale

  1. Data quality becomes the bottleneck. Once you commit to $D > 10$ T tokens, you run out of pristine web text. Modern recipes shift heavily to synthetic data, curriculum filtering, and aggressive deduplication. The scaling law assumes IID tokens of fixed quality; at 1875× tokens/param, that assumption strains.
  2. Training cost grows faster than Chinchilla predicts. The 6N rule counts model FLOPs but ignores data-pipeline cost, checkpoint I/O, and failed runs. Over-training a small model on 15 T tokens is cheap in FLOPs, expensive in wall-clock and dataloader engineering.
  3. Inference workload is uncertain at training time. You commit to $N$ before you know $D_{\text{inf}}$. The decision is a bet on adoption. Llama 3 8B's 15 T tokens is the right answer if the model gets billions of inference calls; it is wasteful if the model never ships. This is why frontier labs train a family of sizes (8 B, 70 B, 405 B): to hedge across adoption outcomes.
  4. Distillation flips the equation. If a 405 B model produces synthetic data that trains an 8 B model, the large model's training is amortised across all downstream students. The inference workload that matters becomes the student's, not the teacher's. Modern LLM recipes are increasingly teacher-student pipelines optimised end-to-end on inference cost.
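The data bottleneck in point 1 can be folded into the search as a hard cap on available training tokens. The sketch below is one hedged way to do that, assuming the constants quoted in this section; the `d_max` parameter, the 15 T cap, and the function names are illustrative assumptions, not part of the original search.

```python
import math

E, A, B = 1.69, 406.4, 410.7
ALPHA, BETA = 0.34, 0.28


def tokens_for_loss(n, target):
    """D(N) on the iso-loss curve, or None if N cannot reach the target."""
    residual = target - E - A * n ** (-ALPHA)
    if residual <= 0:
        return None
    return (B / residual) ** (1.0 / BETA)


def optimal_with_data_cap(target, d_inf, d_max=math.inf,
                          lo=1e8, hi=1e13, steps=4001):
    """Lifetime-cost minimum over N, skipping sizes that need more than d_max tokens.

    Returns (N, D, total FLOPs), or None if no grid point is feasible.
    """
    best = None
    for i in range(steps):
        n = lo * (hi / lo) ** (i / (steps - 1))
        d = tokens_for_loss(n, target)
        if d is None or d > d_max:
            continue
        cost = 6.0 * n * d + 2.0 * n * d_inf
        if best is None or cost < best[2]:
            best = (n, d, cost)
    return best


uncapped = optimal_with_data_cap(2.00, d_inf=2e15)
capped = optimal_with_data_cap(2.00, d_inf=2e15, d_max=15e12)  # hypothetical 15 T cap

for label, (n, d, cost) in (("uncapped", uncapped), ("15 T cap", capped)):
    print(f"{label:9s} N* = {n / 1e9:5.2f} B  D* = {d / 1e12:7.1f} T  "
          f"total = {cost:.2e} FLOPs")
```

With the cap in place the search settles on the smallest model whose required token count still fits the corpus, which is larger than the unconstrained optimum and somewhat more expensive over its lifetime.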

Engineering Reality: When Chinchilla Still Wins

Inference-aware over-training is not always the right answer. Four regimes where Chinchilla optimality still rules:

| Regime | Why Chinchilla wins | Typical model |
| --- | --- | --- |
| Research checkpoint, never served | D_inf = 0, the inference term vanishes | Pre-publication ablations |
| Fine-tuning a base model | Base model already trained, the choice is downstream-only | Domain SFT on a pre-trained 70 B |
| Capacity-bound task (long-context, high recall) | Smaller N hits a quality floor no D can fix | Long-context 1M-token retrieval |
| Pre-training data is the bottleneck, not FLOPs | Cannot get more tokens at acceptable quality | Specialised domains: medical, legal |

And four engineering gotchas that the clean equation hides:

  1. The 6N and 2N rules are approximations. Real training is closer to 8N per token when activations are fully recomputed, and somewhere between 6N and 8N with selective recomputation; real inference is 2N plus attention FLOPs and KV-cache reads that grow with sequence length. At 128k context, attention FLOPs alone can double inference cost. The cleanest analysis treats the constants as functions of sequence length, not literals.
  2. Memory is not in the equation. A 70 B model needs ~140 GB of weights in BF16 alone. Inference cost in FLOPs is one thing; can you fit it on one GPU? Memory bandwidth, not FLOPs, is often the binding constraint at inference. The inference-aware law gives you the FLOP-optimal frontier; the memory-optimal frontier can sit elsewhere.
  3. Inference workload distribution matters. 2 × 10¹⁵ tokens spread across 10 ms-latency mobile chat is different from 2 × 10¹⁵ tokens in nightly batch jobs. The latter can use a larger model at higher utilisation; the former forces a smaller one regardless of training math. The scaling law is necessary but not sufficient.
  4. Quantisation changes the cost ratio. Serving an 8 B model in INT4 does not change the FLOP count, but it cuts weight memory and memory traffic by roughly 4× relative to BF16, and decoding is usually bandwidth-bound, so each served token gets cheaper. Cheaper inference shifts the inference-aware optimum back toward larger models and smaller over-training ratios. A recipe tuned around BF16 serving might prefer a slightly different (N, D) for an INT4-served deployment.
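The first two gotchas are easy to quantify roughly. The sketch below uses the forward-pass approximation from Kaplan et al. (2020), about $2N + 2 \, n_{\text{layer}} \, n_{\text{ctx}} \, d_{\text{attn}}$ FLOPs per token, and a plain bytes-per-parameter estimate for weight memory. The layer count and width are illustrative values for an 8 B-class model, not a published configuration, and the function names are hypothetical.

```python
def forward_flops_per_token(n_params, n_layer, d_attn, context_len):
    """Approximate forward FLOPs per token: 2N plus a context-dependent attention term."""
    return 2.0 * n_params + 2.0 * n_layer * context_len * d_attn


def weight_memory_gb(n_params, bytes_per_param):
    """Memory needed for the weights alone, in GB (10^9 bytes)."""
    return n_params * bytes_per_param / 1e9


# Illustrative 8 B-class shape (assumed): 32 layers, attention width 4096.
N_PARAMS, N_LAYER, D_ATTN = 8e9, 32, 4096

for context_len in (8_192, 131_072):
    flops = forward_flops_per_token(N_PARAMS, N_LAYER, D_ATTN, context_len)
    print(f"context {context_len:7d}: {flops:.2e} FLOPs/token "
          f"({flops / (2.0 * N_PARAMS):.2f}x the 2N rule)")

print(f"70 B weights in BF16: {weight_memory_gb(70e9, 2):.0f} GB")
print(f"70 B weights in INT4: {weight_memory_gb(70e9, 0.5):.0f} GB")
```

At short context the attention term is a small correction to $2N$; at 128k context it is larger than $2N$ itself for a model of this shape, which is why the constants are better treated as functions of sequence length.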
The deepest lesson. Chinchilla taught us that training compute is a constrained resource and the optimum is not where intuition ("just make N bigger") puts it. Sardana et al. taught us the same lesson at a higher level: lifetime compute is a constrained resource, and the optimum is again not where the previous generation's intuition ("just follow Chinchilla") puts it. Every time the cost frontier changes, the optimum moves. The next shift — likely driven by hardware accelerators with non-linear cost-per-FLOP, or by amortised distillation pipelines — will move it again.