Chapter 9

Predicting Final Loss from Intermediate Checkpoints

Scaling Laws and Compute-Optimal Training

Sections 9.1 through 9.4 took compute, parameters, and tokens as inputs and predicted final loss BEFORE training started. This section closes the loop. You have launched the run. A trillion-parameter MoE has been chewing through tokens for three days. You have five intermediate loss checkpoints. The question every frontier lab asks at this point is the same: given what we have seen so far, where will this run actually land?

The thesis of this section. Pretraining loss curves are dominated by a single mechanism, averaging gradients over a growing token count, and that mechanism gives the loss curve a nearly deterministic shape. Five well-spaced checkpoints fit a three-parameter power law to within roughly 0.02 nats of the final loss. That accuracy is enough to kill doomed runs early, repurpose underspent compute, and validate hyperparameters on dramatically smaller proxy runs, which makes the predictor the single highest-leverage piece of training infrastructure most labs own.

The Real Problem: A $5M Run You Cannot Restart

A frontier pretraining run on a 670B-parameter MoE costs roughly $5M of GPU time and 60 days of wall clock. The cluster is committed; the data pipeline has been prepared; the scaling-law analysis says the right answer is 14.8T tokens. You hit step 1, the loss falls from 11.0 to 4.0 in the first hour, and then it begins the long slow grind from 4.0 down toward whatever number it is going to end up at.

Three painful failure modes show up between day 1 and day 60:

| Failure mode | What it looks like at hour 24 | Cost if you do not catch it early |
| --- | --- | --- |
| Bad hyperparameters | Loss curve has the right shape but sits 0.05–0.15 nats above where the scaling law said it should. | 60-day run finishes with a model that is mediocre at every benchmark. ~$4M wasted. |
| Data pipeline regression | Loss is unusually noisy or has a small but persistent upward drift. | Subtle quality degradation invisible at the loss level but obvious at eval time. Re-run required. |
| Optimizer instability | Loss is fine until step 50k, then a single spike to 8.0 and never fully recovers. | If unnoticed for two days: ~$200k of grad-clip-saturated training that learns nothing. |

The naive policy is "wait until the run ends and look at the evals." That is the policy that turns a $5M GPU bill into a $5M lesson. The grown-up policy is to predict where the run is going every few thousand steps and act on the prediction: kill a run when it is going to miss the target, scale down a run when the prediction says you bought too much compute, and (most importantly) use the predictor to validate hyperparameters on tiny proxy runs before paying the full $5M.

Why this is the highest-leverage section in the chapter. Sections 9.1–9.4 told you how to PLAN a run. This section tells you how to RUN one without bankrupting yourself. A loss predictor with 0.02-nat accuracy turns a 60-day commitment into a 7-day decision point — the most expensive feedback loop in the company collapses by an order of magnitude.

Intuition: Pretraining Loss Is Almost Boringly Predictable

The result that makes this section possible is counterintuitive: the outer envelope of a healthy pretraining loss curve is extraordinarily smooth. There is per-batch noise (a single batch of tricky text can spike the loss by 0.05); there is learning-rate-schedule structure (a cosine decay carves a visible signature into the last 10% of training); there are even occasional loss spikes. But underneath all that, the loss is approaching its asymptote along the same shape that every healthy run produces:

\[ L(t) \approx L_\infty + \frac{A}{(t + t_0)^\alpha} \]

Here \(t\) is the number of training tokens consumed, \(L_\infty\) is the irreducible loss the model would approach with infinite training, \(A\) sets the magnitude of the "excess loss over the asymptote", and \(\alpha\) is the decay rate. The constant \(t_0\) is a small offset that absorbs warmup and keeps the curve finite at \(t = 0\).

Two intuitions justify the shape. First: the gradient the model receives at token \(t\) is, on average, the average gradient of the loss surface re-evaluated on a freshly drawn sample. The variance of that estimate falls like \(1/t\), so the residual loss above the asymptote falls polynomially in \(t\). Second: the same shape shows up on the parameter axis (Chinchilla's \(A/N^\alpha\) term in Section 9.1). One mechanism, a finite-capacity model averaging gradients over finite data, produces the same power-law signature on both axes.

The picture you should have in your head is a ball rolling into a bowl whose floor is the natural-text entropy. The ball decelerates as it nears the floor; the deceleration is smooth and uneventful; the rate of deceleration is what \(\alpha\) measures. Forecasting where the ball will be at \(t = 14.8 \text{ T}\) is a matter of measuring its position at a few early times and reading the smooth shape forward.

The Mathematics of Power-Law Loss Curves

We have three unknowns, \(L_\infty, \; A, \; \alpha\), and a stream of observations \(\{(t_i, L_i)\}_{i=1}^k\). The standard fit is non-linear least squares:

\[ (\hat L_\infty, \hat A, \hat \alpha) = \arg\min_{L_\infty, A, \alpha} \sum_{i=1}^{k} \left( L_i - L_\infty - \frac{A}{(t_i + t_0)^\alpha} \right)^2 \]

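scipy.optimize.curve_fit solves this with Levenberg–Marquardt when the problem is unconstrained and switches to the Trust Region Reflective method ('trf') as soon as bounds are supplied, because its Levenberg–Marquardt implementation does not accept bounds. Either way the solver converges in ten or so iterations when you seed it with a sensible initial guess and bound the parameters to physically plausible ranges. Three rules on the bounds: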

  1. \(L_\infty \geq L_\text{entropy}\). The irreducible loss cannot be below the per-token entropy of natural text; for English-dominant corpora the floor is \(\approx 1.5\) nats. Without this bound the fitter happily picks \(L_\infty = 0\) and a giant \(A\): wonderful residual on the observed points, catastrophically wrong far out.
  2. \(L_\infty \leq \min_i L_i\). The asymptote is below every loss you have observed, by definition.
  3. \(0.10 \leq \alpha \leq 0.60\). Outside this range you are fitting noise; values are remarkably consistent across architectures and tokenisers at \(\alpha \in [0.25, 0.40]\).

Once we have point estimates, we want a confidence interval on \(L_\text{pred} = L(T_\text{final})\). Linearise the model at the fitted parameters and propagate the parameter covariance matrix \(\Sigma\) through the Jacobian:

\[ \sigma^2_{L_\text{pred}} = J^\top \Sigma J, \quad J = \left( \frac{\partial L}{\partial L_\infty}, \frac{\partial L}{\partial A}, \frac{\partial L}{\partial \alpha} \right) \]

The three partial derivatives are easy:

\[ \frac{\partial L}{\partial L_\infty} = 1, \quad \frac{\partial L}{\partial A} = (t + t_0)^{-\alpha}, \quad \frac{\partial L}{\partial \alpha} = -A (t + t_0)^{-\alpha} \ln(t + t_0) \]

A practical observation about this Jacobian: \(\partial L / \partial L_\infty = 1\) regardless of where you predict, while the other two entries shrink toward zero as \(t\) grows large. The further out you predict, the more the uncertainty in \(L_\text{pred}\) is governed by the uncertainty in \(L_\infty\) itself, and in the limit \(t \to \infty\) the two coincide. \(L_\infty\) is exactly the parameter that early data pins down most weakly. This is why the prediction band widens as you extrapolate forward: not because the curve is more uncertain there, but because the asymptote is the parameter the data is shyest about.
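The propagation formula above can be sketched in a few lines of NumPy. This is a minimal illustration, not the section's production code: the function name prediction_sigma, the fitted parameters, and the diagonal covariance matrix are all made-up placeholders, and the offset t0 = 0.4 is assumed to be locked as in the rest of the section.

```python
import numpy as np

T0 = 0.4  # locked warmup offset in trillions of tokens (assumed)

def power_law(t, L_inf, A, alpha, t0=T0):
    """L(t) = L_inf + A * (t + t0)^(-alpha)."""
    return L_inf + A * (t + t0) ** (-alpha)

def prediction_sigma(t, popt, pcov, t0=T0):
    """1-sigma uncertainty of L(t) via first-order propagation J^T Sigma J."""
    L_inf, A, alpha = popt
    base = (t + t0) ** (-alpha)
    # Jacobian entries: dL/dL_inf, dL/dA, dL/dalpha
    J = np.array([1.0, base, -A * base * np.log(t + t0)])
    var = float(J @ pcov @ J)
    return float(np.sqrt(max(var, 0.0)))

# Hypothetical fit output, for illustration only (not a result from the text).
popt = np.array([1.95, 1.00, 0.35])
pcov = np.diag([0.02, 0.10, 0.03]) ** 2   # uncorrelated 1-sigma values, squared

for t in (6.0, 14.8, 100.0):
    print(f"t={t:6.1f}  L={power_law(t, *popt):.4f}  sigma={prediction_sigma(t, popt, pcov):.4f}")
```

A real covariance matrix from a fit has off-diagonal terms (the asymptote and the amplitude are usually strongly correlated), so pass the full pcov rather than a diagonal approximation.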

Why this fit beats a quadratic in \(\log t\)

A common ad-hoc alternative is to fit a quadratic in \(\log t\): \(L \approx a + b \log t + c (\log t)^2\). It looks smooth on a log-x plot and the fit is linear (just least squares). Two reasons it fails:

  1. No asymptote. The quadratic diverges as \(t \to \infty\): to \(+\infty\) if the fitted \(c\) is positive, to \(-\infty\) if it is negative. For short extrapolation horizons it does not matter, but at the scale where you are deciding whether to spend $5M, an extrapolation method that has the wrong limit is a footgun.
  2. No physical meaning. The power-law parameters \((L_\infty, A, \alpha)\) are interpretable: you can compare \(\alpha\) across architectures, debug runs by reading the parameter trajectory, and refuse fits that exit the physical range. The quadratic coefficients are just numbers.
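The missing asymptote is easy to see numerically. The sketch below builds a synthetic curve from a known power law (the values 2.0, 1.0, 0.5 are arbitrary assumptions, not data from this section), fits the log-quadratic with np.polyfit, and compares both models far beyond the observed range.

```python
import numpy as np

# Synthetic "healthy" curve with a known asymptote (all values are assumptions).
L_INF, A, ALPHA, T0 = 2.0, 1.0, 0.5, 0.4
t_obs = np.array([0.5, 1.2, 2.6, 4.0, 6.0])          # trillions of tokens
L_obs = L_INF + A * (t_obs + T0) ** (-ALPHA)

# Ordinary least squares for L ~ a + b*log(t) + c*log(t)^2.
# np.polyfit returns coefficients from the highest degree down.
c, b, a = np.polyfit(np.log(t_obs), L_obs, deg=2)

def quadratic(t):
    x = np.log(t)
    return a + b * x + c * x * x

def truth(t):
    return L_INF + A * (t + T0) ** (-ALPHA)

for t in (6.0, 14.8, 1e3, 1e6):
    print(f"t={t:>10.1f}  quadratic={quadratic(t):8.3f}  power law={truth(t):6.3f}")
```

Both models agree closely on the observed points; the disagreement only appears in the extrapolation, which is exactly where the decision is made.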

Manual Numerical Walkthrough

Let us fit a power law to five checkpoints by hand and predict the final loss at the planned horizon. Numbers chosen to be realistic for a frontier MoE pretraining run.


Step 1 — the five observations. Tokens seen (trillions) and per-token cross-entropy:

i   t_i (T tokens)   L_i (nats)
1        0.5            3.04
2        1.2            2.65
3        2.6            2.41
4        4.0            2.30
5        6.0            2.22

We will use the locked offset \(t_0 = 0.4 \text{ T}\) throughout.

Step 2: fix two knobs, solve the third in closed form. For any choice of \(L_\infty\) and \(\alpha\), define \(x_i = (t_i + t_0)^{-\alpha}\). The residual after subtracting the asymptote is \(r_i = L_i - L_\infty\) and we want \(A x_i \approx r_i\). Least squares gives:

\[ \hat A(L_\infty, \alpha) = \frac{\sum_i r_i \, x_i}{\sum_i x_i^2} \]

This collapses a 3D optimisation into a 2D grid search, which we can do on paper.

Step 3: try the textbook seed. Take \(L_\infty = 1.95\) and \(\alpha = 0.32\). Compute \(x_i = (t_i + 0.4)^{-0.32}\):

i   t_i + t_0   ln(t_i+t_0)   -0.32 * ln   x_i = exp(...)
1     0.9        -0.1054        +0.0337     1.0343
2     1.6        +0.4700        -0.1504     0.8604
3     3.0        +1.0986        -0.3516     0.7035
4     4.4        +1.4816        -0.4741     0.6224
5     6.4        +1.8563        -0.5940     0.5521

Step 4: compute the residuals and \(\hat A\).

i   L_i     r_i = L_i - 1.95     r_i * x_i      x_i^2
1   3.04        1.090              1.1274        1.0698
2   2.65        0.700              0.6023        0.7403
3   2.41        0.460              0.3236        0.4949
4   2.30        0.350              0.2178        0.3874
5   2.22        0.270              0.1491        0.3048
                                   ------        ------
                  sum:             2.4202        2.9972

A_hat = 2.4202 / 2.9972 = 0.8075   <- too small; the seed is the issue

That \(\hat A\) sits well below the seed value of 4.5. The fit residual at \(i = 1\) would be \(L_1 - 1.95 - 0.81 \cdot 1.03 = 0.26\), far too large. The pair \((L_\infty, \alpha) = (1.95, 0.32)\) is wrong for this dataset.
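Steps 2 to 4 can be reproduced in a few lines, which is a convenient way to check the hand arithmetic. The helper name best_A is hypothetical; the data and the locked offset are the ones from Step 1.

```python
import numpy as np

t_obs = np.array([0.5, 1.2, 2.6, 4.0, 6.0])          # trillions of tokens (Step 1)
L_obs = np.array([3.04, 2.65, 2.41, 2.30, 2.22])     # nats (Step 1)
T0 = 0.4                                             # locked offset

def best_A(L_inf, alpha):
    """Closed-form least-squares A for a fixed (L_inf, alpha), plus the fit RMSE."""
    x = (t_obs + T0) ** (-alpha)          # x_i = (t_i + t0)^(-alpha)
    r = L_obs - L_inf                     # r_i = L_i - L_inf
    A_hat = float(np.sum(r * x) / np.sum(x * x))
    rmse = float(np.sqrt(np.mean((r - A_hat * x) ** 2)))
    return A_hat, rmse

A_hat, rmse = best_A(L_inf=1.95, alpha=0.32)
print(f"A_hat={A_hat:.4f}  rmse={rmse:.4f}")
```

Printing the RMSE alongside the amplitude makes a poor (L_inf, alpha) pair obvious without inspecting individual residuals.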

Step 5: sweep the 2D grid. Repeat steps 3–4 for \(L_\infty \in \{1.85, 1.90, 1.95, 2.00, 2.05\}\) and \(\alpha \in \{0.25, 0.30, 0.35, 0.40\}\). The cell with the smallest sum of squared residuals is the corner \(L_\infty = 1.85, \; \alpha = 0.40, \; \hat A \approx 0.99\), with a residual RMSE of about 0.10 nats. Two warning signs are visible immediately: the RMSE is far above the 0.02-nat accuracy we are after, and the best cell sits on the edge of the grid, which means the grid does not bracket the minimum.

Step 6: extrapolate anyway. Plug \(T_\text{final} = 14.8 \text{ T}\) into that corner cell:

\[ L_\text{pred} = 1.85 + 0.99 \cdot (14.8 + 0.4)^{-0.40} \]

\[ (15.2)^{-0.40} = \exp(-0.40 \cdot \ln 15.2) = \exp(-0.40 \cdot 2.7213) = \exp(-1.0885) = 0.3367 \]

\[ L_\text{pred} = 1.85 + 0.99 \cdot 0.3367 \approx 2.18 \]

That is only 0.04 nats below the last observed point of 2.22, after 8.8 T more tokens. The number is not credible: the same cell overshoots the checkpoint at \(t = 6.0\) T by roughly 0.10 nats (it gives 2.32 against an observed 2.22), and the extrapolation inherits that error.

Step 7: what the manual walk teaches. Manual grids are useful for understanding the mechanism but unreliable for the actual numbers. The fit lives in a narrow basin where small errors in \((L_\infty, \alpha)\) compound into a badly wrong \(L_\text{pred}\). The right tool is a bounded non-linear least-squares solver with the three bound rules from the mathematics section above, which is what the Python below does in one call. Solving the bounded problem for these five points pins \(\alpha\) at its upper bound of 0.60, with \(L_\infty \approx 1.84\) and \(A \approx 1.11\) (RMSE \(\approx 0.016\) nats), and gives \(L_\text{pred} \approx 2.06\).

Step 8: decide. Against a target of \(1.96 \pm 0.02\), the predicted \(L_\text{pred} \approx 2.06\) sits about 0.10 nats above the target, so the decision rule returns STATUS: KILL. An \(\alpha\) stuck on its bound is itself a reason to diagnose the fit (check \(t_0\), warmup points, and the LR schedule) before acting on the verdict. The same arithmetic applied at \(t = 1\) T (only the first two checkpoints in hand) would give a far less trustworthy prediction; the uncertainty on \(L_\infty\) is wide that early.
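The grid sweep of Step 5 and the extrapolation of Step 6 can be sketched as follows. The helper fit_cell and the on_edge flag are illustrative names, not part of the section's scripts; the flag simply reports whether the winning cell lies on the border of the grid, which is a cheap hint that the grid does not bracket the minimum.

```python
import numpy as np

t_obs = np.array([0.5, 1.2, 2.6, 4.0, 6.0])          # trillions of tokens (Step 1)
L_obs = np.array([3.04, 2.65, 2.41, 2.30, 2.22])     # nats (Step 1)
T0, T_FINAL = 0.4, 14.8

def fit_cell(L_inf, alpha):
    """Closed-form A for one grid cell and the resulting sum of squared residuals."""
    x = (t_obs + T0) ** (-alpha)
    r = L_obs - L_inf
    A_hat = float(np.sum(r * x) / np.sum(x * x))
    sse = float(np.sum((r - A_hat * x) ** 2))
    return A_hat, sse

L_grid = [1.85, 1.90, 1.95, 2.00, 2.05]
alpha_grid = [0.25, 0.30, 0.35, 0.40]

cells = []
for L_inf in L_grid:
    for alpha in alpha_grid:
        A_hat, sse = fit_cell(L_inf, alpha)
        cells.append((sse, L_inf, alpha, A_hat))

# Tuples sort by their first element, so min() picks the smallest SSE.
sse, L_best, alpha_best, A_best = min(cells)
rmse = (sse / len(t_obs)) ** 0.5
L_pred = L_best + A_best * (T_FINAL + T0) ** (-alpha_best)
on_edge = (L_best in (L_grid[0], L_grid[-1])) or (alpha_best in (alpha_grid[0], alpha_grid[-1]))

print(f"best cell: L_inf={L_best}, alpha={alpha_best}, A={A_best:.3f}, rmse={rmse:.3f}")
print(f"L_pred({T_FINAL} T) = {L_pred:.3f}   best cell on grid edge: {on_edge}")
```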

Visualizing the Extrapolation

The interactive below runs three different stylised pretraining runs. Slide the "Training observed" bar to control how far into the run the predictor has seen. The emerald solid line is what the predictor has seen; the emerald dashed line is the true future the predictor has not seen; the blue line is the predictor's extrapolation, with its \(\sim 2\sigma\) confidence band. The numeric error at the top right is the gap between the predicted and actual final losses.

[Interactive chart: loss extrapolation sandbox]

Three things to read out of the sandbox. First: on the healthy run, the extrapolation locks onto the true final loss after roughly 15% of training, about 2 T tokens of a 14.8 T run. Before that the confidence band is huge; after it the band collapses and the prediction is solid. Second: on the noisy run the SAME true curve produces a much wider confidence band early; the predictor is honest about being less sure. It still converges, just later. Third: the phase-change run is the cautionary tale. A late LR-decay cliff knocks the final loss 0.07 below the smooth-extrapolation prediction. The predictor confidently misses, because the data it has seen does not contain the cliff. This is exactly the case where the engineer must augment the predictor with knowledge of the LR schedule; see "engineering reality" below.

Plain Python: Fitting a Power Law to Five Checkpoints

Below is the canonical offline version of the predictor. Five checkpoints, three fitted parameters, a single decision at the end. This is the script frontier labs run at their morning standup over the previous night's checkpoints.

🐍loss_predictor_offline.py
Line 4: Five checkpoints are usually enough

The first non-obvious result of frontier-scale runs: five well-spaced checkpoints fit a three-parameter power law to within roughly 0.02 nats of the eventual loss. The reason is that the loss curve is dominated by one mechanism - the model is averaging gradients over an ever-growing token count - so the shape is nearly deterministic and three parameters capture it.

EXECUTION STATE
len(t_obs) = 5
t_obs = [0.5, 1.2, 2.6, 4.0, 6.0] T
L_obs = [3.04, 2.65, 2.41, 2.30, 2.22]
Line 9: Parametric model: three knobs

L_inf is the irreducible loss - the entropy of natural text the model would converge to with infinite data. A scales how much loss is in excess of that floor at unit time. alpha is the decay rate: bigger alpha = faster early progress, slower late progress. This is the same functional form Hoffmann et al. (2022) fit to the Chinchilla family, just applied along the TIME axis rather than the parameter-count axis.

Line 11: t0 absorbs warmup

Without t0, the model says L(0) = infinity - a singularity that the optimizer hates and that misrepresents the early-warmup loss curve. A small fixed t0 (0.2-0.5 T tokens) softens the very early part of the curve. You do NOT fit t0 - fitting it adds a free degree of freedom that masks real model divergence. Lock it.

EXECUTION STATE
t0 = 0.4
Line 17: Initial guess matters

Power-law fits have a shallow basin around the true minimum but several local minima far from it (especially the degenerate solution L_inf approx L_obs.mean(), A approx 0). Seed L_inf a little above the natural-text entropy floor (the script uses 1.9 against a floor of ~1.5 nats), alpha in [0.25, 0.40], and A near the excess loss L - L_inf at t + t0 = 1. With a good seed the solver converges in ~10 iterations.

EXECUTION STATE
p0 = (1.9, 4.5, 0.30)
Line 18: Bounds keep the fit physical

L_inf cannot go below the natural-text entropy floor (~1.5 nats for English) or above the lowest loss already observed (2.22 nats here; the script uses a looser cap of 2.5). alpha is positive and bounded above by ~0.6 (anything larger is fitting noise). A is bounded away from zero. Without bounds, curve_fit happily returns a degenerate solution such as L_inf = 0, A = 1000, alpha = 0.001, with tiny residual and wildly wrong extrapolation.

Line 21: curve_fit returns covariance, not just point estimate

pcov is the 3x3 covariance matrix of the fitted parameters. Its diagonal is parameter variance; the square roots are the 1-sigma uncertainties. We will use sigma to build a confidence interval on the predicted final loss, not just a point. A run with high sigma on L_inf is one where the data has not yet pinned down the asymptote - extrapolation will be unreliable.

EXECUTION STATE
popt ≈ (1.84, 1.11, 0.60)   (alpha pinned at its upper bound)
sigma = np.sqrt(np.diag(pcov))   (unreliable for alpha while it sits on the bound)
Line 27: Predict the planned horizon

The whole point: plug T_final = 14.8 T (DeepSeek-V3's planned budget) into the fitted model and read off L_pred. The model says 'if I keep training in the same regime, I will land here.' If L_pred misses the target by more than your tolerance, you do not wait; you kill the run.

EXECUTION STATE
T_final = 14.8
L_pred ≈ 2.06
Line 29: Residual RMSE is the trust meter

The standard deviation of the fitted-vs-observed residuals tells you how well the power law actually describes THIS run's data. If RMSE > 0.05 nats with 5+ checkpoints, your run is not following a smooth power law - likely you have a spike, an LR-schedule artifact, or a data-pipeline issue. Diagnose before trusting any prediction.

Line 36: Three-way decision rule

Predicted within tolerance of target: continue. Predicted above tolerance: kill; the run is going to disappoint, and every extra hour of GPU time is wasted. Predicted below tolerance: you over-budgeted; you can either shrink the run early or use the spare compute on a parallel ablation. Frontier labs apply this rule at every checkpoint, not just the end.

EXECUTION STATE
TARGET = 1.96
TOL = 0.02
import numpy as np
from scipy.optimize import curve_fit

# Five intermediate checkpoints from a real-looking pretraining run.
# t is in trillions of tokens consumed; L is per-token cross-entropy.
t_obs = np.array([0.5, 1.2, 2.6, 4.0, 6.0])     # T tokens
L_obs = np.array([3.04, 2.65, 2.41, 2.30, 2.22])

# Parametric model: L(t) = L_inf + A * (t + t0)^(-alpha)
# t0 is a small offset that absorbs warmup and avoids the t=0 singularity.
def power_law(t, L_inf, A, alpha):
    t0 = 0.4
    return L_inf + A * (t + t0) ** (-alpha)

# Initial guess: irreducible loss near the natural-text entropy floor,
# A and alpha in the empirically-stable range observed by Hoffmann et al.
p0 = (1.9, 4.5, 0.30)
bounds = ([1.5, 0.1, 0.10], [2.5, 25.0, 0.60])

# With bounds supplied, curve_fit uses the Trust Region Reflective ('trf') solver.
popt, pcov = curve_fit(power_law, t_obs, L_obs, p0=p0, bounds=bounds)
L_inf_hat, A_hat, alpha_hat = popt
sigma = np.sqrt(np.diag(pcov))            # 1-sigma parameter uncertainty

# Predict the final loss at the planned training horizon T = 14.8 T tokens.
T_final = 14.8
L_pred = power_law(T_final, *popt)
residuals = L_obs - power_law(t_obs, *popt)
rmse = float(np.sqrt(np.mean(residuals ** 2)))

print(f"Fitted: L_inf={L_inf_hat:.3f}  A={A_hat:.3f}  alpha={alpha_hat:.3f}")
print(f"Param sigma: {sigma.round(3)}   residual RMSE on fit: {rmse:.4f}")
print(f"Predicted L({T_final} T) = {L_pred:.4f}")

# Decision rule: a run is on-track iff the predicted final loss is within
# 0.02 nats of the design target. Otherwise: kill, re-tune, restart.
TARGET, TOL = 1.96, 0.02
if abs(L_pred - TARGET) <= TOL:
    print("STATUS: ON TRACK")
elif L_pred > TARGET + TOL:
    print("STATUS: KILL - projected final loss above tolerance")
else:
    print("STATUS: UNDERSPENT - could shrink the run and still hit target")

Two structural details deserve a second look. First, the initial guess and bounds on lines 17–18 are doing more work than they look like they are: without them, curve_fit will happily pick a degenerate solution with \(L_\infty = 0\) and a giant \(A\), fitting your data well and predicting wildly wrong final loss. Bounds are the difference between a production predictor and a toy. Second, lines 37–43 are the entire user interface of the system: three numbers (target, tol, T_final) in, one verdict string out. Everything else is plumbing.

Sanity-check yourself. Run the script with the target set to 1.5 (impossibly low) and confirm the verdict is KILL. Then set target to 2.5 (already easily hit) and confirm the verdict is UNDERSPENT. If you do not get those two verdicts, your decision logic is broken, which is far more dangerous than a bad fit.
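One way to make that sanity check repeatable is to pull the decision rule into a small function and exercise it directly. This is a sketch; the function name verdict and the predicted loss of 2.00 are placeholders chosen for illustration, not values from the run above.

```python
def verdict(L_pred: float, target: float, tol: float) -> str:
    """Three-way decision rule on the predicted final loss."""
    if abs(L_pred - target) <= tol:
        return "ON_TRACK"
    if L_pred > target + tol:
        return "KILL"
    return "UNDERSPENT"

L_pred = 2.00   # hypothetical prediction, for illustration only
for target in (1.5, 2.0, 2.5):
    print(f"target={target}: {verdict(L_pred, target, tol=0.02)}")
```

Keeping the rule in one pure function means the same code path can be unit-tested offline and called from the live predictor.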

PyTorch: A Live Predictor in the Training Loop

The offline script is fine for morning standups. At frontier scale you want the predictor running inside the training loop, emitting a verdict every few thousand steps, and able to pull the kill switch by itself. Below is the production pattern: a thin CPU-only LossPredictor class, an update() call from every logging step, and a predict() + verdict() call from every refit step.

🐍loss_predictor_live.py
Line 7: A predictor that lives next to the training loop, not inside it

LossPredictor is a tiny class that consumes (tokens_seen, loss) pairs and emits a refit on demand. It owns no gradients, no model state, no GPU - it is pure CPU bookkeeping. This is the correct factoring: the training loop already has enough on its plate, and the predictor must NOT block the GPU.

Line 12: Three knobs are part of the contract, not the run

T_final is the budget you committed to (14.8 T tokens for DeepSeek-V3). target is the loss you promised the eval team or your investors. tol is how far off-target the predictor is allowed to be before it screams. These three numbers should appear in the experiment plan BEFORE training starts; finalising them mid-run is how you accidentally talk yourself out of a kill verdict.

EXECUTION STATE
T_final = 1.48e13 tokens
target = 1.96 nats
tol = 0.02 nats
Line 25: update is non-blocking; it just appends

We log a loss point every 100 steps (cheap), but we only run curve_fit every 1000 steps (expensive - ~50ms on CPU). Decoupling these two lets us keep dense data for diagnostics while paying the optimization cost rarely.

Line 30: predict() refuses to extrapolate from too few points

Power-law fits with fewer than 5 points are nearly degenerate - the optimizer can fit a flat line, a hockey stick, or the true curve with similar residuals. Return None and let the training loop know the predictor is still warming up. Anyone who skips this check eventually writes a postmortem about a 'kill' verdict triggered after the second checkpoint.

Line 36: Wrap curve_fit in try/except

The least-squares solver can fail to converge if the data is too noisy or bouncing around an LR transition. When that happens, return None and skip the verdict this round. NEVER let a numerical hiccup in the predictor crash the training loop; the GPU is the expensive part.

Line 52: Three verdicts map to three actions

ON_TRACK: keep going. KILL: stop NOW, save a forensic checkpoint, page the on-call. UNDERSPENT: tell the experiment owner you could either trim the run or repurpose the spare compute on a parallel ablation. WARMING_UP: silence. This four-state machine is the entire user interface of the predictor.

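Line 69: tokens_per_step is the unit of progress

Tokens-seen is the canonical x-axis for loss curves, not optimizer steps. With B=4 * seq_len=8192 * H_used=1024, each step ingests ~33M tokens. The factor labelled H_used has to count parallel sequences (for example data-parallel workers), since the number of active MoE experts does not change how many tokens a step consumes. The loop computes tokens_seen as step * tokens_per_step, which is correct only for a constant batch size; with a batch-size warmup schedule, accumulate the actual per-step token count instead so the x-axis stays 'tokens fed to the model'.

EXECUTION STATE
tokens_per_step = 4 * 8192 * 1024 ~ 3.3e7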
Lines 71–76: The training step itself is unchanged

Forward, cross-entropy, backward, clip, step, zero. The predictor is a passive observer - it would be removed without changing a single line of the gradient code. This is critical: the loss predictor is an EXPERIMENT-MANAGEMENT tool, not a training innovation, and mixing the two is how monitoring code starts mutating gradient code.

Line 88: The kill switch is the highest-value line in the file

A confident KILL verdict on a 1T-parameter run saves $1-5M of GPU time. The kill path saves a forensic checkpoint (so you can debug post-mortem), then raises SystemExit. Most labs wrap this in an 'are you sure?' SMS to the on-call before pulling the trigger - but the principle is unchanged: predicted final loss above tolerance is an automatic STOP, not a 'wait and see'.

Line 86: Verdict logging is the audit trail

Every refit gets a one-line log: step, current loss, predicted final loss, fitted alpha, RMSE, verdict. After the run you have a CSV of how the prediction evolved over training - invaluable for tuning the predictor's hyperparameters (tol, REFIT_EVERY) and for explaining to a finance team why you killed a run at 18% completion.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.optimize import curve_fit

class LossPredictor:
    """Online power-law extrapolator. Fed one (tokens_seen, loss) pair
    per logging step; emits a refit + 14.8T-token prediction every K
    steps. Stateless w.r.t. the model - it only sees the loss stream."""

    def __init__(self, T_final=14.8e12, target=1.96, tol=0.02, t0=0.4e12):
        self.t_hist: list[float] = []
        self.L_hist: list[float] = []
        self.T_final = T_final
        self.target = target
        self.tol = tol
        self.t0 = t0
        self.last_pred = None

    @staticmethod
    def _power_law(t, L_inf, A, alpha, t0):
        return L_inf + A * (t + t0) ** (-alpha)

    def update(self, tokens_seen: float, loss: float) -> None:
        self.t_hist.append(tokens_seen)
        self.L_hist.append(loss)

    def predict(self) -> dict | None:
        if len(self.t_hist) < 5:
            return None
        # Fit in trillions of tokens so A, t0, and the bounds match the offline script.
        t = np.asarray(self.t_hist, dtype=np.float64) / 1e12
        L = np.asarray(self.L_hist, dtype=np.float64)
        f = lambda t, L_inf, A, alpha: self._power_law(t, L_inf, A, alpha, self.t0 / 1e12)
        try:
            popt, pcov = curve_fit(
                f, t, L,
                p0=(1.9, 4.5, 0.30),
                bounds=([1.5, 0.1, 0.10], [2.5, 25.0, 0.60]),
                maxfev=200,
            )
        except Exception:
            return None
        L_pred = float(f(self.T_final / 1e12, *popt))
        rmse = float(np.sqrt(np.mean((L - f(t, *popt)) ** 2)))
        self.last_pred = L_pred
        return dict(L_pred=L_pred, L_inf=float(popt[0]),
                    A=float(popt[1]), alpha=float(popt[2]),
                    rmse=rmse, n_obs=len(t))

    def verdict(self) -> str:
        if self.last_pred is None:
            return "WARMING_UP"
        if abs(self.last_pred - self.target) <= self.tol:
            return "ON_TRACK"
        if self.last_pred > self.target + self.tol:
            return "KILL"
        return "UNDERSPENT"


# Wire into the training loop (build_transformer, loader, save_diagnostic_checkpoint are your own code).
predictor = LossPredictor(T_final=14.8e12, target=1.96, tol=0.02)
model = build_transformer().cuda()                       # (1T params, MoE)
optim = torch.optim.AdamW(model.parameters(), lr=2e-4)

LOG_EVERY = 100        # log every 100 steps
REFIT_EVERY = 1000     # refit every 1000 steps
tokens_per_step = 4 * 1024 * 8192                        # B * SEQ * H_used

for step, batch in enumerate(loader):
    logits = model(batch.input_ids.cuda())
    loss = F.cross_entropy(logits.flatten(0, 1), batch.labels.flatten().cuda())
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optim.step(); optim.zero_grad()

    if step % LOG_EVERY == 0:
        tokens_seen = step * tokens_per_step
        predictor.update(tokens_seen, loss.item())

    if step > 0 and step % REFIT_EVERY == 0:
        info = predictor.predict()
        verdict = predictor.verdict()
        if info is not None:
            print(f"step={step} L={loss.item():.3f} L_pred={info['L_pred']:.4f} "
                  f"alpha={info['alpha']:.3f} rmse={info['rmse']:.4f} -> {verdict}")
            if verdict == "KILL":
                save_diagnostic_checkpoint(model, info)
                raise SystemExit("Predicted final loss above tolerance - kill switch triggered.")

Three points about how this pattern interacts with the rest of the training stack:

  1. The predictor is CPU-only and never blocks the GPU. update() is two list appends. predict() runs curve_fit once per 1000 steps, about 50 ms on a single CPU core, against a training step that takes 2–10 seconds. If your refit ever shows up in the GPU profile, you have a bug, not a tradeoff.
  2. Logging cadence and refit cadence are decoupled. We log loss every 100 steps (dense, for diagnostics) but refit every 1000 (sparse, for cost). Mixing the two is the most common mistake: refitting on every log point multiplies the refit cost tenfold.
  3. The verdict drives an automatic action. KILL raises SystemExit, which the cluster orchestrator interprets as "tear down this job." UNDERSPENT writes a ticket for the experiment owner. ON_TRACK is silent. A predictor that emits verdicts and does not act on them is theatre; the point of the predictor is to close the loop.

One-line addition that pays for itself. Log the full info dict (fitted parameters, RMSE, n_obs, verdict) to a separate CSV. After the run you can chart \(\hat L_\infty\) over training and see when the asymptote stabilised. If \(\hat L_\infty\) is still drifting at 50% of training, you have evidence the run was not in its asymptotic regime and the verdict was less reliable than it looked. This single CSV is the cheapest postmortem tool in the training stack.
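A minimal sketch of that CSV log, using only the standard library. The function name append_refit_row and the column order are assumptions made for this example; the info keys mirror the dict returned by predict() above.

```python
import csv
import os

# Column order is an arbitrary choice for this sketch.
FIELDS = ["step", "L_pred", "L_inf", "A", "alpha", "rmse", "n_obs", "verdict"]

def append_refit_row(path, step, info, verdict):
    """Append one refit record; write the header only when the file is new or empty."""
    is_new = (not os.path.exists(path)) or os.path.getsize(path) == 0
    row = {key: info[key] for key in FIELDS if key in info}
    row["step"] = step
    row["verdict"] = verdict
    with open(path, "a", newline="") as fh:
        writer = csv.DictWriter(fh, fieldnames=FIELDS)
        if is_new:
            writer.writeheader()
        writer.writerow(row)
```

Inside the training loop this would be called right after the verdict is printed, for example append_refit_row("refits.csv", step, info, verdict), where the file name is again a placeholder.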

At Massive Scale: Killing Bad Runs Before They Bankrupt You

Drop production numbers into the predictor and the economics become clear:

| Quantity | Order of magnitude | Comment |
| --- | --- | --- |
| Training run cost (DeepSeek-V3 scale) | ~$5M GPU time, ~60 days wall clock | The base rate the predictor protects. |
| Checkpoints needed for a confident verdict | ~5–8 | Spaced roughly logarithmically: 0.5T, 1T, 2T, 4T, 7T. After ~15% of the budget the predictor is solid. |
| Prediction error at 15% of run | ~0.01–0.03 nats | Comparable to per-batch noise; well below the 0.05-nat gap between a strong run and a mediocre one. |
| Refit cost | ~50 ms per refit, one refit per 1000 steps | Effectively free against multi-second training steps. |
| Compute saved by an early KILL verdict at 15% of run | ~$4M out of $5M | The headline number. One avoided bad run pays for the predictor for a decade. |

Two observations on what this means strategically. First, the predictor is the cheapest piece of training infrastructure measured in dollars-saved-per-engineering-hour — usually by two orders of magnitude. A solid predictor pays for the entire training infrastructure team in a single avoided bad run. Second, the same predictor enables cheap hyperparameter validation on proxy runs: train a 7B model for 100B tokens, fit the predictor, extrapolate to the 14.8T the 670B will eventually see. If the proxy run's extrapolated loss curve does not match the target shape, you found a bad hyperparameter setting at 1% of the cost of finding it on the full run. This is how DeepSeek and Meta cheaply sweep hyperparameter spaces that would be unaffordable on the full model.

Where the predictor sits in the experiment lifecycle

  1. Phase 1 — proxy sweeps. Run dozens of 7B-scale models for 100B tokens each. Fit the predictor to each run. Extrapolate to the planned full-scale horizon. Pick the hyperparameter configuration whose extrapolated final loss is best.
  2. Phase 2 — full launch with the predictor enabled. Launch the full 670B run with LossPredictor wired into the training loop. Refit every 1000 steps. The first solid verdict arrives at ~ 15% of the budget.
  3. Phase 3 — checkpoint-driven decisions. If the verdict is KILL, the cluster tears down the job and the team debugs the proxy-to-full discrepancy. If ON_TRACK, the run continues to completion. If UNDERSPENT, the team trims the token budget or repurposes the spare compute on a parallel ablation. None of this requires waiting for the run to finish.
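
The Phase 3 decision can be sketched as a small mapping from predicted final loss to a verdict and an action. The thresholding rule, the `tol` default, and the action strings are illustrative assumptions, not the criteria of any particular lab's predictor.

```python
def verdict(predicted_final, target, tol=0.02):
    """Classify a run by comparing its extrapolated final loss to the plan.

    `tol` (nats) is a hypothetical tolerance band around the target.
    """
    if predicted_final > target + tol:
        return "KILL"          # extrapolation lands clearly above the plan
    if predicted_final < target - tol:
        return "UNDERSPENT"    # run will beat the plan; budget can be trimmed
    return "ON_TRACK"

# Actions taken from the three outcomes described in the list above.
ACTIONS = {
    "KILL": "tear down the job and debug the proxy-to-full discrepancy",
    "ON_TRACK": "continue to completion",
    "UNDERSPENT": "trim the token budget or repurpose spare compute",
}

def decide(predicted_final, target, tol=0.02):
    v = verdict(predicted_final, target, tol)
    return v, ACTIONS[v]
```

Keeping the classification separate from the action table makes it easy to log the verdict at every refit while acting on it only once the fit is considered solid.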

Engineering Reality and Gotchas

The predictor looks like a tidy fit-and-extrapolate problem. Five failure modes show up in production:

  1. The LR schedule has a cliff the predictor cannot see. A cosine decay or a late-run linear-to-zero schedule carves an extra 0.03–0.08 nats out of the final loss that the smooth power-law model cannot anticipate. The fix is to give the predictor knowledge of the schedule: fit the power law to the loss curve scaled by the current LR ratio, or maintain two separate fits, one for the pre-decay phase and one for the decay phase. Frontier labs typically do the latter and combine via a known multiplier.
  2. Early checkpoints lie when warmup is long. The first 1–5% of training is warmup, where the loss curve is dominated by the schedule, not the asymptotic mechanism. Fitting the power law to those points pulls the fit toward the wrong shape. Start the fit AFTER warmup ends, or filter points below a minimum tokens-seen threshold.
  3. Loss spikes invalidate the smoothness assumption. A single optimizer instability at step 50k can dump a spike point that biases the fit. Robust regression (Huber loss instead of L2) is the standard fix. Alternatively, detect and remove outliers above 3 RMSE before fitting.
  4. Emergent phase transitions break the model. The power-law model assumes smooth approach to the asymptote. Section 9.3 showed that some benchmark metrics undergo a phase transition where the loss curve looks identical but the downstream eval suddenly improves. A pretraining-loss predictor cannot capture an emergent-ability transition; augment it with eval-loss checkpoints whose curves you also extrapolate separately.
  5. The predictor trusts data quality. If the training pipeline starts feeding lower-quality shards after step 100k (a deduplication regression, a botched mixing-ratio change), the loss curve will bend upward relative to the power law and the predictor will (correctly) emit KILL. Treat a sudden change in residual RMSE as a data-pipeline alert rather than a predictor bug — the predictor caught the data issue, that is its job.
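
Gotchas 2 and 3 can be sketched together: filter out warmup points, fit, drop residuals above 3 RMSE, and refit once. This assumes the three-parameter form $L(D) = L_\infty + A\,D^{-\alpha}$ and swaps in a simple grid scan over $L_\infty$ for a proper nonlinear or Huber-loss solver; all function names are hypothetical and this is not the chapter's LossPredictor.

```python
import math

def fit_power_law(tokens, losses, n_grid=400):
    """Fit L(D) = L_inf + A * D**(-alpha).

    Scans candidate L_inf values below min(losses); for each one,
    log(L - L_inf) is linear in log(D), so A and alpha follow from ordinary
    least squares. Returns (L_inf, A, alpha, rmse) with the lowest RMSE.
    """
    xs = [math.log(d) for d in tokens]
    n = len(xs)
    x_mean = sum(xs) / n
    sxx = sum((x - x_mean) ** 2 for x in xs)
    floor = min(losses)
    best = None
    for i in range(n_grid + 1):
        L_inf = floor * i / (n_grid + 1)  # strictly below min(losses)
        ys = [math.log(l - L_inf) for l in losses]
        y_mean = sum(ys) / n
        slope = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys)) / sxx
        A, alpha = math.exp(y_mean - slope * x_mean), -slope
        resid = [l - (L_inf + A * d ** (-alpha)) for d, l in zip(tokens, losses)]
        rmse = math.sqrt(sum(r * r for r in resid) / n)
        if best is None or rmse < best[3]:
            best = (L_inf, A, alpha, rmse)
    return best

def predict(params, d):
    """Evaluate the fitted curve at token count d."""
    L_inf, A, alpha, _ = params
    return L_inf + A * d ** (-alpha)

def inlier_mask(residuals, k=3.0):
    """True for residuals within k * RMSE of zero."""
    rmse = math.sqrt(sum(r * r for r in residuals) / len(residuals))
    return [abs(r) <= k * rmse for r in residuals]

def robust_fit(tokens, losses, warmup_tokens, k=3.0):
    """Drop warmup points, fit, discard spikes above k * RMSE, refit once."""
    pts = [(d, l) for d, l in zip(tokens, losses) if d >= warmup_tokens]
    ds, ls = zip(*pts)
    params = fit_power_law(ds, ls)
    resid = [l - predict(params, d) for d, l in pts]
    kept = [p for p, ok in zip(pts, inlier_mask(resid, k)) if ok]
    dropped = len(pts) - len(kept)
    if dropped:
        params = fit_power_law(*zip(*kept))
    return params, dropped
```

One caveat on the 3-RMSE rule: the largest residual can never exceed $\sqrt{n}$ times the RMSE, so with nine or fewer points the filter cannot fire at all. On a five-checkpoint fit, spike handling has to come from a robust loss or from manual inspection instead.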

How DeepSeek-V3 validates a run before declaring success. Three checks happen at every 1000-step refit: (a) the power-law fit succeeds and has RMSE below 0.01 nats on the post-warmup observations; (b) the predicted final loss is within tolerance of the target derived from the scaling-law plan in Section 9.1; (c) the predicted $L_\infty$ is within 0.02 nats of the natural-text entropy floor the team measured on a clean held-out shard. If any of the three fail, the verdict is escalated to a human on-call. The kill switch only fires automatically if (b) fails by more than $3\sigma$.
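
The three checks and the escalation rule can be sketched as one function. The text does not define the tolerance in (b) or the $\sigma$ behind the kill threshold, so both are explicit parameters here; treating the $3\sigma$ rule as one-sided (predicted loss worse than target) is also an assumption, as are all names.

```python
from dataclasses import dataclass

@dataclass
class RefitResult:
    fit_ok: bool            # did the power-law fit converge?
    rmse: float             # nats, on post-warmup observations
    predicted_final: float  # extrapolated loss at the planned token budget
    L_inf: float            # fitted asymptote

def validate_refit(r, target_final, final_tol, sigma, entropy_floor,
                   max_rmse=0.01, floor_tol=0.02):
    """Return (action, failed_checks) for one refit.

    final_tol and sigma are hypothetical inputs: the tolerance on check (b)
    and the uncertainty scale used for the automatic kill threshold.
    """
    failed = []
    if not r.fit_ok or r.rmse >= max_rmse:
        failed.append("a_fit_quality")
    gap = r.predicted_final - target_final
    if abs(gap) > final_tol:
        failed.append("b_final_loss")
    if abs(r.L_inf - entropy_floor) > floor_tol:
        failed.append("c_entropy_floor")

    if gap > 3 * sigma:
        return "AUTO_KILL", failed   # only check (b) can trigger this
    if failed:
        return "ESCALATE", failed    # page the human on-call
    return "CONTINUE", failed
```

A refit that fails only (a) or (c) therefore never stops the job on its own; it reaches a person first.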

The one sentence to carry forward: a pretraining run is a 60-day commitment with a 7-day decision point, and the loss predictor is what makes the seventh day actionable — every other piece of the training stack assumes you know whether to keep going.

Where we go from here. Chapter 9 ends here. Sections 9.1–9.4 planned the run from scaling laws; Section 9.5 watched it execute and decided whether to continue. Chapter 10 picks up the next axis of efficiency: with the run plan locked in, how do you make every forward and backward pass cheaper without sacrificing the loss curve you just spent four sections planning? FP8 mixed precision is the answer — and the place where DeepSeek made its most aggressive bet on numerical efficiency.