Chapter 11

Memory Optimisation: No Tensor Parallelism Required

Distributed Training: DualPipe and the Parallelism Stack

The Problem: Tensor Parallelism Was Supposed to Be Mandatory

Most other large-model recipes in the open record (Megatron-LM, GPT-NeoX, Llama-3 405B) reach the same conclusion at the same point in the design: the model is too large to fit on one GPU, so you slice each weight matrix across the GPUs of a single node and let an all_reduce glue the partial results back together. That technique is tensor parallelism (TP), and it is what lets a single transformer layer fit within the H100's 80 GB budget.

TP is also expensive in a way that nothing else in the stack is. Each layer of forward and backward fires two all_reduce collectives in the critical path. They are blocking — no kernel inside the layer can start until every rank has its slice of the previous result. They eat NVLink bandwidth that DualPipe was supposed to give back to compute. And because they happen inside the layer, they break the symmetry that lets pipeline schedulers overlap forward and backward cleanly.

DeepSeek-V3 refused to pay that bill. The official report says it plainly: "In order to reduce the memory footprint during training, we employ the following techniques. […] Notably, DeepSeek-V3 does not use tensor parallelism during training." That decision is the load-bearing one for the entire parallelism stack — it is what lets DualPipe schedule cleanly, what lets the all-to-all collectives from expert parallelism share the network with nothing else, and what makes the cluster topology drawable on one page.

The bet. Three modest memory tricks — selective recomputation, CPU-resident EMA, and a shared embedding/output-head — collectively save enough HBM to keep the model below 80 GB per GPU without a single intra-layer collective. The savings are independent, additive, and cheap to implement. The cost is roughly 1% extra FLOPs from recomputation. The win is the rest of this chapter.

Intuition: Trade Three Cheap Tricks for One Expensive One

Tensor parallelism solves a memory problem with a communication tool. That is a category error when communication is already your scarce resource. The right question is: which parts of the per-GPU budget are easiest to shrink without talking to other GPUs?

Three answers fall out of the budget once you write it down:

  • Activations are full of cheap-to-recompute waste. The fattest activations in an MLA block — the RMSNorm output and the up-projection — are also the cheapest to redo. Don't save them; rebuild them on the backward pass.
  • The EMA isn't hot. The exponential moving average copy of the weights is only touched once per step, and the update is embarrassingly parallel with the next forward. It does not need to live in HBM; it can live on pinned host RAM and be updated asynchronously.
  • The embedding and the LM head are the same tensor. Tie their weights and put them on the same pipeline stage, and their $V \times d$ matrix (about 1.8 GB in BF16 for V3) costs you once instead of twice.

The three tricks are independent. You can switch any subset on, in any order, without breaking the others. In the visualizer below you can toggle each one and watch the per-GPU bar shrink piece by piece.

The Memory Bookkeeping: What Lives on a GPU

Before we apply any tricks, let us write down where the bytes actually go. For a model with $N$ total parameters running under $P$ pipeline stages and ZeRO-1 across a data-parallel degree $D$, the per-GPU memory at one training step breaks into six buckets:

| Bucket | Size (bytes) | Why it is there |
| --- | --- | --- |
| Weights (BF16) | 2N / P | One pipeline stage holds N/P parameters in working precision. |
| Gradients (BF16) | 2N / P | Same shape as weights; produced by the backward pass. |
| Optimizer state (FP32) | 12N / (P·D) | AdamW master copy + two FP32 moments. Sharded by ZeRO-1 across the DP group. |
| EMA (FP32) | 4N / (P·D) | Second FP32 copy of every weight, kept for evaluation and final-checkpoint smoothing. |
| Activations (BF16) | ≈ 12 · L · B · S · d | Saved between forward and backward. Scales with sequence length and batch. |
| Embed + LM head (BF16) | 2 · V · d (or 4 · V · d untied) | Vocabulary embedding and output projection; they share or duplicate the same V × d matrix. |

Four of these six are dictated by physics: you cannot avoid weights, gradients, the optimizer state, or activations. But the EMA and the embed/head are discretionary — the framework chose to put them on the GPU because that was the easiest place to put them. The tricks in this section take each discretionary line and ask: does this really need to be here?
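The table can be turned into a few lines of arithmetic. The sketch below simply transcribes the six formulas as written. The function name `per_gpu_buckets` and the `tied` flag are illustrative choices rather than part of any framework, `L` is assumed to mean the number of layers held by one stage, and the example values are the DeepSeek-V3 figures used later in this section.

```python
def per_gpu_buckets(N, P, D, L, B, S, d, V, tied=True):
    """Per-GPU memory buckets in bytes, transcribed from the table above.

    N: total parameters, P: pipeline stages, D: data-parallel degree,
    L: layers on this stage (assumption), B: micro-batch, S: sequence length,
    d: hidden width, V: vocabulary size.
    """
    return {
        "weights_bf16": 2 * N / P,
        "grads_bf16": 2 * N / P,
        "optimizer_fp32": 12 * N / (P * D),
        "ema_fp32": 4 * N / (P * D),
        "activations_bf16": 12 * L * B * S * d,
        "embed_head_bf16": (2 if tied else 4) * V * d,
    }


buckets = per_gpu_buckets(
    N=671e9, P=16, D=64, L=4, B=1, S=4096, d=7168, V=128000
)
for name, nbytes in buckets.items():
    print(f"{name:18s} {nbytes / 1e9:8.2f} GB")
```

Note that the weight and gradient rows here use the total parameter count; the walkthrough later in this section substitutes the active-parameter count for those two rows.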

Trick 1: Recompute RMSNorm and the MLA Up-Projection

Standard PyTorch autograd saves every intermediate it thinks the backward pass might need. For a transformer block that means at least four wide tensors per layer: the input $x$, the RMSNorm output $y = \text{RMSNorm}(x)$, the up-projection result $u = W_{\text{up}} y$, and the block output $a$. Three are width-$d$; one, $u$, is width-$4d$. In a 7168-wide model that single tensor is 4× the size of each of the others, and larger than the other three combined.
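To put numbers on those four tensors, here is a small sketch that sizes them at the shapes used in this section's walkthrough (micro-batch 1, sequence length 4096, width 7168, BF16). The variable names are illustrative.

```python
# Per-tensor activation sizes for one block at the shapes used in this section.
B, S, d = 1, 4096, 7168        # micro-batch, sequence length, hidden width
BYTES_BF16 = 2

narrow = B * S * d * BYTES_BF16        # x, y, or a (width d)
wide = B * S * 4 * d * BYTES_BF16      # u (width 4d)

naive_tape = 3 * narrow + wide         # x, y, u, a all saved
ckpt_tape = 2 * narrow                 # only x and a saved

print(f"narrow tensor: {narrow / 2**20:.0f} MiB")
print(f"wide tensor u: {wide / 2**20:.0f} MiB")
print(f"u / (x + y + a): {wide / (3 * narrow):.2f}")
print(f"naive tape: {naive_tape / 2**20:.0f} MiB, "
      f"checkpoint tape: {ckpt_tape / 2**20:.0f} MiB")
```

The wide tensor accounts for four of the seven width-$d$ units on the naive tape, which is why dropping it dominates the saving.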

Selective recomputation throws away $y$ and $u$. On the backward pass, the framework redoes the RMSNorm and the up-projection, both of which are fast: RMSNorm is one fused kernel, the up-projection is a single GEMM. The boundary tensors $x$ and $a$ stay on the tape because the next block needs them and there is no shortcut to recompute attention cheaply.

The math behind the cost is the punchline. If the layer has forward FLOPs $F$, the up-projection contributes roughly $F/3$ and the RMSNorm contributes roughly $F/100$. Running those two terms a second time costs about $0.34F$ extra for the recomputed block. The block's forward costs $F$ and its backward costs about $2F$, so the extra is $0.34F / (3F) \approx 11\%$ of the block's end-to-end FLOPs. Amortised across the whole forward+backward step of the network, that becomes the often-quoted $\sim 1\%$ overhead for activation checkpointing.
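The per-block estimate is easy to check by hand. A minimal sketch, taking the text's rough shares ($F/3$ for the up-projection, $F/100$ for RMSNorm, $2F$ for the backward) as given assumptions:

```python
F = 1.0                      # forward FLOPs of one block, normalised to 1
up_proj = F / 3              # rough share of the up-projection (from the text)
rms_norm = F / 100           # rough share of RMSNorm (from the text)

extra = up_proj + rms_norm   # work that runs a second time during backward
baseline = F + 2 * F         # forward + backward without recomputation
overhead = extra / baseline

print(f"extra FLOPs: {extra:.3f} F")
print(f"per-block overhead: {overhead:.1%}")
```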

Why these two operations specifically? Because they have the largest memory-to-FLOP ratio in the block. RMSNorm produces a width-$d$ tensor with $O(d)$ FLOPs, which makes it almost pure memory. The up-projection produces a width-$4d$ tensor with $O(d \cdot 4d) = O(d^2)$ FLOPs. Compare to attention: attention is $O(S^2 \cdot d)$ FLOPs and its activations are small relative to its compute. So you save attention's output and you recompute everything that is fat-but-cheap.

[Interactive animation: recomputation, naive tape vs. checkpoint tape]

The animation makes the bookkeeping concrete. Watch the naive column accumulate four glowing tensors during forward, and the checkpoint column accumulate only two. When the backward arrives, the checkpoint column lights up amber where it has to redo work: first the up-projection (because the next backward op wants $u$), and then the RMSNorm (because the up-projection itself wants $y$). Both columns finish at the same final answer.

Trick 2: Stream the EMA to CPU Memory

Many large training runs keep an exponential moving average of the parameters: a second FP32 copy that is updated by the rule $\theta^{\text{EMA}}_t = \beta\,\theta^{\text{EMA}}_{t-1} + (1 - \beta)\,\theta_t$ after every optimizer step. EMA weights are smoother than raw weights and often evaluate better, and the EMA is frequently what ships as the final checkpoint.
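The update rule is a one-liner in practice. The sketch below applies it in place with NumPy and also computes how many steps it takes for an old value's weight to halve. The function name `ema_update` is illustrative, and the default `beta=0.999` is the value this section quotes for the shadow copy.

```python
import math

import numpy as np


def ema_update(ema, theta, beta=0.999):
    """In-place EMA step: ema <- beta * ema + (1 - beta) * theta."""
    ema *= beta
    ema += (1.0 - beta) * theta
    return ema


ema = np.zeros(4, dtype=np.float32)      # FP32 shadow copy, starts at zero
theta = np.ones(4, dtype=np.float32)     # stand-in for constant weights

n_steps = 1000
for _ in range(n_steps):
    ema_update(ema, theta)

# Starting from zero and tracking a constant 1, the closed form is 1 - beta**n.
closed_form = 1.0 - 0.999 ** n_steps
half_life = math.log(0.5) / math.log(0.999)

print("ema after 1000 steps:", ema[0])
print("closed form         :", closed_form)
print("half-life in steps  :", half_life)
```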

At 671 B parameters, an FP32 EMA is $671 \times 10^9 \times 4 \approx 2.68\text{ TB}$ of state. Sharded across a DP degree of 64 alone, that is still ~42 GB per GPU, over half the H800 budget by itself; even after the 16-way pipeline split used in the walkthrough below, it is about 2.6 GB per GPU. That is a lot of HBM for a buffer that gets touched once per step.
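The same arithmetic as a short script. The 64-way DP degree is the one quoted in the paragraph above; the 16 pipeline stages are taken from the walkthrough later in this section.

```python
N = 671e9                 # total parameters
BYTES_FP32 = 4

total = N * BYTES_FP32                 # whole EMA, all ranks
per_gpu_dp_only = total / 64           # sharded over the DP group only
per_gpu_pp_and_dp = total / (16 * 64)  # split over 16 stages, then sharded

print(f"total EMA          : {total / 1e12:.2f} TB")
print(f"per GPU, DP only   : {per_gpu_dp_only / 1e9:.1f} GB")
print(f"per GPU, PP and DP : {per_gpu_pp_and_dp / 1e9:.2f} GB")
```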

DeepSeek pushes the EMA off the GPU entirely. The shadow copy lives in pinned host RAM: page-locked memory that the GPU can DMA to and from directly, without an intermediate staging copy. After each optimizer step, a side CUDA stream copies the new BF16 weights to the host, where they are upcast to FP32 and blended into the shadow copy with $\beta = 0.999$. The whole transfer is asynchronous and overlaps with the next forward pass, so the model never waits for it.

Why pinned memory matters. A device-to-host copy with non_blocking=True (whether via .to('cpu', non_blocking=True) or copy_ into a preallocated buffer) only overlaps with compute when the host destination is pinned. With pageable host memory the runtime falls back to a staged, effectively synchronous copy, which would stall the training loop until the EMA transfer finished, killing the overlap and with it the whole optimization.

Trick 3: Co-locate and Share the Embedding and Output Head

The input embedding maps a token id in $\{1, \dots, V\}$ to a width-$d$ vector. The output head (the "LM head") maps a width-$d$ vector back to a distribution over the same $V$ tokens. Both are $V \times d$ matrices. For V3 with $V \approx 128\text{K}$ and $d = 7168$, that is $128000 \cdot 7168 \cdot 2 \approx 1.83\text{ GB}$ of BF16 weights per matrix.

Under pipeline parallelism, the embedding lives on the first pipeline stage and the LM head lives on the last. If you treat them as separate parameters, you pay twice. DeepSeek-V3 makes two coordinated choices:

  • Tie the weights. The LM head reuses the embedding matrix transposed. One parameter, two roles — the standard weight-sharing trick that goes back to Press & Wolf (2016).
  • Co-locate them on the same PP stage. The DualPipe schedule places the embedding and the head on the same rank, so the tied weight lives on exactly one GPU and no extra all-reduce is needed to keep two copies in sync.

The first choice saves 1.83 GB on the LM-head rank. The second choice saves the cross-rank gradient sync that a tied weight split across two pipeline stages would otherwise require. Together they remove almost 4 GB from the critical-path rank's budget, enough on its own to decide whether the rank that holds the deepest activations lands under the 80 GB line.
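A minimal NumPy sketch of the tied arrangement, using toy sizes and hypothetical helper names `embed` and `lm_head`. It shows one matrix serving both roles, and the two gradient contributions landing in the same buffer, which is what makes co-location attractive.

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 16, 8                       # toy vocabulary and width (V3: ~128K and 7168)
E = rng.standard_normal((V, d))    # the single shared V x d matrix


def embed(token_ids):
    """Input role: look up rows of E. (T,) -> (T, d)."""
    return E[token_ids]


def lm_head(h):
    """Output role: project onto the same rows. (T, d) -> (T, V)."""
    return h @ E.T


tokens = np.array([3, 7, 7])
h = embed(tokens)                  # no layers in between, for brevity
logits = lm_head(h)

# Backward for a given upstream gradient on the logits.
grad_logits = rng.standard_normal(logits.shape)
grad_from_head = grad_logits.T @ h            # (V, d): head's contribution
grad_h = grad_logits @ E                      # (T, d): flows back to the lookup
grad_from_embed = np.zeros_like(E)
np.add.at(grad_from_embed, tokens, grad_h)    # scatter-add, handles repeated ids
grad_E = grad_from_head + grad_from_embed     # one buffer, two contributions

tied_bytes = 128000 * 7168 * 2                # BF16, one copy at V3 scale
print("logits shape:", logits.shape)
print(f"tied: {tied_bytes / 1e9:.2f} GB, untied: {2 * tied_bytes / 1e9:.2f} GB")
```

If the two roles lived on different ranks, the sum on the `grad_E` line would need a cross-rank exchange every step; on one rank it is a local add.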

Manual Numerical Walkthrough: 671B Without TP

Plug the DeepSeek-V3 numbers in — one bucket at a time

Setup. 671 B total parameters, 16 PP stages, 64-way data parallelism, sequence length 4096, micro-batch 1, hidden width 7168, 61 transformer layers, 128K vocabulary.

Weights (BF16). $2 \cdot 671\text{B} / 16 = 83.9\text{ GB}$ spread across the stage. But the stage holds $61 / 16 \approx 3.8$ layers, and this walkthrough counts only the active parameters of an MoE layer as working memory, on the grounds that the routed experts that did not fire for this token do not need to be in fast cache. The effective working set per GPU is then closer to $2 \cdot 37\text{B} / 16 \approx 4.6\text{ GB}$ for active params, plus a thin slice of routed experts.

Gradients (BF16). Same shape, same number: $\approx 4.6\text{ GB}$ per GPU for active weights.

Optimizer state (ZeRO-1). $12 \cdot 671\text{B} / (16 \cdot 64) \approx 7.9\text{ GB}$ per GPU. ZeRO-1 shards it across the 64-way DP group, so each rank keeps only 1/64th of the full 12-bytes-per-param state.

EMA (Trick 2). Without Trick 2: $4 \cdot 671\text{B} / (16 \cdot 64) \approx 2.6\text{ GB}$ per GPU on HBM. With Trick 2: 0 GB (lives on host).

Activations. Approximately $12 \cdot 4 \cdot 1 \cdot 4096 \cdot 7168 \cdot 2 / 2^{30} \approx 2.6\text{ GB}$ per layer of one micro-batch, where 4 is the average layer count per stage. Without recomputation (Trick 1): $\approx 2.6\text{ GB} \cdot 4 = 10.5\text{ GB}$ per stage. With recomputation: $\approx 7.3\text{ GB}$ (about 30% savings from dropping the wide intermediates).

Embed + head (Trick 3). Without sharing: $2 \cdot 128000 \cdot 7168 \cdot 2 / 2^{30} \approx 3.4\text{ GB}$ on the embed rank and the head rank. With sharing: $\approx 1.7\text{ GB}$, only on the co-located rank.

Sum on the heaviest rank (the embed+head one). Without tricks: $4.6 + 4.6 + 7.9 + 2.6 + 10.5 + 3.4 = 33.6\text{ GB}$. With all three tricks: $4.6 + 4.6 + 7.9 + 0 + 7.3 + 1.7 = 26.1\text{ GB}$.

That is the answer. $26\text{ GB}$ per H800 leaves $\approx 54\text{ GB}$ for NCCL buffers, micro-batch staging, attention KV during warm-up, and the inevitable allocator fragmentation, with zero tensor parallelism in the picture.
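The bucket totals can be re-added in a few lines. This sketch only sums the rounded per-bucket figures quoted in the walkthrough; the dictionary keys are illustrative labels.

```python
HBM_GB = 80

# Rounded per-GPU figures from the walkthrough, in GB.
without_tricks = {
    "weights": 4.6,
    "grads": 4.6,
    "optimizer": 7.9,
    "ema": 2.6,
    "activations": 10.5,
    "embed_head": 3.4,
}
# Trick 2 moves the EMA to host, Trick 1 trims activations, Trick 3 shares embed/head.
with_tricks = dict(without_tricks, ema=0.0, activations=7.3, embed_head=1.7)

total_without = round(sum(without_tricks.values()), 1)
total_with = round(sum(with_tricks.values()), 1)
headroom = round(HBM_GB - total_with, 1)

for name in without_tricks:
    saved = without_tricks[name] - with_tricks[name]
    print(f"{name:12s} {without_tricks[name]:5.1f} -> {with_tricks[name]:5.1f}  (saves {saved:.1f})")
print(f"total: {total_without} GB -> {total_with} GB, headroom {headroom} GB")
```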

Visualizing the Budget

The walkthrough above is one configuration. The simulator below lets you sweep the knobs — PP stages, sequence length, micro-batch — and switch each trick on or off independently. The dashed line is the 80 GB H800 limit; the goal of the whole exercise is to walk the second bar comfortably under it.

[Interactive simulator: per-GPU memory budget]

Two things to notice as you experiment. First, the embed/head saving (Trick 3) is the biggest individual jump on the most heavily-loaded rank, but it does not change the average budget much — it is a localized, rank-specific win. Second, recomputation (Trick 1) is the saving that scales with sequence length: pull the sequence slider to 32K and the activation bar grows linearly; toggle recomputation on and the bar drops by ~30% no matter where the slider is.
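The sequence-length behaviour can be approximated offline. The sketch below makes two assumptions taken from the text rather than measured: activations scale linearly from the walkthrough's 10.5 GB at a 4096-token sequence, and recomputation removes a flat 30%. The remaining buckets are held at their walkthrough values with Tricks 2 and 3 on.

```python
BASE_SEQ = 4096
ACT_GB_AT_BASE = 10.5        # walkthrough figure, no recomputation
RECOMPUTE_FACTOR = 0.7       # "about 30% savings" from the walkthrough
OTHER_GB = 4.6 + 4.6 + 7.9 + 0.0 + 1.7   # weights, grads, optimizer, EMA, embed/head
HBM_GB = 80


def activation_gb(seq_len, recompute):
    """Activation memory under the linear-scaling assumption stated above."""
    gb = ACT_GB_AT_BASE * seq_len / BASE_SEQ
    return gb * RECOMPUTE_FACTOR if recompute else gb


for seq_len in (4096, 8192, 16384, 32768):
    for recompute in (False, True):
        total = OTHER_GB + activation_gb(seq_len, recompute)
        verdict = "fits" if total < HBM_GB else "over budget"
        print(f"S={seq_len:6d} recompute={recompute!s:5s} total={total:6.1f} GB  {verdict}")
```

Under these assumptions the 32K row is where recomputation changes the verdict, which is the scaling behaviour the second observation describes.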

Plain Python: Manual Save-and-Redo for One Block

Stripped of autograd, what does selective recomputation actually look like? The code below is the smallest honest implementation: a forward function that explicitly chooses what to put on the tape, and a backward function that explicitly recomputes whatever is missing. Two modes, identical math, different memory peaks — and a final assertion that the gradients match.

🐍ckpt_manual_numpy.py
Line 1: 📚 Import NumPy

We use NumPy instead of PyTorch so every save and every recompute is visible, with no autograd hiding the tape. The same logic runs in PyTorch behind torch.utils.checkpoint.

EXAMPLE
import numpy as np  # → np.ndarray, the manual tensor type

Line 10: Set the random seed

Deterministic inputs so the printed peak-tape and grad-diff numbers are reproducible exactly. Replace 0 with any int; the comparison still holds.

EXAMPLE
rng = np.random.default_rng(0) → rng.standard_normal((2,)) = [0.126, -0.132]

Line 11: Toy shapes

B=2 sequences, S=4 tokens each, hidden width d=8. In the real DeepSeek-V3 block these become B=1 per GPU, S=4096, d=7168: same arithmetic, three orders of magnitude bigger.

EXAMPLE
B·S·d = 2·4·8 = 64 floats per same-width tensor; the 4d tensor is 256.

Line 12: Input activation x

Shape (B, S, d) = (2, 4, 8). This is the boundary tensor: whichever strategy we choose, x is the one input we never throw away. Any recompute path needs it to start from.

EXAMPLE
x.shape = (2, 4, 8); x[0, 0, :3] = [0.126, -0.132,  0.640]

Line 13: Up-projection weight W_up

Shape (4d, d) = (32, 8). Multiplies the normed activation y up to the wide channel dim. Scaled by 0.1 to keep tanh below saturation. In DeepSeek-V3 this is the MLA up-projection; the corresponding (4d, d) GEMM is the single fattest matmul in the block.

EXAMPLE
W_up @ y_row → vector of length 4d = 32

Line 14: Output projection W_out

Shape (d, 4d) = (8, 32). Brings the wide tensor back to d so the block output matches the block input. Same role as the down-projection inside an MLP.

EXAMPLE
W_out @ u_row → vector of length d = 8

Line 16: RMSNorm definition

RMSNorm divides every element by the root-mean-square of its row. No bias, no mean subtraction. This is the cheap-to-recompute operation we will drop from the tape and redo on demand.

EXAMPLE
rms_norm([3, 4]) ≈ [3, 4] / √(12.5) ≈ [0.85, 1.13]

Line 17: 📚 axis=-1, keepdims=True

axis=-1 reduces along the last (channel) axis so we get one RMS per (batch, token) row. keepdims=True keeps the output shape (B, S, 1) so the division broadcasts cleanly against (B, S, d).

EXAMPLE
(t**2).mean(-1, keepdims=True).shape = (2, 4, 1)

Line 18: Divide by the RMS

t / rms broadcasts the (B, S, 1) divisor across all d channels, the standard normalize-by-row pattern. The cost is one elementwise pass; perfect for recomputation.

EXAMPLE
t.shape (2,4,8) ÷ rms.shape (2,4,1) → out.shape (2,4,8)

Line 20: Forward function (def block_forward)

Runs the three sub-ops and, depending on save_intermediates, either keeps every activation on the tape (NAIVE) or keeps only x and a (CHECKPOINT). The choice is a *bookkeeping* choice; the math is identical in both modes.

EXAMPLE
block_forward(x, True)  → tape has 4 tensors
block_forward(x, False) → tape has 2 tensors

Line 21: y = RMSNorm(x)

First sub-op. Same shape as x, width d. In the naive tape this gets saved; in the checkpoint tape it gets dropped and recomputed during backward.

EXAMPLE
y.shape = (2, 4, 8); every row of y has root-mean-square ≈ 1

Line 22: u = y @ W_up.T (the 4d-wide tensor)

This is the expensive line: u is 4× wider than every other activation in the block. At scale, every 1 GB held by a width-d activation becomes 4 GB for u. The whole point of recompute is to throw u away.

EXAMPLE
u.shape = (2, 4, 32), where the last-axis 32 = 4·d

Line 23: 📚 attn(u) (stub): tanh(u @ W_out.T)

Stand-in for the attention block: down-projects u back to d and squashes through tanh. The exact non-linearity does not matter for the memory story; we just need an output a of shape (B, S, d) that the next layer will consume.

EXAMPLE
a.shape = (2, 4, 8); every entry of a lies in (-1, 1)

Line 24: if save_intermediates: build the NAIVE tape

When True, we record all four tensors: x, y, u, a. Total floats stored = (1 + 1 + 4 + 1) · (B·S·d) = 7·(B·S·d). This is what stock autograd does by default.

EXAMPLE
tape = {x, y, u, a} → 7 · B · S · d = 7 · 64 = 448 floats

Line 25: NAIVE tape contents

All four tensors live on the tape. Cheap on toy shapes; ruinous when d = 7168, S = 4096, B = 1, where u alone is 4·1·4096·7168·2 bytes ≈ 0.24 GB per layer.

EXAMPLE
{'x': (2,4,8), 'y': (2,4,8), 'u': (2,4,32), 'a': (2,4,8)}

Line 26: else: build the CHECKPOINT tape

When False, the function deliberately drops y and u, the two intermediate tensors, and only records the boundary tensors x and a. y can be recomputed from x via RMSNorm; u can be recomputed from y via one matmul. The backward will redo them on demand.

EXAMPLE
tape = {x, a} → 2 · B · S · d = 2 · 64 = 128 floats

Line 27: CHECKPOINT tape contents

Only x (input) and a (output) survive. We saved the 1× + 4× = 5× contribution of y and u in exchange for one extra RMSNorm + matmul during backward.

EXAMPLE
{'x': (2,4,8), 'a': (2,4,8)}, same B·S·d each

Line 28: return a, tape

The forward pass yields the next-layer input a and the chosen tape. In autograd terms, the tape is what the engine consults when it walks backward.

EXAMPLE
(a, tape) → next-layer input + the recovery context

Line 30: Backward function (def block_backward)

Takes grad_a (the gradient arriving from the layer above) and the tape, and must return grad_x. If y or u are missing from the tape it must reconstruct them; this is where recompute earns its keep.

EXAMPLE
block_backward(grad_a, ckpt_tape) → grad_x of shape (2,4,8)

Line 32: Read x off the tape

x is the one tensor that is always present (we never drop the boundary). It is the seed for any recomputation we need to do.

EXAMPLE
x = tape['x'] → ndarray (2,4,8)

Line 33: if 'y' not in tape: REDO RMSNorm(x)

Checkpoint branch. Re-fires the RMSNorm, one fused kernel that is fast on the GPU. The result is bit-equal to the forward y because RMSNorm is deterministic given the same x and eps.

EXAMPLE
np.array_equal(y_redone, y_saved) → True in this NumPy script

Line 34: Recompute y

Single line, single op. In PyTorch, torch.utils.checkpoint does this for you: it re-runs the region with grad enabled during backward to rebuild the autograd graph for it.

EXAMPLE
y.shape = (2, 4, 8), same as forward

Line 35: else: read y off the naive tape

Naive branch. y was saved during forward, so we just look it up, with no recompute. The cost was paid in memory instead.

EXAMPLE
y = tape['y'], zero FLOPs

Line 36: Read saved y

Naive path. Hands back the same tensor we stored. The branch above and the branch here produce mathematically identical y values; only the bookkeeping differs.

EXAMPLE
y from tape == y recomputed (within float noise)

Line 37: if 'u' not in tape: REDO the up-projection

Checkpoint branch. Re-runs the (4d, d) matmul. This is the *expensive* recompute, but it is one GEMM, dwarfed by the cost of attention itself. By this section's earlier estimate the redo costs ≈ 1% extra FLOPs per training step.

EXAMPLE
u_redone = y @ W_up.T → shape (2, 4, 32)

Line 38: Recompute u

The single (B·S·d) → (B·S·4d) GEMM whose output is the widest activation in the block. Doing it twice, once forward and once backward, is the price of admission for not saving ≈ 0.24 GB per layer per micro-batch at full scale (S=4096, d=7168, BF16). In this toy script the recomputed y and u are not consumed further, because only grad_x is computed; a full backward would use them for the weight gradients.

EXAMPLE
u.shape = (2, 4, 32) = (B, S, 4d)

Line 39: else: read u off the naive tape

Naive branch. u was saved; just read it. Zero recompute FLOPs, but you paid the 4× memory penalty during forward.

EXAMPLE
u = tape['u'], zero FLOPs

Line 40: Read saved u

Naive path. u is recovered in one tape lookup. The 4d-wide tensor sits in HBM until backward fires.

EXAMPLE
u.shape = (2, 4, 32), recovered, not recomputed

Line 41: Read saved a

Both modes save a: it is the block output and the next layer's input. There is no shortcut for recomputing attention cheaply, so we never drop a.

EXAMPLE
a = tape['a'] → ndarray (2, 4, 8)

Line 42: Back through tanh and W_out

Tanh derivative is 1 − a², applied elementwise. Then we project the d-wide gradient back to 4d via grad_a @ W_out, producing grad_u. This line uses a, not u, which is why a must stay on the tape.

EXAMPLE
(grad_a * (1 - a**2)).shape (2,4,8) · W_out (8,32) → grad_u (2,4,32)

Line 43: Back through W_up

Linear-layer backward: grad_y = grad_u @ W_up. Same matmul cost in both modes; only the question of where u came from differs.

EXAMPLE
grad_u (2,4,32) · W_up (32,8) → grad_y (2,4,8)

Line 44: Comment: back through RMSNorm

RMSNorm backward is a single fused kernel in practice. Here we use a simplified Jacobian (just dividing by the recomputed rms) for clarity; the real PyTorch kernel handles the full chain rule of the rms term itself.

EXAMPLE
Conceptually: grad_x ≈ grad_y / rms (broadcast)

Line 46: Recompute the rms divisor

The rms tensor itself was never saved; it is cheap and we just rebuild it from x. axis=-1, keepdims=True so the shape stays (B, S, 1) for broadcasting.

EXAMPLE
rms.shape = (2, 4, 1)

Line 47: Compute grad_x by division

Final step: broadcast-divide grad_y by rms. In a real implementation there is one more term to handle the gradient of rms with respect to x; we omit it for clarity but the structure (and shape) is the same.

EXAMPLE
grad_y (2,4,8) ÷ rms (2,4,1) → grad_x (2,4,8)

Line 48: return grad_x

Backward complete. grad_x is what the previous layer wants. The naive and the checkpoint backward produce the same grad_x; this is the invariant the print on line 63 verifies.

EXAMPLE
grad_x.shape = (2, 4, 8)

Line 51: Helper: count floats on the tape

Sums .size across every tape entry to get a single peak-memory number. .size is the total number of elements (not bytes); multiply by 4 for FP32 or 2 for BF16 in real life.

EXAMPLE
tape_floats({'x':(2,4,8),'a':(2,4,8)}) = 64 + 64 = 128

Line 52: Sum over tape values

Generator expression: iterates the dict's values (not the keys) and reads .size on each ndarray. The per-tensor shapes are the same regardless of tape strategy; the *count* of entries is what changes.

EXAMPLE
naive: 64+64+256+64 = 448  |  ckpt: 64+64 = 128

Line 54: Run forward with NAIVE tape

save_intermediates=True. a1 is the block output; naive_tape contains x, y, u, a. We will compare its size against the checkpoint tape.

EXAMPLE
naive_tape keys = ['x', 'y', 'u', 'a']

Line 55: Run forward with CHECKPOINT tape

save_intermediates=False. a2 is the *same* block output (forward math is identical), but ckpt_tape contains only x and a. The tape is already 3.5× smaller: 2 units of B·S·d instead of 7.

EXAMPLE
ckpt_tape keys = ['x', 'a']

Line 57: Print naive peak

Naive peak = 7·(B·S·d) = 7·64 = 448 floats. Scale this to DeepSeek-V3 shapes (S=4096, d=7168) and one block of one micro-batch is already ≈ 0.41 GB of activations in BF16.

EXAMPLE
naive tape floats: 448

Line 58: Print checkpoint peak

Checkpoint peak = 2·(B·S·d) = 128 floats. The 7-to-2 ratio holds at every scale because the wide tensor u is exactly 4× a boundary tensor, so dropping it dominates the win.

EXAMPLE
ckpt  tape floats: 128

Line 60: Random grad_a from above

Simulates the gradient handed to this block by the next one. Shape matches a. Any random vector works; we are only checking that the two backwards produce the same grad_x.

EXAMPLE
grad_a.shape = (2, 4, 8)

Line 61: Backward with naive tape

Hits the `else` branches everywhere: y and u come straight off the tape. Zero recompute FLOPs.

EXAMPLE
gx_naive.shape = (2, 4, 8)

Line 62: Backward with checkpoint tape

Hits the `if` branches everywhere: y is recomputed from x, u from the recomputed y. One extra RMSNorm + one extra matmul of forward FLOPs.

EXAMPLE
gx_ckpt.shape = (2, 4, 8)

Line 63: Verify identical gradients

The whole point: checkpointing is not an approximation, it is the same math computed twice. In this toy script both backward calls run identical operations on identical inputs, so the difference is exactly 0.0. In a real framework, expect at most float-rounding noise (~1e-15 in FP64, ~1e-7 in FP32). If you see a real gap, the redo did not match the forward: that is a bug, not a feature.

EXAMPLE
max |Δ grad_x|: 0.0

import numpy as np

# Tiny stand-in for an MLA-style block:
#   1) y = RMSNorm(x)
#   2) u = W_up @ y         (the 4d-wide tensor)
#   3) a = attn(u) (here a tanh head, for compactness)
# We will run it once with the NAIVE tape (save everything) and once
# with SELECTIVE recompute (drop y and u, redo them on the backward).

rng = np.random.default_rng(0)
B, S, d = 2, 4, 8                  # batch, sequence, hidden
x = rng.standard_normal((B, S, d)) # input activation
W_up = rng.standard_normal((4 * d, d)) * 0.1
W_out = rng.standard_normal((d, 4 * d)) * 0.1

def rms_norm(t, eps=1e-6):
    rms = np.sqrt((t ** 2).mean(axis=-1, keepdims=True) + eps)
    return t / rms

def block_forward(x, save_intermediates):
    y = rms_norm(x)                # (B, S, d)   fat-by-d
    u = y @ W_up.T                 # (B, S, 4d)  fat-by-4d  <- the expensive one
    a = np.tanh(u @ W_out.T)       # (B, S, d)
    if save_intermediates:
        tape = {"x": x, "y": y, "u": u, "a": a}        # NAIVE: keeps all 4
    else:
        tape = {"x": x, "a": a}                         # CKPT: keeps only x and a
    return a, tape

def block_backward(grad_a, tape):
    # We need: ∂L/∂x. If y or u are missing from the tape, redo them.
    x = tape["x"]
    if "y" not in tape:
        y = rms_norm(x)            # REDO: one fused norm pass
    else:
        y = tape["y"]
    if "u" not in tape:
        u = y @ W_up.T             # REDO: one matmul, the same one as forward
    else:
        u = tape["u"]
    a = tape["a"]
    grad_u = (grad_a * (1 - a ** 2)) @ W_out  # back through tanh + W_out
    grad_y = grad_u @ W_up                    # back through W_up
    # back through RMSNorm: ∂y/∂x for the same eps; we approximate by the
    # local Jacobian of the norm, a single B*S*d-sized kernel in practice.
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + 1e-6)
    grad_x = grad_y / rms
    return grad_x

# --- Run both modes; report peak tape size in floats ---
def tape_floats(tape):
    return sum(v.size for v in tape.values())

a1, naive_tape = block_forward(x, save_intermediates=True)
a2, ckpt_tape  = block_forward(x, save_intermediates=False)

print("naive tape floats:", tape_floats(naive_tape))   # x+y+u+a = 1+1+4+1 = 7·(B·S·d)
print("ckpt  tape floats:", tape_floats(ckpt_tape))    # x+a     =         2·(B·S·d)

grad_a = rng.standard_normal(a1.shape)
gx_naive = block_backward(grad_a, naive_tape)
gx_ckpt  = block_backward(grad_a, ckpt_tape)
print("max |Δ grad_x|:", np.abs(gx_naive - gx_ckpt).max())  # ~0, same math

The crucial line is the dict-build inside block_forward: that single conditional is the entire policy. Naive autograd is the "always include everything" case; selective recompute is the case where the framework picks which boundary tensors to keep and trusts the backward to rebuild the rest. PyTorch's torch.utils.checkpoint wraps exactly this logic into a single function call: same idea, with the autograd graph rebuilt automatically.
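The toy backward above deliberately skips the term that accounts for how the RMS itself depends on x. For reference, here is a sketch of the complete RMSNorm backward, checked against central finite differences. The helper name `rms_norm_backward` is illustrative; the formula follows from differentiating y = x / rms directly and is not taken from any particular framework's kernel.

```python
import numpy as np


def rms_norm(t, eps=1e-6):
    rms = np.sqrt((t ** 2).mean(axis=-1, keepdims=True) + eps)
    return t / rms


def rms_norm_backward(grad_y, x, eps=1e-6):
    """Full Jacobian-vector product for y = x / rms(x)."""
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    y = x / rms
    # Second term: gradient flowing through rms, projected onto y.
    proj = (grad_y * y).mean(axis=-1, keepdims=True)
    return (grad_y - y * proj) / rms


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8))
grad_y = rng.standard_normal(x.shape)

analytic = rms_norm_backward(grad_y, x)

# Central finite differences of the scalar L = sum(grad_y * rms_norm(x)).
numeric = np.zeros_like(x)
h = 1e-6
for idx in np.ndindex(x.shape):
    x_plus = x.copy()
    x_plus[idx] += h
    x_minus = x.copy()
    x_minus[idx] -= h
    loss_plus = (grad_y * rms_norm(x_plus)).sum()
    loss_minus = (grad_y * rms_norm(x_minus)).sum()
    numeric[idx] = (loss_plus - loss_minus) / (2 * h)

max_err = np.abs(analytic - numeric).max()
print("max |analytic - numeric|:", max_err)
```

Recomputation is unaffected by which backward formula is used: either way, the only inputs are x and the incoming gradient, so nothing extra has to be kept on the tape.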

PyTorch: torch.utils.checkpoint and a Pinned-Host EMA

In production the two tricks become a one-line wrap and a ~20-line helper class. The wrap is the "dropping intermediates" story; the helper class is the "EMA on CPU" story. Trick 3 (shared embed/head) lives in the pipeline builder, which we touched on in the DualPipe section and will not duplicate here.

🐍ckpt_and_ema_pytorch.py
1📚 import torch

Brings in the PyTorch tensor library. Same role as numpy in the previous block — but tensors carry an autograd graph by default, and that is what torch.utils.checkpoint manipulates.

EXAMPLE
torch.tensor([1.0]).requires_grad_(True) → leaf in the autograd DAG
2📚 import torch.nn

Module subsystem. nn.Module subclasses register parameters and submodules so torch.utils.checkpoint can find them when it rebuilds the forward graph during backward.

EXAMPLE
nn.Linear(8, 32) → has .weight and .bias as Parameters
3📚 from torch.utils.checkpoint import checkpoint

The selective-recompute primitive. Wrapping a forward in checkpoint(fn, *args) means: run fn under no_grad on forward, save only the inputs, then re-run fn with grad enabled on backward to rebuild the local autograd graph.

EXAMPLE
out = checkpoint(self._inner, x)  →  one extra forward of self._inner during backward
5class MLABlock(nn.Module)

Skeleton of one transformer block. Three sub-modules: RMSNorm, an up-projection, and an output-projection. In DeepSeek-V3 the attention math sits between up and out; we simplify with tanh for clarity.

EXAMPLE
block = MLABlock(d=7168)  → 1 norm + 2 Linears, ≈ 100M params
6def __init__(self, d=7168)

d is the hidden width. The default 7168 is the real DeepSeek-V3 value. The up-projection inflates 7168 → 28672 (4·d) — that is the tensor we will drop from the activation tape.

EXAMPLE
MLABlock(d=7168) → up.weight.shape = (28672, 7168)
7super().__init__()

Registers this object as an nn.Module. Required before assigning any submodules — otherwise the framework cannot track Parameters or children.

EXAMPLE
Initializes the _parameters, _buffers, _modules OrderedDicts.
8self.norm = nn.RMSNorm(d, eps=1e-6)

Root-mean-square normalisation, no bias. Cheap to forward (one fused CUDA kernel), cheaper to recompute. Output shape matches input: (B, S, d).

EXAMPLE
self.norm(torch.zeros(1,4,7168)).shape == (1,4,7168)
9self.up = nn.Linear(d, 4d, bias=False)

The MLA up-projection. bias=False both matches DeepSeek-V3 and saves a few MB. The .weight matrix has shape (4d, d) — exactly the GEMM we want to recompute on backward rather than store the output of.

EXAMPLE
self.up.weight.shape = (28672, 7168);  output activation shape = (B, S, 28672)
10self.out = nn.Linear(4d, d, bias=False)

Down-projection back to d. Together with `up`, this pair forms the wide-then-narrow funnel that gives transformer blocks their FLOP / activation-memory imbalance.

EXAMPLE
self.out.weight.shape = (7168, 28672)
12def _inner(self, x)

The function we are going to wrap in checkpoint(). Anything that lives inside _inner gets recomputed on backward — that is the whole control surface we expose to selective recompute.

EXAMPLE
checkpoint(self._inner, x) → on backward, _inner runs a second time
13y = self.norm(x)

Forward RMSNorm. Shape stays (B, S, d). This tensor is *not* saved by autograd inside the checkpoint region — it is re-derived from x when backward fires.

EXAMPLE
y.shape = x.shape = (B, S, 7168)
14u = self.up(y) — the dropped tensor

The 4·d-wide intermediate. Without checkpoint, autograd would pin u in HBM for the entire backward window — that is the activation memory we are reclaiming. With checkpoint, u is computed, used, freed, and then recomputed during the backward pass.

EXAMPLE
u.shape = (B, S, 28672); BF16 bytes per layer at S=4096, B=1 ≈ 224 MB
15return self.out(torch.tanh(u))

Block output. Shape returns to (B, S, d) — the boundary tensor that the next layer (or the next checkpoint region) will consume. This is what autograd saves on the tape; everything between this return and the input x is recomputable.

EXAMPLE
out.shape = (B, S, 7168)
Line 17: def forward(self, x)

Public entry point. nn.Module's __call__ routes through forward. We don't put the recomputable region directly here; we route through _inner so checkpoint can wrap it cleanly.

EXAMPLE
y = block(x)  → calls block.forward(x), which calls checkpoint(self._inner, x)
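Line 18 (comment): use_reentrant=False

The non-reentrant checkpoint implementation, which the PyTorch documentation recommends. The legacy reentrant mode carries documented restrictions: it does not work with torch.autograd.grad, does not accept keyword arguments, and needs at least one input with requires_grad=True. use_reentrant=False lifts those restrictions, so it is the setting to ship.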

EXAMPLE
checkpoint(fn, x, use_reentrant=False)  → AMP-safe, grad-scaler-safe
Line 19: return checkpoint(self._inner, x, use_reentrant=False)

The full wrap. During forward, checkpoint runs _inner but keeps only its input x; the intermediates are not saved for backward. During backward, _inner runs again with grad enabled to rebuild the local graph and propagate ∂L/∂x.

EXAMPLE
Memory: only x stays alive across the block. FLOPs: ~1.33× of naive (one extra fwd).
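
The 1.33× figure follows from a common rule of thumb, assumed here rather than measured: a backward pass costs about twice a forward pass. The sketch below uses a hypothetical helper, `recompute_overhead`, and the 3% fraction is purely illustrative. It shows why recomputing only a small part of the forward pass costs far less than checkpointing a whole block.

```python
def recompute_overhead(recomputed_fraction: float, bwd_to_fwd: float = 2.0) -> float:
    """Total-FLOP ratio (with recompute / without), in units of one forward pass.

    recomputed_fraction: share of the forward pass that is run a second time.
    bwd_to_fwd: assumed cost of backward relative to forward (rule of thumb: 2).
    """
    fwd = 1.0
    baseline = fwd + bwd_to_fwd * fwd
    return (baseline + recomputed_fraction * fwd) / baseline


print(round(recompute_overhead(1.0), 3))   # 1.333  (whole block recomputed)
print(round(recompute_overhead(0.03), 3))  # 1.01   (illustrative: 3% of forward recomputed)
```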
Line 21: class CPUEma (pinned-host EMA)

Trick 2 in code. Keeps a second FP32 copy of every parameter, but in pinned (page-locked) host RAM rather than HBM. That removes about 4 bytes per parameter from the GPU budget at the cost of one device-to-host copy per step.

EXAMPLE
ema = CPUEma(model)  → 671B × 4 bytes ≈ 2.7 TB on CPU, 0 bytes on GPU
Line 22: Docstring

Reminds readers of two facts that constrain the design: FP32 (we don't want EMA noise from BF16 rounding), and async update (we are about to overlap with the next-step forward, not block it).

EXAMPLE
help(CPUEma) prints this docstring at runtime.
Line 24: def __init__(self, model, decay=0.999)

decay controls how slowly the EMA tracks. With decay = 0.999 the half-life is ≈ ln(2)/(1 − decay) = 0.693/0.001 ≈ 693 steps: slow enough to smooth out noisy gradients, fast enough to track the real loss trajectory.

EXAMPLE
CPUEma(model, decay=0.9999) → smoother but takes longer to lock on
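
The half-life can be checked directly. This is a small sketch with hypothetical helper names: the exact value solves decay^n = 1/2, and ln(2)/(1 − decay) is the usual first-order approximation for decay close to 1.

```python
import math


def ema_half_life(decay: float) -> float:
    """Steps until an old value's weight in the EMA falls to one half."""
    return math.log(0.5) / math.log(decay)


def ema_half_life_approx(decay: float) -> float:
    """First-order approximation ln(2) / (1 - decay), good when decay is near 1."""
    return math.log(2.0) / (1.0 - decay)


print(round(ema_half_life(0.999), 1))         # 692.8
print(round(ema_half_life_approx(0.999), 1))  # 693.1
```

Moving to decay=0.9999 stretches the half-life by roughly a factor of ten, which is the "smoother but slower to lock on" trade-off.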
Line 25: self.decay = decay

Stash the constant so .update() can use it. Float, never changes during training (modulo warmup schemes some recipes use).

EXAMPLE
self.decay = 0.999
Line 26 (comment): pin_memory=True

Page-locked host memory. A device-to-host copy can only run asynchronously when its destination is pinned; with pageable memory the copy is staged synchronously and the overlap is lost. The listing pins each shadow tensor by calling .pin_memory() on it.

EXAMPLE
torch.zeros(1024, pin_memory=True)  → page-locked, DMA-eligible
Line 27: self.shadow = {n: p.detach()...} (dict comprehension)

Snapshots every named parameter into a CPU FP32 pinned tensor. .detach() severs the autograd link (we don't want EMA gradients). Result is one shadow tensor per parameter, stored under the parameter's qualified name.

EXAMPLE
shadow['layers.0.up.weight'].shape == model.layers[0].up.weight.shape
Line 28: p.detach().to('cpu', torch.float32).pin_memory()

Three chained ops. detach() → a tensor cut off from the autograd graph. .to('cpu', torch.float32) → moves off GPU and upcasts BF16 → FP32 in one call. .pin_memory() → returns a copy of that CPU tensor in page-locked memory.

EXAMPLE
BF16 GPU tensor → FP32 CPU pinned tensor in one statement
Line 29: for n, p in model.named_parameters()

Iterates every Parameter in the model. Includes every sub-Module recursively. For DeepSeek-V3 that is ~10⁴ tensors covering 671B floats.

EXAMPLE
list(model.named_parameters())[0] → ('embed.weight', Parameter(...))
Line 32: @torch.no_grad()

Decorator on .update(). Disables autograd inside the function. EMA is a pure parameter update, never differentiated through — so we want zero overhead from the autograd machinery.

EXAMPLE
Equivalent to wrapping the whole body in `with torch.no_grad(): ...`
Line 33: def update(self, model)

Called once per optimizer step, after optimizer.step(). Streams each new parameter to host RAM, blends it with the shadow copy in place, and returns. Should overlap with the next-step forward almost entirely.

EXAMPLE
for step in train_loop: opt.step(); ema.update(model)
Line 34 (comment): side stream for overlap

The default CUDA stream is what the model runs on. By doing EMA on a *different* stream, the runtime can issue next-step kernels while the EMA copies are still in flight. This is the overlap that makes the trick free.

EXAMPLE
default-stream FWD ┐
 side-stream    EMA ┘  run concurrently
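Line 35: stream = torch.cuda.Stream()

Allocates a fresh CUDA stream object. Streams are FIFO queues of GPU work; work on different streams may run concurrently, and nothing orders it unless you add a dependency. As written, the side stream never waits on the default stream, so it can read a parameter before the optimizer's kernels have finished writing it. Calling stream.wait_stream(torch.cuda.current_stream()) first closes that gap. Creating the stream once in __init__ would also avoid allocating a new one every step.

EXAMPLE
torch.cuda.Stream() → a new stream on the current CUDA device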
Line 36: with torch.cuda.stream(stream):

Context manager that sets `stream` as the *current* stream for kernels launched inside it. Any GPU op issued in the block is enqueued on `stream`, not on the default one.

EXAMPLE
Anything inside the `with` runs on the side stream.
Line 37: for n, p in model.named_parameters()

Same iteration as in __init__, but at update time we already have the shadow dict — we just need to walk parameters in the same order to pair them with their CPU partners.

EXAMPLE
n='embed.weight', p=Parameter on GPU, cpu_p=self.shadow[n]
Line 38: cpu_p = self.shadow[n]

Look up the CPU FP32 pinned partner for this parameter. Same shape, same name. This is the tensor we are about to mutate in place.

EXAMPLE
cpu_p.shape == p.shape; cpu_p.dtype == torch.float32; cpu_p.is_pinned() == True
Line 39: gpu_p_fp32 = p.detach().float() (on-GPU upcast)

Convert the BF16 working-copy parameter to FP32 *on the GPU*, before the host copy. The upcast is a cheap elementwise kernel there, and the FP32 temporary stays resident on the GPU only for the duration of the copy. The trade-off is that the transfer carries 4 bytes per parameter instead of 2.

EXAMPLE
BF16 (2 B/param) → FP32 (4 B/param) on GPU, immediately copied off
Line 40 (comment): in-place CPU EMA blend

Spells out the formula: cpu_p ← decay·cpu_p + (1 − decay)·gpu_p_fp32. In-place avoids any allocator activity on the CPU, which keeps the host fast-path lean.

EXAMPLE
Mathematically:  m_t = β·m_{t-1} + (1−β)·θ_t   with β = 0.999
Line 41: cpu_p.mul_(self.decay)

In-place multiply. After this line cpu_p has been scaled by decay; we are about to add the (1−decay)·new term. Doing the scale first avoids ever materialising a temporary tensor of full parameter size.

EXAMPLE
cpu_p ← 0.999 · cpu_p
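Line 42: cpu_p.add_(gpu_p_fp32.to('cpu', non_blocking=True), alpha=...)

The host copy plus the in-place add. With non_blocking=True the copy call returns as soon as the transfer is queued, so the GPU is free to run the next kernel. The CPU-side add_ does not wait for that transfer: as written, it can read the host buffer before the data has landed. Synchronize the side stream (or wait on a recorded CUDA event) before the add, or drop non_blocking=True and accept a blocking copy.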

EXAMPLE
cpu_p ← cpu_p + (1 − 0.999) · gpu_p_fp32
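Line 43: gpu_p_fp32.to('cpu', non_blocking=True)

Issues the DMA into a freshly allocated host tensor, not into cpu_p itself. The call returns without waiting for the transfer, so neither the CPU nor the GPU stalls on it; the flip side is that the result is only safe to read once the side stream has been synchronized.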

EXAMPLE
Throughput-bound by PCIe / NVLink-C2C, not by Python overhead.
Line 44: alpha=1.0 - self.decay

Scalar multiplier for the right-hand operand in add_(). alpha=0.001 here. PyTorch fuses the scale + add into one CPU kernel call — no temporary tensor for (1−decay)·gpu_p_fp32.

EXAMPLE
alpha = 1.0 - 0.999 = 0.001
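
The two-step update (scale, then add with a scalar weight) can be checked without PyTorch. The sketch below is an illustration on plain Python lists, not the tensor code itself; `ema_blend_` is a hypothetical helper that mirrors the `mul_` and `add_(..., alpha=...)` calls.

```python
def ema_blend_(shadow: list[float], params: list[float], decay: float) -> None:
    """In-place EMA: shadow <- decay * shadow + (1 - decay) * params."""
    for i, theta in enumerate(params):
        shadow[i] *= decay                  # mirrors cpu_p.mul_(decay)
        shadow[i] += (1.0 - decay) * theta  # mirrors cpu_p.add_(theta, alpha=1 - decay)


shadow = [1.0, 2.0]
ema_blend_(shadow, params=[0.0, 4.0], decay=0.5)
print(shadow)  # [0.5, 3.0]
```

Because both steps mutate `shadow` in place, no second buffer the size of the shadow copy is ever created for the blend.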
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class MLABlock(nn.Module):
    def __init__(self, d=7168):
        super().__init__()
        self.norm = nn.RMSNorm(d, eps=1e-6)
        self.up = nn.Linear(d, 4 * d, bias=False)        # the FAT projection
        self.out = nn.Linear(4 * d, d, bias=False)

    def _inner(self, x):
        y = self.norm(x)                                 # (B, S, d)
        u = self.up(y)                                   # (B, S, 4d) ← dropped
        return self.out(torch.tanh(u))                   # (B, S, d)

    def forward(self, x):
        # use_reentrant=False is the modern API; preserves AMP + grad scaling.
        return checkpoint(self._inner, x, use_reentrant=False)

class CPUEma:
    """FP32 exponential moving average kept in pinned host RAM.
    Updated asynchronously after each optimizer step."""
    def __init__(self, model, decay=0.999):
        self.decay = decay
        # pin_memory=True → page-locked, faster GPU↔CPU copies
        self.shadow = {
            n: p.detach().to("cpu", torch.float32).pin_memory()
            for n, p in model.named_parameters()
        }

    @torch.no_grad()
    def update(self, model):
        # Stream each parameter to a side stream so it overlaps next-step fwd.
        stream = torch.cuda.Stream()                    # NOTE: not ordered after the default stream as written
        with torch.cuda.stream(stream):
            for n, p in model.named_parameters():
                cpu_p = self.shadow[n]
                gpu_p_fp32 = p.detach().float()         # on-GPU upcast (BF16→FP32)
                # cpu_p ← decay·cpu_p + (1-decay)·gpu_p_fp32, in-place on CPU
                cpu_p.mul_(self.decay).add_(
                    gpu_p_fp32.to("cpu", non_blocking=True),  # NOTE: async copy; sync before add_ reads it
                    alpha=1.0 - self.decay,
                )
Production note. Pass use_reentrant explicitly and prefer use_reentrant=False: the PyTorch documentation recommends the non-reentrant implementation, and the legacy reentrant mode has documented restrictions (no torch.autograd.grad, no keyword arguments, and at least one input must require grad). If you inherit a code base that still uses the reentrant API, switching is a one-flag change and it's worth doing before you debug anything else.

What Changes at 671B Parameters and 14.8T Tokens

Everything in this section scales, but the way each trick scales is different and worth tracking separately.

| Trick | Scaling axis | How it scales | Limit |
|---|---|---|---|
| Recompute RMSNorm + MLA-up | Sequence length × batch × layers per stage | Linear: doubling sequence length doubles the saving in absolute GB. | Bounded below by the size of the boundary tensors (x and a); you cannot drop them. |
| EMA on CPU | Total parameters | Linear in N: every extra billion params adds 4 GB to the savings. | Bounded by host RAM and PCIe/NVLink-C2C bandwidth; at extreme N you run out of CPU memory before HBM. |
| Shared embed/head | Vocabulary × hidden | Constant in N (once you have one of each) but scales with V·d. Bigger tokenizers, bigger savings. | Bounded by the architectural choice: some recipes deliberately untie the head for quality reasons. |

The numbers compound. On a 64-node, 512-GPU cluster, the recomputation saving alone is $\approx 512 \cdot 3.2\text{ GB} \approx 1.6\text{ TB}$ of HBM reclaimed cluster-wide. The EMA-to-CPU saving moves about $2.7\text{ TB}$ of FP32 state off the accelerators completely. The shared head reclaims another $1.8\text{ GB}$ on the two stages that used to duplicate it.
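
The two cluster-wide totals can be reproduced as a quick back-of-the-envelope check. The inputs are the figures quoted in the text (512 GPUs, 3.2 GB per GPU, 671B parameters, 4 bytes per FP32 value); the variable names are illustrative only.

```python
GB = 10**9
TB = 10**12

num_gpus = 512
recompute_saving_per_gpu_gb = 3.2   # per-GPU figure quoted in the text
total_params = 671 * 10**9
fp32_bytes = 4

recompute_saving_tb = num_gpus * recompute_saving_per_gpu_gb * GB / TB
ema_offloaded_tb = total_params * fp32_bytes / TB

print(round(recompute_saving_tb, 2))  # 1.64
print(round(ema_offloaded_tb, 2))     # 2.68
```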

The most important number is the one that doesn't appear in the table: zero new collectives. None of these tricks introduces an all_reduce, an all_gather, or a barrier. The activation recomputation is a per-GPU forward pass. The EMA copy is a per-GPU host transfer. The head sharing is a static rebinding. The communication graph that DualPipe scheduled in the previous section stays untouched.

Engineering Reality: When the Tricks Bite Back

Three places where the trio is not free:

Recomputation breaks the "just look at the activation" debug pattern

Once an intermediate is recomputed instead of saved, any tooling that wanted to peek at it (gradient inspectors, activation loggers, debugger hooks) sees something subtler than it expects. Hooks fire twice for tensors inside the checkpoint region: once on the discarded forward, once on the recomputed forward. The fix is either to register hooks outside the checkpoint, or to keep the tensors you need to inspect: offload them to host memory instead of dropping them (torch.autograd.graph.save_on_cpu), or place the checkpoint boundaries explicitly (checkpoint_sequential) so that they sit on a saved boundary.

The async EMA can hide bugs until eval time

Because the EMA update is fire-and-forget on a side stream, nothing checks the value of the shadow tensor until you load it for evaluation — potentially hundreds of steps later. A common failure mode is forgetting to torch.cuda.synchronize() before reading the EMA back for a mid-training eval, which can race with the last in-flight update. Always synchronize the side stream explicitly before serializing the EMA.

Tied weights mean tied learning rates

If the embedding and the output head share parameters, they share gradients, and they share whatever optimizer state goes with those gradients. Recipes that want a different learning rate on the LM head than on the embedding cannot tie them — the saving disappears, and you pay the 1.8 GB. DeepSeek-V3 deliberately uses one optimizer group for the shared matrix, which is why the saving holds.
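
The 1.8 GB figure is the size of one vocabulary-by-hidden matrix in BF16. The sketch below assumes a vocabulary of 129,280 tokens, a value not stated in this section and used here only for illustration; the hidden size 7168 comes from the code above, and `tied_matrix_bytes` is a hypothetical helper.

```python
def tied_matrix_bytes(vocab_size: int, hidden: int, bytes_per_elem: int = 2) -> int:
    """Bytes of the one V x d matrix that tying embedding and head avoids duplicating."""
    return vocab_size * hidden * bytes_per_elem


vocab_size = 129_280  # assumption: not given in this section
hidden = 7168         # d from the MLABlock example

saved = tied_matrix_bytes(vocab_size, hidden)
print(saved)                    # 1853358080
print(round(saved / 10**9, 2))  # 1.85
```

Untying the head to give it its own learning rate brings that matrix back, along with whatever optimizer state accompanies it.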

The takeaway. Memory savings that do not add communication are the highest-leverage decision a large-model team can make. Tensor parallelism is the obvious answer to a tight HBM budget, and it is also a tax on every other layer of the parallelism stack. The DeepSeek-V3 trio — selective recomputation, CPU EMA, shared head — reclaims enough HBM to dodge that tax entirely. The result is a cluster where the only inter-GPU traffic is the one collective per layer that DualPipe was always going to schedule anyway. The next section turns to the last piece of the stack that has to be right under those constraints: how to checkpoint and recover the whole training run when something inevitably crashes.