Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness", NeurIPS 2022
Learning Objectives
After studying this section you will be able to:
Explain why standard attention is bottlenecked by memory bandwidth, not arithmetic, and quantify the gap between HBM and SRAM on modern GPUs.
Derive the online softmax identities that allow exact softmax computation in a single pass without storing the full N×N score matrix.
Walk through a complete tiled computation of our running example "the cat sat on the mat" and verify that the output is identical to Chapter 1.
Explain how Flash Attention achieves O(N) memory and 3–8× wall-clock speedup while computing mathematically identical results.
Describe how Flash Attention integrates with multi-head attention, KV-cache, causal masking, and modern systems like PyTorch's scaled_dot_product_attention.
The Real Problem
Through Chapter 12, we have seen attention variants that change the mathematics: masking, sparsity, linear kernels. Flash Attention is fundamentally different because it does NOT change what is computed. Its output is mathematically identical to standard scaled dot-product attention from Chapter 1. Numerically, the reordered arithmetic can differ only by floating-point rounding, on the order of machine epsilon (about $10^{-16}$ in float64). What it changes is how the computation is executed on hardware.
The problem it solves is simple to state: standard attention is slow because it moves too much data, not because it does too much math.
The GPU Memory Hierarchy
Modern GPUs have a two-level memory hierarchy. Understanding it is essential to understanding why Flash Attention works.
| Memory Level | Capacity | Bandwidth | Latency | Role |
| --- | --- | --- | --- | --- |
| HBM (High-Bandwidth Memory) | 40–80 GB (A100) | ~2 TB/s | ~200 cycles | Main GPU memory: stores Q, K, V, outputs |
| SRAM (Shared Memory / Registers) | ~20 MB total on A100 (192 KB per SM × 108 SMs) | ~19 TB/s | ~5 cycles | On-chip memory: where computation happens |
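The key ratios: SRAM has roughly 10× the bandwidth of HBM (~19 TB/s vs ~2 TB/s) but is ~4,000× smaller (80 GB vs ~20 MB on A100). Arithmetic is cheap by comparison. An A100 can execute about 312 TFLOP/s of FP16/BF16 tensor-core math but can stream only ~2 TB/s from HBM, so a kernel must perform on the order of 150 floating-point operations per byte it loads to keep the compute units busy. This gap between compute throughput and memory bandwidth is called the memory wall.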
The Memory Wall in Standard Attention
Standard attention (Chapter 1) executes these steps, each requiring a full round-trip to HBM:
1. Read $Q$ and $K$ from HBM → compute $S = QK^\top / \sqrt{d_k}$ → write $S$ ($N \times N$) back to HBM
2. Read $S$ from HBM → compute $P = \text{softmax}(S)$ → write $P$ ($N \times N$) back to HBM
3. Read $P$ and $V$ from HBM → compute $O = PV$ → write $O$ to HBM
Total HBM access: $O(N^2 + Nd)$ reads/writes. For a sequence of $N = 8192$ tokens with $d = 128$, the score matrix alone has $8192^2 = 67{,}108{,}864$ elements, which is 256 MB in float32. This matrix is written, read, softmaxed, written again, and read again. The GPU spends most of its time waiting for data, not computing.
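A few lines of arithmetic are enough to check the traffic claim. This sketch assumes float32 (4 bytes per element) and counts every step in the list above as a full read or write of its operands, with no caching or fusion. The function name is illustrative.

```python
def standard_attention_hbm_elements(n: int, d: int) -> dict:
    """Count elements read/written in HBM by unfused, three-kernel standard attention.

    Step 1: read Q and K (2*n*d), write S (n*n).
    Step 2: read S (n*n), write P (n*n).
    Step 3: read P (n*n) and V (n*d), write O (n*d).
    """
    step1 = 2 * n * d + n * n
    step2 = 2 * n * n
    step3 = n * n + 2 * n * d
    return {"score_matrix": n * n, "total": step1 + step2 + step3}


BYTES_PER_FP32 = 4
counts = standard_attention_hbm_elements(n=8192, d=128)
score_mib = counts["score_matrix"] * BYTES_PER_FP32 / 2**20
total_gib = counts["total"] * BYTES_PER_FP32 / 2**30
print(f"score matrix: {counts['score_matrix']:,} elements = {score_mib:.0f} MiB")
print(f"total HBM traffic: {total_gib:.2f} GiB")
```

The total is $4N^2 + 4Nd$ elements. At this size, the $4N^2$ coming from the two $N \times N$ intermediates accounts for about 98% of it, since $N/(N + d) = 8192/8320$.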
The core insight: Standard attention is memory-bound, not compute-bound. The GPU has enough arithmetic throughput to compute attention much faster than it can read and write the intermediate matrices. Flash Attention eliminates these intermediates entirely.
Exact vs Approximate: A Key Distinction
Chapter 12 explored methods like Linformer and Performer that reduce attention's complexity by changing the mathematics — using low-rank projections or kernel approximations to avoid computing the full N×N score matrix. These methods produce different outputs from standard attention, trading accuracy for speed.
Flash Attention takes a fundamentally different approach. Its output is mathematically identical to standard attention, and numerically it differs only by floating-point rounding (on the order of machine epsilon; below $10^{-16}$ in float64 for the small example in this chapter). The innovation is entirely in how the computation is scheduled on hardware. This distinction is critical because it makes Flash Attention a drop-in replacement with no accuracy trade-off. Any model that uses standard attention can switch to Flash Attention and keep the same outputs and training dynamics (up to rounding), just faster.
| | Approximate Methods (Ch. 12) | Flash Attention (Ch. 13) |
| --- | --- | --- |
| Output | Different from standard attention | Identical to standard attention |
| Approach | Change the math (low-rank, kernels) | Change the execution (tiling, IO-awareness) |
| Time complexity | O(N) or O(N log N) | O(N²d) (same as standard) |
| Memory complexity | O(N) | O(N) (same improvement) |
| Accuracy trade-off | Yes (approximation error) | None (exact computation) |
| Drop-in replacement | No (may affect model quality) | Yes (identical training dynamics) |
The Story Behind Flash Attention
Tri Dao's IO-Aware Insight
In 2022, Tri Dao (Stanford, advised by Christopher Ré) published "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness." The paper's key contribution was not a new attention formula but a new way to think about the problem: treat attention as an IO problem, not a compute problem.
Dao observed that the standard attention algorithm was designed for mathematical clarity, not hardware efficiency. It creates a clean pipeline (scores → softmax → output) but this pipeline requires materializing the N×N score matrix in HBM. His insight: if we can fuse all three steps into a single kernel that processes data in tiles small enough to fit in SRAM, we eliminate all intermediate HBM writes.
The challenge: softmax does not decompose the way a sum or a max does. Each softmax value depends on every element in its row through the denominator $\sum_j e^{s_j}$. You therefore cannot compute the softmax of the first half of a row, then the second half, and simply concatenate the results, at least not with the naive formulation.
The Online Softmax Trick
The breakthrough came from adapting the online softmax algorithm (Milakov & Gimelshein, 2018). The idea: maintain running statistics — the current maximum m, the current normalizer l, and the current unnormalized output O — and update them incrementally as each tile of keys is processed. When the maximum changes (because a new tile has a larger score), all previously accumulated values are rescaled to be consistent with the new maximum.
This rescaling is the mathematical trick that makes tiling work. It preserves exact numerical equivalence while allowing computation to proceed one small tile at a time, entirely within SRAM.
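Here is a minimal sketch of the online normalizer applied to a plain list of scores, before any tiling or value vectors are involved. It compares the usual two-pass computation (max first, then the sum of exponentials) with a single pass that rescales the running sum whenever the max grows. The function names are illustrative, and the input is row "The" from the running example.

```python
import math


def two_pass_softmax_normalizer(x):
    """Safe softmax normalizer: pass 1 finds the max, pass 2 sums the exponentials."""
    m = max(x)
    l = sum(math.exp(v - m) for v in x)
    return m, l


def online_softmax_normalizer(x):
    """Single pass: running max m and running sum l, kept relative to the current m."""
    m, l = -math.inf, 0.0
    for v in x:
        m_new = max(m, v)
        # Move the old sum from the old max to the new max, then add this element.
        # On the first element m = -inf, so exp(m - m_new) = 0 and the empty sum stays 0.
        l = l * math.exp(m - m_new) + math.exp(v - m_new)
        m = m_new
    return m, l


scores = [0.0, 1.0, 0.5, 0.5, 0.75]  # row "The" of the running example
print(two_pass_softmax_normalizer(scores))
print(online_softmax_normalizer(scores))
```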
Mathematical Definition
The Standard Attention Formula
Flash Attention computes exactly the same function as Chapter 1:
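$$\text{Output} = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V$$
where $Q, K \in \mathbb{R}^{N \times d_k}$ and $V \in \mathbb{R}^{N \times d_v}$. The difference is entirely in the execution strategy.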
Online Softmax Equations
For each query row $i$, we process the keys in tiles of size $T$. Let the current K-tile cover columns $j_0$ through $j_1 - 1$. We maintain three running statistics:
m: the running maximum score seen so far (scalar per query)
l: the running softmax denominator (scalar per query)
O: the running unnormalized output (vector per query)
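For each new tile, compute the tile scores $s_t = Q_i K_t^\top / \sqrt{d_k}$, then update:
$$m_{\text{new}} = \max\left(m_{\text{old}},\ \max_j s_{t,j}\right)$$
$$\alpha = e^{m_{\text{old}} - m_{\text{new}}} \quad \text{(rescaling factor)}$$
$$l_{\text{new}} = l_{\text{old}} \cdot \alpha + \sum_j e^{s_{t,j} - m_{\text{new}}}$$
$$O_{\text{new}} = O_{\text{old}} \cdot \alpha + \sum_j e^{s_{t,j} - m_{\text{new}}}\, V_j$$
After all tiles are processed, the final output is $\text{output}_i = O_{\text{final}} / l_{\text{final}}$.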
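The update rules above map almost line for line onto NumPy. This is a sketch for a single query row, and it assumes keys are visited in order in tiles of `tile_size` rows. The function name `flash_row` is hypothetical. The check at the end compares it against a directly computed softmax on random data.

```python
import numpy as np


def flash_row(q, K, V, tile_size):
    """Exact attention output for one query row using the online-softmax updates."""
    d_k = q.shape[0]
    m, l = -np.inf, 0.0                    # running max and running denominator
    O = np.zeros(V.shape[1])               # running unnormalized output
    for j0 in range(0, K.shape[0], tile_size):
        j1 = min(j0 + tile_size, K.shape[0])
        s_t = K[j0:j1] @ q / np.sqrt(d_k)  # tile scores, shape (j1 - j0,)
        m_new = max(m, s_t.max())
        alpha = np.exp(m - m_new)          # rescales old l and O; equals 0 on the first tile
        p_t = np.exp(s_t - m_new)          # tile exponentials, already on the new scale
        l = l * alpha + p_t.sum()
        O = O * alpha + p_t @ V[j0:j1]
        m = m_new
    return O / l


rng = np.random.default_rng(0)
q = rng.normal(size=8)
K = rng.normal(size=(13, 8))
V = rng.normal(size=(13, 5))

s = K @ q / np.sqrt(8)
weights = np.exp(s - s.max())
weights /= weights.sum()
reference = weights @ V
print(np.max(np.abs(flash_row(q, K, V, tile_size=4) - reference)))
```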
Why Online Softmax Works
The key identity is that rescaling preserves the ratio. Suppose after tile 1 we have computed $l_1 = \sum_j e^{s_j - m_1}$. When tile 2 arrives with a new maximum $m_2 > m_1$, every previously computed exponential $e^{s_j - m_1}$ needs to become $e^{s_j - m_2}$. The multiplicative correction is:
$$e^{s_j - m_2} = e^{s_j - m_1} \cdot e^{m_1 - m_2}$$
This is a single scalar multiplication that corrects all previously accumulated values at once. The factor $\alpha = e^{m_1 - m_2}$ is always $\le 1$ (since $m_2 \ge m_1$), so it never causes overflow.
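The rescaling step also explains why the final division gives exact attention. Consider one query row (row index dropped) and write $m_t$ for the running max after tile $t$. A term added in tile $t$ is computed relative to $m_t$, and each later tile multiplies it by its own $\alpha$. These factors telescope, so after the last tile every term is expressed relative to the global row max $m_{\text{final}} = \max_j s_j$:

```latex
\begin{aligned}
e^{s_j - m_t} \cdot e^{m_t - m_{t+1}} \cdots e^{m_{t'-1} - m_{t'}} &= e^{s_j - m_{t'}}
  \quad \text{for any later tile } t', \\
l_{\text{final}} = \sum_{j=1}^{N} e^{s_j - m_{\text{final}}}, \qquad
O_{\text{final}} &= \sum_{j=1}^{N} e^{s_j - m_{\text{final}}}\, V_j, \\
\frac{O_{\text{final}}}{l_{\text{final}}}
  &= \sum_{j=1}^{N} \frac{e^{s_j - m_{\text{final}}}}{\sum_{k=1}^{N} e^{s_k - m_{\text{final}}}}\, V_j
   = \sum_{j=1}^{N} \operatorname{softmax}(s)_j\, V_j .
\end{aligned}
```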
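Why this matters: the numerically safe softmax needs one pass over a row to find its maximum and a second pass to compute the exponentials and their sum, after which normalization is a cheap final step. Online softmax merges the max and the sum into a single pass by correcting earlier values whenever the max changes. Flash Attention goes one step further: it accumulates $O$ in that same pass and divides by $l$ only at the end. This is what allows the entire attention computation to be fused into one kernel.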
Step-by-Step Worked Example
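Let's trace Flash Attention for query row 0 ("The") using our running example. With $N = 5$ and $T = 2$, we process 3 K-tiles: K[0:2], K[2:4], K[4:5]. The running statistics start at $m = -\infty$, $l = 0$, $O = [0, 0, 0, 0]$, and the scaled scores for this row are $[0.000, 1.000, 0.500, 0.500, 0.750]$.
K-Tile 1 (The, cat): scores $[0.000, 1.000]$, so $m = 1.000$ and $\alpha = e^{-\infty - 1} = 0$; $l = 0.3679 + 1.0000 = 1.3679$; $O = 0.3679 \cdot V_{\text{The}} + 1.0000 \cdot V_{\text{cat}} = [0.3679, 1.0000, 0, 0]$.
K-Tile 2 (sat, on): scores $[0.500, 0.500]$, so $m$ stays $1.000$ and $\alpha = 1$; $l = 1.3679 + 2 \times 0.6065 = 2.5809$; $O = [0.3679, 1.0000, 0.6065, 0.6065]$.
K-Tile 3 (mat): score $0.750$, so $m$ stays $1.000$ and $\alpha = 1$; $l = 2.5809 + 0.7788 = 3.3597$; $O = [0.7573, 1.3894, 0.9959, 0.9959]$.
Final output: $O_{\text{final}} / l_{\text{final}} = [0.2254, 0.4135, 0.2964, 0.2964]$, identical to standard attention (Chapter 1).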
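Notice that for "The", the maximum score was 1.000 (from "cat" in K-Tile 1) and it never changed across subsequent tiles. As a result, $\alpha = 1$ for K-Tiles 2 and 3, and no rescaling was needed. Whether rescaling happens depends on where each row's highest-scoring key falls. For "on", the scores in K-Tile 1 peak at 0.500, but K-Tile 2 contains a score of 1.000. In that case $\alpha = e^{0.5 - 1.0} \approx 0.6065$ scales down the accumulated $l$ and $O$ before the new tile is added.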
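The same trace can be printed programmatically. This sketch reuses the running-example Q, K, V (the values listed in the implementation later in this section). For each of "The" and "on" it prints $m$, $\alpha$, $l$, and $O$ after every K-tile; "on" is the row whose max rises in K-Tile 2. `trace_row` is an illustrative helper, not part of the chapter's class.

```python
import numpy as np

# Running-example matrices (rows: The, cat, sat, on, mat).
Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]])
tokens = ["The", "cat", "sat", "on", "mat"]
T, scale = 2, np.sqrt(4)


def trace_row(i):
    """Run the online-softmax updates for query row i, printing the statistics per K-tile."""
    m, l, O = -np.inf, 0.0, np.zeros(4)
    for j0 in range(0, 5, T):
        j1 = min(j0 + T, 5)
        s_t = Q[i] @ K[j0:j1].T / scale
        m_new = max(m, s_t.max())
        alpha = np.exp(m - m_new)
        p_t = np.exp(s_t - m_new)
        l = l * alpha + p_t.sum()
        O = O * alpha + p_t @ V[j0:j1]
        m = m_new
        print(f"  K-tile {tokens[j0:j1]}: m={m:.4f} alpha={alpha:.4f} l={l:.4f} O={np.round(O, 4)}")
    return O / l


for i in (0, 3):  # "The" (max fixed after tile 1) and "on" (max rises in tile 2)
    print(tokens[i])
    print("  output:", np.round(trace_row(i), 4))
```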
Full Attention Weights and Output
Flash Attention produces the same weights and output as standard attention. The weight matrix is never explicitly computed in Flash Attention — it exists only implicitly through the tiled accumulation — but mathematically the result is identical.
Attention Weight Matrix: Standard = Flash (5×5)

| | The | cat | sat | on | mat |
| --- | --- | --- | --- | --- | --- |
| The | 0.1095 | 0.2976 | 0.1805 | 0.1805 | 0.2318 |
| cat | 0.4026 | 0.0898 | 0.2442 | 0.1481 | 0.1153 |
| sat | 0.1519 | 0.2505 | 0.2505 | 0.1519 | 0.1951 |
| on | 0.1903 | 0.1903 | 0.1154 | 0.3137 | 0.1903 |
| mat | 0.1892 | 0.1892 | 0.1892 | 0.1892 | 0.2430 |
Output Matrix: Standard = Flash (5×4)

| | dim-0 | dim-1 | dim-2 | dim-3 |
| --- | --- | --- | --- | --- |
| The | 0.2254 | 0.4135 | 0.2964 | 0.2964 |
| cat | 0.4602 | 0.1475 | 0.3018 | 0.2058 |
| sat | 0.2495 | 0.3481 | 0.3481 | 0.2495 |
| on | 0.2854 | 0.2854 | 0.2106 | 0.4089 |
| mat | 0.3108 | 0.3108 | 0.3108 | 0.3108 |
Maximum difference from standard attention: $5.55 \times 10^{-17}$, which is at the level of float64 rounding error. The outputs are mathematically identical.
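The tables above can be regenerated with a compact NumPy sketch of the complete tiled forward pass. It follows the loop order of the implementation later in this section (outer loop over K/V tiles, inner loop over Q tiles). It is, however, a standalone illustrative function rather than the chapter's `FlashAttention` class, and it keeps the running statistics in ordinary arrays instead of modelling SRAM. The exact size of the printed difference depends on operation order, so only its order of magnitude is meaningful.

```python
import numpy as np


def flash_attention_forward(Q, K, V, tile_size=2):
    """Tiled exact attention: outer loop over K/V tiles, inner loop over Q tiles.

    Only per-row statistics (m, l) and the unnormalized output O persist across
    tiles; no N x N score matrix is ever allocated.
    """
    n_q, d_k = Q.shape
    n_k = K.shape[0]
    scale = np.sqrt(d_k)
    m = np.full(n_q, -np.inf)            # running row max
    l = np.zeros(n_q)                    # running softmax denominator
    O = np.zeros((n_q, V.shape[1]))      # running unnormalized output
    for j0 in range(0, n_k, tile_size):
        j1 = min(j0 + tile_size, n_k)
        K_t, V_t = K[j0:j1], V[j0:j1]
        for i0 in range(0, n_q, tile_size):
            i1 = min(i0 + tile_size, n_q)
            s_t = Q[i0:i1] @ K_t.T / scale                  # small score block
            m_new = np.maximum(m[i0:i1], s_t.max(axis=1))
            alpha = np.exp(m[i0:i1] - m_new)                # 0 on the first K-tile
            p_t = np.exp(s_t - m_new[:, None])
            l[i0:i1] = l[i0:i1] * alpha + p_t.sum(axis=1)
            O[i0:i1] = O[i0:i1] * alpha[:, None] + p_t @ V_t
            m[i0:i1] = m_new
    return O / l[:, None]


def standard_attention(Q, K, V):
    """Reference that materializes the full score matrix."""
    s = Q @ K.T / np.sqrt(Q.shape[1])
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ V


Q = np.array([[1, 0, 1, 0], [0, 2, 0, 1], [1, 1, 1, 0], [0, 0, 1, 1], [1, 0, 0, 1]], dtype=float)
K = np.array([[0, 1, 0, 1], [1, 0, 1, 0], [1, 1, 0, 0], [0, 0, 1, 1], [1, 0, 0.5, 0.5]])
V = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [0.5, 0.5, 0.5, 0.5]])

flash_out = flash_attention_forward(Q, K, V, tile_size=2)
print(np.round(flash_out, 4))
print("max |flash - standard| =", np.max(np.abs(flash_out - standard_attention(Q, K, V))))
```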
Backward Pass and Recomputation
The forward pass discussion above is only half the story. For training, we need gradients. Standard attention stores the full N×N attention weight matrix P during the forward pass so it is available for backpropagation. Flash Attention cannot store this matrix — that is the whole point — so how does the backward pass work?
The answer is recomputation (the backward pass described in Appendix B of the paper). During the forward pass, Flash Attention stores only three things per query row: the output O, the row-wise maximum m, and the row-wise normalizer l. During the backward pass, when gradients with respect to the attention weights are needed, the algorithm recomputes the attention matrix block-by-block from Q, K, V (which are already stored for backpropagation), using the saved m and l to reconstruct the correct softmax values for each tile.
This recomputation strategy is arguably the paper's most important practical contribution. It is what makes Flash Attention viable for training, not just inference. The extra arithmetic cost of recomputation is modest (one additional pass over the data) and is far outweighed by the massive reduction in HBM access. In practice, the reduced memory traffic more than compensates for the extra FLOPs, resulting in faster training overall.
Why this matters: Without the recomputation trick, Flash Attention would save memory during inference but still need the full N×N matrix for training. The backward pass algorithm is what makes Flash Attention a complete training solution, not just an inference optimization.
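Below is a minimal NumPy sketch of the recomputation idea. It makes simplifying assumptions: it tiles only over keys (all query rows are processed together), uses float64, and is not the paper's kernel. The forward pass saves only O, m, and l. The backward pass rebuilds each softmax block from Q, a K-tile, and those statistics. It also uses the row-wise term $D_i = \sum_k dO_{ik} O_{ik}$, which equals $\sum_j P_{ij}\, dP_{ij}$ because $O = PV$, so the full $P$ is never needed. All function names are illustrative, and the result is checked against a standard backward pass that stores $P$.

```python
import numpy as np


def forward_with_stats(Q, K, V, T):
    """Tiled forward pass returning O and the per-row statistics m and l."""
    scale = 1.0 / np.sqrt(Q.shape[1])
    N = Q.shape[0]
    m, l = np.full(N, -np.inf), np.zeros(N)
    O = np.zeros((N, V.shape[1]))
    for j0 in range(0, K.shape[0], T):
        S = Q @ K[j0:j0 + T].T * scale
        m_new = np.maximum(m, S.max(axis=1))
        alpha = np.exp(m - m_new)
        P = np.exp(S - m_new[:, None])
        l = l * alpha + P.sum(axis=1)
        O = O * alpha[:, None] + P @ V[j0:j0 + T]
        m = m_new
    return O / l[:, None], m, l


def backward_recompute(Q, K, V, O, dO, m, l, T):
    """dQ, dK, dV without a stored P: each block is rebuilt from Q, K_t and the saved m, l."""
    scale = 1.0 / np.sqrt(Q.shape[1])
    dQ, dK, dV = np.zeros_like(Q), np.zeros_like(K), np.zeros_like(V)
    D = np.sum(dO * O, axis=1)                                   # equals rowsum(P * dP)
    for j0 in range(0, K.shape[0], T):
        K_t, V_t = K[j0:j0 + T], V[j0:j0 + T]
        P = np.exp(Q @ K_t.T * scale - m[:, None]) / l[:, None]  # exact softmax block
        dV[j0:j0 + T] = P.T @ dO
        dP = dO @ V_t.T
        dS = P * (dP - D[:, None])                               # softmax backward per row
        dQ += dS @ K_t * scale
        dK[j0:j0 + T] = dS.T @ Q * scale
    return dQ, dK, dV


def reference_grads(Q, K, V, dO):
    """Standard backward pass that keeps the full N x N matrix P."""
    scale = 1.0 / np.sqrt(Q.shape[1])
    S = Q @ K.T * scale
    P = np.exp(S - S.max(axis=1, keepdims=True))
    P /= P.sum(axis=1, keepdims=True)
    dP = dO @ V.T
    dS = P * (dP - np.sum(P * dP, axis=1, keepdims=True))
    return dS @ K * scale, dS.T @ Q * scale, P.T @ dO


rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(10, 4)) for _ in range(3))
dO = rng.normal(size=(10, 4))

O, m, l = forward_with_stats(Q, K, V, T=3)
tiled = backward_recompute(Q, K, V, O, dO, m, l, T=3)
full = reference_grads(Q, K, V, dO)
print([float(np.max(np.abs(a - b))) for a, b in zip(tiled, full)])
```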
Causal Masking Integration
For decoder models (GPT, Llama, etc.), causal masking prevents tokens from attending to future positions (Chapter 3). The paper shows how this integrates naturally with tiling:
Skip entire tiles: When all K-positions in a tile are strictly after all Q-positions (the tile is entirely above the diagonal in the attention matrix), the entire tile can be skipped — no computation, no memory access.
Partial masking: For tiles that straddle the diagonal boundary, masking is applied within the tile by setting future-position scores to −∞ before the softmax computation. The online softmax handles these masked values naturally.
This tile-level skipping provides an additional speedup for causal attention beyond the IO benefits. Roughly half the tiles can be skipped entirely (those in the upper-triangular region of the block grid), reducing effective computation by nearly 2× for causal models compared to bidirectional attention.
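Tile skipping and partial masking can be sketched by extending the tiled loop. In this illustrative NumPy version (the function names are hypothetical), Q-tiles and K-tiles share the same boundaries. Tiles lying entirely above the diagonal are skipped, and diagonal tiles set future positions to $-\infty$ before the online-softmax update. Every row's first K-tile contains key 0, so each row's running max becomes finite on its first update and no NaNs arise from $-\infty - (-\infty)$.

```python
import numpy as np


def causal_flash_forward(Q, K, V, T):
    """Causal tiled attention that skips tiles lying entirely above the diagonal."""
    N, d = Q.shape
    scale = np.sqrt(d)
    m, l = np.full(N, -np.inf), np.zeros(N)
    O = np.zeros((N, V.shape[1]))
    computed = skipped = 0
    for i0 in range(0, N, T):
        i1 = min(i0 + T, N)
        for j0 in range(0, N, T):
            j1 = min(j0 + T, N)
            if j0 > i1 - 1:                            # every key is after every query
                skipped += 1
                continue
            computed += 1
            s_t = Q[i0:i1] @ K[j0:j1].T / scale
            rows = np.arange(i0, i1)[:, None]
            cols = np.arange(j0, j1)[None, :]
            s_t = np.where(cols > rows, -np.inf, s_t)  # partial mask on diagonal tiles
            m_new = np.maximum(m[i0:i1], s_t.max(axis=1))
            alpha = np.exp(m[i0:i1] - m_new)
            p_t = np.exp(s_t - m_new[:, None])         # masked entries become exp(-inf) = 0
            l[i0:i1] = l[i0:i1] * alpha + p_t.sum(axis=1)
            O[i0:i1] = O[i0:i1] * alpha[:, None] + p_t @ V[j0:j1]
            m[i0:i1] = m_new
    return O / l[:, None], computed, skipped


def masked_reference(Q, K, V):
    """Standard causal attention with the full N x N mask."""
    N, d = Q.shape
    s = Q @ K.T / np.sqrt(d)
    s = np.where(np.triu(np.ones((N, N), dtype=bool), k=1), -np.inf, s)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ V


rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(16, 8)) for _ in range(3))
out, computed, skipped = causal_flash_forward(Q, K, V, T=4)
print(computed, skipped)  # 10 6
print(np.allclose(out, masked_reference(Q, K, V)))
```

With 4×4 = 16 tiles, 6 lie strictly above the diagonal. As the number of tiles grows, the skipped fraction approaches one half.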
Dropout in Tiled Computation
Standard attention applies dropout to the attention weight matrix $P$, which requires storing the full $N \times N$ dropout mask. Flash Attention avoids this by storing only the pseudorandom number generator seed for each tile block. During both the forward and backward passes, the dropout mask is regenerated on the fly from the seed, which keeps memory usage at $O(1)$ per tile rather than $O(N^2)$ for the full mask.
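The seed-based approach can be illustrated on the CPU. In this sketch, NumPy's seeded generator stands in for the counter-based random number generator a GPU kernel would use, and `tile_dropout_mask` is a hypothetical helper. Given the same seed and tile indices, it rebuilds the identical keep-mask, so only the seed has to be stored between the forward and backward passes.

```python
import numpy as np


def tile_dropout_mask(seed, i_tile, j_tile, shape, p):
    """Hypothetical helper: rebuild one tile's dropout keep-mask from an integer seed.

    NumPy's seeded Generator stands in for a GPU's counter-based RNG. The same
    (seed, i_tile, j_tile) always yields the same mask, so only `seed` is stored.
    """
    rng = np.random.default_rng([seed, i_tile, j_tile])
    return rng.random(shape) >= p        # True = keep, False = drop


seed, p = 1234, 0.1
mask_fwd = tile_dropout_mask(seed, 0, 1, (2, 2), p)   # generated in the forward pass
mask_bwd = tile_dropout_mask(seed, 0, 1, (2, 2), p)   # regenerated in the backward pass
print(np.array_equal(mask_fwd, mask_bwd))              # True
```

Inside a tiled loop, the mask (scaled by $1/(1-p)$) would be applied to the tile's exponentials only in the $O$ update, while $l$ keeps the undropped sum. That way $O/l$ equals dropout applied to the normalized attention weights.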
Block-Sparse Flash Attention
Section 5 of the paper introduces block-sparse Flash Attention, which combines tiling with predetermined sparsity patterns. Given a block-sparsity mask that specifies which tiles to compute and which to skip, the algorithm simply avoids loading and computing masked-out tiles. The IO complexity improves to:
$$O\left(\frac{N^2 d^2}{M} \cdot s\right)$$ where $s$ is the fraction of non-zero blocks
With 50% sparsity, this provides roughly 2× additional speedup over dense Flash Attention while maintaining exact attention for the non-masked positions. This bridges the gap between exact and sparse attention methods — you get the IO efficiency of Flash Attention combined with the computational savings of sparsity.
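A block-sparse variant needs only one extra check in the tiled loop. This illustrative sketch (hypothetical function names) takes a boolean block mask over the tile grid and skips every tile whose entry is False. It then compares the result with dense attention under the equivalent element-wise mask. It assumes every query block keeps at least one key block; otherwise some rows would have no keys to attend to.

```python
import numpy as np


def block_sparse_flash(Q, K, V, block_mask, T):
    """Tiled exact attention restricted to tiles where block_mask[bi, bj] is True."""
    N, d = Q.shape
    m, l = np.full(N, -np.inf), np.zeros(N)
    O = np.zeros((N, V.shape[1]))
    visited = 0
    for bi, i0 in enumerate(range(0, N, T)):
        i1 = min(i0 + T, N)
        for bj, j0 in enumerate(range(0, N, T)):
            if not block_mask[bi, bj]:
                continue                               # never loaded, never computed
            visited += 1
            j1 = min(j0 + T, N)
            s_t = Q[i0:i1] @ K[j0:j1].T / np.sqrt(d)
            m_new = np.maximum(m[i0:i1], s_t.max(axis=1))
            alpha = np.exp(m[i0:i1] - m_new)
            p_t = np.exp(s_t - m_new[:, None])
            l[i0:i1] = l[i0:i1] * alpha + p_t.sum(axis=1)
            O[i0:i1] = O[i0:i1] * alpha[:, None] + p_t @ V[j0:j1]
            m[i0:i1] = m_new
    return O / l[:, None], visited


def block_mask_reference(Q, K, V, block_mask, T):
    """Dense attention with the block mask expanded to an element-wise mask."""
    N, d = Q.shape
    keep = np.kron(block_mask.astype(int), np.ones((T, T), dtype=int))[:N, :N].astype(bool)
    s = np.where(keep, Q @ K.T / np.sqrt(d), -np.inf)
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ V


rng = np.random.default_rng(0)
N, T = 8, 2
Q, K, V = (rng.normal(size=(N, 4)) for _ in range(3))
block_mask = np.eye(N // T, dtype=bool)   # local pattern: diagonal blocks
block_mask[:, 0] = True                   # plus the first key block for every query block
out, visited = block_sparse_flash(Q, K, V, block_mask, T)
print(visited, "of", block_mask.size, "tiles visited")
print(np.allclose(out, block_mask_reference(Q, K, V, block_mask, T)))
```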
Applications Across Domains
Long-Context NLP
Flash Attention helped enable the jump from 2K to 100K+ context windows. Open-weight models such as Llama are commonly trained and served with Flash Attention kernels. Closed models such as GPT-4, Claude, and Gemini are widely reported to use Flash Attention variants, although their implementations are not public. Either way, the technique has become a standard building block for long-context transformers. Without it, the $O(N^2)$ memory of standard attention makes sequences beyond ~8K tokens impractical on 80 GB GPUs.
Vision Transformers
High-resolution images tokenized into patches produce very long sequences (a 1024×1024 image with 16×16 patches = 4096 tokens). Flash Attention makes ViTs practical at these resolutions without the memory overhead of storing $4096 \times 4096$ attention matrices per head per layer.
Scientific Sequence Modeling
Protein sequences (up to ~30K amino acids) and genomic sequences (millions of base pairs) require long-range attention. AlphaFold 2 (2021) used its own memory-efficient chunking strategy before Flash Attention existed, but subsequent protein structure models (OpenFold, ESMFold, and AlphaFold 3) benefit directly from Flash Attention-style tiling for their pair representation attention layers.
Code Generation
Codex, StarCoder, and similar code models process long source files that span thousands of tokens; StarCoder, for example, uses an 8K-token context window. Flash Attention makes attending over the full file context affordable during both training and inference. This supports reasoning across functions that a shorter context window would cut apart.
Connection to Modern Systems
Flash Attention 2
Dao (2023) released Flash Attention 2 with three main changes. First, it tweaks the online softmax to reduce non-matmul FLOPs, for example by deferring the division by $l$ to the end of the loop. Second, it parallelizes over the sequence-length dimension by making the outer loop iterate over Q tiles instead of K tiles. Third, it partitions work better between the warps inside each thread block. Flash Attention 2 reaches 50–73% of the theoretical peak FLOP/s on A100 GPUs, versus 25–40% for Flash Attention 1.
Flash Attention 3
Flash Attention 3 (Shah et al., 2024) targets Hopper architecture GPUs (H100). It exploits asynchronous data movement with the Tensor Memory Accelerator (TMA), warp specialization to overlap computation with data loading, and FP8 low-precision support. In FP16 it reaches up to about 75% of the H100's theoretical peak throughput.
KV-Cache Integration
During autoregressive generation (Chapters 3, 5, 6), each new token only needs to compute attention scores against all previous keys/values. Flash Attention tiles work naturally with the KV-cache: the K and V tiles are simply loaded from the cache rather than recomputed. The online softmax approach handles the growing sequence length without ever materializing the full attention matrix.
PyTorch scaled_dot_product_attention
Since PyTorch 2.0, torch.nn.functional.scaled_dot_product_attention() automatically dispatches to Flash Attention when it is available. It selects a backend (Flash Attention, Memory-Efficient Attention, or a math fallback) based on input shapes, device, dtype, and masking options. In most PyTorch code you therefore do not call a Flash Attention kernel directly, because PyTorch chooses it for you.
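Here is a short usage sketch, assuming PyTorch 2.x is installed. It calls `F.scaled_dot_product_attention` with `is_causal=True` and checks the result against an explicit masked softmax. On a CPU with float32 inputs, PyTorch uses a CPU backend. On a CUDA GPU with float16 or bfloat16 inputs, it can dispatch to the Flash Attention kernel. Recent releases also provide a `torch.nn.attention.sdpa_kernel` context manager for restricting which backends are allowed.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, heads, seq_len, head_dim = 2, 4, 128, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# PyTorch picks the backend; on a CUDA GPU with fp16/bf16 inputs this can be Flash Attention.
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Reference: explicit causal softmax(QK^T / sqrt(d)) V, materializing the full score matrix.
scores = q @ k.transpose(-2, -1) / math.sqrt(head_dim)
causal = torch.tril(torch.ones(seq_len, seq_len)).bool()
scores = scores.masked_fill(~causal, float("-inf"))
ref = torch.softmax(scores, dim=-1) @ v

print((out - ref).abs().max().item())
```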
Complexity Analysis
| Metric | Standard Attention | Flash Attention |
| --- | --- | --- |
| Time Complexity | O(N²d) | O(N²d) (same) |
| HBM Memory | O(N² + Nd) | O(Nd), no N² matrix |
| HBM Reads/Writes | O(N² + Nd) | O(N²d² / M), where M = SRAM size |
| Wall-Clock Speed | Baseline | 3–8× faster |
| Output | Baseline | Mathematically identical |
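The paradox: Flash Attention has the same arithmetic complexity but is much faster. The speedup comes entirely from reducing HBM access, whose complexity improves from $O(N^2 + Nd)$ to $O(N^2 d^2 / M)$. For typical head dimensions ($d$ = 64–128) and SRAM sizes ($M$ on the order of 100K elements), $d^2$ is many times smaller than $M$, so this is a large constant-factor reduction in HBM traffic.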
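The two leading-order expressions can be compared directly to get a feel for the size of this reduction. This sketch drops all constant factors and assumes $M = 100{,}000$ elements, roughly 192 KB of float16 SRAM on one A100 SM. The printed ratios are therefore only indicative and are not predicted speedups.

```python
def hbm_accesses(n, d, m):
    """Leading-order HBM access counts from the IO analysis (constant factors dropped)."""
    standard = n * n + n * d          # Theta(N^2 + N d)
    flash = n * n * d * d / m         # Theta(N^2 d^2 / M)
    return standard, flash


M = 100_000  # assumed SRAM capacity in elements (about 192 KB of fp16 on one A100 SM)
for d in (64, 128):
    std, fa = hbm_accesses(8192, d, M)
    print(f"d={d}: standard ~{std:.3g}, flash ~{fa:.3g}, ratio ~{std / fa:.1f}x")
```

Halving $d$ cuts $N^2 d^2 / M$ by 4×, which is why the reduction is larger for smaller head dimensions.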
Memory savings are even more dramatic. Standard attention stores the $N \times N$ score matrix, which at $N = 8192$ is 256 MB (float32) per head per layer. Flash Attention stores only the running statistics ($O(N)$ per head) plus the current tile ($O(T^2)$ in SRAM). This reduction enables training with much longer sequences on the same hardware.
Flash Attention does NOT reduce FLOPs. Its forward pass performs the same multiply-add operations as standard attention, and its backward pass actually adds some FLOPs by recomputing attention blocks. The speedup comes from doing those operations with far fewer memory round-trips. This is a hardware optimization, not an algorithmic shortcut.
Formal IO Model and Lower Bound
The paper analyzes an IO model of computation: given SRAM of size $M$ elements and HBM of unbounded size, the cost of an algorithm is measured by the number of HBM accesses (reads and writes). Under this model (Theorem 2 in the paper), Flash Attention requires $\Theta(N^2 d^2 / M)$ HBM accesses, compared with $\Theta(N^2 + Nd)$ for standard attention.
More importantly, Proposition 3 shows this is asymptotically optimal: no exact attention algorithm can achieve $o(N^2 d^2 / M)$ HBM accesses for all SRAM sizes $M$ in the range $[d, Nd]$. In other words, no exact attention algorithm can asymptotically beat Flash Attention's IO cost across that whole range of SRAM sizes. It is optimal up to constant factors.
Benchmark Results from the Paper
The paper reports concrete speedups on standard benchmarks:
BERT-large training: 15% end-to-end speedup over the MLPerf 1.1 speed record.
GPT-2 training: Up to 3× speedup compared to HuggingFace and Megatron-LM implementations.
Long Range Arena: Flash Attention enables training on sequence lengths up to 16K on a single GPU. It also yields the first Transformer to achieve better-than-chance performance on the Path-X task (sequence length 16,384), a task that standard Transformers had previously failed to solve because of memory constraints.
These results demonstrate that the theoretical IO improvements translate directly to practical speedups. The gains are largest for longer sequences, where the $O(N^2)$ memory overhead of standard attention dominates runtime.
Python Implementation
A class-based implementation. The forward() method implements Flash Attention with online softmax tiling, and the standard_attention() method provides the baseline for verification. Each line of code is annotated with the exact values flowing through the computation.
Flash Attention — NumPy Implementation
flash_attention.py
import numpy as np
NumPy is Python’s numerical computing library. It provides ndarray (N-dimensional array) — a fast, memory-efficient matrix type. All math in this file (matrix multiply via @, element-wise ops, broadcasting) runs as optimized C code under the hood, not slow Python loops.
EXECUTION STATE
📚 numpy = Library for numerical computing — provides ndarray, linear algebra (@ operator), random numbers, and mathematical functions (np.exp, np.max, np.sum, np.zeros, np.full). We use it for ALL matrix operations in this implementation.
as np = Creates alias ‘np’ so we write np.array() instead of numpy.array() — universal Python convention
import math
Python’s standard library math module. We use math.sqrt() to compute the scaling factor √d_k. Unlike numpy, math.sqrt() returns a plain Python float, which is what we need for a scalar constant.
EXECUTION STATE
📚 math = Standard library module for mathematical functions. We use math.sqrt(4) = 2.0 for the attention scaling factor. Could also use np.sqrt(), but math.sqrt() is clearer for scalar values.
class FlashAttention:
Self-contained class implementing Flash Attention (Dao et al., 2022). Produces IDENTICAL output to standard attention (Chapter 1) but never materializes the full N×N score matrix. The key innovation is online softmax: processing K, V in small tiles while maintaining running statistics (m, l, O).
EXECUTION STATE
FlashAttention = Contains 4 methods: __init__ (setup), standard_attention (baseline for verification), forward (the Flash algorithm), and explain (debug trace). The forward() method is the core; it uses tiled online softmax.
def __init__(self, d_k, tile_size=2):
Constructor. Sets up the three parameters that control Flash Attention’s behavior: the key dimension (for scaling), the tile size (for chunking), and the pre-computed scaling factor.
EXECUTION STATE
⬇ input: self = The class instance being constructed — we store d_k, tile_size, and scale on it for use by forward() and explain()
⬇ input: d_k = 4 = Dimension of each query/key vector. Controls the scaling factor: larger d_k means dot products grow larger, requiring more aggressive scaling. In real transformers: d_k = d_model / n_heads (e.g., 512/8 = 64).
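⬇ input: tile_size = 2 = How many rows of K/V to process at once. With T=2 and N=5 tokens: ceil(5/2) = 3 tiles per loop. In production, T is chosen so that the Q, K, V, and output tiles fit in a single streaming multiprocessor's on-chip SRAM (about 192 KB per SM on A100; the ~20 MB figure is the total across all 108 SMs), which typically gives tiles of 64 or 128 rows.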
self.scale = math.sqrt(d_k)
Pre-compute √d_k = √4 = 2.0. Dividing dot-product scores by this prevents softmax saturation. Without scaling, the variance of the dot products grows in proportion to d_k (their typical magnitude grows like √d_k), which pushes softmax into near-zero gradient regions.
EXECUTION STATE
📚 math.sqrt() = Returns the square root of a number as a Python float. math.sqrt(4) = 2.0, math.sqrt(64) = 8.0. We use this instead of np.sqrt() because we want a scalar float, not an ndarray.
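self.scale = √4 = 2.0 = Scaling factor. Example: without scaling, Q[The]·K[cat] = 2.0 (large); after ÷2 it is 1.0 (moderate). With d_k=64 and random unit-variance components, dot products have a standard deviation of about √64 = 8, and dividing by 8 brings it back to about 1, which keeps softmax in a smooth region.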
def standard_attention(self, Q, K, V):
Standard O(N²) attention that materializes the full 5×5 score matrix in memory. This is the BASELINE — Flash Attention’s forward() must produce identical output. The method exists purely for verification.
EXECUTION STATE
⬇ input: self = Class instance — provides self.scale = 2.0 for dividing scores
⬇ input: Q (5×4) — Query matrix =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
→ Q purpose = Each row is a query vector — encodes ‘what am I looking for?’. Q[The]=[1,0,1,0] looks for features in d0 and d2.
⬇ input: K (5×4) — Key matrix =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
→ K purpose = Each row is a key vector — encodes ‘what do I contain?’. K[cat]=[1,0,1,0] matches Q[The] perfectly (dot=2).
⬇ input: V (5×4) — Value matrix =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
→ V purpose = Each row is the information that token contributes. Q and K decide WHO to attend to; V decides WHAT to retrieve.
⬆ returns = (weights (5×5), output (5×4)) — both returned for comparison with Flash Attention
Docstring: Standard O(N²) attention
Marks this as the naive baseline that stores the full N×N score matrix. Flash Attention’s forward() avoids this O(N²) memory cost while producing the same result.
scores = Q @ K.T / self.scale
Compute the full 5×5 scaled score matrix. THIS is the O(N²) memory allocation that Flash Attention eliminates. Every element scores[i][j] = (Q[i] · K[j]) / √d_k.
EXECUTION STATE
📚 @ (matrix multiply operator) = Python’s matrix multiplication operator (PEP 465). Q(5×4) @ K.T(4×5) = result(5×5). Each element is a dot product: result[i][j] = sum(Q[i][k] * K.T[k][j] for k in range(4)).
.T = NumPy transpose property — swaps rows and columns. K(5×4) becomes K.T(4×5). Needed so inner dimensions match: Q(5×4) @ K.T(4×5) works, Q(5×4) @ K(5×4) would fail.
/ self.scale = Element-wise division by 2.0. Every score is halved. Example: Q[The]·K[cat] = 2 raw → 2/2 = 1.0 scaled.
scores (5×5) =
The cat sat on mat
The 0.000 1.000 0.500 0.500 0.750
cat 1.500 0.000 1.000 0.500 0.250
sat 0.500 1.000 1.000 0.500 0.750
on 0.500 0.500 0.000 1.000 0.500
mat 0.500 0.500 0.500 0.500 0.750
max_s = np.max(scores, axis=-1, keepdims=True)
Find the maximum value in each row. Subtracting this before exp() prevents overflow (the ‘log-sum-exp trick’). Since softmax(x) = softmax(x - c) for any constant c, the result is unchanged.
EXECUTION STATE
📚 np.max() = NumPy function: finds the maximum value along a specified axis. Without axis: returns the single global max. With axis: returns maxima along that dimension. Example: np.max([3, 1, 4, 1, 5]) → 5
⬇ arg 1: scores (5×5) = The input matrix to search for maxima. Each row represents one token’s attention scores to all other tokens.
⬇ arg 2: axis=-1 = Which axis to find the max along. axis=-1 means the LAST axis (columns). For a 5×5 matrix, finds the max WITHIN each row independently. axis=0 would find max in each column. axis=None would find the single global max.
→ axis example = Row 0 [0.000, 1.000, 0.500, 0.500, 0.750]: max = 1.000
Row 1 [1.500, 0.000, 1.000, 0.500, 0.250]: max = 1.500
⬇ arg 3: keepdims=True = Keep the reduced axis as a size-1 dimension. Without keepdims: result shape is (5,) — a 1D vector. With keepdims: result shape is (5,1) — a 2D column vector. This matters for broadcasting: scores(5×5) - max_s(5,1) works correctly.
exp_s = np.exp(scores - max_s)
Exponentiate shifted scores. Subtracting max_s first ensures the largest value per row becomes exp(0) = 1.0, preventing overflow. All other values are ≤ 1.0.
EXECUTION STATE
📚 np.exp() = NumPy’s element-wise exponential function. Computes e^x for every element. e ≈ 2.71828. Examples: np.exp(0) = 1.0, np.exp(1) = 2.718, np.exp(-1) = 0.368. Always returns positive values (e^x > 0 for all x).
scores - max_s = Broadcasting: scores(5×5) - max_s(5,1) subtracts each row’s max from that row. Row 0: [0.0-1.0, 1.0-1.0, 0.5-1.0, 0.5-1.0, 0.75-1.0] = [-1.0, 0.0, -0.5, -0.5, -0.25]
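→ why subtract max? = Without it, exp(1.5) = 4.48 is fine here, but larger scores overflow: in float16, exp(x) exceeds the largest representable value once x is above about 11, and in float32 once x is above about 88. With it, the largest exponential is exp(0) = 1.0. Since softmax(x) = softmax(x-c), the final result is identical.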
exp_s (5×5) =
The cat sat on mat
The 0.3679 1.0000 0.6065 0.6065 0.7788
cat 1.0000 0.2231 0.6065 0.3679 0.2865
sat 0.6065 1.0000 1.0000 0.6065 0.7788
on 0.6065 0.6065 0.3679 1.0000 0.6065
mat 0.7788 0.7788 0.7788 0.7788 1.0000
weights = exp_s / np.sum(exp_s, axis=-1, keepdims=True)
Normalize each row to sum to 1.0 by dividing by the row sum. This completes the softmax: softmax(x_i) = exp(x_i - max) / sum(exp(x_j - max)).
EXECUTION STATE
📚 np.sum() = NumPy function: sums elements along a specified axis. np.sum([1, 2, 3]) = 6. With axis=-1: sums each row independently. With keepdims=True: result shape is (5,1) for broadcasting.
⬇ arg: axis=-1 = Sum along the last axis (columns within each row). Row 0: 0.3679+1.0+0.6065+0.6065+0.7788 = 3.3597
⬇ arg: keepdims=True = Result shape (5,1) not (5,), so exp_s(5×5) / sums(5,1) broadcasts correctly — each element divided by its row’s sum.
weights (5×5) =
The cat sat on mat
The 0.1095 0.2976 0.1805 0.1805 0.2318
cat 0.4026 0.0898 0.2442 0.1481 0.1153
sat 0.1519 0.2505 0.2505 0.1519 0.1951
on 0.1903 0.1903 0.1154 0.3137 0.1903
mat 0.1892 0.1892 0.1892 0.1892 0.2430
output = weights @ V
Multiply weights (5×5) by V (5×4). Each output row is a weighted sum of ALL value vectors. output[The] = 0.1095×V[The] + 0.2976×V[cat] + 0.1805×V[sat] + 0.1805×V[on] + 0.2318×V[mat].
EXECUTION STATE
📚 @ (matrix multiply) = weights(5×5) @ V(5×4) = output(5×4). Each element output[i][j] = sum(weights[i][k] * V[k][j] for k in range(5)). This mixes the value vectors according to the attention weights.
output (5×4) =
d0 d1 d2 d3
The 0.2254 0.4135 0.2964 0.2964
cat 0.4602 0.1475 0.3018 0.2058
sat 0.2495 0.3481 0.3481 0.2495
on 0.2854 0.2854 0.2106 0.4089
mat 0.3108 0.3108 0.3108 0.3108
return weights, output
Return both the attention weight matrix and the output. We return weights so we can print them and compare with Flash Attention’s implicit weights.
EXECUTION STATE
⬆ return: weights (5×5) = Attention probability matrix — each row sums to 1.0
⬆ return: output (5×4) = Weighted sum of value vectors — the final attention output
def forward(self, Q, K, V):
THE CORE: Flash Attention forward pass. Produces identical output to standard_attention() but NEVER stores the 5×5 score matrix. Instead, processes K and V in tiles of size T=2, maintaining three running statistics per query row: m (max score seen), l (softmax denominator), O (unnormalized output). After all tiles, O/l gives the exact result.
EXECUTION STATE
⬇ input: self = Provides self.scale=2.0 and self.tile_size=2
⬇ input: Q (5×4) — Query matrix =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
→ Q purpose = Same Q as standard_attention(). Processed in tiles of T=2 rows in the inner loop.
⬇ input: K (5×4) — Key matrix =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
→ K purpose = Processed in tiles of T=2 rows in the outer loop. Each K-tile is loaded once from HBM to SRAM.
⬇ input: V (5×4) — Value matrix =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
→ V purpose = Tiled alongside K. Each V-tile is multiplied by the tile’s exponential scores to accumulate the output.
⬆ returns = np.ndarray (5×4) — IDENTICAL to standard attention output (diff < 10⁻¹⁶)
N, D = Q.shape
Unpack Q’s shape: N=5 tokens, D=4 dimensions.
EXECUTION STATE
N = 5
D = 4
D_v = V.shape[1]
Value dimension. Here D_v=4. In general can differ from D.
EXECUTION STATE
D_v = 4
T = self.tile_size
Tile size T=2. With N=5 we get ceil(5/2)=3 tiles: [0:2], [2:4], [4:5].
EXECUTION STATE
T = 2
m_run = np.full(N, -np.inf)
Running maximum per query row. Initialized to -∞ so the first tile’s max always wins (max(-∞, anything) = anything). This tracks the largest score seen so far for each of the 5 query rows.
EXECUTION STATE
📚 np.full(shape, fill_value) = NumPy function: creates an array of the given shape, filled with fill_value. np.full(3, 7) = [7, 7, 7]. np.full((2,3), 0.5) = [[0.5,0.5,0.5],[0.5,0.5,0.5]].
⬇ arg 1: N = 5 = Creates a 1D array with 5 elements (one per query token)
⬇ arg 2: -np.inf = Negative infinity. Any real number is greater than -∞, so the first tile’s scores will always become the initial max.
⬆ result: m_run = [-inf, -inf, -inf, -inf, -inf]
→ why -inf? = On Tile 1: m_new = max(-inf, tile_max) = tile_max. The rescale factor exp(-inf - tile_max) = 0, so old accumulators are zeroed out. This elegantly handles the ‘first tile’ special case.
l_run = np.zeros(N)
Running softmax denominator per query row. Starts at 0 because no exponentials have been accumulated yet. After all tiles: l_run[i] = sum of all exp(score - m_final) for query i.
EXECUTION STATE
📚 np.zeros(shape) = NumPy function: creates an array filled with 0.0. np.zeros(3) = [0.0, 0.0, 0.0]. np.zeros((2,3)) = [[0,0,0],[0,0,0]].
l_run = [0.0, 0.0, 0.0, 0.0, 0.0]
→ final values = After all 9 tiles: [3.3597, 2.4840, 3.9919, 3.1875, 4.1152]
O_run = np.zeros((N, D_v))
Running output accumulator (5×4). Each row accumulates the unnormalized weighted sum of value vectors. After all tiles: output = O_run / l_run.
EXECUTION STATE
📚 np.zeros(shape) = Creates a 5×4 matrix of zeros. Shape must be a tuple for 2D: np.zeros((5, 4)), not np.zeros(5, 4).
O_run (5×4) =
d0 d1 d2 d3
The 0.0 0.0 0.0 0.0
cat 0.0 0.0 0.0 0.0
sat 0.0 0.0 0.0 0.0
on 0.0 0.0 0.0 0.0
mat 0.0 0.0 0.0 0.0
for j0 in range(0, N, T):  # outer loop over K tiles
Outer loop iterates over tiles of K and V. With N=5, T=2: j0 = 0, 2, 4. This loads K[j0:j1] and V[j0:j1] from HBM into SRAM once per tile.
LOOP TRACE · 3 iterations
j0=0 → K-tile [The, cat]
K[0:2] = rows 0,1 of K loaded to SRAM
j0=2 → K-tile [sat, on]
K[2:4] = rows 2,3 of K loaded to SRAM
j0=4 → K-tile [mat]
K[4:5] = row 4 of K loaded to SRAM
j1 = min(j0 + T, N)
End index of K-tile. Handles the last tile which may be smaller: min(4+2, 5) = 5.
EXECUTION STATE
j1 = j0=0→2, j0=2→4, j0=4→5
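The same min-clipping yields every tile boundary at once. A small sketch; tile_ranges is a hypothetical helper name, not part of the chapter's class:

```python
def tile_ranges(n: int, t: int) -> list[tuple[int, int]]:
    """(start, end) pairs covering indices 0..n-1 in tiles of size t; the last tile may be shorter."""
    return [(j0, min(j0 + t, n)) for j0 in range(0, n, t)]


print(tile_ranges(5, 2))  # [(0, 2), (2, 4), (4, 5)]
```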
K_t = K[j0:j1]
Load K-tile from HBM into SRAM. This is a 2×4 (or 1×4) matrix.
s_t = Q_t @ K_t.T / self.scale
Compute the TINY tile score matrix (2×2 or smaller). This is the key insight: instead of allocating the full 5×5 score matrix, we only compute one small block at a time. Memory: O(T²) instead of O(N²).
EXECUTION STATE
📚 @ (matrix multiply) = Q_t(2×4) @ K_t.T(4×2) = s_t(2×2). Only 4 dot products, not 25.
.T = Transpose K_t(2×4) → K_t.T(4×2) so inner dimensions match
/ self.scale = Divide by √d_k = 2.0, same scaling as standard attention
m_new = np.maximum(m_run[i0:i1], s_t.max(axis=1))
Update running max: take the element-wise maximum of the old running max and this tile's row-wise max. If a new tile has higher scores, the max increases and all previously accumulated values must be rescaled.
EXECUTION STATE
📚 np.maximum(a, b) = Element-wise maximum of two arrays. NOT np.max()! np.maximum([1,3], [2,1]) = [2,3]. Compares corresponding elements and keeps the larger one.
s_t.max(axis=1) = Row-wise max of the tile scores. axis=1 means max along columns within each row. Tile 1: [max(0.0, 1.0), max(1.5, 0.0)] = [1.0, 1.5]
⬇ arg: axis=1 = Find max along axis 1 (columns). For a 2×2 tile: returns the max of each row as a 1D array of length 2.
e_t = np.exp(s_t - m_new[:, None])
Exponentiate scores shifted by the NEW running max. Using the updated m_new (not the old m_run) puts these exponentials on the correct scale from the start.
EXECUTION STATE
📚 np.exp() = Element-wise e^x. All results are positive. The max score becomes exp(0) = 1.0.
[:, None] = NumPy indexing trick: reshapes m_new from (2,) to (2,1). This enables broadcasting: s_t(2×2) - m_new(2,1) subtracts each row’s max from that row. Equivalent to m_new.reshape(-1, 1).
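Why the [:, None] matters: with a square 2×2 tile, forgetting it raises no error, because NumPy broadcasts the (2,) vector across columns instead of rows. A minimal sketch using the Tile 1 scores for "The" and "cat" from this section:

```python
import numpy as np

s_t = np.array([[0.0, 1.0],   # "The" vs [The, cat]
                [1.5, 0.0]])  # "cat" vs [The, cat]
m_new = s_t.max(axis=1)       # shape (2,): [1.0, 1.5]

right = s_t - m_new[:, None]  # (2, 2) - (2, 1): row i minus m_new[i]
wrong = s_t - m_new           # (2, 2) - (2,): column j minus m_new[j], silently

e_t = np.exp(right)
print(e_t)                        # each row's largest entry is exp(0) = 1.0
print(np.exp(wrong).max(axis=1))  # row maxima are no longer 1.0
```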
rescale = np.exp(m_run[i0:i1] - m_new)
THE ONLINE SOFTMAX TRICK: when the max changes, all previously accumulated exponentials need correction. The multiplicative factor exp(m_old - m_new) moves them from the old scale to the new scale. Since m_new ≥ m_old, this factor is always ≤ 1.0 (no overflow risk).
EXECUTION STATE
📚 np.exp() = e^(m_old - m_new). If m unchanged: exp(0) = 1.0 (no rescaling). If m increased: exp(negative) < 1.0 (shrink old values).
→ why this works = Old values were computed as exp(s - m_old). We need exp(s - m_new). Multiply by exp(m_old - m_new): exp(s - m_old) × exp(m_old - m_new) = exp(s - m_new) ✔️
rescale (Tile 1) = [exp(-∞ - 1.0), exp(-∞ - 1.5)] = [0.0, 0.0] — first tile, zeroes out empty accumulators
rescale (Tile 4) = [exp(1.0 - 1.0), exp(1.5 - 1.5)] = [1.0, 1.0] — m unchanged, no rescaling needed
rescale (Tile 5, rows 2:4) = [exp(1.0 - 1.0), exp(0.5 - 1.0)] = [1.0, 0.6065] — ‘on’ row gets rescaled because its max increased from 0.5 to 1.0
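The same updates written as equations, per query row, where $s_k$ are the scaled scores of the current K-tile, $v_k$ the matching rows of V, $\ell$ the running denominator l_run, and $\alpha$ the code's rescale factor:

$$
\begin{aligned}
m_{\text{new}} &= \max\!\bigl(m_{\text{old}},\ \max_{k} s_k\bigr), \qquad \alpha = e^{\,m_{\text{old}} - m_{\text{new}}} \\
\ell_{\text{new}} &= \alpha\,\ell_{\text{old}} + \sum_{k} e^{\,s_k - m_{\text{new}}} \\
O_{\text{new}} &= \alpha\,O_{\text{old}} + \sum_{k} e^{\,s_k - m_{\text{new}}}\, v_k
\end{aligned}
$$

Because every update multiplies all earlier terms by $e^{m_{\text{old}} - m_{\text{new}}}$, after the last tile $\ell = \sum_{k=1}^{N} e^{s_k - m}$ and $O = \sum_{k=1}^{N} e^{s_k - m} v_k$, with $m$ the row maximum. Hence $O/\ell = \sum_k \operatorname{softmax}(s)_k\, v_k$, the standard attention output.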
l_run[i0:i1] = l_run[i0:i1] * rescale + e_t.sum(axis=1)
Update running normalizer (softmax denominator) in two steps: (1) rescale the old accumulated value to the new max scale, (2) add the new tile's exponential sums. After all tiles, l_run[i] equals the full softmax denominator for query i.
EXECUTION STATE
📚 .sum(axis=1) = NumPy method: sums elements along axis 1 (columns within each row). For e_t(2×2): returns [row0_sum, row1_sum]. Example: [[0.37,1.0],[1.0,0.22]].sum(axis=1) = [1.37, 1.22].
⬇ arg: axis=1 = Sum along columns. For a 2×2 tile, this sums each row’s 2 exponentials into a single value per query.
O_run[i0:i1] = O_run[i0:i1] * rescale[:, None] + e_t @ V_t
Update running output: (1) rescale the old output to the new max scale, (2) add the new tile's weighted value contribution. e_t @ V_t computes the unnormalized attention output for this tile only.
EXECUTION STATE
rescale[:, None] = Reshape from (2,) to (2,1) for broadcasting against O_run (2×4). Each row of O_run gets multiplied by its corresponding rescale factor.
📚 e_t @ V_t = Matrix multiply: e_t(2×2) @ V_t(2×4) = contribution(2×4). Each row is a weighted sum of the tile’s value vectors, weighted by the exponential scores.
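To watch the rescale act on a single row, here is a minimal sketch that replays the three K-tiles for the query "on" (row 3). Its scaled scores [0.5, 0.5, 0.0, 1.0, 0.5] are hand-computed from the shared Q and K (row 3 of Q @ K.T / 2), so treat them as an assumption to re-check. The Tile 5 rescale of exp(0.5 - 1.0) ≈ 0.6065 appears at the second K-tile:

```python
import numpy as np

# Hand-computed scaled scores of the query "on" (row 3) against all five keys,
# i.e. row 3 of Q @ K.T / 2 for the shared example.
s_on = np.array([0.5, 0.5, 0.0, 1.0, 0.5])
V = np.array([
    [1.0, 0.0, 0.0, 0.0],  # The
    [0.0, 1.0, 0.0, 0.0],  # cat
    [0.0, 0.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 0.0, 1.0],  # on
    [0.5, 0.5, 0.5, 0.5],  # mat
])

m, l, o = -np.inf, 0.0, np.zeros(4)
rescales = []
for j0, j1 in [(0, 2), (2, 4), (4, 5)]:      # the three K-tiles for N=5, T=2
    s = s_on[j0:j1]
    m_new = max(m, s.max())                  # updated running max
    rescale = np.exp(m - m_new)              # 0.0 on the first tile
    e = np.exp(s - m_new)
    l = l * rescale + e.sum()                # running denominator
    o = o * rescale + e @ V[j0:j1]           # running unnormalized output
    m = m_new
    rescales.append(rescale)

# Reference: ordinary softmax over the whole row.
w = np.exp(s_on - s_on.max())
w /= w.sum()

print([round(float(r), 4) for r in rescales])  # [0.0, 0.6065, 1.0]
print(np.round(o / l, 4))                      # same as the 'on' row of std_output
```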
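m_run[i0:i1] = m_new
Commit the updated running max for this Q-tile so that the next K-tile rescales against it.
EXECUTION STATE
m_run after all tiles = [1.000, 1.500, 1.000, 1.000, 0.750]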
return O_run / l_run[:, None]
Final normalization: divide each row of the accumulated output by its normalizer. This single division produces the EXACT softmax-weighted output, identical to standard attention up to floating-point rounding.
EXECUTION STATE
[:, None] = Reshape l_run from (5,) to (5,1) for broadcasting. l_run(5,1) broadcasts against O_run(5×4) — each row divided by its own normalizer.
l_run final = [3.3597, 2.4840, 3.9919, 3.1875, 4.1152]
output (5×4) =
d0 d1 d2 d3
The 0.2254 0.4135 0.2964 0.2964
cat 0.4602 0.1475 0.3018 0.2058
sat 0.2495 0.3481 0.3481 0.2495
on 0.2854 0.2854 0.2106 0.4089
mat 0.3108 0.3108 0.3108 0.3108
def explain(self, Q, K, V, tokens, query_idx=0)
Debug/visualization method that prints a tile-by-tile trace for a SINGLE query token. Shows how m, l, and O evolve across tiles. This method traces the algorithm’s internal state — useful for understanding, not for production.
EXECUTION STATE
⬇ input: self = Class instance — provides self.scale=2.0 and self.tile_size=2
⬇ input: Q (5×4) = Query matrix — we only use row query_idx from this
⬇ input: K (5×4) = Key matrix — processed in tiles of size T=2
⬇ input: V (5×4) = Value matrix — tiled alongside K
e = np.exp(s - m_new)
Exponentiate the shifted scores. For a single query row, s is a 1D array (one score per key in this tile). Subtracting m_new makes the largest entry exp(0) = 1.0.
EXECUTION STATE
📚 np.exp() = Element-wise e^x. Converts raw scores into positive values for softmax.
e (Tile 1: [The,cat]) = [exp(0.0-1.0), exp(1.0-1.0)] = [0.3679, 1.0000]
e (Tile 2: [sat,on]) = [exp(0.5-1.0), exp(0.5-1.0)] = [0.6065, 0.6065]
e (Tile 3: [mat]) = [exp(0.75-1.0)] = [0.7788]
rescale = np.exp(m_val - m_new) if np.isfinite(m_val) else 0.0
Compute the rescaling factor, with a guard for the first tile. For a finite m_new, np.exp(-inf - m_new) is already exactly 0.0, so in this example the guard changes nothing; it is defensive. NaN appears only for -inf - (-inf), i.e. if m_new were also -inf (for example, a row whose scores so far are all masked). forward() has no such guard and relies on np.exp(-inf) = 0.0, which is safe for the unmasked scores used here.
EXECUTION STATE
📚 np.isfinite() = Returns True if the value is finite (not inf, -inf, or NaN). np.isfinite(1.0) = True, np.isfinite(-np.inf) = False, np.isfinite(np.nan) = False.
→ why the guard? = exp(-inf - 1.0) = exp(-inf) = 0.0 in NumPy, and Python's math.exp(-inf) is 0.0 as well. The check only matters when m_new is also -inf: then m_val - m_new = -inf - (-inf) = NaN, and the guard returns 0.0 instead.
rescale (Tile 1) = 0.0 — m_val=-inf is not finite, so use 0.0. Zeroes out empty accumulators.
rescale (Tile 2) = exp(1.0 - 1.0) = exp(0) = 1.0 — m unchanged, no rescaling
rescale (Tile 3) = exp(1.0 - 1.0) = 1.0 — m still unchanged
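A minimal check of when NaN can and cannot appear, using only standard NumPy and math calls:

```python
import math

import numpy as np

neg_inf = -np.inf

# With a finite new max, the first-tile rescale is exactly 0.0 (no NaN).
print(np.exp(neg_inf - 1.0))    # 0.0
print(math.exp(neg_inf - 1.0))  # 0.0: Python's math module agrees

# NaN appears only when the new max is also -inf, e.g. a fully masked row.
with np.errstate(invalid="ignore"):
    bad = np.exp(neg_inf - neg_inf)  # -inf - (-inf) is nan
print(bad)  # nan

# The guard from explain() sidesteps that case for the rescale factor.
m_val, m_new = neg_inf, neg_inf
rescale = np.exp(m_val - m_new) if np.isfinite(m_val) else 0.0
print(rescale)  # 0.0
```

The guard only protects the rescale factor; in a fully masked case the tile's own exp(s - m_new) would also be NaN and would need separate handling.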
l_val = l_val * rescale + float(e.sum())
Update running normalizer: rescale + add new exponentials.
tokens = ["The", "cat", "sat", "on", "mat"]
The shared example sentence used throughout all 15 chapters. 5 tokens with d_k=4.
EXECUTION STATE
tokens = ["The", "cat", "sat", "on", "mat"]
N = 5 tokens
Q = np.array([...])
Query matrix (5×4). Each row is what that token is ‘asking for’ — the features it wants to find.
EXECUTION STATE
Q (5×4) =
d0 d1 d2 d3
The 1.0 0.0 1.0 0.0
cat 0.0 2.0 0.0 1.0
sat 1.0 1.0 1.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.0 1.0
K = np.array([...])
Key matrix (5×4). Each row is what that token ‘advertises’ about itself.
EXECUTION STATE
K (5×4) =
d0 d1 d2 d3
The 0.0 1.0 0.0 1.0
cat 1.0 0.0 1.0 0.0
sat 1.0 1.0 0.0 0.0
on 0.0 0.0 1.0 1.0
mat 1.0 0.0 0.5 0.5
V = np.array([...])
Value matrix (5×4). Q and K decide WHO to attend to; V decides WHAT to retrieve.
EXECUTION STATE
V (5×4) =
d0 d1 d2 d3
The 1.0 0.0 0.0 0.0
cat 0.0 1.0 0.0 0.0
sat 0.0 0.0 1.0 0.0
on 0.0 0.0 0.0 1.0
mat 0.5 0.5 0.5 0.5
fa = FlashAttention(d_k=4, tile_size=2)
Create a FlashAttention instance. d_k=4 sets the scaling factor (√4=2.0), tile_size=2 means we process 2 key rows at a time. With N=5 tokens, this gives 3×3=9 tile computations total.
std_weights, std_output = fa.standard_attention(Q, K, V)
Run standard O(N²) attention for comparison. This materializes the full 5×5 score matrix in memory, which is exactly what Flash Attention avoids.
EXECUTION STATE
std_weights (5×5) =
The cat sat on mat
The 0.1095 0.2976 0.1805 0.1805 0.2318
cat 0.4026 0.0898 0.2442 0.1481 0.1153
sat 0.1519 0.2505 0.2505 0.1519 0.1951
on 0.1903 0.1903 0.1154 0.3137 0.1903
mat 0.1892 0.1892 0.1892 0.1892 0.2430
std_output (5×4) =
d0 d1 d2 d3
The 0.2254 0.4135 0.2964 0.2964
cat 0.4602 0.1475 0.3018 0.2058
sat 0.2495 0.3481 0.3481 0.2495
on 0.2854 0.2854 0.2106 0.4089
mat 0.3108 0.3108 0.3108 0.3108
flash_output = fa.forward(Q, K, V)
Run Flash Attention. Processes 9 tile blocks in total (3 K-tiles × 3 Q-tiles per K-tile). The largest score block held at any time is 2×2; the full 5×5 score matrix is never stored.
EXECUTION STATE
flash_output (5×4) =
d0 d1 d2 d3
The 0.2254 0.4135 0.2964 0.2964
cat 0.4602 0.1475 0.3018 0.2058
sat 0.2495 0.3481 0.3481 0.2495
on 0.2854 0.2854 0.2106 0.4089
mat 0.3108 0.3108 0.3108 0.3108
diff = np.abs(flash_output - std_output).max()
Verify Flash and standard attention produce identical output. np.abs() takes element-wise absolute value of the difference matrix, .max() finds the largest element.
EXECUTION STATE
📚 np.abs() = Element-wise absolute value. Converts negative differences to positive for comparison.
📚 .max() = Returns the single largest element in the array. No axis specified = global max.
diff = 5.55e-17, a rounding-level difference: it is smaller than float64 machine epsilon (2⁻⁵² ≈ 2.2×10⁻¹⁶) and matches 2⁻⁵⁴. The outputs are mathematically identical; the tiny difference comes from floating-point rounding in different computation orders.
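For scale, NumPy reports float64 and float32 machine epsilon directly, and a three-term sum shows how merely reordering additions changes the last bit, the same kind of effect behind the reported diff:

```python
import numpy as np

eps64 = np.finfo(np.float64).eps  # 2**-52, about 2.22e-16
eps32 = np.finfo(np.float32).eps  # 2**-23, about 1.19e-7
print(eps64, eps32, 2.0 ** -54)   # 2**-54 is about 5.55e-17

# Same numbers, different association order, different last bit.
left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)
print(left, right, abs(left - right))
```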
fa.explain(Q, K, V, tokens, query_idx=0)
Print a detailed tile-by-tile trace for ‘The’ (row 0). Shows how m, l, O evolve across 3 tiles. Confirms the tiled computation matches the standard result.
EXECUTION STATE
query_idx = 0 = Trace query row 0 = ‘The’. Change to 1 for ‘cat’, 2 for ‘sat’, etc.
→ output = Prints 3 tiles showing scores, m, l, O_run for each, then the final normalized output [0.2254, 0.4135, 0.2964, 0.2964]
```python
import numpy as np
import math


class FlashAttention:
    """
    Flash Attention (Dao et al., 2022)
    IO-aware exact attention that produces identical output to
    standard scaled dot-product attention but minimizes HBM
    reads/writes by processing Q, K, V in tiles, using online
    softmax to avoid materializing the N*N score matrix.
    """

    def __init__(self, d_k: int, tile_size: int = 2):
        self.d_k = d_k
        self.tile_size = tile_size
        self.scale = math.sqrt(d_k)

    def standard_attention(self, Q, K, V):
        """Standard O(N^2) attention for verification."""
        scores = Q @ K.T / self.scale
        max_s = np.max(scores, axis=-1, keepdims=True)
        exp_s = np.exp(scores - max_s)
        weights = exp_s / np.sum(exp_s, axis=-1, keepdims=True)
        output = weights @ V
        return weights, output

    def forward(self, Q, K, V):
        """
        Flash Attention forward pass -- tiled online softmax.
        Output is IDENTICAL to standard_attention().
        """
        N, D = Q.shape
        D_v = V.shape[1]
        T = self.tile_size

        # Running statistics, one entry (or row) per query.
        m_run = np.full(N, -np.inf)  # running max
        l_run = np.zeros(N)          # running softmax denominator
        O_run = np.zeros((N, D_v))   # running unnormalized output

        for j0 in range(0, N, T):        # outer loop: K/V tiles
            j1 = min(j0 + T, N)
            K_t = K[j0:j1]
            V_t = V[j0:j1]

            for i0 in range(0, N, T):    # inner loop: Q tiles
                i1 = min(i0 + T, N)
                Q_t = Q[i0:i1]

                s_t = Q_t @ K_t.T / self.scale                     # tile scores only
                m_new = np.maximum(m_run[i0:i1], s_t.max(axis=1))  # updated row max
                e_t = np.exp(s_t - m_new[:, None])                 # shifted exponentials
                rescale = np.exp(m_run[i0:i1] - m_new)             # old scale -> new scale

                l_run[i0:i1] = l_run[i0:i1] * rescale + e_t.sum(axis=1)
                O_run[i0:i1] = O_run[i0:i1] * rescale[:, None] + e_t @ V_t
                m_run[i0:i1] = m_new

        return O_run / l_run[:, None]

    def explain(self, Q, K, V, tokens, query_idx=0):
        """Print a tile-by-tile trace for one query token."""
        N = Q.shape[0]
        T = self.tile_size
        t = tokens[query_idx]
        i = query_idx

        m_val = -np.inf
        l_val = 0.0
        o_val = np.zeros(V.shape[1])

        print(f"\n=== Flash trace for '{t}' (row {i}, T={T}) ===")
        print(f"Q[{i}] = {Q[i]}")

        tile_num = 0
        for j0 in range(0, N, T):
            j1 = min(j0 + T, N)
            tile_num += 1
            K_t = K[j0:j1]
            V_t = V[j0:j1]

            scores = Q[i:i+1] @ K_t.T / self.scale
            s = scores[0]
            m_new = max(m_val, float(s.max()))

            print(f"\n--- Tile {tile_num}: K[{j0}:{j1}] ---")
            for idx in range(j0, j1):
                print(f"  score({t},{tokens[idx]}) = {s[idx-j0]:.4f}")

            e = np.exp(s - m_new)
            rescale = np.exp(m_val - m_new) if np.isfinite(m_val) else 0.0  # 0.0 on the first tile

            l_val = l_val * rescale + float(e.sum())
            o_val = o_val * rescale + e @ V_t
            m_val = m_new

            print(f"  m={m_new:.4f}, l={l_val:.4f}")
            print(f"  O_run = {np.round(o_val, 4)}")

        final = o_val / l_val
        print(f"\nFinal: O[{t}] = {np.round(final, 4)}")


# -- Shared Example (same Q, K, V as every chapter) --
tokens = ["The", "cat", "sat", "on", "mat"]

Q = np.array([
    [1.0, 0.0, 1.0, 0.0],  # The
    [0.0, 2.0, 0.0, 1.0],  # cat
    [1.0, 1.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.0, 1.0],  # mat
])

K = np.array([
    [0.0, 1.0, 0.0, 1.0],  # The
    [1.0, 0.0, 1.0, 0.0],  # cat
    [1.0, 1.0, 0.0, 0.0],  # sat
    [0.0, 0.0, 1.0, 1.0],  # on
    [1.0, 0.0, 0.5, 0.5],  # mat
])

V = np.array([
    [1.0, 0.0, 0.0, 0.0],  # The
    [0.0, 1.0, 0.0, 0.0],  # cat
    [0.0, 0.0, 1.0, 0.0],  # sat
    [0.0, 0.0, 0.0, 1.0],  # on
    [0.5, 0.5, 0.5, 0.5],  # mat
])

# -- Run --
fa = FlashAttention(d_k=4, tile_size=2)

std_weights, std_output = fa.standard_attention(Q, K, V)
print("Standard Attention Weights:")
print(np.round(std_weights, 4))

flash_output = fa.forward(Q, K, V)
print("\nFlash Attention Output:")
print(np.round(flash_output, 4))

diff = np.abs(flash_output - std_output).max()
print(f"\nMax diff: {diff:.2e}")

fa.explain(Q, K, V, tokens, query_idx=0)
```
PyTorch Implementation
The same algorithm in PyTorch. Note the differences from NumPy: .size() instead of .shape (tensors also accept .shape), .unsqueeze() instead of [:, None], dim= instead of axis=, and s_t.max(dim=1).values instead of s_t.max(axis=1). In production, use torch.nn.functional.scaled_dot_product_attention(), which dispatches to a fused Flash Attention kernel when the GPU, dtype and inputs support it.
Flash Attention — PyTorch Implementation
🐍flash_attention_torch.py
import torch
PyTorch is a deep learning framework providing GPU-accelerated tensor operations and automatic differentiation. All matrix operations in this file (@ multiply, torch.exp, torch.maximum) can run on GPU transparently.
EXECUTION STATE
📚 torch = Core PyTorch module. Provides Tensor (GPU-capable ndarray), autograd (automatic differentiation), and mathematical functions. Key difference from NumPy: tensors track gradients for backpropagation.
import torch.nn as nn
PyTorch’s neural network module. FlashAttention inherits from nn.Module, which provides parameter registration, device management, and model serialization.
EXECUTION STATE
📚 torch.nn (as nn) = Neural network building blocks. nn.Module is the base class for all layers/models. Provides: .parameters(), .to(device), .train()/.eval(), state_dict for saving.
import math
Standard Python math library. Used for math.sqrt() to compute the scaling factor as a plain float.
EXECUTION STATE
📚 math = Same as NumPy version — math.sqrt(4) = 2.0. Could use torch.sqrt(tensor) but math.sqrt is simpler for scalar constants.
class FlashAttention(nn.Module)
PyTorch module implementing Flash Attention. Inherits from nn.Module for integration with PyTorch’s training infrastructure (autograd, device management, model saving). The algorithm is identical to the NumPy version — only the API differs.
EXECUTION STATE
📚 nn.Module = Base class for all PyTorch neural network modules. Provides: forward() auto-called via __call__(), .to(device) for GPU transfer, .parameters() for optimizer, .train()/.eval() mode switching.
def __init__(self, d_model, tile_size=2)
Constructor. In production, d_model would be the per-head dimension (d_model_total / n_heads, e.g. 512/8=64), and tile_size is chosen based on GPU SRAM capacity (~128 on A100).
EXECUTION STATE
⬇ input: d_model = 4 = Dimension of the Q/K vectors. Although named d_model here, it plays the role of d_k, the per-head Q/K dimension (elsewhere, e.g. in nn.Transformer, d_model denotes the full model width). Sets self.scale = √4 = 2.0.
⬇ input: tile_size = 2 = How many rows of K/V to process per tile. With N=5: 3 tiles per loop pass.
super().__init__()
Initialize nn.Module base class. Required for PyTorch’s parameter tracking, hook system, and device management to work correctly.
EXECUTION STATE
📚 super().__init__() = Calls nn.Module.__init__(). Without this, self.register_parameter(), .to(device), and model saving would break. Every nn.Module subclass MUST call this.
self.d_model = d_model
Store model dimension for reference.
EXECUTION STATE
self.d_model = 4 — dimension of each Q/K vector
self.tile_size = tile_size
Store tile size for the tiled computation loop.
EXECUTION STATE
self.tile_size = 2 — process 2 rows of K/V per tile
self.scale = math.sqrt(d_model)
Pre-compute √d_model = √4 = 2.0. Same scaling as the NumPy version.
EXECUTION STATE
self.scale = 2.0 — divides all dot-product scores
def forward(self, Q, K, V) -> torch.Tensor
Forward pass — identical algorithm to the NumPy version but uses PyTorch operations. Key API differences: .size() vs .shape, .unsqueeze() vs [:, None], dim= vs axis=, .max().values vs .max().
m_run = torch.full((N,), float("-inf"), device=Q.device)
Running max initialized to -∞, created on the same device as Q. In production this would be on GPU; here it's CPU.
EXECUTION STATE
📚 torch.full(size, fill_value) = Creates a tensor of the given size, filled with fill_value. Like np.full() but returns a PyTorch tensor. torch.full((3,), 7.0) = tensor([7., 7., 7.]).
⬇ arg: (N,) = (5,) = Shape tuple — creates a 1D tensor with 5 elements
⬇ arg: float("-inf") = Python float negative infinity. Any score will be greater, so the first tile’s max always wins.
⬇ arg: device=Q.device = Creates tensor on the same device (CPU/GPU) as Q. Critical for GPU: tensors on different devices can’t interact. Q.device = ‘cpu’ here.
m_run = tensor([-inf, -inf, -inf, -inf, -inf])
l_run = torch.zeros(N, device=Q.device)
Running normalizer (softmax denominator) initialized to zero, on same device.
EXECUTION STATE
📚 torch.zeros(size, device=...) = Creates a zero-filled tensor, like np.zeros. Without device= it is created on the CPU, so passing device=Q.device keeps it on the same device as Q (required when Q lives on a GPU).
l_run = tensor([0., 0., 0., 0., 0.])
O_run = torch.zeros(N, D_v, device=Q.device)
Running output accumulator (5×4), on same device.
EXECUTION STATE
O_run (5×4) = tensor of zeros — will accumulate unnormalized weighted value sums
for j0 in range(0, N, T):  (outer loop over K tiles)
Outer loop over K/V tiles. Identical structure to NumPy version. j0 = 0, 2, 4 → 3 K-tiles.
LOOP TRACE · 3 iterations
j0=0 → K[0:2] = [The, cat]
K_t shape = (2, 4)
j0=2 → K[2:4] = [sat, on]
K_t shape = (2, 4)
j0=4 → K[4:5] = [mat]
K_t shape = (1, 4)
for i0 in range(0, N, T):  (inner loop over Q tiles)
Inner loop over Q tiles. For each Q-tile, compute scores against the current K-tile and update running statistics.
LOOP TRACE · 3 iterations
i0=0 → Q[0:2] = [The, cat]
Q_t shape = (2, 4)
i0=2 → Q[2:4] = [sat, on]
Q_t shape = (2, 4)
i0=4 → Q[4:5] = [mat]
Q_t shape = (1, 4)
s_t = Q_t @ K_t.T / self.scale
Compute tile scores. In PyTorch, .T is the transpose property (same as .t() for 2D tensors, or .transpose(-2,-1) for batched).
EXECUTION STATE
📚 .T = PyTorch transpose property. For 2D tensors: K_t.T swaps rows↔cols, same as K_t.t(). For higher dims: reverses ALL dimensions (careful — use .transpose() for batched tensors).
📚 @ (matrix multiply) = Same operator as NumPy. Q_t(2×4) @ K_t.T(4×2) = s_t(2×2).
m_new = torch.maximum(m_run[i0:i1], s_t.max(dim=1).values)
Update running max. Unlike NumPy's .max(axis=1), which returns just the values, PyTorch's .max(dim=1) returns a namedtuple of (values, indices). We need .values.
EXECUTION STATE
📚 torch.maximum(a, b) = Element-wise maximum of two tensors. Like np.maximum(). NOT torch.max() which finds the global max or max along a dim.
📚 s_t.max(dim=1) = Returns torch.return_types.max(values=tensor([...]), indices=tensor([...])). In NumPy, .max(axis=1) returns just values. In PyTorch, you must append .values.
⬇ arg: dim=1 = PyTorch uses ‘dim’ instead of NumPy’s ‘axis’. dim=1 = max along columns within each row. Same behavior as axis=1.
.values = Extract just the maximum values, not the indices. Without .values, m_new would be a namedtuple, not a tensor.
e_t = torch.exp(s_t - m_new.unsqueeze(1))
Exponentiate shifted scores. .unsqueeze(1) is the PyTorch equivalent of NumPy’s [:, None].
EXECUTION STATE
📚 torch.exp() = Element-wise e^x. Identical to np.exp() but operates on PyTorch tensors.
📚 .unsqueeze(dim) = Inserts a size-1 dimension at position dim. .unsqueeze(1) on shape (2,) gives (2,1). This is PyTorch’s equivalent of NumPy’s [:, None]. Example: tensor([a, b]).unsqueeze(1) = tensor([[a],[b]]).
⬇ arg: dim=1 = Insert dimension at position 1: (2,) → (2,1). Enables broadcasting: s_t(2×2) - m_new(2,1) subtracts each row’s max from that row.
rescale = torch.exp(m_run[i0:i1] - m_new)
Rescaling factor for previously accumulated values. Same online softmax trick as NumPy version. PyTorch handles exp(-inf) = 0.0 correctly (no NaN guard needed).
EXECUTION STATE
rescale = tensor — exp(m_old - m_new). If m unchanged: 1.0. If m increased: < 1.0. If first tile (m_old=-inf): 0.0.
→ identical to NumPy = Same values as the NumPy version — both produce mathematically identical output to standard attention.
Q = torch.tensor([...])
Same query matrix as the NumPy version, but as a PyTorch tensor. torch.tensor() creates a tensor from a Python list.
EXECUTION STATE
📚 torch.tensor() = Creates a tensor from data (list, array, scalar). Infers dtype from data: floats become torch.float32 by default.
Q (5×4) = Same values as NumPy Q: [[1,0,1,0],[0,2,0,1],[1,1,1,0],[0,0,1,1],[1,0,0,1]]
K = torch.tensor([...])
Same key matrix as NumPy version.
EXECUTION STATE
K (5×4) = Same values: [[0,1,0,1],[1,0,1,0],[1,1,0,0],[0,0,1,1],[1,0,0.5,0.5]]
V = torch.tensor([...])
Same value matrix as NumPy version. Identity-like structure makes it easy to trace which token’s information flows into the output.
EXECUTION STATE
V (5×4) = Same values: [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1],[0.5,0.5,0.5,0.5]]
fa = FlashAttention(d_model=4, tile_size=2)
Instantiate the module with d_model=4, tile_size=2. Same configuration as NumPy version.
EXECUTION STATE
fa.scale = 2.0 — √4
fa.d_model = 4
fa.tile_size = 2
with torch.no_grad():
Context manager that disables gradient tracking. Since we’re just running inference (not training), there’s no need to build a computational graph. Saves memory and speeds up computation.
EXECUTION STATE
📚 torch.no_grad() = Disables autograd inside this block. No gradient tensors are created, no computation graph is built. Essential for inference. Without it: PyTorch tracks every operation for potential backpropagation, wasting memory.
flash_out = fa(Q, K, V)
Calling fa() invokes fa.forward() via nn.Module’s __call__ method (which also runs hooks). NEVER call fa.forward() directly — always use fa().
EXECUTION STATE
📚 module() = nn.Module.__call__() runs: pre-hooks → forward() → post-hooks. Using fa() instead of fa.forward() ensures hooks run correctly.
flash_out (5×4) = Flash Attention output — identical to standard attention
std_out = torch.softmax(Q @ K.T / fa.scale, dim=-1) @ V
Standard attention in one line for comparison. torch.softmax is PyTorch’s standalone softmax function (vs F.softmax from torch.nn.functional).
EXECUTION STATE
📚 torch.softmax(input, dim) = Applies softmax along the specified dimension. torch.softmax(x, dim=-1) normalizes each row to sum to 1.0. Equivalent to F.softmax(x, dim=-1).
⬇ arg: dim=-1 = Softmax along last dimension (columns). Each ROW becomes a probability distribution. This is the standard choice for attention.
std_out = Standard attention output for verification against flash_out
print(f"Max diff: ...")
Verify Flash and standard outputs are identical. Uses method chaining: .abs() (absolute value) then .max() (global maximum).
EXECUTION STATE
📚 .abs() = Element-wise absolute value. PyTorch method equivalent to torch.abs(). (flash_out - std_out).abs() = all positive differences.
📚 .max() = Returns the single largest element. No dim specified = global max across all elements.
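max diff = a float32 rounding-level difference rather than ~5.55e-17: torch.tensor() creates float32 tensors here, and float32 machine epsilon is about 1.2×10⁻⁷ (the 5.55e-17 figure comes from the float64 NumPy run). The outputs are still mathematically identical.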
Same math, different execution. Flash Attention produces mathematically identical output to standard scaled dot-product attention (differing only at floating-point precision). The N×N attention matrix is never stored — only T×T tile scores at a time.
IO-awareness is the insight. The bottleneck in standard attention is memory bandwidth (HBM reads/writes), not arithmetic. Flash Attention fuses all operations into a single kernel that keeps intermediates in SRAM.
Online softmax enables tiling. The running statistics (m,l,O) with rescaling preserve exact softmax semantics while processing keys in arbitrarily small tiles.
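Memory drops from O(N²) to O(N). Apart from Q, K, V and the output, only the per-row statistics m and l grow with N; the N×N matrix is never allocated. This enables training with 10–100× longer sequences on the same GPU hardware.
Wall-clock speedup of 3–8× despite doing about the same arithmetic (slightly more in the forward pass for the rescaling, and more in the backward pass, which recomputes attention scores instead of storing them). The speedup comes from eliminating HBM round-trips.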
Universal adoption. Flash Attention has become the de facto standard for transformer training and inference, integrated into PyTorch via scaled_dot_product_attention() and adopted across virtually all major LLM frameworks.
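A back-of-the-envelope sketch of the memory takeaway, per head and per sequence. It assumes one fp16 N×N score matrix for standard attention and fp32 m and l vectors for Flash Attention (the statistics' precision is an assumption here); Q, K, V and the output are the same size in both cases and are left out. The helper name is hypothetical, and N=4096 is chosen only for illustration:

```python
def attention_memory_bytes(n: int, score_bytes: int = 2, stat_bytes: int = 4) -> dict[str, int]:
    """Bytes for one n x n score matrix vs. the per-row running statistics m and l."""
    return {
        "score_matrix": n * n * score_bytes,  # grows as O(N^2)
        "running_stats": 2 * n * stat_bytes,  # m and l: grows as O(N)
    }


mem = attention_memory_bytes(4096)
print(mem["score_matrix"] / 2**20, "MiB")   # 32.0 MiB
print(mem["running_stats"] / 2**10, "KiB")  # 32.0 KiB
```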
Exercises
Tile size impact: Modify the code to use tile_size=1 (process one key at a time). Verify the output is still identical. How does the number of tiles change? What happens to the number of rescaling operations?
Rescaling trace: Trace Flash Attention for query "cat" (row 1). Does the running maximum ever change from its tile-1 value, and if so, at which tile? What is the rescaling factor α at each tile?
Memory calculation: For N=16384, d=128, float16, calculate: (a) the size of the full N×N score matrix in GB, (b) the size of Flash Attention's running statistics per head, (c) the memory savings factor.
Causal masking: How would you integrate the causal mask from Chapter 3 into Flash Attention? Hint: a tile in which every query index is smaller than every key index (i1 ≤ j0, using forward()'s loop variables) is entirely masked and can be skipped. For boundary tiles that straddle the diagonal, apply −∞ masking within the tile scores.
Multi-head integration: Flash Attention operates per-head. If a model has h=32 heads, how would you parallelize Flash Attention across heads? What is the peak SRAM usage per head?
References
Dao, T., Fu, D. Y., Ermon, S., Rudra, A., & Ré, C. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. NeurIPS 2022.
Dao, T. (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691.
Shah, J., Bikshandi, G., Zhang, Y., Thakkar, V., Ramani, P., & Dao, T. (2024). FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision. arXiv:2407.08608.
Milakov, M. & Gimelshein, N. (2018). Online normalizer calculation for softmax. arXiv:1805.02867.
Rabe, M. N. & Staats, C. (2022). Self-attention Does Not Need O(n²) Memory. arXiv:2112.05682.