The Real Problem: From Idea to Kernel
The four sections before this one have walked through the idea of FP8 training: why naive FP8 fails, what fine-grained per-tile quantisation buys you, how high-precision accumulation rescues the inner product, and which tensors stay in BF16 or FP32 because their dynamic range is too wide. Every one of those decisions has the flavour of a clean mathematical observation: round-to-nearest is unbiased, tiles localise outliers, partial sums commute.
And yet the gap between those observations and a kernel that survives a fifteen-trillion-token pretraining run is huge. Real FP8 training has to answer questions the math never asks:
- When are tensors cast? Every step? Every microbatch? Once per layer? Each choice has different memory and throughput consequences.
- Where is the scale factor stored, and on which dimension is it tiled? Activations want row-tiles because the batch axis is dynamic; weights want square tiles because they are static across the step.
- Which three GEMMs of a linear layer (forward, weight-gradient, input-gradient) get which FP8 format? They have different operands and different dynamic ranges.
- What stays in FP32 forever — and why? Master weight, Adam moments, LayerNorm parameters, the loss scalar, and the GEMM accumulator are non-negotiable.
- How does the optimizer update an FP32 master weight from FP8 gradients without re-introducing the very quantisation errors we worked so hard to avoid?
This section assembles every previous piece into a complete, working FP8 training step. It is the chapter's capstone — the place where the theory turns into a kernel.
Intuition: One Linear Layer, Three GEMMs
Every nn.Linear in your transformer — the attention projections, the MLP up and down projections, the LM head — triggers three matrix multiplications across a full training step, not one. Most engineers learn the forward and forget the backward needs two more.
The three GEMMs, in plain English
- Forward. Given input activations $X$ of shape $M \times K$ and weights $W$ of shape $K \times N$, compute the output $Y = XW$ of shape $M \times N$.
- Weight-gradient backward. Given the upstream gradient $dY$ of shape $M \times N$, compute $dW = X^\top dY$ of shape $K \times N$, the gradient the optimizer needs.
- Input-gradient backward. Compute $dX = dY\,W^\top$ of shape $M \times K$, the signal that flows to the previous layer.
Each GEMM has a different pair of operands. Each operand has its own dynamic range. So each GEMM has its own quantisation choice.
| GEMM | Operand 1 | Operand 2 | Why this format? |
|---|---|---|---|
| Forward Y = X·W | X → E4M3 | W → E4M3 | Activations after LayerNorm + GeLU + residual have moderate range; weights similar. E4M3's tighter mantissa gives the precision we need. |
| Weight-grad dW = Xᵀ·dY | Xᵀ → E4M3 | dY → E5M2 | Gradient dY can swing six decades — early layers in early steps see dY ≈ 1e-6, late layers see dY ≈ 1e0. Only E5M2's 5-bit exponent contains that range. |
| Input-grad dX = dY·Wᵀ | dY → E5M2 | Wᵀ → E4M3 | Same reason: the gradient half of the matmul lives in E5M2; the weight half lives in E4M3. |
The pattern: activations and weights live in E4M3; gradients live in E5M2; all three GEMMs accumulate into FP32. The dtype of an operand is fixed by what tensor it is, not by which GEMM it appears in.
This is the single biggest piece of architecture-level knowledge in FP8 training. The rest of the section is about implementing it without losing accuracy.
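As a quick shape check, the three GEMMs and the role-based format assignment can be written out in NumPy. This is a full-precision sketch with no quantisation; the dimension names `M`, `K`, `N` and the `FORMAT_BY_ROLE` table are illustrative names chosen here, not part of any library.

```python
import numpy as np

# The role of a tensor decides its FP8 format, independent of the GEMM using it.
FORMAT_BY_ROLE = {"activation": "E4M3", "weight": "E4M3", "gradient": "E5M2"}

M, K, N = 8, 16, 4                      # tokens, in-features, out-features (illustrative)
rng = np.random.default_rng(0)
X = rng.standard_normal((M, K))         # activation
W = rng.standard_normal((K, N))         # weight
dY = rng.standard_normal((M, N))        # upstream gradient

Y = X @ W                               # forward:     (M,K) @ (K,N) -> (M,N)
dW = X.T @ dY                           # weight-grad: (K,M) @ (M,N) -> (K,N)
dX = dY @ W.T                           # input-grad:  (M,N) @ (N,K) -> (M,K)

GEMMS = {
    "forward": ("activation", "weight"),
    "weight_grad": ("activation", "gradient"),
    "input_grad": ("gradient", "weight"),
}
for name, (lhs, rhs) in GEMMS.items():
    print(f"{name:12s} {FORMAT_BY_ROLE[lhs]} x {FORMAT_BY_ROLE[rhs]} -> FP32 accumulate")
```

Note that `dW` has the shape of `W` and `dX` has the shape of `X`, which is a handy sanity check when wiring up a custom backward.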
The Mathematics of a Fine-Grained FP8 Step
Let us write the operations of one linear layer in equations so we can talk about them precisely. Take a single matrix multiplication $Y = XW$ where $X \in \mathbb{R}^{M \times K}$ and $W \in \mathbb{R}^{K \times N}$.
Step 1: Per-tile amax
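Split $W$ into $B \times B$ tiles (DeepSeek V3 uses $B = 128$). For each tile $(i, j)$ compute its absolute maximum and divide by the format's largest representable value (448 for E4M3) to get a scale:
$$s^W_{ij} = \frac{\max_{(k,n) \in \text{tile}(i,j)} |W_{kn}|}{448}$$
For $X$ we use $1 \times B$ row-tiles because the batch dimension is dynamic: $s^X_{m,t} = \max_{k \in \text{tile } t} |X_{mk}| \,/\, 448$.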
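The two tilings can be sketched with a pair of NumPy reshapes. The helper names `block_scales` and `row_scales` are hypothetical, the tile size is shrunk to 4 so the output is easy to inspect, and the sketch assumes every dimension is a multiple of the tile size.

```python
import numpy as np

E4M3_MAX = 448.0   # largest finite E4M3 value

def block_scales(W, tile):
    """One scale per (tile x tile) block of a weight matrix."""
    K, N = W.shape
    blocks = np.abs(W).reshape(K // tile, tile, N // tile, tile)
    return blocks.max(axis=(1, 3)) / E4M3_MAX          # shape (K/tile, N/tile)

def row_scales(X, tile):
    """One scale per (1 x tile) row segment of an activation matrix."""
    M, K = X.shape
    segments = np.abs(X).reshape(M, K // tile, tile)
    return segments.max(axis=2) / E4M3_MAX             # shape (M, K/tile)

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 8))
X = rng.standard_normal((3, 8))
X[0, 1] = 100.0                                        # a single outlier

sW = block_scales(W, tile=4)
sX = row_scales(X, tile=4)
print(sW.shape, sX.shape)      # (2, 2) (3, 2)
print(sX[0])                   # only the first segment of row 0 sees the outlier
```

The outlier inflates exactly one scale; every other segment keeps a scale fitted to its own values.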
Step 2: Cast each tile
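Define the round-to-nearest E4M3 operator $Q_{\text{E4M3}}(v)$ by: clip the magnitude to 448, identify the binade $e = \lfloor \log_2 |v| \rfloor$, round to the nearest of the eight mantissa steps with step size $2^{e-3}$. The cast operation is $\hat{X} = Q_{\text{E4M3}}(X / s^X)$, with each tile divided by its own scale, so the values now live on the tight integer-like grid of E4M3. Re-multiplying by the scale gives the dequantised approximation $\tilde{X} = s^X \hat{X} \approx X$.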
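A minimal NumPy simulation of this cast is sketched below. It assumes the common E4M3 variant with a maximum of 448 and a smallest subnormal step of $2^{-9}$, keeps the result in float64 rather than in a real 8-bit container, and uses a made-up input vector; `quantize_e4m3` is a name chosen for this sketch.

```python
import numpy as np

def quantize_e4m3(v):
    """Round-to-nearest-even onto the E4M3 grid; returns float64 values."""
    v = np.clip(np.asarray(v, dtype=np.float64), -448.0, 448.0)
    mag = np.abs(v)
    _, exp = np.frexp(mag)                 # mag = m * 2**exp with m in [0.5, 1)
    e = np.maximum(exp - 1, -6)            # binade exponent, floored at the subnormal range
    step = np.ldexp(1.0, e - 3)            # 3 mantissa bits -> 8 steps per binade
    return np.sign(v) * np.round(mag / step) * step

x = np.array([0.013, -0.20, 0.40, 0.07])   # illustrative values
scale = np.abs(x).max() / 448.0            # Step 1: one tile, one scale
x_q = quantize_e4m3(x / scale)             # Step 2: values on the E4M3 grid
x_hat = x_q * scale                        # dequantised approximation
rel_err = np.abs(x_hat - x) / np.abs(x)
print(x_q)
print(rel_err)
```

For values in the normal range, the relative error of a single cast is bounded by half a mantissa step, that is $2^{-4} = 6.25\%$.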
Step 3: The accumulated GEMM
The forward output is the FP32 accumulation of FP8 products, with the scales fused into the epilogue. For one output entry $Y_{mn}$:
$$Y_{mn} = \sum_{t=1}^{K/B} s^X_{m,t}\, s^W_{t,j(n)} \sum_{k \in \text{tile } t} \hat{X}_{mk}\, \hat{W}_{kn}$$
where $j(n)$ is the column tile that contains $n$. The inner sum is over one tile of length $B$; the outer sum is over tile indices $t$. In hardware the inner sum lives in a BF16-precision accumulator (the tensor core's native output register) and the outer sum lives in an FP32 register. Every $B$ inner steps the BF16 accumulator is added into the FP32 outer and reset.
Step 4: The backward GEMMs follow the same rule
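For the weight gradient, replace E4M3 with E5M2 on the gradient operand:
$$dW_{kn} = \sum_{t} s^X_{t,k}\, s^{dY}_{t,n} \sum_{m \in \text{tile } t} \hat{X}_{mk}\, \widehat{dY}_{mn}, \qquad \widehat{dY} = Q_{\text{E5M2}}(dY / s^{dY})$$
Here the tiles run along the contraction axis $M$, so the scales are recomputed for that tiling, and $s^{dY}$ divides by 57344, the largest E5M2 value, instead of 448. The structure is identical. Only the dtype of $dY$ changes, E5M2 instead of E4M3, because of dynamic range, not precision.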
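The range argument can be checked numerically. The sketch below simulates both formats with one generic rounding helper (`quantize_fp8` is a hypothetical name) and feeds them a deliberately extreme tile whose entries span six decades under a single scale. With fine-grained tiles a real tile rarely spans this much, so treat it as a stress test rather than a typical case.

```python
import numpy as np

def quantize_fp8(v, mant_bits, min_exp, max_val):
    """Round-to-nearest-even onto a small float grid (simulated in float64)."""
    v = np.clip(np.asarray(v, dtype=np.float64), -max_val, max_val)
    mag = np.abs(v)
    _, exp = np.frexp(mag)
    e = np.maximum(exp - 1, min_exp)       # floor at the subnormal range
    step = np.ldexp(1.0, e - mant_bits)
    return np.sign(v) * np.round(mag / step) * step

E4M3 = dict(mant_bits=3, min_exp=-6, max_val=448.0)
E5M2 = dict(mant_bits=2, min_exp=-14, max_val=57344.0)

# A gradient tile whose entries span six decades (illustrative values).
dY = np.array([1.0, 1e-2, 1e-4, 1e-6])

results = {}
for name, fmt in (("E4M3", E4M3), ("E5M2", E5M2)):
    scale = np.abs(dY).max() / fmt["max_val"]
    results[name] = quantize_fp8(dY / scale, **fmt) * scale
    print(name, results[name])
```

In this setup E4M3 flushes the smallest entry to zero, while E5M2 keeps it with a coarser mantissa: less precision, more range.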
Step 5: The optimizer never sees FP8
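After the backward, $dW$ is cast back to BF16 and the FP32 master weight is updated by AdamW:
$$m_t = \beta_1 m_{t-1} + (1-\beta_1)\, dW, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, dW^2$$
$$W_t = W_{t-1} - \eta \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda W_{t-1} \right)$$
where $m_t, v_t$ are the Adam moments in FP32 (with $\hat{m}_t, \hat{v}_t$ their bias-corrected versions), $\eta$ is the learning rate, and $\lambda$ is weight decay. The next forward pass re-quantises $W_t$ into FP8. The optimizer state never touches the FP8 grid.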
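A sketch of that update in NumPy follows, using the standard AdamW form with bias correction. The hyperparameter defaults are placeholders chosen for illustration, not values taken from any particular training run, and `adamw_step` is a name invented for this sketch.

```python
import numpy as np

def adamw_step(w, g, m, v, t, lr=1e-3, beta1=0.9, beta2=0.95, eps=1e-8, wd=0.1):
    """One AdamW update; the master weight and both moments stay in FP32."""
    g = g.astype(np.float32)                       # BF16 gradient promoted to FP32
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    m_hat = m / (1.0 - beta1 ** t)                 # bias correction
    v_hat = v / (1.0 - beta2 ** t)
    w = w - lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * w)
    return w.astype(np.float32), m.astype(np.float32), v.astype(np.float32)

w = np.ones(4, dtype=np.float32)                   # FP32 master weight
g = np.array([0.5, -0.5, 0.1, 0.0], dtype=np.float32)
m = np.zeros(4, dtype=np.float32)
v = np.zeros(4, dtype=np.float32)

w, m, v = adamw_step(w, g, m, v, t=1)
print(w)   # the FP8 copy for the next forward is re-derived from this array
```

Nothing in the function ever sees an FP8 value: quantisation happens only when the next forward pass casts `w`.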
Interactive: How Tile Size Controls Error
The whole argument for fine-grained quantisation lives in one plot: per-tile error vs. per-tensor error. The widget below builds a 16×16 activation tensor with a tunable number of outliers, lets you pick the tile size, and shows three heatmaps side by side — the input, the dequantised approximation, and the absolute error.
The key reading: turn outlier intensity up. With a per-tensor scale, one big spike forces a huge global scale, and every small bulk value rounds to zero. Drop tile size to 4×4 and the spike is confined to one tile; every other tile recovers its own well-fitted scale and the mean error collapses by an order of magnitude or more.
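The same experiment can be approximated offline. The sketch below assumes a 16×16 tensor of small Gaussian values with one large spike (both magnitudes are arbitrary choices) and compares a single per-tensor scale against 4×4 tiles using a simulated E4M3 cast.

```python
import numpy as np

def quantize_e4m3(v):
    """Round-to-nearest-even onto the E4M3 grid (simulated in float64)."""
    v = np.clip(np.asarray(v, dtype=np.float64), -448.0, 448.0)
    mag = np.abs(v)
    _, exp = np.frexp(mag)
    e = np.maximum(exp - 1, -6)
    step = np.ldexp(1.0, e - 3)
    return np.sign(v) * np.round(mag / step) * step

def fake_quant(x, tile):
    """Quantise x with one scale per (tile x tile) block, then dequantise."""
    n = x.shape[0]
    blocks = x.reshape(n // tile, tile, n // tile, tile)
    amax = np.abs(blocks).max(axis=(1, 3), keepdims=True)
    scale = np.maximum(amax, 1e-30) / 448.0
    return (quantize_e4m3(blocks / scale) * scale).reshape(n, n)

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 16)) * 1e-3     # small "bulk" values
x[3, 5] = 500.0                              # one large outlier

errs = {}
for tile in (16, 4):                         # tile=16 is per-tensor for this input
    errs[tile] = np.abs(fake_quant(x, tile) - x).mean()
    print(f"tile={tile:2d}  mean abs error = {errs[tile]:.3e}")
```

With 4×4 tiles the damage is confined to the one tile that contains the spike, so the mean error is dominated by just 15 neighbouring entries.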
Manual Numerical Walkthrough
Now let us do every step by hand on a 2×4 example so the machinery is undeniable.
Setup. Take $X$, one row of activations with $K = 4$, and the weight column $W$. The exact product is $X \cdot W = 0.4990$.
Step 1: per-tile amax. We use one tile of size 4 (the whole vector). For $X$: amax = 0.40, scale $s_X = 0.40 / 448 \approx 8.93 \times 10^{-4}$. For $W$: amax = 1.20, scale $s_W = 1.20 / 448 \approx 2.68 \times 10^{-3}$.
Step 2: cast. Divide each entry by its scale, round to the nearest E4M3 representable value, then look up the result.
For X: each divided value $X_k / s_X$ is already inside [−448, 448], and the rounding lands on representable mantissa steps for binade $2^8$ (step = 32) or $2^6$ (step = 8), so the casts are exact at this resolution. After the cast and re-multiplying by $s_X$ we get $\tilde{X} = X$: no error at this scale.
For W: the first divided value lands exactly on an E4M3 step. The other three are rounded to their nearest E4M3 steps, so after re-multiplying by $s_W$ the dequantised $\tilde{W}$ differs slightly from $W$ in those entries.
Step 3: accumulated GEMM. The four FP8 products are 0.4800, 0.00589, 0.01206, and 0.001005.
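Inner accumulator after each step (as a simple stand-in for the register's limited precision we truncate each running sum to four decimal places; a true BF16 grid is coarser, roughly 0.002 wide in this range):
- 0 + 0.4800 → 0.4800 (unchanged at four decimals)
- 0.4800 + 0.00589 → 0.4858 (0.48589 truncated to four decimals)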
- 0.4858 + 0.01206 → 0.4978
- 0.4978 + 0.001005 → 0.4988
Step 4: FP32 promotion. Promote the BF16 accumulator into an FP32 register: outer = 0.4988. The result is 0.4988 against a true value of 0.4990 — relative error of 4 × 10⁻⁴, well inside the noise floor of any real training step.
What just happened. Two operands quantised from FP32 to E4M3, four FP8 products accumulated through a BF16 register, one FP32 promotion at the end. Total quantisation loss: 0.0002 on a value of 0.5. That ratio survives at scale: the DeepSeek V3 technical report, which uses the same per-tile scaling and promoted accumulation (with E4M3 for all tensors), reports a relative loss error below 0.25% against a BF16 baseline.
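The running sums of the walkthrough can be replayed exactly with Python's `decimal` module. The listed intermediate values correspond to keeping four decimal places at each step, which is the assumption made here; it is a stand-in for a low-precision register, not a bit-accurate BF16 model.

```python
from decimal import Decimal, ROUND_DOWN

# The four FP8 products and the exact result quoted in the walkthrough.
products = [Decimal("0.4800"), Decimal("0.00589"),
            Decimal("0.01206"), Decimal("0.001005")]
exact = Decimal("0.4990")

acc = Decimal("0")
for p in products:
    # Low-precision inner register: keep four decimal places after each add.
    acc = (acc + p).quantize(Decimal("0.0001"), rounding=ROUND_DOWN)
    print(acc)

rel_err = abs(exact - acc) / exact
print(f"relative error = {float(rel_err):.1e}")
```

The loop prints 0.4800, 0.4858, 0.4978, 0.4988 and a relative error of 4.0e-04, matching the hand calculation.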
Interactive: The Three GEMMs of a Linear Layer
Use the buttons or hit Play to walk through the three GEMMs in order. Notice which operand changes dtype at each step and which stays the same.
The single most useful lesson from this animation: the dtype of a tensor is decided by what role it plays (activation, weight, gradient), not by which GEMM it shows up in. X stays E4M3 whether it is the left operand of forward or, transposed, the left operand of weight-grad; dY stays E5M2 whether it is the right operand of weight-grad or the left operand of input-grad. This regularity is exactly what makes FP8 implementable as a small, bounded change to a working BF16 transformer.
Interactive: Why the Accumulator Must Be Promoted
The simulator below sums 4096 partial products three different ways — accumulating in FP8, in BF16, and using the DeepSeek BF16-inner + FP32-promote ladder with a tunable promotion period K. Drag the K slider and watch the green curve.
Three things to notice. First, the red FP8 accumulator diverges almost immediately: within a few hundred partials it has lost all meaning. Second, the amber BF16 accumulator drifts linearly: relative error grows in proportion to the number of partials. Third, the green promoted accumulator stays at machine epsilon up to about $K = 128$, then starts to track the amber curve once the inner BF16 register starts to overflow its own mantissa. K = 128 is the sweet spot on Hopper because that is the largest value that keeps the inner register within one mantissa decade of any single partial.
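A stripped-down version of this comparison runs in plain Python. The sketch assumes 4096 partials drawn uniformly from [0, 1), which is a harsher input than signed products, rounds to BF16 by bit manipulation, and uses Python floats (float64) in place of the FP32 outer register. The FP8 accumulator is left out for brevity.

```python
import random
import struct

def to_bf16(x: float) -> float:
    """Round a value to BF16 (round-to-nearest-even), returned as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def sum_bf16(values):
    """Accumulate everything in a single BF16-precision register."""
    acc = 0.0
    for v in values:
        acc = to_bf16(acc + v)
    return acc

def sum_promoted(values, period):
    """BF16 inner register, flushed into a wide outer register every `period` adds."""
    outer, inner = 0.0, 0.0
    for i, v in enumerate(values, start=1):
        inner = to_bf16(inner + v)
        if i % period == 0:
            outer += inner           # promote and reset
            inner = 0.0
    return outer + inner

random.seed(0)
partials = [random.random() for _ in range(4096)]
exact = sum(partials)

err_bf16 = abs(sum_bf16(partials) - exact) / exact
err_promoted = abs(sum_promoted(partials, 128) - exact) / exact
print(f"bf16 only       relative error = {err_bf16:.2e}")
print(f"promoted K=128  relative error = {err_promoted:.2e}")
```

With all-positive inputs the pure BF16 sum stalls once its grid spacing exceeds the size of a partial, which is the failure that periodic promotion is designed to prevent.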
Plain Python: An End-to-End FP8 GEMM
Before reaching for PyTorch, it is worth seeing every byte of FP8 quantisation in numpy. The 90 lines below implement the round-to-nearest cast into E4M3, per-tile scale factors, and the two-level BF16-inner / FP32-outer accumulator. Everything generalises one-for-one to TILE = 128 and PROMOTE_EVERY = 128 on real hardware.
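A compact sketch of such a pipeline is given here. It is a simulation under stated assumptions: FP8 values are held in float64 on the E4M3 grid rather than in 8-bit storage, BF16 rounding is done by bit manipulation, the promotion period equals the tile length, and every dimension is a multiple of `TILE`. All helper names are chosen for this sketch.

```python
import numpy as np

E4M3_MAX = 448.0
TILE = 8                      # 128 on real hardware; also the promotion period here

def quantize_e4m3(v):
    """Round-to-nearest-even onto the E4M3 grid (simulated in float64)."""
    v = np.clip(np.asarray(v, dtype=np.float64), -E4M3_MAX, E4M3_MAX)
    mag = np.abs(v)
    _, exp = np.frexp(mag)
    e = np.maximum(exp - 1, -6)
    step = np.ldexp(1.0, e - 3)
    return np.sign(v) * np.round(mag / step) * step

def to_bf16(x):
    """Round float32 values to BF16 precision, keeping float32 storage."""
    bits = np.ascontiguousarray(x, dtype=np.float32).view(np.uint32)
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    bits = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return bits.view(np.float32)

def row_scales(X, tile):
    M, K = X.shape
    amax = np.abs(X).reshape(M, K // tile, tile).max(axis=2)
    return np.maximum(amax, 1e-30) / E4M3_MAX                  # (M, K/tile)

def block_scales(W, tile):
    K, N = W.shape
    amax = np.abs(W).reshape(K // tile, tile, N // tile, tile).max(axis=(1, 3))
    return np.maximum(amax, 1e-30) / E4M3_MAX                  # (K/tile, N/tile)

def fp8_gemm(X, W, tile=TILE):
    M, K = X.shape
    N = W.shape[1]
    sX, sW = row_scales(X, tile), block_scales(W, tile)
    Xq = quantize_e4m3(X / np.repeat(sX, tile, axis=1))
    Wq = quantize_e4m3(W / np.repeat(np.repeat(sW, tile, axis=0), tile, axis=1))
    out = np.zeros((M, N), dtype=np.float32)                   # FP32 outer accumulator
    for t in range(K // tile):
        inner = np.zeros((M, N), dtype=np.float32)             # BF16-precision inner accumulator
        for k in range(t * tile, (t + 1) * tile):
            prod = np.outer(Xq[:, k], Wq[k, :]).astype(np.float32)
            inner = to_bf16(inner + prod)
        scale = sX[:, t:t + 1] * np.repeat(sW[t], tile)[None, :]
        out += (scale * inner).astype(np.float32)              # promote, applying both scales
    return out

rng = np.random.default_rng(0)
X = rng.standard_normal((16, 64)).astype(np.float32)
W = rng.standard_normal((64, 32)).astype(np.float32)
ref = X.astype(np.float64) @ W.astype(np.float64)
out = fp8_gemm(X, W)
rel_err = np.linalg.norm(out - ref) / np.linalg.norm(ref)
print(f"relative error vs FP64 reference: {rel_err:.3e}")
```

The scales are applied at promotion time, once per tile, so the inner loop touches only values on the FP8 grid.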
Run this on any laptop and it prints the relative error against a full-precision reference. That number is the entire case for FP8 training: it sits below the noise of a real mini-batch's gradient estimate, so the model never sees the quantisation error during optimisation, yet the GEMMs run at E4M3 throughput, in roughly half the time of BF16.
PyTorch: A Custom Fp8Linear with Autograd
The numpy code is faithful to the math but won't run on a GPU. In PyTorch we wrap the same operations in a custom torch.autograd.Function so it slots into any model. Two things to watch for: (1) we keep the FP32 master weight as the canonical parameter and quantise on-the-fly, and (2) the backward GEMMs use E5M2 for the gradient operand.
On a Hopper-class GPU the F.linear calls in such a module are replaced by torch._scaled_mm, a private PyTorch op that takes the FP8 operands plus their scales and fuses everything into one tensor-core dispatch. NVIDIA's TransformerEngine library and the open-source torchao package wrap the same pattern with a few production niceties: amax-history tracking, delayed scaling for activations, and a fallback path to BF16 when an FP8 GEMM produces NaNs.
To port an existing BF16 transformer: (1) replace every nn.Linear in attention and MLP with the module above, (2) keep LayerNorm, nn.Embedding, and lm_head in BF16, (3) keep the optimizer (master weight, moments) in FP32, (4) leave all hyperparameters unchanged. The result trains within 0.5% of the BF16 baseline at ~1.4× higher throughput.
At Massive Scale: Where the 1.4× Speed-Up Comes From
At pre-training scale (671B parameters, 14.8T tokens, eight weeks of wall-clock on roughly 2048 H800s for DeepSeek V3), the bottleneck of one training step is the GEMM time inside attention and MLP layers. Memory-bound operators — softmax, LayerNorm, residual adds — are unaffected by the activation dtype because they are already saturating HBM bandwidth. Communication kernels (all-reduce for data-parallel gradients, all-to-all for expert routing in MoE) are also untouched. What changes is the tensor-core throughput on the three GEMMs of every linear layer.
| Tensor / state | Stored dtype | Per-parameter bytes | Why |
|---|---|---|---|
| Master weight | FP32 | 4 B | Optimizer truth; never quantised. |
| FP8 forward weight | E4M3 | 1 B | Regenerated every forward; never persisted. |
| Adam first moment m | FP32 | 4 B | Sensitive to small magnitudes. |
| Adam second moment v | FP32 | 4 B | Same. |
| Gradient buffer dW | BF16 | 2 B | After FP8 backward, cast back to BF16. |
| Activation checkpoint | BF16 or E4M3 | 1–2 B | Recomputation lets us throw most away. |
Memory per parameter for the optimizer state therefore stays at 12 B (FP32 master + FP32 m + FP32 v), same as BF16 training. The memory savings of FP8 come from two places: (1) the persistent weight memory drops from 2 B (BF16) to 1 B (FP8) when we keep only the cast version for inference-style recompute, and (2) the activation checkpoint memory drops by 2× when we store activations in E4M3 instead of BF16. The throughput gain comes entirely from the tensor cores: E4M3 GEMM throughput on H800 is roughly 1.85× the BF16 throughput for compute-bound shapes. Net training-step speed-up is around 1.4× because not all of a transformer's step is GEMM.
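Both numbers can be sanity-checked in a few lines. The byte tally comes straight from the table; the speed-up relation is Amdahl's law, applied here as an assumption that only the GEMM share of a step gets faster. The function names are illustrative.

```python
BYTES = {"FP32": 4, "BF16": 2, "E4M3": 1}

# Persistent per-parameter state from the table above.
optimizer_state = BYTES["FP32"] * 3          # master weight + Adam m + Adam v
grad_buffer = BYTES["BF16"]
print("optimizer state:", optimizer_state, "B/param")                      # 12
print("with gradient buffer:", optimizer_state + grad_buffer, "B/param")   # 14

def step_speedup(gemm_fraction, gemm_speedup):
    """Amdahl's law: only the GEMM share of the step gets faster."""
    return 1.0 / ((1.0 - gemm_fraction) + gemm_fraction / gemm_speedup)

def gemm_fraction_for(target_speedup, gemm_speedup):
    """Invert Amdahl's law: GEMM share implied by an observed step speed-up."""
    return (1.0 - 1.0 / target_speedup) / (1.0 - 1.0 / gemm_speedup)

f = gemm_fraction_for(1.4, 1.85)
print(f"implied GEMM share of a BF16 step: {f:.2f}")
print(f"check: {step_speedup(f, 1.85):.2f}x")
```

Under this model, a 1.85× GEMM gain turning into a 1.4× step gain implies that a little over 60% of the BF16 step time is spent in GEMMs.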
What stays BF16 forever (and why)
- Embeddings and LM head. Vocab matrices are fat-tailed: a handful of tokens have enormous gradient magnitudes and most tokens have nothing. Per-tile scaling does not help because tiles do not align with token rarity.
- LayerNorm parameters. Tiny tensors, all of them, and the mean/variance computations need full BF16 precision to avoid bias.
- Attention softmax. The exponent in softmax dominates the dtype choice; FP8 cannot represent the dynamic range of pre-softmax logits cleanly.
- Residual stream. The skip connection sums many contributions; a cast-and-uncast hop per layer would compound error catastrophically.
These exceptions are not a flaw of FP8 — they are evidence that the format is being used exactly where it pays. The 95% of parameters that live in attention and MLP linears get a 1.85× FLOPs boost; the 5% that demand more precision stay where they belong.
Engineering Reality: The Pitfalls That Will Bite You
FP8 training is harder than the math suggests, and the failures are mostly not bugs in your kernel — they are integration issues. Here is what production teams (DeepSeek, NVIDIA, Cohere, Mistral) have publicly reported tripping over.
1. Amax history and delayed scaling
Computing per-tile amax on every forward pass is correct but expensive: it adds a small reduction kernel per layer. Production kernels instead keep a history of amax over the last K steps (NVIDIA's TransformerEngine examples use a history length of 16) and derive the scale from that history, typically its maximum. This is called delayed scaling. It works because activation distributions are smooth across consecutive steps once training is past warmup, and it removes the per-step amax kernel from the critical path. The gotcha: during warmup or a learning-rate restart, amax can shift sharply and the history lags; most production implementations therefore use just-in-time scaling for the first ~1000 steps and switch to delayed scaling once training has stabilised.
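One simple form of delayed scaling keeps a short window of recent amax values and takes their maximum. The class below is a hypothetical helper written for this section, not the API of TransformerEngine or torchao; the window length and warmup threshold are constructor arguments so the usage example can shrink them.

```python
from collections import deque

E4M3_MAX = 448.0

class DelayedScale:
    """Hypothetical helper: derive this step's scale from past amax values."""

    def __init__(self, history_len=16, warmup_steps=1000):
        self.history = deque(maxlen=history_len)
        self.warmup_steps = warmup_steps
        self.step = 0

    def scale_for(self, current_amax):
        # In a real kernel current_amax is a by-product of the cast and is
        # only recorded; it is passed in here to keep the sketch simple.
        self.step += 1
        if self.step <= self.warmup_steps or not self.history:
            amax = current_amax                    # just-in-time scaling
        else:
            amax = max(self.history)               # delayed: past steps only
        self.history.append(current_amax)
        return amax / E4M3_MAX

ds = DelayedScale(history_len=4, warmup_steps=2)
scales = [ds.scale_for(a) for a in (1.0, 2.0, 3.0, 10.0)]
print([round(s * E4M3_MAX, 3) for s in scales])    # amax actually used per step
```

On the last step the true amax jumps to 10 while the window still says 3, so that step would clip: the lag described above, in miniature.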
2. The amax → scale → cast sequencing
Subtle bug: if you compute amax on the pre-scaled tensor but cast the post-scaled tensor (or vice versa), you double-scale and the forward output is off by exactly one factor of the scale $s$. Loss will look fine for one step because the gradient backs out the error, then NaN within ten. The fix is to be religious about which copy of the tensor any given kernel sees, and to add unit tests on a single-layer forward that compare against a BF16 reference within tolerance.
3. NaN sentries on every GEMM output
E5M2's mantissa is only two bits — a single bit-flip from cosmic ray or a bad GEMM kernel produces NaN. Production trainers check the L2 norm of every layer's output every step and roll back to the last checkpoint if it overflows. DeepSeek reports they hit ~3 GPU-induced NaNs per week across 2048 GPUs; the cost of the per-step norm check is below 0.1% of step time and is non-negotiable.
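A sentry of this kind is only a few lines. The sketch below is a NumPy stand-in for what would be a fused GPU reduction in practice; the function name and the `max_norm` threshold are placeholders, and the rollback itself is left to the surrounding training loop.

```python
import numpy as np

def layer_output_ok(name, y, max_norm=1e6):
    """Return True when the output is finite and its L2 norm is below max_norm."""
    norm = float(np.linalg.norm(np.asarray(y, dtype=np.float64)))
    ok = bool(np.isfinite(norm) and norm < max_norm)
    if not ok:
        print(f"[sentry] {name}: norm={norm}, request rollback to last checkpoint")
    return ok

healthy = np.ones((4, 8), dtype=np.float32)
broken = healthy.copy()
broken[2, 3] = np.nan                       # e.g. a corrupted GEMM output

print(layer_output_ok("mlp.up_proj", healthy))
print(layer_output_ok("mlp.down_proj", broken))
```

Because a NaN or infinity anywhere in the tensor propagates into the norm, a single scalar check per layer is enough to catch it.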
4. The gradient-scaler dance you do not have to do
FP16 mixed-precision training required a global loss scaler to keep gradients in range. FP8 does not — the per-tile amax scales already adapt to the local magnitude of every tile, so a single global scaler would be both unnecessary and harmful. Teams porting from torch.cuda.amp often forget to disable the loss scaler and see training diverge; the fix is one line.
5. Checkpoint format is BF16, not FP8
Always serialise the FP32 master weights (and Adam state) to BF16 in your checkpoints, never to FP8. The reason: FP8 has only eight mantissa steps per binade; saving a checkpoint and reloading it would re-round every parameter, drifting the model away from the optimizer's next intended step. Downstream inference can re-cast to FP8 on load, but training resumes from BF16.
6. The first wrong place to look
When FP8 training diverges, the first suspect is almost always the wrong one: people blame the cast kernel. In practice the offender is nearly always (a) an unscaled LayerNorm, (b) a forgotten loss scaler, or (c) the embedding-table dtype having silently been cast to FP8 by an autocast wrapper. Spend ten minutes auditing those three before re-reading the cast code.
The two-sentence summary. FP8 training works because per-tile scales localise outlier ranges into manageable binades, and a two-level accumulator (BF16 inner + FP32 outer every 128 partials) absorbs the rounding noise. Everything in this section — the three GEMMs, the dtype assignments, the master-weight discipline, the engineering pitfalls — is in the service of getting those two properties from the page into a kernel that survives fifteen trillion tokens of pretraining.
With this implementation in hand, the rest of the book's chapters on distributed training, long-context extension, and post-training can assume FP8 as a given. From here on, every time you see a transformer, picture three GEMMs per linear layer, each cast just-in-time, each accumulated through two precision levels, each contributing its 1.85× to the eight-week pretraining budget. That is how giants are forged at the bit level.