Chapter 13

SFT Training Configuration

Supervised Fine-Tuning (SFT)

The Real Problem: A Base Model Has No Format

After pretraining, a 70-billion-parameter base model has absorbed roughly fifteen trillion tokens of text. It knows physics, knows Python, can finish a sonnet. What it cannot do is answer a question. Give it the prompt “What is 2+2?” and it is just as likely to continue with “is a common arithmetic question asked of children in kindergarten across…” as to say “4”. The base model is a next-token oracle; it has no concept of a conversation, no notion of stopping when a turn is complete, no idea that “assistant” is its role.

Supervised fine-tuning closes exactly this gap. Section 13.1 told us what SFT does conceptually. Section 13.2 built the dataset. Section 13.3 turned every example into a chat-templated token sequence with role markers. This section is where the rubber meets the road: we write the actual training configuration (the loss, the optimiser, the schedule, the batch construction) that turns a base model into a chat model in a few thousand gradient steps.

The challenge is that the loss looks exactly like pretraining cross-entropy at first glance, but every detail has been re-engineered for a different goal. Pretraining wants to predict every token equally well across an internet-shaped distribution. SFT wants to predict only the assistant's response across a small, carefully curated set of demonstrations, without destroying any of the knowledge that took ten million GPU-hours to install. Get the loss masking wrong and the model learns to generate the user's message back to you. Get the learning rate wrong and you erase weeks of pretraining in the first thousand steps. Get the packing wrong and you spend half your training budget on padding tokens. The configuration is not optional decoration; it is the algorithm.

Intuition: SFT Is Continued Pretraining With a Sharp Loss

Hold the picture of pretraining in mind: a gigantic shuffled token stream, every token contributing to the loss, a learning rate hovering around $3 \times 10^{-4}$, the optimiser carving slow rivers across the loss surface for weeks. Now contrast SFT: a few tens of thousands of carefully formatted chat sequences, most tokens masked out of the loss, a learning rate an order of magnitude smaller at $2 \times 10^{-5}$, a run that finishes overnight. Same architecture. Same loss function. Same optimiser. Different everything else.

The most useful analogy is a surgeon returning to a patient after a long operation. The operation (pretraining) installed the organ. The follow-up (SFT) sutures and bandages. You use the same instruments, but you use them gently, in a small region, and you take care not to reopen anything that has already healed. A surgeon who scrubs in for SFT with a pretraining-sized learning rate will tear the sutures back out, and those torn sutures are exactly what catastrophic forgetting (the topic of § 13.5) looks like in practice.

The four design principles of SFT configuration

Every choice in this section flows from four ideas. One: only the assistant's tokens should produce gradient. Two: the optimiser should take small, smooth steps that do not erase pretrained knowledge. Three: the batch should be packed tightly so the GPU never sees padding. Four: the schedule should ramp gently up to a modest peak and decay smoothly to a floor. Every line of the code in this section is one of these four ideas wearing a hat.

Loss Masking: The Mathematical Heart of SFT

The pretraining loss is the per-token cross-entropy averaged over every position in the sequence. Writing it out, for a sequence of length $S$ with input tokens $x_0, \ldots, x_{S-1}$:

$$L_{\text{pretrain}} = -\frac{1}{S-1} \sum_{t=0}^{S-2} \log p_\theta(x_{t+1} \mid x_{\leq t})$$

Every position takes part: the model is asked to predict every next token, and every position contributes equally to the gradient. For SFT we change exactly one thing: we introduce a binary mask $m_t \in \{0, 1\}$ that selects which positions count:

$$L_{\text{SFT}} = -\frac{1}{\sum_{t} m_t} \sum_{t=0}^{S-2} m_t \cdot \log p_\theta(x_{t+1} \mid x_{\leq t})$$

Read this carefully. The sum only includes positions where $m_t = 1$; by convention these are the positions whose label (the next token $x_{t+1}$) is an assistant-generated token. The denominator is the number of kept positions, not $S$. That choice matters: if you divided by $S$, an example with a long prompt and a short response would get an artificially small loss, and its gradient signal would shrink as the prompt grows. Dividing by $\sum_t m_t$ makes each example contribute its average response loss, regardless of prompt length.
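A minimal NumPy sketch of the two denominators, using made-up per-token cross-entropies (the function names and numbers are illustrative, not from any library). Both examples share the same two-token response; only the prompt length differs.

```python
import numpy as np

def sft_loss(per_token_ce, mask):
    """Masked mean: sum the kept cross-entropies, then divide by sum(m_t)."""
    ce = np.asarray(per_token_ce, dtype=float)
    m = np.asarray(mask, dtype=float)
    return float((m * ce).sum() / m.sum())

def length_normalised_loss(per_token_ce, mask):
    """The buggy variant: same numerator, divided by the sequence length instead."""
    ce = np.asarray(per_token_ce, dtype=float)
    m = np.asarray(mask, dtype=float)
    return float((m * ce).sum() / len(ce))

# Hypothetical per-position cross-entropies: the same two-token response
# (CE 1.0 and 0.6) behind a 2-token prompt and behind an 8-token prompt.
response_ce = [1.0, 0.6]
short_ce, short_mask = [2.0, 1.5] + response_ce, [0, 0, 1, 1]
long_ce, long_mask = [2.0] * 8 + response_ce, [0] * 8 + [1, 1]

print(sft_loss(short_ce, short_mask), sft_loss(long_ce, long_mask))
print(length_normalised_loss(short_ce, short_mask), length_normalised_loss(long_ce, long_mask))
```

The masked mean gives both examples a loss of 0.8, while dividing by the sequence length reports 0.4 and 0.16, a difference caused entirely by the prompt.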

In code the mask is encoded indirectly. PyTorch's F.cross_entropy has an ignore_index argument: any label equal to ignore_index (default -100) is treated as $m_t = 0$ and contributes to neither the numerator nor the denominator. So an SFT data collator does one job per example: copy input_ids into labels, then overwrite every non-assistant position with -100. The trainer never knows the mask is there.
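A minimal sketch of such a collator step in plain Python, assuming each example arrives with a hypothetical per-token roles list (in practice you would recover the roles from the chat template's markers). The mode argument covers the three masking modes discussed next; the function and mode names are illustrative.

```python
IGNORE_INDEX = -100  # PyTorch's default ignore_index for F.cross_entropy
MODES = ("all", "completion", "completion_eot")

def build_labels(input_ids, roles, mode="completion_eot"):
    """Copy input_ids into labels, then overwrite positions that must not be trained.

    roles is a hypothetical per-token tag list ("control", "user", "assistant").
    Labels stay aligned with input_ids; the causal shift happens inside the loss.
    """
    if mode not in MODES:
        raise ValueError(f"unknown mode {mode!r}; expected one of {MODES}")
    if len(input_ids) != len(roles):
        raise ValueError("input_ids and roles must have the same length")
    labels = list(input_ids)
    if mode == "all":
        return labels  # naive: every token is trained
    for t, role in enumerate(roles):
        keep = role == "assistant"
        # completion + EOT: also keep the control token that closes an assistant turn
        if mode == "completion_eot" and role == "control" and t > 0 and roles[t - 1] == "assistant":
            keep = True
        if not keep:
            labels[t] = IGNORE_INDEX
    return labels

# The eight-token layout of the walkthrough table in this section, with made-up token ids.
ids = [10, 20, 21, 11, 10, 30, 31, 11]
roles = ["control", "user", "user", "control", "control", "assistant", "assistant", "control"]
for mode in MODES:
    print(mode, build_labels(ids, roles, mode))
```

With these made-up ids, "completion" keeps only positions 5 and 6 (tokens 30 and 31), while "completion_eot" also keeps the closing marker at position 7.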

There are three common masking modes. The first, “train on all”, is the naive approach a tutorial uses on day one; it produces a model that completes both halves of a chat. The second, “completion-only”, is the standard production default for open-weights chat models. The third, “completion + EOT”, is a small but important refinement: it also keeps the end-of-turn marker in the loss so the model learns when to stop. Without it, the model keeps going past its turn and emits user-role tokens, like a play with no curtain.

Manual Numerical Walkthrough: One Masked Sequence

Take a single chat-templated sequence with eight tokens, and suppose the model assigns the per-token cross-entropies listed in the table below. We will compute the pretraining loss, the naive SFT loss, and the masked SFT loss by hand.

position t   role                     label            per-token CE   mask m_t
0            control (<|im_start|>)   control          0.10           0
1            user                     user text        2.80           0
2            user                     user text        1.40           0
3            control (<|im_end|>)     control          0.20           0
4            control (<|im_start|>)   control          0.10           0
5            assistant                assistant text   1.20           1
6            assistant                assistant text   0.90           1
7            control (<|im_end|>)     control          0.40           1 (EOT)

Pretraining loss (everything counted, divide by 8): $L_{\text{pretrain}} = \frac{0.10 + 2.80 + 1.40 + 0.20 + 0.10 + 1.20 + 0.90 + 0.40}{8} = \frac{7.10}{8} \approx 0.888$

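This single number mixes two very different populations. The four control tokens (0.10, 0.20, 0.10, 0.40) are nearly free: the model quickly learns to predict template markers with high confidence because they recur in every templated example. The two user tokens (2.80, 1.40) are the opposite: user text is inherently hard to predict, and together they contribute 4.20 of the 7.10 total. Neither population says anything about how well the model answers.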

Naive SFT loss (no mask, so identical to the pretraining loss for this example): $L \approx 0.888$. The user's tokens now produce gradient, which trains the model to generate user messages: exactly the wrong behaviour.

Completion-only masked SFT loss (the mask sums to 3): $L_{\text{SFT}} = \frac{1.20 + 0.90 + 0.40}{3} = \frac{2.50}{3} \approx 0.833$

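Compare 0.833 with the naive 0.888. Here the masked loss is slightly lower, because masking removes the expensive user tokens (2.80, 1.40) along with the cheap control tokens. The sign of the difference is not meaningful in general: it depends on how hard the prompt is to predict relative to the response. What matters is that only the masked number measures the objective you are actually optimising, which is why a real SFT run reports the completion-only loss in its logs. If your training loss looks ‘too good’ (say, below 0.5 after a few steps), first check that your dashboard is computing the masked version and not the full one: long templated system prompts that repeat in every example are easy to memorise and can drag an unmasked loss down.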
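The same arithmetic in NumPy, as a check on the hand calculation. The values are copied from the table; the mask is the completion + EOT one.

```python
import numpy as np

# Per-position cross-entropies and mask m_t copied from the table above.
ce = np.array([0.10, 2.80, 1.40, 0.20, 0.10, 1.20, 0.90, 0.40])
mask = np.array([0, 0, 0, 0, 0, 1, 1, 1])  # completion + EOT: positions 5, 6, 7

full_loss = ce.mean()                         # pretraining / naive SFT: divide by 8
masked_loss = (ce * mask).sum() / mask.sum()  # completion-only: divide by sum(m_t) = 3

print(f"full   = {full_loss:.4f}")
print(f"masked = {masked_loss:.4f}")
```

It prints full = 0.8875 and masked = 0.8333.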

The off-by-one trap

In the Hugging Face convention, labels is aligned with input_ids rather than pre-shifted: labels[t] is either input_ids[t] or -100, and the shift happens inside the loss, where the logits at position $t-1$ are scored against labels[t]. So the mask at position $t$ decides whether the prediction of token $x_t$ from $x_{<t}$ counts. Most data-collator bugs come from getting this alignment wrong: masking by input position when the code expects label positions (or vice versa), or shifting once in the collator and again in the loss. Verify with one printout: at every position where labels ≠ −100, the label should equal input_ids at that position, and that token should belong to an assistant turn.
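That printout can be turned into a small assertion helper. A sketch, assuming labels follow the Hugging Face alignment described above and that a per-token roles list is available from the collator (both the roles list and the helper's name are hypothetical):

```python
IGNORE_INDEX = -100

def check_label_alignment(input_ids, labels, roles, trainable=("assistant",)):
    """Return a list of problems; an empty list means the labels look aligned."""
    if len(labels) != len(input_ids):
        return ["labels and input_ids differ in length (were labels pre-shifted?)"]
    problems = []
    for t, (tok, lab, role) in enumerate(zip(input_ids, labels, roles)):
        if lab == IGNORE_INDEX:
            continue
        if lab != tok:
            problems.append(f"position {t}: label {lab} != input_ids {tok}")
        if role not in trainable:
            problems.append(f"position {t}: trained token has role {role!r}")
    return problems

# Toy vocab: 0=<bos>, 1=<usr>, 2=<asr>, 3='hi'.
input_ids = [0, 1, 3, 2, 3]
roles = ["control", "control", "user", "control", "assistant"]
good = [-100, -100, -100, -100, 3]
shifted = [-100, -100, -100, 3, -100]  # mask moved one position too early

print(check_label_alignment(input_ids, good, roles))  # []
print(check_label_alignment(input_ids, shifted, roles))
```

The correctly aligned labels pass; the labels with the mask moved one position early fail twice at position 3, once because the label no longer matches input_ids and once because the trained token is a control marker.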

Learning Rate, Warmup, and the Cosine Schedule

The learning rate is the single dial that most strongly controls whether SFT improves or destroys your model. Too high and pretrained knowledge erodes in the first hundred steps — a phenomenon called catastrophic forgetting, addressed in § 13.5. Too low and the model never adapts to the chat format at all. The standard recipe is a linear warmup followed by a cosine decay, parameterised by four numbers: peak LR, warmup ratio, minimum LR fraction, and total steps.

Mathematically, with total steps $T$ and warmup steps $T_w$:

For $t < T_w$: $\eta_t = \eta_{\text{peak}} \cdot \frac{t + 1}{T_w}$

For $t \geq T_w$, with $\text{progress} = (t - T_w) / (T - T_w)$: $\eta_t = \eta_{\text{peak}} \cdot \big[\eta_{\min} + (1 - \eta_{\min}) \cdot \tfrac{1}{2} (1 + \cos(\pi \cdot \text{progress}))\big]$

The cosine term $\tfrac{1}{2}(1 + \cos(\pi \cdot \text{progress}))$ starts at $1$ when progress is 0 and ends at $0$ when progress is 1. So the LR slides smoothly from $\eta_{\text{peak}}$ down to $\eta_{\min} \cdot \eta_{\text{peak}}$ (note that $\eta_{\min}$ is a fraction of peak, not an absolute rate), typically a 10× drop. Compared with linear decay, the cosine shape keeps the LR near peak for longer at the start of decay and flattens out at the end, which in practice often gives a lower final loss at the same peak.

Three things to notice when you vary the knobs. First, increasing the warmup ratio past 5% gives diminishing returns; many open-weights recipes settle on about 3%. Second, the minimum-LR fraction matters more than people expect: dropping the floor from 10% to 1% can noticeably hurt final eval scores on Llama-class fine-tunes, because the last quarter of training stops contributing real learning. Third, when warmup is specified as a ratio, the curve has the same shape whether your total step count is 1 000 or 10 000; it simply stretches along the step axis, which is why the same schedule transfers across run lengths.

Knob                   Typical value        Why this value
peak LR                1e-5 to 5e-5         An order of magnitude smaller than pretraining (3e-4). Bigger models need a smaller LR.
warmup ratio           3%                   Long enough for Adam's moments to stabilise; short enough not to waste budget.
min LR / peak          10%                  Keeps the tail of training contributing; 0% wastes the last quarter of steps.
epochs                 1 to 3               More than 3 epochs on a fixed SFT set usually overfits the demonstrations.
effective batch size   64 to 512 examples   Large enough for a stable gradient direction; small enough that a modest dataset still yields plenty of optimiser steps.
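One way to keep these knobs together is a small frozen dataclass that derives step counts from the dataset size. This is a sketch: the class name, the defaults (picked from the ranges in the table) and the choice to round step counts up are assumptions, not any trainer's API.

```python
import math
from dataclasses import dataclass

@dataclass(frozen=True)
class SFTSchedule:
    """Hypothetical container for the knobs in the table above (not a library class)."""
    num_examples: int
    epochs: int = 2
    effective_batch_size: int = 128  # examples per optimiser step
    peak_lr: float = 2e-5
    warmup_ratio: float = 0.03
    min_lr_ratio: float = 0.10       # floor as a fraction of peak

    @property
    def total_steps(self) -> int:
        return math.ceil(self.num_examples * self.epochs / self.effective_batch_size)

    @property
    def warmup_steps(self) -> int:
        return max(1, math.ceil(self.warmup_ratio * self.total_steps))

    @property
    def min_lr(self) -> float:
        return self.min_lr_ratio * self.peak_lr

cfg = SFTSchedule(num_examples=50_000)
print(cfg.total_steps, cfg.warmup_steps, cfg.min_lr)
```

For 50 000 examples over 2 epochs at 128 examples per step, this gives 782 optimiser steps, 24 of them warmup.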

Manual Numerical Walkthrough: A Ten-Step Schedule

Take $T = 10$ steps, $T_w = 2$ warmup steps, $\eta_{\text{peak}} = 2 \times 10^{-5}$, and $\eta_{\min} = 0.1$. We compute the learning rate at four representative steps: 0 (first warmup step), 1 (last warmup step), 5 (mid-decay), and 9 (final step).

Step 0 (warmup): $\eta_0 = 2 \times 10^{-5} \cdot \frac{0 + 1}{2} = 1.0 \times 10^{-5}$. We start at half of peak, not at zero, only because this toy run has two warmup steps; with the 90 warmup steps of a 3 000-step run, step 0 would sit at 1/90 of peak. A small but real step at step 0 is intentional: it lets Adam start populating its moments. Some implementations start at exactly 0 instead; both conventions exist.

Step 1 (last warmup step): $\eta_1 = 2 \times 10^{-5} \cdot \frac{1 + 1}{2} = 2.0 \times 10^{-5}$. We reach peak.

Step 5 (mid-decay): progress is $(5 - 2) / (10 - 2) = 3/8 = 0.375$. The cosine term is $\tfrac{1}{2}(1 + \cos(\pi \cdot 0.375)) = \tfrac{1}{2}(1 + \cos(1.178)) = \tfrac{1}{2}(1 + 0.383) = 0.691$. Final LR: $\eta_5 = 2 \times 10^{-5} \cdot (0.1 + 0.9 \cdot 0.691) = 2 \times 10^{-5} \cdot 0.722 = 1.44 \times 10^{-5}$.

Step 9 (last step): progress is $(9 - 2) / (10 - 2) = 7/8 = 0.875$. Cosine term: $\tfrac{1}{2}(1 + \cos(\pi \cdot 0.875)) = 0.0381$. Final LR: $2 \times 10^{-5} \cdot (0.1 + 0.9 \cdot 0.0381) = 2 \times 10^{-5} \cdot 0.134 = 2.69 \times 10^{-6}$. Even at the last step we are slightly above the floor of $0.1 \cdot \eta_{\text{peak}} = 2 \times 10^{-6}$: with this formula, progress would only reach 1 at step $T$, one step after training ends. The model is still learning, just very gently.
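The schedule is a few lines of plain Python. The sketch below follows the two formulas above, including the $(t + 1)/T_w$ warmup convention; the function name and signature are illustrative.

```python
import math

def lr_at(step, total_steps, warmup_steps, peak_lr, min_lr_ratio=0.1):
    """Linear warmup from (t+1)/T_w of peak, then cosine decay toward min_lr_ratio * peak."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
    return peak_lr * (min_lr_ratio + (1.0 - min_lr_ratio) * cosine)

for step in (0, 1, 5, 9):
    print(step, f"{lr_at(step, total_steps=10, warmup_steps=2, peak_lr=2e-5):.3e}")
```

It prints 1.000e-05, 2.000e-05, 1.444e-05 and 2.685e-06, matching the walkthrough. For a 3 000-step run with 90 warmup steps, step 1545 is exactly halfway through decay, so the cosine term is 0.5 and the LR is 0.55 × peak = 1.1e-5. Dividing the function by peak_lr gives the multiplier that torch.optim.lr_scheduler.LambdaLR expects.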

Sequence Packing and Effective Batch Size

Chat data is wildly variable in length. A “hello” + “hi” exchange is twenty tokens; a code-review conversation is two thousand. If you pad every example to the longest sequence in its batch, a large share of your GPU's cycles (most of them, with a spread like this) goes into computing on pad tokens, which generate no learning signal but pay the full FLOPs bill.

The fix is example packing: concatenate short examples end-to-end into a single fixed-length row, separated by end-of-sequence markers. Each row of a packed batch contains multiple training examples; the attention mask is block-diagonal so no example attends across the boundary. Done properly, packing recovers 30–80% of the wasted compute, depending on the length distribution.
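A minimal sketch of packing, assuming a greedy first-fit-decreasing strategy: sort examples by length and put each into the first row with room. Real trainers differ (some pack in stream order, some split long examples across rows), and the function names here are hypothetical.

```python
def pack_examples(lengths, max_len):
    """Greedy first-fit packing: return rows as lists of example indices.

    Assumptions: every example fits in max_len, examples are never split,
    and longer examples are placed first.
    """
    rows, free = [], []  # example indices per row; free token slots per row
    for idx in sorted(range(len(lengths)), key=lambda i: -lengths[i]):
        n = lengths[idx]
        if n > max_len:
            raise ValueError(f"example {idx} ({n} tokens) exceeds max_len={max_len}")
        for r, space in enumerate(free):
            if n <= space:  # first row with enough room
                rows[r].append(idx)
                free[r] -= n
                break
        else:  # no existing row fits: open a new one
            rows.append([idx])
            free.append(max_len - n)
    return rows

def padding_ratio(lengths, num_rows, max_len):
    """Pad tokens / total tokens when num_rows rows of max_len hold these examples."""
    return 1.0 - sum(lengths) / (num_rows * max_len)

lengths = [20, 50, 12, 40, 30, 9, 60, 35]
rows = pack_examples(lengths, max_len=64)
print(rows)
print(padding_ratio(lengths, len(lengths), 64), padding_ratio(lengths, len(rows), 64))
```

With these lengths and max_len=64, one example per row needs 8 rows at a padding ratio of 0.5, while packing needs 5 rows at a padding ratio of about 0.2.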

To get a feel for the numbers: with a maximum row length of 64 tokens and a typical spread of short chat examples, padding each example to its own row can waste about a third of every row, while packing fits the same examples into roughly half as many rows. Little padding remains, mostly in the unused space at the end of the final row.

Two production details. One: the attention mask is no longer a 1-D vector of (1 = real, 0 = pad). It is a 2-D block-diagonal matrix in which each block is the lower-triangular causal mask of one packed example. FlashAttention 2's variable-length kernels avoid materialising this matrix: they take cumulative sequence lengths (cu_seqlens) that mark where each example starts. Two: the effective batch size in ‘examples’ is no longer the leading dimension of input_ids, since a single row might contain six examples. SFT trainers therefore track ‘examples per step’ separately from ‘rows per step’.
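To make the block-diagonal structure concrete, here is a NumPy sketch that builds the mask for one packed row from its example lengths, plus the cumulative-length array that variable-length attention kernels take instead. The function names are illustrative; FlashAttention expects these boundaries as an int32 tensor.

```python
import numpy as np

def packed_causal_mask(seq_lens):
    """Boolean (S, S) mask for one packed row: True where query i may attend to key j.

    Each block is lower-triangular (causal) and blocks never see each other.
    """
    total = sum(seq_lens)
    mask = np.zeros((total, total), dtype=bool)
    start = 0
    for n in seq_lens:
        mask[start:start + n, start:start + n] = np.tril(np.ones((n, n), dtype=bool))
        start += n
    return mask

def cu_seqlens(seq_lens):
    """Cumulative boundaries [0, n0, n0 + n1, ...] as int32."""
    return np.concatenate([[0], np.cumsum(seq_lens)]).astype(np.int32)

m = packed_causal_mask([3, 2])
print(m.astype(int))
print(cu_seqlens([3, 2]))
```

For two examples of lengths 3 and 2, the mask is two lower-triangular blocks (9 allowed query-key pairs) and cu_seqlens is [0, 3, 5].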

Rule of thumb: if your padding ratio (pad tokens / total tokens) is above 20% in any batch, the throughput gained from packing will quickly repay the integration work. Most production SFT trainers (TRL, axolotl, llama-recipes) expose packing as a single flag.

Optimizer Choice: AdamW Defaults That Just Work

AdamW remains the dominant SFT optimiser in 2026, for the same reason it has dominated transformer training since 2018: it converges in fewer steps than plain SGD on the kinds of high-dimensional, ill-conditioned losses that transformer FFN matrices produce, and the decoupled weight decay (the “W”) gives a clean knob for regularisation that does not interfere with the LR schedule.

The update rule, for a parameter $\theta$ with gradient $g_t$ at step $t$:

First moment: $m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t$

Second moment: $v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2$

Bias-corrected: $\hat{m}_t = m_t / (1 - \beta_1^t)$, $\hat{v}_t = v_t / (1 - \beta_2^t)$

Decoupled update: $\theta_t = \theta_{t-1} - \eta_t \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon) - \eta_t \cdot \lambda \cdot \theta_{t-1}$
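The four equations translate line for line into NumPy. A minimal sketch (the function name and argument order are illustrative; PyTorch's torch.optim.AdamW implements the same decoupled update):

```python
import numpy as np

def adamw_step(theta, grad, m, v, t, lr, beta1=0.9, beta2=0.95, eps=1e-8, weight_decay=0.1):
    """One AdamW update following the equations above; t counts from 1.

    Returns (new_theta, m, v). Weight decay uses theta_{t-1} and is kept out of the gradient.
    """
    m = beta1 * m + (1.0 - beta1) * grad       # first moment
    v = beta2 * v + (1.0 - beta2) * grad**2    # second moment
    m_hat = m / (1.0 - beta1**t)               # bias corrections
    v_hat = v / (1.0 - beta2**t)
    theta_new = theta - lr * m_hat / (np.sqrt(v_hat) + eps) - lr * weight_decay * theta
    return theta_new, m, v

theta = np.array([1.0, -2.0])
grad = np.array([0.5, -0.001])  # second gradient is 500x smaller
theta1, m, v = adamw_step(theta, grad, np.zeros(2), np.zeros(2), t=1, lr=1e-3)
print(theta1)
```

On the first step the bias-corrected ratio $\hat{m}_t / \sqrt{\hat{v}_t}$ is ±1 in every coordinate, so each parameter moves by about the learning rate whatever the size of its gradient: here both coordinates move by about 1e-3 (plus the small decay term) even though one gradient is 500 times smaller than the other.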

Two things to call out. First, $\beta_2 = 0.95$ instead of the textbook $0.999$. GPT-3 used this setting and the Llama models followed it; a smaller $\beta_2$ makes the second-moment estimate react faster to changes in gradient scale, which tends to make large-model training more stable. Many open-weights chat models trained since Llama-2 use $\beta_2 = 0.95$. Second, the weight-decay term ($\lambda = 0.1$) is applied as a separate shrinkage of $\theta_{t-1}$, not folded into the gradient. This is what ‘decoupled’ means, and it is the only difference between AdamW and Adam with L2 regularisation. Biases and norm parameters are conventionally excluded from weight decay; shrinking a LayerNorm gain toward zero is a quick way to break a model.
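The two parameter groups can be built from named parameters. A sketch, assuming the common heuristic that anything 1-D (biases, norm gains) or named like a bias or norm skips decay; real codebases use their own naming rules, and the helper is hypothetical. The returned list of dicts is the parameter-group format that PyTorch optimisers such as torch.optim.AdamW accept.

```python
import numpy as np

def split_decay_groups(named_params, weight_decay=0.1):
    """Split (name, tensor) pairs into a decayed and a non-decayed parameter group."""
    decay, no_decay = [], []
    for name, p in named_params:
        # Heuristic (assumption): 1-D tensors, biases and norm parameters are not decayed.
        if p.ndim < 2 or "bias" in name or "norm" in name.lower():
            no_decay.append(p)
        else:
            decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]

# NumPy arrays stand in for model parameters; the names mimic a Hugging Face-style decoder.
params = [
    ("model.embed_tokens.weight", np.zeros((10, 4))),
    ("model.layers.0.self_attn.q_proj.weight", np.zeros((4, 4))),
    ("model.layers.0.self_attn.q_proj.bias", np.zeros(4)),
    ("model.layers.0.input_layernorm.weight", np.zeros(4)),
]
groups = split_decay_groups(params)
print([len(g["params"]) for g in groups], [g["weight_decay"] for g in groups])
```

In PyTorch this would be called as torch.optim.AdamW(split_decay_groups(model.named_parameters()), lr=2e-5, betas=(0.9, 0.95)); each group's weight_decay overrides the optimiser-level default.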

Memory cost is non-trivial: AdamW stores two extra fp32 tensors ($m_t$ and $v_t$) per parameter, so optimiser state takes $8 \cdot N_{\text{params}}$ bytes (4 bytes each × 2 tensors). For a 70B model that's 560 GB of optimiser state alone, more than any single H100's memory. This is why ZeRO/FSDP (§ 11) shards optimiser state across devices.
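The arithmetic is easy to script for other model sizes. A sketch that counts only $m_t$ and $v_t$, as the paragraph does (the fp32 master weights and gradients of mixed-precision training would add more); the helper name is illustrative:

```python
def adamw_state_bytes(num_params, bytes_per_value=4, tensors_per_param=2):
    """Optimiser-state memory for AdamW: m_t and v_t, each stored in fp32 by default."""
    return num_params * bytes_per_value * tensors_per_param

for n in (7e9, 70e9):
    print(f"{n / 1e9:.0f}B params -> {adamw_state_bytes(n) / 1e9:.0f} GB of AdamW state")
```

This prints 56 GB for a 7B model and 560 GB for a 70B model.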

Regularisation: Weight Decay, Dropout, and NEFTune

SFT runs are short and use small datasets, both of which raise the risk of overfitting to the exact phrasing of the demonstrations. Three regularisers help.

  1. Weight decay ($\lambda = 0.1$). The decoupled term above, applied to all weight matrices but not to biases or LayerNorm parameters. The standard implementation builds two parameter groups: one with decay = 0.1 and one with decay = 0.0.
  2. Dropout. Many SFT recipes set dropout to zero: the run is short enough that it already behaves like early stopping, and adding dropout makes the loss noisier without improving generalisation. If you do enable dropout, 0.05 is a typical value, applied only to the attention output and FFN output. Llama-3 and Qwen-2 both ship with dropout disabled for SFT.
  3. NEFTune (Noisy Embedding Fine-Tuning). At training time only, add uniform noise $U[-\alpha/\sqrt{Ld}, +\alpha/\sqrt{Ld}]$ to the embedding output, where $L$ is the sequence length and $d$ the embedding dimension, with $\alpha = 5$ a typical value. The NEFTune authors report sizeable AlpacaEval gains for what is essentially a one-line code change. It adds no inference cost, being a pure training-time regulariser, and it is one of the few cheap wins the SFT community has converged on.
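A NumPy sketch of the noise term, assuming a single unpadded sequence whose embedding output has shape (L, d); how batching and padding enter $L$ varies between implementations, and the function name is illustrative:

```python
import numpy as np

def neftune_noise(embeddings, alpha=5.0, rng=None):
    """Add NEFTune noise to an (L, d) embedding output; call during training only.

    Noise ~ Uniform(-alpha / sqrt(L * d), +alpha / sqrt(L * d)).
    """
    rng = np.random.default_rng() if rng is None else rng
    L, d = embeddings.shape
    scale = alpha / np.sqrt(L * d)
    return embeddings + rng.uniform(-scale, scale, size=embeddings.shape)

emb = np.zeros((128, 64))
noisy = neftune_noise(emb, alpha=5.0, rng=np.random.default_rng(0))
print(np.abs(noisy).max(), 5.0 / np.sqrt(128 * 64))
```

Every noise value lies within ±α/√(Ld), so longer sequences and wider embeddings get proportionally smaller per-entry noise. At inference time the function is simply not called.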

Plain Python: Loss-Masked Cross-Entropy from Scratch

We re-implement the exact loss computation an SFT trainer runs each step, in pure NumPy, on a single five-token sequence. The model is a toy embedding lookup followed by a linear projection, so we can read every intermediate. Replace the forward function with a real transformer and the loss code is unchanged: the loss does not care how the logits got there.

SFT loss: masked cross-entropy in pure NumPy (sft_loss_plain.py)
The tokenised chat sequence

Five tokens representing a complete chat turn after the template is applied: a beginning-of-sequence marker, a user-role marker, the literal word 'hi' from the user, an assistant-role marker, and the literal 'hi' from the assistant. Every real SFT script produces a sequence in this shape; the only differences are length (thousands of tokens, not five) and vocabulary size (50k+, not 4).

Example: input_ids = [0, 1, 3, 2, 3]  # <bos> <usr> 'hi' <asr> 'hi'
The labels: same shape as input_ids, but with prompt tokens masked

Labels are a copy of input_ids with the positions we do NOT want to train on replaced by -100, which PyTorch's CrossEntropyLoss treats as 'ignore'. Only the final '3' (the assistant's 'hi') keeps a real label; that is the one token the SFT gradient will push the model toward. The prompt tokens stay at -100, so they contribute zero gradient.

Example: labels = [-100, -100, -100, -100, 3]  # only train on the assistant's 'hi'
Toy model parameters: embedding + LM head

Two weight matrices. W_embed, of shape (vocab_size, hidden), turns a token id into a hidden-state vector. W_lm, of shape (hidden, vocab_size), projects that hidden state back to vocab-sized logits. Both are initialised with small Gaussians, the standard convention. A real model has the same two matrices with N transformer blocks between them, where N ranges from 12 (small) to 96+ (frontier). Crucially, inserting those blocks does not change the loss code at all.

Forward: the simplest possible LM

Look up each input id in the embedding table to get its hidden vector (h has shape (S, HIDDEN) = (5, 6)), then multiply by W_lm to get logits of shape (S, VOCAB) = (5, 4): one logit per position per vocabulary item. This is the same end of the pipeline a real model produces; what varies between a toy and a frontier model is what creates h, not what happens after.

Example: input_ids shape (5,) → h shape (5, 6) → logits shape (5, 4)
Causal shift: position t predicts token t+1

The convention in every causal LM trainer: the model's output at position t is a prediction of input_ids[t+1]. So we drop the last logit row (there is no token t+1 for it) and the first label (no position predicts it). After the shift, shift_logits[i] is the prediction for shift_labels[i]. Off-by-one errors here are the most common SFT bug: the labels look right, but the model is being asked to predict the wrong token.

Example: logits[:-1, :] has shape (4, 4) and labels[1:] has shape (4,); each row of shift_logits predicts the matching entry of shift_labels.
Numerically stable log-softmax

We subtract the row maximum before exponentiating so that no entry overflows to inf; this is the standard log-sum-exp trick. The result log_probs[i, v] is log P(token v | context up to position i). Without the trick, a large logit overflows exp() and the subsequent inf / inf produces a silent NaN. The risk is worst in fp16, whose range is narrow, which is one reason many implementations upcast logits to fp32 before computing the loss even when the model runs in bf16.

Example: logits row [2.1, 3.0, 1.5, 0.0] → log_probs row [-1.42, -0.52, -2.02, -3.52]
Build the per-position cross-entropy

valid is the boolean mask of positions whose loss we want: exactly the positions where labels != -100. We gather the model's log-prob at the true label and negate it (cross-entropy is -log p). For masked positions we substitute label 0 in the gather (any in-range index works) because the next line zeroes them out anyway. Gathering and then masking, rather than indexing the unmasked positions out, keeps every array a fixed shape, which keeps compiled kernels and the autograd graph simple in a real framework.

Example (illustrative values): per_token = [1.40, 0.90, 2.10, 0.30] before the mask; only the last entry is real.
Apply the mask: masked positions contribute zero

np.where replaces the per-token CE with 0 wherever valid is False. This is what makes SFT a 'completion-only' objective: the prompt's gradient signal is exactly zero, and only the response gradient flows through. A mistake here makes the model learn to GENERATE the user message back instead of responding to it, a classic and embarrassing fine-tuning failure.

Example: per_token after the mask = [0, 0, 0, 0.30]
Average over only the kept positions

n_valid is the count of real labels (4 with no mask, 1 here). We divide by n_valid, NOT by the sequence length. Dividing by sequence length is a classic bug that shrinks the per-token signal of examples with long prompts and biases training toward short examples. Real SFT trainers (TRL, axolotl, llama-recipes) divide by the number of valid label tokens by default.

Example: per_token.sum() = 0.30, n_valid = 1, loss = 0.30
One full training step in three calls

Forward, shift, masked CE: three function calls and you have a scalar loss. A real SFT trainer adds loss.backward(), optimizer.step(), optimizer.zero_grad(), lr_scheduler.step(), and a gradient-accumulation guard. None of those touch this code, which is why a 30-line training loop keeps the same skeleton from a 7B model to a 405B model.

Example: with the script's near-zero random weights the logits are almost uniform, so the printed loss is close to ln 4 ≈ 1.386 (the 0.30 in the notes above is illustrative). A real batch averages over 32+ sequences.
```python
"""
Loss-masked cross-entropy for SFT, in pure NumPy (no autograd).

We re-implement the exact computation an SFT trainer performs each step:
  1. forward: logits = model(input_ids)
  2. shift: predict token t+1 from tokens [0..t]
  3. mask: ignore positions where label == -100 (prompt / padding)
  4. average cross-entropy over the kept positions

The 'model' is an embedding lookup plus a linear head over a 4-token
vocabulary, so we can walk every number by hand. Replace 'forward' with a
real transformer and the loss code below is unchanged.
"""

import numpy as np

# ---------------------------------------------------------------------------
# 1. A toy chat-formatted sequence after tokenisation.
#    Vocab: 0=<bos>, 1=<usr>, 2=<asr>, 3='hi'
#    Story: user says "hi", assistant says "hi" back.
# ---------------------------------------------------------------------------

input_ids = np.array([0, 1, 3, 2, 3])                # length 5
# Label = -100 for every token whose loss we DO NOT want to count.
# Here we only learn from the assistant turn (the final '3').
labels    = np.array([-100, -100, -100, -100, 3])    # mask everything but pos 4

VOCAB_SIZE = 4
HIDDEN     = 6

# ---------------------------------------------------------------------------
# 2. A trivial 'model': embedding lookup + linear LM head.
#    Real models replace the middle with N transformer blocks; the
#    shapes and the loss are unchanged.
# ---------------------------------------------------------------------------

rng = np.random.default_rng(0)
W_embed = rng.standard_normal((VOCAB_SIZE, HIDDEN)) * 0.1
W_lm    = rng.standard_normal((HIDDEN, VOCAB_SIZE)) * 0.1

def forward(ids: np.ndarray) -> np.ndarray:
    h = W_embed[ids]              # (S, HIDDEN)
    logits = h @ W_lm             # (S, VOCAB)
    return logits

# ---------------------------------------------------------------------------
# 3. Shift labels: position t predicts token t+1.
#    Standard causal LM convention. After the shift, position t of the
#    'shifted' arrays talks about the prediction for token t+1.
# ---------------------------------------------------------------------------

def shift(logits: np.ndarray, labels: np.ndarray):
    shift_logits = logits[:-1, :]        # drop last: nothing to predict for it
    shift_labels = labels[1:]            # drop first: nothing predicts pos 0
    return shift_logits, shift_labels

# ---------------------------------------------------------------------------
# 4. Loss-masked cross-entropy.
# ---------------------------------------------------------------------------

def masked_ce(logits: np.ndarray, labels: np.ndarray, ignore: int = -100) -> float:
    # Numerically stable log-softmax
    z = logits - logits.max(axis=-1, keepdims=True)
    log_sum_exp = np.log(np.exp(z).sum(axis=-1))
    log_probs = z - log_sum_exp[:, None]      # (S, VOCAB)

    # Gather log-prob of the true label at each position
    valid = labels != ignore                  # (S,) bool mask
    per_token = -log_probs[np.arange(len(labels)), np.where(valid, labels, 0)]

    # Apply mask: positions with label == -100 contribute zero
    per_token = np.where(valid, per_token, 0.0)

    # MEAN over only the kept positions (this is the canonical SFT objective)
    n_valid = int(valid.sum())
    if n_valid == 0:                          # fully masked: nothing to learn from
        return 0.0
    return float(per_token.sum() / n_valid)

# ---------------------------------------------------------------------------
# 5. End-to-end one step.
# ---------------------------------------------------------------------------

logits = forward(input_ids)                       # (5, 4)
sl, sllabels = shift(logits, labels)              # (4, 4) and (4,)
loss = masked_ce(sl, sllabels)
print(f"sequence loss = {loss:.4f}")              # single scalar per sequence
```

PyTorch: A Production-Shaped SFT Training Step

Now the same calculation, but with a Hugging Face causal LM behind the forward pass, real batch shapes, gradient accumulation, gradient clipping, an LR scheduler, and a logger. This is the inner loop of a typical production SFT trainer, and its body is shorter than the data collator that feeds it.

SFT training step: AdamW + cosine + grad clip + accumulation (sft_train_step.py)
AdamW with Llama-class betas, not the PyTorch default

AdamW is the standard transformer optimiser. lr=2e-5 is a common Llama-class SFT peak; for 70B+ models you drop to 1e-5, and for 7B you can sometimes push to 5e-5. betas=(0.9, 0.95) replaces the PyTorch default (0.9, 0.999): the smaller β2 gives a more responsive second-moment estimate and is the setting used by GPT-3 and the Llama models, and many open-weights chat models follow it. eps=1e-8 is the standard value; do not change it without a reason. weight_decay=0.1 is decoupled (the W in AdamW): it is applied as a separate parameter-shrinkage step, not folded into the gradient.

Example: AdamW(params, lr=2e-5, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.1)
3% warmup: the modern default for short SFT runs

Warmup ramps the LR linearly up to peak over the first 3% of steps. Why: Adam's moment estimates start at zero, and on the first step the bias-corrected update moves every parameter by roughly the full learning rate in the direction of a single noisy gradient, however small that gradient is. At full peak LR, that one step can push the model well away from its pretrained weights. A linear warmup keeps the early steps small while the moments accumulate an average over many batches.

Example: NUM_STEPS=3000 → WARMUP=90 steps of linear ramp
Linear warmup then cosine: the curve most SFT trainers use

Two cases: inside warmup we return a linear fraction of peak, (step+1)/WARMUP; after warmup we follow a cosine that decays toward 0.1 × peak at the final step. Cosine tends to outperform linear and step schedules for fine-tunes because it spends more time near peak (good for taking large, stable steps) and tapers without a discontinuity (no eval-loss spike at an LR drop). The 10% floor, rather than 0, keeps the last steps doing real learning instead of noise-dominated perturbations.

Example: step 0 → 2e-5 / 90 ≈ 2.2e-7; step 45 (halfway through warmup) → 46/90 × 2e-5 ≈ 1.02e-5; step 1545 (halfway through decay) → 0.55 × 2e-5 = 1.1e-5
Batched, padded inputs

input_ids has shape (B, S), where B is the per-device batch and S is the per-example sequence length after padding (or packing, see the packing section above). It is an integer tensor (torch.long) of indices into the vocabulary; bf16, the native training dtype on an H100, applies to the weights and activations, not to the token ids.

Example: input_ids.shape = torch.Size([16, 2048])
Labels with prompt tokens already masked

The data collator is responsible for putting -100 at every prompt and pad position. Get this wrong and the model silently trains on the wrong objective; common bugs are forgetting to mask the system prompt, or forgetting to mask the user turns between assistant turns in a multi-turn example.

Example: labels.shape = (16, 2048); in a chat fine-tune, typically 30–50% of entries are -100.
Pass attention_mask alongside input_ids

attention_mask has shape (B, S), with 1 on real tokens and 0 on pad. The model uses it inside attention to block attention to pad positions. Without it, pad positions become a real gradient pathway, and the model would 'learn' from whatever sits in your padded slots. This applies to plain padded batches; the packing path instead uses a 4-D (B, 1, S, S) mask holding one block-diagonal matrix per row (see the packing section above).

use_cache=False is critical during training

By default, Hugging Face causal LMs return key/value caches for the next decoding step. That is useful for generation but wastes memory during training, where nothing is decoded autoregressively, and it is incompatible with gradient checkpointing (Transformers warns and turns the cache off in that case). Always set use_cache=False during training. A common failure is forgetting this on a long-context SFT run and running out of memory at step 1.

Example: out = model(input_ids=..., attention_mask=..., use_cache=False) → out.logits.shape = (B, S, V)
Pull out the logits: bf16, three dimensions

logits has shape (B, S, V), where V is the vocabulary size (often 32k or 128k). At bf16 each entry is two bytes, so the logits alone take B·S·V·2 bytes = 16 · 2048 · 128256 · 2 ≈ 8.4 GB for a Llama-3 batch. This is the largest single tensor in the forward pass and the reason chunked and fused cross-entropy tricks exist.

Example: shape (16, 2048, 128256); ≈ 8.4 GB at bf16
Causal shift: flatten before cross-entropy

We drop the last position from logits (no token follows it to predict) and the first position from labels (no preceding token predicts it). After this, shift_logits[b, t, :] is the prediction for shift_labels[b, t]. The .contiguous() call is needed because slicing off a position leaves a non-contiguous tensor, and the .view() used to flatten it requires contiguous memory (.reshape() would copy as needed instead).

Example: shift_logits.shape = (16, 2047, 128256); shift_labels.shape = (16, 2047)
Cross-entropy with ignore_index=-100

📚 F.cross_entropy: this single PyTorch call IS the SFT loss. It computes the log-softmax, picks out the log-probability of the true label, and applies the mask. Internally it runs a log_softmax followed by nll_loss. ignore_index=-100 makes positions with label -100 contribute zero, so the entire 'completion-only' objective comes from this one argument. reduction='mean' divides by the number of kept tokens (not the total length), exactly as we did by hand in the plain-Python version.

EXAMPLE
loss = F.cross_entropy(shift_logits.view(-1, V), shift_labels.view(-1), ignore_index=-100)
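Below is a plain-Python reference for what that call computes, useful for checking a collator on a few rows by hand. It mirrors ignore_index and reduction='mean' (the mean over kept positions only). It is a sketch, not PyTorch's implementation.

```python
import math
from typing import List


def cross_entropy(logits: List[List[float]], targets: List[int], ignore_index: int = -100) -> float:
    """Mean of -log softmax(row)[target] over positions whose target != ignore_index."""
    total, kept = 0.0, 0
    for row, y in zip(logits, targets):
        if y == ignore_index:
            continue                      # masked position: contributes nothing
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # stable log-sum-exp
        total += log_z - row[y]
        kept += 1
    return total / kept if kept else float("nan")  # every position masked: 0 / 0


logits = [[2.0, 0.0], [0.0, 0.0], [5.0, -5.0]]
loss = cross_entropy(logits, [0, 1, -100])   # the third row is masked out
```

Unmasking the third row with target 1 would add a term of about 10 to the sum. This is why masking mistakes show up so clearly in the loss curve.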
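Gradient accumulation: divide, then backward

📚 We divide the loss by accum_steps BEFORE .backward(). This is the standard way to accumulate. Each microbatch contributes 1/N of the gradient, so after N microbatches the accumulated .grad equals the average of the N microbatch gradients. That average matches the full-batch per-token mean only when every microbatch has the same number of kept tokens. Forgetting to divide is the canonical accumulation bug. With plain SGD it makes the effective LR N× too high. With AdamW, the update is largely invariant to a constant gradient scale. However, the pre-clip gradient norm is N× inflated, so max_norm clipping fires far more often and distorts the updates.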

EXAMPLE
(loss / 4).backward()  # accum_steps = 4 microbatches per effective batch
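The toy check below assumes a single scalar weight w and a squared-error loss per kept token; both are purely illustrative. Dividing by N reproduces the full-batch gradient when every microbatch has the same number of kept tokens. When the counts differ, it over-weights tokens in the shorter microbatch. One way around that is to normalise by the total kept-token count of the whole accumulation window.

```python
from typing import List


def mean_grad(w: float, ys: List[float]) -> float:
    """d/dw of mean((w - y)^2) over one microbatch's kept tokens."""
    return sum(2 * (w - y) for y in ys) / len(ys)


def accumulate(w: float, microbatches: List[List[float]]) -> float:
    """Divide-then-backward: each microbatch adds mean_grad / N to the running grad."""
    return sum(mean_grad(w, ys) / len(microbatches) for ys in microbatches)


def accumulate_token_weighted(w: float, microbatches: List[List[float]]) -> float:
    """Alternative: sum per-token grads, divide by kept tokens across the whole window."""
    total_tokens = sum(len(ys) for ys in microbatches)
    return sum(sum(2 * (w - y) for y in ys) / total_tokens for ys in microbatches)


w = 0.0
full = mean_grad(w, [1.0, 3.0, 5.0, 7.0])        # one big batch of 4 tokens
equal = [[1.0, 3.0], [5.0, 7.0]]                  # 2 + 2 kept tokens
unequal = [[1.0], [3.0, 5.0, 7.0]]                # 1 + 3 kept tokens
print(full, accumulate(w, equal), accumulate(w, unequal), accumulate_token_weighted(w, unequal))
# -8.0 -8.0 -6.0 -8.0
```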
Gradient clipping: the most-watched stability metric

📚 torch.nn.utils.clip_grad_norm_ rescales all gradients so that their global L2 norm is at most max_norm. max_norm=1.0 is the common default. Larger values let occasional huge gradients (from a noisy batch or a rare token) destabilise the run. The returned 'gn' is the PRE-clip gradient norm; log it every step. As a rule of thumb, a run where gn stays above 5 for many consecutive steps is in trouble, and a run where gn stays in [0.1, 2.0] is healthy.

EXAMPLE
gn = clip_grad_norm_(params, max_norm=1.0) → typical value 0.4 ± 0.2
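Here is a plain-Python sketch of the clipping rule, written to mirror how torch.nn.utils.clip_grad_norm_ behaves with the default L2 norm. It computes one global norm over every parameter's gradient and scales everything by max_norm / (norm + 1e-6), but only if that factor is below 1. It returns the pre-clip norm. Gradients are plain lists here for illustration.

```python
import math
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float, eps: float = 1e-6) -> float:
    """Scale all gradients in place so their global L2 norm is at most max_norm.

    Returns the PRE-clip norm, which is the number worth logging.
    """
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:                    # only ever scale down, never up
        for tensor in grads:
            for i in range(len(tensor)):
                tensor[i] *= clip_coef
    return total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # two "parameters" with a global norm of 5
gn = clip_grad_norm(grads, max_norm=1.0)
print(gn)  # 5.0
```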
Step the schedule AFTER the optimizer

The order is optimizer.step() → scheduler.step(). Reversing it is a quiet bug. The schedule runs one step ahead, so its first LR value (here lr_lambda(0)) is never used, and PyTorch only emits a warning. set_to_none=True frees the .grad tensors instead of zero-filling them. This skips a memset and releases gradient memory between steps, and it is the default in PyTorch 2.x.
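A small simulation makes that off-by-one visible. It assumes the same warmup-plus-cosine multiplier as the training loop below. It also mimics LambdaLR's internal counter, which starts at 0 when the scheduler is constructed. The helper lrs_applied is hypothetical.

```python
import math
from typing import List

NUM_STEPS, WARMUP, PEAK = 3000, 90, 2e-5


def lr_lambda(step: int) -> float:
    """Warmup, then cosine decay to 10% of peak, as in the training loop."""
    if step < WARMUP:
        return (step + 1) / WARMUP
    progress = (step - WARMUP) / max(1, NUM_STEPS - WARMUP)
    cos = 0.5 * (1 + math.cos(math.pi * min(1.0, progress)))
    return 0.1 + 0.9 * cos


def lrs_applied(n_updates: int, scheduler_first: bool) -> List[float]:
    """LR seen by each optimizer update, mimicking LambdaLR's step counter."""
    sched_step, used = 0, []      # the counter is 0 right after construction
    for _ in range(n_updates):
        if scheduler_first:
            sched_step += 1       # scheduler.step() called before optimizer.step()
        used.append(PEAK * lr_lambda(sched_step))  # optimizer.step() uses the current LR
        if not scheduler_first:
            sched_step += 1       # scheduler.step() called after optimizer.step()
    return used


right = lrs_applied(3, scheduler_first=False)
wrong = lrs_applied(3, scheduler_first=True)
```

The reversed order never applies lr_lambda(0), and every later update uses the value meant for the next step.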

Log lr, loss, and gradient norm together

These three numbers are the minimum viable training dashboard. lr should match the schedule exactly; a mismatch means scheduler.step() is being called the wrong number of times. loss should trend down, with noise that shrinks as the batch grows (roughly as 1/√batch). |g| should be bounded and stable. If any of the three breaks, you can start debugging from this one log line.
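The noise in a logged mean loss shrinks roughly as one over the square root of the batch size. The quick simulation below assumes, for simplicity, that per-example losses are independent with unit variance; real batches are only approximately so. Quadrupling the batch should roughly halve the noise.

```python
import random
import statistics


def std_of_batch_mean(batch_size: int, trials: int = 2000) -> float:
    """Std of the mean of batch_size i.i.d. unit-variance per-example losses."""
    rng = random.Random(batch_size)  # fixed seed per batch size, for reproducibility
    means = [
        statistics.fmean(rng.gauss(0.0, 1.0) for _ in range(batch_size))
        for _ in range(trials)
    ]
    return statistics.stdev(means)


small, large = std_of_batch_mean(16), std_of_batch_mean(64)
print(f"B=16: {small:.3f}  B=64: {large:.3f}  ratio: {small / large:.2f}")
```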

Save every N steps, not every epoch

Modern SFT runs are short (a few thousand steps), so we save by step, not by epoch. Every checkpoint is a candidate for the final release. Best-checkpoint selection happens later, by running each checkpoint on an eval suite. model.save_pretrained writes the HF-compatible files (config.json, generation_config.json, and the weights as one or more .safetensors shards), so a downstream evaluator can load them with from_pretrained. Save the tokenizer alongside so that the chat template travels with the weights.

"""
PyTorch SFT training step: what TRL / axolotl / llama-recipes do per step.

Drop a HuggingFace causal LM in for 'model' and a data loader for 'dataloader',
use the Llama-3 chat template in your dataset preprocessor, and this loop is
the core of a production fine-tune. Commented numbers are the values each
tensor holds for a typical 8B / batch-of-16 / 2048-seq-len configuration.
"""

import math
import torch
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

# ---------------------------------------------------------------------------
# 1. Optimizer: AdamW with common Llama-class SFT defaults.
# ---------------------------------------------------------------------------

# Decay weight matrices only; biases and norm gains (1-D tensors) get none.
decay_params    = [p for p in model.parameters() if p.requires_grad and p.ndim >= 2]
no_decay_params = [p for p in model.parameters() if p.requires_grad and p.ndim < 2]

optimizer = AdamW(
    [
        {"params": decay_params,    "weight_decay": 0.1},
        {"params": no_decay_params, "weight_decay": 0.0},
    ],
    lr=2e-5,                # peak LR. Typical: 2e-5 to 5e-5 at 7-8B, ~2e-5 at 70B, ~1e-5 at 400B+.
    betas=(0.9, 0.95),      # Llama tradition (PyTorch's default beta2 is 0.999)
    eps=1e-8,
)

# ---------------------------------------------------------------------------
# 2. LR schedule: linear warmup, then cosine decay to 10% of peak.
# ---------------------------------------------------------------------------

NUM_STEPS = 3000
WARMUP    = int(0.03 * NUM_STEPS)   # 3% warmup is the modern default

def lr_lambda(step: int) -> float:
    if step < WARMUP:
        return (step + 1) / WARMUP
    progress = (step - WARMUP) / max(1, NUM_STEPS - WARMUP)
    cos = 0.5 * (1 + math.cos(math.pi * min(1.0, progress)))
    return 0.1 + 0.9 * cos              # cosine floor at 10% of peak

scheduler = LambdaLR(optimizer, lr_lambda)

# ---------------------------------------------------------------------------
# 3. One training micro-step.
# ---------------------------------------------------------------------------

def sft_step(batch, model, accum_steps: int = 1):
    """A single SFT micro-step. 'batch' is what the data collator returns."""
    input_ids = batch["input_ids"]            # (B, S) e.g. (16, 2048)
    labels    = batch["labels"]               # (B, S), prompt and pad tokens = -100
    attention = batch["attention_mask"]       # (B, S), 1 for real, 0 for pad

    # Forward. HF causal LMs compute a shifted loss internally if we pass labels,
    # but we ask for logits only so that the masking below is explicit.
    out = model(
        input_ids=input_ids,
        attention_mask=attention,
        use_cache=False,                       # KV cache is for generation, not training
    )
    logits = out.logits                        # (B, S, V), bf16

    # Causal shift: position t predicts token t+1.
    shift_logits = logits[..., :-1, :].contiguous()    # (B, S-1, V)
    shift_labels = labels[..., 1:].contiguous()        # (B, S-1)

    # Cross-entropy with ignore_index=-100: this IS the SFT loss.
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
        reduction="mean",                      # mean over kept tokens only
    )

    # Gradient accumulation: divide so that the .grad summed over accum_steps
    # micro-steps equals the average gradient of one effective batch.
    (loss / accum_steps).backward()
    return loss.detach()

# ---------------------------------------------------------------------------
# 4. The outer loop: accumulate, clip, step, decay, log, save.
# ---------------------------------------------------------------------------

ACCUM_STEPS = 4
GLOBAL_STEP = 0
LOG_EVERY   = 10
SAVE_EVERY  = 500

model.train()
running_loss = 0.0                             # summed micro-step losses (stays on device)
for step, batch in enumerate(dataloader):
    running_loss += sft_step(batch, model, accum_steps=ACCUM_STEPS)

    if (step + 1) % ACCUM_STEPS == 0:
        # Global grad norm (pre-clip): the single most-watched stability metric.
        gn = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()                       # once per optimizer step, after it
        optimizer.zero_grad(set_to_none=True)
        GLOBAL_STEP += 1

        if GLOBAL_STEP % LOG_EVERY == 0:
            print(
                f"step={GLOBAL_STEP:>5}  "
                f"lr={scheduler.get_last_lr()[0]:.2e}  "
                f"loss={running_loss.item() / ACCUM_STEPS:.3f}  "  # mean over the window
                f"|g|={gn.item():.2f}"
            )
        running_loss = 0.0

        if GLOBAL_STEP % SAVE_EVERY == 0:
            model.save_pretrained(f"ckpt/step-{GLOBAL_STEP}")

        if GLOBAL_STEP >= NUM_STEPS:           # stop where the LR schedule ends
            break

From a single step to a working training run

Three things turn this skeleton into a real run:

  1. Mixed precision. Wrap the forward pass in torch.autocast(device_type="cuda", dtype=torch.bfloat16). The master weights and optimiser state stay in fp32, which is necessary for the numerical stability of Adam's second moment, while activations and the forward matmuls run in bf16. On A100/H100 this roughly halves activation memory and speeds up the forward pass, at little cost in accuracy.
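  2. Distributed wrapper. The model is wrapped in FSDP (or DeepSpeed ZeRO-3), which shards parameters, gradients, and optimiser state across GPUs. The step function itself does not change. FSDP's hooks all-gather each wrapped unit's parameters before its forward pass and reduce-scatter gradients during .backward(). The outer loop needs small changes. With the original FullyShardedDataParallel wrapper, clip with model.clip_grad_norm_ so that the norm covers every shard, and gather a full state dict before saving.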
  3. Eval hook. Every N steps, run the model on a held-out eval set: a few hundred chat examples, a slice of MT-Bench, or the validation split of the SFT dataset. Track eval loss and at least one downstream metric, for example the average MT-Bench score (1 to 10) from a judge model. Select the best checkpoint on the downstream metric, not on eval loss; the two often disagree.

At Massive Scale: What Changes for a 405B SFT Run

Everything above scales to 405B with three changes.

Optimiser state sharding becomes mandatory, not optional

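At 405B parameters, AdamW's fp32 first and second moments are $8 \times 405 \times 10^9 = 3.24\ \text{TB}$ of state. Adding parameters in bf16 (810 GB) and gradients (810 GB) brings the total to ~5 TB before a single activation is allocated. ZeRO-1 (shard optimiser state only) still replicates the bf16 weights on every GPU. At 70B those weights alone are 140 GB, more than an 80 GB GPU holds, so ZeRO-1 works at that scale only in combination with tensor or pipeline parallelism. Past 100B, ZeRO-3 (shard optimiser state + gradients + parameters), or an equivalent mix of sharding and model parallelism, is required. FSDP is PyTorch's native implementation of the ZeRO-3 idea and is widely used for new open-weights training runs.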
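Below is a rough per-GPU calculator using the byte counts above: 2 bytes per bf16 parameter, 2 per bf16 gradient, and 8 for the two fp32 Adam moments. It follows the standard ZeRO stage definitions. Stage 1 shards optimiser state, stage 2 also shards gradients, and stage 3 also shards parameters. The calculator ignores activations, communication buffers, and the fp32 master weights that many mixed-precision setups also keep (another 4 bytes per parameter). The GPU counts are illustrative.

```python
def bytes_per_gpu(n_params: float, n_gpus: int, zero_stage: int) -> float:
    """Per-GPU bytes for bf16 params, bf16 grads, and fp32 Adam moments under ZeRO."""
    params, grads, optim = 2 * n_params, 2 * n_params, 8 * n_params
    if zero_stage >= 1:
        optim /= n_gpus       # stage 1: shard optimiser state
    if zero_stage >= 2:
        grads /= n_gpus       # stage 2: also shard gradients
    if zero_stage >= 3:
        params /= n_gpus      # stage 3: also shard parameters
    return params + grads + optim


GB = 1e9
print(f"405B, unsharded:     {bytes_per_gpu(405e9, 1, 0) / GB:,.0f} GB")
print(f"405B, ZeRO-3 x 1024: {bytes_per_gpu(405e9, 1024, 3) / GB:.1f} GB per GPU")
print(f"70B,  ZeRO-1 x 64:   {bytes_per_gpu(70e9, 64, 1) / GB:.1f} GB per GPU")
```

The last line shows why ZeRO-1 alone cannot fit a 70B model on 80 GB GPUs: the replicated weights and gradients already come to 280 GB.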

Effective batch size in tokens, not examples

At the frontier, ‘batch size’ is reported as tokens per step, not examples per step, because packing makes the per-step example count meaningless. A typical 405B SFT run uses about $4 \times 10^6$ tokens per step (4M tokens). At a context of 8 192, that is ~500 packed rows of 8 192 tokens each. Spread over 1024 GPUs, that is only about half a row per GPU, so the GPUs are grouped with tensor or pipeline parallelism. With 8-way tensor parallelism, for example, 128 data-parallel replicas each process about 4 rows. Total SFT compute for an instruction set of 10M examples × 1k tokens = 10B tokens is about $6 \cdot 405 \cdot 10^9 \cdot 10^{10} \approx 2.4 \times 10^{22}$ FLOPs. That is well under 1% of pretraining compute: a 15T-token pretraining run costs about $6 \cdot 405 \cdot 10^9 \cdot 1.5 \times 10^{13} \approx 3.6 \times 10^{25}$ FLOPs, so SFT is roughly 0.07% of it. SFT is cheap by design.
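Here are the same numbers as a quick script. It assumes 8-way tensor parallelism, an illustrative choice since real runs may also use pipeline or context parallelism. It also uses the standard 6·N·D estimate of training FLOPs. The 15T-token pretraining budget is an assumption for comparison.

```python
TOKENS_PER_STEP = 4_000_000
SEQ_LEN = 8_192
N_GPUS = 1_024
TENSOR_PARALLEL = 8  # assumed; other TP/PP splits change the per-replica figure

rows = TOKENS_PER_STEP / SEQ_LEN             # packed rows per optimiser step
dp_replicas = N_GPUS // TENSOR_PARALLEL      # data-parallel replicas
rows_per_replica = rows / dp_replicas


def train_flops(n_params: float, n_tokens: float) -> float:
    """Standard 6 * N * D estimate: forward plus backward over every token."""
    return 6 * n_params * n_tokens


sft = train_flops(405e9, 10e6 * 1_000)       # 10M examples x 1k tokens
pretrain = train_flops(405e9, 15e12)         # assumed 15T-token pretraining run
print(f"rows={rows:.0f}  rows/replica={rows_per_replica:.1f}  "
      f"SFT={sft:.1e} FLOPs  SFT/pretrain={sft / pretrain:.2%}")
# rows=488  rows/replica=3.8  SFT=2.4e+22 FLOPs  SFT/pretrain=0.07%
```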

The schedule has to interact with the warm-restart strategy

At frontier scale, SFT is rarely a single run. A common pattern is SFT v1 → reward model → DPO → SFT v2, where SFT v2 trains on regenerated, improved completions. Each SFT pass uses its own cosine schedule, starting from the previous checkpoint rather than the base model. Choosing the peak LR for SFT v2 is delicate. Too high and you erase the DPO gains; too low and you do not absorb the new completions. A common choice is about $0.5\times$ the SFT v1 peak for SFT v2.

Engineering Reality: The Knobs That Actually Move the Needle

After running enough SFT experiments, every team converges on the same short list of mistakes and the same short list of high-leverage fixes.

  • The loss mask is wrong. Symptom: model produces plausible user turns when given an empty prompt. Cause: the data collator masked input_ids positions instead of labels positions, or forgot the off-by-one shift. Fix: print one example's tokens with their corresponding labels side-by-side before you launch.
  • LR is set for the wrong model size. Symptom: training loss looks fine but downstream benchmarks crater (e.g. HellaSwag drops 10 points). Cause: a pretraining-grade LR used for SFT erodes pretrained knowledge. Fix: lower the LR as model size grows, e.g. $5 \times 10^{-5}$ for 7B, $2 \times 10^{-5}$ for 70B, and $1 \times 10^{-5}$ for 400B+.
  • Gradient norm spikes and the run NaNs. Symptom: training loss is flat for 50 steps, then explodes to inf. Cause: a single batch with an anomalously large gradient got through because max_norm was set too high. Fix: max_norm=1.0 is the common default; do not raise it without a reason.
  • EOT not in the loss. Symptom: the model generates a plausible response, then keeps going past the end of the assistant turn into a fake user turn. Cause: the end-of-turn token (<|im_end|> in ChatML, <|eot_id|> in Llama 3) was masked out. Fix: include the EOT in the unmasked region (the third mode in the loss-mask visualiser above).
  • Padding eats the batch. Symptom: GPU utilisation is stuck around 30% and training is mysteriously slow. Cause: the dataset has a long tail of example lengths, and every batch is padded to its longest example. Fix: enable example packing; most modern trainers expose it as a single flag.
  • Weight decay applied to LayerNorm. Symptom: small but steady accuracy degradation across all evals. Cause: weight decay applied to every parameter via model.parameters(). Fix: build two parameter groups, one with weight_decay=0.1 for weight matrices and one with weight_decay=0.0 for biases and norm parameters.
  • Scheduler stepped the wrong number of times. Symptom: your logged LR does not match the schedule curve you intended. Cause: scheduler.step() is called per microbatch instead of per accumulation cycle, or the schedule length is counted in microbatches while the scheduler is stepped per optimiser step. Fix: advance the scheduler once per optimiser step, never per microbatch, and count its total length in optimiser steps.
  • Saved best-loss checkpoint underperforms. Symptom: the checkpoint at minimum eval loss is worse on MT-Bench than the checkpoint two epochs later. Cause: eval loss measures how likely the reference responses are, not how good the model's own responses are, which makes it the wrong target for chat models. Fix: select on downstream metrics (judge score, win rate against a reference) rather than loss.
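For the first item in this list, here is a minimal sketch of a side-by-side dump plus an automated end-of-turn check. It assumes you already have the token strings (for example, from decoding each id) and the collator's labels. The ChatML-style strings, the ids, and the helper names are hypothetical.

```python
from typing import List

IGNORE_INDEX = -100


def label_report(tokens: List[str], labels: List[int]) -> List[str]:
    """One line per position: the token, then TRAIN(label) or MASK."""
    lines = []
    for tok, lab in zip(tokens, labels):
        status = "MASK" if lab == IGNORE_INDEX else f"TRAIN({lab})"
        lines.append(f"{tok!r:<26} {status}")
    return lines


def eot_is_trained(tokens: List[str], labels: List[int], eot: str = "<|im_end|>") -> bool:
    """True if at least one end-of-turn token keeps its label."""
    return any(tok == eot and lab != IGNORE_INDEX for tok, lab in zip(tokens, labels))


tokens = ["<|im_start|>user", " Hi", "<|im_end|>", "<|im_start|>assistant", " Hello", "<|im_end|>"]
labels = [-100, -100, -100, -100, 201, 202]
report = label_report(tokens, labels)
print("\n".join(report))
print("EOT trained:", eot_is_trained(tokens, labels))
```

Running a check like eot_is_trained over a few hundred collated examples before launch catches the EOT bug described above, not just the one example you happened to print.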
The mental model that unifies this section: SFT is a continued-pretraining run with three changes — the loss masks all non-assistant tokens, the LR is an order of magnitude smaller, and the data is curated demonstrations rather than the open web. Every other knob (optimiser, schedule, packing, decay) is the same machinery you used for pretraining, dialled to its ‘gentle and precise’ setting. Get the masking right and the LR right, and a 70B chat model is a few thousand training steps away.