Overview
A modern transformer's training run can cost seven figures, and a deployed model serves billions of tokens per day. Yet most of that cost is decided not by the algorithm but by how the algorithm uses hardware. The same matmul can run at 3% or 90% of peak FLOPs depending on cache locality. The same attention can fit in 4 GB or fail with out-of-memory depending on whether the score matrix is materialised.
This section gives you a single mental model — the roofline — and five concrete levers you can pull to move a kernel along it. Then we trace each lever back to the modern systems you have already met (or are about to meet) in this book: Flash Attention, multi-head attention with its MQA / GQA variants, positional encodings (sinusoidal, RoPE, ALiBi), KV-cache optimisations like paged and quantised cache, and the scaling laws that govern transformer training.
The examples lean heavily on transformers because they are the richest current case study, not because the principles are transformer-only. The same performance logic applies to CNNs, RNNs, MLPs, and diffusion models: increase useful work per byte moved, keep the working set in the fastest memory tier you can, and only recompute values when that is cheaper than storing or reloading them.
Why this matters. A correct neural network that runs 100× too slowly is a research artifact. A correct neural network that runs at 80% of peak hardware throughput is a product. The gap is almost always memory traffic — not arithmetic.
The Real Bottleneck Is Not Compute
New practitioners assume neural networks are slow because they do too many multiplications. They do not. A single A100 GPU has a peak of roughly $3.12 \times 10^{14}$ FP16 floating-point operations per second (312 TFLOP/s on Tensor Cores). That is enough raw arithmetic to multiply two $10{,}000 \times 10{,}000$ matrices roughly 150 times per second. The actual measured throughput on a typical PyTorch model is often 5–15% of that.
Where does the missing 85% or more go? It is spent waiting for data to arrive at the arithmetic units. The GPU's compute units sit idle while data crawls from off-chip HBM (~2 TB/s) into on-chip SRAM (~20 TB/s) into registers (~30 TB/s). Each rung closer to the ALU costs roughly 10× more per byte and offers roughly 10× less capacity. That gradient is the single most important fact about modern hardware.
Two simple equations capture the regime. The arithmetic intensity of a kernel is $I = \text{FLOPs} / \text{Bytes moved}$. The hardware exposes a peak compute rate $P_{\text{peak}}$ (FLOPs/s) and a peak memory bandwidth $B_{\text{mem}}$ (Bytes/s). The achievable throughput of the kernel is $\min(P_{\text{peak}},\ I \cdot B_{\text{mem}})$. The two regimes meet at the ridge point $I^{*} = P_{\text{peak}} / B_{\text{mem}}$. Below it you are memory-bound and the only way to go faster is to do more compute per byte loaded. Above it you are compute-bound and the only way is to do fewer FLOPs.
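To make the two regimes concrete, here is a minimal sketch of the roofline calculation. The function names and the hardware figures (312 TFLOP/s and 2 TB/s, roughly A100-class) are illustrative assumptions, not measurements.

```python
def attainable_flops(intensity, peak_flops, mem_bandwidth):
    """Roofline: throughput is capped by compute or by intensity * bandwidth."""
    return min(peak_flops, intensity * mem_bandwidth)


def ridge_point(peak_flops, mem_bandwidth):
    """Arithmetic intensity (FLOP/B) at which the two ceilings meet."""
    return peak_flops / mem_bandwidth


PEAK = 312e12  # FLOP/s, assumed peak compute rate
BW = 2e12      # bytes/s, assumed HBM bandwidth

for name, intensity in [("softmax-like kernel", 1.0), ("large matmul", 500.0)]:
    achieved = attainable_flops(intensity, PEAK, BW)
    regime = "memory-bound" if intensity < ridge_point(PEAK, BW) else "compute-bound"
    print(f"{name}: {achieved / 1e12:.0f} TFLOP/s attainable ({regime})")
```

With these assumed numbers a kernel at 1 FLOP/B can never exceed 2 TFLOP/s, no matter how many FLOPs the chip advertises.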
The Roofline Mental Model
The roofline plot puts these two equations on a log-log graph: arithmetic intensity on the x-axis, throughput on the y-axis. A diagonal line, throughput $= I \cdot B_{\text{mem}}$ (intensity times peak memory bandwidth), represents the memory ceiling. A horizontal line at the peak compute rate $P_{\text{peak}}$ represents the compute ceiling. Every kernel lives at a point under both ceilings. The roof you sit closest to is the bottleneck you must attack.
Consider where familiar operators land on the plot. Naive softmax (about 1 FLOP/B) lives deep in memory-bound territory: it does not matter how fast your GPU is; you are limited by the HBM bandwidth feeding it. Flash Attention pushes attention to a much higher arithmetic intensity by tiling Q, K, V into SRAM, moving it close to the compute ceiling, where the GPU's tensor cores can finally stretch.
The GPU Memory Hierarchy
The roofline assumes a single bandwidth number, but real GPUs have a ladder of memories spanning four orders of magnitude in bandwidth. Each rung is a physically distinct memory (off-chip DRAM stacks, on-chip SRAM, register files), and crossing between them is what makes a kernel slow.
The gap between tiers is what matters. By the figures above, holding your data in SRAM rather than HBM is worth roughly $10\times$ in bandwidth, which is the same order as the speedup Flash Attention reports over standard attention on long sequences.
Three rules that follow from the hierarchy
- Reuse > recompute > re-fetch. If a value is in registers, use it as many times as possible before letting it spill. If it is in SRAM, the same applies one rung down. Crossing to HBM is a 100× cost.
- Fuse adjacent ops. Two element-wise ops back-to-back read from HBM, write to HBM, read again, write again. Fused into a single kernel they read once, write once — halving traffic.
- Tile to fit the smallest fast memory. Flash Attention chooses block sizes precisely so that the Q, K, V tiles and the output accumulator all fit inside one SM's 100–200 KB of SRAM simultaneously. Pick the tile size to fill, but not exceed, the fast tier.
Lever 1 — Vectorization
Before any GPU trick, eliminate the worst possible source of slowness: Python loops over scalar math. The Python interpreter adds roughly 700 ns of overhead per primitive operation; a CPU's ALU can do the underlying multiply in about 1 ns. Doing matrix multiplication element-by-element in pure Python therefore wastes 99.85% of every cycle. Vectorisation hands the entire problem to a compiled BLAS routine that runs at near-peak SIMD throughput on every core.
This is the same idea that distinguishes many small CUDA kernel launches from a single large call: the unit of work the system sees should be as large as possible so the runtime can schedule, batch, and pipeline it.
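The gap between interpreted loops and a vectorised call is easy to observe directly. The sketch below multiplies two small matrices both ways; the helper name `matmul_loops` and the 64×64 size are arbitrary choices for illustration, and the timings it prints will vary by machine.

```python
import time

import numpy as np


def matmul_loops(a, b):
    """Triple-loop matmul over Python lists: one interpreted step per multiply-add."""
    n, k, m = len(a), len(b), len(b[0])
    out = [[0.0] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            acc = 0.0
            for p in range(k):
                acc += a[i][p] * b[p][j]
            out[i][j] = acc
    return out


rng = np.random.default_rng(0)
A = rng.standard_normal((64, 64))
B = rng.standard_normal((64, 64))

t0 = time.perf_counter()
slow = matmul_loops(A.tolist(), B.tolist())
t1 = time.perf_counter()
fast = A @ B  # one call into a compiled BLAS routine
t2 = time.perf_counter()

print(f"python loops: {(t1 - t0) * 1e3:.1f} ms, vectorised: {(t2 - t1) * 1e3:.3f} ms")
```

Both paths compute the same result; only the size of the unit of work handed to compiled code differs.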
When you call torch.matmul(A, B) on the GPU, PyTorch dispatches to cuBLAS. On Tensor Cores it can reach ~95% of theoretical peak FP16 FLOPs. The PyTorch eager-mode overhead per op (~50 µs) is also why torch.compile exists: it fuses many small ops and can capture them into one CUDA graph, shrinking dispatch cost by up to 100×.
Lever 2 — Operator Fusion (Online Softmax)
Vectorisation removes Python overhead, but a kernel can still be slow if it crosses HBM more times than necessary. The textbook softmax is a case in point: it makes three sequential passes over its input, once for the max $m = \max_i x_i$, once for the exponentials $e^{x_i - m}$, and once for the sum $\sum_i e^{x_i - m}$. Each pass crosses HBM. For a long sequence those three passes dominate the runtime; the actual exp() arithmetic is almost free.
The online softmax trick collapses all three into a single pass by maintaining running statistics. When a new max $m_{\text{new}}$ appears, the running denominator is rescaled by $e^{m_{\text{old}} - m_{\text{new}}}$ to undo the now-stale normalisation, then the new term is added on with the new normalisation. This is mathematically exact (no approximation) and the key trick that makes Flash Attention possible: it lets attention scores be computed in tiles without ever materialising the full $N \times N$ matrix.
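The rescaling step can be checked numerically. The following is a minimal NumPy sketch, not a production kernel: `softmax_online` gathers the running max and denominator in one sweep, and the final division is left as a separate step for clarity (in a fused attention kernel that division is folded into the output accumulation). Function names are illustrative.

```python
import numpy as np


def softmax_three_pass(x):
    """Textbook version: max, then exponentials, then sum."""
    m = np.max(x)
    e = np.exp(x - m)
    return e / e.sum()


def softmax_online(x):
    """Single sweep for the statistics, rescaling whenever the max changes."""
    m = -np.inf  # running max
    l = 0.0      # running denominator, always expressed relative to m
    for xi in x:
        m_new = max(m, xi)
        # Re-express the old denominator relative to m_new, then add the new term.
        l = l * np.exp(m - m_new) + np.exp(xi - m_new)
        m = m_new
    return np.exp(x - m) / l


x = np.array([1.0, 3.0, -2.0, 1000.0, 999.0])
print(np.allclose(softmax_online(x), softmax_three_pass(x)))
```

The large values in `x` are deliberate: without the max subtraction the exponentials would overflow, and without the rescale the two results would disagree.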
PyTorch is already doing this for you
When you call F.softmax(x, dim=-1) in PyTorch, the framework dispatches to a fused CUDA kernel that performs max, subtract, exp, sum, and divide in a single HBM round-trip. With torch.compile or F.scaled_dot_product_attention(...), the entire attention block — Q@Kᵀ, softmax, attention @ V — is fused into one Flash-Attention kernel. The naive 3-pass version we wrote in Python exists only as a specification; in production it is never executed.
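As a small check that the fused entry point computes the same function as the specification, the sketch below compares a hand-written attention against `F.scaled_dot_product_attention` on CPU. The tensor sizes are arbitrary, and which backend PyTorch selects (Flash, memory-efficient, or plain math) depends on device, dtype, and version, so treat this as a correctness illustration rather than a benchmark.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, D = 2, 4, 16, 8  # batch, heads, sequence length, head dim (illustrative)
q, k, v = (torch.randn(B, H, T, D) for _ in range(3))

# Specification: materialises the full (T, T) score matrix per head.
scores = q @ k.transpose(-2, -1) / math.sqrt(D)
naive = torch.softmax(scores, dim=-1) @ v

# Fused entry point: same math, backend chosen by PyTorch.
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(naive, fused, atol=1e-5))
```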
Lever 3 — Memory Layout & Contiguity
Two tensors of identical shape can run at radically different speeds depending on how their elements are laid out in memory. A tensor is contiguous when its elements are stored in the order you iterate them; non-contiguous tensors force the cache to fetch one element at a time, defeating the SIMD vector loads.
In PyTorch, operations like .transpose(), .permute(), and slicing with stride do not actually move data — they return a view with rearranged strides. The next op that needs contiguous memory will then either dispatch to a slow non-contiguous kernel or trigger an invisible .contiguous() copy. The fix is to call .contiguous() explicitly after a transpose if the downstream op is bandwidth-sensitive.
| Operation | Returns view (free)? | Triggers copy? |
|---|---|---|
| x.view(B, T, D) | yes | no — must be contiguous-compatible |
| x.reshape(B, T, D) | if possible | yes if not contiguous |
| x.transpose(-2, -1) | yes — strides flipped | no, but breaks contiguity |
| x.permute(0, 2, 1, 3) | yes | no, but next op may pay |
| x.contiguous() | no | yes — explicit copy |
| x[:, :, :32] | yes, basic slicing always returns a view | no, but the result is usually non-contiguous |
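The view-versus-copy behaviour in the table can be inspected through strides and storage pointers. This is a small sketch on a toy tensor; the shapes are arbitrary.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)
print(x.is_contiguous(), x.stride())   # row-major layout

t = x.transpose(-2, -1)                # a view: same storage, swapped strides
print(t.is_contiguous(), t.stride())
print(t.data_ptr() == x.data_ptr())    # no data was moved

c = t.contiguous()                     # explicit copy into row-major order
print(c.is_contiguous(), c.data_ptr() == x.data_ptr())

s = x[:, :, :2]                        # slicing is also a view with gaps between rows
print(s.is_contiguous(), s.data_ptr() == x.data_ptr())

# view() requires compatible strides; reshape() falls back to a copy instead.
try:
    t.view(2, 12)
except RuntimeError as err:
    print("view failed:", type(err).__name__)
r = t.reshape(2, 12)
print(r.data_ptr() == x.data_ptr())
```

The silent copy inside `reshape` is the "invisible `.contiguous()`" cost: it is correct, but it is extra memory traffic that does not show up in the source code.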
The deeper lesson: memory layout is part of the algorithm. Flash Attention chooses the (block, head, dim) memory order deliberately so that loads from HBM into SRAM are 128-byte coalesced bursts. Multi-head attention typically stores Q, K, V interleaved (the "packed QKV" layout) so a single cache line load brings all three projections.
Lever 4 — Mixed Precision (FP16/BF16/FP8)
FP32 stores 32 bits per number. FP16/BF16 store 16. FP8 stores 8. Halving precision halves memory traffic and doubles the FLOPs the Tensor Cores deliver — for free, if the loss in numerical range or precision does not break training.
| Format | Bits | Range | Mantissa | When to use |
|---|---|---|---|---|
| FP32 | 32 | ±10³⁸ | 23 bits | Master weights, optimizer state |
| TF32 | 19 (in 32-bit slot) | ±10³⁸ | 10 bits | Ampere+ matmul speedup with no model changes; in PyTorch, on by default for cuDNN convolutions and opt-in for matmuls |
| FP16 | 16 | ±10⁴ | 10 bits | Activations; needs loss scaling for grads |
| BF16 | 16 | ±10³⁸ | 7 bits | Activations + grads; no loss scaling needed |
| FP8 (E4M3) | 8 | ±10² | 3 bits | Inference activations, H100+ training |
| INT8 | 8 | ±127 | — | Post-training inference quantisation |
| INT4 | 4 | ±7 | — | GPTQ / AWQ weight-only quant for serving |
BF16 has been the workhorse for transformer training since Google's TPUs popularised it: it keeps FP32's exponent range (so gradients rarely underflow) at the cost of mantissa precision, which weights absorb gracefully. FP8 is the new frontier — H100 and TPUv5p ship with Tensor Cores that double FP16 throughput again. nn.Linear weights of GPT-class models are now routinely served as INT4 with per-group scales (GPTQ, AWQ); the model fits in a third the memory with under 1% perplexity loss.
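The range-versus-mantissa trade-off can be read straight from `torch.finfo`, and autocast shows the "FP32 parameters, BF16 compute" split in action. This sketch runs on CPU so it works anywhere (an assumption for portability; on a GPU the device type would be "cuda"), and the tiny `Linear` layer is purely illustrative.

```python
import torch

for dt in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dt)
    print(f"{str(dt):15s} max={info.max:.3e} eps={info.eps:.3e}")

# FP16 overflows where BF16 does not; BF16 pays with a coarser mantissa.
big = torch.tensor(70000.0)
print(big.to(torch.float16), big.to(torch.bfloat16))

model = torch.nn.Linear(16, 4)  # parameters are created and kept in FP32
x = torch.randn(8, 16)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)                # the matmul runs in BF16 inside the context
print(model.weight.dtype, y.dtype)
```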
Practical recipe: wrap the forward pass in torch.autocast("cuda", dtype=torch.bfloat16). Update FP32 master weights from BF16 gradients. Under torch.compile this composes with the Flash Attention kernel for end-to-end BF16 speed.
Lever 5 — Gradient Checkpointing
Backprop needs the activations from the forward pass. For a deep network this can dominate memory: a 70B-parameter model at sequence length 4096 stores ~80 GB of activations alone. The classical fix is gradient checkpointing: keep only a sparse subset of activations in memory, and recompute the missing ones during backward.
For a network with $L$ layers, naive backprop keeps all activations, which is $O(L)$ memory. Checkpointing every $\sqrt{L}$-th layer reduces memory to $O(\sqrt{L})$ while increasing compute by roughly 33% (one extra forward pass distributed over all checkpoint segments). For transformers, torch.utils.checkpoint applied to each transformer block is the standard pattern.
Flash Attention takes this even further: it discards the entire attention matrix after the forward pass and recomputes it from QKV during backward, since Flash Attention is fast enough that recomputation is cheaper than the HBM cost of storing an $N \times N$ matrix.
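The per-block checkpointing pattern can be sketched as follows. The four `Linear` + `GELU` blocks stand in for transformer blocks (an assumption for brevity); the point is only that gradients match with and without recomputation.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
blocks = torch.nn.ModuleList(
    [torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.GELU()) for _ in range(4)]
)


def forward(x, use_checkpoint):
    for block in blocks:
        if use_checkpoint:
            # Activations inside `block` are not stored; they are recomputed in backward.
            x = checkpoint(block, x, use_reentrant=False)
        else:
            x = block(x)
    return x.sum()


x = torch.randn(8, 32, requires_grad=True)

forward(x, use_checkpoint=False).backward()
grad_plain = x.grad.clone()

x.grad = None
forward(x, use_checkpoint=True).backward()
grad_ckpt = x.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt))
```

Checkpointing changes what is kept in memory, not what is computed, so the two gradients agree.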
Inference: The KV Cache
Training and inference have opposite bottlenecks. Training is compute-bound (one big batched forward+backward over a fixed sequence). Inference for a chat model is memory-bound: the model generates one token at a time, doing tiny matmuls but reading the full prefix of K and V from memory at every step.
Without caching, a naive autoregressive decoder reprojects the entire prefix at each step, so past tokens are repeatedly projected into K and V. In a toy decoder, that redundant projection work grows quadratically with total decode length because step $t$ recomputes a prefix of length $t$. With a KV cache we project past K and V exactly once during prefill, then in each generation step we only project the single new token and append its $k_t, v_t$ to the cache. The attention lookup is still linear in current context length, but the repeated prefix projection disappears, which is the difference between a pedagogical decoder and a usable one.
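A single-head toy decoder makes the difference visible. This NumPy sketch uses random projection matrices and random "token" vectors (all hypothetical); it checks that the cached path returns exactly what the re-projecting path returns at every step.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8
Wq, Wk, Wv = (rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(3))


def attend(q, K, V):
    """Softmax attention of one query vector over all cached keys and values."""
    s = K @ q / np.sqrt(d)
    w = np.exp(s - s.max())
    return (w / w.sum()) @ V


def step_no_cache(xs):
    """Re-project the whole prefix, then attend from the last token."""
    X = np.stack(xs)
    return attend(xs[-1] @ Wq, X @ Wk, X @ Wv)


class KVCache:
    def __init__(self):
        self.K, self.V = [], []

    def step(self, x):
        """Project only the new token and append its k, v to the cache."""
        self.K.append(x @ Wk)
        self.V.append(x @ Wv)
        return attend(x @ Wq, np.stack(self.K), np.stack(self.V))


tokens = [rng.standard_normal(d) for _ in range(6)]
cache = KVCache()
for t in range(len(tokens)):
    a = step_no_cache(tokens[: t + 1])  # t + 1 projections at this step
    b = cache.step(tokens[t])           # 1 projection at this step
    assert np.allclose(a, b)
print("cached and uncached outputs match for", len(tokens), "steps")
```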
The cache itself, however, becomes the new memory bottleneck. For a Llama-2 7B-class model (32 layers, 32 heads, head dim 128, BF16) the cache costs $2 \times 32 \times 32 \times 128 \times 2 = 524{,}288$ bytes per token (the leading 2 counts K and V, the trailing 2 is bytes per BF16 value), about 0.5 MB. At sequence length 4096 this is about 2 GB per request. Serving many concurrent users means juggling many of these caches, and that memory pressure is the headline cost of LLM inference today.
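The arithmetic is simple enough to keep as a helper. The function name is illustrative; the configuration is the 32-layer, 32-head, head-dim-128, BF16 model described above, with an 8-KV-head variant added for comparison.

```python
def kv_cache_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_elem):
    """Bytes of K and V stored per token across all layers."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem  # 2 = K and V


per_token = kv_cache_bytes_per_token(n_layers=32, n_kv_heads=32, head_dim=128, bytes_per_elem=2)
per_request = per_token * 4096

print(f"{per_token} bytes per token")
print(f"{per_request / 2**30:.1f} GiB per 4096-token request")

# The same model with 8 K/V heads instead of 32 (a GQA-style configuration).
gqa_per_token = kv_cache_bytes_per_token(32, 8, 128, 2)
print(f"{per_token // gqa_per_token}x smaller cache with 8 K/V heads")
```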
Connections to Modern Systems
Every famous performance technique in modern transformer infrastructure is one of the five levers above, applied to the right rung of the memory hierarchy. Let's walk the connections.
Flash Attention
Standard attention computes $S = QK^\top$, materialises the full $N \times N$ matrix S in HBM, then reads it back to compute the softmax $P = \mathrm{softmax}(S)$ and the output $O = PV$. Each entry of S is touched twice (once written, once read) for a total HBM traffic of $O(N^2)$. At long sequence lengths this is the dominant cost. Arithmetic intensity is roughly 1 FLOP/B, deep in memory-bound territory.
Flash Attention (Dao et al., 2022) tiles Q, K, V into SRAM-sized blocks and uses the online softmax above to compute the output one tile at a time, accumulating both the running normaliser and the running output incrementally. The full S matrix is never written to HBM. HBM traffic drops from $O(N^2)$ to $O(N^2 d^2 / M)$, where $d$ is the head dimension and $M$ is the SRAM size, typically a 5–10× wall-clock speedup on long sequences and the reason 100K-token contexts are now feasible. Lever 2 (fusion) plus Lever 3 (tiled layout) at the SRAM rung.
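The tiling idea can be reproduced in a few lines of NumPy. This is a simplified sketch of the forward pass only: it streams over K/V tiles and keeps running statistics per query row, but it does not tile Q, model SRAM, or implement the backward pass, so it illustrates the algorithm rather than the real kernel. The block size of 4 is arbitrary.

```python
import numpy as np


def attention_reference(Q, K, V):
    """Standard attention: builds the full (N, N) score matrix."""
    S = Q @ K.T / np.sqrt(Q.shape[1])
    P = np.exp(S - S.max(axis=1, keepdims=True))
    return (P / P.sum(axis=1, keepdims=True)) @ V


def attention_tiled(Q, K, V, block=4):
    """Stream over K/V tiles; only an (N, block) slab of scores exists at a time."""
    N, d = Q.shape
    O = np.zeros((N, d))
    m = np.full(N, -np.inf)  # running row max
    l = np.zeros(N)          # running row denominator
    for j in range(0, N, block):
        Kj, Vj = K[j:j + block], V[j:j + block]
        S = Q @ Kj.T / np.sqrt(d)
        m_new = np.maximum(m, S.max(axis=1))
        scale = np.exp(m - m_new)        # rescale stale statistics
        P = np.exp(S - m_new[:, None])
        l = l * scale + P.sum(axis=1)
        O = O * scale[:, None] + P @ Vj  # running, unnormalised output
        m = m_new
    return O / l[:, None]


rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((10, 8)) for _ in range(3))
print(np.allclose(attention_tiled(Q, K, V), attention_reference(Q, K, V)))
```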
Multi-Head, MQA, and GQA
Multi-head attention splits the model dimension across H heads, each with its own Q, K, V projections. Total FLOPs are unchanged versus a single-head attention of the same dimension — the win is representational, not computational. But for inference the choice of how many K/V heads to use is a memory lever.
| Variant | Q heads | K/V heads | KV cache size | Used by |
|---|---|---|---|---|
| MHA | H | H | 1× baseline | Original Transformer, GPT-2/3 |
| MQA | H | 1 | 1/H × baseline | PaLM, Falcon |
| GQA-g | H | g | g/H × baseline | Llama-2 70B (g=8), Mistral, Llama-3 |
Multi-Query Attention (MQA) shares one K/V head across all query heads, shrinking the KV cache by a factor of H. Grouped-Query Attention (GQA) is the practical compromise: it shrinks the cache by a factor of H/g while losing almost no quality. This is Lever 4 (lower precision of representation) applied at the algorithmic level: using fewer K/V heads is a kind of structured pruning.
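The head-sharing pattern can be sketched in NumPy. Here the K/V heads are explicitly repeated so the math is easy to follow; an optimised kernel would index the shared head instead of copying it. The function name and the sizes (H = 8, g = 2) are illustrative.

```python
import numpy as np


def gqa_attention(Q, K, V):
    """Q: (H, T, d); K, V: (g, T, d), with H divisible by g."""
    H, g = Q.shape[0], K.shape[0]
    group = H // g
    K_full = np.repeat(K, group, axis=0)  # each K/V head serves `group` query heads
    V_full = np.repeat(V, group, axis=0)
    S = Q @ K_full.transpose(0, 2, 1) / np.sqrt(Q.shape[-1])
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V_full


rng = np.random.default_rng(0)
H, g, T, d = 8, 2, 5, 4
Q = rng.standard_normal((H, T, d))
K = rng.standard_normal((g, T, d))
V = rng.standard_normal((g, T, d))

out = gqa_attention(Q, K, V)
print(out.shape, "cache is", g / H, "of the MHA baseline")
```

Setting g = H recovers standard multi-head attention and g = 1 recovers MQA, so one function covers all three rows of the table.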
Positional Encodings
Position information must enter the model somehow, and the choice has direct performance consequences:
- Sinusoidal (Vaswani 2017): precomputed table, added to embeddings. Zero memory at runtime, zero extra FLOPs. But fixed — extending beyond training length degrades fast.
- Learned absolute: a learnable matrix of shape $T_{\max} \times d_{\text{model}}$. Costs memory proportional to max sequence length and breaks at unseen positions.
- RoPE (Rotary) (Su et al., 2021): applies a 2D rotation to Q and K based on their absolute position; the dot product QKᵀ then naturally encodes relative position. Implemented as two element-wise multiplies fused into the QK matmul kernel — adds zero FLOPs in practice, no extra memory, and extrapolates well. Used in Llama, Mistral, GPT-NeoX.
- ALiBi (Press 2022): adds a static linear bias to the attention scores before softmax. Even cheaper — a single element-wise add — and extrapolates to sequences far longer than training.
Notice the trend: each newer scheme moves from a separate component to something fused into an existing kernel. That is Lever 2 — operator fusion — applied to position information.
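The claim that RoPE turns absolute rotations into relative-position scores can be verified numerically. This sketch rotates consecutive pairs of dimensions, which is one common convention (implementations differ in how they pair dimensions); the base of 10000 follows the usual default and the vectors are random.

```python
import numpy as np


def rope(x, pos, base=10000.0):
    """Rotate pairs (x[0], x[1]), (x[2], x[3]), ... by angle pos * theta_i."""
    d = x.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)
    cos, sin = np.cos(pos * theta), np.sin(pos * theta)
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out


rng = np.random.default_rng(0)
q, k = rng.standard_normal(8), rng.standard_normal(8)

# Same relative offset (3), very different absolute positions.
score_near = rope(q, 5) @ rope(k, 2)
score_far = rope(q, 13) @ rope(k, 10)
print(np.isclose(score_near, score_far))
```

Because each pair is only rotated, vector norms are unchanged, and the score depends on the position difference alone.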
KV-Cache Optimizations
Once you accept that inference is gated by KV-cache memory, every gigabyte recovered is more concurrent users. The major techniques:
- Paged Attention (vLLM): treat the KV cache like virtual memory. Allocate cache in fixed-size pages instead of one contiguous slab per request. Eliminates fragmentation, lets pages be shared across requests with identical prompts, and enables batching requests of very different lengths together. The single biggest inference systems win of 2023.
- KV-cache quantisation: store K, V in INT8 or INT4 instead of FP16. Halves or quarters cache memory at modest accuracy cost. Lever 4 applied to the cache, not the weights.
- Sliding window attention (Mistral, Longformer): each token only attends to the last W tokens. Cache size becomes $O(W)$ per layer instead of $O(N)$. Combine with global attention on a few special tokens to keep long-range information.
- Eviction (StreamingLLM, H2O): heuristically drop cache entries that future tokens are unlikely to attend to. Constant memory regardless of sequence length, with carefully chosen retention rules.
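The paging idea from the first bullet reduces to a block table per request. The class below is a hypothetical bookkeeping sketch only: it tracks which physical pages a request owns and omits the actual K/V tensors, prefix sharing, and the attention kernel that reads through the table. It is not vLLM's API.

```python
class PagedKVCache:
    """Toy page allocator: request id -> list of fixed-size physical pages."""

    def __init__(self, num_pages, page_size):
        self.page_size = page_size
        self.free = list(range(num_pages))
        self.tables = {}   # request id -> physical page ids, in logical order
        self.lengths = {}  # request id -> number of tokens stored

    def append_token(self, req):
        n = self.lengths.get(req, 0)
        if n % self.page_size == 0:  # first token, or the current page is full
            if not self.free:
                raise MemoryError("no free KV pages")
            self.tables.setdefault(req, []).append(self.free.pop())
        self.lengths[req] = n + 1

    def slot(self, req, t):
        """Physical (page, offset) holding logical token t of a request."""
        return self.tables[req][t // self.page_size], t % self.page_size

    def release(self, req):
        self.free.extend(self.tables.pop(req, []))
        self.lengths.pop(req, None)


cache = PagedKVCache(num_pages=4, page_size=16)
for _ in range(17):
    cache.append_token("a")  # a long request: spills into a second page
for _ in range(3):
    cache.append_token("b")  # a short request: one page is enough
print(len(cache.tables["a"]), len(cache.tables["b"]), len(cache.free))
```

Requests of very different lengths share one pool, and a finished request returns its pages immediately, which is where the fragmentation savings come from.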
Transformer Scaling Laws
Performance optimisation also dictates what model is worth training in the first place. The Chinchilla scaling laws (Hoffmann et al., 2022) showed that for a fixed compute budget $C$, the optimal allocation between model parameters $N$ and training tokens $D$ satisfies $N \propto C^{0.5}$ and $D \propto C^{0.5}$ (about 20 tokens per parameter); roughly, train a smaller model on more data than GPT-3 did.
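A back-of-the-envelope version of the compute-optimal split is sketched below. It relies on two widely used approximations that are assumptions here, not statements from the text above: training cost $C \approx 6ND$ FLOPs, and a ratio of about 20 tokens per parameter. The budget of $5.76 \times 10^{23}$ FLOPs is chosen for illustration.

```python
import math


def compute_optimal(c_flops, tokens_per_param=20.0):
    """Solve C = 6 * N * D with D = tokens_per_param * N for N and D."""
    n_params = math.sqrt(c_flops / (6.0 * tokens_per_param))
    return n_params, tokens_per_param * n_params


n, d = compute_optimal(5.76e23)
print(f"params ~ {n / 1e9:.0f}B, tokens ~ {d / 1e12:.2f}T")

# Square-root scaling: 4x the compute buys 2x the parameters and 2x the tokens.
n4, d4 = compute_optimal(4 * 5.76e23)
print(f"{n4 / n:.1f}x params, {d4 / d:.1f}x tokens")
```

Under these assumptions the result lands close to Chinchilla's published configuration of 70B parameters and 1.4T tokens.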
Inference cost scales differently. For deployed models the per-token cost is roughly $2N$ FLOPs (forward pass) plus the KV-cache memory traffic. Once a model is going to serve trillions of tokens, every percent shaved off N or off the KV cache pays for itself many times over, which is why the recent generation of production LLMs (Llama-3, Mistral, Claude Haiku) pair scaling-law-guided training with aggressive inference engineering: GQA, FP8 / INT4 weights, paged KV cache, speculative decoding, and heavy use of Flash Attention.
Mixture-of-Experts (MoE) models (Mixtral, Switch Transformer, DeepSeek-V3) push this further: they have huge parameter counts but activate only a small subset per token, decoupling capacity from per-token compute. Performance optimisation has reshaped what we even mean by "model size".
Distributed Training: DDP, FSDP, ZeRO
Once a single GPU is saturated, the only path forward is more GPUs — but that path branches into four very different strategies, each suited to a different bottleneck. The choice is dictated by which resource ran out first: compute, activation memory, optimizer-state memory, or per-layer parameter memory.
| Strategy | What is split | Best when | Communication cost |
|---|---|---|---|
| DDP (Data Parallel) | Batch across N GPUs; each holds a full model copy | Model fits on one GPU; you want larger effective batch | All-reduce of gradients each step (~2× model size) |
| FSDP / ZeRO-3 | Parameters, gradients, optimizer state — sharded across N | Model does NOT fit on one GPU; up to ~10× larger models | All-gather params + reduce-scatter grads each layer |
| Tensor Parallel | Each weight matrix sliced across GPUs | A single weight matrix > one GPU; intra-node only | All-reduce within each transformer block (high) |
| Pipeline Parallel | Layers grouped into stages, GPUs form a pipeline | Very deep models; cross-node fine | Activation hand-off between stages; bubble overhead |
Data Parallel (DDP)
The simplest scale-out. Replicate the model on N GPUs, give each a different slice of the batch, and synchronise gradients with all-reduce after the backward pass. Effective batch size becomes $N \times B$ (with per-GPU batch size $B$); per-step time stays roughly the same as single-GPU because the all-reduce overlaps with backward.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# MyModel, train_set, loss_fn, B, and num_epochs are placeholders defined elsewhere.
dist.init_process_group(backend="nccl")  # one process per GPU
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)

model = MyModel().to(local_rank)
model = DDP(model, device_ids=[local_rank])  # wraps + hooks all-reduce
opt = torch.optim.AdamW(model.parameters())

# Sampler shards the dataset so each rank sees disjoint examples
sampler = torch.utils.data.distributed.DistributedSampler(train_set)
loader = torch.utils.data.DataLoader(train_set, batch_size=B, sampler=sampler)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)  # reshuffle deterministically
    for x, y in loader:
        x, y = x.to(local_rank), y.to(local_rank)
        loss = loss_fn(model(x), y)
        opt.zero_grad()
        loss.backward()  # all-reduce happens here
        opt.step()

Keep find_unused_parameters=False when possible; leaving it on hides bugs by tolerating divergent graphs.
FSDP and ZeRO Sharding
DDP keeps a full copy of model weights, gradients, and optimizer state on every rank. For a 7B-parameter model in BF16 with Adam, that is roughly $7 \times 10^{9} \times 16\ \text{bytes} \approx 112$ GB per rank (2 bytes each for BF16 weights and gradients, plus 12 bytes of FP32 master weights and Adam moments), too big for an 80 GB GPU. ZeRO (Rajbhandari et al., 2020) and PyTorch's FSDP partition this state across ranks: each GPU owns $1/N$-th of the parameters, gradients, and optimizer state, and gathers full layers on demand.
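The per-rank footprint is a one-line calculation. The sketch below assumes the common mixed-precision Adam accounting of 16 bytes per parameter (2 for BF16 weights, 2 for BF16 gradients, 12 for FP32 master weights plus the two Adam moments) and ignores activations and temporary buffers; the function name is illustrative.

```python
def training_state_bytes(n_params, weight_bytes=2, grad_bytes=2, optim_bytes=12):
    """Bytes of weights + gradients + optimizer state when one rank holds everything."""
    return n_params * (weight_bytes + grad_bytes + optim_bytes)


n_params = 7_000_000_000
ddp = training_state_bytes(n_params)
print(f"DDP: {ddp / 1e9:.0f} GB per rank")

for n_ranks in (8, 64):
    # Full sharding: every rank owns 1/N of all three kinds of state.
    print(f"fully sharded over {n_ranks} ranks: {ddp / n_ranks / 1e9:.2f} GB per rank")
```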
| Stage | Sharded | Memory / rank | Comm overhead |
|---|---|---|---|
| ZeRO-1 | Optimizer state | ~25% of DDP at large N (mixed-precision Adam) | Low |
| ZeRO-2 | + Gradients | ~12.5% of DDP at large N (mixed-precision Adam) | Medium |
| ZeRO-3 / FSDP | + Parameters | ~1/N of DDP | High (per-layer all-gather) |
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Assumes the process group is already initialised; MyTransformer is a placeholder.
mp = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    MyTransformer(),
    sharding_strategy=ShardingStrategy.FULL_SHARD,  # ZeRO-3 equivalent
    mixed_precision=mp,
    device_id=torch.cuda.current_device(),
)
# Use exactly like DDP: model(x), loss.backward(), opt.step().

Pass an auto_wrap_policy so that FSDP shards layer by layer: transformer_auto_wrap_policy for transformers, or a custom auto_wrap_policy for other architectures.
Tensor and Pipeline Parallelism
For models where even a single weight matrix exceeds one GPU's memory (300B+ params), or where the all-gather traffic of FSDP becomes the bottleneck, partition the model itself.
- Tensor parallel (TP) (Megatron-LM): split each weight matrix column-wise or row-wise across GPUs in the same node. Each transformer block needs an all-reduce — high bandwidth required, so almost always intra-node over NVLink.
- Pipeline parallel (PP) (GPipe, PipeDream): stack layers into stages, one stage per GPU group. Activations flow forward, gradients flow backward — like a CPU pipeline. Solves deep models on slow links but suffers from bubble overhead when the pipeline drains.
- 3-D parallelism: combine DP × TP × PP. Used by GPT-3 (DP=64, TP=8, PP=12) and most modern frontier-scale training.
The largest training runs combine FSDP × TP × PP, but you almost certainly do not need this; over-engineering distributed setups is a leading cause of wasted compute.
Profiling: Measuring Before Optimising
Every optimisation in this section is conditional on which bottleneck you have. Guessing wastes weeks. The right reflex is: profile first, optimise second. PyTorch ships with a profiler that exports Chrome-trace-format timelines and links each kernel back to the Python line that launched it.
from torch.profiler import profile, record_function, ProfilerActivity, schedule

# model, loader, loss_fn, and opt are placeholders defined elsewhere.
# wait → warmup → active: only the active steps are recorded, which keeps the trace small
sched = schedule(wait=1, warmup=1, active=3, repeat=1)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=sched,
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, (x, y) in enumerate(loader):
        with record_function("forward"):
            logits = model(x.cuda(non_blocking=True))
        with record_function("loss"):
            loss = loss_fn(logits, y.cuda(non_blocking=True))
        with record_function("backward"):
            loss.backward()
        opt.step()
        opt.zero_grad()
        prof.step()  # advances the schedule

# Top kernels by self CUDA time
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15))

# Save Chrome trace; open with chrome://tracing or perfetto.dev
prof.export_chrome_trace("trace.json")

| Profiler signature | Likely cause | First fix |
|---|---|---|
| Long gaps with GPU idle, big CPU bars | Dataloader is blocking; no async H2D copies | num_workers, pin_memory, non_blocking=True |
| Many tiny kernels back-to-back | Per-op dispatch overhead | torch.compile or CUDA graphs |
| Single huge kernel dominates | Compute-bound — already near roof | Look for redundant FLOPs (sparsity, low-rank) |
| memcpyHtoD or DtoH bars | Data crossing PCIe each step | Move tensors to GPU once and reuse; non_blocking copies |
| .item() / .cpu() inside hot loop | Forces GPU sync; serialises everything | Defer scalar reads; log async |
Other useful tools: py-spy (sampling Python profiler, finds hot CPU code), nvtop (live GPU utilisation), and the W&B System tab (per-step GPU/CPU/network utilisation cheaply). Pick one tool and learn it well; the diagnostic value of any profiler comes from familiarity, not feature count.
torch.compile and CUDA Graphs
Eager-mode PyTorch dispatches each op separately, and the per-op overhead is roughly 50 µs. For a transformer block with hundreds of small element-wise ops this overhead dominates on small batches. torch.compile (introduced in PyTorch 2.0) traces the Python forward function, fuses adjacent ops, and emits a small number of optimised kernels in place of many tiny ones.
import torch

# MyModel, loader, loss_fn, and opt are placeholders defined elsewhere.
model = MyModel().cuda()

# One line. Wraps model.forward in TorchDynamo + Inductor.
model = torch.compile(model, mode="reduce-overhead")  # or "max-autotune"

# Use as normal. First few calls are slow (compilation); subsequent calls fly.
for x, y in loader:
    x, y = x.cuda(), y.cuda()
    loss = loss_fn(model(x), y)
    loss.backward()
    opt.step()
    opt.zero_grad()

| Mode | What it does | When to use |
|---|---|---|
| default | Trace + fuse element-wise ops | Most cases |
| reduce-overhead | + CUDA graphs (replay captured launch sequence) | Small batches, dispatch-bound |
| max-autotune | + Triton autotuning of fused matmuls | Inference / very repetitive shapes |
CUDA graphs are the magic behind reduce-overhead: the launch sequence of a forward pass is captured once, then replayed as a single GPU command instead of N Python-level launches. The first launch costs ~10 ms of capture; every subsequent launch is essentially free. Limits: the graph is fixed-shape, so dynamic-shape inputs (variable batch or sequence length) trigger expensive recapture — which is why many serving stacks pad to fixed shapes.
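Avoid data-dependent Python control flow inside compiled code, since it forces graph breaks; prefer tensor-level selection (torch.where) when possible.
Quantization: PTQ vs QAT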
The precision table earlier listed INT8 and INT4 as serving formats. There are two families of techniques to get there, and the distinction matters.
| Technique | Training cost | Accuracy | When to use |
|---|---|---|---|
| PTQ (Post-Training Quantization) | Zero retraining; small calibration set | Often 0.5–2% loss | Fast deployment of an existing FP model |
| QAT (Quantization-Aware Training) | Fine-tune with simulated quant ops | Often within 0.1% of FP | When PTQ hurts too much |
| GPTQ / AWQ (weight-only) | Hours of calibration; no full retraining | Excellent for LLM weights at INT4 | LLM inference at minimal quality cost |
| SmoothQuant | Calibration only | Activation outliers handled | INT8 LLM activations + weights |
For LLM serving the practical recipe in 2026 is: weights in INT4 (GPTQ or AWQ), activations in BF16 or INT8 (SmoothQuant), KV cache in INT8. This typically triples throughput on a single GPU at < 1% perplexity cost versus the BF16 baseline.
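To show what "INT4 with per-group scales" means mechanically, here is a round-to-nearest sketch in NumPy. This is the naive baseline, not GPTQ or AWQ (both of which choose the quantised values more carefully using calibration data); the group size of 32 and the symmetric range of ±7 are assumptions for illustration, and the 4-bit values are held in an int8 array rather than bit-packed.

```python
import numpy as np


def quantize_groups(w, bits=4, group_size=32):
    """Symmetric round-to-nearest quantisation with one scale per group."""
    qmax = 2 ** (bits - 1) - 1  # 7 for INT4
    groups = w.reshape(-1, group_size)
    scale = np.abs(groups).max(axis=1, keepdims=True) / qmax
    scale = np.where(scale == 0, 1.0, scale)  # guard against all-zero groups
    q = np.clip(np.round(groups / scale), -qmax, qmax).astype(np.int8)
    return q, scale


def dequantize_groups(q, scale, shape):
    return (q * scale).reshape(shape)


rng = np.random.default_rng(0)
w = rng.standard_normal((64, 64))
q, scale = quantize_groups(w)
w_hat = dequantize_groups(q, scale, w.shape)

print("max abs error:", np.abs(w - w_hat).max())
print("value range:", q.min(), q.max())
```

Per-group scales bound the error by half a quantisation step within each group, so one outlier weight only degrades its own 32 neighbours instead of the whole matrix.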
Speculative Decoding
Inference's deepest secret is that autoregressive decoding is fundamentally sequential: token $t+1$ needs token $t$. Speculative decoding (Leviathan et al., 2023; Chen et al., 2023) breaks the sequential chain by using a small "draft" model to propose $k$ tokens at once, then having the large target model verify them all in one forward pass. The verified prefix is accepted; the rest is regenerated.
The win comes from amortising the target model's per-forward-pass overhead across up to $k$ tokens. Acceptance rates of 60–80% on natural text are common, giving 2–3× wall-clock speedups for free (no quality loss, because the target distribution is exactly preserved).
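The propose-then-verify loop can be sketched with toy deterministic "models". This covers only the greedy case, where verification is an equality check; the cited papers use a rejection-sampling rule to preserve the target distribution under sampling. The `target` and `draft` functions are hypothetical stand-ins, and `passes` counts what would be one batched forward pass of a real target model.

```python
def target(seq):
    """Toy target model: a deterministic next-token rule."""
    return (3 * seq[-1] + len(seq)) % 7


def draft(seq):
    """Toy draft model: agrees with the target except at every fifth position."""
    t = target(seq)
    return (t + 1) % 7 if len(seq) % 5 == 0 else t


def greedy_decode(model, prompt, n_new):
    seq = list(prompt)
    for _ in range(n_new):
        seq.append(model(seq))
    return seq


def speculative_decode(target, draft, prompt, n_new, k=3):
    seq, passes = list(prompt), 0
    goal = len(prompt) + n_new
    while len(seq) < goal:
        proposal = []
        for _ in range(k):  # the cheap model proposes k tokens autoregressively
            proposal.append(draft(seq + proposal))
        passes += 1         # the target checks all k positions in one pass
        for i in range(k):
            expected = target(seq + proposal[:i])
            if proposal[i] != expected:
                seq.extend(proposal[:i] + [expected])  # keep prefix, fix first miss
                break
        else:
            seq.extend(proposal + [target(seq + proposal)])  # all accepted + bonus token
    return seq[:goal], passes


reference = greedy_decode(target, [1], 20)
spec, passes = speculative_decode(target, draft, [1], 20, k=3)
print(spec == reference, "target passes:", passes, "vs", 20)
```

The output is identical to plain greedy decoding with the target; only the number of target passes changes.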
Chapter Capstone: An End-to-End Debug
To close out the chapter, walk through one realistic scenario that exercises every section.
The scenario. A 1.3B-parameter transformer trains cleanly for the first 8,000 steps, then loss spikes to NaN at step 8,212 every restart. Validation loss right before the spike was healthy. GPU utilisation has been 70%.
- (§1) Reproduce and shrink. Pin the seed; save the optimizer state and one offending minibatch. Confirm the NaN reproduces. Then test: does the same minibatch on the step-8,000 checkpoint with a 10× lower LR still NaN? If no, the spike is LR-conditional.
- (§1) Detect-anomaly + first NaN. Re-run the offending step under torch.autograd.set_detect_anomaly(True). The traceback names the first op that produced a NaN, say a log_softmax in the loss head. The cause is upstream: a hidden state with a 0-variance LayerNorm input.
- (§2) Activation histograms. Hook the layer before the offending LayerNorm. Plot pre-LN activations across training. The variance of one channel collapsed to zero around step 7,500: a single neuron died, taking down LayerNorm with it. This is exactly the dying-ReLU pattern from §1, manifesting two layers downstream.
- (§2) Per-layer gradient norms. Confirm that the offending layer's gradients had been shrinking steadily for ~500 steps before the death. The visualisation tells you when, the histogram tells you why.
- (§1 + §3) Apply the fix. Switch ReLU → GELU in that block (cures the dying-neuron mode), add gradient-norm clipping (defends against future spikes), and enable BF16 mixed precision to widen the loss-scale margin (§3 lever 4).
- (§3) Profile after the fix. The NaN is gone but throughput is still 70%, so run torch.profiler. The trace shows a long gap each step where the dataloader is blocking: switch to num_workers=8, pin_memory=True, and async H2D copies. GPU utilisation is now 92%.
- (§3) Compile. torch.compile(model, mode="max-autotune") fuses the attention block; another 1.4× wall-clock win.
The pattern. A NaN in §1 was a numerical instability in §3 caused by a dead neuron diagnosed by §2's activation histogram. None of the three sections solved this alone. The chapter's point is that the techniques compose — debugging is most powerful when symptoms (§1), looking inside (§2), and measurement (§3) work together.
Performance Cheat Sheet
| Symptom | Likely cause | Lever | Concrete fix |
|---|---|---|---|
| GPU utilisation < 30% | Memory-bound kernel | Fusion / tiling | torch.compile, Flash Attention, fused LayerNorm |
| Out-of-memory at long context | KV cache or attention matrix | KV cache + Flash Attention | GQA, paged attention, INT8 KV, sliding window |
| Training stalls on small ops | Python / dispatch overhead | Vectorisation | Larger batch, torch.compile, CUDA graphs |
| Slow inference latency | Autoregressive recompute | KV cache | Enable use_cache=True, speculative decoding |
| Numerical instability in BF16 | Mantissa loss in reductions | Mixed precision recipe | Keep accumulator in FP32, master weights in FP32 |
| Activation memory dominates | Storing all forward activations | Gradient checkpointing | torch.utils.checkpoint per transformer block |
| transpose followed by matmul is slow | Non-contiguous strides | Memory layout | Add .contiguous() or use packed QKV layout |
Quick Check
Q1. A kernel does 2 GFLOPs and reads/writes 1 GB of data. Your A100-class GPU has $P_{\text{peak}} = 312$ TFLOP/s and $B_{\text{mem}} = 2$ TB/s. Where is the kernel on the roofline? What should you do first?
Answer: Arithmetic intensity $I = 2\ \text{GFLOP} / 1\ \text{GB} = 2$ FLOP/B. Ridge point is $I^{*} = 312 / 2 = 156$ FLOP/B. The kernel sits almost two orders of magnitude below the ridge: firmly memory-bound. First fix: fuse it with a neighbour to reduce HBM traffic.
Q2. You enable a KV cache and your inference latency improves 100×. A week later, with longer prompts, latency creeps back up. What changed?
Answer: The cache eliminated the projection waste, but attention itself is still $O(N)$ per step due to reading K and V from HBM. At long N the bandwidth of those reads dominates again; the next levers are GQA (smaller cache), Flash Attention (fewer reads per FLOP), or sliding window (constant cache).
Q3. Why does online softmax have to rescale the running denominator when a new max arrives — why not just keep the old denominator and add a new term?
Answer: The denominator is $\ell = \sum_i e^{x_i - m}$; it depends on $m$. If $m$ grows, every old term needs to be re-expressed relative to the new $m$. Multiplying by $e^{m_{\text{old}} - m_{\text{new}}}$ does exactly that: it converts each $e^{x_i - m_{\text{old}}}$ into $e^{x_i - m_{\text{new}}}$ in one multiplication. Skipping the rescale would underweight the new term relative to the old.
Summary
Performance optimisation in deep learning is dominated by memory traffic, not arithmetic. The roofline model tells you which roof you sit under; the GPU memory hierarchy tells you why. Five levers move you toward the roof:
- Vectorisation — let BLAS / cuBLAS do scalar work in tight C loops, never Python.
- Operator fusion — collapse adjacent ops into one kernel pass; online softmax is the canonical example.
- Memory layout — keep tensors contiguous in the access order the next op needs.
- Mixed precision — BF16 / FP8 / INT4 to halve memory and double throughput where range and precision allow.
- Recomputation (gradient checkpointing) — trade compute for memory when activations dominate.
Modern transformer infrastructure is built on these levers. Flash Attention is fusion + tiling at the SRAM rung. The KV cache is recompute-vs-store applied to autoregressive inference. MQA / GQA shrink the cache at the cost of representational capacity. RoPE / ALiBi are fused position encodings. Paged attention and KV quantisation chase the cache memory pressure further. And the scaling laws tell you which model even deserves these optimisations.
When your network feels slow, do not start by guessing. Measure the arithmetic intensity. Place the kernel on the roofline. Find the nearest rung of the memory hierarchy that can hold the working set. Then — and only then — choose which lever to pull.
References
- Williams, S., Waterman, A., & Patterson, D. (2009). "Roofline: an insightful visual performance model for multicore architectures". Communications of the ACM.
- Dao, T., Fu, D., Ermon, S., Rudra, A., & Re, C. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness". NeurIPS.
- Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning". arXiv:2307.08691.
- Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need" (MQA). arXiv:1911.02150.
- Ainslie, J. et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints". EMNLP.
- Su, J. et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding". arXiv:2104.09864.
- Press, O., Smith, N., & Lewis, M. (2022). "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation" (ALiBi). ICLR.
- Hoffmann, J. et al. (2022). "Training Compute-Optimal Large Language Models" (Chinchilla). NeurIPS.
- Kwon, W. et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention" (vLLM). SOSP.
- Rajbhandari, S., Rasley, J., Ruwase, O., & He, Y. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models". SC.
- Shoeybi, M. et al. (2019). "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism". arXiv:1909.08053.
- Huang, Y. et al. (2019). "GPipe: Efficient Training of Giant Neural Networks Using Pipeline Parallelism". NeurIPS.
- Frantar, E. et al. (2023). "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers". ICLR.
- Lin, J. et al. (2024). "AWQ: Activation-aware Weight Quantization for LLM Compression". MLSys.
- Xiao, G. et al. (2023). "SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models". ICML.
- Leviathan, Y., Kalman, M., & Matias, Y. (2023). "Fast Inference from Transformers via Speculative Decoding". ICML.
- Chen, C. et al. (2023). "Accelerating Large Language Model Decoding with Speculative Sampling". arXiv:2302.01318.
- PyTorch documentation. "FullyShardedDataParallel". pytorch.org/docs/stable/fsdp.html.
- PyTorch documentation. "torch.profiler". pytorch.org/docs/stable/profiler.html.
- PyTorch documentation. "torch.compile Tutorial". pytorch.org/tutorials/intermediate/torch_compile_tutorial.html.