Chapter 13

Catastrophic Forgetting and Mitigation

Supervised Fine-Tuning (SFT)

The Real Problem: Aligning Costs Capability

After fourteen trillion tokens of pretraining, your base model can do an extraordinary amount: solve grade-school math, write working code, answer trivia, translate between forty languages, write a sonnet about a kettle. Then you fine-tune it on a few hundred thousand instruction examples to teach it to follow user requests politely, and a measurable slice of all that ability quietly disappears. HumanEval drops 8 points. GSM8K falls 12. The model that used to translate Hungarian now refuses. This is catastrophic forgetting, and every SFT recipe you have ever read is, at its core, a strategy for paying as little of it as possible.

The phenomenon is older than language models. McCloskey and Cohen described it in 1989 (as "catastrophic interference") after watching a small connectionist network completely overwrite an earlier task when trained on a new one. The mechanism is the same in a 671B-parameter transformer: gradient descent on a new loss has no information about the old loss, so it moves weights freely in directions that happen to be cheap for the new task but expensive for everything else. The model is not forgetting the way a human forgets; it is being actively rewritten, one gradient step at a time, in directions chosen by an objective that does not care about pretraining.

The thesis of this section. SFT is a multi-objective (Pareto) problem, not a single-objective optimisation. You cannot maximise instruction-following without paying some pretraining loss; you can only choose how much, on which capabilities, and through which mechanism. The five-lever mitigation stack below (replay, KL anchor, LoRA / PEFT, learning-rate budget, and capability-aware evals) is how production teams negotiate the trade-off.

Why is this section in the SFT chapter and not in "continual learning"? Because in a frontier LM pipeline, every post-training stage is continual learning: instruction-tuning over the pretrained base, preference-tuning (DPO/RLHF) over the SFT checkpoint, domain SFT over the aligned checkpoint, safety RLHF on top of that. Each stage risks forgetting the prior one. If you understand SFT forgetting, the same vocabulary carries through the rest of the post-training pipeline unchanged.

Intuition: Sculpting Over a Statue

Imagine the base model is a finished marble statue. Pretraining carved every muscle, every fold of robe, every strand of hair. SFT is a sculptor who has been hired to add a decorative ribbon across the chest of the statue. The ribbon is small and the sculptor has the right chisels — but every cut is taken out of the same marble. Carve enough material to make the ribbon stand out, and you have also taken some of the chest. Cut deeper and you start nicking the ribs underneath.

Two facts about the statue make this concrete. First, not all marble is equal. The hair and the muscles are intricate; they were the slowest, most expensive parts to carve, and damaging them costs months of recovery work. The plain back of the statue is smoother; you can plane a millimetre off without anyone noticing. In a network, the Fisher information of each weight is the "intricacy": it tells you how much pretraining loss you would pay for moving that weight, with the cost growing as the square of the move.

Second, the sculptor has options. They can carve the ribbon directly into the chest (full fine-tune, expensive in lost capability). They can lay a thin metal plate over the chest and engrave the ribbon into the plate (LoRA — the base is untouched). They can carve slowly and check the rest of the statue after every cut (low learning rate plus eval gates). They can also keep a mirror of the original statue beside them and stop whenever the differences grow too large (KL anchor). Each option trades how visible the ribbon is against how intact the rest of the statue remains.

The geometry in one sentence. Pretraining converges to a wide basin of parameter space; SFT's gradient field points out of that basin in a direction unrelated to its width. Mitigation is about how far you let the model walk before something else pulls it back.

The Math: Why SFT Pulls Weights Off the Base Minimum

Let $\theta$ be the model parameters and $\theta^*$ the base-model checkpoint (the converged pretraining optimum). Pretraining gave us a loss $\mathcal{L}_{\text{PT}}(\theta)$ for which $\nabla \mathcal{L}_{\text{PT}}(\theta^*) \approx 0$. Around $\theta^*$ we can write the second-order expansion

$$\mathcal{L}_{\text{PT}}(\theta) \approx \mathcal{L}_{\text{PT}}(\theta^*) + \tfrac{1}{2}(\theta - \theta^*)^\top F\,(\theta - \theta^*),$$

where $F$ is the curvature (Hessian) of the pretraining loss at the converged checkpoint; for a log-likelihood loss at a well-fit optimum it is well approximated by the Fisher information matrix, which is the name used from here on. Within this approximation, forgetting is a closed-form quantity: any displacement $\Delta\theta = \theta - \theta^*$ costs us $\tfrac{1}{2}\Delta\theta^\top F\,\Delta\theta$ in pretraining loss.
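A minimal NumPy sketch of that closed form, under a loud assumption: the "pretrained model" here is a 3-feature logistic regression fitted with Newton's method, chosen only because its Fisher has a closed form (for logistic regression the Fisher equals the Hessian of the mean negative log-likelihood). The helpers sigmoid, nll, fit_mle and fisher are illustrative names, not a library API. A small displacement from the fitted optimum should cost almost exactly ½ΔθᵀFΔθ.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def nll(w, X, y):
    """Mean negative log-likelihood of a logistic model (stand-in for L_PT)."""
    z = X @ w
    return np.mean(np.logaddexp(0.0, z) - y * z)

def fit_mle(X, y, iters=30):
    """Newton's method: converges to the 'pretrained' optimum theta*."""
    w = np.zeros(X.shape[1])
    for _ in range(iters):
        p = sigmoid(X @ w)
        grad = X.T @ (p - y) / len(y)
        hess = (X * (p * (1 - p))[:, None]).T @ X / len(y)
        w = w - np.linalg.solve(hess, grad)
    return w

def fisher(w, X):
    """Fisher information of the logistic model (equal to its NLL Hessian)."""
    p = sigmoid(X @ w)
    return (X * (p * (1 - p))[:, None]).T @ X / len(X)

rng = np.random.default_rng(0)
X = rng.normal(size=(5000, 3))
y = (rng.random(5000) < sigmoid(X @ np.array([1.0, -2.0, 0.5]))).astype(float)

theta_star = fit_mle(X, y)
F = fisher(theta_star, X)
delta = 0.02 * rng.normal(size=3)            # a small hypothetical "SFT" displacement

actual = nll(theta_star + delta, X, y) - nll(theta_star, X, y)
predicted = 0.5 * delta @ F @ delta
print(f"actual ΔL_PT = {actual:.3e}   quadratic prediction = {predicted:.3e}")
```

Shrinking delta tightens the agreement, because the neglected third-order term shrinks faster than the quadratic one; large displacements are where the quadratic picture, and every closed form built on it, starts to break down.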

SFT minimises a different objective, $\mathcal{L}_{\text{SFT}}(\theta)$: cross-entropy on instruction data. Its gradient at $\theta^*$ is generally nonzero, so the training trajectory accumulates a displacement $\Delta\theta$. The levers that act through the loss combine into the total objective

$$\mathcal{L}_{\text{total}}(\theta) = (1 - r)\,\mathcal{L}_{\text{SFT}}(\theta) + r\,\mathcal{L}_{\text{PT}}(\theta) + \beta\,\mathrm{KL}\!\big(\pi_\theta \,\big\|\, \pi_{\theta^*}\big) + \lambda\, \lVert \theta - \theta^* \rVert_F^2.$$

Each term is one lever. $r$ is the replay ratio: a fraction of every gradient minibatch is drawn from the pretraining distribution, so part of the gradient already points back toward $\theta^*$. $\beta$ is the KL anchor: it penalises distributional drift of the model's output policy away from the base. To second order the KL is $\tfrac{1}{2}\Delta\theta^\top F\,\Delta\theta$, where $F$ is the Fisher of the model's outputs on the inputs where the KL is measured; when those inputs resemble pretraining data, this is approximately the same quadratic as the PT loss, which is why a small KL coefficient can be a remarkably good proxy for "don't forget". The $\lambda$ term anchors in parameter space rather than function space: read $\lVert \theta - \theta^* \rVert_F^2$ as the Fisher-weighted norm $(\theta - \theta^*)^\top F\,(\theta - \theta^*)$ with a diagonal $F$ and it is EWC; replace $F$ by the identity and it is L2-SP.
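A quick numerical check of the second-order claim, assuming the simplest possible "model": the parameters are the logits of a single categorical distribution, a hypothetical stand-in for one LM output position. The Fisher of a softmax with respect to its logits is diag(p) − ppᵀ, and for a small drift the KL should match the Fisher quadratic closely.

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def kl(p, q):
    """KL(p || q) for two categorical distributions with full support."""
    return float(np.sum(p * (np.log(p) - np.log(q))))

rng = np.random.default_rng(1)
theta_star = rng.normal(size=8)                   # hypothetical base logits, 8-token vocab
p_base = softmax(theta_star)
F = np.diag(p_base) - np.outer(p_base, p_base)    # Fisher of softmax w.r.t. its logits

delta = 0.01 * rng.normal(size=8)                 # small drift of the logits
kl_exact = kl(softmax(theta_star + delta), p_base)
kl_quad = 0.5 * delta @ F @ delta
print(f"KL(π_θ ‖ π_θ*) = {kl_exact:.3e}   ½ΔᵀFΔ = {kl_quad:.3e}")
```

F annihilates the all-ones direction (shifting every logit equally leaves the softmax unchanged), which is a small reminder that the KL only "sees" parameter moves that change the output distribution.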

Two additional levers act outside the loss. LoRA restricts the update of each weight matrix $W \in \mathbb{R}^{d \times d}$ to low rank: $\Delta W = AB$ with $A \in \mathbb{R}^{d \times k}$, $B \in \mathbb{R}^{k \times d}$, $k \ll d$ (PEFT's LoraConfig calls this rank r; it is unrelated to the replay ratio $r$). This is a hard constraint, not a penalty: most directions in weight space are simply unreachable, although drift inside the low-rank subspace is still possible. The learning-rate budget bounds the size of the drift: for plain SGD over $T$ steps, the triangle inequality gives $\lVert \Delta\theta \rVert_2 \le \eta \sum_{t=1}^{T} \lVert g_t \rVert_2 \le \eta\, T\, G$, where $G$ bounds the per-step gradient norm (for example, the gradient-clipping threshold). A small $\eta$ and few steps therefore mean small total drift, whatever direction the gradient wants.
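Both outside-the-loss levers can be checked numerically. The sketch below assumes a single 64×64 weight matrix with rank k = 4 (illustrative sizes) and plain SGD with a hand-written global-norm clip standing in for the learning-rate budget; AdamW normalises each coordinate and obeys a different bound, so treat this as the SGD case only.

```python
import numpy as np

rng = np.random.default_rng(2)

# --- LoRA: a hard low-rank constraint on the update of one weight matrix ---
d, k = 64, 4                                   # illustrative hidden size and rank
A = rng.normal(size=(d, k))
B = rng.normal(size=(k, d))
delta_W = A @ B                                # every reachable update has this form
frac = (A.size + B.size) / (d * d)             # trainable share: 2k/d
print("rank(ΔW) =", np.linalg.matrix_rank(delta_W), "  trainable fraction =", frac)

# --- LR budget: T clipped SGD steps move theta at most eta * T * max_norm ---
def clip_by_norm(g, max_norm):
    """Rescale g so its L2 norm is at most max_norm (global-norm clipping)."""
    norm = np.linalg.norm(g)
    return g * min(1.0, max_norm / (norm + 1e-12))

eta, T, max_norm = 1e-3, 500, 1.0
theta0 = rng.normal(size=1000)
theta = theta0.copy()
for _ in range(T):
    g = rng.normal(scale=10.0, size=theta.shape)   # arbitrary large gradients
    theta -= eta * clip_by_norm(g, max_norm)
drift = np.linalg.norm(theta - theta0)
print(f"drift = {drift:.4f}   bound eta*T*max_norm = {eta * T * max_norm:.4f}")
```

With random gradient directions the measured drift lands near η·√T·G, far below the worst case η·T·G; the bound is approached only when successive steps point the same way, which is closer to how a consistent SFT objective pulls.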

A useful reframing. Catastrophic forgetting is a sample-efficiency problem dressed up as a stability problem. You have ~10⁻⁵ as many SFT tokens as PT tokens and you are letting them rewrite weights as freely as the trillion-token PT run did. The five levers are five ways of saying: don't let a tiny dataset overwrite a giant one.

The Mitigation Stack: Five Levers

In production, the five techniques below are stacked, not picked. Each one is cheap, each one is partial, and the combination is what reliably keeps forgetting under control.

| Lever | Mechanism | Typical setting | What it costs you | What it does not fix |
|---|---|---|---|---|
| Replay (data mixing) | Mix pretraining tokens into SFT batches | 10–20 % replay rate | Extra forward/backward cost, longer wall-clock | Format-level drift (the chat template is still learned the same way) |
| KL anchor (β) | Penalise KL(π‖π_base) per token | β ≈ 0.01–0.1 | Extra forward pass through the frozen base; over-regularises if β is too large | Weights still move; only the output distribution is anchored |
| LoRA / PEFT | Freeze the base; train low-rank adapters | rank 8–64, α/r ≈ 2 | Lower expressivity ceiling; slight inference cost unless merged | Can still forget within the low-rank channel if LR and step count are wrong |
| Learning-rate budget | Small LR + few epochs; cosine decay | lr ≈ 1–5e-6 (full FT), 1–5e-5 (LoRA) | Longer time to reach target instruction quality | Without gradient clipping, one bad batch can still spike drift |
| Capability evals as a gate | Block release on > X-point drop on MMLU/HumanEval/etc. | 1–3 point tolerance per benchmark | Eval compute; release latency | Evals miss what they don't measure (rare languages, niche capabilities) |

Treat the table as a checklist, not a menu. The single biggest mistake in SFT pipelines is to pick one lever ("we're using LoRA so we're safe") and skip the rest. LoRA without small LR still forgets within its low-rank subspace; replay without LR control still forgets on every non-replay step; KL anchoring without capability evals can hide drift on rare skills the KL term does not see. The stack works because the levers are partially independent — when one fails (a bad replay batch, a KL miscalibration), the others still pull.
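The checklist idea can be encoded directly. Below is a minimal, hypothetical config object (not part of any library) whose defaults sit inside the table's "typical setting" ranges; it only reports which levers are on and flags the "one lever only" mistake and a missing eval gate.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class MitigationStack:
    """Hypothetical config object. Defaults sit inside the 'typical setting'
    ranges of the mitigation table."""
    replay_rate: float = 0.15        # fraction of tokens drawn from PT data
    kl_beta: float = 0.05            # KL anchor strength
    lora_rank: Optional[int] = 16    # None means full fine-tuning
    lr: float = 2e-5
    eval_gate: bool = True           # block release on capability regressions

    def active_levers(self) -> List[str]:
        levers = []
        if self.replay_rate > 0:
            levers.append("replay")
        if self.kl_beta > 0:
            levers.append("kl_anchor")
        if self.lora_rank is not None:
            levers.append("lora")
        # The LR counts as a budget only inside the table's range for the
        # chosen mode: up to 5e-6 for full FT, up to 5e-5 for LoRA.
        lr_cap = 5e-5 if self.lora_rank is not None else 5e-6
        if self.lr <= lr_cap:
            levers.append("lr_budget")
        if self.eval_gate:
            levers.append("eval_gate")
        return levers

    def warnings(self) -> List[str]:
        active = self.active_levers()
        msgs = []
        if len(active) <= 1:
            msgs.append(f"only {active} enabled; the stack relies on several "
                        "partially independent levers")
        if not self.eval_gate:
            msgs.append("no capability eval gate; drift on unmeasured skills "
                        "can ship unnoticed")
        return msgs

print(MitigationStack().active_levers())
print(MitigationStack(replay_rate=0.0, kl_beta=0.0, lr=2e-4, eval_gate=False).warnings())
```

The second call models the "we're using LoRA so we're safe" configuration: with a LoRA-typical but out-of-range learning rate and no gate, only one lever is active and both warnings fire.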

Manual Numerical Walkthrough

We will run the two-task toy model from the Python section by hand and watch the forgetting bill change as we turn levers on. Parameters: $\theta \in \mathbb{R}^2$, base optimum $\theta^* = (1, 0)$, SFT optimum $\theta_{\text{SFT}} = (1, 2.5)$, Fisher diagonal $F = \text{diag}(4, 0.25)$, shared by both losses and by the KL term. The two tasks agree on dim 0 (both want $\theta_0 = 1$) and disagree on dim 1 by 2.5 units. Both losses are quadratic, so we can solve the four-regime sweep in closed form.

Solving the four regimes by hand

Step 1: write down the gradient. Combining the first three terms of the mitigation objective ($\lambda = 0$, with the KL replaced by its second-order Fisher quadratic) gives, per coordinate,

$$\nabla_i \mathcal{L}_{\text{total}} = F_i\big[(1-r)(\theta_i - \theta_{\text{SFT},i}) + r(\theta_i - \theta^*_i) + \beta(\theta_i - \theta^*_i)\big].$$

Step 2: solve for the stationary point. Setting $\nabla_i = 0$ and solving for $\theta_i$ gives the fixed point that gradient descent converges to,

$$\theta_i^{\text{fixed}} = \frac{(1-r)\,\theta_{\text{SFT},i} + (r + \beta)\,\theta^*_i}{1 + \beta}.$$

Notice that $F_i$ dropped out: in this toy the equilibrium is independent of curvature, because every term shares the same $F_i$. Curvature only sets how fast gradient descent gets there (large-$F_i$ coordinates converge in fewer steps), not where it ends up. If the SFT and PT losses had different curvatures, each pull would be weighted by its own curvature and the equilibrium would shift toward the tighter loss.
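A small check that $F_i$ sets the speed but not the destination, using the toy's dimension-1 numbers (θ* = 0, θ_SFT = 2.5, r = 0.1, β = 0.5) and a hypothetical range of curvatures. The sketch assumes lr·F_i·(1 + β) < 2 so that plain gradient descent is stable.

```python
def fixed_point(theta_sft, theta_star, r, beta):
    """Step 2 closed form for one coordinate."""
    return ((1 - r) * theta_sft + (r + beta) * theta_star) / (1 + beta)

def gd(F_i, theta_sft, theta_star, r, beta, lr=0.05, tol=1e-6, max_steps=100_000):
    """Gradient descent on one coordinate of the toy objective.
    Returns (endpoint, steps needed to get within tol of the fixed point)."""
    target = fixed_point(theta_sft, theta_star, r, beta)
    theta = theta_star                       # SFT starts from the base model
    for step in range(1, max_steps + 1):
        grad = F_i * ((1 - r) * (theta - theta_sft) + (r + beta) * (theta - theta_star))
        theta -= lr * grad
        if abs(theta - target) < tol:
            return theta, step
    return theta, max_steps

for F_i in (0.25, 1.0, 4.0):                 # hypothetical curvatures for dim 1
    end, steps = gd(F_i, theta_sft=2.5, theta_star=0.0, r=0.1, beta=0.5)
    print(f"F_i={F_i:<5} endpoint={end:.4f}  steps={steps}")
```

Every run ends at 1.5, the Step 2 fixed point; only the step count changes with curvature.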

Step 3: plug in the four regimes for dim 1. Dim 0 is uninteresting: both optima are at 1, so the fixed point is exactly 1 in every regime. All the action is on dim 1, where $\theta^*_1 = 0$ and $\theta_{\text{SFT},1} = 2.5$:

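| Regime | r | β | θ₁ at fixed point | ΔL_PT = ½·F₁·θ₁² | % forgetting |
|---|---|---|---|---|---|
| Naive SFT | 0 | 0 | 2.500 | 0.781 | 100 % |
| + 10 % replay | 0.10 | 0 | 2.250 | 0.633 | 81 % |
| + KL β = 0.5 | 0 | 0.5 | 1.667 | 0.347 | 44 % |
| Stack: replay + KL | 0.10 | 0.5 | 1.500 | 0.281 | 36 % |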

Step 4: read the numbers. Naive SFT lands at the pure SFT optimum and pays 0.781 of pretraining loss; that is our 100 % forgetting baseline. 10 % replay alone reduces forgetting by 19 %, KL alone by 56 %, and stacking both by 64 %. The reductions are not additive (19 + 56 ≠ 64): the levers compose multiplicatively. With $\theta^*_1 = 0$ the fixed point is $\theta_1 = \tfrac{1-r}{1+\beta}\,\theta_{\text{SFT},1}$, so the remaining forgetting is $\big(\tfrac{1-r}{1+\beta}\big)^2$ and the stacked value is the product of the single-lever values: 0.81 × 0.44 ≈ 0.36.
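The Step 3 table and the multiplicative composition can be reproduced in a few lines, using only the toy's dimension-1 numbers (F₁ = 0.25, θ*₁ = 0, θ_SFT,1 = 2.5) and the Step 2 fixed-point formula:

```python
F1, theta_star1, theta_sft1 = 0.25, 0.0, 2.5     # the toy's dimension-1 numbers

def regime(r, beta):
    """Fixed point (Step 2) and PT-loss damage on dimension 1."""
    theta1 = ((1 - r) * theta_sft1 + (r + beta) * theta_star1) / (1 + beta)
    return theta1, 0.5 * F1 * (theta1 - theta_star1) ** 2

baseline = regime(0.0, 0.0)[1]
fractions = {}
for name, r, beta in [("naive", 0.0, 0.0), ("replay", 0.1, 0.0),
                      ("kl", 0.0, 0.5), ("stack", 0.1, 0.5)]:
    theta1, dL = regime(r, beta)
    fractions[name] = dL / baseline
    print(f"{name:>6}: θ₁={theta1:.3f}  ΔL_PT={dL:.3f}  forgetting={fractions[name]:.0%}")

# With θ*₁ = 0 the forgetting fraction is ((1 - r) / (1 + beta))**2, so the
# stacked fraction is the product of the two single-lever fractions.
print(f"replay × KL = {fractions['replay'] * fractions['kl']:.4f}, "
      f"stack = {fractions['stack']:.4f}")
```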

Step 5: check the SFT cost. What did we give up? The instruction optimum was $\theta_{\text{SFT},1} = 2.5$; we landed at 1.50. The SFT loss at the stacked point is $\tfrac{1}{2} \cdot 0.25 \cdot (1.5 - 2.5)^2 = 0.125$. So we cut PT-loss damage from 0.781 to 0.281 (a 0.500 reduction) at a cost of 0.125 in SFT loss: a 4× favourable trade. The ratio does not depend on $F_1$ at all, because both losses share it and it cancels. The favourable trade comes from the shape of the bowls: at the SFT optimum the SFT gradient is zero, so the first steps back toward $\theta^*$ cost only second-order SFT loss while buying first-order PT savings. This generalises: retreating a little from the pure SFT optimum is always cheap, and the trade improves further in directions where the PT loss is tight and the SFT loss is loose. Those are also the directions where forgetting is most expensive, so the cost/benefit ratio works in your favour exactly when it matters.
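To see where the 4× comes from, the sketch below generalises the toy to separate PT and SFT curvatures along one coordinate. This is a hypothetical extension, not part of the original toy: it assumes the replay and KL terms both carry the PT curvature, and it measures PT savings relative to the naive-SFT endpoint.

```python
def tradeoff(F_pt, F_sft, r, beta, theta_star=0.0, theta_sft=2.5):
    """Equilibrium and (PT savings / SFT cost) on one coordinate when the PT
    and SFT losses have their own curvatures (the toy has F_pt == F_sft)."""
    # Stationary point of (1-r)*F_sft*(θ-θ_sft) + (r+beta)*F_pt*(θ-θ*) = 0
    w_sft = (1 - r) * F_sft
    w_pt = (r + beta) * F_pt
    theta = (w_sft * theta_sft + w_pt * theta_star) / (w_sft + w_pt)
    pt_saved = 0.5 * F_pt * ((theta_sft - theta_star) ** 2 - (theta - theta_star) ** 2)
    sft_cost = 0.5 * F_sft * (theta - theta_sft) ** 2
    return theta, pt_saved / sft_cost

for F_pt, F_sft in [(0.25, 0.25), (4.0, 4.0), (4.0, 0.25), (0.25, 4.0)]:
    theta, ratio = tradeoff(F_pt, F_sft, r=0.1, beta=0.5)
    print(f"F_pt={F_pt:<5} F_sft={F_sft:<5} θ={theta:.3f}  savings/cost={ratio:.2f}")
```

With shared curvature the ratio is 4 for any value of F₁; making PT tighter than SFT raises it, and the reverse lowers it.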

The lesson. Forgetting can be quantified, traded against SFT performance, and reduced to a closed form on the toy model. The same machinery, scaled to billions of parameters and with the Fisher diagonal estimated empirically (or approximated by KL on the output distribution), drives every production SFT recipe.

Visualizing Capability Trajectories

The chart below simulates what happens to four canonical capabilities during SFT under different lever settings. The green curve is instruction-following, the capability we are trying to teach. The other three are pretraining capabilities the model is supposed to keep: general knowledge (MMLU), math reasoning (GSM8K), and code (HumanEval). The starting values are the post-pretraining scores of a typical 7B model; the shapes are calibrated to match the magnitudes reported in the SFT mitigation literature (LoRA paper, InstructGPT appendix, LIMA paper).

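[Interactive chart: simulated capability trajectories during SFT (instruction-following, MMLU, GSM8K, HumanEval), with sliders for α, replay r, and KL β]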

Three slider sweeps to try:

  • Start with all controls at zero, then drag α toward 1. You are watching a naive full fine-tune. Instruction-following climbs from 5 % to 95 %; HumanEval falls from 41 % to 8 %, GSM8K from 48 % to 12 %. This is what a paper means by "catastrophic forgetting" — the orange and purple lines simply collapse.
  • Now drag replay r up to 0.2. The orange and purple curves snap back toward their baselines; instruction-following keeps almost the same shape because we only spent 20 % of our gradient budget on replay. This one lever recovers most of the drop and is the largest single line item in most production SFT recipes.
  • Set α down to 0.15 (LoRA-rank regime), r = 0.1, β = 0.05. This is the modern frontier-LM SFT stack. Every capability stays within 2–3 points of its baseline, instruction-following plateaus at ~78 %, and the parameter footprint of the run is ~0.5 % of the base model. This is the picture a paper from 2024–2025 shows when it claims "no measurable forgetting".

Note that the chart is a smoothed model: real curves are noisier and have benchmark-specific shocks (a single bad SFT batch can drop HumanEval 5 points before recovering). But the directions of all three sweeps match what real teams see when they ablate replay, adapters, and KL anchoring.

Plain Python: A Two-Task Toy Model

Before invoking PyTorch and a billion-parameter LM, we reproduce the closed-form math in about 35 lines of NumPy. The two-task quadratic model captures the whole forgetting story in miniature, and watching the trajectory under the four (r, β) settings is the cleanest possible illustration of how levers compose.

sft_forgetting_toy.py
A two-task universe

Real SFT involves billions of parameters and millions of examples; we strip that to two dimensions and two tasks. The base model lives at w_pt = (1, 0). Instruction-following lives at w_sft = (1, 2.5). They agree on dimension 0 (call it 'language modelling') and disagree on dimension 1 ('chat formatting'). This is the simplest setup that makes forgetting visible.

Per-dimension curvature

F is the Fisher-information diagonal: how steeply the pretraining loss curves up around each weight. F[0] = 4.0 means dimension 0 is a deep, narrow valley (small moves cost a lot of PT loss). F[1] = 0.25 means dimension 1 is shallow. EWC and related Bayesian, Fisher-weighted regularisers use this kind of curvature to decide which weights are 'safe to move'; L2-SP, by contrast, penalises every weight's displacement equally.

Execution state:
F = [4.0, 0.25]

Two quadratic loss bowls

loss_pt and loss_sft are halved, F-weighted squared L2 distances from the two optima. They are surrogates for cross-entropy on the two tasks. The key feature: they are not the same function, so minimising one pulls you out of the other.

The mixed objective

Three forces act on w during SFT. (1) g_sft pulls toward the SFT optimum — the gradient of the new task. (2) g_pt is replay: a fraction r of each minibatch is sampled from the pretraining distribution, so its gradient is also there. (3) g_kl is the KL anchor — it penalises moves away from the base. With diagonal Fisher, KL becomes the same quadratic in (w - w_pt) weighted by F.

Execution state:
r = replay ratio (0 to 1)
beta = KL anchor strength

Initial weights = base model

SFT starts from a checkpoint, not from random init. w begins at w_pt — the model already knows how to language-model. Every gradient step from now on is moving away from that initial knowledge in some direction. The question is which direction, how far, and how reversibly.

Plain SGD, no fancy optimiser

We use vanilla gradient descent with a small learning rate so the trajectory is easy to read. In production this would be AdamW with cosine decay; the qualitative picture (drift toward w_sft, slowed by replay and β) is the same — Adam just gets there faster on some dimensions and slower on others depending on per-parameter variance.

Sweep four regimes

(0.0, 0.0) is naive full SFT — no replay, no anchor. (0.1, 0.0) adds 10% replay tokens. (0.0, 0.5) uses no replay but a moderate KL penalty. (0.1, 0.5) stacks both. Compare L_pt at the end — that number is exactly catastrophic forgetting in numerical form.

What the print actually shows

Typical output after 400 steps at lr = 0.05: naive SFT → w ≈ (1.0, 2.48), L_pt ≈ 0.77. The model nearly nailed the SFT task and paid a 0.77 PT-loss bill. Add 10 % replay → w ≈ (1.0, 2.24), L_pt ≈ 0.62: replay holds dimension 1 back. Add β = 0.5 with no replay → w ≈ (1.0, 1.67), L_pt ≈ 0.35. β with replay → w ≈ (1.0, 1.50), L_pt ≈ 0.28. Anchoring is strong on this toy because the diagonal-Fisher KL is exactly aligned with the cost we're measuring.

```python
import numpy as np

# A 2-parameter toy model with two "tasks":
#   - PT: the pretraining task (we already converged on it).
#   - SFT: the new instruction-tuning task.
# Each task has its own quadratic loss; the minima are 2.5 units apart.

w_pt   = np.array([1.0, 0.0])     # base-model parameters (PT optimum)
w_sft  = np.array([1.0, 2.5])     # SFT optimum (different in dim 1)

# Per-dim curvature (Fisher diagonal). Big curvature = "this weight matters".
F = np.array([4.0, 0.25])

# Quadratic losses centered at each optimum:
def loss_pt(w):  return 0.5 * np.sum(F * (w - w_pt) ** 2)
def loss_sft(w): return 0.5 * np.sum(F * (w - w_sft) ** 2)

# Mixed objective: (1 - r) on SFT, r on replay of pretraining.
# β anchors us back toward w_pt: this is the EWC / KL penalty.
def grad(w, r, beta):
    g_sft = F * (w - w_sft)
    g_pt  = F * (w - w_pt)
    g_kl  = F * (w - w_pt)                 # diagonal-Fisher KL gradient
    return (1 - r) * g_sft + r * g_pt + beta * g_kl

# Walk SGD from the base-model init.
def run(r, beta, lr=0.05, steps=400):
    w = w_pt.copy()
    for _ in range(steps):
        w = w - lr * grad(w, r, beta)
    return w, loss_pt(w), loss_sft(w)

for r, beta in [(0.0, 0.0), (0.1, 0.0), (0.0, 0.5), (0.1, 0.5)]:
    w, L_pt, L_sft = run(r, beta)
    print(f"r={r:.2f}  β={beta:.2f}  →  w={w.round(3)}   "
          f"L_pt={L_pt:.3f}   L_sft={L_sft:.3f}")
```

After 400 steps the script's w₁ sits within about 1 % of the closed-form fixed points from the numerical walkthrough table (the two β = 0.5 runs are within 0.1 %), and it converges to them as steps grows, because the remaining gap shrinks by a factor of 1 − lr·F₁·(1 + β) per step. The print at the bottom is the smallest possible ablation table for an SFT mitigation experiment: change one knob, observe both losses, keep the Pareto-best.

PyTorch: SFT with Replay, LoRA, and KL Anchor

The toy model collapses into one line per lever; a real SFT loop needs a block of code per lever. Below is a minimal loop that combines three of them: LoRA from peft, replay via a second DataLoader, and an on-the-fly KL anchor against the frozen base. The optimiser and gradient-clipping lines follow the usual pattern you would carry into a trl SFTTrainer subclass or a similar custom trainer.

sft_with_mitigations.py
One reusable SFT step

The signature has the shape of a typical custom training step in trl-, NeMo-, or Megatron-LM-based SFT pipelines: a batch of instruction examples, an optional replay batch, and a KL coefficient. The mitigations compose: replay and the KL anchor switch off when their arguments are None or 0.0, while LoRA is a property of how the model is wrapped rather than of the step function.

Standard causal-LM loss

Hugging Face's labels=input_ids trick: the model shifts labels by one inside and computes mean cross-entropy on next-token prediction over the whole sequence (masked to the assistant turn in practice via -100 labels on the prompt). out.loss is a scalar — the SFT objective.

Execution state:
out.loss = scalar: mean CE over all labelled positions (only the assistant tokens once the prompt is masked with -100)
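The -100 masking can be sketched in NumPy. This assumes a single sequence laid out as prompt, assistant reply, then right padding (a simplification; real chat templates interleave several turns, and real pipelines usually use the attention mask rather than a pad id to find padding). It mirrors the shift-by-one convention that Hugging Face causal-LM models apply internally; the helper names are illustrative.

```python
import numpy as np

IGNORE_INDEX = -100   # label value that Hugging Face causal-LM losses skip

def build_labels(input_ids, prompt_len, pad_id):
    """Copy input_ids into labels, masking the prompt and the padding.
    Assumes one sequence laid out as [prompt | assistant reply | right padding]."""
    ids = np.asarray(input_ids, dtype=np.int64)
    labels = ids.copy()
    labels[:prompt_len] = IGNORE_INDEX
    labels[ids == pad_id] = IGNORE_INDEX
    return labels

def shifted_ce(logits, labels):
    """Mean next-token cross-entropy over unmasked targets:
    the logits at position t are scored against labels[t + 1]."""
    logits, targets = logits[:-1], labels[1:]
    keep = targets != IGNORE_INDEX
    m = logits.max(-1, keepdims=True)
    log_z = (m + np.log(np.exp(logits - m).sum(-1, keepdims=True)))[:, 0]
    picked = logits[np.arange(len(targets)), np.where(keep, targets, 0)]
    return float(-(picked - log_z)[keep].mean())

labels = build_labels([1, 2, 3, 4, 0, 0], prompt_len=2, pad_id=0)
print(labels)                                   # prompt and padding masked
print(shifted_ce(np.zeros((6, 5)), labels))     # uniform logits over 5 tokens
```

With uniform logits every unmasked target costs log 5, so the mean loss is log 5 however many positions are masked; masking changes which tokens are averaged, not the per-token cost.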
Replay: the cheapest forgetting fix

We take a random batch of plain pretraining text from the same distribution the base model saw and add its loss with weight 0.5. The model is forced to keep predicting next tokens on Wikipedia, code, and books while it learns chat formatting. Empirically a 10–20 % replay rate is enough to recover most of the MMLU / HumanEval drop without slowing SFT convergence.

The KL anchor, computed on the fly

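We run the frozen base model on the same batch under no_grad and compare the token-by-token output distributions. The KL term pulls log_p_now back toward log_p_base wherever they diverge. With a LoRA model the base is the same network with its adapters switched off, so the anchor can cost one gradient-free forward pass rather than a second copy of the weights. β = 0.05 is a conventional starting point, in the same small range as the KL coefficients used in RLHF pipelines such as InstructGPT; β = 0 disables the anchor, and much larger values (β > 0.5 as a rough guide) over-regularise until the model stops learning the task.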

Execution state:
log_p_base = (B, T, V): frozen-base log-probabilities
kl = scalar: per-token KL(p_now ‖ p_base), averaged over tokens

LoRA, the architectural anchor

LoRA freezes the entire base model and inserts low-rank adapters A, B into the attention projections, so the effective weight is W₀ + (α/r)·AB with rank r = 16 and α = 32. Only A and B receive gradients. Because W₀ never changes, the base weights are preserved exactly and removing the adapters recovers the base model; forgetting can only happen through the new low-rank channel. That channel is small relative to the base, but, as the mitigation table warns, it can still do damage if the learning rate or step count is wrong.

Execution state:
r = 16 (rank)
trainable params = well under 1 % of the full model (about 0.25 % for a Llama-2-7B-shaped model with these four target modules)

Smaller learning rate than pretraining

2e-5 is roughly an order of magnitude below typical pretraining peak learning rates (around 3e-4 for 7B-class models). SFT is short (thousands of steps rather than trillions of tokens) and starts from a converged checkpoint, so big steps carry forgetting risk with no convergence benefit. Together with weight decay 0.01, this is a standard 'don't break the base' starting point.

Interleaved replay: one replay batch every five steps

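Rather than mixing replay tokens into every batch, we attach one replay batch to every 5th step. This is operationally simple (the two DataLoaders run independently). Counting sequences, that is one replay batch per five SFT batches, about 17 % of the data seen; because sft_step weights the two losses 0.5 / 0.5 on those steps, replay carries about 10 % of the total loss weight. Both schemes work; per-batch mixing has lower variance, interleaving is easier to instrument.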
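The replay shares for the interleaved schedule follow from simple counting. A small sketch, assuming equal-length sequences and equal batch sizes for the two loaders (8 each, as in the loop) and that each step's combined loss counts equally:

```python
def replay_share(steps=1000, every=5, sft_bs=8, replay_bs=8, replay_weight=0.5):
    """Fraction of sequences and of loss weight that come from replay when one
    replay batch joins every `every`-th SFT step and the two losses are mixed
    (1 - replay_weight) : replay_weight on those steps."""
    sft_seqs = replay_seqs = 0
    sft_w = replay_w = 0.0
    for step in range(steps):
        sft_seqs += sft_bs
        if step % every == 0:
            replay_seqs += replay_bs
            sft_w += 1.0 - replay_weight
            replay_w += replay_weight
        else:
            sft_w += 1.0
    return replay_seqs / (sft_seqs + replay_seqs), replay_w / (sft_w + replay_w)

seq_share, weight_share = replay_share()
print(f"replay share of sequences: {seq_share:.1%}, of loss weight: {weight_share:.1%}")
```

The two numbers differ because the mixed loss is formed before backward: a replay batch that shares a step with an SFT batch receives only half of that step's weight.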

Gradient clipping is the safety net

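Even with all three mitigations, one bad batch can spike a gradient and drag the model into a region it cannot easily recover from. clip_grad_norm_(..., 1.0) caps the gradient norm: under plain SGD that directly bounds the per-step parameter drift, and under AdamW it mainly stops one outlier batch from dominating the update and the moment estimates. It does not fix forgetting; it limits how much you can forget per step, which gives the rest of the stack room to do its job.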

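```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from peft import LoraConfig, get_peft_model        # Hugging Face PEFT

# Assumed to exist (placeholders, as in the original sketch): base_model, a
# Hugging Face causal LM; sft_data / replay_data, datasets of tokenized dicts
# with "input_ids" and "attention_mask" only; NUM_STEPS.

def sft_step(model, batch, *, replay_batch=None, kl_beta=0.0):
    """Loss for one SFT update with optional replay and KL anchor.

    `model` is a PeftModel: the frozen base is the same network with the LoRA
    adapters switched off, so no second copy of the weights is needed.
    """
    # --- 1. Loss on the instruction batch -----------------------
    # labels=input_ids: the model shifts labels internally. In practice, set
    # prompt and padding positions to -100 so only assistant tokens count.
    out  = model(**batch, labels=batch["input_ids"])
    loss = out.loss

    # --- 2. Mix in a replay batch of pretraining text -----------
    if replay_batch is not None:
        out_r = model(**replay_batch, labels=replay_batch["input_ids"])
        loss  = 0.5 * loss + 0.5 * out_r.loss        # equal weight on replay steps

    # --- 3. KL anchor to the frozen base model ------------------
    if kl_beta > 0.0:
        # Base-model distribution: same weights, adapters disabled, no grad.
        with torch.no_grad(), model.disable_adapter():
            log_p_base = F.log_softmax(model(**batch).logits, dim=-1)
        log_p_now = F.log_softmax(out.logits, dim=-1)
        # Per-token KL(p_now || p_base), averaged over non-padding positions.
        kl_tok = (log_p_now.exp() * (log_p_now - log_p_base)).sum(-1)   # (B, T)
        mask   = batch["attention_mask"].to(kl_tok.dtype)
        kl     = (kl_tok * mask).sum() / mask.sum()
        loss   = loss + kl_beta * kl

    return loss

def cycle(loader):
    """Yield batches forever, reshuffling on every pass over the data."""
    while True:
        yield from loader

# --- Wrap the base model with LoRA on every attention projection -
lora_cfg = LoraConfig(
    r=16, lora_alpha=32,                         # update scale lora_alpha / r = 2
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
)
peft_model = get_peft_model(base_model, lora_cfg)   # injects adapters in place
peft_model.print_trainable_parameters()             # well under 1 % of params

trainable = [p for p in peft_model.parameters() if p.requires_grad]
opt   = torch.optim.AdamW(trainable, lr=2e-5, weight_decay=0.01)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=NUM_STEPS)

# --- Training loop with interleaved replay ----------------------
sft_iter    = cycle(DataLoader(sft_data,    batch_size=8, shuffle=True))
replay_iter = cycle(DataLoader(replay_data, batch_size=8, shuffle=True))

peft_model.train()
for step in range(NUM_STEPS):
    batch = next(sft_iter)
    # One replay batch alongside every 5th SFT batch: ~17 % of sequences,
    # ~10 % of loss weight given the 0.5 / 0.5 mix inside sft_step.
    replay_batch = next(replay_iter) if step % 5 == 0 else None
    loss = sft_step(peft_model, batch, replay_batch=replay_batch, kl_beta=0.05)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(trainable, 1.0)
    opt.step()
    sched.step()
    opt.zero_grad()
```
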
What scales and what doesn't. The toy model and the production loop differ in size, optimiser, and tokenizer, but they share the same structure: one loss term or mechanism per lever, with one knob per lever in the hyperparameter sweep. Every complication on top of this (mixed precision, FSDP sharding, sequence packing, masked loss on assistant turns only) is orthogonal to forgetting. Get the five-lever skeleton right first; bolt on the systems engineering second.

What Changes at 671B Parameters

Three things change uncomfortably when you scale the recipe from a 7B model to a frontier 671B-parameter MoE.

| Quantity | 7B SFT | 671B-MoE SFT | Why it matters |
|---|---|---|---|
| SFT dataset size | ~50k–500k examples | 1–5M curated examples | More lever-tuning is justified because the eval surface is larger |
| Active params per token | 7B | ~37B (MoE routing) | Forgetting can concentrate in the routed experts the SFT data selects; attention and shared experts still see every token |
| Replay tokens needed | 5–10B tokens | 50–200B tokens | You cannot afford to redo replay if you set the rate wrong |
| Full FT cost | Tractable on 8 GPUs | Thousands of GPU-hours | LoRA / PEFT becomes much more attractive, though not strictly mandatory for teams with pretraining-scale compute |
| KL forward pass | Cheap (fits on the same GPUs) | Another sharded model in memory | Many teams approximate KL with cached base-model logits |

1. MoE makes forgetting localise — and that helps

In a dense model, every weight contributes to every token, so an SFT batch updates everything. In an MoE model like DeepSeek-V3, each routed expert receives gradient only from the tokens routed to it (attention layers, shared experts, and the router itself still see every token). SFT data has its own routing distribution (instruction-style prompts can route to different experts than encyclopaedia text), so forgetting can concentrate in the instruction-routed experts and barely touch the rest. This is good news: a large share of the model can be effectively frozen by the routing pattern rather than by an explicit adapter. It is also bad news: if your SFT data is narrow, the experts it does touch can be very thoroughly overwritten, and you may not see the damage until production traffic routes a different topic into them.

2. Replay datasets must match pretraining mix proportionally

A 7B model that was 80 % English / 15 % code / 5 % math at pretraining cannot safely replay 100 % English tokens during SFT — the code and math channels will still drift. Production replay datasets are explicitly mixed to match the original PT proportions, usually at the same coarse-domain granularity used during pretraining. This is why teams keep a frozen snapshot of the original PT data mixer for the lifetime of a model: every downstream SFT or RLHF stage uses it for replay.
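A minimal sketch of PT-matched replay, using the illustrative 80 / 15 / 5 mix from the paragraph above and a hypothetical 5B-token replay budget (the low end of the 7B row in the table); allocate_replay and sample_replay_domains are made-up helper names. Real mixers work at finer granularity than three domains.

```python
import numpy as np

def allocate_replay(pt_mix, replay_tokens):
    """Split a replay token budget across domains in PT proportions."""
    total = sum(pt_mix.values())
    return {dom: replay_tokens * share / total for dom, share in pt_mix.items()}

def sample_replay_domains(pt_mix, n, seed=0):
    """Draw a domain label for each of n replay sequences with PT-matched odds."""
    rng = np.random.default_rng(seed)
    domains = list(pt_mix)
    probs = np.array([pt_mix[d] for d in domains], dtype=float)
    return rng.choice(domains, size=n, p=probs / probs.sum())

pt_mix = {"english": 0.80, "code": 0.15, "math": 0.05}   # mix from the text
budget = allocate_replay(pt_mix, replay_tokens=5e9)       # hypothetical budget
draws = sample_replay_domains(pt_mix, n=10_000)
for dom in pt_mix:
    print(f"{dom:>8}: {budget[dom] / 1e9:.2f}B tokens, "
          f"sampled share {np.mean(draws == dom):.3f}")
```

Either route (a fixed per-domain budget, or per-sequence sampling) keeps the code and math channels in the replay stream at their pretraining weight instead of letting English dominate.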

3. The KL anchor needs a frozen base in memory

A naive KL implementation runs the frozen base on every batch, which means keeping a second copy of the weights in memory and paying one extra forward pass per step (roughly a third to a half more FLOPs, depending on how much of the backward pass is skipped for frozen weights). Frontier teams cut this cost with two tricks: cached base logits (pre-compute the base model's top-k token ids and log-probabilities for every training token once, so each step reads under a kilobyte per token at k = 128 instead of running the base model), and top-k KL approximations (compute the KL only over the top-k tokens of the base distribution, e.g. k = 128, which for a confident LM typically covers nearly all of the probability mass). The savings are large enough that production SFT loops commonly use one or both.
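One way to combine both tricks, sketched in NumPy under stated assumptions: the cache holds the base model's top-k token ids and log-probabilities per position, and the probability mass outside the top k is folded into a single "tail" bucket for each distribution. The tail-bucket treatment and all helper names are this sketch's choices, not a description of any particular team's recipe. Merging tokens into one bucket can only lower the KL (the data-processing inequality), so the estimate errs toward under-counting drift.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(-1, keepdims=True)
    return z - np.log(np.exp(z).sum(-1, keepdims=True))

def build_cache(base_logits, k):
    """What would be precomputed once per training token: the base model's
    top-k token ids and their log-probabilities."""
    idx = np.argsort(-base_logits, axis=-1)[..., :k]
    logp = np.take_along_axis(log_softmax(base_logits), idx, axis=-1)
    return idx, logp

def topk_kl(logits_now, base_idx, base_logp):
    """Approximate per-token KL(p_now || p_base) from the cached top-k entries,
    with each distribution's remaining mass merged into one tail bucket."""
    lp_now = np.take_along_axis(log_softmax(logits_now), base_idx, axis=-1)
    p_now, p_base = np.exp(lp_now), np.exp(base_logp)
    tail_now = np.clip(1.0 - p_now.sum(-1), 1e-12, None)
    tail_base = np.clip(1.0 - p_base.sum(-1), 1e-12, None)
    head = (p_now * (lp_now - base_logp)).sum(-1)
    tail = tail_now * (np.log(tail_now) - np.log(tail_base))
    return head + tail

def cache_bytes(n_tokens, k, value_bytes=2, index_bytes=4):
    """Storage for k (fp16 log-prob, int32 token id) pairs per token."""
    return n_tokens * k * (value_bytes + index_bytes)

rng = np.random.default_rng(0)
now = rng.normal(size=(4, 50))                        # 4 positions, 50-token toy vocab
base = now + 0.5 * rng.normal(size=(4, 50))
lp_now_full, lp_base_full = log_softmax(now), log_softmax(base)
exact = (np.exp(lp_now_full) * (lp_now_full - lp_base_full)).sum(-1)
approx = topk_kl(now, *build_cache(base, k=8))
print("exact KL:", exact.round(4))
print("top-8 KL:", approx.round(4))
print("cache bytes per token at k=128:", cache_bytes(1, 128))
```

When k covers the whole vocabulary the tail buckets vanish and the estimate equals the exact KL; shrinking k trades accuracy for cache size.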

Engineering Reality: How Teams Actually Catch It

The single most expensive way to discover catastrophic forgetting is in production. The second most expensive is during the final pre-release eval. Mature SFT pipelines push the detection earlier and earlier in the loop. The pattern below is what frontier teams converge on after a few incidents.

  • Per-step capability probes. Every 100–500 steps, run a 200-example subset of MMLU / HumanEval / GSM8K through the current checkpoint. Plot the curve next to the training loss. If any capability curve dips by more than 1.5 % per 1000 steps, page the on-call SFT engineer. This catches misconfigurations (replay loader serving the wrong distribution, LoRA rank zero by accident) within an hour rather than at end of run.
  • Held-out PT-loss canary. Keep a 1M-token sample of pretraining data that is never used for training. Compute its cross-entropy under the SFT checkpoint every N steps. A rising curve is forgetting in raw form, before any benchmark translates it into a percentage. This is the single most sensitive forgetting metric and the cheapest to compute.
  • Capability evals as a release gate. Block any SFT checkpoint that drops more than 2 points on MMLU, 3 on HumanEval, 3 on GSM8K, or 1 on a multilingual suite relative to the base model, even if the instruction quality is excellent. The implementation cost of this gate is small; the cost of skipping it is occasionally a public regression.
  • Adversarial forgetting probes. Maintain a small set of rare-capability prompts (Hungarian poetry, obscure programming languages, niche academic terminology) that public benchmarks do not cover. Catastrophic forgetting on these is what users notice first because they are the long tail of capability.
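The release gate and the PT-loss canary are simple enough to sketch directly. The tolerances below are the ones quoted for the release gate; the canary's window and threshold, the example scores, and all function names are illustrative assumptions.

```python
# Release-gate tolerances in benchmark points, as quoted for the release gate.
GATE_TOLERANCE = {"mmlu": 2.0, "humaneval": 3.0, "gsm8k": 3.0, "multilingual": 1.0}

def release_gate(base_scores, ckpt_scores, tolerance=GATE_TOLERANCE):
    """Return (benchmark, drop) pairs whose drop from the base exceeds its
    tolerance. An empty list means the checkpoint passes the gate."""
    failures = []
    for bench, tol in tolerance.items():
        drop = base_scores[bench] - ckpt_scores[bench]
        if drop > tol:
            failures.append((bench, round(drop, 2)))
    return failures

def canary_rising(pt_losses, window=3, min_rise=0.01):
    """Flag forgetting when the held-out PT loss has risen by more than
    `min_rise` nats over the last `window` evaluations (illustrative values)."""
    if len(pt_losses) <= window:
        return False
    return pt_losses[-1] - pt_losses[-1 - window] > min_rise

# Hypothetical scores for a base model and an SFT checkpoint.
base = {"mmlu": 60.0, "humaneval": 41.0, "gsm8k": 48.0, "multilingual": 50.0}
ckpt = {"mmlu": 57.5, "humaneval": 40.0, "gsm8k": 47.0, "multilingual": 49.5}
print(release_gate(base, ckpt))
print(canary_rising([2.10, 2.10, 2.11, 2.13, 2.16]))
```

The canary needs no benchmark at all: it watches raw cross-entropy on held-out pretraining text, which is why it tends to move before any benchmark score does.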

Done well, the eval scaffolding can roughly double the wall-clock cost of an SFT run while reducing the incidence of full-run rollbacks by an order of magnitude. The trade is easy to justify: a 2× SFT-cost increase is a tiny line item next to the cost of a 671B-parameter pretraining run whose capabilities you accidentally erased in post-training.

The single sentence to remember. Catastrophic forgetting is gradient descent on a new objective freely overwriting the converged minimum of an old objective; every mitigation in this section is a way to constrain that overwriting — through data, through loss terms, through parameter subspaces, through learning-rate budgets, and through eval gates that block release when the constraints fail.