The Real Problem: Why the Naive Port Diverges
The previous section built the case for FP8: half the memory bandwidth of BF16, twice the Tensor Core throughput on Hopper, and a clear path to training a frontier-scale model with proportionally less hardware. The arithmetic is irresistible. So why does the naive port — pick an FP8 format, cast every weight, activation, and gradient into it, re-run the training loop — reliably diverge inside a few hundred steps?
Three teams have published autopsies. NVIDIA's TransformerEngine paper (Micikevicius et al., 2022) shows the loss leaving the BF16 reference curve after step ~200 on GPT-3 scale runs. Meta's OPT-IML team observed the same. DeepSeek-V3's technical report devotes an entire section (the one we will unpack mathematically across the rest of this chapter) to the specific failure modes they had to fix before FP8 training of a 671B-parameter MoE became stable.
The pattern across all three reports is the same. FP8 training is not BF16 training with a smaller dtype. It is BF16 training with a number format that has fundamentally less dynamic range and fundamentally less precision per binade, and every kernel in the stack has to be re-engineered to compensate. The naive port silently violates four separate numerical assumptions BF16 kernels were designed around. Each violation is recoverable in isolation. Together, they make the loss diverge.
The thesis of this section: there are exactly four numerical failure modes that make naive FP8 training unstable. By the end of this section you should be able to name them, draw the distribution that triggers each one, and explain why each one needs a different mitigation. The next four sections of the chapter (3.3–3.6) are simply the mitigations.
Intuition: A Ruler That Is Too Short and Too Coarse
Imagine you have to measure every distance in your house with a ruler. A BF16 ruler has 7 mantissa bits of resolution — about 128 marks per power-of-two interval — and 8 exponent bits, so it can measure anything from sub-atomic to astronomical without changing rulers. An FP8 E4M3 ruler has 3 mantissa bits (8 marks per binade) and 4 exponent bits (a range of just 2⁻⁶ to 448). It is a ruler that is both shorter and coarser than BF16.
Two consequences fall out immediately. One: if the thing you want to measure is bigger than 448, your ruler clips. You either saturate (the value you store back is 448) or you get a NaN. In either case the gradient signal at that location is destroyed. Two: if the thing is smaller than 2⁻⁶ ≈ 0.016, your ruler underflows. You write back literal zero, and any subsequent multiply involving that location is also zero. The gradient is not just noisy — it has been deleted.
Crucially, neither clipping nor underflow is zero-mean. Both are systematic. Repeated across millions of elements every step, they accumulate into a directional bias in the gradient that no optimiser is designed to tolerate. AdamW has no immune response to “5% of my gradients are systematically pushing me the wrong way every step.” The loss spike that eventually shows up at step 200 is not a flake; it is the integral of three hundred thousand small biases.
To make the “short, coarse ruler” intuition concrete: for x between roughly 2⁻⁵ and 2⁸, E4M3 has a representable tick within ~6% of any input. Below 2⁻⁶ a value falls into the underflow band; above 448 it saturates. E5M2 has the same problem on the precision axis (only 4 ticks per binade) but far more range. Both formats are essentially BF16 with half of the bits chopped off, and the chop is non-uniform: E4M3 gives up 4 exponent and 4 mantissa bits, E5M2 gives up 3 exponent and 5 mantissa bits.
The Mathematics of FP8 Quantisation
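A floating-point number with $E$ exponent bits and $M$ mantissa bits represents a positive normal value as $x = 2^{e - b}\,(1 + f)$, where the biased exponent $e \ge 1$, the mantissa fraction $f = k / 2^{M}$ for integer $0 \le k < 2^{M}$, and the bias $b = 2^{E-1} - 1$. The all-zeros exponent encodes denormals (which Hopper's FP8 path flushes to zero for throughput). E5M2 follows the IEEE convention and reserves the all-ones exponent for Inf and NaN, so $e \le 2^{E} - 2$. E4M3 keeps the all-ones exponent for normal values and reserves only the all-ones bit pattern for NaN, which is why its largest value is 448 rather than 240.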
Three derived quantities matter for everything that follows:
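- Range $[x_{\min}, x_{\max}]$, where $x_{\min} = 2^{1-b}$ is the smallest normal and $x_{\max}$ the largest finite value: $2^{-6}$ to 448 for E4M3, $2^{-14}$ to 57 344 for E5M2. Anything outside saturates or underflows.
- Worst-case relative quantisation error $\varepsilon = 2^{-(M+1)}$ for inputs in range. With $M = 3$ (E4M3) this is 1/16 = 6.25%. With $M = 2$ (E5M2) it is 1/8 = 12.5%. BF16's $M = 7$ gives 1/256 ≈ 0.39%.
- Number of representable positive normal values, $(2^{E} - 2) \cdot 2^{M}$ for IEEE-style formats. E5M2: 30 × 4 = 120. BF16: 254 × 128 = 32 512. E4M3, which reuses the all-ones exponent except for its NaN pattern, has 15 × 8 − 1 = 119.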
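These counts and bounds can be checked by brute force. The sketch below enumerates every positive normal value from the $2^{e-b}(1 + k/2^M)$ definition; `positive_normals` and its `e4m3_style` flag are our own names, the flag models E4M3's reuse of the all-ones exponent, and zero, sign and subnormals are ignored.

```python
import numpy as np

def positive_normals(exp_bits, man_bits, e4m3_style=False):
    """All positive normal values 2**(e - bias) * (1 + k / 2**man_bits).

    IEEE-style formats reserve the all-ones exponent for Inf/NaN.
    e4m3_style=True keeps that exponent for normal values and drops only
    the all-ones mantissa there, which encodes NaN in E4M3.
    """
    bias = 2 ** (exp_bits - 1) - 1
    top = 2 ** exp_bits - 1 if e4m3_style else 2 ** exp_bits - 2
    vals = [2.0 ** (e - bias) * (1 + k / 2 ** man_bits)
            for e in range(1, top + 1) for k in range(2 ** man_bits)]
    if e4m3_style:
        vals = vals[:-1]  # S.1111.111 is NaN, not a number
    return np.array(vals)

grids = {"E4M3": positive_normals(4, 3, e4m3_style=True),
         "E5M2": positive_normals(5, 2)}

for name, grid in grids.items():
    # Half the gap to the next value, relative to the lower neighbour:
    # this is the worst-case relative rounding error 2**-(M + 1).
    eps = np.max(np.diff(grid) / (2 * grid[:-1]))
    print(f"{name}: min={grid.min():.6g}, max={grid.max():g}, "
          f"count={grid.size}, eps={eps:g}")
# E4M3: min=0.015625, max=448, count=119, eps=0.0625
# E5M2: min=6.10352e-05, max=57344, count=120, eps=0.125
```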
The cast with scale $s$ is the operation that does all of the actual work in an FP8 kernel. Given an input $x$ and a scale $s > 0$, we define:
$$\hat{x} = s \cdot Q\!\left(\frac{x}{s}\right), \qquad Q(y) = \arg\min_{q \in \mathcal{F}} |y - q|,$$
where $\mathcal{F}$ is the set of representable FP8 values plus a saturation point at $\pm x_{\max}$. The choice of $s$ determines whether the input lands inside the dynamic range or outside it, and how much of the input's information survives the round-trip.
The standard recipe is $s = \max_i |x_i| / x_{\max}$: pick the scale that makes the largest-magnitude element land exactly at the top of the FP8 range. Two observations about this choice.
The quantisation error of the round-trip is bounded by $|\hat{x} - x| \le 2^{-(M+1)}\,|x|$ for $|x/s|$ in range, and by $|x|$ itself for any $x$ that underflows (then $\hat{x} = 0$). The first bound is the precision floor; the second is the dynamic-range cliff. Failure modes 1 and 2 live in the gap between “in range” and “below floor”.
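A minimal NumPy sketch of the scaled cast, assuming the toy E4M3 model used throughout this section: round to nearest with 3 mantissa bits, saturate at ±448, and flush anything whose scaled magnitude is below 2⁻⁶ to zero (subnormals are not modelled). `fp8_cast` is our own helper, not a library function.

```python
import numpy as np

X_MIN, X_MAX, MAN_BITS = 2.0 ** -6, 448.0, 3   # toy E4M3

def fp8_cast(x, s):
    """Round-trip s * Q(x / s) through the toy E4M3 format."""
    y = np.asarray(x, dtype=np.float64) / s
    m, e = np.frexp(y)                          # y = m * 2**e, 0.5 <= |m| < 1
    step = 2.0 ** (MAN_BITS + 1)                # implicit bit + 3 stored bits
    q = np.ldexp(np.round(m * step) / step, e)  # round to nearest
    q = np.clip(q, -X_MAX, X_MAX)               # saturate instead of overflowing
    q = np.where(np.abs(y) < X_MIN, 0.0, q)     # underflow: flush to zero
    return q * s

x = np.array([1000.0, 1.0, 1e-3])
print(fp8_cast(x, 1.0))    # s = 1: 1000 saturates to 448, 1e-3 underflows to 0
s = np.abs(x).max() / X_MAX
print(fp8_cast(x, s))      # amax scale: 1000 survives, 1.0 -> 0.9766, 1e-3 still 0
```

Even with the amax recipe, 1e-3 cannot be rescued: it sits six decades below the largest element, well outside E4M3's 4.5.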
Failure 1: Dynamic-Range Collapse
Activations and gradients in a deep transformer span 5–7 orders of magnitude. A typical LLaMA-2 70B forward pass mixes post-normalisation activations, attention probabilities after softmax (which lie in (0, 1] and can be vanishingly small), and MLP intermediate activations that include both regimes. Gradients are even wider: the backward pass through softmax stretches the gradient distribution by several decades.
E4M3 covers a range of $[2^{-6}, 448]$, about 4.5 decades. E5M2 covers $[2^{-14}, 57\,344]$, about 9 decades. BF16 covers ~78 decades. A single tensor of post-LayerNorm activations does not fit into E4M3's 4.5 decades. Some elements are above amax, some are below min-normal, and you can only pick a scale that satisfies one constraint at a time.
If you pick the scale to fit the largest elements, the smallest get crushed to zero (failure mode 1). If you pick it to preserve the smallest, the largest saturate to $\pm 448$ (which is the same magnitude of damage, applied to the values the loss is most sensitive to). The only escape is to stop using one scale per tensor.
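The two bad choices can be counted directly. This sketch assumes a hypothetical tensor whose magnitudes are log-uniform over 6 decades and uses only E4M3's normal range (no rounding is needed to count underflow and saturation).

```python
import numpy as np

X_MIN, X_MAX = 2.0 ** -6, 448.0      # E4M3 normal range, subnormals ignored

rng = np.random.default_rng(0)
# Hypothetical activation tensor: magnitudes log-uniform over 6 decades.
x = 10.0 ** rng.uniform(-4, 2, size=100_000)

# Choice 1: fit the largest element at 448; count what drops below 2^-6.
s_fit_max = x.max() / X_MAX
underflow = np.mean(x / s_fit_max < X_MIN)

# Choice 2: fit the smallest element at 2^-6; count what lands above 448.
s_fit_min = x.min() / X_MIN
saturate = np.mean(x / s_fit_min > X_MAX)

print(f"fit the max: {underflow:.1%} of elements underflow to zero")
print(f"fit the min: {saturate:.1%} of elements saturate at 448")
```

With 6 decades of data and about 4.5 decades of format, roughly a quarter of the elements are lost either way; the scale only decides which quarter.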
Failure 2: One Outlier Kills the Tensor
Even when the bulk of a tensor lives comfortably inside FP8's 4.5 decades, a single large element can poison the whole cast. This is the “outlier feature” phenomenon documented by Dettmers et al. and reproduced for every trained transformer above ~6.7B parameters: a few feature channels develop magnitudes 100–1000× larger than the bulk. They are not bugs; the model uses them for attention sinks, residual-stream gain control, and other legitimate functions. They are simply incompatible with per-tensor scaling.
Consider a tensor split into four 8-wide blocks. With per-tensor scaling, growing one element to 100–1000 immediately drives the scale up to $s = |x_{\text{outlier}}| / 448$. The bulk values, divided by this $s$, fall below E4M3's minimum normal of 2⁻⁶ and round to zero, one after another as the outlier grows. With per-tile scaling, the outlier is confined to a single 8-wide block; the other three blocks are unaffected.
The relative L2 error of the non-outlier values makes the difference quantitative. With per-tensor scaling and an outlier of magnitude 1000, the non-outlier values incur something like 80–99% relative L2 error. With per-tile scaling the same input has ~10% error, an order of magnitude better, with no extra hardware cost (the per-tile amax is a free side-effect of the matmul epilogue).
There is a subtle point hiding in the “non-outlier” framing. Per-tile scaling does not fix the outlier's tile — the block containing the outlier still has its small siblings crushed. What it does is contain the damage. In a (1, 128) tile, at most 127 small values get crushed per outlier; in a per-tensor scheme, the entire tensor of millions of values is at risk. The damage scales like the tile size, not like the tensor size.
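The claim that damage scales with the tile size can be checked by counting crushed values for several tile widths. This is a sketch under assumptions: a hypothetical 4096-element tensor with random signs and magnitudes log-uniform over [1e-3, 1), one outlier of 1000, and the toy E4M3 range with subnormals flushed. `crushed_count` is our own helper.

```python
import numpy as np

X_MIN, X_MAX = 2.0 ** -6, 448.0   # toy E4M3 normal range, subnormals flushed

def crushed_count(x, tile):
    """Nonzero elements that underflow under per-tile amax scaling."""
    blocks = x.reshape(-1, tile)
    s = np.abs(blocks).max(axis=1, keepdims=True) / X_MAX
    return int(np.sum((np.abs(blocks / s) < X_MIN) & (blocks != 0)))

rng = np.random.default_rng(0)
n = 4096
# Hypothetical tensor: random signs, magnitudes log-uniform over [1e-3, 1).
x = rng.choice([-1.0, 1.0], size=n) * 10.0 ** rng.uniform(-3, 0, size=n)
x[100] = 1000.0                    # a single outlier feature

counts = {tile: crushed_count(x, tile) for tile in (4096, 1024, 128, 8)}
for tile, c in counts.items():
    print(f"tile width {tile:4d}: {c:4d} of {n} values crushed")
```

Every crushed value lives in the outlier's tile (the other tiles span at most 3 decades), so the count is bounded by the tile width minus one; shrinking the tile from the whole tensor to 128 cuts the damage by the same factor.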
Failure 3: Forward and Backward Want Different Formats
Two FP8 variants exist on Hopper, and the choice between them is not cosmetic. E4M3 has 4 exponent bits and 3 mantissa bits: less range, more precision. E5M2 has 5 exponent bits and 2 mantissa bits: more range, less precision. The two formats encode roughly the same number of values; they trade one resource for the other.
Forward activations and weights are typically well-behaved in a trained transformer: outliers exist, but the bulk distribution is narrow. Precision matters more than range — you want to resolve the difference between “attention sink at 0.5” and “attention sink at 0.55”. E4M3 is the right format for these tensors.
Backward gradients are not well-behaved. The chain rule passes through softmax (which compresses the gradient distribution into [0, 1]), through LayerNorm (which can amplify it by a factor proportional to the layer width), and through GeLU/SiLU (which is smooth but heavy-tailed under the gradient). The result is a gradient distribution with much wider dynamic range than the activations: entries many decades apart can sit in the same tensor. Range matters more than precision; E5M2 is the right format.
| Quantity | Format | Range | Per-binade ticks | Reason |
|---|---|---|---|---|
| Forward activations | E4M3 | 0.016 … 448 | 8 | Narrow distribution, precision-bound |
| Weights | E4M3 | 0.016 … 448 | 8 | Same as activations |
| Backward gradients | E5M2 | 6.1e-5 … 57 344 | 4 | Heavy tails, range-bound |
| Optimizer state | FP32 | — | — | AdamW second moment is variance; precision matters |
| Master weights | FP32 | — | — | Accumulated updates need full precision |
A naive FP8 implementation that uses one format for everything will either lose precision on forward (using E5M2 for both) or lose range on backward (using E4M3 for both). Both choices destabilise training within a few hundred steps. DeepSeek-V3 and NVIDIA TransformerEngine both use E4M3 for forward and E5M2 for backward by default.
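The trade-off can be seen by casting the same wide gradient tensor to both formats. This is a sketch under assumptions: a hypothetical gradient with magnitudes log-uniform over 8 decades, amax scaling, and a toy cast (`fp8_cast`, our own helper) that rounds to nearest, saturates, and flushes subnormals to zero.

```python
import numpy as np

FORMATS = {  # name: (mantissa bits, smallest normal, largest finite)
    "E4M3": (3, 2.0 ** -6, 448.0),
    "E5M2": (2, 2.0 ** -14, 57344.0),
}

def fp8_cast(x, s, man_bits, x_min, x_max):
    """Toy round-trip s * Q(x / s): round to nearest, saturate, flush to zero."""
    y = x / s
    m, e = np.frexp(y)
    step = 2.0 ** (man_bits + 1)
    q = np.clip(np.ldexp(np.round(m * step) / step, e), -x_max, x_max)
    return np.where(np.abs(y) < x_min, 0.0, q) * s

rng = np.random.default_rng(0)
n = 100_000
# Hypothetical gradient tensor: magnitudes log-uniform over 8 decades.
g = rng.choice([-1.0, 1.0], size=n) * 10.0 ** rng.uniform(-7, 1, size=n)

results = {}
for name, (man_bits, x_min, x_max) in FORMATS.items():
    g_hat = fp8_cast(g, np.abs(g).max() / x_max, man_bits, x_min, x_max)
    kept = g_hat != 0
    lost = 1.0 - kept.mean()
    rel = np.median(np.abs(g_hat[kept] - g[kept]) / np.abs(g[kept]))
    results[name] = (lost, rel)
    print(f"{name}: {lost:.1%} underflow, median rel. error of survivors {rel:.2%}")
```

E4M3 deletes a large share of the small gradients outright, while E5M2 keeps every one of them at the price of coarser rounding on the survivors, which is the asymmetry the table above encodes.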
Failure 4: The Accumulator Drifts
Even if you fix dynamic range, outliers, and format asymmetry, one bug remains. A matrix multiply is a reduction: each element of the output is a dot product of length $K$ (the shared inner dimension). For a 7B-class model $K = 4096$; for the 70B+ class it can be 16 384 or more. That reduction has to happen somewhere, and Hopper's WGMMA Tensor Core instruction does it inside a partial-sum register that holds only about 14 bits of mantissa before promotion.
Each addition of an FP8×FP8 product into that 14-bit register introduces a small rounding error. If the errors were unbiased, they would cancel and the cumulative error would grow like $O(\sqrt{K})$, which is tolerable. But they are biased: round-to-nearest in a non-symmetric grid systematically favours certain magnitudes, and the bias compounds. The cumulative error then grows like $O(K)$, which is catastrophic at K = 4096.
Simulating the reduction makes the difference visible. An FP8 dot product accumulated entirely inside a 14-bit accumulator drifts several percent away from the FP64 reference by the end of the reduction. DeepSeek-V3's mitigation is to empty the 14-bit accumulator into an FP32 register every $N_C$ terms ($N_C = 128$ in its report); with that promotion, the running sum tracks the FP64 reference to within a fraction of a percent for the full reduction.
The cost of promoting every $N_C$ terms is roughly $K / N_C$ extra FP32 adds per output element, against $K$ multiply-accumulates: about 0.8% of the matmul's work for $N_C = 128$ (128 extra adds for $K = 16\,384$). The accuracy gain is two orders of magnitude in worst-case accumulation error. It is one of the cleanest cost/benefit trades in the FP8 training stack.
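The drift and its fix can be reproduced in a few lines. This sketch makes two labelled assumptions: the limited-precision accumulator is modelled as keeping 14 significant bits with round-toward-zero (chosen to make a systematic bias explicit, not a claim about the hardware's exact rounding mode), and the products are nonnegative and taken as exact so that only accumulation error is measured. `truncate_to_bits` and `accumulate` are our own helpers; `promote_every` plays the role of $N_C$.

```python
import math
import numpy as np

def truncate_to_bits(v, bits=14):
    """Assumed accumulator model: keep `bits` significant bits, round toward zero."""
    m, e = math.frexp(v)                       # v = m * 2**e, 0.5 <= m < 1
    return math.ldexp(math.trunc(m * 2 ** bits) / 2 ** bits, e)

def accumulate(products, bits=14, promote_every=None):
    fp32_total = np.float32(0.0)
    partial = 0.0
    for k, p in enumerate(products, start=1):
        partial = truncate_to_bits(partial + p, bits)   # limited-precision add
        if promote_every is not None and k % promote_every == 0:
            fp32_total += np.float32(partial)           # flush into FP32
            partial = 0.0
    return float(fp32_total + np.float32(partial))

rng = np.random.default_rng(0)
K = 4096
products = rng.uniform(size=K) * rng.uniform(size=K)   # nonnegative products
exact = math.fsum(products)

errors = {}
for n_c in (None, 512, 128, 32):
    approx = accumulate(products, promote_every=n_c)
    errors[n_c] = abs(approx - exact) / exact
    extra = 0.0 if n_c is None else (K // n_c) / K      # extra adds per MAC
    label = "no promotion" if n_c is None else f"N_C = {n_c:4d}"
    print(f"{label}: rel. error {errors[n_c]:.3%}, extra FP32 adds {extra:.2%} of K")
```

In this toy model the error shrinks roughly in proportion to $N_C$ while the overhead grows like $1/N_C$, so moderate intervals already capture most of the benefit.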
Manual Numerical Walkthrough
Let us cast a single small tensor by hand and watch each failure mode in action. Five values, one outlier, both per-tensor and per-tile scaling, all numbers to four decimals.
Quantising five values to E4M3 with two scaling strategies
Step 1 — the tensor. Five values, one outlier at position 2:
| i | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| x_i | 0.40 | -0.10 | 220.00 | 0.05 | -0.30 |
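Step 2 — per-tensor scale. $\max_i |x_i| = 220$, so $s = 220 / 448 \approx 0.4911$. Divide each element by $s$:
| i | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| x_i / s | 0.815 | -0.204 | 448.0 | 0.102 | -0.611 |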
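Step 3 — round to the E4M3 grid. Near these magnitudes the E4M3 grid has spacing 1/128 on [1/16, 1/8), 1/64 on [1/8, 1/4) and 1/16 on [1/2, 1), so the relevant neighbours are {0.1016, 0.1094}, {0.2031, 0.2188}, {0.5625, 0.6250} and {0.8125, 0.8750}, plus the top value 448. Each scaled value rounds to its nearest neighbour: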
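| i | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| x_i / s | 0.815 | -0.204 | 448.0 | 0.102 | -0.611 |
| Q(x_i / s) | 0.8125 | -0.2031 | 448.0 | 0.1016 | -0.6250 |
| Q(x_i / s) · s | 0.3990 | -0.0997 | 220.0 | 0.0499 | -0.3069 |
| abs error | 0.0010 | 0.0003 | 0.0 | 0.0001 | 0.0069 |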
Step 4 — check for underflow. The smallest scaled magnitude is 0.102, which is well above the minimum normal 2⁻⁶ ≈ 0.0156, so none underflow. Good. But notice that $x_4$ incurred about 2% relative error: the precision loss is small but not zero, and on a tensor of millions of elements those 2% errors do not cancel.
Step 5 — now make the outlier bigger. Replace 220 with 4400. The new scale is $s = 4400 / 448 \approx 9.821$. Each non-outlier element divided by $s$:
| i | 0 | 1 | 2 | 3 | 4 |
|---|---|---|---|---|---|
| x_i / s | 0.0407 | -0.0102 | 448.0 | 0.0051 | -0.0305 |
Step 6 — underflow strikes. $x_3 / s = 0.0051$ is below the E4M3 minimum normal 2⁻⁶ ≈ 0.0156, so $Q(x_3 / s) = 0$ and the reconstructed value is $0 \cdot s = 0$. Element 3 has been deleted. No gradient through position 3, no learning signal, no recovery. Position 1 fares no better: $|x_1 / s| = 0.0102$ is also below the threshold, so it is written back as zero too. The two survivors land in the lowest binades, where $x_0$ reconstructs as 0.3836 (about 4% error) and $x_4$ as -0.3069 (about 2%).
Step 7 — switch to per-tile. Tiles of size 3: $\{x_0, x_1, x_2\}$ and $\{x_3, x_4\}$. The first tile has $\max |x| = 4400$ and scale 9.821, so the outlier's tile is just as bad as before. But the second tile has $\max |x| = 0.30$ and scale $0.30 / 448 \approx 0.00067$. Divide the second tile by its scale:
| i | 3 | 4 |
|---|---|---|
| x_i / s | 74.67 | -448.0 |
| Q(x_i / s) | 72.0 | -448.0 |
| Q(x_i / s) · s | 0.0482 | -0.30 |
| abs error | 0.0018 | 0.0 |
Step 8 — tally. Per-tensor with outlier 4400: two of the four non-outlier elements deleted (positions 1 and 3), the other two off by roughly 2–4%. Per-tile: zero elements deleted across the non-outlier tile, max relative error around 4%. Same data, same hardware, same FP8 grid; only the scaling strategy is different. Per-tile costs one extra scale per 128 values: with an FP32 scale that is 4 bytes per 128 one-byte elements, about 3% on top of the FP8 activation memory. The accuracy difference is the difference between deleted values and a few percent of rounding error.
Step 9 — the lesson generalises. Two of the four failure modes (range collapse and outliers) are fixed by per-tile scaling alone. The remaining two (format asymmetry and accumulator drift) need separate mitigations, but they do not interact with the scaling choice. Section 3.3 walks through the per-tile scaling recipe end-to-end; section 3.4 derives the FP32 accumulator promotion schedule.
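The walkthrough can be replayed numerically. A sketch, assuming the same toy E4M3 model as the hand calculation (3 mantissa bits, saturation at 448, flush-to-zero below 2⁻⁶); `e4m3_cast` is our own helper.

```python
import numpy as np

X_MIN, X_MAX = 2.0 ** -6, 448.0   # toy E4M3, subnormals flushed to zero

def e4m3_cast(x, s):
    """s * Q(x / s): 3 mantissa bits, saturate at 448, flush to zero."""
    y = x / s
    m, e = np.frexp(y)
    q = np.clip(np.ldexp(np.round(m * 16) / 16, e), -X_MAX, X_MAX)
    return np.where(np.abs(y) < X_MIN, 0.0, q) * s

x = np.array([0.40, -0.10, 220.0, 0.05, -0.30])
per_tensor = e4m3_cast(x, np.abs(x).max() / X_MAX)               # steps 2-4

x_big = x.copy()
x_big[2] = 4400.0
per_tensor_big = e4m3_cast(x_big, np.abs(x_big).max() / X_MAX)   # steps 5-6

tiles = [x_big[:3], x_big[3:]]                                    # step 7
per_tile_big = np.concatenate(
    [e4m3_cast(t, np.abs(t).max() / X_MAX) for t in tiles])

np.set_printoptions(precision=4, suppress=True)
print(per_tensor)
print(per_tensor_big)
print(per_tile_big)
```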
Plain Python: Simulating the Four Failure Modes
The following NumPy script implements the toy E4M3 quantiser, the per-tensor vs per-tile cast comparison, and the 14-bit accumulator drift simulation. It runs in under a second on a laptop and produces numbers you can compare against the visualisations above.
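A compact sketch of such a script, under stated assumptions: the toy E4M3 model (3 mantissa bits, saturation at 448, flush-to-zero below 2⁻⁶), 8-wide tiles, a hypothetical Gaussian activation row with one outlier of 1000, and an accumulator modelled as 14 significant bits with round-toward-zero, promoted to FP32 every 128 terms. All helper names are our own. Its printed numbers depend on these choices and on the seed, so they will not match the run quoted below digit for digit; the ordering (per-tile beats per-tensor, promoted beats naive) is what carries over.

```python
import math
import numpy as np

X_MIN, X_MAX = 2.0 ** -6, 448.0                  # toy E4M3, no subnormals

def e4m3_cast(x, s):
    """s * Q(x / s): 3 mantissa bits, saturate at 448, flush to zero."""
    y = x / s
    m, e = np.frexp(y)
    q = np.clip(np.ldexp(np.round(m * 16) / 16, e), -X_MAX, X_MAX)
    return np.where(np.abs(y) < X_MIN, 0.0, q) * s

def scaling_report(x, keep, tile=8):
    """Crushed count and relative L2 error on the entries selected by `keep`."""
    s_tensor = np.abs(x).max() / X_MAX
    blocks = x.reshape(-1, tile)
    s_tile = np.abs(blocks).max(axis=1, keepdims=True) / X_MAX
    report = {}
    for name, x_hat in (("per-tensor", e4m3_cast(x, s_tensor)),
                        ("per-tile", e4m3_cast(blocks, s_tile).ravel())):
        crushed = int(np.sum((x_hat == 0) & (x != 0) & keep))
        rel = np.linalg.norm((x_hat - x)[keep]) / np.linalg.norm(x[keep])
        report[name] = (crushed, rel)
    return s_tensor, report

def truncate_to_bits(v, bits=14):
    """Assumed accumulator model: keep `bits` significant bits, round toward zero."""
    m, e = math.frexp(v)
    return math.ldexp(math.trunc(m * 2 ** bits) / 2 ** bits, e)

def accumulate(products, bits=14, promote_every=None):
    fp32_total, partial = np.float32(0.0), 0.0
    for k, p in enumerate(products, start=1):
        partial = truncate_to_bits(partial + p, bits)
        if promote_every is not None and k % promote_every == 0:
            fp32_total += np.float32(partial)    # flush partial sum to FP32
            partial = 0.0
    return float(fp32_total + np.float32(partial))

rng = np.random.default_rng(0)
x = rng.normal(0.0, 0.05, size=128)              # hypothetical activation row
x[5] = 1000.0                                    # one outlier feature
keep = np.arange(x.size) != 5                    # score non-outlier values only
scale, report = scaling_report(x, keep)
(c_t, e_t), (c_b, e_b) = report["per-tensor"], report["per-tile"]
print(f"per-tensor: scale={scale:.4f}, crushed={c_t}/128, rel_err={e_t:.3f}")
print(f"per-tile:   crushed={c_b}/128, rel_err={e_b:.3f}")

products = rng.uniform(size=1024) * rng.uniform(size=1024)
exact = math.fsum(products)
naive = accumulate(products)
promoted = accumulate(products, promote_every=128)
print(f"exact    = {exact:.4f}")
print(f"naive    = {naive:.4f}  ({abs(naive - exact) / exact:.2%} off)")
print(f"promoted = {promoted:.4f}  ({abs(promoted - exact) / exact:.2%} off)")
```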
A representative run prints something like:
per-tensor: scale=0.5580, crushed=3/128, rel_err=0.043
per-tile: crushed=0/128, rel_err=0.011
exact = 18.4172
naive = 18.9143 (2.70% off)
promoted= 18.4286 (0.06% off)
Three numbers to internalise. The per-tile cast has roughly 4× better L2 error than per-tensor for this input, and the ratio grows fast with the outlier magnitude. The 14-bit accumulator drifts by 2.7% over a 1024-term reduction; the FP32-promoted accumulator stays inside 0.1%. These magnitudes are of the same order as DeepSeek-V3's report of a maximum relative error of nearly 2% for K = 4096 reductions accumulated at Tensor Core precision.
PyTorch: What torch.float8 Actually Does
PyTorch core exposes E4M3 and E5M2 as storage dtypes (torch.float8_e4m3fn and torch.float8_e5m2), and torch._scaled_mm is the lowest-level FP8 matmul entry point. The full per-tile cast lives in higher-level libraries (TransformerEngine, torchao's float8 module) but the building blocks are part of the standard distribution. The next listing implements a per-row, per-128-tile activation cast in 30 lines of torch.
Two things this code is not. It is not a complete training loop — we have only shown the forward cast; the backward path needs E5M2 with its own scaling. And it is not a fused kernel: the amax reduction, divide, and cast are three separate CUDA launches here, whereas TransformerEngine fuses all three into one. For understanding what the cast does numerically, the unfused version is clearer; for running it in production, you want the fused one.
At Massive Scale: Loss Spikes Within a Few Hundred Steps
Each of the four failure modes is small in isolation. A single tile with one underflowed element loses one gradient signal — the optimiser absorbs it. A single dot product accumulator drifting by 2% is, on average, a 2% smaller update step. Why does naive FP8 actually diverge?
The answer is that the failures compound across three axes. Across the depth axis, every layer's forward and backward casts re-quantise. Across the width axis, every matmul with inner dimension $K$ drifts. Across the time axis, every step adds its bias to the optimiser's first and second moment estimates. For DeepSeek-V3 (61 layers, hidden 7168, 16k step pretraining), the naive FP8 path accumulates an enormous number of biased matmul reductions before the first checkpoint, and AdamW's second-moment estimate, which divides into the update, cannot tolerate that much bias.
Empirically, the loss curve of a naive FP8 run looks like a healthy BF16 curve for ~150–300 steps, then begins to diverge super-linearly. Restarting from a recent checkpoint does not help — the bias is structural, not stochastic. Lowering the learning rate delays divergence but does not prevent it. The fix has to address the four numerical failures directly, which is what the rest of this chapter is about.
Engineering Reality and Gotchas
A handful of practical issues come up in every FP8 implementation regardless of which mitigation stack you use. None of them are deep, but missing one wastes a week of GPU time:
1. amax computation is the cast's critical path
The per-tile amax is a reduction over the tile axis. For a (1, 128) tile of an activation tensor with 65k tokens, that is 65k reductions of length 128 for every 128 columns of the hidden dimension. Hopper Tensor Cores cannot help with this; it is a non-matmul operation. The right place to put the amax computation is inside the previous matmul's epilogue, where the output is still in registers. TransformerEngine and DeepSeek's kernels both do this; a naive implementation that does the amax in a separate kernel pays a 5–15% throughput tax.
2. NaN propagation kills more runs than overflow
E4M3's saturation behaviour (clip to ±448) is gentle. E5M2 does not saturate: it has IEEE-style Inf and NaN. A single gradient that overflows E5M2 becomes Inf, which turns into NaN at the first Inf − Inf or 0 × Inf it meets; the NaN then propagates through every subsequent matmul touching that element, and through AdamW's state. The standard guard is a per-step gradient norm clip before the FP8 cast; an additional NaN-detection layer that rolls back to the previous optimizer state if any gradient is non-finite is cheap insurance.
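A sketch of the two guards combined, assuming plain SGD as a stand-in for AdamW to keep it short. `guarded_step` is a hypothetical helper: it skips the update (keeping the previous parameters) when any gradient entry is non-finite, and otherwise applies global-norm clipping before the step.

```python
import numpy as np

def guarded_step(params, grads, lr, max_norm=1.0):
    """Hypothetical helper: clip the global gradient norm, and skip the step
    (keeping the previous parameters) if any gradient entry is non-finite.

    Plain SGD stands in for AdamW. Returns (new_params, applied).
    """
    flat = np.concatenate([g.ravel() for g in grads])
    if not np.all(np.isfinite(flat)):
        return [p.copy() for p in params], False    # roll back: state untouched
    norm = np.linalg.norm(flat)
    factor = min(1.0, max_norm / (norm + 1e-6))     # global-norm clipping
    return [p - lr * factor * g for p, g in zip(params, grads)], True

params = [np.ones(3)]
overflowed = [np.array([1.0, np.inf, 0.0])]         # e.g. an E5M2 overflow
print(np.inf - np.inf)                              # nan: how Inf becomes NaN
_, applied_bad = guarded_step(params, overflowed, lr=1.0)
new_params, applied_ok = guarded_step(params, [np.array([3.0, 4.0, 0.0])], lr=1.0)
print(applied_bad, applied_ok, new_params[0])
```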
3. Scale calibration is not stationary
The amax of a tensor changes across training. Early in training, outliers are mild; the “outlier features” phenomenon emerges around step 1k for 7B models, step 5k for 70B+. A scale calibrated on the first 100 steps is wrong by step 10k. The standard solution is “delayed scaling”: the scale used at step $t$ is computed from the amax history of steps $t - W, \dots, t - 1$ for some window $W$, plus a small safety margin. This is a hyperparameter you have to tune.
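A minimal sketch of delayed scaling, using the text's convention $s = \text{amax} / 448$ and a $2^{\text{margin}}$ safety factor. `DelayedScaler` is a hypothetical helper, not TransformerEngine's API; it also shows the failure case that the margin and window have to be tuned against, a sudden amax spike.

```python
from collections import deque

class DelayedScaler:
    """Hypothetical delayed-scaling helper (a sketch, not TransformerEngine's API).

    The scale used at step t comes from the amax values recorded at steps
    t - W .. t - 1: s = max(history) / x_max * 2**margin.
    """
    def __init__(self, window=16, margin=0, x_max=448.0, init_amax=1.0):
        self.history = deque([init_amax], maxlen=window)
        self.margin = margin
        self.x_max = x_max

    def scale(self):
        return max(self.history) / self.x_max * 2.0 ** self.margin

    def record(self, amax):
        # Called after the step's cast: the new amax only affects later steps.
        self.history.append(float(amax))

scaler = DelayedScaler(window=4, margin=1)
saturated = []
for amax in [1.5, 2.0, 50.0, 2.5, 3.0]:      # a sudden spike at step 2
    s = scaler.scale()
    saturated.append(amax / s > scaler.x_max)  # does this step's max clip?
    scaler.record(amax)
print(saturated)   # [False, False, True, False, False]
```

The spike is cast with a scale computed from stale history and saturates; afterwards the window keeps the scale inflated for up to $W$ steps, squeezing the small values of the following tensors toward underflow.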
4. Communication still happens in BF16
FP8 gives you 2× compute and 2× bandwidth on the GPU. But cross-GPU collectives (all-reduce in DP, all-gather in TP) typically stay in BF16 or FP32, because the cumulative numerical error of an FP8 all-reduce across 256+ ranks is much worse than the intra-GPU error this chapter has been worrying about. DeepSeek-V3 is selective here: it quantises activations to FP8 before the MoE dispatch, but keeps the combine communication in BF16 to preserve precision, even though the on-GPU compute is FP8.
5. Determinism is harder
BF16 training can be made bit-exact deterministic with care (fixed random seeds, deterministic kernels, no atomic accumulation). FP8 training generally cannot, because the amax history depends on scheduling, the partial-sum order matters for the 14-bit accumulator, and the cast kernels themselves frequently use atomic operations. Debugging an FP8 run that diverges differently across replays is a serious time sink.
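Why partial-sum order matters can be shown with a toy reduction. This sketch reuses the assumed accumulator model of 14 significant bits with round-toward-zero (`truncate_to_bits` and `reduce_in_order` are our own helpers): the same 1025 terms, summed in two different orders, give two different answers.

```python
import math

def truncate_to_bits(v, bits=14):
    """Assumed accumulator model: keep `bits` significant bits, round toward zero."""
    m, e = math.frexp(v)
    return math.ldexp(math.trunc(m * 2 ** bits) / 2 ** bits, e)

def reduce_in_order(terms, bits=14):
    acc = 0.0
    for t in terms:
        acc = truncate_to_bits(acc + t, bits)
    return acc

tiny = [2.0 ** -14] * 1024                   # sums exactly to 1/16
big_first = reduce_in_order([1.0] + tiny)    # each tiny term is absorbed
tiny_first = reduce_in_order(tiny + [1.0])   # tiny terms combine first
print(big_first, tiny_first)                 # 1.0 1.0625
```

A scheduler that reorders the reduction between runs therefore changes the result, even with identical inputs and seeds.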