The Real Problem: Aligning Costs Capability
After fourteen trillion tokens of pretraining, your base model can do an extraordinary amount: solve grade-school math, write working code, answer trivia, translate between forty languages, write a sonnet about a kettle. Then you fine-tune it on a few hundred thousand instruction examples to teach it to follow user requests politely, and a measurable slice of all that ability quietly disappears. HumanEval drops 8 points. GSM8K falls 12. The model that used to translate Hungarian now refuses. This is catastrophic forgetting, and every SFT recipe you have ever read is, at its core, a strategy for paying as little of it as possible.
The idea is older than language models. McCloskey and Cohen described it in 1989 under the name catastrophic interference, after watching a small connectionist network completely overwrite an earlier task when trained on a new one. The mechanism is the same in a 671B-parameter transformer: gradient descent on a new loss has no information about the old loss, and so it moves weights freely in directions that happen to be cheap for the new task but expensive for everything else. The model is not forgetting the way a human forgets; it is being actively rewritten, one gradient step at a time, in directions chosen by an objective that does not care about pretraining.
Why is this section in the SFT chapter and not in "continual learning"? Because in a frontier LM pipeline, every post-training stage is continual learning: instruction-tuning over the pretrained base, preference-tuning (DPO/RLHF) over the SFT checkpoint, domain SFT over the aligned checkpoint, safety RLHF on top of that. Each stage risks forgetting the prior one. If you understand SFT forgetting, the same vocabulary carries through the rest of the post-training pipeline unchanged.
Intuition: Sculpting Over a Statue
Imagine the base model is a finished marble statue. Pretraining carved every muscle, every fold of robe, every strand of hair. SFT is a sculptor who has been hired to add a decorative ribbon across the chest of the statue. The ribbon is small and the sculptor has the right chisels — but every cut is taken out of the same marble. Carve enough material to make the ribbon stand out, and you have also taken some of the chest. Cut deeper and you start nicking the ribs underneath.
Two facts about the statue make this concrete. First, not all marble is equal. The hair and the muscles are intricate; they were the slowest, most expensive parts to carve. Damaging them costs months of recovery work. The plain back of the statue is smoother; you can plane a millimetre off without anyone noticing. In a network, the Fisher information of each weight is the "intricacy": it tells you how steeply pretraining loss rises as that weight is moved away from its pretrained value.
Second, the sculptor has options. They can carve the ribbon directly into the chest (full fine-tune, expensive in lost capability). They can lay a thin metal plate over the chest and engrave the ribbon into the plate (LoRA — the base is untouched). They can carve slowly and check the rest of the statue after every cut (low learning rate plus eval gates). They can also keep a mirror of the original statue beside them and stop whenever the differences grow too large (KL anchor). Each option trades how visible the ribbon is against how intact the rest of the statue remains.
The geometry in one sentence. Pretraining converges to a wide basin of parameter space; SFT's gradient field points out of that basin in a direction unrelated to its width. Mitigation is about how far you let the model walk before something else pulls it back.
The Math: Why SFT Pulls Weights Off the Base Minimum
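Let $\theta$ be the model parameters and $\theta_{\text{base}}$ the base-model checkpoint (the converged pretraining optimum). Pretraining gave us a loss $\mathcal{L}_{\text{PT}}$ for which $\nabla \mathcal{L}_{\text{PT}}(\theta_{\text{base}}) = 0$. Around $\theta_{\text{base}}$ we can write the second-order expansion

$$\mathcal{L}_{\text{PT}}(\theta_{\text{base}} + \Delta\theta) \approx \mathcal{L}_{\text{PT}}(\theta_{\text{base}}) + \tfrac{1}{2}\,\Delta\theta^{\top} F\,\Delta\theta,$$

where $F$ is the Fisher information matrix, exactly the curvature of the pretraining loss around the converged checkpoint. Forgetting is now a closed-form quantity: any displacement $\Delta\theta = \theta - \theta_{\text{base}}$ costs us $\tfrac{1}{2}\,\Delta\theta^{\top} F\,\Delta\theta$ in pretraining loss.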
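SFT minimises a different objective, cross-entropy on instruction data: $\mathcal{L}_{\text{SFT}}(\theta) = -\mathbb{E}_{(x,y)\sim\mathcal{D}_{\text{SFT}}}\big[\log \pi_\theta(y \mid x)\big]$. Its gradient at $\theta_{\text{base}}$ is generally nonzero, and the resulting trajectory accumulates a displacement $\Delta\theta = \theta - \theta_{\text{base}}$. The total objective the five-lever mitigation stack actually optimises is

$$\mathcal{L}_{\text{total}}(\theta) = (1-r)\,\mathcal{L}_{\text{SFT}}(\theta) + r\,\mathcal{L}_{\text{PT}}(\theta) + \beta\,\mathrm{KL}\big(\pi_\theta \,\|\, \pi_{\text{base}}\big) + \tfrac{\lambda}{2}\,\|\theta - \theta_{\text{base}}\|_2^2 .$$

Each term is one lever. $r$ is the replay ratio: a fraction $r$ of every gradient minibatch is drawn from the pretraining distribution, so part of the gradient already points back toward $\theta_{\text{base}}$. $\beta$ is the KL anchor: it penalises distributional drift of the model's output policy from the base. The KL itself expands to second order as $\mathrm{KL} \approx \tfrac{1}{2}\,\Delta\theta^{\top} F\,\Delta\theta$. It is the same Fisher quadratic as the PT loss, which is why a small KL coefficient can be a remarkably good proxy for "don't forget". The $\lambda$ term is L2-SP / EWC weight regularisation (EWC additionally weights each coordinate by its Fisher value): explicit anchoring in parameter space rather than function space.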
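The claim that the KL term reduces to the Fisher quadratic can be checked numerically. The sketch below is an illustration under a strong simplifying assumption: the "model" is a single softmax over 8 logits, so the parameters are the logits themselves and the Fisher matrix is $\mathrm{diag}(p) - pp^{\top}$. All names and sizes are made up for the example.

```python
import numpy as np


def softmax(z):
    z = z - z.max()              # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


def kl(p, q):
    """KL(p || q) for two strictly positive probability vectors."""
    return float(np.sum(p * (np.log(p) - np.log(q))))


rng = np.random.default_rng(0)
base_logits = rng.normal(size=8)                      # stands in for the base checkpoint
p_base = softmax(base_logits)
fisher = np.diag(p_base) - np.outer(p_base, p_base)   # Fisher of a softmax w.r.t. its logits

direction = rng.normal(size=8)
rel_errors = []
for scale in (0.5, 0.1, 0.01):
    delta = scale * direction                         # displacement away from the base
    exact = kl(softmax(base_logits + delta), p_base)  # KL(pi || pi_base)
    quad = 0.5 * float(delta @ fisher @ delta)        # second-order prediction
    rel_errors.append(abs(exact - quad) / exact)
    print(f"scale={scale:<5} exact={exact:.3e} quadratic={quad:.3e} rel_err={rel_errors[-1]:.3f}")
```

The relative error should shrink as the displacement shrinks, which is the sense in which the quadratic is a local approximation: it is trustworthy for the small drifts SFT is supposed to make and increasingly loose for large ones.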
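Two additional levers act outside the loss. LoRA restricts each weight update $\Delta W \in \mathbb{R}^{d \times k}$ to a low-rank subspace, $\Delta W = BA$ with $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$ and $r \ll \min(d, k)$ (here $r$ is the adapter rank, not the replay ratio). This is a hard constraint, not a penalty: the model is mechanically incapable of large parameter drift, because most directions in weight space are not reachable. The learning-rate budget caps the norm of $\Delta\theta$: for plain SGD with learning rate $\eta$ and per-step gradient norm bounded by $G$ (for example through gradient clipping), $\|\Delta\theta\| \le \eta\,T\,G$ for $T$ training steps. Small $\eta$ means small total drift, regardless of what the gradient wants.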
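Both out-of-loss levers are easy to see in a few lines of NumPy. The following is a sketch with made-up sizes (`d`, `k`, `rank`, `lr`, `steps`, and `clip` are illustrative, not recommended settings), assuming plain SGD with gradient-norm clipping for the learning-rate bound.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, rank = 64, 48, 4              # illustrative layer shape and adapter rank

# LoRA: the update to a d x k weight matrix is the product of two thin factors.
# (Real LoRA initialises B to zero; random values are used here only to show the rank.)
B = rng.normal(size=(d, rank))
A = rng.normal(size=(rank, k))
delta_w = B @ A

update_rank = int(np.linalg.matrix_rank(delta_w))
lora_params = rank * (d + k)
full_params = d * k
print(f"rank of update: {update_rank} (a full update could reach {min(d, k)})")
print(f"trainable parameters: {lora_params} vs {full_params}")

# Learning-rate budget: with gradient norms clipped to `clip`, plain SGD can move
# the weights by at most lr * steps * clip in total (triangle inequality).
lr, steps, clip = 1e-2, 200, 1.0
theta = np.zeros(full_params)
for _ in range(steps):
    g = rng.normal(size=full_params)               # stand-in for a minibatch gradient
    g *= min(1.0, clip / np.linalg.norm(g))        # clip the gradient norm
    theta -= lr * g
drift = float(np.linalg.norm(theta))
bound = lr * steps * clip
print(f"total drift {drift:.3f} <= budget {bound:.3f}")
```

The bound is a worst case: it is only reached if every gradient points the same way. Adaptive optimisers such as Adam rescale the step, so the same reasoning applies there only approximately.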
The Mitigation Stack: Five Levers
In production, the five techniques below are stacked, not picked. Each one is cheap, each one is partial, and the combination is what reliably keeps forgetting under control.
| Lever | Mechanism | Typical setting | What it costs you | What it does not fix |
|---|---|---|---|---|
| Replay (data mixing) | Mix pretraining tokens into SFT batches | 10–20 % replay rate | extra forward/backward cost, longer wall-clock | format-level drift (chat template still gets learned the same) |
| KL anchor (β) | Penalise KL(π‖π_base) per token | β ≈ 0.01–0.1 | extra forward pass through frozen base; over-regularises if β too large | weights are still moving; only the output distribution is anchored |
| LoRA / PEFT | Freeze base; train low-rank adapters | rank 8–64, α/r ≈ 2 | small expressivity ceiling; slight inference cost without merge | can still forget within the low-rank channel if LR and steps are wrong |
| Learning-rate budget | Small LR + few epochs; cosine decay | lr ≈ 1–5e-6 (full FT), 1–5e-5 (LoRA) | longer time to reach target instruction quality | if you do not also clip gradients, one bad batch still spikes drift |
| Capability evals as a gate | Block release on > X-point drop on MMLU/HumanEval/etc. | 1–3 point tolerance per benchmark | eval compute; release latency | evals miss what they don't measure (rare languages, niche capabilities) |
Treat the table as a checklist, not a menu. The single biggest mistake in SFT pipelines is to pick one lever ("we're using LoRA so we're safe") and skip the rest. LoRA without small LR still forgets within its low-rank subspace; replay without LR control still forgets on every non-replay step; KL anchoring without capability evals can hide drift on rare skills the KL term does not see. The stack works because the levers are partially independent — when one fails (a bad replay batch, a KL miscalibration), the others still pull.
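One way to enforce the checklist mechanically is to validate a run's configuration before launch. The sketch below is hypothetical: `MitigationConfig` and `missing_levers` are illustrative names, and the ranges are simply the "typical setting" column of the table above, not limits from any library.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class MitigationConfig:
    """One knob per lever; field names are illustrative, not from any library."""
    replay_ratio: float = 0.0                  # fraction of each batch drawn from pretraining data
    kl_beta: float = 0.0                       # coefficient on KL(pi || pi_base)
    lora_rank: int = 0                         # 0 means full fine-tuning
    learning_rate: float = 1e-5
    eval_gate_points: Optional[float] = None   # max tolerated benchmark drop, None = no gate


def missing_levers(cfg: MitigationConfig) -> List[str]:
    """Return the levers that are switched off or outside the table's typical range."""
    missing = []
    if not 0.10 <= cfg.replay_ratio <= 0.20:
        missing.append("replay")
    if not 0.01 <= cfg.kl_beta <= 0.1:
        missing.append("kl_anchor")
    if not 8 <= cfg.lora_rank <= 64:
        missing.append("lora")
    max_lr = 5e-5 if cfg.lora_rank else 5e-6   # upper ends of the table's LR ranges
    if cfg.learning_rate > max_lr:
        missing.append("learning_rate")
    if cfg.eval_gate_points is None or not 1.0 <= cfg.eval_gate_points <= 3.0:
        missing.append("eval_gate")
    return missing


lora_only = MitigationConfig(lora_rank=16, learning_rate=2e-5)
full_stack = MitigationConfig(replay_ratio=0.10, kl_beta=0.05, lora_rank=16,
                              learning_rate=2e-5, eval_gate_points=2.0)
print("LoRA only :", missing_levers(lora_only))
print("full stack:", missing_levers(full_stack))
```

A "we're using LoRA so we're safe" configuration is flagged for the three levers it skipped, which is the failure mode the paragraph above warns about.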
Manual Numerical Walkthrough
We will run the two-task toy model from the Python section by hand and watch the forgetting bill change as we turn levers on. Parameters: $\theta \in \mathbb{R}^2$, base optimum $\theta_{\text{base}} = (1, 0)$, SFT optimum $\theta_{\text{SFT}} = (1, 2.5)$, Fisher diagonal $F = (2.5, 0.25)$. The two tasks agree on dim 0 (both want $\theta_0 = 1$) and disagree on dim 1 by 2.5 units. Both losses are quadratic, so we can solve the four-regime sweep in closed form.
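Step 1: write down the gradient. Combining the three terms of the mitigation objective (SFT, replay, and KL anchor, each quadratic with the same Fisher diagonal) gives, per coordinate $i$,

$$g_i = F_i\,\big[(1-r)\,(\theta_i - \theta_{\text{SFT},i}) + (r+\beta)\,(\theta_i - \theta_{\text{base},i})\big].$$

Step 2: solve for the stationary point. Setting $g_i = 0$ and solving for $\theta_i$ gives the fixed point of SGD,

$$\theta_i^{\infty} = \frac{(1-r)\,\theta_{\text{SFT},i} + (r+\beta)\,\theta_{\text{base},i}}{1+\beta}.$$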
Notice the Fisher dropped out: because every term carries the same curvature factor, the equilibrium is independent of it. Curvature only sets how fast the trajectory gets there, not where it ends up. This is why a tiny KL coefficient or a small replay rate can have a big effect even on weights with large Fisher: those weights converge faster, but to the same fixed point.
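Step 3: plug in the four regimes for dim 1. Dim 0 is uninteresting: both optima are at 1, so the fixed point is exactly 1 in every regime. All the action is on dim 1, where $\theta_{\text{base},1} = 0$ and $\theta_{\text{SFT},1} = 2.5$, so the fixed point is $2.5\,(1-r)/(1+\beta)$: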
| Regime | r | β | θ₁ at fixed point | ΔL_PT = ½·F₁·θ₁² | % forgetting |
|---|---|---|---|---|---|
| Naive SFT | 0 | 0 | 2.500 | 0.781 | 100 % |
| + 10 % replay | 0.10 | 0 | 2.250 | 0.633 | 81 % |
| + KL β = 0.5 | 0 | 0.5 | 1.667 | 0.347 | 44 % |
| Stack: replay + KL | 0.10 | 0.5 | 1.500 | 0.281 | 36 % |
Step 4: read the numbers. Naive SFT lands at the pure SFT optimum and pays 0.781 of pretraining loss; that is our 100 % baseline of forgetting. 10 % replay alone reduces forgetting by 19 %. KL on its own reduces it by 56 %. Stacking both reduces it by 64 %. The two reductions are not additive (19 + 56 = 75, not 64), because the KL and replay terms partially compete for the same dimension-1 movement, but they do compound: the stack beats either lever alone.
Step 5: check the SFT cost. What did we give up? The instruction optimum was $\theta_{\text{SFT},1} = 2.5$; we landed at 1.50. The SFT loss at the stacked point is $\tfrac{1}{2} \cdot 0.25 \cdot (1.5 - 2.5)^2 = 0.125$. So we cut PT-loss damage from 0.781 to 0.281 (a 0.500 reduction) at a cost of 0.125 in SFT loss, a 4× favourable trade. The instruction task only mildly cares about dim 1 (it has Fisher 0.25, ten times smaller than dim 0), which is exactly why mitigation buys so much here: the SFT optimum is loose along dim 1 and we can give up most of it cheaply. This generalises: the lever-stack is most effective in directions where the SFT task is loose and the PT task is tight, and those are also the directions where forgetting is most expensive, so the cost/benefit ratio works in your favour exactly when it matters.
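The hand calculation in Steps 2 to 5 can be replayed in a few lines of plain Python. This is a sketch that hard-codes the dim-1 values from the walkthrough; the function and variable names are illustrative.

```python
def fixed_point(theta_sft, theta_base, r, beta):
    """Stationary point of (1 - r) * L_SFT + r * L_PT + beta * KL when all terms share curvature."""
    return ((1.0 - r) * theta_sft + (r + beta) * theta_base) / (1.0 + beta)


F1 = 0.25                                  # Fisher value on dim 1
THETA_BASE_1, THETA_SFT_1 = 0.0, 2.5       # the two optima on dim 1
REGIMES = [
    ("Naive SFT", 0.0, 0.0),
    ("+ 10 % replay", 0.10, 0.0),
    ("+ KL beta = 0.5", 0.0, 0.5),
    ("Stack: replay + KL", 0.10, 0.5),
]

baseline = 0.5 * F1 * (THETA_SFT_1 - THETA_BASE_1) ** 2   # forgetting bill of naive SFT
rows = []
for name, r, beta in REGIMES:
    theta1 = fixed_point(THETA_SFT_1, THETA_BASE_1, r, beta)
    pt_cost = 0.5 * F1 * (theta1 - THETA_BASE_1) ** 2     # pretraining loss paid
    sft_cost = 0.5 * F1 * (theta1 - THETA_SFT_1) ** 2     # instruction loss given up
    pct = 100.0 * pt_cost / baseline
    rows.append((name, theta1, pt_cost, sft_cost, pct))
    print(f"{name:<20} theta1={theta1:.3f}  dL_PT={pt_cost:.3f}  L_SFT={sft_cost:.3f}  forgetting={pct:.0f}%")
```

Adding a regime to `REGIMES` is a quick way to explore other (r, β) pairs before paying for a real ablation.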
The lesson. Forgetting can be quantified, traded against SFT performance, and reduced to a closed form on the toy model. The same machinery, scaled to billions of parameters and with the Fisher diagonal estimated empirically (or approximated by KL on the output distribution), drives every production SFT recipe.
Visualizing Capability Trajectories
The chart below simulates what happens to four canonical capabilities during SFT under different lever settings. The green curve is instruction-following, which is what we are trying to teach. The other three are pretraining capabilities the model is supposed to keep: general knowledge (MMLU), math reasoning (GSM8K), and code (HumanEval). The starting values are the post-pretraining scores of a typical 7B model; the shapes are calibrated to match the magnitudes reported in the SFT mitigation literature (LoRA paper, InstructGPT appendix, LIMA paper).
Three slider sweeps to try:
- Start with all controls at zero, then drag α toward 1. You are watching a naive full fine-tune. Instruction-following climbs from 5 % to 95 %; HumanEval falls from 41 % to 8 %, GSM8K from 48 % to 12 %. This is what a paper means by "catastrophic forgetting" — the orange and purple lines simply collapse.
- Now drag replay r up to 0.2. The orange and purple curves snap back toward their baselines; instruction-following keeps almost the same shape because we only spent 20 % of our gradient budget on replay. This single lever recovers most of the drop and is the single largest line item in any production SFT recipe.
- Set α down to 0.15 (LoRA-rank regime), r = 0.1, β = 0.05. This is the modern frontier-LM SFT stack. Every capability stays within 2–3 points of its baseline, instruction-following plateaus at ~78 %, and the parameter footprint of the run is ~0.5 % of the base model. This is the picture a paper from 2024–2025 shows when it claims "no measurable forgetting".
Note that the chart is a smoothed model: real curves are noisier and have benchmark-specific shocks (a single bad SFT batch can drop HumanEval 5 points before recovering). But the directions of all three sweeps match what real teams see when they ablate replay, adapters, and KL anchoring.
Plain Python: A Two-Task Toy Model
Before invoking PyTorch and a billion-parameter LM, we reproduce the closed-form math in about 50 lines of NumPy. The two-task quadratic model is the entire forgetting story in miniature, and watching the trajectory under the four (r, β) settings is the cleanest possible illustration of how levers compose.
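A minimal sketch of such a script follows. It assumes the quadratic losses and parameter values of the numerical walkthrough, uses deterministic gradient descent rather than noisy minibatches, and picks an illustrative step size and step count; all names are hypothetical.

```python
import numpy as np

# Values from the numerical walkthrough.
THETA_BASE = np.array([1.0, 0.0])    # pretraining optimum
THETA_SFT = np.array([1.0, 2.5])     # instruction-tuning optimum
FISHER = np.array([2.5, 0.25])       # shared diagonal curvature


def pt_loss(theta):
    """Quadratic pretraining loss, measured relative to the base optimum."""
    return 0.5 * float(np.sum(FISHER * (theta - THETA_BASE) ** 2))


def sft_loss(theta):
    """Quadratic instruction loss, measured relative to the SFT optimum."""
    return 0.5 * float(np.sum(FISHER * (theta - THETA_SFT) ** 2))


def grad(theta, r, beta):
    """Gradient of (1 - r) * L_SFT + r * L_PT + beta * KL, with KL taken to second order."""
    g_sft = FISHER * (theta - THETA_SFT)
    g_anchor = FISHER * (theta - THETA_BASE)    # replay and KL pull the same way here
    return (1.0 - r) * g_sft + (r + beta) * g_anchor


def run_regime(r, beta, lr=0.2, steps=500):
    """Gradient descent starting from the base checkpoint."""
    theta = THETA_BASE.copy()
    for _ in range(steps):
        theta -= lr * grad(theta, r, beta)
    return theta


REGIMES = {
    "naive": (0.0, 0.0),
    "replay": (0.10, 0.0),
    "kl": (0.0, 0.5),
    "stack": (0.10, 0.5),
}

results = {}
for name, (r, beta) in REGIMES.items():
    theta = run_regime(r, beta)
    results[name] = (theta, pt_loss(theta), sft_loss(theta))
    print(f"{name:<7} r={r:.2f} beta={beta:.2f}  theta={np.round(theta, 3)}  "
          f"L_PT={results[name][1]:.3f}  L_SFT={results[name][2]:.3f}")
```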
The fixed points the script lands at match the closed-form table from the numerical walkthrough to three decimal places (you can verify by setting $r$ and $\beta$ to the same values and reading $\theta_1$). The print at the bottom is the smallest possible ablation table for an SFT mitigation experiment: change one knob, observe both losses, keep the Pareto-best.
PyTorch: SFT with Replay, LoRA, and KL Anchor
The toy model collapses into one line per lever; a real SFT loop is one function per lever. Below is the minimum loop that combines all three: LoRA from peft, replay via a second DataLoader, and an on-the-fly forward-KL anchor against a frozen base. The scheduler and clip lines are exactly as you would copy them into a Megatron-LM or trl SFTTrainer subclass.
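The loop below is a self-contained sketch of that structure, not a production recipe. To keep it runnable without downloading a model, it makes several loudly labelled substitutions: a hand-rolled `LoRALinear` stands in for a peft adapter, `TinyLM` is a placeholder for a real causal LM, random token ids stand in for SFT and replay data, and the frozen base is obtained by switching the adapter off rather than holding a second copy of the model. The KL term is KL(π‖π_base), matching the lever table, and the hyperparameters are toy-scale values.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
VOCAB, DIM, SEQ = 50, 32, 16
REPLAY_R, KL_BETA, LR, CLIP = 0.1, 0.05, 1e-3, 1.0   # toy-scale values


class LoRALinear(nn.Module):
    """Frozen linear layer plus a trainable low-rank update (hand-rolled stand-in for peft)."""

    def __init__(self, base, rank=8, alpha=16):
        super().__init__()
        self.base = base.requires_grad_(False)
        self.A = nn.Parameter(0.01 * torch.randn(rank, base.in_features))
        self.B = nn.Parameter(torch.zeros(base.out_features, rank))  # zero init: starts at base
        self.scale = alpha / rank
        self.enabled = True

    def forward(self, x):
        out = self.base(x)
        if self.enabled:
            out = out + self.scale * (x @ self.A.T @ self.B.T)
        return out


class TinyLM(nn.Module):
    """Per-token next-token predictor; a placeholder for a real causal LM."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, DIM).requires_grad_(False)
        self.proj = LoRALinear(nn.Linear(DIM, DIM))
        self.head = nn.Linear(DIM, VOCAB).requires_grad_(False)

    def forward(self, ids, use_adapter=True):
        self.proj.enabled = use_adapter
        return self.head(torch.tanh(self.proj(self.embed(ids))))


def lm_loss(logits, ids):
    """Next-token cross-entropy."""
    return F.cross_entropy(logits[:, :-1].reshape(-1, VOCAB), ids[:, 1:].reshape(-1))


def random_loader():
    ids = torch.randint(0, VOCAB, (256, SEQ))      # random tokens stand in for real data
    return DataLoader(TensorDataset(ids), batch_size=16, shuffle=True)


model = TinyLM()
sft_loader, replay_loader = random_loader(), random_loader()
trainable = [p for p in model.parameters() if p.requires_grad]   # only the LoRA factors
base_weight_before = model.proj.base.weight.detach().clone()
opt = torch.optim.AdamW(trainable, lr=LR)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=len(sft_loader))

for (sft_ids,), (replay_ids,) in zip(sft_loader, replay_loader):
    logits = model(sft_ids)
    with torch.no_grad():                          # frozen base = same weights, adapter off
        base_logp = F.log_softmax(model(sft_ids, use_adapter=False), dim=-1)
    policy_logp = F.log_softmax(logits, dim=-1)
    kl = (policy_logp.exp() * (policy_logp - base_logp)).sum(-1).mean()   # KL(pi || pi_base)
    loss = ((1 - REPLAY_R) * lm_loss(logits, sft_ids)
            + REPLAY_R * lm_loss(model(replay_ids), replay_ids)
            + KL_BETA * kl)
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(trainable, CLIP)
    opt.step()
    sched.step()

print(f"final loss {loss.item():.3f}, final KL {kl.item():.5f}")
```

Reusing the adapted model with its adapter disabled is only possible because LoRA leaves the base weights untouched; with full fine-tuning the reference logits would have to come from a separate frozen copy or from a cache.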
What scales and what doesn't. The toy model and the production loop differ in size, optimiser, and tokenizer — but they share the structure: one term per lever, summed into the loss, with one knob per lever in the hyperparameter sweep. Every complication on top of this (mixed precision, FSDP sharding, sequence packing, masked loss on assistant turns only) is orthogonal to forgetting. Get the five-lever skeleton right first; bolt on the systems engineering second.
What Changes at 671B Parameters
Three things change uncomfortably when you scale the recipe from a 7B model to a frontier 671B-parameter MoE.
| Quantity | 7B SFT | 671B-MoE SFT | Why it matters |
|---|---|---|---|
| SFT dataset size | ~50k–500k examples | 1–5M curated examples | more lever-tuning is justified because the eval surface is larger |
| Active params per token | 7B | ~37B (MoE routing) | forgetting is concentrated in routed experts, not shared weights |
| Replay tokens needed | 5–10B tokens | 50–200B tokens | you cannot afford to redo replay if you set the rate wrong |
| Full FT cost | tractable on 8 GPUs | thousands of GPU-hours | LoRA / PEFT becomes operationally mandatory, not just attractive |
| KL forward pass | free (same GPU) | another sharded model in memory | many teams approximate KL with cached base-model logits |
1. MoE makes forgetting localise — and that helps
In a dense model, every weight contributes to every token; an SFT batch updates everything. In an MoE model like DeepSeek-V3, only the experts a token is routed to receive gradient signal from that token (attention layers and any always-active shared experts are still updated by every token). SFT data has its own routing distribution (instruction-style prompts route to different experts than encyclopaedia text), so forgetting concentrates in the instruction-routed experts and barely touches the rest. This is good news: it means much of the expert capacity is mechanically shielded by the routing pattern, not by an explicit adapter. It is also bad news: if your SFT data is narrow, the experts it does touch can be very thoroughly overwritten and you may not see the damage until production traffic routes a different topic into them.
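The routing argument can be made concrete with a toy layer. The sketch below is a deliberately crude illustration, assuming top-1 routing, eight single-vector "experts", and a made-up squared-error loss; it ignores attention layers, shared experts, and load-balancing losses, all of which a real MoE has.

```python
import numpy as np

rng = np.random.default_rng(0)
n_experts, dim = 8, 4
experts = rng.normal(size=(n_experts, dim))    # one weight vector per expert
router = rng.normal(size=(n_experts, dim))     # routing score for expert e is router[e] @ x


def expert_grads(batch):
    """Gradient of sum 0.5 * (w_e @ x - 1)^2 over the batch, with top-1 routing."""
    grads = np.zeros_like(experts)
    for x in batch:
        e = int(np.argmax(router @ x))         # only the selected expert sees this token
        grads[e] += (experts[e] @ x - 1.0) * x
    return grads


def touched(grads):
    """Indices of experts that received any gradient at all."""
    return [e for e in range(n_experts) if np.any(grads[e] != 0.0)]


center = 3.0 * rng.normal(size=dim)
narrow_batch = center + 0.05 * rng.normal(size=(256, dim))   # one tight topic cluster
broad_batch = rng.normal(size=(256, dim))                    # diverse tokens

narrow_touched = touched(expert_grads(narrow_batch))
broad_touched = touched(expert_grads(broad_batch))
print("experts updated by narrow batch:", narrow_touched)
print("experts updated by broad batch: ", broad_touched)
```

Every expert outside `narrow_touched` has an exactly zero gradient for the narrow batch, which is the "mechanically frozen" effect; the experts inside it absorb all 256 updates, which is the "thoroughly overwritten" risk.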
2. Replay datasets must match pretraining mix proportionally
A 7B model that was 80 % English / 15 % code / 5 % math at pretraining cannot safely replay 100 % English tokens during SFT — the code and math channels will still drift. Production replay datasets are explicitly mixed to match the original PT proportions, usually at the same coarse-domain granularity used during pretraining. This is why teams keep a frozen snapshot of the original PT data mixer for the lifetime of a model: every downstream SFT or RLHF stage uses it for replay.
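Proportional replay amounts to weighted sampling over domain pools. The following sketch uses the 80 / 15 / 5 split from the example above; the pool names, batch size, and the 12.5 % replay ratio are illustrative assumptions.

```python
import random
from collections import Counter

# Coarse pretraining mix from the example above; pool names are illustrative.
PT_MIX = {"english": 0.80, "code": 0.15, "math": 0.05}


def sample_replay_domains(n, mix=PT_MIX, seed=0):
    """Draw n replay-example domains with probabilities equal to the PT proportions."""
    rng = random.Random(seed)
    domains = list(mix)
    return rng.choices(domains, weights=[mix[d] for d in domains], k=n)


def batch_plan(batch_size, replay_ratio):
    """Split one batch into (SFT slots, replay slots)."""
    n_replay = round(batch_size * replay_ratio)
    return batch_size - n_replay, n_replay


n_sft, n_replay = batch_plan(batch_size=64, replay_ratio=0.125)
draws = sample_replay_domains(20_000)
shares = {d: c / len(draws) for d, c in Counter(draws).items()}
print(f"per batch: {n_sft} SFT examples, {n_replay} replay examples")
print("empirical replay mix:", {d: round(s, 3) for d, s in shares.items()})
```

Sampling per example keeps the mix right on average; a stricter variant would fix the per-domain counts in every batch, which matters when the replay slice of a batch is only a handful of examples.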
3. The KL anchor needs a frozen base in memory
A naive KL implementation runs the frozen base on every batch, doubling the weight memory and adding an extra forward pass per step (roughly a third more FLOPs for full fine-tuning). Frontier teams cut this cost with two tricks: cached base logits (pre-compute and store top-k logits per training token; ~50 GB for a few billion tokens is much cheaper than running the base each step), and top-k KL approximations (compute KL only over the top-128 tokens of the base distribution, which captures > 99.9 % of the mass for a well-calibrated LM). The savings are large enough that production SFT loops commonly use one or both.
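One possible form of the two tricks is sketched below in NumPy. It is an assumption about how such an approximation could be built, not a description of any particular team's implementation: the cache stores the base model's top-k indices and log-probabilities, and the KL is computed over those k tokens plus a single bucket holding the remaining tail mass. The vocabulary size and logits are synthetic.

```python
import numpy as np


def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())


def cache_topk(base_logits, k):
    """What would be stored offline: indices and log-probs of the base model's top-k tokens."""
    logp = log_softmax(base_logits)
    idx = np.argpartition(logp, -k)[-k:]
    return idx, logp[idx]


def topk_kl(policy_logits, idx, base_logp_topk):
    """KL(pi || pi_base) over the cached top-k tokens plus one bucket for the tail."""
    policy_logp = log_softmax(policy_logits)[idx]
    p, q = np.exp(policy_logp), np.exp(base_logp_topk)
    kl = float(np.sum(p * (policy_logp - base_logp_topk)))
    p_tail, q_tail = 1.0 - p.sum(), 1.0 - q.sum()
    if p_tail > 1e-12 and q_tail > 1e-12:      # skip the bucket when the tail is numerically empty
        kl += float(p_tail * np.log(p_tail / q_tail))
    return kl


rng = np.random.default_rng(0)
vocab = 1000
base_logits = 3.0 * rng.normal(size=vocab)                  # synthetic base distribution
policy_logits = base_logits + 0.3 * rng.normal(size=vocab)  # slightly drifted policy

full_idx, full_logp = cache_topk(base_logits, vocab)        # k = vocab recovers the exact KL
exact = topk_kl(policy_logits, full_idx, full_logp)
approx = topk_kl(policy_logits, *cache_topk(base_logits, 128))
print(f"exact KL={exact:.5f}  top-128 KL={approx:.5f}")
```

Lumping the tail into one bucket makes the result a true KL between coarsened distributions, so it can only underestimate the full KL. How tight it is depends on how peaked the base distribution is, which is worth measuring on real logits before trusting a fixed k.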
Engineering Reality: How Teams Actually Catch It
The single most expensive way to discover catastrophic forgetting is in production. The second most expensive is during the final pre-release eval. Mature SFT pipelines push the detection earlier and earlier in the loop. The pattern below is what frontier teams converge on after a few incidents.
- Per-step capability probes. Every 100–500 steps, run a 200-example subset of MMLU / HumanEval / GSM8K through the current checkpoint. Plot the curve next to the training loss. If any capability curve dips by more than 1.5 % per 1000 steps, page the on-call SFT engineer. This catches misconfigurations (replay loader serving the wrong distribution, LoRA rank zero by accident) within an hour rather than at end of run.
- Held-out PT-loss canary. Keep a 1M-token sample of pretraining data that is never used for training. Compute its cross-entropy under the SFT checkpoint every N steps. A rising curve is forgetting in raw form, before any benchmark translates it into a percentage. This is the single most sensitive forgetting metric and the cheapest to compute.
- Capability evals as a release gate. Block any SFT checkpoint that drops more than 2 points on MMLU, 3 on HumanEval, 3 on GSM8K, or 1 on a multilingual suite relative to the base model, even if the instruction quality is excellent. The implementation cost of this gate is small; the cost of skipping it is occasionally a public regression.
- Adversarial forgetting probes. Maintain a small set of rare-capability prompts (Hungarian poetry, obscure programming languages, niche academic terminology) that public benchmarks do not cover. Catastrophic forgetting on these is what users notice first because they are the long tail of capability.
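The release gate and the PT-loss canary from the list above are both a few lines of code. The sketch below uses the gate tolerances stated in the bullet; the function names, the canary's `window` and `min_rise` parameters, and all scores and loss values are made up for illustration.

```python
# Tolerances in benchmark points, taken from the release-gate bullet above.
GATE_TOLERANCE = {"mmlu": 2.0, "humaneval": 3.0, "gsm8k": 3.0, "multilingual": 1.0}


def release_gate(base_scores, ckpt_scores, tolerance=GATE_TOLERANCE):
    """Return {benchmark: drop} for every benchmark that regressed past its tolerance."""
    failures = {}
    for bench, tol in tolerance.items():
        drop = base_scores[bench] - ckpt_scores[bench]
        if drop > tol:
            failures[bench] = drop
    return failures


def canary_is_rising(pt_losses, window=3, min_rise=0.01):
    """Flag forgetting when held-out PT loss rose by more than min_rise over the last `window` evals."""
    if len(pt_losses) <= window:
        return False
    return pt_losses[-1] - pt_losses[-1 - window] > min_rise


# Made-up scores for illustration only.
base = {"mmlu": 62.0, "humaneval": 41.0, "gsm8k": 48.0, "multilingual": 55.0}
ckpt = {"mmlu": 61.0, "humaneval": 36.5, "gsm8k": 46.0, "multilingual": 53.5}
blocked = release_gate(base, ckpt)
rising = canary_is_rising([2.10, 2.11, 2.13, 2.16, 2.20])
print("blocked on:", blocked)
print("canary rising:", rising)
```

An empty dictionary from `release_gate` means the checkpoint may ship; anything else names the benchmarks to investigate. The canary threshold needs tuning against the run-to-run noise of the held-out loss.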
Done well, the eval scaffolding doubles the wall-clock cost of an SFT run and reduces the incidence of full-run rollbacks by an order of magnitude. The math from the start of this section makes the trade obvious: a 2× SFT-cost increase is a tiny line item next to the cost of a 671B-parameter pretraining run whose capabilities you accidentally erased in post-training.