The Real Problem: A Base Model Has No Format
After pretraining, a 70-billion-parameter base model has absorbed roughly fifteen trillion tokens of text. It knows physics, knows Python, can finish a sonnet. What it cannot do is answer a question. Give it the prompt “What is 2+2?” and it is just as likely to continue with “is a common arithmetic question asked of children in kindergarten across…” as to say “4”. The base model is a next-token oracle; it has no concept of a conversation, no notion of stopping when a turn is complete, no idea that “assistant” is its role.
Supervised fine-tuning solves exactly this gap. Section 13.1 told us what SFT does conceptually. Section 13.2 built the dataset. Section 13.3 turned every example into a chat-templated token sequence with role markers. This section is where the rubber meets the road: we write the actual training configuration — the loss, the optimiser, the schedule, the batch construction — that turns a base model into a chat model in a few thousand gradient steps.
The challenge is that the loss looks exactly like pretraining cross-entropy at first glance, but every detail has been re-engineered for a different goal. Pretraining wants to predict every token equally well across an internet-shaped distribution. SFT wants to predict only the assistant's response across a carefully curated handful of demonstrations, without destroying any of the knowledge that took ten million GPU-hours to install. Get the loss masking wrong and the model learns to generate the user's message back to you. Get the learning rate wrong and you erase weeks of pretraining in the first thousand steps. Get the packing wrong and you spend half your training budget on padding cells. The configuration is not optional decoration — it is the algorithm.
Intuition: SFT Is Continued Pretraining With a Sharp Loss
Hold the picture of pretraining in mind: a gigantic shuffled token stream, every token contributing to the loss, learning rate hovering around $3 \times 10^{-4}$, the optimiser carving slow rivers across the loss surface for weeks. Now contrast SFT: a few tens of thousands of carefully formatted chat sequences, most tokens masked out of the loss, learning rate an order of magnitude smaller at $10^{-5}$ to $5 \times 10^{-5}$, a run that finishes overnight. Same architecture. Same loss function. Same optimiser. Different everything else.
The most useful analogy is a surgeon coming back to a patient after a long surgery. The surgery (pretraining) installed the organ. The follow-up (SFT) sutures and bandages. You use the same instruments, but you use them gently, in a small region, and you take care not to re-open anything that has already healed. A surgeon who scrubs in for SFT with a pretraining-sized learning rate is going to tear the sutures back out. This is not a metaphor — it is exactly what catastrophic forgetting (the topic of § 13.5) looks like in the loss curve.
The four design principles of SFT configuration
Loss Masking: The Mathematical Heart of SFT
The pretraining loss is the per-token cross-entropy averaged over every position in the sequence. Writing it out, for a sequence of length $T$ with input tokens $x_0, \dots, x_{T-1}$, where $y_t$ is the label at position $t$ (the token that follows $x_t$):
$$\mathcal{L}_{\text{pretrain}} = -\frac{1}{T} \sum_{t=0}^{T-1} \log p_\theta(y_t \mid x_{\le t})$$
Every term in the sum is real: the model is asked to predict every next token, and every position contributes equally to the gradient. For SFT we change exactly one thing: we introduce a binary mask $m_t \in \{0, 1\}$ that selects which positions count:
$$\mathcal{L}_{\text{SFT}} = -\frac{\sum_{t=0}^{T-1} m_t \, \log p_\theta(y_t \mid x_{\le t})}{\sum_{t=0}^{T-1} m_t}$$
Read this carefully. The numerator only sums over positions where $m_t = 1$; by convention these are the positions whose label is an assistant-generated token. The denominator divides by the number of kept positions $\sum_t m_t$, not by $T$. That denominator choice matters: if you divided by $T$, longer prompts with shorter responses would have artificially smaller per-token loss, and the gradient signal per example would shrink with prompt length. Dividing by $\sum_t m_t$ makes every example contribute its average response loss to the batch, regardless of prompt length.
In code the mask is encoded indirectly. PyTorch's `F.cross_entropy` has an `ignore_index` argument: any label equal to `ignore_index` (conventionally -100) is treated as $m_t = 0$ and contributes zero to both the numerator and the denominator. So an SFT data collator does one job per example: copy `input_ids` into `labels`, then overwrite every non-assistant position with -100. The trainer never knows the mask is there.
There are three masking modes to compare. The first (“train on all”) is the naive approach a tutorial uses on day one; it produces a model that completes both halves of a chat. The second (“completion-only”) is the production default, used by every modern open-weights chat model. The third (“completion + EOT”) is a small but important refinement: include the end-of-turn marker in the loss so the model also learns when to stop. Without the EOT marker in the loss, the model keeps going past its turn and emits user-role tokens, like a play with no curtain.
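The three masking modes can be sketched as a small label-building function. This is a minimal illustration, not any trainer's actual collator: the per-token `roles` tags, the `build_labels` name, and the token ids are all hypothetical, and it assumes the trainer shifts labels by one position internally (as Hugging Face causal LM heads do).

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross_entropy skips by default

def build_labels(input_ids, roles, mode="completion+eot"):
    """Copy input_ids into labels, then blank out positions that must not train.

    roles[i] is a hypothetical per-token tag: "control", "user", "assistant",
    or "eot" (the end-of-turn marker that closes an assistant turn).
    """
    if mode == "all":
        keep = {"control", "user", "assistant", "eot"}
    elif mode == "completion":
        keep = {"assistant"}
    elif mode == "completion+eot":
        keep = {"assistant", "eot"}
    else:
        raise ValueError(f"unknown mode: {mode}")
    if len(input_ids) != len(roles):
        raise ValueError("input_ids and roles must have the same length")
    return [tok if role in keep else IGNORE_INDEX
            for tok, role in zip(input_ids, roles)]

# Made-up token ids for an eight-token chat sequence.
input_ids = [1, 501, 502, 2, 1, 701, 702, 2]
roles = ["control", "user", "user", "control",
         "control", "assistant", "assistant", "eot"]

for mode in ("all", "completion", "completion+eot"):
    print(mode, build_labels(input_ids, roles, mode))
```

Only the last mode leaves the closing end-of-turn id unmasked, which is what teaches the model to stop.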
Manual Numerical Walkthrough: One Masked Sequence
Take a single chat-templated sequence with eight tokens. Suppose the model gives the per-token cross-entropies in the “per-token CE” column below. We will compute the pretraining loss, the naive SFT loss, and the masked SFT loss by hand.
| position t | role | label | per-token CE | mask m_t |
|---|---|---|---|---|
| 0 | control (`<\|im_start\|>`) | control | 0.10 | 0 |
| 1 | user | user text | 2.80 | 0 |
| 2 | user | user text | 1.40 | 0 |
| 3 | control (`<\|im_end\|>`) | control | 0.20 | 0 |
| 4 | control (`<\|im_start\|>`) | control | 0.10 | 0 |
| 5 | assistant | assistant text | 1.20 | 1 |
| 6 | assistant | assistant text | 0.90 | 1 |
| 7 | control (`<\|im_end\|>`) | control | 0.40 | 1 (EOT) |
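Pretraining loss (everything counted, divide by 8):
$$\mathcal{L}_{\text{pretrain}} = \frac{0.10 + 2.80 + 1.40 + 0.20 + 0.10 + 1.20 + 0.90 + 0.40}{8} = \frac{7.10}{8} \approx 0.888$$
This number is misleading. Of those eight cross-entropies, four (0.10, 0.20, 0.10, 0.40, the control tokens) are nearly free: the model predicts them with high confidence because the chat template repeats them in every example. They add only 0.80 to the numerator but count fully in the denominator, while the two user tokens add 4.20, more than half of the total. The average is a blend of easy template tokens and prompt text the model is never supposed to generate, so it says little about response quality.
Naive SFT loss (same as pretraining for this example): $7.10 / 8 \approx 0.888$. The model also gets credit for predicting the user's tokens, which trains it to generate user messages, exactly the wrong behaviour.
Completion-only masked SFT loss (EOT included, so the mask sums to 3):
$$\mathcal{L}_{\text{SFT}} = \frac{1.20 + 0.90 + 0.40}{3} = \frac{2.50}{3} \approx 0.833$$
Compare 0.833 to the 0.888 of the naive loss. In this example the masked loss is lower, because the hard-to-predict user tokens (2.80 and 1.40) inflate the full-sequence average more than the cheap control tokens deflate it. On data where prompts are easy to predict and control tokens are numerous, the ordering flips and the masked loss comes out higher, since removing the free wins leaves only the assistant content. Either way, the two numbers are not comparable. A real SFT run reports the completion-only loss in its logs. If your training loss looks ‘too good’ (below 0.5 after a few steps), check first that your dashboard is computing the masked version, not the full version.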
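The arithmetic of this walkthrough can be checked in a few lines of plain Python. The sketch below uses the eight per-token cross-entropies and the mask from the table; the function names are illustrative only. It also shows what dividing the masked sum by the full length would give, to make the denominator choice concrete.

```python
per_token_ce = [0.10, 2.80, 1.40, 0.20, 0.10, 1.20, 0.90, 0.40]
mask = [0, 0, 0, 0, 0, 1, 1, 1]  # assistant tokens plus the closing EOT

def full_loss(ce):
    """Average over every position (pretraining / naive SFT)."""
    return sum(ce) / len(ce)

def masked_loss(ce, mask):
    """Average over kept positions only (completion-only SFT)."""
    kept = sum(mask)
    if kept == 0:
        raise ValueError("mask keeps no positions")
    return sum(c * m for c, m in zip(ce, mask)) / kept

def masked_sum_over_length(ce, mask):
    """The discouraged variant: masked numerator divided by sequence length."""
    return sum(c * m for c, m in zip(ce, mask)) / len(ce)

print(round(full_loss(per_token_ce), 4))                     # 0.8875
print(round(masked_loss(per_token_ce, mask), 4))             # 0.8333
print(round(masked_sum_over_length(per_token_ce, mask), 4))  # 0.3125
```

The last figure shrinks as the prompt grows even though the response is unchanged, which is why the denominator counts kept positions.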
The off-by-one trap
Learning Rate, Warmup, and the Cosine Schedule
The learning rate is the single dial that most strongly controls whether SFT improves or destroys your model. Too high and pretrained knowledge erodes in the first hundred steps — a phenomenon called catastrophic forgetting, addressed in § 13.5. Too low and the model never adapts to the chat format at all. The standard recipe is a linear warmup followed by a cosine decay, parameterised by four numbers: peak LR, warmup ratio, minimum LR fraction, and total steps.
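Mathematically, with total steps $T_{\text{total}}$, warmup steps $T_{\text{warm}}$, peak learning rate $\eta_{\text{peak}}$, and minimum LR fraction $r_{\min}$, the learning rate at step $t$ (counting from 0) is:
For $t < T_{\text{warm}}$:
$$\eta(t) = \eta_{\text{peak}} \cdot \frac{t + 1}{T_{\text{warm}}}$$
For $t \ge T_{\text{warm}}$, with $\text{progress} = \dfrac{t - T_{\text{warm}}}{T_{\text{total}} - T_{\text{warm}}}$:
$$\eta(t) = \eta_{\text{peak}} \left[ r_{\min} + (1 - r_{\min}) \cdot \tfrac{1}{2}\bigl(1 + \cos(\pi \cdot \text{progress})\bigr) \right]$$
The cosine term $\tfrac{1}{2}(1 + \cos(\pi \cdot \text{progress}))$ starts at $1$ when progress is 0 and ends at $0$ when progress is 1. So the LR slides smoothly from $\eta_{\text{peak}}$ down to $r_{\min}\,\eta_{\text{peak}}$, typically a 10× drop. The cosine shape (as opposed to linear) keeps the LR near peak for longer at the start of decay and decays slowly at the end, which empirically gives lower final loss than linear decay at the same peak.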
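A minimal sketch of the warmup-plus-cosine schedule as a pure function. It assumes the convention in which warmup ramps as `(step + 1) / warmup_steps`, so step 0 already takes a nonzero learning rate; the function name and argument names are illustrative, not taken from any library.

```python
import math

def lr_at(step, peak_lr, total_steps, warmup_steps, min_frac=0.1):
    """Learning rate at a 0-indexed optimiser step."""
    if step < warmup_steps:
        # Linear warmup that reaches peak on the last warmup step.
        return peak_lr * (step + 1) / warmup_steps
    # Fraction of the decay phase completed, clamped for safety.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(progress, 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))  # 1 -> 0
    return peak_lr * (min_frac + (1.0 - min_frac) * cosine)

# Multiples of peak for a 10-step run with 2 warmup steps.
for s in (0, 1, 5, 9):
    print(s, round(lr_at(s, peak_lr=1.0, total_steps=10, warmup_steps=2), 3))
# 0 0.5
# 1 1.0
# 5 0.722
# 9 0.134
```

Passing `peak_lr=1.0` returns the multiplier itself, which is convenient when the function is plugged into a scheduler that scales a base learning rate.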
Three things to notice about this schedule. First, increasing the warmup ratio past 5% gives diminishing returns; modern open-weights recipes have converged on about 3%. Second, the choice of minimum LR fraction matters more than people expect: dropping the floor from 10% to 1% noticeably hurts final eval scores on Llama-class fine-tunes because the last quarter of training stops contributing real learning. Third, the curve does not care whether your ‘total steps’ is 1 000 or 10 000: the shape is invariant under rescaling, which is why the same schedule carries over across run lengths and model sizes.
| Knob | Typical value | Why this value |
|---|---|---|
| peak LR | 1e-5 to 5e-5 | An order of magnitude smaller than pretraining (3e-4). Bigger models need smaller LR. |
| warmup ratio | 3% | Long enough for Adam moments to stabilise; short enough not to waste budget. |
| min LR / peak | 10% | Keeps the tail of training contributing; 0% wastes the last quarter of steps. |
| epochs | 1 to 3 | More than 3 epochs on a fixed SFT set usually overfits the demonstrations. |
| effective batch size | 64 to 512 examples | Big enough for stable gradient direction; small enough to fit memory. |
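The knobs in the table interact through simple arithmetic: effective batch size is per-device batch × number of devices × gradient-accumulation steps, and warmup steps follow from the warmup ratio and the total step count. The sketch below makes that explicit; the `SFTKnobs` container, its default values, and the 50 000-example dataset are hypothetical choices for illustration.

```python
import math
from dataclasses import dataclass

@dataclass
class SFTKnobs:
    """Hypothetical container for the hyperparameters in the table."""
    peak_lr: float = 2e-5
    warmup_ratio: float = 0.03
    min_lr_frac: float = 0.10
    epochs: int = 2
    effective_batch: int = 128   # examples per optimiser step
    per_device_batch: int = 4    # examples per device per microbatch
    num_devices: int = 8

    def grad_accum_steps(self) -> int:
        per_micro_step = self.per_device_batch * self.num_devices
        if self.effective_batch % per_micro_step != 0:
            raise ValueError("effective_batch must be a multiple of "
                             "per_device_batch * num_devices")
        return self.effective_batch // per_micro_step

    def total_steps(self, num_examples: int) -> int:
        # Optimiser steps, not microbatches.
        return math.ceil(num_examples / self.effective_batch) * self.epochs

    def warmup_steps(self, num_examples: int) -> int:
        return max(1, round(self.warmup_ratio * self.total_steps(num_examples)))

knobs = SFTKnobs()
print(knobs.grad_accum_steps())    # 4
print(knobs.total_steps(50_000))   # 782
print(knobs.warmup_steps(50_000))  # 23
```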
Manual Numerical Walkthrough: A Three-Step Schedule
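Take $T_{\text{total}} = 10$ steps, $T_{\text{warm}} = 2$ warmup steps, a peak LR $\eta_{\text{peak}}$, and a minimum LR fraction $r_{\min} = 0.1$; every result below is a multiple of $\eta_{\text{peak}}$. We compute the learning rate at four representative steps: 0 (start), 1 (end of warmup), 5 (during decay), and 9 (last step).
Step 0 (warmup): $\eta(0) = \eta_{\text{peak}} \cdot (0 + 1)/2 = 0.5\,\eta_{\text{peak}}$. We start at half of peak, not at zero. A nonzero step at step 0 is intentional: it lets Adam start populating its moments. Some implementations start at 0 instead; both conventions exist.
Step 1 (last warmup step): $\eta(1) = \eta_{\text{peak}} \cdot (1 + 1)/2 = \eta_{\text{peak}}$. We reach peak.
Step 5 (during decay): progress is $(5 - 2)/(10 - 2) = 0.375$. The cosine term is $\tfrac{1}{2}(1 + \cos(0.375\pi)) \approx 0.691$. Final LR: $\eta_{\text{peak}} \cdot (0.1 + 0.9 \times 0.691) \approx 0.722\,\eta_{\text{peak}}$.
Step 9 (last step): progress is $(9 - 2)/(10 - 2) = 0.875$. Cosine term: $\tfrac{1}{2}(1 + \cos(0.875\pi)) \approx 0.038$. Final LR: $\eta_{\text{peak}} \cdot (0.1 + 0.9 \times 0.038) \approx 0.134\,\eta_{\text{peak}}$. Even at the last step we sit slightly above the floor of $0.1\,\eta_{\text{peak}}$: close to it, but not at it, and well above zero. The model is still learning, just very gently.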
Sequence Packing and Effective Batch Size
Chat data is wildly variable in length. A “hello” + “hi” exchange is twenty tokens; a code-review conversation is two thousand. If you pad every example to the longest sequence in your batch you spend the majority of your GPU's cycles multiplying zeros by weights, which generates exactly zero learning signal but pays the full FLOPs bill.
The fix is example packing: concatenate short examples end-to-end into a single fixed-length row, separated by end-of-sequence markers. Each row of a packed batch contains multiple training examples; the attention mask is block-diagonal so no example attends across the boundary. Done properly, packing recovers 30–80% of the wasted compute, depending on the length distribution.
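To make the saving concrete, here is a minimal sketch of packing by first-fit-decreasing bin packing, alongside the padding ratio it achieves. This is one simple strategy chosen for illustration, not the algorithm of any particular trainer (some implementations also split long examples across rows); the example lengths are made up.

```python
def pack_first_fit(lengths, max_len):
    """Group example lengths into rows of at most max_len tokens."""
    rows = []  # each row is a list of example lengths
    for n in sorted(lengths, reverse=True):
        if n > max_len:
            raise ValueError(f"example of length {n} exceeds max_len")
        for row in rows:
            if sum(row) + n <= max_len:
                row.append(n)   # fits in an existing row
                break
        else:
            rows.append([n])    # open a new row
    return rows

def padding_ratio(rows, max_len):
    """Pad tokens divided by total tokens when every row is padded to max_len."""
    total = len(rows) * max_len
    used = sum(sum(row) for row in rows)
    return (total - used) / total

lengths = [60, 40, 24, 20, 12, 8, 50, 30]
padded = [[n] for n in lengths]           # one example per row
packed = pack_first_fit(lengths, max_len=64)

print(len(padded), round(padding_ratio(padded, 64), 3))  # 8 0.523
print(len(packed), round(padding_ratio(packed, 64), 3))  # 4 0.047
print(packed)  # [[60], [50, 12], [40, 24], [30, 20, 8]]
```

With these lengths, packing halves the row count and cuts the padding ratio from about 52% to under 5%.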
[Interactive figure: padded vs. packed batches at max length 64. In the illustrated example, padding wastes about a third of every row; packing fits the same examples into roughly half as many rows, with the leftover space at the end of the final row being the only padding.]
Two production details. One: the attention mask is no longer a 1-D vector of flags (1 = real, 0 = pad). It is a 2-D block-diagonal matrix where each block is the lower-triangular causal mask within one packed example. FlashAttention-2's variable-length interface accepts cumulative sequence lengths (`cu_seqlens`) that encode this efficiently without materialising the matrix. Two: the effective batch size in ‘examples’ is no longer the leading dimension of `input_ids`, since a single row might contain six examples. SFT trainers track ‘examples per step’ separately from ‘rows per step’.
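The block-diagonal causal mask and the cumulative-length encoding can both be derived from the list of example lengths in one packed row. The sketch below builds them in plain Python for readability; a real kernel never materialises the dense matrix, and the function names here are illustrative.

```python
from itertools import accumulate

def cu_seqlens(lengths):
    """Cumulative boundaries of the packed examples, starting at 0."""
    return [0] + list(accumulate(lengths))

def block_causal_mask(lengths):
    """Dense mask: entry [q][k] is 1 if query q may attend to key k."""
    segment = []  # which packed example each position belongs to
    for idx, n in enumerate(lengths):
        segment.extend([idx] * n)
    total = len(segment)
    return [[1 if segment[q] == segment[k] and k <= q else 0
             for k in range(total)]
            for q in range(total)]

lengths = [2, 3]  # two examples packed into one row of five tokens
print(cu_seqlens(lengths))  # [0, 2, 5]
for row in block_causal_mask(lengths):
    print(row)
# [1, 0, 0, 0, 0]
# [1, 1, 0, 0, 0]
# [0, 0, 1, 0, 0]
# [0, 0, 1, 1, 0]
# [0, 0, 1, 1, 1]
```

Position 2 starts the second example, so it sees only itself: nothing leaks across the boundary.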
Rule of thumb: if your padding ratio (pad tokens / total tokens) is above 20% in any batch, packing will pay for itself in the first hour of integration work. Most production SFT trainers (TRL, axolotl, llama-recipes) ship packing as a single flag.
Optimizer Choice: AdamW Defaults That Just Work
AdamW remains the dominant SFT optimiser in 2026, for the same reason it has dominated transformer training since 2018: it converges in fewer steps than plain SGD on the kinds of high-dimensional, ill-conditioned losses that transformer FFN matrices produce, and the decoupled weight decay (the “W”) gives a clean knob for regularisation that does not interfere with the LR schedule.
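The update rule, for a parameter $\theta$ with gradient $g_t$ at step $t$, learning rate $\eta_t$, and weight-decay coefficient $\lambda$:
First moment: $m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$
Second moment: $v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$
Bias-corrected: $\hat{m}_t = m_t / (1 - \beta_1^t)$, $\hat{v}_t = v_t / (1 - \beta_2^t)$
Decoupled update: $\theta_t = \theta_{t-1} - \eta_t \left( \dfrac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\, \theta_{t-1} \right)$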
Two things to call out. First, $\beta_2 = 0.95$ instead of the textbook $0.999$: the Llama recipes popularised this convention because the smaller $\beta_2$ makes the second-moment estimate react faster to changes in gradient scale, which empirically gives lower final loss on transformer fine-tunes. Almost every open-weights chat model trained since Llama-2 uses $\beta_2 = 0.95$. Second, the weight-decay term $\lambda\,\theta_{t-1}$ is applied as a separate shrinkage of $\theta$, not folded into the gradient. This is what ‘decoupled’ means and is the only difference between AdamW and the original Adam with L2 regularisation. Biases and norm parameters are excluded from weight decay; shrinking a LayerNorm gain toward zero is a quick way to break a model.
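The update rule is short enough to write out for a single scalar parameter. The sketch below is a plain-Python illustration of one AdamW step with decoupled weight decay; the function name is hypothetical and the learning rate in the example is exaggerated so the effect is visible.

```python
import math

def adamw_step(theta, grad, m, v, t, lr,
               beta1=0.9, beta2=0.95, eps=1e-8, weight_decay=0.1):
    """One AdamW update for a scalar parameter; t is the 1-indexed step."""
    m = beta1 * m + (1.0 - beta1) * grad          # first moment
    v = beta2 * v + (1.0 - beta2) * grad * grad   # second moment
    m_hat = m / (1.0 - beta1 ** t)                # bias correction
    v_hat = v / (1.0 - beta2 ** t)
    adam_term = m_hat / (math.sqrt(v_hat) + eps)
    # Decoupled decay: shrink theta directly instead of adding to the gradient.
    theta = theta - lr * (adam_term + weight_decay * theta)
    return theta, m, v

theta, m, v = adamw_step(theta=1.0, grad=0.5, m=0.0, v=0.0, t=1, lr=0.01)
print(round(theta, 6))  # 0.989
```

On the first step the bias-corrected ratio is close to the sign of the gradient, so the parameter moves by roughly `lr` regardless of gradient scale, plus a shrinkage of `lr * weight_decay * theta`.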
Memory cost is non-trivial: AdamW stores two extra fp32 tensors ($m$ and $v$) per parameter, so optimiser state is $8N$ bytes for $N$ parameters (4 bytes each × 2 tensors). For a 70B model that's 560 GB of optimiser state alone, more than any single H100's memory. This is the reason ZeRO/FSDP (§ 11) shards optimiser state across devices.
Regularisation: Weight Decay, Dropout, and NEFTune
SFT runs are short and use small datasets, both of which raise the risk of overfitting to the exact phrasing of the demonstrations. Three regularisers help.
- Weight decay ($\lambda = 0.1$). The decoupled term above. Applied to all weight matrices but not to biases or LayerNorm parameters. The standard implementation builds two parameter groups: one with decay = 0.1 and one with decay = 0.0.
- Dropout. Many SFT recipes set dropout to zero — the model is already heavily regularised by being undertrained on a small dataset, and adding dropout makes the loss noisier without improving generalisation. If you do enable dropout, 0.05 is the typical value, applied only to the attention output and FFN output. Llama-3 and Qwen-2 both ship with dropout disabled for SFT.
- NEFTune (Noisy Embedding Fine-Tuning). Add uniform noise to the embedding output at training time only, scaled by a noise hyperparameter $\alpha$, with $\alpha = 5$ a typical value. Empirically boosts MT-Bench scores by 1–2 points for a one-line code change. Costs nothing at inference; it is a pure training-time regulariser. NEFTune is one of the few cheap wins the SFT community has converged on.
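Of the three regularisers, NEFTune is the easiest to show in isolation. The sketch below follows the commonly used formulation, in which noise is drawn uniformly from $[-\alpha/\sqrt{Ld},\ \alpha/\sqrt{Ld}]$ for sequence length $L$ and embedding dimension $d$; it is a NumPy illustration with a hypothetical function name, not the code of any specific trainer.

```python
import numpy as np

def neftune_noise(embeddings, alpha=5.0, rng=None, training=True):
    """Add NEFTune-style uniform noise to embeddings of shape (..., L, d)."""
    if not training:
        return embeddings  # inference path is untouched
    if rng is None:
        rng = np.random.default_rng()
    seq_len, dim = embeddings.shape[-2], embeddings.shape[-1]
    scale = alpha / np.sqrt(seq_len * dim)  # noise magnitude bound
    noise = rng.uniform(-scale, scale, size=embeddings.shape)
    return embeddings + noise

emb = np.zeros((2, 8, 16))  # (batch, sequence, dim)
noisy = neftune_noise(emb, alpha=5.0, rng=np.random.default_rng(0))
bound = 5.0 / np.sqrt(8 * 16)
print(noisy.shape, bool(np.abs(noisy).max() <= bound))  # (2, 8, 16) True
```

Because the scale divides by $\sqrt{Ld}$, longer sequences and wider embeddings receive proportionally smaller per-element noise.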
Plain Python: Loss-Masked Cross-Entropy from Scratch
We re-implement the exact loss computation an SFT trainer runs each step, in pure NumPy, on a single five-token sequence. The model is a toy linear projection so we can read every intermediate. Replace the forward function with a real transformer and this code is byte-for-byte what a production SFT script does — the loss does not care how the logits got there.
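A minimal NumPy sketch of that computation follows. The toy “model” is an embedding lookup followed by a linear projection, and the vocabulary size, token ids, and which positions count as the assistant's are all made-up values for illustration. The label shift is done explicitly, so position $t$ is scored on the token at position $t + 1$.

```python
import numpy as np

IGNORE_INDEX = -100

def masked_cross_entropy(logits, labels):
    """Mean cross-entropy over positions whose label is not IGNORE_INDEX.

    logits: (T, V) array; labels: (T,) integer array, already shifted.
    """
    keep = labels != IGNORE_INDEX
    if not keep.any():
        raise ValueError("every position is masked")
    # Numerically stable log-softmax over the vocabulary axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    safe_labels = np.where(keep, labels, 0)  # dummy index at masked positions
    token_nll = -log_probs[np.arange(len(labels)), safe_labels]
    return float((token_nll * keep).sum() / keep.sum())

rng = np.random.default_rng(0)
vocab, dim = 6, 4
embed = rng.normal(size=(vocab, dim))   # toy embedding table
proj = rng.normal(size=(dim, vocab))    # toy output projection

input_ids = np.array([1, 4, 2, 5, 3])   # three prompt tokens, two assistant tokens
# Collator output: copy of input_ids with the prompt positions masked.
unshifted = np.array([IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 5, 3])
# Shift left by one: the last position has no next token to predict.
labels = np.append(unshifted[1:], IGNORE_INDEX)

logits = embed[input_ids] @ proj        # (5, 6) toy forward pass
loss = masked_cross_entropy(logits, labels)
print(labels.tolist())                  # [-100, -100, 5, 3, -100]
print(loss > 0)                         # True
```

A useful sanity check: with all-zero logits the model is uniform over the vocabulary, so the loss must equal $\log V$ no matter which positions are masked.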
PyTorch: A Production-Shaped SFT Training Step
Now the same calculation, but with an HF causal LM behind the forward pass, real-batch shapes, gradient accumulation, gradient clipping, an LR scheduler, and a logger. This is the inner loop of every production SFT trainer. The body is shorter than the data collator that feeds it.
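The following is a minimal sketch of such a step, with simplifications stated up front: `TinyLM` is a stand-in for a Hugging Face causal LM so the example runs without downloading weights, the data is random, the `ndim >= 2` rule for weight decay is a common heuristic rather than a fixed standard, and dividing each microbatch's mean loss by `accum_steps` weights microbatches equally regardless of how many tokens each keeps.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

IGNORE_INDEX = -100

class TinyLM(nn.Module):
    """Toy stand-in for a causal LM: returns logits of shape (B, T, V)."""
    def __init__(self, vocab=32, dim=16):
        super().__init__()
        self.embed = nn.Embedding(vocab, dim)
        self.norm = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, input_ids):
        return self.head(self.norm(self.embed(input_ids)))

def sft_loss(logits, labels):
    # Shift so that position t is scored on token t + 1.
    shift_logits = logits[:, :-1, :]
    shift_labels = labels[:, 1:]
    return F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                           shift_labels.reshape(-1),
                           ignore_index=IGNORE_INDEX)

def make_optimizer(model, lr, weight_decay=0.1):
    # Matrices get decay; biases and norm parameters (1-D) do not.
    decay = [p for p in model.parameters() if p.ndim >= 2]
    no_decay = [p for p in model.parameters() if p.ndim < 2]
    groups = [{"params": decay, "weight_decay": weight_decay},
              {"params": no_decay, "weight_decay": 0.0}]
    return torch.optim.AdamW(groups, lr=lr, betas=(0.9, 0.95))

def warmup_cosine(step, total, warmup, min_frac=0.1):
    if step < warmup:
        return (step + 1) / warmup
    progress = min(1.0, (step - warmup) / max(1, total - warmup))
    return min_frac + (1 - min_frac) * 0.5 * (1 + math.cos(math.pi * progress))

def train(model, micro_batches, accum_steps, optimizer, scheduler, max_norm=1.0):
    model.train()
    used_lrs = []
    optimizer.zero_grad(set_to_none=True)
    for i, (input_ids, labels) in enumerate(micro_batches, start=1):
        loss = sft_loss(model(input_ids), labels)
        (loss / accum_steps).backward()          # accumulate gradients
        if i % accum_steps == 0:                 # one optimiser step per cycle
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            used_lrs.append(optimizer.param_groups[0]["lr"])
            optimizer.step()
            scheduler.step()                     # once per optimiser step
            optimizer.zero_grad(set_to_none=True)
    return used_lrs

torch.manual_seed(0)
model = TinyLM()
optimizer = make_optimizer(model, lr=2e-5)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda s: warmup_cosine(s, total=4, warmup=1))

micro_batches = []
for _ in range(8):
    ids = torch.randint(0, 32, (2, 10))
    labels = ids.clone()
    labels[:, :6] = IGNORE_INDEX                 # pretend the first 6 tokens are prompt
    micro_batches.append((ids, labels))

lrs = train(model, micro_batches, accum_steps=2,
            optimizer=optimizer, scheduler=scheduler)
print(len(lrs))  # 4 optimiser steps from 8 microbatches
```

Note that the scheduler advances inside the accumulation branch, so eight microbatches produce four scheduler steps, not eight.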
From a single step to a working training run
Three things turn this skeleton into a real run:
- Mixed precision. Wrap the forward pass in `torch.autocast(device_type="cuda", dtype=torch.bfloat16)`. The optimiser state stays in fp32 (necessary for numerical stability of Adam's second moment), but activations and the forward computation use bf16. Halves the memory cost of the forward pass for free on H100/A100.
- Distributed wrapper. The model is wrapped in FSDP (or DeepSpeed ZeRO-3), which shards parameters, gradients, and optimiser state across GPUs. No code changes to the training step itself: FSDP hooks its gradient reduction into `.backward()`.
- Eval hook. At a regular step interval, run the model on a held-out eval set (a few hundred chat examples, or a slice of MT-Bench, or the validation split of the SFT dataset). Track eval loss and at least one downstream metric (e.g. average MT-Bench Likert score from a judge model). Best checkpoint is selected on the downstream metric, not the eval loss; the two often disagree.
At Massive Scale: What Changes for a 405B SFT Run
Everything above scales to 405B with three changes.
Optimiser state sharding becomes mandatory, not optional
At 405B parameters, AdamW's fp32 first and second moments are $405 \times 10^{9} \times 8$ bytes $\approx 3.24$ TB of state. Add the parameters in bf16 (810 GB) and the gradients (810 GB), and the total is ~5 TB before you have allocated a single activation. ZeRO-1 (shard optimiser state) is enough for 70B; ZeRO-3 (shard optimiser state + gradients + parameters) is required for anything past 100B. FSDP is the PyTorch-native version and is the default for new open-weights releases.
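The memory budget above can be reproduced, and extended to per-GPU figures under each ZeRO stage, with a short calculator. This sketch follows the section's accounting (bf16 parameters, bf16 gradients, two fp32 Adam moments) and deliberately ignores activations, fp32 master weights, and communication buffers; the function name and the 1024-GPU figure in the example are assumptions for illustration.

```python
def training_state_gb(num_params, num_gpus=1, zero_stage=0,
                      param_bytes=2, grad_bytes=2, optim_bytes=8):
    """Per-GPU memory (GB, 1 GB = 1e9 bytes) for params, grads, optimiser state."""
    params = num_params * param_bytes
    grads = num_params * grad_bytes
    optim = num_params * optim_bytes
    if zero_stage >= 1:
        optim /= num_gpus    # ZeRO-1: shard optimiser state
    if zero_stage >= 2:
        grads /= num_gpus    # ZeRO-2: also shard gradients
    if zero_stage >= 3:
        params /= num_gpus   # ZeRO-3: also shard parameters
    sizes = {"params": params / 1e9, "grads": grads / 1e9,
             "optimizer": optim / 1e9}
    sizes["total"] = sum(sizes.values())
    return sizes

print(training_state_gb(70 * 10**9)["optimizer"])   # 560.0
print(training_state_gb(405 * 10**9))
# {'params': 810.0, 'grads': 810.0, 'optimizer': 3240.0, 'total': 4860.0}
per_gpu = training_state_gb(405 * 10**9, num_gpus=1024, zero_stage=3)
print(round(per_gpu["total"], 2))                   # 4.75
```

Under full sharding the ~4.9 TB of state drops to under 5 GB per device on 1024 GPUs, which is why activations, not weights, become the binding constraint at that point.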
Effective batch size in tokens, not examples
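At the frontier, ‘batch size’ is reported as tokens per step, not examples per step, because packing makes the examples-per-step count meaningless. A typical 405B SFT run configures a batch of about $2^{22}$ tokens per step (4M tokens). At a context of 8 192 that's ~500 packed rows of 8 192 tokens each. Spread across 1024 GPUs, that is fewer rows than GPUs, so the ~4 effective rows per device apply per data-parallel replica (for example, 128 replicas of 8 model-parallel GPUs each), not per GPU. Total SFT compute for an instruction set of 10M examples × 1k tokens = 10B tokens is about $6 \times 405 \times 10^{9} \times 10^{10} \approx 2.4 \times 10^{22}$ FLOPs per epoch by the standard $6ND$ estimate. With 10B SFT tokens against roughly fifteen trillion pretraining tokens, that is well under 1% of the pretraining compute. SFT is cheap by design.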
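The token-batch arithmetic is easy to check. The sketch below assumes a batch of $2^{22}$ tokens and uses the common $6ND$ approximation (six FLOPs per parameter per training token) for compute; that approximation and the function names are assumptions brought in for illustration, not figures from a specific training report.

```python
def rows_per_step(tokens_per_step, context_len):
    """Number of fully packed rows in one optimiser step."""
    return tokens_per_step // context_len

def train_flops(num_params, num_tokens):
    """Approximate training compute via the 6 * N * D rule of thumb."""
    return 6 * num_params * num_tokens

tokens_per_step = 2 ** 22            # about 4M tokens
print(rows_per_step(tokens_per_step, 8192))   # 512

sft_tokens = 10_000_000 * 1_000      # 10M examples x 1k tokens
flops = train_flops(405 * 10**9, sft_tokens)
print(f"{flops:.2e}")                # 2.43e+22
print(sft_tokens // tokens_per_step) # 2384 optimiser steps per epoch
```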
The schedule has to interact with the warm-restart strategy
At frontier scale, SFT is rarely a single run. A common pattern is SFT v1 → reward model → DPO → SFT v2 (regen on improved completions). Each SFT pass uses its own cosine schedule starting from the previous checkpoint, not from the base model. Choosing the peak LR for SFT v2 is delicate: too high and you erase the DPO gains, too low and you do not absorb the new completions. Frontier labs typically use the SFT v1 peak for SFT v2.
Engineering Reality: The Knobs That Actually Move the Needle
After running enough SFT experiments, every team converges on the same short list of mistakes and the same short list of high-leverage fixes.
- The loss mask is wrong. Symptom: model produces plausible user turns when given an empty prompt. Cause: the data collator masked `input_ids` positions instead of `labels` positions, or forgot the off-by-one shift. Fix: print one example's tokens with their corresponding labels side-by-side before you launch.
- LR is set for the wrong model size. Symptom: training loss looks fine but downstream benchmarks crater (HellaSwag drops 10 points). Cause: pretraining-grade LR used for SFT, eroding pretrained knowledge. Fix: scale the peak LR down as model size grows, staying inside the 1e-5 to 5e-5 band from the knob table, with the largest models at the low end.
- Gradient norm spikes and the run NaNs. Symptom: a flat training loss for 50 steps, then the loss explodes to inf. Cause: a single batch with an anomalously high gradient bypassed your clipping because `max_norm` was set too high. Fix: `max_norm=1.0` is the universal default; never raise this without a reason.
- EOT not in the loss. Symptom: model generates plausible responses then keeps going past the end of the assistant turn into a fake user turn. Cause: the end-of-turn token `<|im_end|>` was masked out. Fix: include the EOT in the unmasked region (the “completion + EOT” masking mode).
- Padding eats the batch. Symptom: GPU utilisation stuck around 30% and training is mysteriously slow. Cause: long tail of example lengths in the dataset, padding everything to the max. Fix: enable example packing; every modern trainer has a flag.
- Weight decay applied to LayerNorm. Symptom: small but steady accuracy degradation across all evals. Cause: weight decay applied to every parameter via `model.parameters()`. Fix: build two parameter groups, one with `weight_decay=0.1` for weight matrices and one with `weight_decay=0.0` for biases and norm parameters.
- Scheduler stepped wrong number of times. Symptom: your logged LR does not match the schedule curve you intended. Cause: `scheduler.step()` called per microbatch instead of per accumulation cycle, or the opposite. Fix: scheduler advances once per optimiser step, never per microbatch.
- Saved best-loss checkpoint underperforms. Symptom: checkpoint at minimum eval loss is worse on MT-Bench than the checkpoint two epochs later. Cause: eval loss measures the wrong thing for chat models. Fix: select on downstream metrics (judge score, win rate against a reference) rather than loss.
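The first fix in the list, printing tokens next to their labels before launching, takes only a few lines. The helper below is a hypothetical debugging aid that works on already-decoded token strings; the example tokens and the `eot_is_trained` check are illustrative and assume the chat template closes each turn with `<|im_end|>`.

```python
IGNORE_INDEX = -100

def format_label_alignment(tokens, labels):
    """Return one line per position showing whether it contributes to the loss."""
    if len(tokens) != len(labels):
        raise ValueError("tokens and labels must have the same length")
    lines = []
    for i, (tok, lab) in enumerate(zip(tokens, labels)):
        status = "masked" if lab == IGNORE_INDEX else "TRAIN"
        lines.append(f"{i:>3}  {tok:<14} {status}")
    return "\n".join(lines)

def eot_is_trained(tokens, labels, eot="<|im_end|>"):
    """True if the final end-of-turn token is inside the unmasked region."""
    for tok, lab in zip(reversed(tokens), reversed(labels)):
        if tok == eot:
            return lab != IGNORE_INDEX
    return False  # no end-of-turn token found at all

tokens = ["<|im_start|>", "What", "is", "<|im_end|>",
          "<|im_start|>", "Four", ".", "<|im_end|>"]
labels = [IGNORE_INDEX] * 5 + [11, 12, 2]   # made-up ids for the kept tokens

print(format_label_alignment(tokens, labels))
print(eot_is_trained(tokens, labels))       # True
```

If the user's words show up as `TRAIN`, or the closing end-of-turn marker shows up as `masked`, the collator is wrong and the run should not start.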
The mental model that unifies this section: SFT is a continued-pretraining run with three changes — the loss masks all non-assistant tokens, the LR is an order of magnitude smaller, and the data is curated demonstrations rather than the open web. Every other knob (optimiser, schedule, packing, decay) is the same machinery you used for pretraining, dialled to its ‘gentle and precise’ setting. Get the masking right and the LR right, and a 70B chat model is a few thousand training steps away.