Chapter 10

What Stays in BF16 and FP32

FP8 Mixed-Precision Training

The Real Problem: FP8 Cannot Be Everywhere

Sections 10.3 and 10.4 sold you on FP8: fine-grained block scaling recovers most of the dynamic range, high-precision accumulators stop the running sum from collapsing, and an H100 tensor core in FP8 runs nearly twice as fast as in BF16. So why not push the whole training stack to FP8? Why does DeepSeek-V3's actual precision map still keep entire layers in BF16 and an irreducible core in FP32?

The honest answer: FP8 E4M3 has only 3 mantissa bits and a hard ceiling at $\pm 448$. That is enough for the big matrix multiplies, where per-block scaling stretches the range and where the result is immediately dequantized into a wider format. It is not enough for three families of operations that appear in every training step:

  1. Long-horizon accumulators. Master weights move by $\sim 10^{-7}$ per step against a weight of order $1$. Over millions of steps these updates have to actually land. In FP8 the spacing between representable values near $1.0$ is $0.0625$ to $0.125$ (and still roughly $5 \times 10^{-3}$ under the optimistic per-block model used later in this section), so every update rounds to zero and training stalls.
  2. Wide-range reductions. RMSNorm computes a mean of squares; softmax exponentiates logits that span tens of units. Both routinely produce intermediates that overflow $\pm 448$ or collapse small probabilities to zero.
  3. Second-moment statistics. Adam's $v$ tracks $g^2$. For a typical late-training gradient $g \sim 3 \times 10^{-4}$, the second moment is $v \sim 10^{-7}$, more than four orders of magnitude below $2^{-9} \approx 2 \times 10^{-3}$, the smallest positive value E4M3 can store without a scale.

Putting any of these in FP8 silently corrupts training: no NaN, no explosion, just a loss curve that mysteriously plateaus 0.05 nats above where it should. The job of this section is to make every one of those failures visible, then map out exactly which tensors in a transformer block stay in BF16 versus FP32 versus FP8 — and why.

Intuition: Cheap Ruler, Calipers, Master Blueprint

Imagine a workshop that builds a precise instrument. Most of the work — sawing boards, drilling rough holes, sanding panels — uses a millimetre ruler. It is fast and cheap and the precision is good enough because the next step always re-references the parts against the blueprint. That is FP8.

A few moments demand calipers: matching the bearing diameter to the shaft, reading a balance after many cycles. Errors here cascade: the bearing either fits or seizes the machine. That is BF16: keep the wide range, take the tighter precision, and accept the 2× cost for the operations whose errors compound across layers and steps.

And on the shelf, in a locked drawer, sits the master blueprint — the single canonical drawing every part is referenced against. You do not redraw the blueprint with the millimetre ruler. You make all your sawing and drilling with cheap tools, then you walk back to the drawer, update the blueprint with a fine pen, and the next batch of cuts is taken from that updated reference. That is FP32: master weights, optimizer moments, gamma scales — the source of truth, never the working copy.

The mental rule. If a tensor is read many times to produce one number (residual stream accumulating updates over depth, optimizer state averaged over steps, normalization reducing across d_model), it deserves higher precision. If a tensor is read once and immediately written back (an FP8 GEMM input, a quantized weight tile), low precision is fine.
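The mental rule can be seen in a few lines of NumPy. This is a minimal sketch under stated assumptions: `round_bf16` is our own helper that emulates BF16 round-to-nearest-even through float32 bit manipulation (NumPy has no native bfloat16), and the accumulator is a stand-in for anything that sums many small contributions, such as a residual stream over depth or a running statistic over steps. Ten thousand additions of $10^{-3}$ should total $10$.

```python
import numpy as np

def round_bf16(x):
    """Round to the nearest bfloat16 (ties to even); returns a Python float.

    Hypothetical helper: BF16 is the top 16 bits of a float32, so we round the
    float32 bit pattern at bit 16 and clear the low half.
    """
    bits = int(np.array([x], dtype=np.float32).view(np.uint32)[0])
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return float(np.array([bits], dtype=np.uint32).view(np.float32)[0])

increment, n = 1e-3, 10_000              # the true total is 10.0
acc_bf16 = 0.0
acc_fp32 = np.float32(0.0)
for _ in range(n):
    acc_bf16 = round_bf16(acc_bf16 + increment)               # BF16 accumulator
    acc_fp32 = np.float32(acc_fp32 + np.float32(increment))   # FP32 accumulator

print(f"BF16 accumulator: {acc_bf16}   FP32 accumulator: {float(acc_fp32):.4f}")
```

Once the BF16 sum reaches 0.5, half a BF16 spacing there ($2^{-9} \approx 1.95 \times 10^{-3}$) exceeds the increment, so every further addition rounds away and the accumulator stalls at 0.5; the FP32 accumulator lands within rounding error of 10.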

The Mathematics of Why Some Tensors Refuse FP8

Quantization grain near a value

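FP8 E4M3 has 4 exponent bits and 3 mantissa bits. Around a value of magnitude $|x|$ the spacing between adjacent representable numbers lies between $2^{-4}|x|$ and $2^{-3}|x|$, so $\epsilon_{fp8}(x) \approx 2^{-3}|x| = 0.125\,|x|$. Per-block scaling maps each block's abs-max onto 448: it extends the usable dynamic range, but multiplying by a scale cannot shrink this relative spacing. Concretely, near a weight of $1.0$ the E4M3 spacing is $0.0625$ just below $1.0$ and $0.125$ just above it; near $10^{-4}$, once a block scale has brought the value into range, it is roughly $10^{-5}$. The simulation script and walkthrough below use a simplified uniform grid with step $\text{amax}_{\text{block}}/256$ instead. That grid is finer than real E4M3 near the block maximum (roughly $5 \times 10^{-3}$ near a weight of $1.0$) and coarser near zero, so treat its numbers as illustrations of each failure mode rather than hardware-exact values.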

Compare against BF16, which stores 7 mantissa bits (8 bits of precision counting the implicit leading one), for a spacing between $2^{-8}|x|$ and $2^{-7}|x|$, i.e. $\epsilon_{bf16}(x) \approx 4 \times 10^{-3}\,|x|$; and FP32, with 23 stored mantissa bits and $\epsilon_{fp32}(x) \approx 2^{-23}|x| \approx 1.2 \times 10^{-7}\,|x|$. These three numbers are the entire story.
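The spacings above can be checked by brute force. The sketch below is our own illustration, not a library API: it builds every finite E4M3 value from the bit layout (exponent bias 7, subnormals, no infinities, one NaN pattern per sign, as in the OCP FP8 E4M3 encoding), then reads off the gap to the next value, with no block scale applied. BF16 spacing is derived from NumPy's float32 spacing by dropping 16 mantissa bits.

```python
import numpy as np

def e4m3_grid():
    """Every non-negative finite FP8 E4M3 value, ascending.

    Layout: 4 exponent bits (bias 7) and 3 mantissa bits. Exponent field 0
    holds subnormals; the all-ones pattern is NaN, and there are no infinities.
    """
    vals = []
    for e in range(16):
        for m in range(8):
            if e == 15 and m == 7:
                continue                                 # NaN encoding
            if e == 0:
                vals.append((m / 8) * 2.0 ** -6)         # subnormal: 0.mmm * 2^-6
            else:
                vals.append((1 + m / 8) * 2.0 ** (e - 7))
    return np.array(sorted(vals))

E4M3 = e4m3_grid()

def e4m3_spacing(x):
    """Gap between the largest E4M3 value <= x and the next one up (unscaled)."""
    i = np.searchsorted(E4M3, x, side="right") - 1
    return float(E4M3[i + 1] - E4M3[i])

def bf16_spacing(x):
    """BF16 keeps FP32's 8 exponent bits but only 7 of its 23 mantissa bits."""
    return float(np.spacing(np.float32(x))) * 2.0 ** 16

def fp32_spacing(x):
    return float(np.spacing(np.float32(x)))

for x in (1.0, 0.73, 1e-4):
    print(f"x={x:<7} e4m3={e4m3_spacing(x):.2e}  bf16={bf16_spacing(x):.2e}  "
          f"fp32={fp32_spacing(x):.2e}")
print("E4M3 max:", E4M3.max(), " smallest positive:", E4M3[1])
```

At $x = 10^{-4}$ the E4M3 "spacing" is the gap from zero to the smallest positive value, $2^{-9}$: without a block scale, $10^{-4}$ cannot be stored at all.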

Update-survival inequality

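An Adam update of size $\Delta$ against a master weight $w$ survives the round-trip through a quantized format with grain $\epsilon(w)$ only if:

$$|\Delta| \geq \tfrac{1}{2}\,\epsilon(w)$$

For $w \sim 1$ and $\Delta \sim 10^{-7}$, this fails by about four orders of magnitude in BF16 ($\tfrac{1}{2}\epsilon_{bf16} \approx 2 \times 10^{-3}$), by five to six orders in FP8, and is satisfied in FP32 ($\tfrac{1}{2}\epsilon_{fp32} \approx 6 \times 10^{-8}$). The FP32 margin at $|w| \approx 1$ is only a factor of two to three, and it grows in proportion as $|w|$ shrinks, because $\epsilon$ scales with $|w|$. This is the formal reason master weights live in FP32.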
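A minimal sketch of the inequality in action: store $w$, subtract $\Delta$, re-round to the format, and check whether the stored weight moved. `round_bf16` and `round_e4m3` are our own round-to-nearest emulations (assumptions, not a library API); E4M3 is used unscaled, which is fine here because $w = 1$ is already in range. The two updates tried are the $10^{-7}$ from the text above and the walkthrough's $1.341 \times 10^{-4}$.

```python
import math
import numpy as np

def round_fp32(x):
    """Round to the nearest float32, returned as a Python float."""
    return float(np.float32(x))

def round_bf16(x):
    """Round to the nearest bfloat16 (ties to even) via float32 bit manipulation."""
    bits = int(np.array([x], dtype=np.float32).view(np.uint32)[0])
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return float(np.array([bits], dtype=np.uint32).view(np.float32)[0])

def round_e4m3(x):
    """Round to the nearest FP8 E4M3 value (no scaling; saturates at 448)."""
    a = min(abs(x), 448.0)
    if a == 0.0:
        return 0.0
    e = max(math.floor(math.log2(a)), -6)   # -6 is the smallest normal exponent
    ulp = 2.0 ** (e - 3)                    # 3 mantissa bits per binade
    return math.copysign(round(a / ulp) * ulp, x)

def update_lands(w, delta, rnd):
    """True if w - delta, re-rounded to the format, differs from the stored w."""
    stored = rnd(w)
    return rnd(stored - delta) != stored

for delta in (1e-7, 1.341e-4):
    verdict = {name: update_lands(1.0, delta, rnd) for name, rnd in
               [("FP32", round_fp32), ("BF16", round_bf16), ("FP8", round_e4m3)]}
    print(f"delta={delta:.3e}: {verdict}")
```

Only the FP32 copy moves in either case; BF16 and E4M3 snap both updates straight back to $1.0$.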

Softmax saturation bound

Softmax over logits $z_i$ computes $p_i = \exp(z_i - z_{max}) / \sum_j \exp(z_j - z_{max})$. The shifted maximum is $\exp(0) = 1$; the smallest tail probability is roughly $\exp(-(z_{max} - z_{min}))$. For attention logits with a realistic gap $z_{max} - z_{min} \approx 15$, the tail probability is $\sim 3 \times 10^{-7}$. Stored in FP8 with the row maximum mapped to 448, the smallest nonzero value is $2^{-9}/448 \approx 4 \times 10^{-6}$, so the tail collapses to zero. BF16 keeps it, but with only two to three significant digits, and a BF16 normalizing sum over thousands of keys adds rounding error of its own; FP32 carries both cleanly. That is why every modern stack does the softmax reduction in FP32 even when both inputs and outputs are in BF16.
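The tail bound can be reproduced directly. This sketch assumes a three-logit row with a gap of 15, stores the probabilities as one FP8 tile (row maximum mapped to 448), and uses our own `round_e4m3` emulation of round-to-nearest E4M3 rather than any hardware or library routine.

```python
import math
import numpy as np

def round_e4m3(x):
    """Round to the nearest FP8 E4M3 value (saturating at 448, with subnormals)."""
    a = min(abs(x), 448.0)
    if a == 0.0:
        return 0.0
    e = max(math.floor(math.log2(a)), -6)   # -6 is the smallest normal exponent
    ulp = 2.0 ** (e - 3)                    # 3 mantissa bits per binade
    return math.copysign(round(a / ulp) * ulp, x)

z = np.array([15.0, 7.5, 0.0])              # z_max - z_min = 15
e = np.exp(z - z.max())
p = e / e.sum()                             # float64 reference softmax
scale = 448.0 / p.max()                     # per-row scale: largest prob -> 448
p_fp8 = np.array([round_e4m3(pi * scale) / scale for pi in p])
smallest_kept = 2.0 ** -9 / scale           # smallest nonzero value this tile holds

print("reference:", p)
print("FP8 tile :", p_fp8)
print(f"smallest nonzero FP8 probability: {smallest_kept:.2e}")
```

The $\sim 3 \times 10^{-7}$ entry comes back as exactly zero, while the middle entry survives only coarsely.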

Second-moment underflow

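Adam's $v = \beta_2 v + (1-\beta_2) g^2$. For late-training gradients $g \sim 3 \times 10^{-4}$, $g^2 \sim 10^{-7}$. The smallest positive E4M3 value is $2^{-9} \approx 2 \times 10^{-3}$, so an unscaled $v$ rounds to $0$. Per-block scaling only helps while $v$ stays within the representable range below its block's largest entry, and because $v \propto g^2$, a modest spread of gradient magnitudes inside one block becomes a huge spread in $v$. Once $v$ has flushed to zero, the Adam step $\eta \cdot m / (\sqrt{v} + \epsilon)$ divides by $\epsilon$ and blows up. Keeping $m$ and $v$ out of FP8 is non-negotiable; FP32 is the conservative default.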
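How large a spread can one FP8 block hold before the small second moments flush? A rough bound, sketched below with our own `round_e4m3` helper (an emulation of round-to-nearest E4M3, not a library call) and an assumed pair of gradients that share one scaling block. Anything below half the smallest subnormal after scaling rounds to zero.

```python
import math

def round_e4m3(x):
    """Round to the nearest FP8 E4M3 value (saturating at 448, with subnormals)."""
    a = min(abs(x), 448.0)
    if a == 0.0:
        return 0.0
    e = max(math.floor(math.log2(a)), -6)   # -6 is the smallest normal exponent
    ulp = 2.0 ** (e - 3)                    # 3 mantissa bits per binade
    return math.copysign(round(a / ulp) * ulp, x)

def store_block_fp8(values):
    """Quantize-dequantize one block with a shared scale: block abs-max -> 448."""
    scale = 448.0 / max(abs(v) for v in values)
    return [round_e4m3(v * scale) / scale for v in values]

# Scaled values below 2^-10 (half the smallest subnormal) round to 0, so a
# block can span at most 448 / 2^-10 between its largest and smallest entries.
v_spread_limit = 448.0 / 2.0 ** -10
g_spread_limit = math.sqrt(v_spread_limit)   # v ~ g^2 squares any gradient spread
print(f"max spread of v in one block:   {v_spread_limit:,.0f}x")
print(f"max spread of |g| in one block: {g_spread_limit:,.0f}x")

g_block = [3e-1, 3e-4]                       # assumed: two gradients in one block
v_block = [g * g for g in g_block]           # [0.09, 9e-8]
print("v stored in FP8:", store_block_fp8(v_block))
```

With real E4M3 spacing, a gradient spread of roughly 680× inside one block is enough to zero the smaller second moment; the uniform grid used in this section's simulation is far less forgiving near zero.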

Manual Numerical Walkthrough

Let us trace one Adam step end-to-end in three precisions and see each failure mode appear.

One Adam step in FP32, BF16, and FP8, side by side

Step 1 — the inputs. A single parameter, mid-training.

w        = 0.731 421       (current master weight)
g        = 3.2e-4           (gradient from this minibatch)
m_prev   = 1.1e-4           (Adam first moment)
v_prev   = 8.5e-8           (Adam second moment)
beta1    = 0.9    beta2 = 0.95
lr       = 3.0e-4    eps  = 1.0e-8

Step 2 — Adam update math (FP32 reference).

m  = 0.9  * 1.1e-4 + 0.1  * 3.2e-4 = 1.31e-4
v  = 0.95 * 8.5e-8 + 0.05 * (3.2e-4)^2
   = 8.075e-8 + 5.12e-9 = 8.587e-8
step = lr * m / (sqrt(v) + eps)
     = 3e-4 * 1.31e-4 / (sqrt(8.587e-8) + 1e-8)
     = 3e-4 * 1.31e-4 / (2.93e-4 + 1e-8)
     = 1.341e-4
w_new (FP32) = 0.731 421 - 1.341e-4 = 0.731 287

The actual weight motion is $1.34 \times 10^{-4}$. That is the quantity each precision must preserve.

Step 3 — same computation if v stored in FP8. Stored without a scale, or in a block whose largest entry is far bigger, $v \approx 10^{-7}$ falls below E4M3's smallest positive value of $2^{-9} \approx 2 \times 10^{-3}$, and v rounds to 0.

v_fp8        = 0  (underflow)
step_fp8     = 3e-4 * 1.31e-4 / (sqrt(0) + 1e-8)
             = 3e-4 * 1.31e-4 / 1e-8
             = 3.93         <- ~30,000x the correct step of 1.341e-4

One bad step. The parameter would jump from $0.73$ to $-3.2$. The next forward pass would either produce NaNs, produce a loss spike of dozens of nats, or, in the worst case, silently poison every downstream activation. This is why v must never be stored in FP8.

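Step 4 — same computation if w stored in FP8. Real E4M3 spacing near $w = 0.73$ is about $0.06$ to $0.07$ whether or not a block scale is applied (the rule of thumb $0.73/8 \approx 0.09$ is the same order), because scaling does not shrink relative spacing. To be generous, use the simplified uniform per-block grid from the simulation script later in this section, whose grain with amax ≈ 1 is roughly $5 \times 10^{-3}$: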

w_quantized      = round(0.731421 / 5e-3) * 5e-3 = 0.730
step (correct)   = 1.341e-4
w - step         = 0.730 - 1.341e-4 = 0.729 866
re-quantize      = round(0.729 866 / 5e-3) * 5e-3 = 0.730    <- no move

Every step rounds back to the same quantized weight, so the weight cannot move. If every step were this size, then after 10,000 steps the FP8 master copy would still read 0.730 while the FP32 truth had drifted by about $1.3$. That is the second non-negotiable: master weights must be FP32.

Step 5 — what BF16 does. BF16 spacing near 0.73 is $2^{-8} \approx 3.9 \times 10^{-3}$, so the rounding threshold is about $2 \times 10^{-3}$, roughly 15 times the $1.3 \times 10^{-4}$ update. With round-to-nearest the update never lands; only with stochastic rounding does about one update in 30 move the weight by a full bin, which is correct on average but very noisy. BF16 master weights almost work, which is why a few research papers tried them, but accuracy degrades by 0.05–0.15 nats over a full pretraining run. Nobody at frontier scale ships BF16 master weights anymore.

Step 6 — what BF16 stores happily. Activations of magnitude $\sim 1$, gradients of magnitude $\sim 10^{-3}$, attention scores of magnitude $\sim 10$: all fit cleanly in BF16's range, and its mantissa resolution is fine for tensors that are read once and replaced. This is why BF16 is the universal default for activations, gradients, and the residual stream.

Step 7 — what the walkthrough teaches. The precision boundary is not arbitrary. It is set by three hard mechanical bounds: (a) the update-survival inequality for master weights, (b) the second-moment underflow bound for Adam's v, and (c) the softmax-tail bound for attention probabilities. Everything else can, and should, run in BF16 on the activation highway, with FP8 confined to the GEMM kernels.
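The arithmetic in Steps 2 to 5 can be replayed in a few lines. This sketch uses float64 as the FP32 reference and the same uniform $5 \times 10^{-3}$ grid as Step 4. The stochastic-rounding part of Step 5 is our own illustration (an assumption, not something the walkthrough specifies) of where a landing rate of a few percent per step can come from.

```python
import math
import numpy as np

# Step 2: the FP32 reference update (float64 here; bias correction omitted).
w, g = 0.731421, 3.2e-4
m_prev, v_prev = 1.1e-4, 8.5e-8
beta1, beta2, lr, eps = 0.9, 0.95, 3.0e-4, 1.0e-8
m = beta1 * m_prev + (1 - beta1) * g
v = beta2 * v_prev + (1 - beta2) * g * g
step = lr * m / (math.sqrt(v) + eps)

# Step 3: the same update after v has underflowed to 0.
step_v0 = lr * m / (math.sqrt(0.0) + eps)

# Step 4: a master weight stored on the walkthrough's uniform 5e-3 grid.
def on_grid(x, grain=5e-3):
    return round(x / grain) * grain

w_q = on_grid(w)
frozen = on_grid(w_q - step) == w_q

# Step 5: BF16 spacing near 0.73 is 2^-8. Round-to-nearest snaps back every
# time; stochastic rounding rounds down with probability step / spacing.
ulp = 2.0 ** -8
w_b = round(0.73 / ulp) * ulp              # BF16-representable start: 0.73046875
x = w_b - step
nearest_lands = round(x / ulp) * ulp != w_b
p_down = (w_b - x) / ulp
rng = np.random.default_rng(0)
landed = (rng.random(20_000) < p_down).mean()

print(f"step={step:.4e}  step with v=0: {step_v0:.2f}  frozen on grid: {frozen}")
print(f"BF16 round-to-nearest lands: {nearest_lands}; "
      f"stochastic rounding lands {landed:.1%} of steps (expected {p_down:.1%})")
```

The expected landing rate, $\Delta / \epsilon \approx 1.34 \times 10^{-4} / 3.9 \times 10^{-3}$, is about 3.4%, or roughly one step in 30.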

Visualizing the Precision Map of a Block

[Figure: precision map of one transformer block as DeepSeek-V3 ships it, with each tensor marked FP8, BF16, or FP32.]

Three things to read out of the map. First, FP8 is a kernel decision, not a tensor-type decision: only the big GEMMs (the QKV and output projections here, plus the FFN GEMMs elsewhere in the model) execute in FP8, and their outputs are dequantized to BF16 immediately. Every tensor your Python code sees lives in BF16 or FP32; FP8 is a transient inside a fused kernel. Second, the optimizer side never touches FP8: nobody at scale runs FP8 master weights or FP8 Adam state, and there is no point arguing about it. Third, apart from the RMSNorm reduction, softmax is the only operation in the forward path that escapes to FP32; even with BF16 inputs and BF16 outputs, the exp+sum reduction is FP32 in the middle. That single cast is the difference between a stable long-context model and one whose attention silently truncates.
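"FP8 is a kernel decision" can be emulated in NumPy. The sketch below is an illustration under stated assumptions, not how Hopper kernels are written: it quantize-dequantizes 1×128 activation tiles and 128×128 weight blocks onto the E4M3 grid (the tile shapes DeepSeek-V3 uses), multiplies in float32 as a stand-in for the high-precision accumulator, and leaves the result in float32 where a real kernel would write BF16 (NumPy has no BF16 type). Dequantizing before the multiply is mathematically equivalent to applying the per-tile scales to the partial sums.

```python
import numpy as np

def round_e4m3(x):
    """Vectorized round-to-nearest onto the FP8 E4M3 grid (saturates at 448)."""
    a = np.minimum(np.abs(x), 448.0)
    e = np.floor(np.log2(np.maximum(a, 2.0 ** -6)))  # binade, floored at subnormals
    ulp = 2.0 ** (e - 3)                             # 3 mantissa bits
    return np.sign(x) * np.round(a / ulp) * ulp

def quant_dequant(block):
    """Scale a tile so its abs-max maps to 448, round to E4M3, undo the scale."""
    s = 448.0 / max(np.abs(block).max(), 1e-12)
    return round_e4m3(block * s) / s

def fp8_gemm_emulated(A, B, t=128):
    """A (M, K) with 1 x t activation tiles; B (K, N) with t x t weight blocks."""
    M, K = A.shape
    N = B.shape[1]
    Aq = np.empty_like(A)
    Bq = np.empty_like(B)
    for i in range(M):
        for k in range(0, K, t):
            Aq[i, k:k + t] = quant_dequant(A[i, k:k + t])
    for k in range(0, K, t):
        for n in range(0, N, t):
            Bq[k:k + t, n:n + t] = quant_dequant(B[k:k + t, n:n + t])
    # float32 accumulation stands in for the kernel's high-precision accumulator
    return Aq.astype(np.float32) @ Bq.astype(np.float32)

rng = np.random.default_rng(0)
A = rng.standard_normal((4, 256))
B = rng.standard_normal((256, 128)) * 0.02
ref = A @ B
out = fp8_gemm_emulated(A, B)
rel_err = np.linalg.norm(out - ref) / np.linalg.norm(ref)
print(f"relative error of the emulated FP8 GEMM: {rel_err:.3%}")
```

Expect a relative error of a few percent against the float64 reference: acceptable for a GEMM whose output is consumed once, and far too coarse for any of the long-horizon accumulations discussed above.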

Plain Python: Simulating FP8 to See It Fail

Below is the diagnostic script every mixed-precision engineer should be able to write in their sleep. It quantizes a few canonical tensors to FP8 with per-block scaling, then measures the damage against an FP32 reference. Three tests, three failure modes, three reasons that FP32 paths exist.

fp8_failure_modes.py

FP8 E4M3 in one constant: 448. E4M3 has 4 exponent bits and 3 mantissa bits. Its largest representable value is 448, and any tensor whose magnitude exceeds 448 saturates, which silently destroys gradients and softmax tails. On real hardware (H100), per-block scaling stretches the range so that the block abs-max maps to 448; to_fp8_blockwise below simulates that.

Block-wise scaling: the precision trick that saves FP8. Instead of quantizing a whole tensor with one global scale, we pick a per-block scale so the block's abs-max maps to 448. That gives FP8 a usable dynamic range on activations that vary across heads, channels, and tokens. This is the per-128-element scaling DeepSeek-V3 uses (Section 10.3) and the reason FP8 GEMMs are accurate enough for the big matmuls. The scale is shared by every element of the block, so the optional block_amax argument lets a test place a value in a block whose largest entry is bigger than the value itself. Quantizing a lone value with its own scale would map it exactly onto 448 and hide every failure below.

Why we model a coarse step size. Three mantissa bits give 8 quantization levels per power-of-two binade. We approximate this with a uniform step of 1.75 scaled units, i.e. amax/256. This is a teaching model, not E4M3: real E4M3 is log-spaced, with a step of 32 scaled units near 448 (much coarser than this grid) and a step of 2^-9 near zero (much finer). The grid is still too coarse for Adam updates, softmax tails, and variance estimates, which is all the tests need.

Test 1: naive FP8 master weights never move. We run 5,000 Adam-sized updates of 1e-4 against a weight of 1.0 that sits in a block with abs-max 1.0. The FP32 copy honestly accumulates each step and ends up at 0.5. The FP8 master copy is re-quantized after every step; because 1e-4 is far smaller than the grid step of about 4e-3 near w = 1, every update rounds away and the FP8 weight stays at 1.0. This is the core reason every modern training stack keeps master weights in FP32, even DeepSeek-V3, which is otherwise aggressive about FP8.

Test 2: storing softmax output in FP8 destroys the probability tail. Logits of [2.1, 8.7, 11.3, -3.4, 6.0] are typical for a mid-trained transformer. The FP32 softmax preserves a long tail of small probabilities (the -3.4 logit becomes about 4e-7, still meaningful for gradient flow). Storing those probabilities as one FP8 block, scaled by the row's largest probability, quantizes the small ones to zero, breaks the gradient through low-attention paths, and silently kills long-range attention. The printed L1 error is the measurable damage.

Test 3: Adam's v underflows below FP8's grain. Adam tracks v, a running average of g². For a late-training gradient g ≈ 3e-4, v ≈ 9e-8. Here v shares a block with a second moment of 1e-4 (a neighbouring gradient of 1e-2, an assumption for the demo), so v falls below half a grid step and collapses to 0. The Adam update divides by sqrt(v) + eps: with v forced to 0 the step becomes lr · m / eps, tens of thousands of times too large. One bad step is enough to spike the loss or NaN the run. This is why Adam's m and v must never be stored in FP8.

```python
import numpy as np

# Simplified FP8 E4M3 model: per-block scaling maps the block's abs-max to
# FP8_MAX (448), values are clipped to [-448, 448], then rounded to a UNIFORM
# grid. Real E4M3 is log-spaced; this grid is a teaching approximation.
FP8_MAX = 448.0
MANTISSA_LEVELS = 8                         # 2^3 levels per binade in E4M3
STEP = FP8_MAX / (MANTISSA_LEVELS * 32)     # 1.75 scaled units = amax / 256

def to_fp8_blockwise(x, block_amax=None):
    """Quantize-dequantize x with ONE scale shared by its whole block.

    block_amax is the abs-max of the block x belongs to. It defaults to x's
    own abs-max; pass a larger value when x shares a block with bigger entries.
    """
    x = np.asarray(x, dtype=np.float64)
    amax = np.abs(x).max() if block_amax is None else block_amax
    scale = FP8_MAX / max(amax, 1e-12)      # block abs-max -> FP8_MAX
    y = np.clip(x * scale, -FP8_MAX, FP8_MAX)
    y = np.round(y / STEP) * STEP           # snap to the coarse grid
    return y / scale                        # dequantize for downstream math

# ------ Test 1: master weight in FP8 vs FP32 over many tiny Adam updates ------
w_fp32 = 1.0                                # canonical weight
w_fp8 = 1.0                                 # naive FP8 master weight
delta = 1e-4                                # representative Adam update
W_BLOCK_AMAX = 1.0                          # the weight's block has abs-max ~1
for _ in range(5_000):
    w_fp32 -= delta
    # FP8 master: re-quantize after each update with the block's shared scale
    w_fp8 = float(to_fp8_blockwise(w_fp8 - delta, block_amax=W_BLOCK_AMAX))

print(f"FP32 master after 5k steps: {w_fp32:.6f}")
print(f"FP8  master after 5k steps: {w_fp8:.6f}")
print(f"FP8 drift vs FP32 truth:    {abs(w_fp8 - w_fp32):.6f}")

# ------ Test 2: softmax probabilities stored in FP8 vs FP32 ------
def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def fmt(p):
    return " ".join(f"{v:.5f}" for v in p)

logits = np.array([2.1, 8.7, 11.3, -3.4, 6.0])
p_fp32 = softmax(logits)
p_fp8 = to_fp8_blockwise(p_fp32)            # the row is one block: amax = max prob
print("\nFP32 softmax:", fmt(p_fp32))
print("FP8  softmax:", fmt(p_fp8))
print(f"L1 error:     {np.abs(p_fp32 - p_fp8).sum():.5f}")

# ------ Test 3: Adam second moment v ~ g^2 under FP8 ------
g = 3e-4                                    # realistic late-training gradient
v_fp32 = g * g                              # 9e-8
V_BLOCK_AMAX = 1e-2 ** 2                    # assumed block neighbour with g ~ 1e-2
v_fp8 = float(to_fp8_blockwise(v_fp32, block_amax=V_BLOCK_AMAX))
lr, m, eps = 3e-4, 1.31e-4, 1e-8            # lr, m and eps from the walkthrough
print(f"\nv (FP32): {v_fp32:.3e}")
print(f"v (FP8) : {v_fp8:.3e}")
print(f"Adam step, v in FP32: {lr * m / (np.sqrt(v_fp32) + eps):.3e}")
print(f"Adam step, v in FP8 : {lr * m / (np.sqrt(v_fp8) + eps):.3e}")
```

The output of running this script:

```
FP32 master after 5k steps: 0.500000
FP8  master after 5k steps: 1.000000
FP8 drift vs FP32 truth:    0.500000

FP32 softmax: 0.00009 0.06881 0.92647 0.00000 0.00462
FP8  softmax: 0.00000 0.06876 0.92647 0.00000 0.00362
L1 error:     0.00115

v (FP32): 9.000e-08
v (FP8) : 0.000e+00
Adam step, v in FP32: 1.310e-04
Adam step, v in FP8 : 3.930e+00
```

Read each line as a verdict: the FP8 master never moved, storing the softmax output in FP8 zeroed the two smallest probabilities (including the lowest-probability token) and shaved a fifth off the next one, and the FP8 second moment vaporized, turning a 1.3e-4 step into a step of 3.93. Three small failures that compound into an unusable training run.

Sanity-check yourself. Re-run the master-weight loop with delta = 1e-2 (a hundred times larger). The FP8 master finally moves, although every step snaps to a whole number of grid steps of about 3.9e-3 and so still overshoots. By the update-survival inequality, an update registers only once it exceeds half a grid step: about 2e-3 on this grid, and about 3e-2 for real E4M3 near 1.0, at least an order of magnitude above realistic Adam updates at scale. Confirm the breakdown bound for yourself; it is the cleanest way to internalize why FP32 master weights are mandatory.

PyTorch: The Mixed-Precision Recipe DeepSeek Ships

Below is a transformer-block sketch with the precision map baked in. Three patterns to study: (a) RMSNorm with a forced FP32 reduction, (b) the GEMMs that FP8 targets, scoped by autocast (plain torch.autocast runs them in BF16; executing them in FP8 needs an FP8-aware library), and (c) softmax with an explicit FP32 cast in the middle. The optimizer sees only the FP32 master copy and updates it in FP32.

block_mixed_precision.py

RMSNorm in FP32: a 6-line discipline that prevents silent NaNs. The norm computes mean(x²) over d_model. In BF16 that mean is borderline; in FP8 it collapses, because small squares such as 0.05² = 0.0025 land at the very bottom of E4M3's unscaled range (its smallest positive value is 2^-9 ≈ 0.002) and the variance estimate becomes garbage. The fix is to upcast inside forward, do the reduction in FP32, and cast back. gamma stays in FP32 because it is a learned scale that the FP32 optimizer updates. This pattern is standard across modern stacks (LLaMA, Mistral, DeepSeek).

qkv and proj are the FP8 GEMM targets, the only FP8 ops in the block. These two nn.Linear layers are where FP8 earns its keep in production: Hopper tensor cores execute the GEMM with FP8 inputs and weights and a higher-precision accumulator, and the output is dequantized back to BF16 immediately. In this sketch they run in BF16 under autocast; replacing them with an FP8-aware linear layer (for example from NVIDIA Transformer Engine) is what actually moves them to FP8. Everything else in the block (the norm, the softmax, the residual, the optimizer state) stays in BF16 or FP32.

autocast scopes the low-precision region. torch.autocast tags a region of the graph as eligible for low-precision execution. With dtype=torch.bfloat16 it runs eligible ops such as nn.Linear in BF16; it never selects FP8 on its own. Outside these regions we deliberately run in BF16 or FP32: autocast is the lever that selectively enables low precision only where the math is safe. Passing device_type=x.device.type keeps the sketch runnable on CPU as well as CUDA.

softmax(.float()) is the single most important line in the file. We cast scores up to FP32 before softmax and cast the probabilities back to BF16 afterwards. Softmax subtracts the row maximum, so every exponential lies in (0, 1]; the risk in BF16 is precision, not range, because the normalizing sum over thousands of keys and the small tail probabilities keep only two or three significant digits. Doing it in FP8 is catastrophic. Forcing FP32 here costs one transient cast per attention call and prevents probability-tail collapse. Scores enter in BF16, and probabilities leave in BF16 with an FP32 intermediate. The causal mask applied just before is the usual decoder-only masking.

Residual addition is BF16, the same as the rest of the highway. The residual stream is the spine of the network. Adding a freshly computed attention output to it in BF16 is safe (the magnitudes match), but doing it in FP8 would compound quantization noise over 60+ layers. Modern LLMs universally keep the residual in BF16, including DeepSeek-V3, despite its aggressive FP8 use elsewhere.

AdamW operates only on the FP32 master parameters. The parameters are created in FP32 and never cast down, so they are the master copy, and AdamW allocates its m and v buffers with each parameter's dtype, so those are FP32 too. When step() runs, it updates the FP32 master in place. The BF16 working copies are produced by autocast on the fly inside every forward; hand-rolled trainers that keep explicit BF16 shadow weights instead copy the master into them after each step. Either way this is the source-of-truth/cache pattern that every mixed-precision trainer implements.

```python
import torch
import torch.nn as nn
from torch.optim import AdamW

# ------------- Custom mixed-precision RMSNorm (FP32 reduction) -------------
class RMSNormFP32(nn.Module):
    """RMSNorm where the reduction is forced to FP32 even if inputs are BF16."""
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d))   # gamma stays in FP32
        self.eps = eps

    def forward(self, x):                          # x: (B, S, d) BF16
        orig_dtype = x.dtype
        x32 = x.float()                            # cast up for the reduction
        rms = x32.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        out = (x32 * rms) * self.gamma             # gamma is FP32
        return out.to(orig_dtype)                  # cast back to BF16

# ------------- A transformer block whose precision map mirrors DeepSeek-V3 ----
class Block(nn.Module):
    def __init__(self, d, n_heads):
        super().__init__()
        self.norm1 = RMSNormFP32(d)
        self.qkv = nn.Linear(d, 3 * d, bias=False)   # FP8 GEMM target in production
        self.proj = nn.Linear(d, d, bias=False)      # FP8 GEMM target in production
        self.n_heads = n_heads
        self.d_h = d // n_heads

    def forward(self, x):                            # x: (B, S, d) BF16
        h = self.norm1(x)                            # FP32 inside, BF16 out
        # ---- GEMM region: autocast runs it in BF16 (FP8 needs an FP8 kernel) ----
        with torch.autocast(device_type=x.device.type, dtype=torch.bfloat16):
            qkv = self.qkv(h)                        # FP32 weight cast to BF16
        B, S, _ = qkv.shape
        q, k, v = qkv.view(B, S, 3, self.n_heads, self.d_h).unbind(2)
        # ---- Attention: BF16 scores, FP32 softmax ----
        scores = torch.einsum("bshd,bthd->bhst", q, k) / (self.d_h ** 0.5)
        future = torch.ones(S, S, device=x.device).triu(1).bool()
        scores = scores.masked_fill(future, float("-inf"))     # causal mask
        probs = scores.float().softmax(-1).to(scores.dtype)    # safe softmax
        attn = torch.einsum("bhst,bthd->bshd", probs, v).reshape(B, S, -1)
        # ---- Output projection: same GEMM treatment as qkv ----
        with torch.autocast(device_type=x.device.type, dtype=torch.bfloat16):
            out = self.proj(attn)
        return x + out                               # BF16 residual

# ------------- Optimizer: master weights + m + v all live in FP32 -----------
def build_optimizer(model, lr=3e-4):
    # The parameters stay FP32: they ARE the master copy. AdamW allocates its
    # m and v buffers with each parameter's dtype, so those are FP32 as well.
    assert all(p.dtype == torch.float32 for p in model.parameters())
    return AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)

# Training step (sketch; embed, head, tokens and targets are assumed to exist,
# F = torch.nn.functional):
#   x_bf16 = embed(tokens).bfloat16()
#   y      = model(x_bf16)                          # BF16 activations
#   logits = head(y)                                # (B, S, vocab)
#   loss   = F.cross_entropy(logits.float().flatten(0, 1),
#                            targets.flatten())     # FP32 cross-entropy
#   loss.backward()                                 # .grad is FP32 on FP32 params
#   optim.step()                                    # FP32 master updated in place
#   optim.zero_grad(set_to_none=True)
#   # No master -> BF16 sync needed: autocast re-casts the weights every forward.
```

Two structural details worth a second look. First, autocast is not a magic wand, and it is not an FP8 switch: torch.autocast with dtype=torch.bfloat16 tags a region as eligible for low precision and runs the eligible ops in BF16, on H100 and A100 alike. Running the qkv and proj GEMMs in FP8 on an H100 requires an FP8-aware layer in place of nn.Linear, for example from NVIDIA Transformer Engine; the rest of the precision map carries over unchanged, which is what keeps it portable. Second, the cost of the FP32 cast inside softmax is negligible: softmax is memory-bound, not compute-bound, and a single extra cast costs microseconds against the milliseconds of the surrounding matmuls. There is no engineering reason to economize on it.

One line that catches 90% of mixed-precision bugs. Add assert torch.isfinite(loss).all() immediately after the loss computation (it catches inf as well as NaN) and pin the failing batch to a file. NaN at training time almost always traces back to one of three things in this section's precision map: a missing FP32 cast in softmax, an FP8 input that breached $\pm 448$, or an Adam step against a quantized v. The assertion fires before the gradient propagates and the postmortem is trivial. Frontier labs run this assertion on every step.

At Massive Scale: Why FP32 Master Weights Cost Terabytes (and Are Worth It)

Snap the precision map onto DeepSeek-V3's 671B parameter budget and the costs become concrete:

| Tensor class | Precision | Bytes / param | Total at 671B params |
|---|---|---|---|
| Forward weights (FP8 GEMM input) | FP8 | 1 | 671 GB |
| BF16 shadow weights (for autocast) | BF16 | 2 | 1.34 TB |
| FP32 master weights | FP32 | 4 | 2.68 TB |
| Adam m + v (FP32 each) | FP32 | 8 | 5.37 TB |
| Gradients (BF16) | BF16 | 2 | 1.34 TB |
| Total (weights, gradients, optimizer state) | | 17 | 11.4 TB |

That 11.4 TB is resident memory per model replica if nothing is sharded. The only reason this is tractable is ZeRO-1 / FSDP, which splits the FP32 master and Adam state across data-parallel ranks (Chapter 11.7). With $1024$ H100s in the cluster, the FP32 master plus Adam state (12 bytes per parameter, about 8 TB) comes to roughly $8 \text{ GB}$ per GPU, and even spreading all 11.4 TB evenly would be about $11 \text{ GB}$. Either fits comfortably in 80 GB of HBM, with room left for activations and for the model-parallel slice of the forward weights.
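The byte counts in the table, and the per-GPU share after sharding, are simple arithmetic. This sketch assumes ideal ZeRO-1 sharding of the FP32 master and Adam state over 1,024 data-parallel ranks and ignores model parallelism and communication buffers.

```python
PARAMS = 671e9
BYTES_PER_PARAM = {                    # the rows of the table above
    "FP8 forward weights": 1,
    "BF16 shadow weights": 2,
    "FP32 master weights": 4,
    "Adam m + v (FP32)": 8,
    "BF16 gradients": 2,
}
total_tb = PARAMS * sum(BYTES_PER_PARAM.values()) / 1e12
optimizer_tb = PARAMS * (BYTES_PER_PARAM["FP32 master weights"]
                         + BYTES_PER_PARAM["Adam m + v (FP32)"]) / 1e12
DP_RANKS = 1024                        # the cluster size used in the text
per_gpu_gb = optimizer_tb * 1e12 / DP_RANKS / 1e9

print(f"total: {total_tb:.1f} TB, FP32 master + Adam: {optimizer_tb:.2f} TB")
print(f"sharded FP32 master + Adam per GPU: {per_gpu_gb:.1f} GB")
```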

Two strategic consequences fall out of this arithmetic. First, FP8 is the leverage that pays for everything else. The table counts an FP8 weight copy on top of the BF16 shadow, so FP8 does not save weight bytes here; its payoff is roughly twice the GEMM throughput of BF16 plus smaller FP8-cached activations, which buys back the budget for the FP32 master and optimizer state we are not willing to give up. Without FP8 GEMMs, the same run would cost substantially more compute. Second, the FP32 reservations are not waste; they are the survival kit. The two failure modes from the top of this section (master-weight underflow and Adam v underflow) are exactly what the high-precision reservations prevent. Cutting the master weights to BF16 to save bytes is the most common rookie optimization, and the most expensive one when the run silently underperforms.

Where the bytes go in practice

  1. FP8 forward weights live on every GPU that holds the corresponding model-parallel shard. They are read once per forward step and discarded.
  2. BF16 shadow weights are the working copy the forward pass reads. They are refreshed from the FP32 master after each optimizer step; with plain torch.autocast, the equivalent BF16 casts are made on the fly from the FP32 parameters instead.
  3. FP32 master weights and Adam m, v are sharded across data-parallel ranks (ZeRO-1). On a step boundary, each rank runs the optimizer step on its own shard, and an all-gather then brings the updated weights back to every rank that needs them. This adds a full weight all-gather to every step, communication that has to be overlapped with computation (see Chapter 11.5 DualPipe).
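Sharding the optimizer is exact because AdamW is elementwise: updating slices independently gives the same result as updating the whole vector. The single-process simulation below is a sketch under stated assumptions: plain NumPy slices stand in for ranks, the gradient is assumed to be already reduced across ranks, bias correction is omitted, and the final concatenation stands in for the all-gather collective.

```python
import numpy as np

def adamw_step_(w, m, v, g, lr=3e-4, b1=0.9, b2=0.95, eps=1e-8, wd=0.1):
    """In-place AdamW on whatever slice it is given (bias correction omitted)."""
    m *= b1
    m += (1 - b1) * g
    v *= b2
    v += (1 - b2) * g * g
    w -= lr * (m / (np.sqrt(v) + eps) + wd * w)

n, ranks = 16, 4
rng = np.random.default_rng(0)
master = rng.normal(0.0, 0.02, n).astype(np.float32)   # FP32 master weights
grad = rng.normal(0.0, 1e-3, n).astype(np.float32)     # already-reduced gradient

# Reference: a single rank holding the full master and Adam state.
w_ref, m_ref, v_ref = master.copy(), np.zeros_like(master), np.zeros_like(master)
adamw_step_(w_ref, m_ref, v_ref, grad)

# ZeRO-1 style: rank r owns one contiguous slice of the master and Adam state.
w_all, m_all, v_all = master.copy(), np.zeros_like(master), np.zeros_like(master)
size = n // ranks
for r in range(ranks):                   # these run in parallel on real hardware
    sl = slice(r * size, (r + 1) * size)
    adamw_step_(w_all[sl], m_all[sl], v_all[sl], grad[sl])   # slices are views

# Stand-in for the all-gather: every rank receives every updated slice.
gathered = np.concatenate([w_all[r * size:(r + 1) * size] for r in range(ranks)])
print("sharded update matches unsharded:", np.array_equal(gathered, w_ref))
```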

Engineering Reality: The DeepSeek-V3 Precision Map

Here is the precision map as DeepSeek-V3 actually ships it, layer by layer. Every entry is a deliberate engineering choice backed by a failure mode like the ones we traced above.

| Component | Precision | Why |
|---|---|---|
| Token embedding & output head | BF16 | Embeddings are sparse and gradient-sensitive (a token seen once must still produce a meaningful gradient). FP8 quantization on the embedding matrix loses tail tokens. The output head feeds cross-entropy, which itself runs in FP32. |
| RMSNorm reduction | FP32 reduction, BF16 output | Mean-of-squares overflows or underflows in FP8 and is borderline in BF16. The reduction is forced to FP32 and the rescaled activation is cast back. |
| RMSNorm gamma (learned scale) | FP32 master, BF16 shadow | Same logic as every other learned parameter: master in FP32 so the optimizer step actually lands, shadow in BF16 for the forward pass. |
| QKV / output / FFN GEMMs | FP8 input, FP8 weight, BF16 output | The whole point of FP8. Per-block scaling (1×128 activation tiles, 128×128 weight blocks) preserves dynamic range; the FP32 accumulator (Section 10.4) prevents catastrophic accumulation error. |
| Attention scores (QKᵀ) | BF16 | Logits routinely span ±30 in long-context models, and the softmax that follows needs the full range. FP8 saturates and the post-softmax tail collapses. |
| Softmax | FP32 reduction, BF16 output | exp(z) where z spans 15 to 30 produces a probability tail at 1e-7 to 1e-13. Only FP32 preserves it; BF16 distorts it; FP8 destroys it. |
| Residual stream | BF16 | Compounded over 60+ layers, an FP8 residual would accumulate quantization noise on every block. BF16 grain (≈ 4e-3 × magnitude) is small enough relative to typical residual magnitudes (≈ 1) that depth-wise drift is negligible. |
| Gradients | BF16 (per-rank), FP8 (cross-node compression) | Per-rank gradient is BF16, the standard. For cross-node all-reduce, DeepSeek-V3 compresses to FP8 with per-block scaling and recovers the BF16 average on the receive side. This is an aggressive choice; Section 11.6 covers the analysis. |
| Master weights | FP32 | Update-survival inequality: every line of math points here. Sharded across DP ranks via ZeRO-1. |
| Adam m, v | BF16 in DeepSeek-V3 (FP32 is the conservative default) | The second-moment underflow bound rules out FP8. BF16 shares FP32's exponent range, so v ≈ 1e-7 does not underflow, and the DeepSeek-V3 technical report states that tracking m and v in BF16 caused no observable degradation. Sharded across DP ranks via ZeRO-1. |
| Cross-entropy loss | FP32 | Loss values are summed across the batch and then divided; numerical drift in the reduction shows up directly in the gradient norm. The whole loss path is FP32: the loss itself is one scalar per step, and upcasting the logits for the log-softmax is cheap next to the output-head GEMM that produced them. |
| Gradient clipping | FP32 | Computes the global gradient L2 norm by summing squared values across every parameter shard. Mixed-precision implementations of this step are a known source of subtle bugs; FP32 is the safe default and the cost is negligible. |
The one sentence to remember. FP8 lives inside the big matmuls; BF16 lives on the activation highway; FP32 lives in the master copies and the small reductions that compound over time. Crossing those lines costs accuracy with no real compensation in speed — the FP8 GEMM is already saturating the tensor cores, so demoting the residual or the softmax buys nothing and breaks the loss curve.

DeepSeek-V3's precision map looks intricate because every cell records a separate engineering battle: did this tensor survive FP8 in the ablations, or did the run diverge by 200B tokens in? The answer is rarely "FP8 was fine" and rarely "FP32 was needed." The answer almost always is this exact tensor at this exact place needed BF16. The whole chapter builds toward this one observation: mixed precision is not a format choice. It is a precision-by-precision audit of every tensor flow in the model. Section 10.6 puts that audit into runnable code — the full training loop with the precision map wired in.

Where we go from here. Section 10.5 mapped the precision boundaries for one transformer block. Section 10.6 implements the complete mixed-precision training step that fits on a real H100 cluster, including the FP32 master/BF16 shadow synchronization, the autocast scopes, and the gradient handling that lets DeepSeek-V3 reach its target loss on 14.8 T tokens without ever crossing the FP8 lines we just drew.