Chapter 20

Performance Optimization Tips

Debugging and Improving Networks

Overview

A modern transformer's training run can cost seven figures, and a deployed model serves billions of tokens per day. Yet most of that cost is decided not by the algorithm but by how the algorithm uses hardware. The same matmul can run at 3% or 90% of peak FLOPs depending on cache locality. The same attention can fit in 4 GB or fail with out-of-memory depending on whether the score matrix is materialised.

This section gives you a single mental model — the roofline — and five concrete levers you can pull to move a kernel along it. Then we trace each lever back to the modern systems you have already met (or are about to meet) in this book: Flash Attention, multi-head attention with its MQA / GQA variants, positional encodings (sinusoidal, RoPE, ALiBi), KV-cache optimisations like paged and quantised cache, and the scaling laws that govern transformer training.

The examples lean heavily on transformers because they are the richest current case study, not because the principles are transformer-only. The same performance logic applies to CNNs, RNNs, MLPs, and diffusion models: increase useful work per byte moved, keep the working set in the fastest memory tier you can, and only recompute values when that is cheaper than storing or reloading them.

Why this matters. A correct neural network that runs 100× too slowly is a research artifact. A correct neural network that runs at 80% of peak hardware throughput is a product. The gap is almost always memory traffic — not arithmetic.

The Real Bottleneck Is Not Compute

New practitioners assume neural networks are slow because they do too many multiplications. They do not. A single A100 GPU has a peak of $3.12 \times 10^{14}$ FP16 floating-point operations per second (312 TFLOP/s on its Tensor Cores). That is enough raw arithmetic to multiply two $10000 \times 10000$ matrices roughly 150 times per second. The throughput actually measured on a typical PyTorch model is often only 5–15% of that.

Where does the rest go? It is spent waiting for data to arrive at the arithmetic units. The GPU's compute units sit idle while data crawls from off-chip HBM (~2 TB/s) into on-chip SRAM (~20 TB/s) and then into registers (~30 TB/s). Each step closer to the ALUs is faster, but it is also far smaller and more expensive per byte. That trade-off between speed and capacity is the single most important fact about modern hardware.

Two simple equations capture the regime. The arithmetic intensity of a kernel is $I = \frac{\text{FLOPs}}{\text{Bytes moved}}$. The hardware exposes a peak compute rate $P_{\max}$ (FLOPs/s) and a peak memory bandwidth $B_{\max}$ (Bytes/s). The achievable throughput of the kernel is $P = \min(P_{\max},\; I \cdot B_{\max})$. The two regimes meet at the ridge point $I^{*} = P_{\max} / B_{\max}$. Below it you are memory-bound, and the only way to go faster is to do more compute per byte loaded. Above it you are compute-bound, and the only way is to do fewer FLOPs.
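Here is a minimal Python sketch of the two equations. The peak numbers are assumptions: they are the A100-class figures used elsewhere in this section (312 TFLOP/s, 2 TB/s), not measured values. The function names are illustrative and do not come from any library.

```python
# Roofline model: attainable throughput P = min(P_max, I * B_max).
P_MAX = 312e12   # assumed peak compute, FLOP/s (A100-class FP16 Tensor Cores)
B_MAX = 2e12     # assumed peak HBM bandwidth, bytes/s


def ridge_point(p_max=P_MAX, b_max=B_MAX):
    """Arithmetic intensity I* (FLOP/byte) where the two ceilings meet."""
    return p_max / b_max


def attainable_throughput(intensity, p_max=P_MAX, b_max=B_MAX):
    """Upper bound on throughput (FLOP/s) for a kernel with intensity I."""
    return min(p_max, intensity * b_max)


def regime(intensity, p_max=P_MAX, b_max=B_MAX):
    """Which ceiling limits the kernel."""
    return "memory-bound" if intensity < ridge_point(p_max, b_max) else "compute-bound"


for I in (1.0, 4.0, 60.0, 156.0, 1000.0):
    P = attainable_throughput(I)
    print(f"I = {I:7.1f} FLOP/B -> {P / 1e12:6.1f} TFLOP/s "
          f"({P / P_MAX:6.1%} of peak, {regime(I)})")
```

With these numbers the ridge point lands at 156 FLOP/B. A kernel at I = 4 FLOP/B is therefore capped at 8 TFLOP/s, about 2.6% of peak, however much compute sits idle.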


The Roofline Mental Model

The roofline plot puts these two equations on a log-log graph: arithmetic intensity on the x-axis, throughput on the y-axis. A diagonal line, $P = I \cdot B_{\max}$, represents the memory ceiling; on log-log axes it has slope 1, and $B_{\max}$ sets its height. A horizontal line at $P_{\max}$ represents the compute ceiling. Every kernel lives at a point under both ceilings. The roof you sit closest to is the bottleneck you must attack.

On the roofline, naive softmax (about 1 FLOP/B) sits deep in memory-bound territory. It does not matter how fast your GPU is, because you are limited by the HBM bandwidth feeding it. Flash Attention pushes the operator to $I \approx 60$ FLOP/B by tiling Q, K and V into SRAM. That is much closer to the ridge point, where the GPU's Tensor Cores can finally stretch.

Figure: Interactive roofline plot. Arithmetic intensity I = FLOPs / byte (log scale) vs. throughput in TFLOP/s (log scale) for a GPU with 312 TFLOP/s peak compute and 2 TB/s memory bandwidth; ridge point I* = 156 FLOP/B. Example marker: a kernel at I = 4.00 FLOP/B reaches 8.0 TFLOP/s, 2.6% of peak (memory-bound).
The optimisation algorithm. Step 1: measure arithmetic intensity. Step 2: place the kernel on the roofline. Step 3: if it sits on the diagonal, find a way to load fewer bytes (fusion, tiling, quantisation). If it sits on the horizontal, find a way to compute fewer FLOPs (sparsity, low-rank, distillation). Never optimise blindly — the roofline tells you which knob to turn.
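Step 1 can often be done on paper. The sketch below makes three assumptions. Operands are FP16 (2 bytes each). Every input is read from HBM once and every output is written once. The GPU is the same hypothetical 312 TFLOP/s / 2 TB/s part. Real kernels also have cache effects that this count ignores, and the function names are illustrative.

```python
BYTES_FP16 = 2
RIDGE = 312e12 / 2e12  # 156 FLOP/byte for the assumed GPU


def intensity_elementwise_add(n):
    """z = x + y over n elements: 1 FLOP per element, 3 tensors moved."""
    flops = n
    bytes_moved = 3 * n * BYTES_FP16          # read x, read y, write z
    return flops / bytes_moved


def intensity_square_matmul(n):
    """C = A @ B for n x n matrices: n^3 multiply-adds, 2 FLOPs each."""
    flops = 2 * n ** 3
    bytes_moved = 3 * n * n * BYTES_FP16      # read A, read B, write C
    return flops / bytes_moved


for name, I in [("add (1M elements)", intensity_elementwise_add(1_000_000)),
                ("matmul 256x256", intensity_square_matmul(256)),
                ("matmul 4096x4096", intensity_square_matmul(4096))]:
    side = "memory-bound" if I < RIDGE else "compute-bound"
    print(f"{name:18s} I = {I:8.2f} FLOP/B -> {side}")
```

Under this idealised count a square matmul has I = n/3, so it only crosses the assumed ridge point around n ≈ 470. An element-wise add stays at 1/6 FLOP/B at every size. For that kernel, loading fewer bytes (for example through fusion) is the lever, not faster arithmetic.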

The GPU Memory Hierarchy

The roofline assumes a single bandwidth number, but real GPUs have a ladder of memories that differ by orders of magnitude in capacity and substantially in bandwidth. Each rung is a physically different memory (registers and SRAM on the compute die, HBM in separate stacks beside it), and moving data between rungs is what makes a kernel slow.

The difference between holding your data in HBM versus SRAM is roughly $10\times$ in bandwidth. That is the same order of magnitude as the attention speedups Flash Attention reports over standard attention on long sequences.


Three rules that follow from the hierarchy

  1. Reuse > recompute > re-fetch. If a value is in registers, use it as many times as possible before letting it spill. If it is in SRAM, the same applies one rung down. Crossing back to HBM is the most expensive move on the ladder.
  2. Fuse adjacent ops. Two element-wise ops back-to-back read from HBM, write to HBM, read again, write again. Fused into a single kernel they read once, write once — halving traffic.
  3. Tile to fit the smallest fast memory. Flash Attention chooses block sizes precisely so that $Q_{block}, K_{block}, V_{block}$ all fit inside one SM's 100–200 KB of SRAM simultaneously. Pick the tile size to fill, but not exceed, the fast tier.
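Rules 2 and 3 reduce to simple byte counting. The sketch assumes FP16 elements and a 192 KB SRAM budget per SM, which sits inside the 100–200 KB range above. It also uses a simplified Flash-Attention tile that counts only the Q, K and V blocks. Real kernels also keep score and output tiles on-chip, so their blocks are smaller. The function names are hypothetical.

```python
BYTES = 2              # FP16 element size (assumed)
SRAM = 192 * 1024      # assumed on-chip SRAM budget per SM, in bytes


def hbm_traffic_elementwise(n, num_ops, fused):
    """Bytes moved through HBM by a chain of unary element-wise ops on n elements."""
    if fused:
        return 2 * n * BYTES              # one kernel: read input once, write output once
    return 2 * n * BYTES * num_ops        # one kernel per op: a read and a write each


def max_block_rows(head_dim, sram_bytes=SRAM):
    """Largest power-of-two B such that Q, K and V tiles of shape (B, head_dim) fit together."""
    if 3 * head_dim * BYTES > sram_bytes:
        return 0                          # not even one row of each fits
    b = 1
    while 3 * (2 * b) * head_dim * BYTES <= sram_bytes:
        b *= 2
    return b


n = 1 << 20
print(hbm_traffic_elementwise(n, 2, fused=False))   # 8388608 bytes
print(hbm_traffic_elementwise(n, 2, fused=True))    # 4194304 bytes: half the traffic
print(max_block_rows(128))                          # 256
print(max_block_rows(64))                           # 512
```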

Lever 1 — Vectorization

Before any GPU trick, eliminate the worst possible source of slowness: Python loops over scalar math. In the naive matmul below, each inner-loop step costs roughly 700 ns of interpreter overhead, while a CPU's ALU can do the underlying multiply in about 1 ns. Doing matrix multiplication element-by-element in pure Python therefore spends more than 99.8% of its time on overhead rather than arithmetic. Vectorisation hands the entire problem to a compiled BLAS routine that runs at near-peak SIMD throughput on every core.

This is the same idea that distinguishes hand-written for i, j, k CUDA kernels from a single cublasSgemm call: the unit of work the system sees should be as large as possible, so the runtime can schedule, batch, and pipeline it.

Naive triple-loop matmul vs A @ B (vectorization.py)
Line 1: import numpy as np

NumPy gives us ndarray and the @ matmul operator. Crucially, A @ B does not loop in Python — it dispatches to a precompiled BLAS routine (OpenBLAS / MKL / Accelerate) written in C and assembly, which uses SIMD vector registers and multiple CPU cores. This is the difference between 'Python doing math' and 'Python asking the CPU to do math'.

EXECUTION STATE
📚 numpy = Numerical library. ndarray stores values in a contiguous, typed C buffer (no Python object per element).
@ operator = Bound to ndarray.__matmul__ → calls np.matmul → calls cblas_sgemm (FP32) or cblas_dgemm (FP64). Hand-tuned for L1/L2 cache blocking.
Line 2: import time

Standard library timer. We use time.perf_counter() — the highest-resolution monotonic clock — to compare wall-clock seconds between the naive and vectorized versions.

EXECUTION STATE
📚 time.perf_counter() = Returns a float in seconds with sub-microsecond resolution. Monotonic — won't go backwards if the system clock is adjusted.
Line 4: np.random.seed(0)

Make the random matrices reproducible so the timing is comparable across runs. seed(0) initialises the global PRNG state to a deterministic starting point.

EXECUTION STATE
📚 np.random.seed = Sets the seed of NumPy's legacy global RandomState.
⬇ arg: 0 = Any integer works — 0 is conventional. Same seed → same matrices.
Line 5: N = 256

Matrix size. We deliberately pick a moderate N: multiplying two 256×256 matrices takes 256³ ≈ 16.7 million scalar multiplications. That is big enough to expose Python's per-operation overhead, yet small enough for the naive version to finish in seconds.

EXECUTION STATE
N = 256 — matrix dimension
→ FLOPs = 2 × N³ = 33.5 M FLOPs (a multiply-add counts as 2)
Line 6: A = np.random.randn(N, N).astype(np.float32)

Allocate a 256×256 matrix sampled from N(0, 1). astype(np.float32) downcasts from the default float64 — halves memory and matches what GPUs use most.

EXECUTION STATE
📚 np.random.randn(*shape) = Samples from standard normal N(0, 1). Returns ndarray of given shape.
⬇ arg 1: N = 256 = rows
⬇ arg 2: N = 256 = columns
.astype(np.float32) = Cast each element from FP64 (8 bytes) to FP32 (4 bytes). Memory drops from 524 KB to 262 KB per matrix. GPUs typically run FP32 at least 2× faster than FP64 (consumer cards far more), and BF16 on Tensor Cores is faster still.
⬆ A = ndarray shape (256, 256), dtype float32, ~262 KB
Line 7: B = np.random.randn(N, N).astype(np.float32)

Second random matrix, same shape and dtype. We will multiply A @ B → C of shape (256, 256).

EXECUTION STATE
⬆ B = ndarray shape (256, 256), dtype float32
Line 9: def matmul_naive(A, B)

The textbook O(N³) matmul, written in pure Python. Every multiply-and-add goes through the Python interpreter: bytecode dispatch, type checking, memory allocation for intermediate floats. This is what every algorithms class shows, and it is catastrophically slow.

EXECUTION STATE
⬇ input: A (256×256) = First operand. Shape (n, m) where n=256, m=256.
⬇ input: B (256×256) = Second operand. Shape (m, p) where m=256, p=256.
⬆ returns = C of shape (n, p) = (256, 256), where C[i,j] = Σ_k A[i,k] · B[k,j]
Line 10: n, m = A.shape

Unpack A's dimensions. .shape is a tuple of ints. For A=(256, 256): n=256, m=256.

EXECUTION STATE
.shape = (256, 256) — (rows, cols) of A
n = 256 — rows of A and rows of output C
m = 256 — cols of A, must equal rows of B
Line 11: _, p = B.shape

Unpack B's columns. The underscore _ is Python convention for 'I do not need this value' (we already know B's row dim equals m).

EXECUTION STATE
_ = Discarded — B.shape[0] which equals m=256
p = 256 — cols of B and cols of output C
Line 12: C = np.zeros((n, p), dtype=np.float32)

Allocate the output buffer pre-filled with zeros. Doing this once outside the loop avoids 65,536 fresh allocations.

EXECUTION STATE
📚 np.zeros(shape, dtype) = Returns an ndarray filled with 0.0. Memory-efficient — uses calloc.
⬇ arg 1: (n, p) = (256, 256) — shape of the output matrix
⬇ arg 2: dtype=np.float32 = Match the inputs to avoid implicit FP64 promotion later. FP32 = 4 bytes per element.
⬆ C = ndarray (256, 256) of zeros, dtype float32
Line 13: for i in range(n):

Outer row loop. Iterates 256 times in the Python interpreter — each iteration costs ~50 ns of pure interpreter overhead BEFORE any math happens.

EXECUTION STATE
i = 0, 1, 2, …, 255 — current row of C being computed
→ Python overhead = 256 outer iterations × ~50 ns = 12.8 µs just for iteration. Negligible alone — adds up at deeper levels.
Line 14: for j in range(p):

Middle column loop. 256 × 256 = 65,536 (i, j) pairs. Each (i, j) computes one output element C[i, j].

EXECUTION STATE
j = 0, 1, …, 255 — current column of C
→ pairs = 65,536 (i, j) cells to fill
Line 15: s = 0.0

Per-cell accumulator. Holds the running dot product for cell (i, j). It is reset 65,536 times. The literal 0.0 is a cached constant, so the reset itself is cheap. The allocations happen in the inner loop, where every += produces a new scalar object.

EXECUTION STATE
s = 0.0 — running sum for C[i, j]
→ hidden cost = Every arithmetic result is a NEW, immutable scalar object. Over the run, tens of millions of temporaries are allocated and then freed by reference counting.
Line 16: for k in range(m):

Inner reduction loop — runs 256 times PER cell. Total inner-loop iterations: 256 × 256 × 256 = 16,777,216. Each iteration does Python work: bytecode for the add, two int→array index lookups, and a multiplication.

EXECUTION STATE
k = 0, 1, …, 255 — index along the shared dimension
→ total iterations = 16.78 million — and each is ~700 ns of Python interpreter work → ~12 seconds.
Line 17: s += A[i, k] * B[k, j]

The actual math. But it is wrapped in 5 layers of Python overhead per iteration: 2 indexed array reads (each goes through ndarray.__getitem__, which boxes the element into a NumPy scalar object), 1 multiply, 1 add, and 1 rebinding of s.

EXECUTION STATE
A[i, k] = Indexed access. Calls ndarray.__getitem__, which allocates an np.float32 scalar wrapping the underlying FP32 value. ~200 ns.
B[k, j] = Same: another ~200 ns indexed read.
* = Scalar multiply on two np.float32 objects; allocates a new scalar for the result.
+= = Not truly in-place: scalars are immutable, so s is rebound to a newly allocated result.
→ Why so slow? = The ALU does the multiply in 1 ns. The other 699 ns is the interpreter shuffling Python objects around.
Line 18: C[i, j] = s

Store the final accumulator into the output buffer. Writes the FP32 value back into C's contiguous memory.

EXECUTION STATE
C[i, j] = s = calls ndarray.__setitem__ with the float — converts back to FP32 and writes 4 bytes.
Line 19: return C

Return the filled matrix. After ~12 seconds of work.

EXECUTION STATE
⬆ return: C (256×256) = A @ B computed element-by-element. Numerically equivalent to matmul_vectorized(A, B) up to FP32 reordering rounding.
Line 21: def matmul_vectorized(A, B)

The same mathematical result, but written as a single line that hands the entire problem to BLAS. There are no Python loops. The C library tiles the matrices into cache-sized blocks and uses SIMD instructions: AVX2 / AVX-512 process 8 / 16 FP32 values per instruction, and NEON processes 4. It also runs on every available core.

EXECUTION STATE
⬇ input: A, B = Same matrices as before. Both contiguous FP32 buffers — important so BLAS can use vector loads.
⬆ returns = ndarray (256, 256) — same numerical answer, computed in ~1 ms.
Line 22: return A @ B

The matmul operator. The Python interpreter sees @ exactly once → calls A.__matmul__(B) → np.matmul(A, B) → cblas_sgemm in the linked BLAS library (e.g. libopenblas.so). Inside C, this becomes 256³ ≈ 16.8 million multiply-accumulate (MAC) operations, i.e. 33.5 million FLOPs, executed at near-peak SIMD throughput.

EXECUTION STATE
📚 @ → cblas_sgemm = Single-precision general matrix multiply. The reference workhorse of dense linear algebra.
→ Tiling = BLAS partitions A and B into cache-sized blocks (block sizes are tuned per CPU), so each block is reused many times while it sits in L1/L2 cache instead of being re-fetched from DRAM.
→ SIMD = AVX-512 does 16 FP32 multiply-adds per FMA instruction. NEON on Apple Silicon does 4. A100 Tensor Cores each do 256 FP16 multiply-adds per clock.
→ Threads = Multithreading across CPU cores (OpenMP or pthreads, depending on the BLAS build). 8 cores → up to 8× more FLOPs/s.
⬆ return: C (256×256) = Equal to the naive version up to floating-point reordering (absolute differences on the order of 1e-4).
Line 24: t0 = time.perf_counter(); C1 = matmul_naive(A, B); t1 = time.perf_counter()

Time the naive version. perf_counter is called immediately before and after — t1 − t0 gives wall-clock seconds.

EXECUTION STATE
t0 = Start time, e.g. 1234.56789012
matmul_naive(A, B) = Runs the triple loop. ~12 seconds.
t1 = End time, e.g. 1246.57123456
→ elapsed = t1 − t0 ≈ 12.0 s
Line 25: t2 = time.perf_counter(); C2 = matmul_vectorized(A, B); t3 = time.perf_counter()

Time the vectorized version. Same input, same answer.

EXECUTION STATE
matmul_vectorized(A, B) = BLAS sgemm. ~1 ms on a modern laptop CPU.
→ elapsed = t3 − t2 ≈ 0.001 s
→ speedup = 12.0 / 0.001 ≈ 12,000×, roughly four orders of magnitude, for the same mathematical result on the same hardware.
Line 27: print(f"naive : {t1 - t0:8.3f} s")

f-string with format spec :8.3f means 'right-aligned, width 8, 3 decimal places'. Prints e.g. ' 12.034 s'.

EXECUTION STATE
f"…{expr:fmt}…" = Python f-string. {t1 - t0:8.3f} substitutes the elapsed time formatted as fixed-point.
→ output = naive : 12.034 s
Line 28: print(f"vectorized : {t3 - t2:8.4f} s")

Same idea, 4 decimals to actually show the millisecond.

EXECUTION STATE
→ output = vectorized : 0.0010 s
Line 29: print("max abs diff:", np.max(np.abs(C1 - C2)))

Sanity check: the two answers should agree to floating-point precision. They differ by ~1e-4 because BLAS sums in a different order, and floating-point addition is not associative.

EXECUTION STATE
📚 np.abs(x) = Element-wise absolute value.
📚 np.max(x) = Single largest element across the entire array.
→ output = max abs diff: 0.000122 — well within FP32 expected error (relative ε ≈ 6e-8 per op, accumulated over 256 mults).
```python
import numpy as np
import time

np.random.seed(0)
N = 256
A = np.random.randn(N, N).astype(np.float32)
B = np.random.randn(N, N).astype(np.float32)

def matmul_naive(A, B):
    n, m = A.shape
    _, p = B.shape
    C = np.zeros((n, p), dtype=np.float32)
    for i in range(n):
        for j in range(p):
            s = 0.0
            for k in range(m):
                s += A[i, k] * B[k, j]
            C[i, j] = s
    return C

def matmul_vectorized(A, B):
    return A @ B

t0 = time.perf_counter(); C1 = matmul_naive(A, B);     t1 = time.perf_counter()
t2 = time.perf_counter(); C2 = matmul_vectorized(A, B); t3 = time.perf_counter()

print(f"naive       : {t1 - t0:8.3f} s")  # ~12.0 s on a laptop CPU
print(f"vectorized  : {t3 - t2:8.4f} s")  # ~0.001 s, about 10,000x faster
print("max abs diff:", np.max(np.abs(C1 - C2)))  # ~1e-4 (FP32 reorder)
```
Connection to PyTorch. When you write torch.matmul(A, B) on the GPU, PyTorch dispatches to cuBLAS (for example cublasGemmEx). On Tensor Cores, large matmuls can approach the theoretical peak FP16 FLOPs. Eager-mode PyTorch also pays a per-op dispatch overhead on the CPU side, on top of each kernel launch. That overhead is why torch.compile exists: it fuses many small ops into fewer generated kernels. With mode="reduce-overhead" it can additionally replay them as CUDA graphs, shrinking dispatch cost further.

Lever 2 — Operator Fusion (Online Softmax)

Vectorisation removes Python overhead, but a kernel can still be slow if it crosses HBM more times than necessary. The textbook softmax is a case in point. It makes three sequential passes over its input: once for the max, once for $e^{x - m}$, and once for the sum. Each pass crosses HBM. For a long sequence those passes dominate the runtime; the actual exp() arithmetic is almost free.

The online softmax trick collapses the max and sum passes into a single pass by maintaining running statistics. When a new max appears, the running denominator is rescaled by $e^{m_{\text{old}} - m_{\text{new}}}$ to undo the now-stale normalisation, and then the new term is added with the new normalisation. This is mathematically exact (no approximation). It is also the key trick that makes Flash Attention possible, because it lets attention scores be computed in tiles without ever materialising the full $N \times N$ matrix.

3-pass softmax vs single-pass online softmax (online_softmax.py)
Line 1: import numpy as np

Load NumPy. We will use np.array, np.max, np.exp, np.sum, np.round, and np.inf.

EXECUTION STATE
📚 numpy = Provides ndarray plus the math functions used below.
Line 3: x = np.array([2.0, 1.0, 0.5, 3.0], dtype=np.float64)

Our toy input vector. Picked so the largest value (3.0) is NOT first — this forces the online algorithm to actually 'rescale' once, exposing the trick.

EXECUTION STATE
📚 np.array(list, dtype) = Builds an ndarray from a Python list.
⬇ arg 1: [2.0, 1.0, 0.5, 3.0] = Four scores. In real attention these are Q·Kᵀ row entries.
⬇ arg 2: dtype=np.float64 = Use FP64 here so rounding doesn't muddy the comparison. In production attention you'd use FP16 or BF16.
⬆ x = [2.0, 1.0, 0.5, 3.0] — shape (4,), float64
Line 5: def softmax_naive(x), the 3-pass version

The textbook stable softmax. It walks over x three times: once to find the max, once to compute exp(x − max), once to sum. Each pass means another full read of x from memory. For long sequences, those memory passes dominate over the actual arithmetic.

EXECUTION STATE
⬇ input: x = [2.0, 1.0, 0.5, 3.0] — the score vector
⬆ returns = Probability distribution summing to 1.0 — same shape as x
→ memory cost = Three full passes over N elements (x, x, then e) = 3·N reads before the final divide. Each read crosses HBM if x is large.
Line 6: Docstring: """Standard 3-pass softmax — reads x three times."""

The docstring is the single most important comment: it names the inefficiency we are about to fix.

Line 7: m = np.max(x) # pass 1

First pass over x. We need the max because exp(x) overflows for large x — exp(1000) is +∞. Subtracting m before exp keeps the largest value at exp(0) = 1, which is always representable. The output of softmax is mathematically identical: softmax(x) = softmax(x − c) for any constant c.

EXECUTION STATE
📚 np.max(arr) = Returns the single largest element. Implemented as a tight C loop with SIMD reductions.
⬇ arg: x = [2.0, 1.0, 0.5, 3.0] — the array to scan
⬆ m = 3.0 — the max element
→ why subtract? = exp(x − m) ≤ exp(0) = 1 → no overflow. Largest entry exp(0)=1, smaller entries shrink toward 0.
Line 8: e = np.exp(x - m) # pass 2

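Second pass: subtract m from every element, then take exp element-wise. In NumPy these are two separate array operations with a temporary in between; a fused GPU kernel would do both in a single read of x.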

EXECUTION STATE
📚 np.exp(arr) = Element-wise e^x.
⬇ arg: x - m = [2-3, 1-3, 0.5-3, 3-3] = [-1.0, -2.0, -2.5, 0.0]
⬆ e = [exp(-1), exp(-2), exp(-2.5), exp(0)] = [0.3679, 0.1353, 0.0821, 1.0]
Line 9: s = np.sum(e) # pass 3

Third pass: total of all exponentials. This is the normaliser that makes the output sum to 1.

EXECUTION STATE
📚 np.sum(arr) = Element-wise sum reduction. C loop with parallel reductions.
⬇ arg: e = [0.3679, 0.1353, 0.0821, 1.0]
⬆ s = 0.3679 + 0.1353 + 0.0821 + 1.0 = 1.5853
Line 10: return e / s

Final element-wise divide — convert each exp into a probability. NumPy broadcasts the scalar s across e.

EXECUTION STATE
/ (broadcasting) = Each element e[i] divided by the same scalar s.
⬆ return = [0.3679/1.5853, 0.1353/1.5853, 0.0821/1.5853, 1.0/1.5853] = [0.2321, 0.0854, 0.0518, 0.6308]
→ verify = 0.2321 + 0.0854 + 0.0518 + 0.6308 = 1.0001 ≈ 1.0 ✓
Line 12: def softmax_online(x), single pass

The online variant, known as 'online softmax' or 'streaming softmax'. It maintains a RUNNING max m and a RUNNING denominator l, updating both as new elements arrive. When a new max is discovered, l is rescaled by exp(m_old − m_new) to undo the old normalisation and rebase to the new max. This is the central trick that makes Flash Attention work: it lets attention be computed in tiles without ever materialising the full N×N score matrix.

EXECUTION STATE
⬇ input: x = [2.0, 1.0, 0.5, 3.0]
⬆ returns = Same probability distribution as softmax_naive — but computed in ONE pass.
→ key invariant = After processing the first k elements, l[k] = Σ exp(x_i − m[k]) for i<k. When m updates, we rescale l to maintain the invariant.
Line 13: Docstring: """Single-pass softmax — Flash Attention's core trick."""

Names the trick. The running-max and running-sum update below is the same rescaling that Algorithm 1 of Dao et al. 2022 applies to blocks of attention scores; here it runs on a 1D vector instead.

Line 14: m = -np.inf

Initialise running max to negative infinity so the first element always becomes the max. np.inf is IEEE 754 +∞; -np.inf is −∞.

EXECUTION STATE
np.inf = +∞ — the IEEE 754 infinity sentinel
m = −∞ — running max, replaced by the first finite x_i
Line 15: l = 0.0

Initialise running normaliser to 0. After the first element it becomes exp(0) = 1.

EXECUTION STATE
l = 0.0 — running sum of exp(x_i − m)
Line 16: for xi in x: # one pass over x

The single pass. Each iteration consumes one element xi and updates (m, l). For attention, this loop iterates over key blocks, not individual scalars — but the math is identical.

LOOP TRACE · 4 iterations
iter 1: xi = 2.0
m_new = max(-∞, 2.0) = 2.0 (new max)
l update = 0 · exp(-∞ - 2) + exp(2 - 2) = 0 + 1 = 1.0
after = m = 2.0, l = 1.0
iter 2: xi = 1.0
m_new = max(2.0, 1.0) = 2.0 (no change)
l update = 1.0 · exp(2 - 2) + exp(1 - 2) = 1.0 + 0.3679 = 1.3679
after = m = 2.0, l = 1.3679
iter 3: xi = 0.5
m_new = max(2.0, 0.5) = 2.0 (no change)
l update = 1.3679 · exp(0) + exp(-1.5) = 1.3679 + 0.2231 = 1.5910
after = m = 2.0, l = 1.5910
iter 4: xi = 3.0 ← max changes!
m_new = max(2.0, 3.0) = 3.0 ← NEW MAX, must rescale
rescale factor = exp(m_old - m_new) = exp(2 - 3) = exp(-1) = 0.3679
l update = 1.5910 · 0.3679 + exp(3 - 3) = 0.5853 + 1.0 = 1.5853
after = m = 3.0, l = 1.5853 — IDENTICAL to naive sum 1.5853 ✓
Line 17: m_new = max(m, xi)

Tentative new running max. If xi exceeds the old max, m_new > m and we will need to rescale l. Otherwise m_new = m.

EXECUTION STATE
📚 max(a, b) = Python builtin — returns the larger of two values.
→ branchless = GPU implementations use an fmax intrinsic to avoid divergent branches.
Line 18: l = l * np.exp(m - m_new) + np.exp(xi - m_new)

The heart of the trick. Rebase the running denominator from m to m_new, then add the new term. When m_new = m, exp(m − m_new) = exp(0) = 1.0 → no rescale. When m_new > m (a new max appeared), exp(m − m_new) < 1.0 → l shrinks, undoing the now-stale normalisation.

EXECUTION STATE
l * np.exp(m - m_new) = Old sum, rescaled. Recall l was Σ exp(x_i − m) for old m. Multiplying by exp(m − m_new) converts each term to exp(x_i − m_new), the new normalisation.
+ np.exp(xi - m_new) = Add the new term, normalised against m_new.
→ invariant maintained = After this line, l = Σ exp(x_i − m_new) for all x_i seen so far.
Line 19: m = m_new

Commit the new max into the running state.

EXECUTION STATE
m = Updated to m_new
Line 20: return np.exp(x - m) / l

Now that we have the final m and l, produce probabilities. Note this is a SECOND pass over x — but in Flash Attention, the values V have been incrementally accumulated alongside l, so this final pass becomes a single divide on the already-computed weighted sum, with NO need to revisit the original scores.

EXECUTION STATE
x - m = [2-3, 1-3, 0.5-3, 3-3] = [-1.0, -2.0, -2.5, 0.0]
np.exp(x - m) = [0.3679, 0.1353, 0.0821, 1.0]
/ l = Divide each by 1.5853
⬆ return = [0.2321, 0.0854, 0.0518, 0.6308], matching the naive result up to floating-point rounding
Line 22: p_naive = softmax_naive(x)

Run the 3-pass version and store the result.

EXECUTION STATE
p_naive = [0.2321, 0.0854, 0.0518, 0.6308]
Line 23: p_online = softmax_online(x)

Run the single-pass version. Should be identical.

EXECUTION STATE
p_online = [0.2321, 0.0854, 0.0518, 0.6308]
Line 24: print("naive :", np.round(p_naive, 4))

np.round rounds element-wise to the requested number of decimals — purely cosmetic.

EXECUTION STATE
📚 np.round(arr, decimals) = Element-wise round to nearest, ties-to-even.
→ output = naive : [0.2321 0.0854 0.0518 0.6308]
Line 25: print("online:", np.round(p_online, 4))

Print the online result for visual comparison.

EXECUTION STATE
→ output = online: [0.2321 0.0854 0.0518 0.6308] ← identical to 4 decimals
Line 26: print("max diff:", np.max(np.abs(p_naive - p_online)))

Verify equivalence. The maximum element-wise difference should be at floating-point precision (~1e-16 in FP64, ~1e-7 in FP32).

EXECUTION STATE
→ output = max diff: on the order of 1e-16 (possibly exactly 0), i.e. equal up to FP64 rounding ✓
→ implication = The statistics now come from one pass over x instead of three, with no loss of precision. Applied block by block to rows of attention scores, this is what lets Flash Attention avoid materialising the N×N attention matrix.
```python
import numpy as np

x = np.array([2.0, 1.0, 0.5, 3.0], dtype=np.float64)

def softmax_naive(x):
    """Standard 3-pass softmax — reads x three times."""
    m = np.max(x)              # pass 1: find max for stability
    e = np.exp(x - m)          # pass 2: shifted exponentials
    s = np.sum(e)              # pass 3: normaliser
    return e / s

def softmax_online(x):
    """Single-pass softmax — Flash Attention's core trick."""
    m = -np.inf
    l = 0.0
    for xi in x:                              # one pass over x
        m_new = max(m, xi)
        l = l * np.exp(m - m_new) + np.exp(xi - m_new)
        m = m_new
    return np.exp(x - m) / l                  # final element-wise divide

p_naive  = softmax_naive(x)
p_online = softmax_online(x)
print("naive :", np.round(p_naive,  4))   # [0.2321 0.0854 0.0518 0.6308]
print("online:", np.round(p_online, 4))   # [0.2321 0.0854 0.0518 0.6308]
print("max diff:", np.max(np.abs(p_naive - p_online)))  # ~0
```
Why this is exact, not approximate. At every iteration we maintain the invariant $\ell_t = \sum_{i \le t} e^{x_i - m_t}$. When $m$ changes, multiplying $\ell$ by $e^{m_{\text{old}} - m_{\text{new}}}$ converts each old summand $e^{x_i - m_{\text{old}}}$ into $e^{x_i - m_{\text{new}}}$. The invariant survives, so the final $p_i = e^{x_i - m} / \ell$ equals the textbook result up to floating-point noise.
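The same invariant holds when x is consumed in tiles rather than one element at a time, which is closer to how Flash Attention walks over key blocks. In the sketch below each tile computes its own local max and sum. The running state is then rebased onto the larger of the two maxima before the sums are combined. The block sizes and the function name are illustrative assumptions.

```python
import numpy as np


def softmax_tiled(x, block=2):
    """Online softmax over tiles of x, carrying only (m, l) between tiles."""
    m, l = -np.inf, 0.0
    for start in range(0, len(x), block):
        tile = x[start:start + block]
        m_tile = np.max(tile)                      # tile-local max
        l_tile = np.sum(np.exp(tile - m_tile))     # tile-local sum, relative to m_tile
        m_new = max(m, m_tile)
        # Rebase both partial sums onto the shared max, then combine them.
        l = l * np.exp(m - m_new) + l_tile * np.exp(m_tile - m_new)
        m = m_new
    return np.exp(x - m) / l


x = np.array([2.0, 1.0, 0.5, 3.0])
for block in (1, 2, 3, 4):
    print(block, np.round(softmax_tiled(x, block), 4))  # same result for every tile size
```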

PyTorch is already doing this for you

When you call F.softmax(x, dim=-1) on a CUDA tensor, PyTorch dispatches to a fused CUDA kernel that performs max, subtract, exp, sum, and divide without writing the intermediate tensors back to HBM. F.scaled_dot_product_attention(...) goes further. When the inputs meet its requirements (for example FP16/BF16 tensors on a supported GPU), it can run the entire attention block (Q@Kᵀ, softmax, attention @ V) as one Flash-Attention kernel. torch.compile can also route some hand-written attention patterns to this fused path. The naive 3-pass version we wrote in Python is best read as a specification; production code rarely executes it as written.
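A fused attention kernel applies the running-max trick to the weighted sum over V as well as to the denominator. The NumPy sketch below does this for a single query vector. It walks over key/value blocks, rescales both the running sum l and the running output acc whenever the max changes, and divides once at the end. It illustrates the idea only and is not PyTorch's or Flash Attention's actual implementation. The block size, shapes and function names are assumptions.

```python
import numpy as np


def attention_reference(q, K, V):
    """Textbook attention for one query: materialise all scores, then softmax."""
    s = K @ q / np.sqrt(q.shape[0])
    p = np.exp(s - s.max())
    return (p / p.sum()) @ V


def attention_online(q, K, V, block=4):
    """One pass over key/value blocks with running max m, sum l and output acc."""
    d = q.shape[0]
    m, l = -np.inf, 0.0
    acc = np.zeros(V.shape[1])
    for start in range(0, K.shape[0], block):
        s = K[start:start + block] @ q / np.sqrt(d)    # scores for this block only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)                       # rebase the old state
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ V[start:start + block]  # weighted values, same rebasing
        m = m_new
    return acc / l                                      # single final divide


rng = np.random.default_rng(0)
q = rng.standard_normal(16)
K = rng.standard_normal((10, 16))
V = rng.standard_normal((10, 8))
print(np.max(np.abs(attention_reference(q, K, V) - attention_online(q, K, V))))
```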


Lever 3 — Memory Layout & Contiguity

Two tensors of identical shape can run at radically different speeds depending on how their elements are laid out in memory. A tensor is contiguous when its elements are stored in row-major order with no gaps, so walking along the last dimension touches consecutive addresses. A non-contiguous tensor turns that walk into strided access, which wastes most of each fetched cache line and defeats SIMD vector loads.

In PyTorch, operations like .transpose(), .permute(), and slicing with stride do not actually move data — they return a view with rearranged strides. The next op that needs contiguous memory will then either dispatch to a slow non-contiguous kernel or trigger an invisible .contiguous() copy. The fix is to call .contiguous() explicitly after a transpose if the downstream op is bandwidth-sensitive.
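NumPy exposes the same view-and-stride machinery, so the effect is easy to inspect on a CPU. In this sketch x.T plays the role of .transpose(), and np.ascontiguousarray plays the role of .contiguous(). NumPy reports strides in bytes, whereas PyTorch's x.stride() reports them in elements. x.is_contiguous() corresponds to NumPy's C_CONTIGUOUS flag.

```python
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)
print(x.strides, x.flags["C_CONTIGUOUS"])        # (16, 4) True: rows are packed

xt = x.T                                         # a view: no data is moved
print(xt.strides, xt.flags["C_CONTIGUOUS"])      # (4, 16) False: strides swapped
print(np.shares_memory(x, xt))                   # True

xt_c = np.ascontiguousarray(xt)                  # explicit copy into row-major order
print(xt_c.strides, xt_c.flags["C_CONTIGUOUS"])  # (12, 4) True
print(np.shares_memory(x, xt_c))                 # False
```

Walking the last axis of xt jumps 16 bytes per element instead of 4. That jump is the strided access the paragraph warns about, and the explicit copy pays for it once up front.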

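| Operation | Returns a view (free)? | Triggers a copy? |
|---|---|---|
| x.view(B, T, D) | yes | never; raises an error if the strides are incompatible |
| x.reshape(B, T, D) | when possible | only if a view is impossible |
| x.transpose(-2, -1) | yes, strides swapped | no, but the result is non-contiguous |
| x.permute(0, 2, 1, 3) | yes | no, but the next op may pay |
| x.contiguous() | returns x itself if already contiguous | yes, if x is non-contiguous |
| x[:, :, :32] | yes | no, but the result is non-contiguous when the last dimension is cut |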

The deeper lesson: memory layout is part of the algorithm. Flash Attention chooses the (block, head, dim) memory order deliberately so that loads from HBM into SRAM are 128-byte coalesced bursts. Many multi-head attention implementations also compute Q, K, V with one fused projection and keep them in a single packed QKV tensor, so one large matmul replaces three smaller ones.


Lever 4 — Mixed Precision (FP16/BF16/FP8)

FP32 stores 32 bits per number. FP16/BF16 store 16. FP8 stores 8. Halving the width halves memory traffic and, on recent GPUs, roughly doubles the peak FLOPs the Tensor Cores deliver. The gain is close to free, provided the loss in numerical range or precision does not break training.

| Format | Bits | Range | Mantissa | When to use |
|---|---|---|---|---|
| FP32 | 32 | ±3.4×10³⁸ | 23 bits | Master weights, optimizer state |
| TF32 | 19 (in a 32-bit slot) | ±3.4×10³⁸ | 10 bits | Ampere+ Tensor Core mode for FP32 matmuls; drop-in speedup |
| FP16 | 16 | ±6.5×10⁴ | 10 bits | Activations; gradients need loss scaling |
| BF16 | 16 | ±3.4×10³⁸ | 7 bits | Activations and gradients; no loss scaling needed |
| FP8 (E4M3) | 8 | ±448 | 3 bits | Inference activations, H100+ training |
| INT8 | 8 | −128 to 127 | n/a | Post-training inference quantisation |
| INT4 | 4 | −8 to 7 | n/a | GPTQ / AWQ weight-only quantisation for serving |
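The FP16 and BF16 rows can be checked in NumPy. NumPy has a native float16 but no bfloat16, so the sketch below emulates BF16 by rounding away the low 16 bits of an FP32 word (ties to even; NaN and Inf are not handled). The helper names are hypothetical. The output shows the trade-off from the table: FP16 overflows above about 6.5×10⁴ and flushes very small values to zero, while BF16 keeps both but rounds away detail that FP16 would keep.

```python
import numpy as np


def to_bf16(x):
    """Round float32 values to the nearest bfloat16 (ties to even), returned as float32.

    BF16 is the top 16 bits of an FP32 word, so rounding clears the low 16 bits.
    NaN/Inf inputs are not handled specially in this sketch.
    """
    bits = np.atleast_1d(np.asarray(x, dtype=np.float32)).view(np.uint32)
    lsb = (bits >> 16) & 1                        # tie-breaker for round-half-to-even
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000     # round, then drop the low mantissa bits
    return bits.view(np.float32)


def to_fp16(x):
    """Native IEEE half precision (may emit an overflow RuntimeWarning)."""
    return np.atleast_1d(np.asarray(x, dtype=np.float32)).astype(np.float16)


for v in (70000.0, 1e-8, 1.001):
    print(v, float(to_fp16(v)[0]), float(to_bf16(v)[0]))
# 70000.0 -> FP16 inf,          BF16 70144.0  (range)
# 1e-08   -> FP16 0.0,          BF16 ~1e-08   (tiny gradients survive)
# 1.001   -> FP16 1.0009765625, BF16 1.0      (precision)
```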

BF16 has been the workhorse for transformer training since Google's TPUs popularised it. It keeps FP32's exponent range (so gradients rarely underflow) at the cost of mantissa precision, which weights absorb gracefully. FP8 is the new frontier: H100-class GPUs ship with Tensor Cores whose FP8 throughput is double their FP16 throughput. The nn.Linear weights of GPT-class models are now routinely served as INT4 with per-group scales (GPTQ, AWQ). The weights then take roughly a quarter to a third of their FP16 memory, often with only a small increase in perplexity.
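Here is a minimal sketch of weight-only INT4 quantisation with per-group scales. It uses plain symmetric round-to-nearest, not GPTQ or AWQ, both of which choose the quantised values more carefully. It stores the 4-bit codes unpacked in int8 for readability and assumes a group size of 128 with one 16-bit scale per group. The function names are hypothetical.

```python
import numpy as np


def quantize_int4_groupwise(w, group_size=128):
    """Symmetric round-to-nearest INT4 with one scale per group of weights.

    Codes are stored unpacked in int8 for readability; a real kernel would
    pack two 4-bit codes per byte. Requires w.size % group_size == 0.
    """
    w = np.asarray(w, dtype=np.float32)
    groups = w.reshape(-1, group_size)
    scales = np.abs(groups).max(axis=1, keepdims=True) / 7.0   # map max |w| to code 7
    scales[scales == 0] = 1.0                                  # all-zero group: any scale works
    q = np.clip(np.round(groups / scales), -8, 7).astype(np.int8)
    return q, scales


def dequantize_int4_groupwise(q, scales, shape):
    return (q.astype(np.float32) * scales).reshape(shape)


rng = np.random.default_rng(0)
w = rng.standard_normal((256, 512)).astype(np.float32)
q, scales = quantize_int4_groupwise(w)
w_hat = dequantize_int4_groupwise(q, scales, w.shape)

bits_per_weight = 4 + 16 / 128     # 4-bit code + a 16-bit scale shared by 128 weights
print(q.min(), q.max())
print(f"relative error: {np.linalg.norm(w - w_hat) / np.linalg.norm(w):.3f}")
print(f"memory vs FP16: {bits_per_weight / 16:.3f}")   # 0.258
```

The memory figure follows from the bookkeeping alone: 4.125 bits per weight is about 26% of FP16. Smaller groups track outliers better at the cost of more scale storage.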

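The standard mixed-precision recipe (PyTorch AMP). Keep the model's weights and optimizer state in FP32. Run the forward pass inside torch.autocast("cuda", dtype=torch.bfloat16), which casts eligible ops such as matmuls to BF16 on the fly while keeping precision-sensitive ops in FP32. Because the parameters themselves stay in FP32, their gradients come back in FP32 and the optimizer updates the FP32 weights directly. Combined with F.scaled_dot_product_attention (and optionally torch.compile), the attention block can run as a fused BF16 Flash-Attention kernel.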

Lever 5 — Gradient Checkpointing

Backprop needs the activations from the forward pass. For a deep network these can dominate memory. Depending on batch size and implementation, a 70B-parameter model at sequence length 4096 can store ~80 GB of activations alone. The classical fix is gradient checkpointing: keep only a sparse subset of activations in memory, and recompute the missing ones during backward.

For a network with $L$ layers, naive backprop keeps all $L$ activations → $O(L)$ memory. Checkpointing every $\sqrt{L}$-th layer reduces memory to $O(\sqrt{L})$ while increasing compute by roughly 33%. The extra cost is one additional forward pass distributed over all checkpoint segments, and a backward pass costs about twice a forward pass. For transformers, torch.utils.checkpoint applied to each transformer block is the standard pattern.
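Here is a pure-Python sketch of the $\sqrt{L}$ schedule on a toy chain of scalar layers y = tanh(w·x). The layer, the weights and the segment size are illustrative assumptions, and "memory" is modelled as the number of stored scalar activations. The sketch checks that checkpointing returns exactly the same gradient while storing fewer values. The price is recomputing each segment's forward pass.

```python
import math


def layer(w, x):
    """Toy scalar layer y = tanh(w * x), a stand-in for a network block."""
    return math.tanh(w * x)


def layer_grad(w, x):
    """d/dx tanh(w * x), evaluated at the layer's input x."""
    t = math.tanh(w * x)
    return w * (1.0 - t * t)


def grad_full(ws, x0):
    """Standard backprop: store every activation (O(L) memory)."""
    xs = [x0]
    for w in ws:
        xs.append(layer(w, xs[-1]))
    g = 1.0
    for i in reversed(range(len(ws))):
        g *= layer_grad(ws[i], xs[i])          # chain rule through layer i
    return g, len(xs)                          # gradient, activations stored


def grad_checkpointed(ws, x0, k):
    """Store only every k-th layer input; recompute each segment during backward."""
    L = len(ws)
    checkpoints, x, forward_calls = {}, x0, 0
    for i in range(L):
        if i % k == 0:
            checkpoints[i] = x                 # keep the segment's input
        x = layer(ws[i], x)
        forward_calls += 1
    g, peak = 1.0, 0
    for start in reversed(range(0, L, k)):
        end = min(start + k, L)
        seg = [checkpoints[start]]             # recompute inputs of layers start..end-1
        for i in range(start, end - 1):
            seg.append(layer(ws[i], seg[-1]))
            forward_calls += 1
        peak = max(peak, len(checkpoints) + len(seg))
        for i in reversed(range(start, end)):
            g *= layer_grad(ws[i], seg[i - start])
    return g, peak, forward_calls


L = 16
ws = [0.9 + 0.01 * i for i in range(L)]
g_full, stored_full = grad_full(ws, 0.5)
g_ckpt, peak_ckpt, calls = grad_checkpointed(ws, 0.5, k=int(math.sqrt(L)))
print(g_full, g_ckpt)          # identical gradients
print(stored_full, peak_ckpt)  # 17 stored values vs 8 at peak
print(calls)                   # 28 layer evaluations instead of 16
```

With L = 16 and k = 4, the peak drops from 17 stored activations to 8 (4 checkpoints plus one 4-value segment). The forward work grows from 16 to 28 layer evaluations, which is less than one extra forward pass.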

Flash Attention takes this even further: it never writes the $N \times N$ attention matrix to HBM at all. It keeps only per-row softmax statistics from the forward pass and recomputes the attention tiles from Q, K and V during backward, because that recomputation is cheaper than the HBM traffic of storing and reloading an $N \times N$ matrix.
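The per-block checkpointing pattern can be sketched as follows, assuming a toy pre-norm MLP block as a stand-in for a transformer block (the Block and Net classes are hypothetical). It checkpoints every block rather than every $\sqrt{L}$-th one, which is the common practical choice; outputs and gradients match the un-checkpointed model, and only the memory/compute trade-off changes.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)

class Block(nn.Module):
    """Stand-in for a transformer block: pre-norm MLP with a residual connection."""
    def __init__(self, d):
        super().__init__()
        self.norm = nn.LayerNorm(d)
        self.mlp = nn.Sequential(nn.Linear(d, 4 * d), nn.GELU(), nn.Linear(4 * d, d))

    def forward(self, x):
        return x + self.mlp(self.norm(x))

class Net(nn.Module):
    def __init__(self, d=32, n_layers=6, use_checkpoint=False):
        super().__init__()
        self.blocks = nn.ModuleList(Block(d) for _ in range(n_layers))
        self.use_checkpoint = use_checkpoint

    def forward(self, x):
        for block in self.blocks:
            if self.use_checkpoint and self.training:
                # Keep only the block's input; its internals are recomputed in backward
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x

plain = Net()
ckpt = Net(use_checkpoint=True)
ckpt.load_state_dict(plain.state_dict())   # identical weights

x = torch.randn(8, 16, 32)
out_plain, out_ckpt = plain(x), ckpt(x)
out_plain.sum().backward()
out_ckpt.sum().backward()

g_plain = plain.blocks[0].mlp[0].weight.grad
g_ckpt = ckpt.blocks[0].mlp[0].weight.grad
print(torch.allclose(out_plain, out_ckpt), torch.allclose(g_plain, g_ckpt, atol=1e-6))
```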


Inference: The KV Cache

Training and inference have opposite bottlenecks. Training is compute-bound (one big batched forward+backward over a fixed sequence). The decode phase of chat inference is memory-bound: the model generates one token at a time, doing tiny matmuls but reading every weight and the full prefix of K and V from memory at every step.

Without caching, a naive autoregressive decoder reprojects the entire prefix at each step, so past tokens are repeatedly projected into K and V. In the toy code below, that redundant projection work grows quadratically with total decode length because step $t$ recomputes a prefix of length $t$. With a KV cache we project past K and V exactly once during prefill; in each generation step we then project only the single new token and append its $k, v$ to the cache. The attention lookup is still linear in current context length, but the repeated prefix projection disappears, which is the difference between a pedagogical decoder and a usable one.

Naive recompute decode vs KV-cache decode
kv_cache.py
Line 1: import torch

PyTorch — gives us tensors with autograd and GPU support. We use torch.Tensor (the analogue of np.ndarray), torch.randn for random initialisation, @ for matmul, and .transpose for axis swaps.

EXECUTION STATE
📚 torch = Deep learning framework. torch.Tensor is the core dtype; supports CPU, CUDA, MPS backends.
Line 2: import torch.nn.functional as F

torch.nn.functional contains stateless math operations (softmax, gelu, dropout, etc.). Aliased to F by convention.

EXECUTION STATE
📚 F.softmax = Functional softmax — no learnable parameters. Numerically stable; accepts dim arg.
Line 4: torch.manual_seed(0)

Make randn calls deterministic — same seed gives same prompt and same weights across runs.

EXECUTION STATE
📚 torch.manual_seed(int) = Sets the seed of the global CPU PRNG (and CUDA, if available).
⬇ arg: 0 = Conventional fixed seed for reproducibility.
Line 5: B, H, D = 1, 1, 4  # batch, heads, head dim

Tiny dimensions to keep things readable. B=batch=1 (single sequence). H=1 head. D=4-dim per head. Real models: B=8–256, H=32, D=128.

EXECUTION STATE
B = 1 — batch size
H = 1 — number of attention heads (we keep it 1 here for clarity)
D = 4 — per-head feature dimension
Line 6: W_q = torch.randn(D, D)

Random query projection. In a real model this would be nn.Linear with learned weights and shape (d_model, d_model). We use a 4×4 matrix for the demo.

EXECUTION STATE
📚 torch.randn(*shape) = Samples from N(0, 1). Returns Tensor.
⬇ args: D, D = (4, 4) — square projection matrix
⬆ W_q = Tensor shape (4, 4) — projects emb → query
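Line 7: W_k = torch.randn(D, D)

Key projection. Crucially, K depends ONLY on the token's own input embedding, so once a token is in the prompt its K never changes. That immutability is what makes caching K safe. (In a multi-layer causal model the same holds layer by layer: the hidden state at position t depends only on tokens up to t.)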

EXECUTION STATE
⬆ W_k = Tensor (4, 4) — projects emb → key
→ key insight = K[t] = emb[t] @ W_k is a function of emb[t] alone. Future tokens cannot change past keys (causal).
Line 8: W_v = torch.randn(D, D)

Value projection. Same property as K — V[t] depends only on emb[t]. We can cache V[t] for the entire prefix and never recompute.

EXECUTION STATE
⬆ W_v = Tensor (4, 4) — projects emb → value
Line 10: def attn(Q, K, V):

Scaled dot-product attention. Computes softmax(QKᵀ/√D) V. We will call this from both decode functions.

EXECUTION STATE
⬇ Q = (B, T_q, D) — query rows we want to attend FROM
⬇ K = (B, T_k, D) — key rows we attend TO
⬇ V = (B, T_k, D) — value rows we read
⬆ returns = (B, T_q, D) — context vector per query
Line 11: scores = Q @ K.transpose(-2, -1) / (D ** 0.5)

Compute pairwise scores then scale by √D for variance control.

EXECUTION STATE
📚 .transpose(dim0, dim1) = Swap two tensor dims. -2 = second-to-last, -1 = last. For K shape (B, T_k, D), this becomes (B, D, T_k).
⬇ args: -2, -1 = Swap the last two axes — converts K from (B, T_k, D) to (B, D, T_k) so Q@Kᵀ has matching inner dim D.
@ = Batched matmul. (B, T_q, D) @ (B, D, T_k) = (B, T_q, T_k).
/ (D ** 0.5) = Divide by √D = √4 = 2. Keeps the variance of the scores ≈ 1 (for unit-variance q and k) regardless of D, so the softmax does not saturate.
⬆ scores = Tensor (B, T_q, T_k) — pairwise compatibilities
Line 12: return F.softmax(scores, dim=-1) @ V

Row-wise softmax (each query distributes 1.0 of attention across keys), then multiply by V.

EXECUTION STATE
📚 F.softmax(x, dim) = Numerically stable softmax — internally subtracts the max along dim before exp.
⬇ arg: dim=-1 = Apply softmax along the LAST axis (T_k). Each query row sums to 1.0.
@ V = Weighted sum: (B, T_q, T_k) @ (B, T_k, D) = (B, T_q, D)
⬆ return = Tensor (B, T_q, D) — context vector for each query
Line 15: def decode_naive(prompt_emb, n_new):  (RECOMPUTE everything)

The wasteful baseline: at every new token, project the ENTIRE prefix (prompt + everything generated so far) to obtain Q, K, V — even though the K and V of past tokens never change. Cost grows quadratically with sequence length.

EXECUTION STATE
⬇ prompt_emb = (B, T_prompt, D) — embeddings of the prompt tokens
⬇ n_new = How many new tokens to generate
⬆ returns = (B, n_new, D) — generated context vectors
→ complexity = O(T · D²) projection work per generated token (all T tokens re-projected) plus O(T · D) attention. Total: O(N² · D²) for N total tokens, dominated by repeated projections.
Line 16: out_tokens = []

Collect the per-step outputs to concatenate at the end.

EXECUTION STATE
out_tokens = [] — list of (B, 1, D) tensors
Line 17: seq = prompt_emb.clone()

.clone() makes an independent copy of the prompt, so seq can grow without touching the caller's tensor. seq grows by one token per loop iteration.

EXECUTION STATE
📚 .clone() = Returns an independent tensor with the same data. Defensive here: torch.cat always allocates a new tensor and never modifies its inputs.
seq = Initially shape (1, 3, 4) — a copy of the prompt
Line 18: for _ in range(n_new):

Generate n_new tokens. The underscore _ means 'we don't care about the iteration index'.

EXECUTION STATE
→ iterations = n_new = 5 → loop body runs 5 times
Line 19: Q_all = seq @ W_q

Project the entire current sequence to queries — but we will only USE the last row. Wasteful.

EXECUTION STATE
seq @ W_q = (1, T, 4) @ (4, 4) = (1, T, 4)
Q_all = Shape grows: T=3 → 4 → 5 → 6 → 7 across iterations
→ waste = We compute T queries but only need 1. (T-1)/T fraction of work is thrown away.
Line 20: K_all = seq @ W_k  # ← recomputed every step (waste!)

RECOMPUTE the keys for the entire prefix. Past tokens' keys are mathematically identical to last step's, but we redo the matmul. This is the redundant work the KV cache eliminates.

EXECUTION STATE
K_all = Shape (1, T, 4) — full key matrix
→ FLOPs = T × D² multiply-adds every step → O(T²) over the whole decode. With T = 2048 and D = 128 this is 2048 · 128² ≈ 33.6M multiply-adds per step, per head and per layer, for K alone, nearly all of it repeated work.
Line 21: V_all = seq @ W_v

Same waste for V.

EXECUTION STATE
V_all = Shape (1, T, 4)
Line 22: h = attn(Q_all[:, -1:], K_all, V_all)

Slice Q_all to the LAST row (the only query that matters for the new token), then call attention.

EXECUTION STATE
Q_all[:, -1:] = Slice notation — keep all batches, take only the last token along dim 1. Shape (1, 1, 4). The colon-after-index keeps the dim instead of squeezing it.
→ -1: vs -1 = [:, -1] would squeeze to (1, 4); [:, -1:] keeps it (1, 1, 4) which attn() expects.
h = Shape (1, 1, 4) — context vector for the new token
Line 23: seq = torch.cat([seq, h], dim=1)

Append the newly generated h to the running sequence. Next iteration will (wastefully) reproject ALL of seq.

EXECUTION STATE
📚 torch.cat(tensors, dim) = Concatenate tensors along the given axis. All other axes must match in size.
⬇ arg 1: [seq, h] = List of two tensors — current sequence and new token
⬇ arg 2: dim=1 = Concatenate along the time axis. (1, T, 4) cat (1, 1, 4) → (1, T+1, 4).
Line 24: out_tokens.append(h)

Save this step's output.

Line 25: return torch.cat(out_tokens, dim=1)

Concatenate the 5 single-token outputs into one (1, 5, 4) tensor.

EXECUTION STATE
⬆ return = Tensor (1, 5, 4) — five generated context vectors stacked along time
Line 28: def decode_cached(prompt_emb, n_new):  (STORE K, V; project only the new token)

The KV-cache decode. In the 'prefill' phase we project K and V ONCE for every prompt token except the last one; the last prompt token is the first query, so its q, k and v are projected together in the first loop step. After that, each generation step projects only the single current token's q, k, v and APPENDS its k, v to the cache. The prefix projection work is done once instead of at every step.

EXECUTION STATE
⬇ prompt_emb = (B, T_prompt, D) — embeddings of the prompt tokens
⬇ n_new = Number of tokens to generate
⬆ returns = (B, n_new, D) — same answer as decode_naive
→ complexity = O(D²) projection work plus O(T · D) attention per generated token. Total: O(N · D² + N² · D), removing the O(N² · D²) projection term of the naive version.
Line 29: K_cache = prompt_emb[:, :-1] @ W_k  # one-time prefill (all but the last token)

PREFILL phase: project the prompt's keys (all but the last token's) ONCE. From now on, no past key is ever recomputed.

EXECUTION STATE
K_cache = Shape (1, 2, 4) initially: keys for the first 2 of the 3 prompt tokens. The last prompt token's key is appended in the first loop iteration.
→ memory cost = K_cache grows by D elements (2·D bytes in BF16) per token per layer per head. For Llama-2 7B at 2048 ctx, K and V together take 1 GiB.
Line 30: V_cache = prompt_emb[:, :-1] @ W_v

Same for values. Together K_cache and V_cache are the 'KV cache' you hear about in inference papers.

EXECUTION STATE
V_cache = Shape (1, 2, 4): values for the first 2 prompt tokens
Line 31: last_emb = prompt_emb[:, -1:]  # most recent token, not yet cached

Only the LAST embedding is needed to start generating. Its query attends to everything in the cache, and its own k, v are appended in the first loop iteration before that attention runs, so no token is projected twice.

EXECUTION STATE
last_emb = Shape (1, 1, 4) — embedding of the last prompt token
Line 32: out_tokens = []

Collect outputs.

Line 33: for _ in range(n_new):

Same outer loop, but each iteration is much cheaper.

Line 34: q = last_emb @ W_q  # project only the new query

Project ONE token. Cost: 1 × D × D = 16 multiply-adds vs naive's T × D × D.

EXECUTION STATE
q = Shape (1, 1, 4) — query for the new token
→ saving = 1 row matmul instead of T rows. At T=2048, this is 2048× less work for the projection.
Line 35: k = last_emb @ W_k

New key for the new token. Will be appended to K_cache.

EXECUTION STATE
k = Shape (1, 1, 4) — new key
Line 36: v = last_emb @ W_v

New value for the new token.

EXECUTION STATE
v = Shape (1, 1, 4) — new value
Line 37: K_cache = torch.cat([K_cache, k], dim=1)  # append new K

Append the new key into the cache. dim=1 = the time axis. After this, K_cache has one more row than before.

EXECUTION STATE
torch.cat([K_cache, k], dim=1) = (1, T, 4) cat (1, 1, 4) → (1, T+1, 4) along the time axis
→ in production = Real frameworks pre-allocate K_cache with shape (B, max_T, H, D) and write into a slice — no torch.cat allocations.
Line 38: V_cache = torch.cat([V_cache, v], dim=1)  # append new V

Same for values.

Line 39: h = attn(q, K_cache, V_cache)

Attention with a single query against the FULL cached prefix. The attention itself is still O(T) per step, but the wasted projection of past tokens is gone.

EXECUTION STATE
q (1, 1, 4) = Single query
K_cache, V_cache (1, T+1, 4) = Full prefix keys and values
h = Shape (1, 1, 4) — output for the new token
Line 40: out_tokens.append(h)

Save the new token's output.

Line 41: last_emb = h  # next query = current output

In autoregressive generation, the model's output at step t becomes the input at step t+1 (after passing through the LM head, sampling, and re-embedding — abbreviated here).

EXECUTION STATE
→ true loop in real LMs = h → LM head → logits → sample token id → embed → next last_emb. We skip those for brevity.
Line 42: return torch.cat(out_tokens, dim=1)

Stack the per-token outputs into a (B, n_new, D) tensor.

EXECUTION STATE
⬆ return = (1, 5, 4) — same shape and (modulo numerics) same values as decode_naive
Line 44: prompt = torch.randn(B, 3, D)

A random 3-token prompt. In a real model these would come from the embedding table after tokenisation.

EXECUTION STATE
📚 torch.randn(*shape) = Standard normal samples
⬇ args: B, 3, D = Shape (1, 3, 4)
prompt = Tensor (1, 3, 4)
Line 45: print(decode_naive(prompt, 5).shape)

Generate 5 tokens with the wasteful version. Should print torch.Size([1, 5, 4]).

EXECUTION STATE
→ output = torch.Size([1, 5, 4])
→ projection work (5 steps, T_prompt = 3) = Q, K and V for the whole sequence every step: 3 · 16 · (3 + 4 + 5 + 6 + 7) = 1,200 multiply-adds, most of it recomputing values that have not changed.
Line 46: print(decode_cached(prompt, 5).shape)

Same shape, but the projection cost per step is constant (3 · 16 = 48 multiply-adds for q, k, v) instead of growing with T. At T = 2048 this is the difference between 'instant' and 'unusable'.

EXECUTION STATE
→ output = torch.Size([1, 5, 4])
→ projection work (5 steps) = prefill 2 · 2 · 16 = 64 plus 5 · 48 = 240, total 304 multiply-adds versus 1,200 for naive. Attention against the cache costs the same in both versions, since each step uses a single query.
→ in production = vLLM, TGI and llama.cpp all use a KV cache. Without one, every generated token would re-project the entire prefix in every layer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, D = 1, 1, 4         # batch, heads, head dim
W_q = torch.randn(D, D)
W_k = torch.randn(D, D)
W_v = torch.randn(D, D)

def attn(Q, K, V):
    scores = Q @ K.transpose(-2, -1) / (D ** 0.5)   # [B, T_q, T_k]
    return F.softmax(scores, dim=-1) @ V

# --- NAIVE decode: recompute K, V for the full prefix at each step ---
def decode_naive(prompt_emb, n_new):
    out_tokens = []
    seq = prompt_emb.clone()                          # grows token-by-token
    for _ in range(n_new):
        Q_all = seq @ W_q
        K_all = seq @ W_k         # ← recomputed every step (waste!)
        V_all = seq @ W_v
        h = attn(Q_all[:, -1:], K_all, V_all)         # only last query needed
        seq = torch.cat([seq, h], dim=1)
        out_tokens.append(h)
    return torch.cat(out_tokens, dim=1)

# --- CACHED decode: store K, V; only project the new token ---
def decode_cached(prompt_emb, n_new):
    K_cache = prompt_emb[:, :-1] @ W_k                # one-time prefill (all but last token)
    V_cache = prompt_emb[:, :-1] @ W_v
    last_emb = prompt_emb[:, -1:]                     # most recent token, not yet cached
    out_tokens = []
    for _ in range(n_new):
        q = last_emb @ W_q                            # project only the new query
        k = last_emb @ W_k
        v = last_emb @ W_v
        K_cache = torch.cat([K_cache, k], dim=1)      # append new K
        V_cache = torch.cat([V_cache, v], dim=1)      # append new V
        h = attn(q, K_cache, V_cache)
        out_tokens.append(h)
        last_emb = h                                  # next query = current output
    return torch.cat(out_tokens, dim=1)

prompt = torch.randn(B, 3, D)
print(decode_naive(prompt,  5).shape)   # torch.Size([1, 5, 4])
print(decode_cached(prompt, 5).shape)   # torch.Size([1, 5, 4])
```
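Production caches are usually pre-allocated rather than grown with torch.cat. A hedged sketch of that idea follows, using toy projections like those in kv_cache.py and a prefill of every prompt token except the last (whose k and v are projected in the first decode step). PreallocatedKVCache is a hypothetical helper, not a library class; the last line compares the result with a decoder that recomputes the whole prefix at every step.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, MAX_T = 1, 4, 16
W_q, W_k, W_v = (torch.randn(D, D) for _ in range(3))

class PreallocatedKVCache:
    """Hypothetical helper: fixed-size K/V buffers filled in place (no torch.cat)."""
    def __init__(self, batch, max_len, dim):
        self.k = torch.zeros(batch, max_len, dim)
        self.v = torch.zeros(batch, max_len, dim)
        self.len = 0

    def append(self, k, v):
        n = k.shape[1]
        if self.len + n > self.k.shape[1]:
            raise RuntimeError("KV cache is full")
        self.k[:, self.len:self.len + n] = k      # in-place slice write
        self.v[:, self.len:self.len + n] = v
        self.len += n

    def view(self):
        return self.k[:, :self.len], self.v[:, :self.len]

def attn(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
    return F.softmax(scores, dim=-1) @ v

def decode_preallocated(prompt_emb, n_new):
    cache = PreallocatedKVCache(prompt_emb.shape[0], MAX_T, D)
    cache.append(prompt_emb[:, :-1] @ W_k, prompt_emb[:, :-1] @ W_v)   # prefill
    x = prompt_emb[:, -1:]
    outs = []
    for _ in range(n_new):
        cache.append(x @ W_k, x @ W_v)          # the current token's k, v
        h = attn(x @ W_q, *cache.view())
        outs.append(h)
        x = h                                   # toy feedback loop, as in kv_cache.py
    return torch.cat(outs, dim=1)

def decode_recompute(prompt_emb, n_new):
    seq, outs = prompt_emb, []
    for _ in range(n_new):
        h = attn(seq[:, -1:] @ W_q, seq @ W_k, seq @ W_v)
        outs.append(h)
        seq = torch.cat([seq, h], dim=1)
    return torch.cat(outs, dim=1)

prompt = torch.randn(B, 3, D)
a = decode_preallocated(prompt, 5)
b = decode_recompute(prompt, 5)
print(a.shape, torch.allclose(a, b, atol=1e-5))
```

Pre-allocation trades a fixed up-front reservation of MAX_T positions for zero reallocation per step; paged caches, discussed later in this section, relax that reservation.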

The cache itself, however, becomes the new memory bottleneck. For a Llama-2 7B-class model (32 layers, 32 heads, head dim 128, BF16) the cache costs $2 \cdot 32 \cdot 32 \cdot 128 \cdot 2 = 524{,}288$ bytes per token (K and V, times layers, heads, head dim, and 2 bytes per BF16 value), about 0.5 MB. At sequence length 4096 this is $\approx 2$ GB per request. Serving many concurrent users means juggling many of these caches, and that memory pressure is the headline cost of LLM inference today.
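The same arithmetic works for any configuration. This small helper (a hypothetical function written for illustration) multiplies out K and V, layers, KV heads, head dim, bytes per element, and tokens, and shows how cutting the number of KV heads, as GQA and MQA do, shrinks the cache proportionally.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, bytes_per_elem=2, batch=1):
    """K and V (factor 2) for every layer, KV head, position and sequence in the batch."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem * seq_len * batch

# Llama-2 7B-class config from the text: 32 layers, 32 heads, head dim 128, BF16
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
per_request = kv_cache_bytes(32, 32, 128, seq_len=4096)
print(per_token)               # 524288 bytes
print(per_request / 2**30)     # 2.0 (GiB)

# Same model with fewer K/V heads (query heads unchanged)
for name, kv_heads in [("MHA", 32), ("GQA-8", 8), ("MQA", 1)]:
    gib = kv_cache_bytes(32, kv_heads, 128, seq_len=4096) / 2**30
    print(f"{name:6s} {gib:.4f} GiB per 4096-token request")
```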

Figure (interactive): KV cache memory growth for a Llama-2 7B-class config (d_model = 4096, 32 layers, 32 heads, BF16; 524 KB per token). At a sequence length of 64 tokens the MHA cache is 33.6 MB, the GQA-8 cache 8.4 MB, and the MQA cache 1.0 MB; the naive decoder re-projects K and V for all 64 tokens at each step, 64× the projection work of the cached decoder.

Why GQA/MQA? MQA shares one K, V head across all heads, a 32× smaller cache. GQA-8 is the practical middle ground (Llama-2 70B, Mistral).

Connections to Modern Systems

Every famous performance technique in modern transformer infrastructure is one of the five levers above, applied to the right rung of the memory hierarchy. Let's walk the connections.

Flash Attention

Standard attention computes $S = QK^{\top}/\sqrt{d}$, materialises the full $N \times N$ matrix $S$ in HBM, then reads it back to compute the softmax and $O = \text{softmax}(S)V$. Each entry of $S$ (and of the softmax output) is written to HBM and read back, so HBM traffic grows as $O(N^{2})$. At $N = 4096$ this traffic is the dominant cost: the softmax and masking steps run at roughly 1 FLOP per byte moved, deep in memory-bound territory.

Flash Attention (Dao et al., 2022) tiles Q, K, V into SRAM-sized blocks and uses the online softmax above to compute the output one tile at a time, accumulating the running row maximum, the running normaliser $\ell$, and the running output $O$ incrementally. The full $S$ matrix is never written to HBM. HBM accesses drop from $\Theta(Nd + N^{2})$ to $O(N^{2} d^{2} / M)$, where $M$ is the SRAM size; for typical head dimensions and SRAM sizes that is many times fewer bytes, which gives multi-fold wall-clock speedups on long sequences and is a key reason 100K-token contexts are now feasible. Lever 2 (fusion) plus Lever 3 (tiled layout) at the SRAM rung.
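The online-softmax accumulation can be sketched in a few lines of PyTorch. This is a readability sketch, not the Flash Attention kernel: it is single-head, tiles only K and V (the real kernel also tiles Q and runs as one fused GPU kernel), and the block size of 16 is arbitrary. It never builds the full N × N score matrix, yet it matches the naive softmax(QKᵀ/√d)V.

```python
import torch
import torch.nn.functional as F

def tiled_attention(Q, K, V, block=16):
    """Single-head attention over K/V tiles with an online softmax.
    Peak extra memory is one (N x block) score tile instead of (N x N)."""
    N, d = Q.shape
    scale = d ** -0.5
    m = torch.full((N, 1), float("-inf"))   # running row max
    l = torch.zeros(N, 1)                   # running softmax normaliser
    O = torch.zeros(N, d)                   # running unnormalised output
    for start in range(0, K.shape[0], block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = (Q @ Kb.T) * scale                          # scores for this tile only
        m_new = torch.maximum(m, S.max(dim=1, keepdim=True).values)
        P = torch.exp(S - m_new)                        # tile weights under the new max
        correction = torch.exp(m - m_new)               # rescale earlier accumulations
        l = l * correction + P.sum(dim=1, keepdim=True)
        O = O * correction + P @ Vb
        m = m_new
    return O / l

torch.manual_seed(0)
N, d = 100, 32
Q, K, V = torch.randn(N, d), torch.randn(N, d), torch.randn(N, d)
reference = F.softmax(Q @ K.T / d ** 0.5, dim=-1) @ V
print(torch.allclose(tiled_attention(Q, K, V), reference, atol=1e-5))
```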

Multi-Head, MQA, and GQA

Multi-head attention splits the model dimension across H heads, each with its own Q, K, V projections. Total FLOPs are unchanged versus a single-head attention of the same dimension — the win is representational, not computational. But for inference the choice of how many K/V heads to use is a memory lever.

| Variant | Q heads | K/V heads | KV cache size | Used by |
|---|---|---|---|---|
| MHA | H | H | 1× baseline | Original Transformer, GPT-2/3 |
| MQA | H | 1 | 1/H × baseline | PaLM, Falcon |
| GQA-g | H | g | g/H × baseline | Llama-2 70B (g = 8), Mistral, Llama-3 |

Multi-Query Attention (MQA) shares one K/V head across all query heads, shrinking the KV cache by a factor of H. Grouped-Query Attention (GQA) is the practical compromise: it shrinks the cache by a factor of H/g while losing almost no quality. This is Lever 4 (lower precision of representation) at the algorithmic level: fewer K/V heads is a kind of structured pruning.
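GQA is easy to express as a head-index mapping. In this sketch (toy sizes chosen for illustration), only G key/value heads are computed and cached, and each is shared by H/G consecutive query heads via repeat_interleave. Fused kernels avoid materialising the repeated copy; it is done here only for clarity.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, H, G, Dh = 2, 10, 8, 2, 16   # batch, seq len, query heads, K/V heads, head dim

q = torch.randn(B, H, T, Dh)
k = torch.randn(B, G, T, Dh)       # only G key heads are projected and cached
v = torch.randn(B, G, T, Dh)

def gqa(q, k, v):
    group = q.shape[1] // k.shape[1]          # query heads sharing each K/V head
    k = k.repeat_interleave(group, dim=1)     # (B, H, T, Dh): head h reads K/V head h // group
    v = v.repeat_interleave(group, dim=1)
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ v

out = gqa(q, k, v)
# An MHA cache would hold H heads of K and V, i.e. 2 * q.numel() values here
cache_ratio = (k.numel() + v.numel()) / (2 * q.numel())
print(out.shape, cache_ratio)
```

Setting G = 1 turns the same function into MQA, and G = H into ordinary multi-head attention.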

Positional Encodings

Position information must enter the model somehow, and the choice has direct performance consequences:

  • Sinusoidal (Vaswani 2017): a fixed table (or one computed on the fly), added to embeddings. No learned parameters and negligible FLOPs. But fixed: extending beyond training length degrades fast.
  • Learned absolute: a learnable matrix of shape $(N_{\max}, d_{\text{model}})$. Costs memory proportional to max sequence length and breaks at unseen positions.
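  • RoPE (Rotary) (Su et al., 2021): rotates pairs of dimensions of Q and K by angles proportional to their absolute position; the dot product $QK^{\top}$ then depends only on relative position. It costs a few element-wise multiplies and adds with precomputed cos/sin tables, usually fused into a neighbouring kernel, so the overhead is negligible and no parameters are added. Extrapolation beyond the training length is limited without rescaling tricks such as position interpolation. Used in Llama, Mistral, GPT-NeoX.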
  • ALiBi (Press et al., 2022): adds a static bias to the attention scores before softmax, linear in the query-key distance with a per-head slope. Even cheaper (a single element-wise add), and it extrapolates to sequences considerably longer than those seen in training.
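RoPE's relative-position property can be checked numerically. This sketch rotates consecutive pairs of dimensions using the usual base of 10000; implementations differ in how they pair dimensions (the Hugging Face Llama code, for example, rotates the two halves of the vector instead), which changes the memory layout but not the property being checked.

```python
import torch

def rope(x, pos, base=10000.0):
    """Rotate pairs (x[2i], x[2i+1]) of the last dim by the angle pos * theta_i."""
    d = x.shape[-1]
    theta = base ** (-torch.arange(0, d, 2, dtype=torch.float64) / d)   # (d/2,)
    angle = pos * theta
    cos, sin = torch.cos(angle), torch.sin(angle)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out

torch.manual_seed(0)
q = torch.randn(64, dtype=torch.float64)
k = torch.randn(64, dtype=torch.float64)

# The score depends only on the offset between positions, not on where they are
s1 = rope(q, 10) @ rope(k, 3)
s2 = rope(q, 110) @ rope(k, 103)
print(torch.isclose(s1, s2))
```

Because each pair is only rotated, vector norms are unchanged, which is why RoPE adds no parameters and does not disturb the scale of the scores.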

Notice the trend: each newer scheme moves from a separate component to something fused into an existing kernel. That is Lever 2 — operator fusion — applied to position information.

KV-Cache Optimizations

Once you accept that inference is gated by KV-cache memory, every gigabyte recovered is more concurrent users. The major techniques:

  1. Paged Attention (vLLM): treat the KV cache like virtual memory. Allocate cache in fixed-size pages instead of one contiguous slab per request. Eliminates fragmentation, lets pages be shared across requests with identical prompts, and enables batching requests of very different lengths together. The single biggest inference systems win of 2023.
  2. KV-cache quantisation: store K, V in INT8 or INT4 instead of FP16. Halves or quarters cache memory at modest accuracy cost. Lever 4 applied to the cache, not the weights.
  3. Sliding window attention (Mistral, Longformer): each token only attends to the last $W$ tokens. Cache size becomes $O(W)$ per layer instead of $O(N)$. Combine with global attention on a few special tokens to keep long-range information.
  4. Eviction (StreamingLLM, H2O): heuristically drop cache entries that future tokens are unlikely to attend to. Constant memory regardless of sequence length, with carefully chosen retention rules.
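Paged attention's bookkeeping is mostly a block table. The sketch below (PagedKVAllocator is a hypothetical class, not vLLM's API) keeps a pool of fixed-size physical blocks and maps each request's logical token positions onto them, so a request wastes at most one partially filled block. It omits prefix sharing, copy-on-write, and the attention kernel that reads through the table.

```python
class PagedKVAllocator:
    """Hypothetical sketch of a paged KV cache's block table."""

    def __init__(self, num_blocks, block_size):
        self.block_size = block_size
        self.free = list(range(num_blocks))   # physical block ids not in use
        self.tables = {}                      # request id -> list of physical block ids
        self.lengths = {}                     # request id -> tokens stored

    def append_token(self, req):
        table = self.tables.setdefault(req, [])
        n = self.lengths.get(req, 0)
        if n % self.block_size == 0:          # current block full (or none yet)
            if not self.free:
                raise MemoryError("no free KV blocks; preempt or queue the request")
            table.append(self.free.pop())
        self.lengths[req] = n + 1
        # Physical block and slot where this token's K/V would be written
        return table[n // self.block_size], n % self.block_size

    def release(self, req):
        self.free.extend(self.tables.pop(req, []))
        self.lengths.pop(req, None)

pool = PagedKVAllocator(num_blocks=8, block_size=16)
for _ in range(40):
    pool.append_token("a")                    # 40 tokens -> 3 blocks
for _ in range(5):
    pool.append_token("b")                    # 5 tokens -> 1 block
print(len(pool.tables["a"]), len(pool.tables["b"]), len(pool.free))   # 3 1 4
pool.release("a")
print(len(pool.free))                         # 7
```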

Transformer Scaling Laws

Performance optimisation also dictates what model is worth training in the first place. The Chinchilla scaling laws (Hoffmann et al., 2022) showed that for a fixed compute budget $C$, the optimal allocation between model parameters $N$ and training tokens $D$ satisfies $N \propto C^{0.5},\; D \propto C^{0.5}$: roughly, train a smaller model on more data than GPT-3 did (about 20 tokens per parameter, versus under 2 for GPT-3).
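A back-of-the-envelope version of the Chinchilla allocation, assuming the common approximations C ≈ 6ND training FLOPs and a compute-optimal ratio of about 20 tokens per parameter (rules of thumb, not the paper's fitted law):

```python
import math

def chinchilla_allocation(compute_flops, tokens_per_param=20.0):
    """Rule-of-thumb compute-optimal split, assuming C ≈ 6·N·D and D ≈ 20·N."""
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens

C = 6 * 70e9 * 1.4e12                  # roughly Chinchilla's own training budget
N, D = chinchilla_allocation(C)
print(f"N ≈ {N / 1e9:.0f}B params, D ≈ {D / 1e12:.1f}T tokens")   # N ≈ 70B params, D ≈ 1.4T tokens

# Doubling compute grows both N and D by sqrt(2), i.e. N, D ∝ C^0.5
N2, D2 = chinchilla_allocation(2 * C)
print(round(N2 / N, 3), round(D2 / D, 3))   # 1.414 1.414
```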

Inference cost scales differently. For deployed models the per-token cost is roughly $2N$ FLOPs (one forward pass, with $N$ the parameter count) plus the KV-cache memory traffic. Once a model is going to serve trillions of tokens, every percent shaved off $N$ or off the KV cache pays for itself many times over. That is why recent production families deliberately train smaller models on far more tokens than the Chinchilla-optimal amount (Meta reports over 15T training tokens for Llama-3) and pair them with aggressive inference engineering: GQA, FP8 / INT4 weights, paged KV cache, speculative decoding, and heavy use of Flash Attention.

Mixture-of-Experts (MoE) models (Mixtral, Switch Transformer, DeepSeek-V3) push this further: they have huge parameter counts but activate only a small subset per token, decoupling capacity from per-token compute. Performance optimisation has reshaped what we even mean by "model size".


Distributed Training: DDP, FSDP, ZeRO

Once a single GPU is saturated, the only path forward is more GPUs — but that path branches into four very different strategies, each suited to a different bottleneck. The choice is dictated by which resource ran out first: compute, activation memory, optimizer-state memory, or per-layer parameter memory.

| Strategy | What is split | Best when | Communication cost |
|---|---|---|---|
| DDP (Data Parallel) | Batch across N GPUs; each holds a full model copy | Model fits on one GPU; you want a larger effective batch | All-reduce of gradients each step (~2× model size) |
| FSDP / ZeRO-3 | Parameters, gradients, optimizer state, sharded across N | Model state does NOT fit on one GPU | All-gather params + reduce-scatter grads each layer |
| Tensor Parallel | Each weight matrix sliced across GPUs | A single layer is too big or too slow for one GPU; intra-node only | All-reduces within each transformer block (high) |
| Pipeline Parallel | Layers grouped into stages; GPUs form a pipeline | Very deep models; tolerates slower cross-node links | Activation hand-off between stages; bubble overhead |

Data Parallel (DDP)

The simplest scale-out. Replicate the model on N GPUs, give each a different slice of the batch, and synchronise gradients with all-reduce after the backward pass. Effective batch size becomes $N \cdot B$; per-step time stays roughly the same as single-GPU because the all-reduce overlaps with backward.
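Why averaging gradients across ranks is equivalent to one big batch: with a mean-reduced loss and equally sized shards, the full-batch gradient is the mean of the per-shard gradients. This single-process sketch simulates four ranks (all sizes are arbitrary) and performs the all-reduce by hand.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(8, 1)
x, y = torch.randn(16, 8), torch.randn(16, 1)
loss_fn = nn.MSELoss()                 # mean over the batch

# Reference: full-batch gradient on one device
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# "4 ranks", each holding an equal shard of 4 examples and the same weights
world_size = 4
rank_grads = []
for xs, ys in zip(x.chunk(world_size), y.chunk(world_size)):
    model.zero_grad()
    loss_fn(model(xs), ys).backward()
    rank_grads.append(model.weight.grad.clone())

# DDP's all-reduce sums the per-rank gradients and divides by the world size
averaged = torch.stack(rank_grads).sum(dim=0) / world_size
print(torch.allclose(averaged, full_grad, atol=1e-6))   # True
```

With unequal shard sizes (for example, a short last batch) the equivalence is only approximate, which is one reason DistributedSampler pads or drops samples to keep shards equal.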

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Launch with: torchrun --nproc_per_node=<num_gpus> train.py
dist.init_process_group(backend="nccl")          # one process per GPU
local_rank = int(os.environ["LOCAL_RANK"])       # set by torchrun
torch.cuda.set_device(local_rank)

# MyModel, train_set, B, loss_fn, num_epochs: your own definitions
model = MyModel().to(local_rank)
model = DDP(model, device_ids=[local_rank])      # wraps + hooks all-reduce
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Sampler shards the dataset so each rank sees disjoint examples
sampler = torch.utils.data.distributed.DistributedSampler(train_set)
loader  = torch.utils.data.DataLoader(train_set, batch_size=B, sampler=sampler)

for epoch in range(num_epochs):
    sampler.set_epoch(epoch)                     # reshuffle deterministically
    for x, y in loader:
        x, y = x.to(local_rank, non_blocking=True), y.to(local_rank, non_blocking=True)
        loss = loss_fn(model(x), y)
        opt.zero_grad(); loss.backward()         # all-reduce happens here
        opt.step()

dist.destroy_process_group()
```
The DDP gotcha that wastes hours. If different ranks take different code paths (e.g. a conditional branch on rank-local data), their collective calls stop matching and DDP's all-reduce can hang silently. Always make every rank execute the same forward/backward graph, and keep find_unused_parameters=False (the default) when possible: turning it on adds a graph traversal every step and tolerates parameters that never receive gradients, which can hide bugs.

FSDP and ZeRO Sharding

DDP keeps a full copy of model weights, gradients, and optimizer state on every rank. For a 7B-parameter model in BF16 with Adam, that is roughly $7\text{B} \cdot (2 + 2 + 12)\ \text{bytes} = 112\ \text{GB}$ per rank (2 bytes of BF16 weights, 2 of BF16 gradients, and 12 of FP32 master weights plus Adam's two moments), more than an 80 GB GPU holds before a single activation is stored. ZeRO (Rajbhandari et al., 2020) and PyTorch's FSDP partition this state across ranks: each GPU owns a $1/N$ share of the parameters, gradients, and optimizer state, and gathers full layers on demand.

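| Stage | Sharded | Model-state memory / rank | Comm overhead |
|---|---|---|---|
| ZeRO-1 | Optimizer state | (4 + 12/N)/16 of DDP, approaching ~25% for large N | Low |
| ZeRO-2 | + Gradients | (2 + 14/N)/16 of DDP, approaching ~12.5% for large N | Medium |
| ZeRO-3 / FSDP | + Parameters | ~1/N of DDP | High (per-layer all-gather) |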
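Per-rank memory for each stage follows from the bytes-per-parameter split given above. This hypothetical helper reproduces that accounting (model state only; activations, buffers, and fragmentation are ignored):

```python
def per_rank_gb(n_params, world_size, stage, k_opt=12):
    """Model-state memory per rank in GB: 2 bytes BF16 params + 2 bytes BF16 grads
    + k_opt bytes of FP32 optimizer state (master weights, Adam m and v) per parameter."""
    p, g, o = 2.0 * n_params, 2.0 * n_params, k_opt * n_params
    n = world_size
    if stage == 0:      # plain DDP: everything replicated
        total = p + g + o
    elif stage == 1:    # shard the optimizer state
        total = p + g + o / n
    elif stage == 2:    # + shard the gradients
        total = p + (g + o) / n
    elif stage == 3:    # + shard the parameters (FSDP FULL_SHARD)
        total = (p + g + o) / n
    else:
        raise ValueError("stage must be 0-3")
    return total / 1e9

results = {stage: per_rank_gb(7e9, world_size=8, stage=stage) for stage in range(4)}
for stage, gb in results.items():
    print(f"stage {stage}: {gb:.2f} GB per rank")   # roughly 112, 38.5, 26.25, 14
```

With 8 ranks the 7B model goes from 112 GB per rank under DDP to 14 GB under ZeRO-3; the price is the extra all-gather traffic in the Comm overhead column.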
```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision, ShardingStrategy

# Assumes dist.init_process_group(...) and torch.cuda.set_device(...) already ran,
# as in the DDP example; MyTransformer is your own model class.
mp = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

model = FSDP(
    MyTransformer(),
    sharding_strategy=ShardingStrategy.FULL_SHARD,   # ZeRO-3 equivalent
    mixed_precision=mp,
    device_id=torch.cuda.current_device(),
)
# Create the optimizer AFTER wrapping, then use exactly like DDP:
# model(x), loss.backward(), opt.step().
```
FSDP wraps blocks, not whole models. The right wrapping policy is per-transformer-block (not the whole model in one shard) — otherwise you lose the per-layer all-gather window that lets FSDP overlap communication with compute. Use transformer_auto_wrap_policy for transformers, or a custom auto_wrap_policy for other architectures.

Tensor and Pipeline Parallelism

For models where even a single layer's weights and activations strain one GPU (typically hundreds of billions of parameters), or where the all-gather traffic of FSDP becomes the bottleneck, partition the model itself.

  • Tensor parallel (TP) (Megatron-LM): split each weight matrix column-wise or row-wise across GPUs in the same node. Each transformer block then needs all-reduces (two in the forward pass with Megatron-LM's layout: one after attention, one after the MLP), so high bandwidth is required and TP almost always runs intra-node over NVLink.
  • Pipeline parallel (PP) (GPipe, PipeDream): stack layers into stages, one stage per GPU group. Activations flow forward, gradients flow backward — like a CPU pipeline. Solves deep models on slow links but suffers from bubble overhead when the pipeline drains.
  • 3-D parallelism: combine DP × TP × PP. This is how GPT-3-scale and most modern frontier-scale models are trained.
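Megatron-style tensor parallelism for an MLP can be simulated on one device. In this sketch (sizes are arbitrary; 4 simulated GPUs), the first weight is split by columns and the second by rows, so each shard computes a partial output with no communication, and a single sum, playing the role of the all-reduce, recovers the unsharded result.

```python
import torch

torch.manual_seed(0)
d_model, d_ff, tp = 16, 64, 4           # tp = number of simulated GPUs
x = torch.randn(3, d_model)
W1 = torch.randn(d_model, d_ff)
W2 = torch.randn(d_ff, d_model)

reference = torch.relu(x @ W1) @ W2     # unsharded MLP

# Column-parallel W1: each GPU owns d_ff / tp output columns and applies ReLU locally.
# Row-parallel W2: each GPU owns the matching d_ff / tp input rows.
partials = []
for W1_shard, W2_shard in zip(W1.chunk(tp, dim=1), W2.chunk(tp, dim=0)):
    h_local = torch.relu(x @ W1_shard)  # no communication needed
    partials.append(h_local @ W2_shard) # partial sum of the final output

out = torch.stack(partials).sum(dim=0)  # the all-reduce (sum) across GPUs
print(torch.allclose(out, reference, rtol=1e-5, atol=1e-4))
```

Splitting W1 by columns is what lets the element-wise ReLU run locally; splitting it by rows instead would force a communication step before the nonlinearity.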
The decision tree. Model fits on one GPU? Use DDP. Doesn't fit but you have NVLink? FSDP. A single layer too big or too slow for one GPU? Add tensor parallelism. More than ~256 GPUs? Add pipeline parallelism. The frontier-scale stack is FSDP × TP × PP, but you almost certainly do not need it; over-engineering distributed setups is a leading cause of wasted compute.

Profiling: Measuring Before Optimising

Every optimisation in this section is conditional on which bottleneck you have. Guessing wastes weeks. The right reflex is: profile first, optimise second. PyTorch ships with a profiler that exports Chrome-trace-format timelines and links each kernel back to the Python line that launched it.

```python
from torch.profiler import profile, record_function, ProfilerActivity, schedule

# skip 1 step, warm up 1, record 3, once: keeps the trace small
sched = schedule(wait=1, warmup=1, active=3, repeat=1)

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=sched,
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for step, (x, y) in enumerate(loader):    # model, loader, loss_fn, opt: your own
        if step >= 1 + 1 + 3:                 # wait + warmup + active: trace is complete
            break
        with record_function("forward"):
            logits = model(x.cuda(non_blocking=True))
        with record_function("loss"):
            loss = loss_fn(logits, y.cuda(non_blocking=True))
        with record_function("backward"):
            loss.backward()
        opt.step(); opt.zero_grad()
        prof.step()                           # advances the schedule

# Top kernels by self CUDA time
print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=15))

# Save Chrome trace; open with chrome://tracing or perfetto.dev
prof.export_chrome_trace("trace.json")
```
| Profiler signature | Likely cause | First fix |
|---|---|---|
| Long gaps with GPU idle, big CPU bars | Dataloader is blocking; no async H2D copies | num_workers, pin_memory, non_blocking=True |
| Many tiny kernels back-to-back | Per-op dispatch overhead | torch.compile or CUDA graphs |
| Single huge kernel dominates | Compute-bound, already near the roof | Look for redundant FLOPs (sparsity, low-rank) |
| memcpyHtoD or DtoH bars | Data crossing PCIe each step | Move tensors to GPU once and reuse; non_blocking copies |
| .item() / .cpu() inside hot loop | Forces GPU sync; serialises everything | Defer scalar reads; log async |
Other profilers worth knowing. Nsight Systems (NVIDIA, deeper kernel-level view), py-spy (sampling Python profiler — finds hot CPU code), nvtop (live GPU utilisation), and the W&B System tab (per-step GPU/CPU/network utilisation cheaply). Pick one tool, learn it well — the diagnostic value of any profiler comes from familiarity, not feature count.

torch.compile and CUDA Graphs

Eager-mode PyTorch dispatches each op separately, with a per-op CPU overhead on the order of microseconds to tens of microseconds. For a transformer block with hundreds of small element-wise ops this overhead dominates on small batches. torch.compile (introduced in PyTorch 2.0) captures the Python forward function as a graph, fuses adjacent ops, and emits a small number of fused (typically Triton) kernels.

```python
import torch

model = MyModel().cuda()                 # MyModel, loader, loss_fn: your own
opt = torch.optim.AdamW(model.parameters())

# One line. Wraps model.forward in TorchDynamo + Inductor.
model = torch.compile(model, mode="reduce-overhead")  # or "max-autotune"

# Use as normal. First few calls are slow (compilation); subsequent calls fly.
for x, y in loader:
    loss = loss_fn(model(x.cuda()), y.cuda())
    loss.backward(); opt.step(); opt.zero_grad()
```
| Mode | What it does | When to use |
|---|---|---|
| default | Trace + fuse element-wise ops | Most cases |
| reduce-overhead | + CUDA graphs (replay captured launch sequence) | Small batches, dispatch-bound |
| max-autotune | + Triton autotuning of fused matmuls | Inference / very repetitive shapes |

CUDA graphs are the magic behind reduce-overhead: the kernel launch sequence of a forward pass is captured once, then replayed as a single GPU command instead of one launch per kernel. Capture is a one-time cost; each replay has almost no CPU launch overhead. Limits: the graph is fixed-shape, so dynamic-shape inputs (variable batch or sequence length) trigger expensive recapture, which is why many serving stacks pad to a few fixed shapes.

The torch.compile failure mode. Code with Python-level branching on tensor values (each such if-statement causes a graph break) compiles into many small graphs and barely speeds up. When possible, restructure tensor-dependent control flow as tensor ops such as torch.where, which compute both branches and select element-wise.
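A minimal before/after of that rewrite (both functions are illustrative). The branchless version is the one you would hand to torch.compile; note that torch.where evaluates both branches, which is only a good trade when both are cheap.

```python
import torch

def branchy(x):
    # Python `if` on a tensor value: Dynamo must break the graph here
    if x.sum() > 0:
        return x * 2
    return -x

def branchless(x):
    # Same result as tensor ops: both branches computed, then selected
    return torch.where(x.sum() > 0, x * 2, -x)

for x in (torch.tensor([1.0, 2.0]), torch.tensor([-3.0, 1.0])):
    print(torch.equal(branchy(x), branchless(x)))   # True, True
```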

Quantization: PTQ vs QAT

The precision table earlier listed INT8 and INT4 as serving formats. There are two families of techniques to get there, and the distinction matters.

| Technique | Training cost | Accuracy | When to use |
|---|---|---|---|
| PTQ (Post-Training Quantization) | Zero retraining; small calibration set | Often 0.5–2% loss | Fast deployment of an existing FP model |
| QAT (Quantization-Aware Training) | Fine-tune with simulated quant ops | Often within 0.1% of FP | When PTQ hurts too much |
| GPTQ / AWQ (weight-only) | Hours of calibration; no full retraining | Excellent for LLM weights at INT4 | LLM inference at minimal quality cost |
| SmoothQuant | Calibration only | Activation outliers handled | INT8 LLM activations + weights |

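For LLM serving a common recipe in 2026 is: weights in INT4 (GPTQ or AWQ), activations in BF16 or INT8 (SmoothQuant), KV cache in INT8. Because decode is memory-bound, cutting the bytes read per token this way can raise single-GPU throughput substantially, typically at a small perplexity cost versus the BF16 baseline; the exact gain depends on kernel support and batch size.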
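The core of weight-only quantisation is one scale per group of weights. This pure-Python sketch uses plain round-to-nearest with a symmetric scale per group of 4 (real deployments typically use groups of 64 or 128 and pack two 4-bit values per byte). GPTQ and AWQ differ in how they choose the rounding or rescale weights using calibration data, which is not shown here.

```python
def quantize_int4_groups(weights, group_size=4):
    """Symmetric per-group INT4: each group shares one scale; values land in [-8, 7]."""
    q, scales = [], []
    for start in range(0, len(weights), group_size):
        group = weights[start:start + group_size]
        scale = max(abs(w) for w in group) / 7 or 1.0   # guard against an all-zero group
        scales.append(scale)
        q.extend(max(-8, min(7, round(w / scale))) for w in group)
    return q, scales

def dequantize(q, scales, group_size=4):
    return [qi * scales[i // group_size] for i, qi in enumerate(q)]

w = [0.12, -0.50, 0.33, 0.07, 2.40, -1.10, 0.05, 0.90]
q, s = quantize_int4_groups(w)
w_hat = dequantize(q, s)
print(q)                                            # [2, -7, 5, 1, 7, -3, 0, 3]
print(max(abs(a - b) for a, b in zip(w, w_hat)))    # worst-case rounding error
```

Per-group scales matter because of outliers: the 2.40 in the second group only coarsens that group's grid, leaving the small weights in the first group finely quantised.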


Speculative Decoding

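Inference's deepest secret is that autoregressive decoding is fundamentally sequential: token $t+1$ needs token $t$. Speculative decoding (Leviathan et al., 2023; Chen et al., 2023) breaks the sequential chain by using a small "draft" model to propose $k$ tokens, then having the large target model score them all in one forward pass. Draft tokens are accepted left to right with probability $\min(1, p/q)$, where $p$ and $q$ are the target and draft probabilities; at the first rejection a replacement token is sampled from the normalised residual $\max(0, p - q)$ and the remaining draft tokens are discarded.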

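The win comes from amortising the target model's per-forward-pass cost (reading all of its weights) across up to $k + 1$ tokens. Reported acceptance rates on natural text are often 60–80%, giving 2–3× wall-clock speedups with no quality loss: the accept/reject rule preserves the target distribution exactly. The price is the extra compute spent on the draft model and on verifying tokens that end up rejected.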
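The accept/reject rule, and why it is lossless, fits in a short pure-Python sketch. It verifies one drafted token against toy target and draft distributions (p and q are made-up numbers) and checks empirically that the emitted tokens follow p. The expected_tokens_per_pass helper uses the standard simplification that every draft token is accepted independently with the same probability alpha.

```python
import random

def speculative_accept(p, q, draft_token, rng):
    """Verify one drafted token. p: target distribution, q: draft distribution
    (lists over the vocabulary); draft_token was sampled from q."""
    if rng.random() < min(1.0, p[draft_token] / q[draft_token]):
        return draft_token, True
    # Rejected: resample from the normalised residual max(0, p - q)
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    return rng.choices(range(len(p)), weights=residual)[0], False

def expected_tokens_per_pass(alpha, k):
    """Expected tokens per target forward pass with k drafts, acceptance rate alpha."""
    return (1 - alpha ** (k + 1)) / (1 - alpha)

rng = random.Random(0)
p = [0.6, 0.3, 0.1]      # toy target distribution
q = [0.3, 0.3, 0.4]      # toy (imperfect) draft distribution
counts = [0, 0, 0]
trials = 100_000
for _ in range(trials):
    x = rng.choices(range(3), weights=q)[0]
    token, _ = speculative_accept(p, q, x, rng)
    counts[token] += 1
print([c / trials for c in counts])                 # close to p = [0.6, 0.3, 0.1]
print(round(expected_tokens_per_pass(0.8, 4), 4))   # 3.3616
```

With an 80% acceptance rate and 4 drafts, each expensive target pass yields about 3.4 tokens instead of 1, which is where the wall-clock speedup comes from.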

Picking the draft model. A draft model that is too weak gets too many tokens rejected; one that is too strong spends too much per draft. A common rule of thumb is a draft roughly 10–20× smaller than the target, and it must share the target's tokenizer. For Llama-3 70B, the 8B model from the same family is a natural draft; in general, a small model from the same family, trained on similar data, agrees with the target most often.

Chapter Capstone: An End-to-End Debug

To close out the chapter, walk through one realistic scenario that exercises every section.

The scenario. A 1.3B-parameter transformer trains cleanly for the first 8,000 steps, then the loss spikes to NaN at step 8,212 on every restart. Validation loss right before the spike was healthy. GPU utilisation has held at 70%.

  1. (§1) Reproduce and shrink. Pin the seed; save the optimizer state and one offending minibatch. Confirm the NaN reproduces. Then test: run the same minibatch from the step-8,000 checkpoint with a 10× lower LR. Does it still produce a NaN? If not, the spike is LR-conditional.
  2. (§1) Detect-anomaly + first NaN. Re-run the offending step under torch.autograd.set_detect_anomaly(True). When a backward function returns a NaN, anomaly mode raises an error naming that function and prints the traceback of the forward op that created it: say, a log_softmax in the loss head. The cause is upstream: a hidden state feeding a LayerNorm with zero-variance input.
  3. (§2) Activation histograms. Hook the layer before the offending LayerNorm. Plot pre-LN activations across training. The variance of one channel collapsed to zero around step 7,500 — a single neuron died, taking down LayerNorm with it. This is exactly the dying-ReLU pattern from §1, manifesting two layers downstream.
  4. (§2) Per-layer gradient norms. Confirm that the offending layer's gradients had been shrinking steadily for ~500 steps before the neuron died. The gradient-norm plot tells you when; the histogram tells you why.
  5. (§1 + §3) Apply the fix. Switch ReLU → GELU in that block (cures the dying-neuron mode) and add gradient clipping at $\tau = 1.0$ (defends against future spikes). Also run mixed precision in BF16 rather than FP16 (§3 lever 4). BF16 keeps FP32's 8-bit exponent, so it needs no loss scaling and has far more headroom against overflow.
  6. (§3) Profile after the fix. The NaN is gone but GPU utilisation is still 70%, so run torch.profiler. The trace shows a long gap each step where the dataloader is blocking. Switch to num_workers=8, pin_memory=True, and asynchronous host-to-device copies (non_blocking=True). GPU utilisation is now 92%.
  7. (§3) Compile. torch.compile(model, mode="max-autotune") fuses the attention block; another 1.4× wall-clock win.
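Step 3 comes down to tracking one statistic per channel. The plain-Python sketch below assumes that a forward hook has already copied the pre-LayerNorm activations into a list of rows, one per token. In PyTorch, the same check is a `var(dim=0)` on the hooked tensor. The function name `dead_channels` and the `eps` threshold are illustrative choices, not a library API.

```python
from statistics import pvariance


def dead_channels(activations, eps=1e-8):
    """Return indices of channels whose variance across tokens is ~zero.

    activations: list of rows (one per token), each a list of channel values;
    a hypothetical stand-in for a hooked [tokens, channels] tensor.
    """
    n_channels = len(activations[0])
    flagged = []
    for c in range(n_channels):
        column = [row[c] for row in activations]
        if pvariance(column) < eps:  # constant output: the unit has stopped responding
            flagged.append(c)
    return flagged


if __name__ == "__main__":
    # Channel 1 outputs 0.0 for every token, the signature of a dead ReLU unit.
    acts = [[0.5, 0.0, 1.2], [0.1, 0.0, -0.3], [0.9, 0.0, 0.4]]
    print(dead_channels(acts))
```

Logging this list every few hundred steps is cheap. It turns "the variance collapsed around step 7,500" from a post-mortem finding into an alert.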
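The clipping in step 5 rescales every gradient by the same factor whenever their combined L2 norm exceeds $\tau$. The update direction is preserved and only its length shrinks. The plain-Python sketch below follows the global-norm rule. In PyTorch, the equivalent call is `torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)`, which likewise adds a small epsilon to the norm. The dict-of-lists layout here is only for illustration.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients by one factor so their joint L2 norm is <= max_norm.

    grads: dict of parameter name -> flat list of gradient values (illustrative).
    Returns (clipped_grads, total_norm_before_clipping).
    """
    total_norm = math.sqrt(sum(g * g for values in grads.values() for g in values))
    coef = min(1.0, max_norm / (total_norm + eps))  # never scales gradients up
    clipped = {name: [g * coef for g in values] for name, values in grads.items()}
    return clipped, total_norm


if __name__ == "__main__":
    grads = {"w": [3.0, 4.0], "b": [0.0]}
    clipped, norm = clip_by_global_norm(grads, max_norm=1.0)
    print(norm, clipped)
```

Clipping by the global norm, rather than per tensor, keeps the relative sizes of the layers' updates intact, so a spike in one layer shrinks the whole step instead of distorting its direction.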
The pattern. The NaN (a §1 symptom) was a numerical instability (§3) caused by a dead neuron that §2's activation histogram diagnosed. None of the three sections solved this alone. The chapter's point is that the techniques compose: debugging is most powerful when symptoms (§1), looking inside (§2), and measurement (§3) work together.

Performance Cheat Sheet

| Symptom | Likely cause | Lever | Concrete fix |
|---|---|---|---|
| GPU utilisation < 30% | Memory-bound kernel | Fusion / tiling | torch.compile, Flash Attention, fused LayerNorm |
| Out-of-memory at long context | KV cache or attention matrix | KV cache + Flash Attention | GQA, paged attention, INT8 KV, sliding window |
| Training stalls on small ops | Python / dispatch overhead | Vectorisation | Larger batch, torch.compile, CUDA graphs |
| Slow inference latency | Autoregressive recompute | KV cache | Enable use_cache=True, speculative decoding |
| Numerical instability in BF16 | Mantissa loss in reductions | Mixed precision recipe | Keep accumulators and master weights in FP32 |
| Activation memory dominates | Storing all forward activations | Gradient checkpointing | torch.utils.checkpoint per transformer block |
| transpose followed by matmul is slow | Non-contiguous strides | Memory layout | Add .contiguous() or use a packed QKV layout |

Quick Check

Q1. A kernel does 2 GFLOPs and reads/writes 1 GB of data. Your GPU has $P_{\max} = 300$ TFLOP/s and $B_{\max} = 2$ TB/s. Where is the kernel on the roofline? What should you do first?
Answer: Arithmetic intensity is $I = 2 \times 10^{9} / 10^{9} = 2$ FLOP/B. The ridge point is $I^{*} = P_{\max} / B_{\max} = 150$ FLOP/B. The kernel sits about 75× below the ridge, so it is heavily memory-bound. Attainable throughput is at most $I \cdot B_{\max} = 4$ TFLOP/s, roughly 1.3% of peak. First fix: fuse it with a neighbour to reduce HBM traffic.
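The same arithmetic takes a few lines of Python. The sketch uses the decimal units from the question (1 GB = $10^9$ bytes, 1 TB/s = $10^{12}$ bytes/s), and `roofline` is an illustrative helper name.

```python
def roofline(flops, bytes_moved, peak_flops, peak_bw):
    """Classic roofline: attainable throughput = min(peak compute, I * bandwidth)."""
    intensity = flops / bytes_moved  # FLOP per byte moved
    ridge = peak_flops / peak_bw     # intensity at which the two roofs meet
    attainable = min(peak_flops, intensity * peak_bw)
    bound = "memory-bound" if intensity < ridge else "compute-bound"
    return intensity, ridge, attainable, bound


if __name__ == "__main__":
    I, ridge, attainable, bound = roofline(
        flops=2e9, bytes_moved=1e9, peak_flops=300e12, peak_bw=2e12
    )
    print(f"I={I:.1f} FLOP/B, ridge={ridge:.0f} FLOP/B, {bound}")
    print(f"attainable={attainable / 1e12:.1f} TFLOP/s "
          f"({100 * attainable / 300e12:.1f}% of peak)")
```

Fusion helps because it cuts `bytes_moved` while leaving `flops` unchanged. Halving the traffic doubles $I$ and, below the ridge, doubles attainable throughput.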
Q2. You enable a KV cache and your inference latency improves 100×. A week later, with longer prompts, latency creeps back up. What changed?
Answer: The cache eliminated recomputing the K and V projections for past tokens. Every decode step, however, still reads the whole K and V cache from HBM, so attention still costs $O(N)$ per step. At long $N$ the bandwidth of those reads dominates again. The next levers are GQA (smaller cache), Flash Attention (fewer reads per FLOP), or sliding window (constant cache).
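The linear growth is easy to estimate. The sketch assumes a hypothetical 70B-class shape: 80 layers, head dimension 128, and a BF16 cache. It compares 64 KV heads (full multi-head attention) against 8 KV heads (GQA). Bandwidth is the 2 TB/s from Q1. Treat the results as a bandwidth floor for a single sequence. At the longest contexts, the MHA cache would not even fit in one GPU's memory.

```python
def kv_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_elem):
    """Bytes cached per token: K and V (factor 2) for every layer and KV head."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


def decode_step_kv_read_ms(context_len, per_token_bytes, bandwidth_bytes_per_s):
    """Lower bound (ms) on streaming the whole cache once from HBM per step."""
    return 1e3 * context_len * per_token_bytes / bandwidth_bytes_per_s


if __name__ == "__main__":
    bw = 2e12  # assumed 2 TB/s HBM bandwidth, as in Q1
    mha = kv_bytes_per_token(n_layers=80, n_kv_heads=64, head_dim=128, bytes_per_elem=2)
    gqa = kv_bytes_per_token(n_layers=80, n_kv_heads=8, head_dim=128, bytes_per_elem=2)
    for n in (1_000, 32_000, 128_000):
        print(f"N={n:>7}: MHA {decode_step_kv_read_ms(n, mha, bw):7.2f} ms, "
              f"GQA-8 {decode_step_kv_read_ms(n, gqa, bw):6.2f} ms")
```

Both columns grow linearly in $N$. GQA divides the cost by the head-sharing factor (8× here) but does not change the slope's shape. Only a sliding window caps $N$ itself.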
Q3. Why does online softmax have to rescale the running denominator when a new max arrives — why not just keep the old denominator and add a new term?
Answer: The denominator is $\sum_i e^{x_i - m}$, so it depends on $m$. If $m$ grows, every old term must be re-expressed relative to the new $m$. Multiplying the running sum by $e^{m_{\text{old}} - m_{\text{new}}}$ does exactly that: it converts every $e^{x_i - m_{\text{old}}}$ into $e^{x_i - m_{\text{new}}}$ in one multiplication. Skipping the rescale would leave the old terms measured against the smaller $m_{\text{old}}$. That would overweight them and underweight the new term.
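The running update described in the answer fits in a short plain-Python sketch. One pass keeps the max $m$ and the denominator $d$. Every time $m$ grows, the old $d$ is multiplied by $e^{m_{\text{old}} - m_{\text{new}}}$ before the new term is added. Flash Attention applies the same correction per tile, and also to its running output accumulator. This sketch covers only the denominator and the final probabilities.

```python
import math


def online_softmax_denominator(xs):
    """One pass over xs, keeping a running max m and running sum d = sum(exp(x - m))."""
    m = float("-inf")
    d = 0.0
    for x in xs:
        m_new = max(m, x)
        # Re-base the old sum from m to m_new, then add the new term.
        # On the first step exp(-inf) == 0.0, so the empty sum contributes nothing.
        d = d * math.exp(m - m_new) + math.exp(x - m_new)
        m = m_new
    return m, d


def online_softmax(xs):
    """Softmax probabilities using the single-pass max and denominator."""
    m, d = online_softmax_denominator(xs)
    return [math.exp(x - m) / d for x in xs]


if __name__ == "__main__":
    print(online_softmax([1.0, 3.0, -2.0, 5.0, 0.5]))
    print(online_softmax([1000.0, 1001.0]))  # no overflow despite large logits
```

The result matches the two-pass softmax (max first, then sum), but it never needs to see all of `xs` before it starts accumulating. That is what lets a fused kernel process scores tile by tile.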

Summary

Performance optimisation in deep learning is dominated by memory traffic, not arithmetic. The roofline model tells you which roof you sit under; the GPU memory hierarchy tells you why. Five levers move you toward the roof:

  1. Vectorisation: let BLAS / cuBLAS do scalar work in tight compiled loops, never in Python.
  2. Operator fusion: collapse adjacent ops into one kernel pass; online softmax is the canonical example.
  3. Memory layout: keep tensors contiguous in the access order the next op needs.
  4. Mixed precision: BF16, FP8, or INT4 cut the bytes per value by 2×, 4×, or 8× relative to FP32, raising throughput where range and precision allow.
  5. Recomputation (gradient checkpointing): trade compute for memory when activations dominate.

Modern transformer infrastructure is built on these levers. Flash Attention is fusion + tiling at the SRAM rung. The KV cache is recompute-vs-store applied to autoregressive inference. MQA / GQA shrink the cache at the cost of representational capacity. RoPE and ALiBi inject position information inside attention (RoPE rotates Q and K; ALiBi biases the scores). They therefore fuse into the attention kernel instead of needing a separate embedding pass. Paged attention and KV quantisation attack the cache's memory pressure further. And the scaling laws tell you which model deserves these optimisations in the first place.

When your network feels slow, do not start by guessing. Measure the arithmetic intensity. Place the kernel on the roofline. Find the nearest rung of the memory hierarchy that can hold the working set. Then — and only then — choose which lever to pull.


References

  • Williams, S., Waterman, A., & Patterson, D. (2009). "Roofline: an insightful visual performance model for multicore architectures". Communications of the ACM.
  • Dao, T., Fu, D., Ermon, S., Rudra, A., & Ré, C. (2022). "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness". NeurIPS.
  • Dao, T. (2023). "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning". arXiv:2307.08691.
  • Shazeer, N. (2019). "Fast Transformer Decoding: One Write-Head is All You Need" (MQA). arXiv:1911.02150.
  • Ainslie, J. et al. (2023). "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints". EMNLP.
  • Su, J. et al. (2021). "RoFormer: Enhanced Transformer with Rotary Position Embedding". arXiv:2104.09864.
  • Press, O., Smith, N., & Lewis, M. (2022). "Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation" (ALiBi). ICLR.
  • Hoffmann, J. et al. (2022). "Training Compute-Optimal Large Language Models" (Chinchilla). NeurIPS.
  • Kwon, W. et al. (2023). "Efficient Memory Management for Large Language Model Serving with PagedAttention" (vLLM). SOSP.
  • Rajbhandari, S., Rasley, J., Ruwase, O., & He, Y. (2020). "ZeRO: Memory Optimizations Toward Training Trillion Parameter Models". SC.
  • Shoeybi, M. et al. (2019). "Megatron-LM: Training Multi-Billion Parameter Language Models Using Model Parallelism". arXiv:1909.08053.
  • Huang, Y. et al. (2019). "GPipe: Efficient Training of Giant Neural Networks Using Pipeline Parallelism". NeurIPS.
  • Frantar, E. et al. (2023). "GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers". ICLR.
  • Lin, J. et al. (2024). "AWQ: Activation-aware Weight Quantization for LLM Compression". MLSys.
  • Xiao, G. et al. (2023). "SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models". ICML.
  • Leviathan, Y., Kalman, M., & Matias, Y. (2023). "Fast Inference from Transformers via Speculative Decoding". ICML.
  • Chen, C. et al. (2023). "Accelerating Large Language Model Decoding with Speculative Sampling". arXiv:2302.01318.
  • PyTorch documentation. "FullyShardedDataParallel". pytorch.org/docs/stable/fsdp.html.
  • PyTorch documentation. "torch.profiler". pytorch.org/docs/stable/profiler.html.
  • PyTorch documentation. "torch.compile Tutorial". pytorch.org/tutorials/intermediate/torch_compile_tutorial.html.