Chapter 10
Section 55 of 117

Why Naive FP8 Training Fails

FP8 Mixed-Precision Training

The Real Problem: Why the Naive Port Diverges

The previous section built the case for FP8: half the memory bandwidth of BF16, twice the Tensor Core throughput on Hopper, and a clear path to training a frontier-scale model with proportionally less hardware. The arithmetic is irresistible. So why does the naive port — pick an FP8 format, cast every weight, activation, and gradient into it, re-run the training loop — reliably diverge inside a few hundred steps?

Three teams have published autopsies. NVIDIA's FP8 formats paper (Micikevicius et al., 2022) shows the loss leaving the BF16 reference curve after step ~200 on GPT-3 scale runs. Meta's OPT-IML team observed the same. DeepSeek-V3's technical report devotes an entire section (the one we will unpack mathematically across the rest of this chapter) to the specific failure modes they had to fix before FP8 training of a 671B-parameter MoE became stable.

The pattern across all three reports is the same. FP8 training is not BF16 training with a smaller dtype. It is BF16 training with a number format that has fundamentally less dynamic range and fundamentally less precision per binade, and every kernel in the stack has to be re-engineered to compensate. The naive port silently violates four separate numerical assumptions BF16 kernels were designed around. Each violation is recoverable in isolation. Together, they make the loss diverge.

The thesis of this section: there are exactly four numerical failure modes that make naive FP8 training unstable. By the end of this section you should be able to name them, draw the distribution that triggers each one, and explain why each one needs a different mitigation. The next four sections of the chapter (3.3–3.6) are simply the mitigations.

Intuition: A Ruler That Is Too Short and Too Coarse

Imagine you have to measure every distance in your house with a ruler. A BF16 ruler has 7 mantissa bits of resolution (about 128 marks per power-of-two interval) and 8 exponent bits, so it can measure anything from sub-atomic to astronomical without changing rulers. An FP8 E4M3 ruler has 3 mantissa bits (8 marks per binade) and 4 exponent bits (a normal range of just $2^{-6}$ to 448). It is a ruler that is both shorter and coarser than BF16.

Two consequences fall out immediately. One: if the thing you want to measure is bigger than 448, your ruler clips. You either saturate (the value you store back is 448) or you get a NaN. In either case the gradient signal at that location is destroyed. Two: if the thing is smaller than $2^{-6} \approx 0.016$, your ruler underflows. You write back literal zero, and any subsequent multiply involving that location is also zero. The gradient is not just noisy; it has been deleted.

Crucially, neither clipping nor underflow is zero-mean. Both are systematic. Repeated across millions of elements every step, they accumulate into a directional bias in the gradient that no optimiser is designed to tolerate. AdamW has no immune response to “5% of my gradients are systematically pushing me the wrong way every step.” The loss spike that eventually shows up at step 200 is not a flake; it is the integral of three hundred thousand small biases.

[Figure: dynamic-range chart for E4M3 and E5M2]

The picture above makes the “short, coarse ruler” intuition concrete. Slide the orange marker. Notice that for x between roughly $2^{-6}$ and 448, E4M3 has a representable tick within about 6% of any input. Move the marker below $2^{-6}$ and it falls into the grey underflow band. Move it above 448 and it saturates. E5M2 has the same problem on the precision axis (only 4 ticks per binade) but far more range. Both formats are essentially BF16 with half of its bits chopped off, and the chop is non-uniform: E4M3 gives up 4 exponent bits and 4 mantissa bits, while E5M2 gives up 3 exponent bits and 5 mantissa bits.

The Mathematics of FP8 Quantisation

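A floating-point number with $e$ exponent bits and $m$ mantissa bits represents a positive normal value as $x = (1 + f) \cdot 2^{E - b}$. Here the biased exponent is $E \in \{1, \ldots, 2^e - 2\}$, the mantissa fraction is $f = k / 2^m$ for integer $k \in \{0, \ldots, 2^m - 1\}$, and the bias is $b = 2^{e-1} - 1$. The all-ones exponent $2^e - 1$ is reserved for infinities and NaN. The all-zeros exponent encodes denormals (which Hopper's FP8 path flushes to zero for throughput). E5M2 and BF16 follow this layout exactly. E4M3 is the exception. The variant used for FP8 training (PyTorch's torch.float8_e4m3fn, where “fn” means finite) has no infinities and reserves only the all-ones pattern S.1111.111 for NaN. The all-ones exponent therefore also carries normal values, which raises the maximum from 240 to 448.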
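The layout above can be turned into a small decoder. The sketch below uses a hypothetical helper, decode_fp8, which is not a library function. It assumes the sign | exponent | mantissa packing described in the text and adds an fn flag for the E4M3 variant used in training. That variant has no infinities and treats only the pattern with every exponent and mantissa bit set as NaN. Subnormals are decoded the way the format defines them, even though the text treats them as flushed to zero on Hopper's FP8 path.

```python
def decode_fp8(byte, e, m, fn=False):
    """Decode an 8-bit pattern laid out as sign | e exponent bits | m mantissa bits.

    fn=True models the E4M3 'fn' variant: no infinities, and only the pattern
    with all exponent and mantissa bits set is NaN, so the top exponent
    still encodes normal values.
    """
    bias = 2 ** (e - 1) - 1
    sign = -1.0 if (byte >> (e + m)) & 1 else 1.0
    E = (byte >> m) & ((1 << e) - 1)          # biased exponent field
    k = byte & ((1 << m) - 1)                 # mantissa field
    f = k / 2 ** m
    if E == (1 << e) - 1:                     # all-ones exponent
        if not fn:
            return sign * float("inf") if k == 0 else float("nan")
        if k == (1 << m) - 1:
            return float("nan")
        # fn variant: fall through and decode as an ordinary normal value
    if E == 0:                                # subnormal: no implicit leading 1
        return sign * f * 2.0 ** (1 - bias)
    return sign * (1 + f) * 2.0 ** (E - bias)


print(decode_fp8(0b0_1111_110, e=4, m=3, fn=True))   # largest E4M3 normal
print(decode_fp8(0b0_0001_000, e=4, m=3, fn=True))   # smallest E4M3 normal
print(decode_fp8(0b0_11110_11, e=5, m=2))            # largest E5M2 normal
```

Decoding 0b0_1110_111 without the fn flag gives 240, the largest value E4M3 would have under the plain IEEE-style layout.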

Three derived quantities matter for everything that follows:

  1. Range $[x_{\min}, x_{\max}] = [2^{1 - b}, (2 - 2^{-m}) \cdot 2^{2^e - 2 - b}]$. Anything outside saturates or underflows. For E4M3, whose top exponent also holds normal values, the upper end becomes $x_{\max} = (2 - 2^{1-m}) \cdot 2^{2^e - 1 - b} = 1.75 \cdot 2^{8} = 448$.
  2. Worst-case relative quantisation error $\epsilon_{\text{rel}} = 2^{-(m + 1)}$ for inputs in range. With $m = 3$ (E4M3) this is 1/16 = 6.25%. With $m = 2$ (E5M2) it is 12.5%. BF16's $m = 7$ gives 0.39%.
  3. Number of representable positive normal values $(2^e - 2) \cdot 2^m$. E5M2: 30 × 4 = 120. BF16: 254 × 128 = 32 512. E4M3 again departs from the formula. Its top exponent adds 8 codes, one of which is NaN, so it has 15 × 8 - 1 = 119.
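The three derived quantities can be checked mechanically. The helper below, fp_properties (a hypothetical name), is a sketch that evaluates the formulas for E4M3, E5M2 and BF16. An fn flag covers E4M3's reclaimed top exponent. The helper also reports how many decades each normal range spans.

```python
import math


def fp_properties(e, m, fn=False):
    """Return (x_min, x_max, eps_rel, n_normals, decades) for positive normals.

    fn=True applies the E4M3 'fn' exception: the top exponent holds normal
    values and only its all-ones mantissa code is NaN.
    """
    b = 2 ** (e - 1) - 1
    x_min = 2.0 ** (1 - b)
    if fn:
        x_max = (2 - 2.0 ** (1 - m)) * 2.0 ** (2 ** e - 1 - b)
        n_normals = (2 ** e - 1) * 2 ** m - 1
    else:
        x_max = (2 - 2.0 ** -m) * 2.0 ** (2 ** e - 2 - b)
        n_normals = (2 ** e - 2) * 2 ** m
    eps_rel = 2.0 ** -(m + 1)                  # half a tick at the bottom of a binade
    decades = math.log10(x_max / x_min)
    return x_min, x_max, eps_rel, n_normals, decades


for name, e, m, fn in [("E4M3", 4, 3, True), ("E5M2", 5, 2, False), ("BF16", 8, 7, False)]:
    x_min, x_max, eps, n, dec = fp_properties(e, m, fn)
    print(f"{name}: [{x_min:.3g}, {x_max:.4g}]  eps_rel={eps:.2%}  "
          f"normals={n}  decades={dec:.1f}")
```

For E4M3 this gives $[2^{-6}, 448]$, 119 normals and about 4.5 decades. Without the fn flag, the same formulas would cap E4M3 at 240.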

The cast with scale is the operation that does all of the actual work in an FP8 kernel. Given an input $x$ and a scale $s > 0$, we define:

$$\hat x = s \cdot Q\!\left( \frac{x}{s} \right), \quad Q(y) = \arg\min_{g \in \mathcal{G}} |y - g|$$

where $\mathcal{G}$ is the set of representable FP8 values plus a saturation point at $\pm x_{\max}$. The choice of $s$ determines whether the input lands inside the dynamic range or outside it, and how much of the input's information survives the round-trip.

The standard recipe is $s = \text{amax}(x) / x_{\max}$: pick the scale that makes the largest-magnitude element land exactly at the top of the FP8 range. Two observations about this choice.

The amax recipe is optimal only when the magnitude distribution of xx is narrow. If amax is dominated by a handful of outliers, the scale chosen to fit them squashes the bulk of the distribution into the lowest few ticks of the grid — or below the minimum normal entirely. This is failure mode 2.

The quantisation error of the round-trip is bounded by $|\hat x_i - x_i| \leq \epsilon_{\text{rel}} \cdot |x_i|$ for $x_i$ in range, and by $|x_i|$ itself for any $x_i$ that underflows. The first bound is the precision floor; the second bound is the dynamic-range cliff. Failure modes 1 and 2 live in the gap between “in range” and “below floor”.
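Combining the amax recipe with the underflow cliff gives a quick test for which elements a per-tensor cast deletes. Under the flush-to-zero model used in this section, $x_i$ underflows exactly when $|x_i| / s < x_{\min}$. For E4M3 ($x_{\min} = 2^{-6}$, $x_{\max} = 448$) that condition becomes:

$$|x_i| \;<\; s\,x_{\min} \;=\; \text{amax}(x)\,\frac{x_{\min}}{x_{\max}} \;=\; \frac{\text{amax}(x)}{28\,672}$$

For $\text{amax}(x) = 220$ the threshold is about 0.0077. For $\text{amax}(x) = 4400$ it rises to about 0.153, which is enough to swallow any element of magnitude 0.1 or below.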

Failure 1: Dynamic-Range Collapse

Activations and gradients in a deep transformer span 5–7 orders of magnitude. A typical LLaMA-2 70B forward pass has post-LayerNorm activations in $[10^{-3}, 10^{2}]$, attention scores after softmax in $[10^{-6}, 10^{0}]$, and MLP intermediate activations that include both. Gradients are even wider: the backward pass through softmax stretches the gradient distribution by several decades.

E4M3 covers a range of $448 / 2^{-6} \approx 28\,672 \approx 10^{4.5}$, or 4.5 decades. E5M2 covers $57\,344 / 2^{-14} \approx 9.4 \cdot 10^{8}$, or 9 decades. BF16 covers about 76 decades. A single tensor of post-LayerNorm activations does not fit into E4M3's 4.5 decades. Whatever scale you pick, either the largest elements land above $x_{\max}$ or the smallest land below the minimum normal, and one scale can satisfy only one of those constraints at a time.

If you pick the scale to fit the largest elements, the smallest get crushed to zero (failure mode 1). If you pick it to preserve the smallest, the largest saturate to $\pm 448$ (which is the same magnitude of damage, applied to the values the loss is most sensitive to). The only escape is to stop using one scale per tensor.

Mitigation preview: per-tile scaling (Section 3.3). Split the tensor into small blocks and pick a scale for each block. Each block then only has to span its own much narrower local range. The next visualiser demonstrates this numerically.
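The claim that each tile only needs to span its local range can be checked on a contrived tensor. The sketch below makes three assumptions: the E4M3 envelope from earlier, the flush-to-zero underflow model used throughout this section, and a hypothetical tensor built from six 128-wide tiles whose base magnitudes step from 1e-3 to 1e2. It counts how many elements fall below the minimum normal after scaling, first with a single amax scale and then with one scale per tile.

```python
import numpy as np

E4M3_MAX, E4M3_MIN_NORMAL = 448.0, 2.0 ** -6


def count_underflows(x, scale):
    """Elements whose scaled magnitude drops below the E4M3 minimum normal."""
    return int(np.sum(np.abs(x) / scale < E4M3_MIN_NORMAL))


# Hypothetical tensor: each tile spans one decade (0.1x to 1x of its base),
# and the bases step from 1e-3 to 1e2, so the whole tensor spans six decades.
bases = [1e-3, 1e-2, 1e-1, 1.0, 1e1, 1e2]
tiles = [b * np.linspace(0.1, 1.0, 128) for b in bases]
x = np.concatenate(tiles)

per_tensor = count_underflows(x, np.abs(x).max() / E4M3_MAX)
per_tile = sum(count_underflows(t, np.abs(t).max() / E4M3_MAX) for t in tiles)
print(f"per-tensor underflows: {per_tensor}/{x.size}")
print(f"per-tile underflows:   {per_tile}/{x.size}")
```

With one scale, the whole smallest tile and part of the next are deleted. With per-tile scales nothing underflows, because each tile spans only one decade.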

Failure 2: One Outlier Kills the Tensor

Even when the bulk of a tensor lives comfortably inside FP8's 4.5 decades, a single large element can poison the whole cast. This is the “outlier feature” phenomenon documented in Dettmers et al.'s LLM.int8() work and observed in trained transformers from roughly 6.7B parameters upward: a few feature channels develop magnitudes 100–1000× larger than the bulk. They are not bugs. The model uses them for attention sinks, residual-stream gain control, and other legitimate functions. They are simply incompatible with per-tensor scaling.

Slide the outlier in the visualisation below. With per-tensor scaling, growing one element to 100–1000 immediately drives the scale up to $s = \text{outlier} / 448$. The bulk values, divided by this $s$, fall below E4M3's minimum normal of $2^{-6}$ and round to zero. You can watch the small values collapse to red dots one by one as the slider moves. Then switch the mode to per-tile and observe that the outlier is now confined to a single 8-wide block; the other three blocks are unaffected.

[Figure: outlier visualizer, per-tensor vs per-tile scaling]

The relative L2 error metric on the right makes the difference quantitative. With per-tensor scaling and an outlier of magnitude 1000, the non-outlier values incur something like 80–99% relative L2 error. With per-tile scaling the same input has ~10% error — an order of magnitude better, with no extra hardware cost (the per-tile amax is a free side-effect of the matmul epilogue).

There is a subtle point hiding in the “non-outlier” framing. Per-tile scaling does not fix the outlier's tile — the block containing the outlier still has its small siblings crushed. What it does is contain the damage. In a (1, 128) tile, at most 127 small values get crushed per outlier; in a per-tensor scheme, the entire tensor of millions of values is at risk. The damage scales like the tile size, not like the tensor size.

Failure 3: Forward and Backward Want Different Formats

Two FP8 variants exist on Hopper, and the choice between them is not cosmetic. E4M3 has 4 exponent bits and 3 mantissa bits: less range, more precision. E5M2 has 5/2: more range, less precision. Both fit in one byte and encode roughly the same number of values; they trade one resource for the other.

Forward activations and weights are typically well-behaved in a trained transformer: outliers exist, but the bulk distribution is narrow. Precision matters more than range — you want to resolve the difference between “attention sink at 0.5” and “attention sink at 0.55”. E4M3 is the right format for these tensors.
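To make the 0.5 versus 0.55 example concrete, the sketch below builds the positive normal grids of both formats and rounds each value to its nearest grid point. normal_grid and round_nearest are hypothetical helpers, not library calls. E4M3 keeps the two values apart (0.5 and 0.5625). E5M2, with only four ticks per binade, sends both to 0.5.

```python
import numpy as np


def normal_grid(e, m, fn=False):
    """Sorted positive normal values of a small binary float format.

    fn=True reuses the all-ones exponent for normal values and drops only
    its all-ones mantissa code (NaN), as in E4M3.
    """
    b = 2 ** (e - 1) - 1
    top = 2 ** e - 1 if fn else 2 ** e - 2
    vals = [(1 + k / 2 ** m) * 2.0 ** (E - b)
            for E in range(1, top + 1) for k in range(2 ** m)]
    if fn:
        vals = vals[:-1]          # the last entry generated is the NaN encoding
    return np.array(sorted(vals))


def round_nearest(v, grid):
    """Round a positive, in-range value to the nearest grid point."""
    return float(grid[np.argmin(np.abs(grid - v))])


E4M3 = normal_grid(4, 3, fn=True)
E5M2 = normal_grid(5, 2)
for v in (0.5, 0.55):
    print(f"{v}: E4M3 -> {round_nearest(v, E4M3)}, E5M2 -> {round_nearest(v, E5M2)}")
```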

Backward gradients are not well-behaved. The chain rule passes through softmax (which compresses the gradient distribution into [0, 1]), through LayerNorm (which can amplify it by a factor proportional to the layer width), and through GeLU/SiLU (which is smooth but heavy-tailed under the gradient). The result is a gradient distribution with much wider dynamic range than the activations: you can have $10^{-7}$ and $10^{1}$ in the same tensor. Range matters more than precision; E5M2 is the right format.

Quantity              Format   Range (normal values)   Per-binade ticks   Reason
Forward activations   E4M3     0.016 … 448             8                  Narrow distribution, precision-bound
Weights               E4M3     0.016 … 448             8                  Same as activations
Backward gradients    E5M2     6.1e-5 … 57 344         4                  Heavy tails, range-bound
Optimizer state       FP32     1.2e-38 … 3.4e38        8 388 608          AdamW second moment is variance; precision matters
Master weights        FP32     1.2e-38 … 3.4e38        8 388 608          Accumulated updates need full precision

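A naive FP8 implementation that uses one format for everything, with one scale per tensor, will either lose precision on forward (using E5M2 for both) or lose range on backward (using E4M3 for both). Both choices destabilise training within a few hundred steps. NVIDIA TransformerEngine's default hybrid recipe uses E4M3 for forward and E5M2 for backward. DeepSeek-V3 takes the other route: it reports using E4M3 for all tensors and relies on fine-grained per-tile scaling to supply the range that E5M2 would otherwise provide.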
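A rough way to see why range dominates for gradients: a single scale can only place a tensor inside a format if the ratio between its largest and smallest magnitudes fits inside the format's normal range. The sketch below uses a hypothetical helper, fits_one_scale, to check that ratio for the 1e-7 to 1e1 gradient span mentioned above. It ignores precision entirely and assumes values below $x_{\min}$ flush to zero.

```python
# Normal-range endpoints (x_min, x_max) for the two FP8 formats.
FORMATS = {"E4M3": (2.0 ** -6, 448.0), "E5M2": (2.0 ** -14, 57344.0)}


def fits_one_scale(smallest, largest, fmt):
    """True if some single scale maps [smallest, largest] into fmt's normal range.

    Only range is checked; precision is ignored.
    """
    x_min, x_max = FORMATS[fmt]
    return largest / smallest <= x_max / x_min


for fmt in FORMATS:
    print(f"{fmt}: gradient span 1e-7..1e1 fits -> {fits_one_scale(1e-7, 1e1, fmt)}")
```

E4M3's ratio of about 2.9e4 cannot hold an 8-decade gradient, while E5M2's ratio of about 9.4e8 can. The same check also rejects E4M3 for the 1e-3 to 1e2 activation span quoted earlier.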

Failure 4: The Accumulator Drifts

Even if you fix dynamic range, outliers, and format asymmetry, one bug remains. A matrix multiply is a reduction: each element of the output is a dot product of length $K$ (the shared inner dimension). For a 7B-class model $K \sim 4096$; for the 70B+ class it can be 16 384 or more. That reduction has to happen somewhere, and Hopper's WGMMA Tensor Core instruction does it inside a partial-sum register that holds only $\sim 14$ bits of mantissa before promotion.

Each addition of an FP8×FP8 product into that 14-bit register introduces a small rounding error. If the errors were unbiased, they would cancel and the cumulative error would grow like $O(\sqrt{K})$, which is tolerable. But they are biased: round-to-nearest in a non-symmetric grid systematically favours certain magnitudes, and the bias compounds. The cumulative error then grows like $O(K)$, which is catastrophic at K = 4096.
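The difference between $O(\sqrt{K})$ and $O(K)$ growth is easy to see with a toy model. The sketch below draws one rounding error per add, measured in units of the accumulator's last place (ulp). Round-to-nearest is modelled as a zero-mean uniform error. A biased rounding mode such as truncation is modelled as a uniform error in [-1, 0] ulp. Both error distributions and the fixed ulp size are simplifying assumptions; in a real accumulator the ulp grows with the partial sum.

```python
import numpy as np

rng = np.random.default_rng(0)
K = 4096

# One rounding error per add, in units of the accumulator's last place (ulp).
unbiased = rng.uniform(-0.5, 0.5, K)   # round-to-nearest model: zero mean
biased = rng.uniform(-1.0, 0.0, K)     # truncation-toward-zero model: mean -0.5 ulp

err_unbiased = np.cumsum(unbiased)      # random walk: spread grows like sqrt(K)
err_biased = np.cumsum(biased)          # drift: grows like K / 2

for k in (256, 1024, 4096):
    print(f"K={k:5d}  unbiased {err_unbiased[k - 1]:8.1f} ulp   "
          f"biased {err_biased[k - 1]:8.1f} ulp")
```

At K = 4096 the unbiased walk typically ends within a few tens of ulps of zero (its standard deviation is $\sqrt{K/12} \approx 18$). The biased walk lands near $K/2 = 2048$ ulps.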

[Figure: accumulation chart, 14-bit accumulator vs FP64 reference vs FP32 promotion]

The red curve in the chart above is the running sum of an FP8 dot product done entirely inside a 14-bit accumulator. The green curve is the FP64 reference. By $K = 1024$ the red curve is several percent off. The dashed blue curve is DeepSeek-V3's mitigation: empty the 14-bit accumulator into an FP32 register every $M = 128$ terms. The blue curve tracks the green curve to within a fraction of a percent for the full reduction.

The cost of the promotion is roughly $K / M$ FP32 adds per output element, about 0.8% of the reduction's multiply-adds for M = 128 and K = 16 384. The accuracy gain is two orders of magnitude in worst-case accumulation error. It is one of the cleanest cost/benefit trades in the FP8 training stack.
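Per output element, the overhead is the ratio of promotion adds to the reduction's multiply-adds, and $K$ cancels:

$$\frac{K/M \ \text{FP32 adds}}{K \ \text{multiply-adds}} \;=\; \frac{1}{M} \;=\; \frac{1}{128} \;\approx\; 0.78\%$$

Counting each multiply-add as two FLOPs halves the figure to about 0.39%. Either way, the overhead depends only on the promotion interval $M$, not on $K$.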

Manual Numerical Walkthrough

Let us cast a single small tensor by hand and watch each failure mode in action. Five values, one outlier, both per-tensor and per-tile scaling, all numbers to four decimals.

Quantising five values to E4M3 with two scaling strategies

Step 1 — the tensor. Five values, one outlier at position 2:

i      0       1       2        3       4
x_i    0.40   -0.10   220.00   0.05   -0.30

Step 2 — per-tensor scale. $\text{amax}(x) = 220$ so $s = 220 / 448 = 0.4911$. Divide each element by $s$:

i           0       1       2       3       4
x_i / s     0.815  -0.204   448.0   0.102  -0.611

Step 3 — round to the E4M3 grid. The E4M3 grid near these magnitudes contains the values $\{0.0156, 0.0176, \ldots, 0.0625, 0.0703, 0.0781, 0.0859, \ldots, 0.75, 0.8125, 0.875, 0.9375, 1.0, 1.125, \ldots\}$. Each scaled value rounds to its nearest neighbour:

i           0          1          2        3          4
x_i / s     0.815     -0.204      448.0    0.102     -0.611
Q(...)      0.8125    -0.2031     448.0    0.1016    -0.6250
Q(x)*s      0.3990    -0.0997     220.0    0.0499    -0.3069
abs error   0.0010     0.0003     0.0      0.0001    0.0069

Step 4 — check for underflow. The smallest scaled magnitude is 0.102, which is well above the minimum normal $2^{-6} \approx 0.0156$, so none underflow. Good. But notice that $x_4 = -0.30$ incurred about 2% relative error. The precision loss is small but not zero, and on a tensor of millions of elements those 2% errors do not cancel.
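Steps 2 to 4 can be reproduced with a few lines of NumPy. The sketch assumes the flush-to-zero model used in this section (anything below $2^{-6}$ after scaling becomes zero), saturation at 448, and plain round-to-nearest with no special tie-breaking rule. GRID and q_e4m3 are hypothetical names. The script prints each element's reconstruction and absolute error under the per-tensor scale s = 220 / 448.

```python
import numpy as np

E4M3_MAX, E4M3_MIN_NORMAL = 448.0, 2.0 ** -6

# Positive normal E4M3 values: exponents 1..15, minus the final NaN code.
GRID = np.array([(1 + k / 8) * 2.0 ** (E - 7)
                 for E in range(1, 16) for k in range(8)][:-1])


def q_e4m3(v):
    """Nearest E4M3 normal; saturate at 448; flush below 2^-6 to zero."""
    a = min(abs(v), E4M3_MAX)
    if a < E4M3_MIN_NORMAL:
        return 0.0
    return float(np.copysign(GRID[np.argmin(np.abs(GRID - a))], v))


x = [0.40, -0.10, 220.00, 0.05, -0.30]
s = max(abs(v) for v in x) / E4M3_MAX        # per-tensor amax scale, about 0.4911
x_hat = [q_e4m3(v / s) * s for v in x]
for i, (v, h) in enumerate(zip(x, x_hat)):
    print(f"x{i} = {v:7.2f}   Q(x/s)*s = {h:9.4f}   abs error = {abs(h - v):.4f}")
```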

Step 5 — now make the outlier bigger. Replace 220 with 4400. The new scale is $s = 4400 / 448 = 9.821$. Each element divided by $s$:

i           0        1        2       3        4
x_i / s     0.0407  -0.0102   448.0   0.0051  -0.0305

Step 6 — underflow strikes. $|0.0051|$ is below the E4M3 minimum normal $2^{-6} \approx 0.0156$, so $Q(0.0051) = 0$. The reconstructed value is $0 \cdot s = 0$. Element 3 has been deleted: no gradient through position 3, no learning signal, no recovery. Position 1 fares no better. $|0.0102|$ is also below $2^{-6}$, so it is flushed to zero as well. Only positions 0 and 4 survive. They land on the grid points 0.0391 and 0.03125 and reconstruct to 0.3836 and -0.3069, about 4% and 2% error respectively.

Step 7 — switch to per-tile. Tile of size 3: $\{x_0, x_1, x_2\}$ and $\{x_3, x_4\}$. The first tile has $\text{amax} = 4400$ and scale 9.821, so the outlier's tile is just as bad as before. But the second tile has $\text{amax} = 0.30$ and scale $s = 0.30 / 448 = 6.7 \cdot 10^{-4}$. Divide the second tile by its scale:

i           3       4
x_i / s     74.67  -448.0
Q(...)      72.0   -448.0
Q(x) * s    0.0482 -0.30
abs error   0.0018  0.0

Step 8 — tally. Per-tensor with outlier 4400: two elements deleted (positions 1 and 3), and the two survivors off by about 4% and 2%. Per-tile: the outlier's tile still loses position 1, but the non-outlier tile keeps both of its values, with a worst relative error of about 4% (0.0018 on 0.05). Same data, same hardware, same FP8 grid; only the scaling strategy is different. At production tile size, per-tile scaling costs one FP32 scale (4 bytes) per 128 one-byte FP8 values, about 3% on top of the FP8 payload. In exchange, small values outside the outlier's tile stop being deleted and are kept to within a few percent.

Step 9 — the lesson generalises. Two of the four failure modes (range collapse and outliers) are fixed by per-tile scaling alone. The remaining two (format asymmetry and accumulator drift) need separate mitigations, but they do not interact with the scaling choice. Section 3.3 walks through the per-tile scaling recipe end-to-end; Section 3.4 derives the FP32 accumulator promotion schedule.
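The same sketch extends to Steps 5 through 8. It uses the same assumptions: flush to zero below $2^{-6}$, saturation at 448, nearest rounding, and hypothetical helper names. The script casts the five values twice, once with a single per-tensor scale and once with the two tiles from Step 7.

```python
import numpy as np

E4M3_MAX, E4M3_MIN_NORMAL = 448.0, 2.0 ** -6
GRID = np.array([(1 + k / 8) * 2.0 ** (E - 7)
                 for E in range(1, 16) for k in range(8)][:-1])


def q_e4m3(v):
    """Nearest E4M3 normal; saturate at 448; flush below 2^-6 to zero."""
    a = min(abs(v), E4M3_MAX)
    if a < E4M3_MIN_NORMAL:
        return 0.0
    return float(np.copysign(GRID[np.argmin(np.abs(GRID - a))], v))


def amax_cast(values):
    """Round-trip one group of values through E4M3 with a single amax scale."""
    s = max(abs(v) for v in values) / E4M3_MAX
    return [q_e4m3(v / s) * s for v in values]


x = [0.40, -0.10, 4400.0, 0.05, -0.30]
per_tensor = amax_cast(x)
per_tile = amax_cast(x[:3]) + amax_cast(x[3:])   # tiles {x0, x1, x2} and {x3, x4}

for i, (v, t, p) in enumerate(zip(x, per_tensor, per_tile)):
    print(f"x{i} = {v:8.2f}   per-tensor {t:10.4f}   per-tile {p:10.4f}")
```

Under this model the per-tensor cast zeroes both x1 and x3, since 0.0102 and 0.0051 are each below $2^{-6}$ after scaling. The per-tile cast still loses x1 inside the outlier's tile but keeps x3 at about 0.0482.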

Plain Python: Simulating the Four Failure Modes

The following NumPy script implements the toy E4M3 quantiser, the per-tensor vs per-tile cast comparison, and the 14-bit accumulator drift simulation. It runs in under a second on a laptop and produces numbers you can compare against the visualisations above.

Simulating FP8 quantisation, outlier collapse, and accumulator drift
E4M3 numeric envelope

E4M3 has 4 exponent bits (bias 7) and 3 mantissa bits, so positive normals live in [2^-6, 2^8 * (1 + 6/8)] = [0.0156, 448]. The code one step higher, 2^8 * (1 + 7/8), is the NaN encoding. We hardcode the two endpoints because every E4M3 cast must clamp against the maximum (saturation) and round to zero below the minimum normal (underflow).

E4M3_MAX = 448.0
E4M3_MIN_NORMAL = 2^-6 ≈ 0.0156

Enumerate the representable values

There are exactly 15 × 8 - 1 = 119 positive normal E4M3 values. The top exponent still carries normal values, and only its all-ones mantissa code is reserved for NaN. We pre-compute them into a sorted array so the quantiser is a single binary search. Real hardware rounds by operating on the mantissa bits directly. The array makes the algorithm easier to read, and the result is the same apart from how ties are broken.

len(E4M3_GRID) = 119
E4M3_GRID[:5] = [0.0156, 0.0176, 0.0195, 0.0215, 0.0234]
E4M3_GRID[-3:] = [384.0, 416.0, 448.0]

The quantise step, vectorised

to_e4m3 is the workhorse. The sign is peeled off and the magnitude is clamped to E4M3_MAX, so overflow saturates rather than producing NaN. Anything below the minimum normal is marked for flushing to zero (underflow). searchsorted finds where the magnitude would slot into the sorted grid; we then pick whichever of the two neighbours is closer and finally zero out the flushed entries. Ties go to the smaller neighbour. That is a simplification: hardware conversions use round-to-nearest-even so that ties carry no bias.

input x = [0.5, 250.0, 0.008, -0.31]
after clip+underflow = [0.5, 250.0, 0.0, 0.31]
to_e4m3 output = [0.5, 256.0, 0.0, -0.3125]

The full cast: scale, quantise, unscale

Real FP8 training never stores raw tensors; it stores (q, s) pairs where q is FP8 and s is a small FP32 scaling factor. cast_with_scale models the full round-trip: divide by s, quantise to E4M3, multiply back. The size of s controls where the input lands inside the E4M3 grid; choosing s well is the entire game.

Build the pathological tensor

We draw 128 N(0, 0.5) values (magnitudes mostly below 1.5) and overwrite position 63 with an outlier, OUTLIER = 250. This pattern, a handful of feature channels with magnitudes hundreds of times the bulk distribution, is empirically observed in trained transformers; see Dettmers et al. 2022 and Xiao et al. 2023.

x[63] = 250.0 (outlier)
np.abs(x[:63]).mean() = ~0.40
np.abs(x).max() = 250.0

Per-tensor amax scaling chooses the outlier

The 'standard' recipe sets s = amax(x) / E4M3_MAX so the largest magnitude lands at the top of the FP8 range. With amax = 250 and E4M3_MAX = 448 we get s ≈ 0.558. Every other value must be divided by 0.558 before rounding. A typical value of 0.4 / 0.558 ≈ 0.72 is still far above the minimum normal, so the small values mostly survive with E4M3's usual few-percent rounding error. The damage is confined to the lower tail: any |x_i| below s · 2^-6 ≈ 0.0087 underflows to zero.

amax = 250.0
s_per_tensor = 250 / 448 ≈ 0.558

Count the dead values and the L2 error

report computes two metrics over the 127 non-outlier values: (1) how many rounded to literal zero, and (2) their relative L2 error. Scoring only the bulk matters. Over the whole tensor, the norm is dominated by the outlier, which lands exactly on 448 and comes back almost unchanged, so the error would look negligible. For an outlier of 250 you should see roughly 0–5 zeros and an L2 error of 1–4%, which is E4M3's ordinary precision floor. Now set OUTLIER = 1e4 instead. The scale becomes about 22.3, every value below 22.3 · 2^-6 ≈ 0.35 (roughly half of the bulk) rounds to zero, and the tensor starts to look like a one-hot.

zeroed (this run) = 0-5 depending on draw
rel_err (this run) = 0.01-0.04

Per-tile scaling: a separate scale every 16 values

Per-tile scaling treats the tensor as a stack of small blocks and gives each block its own s. The block containing the outlier (indices 48–63) still has s ≈ 0.558, so its small siblings face the same underflow threshold as before. Every other block sees s ≈ amax_block / 448 ≈ 0.003, a scale roughly 200× smaller. A value of 0.4 becomes 0.4 / 0.003 ≈ 133, well inside the E4M3 grid, and the underflow threshold drops to about 0.003 · 2^-6 ≈ 5 · 10^-5. Only one block out of eight is exposed to the outlier; the others stay healthy.

tile = 16
s_per_tile ≈ [0.003, 0.003, 0.003, 0.558, 0.003, 0.003, 0.003, 0.003]

round_to_bits: simulate a fixed-width accumulator

A hardware accumulator has a fixed mantissa width; every addition is followed by a round-to-bits step. We model that in three steps: extract the exponent (floor(log2(|v|))), scale the fractional mantissa by 2^bits, and then either round it to nearest-even or, with truncate=True, chop it toward zero. With bits=14 we are simulating Hopper's WGMMA partial sum; with bits=23 we are simulating an FP32 register.

Three accumulators running in lock-step

exact is an FP64 dot product of the same FP8 inputs, so the comparison isolates accumulation error from input quantisation error. s_naive keeps the whole reduction inside the ~14-bit partial sum, truncating it to 14 mantissa bits after every add. s_promo behaves identically until k+1 is divisible by 128. At that point it dumps the partial sum into a high-precision FP32 register (round_to_bits(_, 23)) and resets the 14-bit accumulator to zero. That dump is DeepSeek-V3's high-precision accumulation trick.

K = 1024
promote every = 128 terms
partial sum mantissa width = 14 bits
FP32 register mantissa width = 23 bits

The product is FP8 × FP8

Each multiplication takes two FP8 operands and produces a higher-precision product. Real Tensor Cores keep enough bits inside the multiplier to lose nothing here, and the simulator does the same: two 4-bit E4M3 significands multiply to at most 8 significant bits, which float64 holds exactly. The lossy step is the next line, where the product is added into the 14-bit accumulator.

The accumulation round is where precision dies

Each truncation in round_to_bits(_, 14, truncate=True) is small, under 1 part in 16 384 of the partial sum. But it always points toward zero, so the errors do not cancel. With unbiased rounding, the standard deviation of the cumulative error grows like sqrt(K); with biased rounding, the error itself grows like K. After K = 1024, even unbiased rounding leaves an error of order sqrt(1024) = 32 times a single rounding error. With biased rounding the drift can reach 1–3% of the magnitude. Truncation is this simulator's stand-in for biased accumulation, not a statement about the exact rounding mode inside WGMMA.

The promotion: empty the small accumulator into a big one

Every 128 terms, we take the current 14-bit partial sum, add it to a separate FP32 register, and reset the partial sum to zero. The 14-bit accumulator now starts each block from scratch, so it never has to represent a sum that has grown across thousands of terms. The FP32 register stores the long-running history at full precision. The cost is one extra add per 128 multiplies, well under 1% of the matmul's FLOPs.

import numpy as np

# ----------------------------------------------------------------------
# 1. A toy E4M3 quantiser. 4 exponent bits, 3 mantissa bits, max 448.
# ----------------------------------------------------------------------
E4M3_MAX = 448.0
E4M3_MIN_NORMAL = 2 ** -6  # ~ 0.0156

# All positive normal E4M3 values, sorted ascending.
def _enumerate_e4m3():
    bias = 7
    vals = []
    for e in range(1, 16):            # E4M3 'fn': the top exponent holds normals too
        for m in range(8):            # 2^3 mantissa codes
            if e == 15 and m == 7:    # ...except S.1111.111, the NaN encoding
                continue
            vals.append(2.0 ** (e - bias) * (1 + m / 8))
    return np.array(sorted(vals))

E4M3_GRID = _enumerate_e4m3()         # 119 values, 2^-6 ... 448

def to_e4m3(x):
    """Round x (any shape, any sign) to nearest E4M3; saturate at 448 and
    flush magnitudes below the minimum normal to zero."""
    x = np.asarray(x, dtype=np.float64)
    sign = np.sign(x)
    a = np.clip(np.abs(x), 0.0, E4M3_MAX)        # overflow -> saturate
    underflow = a < E4M3_MIN_NORMAL               # underflow -> zero (applied below)
    # vectorised nearest-grid-point lookup
    idx = np.searchsorted(E4M3_GRID, a)
    idx = np.clip(idx, 1, len(E4M3_GRID) - 1)
    lo = E4M3_GRID[idx - 1]
    hi = E4M3_GRID[idx]
    pick_hi = (hi - a) < (a - lo)                 # ties go to lo (simplification)
    q = np.where(pick_hi, hi, lo)
    return sign * np.where(underflow, 0.0, q)

def cast_with_scale(x, scale):
    """The standard FP8 cast: divide by scale, quantise, multiply back."""
    return to_e4m3(x / scale) * scale

# ----------------------------------------------------------------------
# 2. Failure 1+2: one outlier + per-tensor amax scaling.
# ----------------------------------------------------------------------
OUTLIER = 250.0                       # try 1e4 to watch the bulk collapse
rng = np.random.default_rng(0)
x = rng.normal(0, 0.5, size=128).astype(np.float32)
x[63] = OUTLIER
bulk = np.arange(x.size) != 63        # score only the non-outlier values

def report(name, x_hat):
    zeroed = int(np.sum((x_hat[bulk] == 0) & (np.abs(x[bulk]) > 1e-6)))
    err = np.linalg.norm(x_hat[bulk] - x[bulk]) / np.linalg.norm(x[bulk])
    print(f"{name}: crushed={zeroed}/{bulk.sum()}, rel_err={err:.3f}")

amax = np.abs(x).max()
s_per_tensor = amax / E4M3_MAX
report(f"per-tensor (scale={s_per_tensor:.4f})", cast_with_scale(x, s_per_tensor))

# Per-tile scaling: 16-wide tiles, one scale each.
tile = 16
s_per_tile = np.array([np.abs(x[i:i + tile]).max() / E4M3_MAX
                       for i in range(0, len(x), tile)])
x_hat_tile = np.concatenate([
    cast_with_scale(x[i:i + tile], s_per_tile[i // tile])
    for i in range(0, len(x), tile)
])
report("per-tile", x_hat_tile)

# ----------------------------------------------------------------------
# 3. Failure 4: the accumulator. Keep only ~14 mantissa bits per add.
# ----------------------------------------------------------------------
def round_to_bits(v, bits, truncate=False):
    """Keep `bits` fractional mantissa bits of v: round-to-nearest-even,
    or truncate toward zero when truncate=True."""
    if v == 0 or not np.isfinite(v):
        return v
    sign = np.sign(v)
    av = abs(v)
    base = 2.0 ** np.floor(np.log2(av))
    steps = 2 ** bits
    frac = (av / base - 1) * steps
    m = np.floor(frac) if truncate else np.round(frac)
    return float(sign * base * (1 + m / steps))

K = 1024
a = to_e4m3(rng.normal(0, 1, K))      # FP8 operands (N(0, 1) sits inside E4M3's range)
b = to_e4m3(rng.normal(0, 1, K))

exact = float(np.dot(a, b))           # FP64 reference on the same FP8 inputs
s_naive = 0.0
s_promo = 0.0
s_promo_hi = 0.0
for k in range(K):
    p = float(a[k] * b[k])            # E4M3 x E4M3 product: exact in float64
    # Assumption: model the ~14-bit partial sum as truncating (biased) adds.
    s_naive = round_to_bits(s_naive + p, 14, truncate=True)
    s_promo = round_to_bits(s_promo + p, 14, truncate=True)
    if (k + 1) % 128 == 0:            # DeepSeek-V3-style promotion
        s_promo_hi = round_to_bits(s_promo_hi + s_promo, 23)
        s_promo = 0.0

promoted = s_promo_hi + s_promo
print(f"exact    = {exact:.4f}")
print(f"naive    = {s_naive:.4f}   ({100 * abs(s_naive - exact) / abs(exact):.2f}% off)")
print(f"promoted = {promoted:.4f}   ({100 * abs(promoted - exact) / abs(exact):.2f}% off)")

The exact numbers depend on the random draw, so instead of quoting one run, here is what to look for. With OUTLIER = 250, the per-tensor and per-tile errors come out close. The outlier-to-bulk ratio (a few hundred) sits well inside E4M3's ~28 672× range, so per-tensor scaling deletes only the few values below about 0.0087, and both casts sit near E4M3's precision floor. Rerun with OUTLIER = 1e4 and the per-tensor error jumps as roughly half of the bulk underflows, while per-tile scaling confines that damage to the one tile holding the outlier. On the accumulator side, the truncating 14-bit sum drifts away from the reference over the 1024-term reduction, and promoting into FP32 every 128 terms should shrink the drift substantially. DeepSeek-V3 reports the same effect at production scale: with accumulation limited to roughly 14 bits, a K = 4096 FP8 GEMM showed a maximum relative error of nearly 2%.

PyTorch: What torch.float8 Actually Does

PyTorch core exposes E4M3 and E5M2 as storage dtypes, and torch._scaled_mm is the lowest-level FP8 matmul entry point. The full per-tile cast lives in higher-level libraries (TransformerEngine, torchao's float8 module) but the building blocks are part of the standard distribution. The next listing implements a per-row, per-128-tile activation cast in 30 lines of torch.

Per-tile FP8 cast in PyTorch
float8_e4m3fn: the storage type

PyTorch exposes E4M3 as torch.float8_e4m3fn (the 'fn' suffix means finite: no infinities). It is a storage dtype: ordinary ops such as torch.matmul do not accept two float8 tensors. The mainstream FP8 path instead hands the FP8 operands and their FP32 scales to a dedicated scaled-matmul kernel (torch._scaled_mm, or TransformerEngine's GEMMs). That kernel runs on the FP8 Tensor Cores and writes a higher-precision output.

torch.float8_e4m3fn.itemsize = 1 (byte)
torch.float8_e5m2.itemsize = 1 (byte), commonly used for gradients

Per-row, per-128-tile scaling shape

DeepSeek-V3 uses 1×128 tiles for activations and 128×128 tiles for weights. The activation tensor has shape (M, N) where N is the inner reduction dimension; we slice N into N/128 tiles per row. The scale tensor has shape (M, N/128): one FP32 value per 128 one-byte elements, about 3% on top of the FP8 payload.

x.shape = (64, 1024)
tile = 128
scale.shape = (64, 8)

Reshape exposes the tile axis

Reshaping (M, N) into (M, N//tile, tile) makes the per-tile amax a single reduce. For a contiguous input PyTorch will not copy data here. The reshape is a view, because we are just splitting one contiguous dim into two contiguous dims of the same total size.

x_tiles.shape = (64, 8, 128)
x_tiles.is_contiguous() = True

amax in BF16 → scale in FP32

amax is computed in BF16 (the input dtype). This is exact, because the maximum of a set of BF16 values is itself one of those values. Promoting to FP32 before dividing by 448 keeps the scale at full precision; the scale itself is never quantised. The clamp(min=1e-12) prevents divide-by-zero in the (very rare) case that an entire tile is exactly zero.

amax[0, 3] = 200.0 (the outlier's tile); the other entries of amax[0] are around 0.8–0.9
scale[0, 3] ≈ 0.446; the other entries of scale[0] are around 0.0019

The cast is the only lossy step

x_scaled is FP32 and mathematically lies in [-448, 448], up to tiny excursions from FP32 rounding in the division. Casting to float8_e4m3fn rounds each element to its nearest E4M3 value. PyTorch's cast does not saturate out-of-range values, so the listing clamps to ±448 first. Both the clamp and the cast run as vectorised kernels in PyTorch core, with no Python overhead per element.

Construct the pathological tensor

Activations drawn from N(0, 0.3) are in the 'small' magnitude range that BF16 handles trivially. We jam one outlier of 200 into position (0, 511), the last element of tile 3 in row 0. In a real transformer this would be a feature channel that has learned to act as an attention sink or carry a magnitude signal.

Per-tile relative error: roughly 1%

For this input the per-tile cast should land at roughly 1% relative L2 error or a little below, which is about what E4M3's 3-bit mantissa allows. The tile containing the outlier has scale 200/448 ≈ 0.446; the other tiles see scales near 0.002. E4M3's relative precision is the same in every binade, so in-range values round with similar relative error in both kinds of tile.

err (per-tile) ≈ 0.01

The naive per-tensor cast for comparison

Per-tensor: one scale for the whole tensor. amax is the outlier's 200, so the scale is 0.446 for every element, and all 65 535 small values are divided by 0.446 before rounding. For N(0, 0.3) data a typical scaled magnitude is still around 0.5, well above the smallest E4M3 normal. Only values below about 0.446 · 2^-6 ≈ 0.007 leave the normal range.

Per-tensor relative error: set by the outlier-to-bulk ratio

For this particular input the per-tensor error comes out close to the per-tile error. The outlier-to-bulk ratio (a few hundred) fits inside E4M3's ~28 672× dynamic range. The gap opens as the ratio approaches that range. The bulk is pushed below the minimum normal and the per-tensor error on the small values climbs toward 100%, while per-tile scaling keeps every tile except the outlier's near the precision floor. Under realistic transformer activations, that widening gap is the headline reason naive FP8 fails.

import torch
from torch import nn

# torch.float8_e4m3fn is E4M3 with no infinities: only the all-ones pattern
# is NaN, which frees the top exponent for normal values up to 448 (the
# variant Hopper uses for forward casts). E5M2 is also exposed.
# Both are storage dtypes: you cannot run torch.matmul on them directly.

# Per-tile cast in the style of DeepSeek-V3's (1, 128) activation tiles.
def cast_to_e4m3_per_tile(x_bf16: torch.Tensor, tile: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
    """Cast (M, N) BF16 activations to FP8 E4M3 with per-row, per-128-tile scales.

    Returns (q, scale) where q is float8_e4m3fn and scale is FP32 of shape
    (M, N // tile). The reconstruction is q.to(torch.float32) * scale.repeat_interleave(tile, -1).
    """
    assert x_bf16.dtype == torch.bfloat16
    M, N = x_bf16.shape
    assert N % tile == 0, "Inner dim must be tile-multiple"

    # Reshape to (M, N//tile, tile); take amax over the tile axis.
    x_tiles = x_bf16.reshape(M, N // tile, tile)
    amax = x_tiles.abs().amax(dim=-1)                # (M, N//tile)
    E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    scale = (amax.float() / E4M3_MAX).clamp(min=1e-12)

    # Scaled inputs land in [-448, 448] up to FP32 rounding. PyTorch's
    # float8 cast does not saturate, so clamp explicitly before casting.
    x_scaled = x_tiles.float() / scale.unsqueeze(-1)
    x_scaled = x_scaled.clamp(-E4M3_MAX, E4M3_MAX)
    q = x_scaled.to(torch.float8_e4m3fn)             # round to nearest E4M3
    q = q.reshape(M, N)
    return q, scale

# Round-trip and measure quality.
torch.manual_seed(0)
x = torch.randn(64, 1024, dtype=torch.bfloat16) * 0.3
x[0, 511] = 200.0                                    # the outlier (tile 3 of row 0)

q, scale = cast_to_e4m3_per_tile(x, tile=128)
x_hat = q.float() * scale.repeat_interleave(128, dim=-1)
err = (x_hat - x.float()).norm() / x.float().norm()
print(f"per-tile FP8 round-trip rel_err = {err.item():.4f}")

# Compare against the naive per-tensor recipe.
def cast_to_e4m3_per_tensor(x_bf16: torch.Tensor):
44    amax = x_bf16.float().abs().amax()
45    scale = (amax / 448.0).clamp(min=1e-12)
46    q = (x_bf16.float() / scale).to(torch.float8_e4m3fn)
47    return q, scale.reshape(1)
48
49q2, s2 = cast_to_e4m3_per_tensor(x)
50x_hat_2 = q2.float() * s2
51err2 = (x_hat_2 - x.float()).norm() / x.float().norm()
52print(f"per-tensor FP8 round-trip rel_err = {err2.item():.4f}")
53print(f"ratio (per-tensor / per-tile) = {err2.item() / err.item():.1f}x")

Two things this code is not. It is not a complete training loop: we have only shown the forward cast, and the backward path needs its own scaling (in the common hybrid recipe, gradients are cast to E5M2). And it is not a fused kernel: the amax reduction, divide, clamp, and cast run as separate PyTorch ops here, each its own kernel launch, whereas production kernels such as TransformerEngine's fuse these steps. For understanding what the cast does numerically, the unfused version is clearer; for running it in production, you want the fused one.

At Massive Scale: Loss Spikes Within a Few Hundred Steps

Each of the four failure modes is small in isolation. A single tile with one underflowed element loses one gradient signal — the optimiser absorbs it. A single dot product accumulator drifting by 2% is, on average, a 2% smaller update step. Why does naive FP8 actually diverge?

The answer is that the failures compound across three axes. Across the depth axis, every layer's forward and backward casts re-quantise. Across the width axis, every matmul with a reduction dimension of $K = 4096$ or more drifts. Across the time axis, every step adds its bias to the optimiser's first- and second-moment estimates. For DeepSeek-V3 (61 layers, hidden size 7168, 16k-step pretraining), the naive FP8 path accumulates roughly $61 \cdot 4 \cdot 16\,000 \approx 4 \cdot 10^{6}$ biased matmul reductions before the first checkpoint, and AdamW's second-moment estimate, which divides into the update, cannot tolerate that much bias.

Empirically, the loss curve of a naive FP8 run looks like a healthy BF16 curve for ~150–300 steps, then begins to diverge super-linearly. Restarting from a recent checkpoint does not help — the bias is structural, not stochastic. Lowering the learning rate delays divergence but does not prevent it. The fix has to address the four numerical failures directly, which is what the rest of this chapter is about.
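The claim that the bias is structural rather than stochastic is the usual random-walk argument: zero-mean errors grow like √N over N steps, while a systematic error grows like N. The toy sketch below makes that concrete with made-up per-step error sizes; it is not a measurement from any FP8 run and not a model of any real optimizer, and `accumulated_drift` is a hypothetical name.

```python
import torch


def accumulated_drift(steps: int, bias: float, noise_std: float, seed: int = 0) -> tuple[float, float]:
    """Toy model: sum per-step errors that are either systematic or zero-mean.

    Returns (drift from a constant per-step bias, drift from zero-mean noise
    of the same per-step size).
    """
    g = torch.Generator().manual_seed(seed)
    biased = torch.full((steps,), bias, dtype=torch.float64).sum().item()
    noisy = (torch.randn(steps, generator=g, dtype=torch.float64) * noise_std).sum().item()
    return biased, noisy


b, n = accumulated_drift(16_000, bias=1e-3, noise_std=1e-3)
print(f"systematic drift = {b:.3f}, zero-mean drift = {n:+.3f}")
```

Over 16,000 steps the systematic term reaches 16,000 · 10⁻³ = 16, while the zero-mean term has a standard deviation of only √16,000 · 10⁻³ ≈ 0.13. Restarting from a checkpoint just replays the same systematic term.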

Do not let an FP8 training run reach divergence to confirm it is wrong. A diverged run is a 10–30% loss of capital. The leading indicators are: (1) gradient amax growing faster than in the BF16 reference, (2) the optimizer's second-moment estimates drifting away from those of a BF16 reference run, and (3) the FP8 loss rising above the BF16 curve by more than 0.01 nats. Section 7.4 covers how to wire these signals into a kill-switch.
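The three leading indicators can be sketched as a per-step comparison against a BF16 reference run. This is a minimal sketch under assumptions: `StepStats` and `fp8_warnings` are hypothetical names, the telemetry is assumed to be collected elsewhere, and only the 0.01-nat loss threshold comes from the text; the two ratio limits are placeholders you would tune.

```python
from dataclasses import dataclass


@dataclass
class StepStats:
    """Per-step telemetry; the field names are hypothetical, not from any library."""
    loss: float          # training loss in nats
    grad_amax: float     # max |gradient| over the model
    adam_v_mean: float   # mean of AdamW second-moment estimates


def fp8_warnings(fp8: StepStats, ref: StepStats,
                 amax_ratio_limit: float = 2.0,
                 v_ratio_limit: float = 1.5,
                 loss_gap_limit: float = 0.01) -> list[str]:
    """Compare an FP8 run against a BF16 reference at the same step.

    amax_ratio_limit and v_ratio_limit are placeholder values, not recommendations.
    """
    warnings = []
    if fp8.grad_amax > amax_ratio_limit * ref.grad_amax:
        warnings.append("gradient amax growing faster than BF16 reference")
    v_ratio = fp8.adam_v_mean / ref.adam_v_mean
    if not (1.0 / v_ratio_limit <= v_ratio <= v_ratio_limit):
        warnings.append("AdamW second moment drifting from BF16 reference")
    if fp8.loss - ref.loss > loss_gap_limit:
        warnings.append("loss above BF16 reference by more than the limit")
    return warnings


ref = StepStats(loss=2.31, grad_amax=0.8, adam_v_mean=1e-6)
fp8 = StepStats(loss=2.33, grad_amax=2.0, adam_v_mean=2e-6)
print(fp8_warnings(fp8, ref))
```

Comparing at matched steps keeps the check independent of the learning-rate schedule, since both runs share it.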

Engineering Reality and Gotchas

A handful of practical issues come up in every FP8 implementation regardless of which mitigation stack you use. None of them are deep, but missing one wastes a week of GPU time:

1. amax computation is the cast's critical path

The per-tile amax is a reduction over the tile axis. For (1, 128) tiles of an activation tensor with 65k tokens and hidden size H, that is 65k × H / 128 reductions of length 128: roughly 3.7 million for H = 7168. Hopper Tensor Cores cannot help with this, because it is a non-matmul operation. The right place for the amax computation is inside the previous matmul's epilogue, where the output is still in registers. TransformerEngine and DeepSeek's kernels both do this; a naive implementation that computes the amax in a separate kernel pays a 5–15% throughput tax.

2. NaN propagation kills more runs than overflow

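E4M3 has no infinity, and FP8 recipes typically saturate overflow to ±448, which is a gentle failure. (PyTorch's plain .to(torch.float8_e4m3fn) cast does not saturate on its own: large values become NaN, so the clamp has to be explicit.) E5M2 does not saturate: it has IEEE-style Inf and NaN. A single gradient that overflows E5M2's ±57,344 range becomes Inf, which turns into NaN at the first Inf − Inf or 0 · Inf and then propagates through every subsequent matmul touching that element, and through AdamW's state. The standard guard is a per-step gradient norm clip before the FP8 cast; an additional check that rolls back to the previous optimizer state if any gradient is non-finite is cheap insurance.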
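The non-finite guard can be sketched in plain PyTorch. This is an assumption-laden sketch, not any framework's implementation: `guarded_step` is a hypothetical helper, it inspects the final parameter gradients rather than the intermediate FP8 casts, and instead of restoring a saved optimizer state it simply skips the step, which leaves AdamW's moment estimates exactly as they were.

```python
import torch
from torch import nn


def guarded_step(model: nn.Module, optimizer: torch.optim.Optimizer, max_norm: float = 1.0) -> bool:
    """Skip the optimizer step if any gradient is non-finite; otherwise clip and step.

    Returns True if the step was applied. Skipping never touches the optimizer
    state, which is equivalent to rolling it back by one step.
    """
    params = [p for p in model.parameters() if p.grad is not None]
    finite = all(bool(torch.isfinite(p.grad).all()) for p in params)
    if finite:
        nn.utils.clip_grad_norm_(params, max_norm)
        optimizer.step()
    optimizer.zero_grad(set_to_none=True)  # drop the (possibly poisoned) gradients either way
    return finite
```

Checking before clipping matters: with a NaN in the gradients, the total norm is NaN and the clip coefficient would spread it to every parameter.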

3. Scale calibration is not stationary

The amax of a tensor changes across training. Early in training, outliers are mild; the “outlier features” phenomenon emerges around step 1k for 7B models and step 5k for 70B+. A scale calibrated on the first 100 steps is wrong by step 10k. The standard solution is “delayed scaling”: the scale used at step $t$ is computed from the amax history of steps $[t - W, t - 1]$ for some window $W$, plus a small safety margin. The window and the margin are hyperparameters you have to tune.

4. Communication still happens in BF16

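FP8 gives you 2× compute and 2× bandwidth on the GPU relative to BF16. But cross-GPU collectives (all-reduce in DP, all-gather in TP) typically stay in BF16 or FP32, because the cumulative numerical error of an FP8 all-reduce across 256+ ranks is much worse than the intra-GPU error this chapter has been worrying about. DeepSeek-V3 is selective here even though its on-GPU compute is FP8: its report describes casting to FP8 only selected MoE traffic (the activations dispatched to the MoE up-projections and the matching activation gradients), while keeping the forward and backward MoE combine steps in BF16.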

5. Determinism is harder

BF16 training can be made bit-exact deterministic with care (fixed random seeds, deterministic kernels, no atomic accumulation). FP8 training generally cannot, because the amax history depends on scheduling, the partial-sum order matters for the 14-bit accumulator, and the cast kernels themselves frequently use atomic operations. Debugging an FP8 run that diverges differently across replays is a serious time sink.
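The point that partial-sum order matters for a limited-precision accumulator is easy to reproduce. The sketch below uses BF16 (8 significant bits) as a stand-in for the accumulator; it is an illustration of the rounding effect, not Hopper's actual FP8 Tensor Core datapath. The same 257 numbers summed in two orders give 1.0 and 2.0.

```python
import torch


def sequential_sum(values: list[float], dtype: torch.dtype = torch.bfloat16) -> float:
    """Add values one at a time, rounding the running total to `dtype` after each add."""
    acc = torch.zeros((), dtype=dtype)
    for v in values:
        acc = acc + torch.tensor(v, dtype=dtype)
    return acc.item()


vals = [1.0] + [2.0 ** -8] * 256        # one big term, 256 tiny ones that sum to 1.0
print(sequential_sum(vals))             # big term first: each 2**-8 is half an ulp and rounds away
print(sequential_sum(vals[::-1]))       # tiny terms first: they add up exactly before meeting 1.0
```

Any reordering (split-K, a different reduction tree, atomic adds whose arrival order varies) can shift results the same way, which is why bit-exact replays need a fixed reduction order.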

What's next. Sections 3.3 and 3.4 give the full per-tile scaling and high-precision accumulation recipes, including the exact tile shapes DeepSeek-V3 uses (1×128 for activations, 128×128 for weights). Section 3.5 covers the BF16/FP32 components that stay outside the FP8 cast. Section 3.6 walks through a from-scratch FP8 training implementation in PyTorch that integrates all the mitigations.