Chapter 10

Implementing FP8 Training

FP8 Mixed-Precision Training

The Real Problem: From Idea to Kernel

The four sections before this one have walked through the idea of FP8 training: why naive FP8 fails, what fine-grained per-tile quantisation buys you, how high-precision accumulation rescues the inner product, and which tensors stay in BF16 or FP32 because their dynamic range is too wide. Every one of those decisions has the flavour of a clean mathematical observation: round-to-nearest is unbiased, tiles localise outliers, partial sums commute.

And yet the gap between those observations and a kernel that survives a fifteen-trillion-token pretraining run is huge. Real FP8 training has to answer questions the math never asks:

  • When are tensors cast? Every step? Every microbatch? Once per layer? Each choice has different memory and throughput consequences.
  • Where is the scale factor stored, and on which dimension is it tiled? Activations want row-tiles because the batch axis is dynamic; weights want square tiles because they are static across the step.
  • Which FP8 format does each of the three GEMMs of a linear layer (forward, weight-gradient, input-gradient) use? They have different operands and different dynamic ranges.
  • What stays in FP32 forever — and why? Master weight, Adam moments, LayerNorm parameters, the loss scalar, and the GEMM accumulator are non-negotiable.
  • How does the optimizer update an FP32 master weight from FP8 gradients without re-introducing the very quantisation errors we worked so hard to avoid?

This section assembles every previous piece into a complete, working FP8 training step. It is the chapter's capstone — the place where the theory turns into a kernel.

The promise of this section: by the end, you will be able to look at a transformer block and name, for every tensor it touches, the dtype it lives in, the dtype the GEMM accumulates in, and the moment it gets cast. You will also have a short numpy implementation of an FP8 GEMM (E4M3 cast, per-tile scales, two-level accumulation) and a PyTorch autograd Function that performs all three FP8 GEMMs and can be dropped into a model.

Intuition: One Linear Layer, Three GEMMs

Every nn.Linear in your transformer — the attention projections, the MLP up and down projections, the LM head — triggers three matrix multiplications across a full training step, not one. Most engineers learn the forward and forget the backward needs two more.

The three GEMMs, in plain English

  1. Forward. Given input activations $X$ of shape $(B, K)$ and weights $W$ of shape $(K, N)$, compute output $Y = X \cdot W$ of shape $(B, N)$.
  2. Weight-gradient backward. Given the upstream gradient $dY$ of shape $(B, N)$, compute $dW = X^{\top} \cdot dY$ of shape $(K, N)$, the gradient the optimizer needs.
  3. Input-gradient backward. Compute $dX = dY \cdot W^{\top}$ of shape $(B, K)$, the signal that flows to the previous layer.
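The three shapes and formulas can be checked numerically. The sketch below is plain numpy at full precision, with no FP8 casts. It uses the surrogate loss $L = \sum_{b,n} Y_{bn}\, dY_{bn}$ (an assumption chosen so that $dY$ is exactly $\partial L / \partial Y$) and compares the two backward formulas against central finite differences; the sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
B, K, N = 4, 6, 3
X = rng.standard_normal((B, K))
W = rng.standard_normal((K, N))
dY = rng.standard_normal((B, N))       # upstream gradient dL/dY


def loss(X, W):
    # Surrogate loss whose gradient with respect to Y is exactly dY
    return np.sum((X @ W) * dY)


Y = X @ W          # forward:     (B, K) @ (K, N) -> (B, N)
dW = X.T @ dY      # weight-grad: (K, B) @ (B, N) -> (K, N)
dX = dY @ W.T      # input-grad:  (B, N) @ (N, K) -> (B, K)

# Central finite differences on one entry of W and one entry of X
eps = 1e-6
E_w = np.zeros_like(W)
E_w[2, 1] = eps
E_x = np.zeros_like(X)
E_x[3, 4] = eps
fd_w = (loss(X, W + E_w) - loss(X, W - E_w)) / (2 * eps)
fd_x = (loss(X + E_x, W) - loss(X - E_x, W)) / (2 * eps)

print(Y.shape, dW.shape, dX.shape)     # (4, 3) (6, 3) (4, 6)
print(np.isclose(fd_w, dW[2, 1]), np.isclose(fd_x, dX[3, 4]))
```

Because the surrogate loss is linear in both $X$ and $W$, the finite differences match the analytic gradients up to floating-point roundoff.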

Each GEMM has a different pair of operands. Each operand has its own dynamic range. So each GEMM has its own quantisation choice.

| GEMM | Operand 1 | Operand 2 | Why this format? |
|---|---|---|---|
| Forward $Y = X \cdot W$ | $X$ → E4M3 | $W$ → E4M3 | Activations after LayerNorm + GeLU + residual have moderate range; weights are similar. E4M3's tighter mantissa gives the precision we need. |
| Weight-grad $dW = X^{\top} \cdot dY$ | $X^{\top}$ → E4M3 | $dY$ → E5M2 | The gradient $dY$ can swing six decades: early layers in early steps see $dY \approx 10^{-6}$, late layers see $dY \approx 10^{0}$. Only E5M2's 5-bit exponent contains that range. |
| Input-grad $dX = dY \cdot W^{\top}$ | $dY$ → E5M2 | $W^{\top}$ → E4M3 | Same reason: the gradient half of the matmul lives in E5M2; the weight half lives in E4M3. |
The pattern: activations and weights live in E4M3; gradients live in E5M2; all three GEMMs accumulate into FP32. The dtype of an operand is fixed by what tensor it is, not by which GEMM it appears in.

This is the single biggest piece of architecture-level knowledge in FP8 training. The rest of the section is about implementing it without losing accuracy.

The Mathematics of a Fine-Grained FP8 Step

Let us write the operations of one linear layer in equations so we can talk about them precisely. Take a single matrix multiplication $Y = X \cdot W$ where $X \in \mathbb{R}^{B \times K}$ and $W \in \mathbb{R}^{K \times N}$.

Step 1: Per-tile amax

Split $W$ into $T \times T$ tiles (DeepSeek V3 uses $T = 128$). For each tile $W^{(p, q)}$ compute its absolute maximum and divide by the format's largest representable value to get a scale:

$$s^{(p,q)}_W = \frac{\max_{i, j \in (p,q)} |W_{ij}|}{\text{FP8}_{\max}}, \qquad \text{FP8}_{\max} = 448 \text{ for E4M3}$$

For $X$ we use $1 \times T$ row-tiles because the batch dimension is dynamic: $s^{(b,q)}_X = \max_{j \in (q)} |X_{bj}| / 448$.
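Both tilings can be computed without Python loops by reshaping so that each tile gets its own axes, then reducing over them. This is a sketch under two stated assumptions: $K$ and $N$ are multiples of $T$, and an all-zero tile (which would produce a zero scale) is not handled.

```python
import numpy as np

FP8_MAX = 448.0   # E4M3
T = 128           # tile length, as in DeepSeek V3


def weight_scales(W: np.ndarray, T: int = T) -> np.ndarray:
    """One scale per T x T block of W with shape (K, N); result is (K/T, N/T)."""
    K, N = W.shape
    blocks = W.reshape(K // T, T, N // T, T)
    return np.abs(blocks).max(axis=(1, 3)) / FP8_MAX


def activation_scales(X: np.ndarray, T: int = T) -> np.ndarray:
    """One scale per 1 x T row-tile of X with shape (B, K); result is (B, K/T)."""
    B, K = X.shape
    tiles = X.reshape(B, K // T, T)
    return np.abs(tiles).max(axis=2) / FP8_MAX


rng = np.random.default_rng(0)
X = rng.standard_normal((5, 512))     # 5 rows: B need not be a multiple of T
W = rng.standard_normal((512, 256))
sX, sW = activation_scales(X), weight_scales(W)
print(sX.shape, sW.shape)             # (5, 4) (4, 2)
```

The activation scales depend on the batch only through the row index, so a batch of any size works; only the reduction dimension $K$ has to be tiled.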

Step 2: Cast each tile

Define the round-to-nearest E4M3 operator $Q_{\text{E4M3}}(\cdot)$ by: clip the magnitude to 448, identify the binade $[2^e, 2^{e+1})$, round to the nearest of the eight mantissa steps with step size $2^{e-3}$. The cast operation is $\tilde{W}^{(p,q)} = Q_{\text{E4M3}}(W^{(p,q)} / s^{(p,q)}_W)$: the values now live on the tight integer-like grid of E4M3. Re-multiplying by the scale gives the dequantised approximation $\hat{W}^{(p,q)} = s^{(p,q)}_W \cdot \tilde{W}^{(p,q)}$.
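One way to convince yourself that this binade-and-step recipe really is round-to-nearest is to enumerate the E4M3 grid and compare. The sketch below assumes the OCP E4M3 convention (top code reserved for NaN, so the largest finite value is 448) and, like the rest of this section, flushes subnormals to zero; it checks $Q_{\text{E4M3}}$ against a brute-force nearest-grid-point search on random inputs.

```python
import numpy as np

# Positive normal E4M3 values: (1 + m/8) * 2**e for e in -6..8 and m in 0..7,
# minus the top code (e = 8, m = 7), which OCP E4M3 reserves for NaN.
grid = np.array(sorted((1 + m / 8) * 2.0 ** e
                       for e in range(-6, 9)
                       for m in range(8)
                       if not (e == 8 and m == 7)))


def q_e4m3(x):
    """Q_E4M3 as described: clip, find the binade, round to a step of 2**(e-3).
    Subnormals are flushed to zero in this sketch."""
    ax = np.minimum(np.abs(x), 448.0)
    normal = ax >= 2.0 ** -6
    e = np.floor(np.log2(np.where(normal, ax, 1.0)))
    step = 2.0 ** (e - 3)
    return np.sign(x) * np.where(normal, np.round(ax / step) * step, 0.0)


def nearest_on_grid(x):
    """Brute force: the grid value nearest to |x|, with the sign restored."""
    idx = np.abs(np.abs(x)[:, None] - grid[None, :]).argmin(axis=1)
    return np.sign(x) * grid[idx]


rng = np.random.default_rng(0)
mags = 2.0 ** rng.uniform(-6.0, np.log2(448.0), size=10_000)  # log-uniform magnitudes
x = mags * rng.choice([-1.0, 1.0], size=10_000)

print(len(grid), grid.min(), grid.max())          # 119 0.015625 448.0
print(np.array_equal(q_e4m3(x), nearest_on_grid(x)))
```

The grid has 15 binades of 8 values each, minus the NaN code, which is where the count of 119 positive normal values comes from.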

Step 3: The accumulated GEMM

The forward output is the FP32 accumulation of FP8 products, with the scales fused into the epilogue. For one output entry $Y_{bn}$:

$$Y_{bn} = \sum_{q} s^{(b,q)}_X \cdot s^{(q,n)}_W \sum_{k \in (q)} \tilde{X}_{bk} \cdot \tilde{W}_{kn}$$

The inner sum is over one tile of length $T$; the outer sum is over tile indices $q$. In hardware the inner sum lives in a limited-precision accumulator inside the tensor core (modelled in this section as BF16; DeepSeek V3 reports that FP8 GEMMs on H800 retain only about 14 bits there) and the outer sum lives in an FP32 register. Every $K = 128$ inner steps the inner accumulator is added into the FP32 outer register and reset. Here $K$ denotes the promotion interval, not the inner dimension of $X$.
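The point of the factorisation is that the inner sum touches only grid values; the scales are applied once per tile afterwards. The sketch below checks this on a toy shape, assuming $T = 4$, $1 \times T$ row-tiles for $X$, $T \times T$ blocks for $W$, and the same flush-to-zero E4M3 simulation as elsewhere in this section. The tile-by-tile loop with scales in the "epilogue" matches a plain GEMM over the dequantised operands.

```python
import numpy as np

FP8_MAX, T = 448.0, 4       # a tiny tile so the example stays small


def q_e4m3(x):
    """Round-to-nearest E4M3 simulation (subnormals flushed to zero)."""
    ax = np.minimum(np.abs(x), FP8_MAX)
    normal = ax >= 2.0 ** -6
    e = np.floor(np.log2(np.where(normal, ax, 1.0)))
    step = 2.0 ** (e - 3)
    return np.sign(x) * np.where(normal, np.round(ax / step) * step, 0.0)


rng = np.random.default_rng(0)
B, K, N = 3, 16, 8
X = rng.standard_normal((B, K))
W = rng.standard_normal((K, N))

# Step 1 scales: 1 x T row-tiles for X, T x T blocks for W
sX = np.abs(X.reshape(B, K // T, T)).max(axis=2) / FP8_MAX               # (B, K/T)
sW = np.abs(W.reshape(K // T, T, N // T, T)).max(axis=(1, 3)) / FP8_MAX  # (K/T, N/T)
sX_full = np.repeat(sX, T, axis=1)                                       # (B, K)
sW_full = np.repeat(np.repeat(sW, T, axis=0), T, axis=1)                 # (K, N)

# Step 2: values on the E4M3 grid (the tilde tensors)
Xt = q_e4m3(X / sX_full)
Wt = q_e4m3(W / sW_full)

# Step 3: per K-tile, an unscaled inner GEMM, then the scales in the "epilogue"
Y = np.zeros((B, N))
for q in range(K // T):
    ks = slice(q * T, (q + 1) * T)
    inner = Xt[:, ks] @ Wt[ks, :]                 # (B, N), no scales involved
    col_scale = np.repeat(sW[q], T)               # s_W for the tile holding column n
    Y += sX[:, q:q + 1] * inner * col_scale[None, :]

# Same result as a GEMM over the dequantised operands
Y_deq = (Xt * sX_full) @ (Wt * sW_full)
print(np.allclose(Y, Y_deq))                      # True
```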

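Why two-level accumulation? A standard worst-case bound for a running sum of $n$ terms in a format with machine epsilon $\epsilon$ grows like $n \cdot \epsilon$ (relative to the sum of the terms' magnitudes). Accumulating $N = 4096$ FP8 partial products (here $N$ is the length of the dot product) in one BF16 register ($\epsilon = 2^{-7}$) gives a bound of order $N \cdot 2^{-7} = 32$: not a single bit of the answer is guaranteed. Splitting the same partials into 32 BF16 chunks of length $K = 128$ and summing the chunk results in FP32 gives a bound of order $K \cdot 2^{-7} + (N/K) \cdot 2^{-23} \approx 1 + 4 \times 10^{-6}$. The FP32 outer sum contributes essentially nothing, and the bound shrinks by the factor $N/K = 32$. Both figures are pessimistic, because random rounding errors partly cancel, but the mechanism is the one you see in practice: a short inner sum never grows large enough, relative to each incoming partial, to round that partial away.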

Step 4: The backward GEMMs follow the same rule

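For the weight gradient, replace E4M3 with E5M2 on the gradient operand. The contraction now runs over the batch index $b$, so for each pair $(b, k)$ the activation scale is the one for row $b$ and the tile $q$ that contains column $k$:

$$dW_{kn} = \sum_{b} s^{(b,q)}_X \cdot s^{(b,n)}_{dY} \cdot \tilde{X}_{bk} \cdot \widetilde{dY}_{bn}, \qquad q = \text{the tile containing } k$$

The structure mirrors the forward pass, except that the scales now vary along the contraction index and therefore sit inside the sum. Only the dtype of $dY$ changes, to E5M2 instead of E4M3, and the reason is dynamic range, not precision.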
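A short experiment shows what "dynamic range, not precision" means in practice. The sketch below takes a gradient whose entries span six decades, scales it so its amax lands on each format's maximum, and counts how many entries are flushed to zero. To isolate the range effect it uses one scale for the whole tensor rather than per-tile scales, and it flushes subnormals to zero; real E4M3 subnormals (down to $2^{-9}$) would rescue three of the four flushed entries, though the smallest survivor would be off by more than a third.

```python
import numpy as np

# (mantissa bits, largest finite value, smallest normal value) for each format
FORMATS = {"e4m3": (3, 448.0, 2.0 ** -6), "e5m2": (2, 57344.0, 2.0 ** -14)}


def cast_fp8(x, fmt):
    """Round-to-nearest FP8 simulation; subnormals are flushed to zero."""
    mant, fmax, fmin = FORMATS[fmt]
    ax = np.minimum(np.abs(x), fmax)
    normal = ax >= fmin
    e = np.floor(np.log2(np.where(normal, ax, 1.0)))
    step = 2.0 ** (e - mant)
    return np.sign(x) * np.where(normal, np.round(ax / step) * step, 0.0)


def fake_quant(x, fmt):
    s = np.abs(x).max() / FORMATS[fmt][1]   # amax maps onto the format's max
    return cast_fp8(x / s, fmt) * s


dY = np.logspace(-6, 0, num=13)             # gradient entries from 1e-6 to 1e0

flushed = {fmt: int(np.count_nonzero(fake_quant(dY, fmt) == 0)) for fmt in FORMATS}
print(flushed)                              # {'e4m3': 4, 'e5m2': 0}
```

E4M3 loses every entry below about $3.5 \times 10^{-5}$ (that is, $2^{-6} / 448$ of the amax), while E5M2's normal range reaches roughly nine orders of magnitude below its maximum and keeps all of them.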

Step 5: The optimizer never sees FP8

After the backward, $dW$ leaves the FP32 accumulator and is stored in BF16, and the FP32 master weight is updated by AdamW:

$$W_{t+1} = W_t - \eta \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \varepsilon) - \eta \lambda W_t$$

where $m_t, v_t$ are the Adam moments in FP32, $\hat{m}_t, \hat{v}_t$ are their bias-corrected versions, $\eta$ is the learning rate, and $\lambda$ is the weight decay. The next forward pass re-quantises $W_{t+1}$ into FP8. The optimizer state never touches the FP8 grid.
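Written out for a single step, the update looks like the sketch below. It is a minimal numpy version of the equation above with PyTorch-style defaults for the betas and epsilon; the learning rate and weight decay are placeholders, not values from the text. numpy has no bfloat16 dtype, so float16 stands in for the low-precision gradient; the essential point is that the gradient is upcast to FP32 before it touches the moments or the master weight.

```python
import numpy as np


def adamw_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.1):
    """One AdamW update of an FP32 master weight (decoupled weight decay).

    w, m, v : float32 arrays (master weight and the two Adam moments)
    g       : weight gradient in any lower precision; upcast to FP32 first
    t       : 1-based step count, used for bias correction
    """
    g = g.astype(np.float32)
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g * g
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (np.sqrt(v_hat) + eps) - lr * wd * w
    return w, m, v


rng = np.random.default_rng(0)
w = rng.standard_normal(8).astype(np.float32)
m = np.zeros_like(w)
v = np.zeros_like(w)
# float16 stands in for the BF16 gradient buffer
g = (rng.standard_normal(8) * 1e-2).astype(np.float16)

w1, m1, v1 = adamw_step(w, g, m, v, t=1)
print(w1.dtype, m1.dtype, v1.dtype)   # float32 float32 float32
```

On the first step the bias-corrected moments reduce to $\hat{m}_1 = g$ and $\hat{v}_1 = g^2$, so each weight moves by almost exactly $\eta$ in the direction of $-\operatorname{sign}(g)$, plus the decay term.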

Interactive: How Tile Size Controls Error

The whole argument for fine-grained quantisation lives in one plot: per-tile error vs. per-tensor error. The widget below builds a 16×16 activation tensor with a tunable number of outliers, lets you pick the tile size, and shows three heatmaps side by side — the input, the dequantised approximation, and the absolute error.


The key reading: turn outlier intensity up. With a per-tensor scale, one big spike forces a huge global scale, and every small bulk value rounds to zero. Drop tile size to 4×4 and the spike is confined to one tile; every other tile recovers its own well-fitted scale and the mean error collapses by an order of magnitude or more.
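The widget's experiment can be approximated in a few lines of numpy. The bulk scale, the outlier size and the flush-to-zero E4M3 simulation are assumptions of this sketch, not the widget's exact settings. One outlier of 100 sits in a $16 \times 16$ tensor of values around $10^{-3}$; with a single per-tensor scale almost the whole bulk falls below E4M3's smallest normal and is flushed, while $4 \times 4$ tiles confine the damage to the outlier's own tile.

```python
import numpy as np

FP8_MAX = 448.0


def q_e4m3(x):
    """Round-to-nearest E4M3 simulation (subnormals flushed to zero)."""
    ax = np.minimum(np.abs(x), FP8_MAX)
    normal = ax >= 2.0 ** -6
    e = np.floor(np.log2(np.where(normal, ax, 1.0)))
    step = 2.0 ** (e - 3)
    return np.sign(x) * np.where(normal, np.round(ax / step) * step, 0.0)


def fake_quant(x, tile):
    """One amax scale per (tile x tile) block: quantise, then dequantise."""
    out = np.empty_like(x)
    for i in range(0, x.shape[0], tile):
        for j in range(0, x.shape[1], tile):
            blk = x[i:i + tile, j:j + tile]
            amax = np.abs(blk).max()
            s = amax / FP8_MAX if amax > 0 else 1.0
            out[i:i + tile, j:j + tile] = q_e4m3(blk / s) * s
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((16, 16)) * 1e-3   # small bulk values
x[3, 5] = 100.0                            # a single large outlier

err_tensor = np.abs(fake_quant(x, 16) - x).mean()   # one scale for the whole tensor
err_tile = np.abs(fake_quant(x, 4) - x).mean()      # sixteen 4 x 4 tiles
print(f"mean abs error  per-tensor: {err_tensor:.2e}  per-tile: {err_tile:.2e}")
```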

Manual Numerical Walkthrough

Now let us do every step by hand on a tiny example, one row of four activations times one column of four weights, so the machinery is undeniable.

Full FP8 GEMM on a (1×4)·(4×1) example

Setup. Take $X = \begin{bmatrix} 0.40 & 0.10 & -0.30 & 0.05 \end{bmatrix}$ (one row of activations, $K = 4$) and the weight column $W = \begin{bmatrix} 1.20 \\ 0.06 \\ -0.04 \\ 0.02 \end{bmatrix}$. The exact product is $Y = X \cdot W = 0.40 \cdot 1.20 + 0.10 \cdot 0.06 + (-0.30) \cdot (-0.04) + 0.05 \cdot 0.02 = 0.4990$.

Step 1: per-tile amax. We use one tile of size 4 (the whole vector). For $X$: amax = 0.40, scale $s_X = 0.40 / 448 \approx 8.93 \times 10^{-4}$. For $W$: amax = 1.20, scale $s_W = 1.20 / 448 \approx 2.68 \times 10^{-3}$.

Step 2: cast. Divide each entry by its scale, then round to the nearest E4M3 representable value.

For X: the divided values are $X / s_X = (448,\ 112,\ -336,\ 56)$. Three of them are representable: $448 = 14 \cdot 32$ in binade $e = 8$ (step 32), $112 = 14 \cdot 8$ in binade $e = 6$ (step 8), and $56 = 14 \cdot 4$ in binade $e = 5$ (step 4). The fourth is not: $-336$ sits in binade $e = 8$ exactly halfway between $-320$ and $-352$, and round-half-to-even picks $-320$. Re-multiplying by $s_X$ gives $\hat{X} = (0.400,\ 0.100,\ -0.2857,\ 0.050)$, so the third entry is off by about 4.8%.

For W: the divided values are $W / s_W = (448.0,\ 22.4,\ -14.93,\ 7.47)$. The first lands exactly. The others land on the nearest E4M3 step: $22.4 \to 22$ (step 2), $-14.93 \to -15$ (step 1), $7.47 \to 7.5$ (step 0.5). After re-multiplying: $\hat{W} = (1.20,\ 0.0589,\ -0.0402,\ 0.0201)$.

Step 3: accumulated GEMM. The four products of the dequantised values are $0.400 \cdot 1.20 = 0.4800$, $0.100 \cdot 0.0589 = 0.00589$, $(-0.2857) \cdot (-0.0402) = 0.01148$, and $0.050 \cdot 0.0201 = 0.00100$. Their exact sum is 0.4984: the E4M3 cast alone has already pulled the answer down by about 0.0006.

BF16 inner accumulator after each step. BF16 keeps 8 significant bits (7 stored mantissa bits plus the implicit leading 1), so between 0.25 and 0.5 its representable values are $2^{-9} \approx 0.00195$ apart, and every running sum is rounded to that grid:

  1. 0 + 0.4800 → 0.48047 (0.48 itself is not representable; the nearest BF16 value is $246 \cdot 2^{-9}$)
  2. 0.48047 + 0.00589 = 0.48636 → 0.48633 ($249 \cdot 2^{-9}$)
  3. 0.48633 + 0.01148 = 0.49781 → 0.49805 ($255 \cdot 2^{-9}$)
  4. 0.49805 + 0.00100 = 0.49905 → 0.50000 ($256 \cdot 2^{-9}$)

Step 4: FP32 promotion. Promote the BF16 accumulator into an FP32 register: outer = 0.5000. Against the true value of 0.4990, that is a relative error of $2 \times 10^{-3}$, well inside the noise floor of a real training step. In this example the two error sources happen to push in opposite directions: the E4M3 cast lowered the sum to 0.4984, and BF16 rounding raised it to 0.5000.

What just happened. Two operands quantised to E4M3, four products accumulated through a BF16 register, one FP32 promotion at the end. Total error: 0.001 on a value of 0.5. At scale, the DeepSeek V3 technical report finds that its FP8 recipe keeps the relative loss gap to a BF16 baseline below 0.25% in validation runs at roughly DeepSeek-V2-Lite and DeepSeek-V2 scale (about 1T tokens each); the same recipe was then used to train the 671B-parameter model on 14.8T tokens.
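Running the same steps in code is a good way to check every intermediate value of the hand calculation. The sketch below assumes the flush-to-zero E4M3 simulation used throughout this section and models BF16 by rounding to 8 significant bits with `np.frexp`/`np.ldexp`, ignoring BF16's exponent-range limits; both helpers are illustrative, not library functions.

```python
import numpy as np


def q_e4m3(x):
    """Round-to-nearest E4M3 (ties to even); subnormals flushed, magnitude clipped at 448."""
    ax = np.minimum(np.abs(x), 448.0)
    normal = ax >= 2.0 ** -6
    e = np.floor(np.log2(np.where(normal, ax, 1.0)))
    step = 2.0 ** (e - 3)
    return np.sign(x) * np.where(normal, np.round(ax / step) * step, 0.0)


def bf16(x):
    """Round to BF16 precision: 8 significant bits (7 stored + the implicit 1).
    Exponent-range limits are ignored."""
    m, e = np.frexp(x)
    return np.ldexp(np.round(m * 256.0) / 256.0, e)


X = np.array([0.40, 0.10, -0.30, 0.05])
W = np.array([1.20, 0.06, -0.04, 0.02])

# Step 1: one tile per vector, scale = amax / 448
sX = np.abs(X).max() / 448.0
sW = np.abs(W).max() / 448.0

# Step 2: cast onto the E4M3 grid, then dequantise
X_hat = q_e4m3(X / sX) * sX
W_hat = q_e4m3(W / sW) * sW

# Step 3: BF16-precision running sum of the four products
inner = 0.0
for p in X_hat * W_hat:
    inner = bf16(inner + p)

# Step 4: promote into an FP32 register
outer = np.float32(inner)

print("X_hat:", np.round(X_hat, 4))
print("W_hat:", np.round(W_hat, 4))
print("result:", outer, " exact:", X @ W)
```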

Interactive: The Three GEMMs of a Linear Layer

Use the buttons or hit Play to walk through the three GEMMs in order. Notice which operand changes dtype at each step and which stays the same.


The single most useful lesson from this animation: the dtype of a tensor is decided by what role it plays (activation, weight, gradient), not by which GEMM it shows up in. $X$ stays E4M3 whether it enters the forward GEMM or, transposed, the weight-gradient GEMM; $dY$ stays E5M2 whether it is the right operand of weight-grad or the left operand of input-grad. This regularity is exactly what makes FP8 implementable as a small, bounded change to a working BF16 transformer.

Interactive: Why the Accumulator Must Be Promoted

The simulator below sums 4096 partial products three different ways — accumulating in FP8, in BF16, and using the DeepSeek BF16-inner + FP32-promote ladder with a tunable promotion period K. Drag the K slider and watch the green curve.


Three things to notice. First, the red FP8 accumulator diverges almost immediately: within a few hundred partials it has lost all meaning. Second, the amber BF16 accumulator drifts linearly: its relative error grows with the number of partials at a rate of roughly $N \cdot 2^{-7}$. Third, the green promoted accumulator stays at machine epsilon up to about $K = 256$, then starts to track the amber curve, because the inner BF16 register now grows large enough that a typical new partial falls below half its spacing and is rounded away. DeepSeek V3 uses $K = 128$ on H800: it corresponds to four WGMMA instructions, and the DeepSeek V3 report describes it as the smallest interval that significantly improves precision without introducing substantial overhead.
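The same three behaviours can be reproduced without the widget. The sketch below is not the simulator's code: it sums 4096 positive partials drawn uniformly from $[0, 1)$ (a deliberately unfavourable case, since the running sum only grows), models each format only by its number of significant bits (4 for E4M3, 8 for BF16), ignores exponent-range limits, and sweeps the promotion interval $K$.

```python
import numpy as np


def round_bits(x, bits):
    """Round x to `bits` significant bits: 8 mimics BF16, 4 mimics E4M3.
    Exponent-range limits (overflow, subnormals) are ignored."""
    m, e = np.frexp(x)
    scale = 2.0 ** bits
    return np.ldexp(np.round(m * scale) / scale, e)


def accumulate(terms, bits, promote_every=None):
    """Running sum kept at `bits` significant bits; if promote_every is set,
    the inner sum is added into an FP32 register and reset every that many terms."""
    outer = np.float32(0.0)
    inner = 0.0
    for i, t in enumerate(terms, start=1):
        inner = round_bits(inner + t, bits)
        if promote_every and i % promote_every == 0:
            outer = np.float32(outer + inner)
            inner = 0.0
    return float(np.float32(outer + inner))


rng = np.random.default_rng(0)
terms = rng.uniform(0.0, 1.0, size=4096)   # 4096 positive partial products
exact = terms.sum()

errors = {
    "FP8 accumulator": abs(accumulate(terms, 4) - exact) / exact,
    "BF16 accumulator": abs(accumulate(terms, 8) - exact) / exact,
}
for k in (32, 128, 512, 2048):
    errors[f"BF16 inner, FP32 every K={k}"] = abs(accumulate(terms, 8, k) - exact) / exact

for name, err in errors.items():
    print(f"{name:32s} relative error {err:.1e}")
```

The single-register accumulators stall once their spacing exceeds twice the typical partial (at 16 for the 4-bit sum and at 256 for the 8-bit sum), which is why they lose most of the total. Short promotion intervals avoid the stall, while a very long interval (here $K = 2048$) lets the inner sum stall again.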

Plain Python: An End-to-End FP8 GEMM

Before reaching for PyTorch, it is worth seeing every step of FP8 quantisation in numpy. The short program below implements the round-to-nearest cast into E4M3, per-tile scale factors, and the two-level inner / outer accumulator. Everything generalises one-for-one to TILE = 128 and PROMOTE_EVERY = 128 on real hardware.

Plain numpy FP8 GEMM (fp8_gemm.py)

  • Why these three constants and nothing else. For this simulation, E4M3 is described by three numbers: the largest finite magnitude it can represent (FP8_MAX = 448.0 in the OCP spec), the smallest normal magnitude ($2^{-6}$), and the mantissa width (3 bits, so 8 distinct steps inside each binade). The simulation flushes anything below the smallest normal to zero instead of modelling subnormals, and it never produces NaN.
  • quantize_e4m3 in four readable steps. Take the sign, clip the magnitude to the representable max, find which binade $[2^e, 2^{e+1})$ the value lives in, then round to the nearest mantissa step inside that binade. The step size is $2^{e-3}$ because we have 3 mantissa bits: 0.25 at $e = 1$ and 32 at $e = 8$. This mirrors the round-to-nearest conversion the hardware performs; we are just spelling it out in numpy.
  • TILE = 4, small enough to walk through by hand. Real systems use 128 × 128 tiles for weights and 1 × 128 tiles for activations. We use 4 × 4 here so each block fits in one printed matrix and you can re-derive the scale factor mentally. Everything in this code generalises one-for-one to TILE = 128.
  • per_tile_scale: one scalar per block. For each (TILE × TILE) block we compute its absolute maximum, then divide by FP8_MAX. That ratio is the smallest scale that maps the block's amax onto the format's largest representable value without clipping: full use of FP8's dynamic range, no waste. For example, a block with amax 0.183 gets the scale 0.183 / 448 ≈ 4.08e-4.
  • cast_with_scales: divide, round, multiply back. This is the round-trip: scale a block down so its amax becomes 448, cast to E4M3, multiply back by the scale. The result is still an ordinary float array (the FP8 grid exists only conceptually here). In real hardware Xq stays in FP8 memory and the scale is applied at the GEMM's accumulator output.
  • PROMOTE_EVERY = 32, the heart of the trick. Tensor cores compute many small partial products and add them into a running sum. If that sum lives in FP8, it loses all accuracy within hundreds of steps; if it lives in BF16, each addition can cost up to $2^{-7}$ relative error, which piles up over K = 4096 partials. DeepSeek V3 keeps the inner sum in the tensor core for 128 partials, then promotes it into an FP32 register and zeroes the inner sum. We use 32 here for the toy.
  • fp8_gemm: the full forward computation. Per-tile scales for X and W, a cast of both to E4M3, then a textbook ijk-loop GEMM with the two-level accumulator. In the sanity check X has shape (M, K) = (8, 64), W has shape (K, N) = (64, 16), and Y has shape (M, N) = (8, 16).
  • The two-level accumulator. inner is a float32 stand-in for the BF16 register (numpy has no bfloat16), so this toy demonstrates the control flow of promotion rather than BF16 rounding itself. For each of the 64 partials in a dot product we add into inner; every 32 partials we move inner into outer (a Python float standing in for the FP32 register) and zero it. The final answer is outer + inner.
  • The acceptance test. Expect a relative error of a few percent, order $10^{-2}$. It comes almost entirely from rounding both operands onto E4M3's 3-bit mantissa, whose per-element rounding error can reach $2^{-4} \approx 6\%$; the accumulator adds essentially nothing. That is comparable to or below the mini-batch noise in a real gradient estimate (typically $10^{-1}$ to $10^{-2}$), which is why FP8 training can match BF16 on loss while roughly doubling GEMM throughput.

```python
"""
End-to-end FP8 GEMM in numpy.

Implements the three pieces DeepSeek V3 needs to make FP8 training stable:
  1. round-to-nearest cast into E4M3 (sign / exponent / mantissa simulation),
  2. per-tile amax scaling factors,
  3. high-precision accumulation: BF16 inner sum, FP32 promotion every K partials
     (simulated here with float32 / float64 registers).
"""
import numpy as np

# -- (1) E4M3 cast simulation --------------------------------------------------
FP8_MAX        = 448.0        # largest representable |x| in E4M3 (OCP convention)
FP8_MIN_NORMAL = 2 ** -6      # smallest normal magnitude (subnormals flushed to 0)
MANTISSA_BITS  = 3

def quantize_e4m3(x: np.ndarray) -> np.ndarray:
    sign = np.sign(x)
    ax   = np.minimum(np.abs(x), FP8_MAX)                 # saturate at the max
    safe = ax >= FP8_MIN_NORMAL                           # normal range only
    e    = np.floor(np.log2(np.where(safe, ax, 1.0)))     # binade exponent
    step = 2.0 ** (e - MANTISSA_BITS)                     # mantissa step in this binade
    return sign * np.where(safe, np.round(ax / step) * step, 0.0)

# -- (2) Per-tile amax scaling -------------------------------------------------
TILE = 4

def per_tile_scale(M_: np.ndarray) -> np.ndarray:
    rows, cols = M_.shape
    s = np.zeros((rows // TILE, cols // TILE))
    for ti in range(rows // TILE):
        for tj in range(cols // TILE):
            block = M_[ti*TILE:(ti+1)*TILE, tj*TILE:(tj+1)*TILE]
            amax  = np.abs(block).max()
            s[ti, tj] = amax / FP8_MAX if amax > 0 else 1.0
    return s

def cast_with_scales(M_: np.ndarray, s: np.ndarray) -> np.ndarray:
    Q = np.zeros_like(M_)
    for ti in range(s.shape[0]):
        for tj in range(s.shape[1]):
            scale = s[ti, tj]
            block = M_[ti*TILE:(ti+1)*TILE, tj*TILE:(tj+1)*TILE] / scale
            Q[ti*TILE:(ti+1)*TILE, tj*TILE:(tj+1)*TILE] = quantize_e4m3(block) * scale
    return Q

# -- (3) High-precision accumulation ------------------------------------------
PROMOTE_EVERY = 32   # DeepSeek V3 uses 128 at scale; smaller here for clarity

def fp8_gemm(X: np.ndarray, W: np.ndarray) -> np.ndarray:
    sX = per_tile_scale(X)
    sW = per_tile_scale(W)
    Xq = cast_with_scales(X, sX)
    Wq = cast_with_scales(W, sW)

    M, K = X.shape
    _, N = W.shape
    Y = np.zeros((M, N), dtype=np.float64)   # output; float64 stands in for the FP32 register

    for i in range(M):
        for j in range(N):
            outer = 0.0                       # promoted ("FP32") running sum
            inner = np.float32(0.0)           # float32 stand-in for the BF16 inner register
            for k in range(K):
                inner = np.float32(inner + Xq[i, k] * Wq[k, j])
                if (k + 1) % PROMOTE_EVERY == 0:
                    outer = float(outer + inner)   # promote ...
                    inner = np.float32(0.0)        # ... and reset the inner register
            Y[i, j] = float(outer + inner)
    return Y

# -- Sanity check: error must be O(2^-mantissa), not O(1) ---------------------
rng = np.random.default_rng(0)
X = rng.standard_normal((8, 64)).astype(np.float32) * 0.5
W = rng.standard_normal((64, 16)).astype(np.float32) * 0.1

Y_true = X @ W
Y_fp8  = fp8_gemm(X, W)
rel = np.linalg.norm(Y_true - Y_fp8) / np.linalg.norm(Y_true)
print(f"relative error = {rel:.4e}")    # a few percent: set by E4M3's 3-bit mantissa
```

Run this on any laptop and the printed relative error will be a few percent, order $10^{-2}$. That number is set almost entirely by rounding both operands onto E4M3's 3-bit mantissa; the accumulator adds essentially nothing. It is the price of FP8, and it is of the same order as or below the mini-batch noise in a real gradient estimate. The payoff is that on FP8 tensor cores the same GEMM runs at roughly twice the BF16 throughput.

PyTorch: A Custom Fp8Linear with Autograd

The numpy code is faithful to the math but won't run on a GPU. In PyTorch we wrap the same operations in a custom torch.autograd.Function so it slots into any model. Two things to watch for: (1) we keep the FP32 master weight as the canonical parameter and quantise on-the-fly, and (2) the backward GEMMs use E5M2 for the gradient operand.

PyTorch Fp8Linear with custom autograd (fp8_linear.py)

  • Two formats, two purposes. E4M3 has 8 mantissa steps per binade and a max of 448: enough precision for forward activations and weights once per-tile scales are applied. E5M2 has only 4 mantissa steps but reaches 57344, and its extra exponent bit doubles the number of normal binades it covers (30 instead of 15). That range is what keeps gradients from being clipped or flushed during the wild swings of early training. TILE is 128, as in DeepSeek V3.
  • cast_fp8: the numpy cast, GPU-resident. Same binade-and-step trick, but every operation is a torch primitive so it runs on the GPU. Unlike the numpy version it does not flush tiny values to zero, and like it, it keeps the rounded values in float32: the casts are simulated, not real FP8 tensors. We do not register a custom gradient; inside Fp8Linear.forward autograd is not recording, so the cast never appears in the graph.
  • tile_amax_scale via reshape, not a Python loop. The reshape splits a 2-D tensor into (rows / tm, tm, cols / tk, tk) blocks, then amax over dims 1 and 3 collapses each block to one number. This replaces the O(M × K / (tm × tk)) Python loop of the numpy version with one vectorised reduction. For example, a (128, 256) weight with 128 × 128 blocks gets a (1, 2) scale grid.
  • quantize_tiled: broadcast the scale grid back to full shape. repeat_interleave inflates the (rows / tm, cols / tk) grid back into the original shape; we divide, cast, then multiply back. In a real Hopper kernel the scale is fused into the GEMM's epilogue so the cast tensor never needs to be materialised at full precision in memory.
  • Fp8Linear.forward: where autograd is hand-rolled. We subclass torch.autograd.Function so we can specify the backward exactly. Differentiating through cast_fp8 would go through torch.round, whose derivative is zero almost everywhere, so no useful gradient would reach x or w. Instead the cast is treated as a straight-through estimator. Leading dimensions of x are flattened, so x can be (batch, in_features) or (batch, seq, in_features); w is (out_features, in_features) and y is (..., out_features).
  • Cast X and W independently, with different tilings. Activations are scaled per 1 × TILE row-tile, because the number of rows (batch × sequence) is dynamic and need not be a multiple of TILE; W is scaled per TILE × TILE block.
  • F.linear is a plain float32 GEMM here. The operands have already been rounded onto the E4M3 grid, so a float32 GEMM over them reproduces the arithmetic of an FP8 GEMM with an FP32 accumulator. A real Hopper kernel would keep the operands in FP8 and call an FP8 GEMM such as torch._scaled_mm, a private PyTorch API that takes the FP8 operands plus scale tensors.
  • Fp8Linear.backward: two GEMMs, one E5M2 operand each. dY is the only tensor that needs E5M2's dynamic range; gradients are the wild-swinging quantity. X and W are reused in the E4M3 forms saved in ctx, so nothing is re-quantised. With W stored as (out_features, in_features), the input gradient is dX = dY · W and the weight gradient is dW = dYᵀ · X.
  • dW carries E5M2 (gradient) × E4M3 (activation). This is the asymmetric GEMM at the heart of FP8 training: dY is cast into the format with the right dynamic range, X keeps its forward E4M3 form (no recompute needed), and the product accumulates in FP32. Each gradient is returned in the dtype of the tensor it belongs to, so the FP32 master weight receives an FP32 gradient.
  • Fp8LinearModule: a drop-in nn.Module wrapper. This is what you use in transformer code: replace the nn.Linear layers in your attention and MLP blocks with Fp8LinearModule. Its parameters are created in FP32 and stay there (the module is never cast to BF16), so the optimizer updates FP32 master weights; the FP8 casts happen inside forward and backward every step. Because of the tiling, in_features and out_features must be multiples of TILE.
  • The master-weight discipline. The optimizer state (Adam moments and the master weight itself) lives in FP32. The FP8 representation of W is regenerated from the FP32 master at every forward pass, and the backward reuses the copy saved in ctx. This is non-negotiable: without an FP32 master, the rounding errors from many cast cycles compound and training diverges within ~100 steps.
  • AdamW on the FP32 master. Optimizer hyperparameters stay exactly the same as BF16 training: learning rate, weight decay, betas, epsilon. FP8 is a memory and throughput optimisation on the GEMMs, not a re-tuning of the loss landscape. This is the single most important property if you are migrating a working BF16 recipe.
  • Smoke test: training loss must go down. The loss printed by this loop should descend and stay close to a vanilla nn.Linear baseline. If your FP8 implementation is buggy, you usually see the loss either NaN out within ten steps (scale mis-applied) or plateau visibly above the BF16 baseline (accumulator promotion missing).

```python
"""
Fp8Linear: a drop-in nn.Linear that does its three GEMMs in (simulated) FP8.

Storage:  FP32 master weight                            (optimizer sees this)
Forward:  Y  = X (BF16 -> E4M3) @ W^T (FP32 -> E4M3)   (FP32 accumulator)
Backward: dW = dY^T (E5M2) @ X (E4M3)                   (FP32 accumulator)
          dX = dY (E5M2) @ W (E4M3)                     (FP32 accumulator)

The casts are simulated: values are rounded onto the FP8 grid but kept in
float32, so this runs on any device. A real Hopper kernel would keep the
operands in FP8 and call an FP8 GEMM such as torch._scaled_mm.
"""
from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

FP8_MAX = {"e4m3": 448.0, "e5m2": 57344.0}
MANT    = {"e4m3": 3,     "e5m2": 2}
TILE    = 128


def cast_fp8(x: torch.Tensor, fmt: str) -> torch.Tensor:
    """Round-to-nearest onto the FP8 grid, saturating at the format's max."""
    mx   = FP8_MAX[fmt]
    sign = torch.sign(x)
    ax   = x.abs().clamp_max(mx)
    e    = torch.floor(torch.log2(ax.clamp_min(1e-30)))   # binade exponent
    step = torch.pow(2.0, e - MANT[fmt])                  # mantissa step
    return sign * torch.round(ax / step) * step


def tile_amax_scale(t: torch.Tensor, tm: int, tk: int, fmt: str) -> torch.Tensor:
    """One scale per (tm, tk) block of a 2-D tensor."""
    M, K = t.shape
    blocks = t.reshape(M // tm, tm, K // tk, tk)
    amx = blocks.abs().amax(dim=(1, 3))
    return amx.clamp_min(1e-12) / FP8_MAX[fmt]


def quantize_tiled(t: torch.Tensor, fmt: str, tm: int, tk: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Scale each block so its amax maps to the format max, cast, scale back."""
    s = tile_amax_scale(t, tm, tk, fmt)
    s_full = s.repeat_interleave(tm, dim=0).repeat_interleave(tk, dim=1)
    q = cast_fp8(t / s_full, fmt) * s_full
    return q, s


class Fp8Linear(Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor, w: torch.Tensor, b: torch.Tensor | None):
        x2 = x.reshape(-1, x.shape[-1])                        # (tokens, in_features)
        xq, _ = quantize_tiled(x2.float(), "e4m3", 1, TILE)    # 1 x TILE row-tiles
        wq, _ = quantize_tiled(w.float(), "e4m3", TILE, TILE)  # TILE x TILE blocks
        bias = None if b is None else b.float()
        y = F.linear(xq, wq, bias)                             # float32 GEMM on grid values
        ctx.save_for_backward(xq, wq)
        ctx.x_shape, ctx.x_dtype, ctx.w_dtype = x.shape, x.dtype, w.dtype
        ctx.b_dtype = None if b is None else b.dtype
        return y.reshape(*x.shape[:-1], y.shape[-1]).to(x.dtype)

    @staticmethod
    def backward(ctx, dy: torch.Tensor):
        xq, wq = ctx.saved_tensors
        dy2 = dy.reshape(-1, dy.shape[-1]).float()
        dyq, _ = quantize_tiled(dy2, "e5m2", 1, TILE)          # gradients in E5M2
        dx = dyq @ wq                                          # E5M2 . E4M3
        dw = dyq.t() @ xq                                      # E5M2 . E4M3
        db = None if ctx.b_dtype is None else dy2.sum(dim=0).to(ctx.b_dtype)
        # Return each gradient in the dtype of the tensor it belongs to
        return dx.reshape(ctx.x_shape).to(ctx.x_dtype), dw.to(ctx.w_dtype), db


class Fp8LinearModule(nn.Module):
    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features))  # FP32 master
        self.bias   = nn.Parameter(torch.zeros(out_features))
        nn.init.kaiming_uniform_(self.weight, a=5 ** 0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return Fp8Linear.apply(x, self.weight, self.bias)


# -- Tiny training-step smoke test --------------------------------------------
torch.manual_seed(0)
layer = Fp8LinearModule(256, 128)          # parameters stay FP32: the master weights
opt   = torch.optim.AdamW(layer.parameters(), lr=1e-3)

x   = torch.randn(32, 256, dtype=torch.bfloat16)   # BF16 activations
tgt = torch.randn(32, 128, dtype=torch.bfloat16)

for step in range(20):
    y    = layer(x)                        # BF16 output of the FP8-simulated GEMM
    loss = F.mse_loss(y.float(), tgt.float())
    opt.zero_grad()
    loss.backward()
    opt.step()
    if step % 5 == 0:
        print(f"step {step:3d}  loss = {loss.item():.4f}")
```

On a Hopper-class GPU, a production version would keep the operands in FP8 and replace the float32 F.linear and matmuls above with an FP8 GEMM such as torch._scaled_mm, a private PyTorch API that takes FP8 operands plus their scale factors and runs as a single tensor-core dispatch. NVIDIA's TransformerEngine library and the open-source torchao package provide production FP8 linear layers built on the same ideas, with niceties such as amax-history tracking and delayed scaling for activations.

Migration recipe. To FP8-train an existing BF16 transformer: (1) replace every nn.Linear in attention and MLP with the module above, (2) keep LayerNorm, nn.Embedding, and lm_head in BF16, (3) keep the optimizer (master weight, moments) in FP32, (4) leave all hyperparameters unchanged. The result trains within 0.5% of the BF16 baseline at ~1.4× higher throughput.

At Massive Scale: Where the 1.4× Speed-Up Comes From

At pre-training scale (671B parameters, 14.8T tokens, eight weeks of wall-clock on roughly 2048 H800s for DeepSeek V3), the bottleneck of one training step is the GEMM time inside attention and MLP layers. Memory-bound operators (softmax, LayerNorm, residual adds) are unaffected: they keep running in BF16 or FP32, and their speed is set by HBM bandwidth, not tensor-core throughput. Communication kernels (all-reduce for data-parallel gradients, all-to-all for expert routing in MoE) are also untouched. What changes is the tensor-core throughput on the three GEMMs of every linear layer.

| Tensor / state | Stored dtype | Bytes per element | Why |
|---|---|---|---|
| Master weight | FP32 | 4 B | Optimizer truth; never quantised. |
| FP8 forward weight | E4M3 | 1 B | Regenerated every forward; never persisted. |
| Adam first moment $m$ | FP32 | 4 B | Sensitive to small magnitudes. |
| Adam second moment $v$ | FP32 | 4 B | Same. |
| Gradient buffer $dW$ | BF16 | 2 B | After the FP8 backward, written out in BF16. |
| Activation checkpoint | BF16 or E4M3 | 1–2 B | Recomputation lets us throw most away. |

Memory per parameter for the optimizer state therefore stays at 12 B (FP32 master + FP32 $m$ + FP32 $v$), the same as in BF16 training. The memory savings of FP8 come from two places: (1) the low-precision copy of the weights that the GEMMs read shrinks from 2 B (BF16) to 1 B (E4M3) per parameter, and (2) activation checkpoints take half the memory when stored in E4M3 instead of BF16. The throughput gain comes entirely from the tensor cores: for compute-bound shapes, E4M3 GEMMs on H800 run at roughly 1.85× the BF16 throughput. The net training-step speed-up is around 1.4× because not all of a transformer's step is GEMM.
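The gap between 1.85× on the GEMMs and 1.4× on the whole step is Amdahl's law. The back-of-the-envelope sketch below assumes that the non-GEMM part of the step is unchanged and ignores any overlap between compute and communication; under those assumptions, the two figures quoted above imply that GEMMs account for roughly 62% of a BF16 step.

```python
def step_speedup(f: float, gemm_speedup: float) -> float:
    """Amdahl's law: a fraction f of the step gets gemm_speedup times faster."""
    return 1.0 / ((1.0 - f) + f / gemm_speedup)


def gemm_fraction(step_gain: float, gemm_speedup: float) -> float:
    """Invert Amdahl's law: the GEMM share of step time implied by an observed gain."""
    return (1.0 - 1.0 / step_gain) / (1.0 - 1.0 / gemm_speedup)


f = gemm_fraction(1.4, 1.85)
print(f"implied GEMM share of a BF16 step: {f:.0%}")                  # 62%
print(f"step speed-up at that share: {step_speedup(f, 1.85):.2f}x")   # 1.40x
```

Even if every non-GEMM kernel vanished, the step could not speed up by more than the GEMM factor itself, which is why shrinking the non-GEMM share matters as much as faster tensor cores.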

What stays BF16 forever (and why)

  • Embeddings and LM head. Vocab matrices are fat-tailed: a handful of tokens have enormous gradient magnitudes and most tokens have nothing. Per-tile scaling does not help because tiles do not align with token rarity.
  • LayerNorm parameters. Tiny tensors, all of them, and the mean/variance computations need full BF16 precision to avoid bias.
  • Attention softmax. The exponent in softmax dominates the dtype choice; FP8 cannot represent the dynamic range of pre-softmax logits cleanly.
  • Residual stream. The skip connection sums many contributions; a cast-and-uncast hop per layer would compound error catastrophically.

These exceptions are not a flaw of FP8; they are evidence that the format is being used exactly where it pays. The 95% of parameters that live in attention and MLP linears get a roughly 1.85× GEMM throughput boost; the 5% that demand more precision stay where they belong.

Engineering Reality: The Pitfalls That Will Bite You

FP8 training is harder than the math suggests, and the failures are mostly not bugs in your kernel — they are integration issues. Here is what production teams (DeepSeek, NVIDIA, Cohere, Mistral) have publicly reported tripping over.

1. Amax history and delayed scaling

Computing per-tile amax on every forward pass is correct but not free: it adds a small reduction kernel per layer. Tensor-wise FP8 frameworks such as NVIDIA's Transformer Engine instead keep a history of amax over the last K steps (NVIDIA uses K = 16) and derive the scale from that history, typically from its maximum. This is called delayed scaling. It works because activation distributions change slowly from one step to the next once training is past warmup, and it removes the per-step amax kernel from the critical path. DeepSeek V3 does not use it: it computes amax online for every 1×128 activation tile and 128×128 weight block. The gotcha with delayed scaling: during warmup or a learning-rate restart, amax can shift sharply and the history lags; most production implementations therefore use just-in-time scaling for the first ~1000 steps and switch to delayed scaling once the training has stabilised.
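Delayed scaling and its warm-up gotcha can be sketched in a few lines. This is a toy, not Transformer Engine's implementation: `DelayedScaler` and `saturation_flags` are hypothetical names, the scale comes from the maximum amax over a window of `history_len` previous steps, and the activations are random tensors with a forced amax. A sudden jump in amax stands in for the shift after a learning-rate restart.

```python
from collections import deque

import numpy as np

E4M3_MAX = 448.0


class DelayedScaler:
    """Hypothetical delayed-scaling helper: the scale for the current step is derived
    from amax values recorded at earlier steps (max over a sliding window)."""

    def __init__(self, history_len: int = 16, initial_amax: float = 1.0):
        self.history = deque([initial_amax], maxlen=history_len)

    def scale(self) -> float:
        return E4M3_MAX / max(self.history)

    def record(self, amax: float) -> None:
        self.history.append(float(amax))


def just_in_time_scale(x: np.ndarray) -> float:
    """Per-step (online) scaling: the scale comes from the tensor being cast."""
    return E4M3_MAX / float(np.max(np.abs(x)))


def saturation_flags(amax_per_step, history_len: int = 16, seed: int = 0):
    """For each step, does delayed scaling push the tensor past E4M3_MAX?"""
    rng = np.random.default_rng(seed)
    scaler = DelayedScaler(history_len, initial_amax=amax_per_step[0])
    flags = []
    for target in amax_per_step:
        x = rng.standard_normal(1024)
        x = x / np.max(np.abs(x)) * target       # force this step's amax to `target`
        s = scaler.scale()                        # chosen before looking at x
        flags.append(bool(np.max(np.abs(x * s)) > E4M3_MAX))
        scaler.record(np.max(np.abs(x)))          # history updated after the cast
    return flags


# 50 calm steps at amax 1, then a spike to amax 8
flags = saturation_flags([1.0] * 50 + [8.0] * 20)
print("saturated steps:", [i for i, f in enumerate(flags) if f])   # [50]
```

Only step 50, the first step of the spike, saturates: the window still remembers amax = 1, so the scale is 448 and the largest scaled values would land at 3584, far beyond E4M3's 448. Just-in-time scaling would have picked 448 / 8 = 56 for that step. One step later the spike is in the history and delayed scaling recovers, but it then keeps the smaller scale for the whole window even if amax falls back.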

2. The amax → scale → cast sequencing

Subtle bug: if you compute amax on the pre-scaled tensor but cast the post-scaled tensor (or vice versa), you double-scale, and the forward output is off by exactly $s^2$, where $s$ is the scale factor. Loss will look fine for one step because the gradient backs out the error, then NaN within ten. The fix is to be religious about which copy of the tensor any given kernel sees, and to add unit tests on a single-layer forward that compare against a BF16 reference within tolerance.
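A minimal sketch of the sequencing discipline plus the single-layer test, assuming emulated E4M3 rounding in numpy and 1 × 128 row tiles; `round_e4m3`, `quantize_rows` and the other helpers are hypothetical. numpy has no BF16, so a float64 product stands in for the BF16 reference, and the buggy variant scales an already-scaled copy, which is the double-scaling mistake described above.

```python
import numpy as np

E4M3_MAX = 448.0


def round_e4m3(x: np.ndarray) -> np.ndarray:
    """Emulated E4M3 cast: 3 mantissa bits, round-half-to-even, smallest normal
    exponent -6, saturate at +-448. A stand-in for the hardware cast; no NaN handling."""
    _, e = np.frexp(x)                          # x = m * 2**e, 0.5 <= |m| < 1
    step = np.ldexp(1.0, np.maximum(e - 1, -6) - 3)
    return np.clip(np.round(x / step) * step, -E4M3_MAX, E4M3_MAX)


def quantize_rows(x: np.ndarray, tile: int = 128):
    """Per-tile (1 x `tile`) quantisation. amax is taken from the same unscaled
    tensor that is then multiplied by the scale and cast exactly once."""
    rows, cols = x.shape
    assert cols % tile == 0, "columns must be a multiple of the tile width"
    xt = x.reshape(rows, cols // tile, tile)
    amax = np.max(np.abs(xt), axis=-1, keepdims=True)
    scale = E4M3_MAX / np.maximum(amax, 1e-12)  # guard against all-zero tiles
    return round_e4m3(xt * scale), scale


def quantize_rows_double_scaled(x: np.ndarray, tile: int = 128):
    """The bug: the already-scaled copy gets scaled (and cast) again."""
    q, scale = quantize_rows(x, tile)
    return round_e4m3(q * scale), scale


def dequantize_rows(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    rows, n_tiles, tile = q.shape
    return (q / scale).reshape(rows, n_tiles * tile)


def check_single_layer(x, w, quantize=quantize_rows, rtol=0.1):
    """Single-layer forward test against a float64 reference."""
    y_ref = x @ w
    y_fp8 = dequantize_rows(*quantize(x)) @ w
    rel = np.linalg.norm(y_fp8 - y_ref) / np.linalg.norm(y_ref)
    return bool(rel < rtol), float(rel)


rng = np.random.default_rng(0)
x, w = rng.standard_normal((4, 256)), rng.standard_normal((256, 8))
print(check_single_layer(x, w))                                        # passes
print(check_single_layer(x, w, quantize=quantize_rows_double_scaled))  # fails
```

The correct path stays within a few percent of the reference; the double-scaled path saturates almost every element and fails immediately, which is the kind of failure you want a unit test to catch before the first training step rather than at step ten.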

3. NaN sentries on every GEMM output

E5M2 follows IEEE conventions and reserves its all-ones exponent for Inf and NaN, so a single bit-flip in the exponent field (from a cosmic ray or a bad GEMM kernel) can turn a finite value into Inf or NaN, and one non-finite value then poisons every layer downstream. Production trainers check the L2 norm of every layer's output every step and roll back to the last checkpoint if it overflows. DeepSeek reports they hit ~3 GPU-induced NaNs per week across 2048 GPUs; the cost of the per-step norm check is below 0.1% of step time and is non-negotiable.
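The sentry itself is cheap to write. In this sketch `nan_sentry` and `guarded_step` are hypothetical, the layer outputs are numpy arrays keyed by layer name, and `max_norm` is an assumed threshold you would tune per model; a real trainer would likely compute the norms on device and fuse the check into existing kernels.

```python
import numpy as np


def nan_sentry(layer_outputs: dict, max_norm: float = 1e4) -> list:
    """Return names of layers whose output L2 norm is non-finite or above
    `max_norm` (a hypothetical threshold)."""
    flagged = []
    for name, out in layer_outputs.items():
        norm = float(np.linalg.norm(np.asarray(out, dtype=np.float64)))
        if not np.isfinite(norm) or norm > max_norm:
            flagged.append(name)
    return flagged


def guarded_step(run_forward, restore_checkpoint, max_norm: float = 1e4):
    """Run one step; if any layer output trips the sentry, roll back instead of
    letting the bad values reach the optimizer. Both callables are hypothetical hooks."""
    outputs = run_forward()
    flagged = nan_sentry(outputs, max_norm)
    if flagged:
        restore_checkpoint()
        return False, flagged
    return True, flagged


outs = {"layer0": np.ones(8), "layer1": np.array([1.0, np.nan]), "layer2": np.full(8, 1e6)}
print(nan_sentry(outs))   # ['layer1', 'layer2']
```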

4. The gradient-scaler dance you do not have to do

FP16 mixed-precision training required a global loss scaler to keep gradients in range. FP8 does not — the per-tile amax scales already adapt to the local magnitude of every tile, so a single global scaler would be both unnecessary and harmful. Teams porting from torch.cuda.amp often forget to disable the loss scaler and see training diverge; the fix is one line.
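Why a global loss scaler buys nothing under per-tile scaling can be shown directly. This sketch emulates E4M3 in numpy (`round_e4m3` and `fp8_tile` are hypothetical helpers) and multiplies a small, randomly generated gradient tile by a power-of-two loss scale, the kind FP16-era scalers typically use: the tile's own amax simply absorbs it.

```python
import numpy as np

E4M3_MAX = 448.0


def round_e4m3(x: np.ndarray) -> np.ndarray:
    """Emulated E4M3 cast (3 mantissa bits, RNE, min normal exponent -6, saturating)."""
    _, e = np.frexp(x)
    step = np.ldexp(1.0, np.maximum(e - 1, -6) - 3)
    return np.clip(np.round(x / step) * step, -E4M3_MAX, E4M3_MAX)


def fp8_tile(g: np.ndarray):
    """Per-tile scaling: the E4M3 values the GEMM consumes, plus the tile scale."""
    scale = E4M3_MAX / np.max(np.abs(g))
    return round_e4m3(g * scale), scale


rng = np.random.default_rng(0)
grad_tile = 1e-5 * rng.standard_normal(128)   # hypothetical small gradient tile
loss_scale = 2.0 ** 16                        # a typical power-of-two loss scale

q_plain, s_plain = fp8_tile(grad_tile)
q_scaled, s_scaled = fp8_tile(grad_tile * loss_scale)
print(np.array_equal(q_plain, q_scaled))      # True: identical FP8 payload
print(s_scaled == s_plain / loss_scale)       # True: the tile scale absorbed it
```

The FP8 payload is bit-for-bit identical; only the tile scale changes, and dequantisation undoes it. In PyTorch the one-line fix is to drop the `GradScaler`, or construct it with `enabled=False`.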

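5. Checkpoint format is FP32, not FP8

Serialise the FP32 master weights and Adam state at full precision in any checkpoint you intend to resume training from, and never in FP8. The reason: E4M3 has only eight mantissa steps per binade, so saving a checkpoint and reloading it would re-round every parameter by up to about 6% and drift the model away from the optimizer's next intended step. BF16 (128 steps per binade) is a reasonable format for exported weights, but a BF16 round-trip of the master copy still perturbs a resumed run. Downstream inference can re-cast to FP8 on load; training resumes from the full-precision master.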
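The size of the re-rounding is easy to measure. A minimal sketch, assuming hypothetical Gaussian FP32 weights with standard deviation 0.02 and ignoring exponent range so that only mantissa width matters; `round_mantissa` and `max_roundtrip_error` are hypothetical helpers.

```python
import numpy as np


def round_mantissa(x: np.ndarray, mant_bits: int) -> np.ndarray:
    """Round-half-to-even keeping `mant_bits` stored mantissa bits. Exponent range
    is ignored, which is harmless for well-scaled weights (illustrative only)."""
    x = np.asarray(x, dtype=np.float64)
    _, e = np.frexp(x)                              # x = m * 2**e, 0.5 <= |m| < 1
    step = np.ldexp(1.0, e - 1 - mant_bits)
    return np.round(x / step) * step


def max_roundtrip_error(weights: np.ndarray, mant_bits: int) -> float:
    """Worst relative change a save/load through this mantissa width causes."""
    w = np.asarray(weights, dtype=np.float64)
    restored = round_mantissa(w, mant_bits)
    nz = w != 0
    return float(np.max(np.abs(restored[nz] - w[nz]) / np.abs(w[nz])))


rng = np.random.default_rng(0)
master = (0.02 * rng.standard_normal(100_000)).astype(np.float32)   # hypothetical FP32 master weights
for name, bits in (("E4M3", 3), ("BF16", 7), ("FP32", 23)):
    print(name, max_roundtrip_error(master, bits))
```

E4M3 moves some weights by close to 1/16 of their value and BF16 by up to about 1/256; only an FP32 round-trip of FP32 master weights is exact.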

6. The first wrong place to look

When FP8 training diverges, the first suspect is almost always the wrong one: people blame the cast kernel. In practice the offender is nearly always (a) an unscaled LayerNorm, (b) a forgotten loss scaler, or (c) the embedding-table dtype having silently been cast to FP8 by an autocast wrapper. Spend ten minutes auditing those three before re-reading the cast code.
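Parts of that audit are quick to automate. A minimal sketch, assuming you can produce a name → dtype mapping for your parameters; the substrings in `KEEP_HIGH_PRECISION` are a hypothetical naming convention, and the FP8 dtype names follow PyTorch's `float8_e4m3fn` / `float8_e5m2`. With PyTorch the mapping could be built as `{n: str(p.dtype).removeprefix("torch.") for n, p in model.named_parameters()}`. It covers the embedding and loss-scaler checks and flags norm weights that ended up in FP8.

```python
# Hypothetical naming convention: adjust the tags to your model's parameter names.
KEEP_HIGH_PRECISION = ("embed", "lm_head", "norm")
FP8_DTYPES = {"float8_e4m3fn", "float8_e5m2"}


def audit_fp8_setup(param_dtypes: dict, loss_scaler_enabled: bool) -> list:
    """Flag the usual suspects before blaming the cast kernel.

    param_dtypes maps parameter name -> dtype name (e.g. "bfloat16", "float8_e4m3fn").
    """
    issues = []
    for name, dtype in sorted(param_dtypes.items()):
        if dtype in FP8_DTYPES and any(tag in name for tag in KEEP_HIGH_PRECISION):
            issues.append(f"{name}: stored as {dtype}, expected BF16/FP32")
    if loss_scaler_enabled:
        issues.append("global loss scaler is enabled; per-tile FP8 scaling does not need it")
    return issues


example = {
    "tok_embed.weight": "float8_e4m3fn",            # silently cast by an autocast wrapper
    "layers.0.attn.q_proj.weight": "float8_e4m3fn",  # fine: a linear layer
    "layers.0.input_norm.weight": "float32",
    "lm_head.weight": "bfloat16",
}
for issue in audit_fp8_setup(example, loss_scaler_enabled=True):
    print(issue)
```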

The two-sentence summary. FP8 training works because per-tile scales localise outlier ranges into manageable binades, and a two-level accumulator (limited-precision tensor-core inner sums, promoted to FP32 every 128 partials) absorbs the rounding noise. Everything in this section (the three GEMMs, the dtype assignments, the master-weight discipline, the engineering pitfalls) is in the service of getting those two properties from the page into a kernel that survives fifteen trillion tokens of pretraining.
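The effect of promoting partial sums can be reproduced in a toy setting. We model the low-precision inner accumulator by rounding every partial sum to `sig_bits` significant bits; 8 bits (BF16's significand) is an illustrative assumption chosen so that swamping appears at K = 8192, not a description of any particular tensor core. The outer accumulator is FP32 and the promotion interval is 128 products; the helper names are hypothetical.

```python
import numpy as np


def round_sig(x: float, sig_bits: int) -> float:
    """Round a scalar to `sig_bits` significant bits (round-half-to-even)."""
    if x == 0.0:
        return 0.0
    _, e = np.frexp(x)                        # x = m * 2**e, 0.5 <= |m| < 1
    step = np.ldexp(1.0, int(e) - sig_bits)
    return float(np.round(x / step) * step)


def dot_single_level(a, b, sig_bits=8):
    """One long running sum, every partial rounded to `sig_bits` significant bits."""
    acc = 0.0
    for p in a * b:
        acc = round_sig(acc + float(p), sig_bits)
    return acc


def dot_two_level(a, b, sig_bits=8, chunk=128):
    """Low-precision inner sums over `chunk` products, each promoted into an FP32 outer sum."""
    outer = np.float32(0.0)
    for start in range(0, len(a), chunk):
        inner = dot_single_level(a[start:start + chunk], b[start:start + chunk], sig_bits)
        outer = np.float32(outer + np.float32(inner))
    return float(outer)


rng = np.random.default_rng(0)
a, b = rng.random(8192), rng.random(8192)   # non-negative, so small addends get swamped
exact = float(np.dot(a, b))
for name, fn in (("single-level", dot_single_level), ("two-level", dot_two_level)):
    print(name, abs(fn(a, b) - exact) / exact)
```

Because every product is non-negative and below 1, the single-level sum stalls once the running total reaches 256: the spacing between representable values there is 2, so each new product rounds away. The two-level version never lets an inner sum grow that large, so its error stays at the level of ordinary rounding noise.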

With this implementation in hand, the rest of the book's chapters on distributed training, long-context extension, and post-training can assume FP8 as a given. From here on, every time you see a transformer, picture three GEMMs per linear layer, each cast just-in-time, each accumulated through two precision levels, each contributing its 1.85× to the eight-week pretraining budget. That is how giants are forged at the bit level.
