Chapter 19

Mixed Precision Training

Modern Training Techniques

Why Precision Became a Bottleneck

For two decades, deep-learning math was synonymous with a single number format: IEEE 754 single-precision, also called FP32. Every weight, every activation, every gradient was a 32-bit floating-point number. It was simple, numerically safe, and, for AlexNet-sized networks on 2012 hardware, entirely practical.

Then the models grew. GPT-3 has 175 billion parameters. In FP32 the weights alone take $175 \times 10^9 \times 4\ \text{B} = 700\ \text{GB}$. That is enough to fill nearly nine A100-80GB GPUs just to store the model. During training you also need gradients (another 700 GB) and the Adam optimizer state (two FP32 buffers per parameter, another 1.4 TB). That adds up to 2.8 TB before a single activation has been computed.

At the same time NVIDIA's tensor cores, starting with Volta in 2017, began to offer enormous speed-ups for narrower number formats: roughly 8× for FP16 over FP32 on V100, and 16× on A100. The question was no longer whether to drop precision but how. Mixed precision training is the discipline that answers it without blowing up the numerics.

The central idea in one sentence: compute in FP16/BF16, accumulate and store master weights in FP32, and use loss scaling to keep small gradients from underflowing. Everything else in this section is a justification for — or a consequence of — that one sentence.

Anatomy of a Floating-Point Number

An IEEE-754 float is three fields packed into a bit string: a sign bit $s$, an exponent field $e$, and a mantissa field $m$. For normal numbers, the value it encodes is $x = (-1)^{s} \cdot (1 + m / 2^{k}) \cdot 2^{e - b}$, where $k$ is the number of mantissa bits and $b$ is the exponent bias. The exponent controls the range (how big or small a number can be), and the mantissa controls the precision (how finely numbers are spaced). The formats that matter for neural networks split their bit budgets differently:

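| Format | Sign | Exponent | Mantissa | Range (±) | ULP @ 1.0 | Bytes |
|---|---|---|---|---|---|---|
| FP32 | 1 | 8 | 23 | ~3.4×10³⁸ | 1.2×10⁻⁷ | 4 |
| FP16 | 1 | 5 | 10 | ~6.5×10⁴ | 9.8×10⁻⁴ | 2 |
| BF16 | 1 | 8 | 7 | ~3.4×10³⁸ | 7.8×10⁻³ | 2 |
| FP8 (E4M3) | 1 | 4 | 3 | 448 | 0.125 | 1 |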
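You can check the encoding formula directly for FP16, where $k = 10$ and $b = 15$. Below is a minimal sketch. NumPy's float16 supplies the bit pattern, and the hypothetical helper `decode_fp16` extracts the three fields and rebuilds the value from them. It handles normal numbers only, since zeros, subnormals, infinities, and NaN follow different rules.

```python
import numpy as np

def decode_fp16(x):
    """Split an FP16 value into (s, e, m) and rebuild it with
    x = (-1)^s * (1 + m / 2^k) * 2^(e - b). Normal numbers only."""
    k, b = 10, 15                                   # FP16: 10 mantissa bits, exponent bias 15
    bits = int(np.array(x, dtype=np.float16).view(np.uint16))
    s = bits >> 15                                  # 1 sign bit
    e = (bits >> k) & 0x1F                          # 5 exponent bits
    m = bits & ((1 << k) - 1)                       # 10 mantissa bits
    if e == 0 or e == 0x1F:
        raise ValueError("zero/subnormal or inf/NaN: the normal-number formula does not apply")
    value = (-1) ** s * (1 + m / 2 ** k) * 2.0 ** (e - b)
    return {"s": s, "e": e, "m": m, "value": value}

print(decode_fp16(np.pi))   # {'s': 0, 'e': 16, 'm': 584, 'value': 3.140625}
print(decode_fp16(-0.1))    # value equals float(np.float16(-0.1)) = -0.0999755859375
```

The decoded value always matches NumPy's own FP16 rounding, because the formula is exactly how the hardware interprets the bits.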

The key insight is that FP16 and BF16 are not the same 16-bit format: they sit on opposite sides of the range-vs-precision trade. FP16 spends more bits on the mantissa (10) and pays the price in range. BF16 keeps FP32's range (the same 8 exponent bits) and pays the price in precision (7 mantissa bits).

Interactive: The Same Number in Four Formats

Type any number and see exactly which bits it occupies in each format, what value those bits actually decode back to, and where the encoding overflows or underflows. Use the preset buttons to jump to values that highlight specific regimes — for instance, 1e-8 is a typical tiny gradient that underflows FP16 but survives BF16.

Try π first. FP16 and BF16 both round it to 3.140625, because bits 8 to 10 of π's mantissa happen to be zero, so FP16's three extra mantissa bits buy nothing here. Then try 0.1, where the gap shows: FP16 decodes to 0.0999755859375 and BF16 to 0.10009765625. Finally try 1e-8: FP16 and FP8 both underflow to zero, while BF16 keeps the value. The yellow/blue/pink coloring marks the sign/exponent/mantissa regions.

The Great 16-Bit Debate: FP16 vs BF16

Until Ampere, the default was FP16. Pascal GPUs offered fast FP16 arithmetic and Volta added FP16 tensor cores, but neither supported BF16 natively. Mixed precision recipes were built around FP16, including all the loss-scaling machinery we'll meet shortly.

Starting with Ampere (2020, A100), BF16 became a first-class citizen on NVIDIA GPUs as well; Google's TPUs already supported it. A quiet revolution followed: almost every large-model training run moved from FP16 to BF16. PaLM, Chinchilla, LLaMA, and Gemini were all trained primarily in BF16.

Why? Because gradients are the hardest tensor to keep numerically healthy, and gradients care more about range than precision:

  • BF16's range is essentially identical to FP32's. Both use an 8-bit exponent field, so a BF16 cast of an FP32 gradient cannot underflow or overflow except at the extreme edges of FP32's own range. It can only get rounded.
  • FP16's mantissa is precise but its range is tiny. A gradient below ~6×10⁻⁵ becomes subnormal and below ~6×10⁻⁸ vanishes to zero. A gradient above 65504 explodes to ±∞.
  • BF16 precision is coarser, but optimizers like Adam effectively apply a running average that washes out per-step quantization noise. You lose a few bits of mantissa and nothing else.
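These range and precision claims are easy to check in NumPy, with one caveat: NumPy has no bfloat16 dtype. The sketch below therefore emulates BF16 by rounding an FP32 bit pattern to its top 16 bits with round-to-nearest-even. The helper `to_bf16` is hypothetical, not a library function. The sketch then compares the result against NumPy's native float16.

```python
import numpy as np

def to_bf16(x):
    """Round FP32 value(s) to the nearest BF16 (ties to even) and return them as FP32.

    NumPy has no bfloat16 dtype, so this hypothetical helper keeps the top 16 bits
    of each FP32 bit pattern. Finite inputs only; NaN payloads are not handled."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> 16) & 1                          # lowest bit that BF16 keeps
    rounded = (bits + 0x7FFF + lsb) & 0xFFFF0000    # round to nearest even, drop 16 bits
    return np.asarray(rounded, dtype=np.uint32).view(np.float32)

with np.errstate(over="ignore"):                    # np.float16(1e5) overflows to inf
    for v in [np.pi, 0.1, 1e-8, 1e5]:
        print(f"{v:<8g} fp16 = {float(np.float16(v)):<16.10g} bf16 = {float(to_bf16(v)):.10g}")
```

The output makes each claim concrete:

- π rounds to the same 3.140625 in both formats.
- 0.1 comes out coarser in BF16 (0.10009765625 versus FP16's 0.0999755859375).
- 1e-8 survives only in BF16.
- 1e5 overflows FP16 to inf, while BF16 rounds it to 99840.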

Interactive: Representable Ranges

Rule of thumb: on Ampere-class hardware or newer, default to BF16. Reach for FP16 only when you are targeting older hardware (V100, T4) or have a specific case where FP16's extra precision demonstrably helps — and when you do, budget time for loss-scale tuning.

The Three Numerical Crises of Half-Precision Training

Naive half-precision training — "just cast everything to FP16" — fails in three well-catalogued ways. Each of the fixes we'll meet in the next section exists because of one of these failure modes.

Crisis 1 — Gradient Underflow

FP16 cannot represent positive values smaller than $2^{-24} \approx 5.96 \times 10^{-8}$ (its smallest positive subnormal). Any gradient smaller than about half of that rounds to exactly zero. The corresponding weight receives no update at all, as if the layer were frozen.

The Micikevicius et al. (2018) Mixed Precision Training paper (ICLR) documents that for many deep networks, a substantial fraction of gradient values falls below FP16's smallest representable positive value ($2^{-24} \approx 5.96 \times 10^{-8}$). Depending on the architecture and training phase, this is often a majority. Their Figure 2 plots the gradient histogram for an SSD-ResNet detector and shows that, without scaling, the bulk of the distribution is in FP16's underflow region.

In FP16-only training, gradient underflow is silent. You see a loss that plateaus or drifts upward, but no error, no warning. Symptom: weights in the lower layers stop changing while the upper layers keep moving.

Crisis 2 — Activation Overflow

At the other end of the range cliff, FP16's maximum is 65504. A self-attention block computing $QK^{T}$ over long sequences can produce pre-softmax scores in the thousands. A naive exp of any score above about 11.1 already exceeds 65504 ($\ln 65504 \approx 11.09$). Squaring large activations in an MSE loss can likewise exceed 65504. Once a value becomes $\pm\infty$, it poisons every downstream tensor: any nonzero finite value × ∞ = ±∞, while 0 × ∞ and ∞ − ∞ both give NaN.

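This is why autocast's FP32 list (covered below) includes softmax, reductions, and most losses. Operations whose output can grow explosively with their input, such as exp, pow, and long sums, stay in FP32 precisely to avoid overflow. Matmuls are the exception: they run in FP16, but the tensor-core kernels accumulate their dot products in FP32 internally.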
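Here is a minimal NumPy sketch of both failure paths: squares and exponentials leave FP16's range, and a single inf turns into NaN downstream. The second half contrasts a naive FP16 softmax with one that upcasts to FP32 and subtracts the maximum logit first. Subtracting the max is the standard stabilization; it is used here as an illustration, not as a description of PyTorch's softmax kernel. The logits are made up.

```python
import numpy as np

with np.errstate(over="ignore", invalid="ignore"):
    a = np.float16(300.0)
    print(a * a)                              # inf: 90000 > 65504
    print(np.exp(np.float16(12.0)))           # inf: e**12 is about 162755
    inf = np.float16(np.inf)
    print(inf - inf, np.float16(0.0) * inf)   # nan nan

def softmax_naive_fp16(z):
    """Exponentiate directly in FP16: overflows for any logit above ~11.09."""
    e = np.exp(np.asarray(z, dtype=np.float16))
    return e / e.sum()

def softmax_fp32_stable(z):
    """Upcast to FP32 and subtract the max, so every exponent is <= 0."""
    z = np.asarray(z, dtype=np.float32)
    e = np.exp(z - z.max())
    return e / e.sum()

scores = [20.0, 18.0, 5.0]                    # hypothetical pre-softmax logits
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive_fp16(scores))         # NaNs from inf / inf
print(softmax_fp32_stable(scores))            # finite probabilities summing to 1
```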

Crisis 3 — The Disappearing Weight Update

Suppose we fixed both underflow and overflow. A third, subtler issue remains. Consider a weight $W = 1.0$ and an Adam update $\Delta W = \eta \cdot \hat{m} = 5 \times 10^{-5}$.

The FP16 "gridpoints" just above 1.0 are spaced by $2^{-10} \approx 9.77 \times 10^{-4}$. Any update smaller than half of that (≈ 5×10⁻⁴) rounds to the same bin as W itself. So storing $W + \Delta W$ in FP16 gives back exactly W: the update was computed correctly but then thrown away during the write-back.

This is why the master weight must be FP32. Keeping an extra FP32 copy is not excess caution; it is a correctness requirement. The tiny updates that would vanish in FP16 accumulate faithfully in FP32.

To see this concretely, drag the "center x" slider and adjust the simulated update magnitude. Watch the green update tick turn red when it falls below half an ULP of FP16.


The Mixed-Precision Recipe

The complete recipe, first published in Micikevicius et al. 2018, Mixed Precision Training, has four ingredients. Three of them each address one of the three crises above; the remaining one is where the speed-up comes from.

  1. FP32 master weights. The optimizer keeps a canonical FP32 copy of every parameter. Fixes Crisis 3.
  2. FP16/BF16 forward & backward. At each step, the master weights are cast to the narrow dtype and used for compute. This halves activation memory and memory traffic, and lets matmuls run on the tensor cores.
  3. Loss scaling. The loss is multiplied by a scalar $S$ before backward. By the chain rule, every gradient is scaled by $S$. Fixes Crisis 1. (Not needed with BF16.)
  4. FP32 "safe list". Operations that are known to overflow or lose precision (softmax, layer norm, losses, reductions) are kept in FP32. Fixes Crisis 2.
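Here is an end-to-end sketch of the four ingredients on a toy problem. It is a minimal NumPy illustration that rests on several assumptions:

- linear regression with an MSE loss stands in for a network;
- NumPy's float16 stands in for tensor-core compute;
- the loss scale is a fixed, illustrative value rather than a dynamic one;
- the "safe list" is reduced to computing the loss in FP32.

Comments tag each line with its ingredient number from the list.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((256, 8)).astype(np.float32)   # toy inputs
w_true = rng.standard_normal(8).astype(np.float32)
y = X @ w_true                                          # FP32 targets
X16 = X.astype(np.float16)                              # inputs cast once for compute

w_master = np.zeros(8, dtype=np.float32)   # (1) FP32 master weights
lr = np.float32(0.1)
S = np.float32(2.0 ** 10)                  # (3) static loss scale, illustrative value
skipped = 0

for step in range(200):
    # (2) Materialize an FP16 copy of the weights for forward and backward.
    w16 = w_master.astype(np.float16)
    pred16 = X16 @ w16

    # (4) Loss arithmetic stays in FP32, our stand-in for the "safe list".
    err = pred16.astype(np.float32) - y
    loss = np.mean(err ** 2)

    # Backward of the MSE loss. The upstream gradient is multiplied by S
    # before it enters FP16, so small values stay above the underflow floor.
    grad_out16 = (S * 2.0 * err / len(y)).astype(np.float16)
    grad16 = X16.T @ grad_out16                # FP16 backward matmul

    if not np.all(np.isfinite(grad16)):        # overflow: skip this update
        skipped += 1
        continue

    grad = grad16.astype(np.float32) / S       # unscale in FP32
    w_master -= lr * grad                      # update the FP32 master copy only

print(f"loss={loss:.2e}  max|w - w_true|={np.max(np.abs(w_master - w_true)):.2e}  skipped={skipped}")
```

The final weights land close to `w_true`, limited mainly by FP16 rounding in the forward pass. Because only the FP32 master copy is ever updated, small late-stage updates keep accumulating even when they are too small to move the FP16 copy.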

Loss Scaling — Shifting the Histogram

Loss scaling is the most distinctive trick of FP16 training, and it is surprisingly simple. Let $L$ be the loss. Instead of backpropagating $L$, we backpropagate $L' = S \cdot L$. By linearity of differentiation, every gradient becomes $\frac{\partial L'}{\partial \theta} = S \cdot \frac{\partial L}{\partial \theta}$. On a base-2 log axis, the gradient histogram is therefore shifted to the right by $\log_2 S$. Tiny gradients that would have fallen off FP16's lower cliff now land safely inside the representable band. After backprop, we divide each gradient by $S$ to recover the true gradient for the optimizer step.
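The shift is easy to measure. The sketch below draws synthetic gradient magnitudes and multiplies them by several candidate values of $S$. It then casts to FP16 and counts how many values flush to zero or overflow to inf. The magnitudes are log-uniform between 1e-10 and 1e-3, an assumption chosen for illustration, and the helper names are ours.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic gradient magnitudes, log-uniform between 1e-10 and 1e-3.
# This distribution is an assumption for illustration, not a measured one.
grads = (10.0 ** rng.uniform(-10.0, -3.0, size=100_000)).astype(np.float32)

def cast_scaled_to_fp16(grads, S):
    """Multiply by the loss scale S, then round to FP16 (overflow becomes inf)."""
    with np.errstate(over="ignore"):
        return (grads * np.float32(S)).astype(np.float16)

def fraction_zero(grads, S):
    return float(np.mean(cast_scaled_to_fp16(grads, S) == 0))

def fraction_inf(grads, S):
    return float(np.mean(np.isinf(cast_scaled_to_fp16(grads, S))))

for log2_S in [0, 8, 15, 30]:
    S = 2 ** log2_S
    print(f"S = 2^{log2_S:<2d}  flushed to zero: {fraction_zero(grads, S):6.1%}"
          f"  overflowed: {fraction_inf(grads, S):6.1%}")
```

Under these assumptions, roughly a third of the values flush to zero at $S = 1$, and none flush or overflow at $S = 2^{15}$. At $S = 2^{30}$ a noticeable fraction overflows. These are the two failure modes a dynamic scaler has to steer between.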

In practice we never pick $S$ by hand. The dynamic loss scaler starts at a hopeful value (PyTorch's GradScaler defaults to $2^{16}$) and adjusts automatically. Whenever any gradient overflows to $\pm\infty$ or becomes NaN, it halves $S$ and skips the step. After many clean steps in a row (2000 by default in PyTorch), it doubles $S$ to stay aggressive.

Watch the feedback loop evolve step by step in the simulator below. Step through a training run and observe $S$ doubling after stretches of clean steps and halving the moment an overflow is injected.


Interactive: Loss Scaling in Action

Drag the slider. At $S = 1$, most of the synthetic gradient distribution falls below FP16's minimum and gets crushed to zero (red on the left). Push $S$ up to 2¹⁵ and essentially the whole distribution lands in the representable band. Push it too far and overflow kicks in on the right. PyTorch's GradScaler automates exactly this search, adjusting $S$ on every step.


FP32 Master Weights

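The optimizer owns an FP32 copy of every parameter. At the top of each step, an FP16 copy is materialized from it for the forward and backward pass. After backward, gradients are unscaled and applied to the master copy, not to the FP16 copy. The now-stale FP16 copy is discarded, and a fresh one is materialized from the updated master at the start of the next step. PyTorch's autocast achieves the same effect implicitly: the model's parameters stay in FP32, and each fast-list op casts its inputs to the narrow dtype on the fly.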

Memory-wise this sounds expensive, since we are carrying an extra FP32 tensor. In practice that copy is smaller than the Adam state, which was already FP32: two moments × 4 B/param = 8 B/param, twice the 4 B/param of the master copy. Mixed precision is mostly about activations and bandwidth, not parameters.
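Here is a back-of-the-envelope sketch of the per-parameter training budget for Adam, excluding activations. It assumes gradients are kept in the compute dtype (FP16 in the mixed case); frameworks that keep FP32 gradients add another 2 B/param. The helper `bytes_per_param` is hypothetical.

```python
def bytes_per_param(mixed_precision: bool) -> int:
    """Per-parameter training memory for Adam, excluding activations.

    Assumes gradients live in the compute dtype (FP16 when mixed)."""
    adam_moments = 2 * 4                   # m and v, FP32 in both setups
    if mixed_precision:
        return 2 + 2 + 4 + adam_moments    # FP16 weights + FP16 grads + FP32 master + Adam
    return 4 + 4 + adam_moments            # FP32 weights + FP32 grads + Adam

n_params = 175e9                           # GPT-3 scale
for mixed in (False, True):
    tb = n_params * bytes_per_param(mixed) / 1e12
    print(f"mixed={mixed}: {bytes_per_param(mixed)} B/param, {tb:.1f} TB")
# both lines report 16 B/param and 2.8 TB
```

Under this assumption, both setups cost 16 B/param, which reproduces the 2.8 TB figure for GPT-3. The FP32 master copy is paid for by the halved weights and gradients, so the real savings come from activations and bandwidth.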

The Autocast Cast List

Rather than leave the user to tag every op, PyTorch ships a curated cast list. A non-exhaustive but representative split:

| Kept in FP32 (safe list) | Cast to FP16/BF16 (fast list) |
|---|---|
| softmax, log_softmax | matmul, mm, bmm, addmm |
| layer_norm, group_norm | conv1d, conv2d, conv3d |
| NLLLoss, CrossEntropyLoss, MSELoss | linear / F.linear |
| log, exp, pow, reciprocal, rsqrt | LSTM / GRU cells |
| sum, prod, cumsum (reductions) | scaled_dot_product_attention |

The pattern is consistent. Ops dominated by large dot products (matmuls, convolutions) run narrow, because their hardware kernels do the accumulation in wider precision internally. Ops that compute transcendentals (exp, log, pow) or reduce over many elements (sum, softmax, norms, losses) stay wide. For these, a single FP16 intermediate could overflow, underflow, or lose the small terms of a long sum.
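To see why long reductions such as sum stay in FP32, here is a minimal NumPy sketch that adds FP16(0.1) ten thousand times. It does this once with an FP16 running sum and once with an FP32 one. The sketch models a naive sequential accumulator, not how any particular GPU kernel (or NumPy's own `np.sum`) orders its additions.

```python
import numpy as np

vals = np.full(10_000, 0.1, dtype=np.float16)   # 10,000 copies of FP16(0.1)

acc16 = np.float16(0.0)
for v in vals:
    acc16 = acc16 + v            # float16 + float16 -> float16: rounded every step

acc32 = np.float32(0.0)
for v in vals:
    acc32 = acc32 + np.float32(v)   # same values, FP32 running sum

print(float(acc16), float(acc32))   # acc16 stalls at 256.0; acc32 lands near 999.76
```

Once the FP16 running sum reaches 256, the spacing between neighbouring FP16 values is 0.25. Adding 0.1 therefore rounds straight back to 256, and every later term is silently dropped. This is the same write-back failure as Crisis 3, repeated thousands of times.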


From Scratch: Mixed Precision in NumPy

NumPy has no autograd, no tensor cores, and no GPU, but it has the exact same IEEE-754 FP16 we are trying to understand. That is all we need to reproduce two of the three crises (underflow and the swallowed update) and their fixes by hand. Each annotated line below shows what the computer is holding after that line runs.

Mixed Precision — Hand-Simulated in NumPy
🐍mixed_precision_demo.py
1import numpy as np

NumPy is our microscope for IEEE-754 arithmetic. It exposes native FP32, FP16 and FP64 dtypes, and we can quantize a value by casting it through the narrower type — exactly the effect a real GPU tensor core would have when it stores an intermediate in FP16.

EXECUTION STATE
numpy = Gives us np.float32, np.float16, np.finfo (format limits), and round-to-nearest-even casts. We don't need autograd here — we're just studying the number system.
3def to_fp16(x) → np.float32

A helper that quantizes x through FP16 and returns the result as FP32. We quantize (lose precision) but then return in FP32 so the rest of the arithmetic is exact — this is exactly what hardware does when an FP16 tensor is loaded into an FP32 accumulator.

EXECUTION STATE
⬇ input: x = Any number or array of numbers in FP32 (or Python float). Will be rounded to the nearest FP16 value that exists.
📚 np.float16(x) = Casts x to IEEE half-precision, using round-to-nearest-even on the mantissa. Values whose magnitude rounds beyond 65504 become ±inf (there is no clamping). Anything below about half the smallest FP16 subnormal (≈5.96e-8) becomes 0.
📚 .astype(np.float32) = Widens the already-rounded value back to FP32. No new information is added — this just lets us add/multiply precisely from here on.
⬆ returns = An np.float32 holding the value that FP16 can actually represent closest to x. Think of this as 'x after one round-trip through an FP16 register'.
4Docstring: "Quantize x through FP16…"

Reminds the reader that the function is a round-trip. It mirrors what autocast does when it downgrades an activation from FP32 to FP16 for matmul and then writes it back into memory.

5return np.float16(x).astype(np.float32)

Two casts chained. The first performs the lossy rounding; the second converts the bit pattern to FP32 for downstream arithmetic.

EXECUTION STATE
→ example = to_fp16(3.14159265) → 3.140625 (FP16 has 10 mantissa bits; for values in [2, 4) the ULP is 2⁻⁹ ≈ 0.00195)
→ another example = to_fp16(1e-8) → 0.0 (below min subnormal — rounded to zero)
7# ---------- CRISIS 1 — underflow ----------

A section header. The first of three failure modes we will reproduce: gradients that are too small to be represented in FP16 quietly become zero.

8grad_fp32 = np.float32(1.0e-8)

Pretend the backward pass produced a gradient of magnitude 1×10⁻⁸. This is realistic for deep networks late in training or in layers close to the input — a common regime where FP16 training silently dies.

EXECUTION STATE
grad_fp32 = 1e-08 (well-represented in FP32: its smallest normal is ≈1.18e-38)
→ why this value? = The smallest positive subnormal of FP16 is ≈5.96e-8, and anything below about half of it (≈2.98e-8) rounds to zero. We pick 1e-8 to sit safely under that cliff.
9grad_fp16 = to_fp16(grad_fp32)

Round-trip the gradient through FP16. Since 1e-8 < 5.96e-8 (FP16's smallest representable positive value), it rounds to exactly 0. This is the canonical underflow problem that motivates loss scaling.

EXECUTION STATE
⬇ arg: grad_fp32 = 1e-08
⬆ grad_fp16 = 0.0 ← the gradient DISAPPEARED
→ consequence = Autograd multiplies this zero into every upstream operation. The corresponding weights get exactly zero update — the layer stops learning.
10# grad_fp32 = 1e-08 grad_fp16 = 0.0 <-- disappeared

Comment showing the punch-line. Whenever a layer's weights never budge, or the loss plateaus for no visible reason, underflow is suspect #1. NaN losses, by contrast, usually point to overflow.

12# ---------- FIX 1 — loss scaling ----------

Section header. The fix is simple: multiply the loss (and therefore every gradient) by a large constant S before backprop, then divide the gradients back by S before the optimizer step. This shifts the whole gradient histogram into FP16's sweet spot.

13S = np.float32(2**15) # loss scale = 32768

S is the loss scale. We use a power of two so that multiplying by S is exact in floating point: it only shifts the exponent, and no mantissa bits change. We start at 2¹⁵ = 32768 for this example; PyTorch's GradScaler defaults to 2¹⁶ = 65536.

EXECUTION STATE
📚 np.float32(...) = Wraps the Python int 32768 in a NumPy FP32 scalar. Using FP32 here keeps S precise even when combined with FP16 operands.
⬆ S = 32768.0
→ why a power of two? = Multiplying by 2ⁿ moves the exponent field by n and leaves the mantissa untouched. As long as the result stays in range, the product is exactly grad × S, so the multiplication itself adds no rounding error.
→ why 2¹⁵? = Gradients in a transformer near convergence often sit around 1e-5…1e-7. Multiplying by 2¹⁵ pushes that band to roughly 3e-1…3e-3, comfortably inside FP16's normal range.
14scaled = to_fp16(grad_fp32 * S) # ≈ 3.28e-4, now in FP16 range

Multiply the FP32 gradient by 32768 THEN quantize to FP16. The product 1e-8 × 32768 = 3.28e-4 lives squarely inside the FP16 normal range, so the cast is virtually lossless.

EXECUTION STATE
grad_fp32 * S = 1e-08 × 32768 = 3.2768e-4 (FP32 value before the cast)
⬆ scaled = 3.275871e-04 (nearest FP16 value — tiny rounding, big survival)
→ checkpoint = The gradient information is preserved. Without scaling, line 9 turned it into 0; with scaling, it became a healthy 3.28e-4.
15recovered = scaled / S # unscale in FP32

Divide by S to recover the true gradient magnitude. The division happens in FP32, so we don't introduce FP16 rounding here. The result is almost equal to grad_fp32 — the tiny discrepancy is the FP16 mantissa rounding error, not a vanished gradient.

EXECUTION STATE
/ (division operator) = Scalar division in FP32. Because S is a power of two, the division only shifts the exponent, so it is exact and adds no new rounding error.
⬆ recovered = 9.997166e-09 (vs. original 1.00e-08) — relative error ≈ 3e-4, introduced entirely by the single FP16 step at line 14.
→ intuition = Loss scaling trades a few bits of mantissa rounding (acceptable) for keeping the gradient off the zero-floor (essential). The unscale-in-FP32 step is what avoids a second round-trip loss.
17# ---------- CRISIS 2 — precision loss ----------

Section header for the second failure mode we reproduce (Crisis 3 in the discussion above, the disappearing weight update). Even if you successfully recover a gradient, adding a tiny lr·grad to a weight of order ~1 in FP16 can still fail, not by overflow but by being rounded away.

18W = np.float32(1.0)

A representative weight value. Many late-training transformer weights have magnitudes in the 0.1–1 range.

EXECUTION STATE
W = 1.0
→ FP16 ULP at 1.0 = 2⁻¹⁰ = 0.0009766. The gap between 1.0 and the next representable FP16 value is about 1e-3. Any update smaller than half of that (≈5e-4) rounds to 1.0 exactly.
19lr = np.float32(1e-3)

Learning rate. 1e-3 is typical for Adam-style optimizers in transformers.

EXECUTION STATE
lr = 0.001
20grad = np.float32(5.0e-2)

A moderately sized gradient — comfortably within FP16's representable range. This scenario is the tricky one: the grad itself is fine, the problem is what happens when we multiply by lr and add to W.

EXECUTION STATE
grad = 0.05
21update = lr * grad # = 5e-5

The actual change we want to apply to W. 0.001 × 0.05 = 5e-5 — well below FP16's ULP at 1.0 (≈1e-3). This is the silent killer.

EXECUTION STATE
update = 5e-5
→ math sanity check = 5e-5 < 5e-4 = ½·ULP_FP16(1.0), so rounding to FP16 ON TOP of 1.0 will drop it.
23W_after_fp32 = W + update # -> 1.00004995

Do the update in FP32. The sum is 1.00005 — the tiny change is fully preserved. FP32's ULP at 1.0 is 1.19e-7, so 5e-5 sits ~420 ULPs away from 1.0 — lots of headroom.

EXECUTION STATE
W + update = 1.00004995
24W_after_fp16 = np.float16(W + update)

Do the same sum and cast the result to FP16. The FP32 sum is 1.00005, but quantizing to FP16 snaps it back to the nearest representable value — which is exactly 1.0. The update vanishes.

EXECUTION STATE
np.float16(1.00004995) = 1.0 ← swallowed
→ why? = FP16's representable values near 1.0 are {…, 0.99951, 1.0, 1.00098, …}. The spacing is 2⁻¹¹ just below 1.0 and 2⁻¹⁰ just above it. 1.00004995 is closer to 1.0 than to 1.00098, so round-to-nearest picks 1.0.
25# cast to FP16 rounds back to exactly 1.0 -- the update was swallowed

The take-away. Even though we computed the correct update, storing W in FP16 destroyed it during the write-back. We would need to keep W itself in FP32 to escape this.

27# ---------- FIX 2 — FP32 master weight + FP16 compute copy ----------

Section header for the fix: hold the weights in FP32 as the permanent source of truth, and only materialize an FP16 copy at the top of each forward pass.

28W_master = np.float32(1.0) # source of truth lives in FP32

The master weight. The optimizer only ever updates this tensor; FP16 tensors are strictly derived copies that can be discarded after each step.

EXECUTION STATE
W_master = 1.0
29for step in range(3):

Three training steps. We'll see the master weight move visibly while an FP16-only weight would not.

LOOP TRACE · 3 iterations
step = 0
W_fp16 = to_fp16(W_master) = to_fp16(1.0) = 1.0
W_master -= lr*grad = 1.0 - 1e-3 × 0.05 = 0.99995
step = 1
W_fp16 = to_fp16(W_master) = to_fp16(0.99995) = 1.0 (nearest FP16 gridpoint; the spacing just below 1.0 is 2⁻¹¹ ≈ 4.9e-4)
W_master -= lr*grad = 0.99995 - 5e-5 = 0.99990
step = 2
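W_fp16 = to_fp16(W_master) = to_fp16(0.99990) = 1.0 (still the nearest gridpoint; the copy only moves once W_master drops below the midpoint 0.999756, after five updates)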
W_master -= lr*grad = 0.99990 - 5e-5 = 0.99985
30W_fp16 = to_fp16(W_master) # forward/backward use this copy

At the top of each step we cast the master weight to FP16. That FP16 copy is what participates in matmuls on the tensor cores. It's OK that this copy is quantized — we only ever READ from it, never update it.

EXECUTION STATE
→ key insight = The FP16 copy is disposable. Whatever precision is lost in the W_master→W_fp16 cast shows up as a slightly noisy forward pass, which the optimizer absorbs. But the master is never corrupted.
31W_master = W_master - lr * grad # update done in FP32 (precise)

The optimizer update runs in FP32. The FP32 ULP at 1.0 is ~1e-7, so updates of size 5e-5 are preserved cleanly. Over many steps the tiny changes accumulate into something large.

EXECUTION STATE
lr * grad = 1e-3 × 0.05 = 5e-5 (done in FP32 — lossless)
→ after 3 steps = 1.0 → 0.99995 → 0.99990 → 0.99985 (real movement)
33# W_master after 3 steps: 0.99984998

The master weight has moved by ~1.5e-4 across three steps. This is exactly the behavior you want.

34# Without master (FP16-only): 1.00000000 <-- no learning at all

For contrast: if we stored W itself in FP16 and did W = np.float16(W - lr*grad) every step, the update would be swallowed every time and W would stay pinned at 1.0 forever. The master-weight pattern is not an optimization — it is a correctness requirement for FP16 training.

```python
import numpy as np

def to_fp16(x):
    """Quantize x through FP16 and return as FP32 for arithmetic."""
    return np.float16(x).astype(np.float32)

# ---------- CRISIS 1 — underflow ----------
grad_fp32 = np.float32(1.0e-8)
grad_fp16 = to_fp16(grad_fp32)
# grad_fp32 = 1e-08      grad_fp16 = 0.0   <-- disappeared

# ---------- FIX 1 — loss scaling ----------
S = np.float32(2**15)                # loss scale = 32768
scaled    = to_fp16(grad_fp32 * S)   # ≈ 3.28e-4, now in FP16 range
recovered = scaled / S               # unscale in FP32

# ---------- CRISIS 2 — precision loss ----------
W    = np.float32(1.0)
lr   = np.float32(1e-3)
grad = np.float32(5.0e-2)
update = lr * grad                   # = 5e-5

W_after_fp32 = W + update            # -> 1.00004995
W_after_fp16 = np.float16(W + update)
# cast to FP16 rounds back to exactly 1.0 -- the update was swallowed

# ---------- FIX 2 — FP32 master weight + FP16 compute copy ----------
W_master = np.float32(1.0)           # source of truth lives in FP32
for step in range(3):
    W_fp16   = to_fp16(W_master)     # forward/backward use this copy
    W_master = W_master - lr * grad  # update done in FP32 (precise)

# W_master after 3 steps: 0.99984998
# Without master (FP16-only): 1.00000000  <-- no learning at all
```

Three numerical facts that drop out of the simulation:

  • Underflow is total, not graceful. 1e-8 cast to FP16 is not "a tiny positive number" — it is exactly 0.0. Whatever downstream computation depended on it is dead.
  • Loss scaling is (almost) lossless. Multiplying by 2¹⁵ and dividing back recovered the gradient within ~3×10⁻⁴ relative error. That is one FP16 rounding, not cumulative decay.
  • Master-weight-free FP16 training silently stalls. The final comment in the listing shows that an FP16-only weight does not move at all over three steps. This is not a tuning issue: whenever an update is smaller than half an FP16 ULP of the weight, the number system simply cannot hold it.
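The listing states the FP16-only result only as a comment. The minimal sketch below actually runs both variants for 1,000 steps with the same update as the listing (lr · grad ≈ 5e-5). It also records when the FP16 compute copy first moves. The variable names are ours.

```python
import numpy as np

update = np.float32(1e-3) * np.float32(5.0e-2)   # lr * grad from the listing, about 5e-5

w_fp16_only = np.float16(1.0)    # weight stored and updated in FP16
w_master = np.float32(1.0)       # FP32 master weight
first_copy_move = None           # step at which the FP16 compute copy first leaves 1.0

for step in range(1000):
    # FP16-only: the update is computed, then rounded away on write-back.
    w_fp16_only = np.float16(np.float32(w_fp16_only) - update)
    # Master-weight pattern: update in FP32, derive the FP16 copy from it.
    w_master = np.float32(w_master - update)
    if first_copy_move is None and np.float16(w_master) != 1.0:
        first_copy_move = step

print(float(w_fp16_only))   # 1.0
print(float(w_master))      # close to 0.95
print(first_copy_move)      # 4, i.e. after the fifth update
```

The FP16 copy lags behind the master. It stays at 1.0 until the master has accumulated enough updates to cross 0.999756, the midpoint between 1.0 and its lower FP16 neighbour, and then snaps to 0.99951. The FP16-only weight never gets there.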

Dynamic Loss Scaling — A From-Scratch Simulator

The GradScaler that PyTorch's FP16 recipe relies on is not magic: its core growth/backoff rule fits in roughly 40 lines of Python. Below is a port of that rule in pure NumPy. It is followed by a 20-step simulation with one deliberate overflow at step 10, so you can watch $S$ double, halve, and recover in real numbers.

Dynamic Loss Scaler — NumPy Port of GradScaler.update()
🐍dynamic_loss_scaler.py
1import numpy as np

The only dependency. We reuse NumPy's FP32 arithmetic and its np.isfinite / np.any / np.all reductions. These are plain-Python counterparts of the checks PyTorch's GradScaler performs in fused GPU kernels.

EXECUTION STATE
numpy = Provides np.float32, np.random.randn, np.isfinite (checks each element is neither inf nor NaN), np.any, np.all, np.abs.
3class DynamicLossScaler:

Our stand-in for torch.amp.GradScaler. It holds a single mutable number S (the loss scale) plus a counter, and exposes one method — step(raw_grad) — that decides whether the optimizer should run this iteration.

4Docstring: "Minimal port of PyTorch's GradScaler.update() semantics."

Advertises scope: we implement only the growth/backoff logic, not the full wrap-around-loss-and-call-backward plumbing. In PyTorch, that plumbing is the standard scaler.scale(loss).backward(), scaler.step(optimizer), scaler.update() sequence.

6def __init__(self, init_scale=2**15, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):

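Constructor whose growth_factor, backoff_factor and growth_interval defaults match torch.amp.GradScaler. PyTorch's init_scale default is 2¹⁶; we start at 2¹⁵ to match the previous example. Every argument has a precise role: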

EXECUTION STATE
⬇ arg: init_scale=2**15 = Starting S = 32768. Powers of two keep multiplication exact, because only the exponent field moves. Multiplied by the ~1e-4 gradients used below, it gives values of order 1 to 10, well inside FP16's normal range.
⬇ arg: growth_factor=2.0 = When things are calm, S doubles. Keeping the factor at 2.0 (a power of two) means the new S is still exactly representable — no drift.
⬇ arg: backoff_factor=0.5 = On overflow, S is halved immediately. Also a power of two for the same reason.
⬇ arg: growth_interval=2000 = How many consecutive clean steps must pass before we try growing S. The real default is 2000; we use 5 in the simulation below so growth events are visible in 20 steps.
8self.S = float(init_scale)

Store S as a Python float (which is FP64 under the hood — doesn't matter, we only hold one scalar). Using float() unwraps any NumPy/torch wrappers the caller might have passed.

EXECUTION STATE
self.S (initial) = 32768.0
9self.growth_factor = growth_factor

Save the growth multiplier. Used inside step() when clean_steps hits the growth interval.

10self.backoff_factor = backoff_factor

Save the backoff multiplier. Applied immediately on any overflow.

11self.growth_interval = growth_interval

Save the growth interval so step() can compare clean_steps against it.

12self.clean_steps = 0

Counter: how many consecutive non-overflowing steps we have seen. Resets to 0 on any overflow OR on a successful growth event.

EXECUTION STATE
self.clean_steps = 0
14def step(self, raw_grad):

The one method that does real work. Given a raw (unscaled) gradient tensor, return either the recovered gradient (if clean) or None (if overflow was detected — caller should skip the optimizer step).

EXECUTION STATE
⬇ input: raw_grad = An np.ndarray of FP32 gradients, as would come out of backward() before any scaling. In the simulation these are ~1e-4 magnitude.
⬆ returns = np.ndarray (FP32) recovered grad — OR — None (caller skips optimizer.step()).
15# Would-be scaled gradient in FP16 range

Comment: this line is where the histogram shift happens. Multiplying by S is what lifts tiny grads above FP16's underflow cliff.

16scaled = raw_grad * self.S

Element-wise multiply. If raw_grad has magnitude ~1e-4 and S = 32768, then scaled has magnitude ~3.3 — well inside FP16's normal band.

EXECUTION STATE
raw_grad * self.S (step 0 sample) = randn() × 1e-4 × 32768 ≈ 3.3 × randn()
⬆ scaled (step 10, after ×1e6 spike) = randn() × 100 × 131072 ≈ 1.3e7 × randn(); S has already doubled twice by then. This trips the >65504 check below.
17# Detect overflow as PyTorch does — any non-finite element

Comment explaining the next line. PyTorch performs the equivalent check on the GPU (torch._amp_foreach_non_finite_check_and_unscale_), setting a found-inf flag if any gradient element is inf or NaN.

18overflow = not np.all(np.isfinite(scaled))

np.isfinite returns a boolean array (True where the value is neither inf nor NaN). np.all collapses it. If ANY element is non-finite, np.all is False, and overflow becomes True.

EXECUTION STATE
📚 np.isfinite(x) = Element-wise: True for a normal/subnormal/zero finite number, False for ±inf or NaN. Works on any NumPy dtype.
📚 np.all(bool_array) = True iff every element is True. Cost is O(n) in the number of elements.
⬆ overflow (clean step) = False
19# Also check FP16's representable max

Comment: on a real GPU the overflow is detected because the value actually became inf when cast to FP16. Here we are simulating, so we check the boundary 65504 by hand.

20if np.any(np.abs(scaled) > 65504):

Belt-and-suspenders overflow check for the simulation. np.abs gives magnitudes; the comparison yields a boolean array; np.any is True if even one element exceeds 65504 (FP16's max finite value).

EXECUTION STATE
📚 np.any(bool_array) = True if at least one True. Complement of np.all.
→ trigger (step 10) = scaled values reach ~1.3e7 after the ×1e6 spike (S is 131072 by then) → condition True → overflow becomes True.
21overflow = True

Only reached when the boundary check fires. Latches overflow regardless of what the isfinite check said.

23if overflow:

Branch on the overflow flag. This is where the growth/backoff asymmetry lives.

24self.S *= self.backoff_factor # halve on overflow

Immediate halving. Example: S = 32768 × 0.5 = 16384 after one overflow. If more overflows follow, S keeps halving — this is the mechanism that hunts for a safe scale automatically.

EXECUTION STATE
→ step 10 effect = S: 131072.0 → 65536.0 (halved)
25self.clean_steps = 0

Reset the clean-step counter so we require another full growth_interval of clean steps before attempting to grow again.

26return None # signal: skip optimizer step

Returning None is our in-band "skip" signal. PyTorch's GradScaler behaves equivalently. When inf/NaN gradients are found, scaler.step(optimizer) skips optimizer.step() and the parameters stay put. scaler.update() then multiplies S by backoff_factor.

27else:

Clean-step branch. No inf/NaN and no >65504 elements — the step is safe to apply.

28self.clean_steps += 1

Increment the consecutive-clean-step counter.

EXECUTION STATE
→ after step 0 = clean_steps = 1
→ after step 4 = clean_steps = 5 (hits growth_interval)
29if self.clean_steps >= self.growth_interval:

Time to try doubling S. In the simulation growth_interval = 5, so this fires every 5 clean steps. In production the default is 2000, so S grows only after 2000 consecutive clean steps. In a stable run, S typically creeps up, hits an occasional overflow, and backs off again.

30self.S *= self.growth_factor

Double S. If no overflows follow for another growth_interval steps, we'll double again — S exponentially hunts the largest safe scale.

EXECUTION STATE
→ step 4 effect = S: 32768.0 → 65536.0 (doubled)
31self.clean_steps = 0

Reset the counter so the NEXT growth event requires another full interval. Without this reset S would double on every single step after the first success.

32return scaled / self.S # recover true gradient in FP32

Unscale the gradient so the optimizer sees the mathematically correct value. The division runs in FP32 (we never stored scaled in FP16 in this simulation), so no extra rounding is introduced by the unscale step.

EXECUTION STATE
⬆ returned grad (step 0) = scaled / 32768.0 ≈ raw_grad (within FP32 rounding)
34# Simulate: 20 steps, one deliberate overflow at step 10

Section header for the driver loop.

35np.random.seed(0)

Deterministic gradient stream so the numbers in this trace are reproducible. Seeding before the loop — not inside — means all 20 calls to randn draw from the same stream.

36scaler = DynamicLossScaler(init_scale=2**15, growth_interval=5)

Instantiate with a short growth interval (5 instead of the real default 2000) so we can watch S double several times in only 20 steps.

EXECUTION STATE
scaler.S = 32768.0
scaler.growth_interval = 5
scaler.clean_steps = 0
Line 37: `for step in range(20):`

Run 20 steps. Step 10 is the deliberately bad one. Watch S double at steps 4, 9 and 15 (each time clean_steps reaches 5) and halve at step 10.

LOOP TRACE · 5 iterations
step = 0 (first clean step)
g = randn(4) × 1e-4 ≈ [1.76e-4, 4.0e-5, 9.8e-5, 2.24e-4]
scaled = g × 32768 ≈ [5.78, 1.31, 3.21, 7.34] (all < 65504)
overflow = False
scaler.S (after) = 32768.0 (unchanged)
clean_steps = 1
return = scaled / 32768 ≈ g (ok)
step = 4 (fifth clean step → triggers growth)
clean_steps before ++ = 4
clean_steps after ++ = 5
5 >= growth_interval(5) = True
scaler.S = 32768.0 × 2.0 = 65536.0 (doubled!)
clean_steps reset = 0
step = 9 (fifth clean step AGAIN → doubles again)
scaler.S before = 65536.0
scaler.S after = 131072.0 (= 2¹⁷)
clean_steps = 0
step = 10 (OVERFLOW — g is multiplied by 1e6)
g (after ×1e6) = randn(4) × 1e2 ≈ [O(100)]
scaled = g × 131072 = ~1.3e7 (far above 65504)
np.any(|scaled| > 65504) = True
overflow = True
scaler.S = 131072.0 × 0.5 = 65536.0 (halved!)
clean_steps = 0
return = None (caller skips optimizer.step)
step = 15 (fifth clean step AFTER the overflow → grow back)
scaler.S before = 65536.0
scaler.S after = 131072.0 (growing again)
Line 38: `g = np.random.randn(4).astype(np.float32) * 1e-4`

Draw 4 standard-normal samples, cast them to FP32 to match a realistic training tensor, and scale them to roughly 1e-4, a plausible magnitude for small gradients late in training.

EXECUTION STATE
📚 np.random.randn(4) = Shape (4,) array of N(0,1) samples. Uses the stream seeded on line 35, so each step draws different values deterministically.
📚 .astype(np.float32) = NumPy's randn defaults to FP64; we narrow to FP32 so the arithmetic mirrors a real GPU tensor.
Line 39: `if step == 10:`

Set up the deliberate overflow.

Line 40: `g *= 1e6  # force overflow`

Scale the gradient by a million. Original magnitude 1e-4 × 1e6 = 100 → scaled (×131072 because S just doubled twice) = ~1.3e7, which trips the >65504 branch. This simulates a real training pathology: a bad batch, a diverging layer, or a numerical instability in a custom op.

Line 41: `out = scaler.step(g)`

Feed the gradient through the scaler. For clean steps this returns the unscaled gradient; for step 10 it returns None.

EXECUTION STATE
→ out (clean step) = np.ndarray ≈ g
→ out (step 10) = None ← skip optimizer
Line 42: `action = "skip" if out is None else "ok"`

A tiny inline-if for printing. Matches PyTorch's mental model: either the optimizer.step went through (ok) or it was skipped (skip).

Line 43: `print(f"step {step:02d}  S={scaler.S:>10.0f}  clean={scaler.clean_steps}  {action}")`

Format string with :02d zero-padding, :>10.0f right-aligned no-decimals for S. Output walks the S trajectory: starts at 32768, doubles at steps 4 and 9 to 131072, halves at step 10 back to 65536, then climbs again.

EXECUTION STATE
⬆ expected output (first 12 lines) =
step 00  S=     32768  clean=1  ok
step 01  S=     32768  clean=2  ok
step 02  S=     32768  clean=3  ok
step 03  S=     32768  clean=4  ok
step 04  S=     65536  clean=0  ok    ← first double
step 05  S=     65536  clean=1  ok
step 06  S=     65536  clean=2  ok
step 07  S=     65536  clean=3  ok
step 08  S=     65536  clean=4  ok
step 09  S=    131072  clean=0  ok    ← second double
step 10  S=     65536  clean=0  skip  ← overflow, halved
step 11  S=     65536  clean=1  ok
import numpy as np

class DynamicLossScaler:
    """Minimal port of PyTorch's GradScaler.update() semantics."""

    def __init__(self, init_scale=2**15, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.S = float(init_scale)
        self.growth_factor   = growth_factor
        self.backoff_factor  = backoff_factor
        self.growth_interval = growth_interval
        self.clean_steps     = 0

    def step(self, raw_grad):
        # Would-be scaled gradient in FP16 range
        scaled = raw_grad * self.S
        # Detect overflow as PyTorch does: any non-finite element
        overflow = not np.all(np.isfinite(scaled))
        # Also check FP16's representable max
        if np.any(np.abs(scaled) > 65504):
            overflow = True

        if overflow:
            self.S *= self.backoff_factor   # halve on overflow
            self.clean_steps = 0
            return None                     # signal: skip optimizer step
        else:
            self.clean_steps += 1
            if self.clean_steps >= self.growth_interval:
                self.S *= self.growth_factor
                self.clean_steps = 0
            return scaled / self.S          # recover true gradient in FP32

# Simulate: 20 steps, one deliberate overflow at step 10
np.random.seed(0)
scaler = DynamicLossScaler(init_scale=2**15, growth_interval=5)
for step in range(20):
    g = np.random.randn(4).astype(np.float32) * 1e-4
    if step == 10:
        g *= 1e6                             # force overflow
    out = scaler.step(g)
    action = "skip" if out is None else "ok"
    print(f"step {step:02d}  S={scaler.S:>10.0f}  clean={scaler.clean_steps}  {action}")
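The simulation above has one quirk worth flagging: on a growth step, `self.S *= self.growth_factor` runs before `return scaled / self.S`, so the gradient handed back on that step is half the true value. PyTorch's GradScaler does not have this problem, because scaler.step() unscales the gradients before scaler.update() changes S. Below is a minimal sketch of a corrected step, assuming plain Python lists instead of NumPy arrays so it runs with the standard library alone; `FixedLossScaler` is a hypothetical name, not a PyTorch class.

```python
import math

FP16_MAX = 65504.0  # largest finite FP16 value


class FixedLossScaler:
    """Hypothetical variant of DynamicLossScaler that unscales BEFORE changing S."""

    def __init__(self, init_scale=2.0**15, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.S = float(init_scale)
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.clean_steps = 0

    def step(self, raw_grad):
        scaled = [g * self.S for g in raw_grad]
        overflow = any(not math.isfinite(s) or abs(s) > FP16_MAX for s in scaled)
        if overflow:
            self.S *= self.backoff_factor   # halve and skip, exactly as before
            self.clean_steps = 0
            return None
        # Unscale with the S that produced `scaled`, THEN update S for the next step
        unscaled = [s / self.S for s in scaled]
        self.clean_steps += 1
        if self.clean_steps >= self.growth_interval:
            self.S *= self.growth_factor
            self.clean_steps = 0
        return unscaled


scaler = FixedLossScaler(growth_interval=5)
grad = [1e-4, -3e-5]
for step in range(5):
    out = scaler.step(grad)
print(scaler.S)       # 65536.0: S doubled on the fifth clean step
print(out == grad)    # True: the growth step still returns the true gradient
```

The only change is ordering: the unscale uses the scale that produced `scaled`, and S is updated afterwards, which mirrors PyTorch's split between step() and update().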

Now the PyTorch counterpart. The only new pieces are the import path and scaler.get_scale() for introspection: every argument of the constructor maps directly onto a field of our NumPy class, and scaler.update() is the method that implements the growth/backoff logic we just traced by hand.

PyTorch GradScaler — The Production Counterpart
🐍pytorch_gradscaler.py
Line 1: `import torch`

Core PyTorch. `torch.amp`, which provides GradScaler, is one of its submodules.

Line 2: `from torch.amp import GradScaler`

The unified (PyTorch ≥ 2.4) import path for the dynamic loss scaler. In legacy code you may still see `torch.cuda.amp.GradScaler` — same class, same algorithm.

EXECUTION STATE
📚 GradScaler = The C++/Python class that implements EXACTLY the growth/backoff logic you just traced by hand. Internally holds a scale tensor (FP32, on device) and a growth tracker.
Line 4: `scaler = GradScaler(device='cuda', init_scale=2.**15,`

Construct the scaler. Each argument maps 1:1 to a field on our NumPy DynamicLossScaler.

EXECUTION STATE
⬇ arg: device='cuda' = Pins the internal scale tensor to the CUDA device so no host↔device transfer is needed when it multiplies the loss.
⬇ arg: init_scale=2.**15 = Starting S = 32768.0. The `.` makes `2.**15` a Python float; the scaler copies it into an FP32 tensor on the device, created lazily on first use.
Line 5: `growth_factor=2.0, backoff_factor=0.5,`

Identical to our NumPy constants. They are stored on the scaler and read inside scaler.update().

EXECUTION STATE
⬇ arg: growth_factor=2.0 = Same doubling behavior. Non-power-of-two values would work but would slowly drift S's mantissa — best left at 2.0.
⬇ arg: backoff_factor=0.5 = Same halving behavior. Applied by scaler.update() when the previous step was skipped.
Line 6: `growth_interval=2000)`

Production default: 2000 consecutive clean steps are needed to trigger a doubling, so S usually settles early in training and then changes only rarely.

EXECUTION STATE
⬇ arg: growth_interval=2000 = PyTorch's tuned default. Larger values make S conservative (fewer overflow-and-backoff cycles) but adapt slowly to regime changes.
→ vs NumPy sim = Our simulation used 5 so growth was visible in 20 steps. The real default of 2000 would mean the scale never doubles in a short demo — visually boring, but correct for production.
Line 8: `print("scale at start:", scaler.get_scale())`

scaler.get_scale() returns the current S as a plain Python float (synchronizes with the device tensor). Exactly analogous to reading self.S in our NumPy port.

EXECUTION STATE
📚 scaler.get_scale() = Reads the device-resident scale tensor and returns its scalar value. Handy for logging; don't call it every step in hot loops (forces a CUDA sync).
⬆ expected print = scale at start: 32768.0
Line 9: `# The skeleton of a step (see pytorch_amp_step.py for the full loop):`

Cross-reference to the full AMP training step, pytorch_amp_step.py, later in this section. The full training step is scaler.scale(loss).backward() → scaler.step(optimizer) → scaler.update().

Line 10: `# with autocast(...): loss = ...`

Inside the autocast region the forward pass runs in FP16/BF16. The loss comes back in FP32 because mse_loss, like most loss functions, is on autocast's FP32 ("safe") list.

Line 11: `# scaler.scale(loss).backward()`

scaler.scale(loss) returns loss × S (lossless because S is a power of two). .backward() propagates the scaled scalar through the graph, so every .grad tensor lands in memory with the same multiplicative factor S baked in.

Line 12: `# scaler.step(optimizer)`

Internally: (1) divide every .grad by S in FP32, (2) scan for inf/NaN, (3) if clean → call optimizer.step(); if dirty → skip. This is where our NumPy step(raw_grad) returning None maps directly to a skipped optimizer.step.

Line 13: `# scaler.update()`

THE method that implements the growth/backoff logic we just traced by hand. If step() saw inf/NaN, update() halves S. Otherwise it increments a counter and doubles S once the counter reaches growth_interval. Calling it once per training step is mandatory.

EXECUTION STATE
📚 scaler.update() = Exact C++ analog of the if/else in our NumPy step(): on overflow S *= backoff_factor; else clean_steps += 1 and maybe S *= growth_factor.
Line 14: `# print("scale after step:", scaler.get_scale())`

Logging pattern. In production you'd log scaler.get_scale() to TensorBoard/WandB to watch S stabilize over the first thousand steps of training.

import torch
from torch.amp import GradScaler

scaler = GradScaler(device='cuda', init_scale=2.**15,
                    growth_factor=2.0, backoff_factor=0.5,
                    growth_interval=2000)

print("scale at start:", scaler.get_scale())
# The skeleton of a step (see pytorch_amp_step.py for the full loop):
# with autocast(...): loss = ...
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()
# print("scale after step:", scaler.get_scale())
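Two claims in this walkthrough, that scaling rescues gradients which would otherwise underflow and that multiplying by a power of two is lossless, can be checked without a GPU. Python's `struct` module packs and unpacks IEEE-754 half precision with the `'e'` format, which makes it a convenient FP16 emulator. A small sketch; `to_fp16` is a helper defined here, not a library function:

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE-754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


tiny_grad = 1e-8                     # below FP16's smallest subnormal (~5.96e-8)
S = 2.0**15                          # the init_scale used above

print(to_fp16(tiny_grad))            # 0.0: the unscaled gradient underflows
scaled = to_fp16(tiny_grad * S)      # ~3.28e-4, comfortably representable
print(scaled / S)                    # ~1e-8 recovered (within FP16 rounding)

# Multiplying by a power of two only shifts the exponent, so it is exact:
v = to_fp16(0.1)
print(to_fp16(v * 1024) / 1024 == v)  # True
```

This is also why growth_factor and backoff_factor are best left at 2.0 and 0.5: starting from a power of two, S stays a power of two, and scaling and unscaling never add rounding error of their own.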

The Same Thing in PyTorch: autocast + GradScaler

Production code doesn't do the casts by hand: PyTorch's AMP (automatic mixed precision) module wires everything up for us. The short pattern below is the canonical PyTorch AMP training step. nanoGPT, for example, follows essentially this skeleton, and frameworks such as Megatron-LM that ship their own loss scaler still run the same scale, backward, unscale-and-check, step, update sequence. Every step in the skeleton maps directly to something we did by hand above.

PyTorch AMP — The Canonical Training Step
🐍pytorch_amp_step.py
Line 1: `import torch`

Core PyTorch: tensors, autograd, the CUDA stream scheduler, and the `.cuda()` device helpers we use on the very next few lines.

EXECUTION STATE
torch = The root module. We'll use torch.float16, torch.float32, torch.randn (creates tensors), and torch.optim (optimizers).
Line 2: `import torch.nn as nn`

torch.nn holds the Module classes. We only need nn.Linear (a weight matrix + bias) and nn.MSELoss (mean-squared-error) here.

Line 3: `from torch.amp import autocast, GradScaler`

The two entry points to PyTorch's Automatic Mixed Precision. `autocast` is a context manager that replaces eligible ops with their FP16 (or BF16) kernels on the fly. `GradScaler` tracks and adjusts the loss scale between training steps.

EXECUTION STATE
📚 autocast = Context manager. Inside `with autocast(...):`, PyTorch consults a cast-list per op: matmul/conv → FP16, softmax/layernorm → FP32, etc. Ops outside the block are untouched.
📚 GradScaler = Stateful helper that (1) multiplies the loss by a scale S before backward, (2) unscales gradients before optimizer.step, (3) checks for inf/NaN and skips the step if overflow happened, (4) increases S when things are calm and halves it when overflow occurs.
→ import path note = `torch.amp` is the unified API (PyTorch ≥ 2.4). Older code imported from `torch.cuda.amp` — same objects, legacy path still works.
Line 5: `# ---------- Model, data, optimizer (all FP32 by default) ----------`

Setup section. The important fact is that the MODEL WEIGHTS are FP32 the whole time. Autocast does not change stored parameters — it only changes the dtype of intermediate computations.

Line 6: `model = nn.Linear(512, 512).cuda()`

A single linear layer for demo. Constructor creates two parameters: `weight` of shape (512, 512) and `bias` of shape (512). `.cuda()` moves them to the default GPU.

EXECUTION STATE
📚 nn.Linear(in, out) = Weight W ∈ R^{out×in}, bias b ∈ R^{out}. Forward: y = x W^T + b.
→ weight dtype = torch.float32 — default for nn.Linear. Stays FP32 for the entire training run.
→ parameter count = 512 × 512 + 512 = 262,656 params × 4 B = 1.05 MB weights
Line 7: `optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)`

Adam keeps two FP32 buffers per parameter (the first and second moments, m and v). These are optimizer state, not master weights: in this setup the FP32 master weights are simply the model's own parameters, which autocast never downcasts.

EXECUTION STATE
📚 Adam(...) = Adaptive moment estimation. Maintains m (EMA of grads) and v (EMA of grad²). Update: W ← W − lr · m̂ / (√v̂ + ε).
→ memory cost = Adam adds 8 bytes per parameter (2 × FP32), twice the 4 bytes of the FP32 weights themselves. That is why optimizer state is often the largest single item in training memory.
Line 8: `loss_fn = nn.MSELoss()`

Mean squared error. Autocast runs mse_loss, like most loss functions and their building blocks (softmax, log, exp), in FP32, because these ops easily overflow or lose precision in FP16.

Line 10: `# GradScaler — dynamic loss scaling controller`

The heart of FP16 training. Without it, gradients much smaller than FP16's smallest subnormal (about 6e-8) flush to zero, and the weights they belong to stop updating.

Line 11: `scaler = GradScaler(device='cuda', init_scale=2.**15)`

Construct the scaler. It starts at S = 32768, doubles after every `growth_interval` = 2000 consecutive clean steps, and halves whenever any gradient becomes inf/NaN. There is no fixed upper limit on S; overflows are what keep it in check.

EXECUTION STATE
⬇ arg: device='cuda' = Tells the scaler which device's tensors it will operate on. In torch.amp.GradScaler it defaults to 'cuda'; the legacy `torch.cuda.amp.GradScaler()` was CUDA-only and took no device argument.
⬇ arg: init_scale=2.**15 = Starting value for S (PyTorch's own default is 2.**16). 32768 is large enough to lift typical small gradients out of FP16's underflow zone (normal values start near 6.1e-5), and the scaler tunes it from there.
→ hidden defaults = growth_factor=2.0, backoff_factor=0.5, growth_interval=2000. Every 2000 clean steps S doubles; any inf/NaN halves S and skips the optimizer step.
Line 13: `# Synthetic batch`

Create random inputs and targets so this snippet is self-contained.

Line 14: `x = torch.randn(64, 512, device='cuda')`

A batch of 64 examples, each 512-dimensional. `device='cuda'` keeps allocation on-GPU.

EXECUTION STATE
📚 torch.randn(...) = Samples from a standard normal N(0,1). Returns a tensor with the requested shape on the requested device.
⬆ x.dtype = torch.float32 (default)
⬆ x.shape = torch.Size([64, 512])
Line 15: `y = torch.randn(64, 512, device='cuda')`

Targets. Same shape and dtype as x.

Line 17: `# ---------- One training step with mixed precision ----------`

Everything above was setup. The next seven statements are the entire mixed-precision training step.

Line 18: `optimizer.zero_grad()`

Reset the gradients of every parameter. PyTorch accumulates gradients, so without this the next backward() would add to whatever was there from the previous step.

EXECUTION STATE
📚 optimizer.zero_grad() = Clears .grad on every parameter in every param group. Since PyTorch 2.0 the default is set_to_none=True, which sets .grad to None rather than zeroing it in place (cheaper). Must happen BEFORE the backward that produces new grads.
Line 20: `with autocast(device_type='cuda', dtype=torch.float16):`

Open the autocast region. Every eligible op inside this block will be dispatched to its FP16 kernel. The block does two things under the hood: (1) wrap tensor arguments in an implicit .to(FP16) cast when needed, (2) register an op-type → target-dtype mapping for the duration of the block.

EXECUTION STATE
⬇ arg: device_type='cuda' = The device whose autocast rules to use. 'cuda' uses CUDA's cast-list; 'cpu' uses CPU's (which defaults to BF16, not FP16). 'xpu' and 'hpu' exist for other accelerators.
⬇ arg: dtype=torch.float16 = Target narrow dtype. On Ampere and newer this can be torch.bfloat16: same autocast machinery and cast lists, and loss scaling is usually unnecessary.
→ autocast cast-list (partial) = → FP16: matmul, mm, bmm, addmm, linear, conv1d/2d/3d, LSTMCell/GRUCell → FP32: softmax, log_softmax, layer_norm, group_norm, mse_loss, l1_loss, nll_loss, log, exp, pow, sum
Line 21: `out = model(x)  # matmul runs in FP16 on tensor cores`

The forward pass. `model(x)` is `x @ W^T + b`, but because we're inside autocast, the matmul is dispatched to the FP16 tensor-core kernel. Inputs get implicit downcasts, the multiply runs at FP16, and the accumulator inside the kernel is FP32 to preserve precision.

EXECUTION STATE
→ x dtype inside autocast = Still torch.float32 externally, but when fed into `F.linear`, autocast downcasts the matmul inputs to FP16 on the fly.
→ out dtype = torch.float16 — the matmul output comes back at the narrow dtype. This is the expected behavior.
→ tensor-core throughput = On A100: FP16/BF16 matmul delivers up to 312 TFLOPS vs 19.5 TFLOPS for FP32 — a 16× hardware speed-up.
Line 22: `loss = loss_fn(out, y)  # MSELoss output stays FP32 (safe list)`

Even though `out` is FP16, MSELoss is on autocast's FP32-only list. The losses are computed by upcasting inputs and returning an FP32 scalar. This is crucial: a loss in FP16 might overflow, and backprop needs a numerically precise scalar to start from.

EXECUTION STATE
→ loss.dtype = torch.float32
→ why? = MSELoss internally does `(out - y)**2`.mean(). Squaring FP16 can easily exceed 65504 (e.g. out − y = 300 → squared = 90000). Autocast's safety net sends it to FP32.
Line 24: `scaler.scale(loss).backward()  # multiply loss by S, then backprop`

Left to right: `scaler.scale(loss)` returns `loss * S`, still FP32 because both the loss and the scale are FP32. `.backward()` runs reverse-mode autodiff from this scaled scalar. By the chain rule, every gradient in the graph is also multiplied by S, which is exactly the histogram shift we need.

EXECUTION STATE
📚 scaler.scale(loss) = Returns a new tensor `loss * S`. `S` is whatever the scaler's current scale is (32768 initially). Multiplication by a power of two is exact, so no rounding error.
📚 .backward() = Reverse mode autodiff. Walks the graph from this scalar. Because we scaled the root, every ∂loss/∂θ along the way is also scaled by S.
→ what becomes of the grads? = After this call, `p.grad` for each parameter p holds S × (true gradient). These are SCALED grads — you must NOT run optimizer.step on them directly.
Line 25: `scaler.step(optimizer)  # unscale grads, skip step if inf/NaN`

Does three jobs in one call. (1) Unscale every .grad by dividing by S (in FP32). (2) Check for inf/NaN in the unscaled grads. (3) If clean, call optimizer.step(); if not, skip the step entirely — because a step with garbage grads would poison the master weights.

EXECUTION STATE
📚 scaler.step(opt) = Internally: unscale the grads and check them for inf/NaN (one fused pass), then either call opt.step() or skip it. If the step is skipped, the next line (scaler.update) halves S so the next step runs with a smaller scale.
→ anatomy of the skip = If any grad is ±inf or NaN, the step is a no-op and the parameters stay put for that iteration. Occasional skips, especially early in training while S is still being tuned, are expected and no cause for alarm.
Line 26: `scaler.update()  # adjust S for next step (x2 or ÷2)`

Dynamic loss scaling. If the last step had no inf/NaN, a success counter increments; once `growth_interval` successful steps pass, S *= 2. If the last step had inf/NaN, S *= 0.5 immediately. This feedback loop means you don't need to tune the scale by hand — it finds its level.

EXECUTION STATE
→ steady state = After the early adjustments S usually settles at a level large enough to rescue small gradients and small enough to avoid overflow; the exact value depends on the model and data.
→ debugging tip = If `scaler.get_scale()` keeps shrinking toward 0, something is producing inf/NaN regardless of the scale: typically a division by zero or log(0) in a custom layer, or an FP16 overflow in the forward pass (activations above 65504), rather than a problem with the loss scaling itself.
Line 28: `print(loss.dtype, out.dtype, model.weight.dtype)`

Sanity check. Confirms the dtype choreography: loss is FP32 (safe), out is FP16 (fast), weight is FP32 (precise master copy). If any of these is surprising, your autocast region is wider or narrower than you think.

EXECUTION STATE
⬆ expected output = torch.float32 torch.float16 torch.float32
Line 29: `# torch.float32   torch.float16   torch.float32`

The three dtypes summarize the entire mixed-precision contract: fast narrow compute (FP16), precise accumulators and losses (FP32), and permanent master weights (FP32).

import torch
import torch.nn as nn
from torch.amp import autocast, GradScaler

# ---------- Model, data, optimizer (all FP32 by default) ----------
model     = nn.Linear(512, 512).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn   = nn.MSELoss()

# GradScaler: dynamic loss scaling controller
scaler = GradScaler(device='cuda', init_scale=2.**15)

# Synthetic batch
x = torch.randn(64, 512, device='cuda')
y = torch.randn(64, 512, device='cuda')

# ---------- One training step with mixed precision ----------
optimizer.zero_grad()

with autocast(device_type='cuda', dtype=torch.float16):
    out  = model(x)          # matmul runs in FP16 on tensor cores
    loss = loss_fn(out, y)   # MSELoss output stays FP32 (safe list)

scaler.scale(loss).backward()    # multiply loss by S, then backprop
scaler.step(optimizer)           # unscale grads, skip step if inf/NaN
scaler.update()                  # adjust S for next step (x2 or ÷2)

print(loss.dtype, out.dtype, model.weight.dtype)
# torch.float32   torch.float16   torch.float32
Swap dtype=torch.float16 for dtype=torch.bfloat16, and on Ampere or newer you can usually drop the GradScaler entirely: BF16 has the same 8-bit exponent, and therefore the same range, as FP32, which makes loss scaling unnecessary. That three-line simplification is a big part of why BF16 has become the default.
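A stdlib-only sketch of the range difference behind that advice. FP16 is emulated with `struct`'s `'e'` format; BF16 is emulated by rounding an FP32 bit pattern to its top 16 bits (round-to-nearest-even), which is how BF16 relates to FP32. Both helpers are illustrative, not PyTorch functions, and NaN inputs are not handled:

```python
import math
import struct


def to_fp16(x):
    """Round to IEEE half precision; values beyond the FP16 range become +/-inf."""
    try:
        return struct.unpack('<e', struct.pack('<e', x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)


def to_bf16(x):
    """Round an FP32 value to bfloat16 (nearest-even on the top 16 bits)."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack('<f', struct.pack('<I', bits))[0]


for value in (1e-8, 1e5):   # a tiny gradient and a large activation
    print(f"{value:g}: fp16={to_fp16(value):g}  bf16={to_bf16(value):g}")
```

The tiny gradient survives in BF16 but flushes to zero in FP16, and the large value overflows FP16 while BF16 merely rounds it (to 99840). That is the pattern that makes loss scaling necessary for FP16 and usually unnecessary for BF16.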

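Training Memory Footprint

How much memory do you actually save? The honest answer is "less than you think for weights, a lot for activations." With FP32 master weights, the parameters, gradients and Adam state still cost about 16 bytes per parameter, so a 7B-parameter model needs roughly 112 GB for that state alone, more than a single 80-GB GPU holds; mixed precision saves its memory mainly on activations, which are stored in half precision. Fitting a 7B model on one 80-GB GPU therefore takes a leaner recipe, such as keeping weights, gradients and optimizer state all in BF16 (about 8 bytes per parameter, 56 GB before activations), or sharding the state across GPUs.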
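A back-of-envelope version of that comparison. The byte counts per parameter below are assumptions for illustration (weights, gradients and Adam's two moments; activations, buffers and fragmentation excluded), and the recipe names are made up for this sketch:

```python
# Bytes per parameter for weights + gradients + Adam state, activations excluded.
# Recipe definitions are illustrative assumptions, not measurements.
RECIPES = {
    # FP32 weights (4) + FP32 grads (4) + Adam m and v in FP32 (8)
    "fp32": 4 + 4 + 8,
    # FP16/BF16 weights (2) + FP16/BF16 grads (2) + FP32 master copy (4) + FP32 m, v (8)
    "amp_master_fp32": 2 + 2 + 4 + 8,
    # Everything, including both Adam moments, stored in BF16
    "pure_bf16": 2 + 2 + 2 + 2,
}


def training_state_gb(n_params, recipe):
    """Memory for parameters, gradients and optimizer state, in GB (1e9 bytes)."""
    return n_params * RECIPES[recipe] / 1e9


for name in RECIPES:
    print(f"7B params, {name:16s}: {training_state_gb(7e9, name):6.1f} GB")
```

Note that the mixed-precision recipe with FP32 master weights costs the same 16 bytes per parameter as pure FP32 here; its savings show up in the activations, which this sketch leaves out.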


Connections to Modern Transformer Systems

Every major efficiency trick in modern transformer architectures is, directly or indirectly, a mixed-precision story. Here is how the concept ripples outward.

Flash Attention — Mixed Precision at the Kernel Level

Flash Attention (Dao et al., 2022) was the watershed optimization for long-context transformers. Its headline trick, tiling the attention computation to keep intermediates in on-chip SRAM instead of writing the full $N \times N$ attention matrix to HBM, is a memory story. Its quieter but equally important trick is mixed precision inside the kernel.

The inputs $Q, K, V$ arrive in FP16 or BF16. Inside the kernel, the matmul $Q K^{T}$ runs on tensor cores at that narrow dtype with FP32 accumulation, and the running max and the softmax denominator are maintained in FP32. This is Crisis 2 avoidance at the hardware level: in FP16, $\exp$ overflows for any pre-softmax score above $\ln(65504) \approx 11.09$, and even with the running max subtracted, a denominator summing thousands of terms would lose precision in a half format. Keeping these statistics in FP32 is why Flash Attention matches the accuracy of standard attention while running at half-precision throughput.
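The running-max bookkeeping can be sketched in a few lines of Python. This is a simplified, one-query, scalar-value version of the online softmax that FlashAttention-style kernels use, with Python floats (FP64) standing in for the kernel's FP32 running max, denominator and accumulator. Tiling is over the key dimension only, and the function names are hypothetical:

```python
import math


def online_softmax_average(scores, values, tile=4):
    """Streaming sum_j softmax(scores)_j * values_j, processed tile by tile."""
    m, l, acc = -math.inf, 0.0, 0.0               # running max, denominator, numerator
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        rescale = math.exp(m - m_new)             # 0.0 on the first tile (m = -inf)
        weights = [math.exp(s - m_new) for s in s_tile]   # each <= 1, cannot overflow
        l = l * rescale + sum(weights)
        acc = acc * rescale + sum(w * v for w, v in zip(weights, v_tile))
        m = m_new
    return acc / l


def reference_softmax_average(scores, values):
    """Non-streaming version: full max subtraction, then one normalization."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


scores = [3.0, 12.5, -1.0, 15.0, 0.5, 9.0, 11.2, 2.0, 14.1]  # several exceed ln(65504)
values = [0.1, -0.4, 2.0, 0.7, 1.5, -2.2, 0.3, 0.9, -0.6]
print(online_softmax_average(scores, values))
print(reference_softmax_average(scores, values))
```

Because every exponent is taken after subtracting the running max, each weight is at most 1, so the scores above 11 that would overflow a naive FP16 `exp` are harmless; earlier partial sums are simply rescaled by `exp(m - m_new)` whenever a larger score appears.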

Multi-Head Attention — Where Softmax Stays FP32

The canonical multi-head attention formula is $\text{head}_h = \text{softmax}\!\left(\frac{Q_h K_h^{T}}{\sqrt{d_k}}\right) V_h$. Look at where each dtype lives in a BF16 training setup:

  • $Q_h, K_h, V_h$: BF16 projections of the input (matmul runs on tensor cores).
  • $Q_h K_h^{T}$: BF16 matmul with an FP32 accumulator internally, result written back to BF16.
  • $\sqrt{d_k}$: a scalar constant; dividing a BF16 tensor by a Python scalar keeps the result in BF16.
  • softmax: always FP32. Even with the max-subtraction (log-sum-exp) trick, the exponentials and especially their running sum lose too much precision with BF16's 8 significant bits (7 stored mantissa bits), and without that trick FP16's exponentials overflow for scores above about 11. This is a direct consequence of Crisis 2.
  • $\text{softmax}(\cdot)\, V_h$: back to a BF16 matmul.

That softmax-in-FP32 choice is exactly what PyTorch's autocast cast-list enforces for you (the LayerNorm that precedes the QKV projections — Section 1 — is what gives us the unit-variance input that makes this softmax-in-FP32 economical). You never wrote it; the safe list did.
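Why keep the softmax reduction in FP32 even though BF16 has enough range? A quick sketch of a narrow accumulator: round the running sum after every addition, as a kernel without a wide accumulator would. BF16 is emulated by rounding FP32 bit patterns and FP16 with `struct`'s `'e'` format; both helpers are illustrative:

```python
import struct


def to_bf16(x):
    """Round an FP32 value to bfloat16 (nearest-even on the top 16 bits)."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack('<f', struct.pack('<I', bits))[0]


def to_fp16(x):
    """Round to the nearest IEEE-754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


def running_sum(values, rounder):
    total = 0.0
    for v in values:
        total = rounder(total + v)   # round after every add: a narrow accumulator
    return total


ones = [1.0] * 1000                  # e.g. 1000 softmax numerators equal to exp(0) = 1
print(running_sum(ones, to_bf16))    # 256.0: adding 1 to 256 rounds back to 256
print(running_sum(ones, to_fp16))    # 1000.0: FP16 is exact for integers up to 2048
print(running_sum(ones, float))      # 1000.0
```

With 8 significant bits, BF16 cannot represent 257, so once the running sum reaches 256 every further 1 rounds away. A softmax denominator over a long sequence is exactly this kind of sum.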

Positional Encodings — A Quiet FP32 Holdout

Sinusoidal positional encodings are computed once at startup and added to the token embeddings as $PE(p, 2i) = \sin(p / 10000^{2i/d})$ and $PE(p, 2i+1) = \cos(p / 10000^{2i/d})$. The denominator $10000^{2i/d}$ spans from 1 up to nearly $10000 \approx 2^{13.3}$ for a large $d$ such as $d = 12288$ (GPT-3's hidden size), but the real hazard is the angle itself: for $i = 0$ it equals the position $p$, which runs into the thousands. FP16 spaces representable numbers 2 apart near $p = 2048$ (4 apart near 4096) and BF16 16 apart, so $\sin$ and $\cos$ of a half-precision angle can be badly wrong. Production code computes $PE$ in FP32 and then casts to the activation dtype, so the precision loss only happens once, at the end.

Rotary (RoPE) and ALiBi encodings have the same discipline: compute the rotation/bias in FP32 and apply it to BF16 activations. The common pattern is "construct in the widest precision, apply in the narrowest".
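A minimal sketch of the lowest-frequency sinusoid (i = 0, where the angle is just the position p), showing what happens when the angle itself is rounded to half precision before calling sin, versus computing in full precision and casting only the result. The FP16/BF16 rounding helpers are illustrative emulations, not library calls:

```python
import math
import struct


def to_fp16(x):
    """Round to the nearest IEEE-754 half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


def to_bf16(x):
    """Round to bfloat16 by keeping the top 16 bits of the FP32 pattern (nearest-even)."""
    bits = struct.unpack('<I', struct.pack('<f', x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack('<f', struct.pack('<I', bits))[0]


p = 2049.0                                # token position; for i = 0 the angle is p itself
print("wide angle       :", math.sin(p))
print("fp16 angle       :", math.sin(to_fp16(p)), "angle ->", to_fp16(p))
print("bf16 angle       :", math.sin(to_bf16(p)), "angle ->", to_bf16(p))
print("wide, cast at end:", to_fp16(math.sin(p)))
```

Both half formats round the angle 2049 to 2048, which moves the sine from about 0.63 to about -0.31; computing in full precision and casting only the result keeps the error down to FP16's rounding of a value below 1.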

KV-Cache — Inference Memory is a Precision Problem

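Autoregressive inference caches the keys and values from every past token so each new token only needs one new attention step. For a model with $L$ layers and $H$ key/value heads of size $d_k$, the KV-cache for one sequence of length $N$ stores $2 \cdot L \cdot H \cdot d_k \cdot N \cdot \text{bytes/elem}$ bytes, times the batch size. For a LLaMA-65B-style model with full multi-head attention ($L = 80$, $H = 64$, $d_k = 128$), one 4K-token sequence takes about 10.7 GB in FP16, so a batch of 16 such sequences (about 172 GB) already outweighs the roughly 130 GB of FP16 weights. (LLaMA-2-70B uses grouped-query attention with 8 KV heads, which cuts this by 8×.) Serving stacks therefore work very hard on KV-cache precision: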

  • FP16/BF16 is the starting point. Halves the cache compared to FP32.
  • FP8 KV-cache is supported by TensorRT-LLM and vLLM on H100. It stores K and V in E4M3 or E5M2 with scale factors, halving memory and bandwidth again, usually with only a small accuracy loss on downstream benchmarks.
  • INT8 KV-cache (supported, for example, by TensorRT-LLM) also uses one byte per element, storing integers plus scale factors; it is an alternative to FP8, notably on GPUs without FP8 support.

The tradeoff is identical to training's: narrower storage means more tokens per GPU second, at the cost of small numerical drift that teams measure with perplexity and task-level evals.
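The cache-size formula is easy to turn into a calculator. A sketch, with the architecture numbers as stated assumptions (80 layers, head size 128, 4K tokens; 64 KV heads for full multi-head attention versus 8 for grouped-query attention), and `kv_cache_bytes` a helper defined here:

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem, batch=1):
    """2 (K and V) * L * H_kv * d_k * N * bytes per element, times the batch size."""
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_elem * batch


# Assumed LLaMA-65B-style full multi-head attention: 80 layers, 64 heads of size 128
mha_fp16 = kv_cache_bytes(80, 64, 128, 4096, bytes_per_elem=2)
# Assumed grouped-query attention with 8 KV heads
gqa_fp16 = kv_cache_bytes(80, 8, 128, 4096, bytes_per_elem=2)
gqa_fp8 = kv_cache_bytes(80, 8, 128, 4096, bytes_per_elem=1)

for name, b in [("MHA fp16", mha_fp16), ("GQA fp16", gqa_fp16), ("GQA fp8", gqa_fp8)]:
    print(f"{name}: {b / 1e9:6.2f} GB per 4K-token sequence")
```

Halving bytes_per_elem (FP16 to FP8 or INT8) halves the cache exactly, which is why KV-cache precision is one of the first knobs serving stacks turn.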

FP8 and the Scaling Frontier

Hopper (H100, 2022) introduced FP8 tensor cores and two formats: E4M3 (4 exponent bits, 3 mantissa bits, max ≈ 448) for forward activations, and E5M2 (5 exponent bits, 2 mantissa bits, max ≈ 57344) for gradients, where range matters most. NVIDIA's Transformer Engine library automates the casts and manages per-tensor scale factors the way GradScaler manages a single global scale.

With only 3 mantissa bits in E4M3, the spacing between representable values just above 1.0 (the ULP) is 0.125, an enormous precision hit. FP8 therefore demands per-tensor scaling: each tensor gets its own multiplier, chosen so that the tensor's largest magnitude maps near the top of FP8's representable range and as little of the tensor as possible falls into the coarse subnormal region or underflows. The scales themselves are FP32 and computed from recent tensor statistics.

For example, if an activation tensor's observed amax (max absolute value) is 6.4 and we are targeting FP8 E4M3 (max representable 448), the per-tensor scale factor is $s = 448 / 6.4 = 70$. We store $s \cdot x$ in FP8 and keep $s$ (as FP32) alongside the tensor; at read time we divide by $s$. NVIDIA's Transformer Engine, in its delayed-scaling recipe, derives $s$ from a history of amax values observed on recent iterations rather than from the current tensor alone, much as BatchNorm relies on running statistics. The cost is one extra FP32 scale per tensor (4 bytes) and one multiply or divide per read, trivial compared with the memory and bandwidth saved by storing the bulk tensor in 1-byte FP8.
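A minimal sketch of per-tensor scaling into E4M3, in plain Python. The E4M3 rounding here is an emulation assuming the "FN" variant used on H100 (3 mantissa bits, smallest normal exponent -6, largest finite value 448, saturating instead of producing infinity). The helper names are hypothetical, and a real stack would use hardware FP8 types and amax statistics gathered over recent iterations rather than the current tensor's amax:

```python
import math

E4M3_MAX = 448.0   # largest finite E4M3 value (1.75 * 2**8)


def round_e4m3(x):
    """Round to the nearest E4M3 value (3 mantissa bits), saturating at +/-448."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    e = max(math.floor(math.log2(a)), -6)   # -6 is E4M3's smallest normal exponent
    step = 2.0 ** (e - 3)                   # spacing between neighbours at this exponent
    q = min(round(a / step) * step, E4M3_MAX)
    return math.copysign(q, x)


def fp8_roundtrip(tensor):
    """Scale by 448 / amax, round to E4M3, then dequantize on read."""
    amax = max(abs(v) for v in tensor)
    s = E4M3_MAX / amax                           # per-tensor scale, kept in high precision
    stored = [round_e4m3(v * s) for v in tensor]  # what would live in 1-byte FP8
    return s, [q / s for q in stored]


acts = [6.4, -3.1, 0.5, 0.02]
s, recovered = fp8_roundtrip(acts)
print(s)           # the per-tensor scale for amax = 6.4
print(recovered)
```

The scale of 70 maps the largest activation onto 448, and every value comes back within E4M3's relative rounding error of 1/16; without the scale, a value like 1e-4 would simply round to zero.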

Hardware peak TFLOPS ratios (e.g., 2× FP8 over BF16 on H100 per NVIDIA's Hopper white paper) do not translate 1:1 to end-to-end training throughput — real-world gains are typically 1.5–2× vs BF16, limited by memory bandwidth, data loading, and non-matmul layers.

The upshot: every generation of the mixed-precision story is the same pattern at finer granularity. FP32 training used one precision everywhere. FP16 AMP introduced one global scale (the GradScaler). FP8 training introduces per-tensor scales. The logical endpoint — per-group scales like those used by MXFP8 and fine-grained INT4 quantization in inference — is already shipping.


Tradeoffs at a Glance

| Recipe | Speed vs FP32 | Memory vs FP32 | Numerical gotchas | When to pick it |
|---|---|---|---|---|
| FP32 baseline | 1× | 1× | None; numerics are trivial. | Research prototypes on small models, or when debugging a divergence. |
| FP16 AMP + GradScaler | ~2–2.5× | ~0.6–0.7× | Underflow if the loss scale is too low; overflow if too high. GradScaler tunes it. | Older hardware (V100, T4); legacy codebases. |
| BF16 AMP | ~2–2.5× | ~0.6–0.7× | Coarser mantissa, so slightly noisier convergence. No loss scaling needed. | Modern default for large-model training (A100, H100, TPU). |
| FP8 (E4M3/E5M2) | ~1.5–2× vs BF16 in practice (hardware peak ratio larger) | ~0.5× | Requires per-tensor scaling; narrow safety margin. | Very large models where memory and throughput dominate. |
| FP4 / INT4 (inference only) | Hardware- and kernel-dependent (native FP4 tensor cores arrive with Blackwell) | ~0.125× for weights, before scale overhead | Group-wise quantization; accuracy loss on sensitive layers. | Serving frontier: weight quantization, not training. |

Summary

Mixed precision training is a big part of why a 7-billion-parameter model can run on one GPU and a 175-billion-parameter model can be trained on a single cluster. But it is not a magic switch; it is a careful division of labor between number formats.

  • Compute narrow, store wide. FP16/BF16 on the tensor cores for speed, FP32 master weights and optimizer state for correctness.
  • Use loss scaling when range is tight. FP16's 5-bit exponent cannot reach the small tail of typical gradient magnitudes, so shift the histogram up by a scale $S$. Skip loss scaling with BF16 because its 8-bit exponent already covers FP32's range.
  • Keep softmax, norms, and losses in FP32. These ops overflow or lose precision in half formats. The autocast cast-list encodes the hard-won knowledge of which ops are safe where.
  • Every modern transformer optimization is a precision story. Flash Attention accumulates softmax in FP32. Multi-head attention keeps its softmax in FP32. KV-cache quantization is precision-traded-for-memory. FP8/Transformer Engine replaces one global scale with per-tensor scales.
If FP32 was the era of "just use floats," the modern era is the era of "use the narrowest representation that survives the math of your op, and not a bit narrower." Mastering mixed precision is mastering exactly that decision on every tensor in your network.