Chapter 13

SFT Data Collection

Supervised Fine-Tuning (SFT)

The Real Problem: A Fluent Model With No Manners

After trillions of tokens of pretraining, the base model has read more text than any human ever will. It can complete a sentence in any of forty languages, summarise a Wikipedia paragraph in the style of Hemingway, finish a Python function it has never seen, and recite the first hundred digits of pi. What it cannot do is answer a question.

Drop a base model into a chat box and type “What is the capital of France?”. A typical un-aligned Llama-2-7B base will produce one of three responses: it will continue the question (“What is the capital of France? What is the capital of Germany?...”), it will hallucinate a multiple-choice quiz around it, or it will dutifully complete the most common web-scraped continuation, which on a typical pretraining mix happens to be the start of an IELTS practice exercise. Paris appears nowhere. The model is not broken; it is doing exactly what we trained it to do — predict the next token under the empirical distribution of the web.

The web is not full of helpful answers. It is full of forum posts, listicles, half-finished tutorials, question stems without answers, and arguments. A base model trained on the web knows facts but not format. The mechanism that turns “a library of completion patterns” into “a useful assistant” is supervised fine-tuning, and the lever that turns SFT from a vague hope into a reliable engineering process is the data.

The thesis of this section. SFT is not where a model learns facts — that ship sailed when pretraining finished. SFT is where a model learns a conditional distribution over response shapes: given this kind of question, produce a response that looks like that. And in the modern era — post-LIMA, post-Tülu, post-DeepSeek-R1 — the quality of those few thousand demonstration responses dominates every other knob in the SFT pipeline. Quantity is cheap. Curation is the moat.

Why is this a problem worth a whole section? Because the instinct most engineers walk in with is wrong. People assume SFT scales like pretraining — more data, longer training, lower loss, better model. The empirical record from every frontier lab over the last three years is the exact opposite:

  • LIMA (Meta, 2023): a 65B model fine-tuned on 1,000 carefully hand-curated examples matched or beat, in human-preference comparisons, the same base model fine-tuned on the 52,000-example Alpaca instruction set. Quality beat quantity despite a roughly 50× gap in example count.
  • Tülu 2 (Allen AI, 2023): systematic ablations showed that adding low-quality data to a clean SFT mix decreases downstream eval scores by 2–6 points. More data is worse than no data when the floor is low.
  • DeepSeek-R1 (2025): the SFT stage that follows the first round of reasoning RL used roughly 800,000 examples, most of them rejection-sampled, but only after the team threw away ~95% of the candidate pool. The keep rate, not the raw volume, is what trained the model.

Three independent teams, three different scales, one conclusion: SFT is a curation problem disguised as a training problem. The mathematics of the next-token loss is genuinely trivial — we will derive it in two equations. Everything difficult about SFT lives upstream of the train loop, in the data pipeline.

Intuition: A Style Demo, Not a Knowledge Injection

Here is the analogy that makes SFT click. Imagine a brilliant but socially-awkward librarian. They have read every book in the building and can recite passages from any of them on request. But when you walk in and say “Could you help me write a polite email declining a wedding invitation?”, they freeze. Not because they don't know what a polite email looks like — they have read thousands of them. Because they don't know that your question is a request for one. They might quote the etiquette section of a 1950s manners book at you. They might describe the historical evolution of the declined-RSVP genre. They might, in their confusion, just complete your sentence and ask “why?”.

SFT is the moment you hand this librarian a stack of three thousand index cards. On the front of each card: a question someone might ask. On the back: a model answer in the voice and shape you want. You do not teach them new facts — they already know everything on the back of every card. You teach them which face goes with which face. After enough cards, the next time someone walks in and says “polite email declining a wedding invitation”, the librarian instinctively flips to the back of the card, not the front.

The geometric picture. Pretraining built a massive low-perplexity manifold over all of human text. SFT does not move that manifold; it carves a small response sub-manifold inside it, and teaches the model that whenever the input has chat-template shape, the output should live on that sub-manifold. The sub-manifold is defined entirely by your demonstration examples — which is why three thousand carefully chosen examples can beat fifty thousand sloppy ones. The sloppy ones carve a sub-manifold that includes “polite email” and also “rude email”, “passive-aggressive email”, and “legal disclaimer accidentally formatted as an email”. The model has no way of preferring one over another.

This is also why the SFT data pipeline looks nothing like a pretraining data pipeline. Pretraining wants volume and breadth: every domain, every register, every language, hundreds of billions of tokens. SFT wants precision and consistency: every example needs to be on the response sub-manifold you want, written in the voice you want, formatted the way you want. A single bad example — a hallucinated fact, a refusal where there shouldn't be one, a sycophantic preamble — tells the model that your sub-manifold includes that point too. Multiply by a few hundred such examples and the sub-manifold gets so blurry that the model defaults back to its base behaviour for any prompt that does not sit exactly on a demonstrated point.

The practical consequence is uncomfortable for anyone who has spent a career scaling data pipelines: SFT data quality is a manual problem, all the way down. Every successful frontier lab has at some point in their post-training pipeline a group of expert annotators rewriting model outputs by hand. Synthetic data from a stronger model helps. Rejection sampling helps. But the floor is set by the worst example in the corpus, and lowering that floor requires human eyes on every candidate.

The Math: Cross-Entropy on the Assistant Turn Only

Strip away the curation drama and the SFT loss is the same next-token cross-entropy that drove pretraining. The only difference is which positions contribute to the average. Let $x = (x_1, x_2, \ldots, x_T)$ be the flat token sequence after applying the chat template: system prompt, then user turn, then assistant turn, optionally more turns. Let $m \in \{0, 1\}^T$ be the loss mask: $m_t = 1$ if token $x_t$ belongs to an assistant turn, and $m_t = 0$ otherwise.

The SFT loss for this single example is

$$\mathcal{L}_{\text{SFT}}(\theta; x, m) = -\frac{1}{\sum_t m_t}\sum_{t=1}^{T-1} m_{t+1}\, \log p_\theta(x_{t+1} \mid x_{\leq t}).$$

The right-hand sum is the standard next-token cross-entropy. The $m_{t+1}$ factor turns off the loss everywhere the target token is not an assistant token: we do not penalise the model for being wrong about the user's next word, only about its own. The normaliser $\sum_t m_t$ divides by the count of gradient-bearing positions so that the loss is comparable across examples with different assistant-token densities.
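The formula can be checked with a few lines of pure Python. This is only a sketch: the name sft_loss and the choice to pass precomputed log-probabilities (rather than logits) are assumptions made for illustration.

```python
import math

def sft_loss(target_log_probs: list[float], mask: list[int]) -> float:
    """Masked next-token NLL for one example.

    target_log_probs[t] is log p(x_{t+1} | x_{<=t}), one entry per
    next-token target, so it has len(mask) - 1 entries.
    mask[t] = 1 iff token t is an assistant token.
    """
    assert len(target_log_probs) == len(mask) - 1
    target_mask = mask[1:]            # m_{t+1}: the mask of the TARGET token
    denom = sum(mask)                 # normaliser from the formula
    if denom == 0:
        return 0.0                    # nothing supervised in this example
    total = sum(m * lp for m, lp in zip(target_mask, target_log_probs))
    return -total / denom

# Toy check: 4 tokens, only the last two belong to the assistant.
mask = [0, 0, 1, 1]
log_probs = [math.log(0.5), math.log(0.25), math.log(0.5)]
loss = sft_loss(log_probs, mask)
print(round(loss, 4))  # → 1.0397
```

Only the last two targets contribute, so the result is the mean of log 4 and log 2, and the first (user-token) target is ignored entirely.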

The classic SFT bug. If you forget the mask and average over all tokens, you train the model to autocomplete the user question. The loss looks normal, perplexity goes down, training is stable, and the resulting model continues prompts instead of answering them. This bug has shipped to internal demos at multiple frontier labs. The unit test is one line: assert mask.sum() < 0.6 * T, applied to short-response chat data (assistant tokens are typically 20–40% of such a sequence, so anything above 60% suggests the mask is wrong). Long chain-of-thought examples legitimately exceed that bound, so run the check per category rather than over the whole corpus.
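A minimal sketch of that check as a reusable helper follows. The function names and the list-of-lists mask layout are hypothetical; the 0.6 default is the threshold quoted in the paragraph above.

```python
def supervision_density(mask: list[int]) -> float:
    """Fraction of tokens in one example that carry gradient."""
    return sum(mask) / len(mask) if mask else 0.0

def check_mask_density(masks: list[list[int]], upper: float = 0.6) -> list[int]:
    """Return the indices of examples whose density exceeds `upper`."""
    return [i for i, m in enumerate(masks) if supervision_density(m) > upper]

masks = [
    [0] * 9 + [1] * 3,   # 25% supervised: a normal short chat example
    [1] * 12,            # 100% supervised: the mask was never applied
]
suspicious = check_mask_density(masks)
print(suspicious)  # → [1]
```

Returning the offending indices instead of raising makes it easy to inspect the flagged examples before deciding whether the mask or the threshold is at fault.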

Three quantities derived from this loss matter for the data side of SFT. First, the effective token count $N_{\text{eff}} = \sum_{e \in \mathcal{D}} \sum_t m_t^{(e)}$, the total number of supervised positions across the corpus. This, not the raw example count, is what your one-epoch compute budget should be sized against. A corpus of 100,000 single-turn chat examples with average 150-token assistant responses gives $N_{\text{eff}} \approx 1.5 \times 10^7$: fifteen million gradient-bearing tokens, or about 0.0001% of a 15-trillion-token pretraining run.

Second, the per-category effective contribution $N_c = \sum_{e \in \mathcal{D}_c} \sum_t m_t^{(e)}$. Long-reasoning examples contribute proportionally more than single-line refusals. A naive 50/50 example-count split between “math” and “safety” might be a 90/10 split in $N_c$ terms, because a math chain-of-thought is 5–10× longer than a calibrated refusal. Always report mix proportions in effective tokens, not example counts.

Third, the quality-weighted effective count $Q = \sum_e q_e \sum_t m_t^{(e)}$, where $q_e \in [0, 1]$ is a per-example quality score (from a learned quality classifier, a human rating, or a rejection-sampling rank). The LIMA result, in these terms, says that $Q$ dominates eval performance once it exceeds a small threshold; pushing $N_{\text{eff}}$ up at the cost of average $q_e$ is a net loss.

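| Quantity | Symbol | Typical SFT value | What it controls |
| --- | --- | --- | --- |
| Examples | $\lvert\mathcal{D}\rvert$ | 1K – 1M | labour cost |
| Effective tokens | $N_{\text{eff}}$ | 0.5M – 500M | compute / one-epoch step count |
| Per-category fraction | $N_c / N_{\text{eff}}$ | any | downstream skill distribution |
| Quality-weighted count | $Q$ | context-dependent | eval-score uplift (LIMA) |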
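The three quantities are easy to compute once every example carries its mask. The sketch below assumes a hypothetical record layout of (category, loss_mask, quality score); the two records are invented to illustrate the 50/50-by-count versus roughly 90/10-by-token effect.

```python
from collections import defaultdict

# Hypothetical records: (category, loss_mask, quality score in [0, 1]).
corpus = [
    ("math",   [0] * 30 + [1] * 600, 0.9),   # long chain-of-thought answer
    ("safety", [0] * 30 + [1] * 60,  0.8),   # short calibrated refusal
]

def effective_counts(corpus):
    """Return (N_eff, {category: N_c}, Q) as defined in the text."""
    n_eff = 0
    n_c = defaultdict(int)
    q = 0.0
    for category, mask, quality in corpus:
        supervised = sum(mask)          # gradient-bearing tokens in this example
        n_eff += supervised
        n_c[category] += supervised
        q += quality * supervised
    return n_eff, dict(n_c), q

n_eff, n_c, q = effective_counts(corpus)
shares = {c: n / n_eff for c, n in n_c.items()}
print(n_eff, n_c, round(q, 1), {c: round(s, 2) for c, s in shares.items()})
```

One example per category is a 50/50 split by count, yet math ends up with about 91% of the supervised tokens.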

Manual Numerical Walkthrough

Building the loss mask on a 12-token toy chat and computing one SFT step

Take a tiny chat-template example with one user turn and one assistant turn. Use a stylised tokeniser where each word maps to one token and we have role tags [USR] and [AST] plus an end-of-turn marker [EOT]. The conversation is:

User: “What is two plus three?”
Assistant: “Five.”

Step 1: tokenise + emit the mask.

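| t | token | role | m_t |
| --- | --- | --- | --- |
| 1 | [USR] | tag | 0 |
| 2 | What | user | 0 |
| 3 | is | user | 0 |
| 4 | two | user | 0 |
| 5 | plus | user | 0 |
| 6 | three | user | 0 |
| 7 | ? | user | 0 |
| 8 | [EOT] | user-EOT | 0 |
| 9 | [AST] | tag | 0 |
| 10 | Five | assistant | 1 |
| 11 | . | assistant | 1 |
| 12 | [EOT] | assistant-EOT | 1 |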

Notice exactly which positions get $m_t = 1$: the assistant content tokens (10, 11) and the assistant-EOT (12), but not the [AST] role tag (9). The model should learn to produce the answer and the stop signal, but it should never be supervised on the role tag, because the chat template inserts that tag automatically at inference time.
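The same mask can be produced programmatically. This sketch uses the stylised one-word-one-token tokeniser from the walkthrough; the function name toy_tokenize_and_mask is hypothetical.

```python
USR, AST, EOT = "[USR]", "[AST]", "[EOT]"

def toy_tokenize_and_mask(turns):
    """turns: list of (role, list_of_word_tokens). Returns (tokens, mask)."""
    tokens, mask = [], []
    for role, words in turns:
        is_assistant = role == "assistant"
        tokens.append(AST if is_assistant else USR)
        mask.append(0)                                 # role tag: never supervised
        tokens.extend(words)
        mask.extend([int(is_assistant)] * len(words))  # content tokens
        tokens.append(EOT)
        mask.append(int(is_assistant))                 # EOT supervised only for the assistant
    return tokens, mask

tokens, mask = toy_tokenize_and_mask([
    ("user", ["What", "is", "two", "plus", "three", "?"]),
    ("assistant", ["Five", "."]),
])
supervised_positions = [t + 1 for t, m in enumerate(mask) if m]  # 1-indexed, as in the table
print(supervised_positions)  # → [10, 11, 12]
```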

Step 2: count the gradient-bearing positions. $\sum_t m_t = 3$. Out of 12 tokens, only 3 contribute to the loss, a 25% supervision density. This is squarely in the “healthy” band; a multi-turn reasoning example with 600-token chains-of-thought and 30-token user turns would hit 95% supervision density, and a single-line refusal example would hit 10%.

Step 3: compute the per-token NLL. Pretend the model assigns the following probabilities to the true next-token at each supervised position (after a single SFT step from the base model):

| predict at t | target token | p(target) | −log p |
| --- | --- | --- | --- |
| 9 → 10 | Five | 0.04 | 3.22 |
| 10 → 11 | . | 0.55 | 0.60 |
| 11 → 12 | [EOT] | 0.28 | 1.27 |

Step 4: average over supervised tokens. $\mathcal{L} = (3.22 + 0.60 + 1.27) / 3 = 5.09 / 3 \approx 1.70$. That is the SFT loss for this example.
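The arithmetic of Steps 3 and 4 can be reproduced directly from the three probabilities given above; nothing else is assumed.

```python
import math

# p(target) at the three supervised positions, taken from Step 3.
probs = {"Five": 0.04, ".": 0.55, "[EOT]": 0.28}

nll = {tok: -math.log(p) for tok, p in probs.items()}   # per-token negative log-likelihood
loss = sum(nll.values()) / len(nll)                     # mean over supervised tokens only

print({tok: round(v, 2) for tok, v in nll.items()})
print(round(loss, 2))  # → 1.7
```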

Step 5: contrast with the buggy version. Suppose you forgot the mask and averaged over every next-token target. A 12-token sequence has 11 targets (the first token has nothing before it to be predicted from), so 8 unsupervised targets join the 3 assistant ones. With a base-model-like probability of 0.1 for each non-assistant target (typical for “the next word in a random web-scraped sequence”), you would get 8 additional contributions of ~2.3 each and the final “loss” would be roughly $(5.09 + 8 \cdot 2.3) / 11 \approx 2.14$. That is higher in absolute terms but dominated by noise from positions that should be unsupervised, and part of the gradient now points in the wrong direction. The model would “learn” to predict the user's next word better, which is precisely the wrong behaviour.

Step 6: the corpus view. Multiply this by ten thousand examples. The healthy SFT corpus has $N_{\text{eff}} \approx 10{,}000 \times 3 \times 50 = 1.5\,\text{M}$ tokens of gradient, where the factor of 50 is the average multiplier from this 3-token toy answer to a realistic 150-token response. That is 4 minutes of compute on an 8×H100 node at BF16. Compute-cheap. Curation-expensive.

Interactive: Designing the Mix

The math says it; the widget shows it. Slide the budget, slide the curation quality, slide the per-category mix. Three things change at once:

  • The stacked share bar shows what proportion of your example budget goes to each category. The numbers on each segment are the absolute example counts.
  • The response-length histogram stacks the assistant-turn-length distributions of each category. Math and code skew right; refusals huddle on the left. This is where most of your $N_{\text{eff}}$ actually lives.
  • The metrics panel shows the gradient token count, a diversity score, and an estimated eval uplift. The uplift curve is the empirical S-shape from LIMA / Tülu / R1 — quality dominates past a few thousand examples.

Three exercises to try with the widget. (a) Set the budget to 1K and crank quality to 100%. Note the eval uplift. Now set the budget to 250K and drop quality to 30%. The gradient-token count is 250× bigger, but the predicted uplift is roughly the same — the curated 1K beats the sloppy 250K. This is the LIMA result rendered in pixels. (b) Push the math+code share to 80% and the chat share to 5%. The histogram leans hard to the right. Models trained on this mix are great at reasoning and terrible at small talk; this is exactly the failure mode that pushed OpenAI to publish the InstructGPT mix proportions in 2022. (c) Crank safety+refusals to 50%. The histogram piles up on the left (refusals are short) and the diversity score plummets. The resulting model refuses everything — the “over-refusal tax” that haunted Claude 2 and Llama-2-Chat at release.

Plain Python: Mask, Dedupe, Decontaminate

Before any GPU sees an example, the SFT pipeline does three things in CPU-land: it builds the assistant-only loss mask, it removes near-duplicate prompts, and it filters out any example that overlaps an eval benchmark. The code below implements all three from scratch, with no PyTorch and no framework magic — every line is the literal algorithm the frontier labs publish in their data-prep appendices.

From-scratch SFT data preparer
sft_prep.py
Turn: the atomic unit of SFT data

Every modern chat dataset is a list of (role, content) turns. The role is one of 'system', 'user', or 'assistant'. The whole supervised-fine-tuning game hinges on one rule: gradient flows ONLY when the role is 'assistant'. Everything else — the system prompt that frames the run, the user question that prompts it, the role tags themselves — is conditioning context. The model reads it, but is not graded on producing it.

EXECUTION STATE
role = 'system' | 'user' | 'assistant'
content = raw UTF-8 string, no template formatting yet
tokenize_and_mask: the most consequential 12 lines in SFT

We walk the turns in order, encoding each one and emitting a parallel loss_mask. The mask is 0 on every system and user token, 0 on the role-tag tokens, and 1 on every assistant token (including the end-of-turn marker — the model must learn when to stop). If you flip this and train on the user turn too, you teach the model to autocomplete questions instead of answering them; this is the classic 'SFT bug' that has wasted weeks of compute at more than one lab.

EXECUTION STATE
ids = flat list of int token IDs across all turns
mask = parallel list of 0/1; sum(mask) = # of gradient-bearing tokens
EOT on an assistant turn is part of the answer

A subtle point that costs perplexity if you get it wrong. The end-of-turn (EOT) token that closes an assistant turn IS supervised. The model needs to learn 'when to stop talking' just as much as it needs to learn the words in between. Mask EOT to 0 and you get a model that rambles past the answer at inference time — a failure that looks like 'verbosity' but is really just a one-bit data-prep bug.

EXECUTION STATE
mask[-1] = 1 on assistant EOT — supervises stop-token prediction
shingles: the input to every dedup pipeline

k-character shingles (k=5 is the standard for English text) give us a bag-of-substrings representation that is robust to minor edits. Two paraphrases of the same prompt overlap heavily in shingle space even if every word is reordered. Lowercase + whitespace normalisation strips the easy variations the data-augmentation step would otherwise hide behind.

EXECUTION STATE
k = 5 characters per shingle (industry default for English)
return = set of strings, |set| ≈ len(text) - k + 1 before dedup
MinHash: Jaccard similarity in 64 integers

MinHash maps a set to a fixed-length signature such that the fraction of equal positions between two signatures is an unbiased estimate of the Jaccard similarity of the original sets. We pay a one-time O(|shingles| · num_perms) hash cost per example, then any pair of examples can be compared in O(num_perms). For SFT corpora of millions of examples, this turns a quadratic problem into a linear-scan-plus-LSH problem.

EXECUTION STATE
num_perms = 64 (good enough; Tülu uses 128, R1 uses 256)
sigs[p] = the min over all shingles of the p-th hash — the 'min-hash'
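A quick way to convince yourself that the signature comparison behaves like Jaccard similarity is to try it on tiny sets. This is a compact sketch that salts SHA-1 once per permutation; it is not the exact routine from the listing, and the word sets are invented.

```python
import hashlib

def minhash(items: set[str], num_perms: int = 64) -> list[int]:
    """Salted-SHA-1 MinHash over a NON-EMPTY set of strings."""
    return [
        min(
            int.from_bytes(
                hashlib.sha1(p.to_bytes(2, "big") + s.encode("utf-8")).digest()[:8], "big"
            )
            for s in items
        )
        for p in range(num_perms)
    ]

def estimate(sig_a: list[int], sig_b: list[int]) -> float:
    """Fraction of signature positions that agree."""
    return sum(a == b for a, b in zip(sig_a, sig_b)) / len(sig_a)

a = {"alpha", "beta", "gamma", "delta"}
b = {"alpha", "beta", "gamma", "epsilon"}   # exact Jaccard with a: 3 / 5
c = {"one", "two", "three"}                 # shares nothing with a

est_same = estimate(minhash(a), minhash(a))
est_ab = estimate(minhash(a), minhash(b))
est_disjoint = estimate(minhash(a), minhash(c))
print(est_same, est_ab, est_disjoint)
```

Identical sets agree at every position and disjoint sets at none; the middle value is a noisy estimate of 0.6 whose spread shrinks as num_perms grows.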
jaccard_estimate: what the signatures buy you

Fraction of equal signature positions ≈ Jaccard(A, B) = |A∩B| / |A∪B|. Threshold 0.85 is what LLaMA-2's data team published as 'aggressive but not destructive' for instruction data. Drop the threshold below 0.7 and you start discarding prompts that are merely similar rather than duplicated; raise it above 0.95 and you start letting through templated variants like 'Translate hello to French' vs 'Translate hello to Spanish'.

EXECUTION STATE
threshold = 0.85 (industry default for instruction data)
dedupe: a greedy single pass over precomputed signatures

Two phases: compute every signature first, then make one greedy pass in which each new example is compared against every already-kept example. For large corpora the comparison is replaced by Locality-Sensitive Hashing (LSH) with banded buckets, but the semantics are the same: a single representative survives, the rest are dropped. The order matters because earlier examples win, so reputable pipelines pre-sort by a quality score before deduping.

EXECUTION STATE
keep = indices of survivors; |keep| typically 30–70% of |examples|
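Because the first example seen wins, sorting by quality before the greedy pass decides which copy of a near-duplicate survives. The sketch below uses exact Jaccard on shingles instead of MinHash to stay short; the function name dedupe_by_quality and the (prompt, quality) tuples are hypothetical.

```python
import re

def shingles(text: str, k: int = 5) -> set[str]:
    t = re.sub(r"\s+", " ", text.strip().lower())
    # Fall back to the whole string when it is shorter than k characters.
    return {t[i:i + k] for i in range(len(t) - k + 1)} or {t}

def jaccard(a: set[str], b: set[str]) -> float:
    return len(a & b) / len(a | b)

def dedupe_by_quality(examples, threshold: float = 0.85) -> list[int]:
    """examples: list of (prompt, quality). Returns kept indices, best first."""
    order = sorted(range(len(examples)), key=lambda i: examples[i][1], reverse=True)
    keep, kept_sets = [], []
    for i in order:
        s = shingles(examples[i][0])
        if any(jaccard(s, kept) >= threshold for kept in kept_sets):
            continue                      # a better-scored near-duplicate already survived
        keep.append(i)
        kept_sets.append(s)
    return keep

examples = [
    ("Explain how binary search works on a sorted array.", 0.2),
    ("Explain how binary search works on a sorted array!", 0.9),
    ("Write a haiku about autumn rain.", 0.5),
]
keep = dedupe_by_quality(examples)
print(keep)  # → [1, 2]
```

Without the sort, index 0 (quality 0.2) would have survived and the 0.9 copy would have been dropped.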
ngrams: the eval-set decontamination unit

A 13-token contiguous word n-gram is the standard 'fingerprint' for benchmark contamination, popularised by the 13-gram overlap analysis in the GPT-3 paper and widely reused since. Why 13? It is short enough to catch genuine leakage of test questions and long enough that natural overlap (common phrases like 'on the other hand') is extremely rare. We tokenise on word boundaries, not subword tokens, so that contamination cannot be evaded by switching tokenisers.

EXECUTION STATE
n = 13 (the GPT-3 convention, widely reused)
is_contaminated: fast set-intersection check

Python set intersection is O(min(|A|, |B|)). With ~10⁶ eval n-grams pre-indexed once, checking a single example is microseconds. The shocking statistic from the LLaMA-2 paper: even after this filter, ~4% of their carefully curated SFT data still tripped on a 13-gram match. Without the filter, that number would have polluted every public benchmark they reported.
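A small, self-contained illustration of the check follows. The benchmark question and both candidate examples are invented for the demo.

```python
import re

def ngrams(text: str, n: int = 13) -> set[tuple[str, ...]]:
    """Lowercased word-level n-grams; empty if the text has fewer than n words."""
    toks = re.findall(r"\w+", text.lower())
    return {tuple(toks[i:i + n]) for i in range(len(toks) - n + 1)}

def is_contaminated(example: str, index: set[tuple[str, ...]]) -> bool:
    return bool(ngrams(example) & index)

# Invented 22-word "benchmark" question.
eval_question = (
    "A train leaves the station at nine and travels at sixty miles per hour "
    "for three hours, how far does it go"
)
eval_index = ngrams(eval_question)   # built once, reused for every example

leaked = "Solve this: " + eval_question + ". Show your work."
clean = "A bus leaves the depot at noon and travels at forty miles per hour for two hours."

print(is_contaminated(leaked, eval_index))  # → True
print(is_contaminated(clean, eval_index))   # → False
```

The clean example shares plenty of short phrases with the benchmark question but no 13-word run, which is the behaviour the choice of n is meant to give.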

prepare_sft_corpus: the three filters in their canonical order

Decontamination first (cheap, catches the most damaging issue), dedup second (expensive, and order-sensitive because earlier examples win), tokenise+mask last (only on data that has already passed both filters, so no work is wasted). Every frontier lab's pipeline has this same skeleton plus more: quality scoring, length filtering, PII scrubbing, toxicity filtering, and language-ID checks.

Dedupe only on the user prompt, not the full example

A subtle but important choice. We hash the USER question, not the full turn list. Why? The prompt distribution is what we are trying to diversify, so two examples with the SAME user question are redundant even when their assistant answers are very different (and both good). In practice such pairs are usually rejection-sampling siblings, and we want exactly one per prompt.

"""
A from-scratch SFT data preparer. Three jobs:

  1. Tokenise a chat-template conversation into a flat token sequence AND
     a parallel 'loss_mask' that is 1 only on the assistant-turn tokens.
     This is the single most consequential design choice in SFT: the
     model should learn to PRODUCE the assistant turn, not to PARROT the
     user prompt.

  2. Near-duplicate filter with MinHash (greedy pairwise comparison here;
     production pipelines add LSH). Two examples that share more than a
     Jaccard similarity of, say, 0.85 are almost certainly templated
     variants of the same prompt. Keeping both is dead weight at best and
     a memorisation hazard at worst.

  3. Decontamination against eval benchmarks. If your SFT set leaks the
     test set of MMLU, GSM8K, or HumanEval, your headline numbers are
     meaningless. We check 13-gram overlap, the same n-gram size every
     frontier lab now publishes against.

Nothing here imports torch. The whole pipeline runs on the CPU in the
data-prep stage, hours before any GPU sees an example.
"""
from __future__ import annotations

import hashlib
import re
from dataclasses import dataclass
from typing import Iterable

# ─── 1. Assistant-only loss masking ─────────────────────────────────────
@dataclass
class Turn:
    role: str          # 'system', 'user', or 'assistant'
    content: str

def tokenize_and_mask(
    turns: list[Turn],
    tokenizer,           # any HF-style tokenizer with .encode(str) -> list[int]
    role_tokens: dict[str, int],   # special tokens that mark each role boundary
    eot_token: int,                # end-of-turn marker
) -> tuple[list[int], list[int]]:
    """Return (token_ids, loss_mask) of equal length.
    loss_mask[i] = 1 ⇔ token i belongs to an assistant turn (gradient flows).
    All other tokens (system prompt, user turn, role markers) get 0."""
    ids:  list[int] = []
    mask: list[int] = []
    for turn in turns:
        ids.append(role_tokens[turn.role])  # role tag: NEVER gets gradient
        mask.append(0)
        body = tokenizer.encode(turn.content, add_special_tokens=False)
        ids.extend(body)
        # gradient ONLY on assistant content tokens
        mask.extend([1 if turn.role == "assistant" else 0] * len(body))
        ids.append(eot_token)
        # EOT after an assistant turn: model should learn to stop here.
        mask.append(1 if turn.role == "assistant" else 0)
    return ids, mask

# ─── 2. MinHash near-duplicate filter ───────────────────────────────────
def shingles(text: str, k: int = 5) -> set[str]:
    """k-character shingles, lowercased, whitespace-collapsed."""
    t = re.sub(r"\s+", " ", text.strip().lower())
    if len(t) < k:
        # A text shorter than k would otherwise give an empty set, and all
        # empty sets share one signature, so short prompts would collapse
        # into false duplicates of each other.
        return {t}
    return {t[i : i + k] for i in range(len(t) - k + 1)}

def minhash(shingle_set: set[str], num_perms: int = 64, seed: int = 0) -> list[int]:
    """Tiny pure-Python MinHash. Each permutation is just a salted SHA-1.
    `seed` is kept for interface compatibility and is currently unused."""
    # Sentinel must be >= any 8-byte hash, hence 2**64 - 1 (not 2**63 - 1).
    sigs = [(1 << 64) - 1] * num_perms
    for sh in shingle_set:
        h0 = hashlib.sha1(sh.encode("utf-8")).digest()
        for p in range(num_perms):
            # Cheap re-hash per permutation: salt with the perm index.
            h = int.from_bytes(
                hashlib.sha1(p.to_bytes(2, "big") + h0).digest()[:8],
                "big",
            )
            if h < sigs[p]:
                sigs[p] = h
    return sigs

def jaccard_estimate(sig_a: list[int], sig_b: list[int]) -> float:
    """Estimated Jaccard similarity ∈ [0, 1] from two MinHash signatures."""
    eq = sum(1 for a, b in zip(sig_a, sig_b) if a == b)
    return eq / len(sig_a)

def dedupe(examples: list[str], threshold: float = 0.85) -> list[int]:
    """Return indices to KEEP after near-duplicate removal.
    Greedy: the first example of each near-duplicate cluster survives."""
    sigs = [minhash(shingles(x)) for x in examples]
    keep, kept_sigs = [], []
    for i, s in enumerate(sigs):
        if any(jaccard_estimate(s, k) >= threshold for k in kept_sigs):
            continue
        keep.append(i)
        kept_sigs.append(s)
    return keep

# ─── 3. 13-gram decontamination ─────────────────────────────────────────
def ngrams(text: str, n: int = 13) -> set[tuple[str, ...]]:
    """Lowercased word-level n-grams (word boundaries, not subword tokens)."""
    toks = re.findall(r"\w+", text.lower())
    return {tuple(toks[i : i + n]) for i in range(max(0, len(toks) - n + 1))}

def is_contaminated(example: str, eval_ngrams: set[tuple[str, ...]]) -> bool:
    """True if ANY 13-token contiguous span from the example appears in
    the evaluation set. The frontier-lab standard."""
    return bool(ngrams(example) & eval_ngrams)

# ─── 4. Putting it all together ─────────────────────────────────────────
def prepare_sft_corpus(
    raw: Iterable[dict],          # each dict has 'turns' and 'meta'
    tokenizer,
    role_tokens, eot_token,
    eval_texts: list[str],
) -> list[dict]:
    # Build the eval-set 13-gram index ONCE.
    eval_grams: set[tuple[str, ...]] = set()
    for t in eval_texts:
        eval_grams |= ngrams(t)

    examples = list(raw)
    # 1. drop anything that touches an eval n-gram.
    examples = [
        e for e in examples
        if not is_contaminated(" ".join(t["content"] for t in e["turns"]),
                               eval_grams)
    ]
    # 2. dedupe on the first user-turn prompt only ("" if there is none).
    user_prompts = [
        next((t["content"] for t in e["turns"] if t["role"] == "user"), "")
        for e in examples
    ]
    keep = dedupe(user_prompts, threshold=0.85)
    examples = [examples[i] for i in keep]
    # 3. tokenise + mask.
    out = []
    for e in examples:
        turns = [Turn(t["role"], t["content"]) for t in e["turns"]]
        ids, mask = tokenize_and_mask(turns, tokenizer, role_tokens, eot_token)
        out.append({"input_ids": ids, "loss_mask": mask, "meta": e["meta"]})
    return out

The whole file is ~130 lines and runs on a laptop. For a million-example SFT corpus the dedup step is the only non-trivial cost: O(N · num_perms) hashing plus O(N²) pair comparisons. Production pipelines (datatrove, dolma, nemotron) replace the inner loop with banded LSH: split the signature into bands, hash each band into a bucket, and only compare examples that collide in at least one bucket. That turns the quadratic step into a near-linear pass at the cost of a bit of recall.
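The banding idea fits in a few lines. This is a sketch, not the implementation used by any of the named libraries: lsh_candidate_pairs is a hypothetical name, and 16 bands of 4 rows is an arbitrary illustrative split of a 64-value signature.

```python
import hashlib
from collections import defaultdict
from itertools import combinations

def minhash(items: set[str], num_perms: int = 64) -> list[int]:
    """Salted-SHA-1 MinHash over a non-empty set of strings."""
    return [
        min(
            int.from_bytes(
                hashlib.sha1(p.to_bytes(2, "big") + s.encode("utf-8")).digest()[:8], "big"
            )
            for s in items
        )
        for p in range(num_perms)
    ]

def lsh_candidate_pairs(signatures: list[list[int]], bands: int = 16) -> set[tuple[int, int]]:
    """Return index pairs whose signatures agree on at least one full band."""
    rows = len(signatures[0]) // bands
    buckets = defaultdict(list)
    for idx, sig in enumerate(signatures):
        for b in range(bands):
            band = tuple(sig[b * rows:(b + 1) * rows])
            buckets[(b, band)].append(idx)       # same band index + same values = same bucket
    pairs = set()
    for members in buckets.values():
        pairs.update(combinations(members, 2))   # only bucket-mates get compared
    return pairs

docs = [
    {"what", "is", "the", "capital", "of", "france"},
    {"what", "is", "the", "capital", "of", "france"},    # exact duplicate of docs[0]
    {"write", "a", "haiku", "about", "autumn", "rain"},  # shares nothing with the others
]
pairs = lsh_candidate_pairs([minhash(d) for d in docs])
print(pairs)  # → {(0, 1)}
```

Only the candidate pairs need the full signature comparison. More bands with fewer rows each raise recall and produce more candidates; fewer, wider bands do the opposite.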

The single best line in this file is the EOT-on-assistant masking at the end of tokenize_and_mask: mask.append(1 if turn.role == "assistant" else 0). Half the SFT bugs we have seen in the wild traced to teams that supervised the assistant content tokens but left the EOT unsupervised, or vice versa. The model either rambled past the answer or stopped before completing it. One bit of mask, days of debugging.
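A cheap regression test for that bit might look like the sketch below. It assumes the [role tag][content...][EOT] layout used in this section and that the special-token IDs never occur inside turn content; the function name and the toy IDs are hypothetical.

```python
def assistant_eot_supervised(ids, mask, role_ids, assistant_role_id, eot_id) -> bool:
    """True iff every assistant EOT has mask 1 and every other EOT has mask 0."""
    current_role = None
    for tok, m in zip(ids, mask):
        if tok in role_ids:
            current_role = tok                      # entering a new turn
        elif tok == eot_id:
            expected = 1 if current_role == assistant_role_id else 0
            if m != expected:
                return False
    return True

# Toy IDs: 1 = user tag, 2 = assistant tag, 3 = EOT, anything >= 10 = content.
ids       = [1, 10, 11, 3, 2, 12, 3]
good_mask = [0,  0,  0, 0, 0,  1, 1]
bad_mask  = [0,  0,  0, 0, 0,  1, 0]   # assistant EOT left unsupervised

ok_good = assistant_eot_supervised(ids, good_mask, {1, 2}, 2, 3)
ok_bad = assistant_eot_supervised(ids, bad_mask, {1, 2}, 2, 3)
print(ok_good, ok_bad)  # → True False
```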

PyTorch: A Production-Style SFT Dataset

Once the corpus is filtered, deduped, and decontaminated, the PyTorch side is the easy half. We need three things: an IterableDataset that streams JSONL lines and emits packed windows, a masked cross-entropy loss that averages only over assistant tokens, and a tokeniser that knows how to apply the model's chat template. The 90 lines below cover all three.

PyTorch SFT dataset with packing and assistant-only masking
sft_dataset.py
An IterableDataset, not a map-style one

SFT corpora are heavy (tens to hundreds of GB after tokenisation) and almost always exceed worker RAM if loaded eagerly. We extend IterableDataset so the train loop streams JSONL lines on demand and can shard cleanly across DataLoader workers. The map-style alternative would need a full random-access index — fine for small SFT runs but a footgun the moment you cross 100K examples.

Constructor inputs: only four numbers that matter

max_seq_len is the single hyperparameter every SFT run lives and dies by. Too small (≤1024) and you truncate every reasoning example; too large (≥8192) and you waste compute on padding. The frontier-lab consensus is 4096 for chat-heavy mixes and 8192 once you add long math/code reasoning. pad_id matters because some tokenisers (Llama-3) ship without a pad token and you must register one before this dataset is safe to use.

EXECUTION STATE
max_seq_len = 4096 (chat), 8192 (reasoning), 16384 (R1-class)
pad_id = tokeniser-specific; for Llama-3 it must be REGISTERED first
_encode_one: applying the chat template, turn by turn

This is the per-example workhorse. We re-render the chat template after each turn and diff against the previous render to figure out which tokens are 'new'. Those new tokens belong to exactly one turn, so we can label them as assistant (mask=1) or non-assistant (mask=0). Doing this turn-by-turn is the only template-agnostic way to get accurate assistant-only masking — Llama-3, Qwen-3, and DeepSeek-V3 chat templates differ in ways that any 'split by string' shortcut will eventually break on.

apply_chat_template: the standard HF entrypoint

Every modern HF tokeniser ships with a chat_template Jinja string. apply_chat_template walks it deterministically: render the system message, the first user turn, the first assistant turn, etc., adding all the special tokens (<|im_start|>, <|im_end|>, etc.) along the way. By calling it with progressively longer prefixes of ex['turns'], we get a strictly-increasing token sequence — the difference between consecutive renders is what was added by the latest turn.

EXECUTION STATE
prefix = growing list of token IDs after each chat-template render
new_tokens = slice of prefix that was added by THIS turn
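The render-and-diff idea can be tried without downloading a tokeniser. In the sketch below, render is a hypothetical stand-in for a chat template that works on whitespace-split words; its add_generation_prompt flag mirrors the parameter of the same name in Hugging Face's apply_chat_template. The sketch keeps the assistant header at mask 0, matching the toy walkthrough earlier in this section; the section's own listing may treat the header differently.

```python
def render(turns, add_generation_prompt: bool = False) -> list[str]:
    """Hypothetical chat template: role header, content words, end marker."""
    out = []
    for role, content in turns:
        out += [f"<|{role}|>"] + content.split() + ["<|end|>"]
    if add_generation_prompt:
        out.append("<|assistant|>")      # header that cues the model to answer
    return out

def encode_with_mask(turns):
    ids, mask = [], []
    for i, (role, _) in enumerate(turns):
        full = render(turns[: i + 1])            # everything up to and including turn i
        if role == "assistant":
            # The assistant header is part of the prompt, not of the answer.
            prompt_len = len(render(turns[:i], add_generation_prompt=True))
        else:
            prompt_len = len(full)
        new = full[len(ids):]                    # tokens added by THIS turn
        n_context = prompt_len - len(ids)        # how many of them are unsupervised
        mask.extend([0] * n_context + [1] * (len(new) - n_context))
        ids.extend(new)
    return ids, mask

turns = [("user", "What is two plus three ?"), ("assistant", "Five .")]
ids, mask = encode_with_mask(turns)
print(list(zip(ids, mask)))
```

Because only rendered prefixes are compared, the masking logic never needs to know what the template's special tokens look like.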
Truncation by simple slicing

We truncate the encoded sequence at max_seq_len. This is intentional — we want the packer downstream to see ragged-but-bounded chunks, not a long tail of giant examples that would dominate the loss. The frontier-lab heuristic is: if your truncation discards >5% of your assistant tokens, your max_seq_len is too small; raise it before you blame the model.

Worker sharding by line number modulo worker count

A clean round-robin shard: worker w reads lines where lineno % num_workers == w. No coordination needed, no overlap, no missed lines. The trade-off is that adjacent lines on disk end up on different workers, which can defeat OS-level read-ahead — but for JSONL on SSD this rarely matters in practice. NEVER shard by random sampling without coordination; you will silently duplicate examples and skew the loss.

EXECUTION STATE
wid = this worker's index (0 to num_workers-1)
wn = total workers; sharding modulus
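The modulo rule is easy to verify in isolation. This sketch passes the worker index and count in explicitly; inside a real DataLoader worker they would typically come from torch.utils.data.get_worker_info(). The function name shard_lines is hypothetical.

```python
def shard_lines(lines, worker_id: int, num_workers: int):
    """Yield only the lines owned by this worker (round-robin by line number)."""
    for lineno, line in enumerate(lines):
        if lineno % num_workers == worker_id:
            yield line

lines = [f"example-{i}" for i in range(10)]
shards = [list(shard_lines(lines, w, 3)) for w in range(3)]
print(shards[0])  # → ['example-0', 'example-3', 'example-6', 'example-9']
```

Every line lands in exactly one shard, so the union of the shards is the original corpus with no duplicates and no gaps.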
Packing into rolling L-token windows

The single biggest throughput optimisation in modern SFT. Without packing, every example becomes a separate sequence padded to L; with chat-heavy mixes (median response ~150 tokens) this wastes 90%+ of every batch. With packing we concatenate examples in the order they arrive and yield a new window every time the buffer crosses L. The downside is that one window may contain pieces of two unrelated conversations. The loss_mask keeps the gradient on assistant content only, but it does not stop attention from crossing the conversation boundary; that requires a per-document attention mask or position reset, which the triple yielded here does not carry.

LOOP TRACE · 2 iterations
after example 1 (ids=120, mask=80)
len(buf_ids) = 120
yielded? = no (buf < L=4096)
after example 30 (ids ≈ 4100)
len(buf_ids) = ≈ 4100
yielded? = YES — emit first 4096, keep ≈4 tokens for next window
Packing the tail: do not throw the last window away

When the corpus ends with a partial window we pad and yield it. Throwing it away biases the corpus toward longer examples (because long examples are over-represented in completed-window memberships). The pad tokens get loss_mask=0 so they contribute zero gradient. This 'one extra padded window per epoch' is a free fix for a sneaky bias that has shipped to production at more than one lab.
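Packing and tail padding together can be sketched with plain lists. The function name pack and the tiny window size are illustrative only; the section's dataset class works on tensors and a real max_seq_len.

```python
def pack(examples, window: int, pad_id: int = 0):
    """examples: iterable of (ids, mask). Yields (ids, mask) windows of exactly `window` tokens."""
    buf_ids, buf_mask = [], []
    for ids, mask in examples:
        buf_ids.extend(ids)
        buf_mask.extend(mask)
        while len(buf_ids) >= window:                 # emit every full window
            yield buf_ids[:window], buf_mask[:window]
            buf_ids, buf_mask = buf_ids[window:], buf_mask[window:]
    if buf_ids:                                       # partial tail: pad it, never drop it
        pad = window - len(buf_ids)
        yield buf_ids + [pad_id] * pad, buf_mask + [0] * pad   # pads carry no gradient

examples = [
    ([1, 2, 3],    [0, 1, 1]),
    ([4, 5, 6, 7], [0, 0, 1, 1]),
    ([8, 9],       [0, 1]),
]
windows = list(pack(examples, window=4))
for w in windows:
    print(w)
```

The five supervised tokens of the input are all still supervised after packing, including the single one that ends up in the padded tail window.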

_pack: yield a triple, not a pair

We yield input_ids, labels (== input_ids for next-token teacher forcing), and loss_mask. The train loop will shift these to produce the standard next-token loss; the loss_mask is what makes this SFT and not pretraining. Hugging Face-style trainers (trl, axolotl) usually carry the same information in two fields instead, setting labels to -100 at every unsupervised position, so converting this triple to that convention is a one-line mapping.

EXECUTION STATE
input_ids = shape [L], dtype long
labels = clone of input_ids (next-token teacher forcing)
loss_mask = shape [L], dtype float32, 1.0 on assistant tokens
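Trainers that follow the Hugging Face convention typically have no separate loss_mask; they mark unsupervised positions inside labels with -100, which is the default ignore_index of torch.nn.CrossEntropyLoss. A minimal sketch of the conversion, using plain Python lists and a hypothetical helper name:

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def to_ignore_index_labels(input_ids, loss_mask):
    """Fold a 0/1 loss mask into the labels: supervised positions keep
    their token id, everything else becomes IGNORE_INDEX."""
    return [tok if m else IGNORE_INDEX for tok, m in zip(input_ids, loss_mask)]


input_ids = [11, 12, 13, 14, 15]
loss_mask = [0.0, 0.0, 1.0, 1.0, 0.0]   # last position is padding
labels = to_ignore_index_labels(input_ids, loss_mask)
print(labels)
```

The same mapping works on tensors with torch.where; the list version is shown only to keep the sketch dependency-free.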
masked_cross_entropy: the actual SFT loss

Standard causal-LM cross-entropy with one twist: we average over the supervised tokens (loss_mask = 1), not over all tokens. If you average over all tokens the loss is artificially small, because the 60-80% of tokens with mask=0 contribute zero to the sum but still inflate the denominator. Averaging over supervised tokens makes the loss comparable across batches of different assistant-density, which is essential for any LR schedule that responds to absolute loss values.

The 'shift by one' for next-token prediction

Causal language modelling predicts token t+1 from tokens [0..t]. So we slice logits to [:, :-1, :] (drop the last position — we have no target for it) and labels to [:, 1:] (drop the first — there is no input that predicted it). The loss_mask shifts the same way: a position is supervised iff the TARGET token is an assistant token.

EXECUTION STATE
logits[:, :-1, :] = shape [B, T-1, V] — predictions for positions 0..T-2
labels[:, 1:] = shape [B, T-1] — true next-tokens for those positions
mask[:, 1:] = shape [B, T-1] — 1.0 iff the TARGET is assistant
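The alignment is easier to trust once it is written out on a tiny sequence. The toy below uses string tokens and an invented mask (only the reply tokens are marked as assistant) purely for illustration.

```python
tokens = ["<user>", "Hi", "<asst>", "Hello", "!", "<eos>"]
mask   = [0,        0,    0,        1,       1,   1]   # 1 = assistant token

inputs  = tokens[:-1]   # positions that make a prediction
targets = tokens[1:]    # the token each position has to predict
sup     = mask[1:]      # supervise iff the TARGET is an assistant token

# (input position, target) pairs that actually receive gradient.
pairs = [(i, t) for i, t, s in zip(inputs, targets, sup) if s]
print(pairs)
```

Note that the position holding "<asst>" is supervised even though its own mask is 0, because the token it must predict is the first assistant token.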
Normalise by the number of supervised tokens

denom = mask.sum() is the count of gradient-bearing positions in this batch. clamp(min=1.0) guards against an all-zero mask (a batch with only system+user turns — should not happen in a healthy corpus, but defensive code matters at scale). The returned loss is mean-per-supervised-token cross-entropy, exactly the quantity papers report.

EXECUTION STATE
denom = sum of mask; typically 20-40% of B*T for chat data
return = mean nll over assistant tokens — comparable across batches
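The effect of the denominator can be seen with made-up numbers. In the sketch below every token has the same hypothetical NLL of 2.0; only the share of assistant tokens differs between the two batches.

```python
def mean_nll(nll, mask, over_supervised=True):
    """Masked sum of per-token NLL divided either by the number of
    supervised tokens (clamped to at least 1) or by all tokens."""
    masked_sum = sum(n * m for n, m in zip(nll, mask))
    denom = max(sum(mask), 1.0) if over_supervised else len(nll)
    return masked_sum / denom


nll    = [2.0] * 10                 # hypothetical per-token NLL
sparse = [1.0] * 2 + [0.0] * 8      # 20% assistant tokens
dense  = [1.0] * 8 + [0.0] * 2      # 80% assistant tokens

print(mean_nll(nll, sparse), mean_nll(nll, dense))                 # same value
print(mean_nll(nll, sparse, False), mean_nll(nll, dense, False))   # differ
```

With the supervised-token denominator both batches report the same loss; with the all-token denominator the reported loss tracks assistant density instead of model quality.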
"""
A PyTorch IterableDataset for SFT. Three production-grade behaviours:

  - Packs multiple short examples into one fixed-length sequence so we
    do not waste GPU FLOPs on padding tokens; at small-batch SFT, the
    pad waste is the single biggest source of throughput loss.
  - Streams tokenisation lazily so the worker fleet can keep up with
    a multi-TB corpus without loading it into RAM.
  - Yields BOTH input_ids and a loss_mask, so the train loop applies
    cross-entropy only on assistant tokens.

Compatible with Hugging Face tokenisers (LLaMA, Mistral, Qwen,
DeepSeek, Phi all work) and any chat template that can be applied
turn-by-turn.
"""
import json
from typing import Iterator
import torch
from torch.utils.data import IterableDataset

class SftPackedDataset(IterableDataset):
    def __init__(
        self,
        jsonl_path: str,
        tokenizer,                # HF AutoTokenizer with chat_template set
        max_seq_len: int = 4096,
        pad_id: int | None = None,
        seed: int = 0,
    ):
        super().__init__()
        self.path = jsonl_path
        self.tok = tokenizer
        self.L = max_seq_len
        # Some tokenisers (LLaMA-family, for instance) define no pad token.
        # Fall back to EOS in that case; pad positions get loss_mask=0, so
        # the choice of pad id never reaches the loss.
        if pad_id is None:
            pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
        self.pad_id = pad_id
        self.seed = seed

    def _encode_one(self, ex: dict) -> tuple[list[int], list[int]]:
        """Run the chat template on a single example and emit the
        parallel loss mask. We re-apply the template per turn so we know
        EXACTLY which slice corresponds to assistant content. The
        tokenizer.apply_chat_template(..., return_assistant_tokens_mask=True)
        path also works but breaks on some tokenisers, so we do it by hand."""
        ids: list[int] = []
        mask: list[int] = []
        # Render the conversation prefix one turn at a time; whatever the
        # longer prefix adds to the token list belongs to the newest turn.
        for i, turn in enumerate(ex["turns"]):
            prefix = self.tok.apply_chat_template(
                ex["turns"][: i + 1],
                tokenize=True,
                add_generation_prompt=False,
            )
            new_tokens = prefix[len(ids):]
            ids.extend(new_tokens)
            mask.extend(
                [1] * len(new_tokens) if turn["role"] == "assistant"
                else [0] * len(new_tokens)
            )
        return ids[: self.L], mask[: self.L]

    def __iter__(self) -> Iterator[dict]:
        """Pack examples into rolling windows of length L. We yield a
        new window every time the running buffer reaches L tokens. The
        last partial window is padded and yielded too; losing it
        biases the loss toward longer examples."""
        worker = torch.utils.data.get_worker_info()
        wid = worker.id if worker else 0
        wn  = worker.num_workers if worker else 1
        rng = torch.Generator().manual_seed(self.seed + wid)

        buf_ids:  list[int] = []
        buf_mask: list[int] = []
        with open(self.path) as f:
            for lineno, line in enumerate(f):
                if lineno % wn != wid:    # round-robin shard across workers
                    continue
                ex = json.loads(line)
                ids, mask = self._encode_one(ex)
                buf_ids.extend(ids)
                buf_mask.extend(mask)
                while len(buf_ids) >= self.L:
                    yield self._pack(buf_ids[: self.L], buf_mask[: self.L])
                    buf_ids  = buf_ids[self.L:]
                    buf_mask = buf_mask[self.L:]
            if buf_ids:                    # tail end: pad & emit
                pad = self.L - len(buf_ids)
                yield self._pack(buf_ids + [self.pad_id] * pad,
                                 buf_mask + [0] * pad)

    def _pack(self, ids: list[int], mask: list[int]) -> dict:
        x = torch.tensor(ids,  dtype=torch.long)
        m = torch.tensor(mask, dtype=torch.float32)
        return {
            "input_ids":   x,
            "labels":      x.clone(),   # teacher-forced next-token target
            "loss_mask":   m,
        }


def masked_cross_entropy(logits: torch.Tensor,
                         labels: torch.Tensor,
                         loss_mask: torch.Tensor) -> torch.Tensor:
    """Standard next-token loss, but averaged over loss_mask=1 tokens only.
    Shapes:
      logits     [B, T, V]
      labels     [B, T]
      loss_mask  [B, T]   1.0 on supervised tokens, 0.0 elsewhere
    """
    # Standard 'shift by one' for causal LM teacher forcing.
    logits = logits[:, :-1, :].contiguous()
    labels = labels[:,  1:].contiguous()
    mask   = loss_mask[:, 1:].contiguous()

    log_probs = torch.nn.functional.log_softmax(
        logits.float(), dim=-1
    )
    nll = -log_probs.gather(-1, labels.unsqueeze(-1)).squeeze(-1)  # [B, T-1]

    # Sum over masked tokens, normalise by the number of masked tokens.
    denom = mask.sum().clamp(min=1.0)
    return (nll * mask).sum() / denom

Two engineering details that this code gets right and most first-attempts get wrong. Packing: concatenating multiple short examples into one fixed-length window. Without packing, every example becomes its own padded sequence, and the median chat example wastes 70–90% of every batch on pad tokens. With packing, GPU utilisation on an 8×H100 SFT run jumps from ~40% to ~85% on the same corpus. Tail emission: yielding the final partial window rather than discarding it. Discarding biases the corpus toward longer examples and is the second-most common “mysterious loss drift after epoch 1” bug.

Packing has a sharp edge. When two packed examples share one window, the causal self-attention of example B can attend back to example A's tokens. The loss mask only decides which positions are supervised; the model still computes attention across the boundary, so B's supervised tokens are predicted with A in context, which is statistically a tiny contamination. Modern trainers (DeepSpeed-Chat, axolotl, Llama-Factory) offer a document-attention mask that blocks cross-document attention exactly. Use it. The throughput cost is typically small and it eliminates the boundary effect entirely.
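What a document-attention mask looks like can be sketched in a few lines. This version assumes the packer also emits a per-position document id, which the SftPackedDataset above does not do; `document_causal_mask` and `doc_ids` are hypothetical names, and production trainers usually implement the same rule inside a variable-length attention kernel instead of a dense boolean matrix.

```python
def document_causal_mask(doc_ids):
    """allowed[i][j] is True iff position i may attend to position j:
    j must not lie in the future and must belong to the same document."""
    T = len(doc_ids)
    return [
        [j <= i and doc_ids[i] == doc_ids[j] for j in range(T)]
        for i in range(T)
    ]


# One packed window holding two documents: positions 0-2 and 3-4.
doc_ids = [0, 0, 0, 1, 1]
allowed = document_causal_mask(doc_ids)
for row in allowed:
    print("".join("x" if a else "." for a in row))
```

The printed pattern is block-diagonal and lower-triangular: the second document never sees the first.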

At Massive Scale: How the Frontier Labs Actually Do It

The from-scratch pipeline above is the minimum viable SFT prep. At frontier scale — Llama, Claude, GPT-4, DeepSeek-V3, Gemini, Qwen-3 — the same skeleton holds, but several production-grade layers are added on top. We will walk through each one with concrete numbers from public papers and post-mortems.

The four data sources, ranked by quality bar

  1. Hand-written by domain experts. 1K–10K examples per skill. The most expensive ($5–25 per example) and highest-quality source. Used for safety-critical skills (refusal calibration, instruction following, system-prompt adherence) and for “seed” examples that anchor a synthetic pipeline. LIMA's 1K set belongs in this tier: roughly 200 of its examples were written by the paper's authors and most of the rest were hand-curated from community sources such as Stack Exchange, wikiHow and Reddit. Anthropic's “constitution authors” produce this. OpenAI's contractor fleet produces this at scale.
  2. Distilled from a stronger model. A few thousand to a few million examples generated by a frontier model (often the previous generation of the same family) and filtered with a smaller reward or quality model. Tülu 3 uses this for ~40% of its mix. DeepSeek-R1's cold-start SFT is largely this: long chain-of-thought samples, including cleaned-up R1-Zero outputs, with human post-processing on top. Per-example cost is cents of compute rather than dollars of labour, roughly a 100× cost reduction over hand-writing.
  3. Rejection-sampled from the model itself. Generate K candidate answers per prompt, score with a reward model or rule-based checker, keep only the top-1. DeepSeek-R1 used K = 64 for reasoning, keeping ~5% of candidates. Llama-3 used K = 30 on instruction-following prompts. The keep rate, not the candidate count, is the headline quality knob.
  4. Open instruction datasets. ShareGPT, UltraChat, OpenOrca, NoRobots, WildChat. Cheap (free) and large (1M+ examples) but quality is the lowest of the four tiers and the contamination risk is the highest. Frontier labs use these as a base layer or drop them entirely; small-team SFT runs lean on them heavily.

The canonical mix proportion (effective tokens, not examples)

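| Category | Llama-3 SFT | Tülu 3 SFT | Phi-3 SFT | DeepSeek-R1 SFT |
| --- | --- | --- | --- | --- |
| General chat / IF | ~50% | ~30% | ~25% | ~20% |
| Math / reasoning | ~15% | ~25% | ~30% | ~45% |
| Code | ~14% | ~20% | ~25% | ~20% |
| Multilingual | ~8% | ~5% | ~5% | ~5% |
| Tool use | ~5% | ~10% | ~5% | ~5% |
| Safety / refusal | ~8% | ~10% | ~10% | ~5% |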

Two things to notice. First, the mix shifts dramatically with the model's intended workload: R1 is a reasoning model and its mix is reasoning-heavy; Llama-3 is a general-purpose assistant and its mix is chat-heavy. Second, the safety fraction is in single-digit percent across the board. The instinct “just add more safety data” is wrong: past ~10% safety share, the model starts refusing benign prompts and the over-refusal tax kicks in. Anthropic's Claude 3.5 sat closer to 6%; OpenAI's GPT-4 turbo post-2024 sits closer to 4%.
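Because the proportions are stated in effective tokens, the probability of drawing an example from a category has to be corrected for how long that category's examples are. A minimal sketch, with hypothetical category names, target shares and mean lengths:

```python
def example_sampling_weights(token_share, mean_tokens):
    """Turn target TOKEN shares into per-category EXAMPLE sampling
    probabilities: p_c is proportional to share_c / mean_length_c."""
    raw = {c: token_share[c] / mean_tokens[c] for c in token_share}
    z = sum(raw.values())
    return {c: w / z for c, w in raw.items()}


token_share = {"chat": 0.5, "math": 0.5}   # hypothetical: equal token budget
mean_tokens = {"chat": 200, "math": 800}   # hypothetical mean example length
weights = example_sampling_weights(token_share, mean_tokens)
print(weights)

# Expected tokens contributed per draw are now equal across categories.
per_draw = {c: weights[c] * mean_tokens[c] for c in weights}
print(per_draw)
```

Sampling examples at 50/50 here would give math four times the token budget of chat, which is exactly the mistake that mixing by example count makes.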

The rejection-sampling pipeline (the modern data factory)

The most consequential shift in SFT data collection from 2023 to 2025 has been the rise of self-distillation via rejection sampling. The pipeline:

  1. Take a base model and a set of high-quality prompts (the prompt pool can be hand-written, scraped, or model-generated).
  2. Generate K ∈ [16, 256] candidate completions per prompt with a high sampling temperature.
  3. Score each candidate with a stack of filters: a reward model (learned), a rule-based checker (regex, execution, equality), an LLM-as-judge pass, sometimes a human spot-check on a sub-sample.
  4. Keep the top-1 candidate per prompt — or drop the prompt entirely if no candidate clears the bar.
  5. The surviving (prompt, best-completion) pairs are the SFT corpus.

Why is this so effective? Because the best of K samples from a model is usually better than a typical single sample from it. Rejection sampling is just importance-weighted maximum-likelihood: we approximate the policy “sample, then take the argmax under the reward” with a finite-sample top-1, and SFT on the result. DeepSeek-R1 scaled this to ~600,000 reasoning prompts with K = 64 candidates each, keeping the top-1 after a multi-stage filter; the surviving 600K examples are the SFT corpus that gave R1 its non-RL reasoning baseline. The compute cost was 64× the cost of single-shot inference; the labour cost was zero.
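The five-step pipeline reduces to a short loop. The sketch below is illustrative only: `generate` and `score` stand in for a real sampler and a real filter stack, and the toy versions (a noisy adder and an exact-match checker) exist just to make the loop runnable.

```python
import random


def rejection_sample(prompts, generate, score, k, threshold):
    """Best-of-k per prompt: keep the top-scoring candidate, or drop the
    prompt entirely if even the best candidate misses the threshold."""
    corpus = []
    for prompt in prompts:
        candidates = [generate(prompt) for _ in range(k)]
        best = max(candidates, key=lambda c: score(prompt, c))
        if score(prompt, best) >= threshold:
            corpus.append((prompt, best))
    return corpus


# Toy stand-ins (hypothetical): the "model" adds two numbers with noise,
# the "checker" is a rule-based exact-match test.
rng = random.Random(0)

def toy_generate(prompt):
    a, b = prompt
    return a + b + rng.choice([-1, 0, 0, 1])

def toy_score(prompt, completion):
    return 1.0 if completion == sum(prompt) else 0.0


corpus = rejection_sample([(1, 2), (3, 4)], toy_generate, toy_score,
                          k=16, threshold=1.0)
print(corpus)
```

Everything that survives is correct by construction of the checker; prompts for which no candidate passes simply vanish from the corpus, which is why the keep rate is worth logging.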

The economic shape of modern SFT. In 2022, every SFT example cost human-labour money (~$5–$25 each) and the constraint was annotator throughput. In 2025, the dominant cost line is candidate generation compute, which is roughly K× the cost of normal inference. A 600K-example R1-style corpus at K = 64 costs on the order of $50–100K of generation compute and zero annotator hours. This is why the keep rate, not the candidate count, is the lever: doubling K doubles cost; tightening the keep filter from 10% to 5% halves corpus size with roughly constant cost.

Decontamination at scale

Most published frontier models since GPT-4 report decontamination against a public eval suite. The standard protocol, 13-gram word overlap as popularised by the GPT-3 paper, is now industry-canonical. Llama-3's data team reported that their SFT corpus, before filtering, had a 13-gram match against ~4% of the GSM8K test set, ~6% of MMLU, ~11% of HumanEval, and nontrivially against every benchmark they tested. After filtering, those numbers go to zero. The cost is dropping ~5–8% of the candidate corpus.
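A 13-gram overlap check is small enough to sketch in full. The version below assumes whitespace tokenisation and lowercasing as the only normalisation, which is a simplification; real pipelines usually normalise punctuation as well and use a hashed or sharded index. All function names are illustrative.

```python
def ngrams(text, n=13):
    """Set of word n-grams in `text` after lowercasing and whitespace
    splitting. Texts shorter than n words yield an empty set."""
    words = text.lower().split()
    return {tuple(words[i:i + n]) for i in range(len(words) - n + 1)}


def build_eval_index(eval_texts, n=13):
    """Union of all n-grams that occur anywhere in the eval set."""
    index = set()
    for text in eval_texts:
        index |= ngrams(text, n)
    return index


def is_contaminated(example_text, eval_index, n=13):
    """True if the example shares at least one n-gram with the eval set."""
    return not ngrams(example_text, n).isdisjoint(eval_index)


# Toy data (hypothetical): a 15-word eval item leaked into one example.
eval_item = " ".join(f"w{i}" for i in range(15))
index = build_eval_index([eval_item])
leaked = "intro " + eval_item + " outro"
clean = " ".join(f"x{i}" for i in range(20))
print(is_contaminated(leaked, index), is_contaminated(clean, index))
```

For cross-version decontamination the same index would simply be built from the union of the old and new eval sets.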

The frontier extension is cross-version decontamination: when you train v2 of your model, you decontaminate against not just the v2 eval set but also v1's eval set, so that version comparisons are not biased by “v2 saw the v1 test set during SFT”. This is what blew up several open-model leaderboard scores in late 2024 — teams trained on synthetic data from a model that had been evaluated on the same benchmark, and the leakage flowed through.

Engineering Reality: The Tax Items Nobody Talks About

The textbook version of SFT data collection ends at the rejection-sampling pipeline. The honest version has another half-dozen production concerns that every team ships and almost no paper publishes. They are not glamorous, but every one of them has at some point cost a team a week of training or a quarter of eval scores.

1. Length filtering, both ends

Drop assistant turns shorter than L_min ≈ 16 tokens: these are almost always either refusals (collected separately) or low-effort responses that teach the model to be terse. Drop assistant turns longer than L_max, typically the 99th percentile of your corpus, but capped at the model's context window minus the prompt budget. Without this, a single 32K-token reasoning example dominates an entire packed window and the gradient signal from short-form chat data drowns.
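Both bounds can be derived from the corpus itself. The sketch below uses a nearest-rank percentile and an optional hard cap; the function names and the default of 16 tokens follow the text above, while the toy length distribution is invented.

```python
def length_bounds(assistant_lengths, l_min=16, pct=99, hard_cap=None):
    """Return (l_min, l_max), where l_max is the nearest-rank `pct`-th
    percentile of assistant-turn lengths, optionally capped (for example
    at context window minus prompt budget)."""
    ordered = sorted(assistant_lengths)
    rank = -(-pct * len(ordered) // 100)      # ceil(pct * N / 100)
    l_max = ordered[max(rank - 1, 0)]
    if hard_cap is not None:
        l_max = min(l_max, hard_cap)
    return l_min, l_max


def keep(length, bounds):
    """True if an assistant turn of this length survives the filter."""
    lo, hi = bounds
    return lo <= length <= hi


lengths = list(range(1, 101))                 # toy corpus: lengths 1..100
bounds = length_bounds(lengths)
print(bounds, [n for n in (10, 16, 99, 100) if keep(n, bounds)])
```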

2. PII scrubbing

Run a named-entity recogniser pass for emails, phone numbers, social-security-like patterns, IP addresses, and credit-card-like patterns. Replace them with placeholders ([EMAIL], [PHONE], etc.). This is not optional: PII memorisation is a publishable bug, a regulatory risk (GDPR, CCPA), and a brand risk. The Microsoft Phi team published a notable case in 2024 where their pre-PII-scrubbing SFT corpus had memorisable email addresses that surfaced in chat-mode outputs.
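For the pattern-like categories, a regex pass is a common first layer underneath a recogniser. The sketch below is a stand-in of that kind, not an NER model: the patterns are deliberately simple assumptions, they will miss many real-world formats, and credit-card detection (which normally needs a checksum test) is left out.

```python
import re

# Order matters: more specific patterns run before the broad phone pattern.
PII_PATTERNS = [
    ("[EMAIL]", re.compile(r"\b[\w.+-]+@[\w-]+(?:\.[\w-]+)+\b")),
    ("[IP]",    re.compile(r"\b(?:\d{1,3}\.){3}\d{1,3}\b")),
    ("[SSN]",   re.compile(r"\b\d{3}-\d{2}-\d{4}\b")),
    ("[PHONE]", re.compile(r"(?<!\w)\+?\d[\d\s().-]{7,}\d(?!\w)")),
]


def scrub(text):
    """Replace every match of each pattern with its placeholder."""
    for placeholder, pattern in PII_PATTERNS:
        text = pattern.sub(placeholder, text)
    return text


sample = "Mail alice@example.com or call +1 415-555-0100 from 10.0.0.1"
print(scrub(sample))
```

Names, street addresses and free-form identifiers are where a learned recogniser earns its keep; regexes only cover the strings with a rigid shape.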

3. Refusal calibration

The most counter-intuitive piece of the SFT data design. You need to teach the model when to refuse — but every refusal example you add also teaches it the shape of refusal, which leaks into prompts where the answer should be a thoughtful response, not a refusal. Modern pipelines therefore include a small fraction (1–3%) of borderline-but-helpful examples: prompts that look like they should be refused (medical, legal, sensitive-but-public-info) but where the demonstrated response is a calibrated, factual, contextually-appropriate answer. Without these, the over-refusal tax shows up within days of release.

4. Multi-turn balance

Single-turn SFT data is easy to collect; multi-turn data is hard. Most public datasets are single-turn or shallow multi-turn (one follow-up). The result, if you train only on single-turn data, is a model that resets context after every assistant response: it forgets the previous turn because it was never supervised on conditioning the next response on prior assistant turns. Tülu 3 and Llama-3 both report that pushing the multi-turn fraction above ~25% of effective tokens was necessary to make conversation feel coherent past turn 3.

5. Language-ID and code-language balance

A subtle but important step. Run a fastText langid model over every example and drop examples where the user turn and the assistant turn are confidently detected as different languages (unless the example is explicitly a translation example). Without this, the corpus is sprinkled with cross-lingual artefacts (English user turn, French assistant turn) that confuse the model. The same trick applies to code: detect the programming language and balance the mix so Python does not eat 80% of the code budget.
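The filter logic is independent of which detector you plug in. The sketch below takes the detector as a callable returning (language, confidence); the example schema (`user`, `assistant`, `is_translation`), the 0.8 confidence floor and the keyword-based toy detector are all assumptions made for illustration. In practice the callable would wrap a real language-ID model.

```python
def languages_agree(example, detect_lang, min_conf=0.8):
    """Keep an example unless both turns are confidently detected and
    the detected languages differ. Translation examples always pass."""
    if example.get("is_translation"):
        return True
    user_lang, user_conf = detect_lang(example["user"])
    asst_lang, asst_conf = detect_lang(example["assistant"])
    if min(user_conf, asst_conf) < min_conf:
        return True   # detector unsure: do not drop on weak evidence
    return user_lang == asst_lang


def toy_detect(text):
    """Keyword stand-in for a real language-ID model."""
    return ("fr", 0.99) if "bonjour" in text.lower() else ("en", 0.99)


ok  = {"user": "Hello", "assistant": "Hi there"}
bad = {"user": "Hello", "assistant": "Bonjour"}
tr  = {"user": "Translate: hello", "assistant": "Bonjour", "is_translation": True}
print([languages_agree(e, toy_detect) for e in (ok, bad, tr)])
```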

6. The eval-feedback loop

Train on the corpus, evaluate against your held-out benchmark suite, identify the worst-performing categories, and route the next data-collection cycle toward those categories. The frontier labs run this loop weekly. The DeepSeek-V3 post-training report describes nine such cycles between the base model checkpoint and the final chat model. The Llama-3 post-training report describes six. SFT is not a one-shot pipeline; it is a closed-loop control system whose input is eval scores and whose output is data-collection priorities.
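One simple way to turn eval scores into collection priorities is to split the next cycle's example budget in proportion to each category's gap from its target. This is only one possible routing rule, sketched with hypothetical category names, scores and targets.

```python
def next_cycle_budget(scores, targets, total_examples):
    """Allocate the next data-collection budget in proportion to how far
    each category's eval score sits below its target. If nothing is
    below target, split the budget evenly."""
    gaps = {c: max(targets[c] - scores[c], 0.0) for c in scores}
    z = sum(gaps.values())
    if z == 0:
        return {c: total_examples // len(scores) for c in scores}
    return {c: round(total_examples * g / z) for c, g in gaps.items()}


scores  = {"math": 0.5, "code": 0.7, "chat": 0.9}   # hypothetical eval scores
targets = {"math": 0.9, "code": 0.9, "chat": 0.9}   # hypothetical targets
budget = next_cycle_budget(scores, targets, total_examples=1000)
print(budget)
```

Categories already at target receive nothing in this scheme; a real loop would likely keep a small floor per category to guard against regressions.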

The takeaway. If you remember nothing else from this section, remember this: the SFT loss is trivial, and the corpus is everything. Mask the assistant turn only. Decontaminate against your eval set. Dedupe aggressively. Mix categories by effective tokens, not example counts. Filter by quality, not quantity. Rejection-sample whenever you can. And every time eval scores regress, look at the data before you look at the model. Two times out of three, it is the data.

With the corpus prepared, the next section turns to the chat template itself: the deterministic Jinja string that converts (role, content) turns, together with their role tags, into a token sequence. It looks like a formatting detail. It is, in fact, the single most overlooked source of train/inference skew in modern post-training.
