Why Precision Became a Bottleneck
For most of the modern deep-learning era, neural-network math was synonymous with a single number format: IEEE 754 single precision, also called FP32. Every weight, every activation, and every gradient was a 32-bit floating-point number. It was simple, numerically safe, and, for AlexNet-sized networks on 2012 hardware, entirely practical.
Then the models grew. GPT-3 has 175 billion parameters. In FP32 (4 bytes each) the weights alone take 700 GB, enough to fill nine A100-80GB GPUs just to store the model. During training you also need gradients (another 700 GB) and the Adam optimizer state (two FP32 buffers per parameter, another 1.4 TB). That is 2.8 TB before a single activation has been computed.
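The arithmetic above is easy to reproduce. A minimal sketch, assuming 4 bytes per FP32 value, Adam's two moment buffers, and decimal gigabytes (10⁹ bytes); the function name is ours:

```python
FP32_BYTES = 4

def fp32_adam_training_bytes(num_params: int) -> dict:
    """Bytes for weights, gradients and Adam state when everything is FP32 (activations excluded)."""
    weights = num_params * FP32_BYTES
    grads = num_params * FP32_BYTES
    adam_state = 2 * num_params * FP32_BYTES   # first and second moments
    return {"weights": weights, "grads": grads, "adam": adam_state,
            "total": weights + grads + adam_state}

gpt3 = fp32_adam_training_bytes(175_000_000_000)
for name, nbytes in gpt3.items():
    print(f"{name:>7}: {nbytes / 1e9:,.0f} GB")

# Ceiling division: how many 80 GB cards the FP32 weights alone occupy.
print("A100-80GB cards for the weights alone:", -(-gpt3["weights"] // 80_000_000_000))
```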
At the same time NVIDIA's tensor cores, starting with Volta in 2017, began to offer enormous speed-ups for narrower number formats: roughly 8× peak throughput for FP16 tensor-core math over plain FP32 on V100, and 16× on A100. The question was no longer whether to drop precision but how. Mixed precision training is the discipline that answers it without blowing up the numerics.
Anatomy of a Floating-Point Number
An IEEE-754 float is three fields packed into a bit string: a sign bit $s$, an exponent field $e$, and a mantissa field $m$. For normal numbers the value it encodes is $(-1)^s \times 2^{e-b} \times (1 + m/2^p)$, where $p$ is the number of mantissa bits and $b$ is the exponent bias ($2^{k-1} - 1$ for a $k$-bit exponent). The exponent controls the range (how big or small a number can be), and the mantissa controls the precision (how finely numbers are spaced). The formats that matter for neural networks split their bit budgets differently:
| Format | Sign | Exponent | Mantissa | Range (±) | ULP @ 1.0 | Bytes |
|---|---|---|---|---|---|---|
| FP32 | 1 | 8 | 23 | ~3.4×10³⁸ | 1.2×10⁻⁷ | 4 |
| FP16 | 1 | 5 | 10 | ~6.5×10⁴ | 9.8×10⁻⁴ | 2 |
| BF16 | 1 | 8 | 7 | ~3.4×10³⁸ | 7.8×10⁻³ | 2 |
| FP8 (E4M3) | 1 | 4 | 3 | ±448 | 0.125 | 1 |
The key insight is that FP16 and BF16 are not the same 16-bit format — they chose opposite sides of the range-vs-precision trade. FP16 kept FP32's precision philosophy and paid the price in range. BF16 kept FP32's range (same 8 exponent bits) and paid the price in precision.
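The range and ULP columns follow directly from the field widths. The sketch below (helper names are ours) assumes IEEE-754-style conventions, where the all-ones exponent is reserved for ∞ and NaN. That holds for FP32, FP16, BF16 and FP8 E5M2, but not for E4M3, which reclaims most of that exponent code: the formula gives E4M3 a maximum of 240 rather than its actual 448. decode() evaluates the value formula above for finite bit patterns; 0x4248 is the FP16 pattern nearest to π.

```python
def format_stats(exp_bits: int, man_bits: int) -> dict:
    """Derive range and precision of an IEEE-754-style binary format from its field widths."""
    bias = 2 ** (exp_bits - 1) - 1
    return {
        "max": (2 - 2.0 ** -man_bits) * 2.0 ** bias,      # largest finite value
        "min_normal": 2.0 ** (1 - bias),
        "min_subnormal": 2.0 ** (1 - bias - man_bits),   # smallest positive value
        "ulp_at_1": 2.0 ** -man_bits,                    # gap between 1.0 and the next value
    }

def decode(bits: int, exp_bits: int, man_bits: int) -> float:
    """Evaluate (-1)^s * 2^(e-b) * (1 + m/2^p); subnormals (e == 0) use 2^(1-b) * m/2^p."""
    bias = 2 ** (exp_bits - 1) - 1
    s = bits >> (exp_bits + man_bits)
    e = (bits >> man_bits) & ((1 << exp_bits) - 1)
    m = bits & ((1 << man_bits) - 1)
    sign = -1.0 if s else 1.0
    if e == 0:  # subnormal: no implicit leading 1
        return sign * 2.0 ** (1 - bias) * (m / 2 ** man_bits)
    return sign * 2.0 ** (e - bias) * (1 + m / 2 ** man_bits)

for name, (e, m) in {"FP32": (8, 23), "FP16": (5, 10), "BF16": (8, 7), "E5M2": (5, 2)}.items():
    st = format_stats(e, m)
    print(f"{name}: max={st['max']:.4g}  ulp@1={st['ulp_at_1']:.3g}  min_sub={st['min_subnormal']:.3g}")

print(decode(0x4248, 5, 10))   # FP16 value nearest to pi
```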
Interactive: The Same Number in Four Formats
Type any number and see exactly which bits it occupies in each format, what value those bits actually decode back to, and where the encoding overflows or underflows. Use the preset buttons to jump to values that highlight specific regimes — for instance, 1e-8 is a typical tiny gradient that underflows FP16 but survives BF16.
Try 1.0 and π first. FP16 and BF16 both round π to 3.140625, but only by coincidence: the next representable value above it is 3.142578125 in FP16 and 3.15625 in BF16, because BF16 has three fewer mantissa bits. Then try 1e-8: FP16 and FP8 both underflow to zero while BF16 keeps the value. The yellow/blue/pink coloring shows the sign/exponent/mantissa regions.
The Great 16-Bit Debate: FP16 vs BF16
Until Ampere arrived in 2020, the default was FP16. Volta GPUs had FP16 tensor cores (the Pascal P100 had fast FP16 arithmetic but no tensor cores), and neither generation supported BF16 natively. Mixed precision recipes were built around FP16, including all the loss-scaling machinery we'll meet shortly.
Starting with Ampere (2020, A100), BF16 became a first-class citizen on NVIDIA GPUs (Google's TPUs had supported it earlier). And a quiet revolution happened: almost every large-model training run switched from FP16 to BF16. PaLM, Chinchilla, LLaMA, and Gemini were all trained primarily in BF16; GPT-3, trained on V100s that lack BF16 support, predates the switch.
Why? Because gradients are the hardest tensor to keep numerically healthy, and gradients care more about range than precision:
- BF16's range is essentially identical to FP32's. The exponent field is 8 bits in both. Apart from FP32's deepest subnormals, a BF16 cast of an FP32 gradient cannot underflow or overflow; it can only be rounded.
- FP16's mantissa is more precise but its range is tiny. A gradient below ~6×10⁻⁵ becomes subnormal, and one below ~3×10⁻⁸ (half the smallest subnormal, 2⁻²⁴ ≈ 6×10⁻⁸) rounds to zero. A gradient above 65504 (strictly, from 65520 up, after rounding) explodes to ±∞.
- BF16 precision is coarser, but optimizers like Adam effectively apply a running average that washes out per-step quantization noise. You lose a few bits of mantissa and nothing else.
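The asymmetry between the two 16-bit formats is easy to check with Python's standard library: struct's 'e' format packs IEEE binary16, and BF16 can be emulated by rounding an FP32 bit pattern to its top 16 bits. A minimal sketch (helper names are ours; NaN inputs and values beyond FP32's range are not handled):

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round to IEEE binary16; struct raises OverflowError past the largest value, so map that to inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

def to_bf16(x: float) -> float:
    """Round to bfloat16: keep the top 16 bits of the FP32 pattern, rounding to nearest even."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round-to-nearest-even on the 16 dropped bits
    return struct.unpack("<f", struct.pack("<I", (bits >> 16) << 16))[0]

for g in (1e-3, 1e-5, 1e-8, 7e4):
    print(f"{g:>8.0e}  fp16={to_fp16(g):.6g}  bf16={to_bf16(g):.6g}")
```

The 1e-5 gradient survives in FP16 only as a subnormal, 1e-8 vanishes, and 7e4 becomes inf, while BF16 keeps all four at roughly 2 to 3 significant digits.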
Rule of thumb: on Ampere-class hardware or newer, default to BF16. Reach for FP16 only when you are targeting older hardware (V100, T4) or have a specific case where FP16's extra precision demonstrably helps — and when you do, budget time for loss-scale tuning.
The Three Numerical Crises of Half-Precision Training
Naive half-precision training — "just cast everything to FP16" — fails in three well-catalogued ways. Each of the fixes we'll meet in the next section exists because of one of these failure modes.
Crisis 1 — Gradient Underflow
FP16 cannot represent positive values smaller than $2^{-24} \approx 5.96\times10^{-8}$ (its smallest positive subnormal). Any gradient smaller than half that value rounds to exactly zero. The corresponding weight receives no update at all, as if the layer were frozen.
The Micikevicius et al. (2018) Mixed Precision Training paper (ICLR) documents that for some deep networks a substantial fraction of gradient values falls below FP16's smallest representable positive value ($2^{-24}$). Their Figure 3 plots a histogram of activation gradients collected while training a Multibox SSD detector in FP32 and shows that, without scaling, many values land in FP16's underflow region and would become zeros.
Crisis 2 — Activation Overflow
The other end of the range cliff. FP16's maximum is 65504. A self-attention block computing $QK^\top$ over long sequences can produce pre-softmax scores in the thousands, while $e^x$ already exceeds 65504 for $x > \ln 65504 \approx 11.1$; in an MSE loss, squaring any activation larger than about 256 exceeds 65504. Once a value becomes $\infty$, it poisons every downstream tensor (any nonzero number × ∞ = ±∞; ∞ − ∞ = NaN; 0 × ∞ = NaN).
This is why autocast's FP32-only list includes softmax, exponentials, reductions, and most losses. Operations whose outputs can grow far beyond their inputs (exp, pow, large sums) stay in FP32 precisely to avoid overflow; matmuls run in FP16 but accumulate in FP32 inside the tensor-core kernel.
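A short sketch of both overflow routes and of how an infinity turns into NaN, using struct's 'e' format as an FP16 stand-in (struct raises OverflowError instead of returning ∞, so the helper, which is ours, maps that to ∞):

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round to IEEE binary16, returning ±inf where the value overflows."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)

FP16_MAX = 65504.0

# Squaring any activation above sqrt(65504) ~= 255.9 overflows FP16.
print(to_fp16(255.0 ** 2), to_fp16(256.0 ** 2))       # 65024.0 inf

# exp(x) exceeds FP16's max once x > ln(65504) ~= 11.09, so a raw score of 12 already overflows.
print(math.log(FP16_MAX), to_fp16(math.exp(12.0)))    # ~11.09 inf

# Once an inf exists, ordinary arithmetic turns it into NaN.
inf = to_fp16(256.0 ** 2)
print(inf - inf, 0.0 * inf)                           # nan nan
```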
Crisis 3 — The Disappearing Weight Update
Suppose we fixed both underflow and overflow. A third, subtler issue remains. Consider a weight $W = 1.0$ and a small Adam update $\Delta W$, say $10^{-4}$.
The FP16 gridpoints just above 1.0 are spaced by $2^{-10} \approx 9.8\times10^{-4}$. Any update smaller than half of that (≈ 4.9×10⁻⁴) rounds back to $W$ itself. So storing $W + \Delta W$ in FP16 gives back exactly $W$: the update was computed correctly but then thrown away during the write-back.
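The write-back can be simulated directly. A minimal sketch, assuming W = 1.0 and a few illustrative update sizes, with struct's 'e' and 'f' formats standing in for FP16 and FP32 storage:

```python
import struct

def to_fp16(x: float) -> float:
    """Round to IEEE binary16 (values here stay far from overflow)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round to IEEE binary32."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

W = 1.0
half_ulp = 2.0 ** -10 / 2        # half the FP16 spacing just above 1.0
print("half-ULP threshold:", half_ulp)

for dW in (1e-4, 4e-4, 5e-4, 1e-3):
    fp16_moved = to_fp16(W + dW) != W   # write-back into an FP16 weight
    fp32_moved = to_fp32(W + dW) != W   # write-back into an FP32 master weight
    print(f"dW={dW:.0e}  FP16 weight moved: {fp16_moved}  FP32 master moved: {fp32_moved}")
```

Updates of 10⁻⁴ and 4×10⁻⁴ sit below the half-ULP threshold and leave the FP16 weight unchanged, while the FP32 master absorbs every one of them.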
The Mixed-Precision Recipe
The complete recipe, first published in Micikevicius et al. 2018, Mixed Precision Training, has four ingredients. Three of them each address one of the three crises above; the remaining one is where the speed-up comes from.
- FP32 master weights. The optimizer keeps a canonical FP32 copy of every parameter. Fixes Crisis 3.
- FP16/BF16 forward & backward. At each step, the master weights are cast to the narrow dtype and used for compute. Cuts compute time and activation memory roughly in half.
- Loss scaling. The loss is multiplied by a scalar $S$ before backward. By the chain rule, every gradient is scaled by $S$. Fixes Crisis 1. (Not needed with BF16.)
- FP32 "safe list". Operations that are known to overflow or lose precision (softmax, layer norm, losses, reductions) are kept in FP32. Fixes Crisis 2.
Loss Scaling — Shifting the Histogram
Loss scaling is the most distinctive trick of FP16 training, and it is surprisingly simple. Let $L$ be the loss. Instead of backpropagating $L$, we backpropagate $S \cdot L$. By linearity of differentiation, every gradient becomes $S \cdot \partial L / \partial \theta$, so the gradient histogram is translated to the right on a log₂ axis by $\log_2 S$ bits. Tiny gradients that would have fallen off FP16's lower cliff now land inside the representable band. After backprop, we divide each gradient by $S$ (in FP32) to get the true gradient back for the optimizer step.
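The histogram shift can be checked numerically. The sketch below assumes a hypothetical set of gradient magnitudes spread log-uniformly from 10⁻¹² to 10⁻¹ (not measured data) and counts how many are lost to underflow or overflow when cast to FP16 after scaling by a factor S:

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round to IEEE binary16, returning inf on overflow."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.inf

# Hypothetical gradient magnitudes, log-uniform from 1e-12 to 1e-1.
grads = [10.0 ** (-12 + 11 * i / 999) for i in range(1000)]

def survival(scale: float) -> tuple:
    """Return (underflow count, overflow count) after scaling every gradient by `scale`."""
    scaled = [to_fp16(g * scale) for g in grads]
    underflow = sum(1 for s in scaled if s == 0.0)
    overflow = sum(1 for s in scaled if math.isinf(s))
    return underflow, overflow

for S in (1.0, 2.0 ** 15, 2.0 ** 24):
    u, o = survival(S)
    print(f"S = {S:>10.0f}: underflow={u:4d}  overflow={o:4d}")
```

With S = 1 roughly 40% of these values flush to zero, S = 2¹⁵ keeps all of them, and S = 2²⁴ pushes the largest ones past 65504, which is the too-low/too-high trade-off a dynamic scaler has to navigate.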
In practice we never pick $S$ by hand. The dynamic loss scaler starts at a hopeful value (PyTorch's GradScaler defaults to 2¹⁶) and adjusts automatically: whenever a gradient overflows to inf or NaN, it halves $S$ and skips the step; after a long run of clean steps (2000 by default in PyTorch), it doubles $S$ to stay aggressive.
Interactive: Loss Scaling in Action
At $S = 1$, most of the synthetic gradient distribution falls below FP16's minimum and gets crushed to zero (red on the left). Push $S$ up to 2¹⁵ and essentially the whole distribution lands in the representable band. Push it too far and overflow kicks in on the right. GradScaler's feedback loop automates exactly this choice of $S$ on every step.
FP32 Master Weights
The optimizer owns an FP32 copy of every parameter. At the top of each step, it materializes an FP16 copy for the forward and backward pass. After backward, gradients are unscaled and applied to the master copy, not to the FP16 copy. The FP16 copy is then discarded and re-materialized from the updated master weights at the next step.
Memory-wise this sounds expensive: we are carrying an extra FP32 tensor. In practice it is dominated by the Adam state, which was already FP32 (two moments × 4 B/param = 8 B/param): twice the size of the FP32 master copy and four times the size of the FP16 weights. Mixed precision is mostly about activations and bandwidth, not parameters.
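A per-parameter tally makes the point concrete. This sketch assumes Adam, half-precision gradients in the mixed recipe (some frameworks keep FP32 gradients instead), and ignores activations entirely:

```python
def bytes_per_param(recipe: str) -> dict:
    """Per-parameter bytes for weights, gradients, master copy and Adam moments (activations excluded)."""
    if recipe == "fp32":
        return {"weights": 4, "grads": 4, "master": 0, "adam_m_v": 8}
    if recipe == "mixed":   # FP16/BF16 working copy + FP32 master + FP32 Adam moments
        return {"weights": 2, "grads": 2, "master": 4, "adam_m_v": 8}
    raise ValueError(f"unknown recipe: {recipe}")

for recipe in ("fp32", "mixed"):
    b = bytes_per_param(recipe)
    total = sum(b.values())
    print(f"{recipe:>5}: {b}  total={total} B/param  7B model: {7e9 * total / 1e9:.0f} GB")
```

Both recipes come to 16 bytes per parameter, about 112 GB for a 7B model before activations, so the savings show up in activations and bandwidth rather than in parameter memory.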
The Autocast Cast List
Rather than leave the user to tag every op, PyTorch ships a curated cast list. A non-exhaustive but representative split:
| Kept in FP32 (safe list) | Cast to FP16/BF16 (fast list) |
|---|---|
| softmax, log_softmax | matmul, mm, bmm, addmm |
| layer_norm, group_norm | conv1d, conv2d, conv3d |
| NLLLoss, CrossEntropyLoss, MSELoss | linear / F.linear |
| log, exp, pow, reciprocal, rsqrt | LSTM / GRU cells |
| sum, prod, cumsum (reductions) | scaled_dot_product_attention |
The pattern is consistent: ops whose output variance scales with input size (matmul reductions, convolutions) run narrow because their hardware kernels do the accumulation in wider precision internally. Ops that compute transcendentals or take log/exp of small quantities stay wide because a single FP16 intermediate would overflow or underflow.
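Why does accumulation precision matter so much? A minimal illustration (not how any real kernel is written): sum 4096 FP16 copies of 0.01, once rounding the running total to FP16 after every addition and once keeping it in a wide Python float. The FP16 total stalls at 32, where half the spacing between neighboring FP16 values (2⁻⁶) exceeds the addend.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to IEEE binary16 (the sums here stay far below the overflow point)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

x = to_fp16(0.01)      # one FP16 input element, ~0.010002
n = 4096

narrow_acc = 0.0
wide_acc = 0.0
for _ in range(n):
    narrow_acc = to_fp16(narrow_acc + x)   # accumulator rounded back to FP16 every step
    wide_acc += x                          # wide accumulator (Python float = FP64)

print(narrow_acc, to_fp16(wide_acc))       # 32.0 40.96875
```

Rounding the wide result to FP16 once at the end loses almost nothing; rounding after every addition loses a fifth of the sum. Tensor-core matmuls avoid the second failure by accumulating in FP32.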
From Scratch: Mixed Precision in NumPy
NumPy has no autograd, no tensor cores, and no GPU — but it has the exact same IEEE-754 FP16 we are trying to understand. That is all we need to reproduce every crisis and every fix by hand.
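A minimal NumPy sketch of all three experiments. The gradient (10⁻⁸), loss scale (2¹⁵) and update size (10⁻⁴) are illustrative assumptions:

```python
import numpy as np

# 1. Underflow: a tiny gradient cast to FP16 becomes exactly zero.
g_true = 1e-8
print("fp16(g):", np.float16(g_true))                 # 0.0

# 2. Loss scaling: scale before the FP16 cast, unscale in FP32 afterwards.
S = 2.0 ** 15
g_scaled_fp16 = np.float16(g_true * S)                # now a normal FP16 number
g_recovered = np.float32(g_scaled_fp16) / np.float32(S)
rel_err = abs(float(g_recovered) - g_true) / g_true
print("recovered:", g_recovered, "relative error:", rel_err)

# 3. Master weights: apply the same small update three times.
update = 1e-4
w_fp16_only = np.float16(1.0)
w_master = np.float32(1.0)
for _ in range(3):
    w_fp16_only = np.float16(np.float32(w_fp16_only) - np.float32(update))  # write-back rounds to FP16
    w_master = np.float32(w_master - np.float32(update))                    # FP32 master keeps the change
print("fp16-only weight:", w_fp16_only)               # 1.0 (never moved)
print("master weight:", w_master, "-> fp16 copy:", np.float16(w_master))
```

The relative error in step 2 comes from a single FP16 rounding, and in step 3 the FP16 copy taken from the master weight does move once the accumulated updates exceed half an ULP.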
Three numerical facts that drop out of the simulation:
- Underflow is total, not graceful. 1e-8 cast to FP16 is not "a tiny positive number" — it is exactly 0.0. Whatever downstream computation depended on it is dead.
- Loss scaling is (almost) lossless. Multiplying by 2¹⁵ and dividing back recovered the gradient within ~3×10⁻⁴ relative error. That is one FP16 rounding, not cumulative decay.
- Master-weight-free FP16 training is strictly broken. Our last print line shows that an FP16-only weight does not move at all over three steps. This is not a tuning issue: the number system simply cannot hold the update.
Dynamic Loss Scaling — A From-Scratch Simulator
The GradScaler we relied on in the previous example is not magic — it is roughly 40 lines of Python. Below is a faithful port of PyTorch's growth/backoff algorithm in pure NumPy, followed by a 20-step simulation with one deliberate overflow at step 10 so you can watch double, halve, and recover in real numbers.
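A minimal sketch of that growth/backoff rule in plain Python (NumPy adds nothing for scalar bookkeeping). The class and the simulated gradients are hypothetical; the constructor argument names and defaults mirror torch.amp.GradScaler (init_scale=65536.0, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000), but this is an illustration, not PyTorch's source. To make doubling visible within 20 steps, the simulation starts at 2¹⁵ and uses growth_interval=3.

```python
import math

class DynamicLossScaler:
    """Growth/backoff loss scaler (a hypothetical teaching class, not PyTorch's implementation)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._clean_steps = 0            # consecutive steps without inf/NaN gradients

    def update(self, grads_are_finite: bool) -> bool:
        """Adjust the scale after one step; return True if the optimizer step should be applied."""
        if not grads_are_finite:
            self.scale *= self.backoff_factor   # overflow: shrink the scale and skip this step
            self._clean_steps = 0
            return False
        self._clean_steps += 1
        if self._clean_steps == self.growth_interval:
            self.scale *= self.growth_factor    # long clean streak: try a larger scale
            self._clean_steps = 0
        return True

def simulate(steps=20, overflow_step=10, growth_interval=3):
    """Run the scaler for `steps` steps with a single injected overflow."""
    scaler = DynamicLossScaler(init_scale=2.0 ** 15, growth_interval=growth_interval)
    history = []
    for step in range(1, steps + 1):
        # Hypothetical scaled gradient: finite on every step except the injected overflow.
        scaled_grad = math.inf if step == overflow_step else 1e-3 * scaler.scale
        applied = scaler.update(math.isfinite(scaled_grad))
        history.append((step, scaler.scale, applied))
    return history

for step, scale, applied in simulate():
    print(f"step {step:2d}  scale=2^{int(math.log2(scale))}  {'applied' if applied else 'skipped (overflow)'}")
```

The trace doubles after every three clean steps, halves and skips at step 10, then resumes growing from the reduced value.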
Now the PyTorch counterpart. The only new pieces are the import path and scaler.get_scale() for introspection: every argument of the constructor maps directly onto a field of our NumPy class, and scaler.update() is the exact method that implements the growth/backoff logic we just traced by hand.
The Same Thing in PyTorch: autocast + GradScaler
Production code doesn't do the casts by hand: PyTorch's AMP (automatic mixed precision) module wires everything up for us. The eight-line pattern below is the canonical training step from PyTorch's AMP documentation, and close variants appear in nanoGPT and many other open-source transformer trainers. Every step in the skeleton maps directly to something we did by hand above.
Swap dtype=torch.float16 for dtype=torch.bfloat16 and, on Ampere or newer, you can usually drop the GradScaler entirely: BF16's FP32-matching range makes loss scaling unnecessary. That three-line simplification is a large part of why BF16 has become the default.
Interactive: Training Memory Footprint
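How much memory do you actually save? The honest answer is "less than you think for weights, a lot for activations." Pick a model size and an activation budget and compare the four training recipes side-by-side. Note that with Adam, both FP32 training and standard mixed precision need about 16 bytes per parameter for weights, gradients, and optimizer state, so a 7B-parameter model needs roughly 112 GB before activations either way; fitting it on one 80-GB GPU takes lower-precision optimizer state or sharding, not just a switch to BF16.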
Connections to Modern Transformer Systems
Every major efficiency trick in modern transformer architectures is, directly or indirectly, a mixed-precision story. Here is how the concept ripples outward.
Flash Attention — Mixed Precision at the Kernel Level
Flash Attention (Dao et al., 2022) was the watershed optimization for long-context transformers. Its headline trick — tiling the attention computation to keep intermediates in on-chip SRAM instead of writing the full attention matrix to HBM — is a memory story. Its quieter but equally important trick is mixed precision inside the kernel.
The inputs arrive in FP16 or BF16. Inside the kernel, the matmuls run on tensor cores at that narrow dtype with FP32 accumulation, and the running max and the softmax denominator are maintained in FP32. This is Crisis 2 avoidance at the hardware level: without the running-max subtraction, $e^x$ would exceed FP16's maximum for any pre-softmax score above about 11.1, and even with it, a denominator summed over thousands of keys would lose precision in a 16-bit format. Keeping these statistics in FP32 is why Flash Attention matches standard attention's accuracy at BF16 throughput.
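The running-max bookkeeping is the online-softmax recurrence. Below is a heavily simplified sketch in plain Python: Python floats (FP64) stand in for the kernel's FP32 accumulators, there is no tiling of Q, K, V and no output accumulation, and the scores are made up:

```python
import math

def online_softmax_denominator(scores, block_size=4):
    """Stream over score blocks, keeping a running max m and a running denominator l."""
    m = -math.inf   # running max of all scores seen so far
    l = 0.0         # running sum of exp(score - m)
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old sum to the new max, then add this block's terms (every exponent is <= 0).
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    return m, l

scores = [3.0, 250.0, -7.5, 40.0, 251.5, 12.0, 0.0, 249.0, 100.0]
m, l = online_softmax_denominator(scores)
probs = [math.exp(s - m) / l for s in scores]

# Reference: two-pass softmax with the global max subtracted.
ref_m = max(scores)
ref_l = sum(math.exp(s - ref_m) for s in scores)
print(m, l, ref_l)
```

A raw exp(250) would overflow even FP32, but because every exponent is taken relative to the running max, the streamed result matches the two-pass reference.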
Multi-Head Attention — Where Softmax Stays FP32
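The canonical multi-head attention formula is $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(QK^\top / \sqrt{d_k}\right) V$, applied per head. Look at where each dtype lives in a BF16 training setup: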
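- $Q, K, V$: BF16 projections of the input (the matmuls run on tensor cores).
- $QK^\top$: BF16 matmul with an FP32 accumulator internally, result written back to BF16.
- $\sqrt{d_k}$: a scalar constant; dividing the BF16 scores by it leaves them in BF16, since PyTorch does not promote a tensor's dtype for a Python scalar.
- softmax: run in FP32 under autocast. BF16 shares FP32's range, so the risk here is precision: with a 7-bit mantissa the exponentials and their sum would be too coarse. In FP16 the same op would also risk overflow, which is Crisis 2.
- $\mathrm{softmax}(\cdot)\,V$: back to a BF16 matmul.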
That softmax-in-FP32 choice is exactly what PyTorch's autocast cast-list enforces for you (the LayerNorm that precedes the QKV projections — Section 1 — is what gives us the unit-variance input that makes this softmax-in-FP32 economical). You never wrote it; the safe list did.
Positional Encodings — A Quiet FP32 Holdout
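Sinusoidal positional encodings are computed once at startup and added to the token embeddings as $PE_{(pos,2i)} = \sin\left(pos / 10000^{2i/d}\right)$ and $PE_{(pos,2i+1)} = \cos\left(pos / 10000^{2i/d}\right)$. The problematic piece is the angle $pos / 10000^{2i/d}$. Its denominator ranges from 1 up to nearly $10^4$ (for $d = 12288$, GPT-3's hidden size, and $2i$ near the top), and for small $i$ the angle is roughly $pos$ itself. At long-context positions that is a large number: the spacing between representable values at 4096 is 4 in FP16 and 32 in BF16, so the rounding error in the angle alone reaches whole radians and the resulting sin and cos values are meaningless. Production code computes the encodings in FP32 and then casts to the activation dtype, so the precision loss only happens once, at the end.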
Rotary (RoPE) and ALiBi encodings have the same discipline: compute the rotation/bias in FP32 and apply it to BF16 activations. The common pattern is "construct in the widest precision, apply in the narrowest".
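A small experiment shows why the angle should be computed wide. This sketch assumes a hypothetical model width of 512, positions up to 8192, and struct's 'e' format as the FP16 cast; it compares casting once at the end with rounding the angle to FP16 before taking sin:

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Round to IEEE binary16 (all values here are far below 65504)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

d_model = 512             # hypothetical model width
positions = range(8192)   # a long context
dims = (0, 2, 8)          # a few even dimension indices 2i

err_cast_at_end = 0.0     # angle and sin in wide precision, cast once at the end
err_angle_in_fp16 = 0.0   # angle rounded to FP16 before sin
for pos in positions:
    for two_i in dims:
        angle = pos / 10000 ** (two_i / d_model)
        exact = math.sin(angle)
        err_cast_at_end = max(err_cast_at_end, abs(to_fp16(exact) - exact))
        err_angle_in_fp16 = max(err_angle_in_fp16, abs(math.sin(to_fp16(angle)) - exact))

print(f"cast at end: {err_cast_at_end:.1e}   angle in FP16: {err_angle_in_fp16:.2f}")
```

Casting at the end costs at most half an FP16 ULP of a value in [-1, 1]; rounding the angle first produces errors comparable to the full amplitude of the encoding.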
KV-Cache — Inference Memory is a Precision Problem
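Autoregressive inference caches the keys and values from every past token so each new token only needs one new attention step. For a model with $n_\ell$ layers and $n_h$ key/value heads of size $d_h$, the KV-cache for a sequence of length $T$ stores $2 \cdot n_\ell \cdot n_h \cdot d_h \cdot T$ values (the 2 counts keys and values). For LLaMA-2-70B (80 layers, 8 key/value heads of size 128 under grouped-query attention) that is about 1.3 GB per 4K-token sequence in FP16, so roughly a hundred concurrent sequences already match the ~140 GB of FP16 weights. Serving stacks therefore work very hard on KV-cache precision: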
- FP16/BF16 is the starting point. Halves the cache compared to FP32.
- FP8 KV-cache is supported by TensorRT-LLM and vLLM on H100. Stores K and V in E4M3 or E5M2 with per-tensor scales. Cuts memory and bandwidth in half again, typically with < 0.5 point loss on downstream benchmarks.
- INT8 KV-cache (vLLM, SmoothQuant) uses the same one byte per value as FP8 but with integer quantization, often asymmetric and per-channel: more complex to calibrate, but a common choice for cost-constrained serving.
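The cache-size formula is simple enough to script. A sketch using LLaMA-2-70B's attention shape (80 layers, 8 key/value heads of size 128) and a 4K context; other models only need different arguments:

```python
def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int, seq_len: int,
                   bytes_per_value: int, batch: int = 1) -> int:
    """Bytes needed to cache keys and values (the leading 2) for every layer, head and token."""
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per_value * batch

shape = dict(layers=80, kv_heads=8, head_dim=128, seq_len=4096)
for name, nbytes in (("FP32", 4), ("FP16/BF16", 2), ("FP8/INT8", 1)):
    size_gb = kv_cache_bytes(**shape, bytes_per_value=nbytes) / 1e9
    print(f"{name:>9}: {size_gb:.2f} GB per 4K-token sequence")
```

Each halving of bytes per value halves the cache, which translates directly into twice as many concurrent sequences in the same memory.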
The tradeoff is identical to training's: narrower storage means more tokens per GPU second, at the cost of small numerical drift that teams measure with perplexity and task-level evals.
FP8 and the Scaling Frontier
Hopper (H100, 2022) introduced FP8 tensor cores and two formats: E4M3 (4 exponent bits, 3 mantissa bits, max ≈ 448) for weights and activations in the forward pass, and E5M2 (5 exponent bits, 2 mantissa bits, max ≈ 57344) for gradients, where range matters most. NVIDIA's Transformer Engine library automates the casts and manages per-tensor scale factors the way GradScaler manages a single global scale.
With only 3 mantissa bits in E4M3, the ULP at 1.0 is 0.125, an enormous precision hit, and the range is narrow too: the largest value is 448 and the smallest positive subnormal is $2^{-9} \approx 0.002$. FP8 therefore demands per-tensor scaling: each tensor gets its own floating-point multiplier, chosen so that the tensor's actual values fill FP8's representable band instead of spilling off either end. The scales themselves are FP32 and computed from recent tensor statistics.
For example, if an activation tensor's observed amax (max absolute value) is 6.4 and we are targeting FP8 E4M3 (max representable ≈ 448), the per-tensor scale factor is $s = 448 / 6.4 = 70$. We store $x \cdot s$ in FP8 and remember $s$ (as FP32) alongside the tensor; at read-time we divide by $s$. NVIDIA's Transformer Engine, in its delayed-scaling recipe, derives $s$ from a history of amax values observed over recent steps: an online statistic similar in spirit to BatchNorm's running statistics, but applied per-tensor per-step. The cost is one extra FP32 scale per tensor (4 bytes) and one division per read, trivial compared to the memory and bandwidth saved by storing the bulk tensor in 1-byte FP8.
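A minimal sketch of per-tensor scaling, assuming a saturating E4M3 rounding function written by hand (the helper names are ours, not Transformer Engine's API) and made-up activations with amax 6.4:

```python
import math

E4M3_MAX = 448.0      # largest finite E4M3 value
E4M3_MIN_EXP = -6     # exponent of the smallest normal (bias 7); subnormals share its spacing
E4M3_MAN_BITS = 3

def round_to_e4m3(x: float) -> float:
    """Round to the nearest E4M3 value, saturating at ±448 (NaN handling omitted)."""
    if x == 0.0:
        return 0.0
    a = abs(x)
    exp = max(math.floor(math.log2(a)), E4M3_MIN_EXP)
    step = 2.0 ** (exp - E4M3_MAN_BITS)          # spacing of representable values at this magnitude
    q = min(round(a / step) * step, E4M3_MAX)    # round() is round-half-to-even
    return math.copysign(q, x)

def fp8_roundtrip(values, use_scaling=True):
    """Quantize a tensor to E4M3 with one per-tensor scale, then dequantize it."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if use_scaling else 1.0       # kept in full precision beside the tensor
    stored = [round_to_e4m3(v * scale) for v in values]   # what would live in 1-byte FP8 memory
    return [q / scale for q in stored], scale

acts = [6.4, -3.2, 0.05, 0.0009]           # hypothetical activations with amax = 6.4
unscaled, _ = fp8_roundtrip(acts, use_scaling=False)
scaled, s = fp8_roundtrip(acts)
print("scale:", s)
print("no scale:  ", unscaled)
print("per-tensor:", scaled)
```

Without a scale, 0.0009 falls below E4M3's smallest subnormal and becomes 0; with s = 70 it survives with under 1% error, and the largest value lands exactly on 448.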
The upshot: every generation of the mixed-precision story is the same pattern at finer granularity. FP32 training used one precision everywhere. FP16 AMP introduced one global scale (the GradScaler). FP8 training introduces per-tensor scales. The logical endpoint — per-group scales like those used by MXFP8 and fine-grained INT4 quantization in inference — is already shipping.
Tradeoffs at a Glance
| Recipe | Speed vs FP32 | Memory vs FP32 | Numerical gotchas | When to pick it |
|---|---|---|---|---|
| FP32 baseline | 1× | 1× | None — numerics are trivial. | Research prototypes on small models, or when debugging a divergence. |
| FP16 AMP + GradScaler | ~2–2.5× | ~0.6–0.7× | Underflow if loss scale too low; overflow if too high. GradScaler tunes it. | Older hardware (V100, T4). Legacy codebases. |
| BF16 AMP | ~2–2.5× | ~0.6–0.7× | Coarser mantissa — slightly noisier convergence. No loss scaling needed. | Modern default for large-model training (A100, H100, TPU). |
| FP8 (E4M3/E5M2) | Up to 2× vs BF16 (the H100 peak ratio); less end to end in practice | ~0.5× | Requires per-tensor scaling; narrow safety margin. | Very large models where memory and throughput dominate. |
| FP4 / INT4 (inference only) | Mostly a bandwidth win for weight-bound decoding; H100 has no FP4 tensor cores (Blackwell adds them) | ~0.125× for weights | Group-wise quantization; accuracy loss on sensitive layers. | Serving frontier: weight quantization, not training. |
Summary
Mixed precision training is a big part of why a 7-billion-parameter model can be served from a single GPU and a 175-billion-parameter model can be trained on a single cluster. But it is not a magic switch: it is a careful division of labor between number formats.
- Compute narrow, store wide. FP16/BF16 on the tensor cores for speed, FP32 master weights and optimizer state for correctness.
- Use loss scaling when range is tight. FP16's 5-bit exponent can't cover typical gradient magnitudes, so shift the histogram by multiplying the loss by $S$. Skip loss scaling with BF16 because its 8-bit exponent already covers FP32's range.
- Keep softmax, norms, and losses in FP32. These ops overflow or lose precision in half formats. The autocast cast-list encodes the hard-won knowledge of which ops are safe where.
- Every modern transformer optimization is a precision story. Flash Attention accumulates softmax in FP32. Multi-head attention keeps its softmax in FP32. KV-cache quantization is precision-traded-for-memory. FP8/Transformer Engine replaces one global scale with per-tensor scales.
If FP32 was the era of "just use floats," the modern era is the era of "use the narrowest representation that survives the math of your op, and not a bit narrower." Mastering mixed precision is mastering exactly that decision on every tensor in your network.