Chapter 9

Hyperparameter Scaling

Scaling Laws and Compute-Optimal Training

Chinchilla told us how to spend our compute budget on parameters versus tokens. It did not answer a much harder operational question: what learning rate, batch size, warmup, and weight decay should we use at 671B parameters? Sweep them, you say. Run a few dozen short runs, pick the best, and ship. That answer works at 100M parameters. At frontier scale it costs more than the model itself.

The thesis of this section. Hyperparameters do not survive scale-up. The optimal learning rate at width 256 is wrong by an order of magnitude at width 4096. The batch size that converged beautifully at 100M parameters wastes 80% of your gradient signal at 100B. The fix is not to sweep harder — it is to choose a parametrization under which the small-model optimum is also the big-model optimum. That parametrization is called μP (Maximal Update Parametrization), and it is the quietly load-bearing trick behind every frontier training run since 2023.

The Real Problem: You Cannot Sweep at Frontier Scale

Here is the accounting nobody puts on the slide. A single training run of DeepSeek-V3 (671B parameters, 14.8T tokens) costs roughly $5.6M of GPU time and burns about \(2.788 \times 10^{6}\) H800-hours. The standard hyperparameter sweep at small scale (vary LR over six values, β₂ over three, weight decay over three, warmup over two) is 6 × 3 × 3 × 2 = 108 runs. If you sweep at full scale, you have just spent $605M on hyperparameter selection alone. Nobody does this. Nobody can.
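The arithmetic above is easy to re-derive. The following is a minimal sketch that uses only the figures quoted in this section; the variable names are illustrative.

```python
# Sweep-cost accounting with the figures quoted above.
from math import prod

cost_per_run_usd = 5.6e6  # one full DeepSeek-V3-scale run, as quoted
grid = {"lr": 6, "beta2": 3, "weight_decay": 3, "warmup": 2}

n_runs = prod(grid.values())                 # full factorial grid
sweep_cost_usd = n_runs * cost_per_run_usd   # every grid point at full scale

print(f"runs in the grid: {n_runs}")
print(f"full-scale sweep: ${sweep_cost_usd / 1e6:.1f}M")
```

Swapping in a different grid or per-run cost shows how quickly a factorial sweep outgrows any realistic budget.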

So everyone falls back to one of three options, and the first two are traps.

| Strategy | What it looks like | Why it fails (or works) at frontier scale |
|---|---|---|
| Sweep at full scale | Run 100+ full-size training runs and pick the best. | Costs $100M–1B per generation. Even the largest labs do not have this budget. |
| Sweep at small scale, hope for transfer | Tune at 1B parameters under standard parametrization (SP), reuse the LR at 671B. | Under SP the optimal LR drifts with width. The 1B-tuned LR is 4–16× too high at 671B. The big run diverges in the first 500 steps or trains stably to a worse loss. |
| Tune small under μP, ship big | Tune at 1B parameters under Maximal Update Parametrization, reuse the SAME hyperparameters at 671B. | Works. Empirically validated by the DeepSeek-V3 paper, the Tensor Programs V paper, and the public Llama-3 technical report. |

Why this is the most under-discussed bottleneck in scaling. The Chinchilla paper is famous for telling us the compute-optimal token count. It is silent on whether the hyperparameters used at 70M extrapolate to 70B. Without μP, the answer is no, and every paper that follows Chinchilla's budgeting prescription secretly assumes the hyperparameters were also tuned at scale. μP is what makes the Chinchilla budgeting actually executable.

Intuition: A Wind-Tunnel Model for Hyperparameters

Aircraft engineers do not test full-size jets in wind tunnels. They test scale models — 1/20th the size — and they use dimensionless numbers (Reynolds, Mach) to make sure the small-model behaviour is dynamically similar to the full-size behaviour. If the Reynolds number matches, the flow pattern matches, no matter the absolute scale.

μP does the same thing for neural networks. Width is the scale axis — the analogue of model size in the wind tunnel. The "dimensionless numbers" are the per-layer learning-rate scalings. The promise of μP is exactly the promise of dimensional analysis: if the rescaled hyperparameters are right at the small scale, they are right at every scale.

The intuition for why standard parametrization fails is gradient accumulation in a different guise. Consider one row of a weight matrix in a wide layer. When the model is \(n\) units wide, that row participates in \(n\) dot products on every forward pass. The gradient that flows back into that row is therefore a sum of \(n\) error terms. A single SGD step of size \(\eta\) changes the row by \(\eta \cdot \sum\) (the sum of those error terms), and that sum grows with \(n\). The natural fix is to shrink \(\eta\) in proportion: that is exactly μP's prescription for hidden-layer learning rates.

The mental flip. In SP, you tune one global \(\eta\) and the per-layer effective LR is whatever it happens to be at that width. In μP, you tune one global \(\eta\) and the per-layer effective LR is explicitly rescaled by width so that the per-update change in each row stays the same across scales. The forward pass is unchanged. The optimizer is unchanged. Only the per-layer LR multiplier moves.
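A small numeric sketch of that argument, assuming a sign-style update as a stand-in for AdamW (the helper name `output_shift` is hypothetical). For a single output unit y = w · x, a step of size η against the sign of the gradient moves y by η times the sum of |x_i|, which grows linearly with fan-in. Rescaling η by n_0/n cancels that growth.

```python
import random

def output_shift(width, lr, seed=0):
    """Change in y = w . x after one sign-style step on w (toy stand-in for AdamW)."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    # For y = w . x with upstream gradient +1, dL/dw_i = x_i, so the
    # sign-style update is dw_i = -lr * sign(x_i) and dy = sum_i dw_i * x_i.
    dw = [-lr * (1.0 if xi > 0 else -1.0) for xi in x]
    return abs(sum(d * xi for d, xi in zip(dw, x)))

base_width, lr = 256, 3e-3
for width in (256, 1024, 4096):
    sp = output_shift(width, lr)                        # same LR at every width
    mup = output_shift(width, lr * base_width / width)  # LR shrunk by n_0 / n
    print(f"width={width:5d}  SP shift={sp:.3f}  muP shift={mup:.3f}")
```

The SP column should grow roughly in proportion to width, while the rescaled column should stay roughly flat. This is only the single-unit picture; a real layer has many such units.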

The Mathematics of μP and Scaling Rules

Let \(n\) be the hidden width of a transformer and \(n_0\) a fixed reference width (typically the proxy you can afford to sweep). Partition every parameter into one of three layer kinds:

  1. Input: token + positional embeddings, first projection from a width-independent input dimension. Fan-in is fixed (vocab size, embedding dim) and does not change with \(n\).
  2. Hidden: Q, K, V, attention output projection, MLP up- and down-projection. Both fan-in and fan-out scale with \(n\).
  3. Output: final projection to the vocabulary logits. Fan-in scales with \(n\); fan-out is fixed (vocab size).

Under μP with the AdamW optimizer, the per-layer learning-rate multiplier \(m_\ell(n)\) follows the table below. \(\sqrt{\text{fan\_in}}\) appears in every init row because that is the standard fan-in initialization: μP rescales the LR, not the init.

| Layer kind | LR multiplier | Forward multiplier | Init std |
|---|---|---|---|
| Input | 1 | 1 | 1 / sqrt(fan_in) |
| Hidden | n_0 / n | 1 | 1 / sqrt(fan_in) |
| Output | n_0 / n | n_0 / n | 1 / sqrt(fan_in) |

The effective learning rate of layer \(\ell\) at width \(n\) is therefore \(\eta_\ell(n) = \eta \cdot m_\ell(n)\), where \(\eta\) is a single global learning rate you tuned at the proxy. The key claim of μP, derived via the Tensor Programs framework (Yang et al., 2022), is that the optimum of \(\eta\) is invariant in \(n\) as \(n \to \infty\). In plain words: tune \(\eta\) at \(n_0\), use the same \(\eta\) at every \(n\).

Two other hyperparameters need their own scaling rules. Batch size follows McCandlish et al. (2018): there is a critical batch size \(B_{\mathrm{crit}}\) below which doubling the batch nearly halves training time, and above which doubling the batch barely helps. Empirically \(B_{\mathrm{crit}} \propto L^{-\alpha}\) with \(\alpha > 0\) (Kaplan et al., 2020, fit an exponent near 4.8 for language models), so \(B_{\mathrm{crit}}\) rises as the loss falls and bigger, better-trained models can absorb bigger batches before saturating. Weight decay scales as \(\lambda(n) = \lambda_0\) (no width rescaling under decoupled AdamW), and warmup steps are kept constant in absolute count, not as a fraction of total steps.
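The "nearly halves" and "barely helps" regimes fall out of the trade-off curve in McCandlish et al. (2018), which models the number of optimizer steps as S = S_min · (1 + B_crit / B). The sketch below evaluates that expression; the values chosen for `s_min` and `b_crit` are placeholders, not measurements.

```python
def steps_needed(batch, b_crit, s_min):
    """Steps to reach a target loss under the McCandlish et al. trade-off model."""
    return s_min * (1.0 + b_crit / batch)

def speedup_from_doubling(batch, b_crit, s_min=1.0):
    """Factor by which the step count shrinks when the batch size doubles."""
    return steps_needed(batch, b_crit, s_min) / steps_needed(2 * batch, b_crit, s_min)

b_crit = 1_000_000  # placeholder critical batch size, in tokens
for batch in (b_crit // 64, b_crit // 8, b_crit, 8 * b_crit, 64 * b_crit):
    ratio = speedup_from_doubling(batch, b_crit)
    print(f"B/B_crit={batch / b_crit:7.3f}  speed-up from doubling={ratio:.2f}x")
```

Far below the critical batch size the speed-up approaches 2×; far above it the speed-up approaches 1×, which is the saturation the paragraph describes.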

Putting these together gives the modern compute-optimal hyperparameter recipe: pick model width \(n\) and training token count \(D\) from Chinchilla; pick batch size near \(B_{\mathrm{crit}}(n)\); pick learning rate by sweeping a proxy at width \(n_0\) with μP; pick weight decay and β₂ at the same proxy; keep warmup at a fixed absolute step count (typically 2000–4000). Five hyperparameter decisions, one of which (LR) is the only one that needs sweeping at all.

Manual Numerical Walkthrough

Walkthrough: scaling LR from width 256 to width 4096

Suppose you swept the LR at a proxy of width \(n_0 = 256\) and found the optimum at \(\eta = 3 \times 10^{-3}\). You now want to train the production model at width \(n = 4096\). Compute the per-layer effective LRs under μP and SP and compare.

| Layer | Mode | μP effective LR at n=4096 | SP effective LR at n=4096 |
|---|---|---|---|
| tok_emb | input | 3e-3 × 1 = 3e-3 | 3e-3 |
| attn.q_proj | hidden | 3e-3 × (256/4096) = 1.88e-4 | 3e-3 (same as proxy, wrong) |
| mlp.up_proj | hidden | 1.88e-4 | 3e-3 |
| lm_head | output | 1.88e-4, plus forward × 1/16 | 3e-3, no forward rescale |
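The walkthrough table can be reproduced in a few lines. This sketch assumes the AdamW multipliers from the per-layer table earlier in this section; the function names are illustrative and do not come from any library.

```python
def mup_lr_multiplier(layer_kind, width, base_width):
    """Per-layer LR multiplier m_l(n) from the layer-kind table (AdamW-style rule)."""
    if layer_kind == "input":
        return 1.0
    if layer_kind in ("hidden", "output"):
        return base_width / width
    raise ValueError(f"unknown layer kind: {layer_kind!r}")

def effective_lr(base_lr, layer_kind, width, base_width, mode="mup"):
    """eta_l(n) = eta * m_l(n) under muP; plain eta under SP."""
    if mode == "sp":
        return base_lr
    return base_lr * mup_lr_multiplier(layer_kind, width, base_width)

layers = {"tok_emb": "input", "attn.q_proj": "hidden",
          "mlp.up_proj": "hidden", "lm_head": "output"}
for name, kind in layers.items():
    mup = effective_lr(3e-3, kind, width=4096, base_width=256)
    sp = effective_lr(3e-3, kind, width=4096, base_width=256, mode="sp")
    print(f"{name:12s} {kind:7s} muP={mup:.3e}  SP={sp:.3e}")
```

The hidden and output rows come out at 1.875e-4, which the table rounds to 1.88e-4. The output layer's forward multiplier is separate from the LR and is not modelled here.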

Now compute the magnitude of a single AdamW update to `mlp.up_proj`. Suppose the gradient has per-element RMS of 0.5 at width 256. At width 4096 under SP, the per-element RMS is comparable (fan-in init is preserved), while the gradient norm of each row grows by about \(\sqrt{4096/256} = 4\times\) because the row has 16× more entries. AdamW normalizes that away: the update is roughly \(\eta \cdot \mathrm{sign}(\nabla)\), so every element still moves by about \(3 \times 10^{-3}\). What changes is the effect on the layer's output. Each output coordinate now sums over 16× more inputs, and the per-element updates are correlated with those inputs, so the output shifts about 16× more per step than it did at width 256. Result: training can diverge in the first few hundred steps, which is the classic symptom of a learning rate that was not rescaled for width.

Under μP, the per-element AdamW update at the same layer is \(1.88 \times 10^{-4}\), sixteen times smaller. That is precisely the factor needed to keep the per-step parameter change comparable to width 256. The forward pass produces the same activation magnitudes (because init was already fan-in-correct), so the loss curve looks like a shifted copy of the width-256 curve: same shape, lower asymptote, same optimal hyperparameters.

Now the proxy savings. Width 256 has roughly \((256/4096)^2\) times the parameters of width 4096, i.e. roughly \(1/256\) as many. A full LR sweep at the proxy costs about \(6/256 \approx 2.3\%\) of one full-size training run. Under μP, you do exactly one sweep at the proxy and use the winner verbatim at full scale. Under SP, you would need to sweep at every scale you train, a 6× multiplier on every Chinchilla rung.
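The savings estimate can be checked directly. This is a rough sketch under the same simplifications the paragraph makes: parameter count scales with width squared, and every run uses the same token budget, so cost is proportional to parameter count. The function name is hypothetical.

```python
def proxy_sweep_fraction(proxy_width, target_width, n_sweep_runs):
    """Sweep cost at the proxy as a fraction of one full-size run.

    Assumes parameter count scales with width squared and per-run cost is
    proportional to parameter count (same token budget for every run).
    """
    param_ratio = (proxy_width / target_width) ** 2
    return n_sweep_runs * param_ratio

frac = proxy_sweep_fraction(256, 4096, n_sweep_runs=6)
print(f"parameter ratio: 1/{round((4096 / 256) ** 2)}")
print(f"6-point LR sweep at the proxy: {frac:.1%} of one full run")
```

If the proxy is also trained on fewer tokens than the target, the real fraction is smaller still.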

Visualizing LR Transfer

The visualizer below plots loss vs \(\log_{10}(\eta)\) at three widths (256, 1024, 4096) under both parametrizations. Toggle between SP and μP and hover any LR value to read the loss at each width.

[Interactive LR-transfer visualizer: loss vs. log10(η) at widths 256, 1024, and 4096 under SP and μP. Not reproduced here.]

The qualitative pattern reproduces the result in Figure 1 of the Tensor Programs V paper. Under SP, the three U-shaped curves drift to the left as width grows: the optimal LR at width 4096 is about 16× smaller than at width 256. Under μP, the three curves stack: the optimum sits at the same \(\eta\) regardless of width. That stacking is the entire promise of μP, and it is what every frontier lab now exploits to keep hyperparameter search costs bounded as model scale grows.

Plain Python: A μP Coordinate-Check Simulator

Before you trust μP in a 671B training run, you have to verify it empirically on a tiny model. The standard diagnostic is the coordinate check: train the same network at several widths, log the per-layer activation magnitudes and gradient norms, and confirm that the products \(\eta_\ell \cdot \nabla_\ell\) stay constant across widths. The script below does exactly that, in pure NumPy.

🐍mup_coord_check.py
Line 12: Why we need a coordinate check

Before you can claim μP transfer works, you have to verify the basic invariant: at every width, the magnitudes of activations, gradients, and weight updates stay O(1) across training steps. That is exactly what a coordinate check measures: sweep widths from small to large and confirm the per-layer statistics do not drift.

Line 16: Fan-in init: the SP baseline

Every entry of every weight matrix is drawn from a Gaussian with variance 1/d_in. Under this init the forward activations have unit variance regardless of width, so that part of SP is already 'scale-stable'. The problem is the BACKWARD pass: a single SGD step of size η produces an update whose magnitude grows with width.

EXECUTION STATE
sigma (width=4096) = 1/sqrt(4096) = 0.0156
W shape = (d_in, d_out)

Line 27: The single rule that turns SP into μP

The whole μP story compresses into this function. Input-layer LRs do not scale. Every hidden-to-hidden and hidden-to-output LR is multiplied by base_width / width. At width 256 the multiplier is 1.0 (μP collapses to your baseline). At width 4096 it is 256/4096 = 1/16, so the effective hidden LR is 16× smaller, exactly the amount that keeps updates O(1) instead of O(width).

EXECUTION STATE
lr_multiplier (width=4096, hidden) = 256/4096 = 0.0625
lr_multiplier (width=4096, input) = 1.0

Line 37: Forward pass: identical at every width

Two matmuls and a ReLU. Note that NOTHING in the forward pass is width-dependent; μP only changes how learning rates are scaled, not how the model computes. This is why a μP transformer is bit-for-bit a normal transformer at inference time.

EXECUTION STATE
h shape = (batch=4, width)
y shape = (batch=4, d_out=8)

Line 50: Per-layer effective learning rates

lr1 (input layer) uses the unscaled base_lr. lr2 (hidden→output) uses base_lr × base_width/width. The whole point of μP is that base_lr can now be tuned ONCE at width=256, and the same base_lr value works for width=4096 without re-sweeping.

EXECUTION STATE
lr1 (base_lr=1e-2, width=4096) = 1.00e-2
lr2 (base_lr=1e-2, width=4096, μP) = 6.25e-4
lr2 (base_lr=1e-2, width=4096, SP) = 1.00e-2 (too big!)

Line 60: Manual gradient: the part SP gets wrong

dW2 = hᵀ · dy. Under SP, the magnitude of dW2 scales with the number of hidden units summed over (i.e. with width). One SGD step then changes W2 by η · dW2, and that change explodes with width, so SP requires a smaller η at larger widths to stay stable. μP cancels exactly that scaling by shrinking the LR multiplier in lock-step.

EXECUTION STATE
dW2 shape = (width, d_out)
||dW2|| scaling (SP) = ∝ sqrt(width)
||dW2 · lr|| scaling (μP) = ∝ 1 (constant)
Line 68: Two diagnostics that tell you μP is working

hidden_std should be roughly constant across widths under BOTH SP and μP; that is just fan-in init working. The update-side diagnostic is how far one step moves the logits, which scales like lr × width × (per-element dW2 magnitude): it should stay roughly constant across widths under μP and grow with width under SP. The script logs the per-element dW2_std, which is itself roughly width-independent, so read it together with lr_hidden and the width. If you run this script you should see that pattern: SP's effective update blows up at width=4096, μP's stays put.

EXECUTION STATE
hidden_std (any width, both modes) = ≈ 0.6
||lr · dW2|| (SP, width 256 → 4096) = 0.003 → 0.012
||lr · dW2|| (μP, width 256 → 4096) = 0.003 → 0.003

Line 85: The signal you want to see in the print loop

Run the script. Under mode='sp' the final loss at width=4096 should be much WORSE than at width=256 with the same base_lr, because the same LR produced updates that were too big and the run diverged. Under mode='mup' the losses across widths should be within ~10% of each other. That agreement is the empirical signature of zero-shot LR transfer.
import numpy as np

# Toy two-layer MLP whose only "scaling axis" is the hidden width.
# We will train it at three widths and check whether activations,
# logits, and gradient norms scale predictably under SP vs muP.
#
# This script reproduces the "coordinate check" that every muP
# implementation has to pass before it can be trusted at scale.

rng = np.random.default_rng(0)

def init_weights(d_in, d_out, mode, base_width=256):
    # Standard Parametrization (SP):
    #   Every weight ~ N(0, 1 / d_in).
    #   The activations have unit variance, but the LR that keeps
    #   updates O(1) shrinks as width grows.
    #
    # Maximal Update Parametrization (muP):
    #   Same forward init, but the LR of "hidden -> hidden" and
    #   "hidden -> output" layers is rescaled by base_width / width.
    #   That rescaling is what makes the updates O(1) at every width.
    sigma = 1.0 / np.sqrt(d_in)
    return rng.normal(0, sigma, size=(d_in, d_out))

def lr_multiplier(layer_kind, width, mode, base_width=256):
    if mode == "sp":
        return 1.0
    # muP rules (Yang et al. 2022, Table 3, AdamW row):
    #   input layer (vocab/embed -> hidden):     LR scales as 1.0
    #   hidden layer (hidden -> hidden):         LR scales as base/width
    #   output layer (hidden -> logits):         LR scales as base/width
    if layer_kind == "input":
        return 1.0
    return base_width / width

def forward(x, W1, W2):
    h = np.maximum(0.0, x @ W1)      # ReLU hidden activations
    y = h @ W2                       # logits
    return h, y

def coord_check(width, mode, base_lr=1e-2, steps=10):
    d_in, d_out = 32, 8
    W1 = init_weights(d_in, width, mode)        # input layer
    W2 = init_weights(width, d_out, mode)       # output layer

    # One fixed batch per call (reproducible through the seeded rng)
    x = rng.standard_normal((4, d_in))
    y_true = rng.standard_normal((4, d_out))

    lr1 = base_lr * lr_multiplier("input", width, mode)
    lr2 = base_lr * lr_multiplier("output", width, mode)

    h_mags, dW2_mags = [], []
    for _ in range(steps):
        h, y = forward(x, W1, W2)
        loss = ((y - y_true) ** 2).mean()
        # Manual gradients (chain rule, MSE loss)
        dy = 2 * (y - y_true) / y.size
        dW2 = h.T @ dy
        dh = dy @ W2.T
        dh[h <= 0] = 0.0
        dW1 = x.T @ dh
        # SGD step (toy: the Adam-style multipliers above are reused with plain SGD)
        W1 -= lr1 * dW1
        W2 -= lr2 * dW2
        h_mags.append(np.std(h))
        dW2_mags.append(np.std(dW2))

    return {
        "width": width,
        "mode": mode,
        "loss": float(loss),
        "hidden_std": float(np.mean(h_mags)),
        "dW2_std": float(np.mean(dW2_mags)),
        "lr_input": lr1,
        "lr_hidden": lr2,
    }

# Run the coord check across widths under both parametrizations.
for mode in ("sp", "mup"):
    print(f"\n--- mode = {mode} ---")
    for w in (256, 1024, 4096):
        r = coord_check(w, mode)
        print(
            f"width={r['width']:>5} "
            f"hidden_std={r['hidden_std']:.3f} "
            f"dW2_std={r['dW2_std']:.4f} "
            f"lr_hidden={r['lr_hidden']:.2e} "
            f"loss={r['loss']:.3f}"
        )

What the coord check tells you. Under SP, the printed loss at width 4096 with base_lr = 1e-2 should be visibly worse than at width 256, because the same global LR was too large for the wider model and the optimizer overshot. Under μP, the losses across widths should land close to each other. If your μP wiring is wrong (a layer misclassified as "input" instead of "hidden", say), the coord check fails LOUDLY: one width diverges and the others converge. That is exactly the signal you want to see in CI before any expensive run launches.
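One way to turn that signal into a CI gate is to compare a logged statistic across widths and fail when it drifts. The sketch below is an assumption about how such a gate might look, not part of any library: the function name, the threshold, and the example numbers are all made up for illustration.

```python
def coord_check_passes(stat_by_width, max_ratio=2.0):
    """Return True if a logged statistic stays within max_ratio across widths.

    stat_by_width maps width -> a positive scalar (for example mean |activation|
    or the size of the logit change after one step). The threshold is a
    placeholder; pick it from runs you already trust.
    """
    values = list(stat_by_width.values())
    if any(v != v or v <= 0 for v in values):   # NaN or non-positive: fail
        return False
    return max(values) / min(values) <= max_ratio

# Made-up numbers, for illustration only.
flat = {256: 0.61, 1024: 0.60, 4096: 0.62}      # roughly width-independent
drift = {256: 0.61, 1024: 2.45, 4096: 9.80}     # grows with width

print(coord_check_passes(flat))
print(coord_check_passes(drift))
```

A ratio test is deliberately crude. It will not tell you which layer is mis-tagged, only that something scales with width when it should not.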

PyTorch: μP Wiring on a Real Transformer

The production pattern keeps μP entirely inside the layer definitions and the optimizer setup. The rest of the training stack — FSDP, gradient accumulation, bf16, activation checkpointing — is unchanged.

🐍mup_transformer.py
Line 6: Subclass nn.Linear, do not replace it

MuPLinear inherits the full nn.Linear forward path and parameter registration, so autograd, FSDP sharding, mixed-precision casting, and quantization all still work without modification. We only add two attributes: width_mult (consumed by the optimizer) and forward_mult (consumed by the forward pass for the output projection).

EXECUTION STATE
width_mult (hidden, current=4096) = 256/4096 = 0.0625
forward_mult (output layer, current=4096) = 256/4096 = 0.0625
Line 13: Three layer 'modes' to memorize

μP partitions every parameter into three buckets. input (vocab embedding, position embedding, first projection from input dim): width_mult = 1. hidden (every Linear whose BOTH ends are width-scaled: Q, K, V, attention output projection, MLP up/down projections, layernorm if width-dependent): width_mult = base/current. output (the final logit projection from hidden to vocab): width_mult = base/current AND a 1/width forward rescale on the logits. Get this wrong and transfer breaks silently: the loss curve still converges, just to a worse point.

Line 20: The 1/width logit rescale (output layer only)

Without forward_mult, the output logits scale with width, because the hidden vector has width components and each output unit sums over all of them. The 1/width factor in the forward keeps logit magnitude O(1) at every width, which keeps the softmax temperature constant and lets the cross-entropy loss have the same numerical range across the LR sweep.

EXECUTION STATE
logit norm before rescale (width=4096) = ~16 × baseline
logit norm after rescale = ~1 × baseline

Line 22: Fan-in init is unchanged from SP

This is the most common confusion. μP DOES NOT change initialization: std = 1/sqrt(fan_in) is still the right choice. μP changes the LEARNING RATE per layer, not the init. (There is a separate variant called μP-init that also rescales the init, but it is not what the DeepSeek/GPT-4-style 'tune small, ship big' workflow uses.)

Line 30: Param groups: how μP plugs into a real optimizer

AdamW already supports per-group learning rates. We exploit that: every parameter gets its own group with lr = base_lr × width_mult. The result is one AdamW instance that internally applies a different LR per layer, with no surgery on the optimizer and no custom step function. This is why μP integrates cleanly with FSDP, ZeRO, and bf16 stacks: AdamW is already where the per-group LRs live.

EXECUTION STATE
groups (target model) = 4 (W1, b1, W2, b2)
lr['fc2.weight'] = 3e-3 × 0.0625 = 1.88e-4
lr['fc1.weight'] (hidden mode) = 3e-3 × 0.0625 = 1.88e-4

Line 50: TinyMLP stands in for the real transformer

This is the structural pattern. Inside a transformer block, attention's Q/K/V/O projections and the MLP's up/down projections are all 'hidden' mode. The final lm_head Linear is 'output' mode. The token + position embedding lookup is 'input' mode. Bolt those modes onto every Linear in your block and you have a μP-parametrized transformer, with no other change required.

Line 60: The 'proxy' and 'target' workflow

This is the line that saves the project $5M of GPU time. You build a small proxy (hidden=256) that costs <$100 to train for a few thousand steps. You sweep LR, weight decay, β₂ on that proxy. The winning hyperparameters are reused verbatim at hidden=4096; μP is designed so that they stay near-optimal there. DeepSeek-V3, GPT-4, and Llama-3 all use this workflow.

EXECUTION STATE
proxy cost (rough) = $50–500
target cost (rough) = $1M–5M+
sweep cost saved = ~99%

Line 66: One base_lr, two models, correct per-layer LRs

base_lr = 3e-3 is the SINGLE number you tuned. AdamW sees it on both the proxy and the target, but the per-group multipliers fan it out to the right value for each layer at each scale. Print the param groups before kicking off the run: every config-file mistake in production μP comes from a layer being classified as the wrong mode, and the per-group LR print catches that in five seconds.
import math
import torch
import torch.nn as nn

# A muP-aware linear layer. Over nn.Linear it adds 'width_mult' (read when
# building optimizer param groups) and 'forward_mult' (applied in forward).
class MuPLinear(nn.Linear):
    def __init__(self, in_features, out_features, mode="hidden",
                 base_width=256, current_width=256, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        self.mode = mode
        if mode in ("hidden", "output"):
            self.width_mult = base_width / current_width
        else:
            self.width_mult = 1.0
        # muP also rescales the output layer's FORWARD by base/current
        # (1/16 at width 4096) so the logits stay O(1) at every width.
        self.forward_mult = self.width_mult if mode == "output" else 1.0
        # Standard fan-in init at every width.
        nn.init.normal_(self.weight, mean=0.0, std=1.0 / math.sqrt(in_features))
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, x):
        return super().forward(x) * self.forward_mult


def build_param_groups(model, base_lr):
    # One param group per parameter, each with lr = base_lr * width_mult, so
    # a single AdamW instance applies the right LR to every layer.
    groups = []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Walk up to the owning module to read its width_mult
        mod = model.get_submodule(name.rsplit(".", 1)[0])
        wm = getattr(mod, "width_mult", 1.0)
        groups.append({
            "params": [p],
            "lr": base_lr * wm,
            "name": name,
            "width_mult": wm,
        })
    return groups


# Two-layer MLP stand-in; real blocks tag each projection the same way.
# NOTE: fc1 maps a fixed d_in to the width, so the three-way partition would
# call it "input"; it stays "hidden" here to match the printed per-group LRs.
class TinyMLP(nn.Module):
    def __init__(self, d_in, hidden, d_out, base_width=256):
        super().__init__()
        self.fc1 = MuPLinear(d_in, hidden, mode="hidden",
                             base_width=base_width, current_width=hidden)
        self.fc2 = MuPLinear(hidden, d_out, mode="output",
                             base_width=base_width, current_width=hidden)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


# Tune ONCE at the small "proxy" model. The base_lr that wins here is the
# base_lr you ship at the big model.
proxy = TinyMLP(d_in=512, hidden=256, d_out=128, base_width=256)
target = TinyMLP(d_in=512, hidden=4096, d_out=128, base_width=256)

base_lr = 3e-3   # found by sweeping the PROXY, not the target

opt_proxy  = torch.optim.AdamW(build_param_groups(proxy,  base_lr))
opt_target = torch.optim.AdamW(build_param_groups(target, base_lr))

# Inspect the per-group LRs to convince yourself transfer is correct.
for g in opt_target.param_groups:
    print(f"{g['name']:>16s} width_mult={g['width_mult']:.4f} lr={g['lr']:.2e}")

Two architectural details worth marking. First, μP composes with optimizer state sharding (ZeRO-3, FSDP). Each shard sees the same per-group LRs because the LR is a scalar on each param group, not a per-parameter tensor. Second, μP does not interact with mixed-precision scaling — the gradient scaler operates on the gradient, not on the learning rate, so bf16 / fp8 training runs see μP as a no-op at the precision boundary.

What the production version adds

Real μP implementations (the open-source mup library and the in-house variants at DeepMind, Anthropic, and Meta) extend this skeleton in three ways:

  1. Attention scaling. The dot-product score temperature \(1/\sqrt{d_k}\) becomes \(1/d_k\) under μP (the so-called "μP attention scaling"). This keeps the pre-softmax logits O(1) at every head dimension.
  2. Embedding LR. Token embeddings get their own LR multiplier; some teams scale them by \(\sqrt{n_0/n}\) to keep the embedding dynamics aligned with the rest of the network.
  3. Reparam vs init μP. Two equivalent recipes exist: rescale the LR per layer (what we showed) or rescale the init per layer and use a single global LR. Reparametrization is easier to integrate with existing optimizer code and is the variant DeepSeek-V3 and Llama-3 use.
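To see why the attention temperature changes, consider the correlated case in which a query aligns with its key. The following is a minimal sketch in plain Python (the helper name `aligned_score` is hypothetical). With q = k, the raw score q · k grows like d_k, so dividing by sqrt(d_k) leaves a logit that grows like sqrt(d_k), while dividing by d_k keeps it O(1).

```python
import math
import random

def aligned_score(d_k, scale, seed=0):
    """Pre-softmax logit for a query that equals its key, under a given scale."""
    rng = random.Random(seed)
    k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
    q = k  # fully aligned query/key: the worst case for logit growth
    return scale * sum(qi * ki for qi, ki in zip(q, k))

for d_k in (64, 256, 1024):
    sp = aligned_score(d_k, 1.0 / math.sqrt(d_k))   # standard temperature
    mup = aligned_score(d_k, 1.0 / d_k)             # muP-style temperature
    print(f"d_k={d_k:5d}  1/sqrt(d_k) logit={sp:7.2f}  1/d_k logit={mup:5.2f}")
```

For independent random q and k the 1/sqrt(d_k) scaling is the right one; the 1/d_k rule targets the aligned regime, which is the case the μP argument is concerned with.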

At Massive Scale: How DeepSeek, GPT-4, and Llama Tune

The reason μP matters is not pedagogy — it is the actual procedure every frontier lab now follows. Reconstructed from public technical reports and post-mortems:

| Lab / Model | Proxy size | Target size | Sweep cost | What they tune at the proxy |
|---|---|---|---|---|
| DeepSeek-V3 (2024) | ~1.4B active / ~10B total | 37B active / 671B total | ~0.5% of full run | LR, β₂, weight decay, expert router temperature, MTP head loss weight |
| Llama-3 (2024) | ~8B | 405B | ~1% of full run | LR, batch size schedule, warmup, weight decay |
| GPT-4 (2023, inferred) | Reported as a 'small model with predictable scaling' (system card) | ~1.8T MoE (rumoured) | Not disclosed; described as 'reliable extrapolation from ≤10000× smaller' | Loss curve, LR, optimization stability flags |

The DeepSeek-V3 technical report is unusually forthcoming about its final hyperparameters, though it does not describe a μP proxy sweep by name, so the proxy details in the table above are a reconstruction rather than a quotation. Its pre-training hyperparameter section reports AdamW with β₂ = 0.95 and weight decay 0.1, and a peak LR of 2.2e-4 reached after a 2K-step linear warmup. The report also states that the full 2.788M H800-hour run saw no irrecoverable loss spikes and needed no rollbacks.

The bottleneck μP relaxes is wall-clock, not just dollars. Even if you had infinite money, a full-size LR sweep at 671B would take a month of calendar time per run. μP turns that month into an afternoon at the proxy. The next training run is gated on the previous one finishing — μP collapses the gate.

The batch-size half of the story

μP fixes LR transfer. It does not fix batch size. The McCandlish paper's critical batch size \(B_{\mathrm{crit}} = \mathrm{tr}(H) \cdot \sigma^2 / (\nabla L^\top H \nabla L)\) depends on the curvature of the loss and the gradient noise, both of which evolve during training and across scales. In practice every lab measures \(B_{\mathrm{crit}}\) empirically by sweeping a few batch sizes early in training (the gradient-noise-scale estimator gives a cheap online estimate) and ramping the batch from a small starting value (4M tokens) to a large terminal value (60M+ tokens for DeepSeek-V3, ~16M for Llama-3). This batch-ramp schedule is independent of μP; the two recipes compose.
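The gradient-noise-scale estimator mentioned above can be sketched as follows. It uses the simplified noise scale from McCandlish et al. (2018), B_simple = tr(Σ) / |G|², which drops the Hessian from the formula above and is estimated from squared gradient norms measured at two batch sizes. The function name and the synthetic numbers are illustrative; in a real run the two norms would come from averaged measurements.

```python
def simple_noise_scale(b_small, g2_small, b_big, g2_big):
    """Estimate B_simple = tr(Sigma) / |G|^2 from two batch sizes.

    Relies on E[|G_B|^2] = |G|^2 + tr(Sigma) / B for a batch of size B.
    g2_small and g2_big are (averaged) squared gradient norms.
    """
    true_g2 = (b_big * g2_big - b_small * g2_small) / (b_big - b_small)
    trace_sigma = (g2_small - g2_big) / (1.0 / b_small - 1.0 / b_big)
    return trace_sigma / true_g2

# Synthetic check: choose |G|^2 and tr(Sigma), generate the expected norms,
# and confirm the estimator recovers tr(Sigma) / |G|^2.
g2, tr_sigma = 0.04, 80.0
b_small, b_big = 32, 1024
g2_small = g2 + tr_sigma / b_small
g2_big = g2 + tr_sigma / b_big

estimate = simple_noise_scale(b_small, g2_small, b_big, g2_big)
print(f"estimated noise scale: {estimate:.1f} (expected {tr_sigma / g2:.1f})")
```

With noisy single-step measurements the two intermediate estimates should each be averaged over many steps before taking the ratio, otherwise the result is badly biased.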

Engineering Reality and Failure Modes

Three failure modes account for nearly every μP-related incident that surfaces in public post-mortems.

  1. Layer misclassification. An embedding layer accidentally tagged as "hidden" gets a width-shrunk LR and learns to a worse representation. A hidden layer accidentally tagged as "input" gets an un-shrunk LR and the training run diverges in the first 1000 steps. Defence: print every parameter group's width_mult before the first step and gate the launch on a manual diff against the reference config. This is a five-line check that catches a $5M bug.
  2. Coord-check skipped. Teams sometimes assume μP "just works" because the library exports a MuPLinear. Then the attention scaling is missed (still using \(1/\sqrt{d_k}\) instead of \(1/d_k\)) and the transfer breaks silently: loss curves at small scale look fine, but the big run lands at a worse loss than expected. Defence: every new model architecture re-runs the coord check before its first scale-up.
  3. Proxy too small. μP is a large-width limit. Below width ~128 the asymptotic behaviour has not kicked in, and the LR optimum at the proxy will mis-extrapolate by 2–3×. Defence: pick the proxy width as the largest model that fits in your sweep budget, not the smallest. Most teams use proxies in the 128M–2B parameter range; below 100M the transfer gets noisy.
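The first defence, diffing every parameter group's width_mult against a reference config, can be sketched in plain Python. This is one possible shape for that check, not an established tool: the function name and the layer names are hypothetical, and the group dicts only mirror the "name" and "width_mult" keys produced by the param-group builder in this section.

```python
def audit_width_mults(param_groups, expected_modes, base_width, width, tol=1e-12):
    """Return a list of mismatches between actual and expected width_mult.

    param_groups: iterable of dicts with "name" and "width_mult" keys.
    expected_modes: reference mapping name -> "input" | "hidden" | "output".
    """
    problems = []
    for group in param_groups:
        name = group["name"]
        mode = expected_modes.get(name)
        if mode is None:
            problems.append(f"{name}: missing from reference config")
            continue
        want = 1.0 if mode == "input" else base_width / width
        if abs(group["width_mult"] - want) > tol:
            problems.append(f"{name}: width_mult={group['width_mult']} expected {want}")
    return problems

# Hypothetical layer names; the second group is deliberately mis-tagged.
groups = [
    {"name": "tok_emb.weight", "width_mult": 1.0},
    {"name": "mlp.up_proj.weight", "width_mult": 1.0},   # should be 256/4096
    {"name": "lm_head.weight", "width_mult": 0.0625},
]
reference = {"tok_emb.weight": "input", "mlp.up_proj.weight": "hidden",
             "lm_head.weight": "output"}

problems = audit_width_mults(groups, reference, base_width=256, width=4096)
for line in problems:
    print("MISMATCH", line)
```

Gating the launch on an empty `problems` list catches the mis-tagged hidden layer before the first optimizer step.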

The good news: when μP works (which is most of the time), it is invisible. The training loop looks like a normal AdamW loop. The config file has an extra width_mult column. The optimizer prints a slightly longer param-group list. The reward is that you trained a 671B-parameter model with hyperparameters you tuned in an afternoon on a small proxy — and the post-mortem column titled "LR-related divergences" reads zero.

The big picture. Chinchilla told us how much compute to spend on parameters and tokens. μP told us how to spend it without re-sweeping hyperparameters at every scale. Together, they reduce the cost of training a new frontier model from "sample-efficiency-bound and sweep-bound" to just "sample-efficiency-bound" — which is the regime every post-2023 release operates in. The next section returns to the loss-curve side of the story: scaling laws for inference, where the compute-optimal question flips on its head.