Introduction
This section covers Flash Attention, an exact, IO-aware attention algorithm. It speeds up transformer training and inference while reducing the extra memory that attention needs from quadratic to linear in sequence length. We'll first build intuition for the algorithm, then see how to use it in practice.
1.1 The Problem with Standard Attention
Memory Bottleneck
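The size of the score matrix follows from its shape alone. Write $B$ for batch size, $H$ for heads, $N$ for sequence length and $s$ for bytes per element. The worked numbers below use the configuration from the analysis in this subsection ($B = 8$, $H = 12$, float16 so $s = 2$). The script divides by $1024^3$, so its "GB" column is really GiB.

```latex
\begin{aligned}
\text{Mem}(S) &= B \cdot H \cdot N^2 \cdot s \\
N = 4096:\quad \text{Mem}(S) &= 8 \cdot 12 \cdot 4096^2 \cdot 2 = 3{,}221{,}225{,}472\ \text{bytes} = 3\ \text{GiB} \\
N = 8192:\quad \text{Mem}(S) &= 4 \times 3\ \text{GiB} = 12\ \text{GiB}
\end{aligned}
```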
import torch
import torch.nn as nn
import math
from typing import Optional, Tuple


def attention_memory_analysis():
    """
    Analyze memory usage of standard attention.
    """
    print("=" * 70)
    print("STANDARD ATTENTION: MEMORY ANALYSIS")
    print("=" * 70)

    print("""
    STANDARD ATTENTION ALGORITHM:
    -----------------------------

    Input: Q, K, V each of shape [batch, heads, seq_len, head_dim]

    Step 1: Compute attention scores
        S = Q @ K^T
        Shape: [batch, heads, seq_len, seq_len]
        Memory: O(n²) where n = seq_len

    Step 2: Apply softmax (d = head_dim)
        P = softmax(S / sqrt(d))
        Shape: [batch, heads, seq_len, seq_len]
        Memory: O(n²)

    Step 3: Compute output
        O = P @ V
        Shape: [batch, heads, seq_len, head_dim]
        Memory: O(n * d)

    Total memory for attention scores: O(n²)


    MEMORY USAGE EXAMPLES:
    ----------------------

    Configuration:
      - batch_size = 8
      - num_heads = 12
      - head_dim = 64
      - dtype = float16 (2 bytes)
    """)

    batch_size = 8
    num_heads = 12
    head_dim = 64
    bytes_per_element = 2  # float16

    seq_lengths = [512, 1024, 2048, 4096, 8192, 16384]

    print(f"{'Seq Length':<12} {'Attention Matrix':<20} {'Memory (GB)':<15} {'Status'}")
    print("-" * 60)

    for seq_len in seq_lengths:
        # Attention matrix: [batch, heads, seq, seq]
        attention_elements = batch_size * num_heads * seq_len * seq_len
        attention_bytes = attention_elements * bytes_per_element
        attention_gb = attention_bytes / (1024 ** 3)

        # Rough status based on typical GPU memory (score matrix only)
        if attention_gb < 4:
            status = "✓ Fits in 8GB GPU"
        elif attention_gb < 12:
            status = "⚠ Needs 16GB+ GPU"
        elif attention_gb < 24:
            status = "⚠ Needs 24GB+ GPU"
        else:
            status = "✗ Too large for consumer GPUs"

        matrix = f"{seq_len}x{seq_len}"
        print(f"{seq_len:<12} {matrix:<20} {attention_gb:<15.2f} {status}")

    print("""

    THE QUADRATIC PROBLEM:
    ----------------------

    Memory grows as O(n²):
      - Double sequence length → 4x memory
      - 4x sequence length → 16x memory

    This limits:
      • Maximum sequence length
      • Batch size
      • Model size that fits in memory

    For long documents, code, or conversations, standard attention
    becomes impractical!
    """)


attention_memory_analysis()

1.2 Flash Attention Algorithm
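Tiling only works because of online softmax, explained at the end of this section. A softmax normalizer can be accumulated chunk by chunk, provided the earlier partial sum is rescaled whenever the running maximum grows. Below is a minimal plain-Python sketch of that bookkeeping on a single row of scores. The function names are illustrative and do not come from any library.

```python
import math
from typing import Iterable, List


def online_softmax_stats(chunks: Iterable[List[float]]):
    """Stream over chunks of one score row; return (max, sum of exp(x - max))."""
    m = float("-inf")  # running maximum
    l = 0.0            # running sum of exponentials, relative to m
    for chunk in chunks:
        m_new = max(m, max(chunk))
        # Rescale the old sum to the new max, then add this chunk's terms.
        l = math.exp(m - m_new) * l + sum(math.exp(x - m_new) for x in chunk)
        m = m_new
    return m, l


def softmax_from_stats(x: List[float], m: float, l: float) -> List[float]:
    """Turn raw scores into probabilities using the streamed statistics."""
    return [math.exp(v - m) / l for v in x]


scores = [1.0, 3.0, -2.0, 0.5, 4.0, 2.5]
m, l = online_softmax_stats([scores[0:2], scores[2:4], scores[4:6]])

# Reference: ordinary softmax statistics over the full row at once.
ref_max = max(scores)
ref_sum = sum(math.exp(v - ref_max) for v in scores)
probs = softmax_from_stats(scores, m, l)
print(m == ref_max, math.isclose(l, ref_sum))  # True True
```

On the first chunk, `math.exp(-inf)` is `0.0`, so the empty initial state drops out without any special case.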
Core Idea: Tiling and Recomputation
def flash_attention_explained():
    """
    Explain the Flash Attention algorithm.
    """
    print("=" * 70)
    print("FLASH ATTENTION ALGORITHM")
    print("=" * 70)

    print("""
    KEY INSIGHT: Memory Hierarchy
    -----------------------------

    GPU memory hierarchy (NVIDIA A100):

      HBM (High Bandwidth Memory): 40-80 GB
        - Slower to access (~1.5-2 TB/s)
        - Where model weights & activations are stored

      SRAM (on-chip memory): 192 KB per SM (~20 MB across 108 SMs)
        - Very fast (~19 TB/s)
        - Where computation happens

      SRAM bandwidth is roughly 10x that of HBM, but SRAM is far too
      small to hold a full N x N attention matrix.

    Standard Attention Problem:
      - Computes the full N x N attention matrix
      - Must store it in slow HBM
      - Then loads it back for softmax and the matmul with V
      - Many slow memory transfers


    FLASH ATTENTION SOLUTION: Tiling
    --------------------------------

    Instead of computing the full attention matrix:

    Standard:
    +---------------------------------------+
    |                                       |
    |        N x N Attention Matrix         |
    |          (Stored in HBM)              |
    |                                       |
    +---------------------------------------+

    Flash Attention:
    +----+----+----+----+----+----+----+----+
    | B  | B  | B  | B  | B  | B  | B  | B  |  <- Computed one at a time
    +----+----+----+----+----+----+----+----+     in SRAM
    | B  | B  | B  | B  | B  | B  | B  | B  |
    +----+----+----+----+----+----+----+----+
    | ...                                   |
    +---------------------------------------+

    Process blocks sequentially, keeping each block in fast SRAM.
    Never materialize the full attention matrix!


    ALGORITHM SKETCH:
    -----------------

    Split Q, K, V into blocks along the sequence:
        Q = [Q_1, Q_2, ..., Q_r]   (blocks of query rows)
        K = [K_1, K_2, ..., K_c]   (blocks of key rows = columns of S)
        V = [V_1, V_2, ..., V_c]   (matching blocks of value rows)

    For each Q block Q_i:
        Initialize output O_i = 0, normalizer l = 0, running max m = -inf
        (m and l hold one value per query row)

        For each K, V block (K_j, V_j):
            # Compute local attention scores
            S = Q_i @ K_j^T / sqrt(d)

            # Online softmax with running max (max/sum taken per row)
            m_new = max(m, rowmax(S))
            P = exp(S - m_new)
            l_new = exp(m - m_new) * l + rowsum(P)

            # Update the (unnormalized) output
            O_i = exp(m - m_new) * O_i + P @ V_j

            m = m_new
            l = l_new

        O_i = O_i / l   # Normalize once at the end


    THE TRICK: Online Softmax
    -------------------------

    Normal softmax needs all values of a row to compute its max and sum.

    Online softmax computes them incrementally:
      1. Track the running maximum (m)
      2. Track the running sum of exponentials (l)
      3. Rescale previous results when the max changes

    This allows processing blocks without storing the full matrix!


    RECOMPUTATION (backward pass):
    ------------------------------

    Flash Attention does not save the N x N matrix P for backprop.
    It keeps only the output O and the per-row statistics (m, l),
    and recomputes the attention blocks in SRAM during the backward
    pass: extra FLOPs, but far fewer slow HBM accesses.
    """)


flash_attention_explained()

1.3 Simplified Flash Attention Implementation
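The `SimplifiedFlashAttention` module in this section loops over K/V blocks while processing all queries at once. The sketch here instead follows the two-level loop of the 1.2 algorithm sketch: an outer loop over Q blocks and an inner loop over K/V blocks. The function name and block sizes are illustrative assumptions. It is pure PyTorch on the CPU, so it shows the arithmetic of tiling, not its speed.

```python
import math
import torch


def tiled_attention(q, k, v, block_q=32, block_kv=32):
    """Exact attention computed tile by tile (illustrative sketch, no speed benefit).

    q, k, v: [batch, heads, seq_len, head_dim]. Outer loop over Q blocks,
    inner loop over K/V blocks, with online-softmax statistics per query row.
    """
    seq_len, head_dim = q.shape[-2], q.shape[-1]
    scale = 1.0 / math.sqrt(head_dim)
    out = torch.empty_like(q)
    for qs in range(0, seq_len, block_q):
        q_i = q[..., qs:qs + block_q, :]
        stat_shape = (*q_i.shape[:-1], 1)
        o_i = torch.zeros_like(q_i)                    # unnormalized output
        m = torch.full(stat_shape, float("-inf"), dtype=q.dtype, device=q.device)
        l = torch.zeros(stat_shape, dtype=q.dtype, device=q.device)
        for ks in range(0, seq_len, block_kv):
            k_j = k[..., ks:ks + block_kv, :]
            v_j = v[..., ks:ks + block_kv, :]
            s = (q_i @ k_j.transpose(-2, -1)) * scale  # [B, H, block_q, block_kv]
            m_new = torch.maximum(m, s.amax(dim=-1, keepdim=True))
            p = torch.exp(s - m_new)
            alpha = torch.exp(m - m_new)               # rescales the old statistics
            l = alpha * l + p.sum(dim=-1, keepdim=True)
            o_i = alpha * o_i + p @ v_j
            m = m_new
        out[..., qs:qs + block_q, :] = o_i / l         # normalize once per Q block
    return out


torch.manual_seed(0)
q, k, v = (torch.randn(2, 3, 100, 16) for _ in range(3))
reference = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(16), dim=-1) @ v
result = tiled_attention(q, k, v)
print(torch.allclose(result, reference, atol=1e-5))
```

A sequence length of 100 is deliberately not a multiple of the block size, so the last tile (4 rows or keys) exercises the ragged-edge case.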
Educational Implementation
class SimplifiedFlashAttention(nn.Module):
    """
    Simplified Flash Attention implementation for education.

    Note: This is NOT optimized like the real implementation.
    The actual Flash Attention uses custom CUDA kernels and tiles both Q
    and K/V in on-chip SRAM. Here all queries are processed together
    against one K/V block at a time, so the largest intermediate is
    [batch, heads, seq_len, block_size] instead of [batch, heads, seq_len, seq_len].
    """

    def __init__(
        self,
        block_size: int = 64,
        dropout: float = 0.0
    ):
        """
        Initialize Flash Attention.

        Args:
            block_size: Size of K/V blocks for tiling
            dropout: Dropout probability applied to the attention weights
        """
        super().__init__()
        self.block_size = block_size
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Flash attention forward pass.

        Args:
            query: [batch, heads, seq_len, head_dim]
            key: [batch, heads, seq_len, head_dim]
            value: [batch, heads, seq_len, head_dim]
            attn_mask: Optional boolean mask broadcastable to
                [batch, heads, seq_len, seq_len]; True = attend

        Returns:
            Output [batch, heads, seq_len, head_dim]
        """
        batch_size, num_heads, seq_len, head_dim = query.shape
        scale = 1.0 / math.sqrt(head_dim)

        # Initialize the (unnormalized) output and per-row softmax statistics
        output = torch.zeros_like(query)
        row_max = torch.full(
            (batch_size, num_heads, seq_len, 1),
            float('-inf'),
            device=query.device,
            dtype=query.dtype
        )
        row_sum = torch.zeros(
            (batch_size, num_heads, seq_len, 1),
            device=query.device,
            dtype=query.dtype
        )

        # Number of K/V blocks
        num_blocks = (seq_len + self.block_size - 1) // self.block_size

        # Process K, V blocks
        for j in range(num_blocks):
            # Get K, V block
            kv_start = j * self.block_size
            kv_end = min((j + 1) * self.block_size, seq_len)

            k_block = key[:, :, kv_start:kv_end, :]  # [B, H, block, d]
            v_block = value[:, :, kv_start:kv_end, :]

            # Scores of all queries against this K block: [B, H, seq_len, block]
            scores = torch.matmul(query, k_block.transpose(-2, -1)) * scale

            # Apply mask if provided
            if attn_mask is not None:
                mask_block = attn_mask[..., kv_start:kv_end]
                scores = scores.masked_fill(~mask_block, float('-inf'))

            # Online softmax update: new running maximum
            block_max = scores.max(dim=-1, keepdim=True).values
            new_max = torch.maximum(row_max, block_max)
            # Rows masked out so far keep a max of -inf; use 0 there to avoid
            # (-inf) - (-inf) = nan (e.g. fully masked blocks under a causal mask)
            safe_max = torch.where(new_max == float('-inf'), torch.zeros_like(new_max), new_max)

            # Factor that rescales the old statistics to the new maximum
            old_scale = torch.exp(row_max - safe_max)

            # Exponentials of this block relative to the new maximum
            exp_scores = torch.exp(scores - safe_max)
            new_row_sum = old_scale * row_sum + exp_scores.sum(dim=-1, keepdim=True)

            # Scale old output and add this block's contribution. Dropout acts on
            # the attention weights; the normalizer uses the undropped values.
            output = old_scale * output + torch.matmul(self.dropout(exp_scores), v_block)

            # Update statistics
            row_max = new_max
            row_sum = new_row_sum

        # Final normalization
        return output / row_sum


def compare_attention_methods():
    """
    Compare standard vs flash attention.
    """
    print("Comparing Standard vs Flash Attention")
    print("=" * 60)

    # Small example for demonstration
    batch_size = 2
    num_heads = 4
    seq_len = 256
    head_dim = 64
    block_size = 64

    # Random inputs
    torch.manual_seed(42)
    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim)

    # Standard attention
    def standard_attention(q, k, v):
        scale = 1.0 / math.sqrt(q.size(-1))
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, v)

    # Flash attention
    flash_attn = SimplifiedFlashAttention(block_size=block_size)

    # Compute both
    standard_output = standard_attention(query, key, value)
    flash_output = flash_attn(query, key, value)

    # Compare
    max_diff = (standard_output - flash_output).abs().max().item()
    mean_diff = (standard_output - flash_output).abs().mean().item()

    print(f"\nInput shape: [{batch_size}, {num_heads}, {seq_len}, {head_dim}]")
    print(f"Block size: {block_size}")
    print("\nOutput comparison:")
    print(f"  Max difference: {max_diff:.6e}")
    print(f"  Mean difference: {mean_diff:.6e}")
    print(f"  Match: {'✓ Yes' if max_diff < 1e-5 else '✗ No'}")

    # Memory comparison (float32 = 4 bytes per element)
    standard_mem = batch_size * num_heads * seq_len * seq_len * 4  # full score matrix
    flash_mem = batch_size * num_heads * seq_len * block_size * 4  # one score block

    print("\nMemory comparison:")
    print(f"  Standard attention matrix: {standard_mem / 1024:.1f} KB")
    print(f"  Flash attention (per block): {flash_mem / 1024:.1f} KB")
    print(f"  Reduction: {standard_mem / flash_mem:.1f}x")


compare_attention_methods()

1.4 Using Flash Attention in Practice
PyTorch Native and FlashAttention Library
Option 1: PyTorch Native (2.0+)
PyTorch 2.0+ includes scaled_dot_product_attention with automatic Flash Attention backend selection.
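`scaled_dot_product_attention` also runs on the CPU through a non-Flash backend, so its semantics can be checked against the explicit formula without a GPU. Below is a minimal sketch. `reference_attention` is a helper written here for comparison, not a PyTorch function, and the default scale of SDPA is 1/sqrt(head_dim).

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q, k, v = (torch.randn(2, 4, 128, 32) for _ in range(3))  # [B, H, S, D], float32 on CPU


def reference_attention(q, k, v, causal=False):
    """Explicit softmax(QK^T / sqrt(d)) V, optionally with a causal mask."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if causal:
        n = q.size(-2)
        keep = torch.ones(n, n, dtype=torch.bool).tril()  # query i sees keys 0..i
        scores = scores.masked_fill(~keep, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


out = F.scaled_dot_product_attention(q, k, v)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(out, reference_attention(q, k, v), atol=1e-4))
print(torch.allclose(out_causal, reference_attention(q, k, v, causal=True), atol=1e-4))
```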
import torch
import torch.nn.functional as F

# Check which SDPA backends are enabled (global switches; the GPU and
# inputs must still meet each backend's requirements)
print(torch.backends.cuda.flash_sdp_enabled())          # Flash Attention
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # Memory efficient
print(torch.backends.cuda.math_sdp_enabled())           # Standard math

# Use scaled_dot_product_attention:
# it automatically selects the best available backend!
query = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
key = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
value = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)

# The Flash Attention backend is used only if, among other conditions:
# - The GPU supports it (Ampere, Ada, Hopper)
# - Inputs are float16 or bfloat16
# - No explicit attn_mask is passed (causal masking via is_causal=True is fine)
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False  # Set True for decoder (causal) attention
)

# Force a specific backend (raises an error if Flash Attention cannot run).
# Newer PyTorch releases deprecate this context manager in favor of
# torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION).
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(query, key, value)

Option 2: FlashAttention Library
The original implementation by Tri Dao. It exposes more features than the PyTorch built-in, for example variable-length (unpadded) batches.
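flash-attn takes tensors as [batch, seq_len, num_heads, head_dim], whereas `scaled_dot_product_attention` expects [batch, num_heads, seq_len, head_dim]. Moving between the two only needs a transpose of dims 1 and 2. The CPU-runnable sketch below feeds flash-attn-layout tensors through SDPA. The helper names `to_bshd` and `to_bhsd` are illustrative, not library functions.

```python
import torch
import torch.nn.functional as F


def to_bshd(x: torch.Tensor) -> torch.Tensor:
    """[batch, heads, seq, dim] (PyTorch SDPA layout) -> [batch, seq, heads, dim] (flash-attn layout)."""
    return x.transpose(1, 2).contiguous()


def to_bhsd(x: torch.Tensor) -> torch.Tensor:
    """[batch, seq, heads, dim] (flash-attn layout) -> [batch, heads, seq, dim] (SDPA layout)."""
    return x.transpose(1, 2)


# Tensors in flash-attn layout: [B, S, H, D]
q = torch.randn(2, 256, 8, 64)
k = torch.randn(2, 256, 8, 64)
v = torch.randn(2, 256, 8, 64)

# Run SDPA in its own layout, then convert the result back to [B, S, H, D].
out = to_bshd(F.scaled_dot_product_attention(to_bhsd(q), to_bhsd(k), to_bhsd(v)))
print(out.shape)  # torch.Size([2, 256, 8, 64])
```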
# Installation (run in a shell)
pip install flash-attn --no-build-isolation

import torch
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

# Basic usage
# Input shape: [batch, seq_len, num_heads, head_dim]
# Note: different from the PyTorch SDPA convention [batch, num_heads, seq_len, head_dim]!
query = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
key = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
value = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)

# Flash attention
output = flash_attn_func(
    query, key, value,
    dropout_p=0.0,
    softmax_scale=None,  # None -> 1/sqrt(head_dim)
    causal=False
)

# For QKV packed together: [batch, seq_len, 3, num_heads, head_dim]
qkv = torch.randn(2, 1024, 3, 8, 64, device='cuda', dtype=torch.float16)
output = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)

Option 3: xFormers Library
# Installation (run in a shell)
pip install xformers

import torch
from xformers.ops import memory_efficient_attention

# Input shape: [batch, seq_len, num_heads, head_dim]
query = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
key = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
value = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)

output = memory_efficient_attention(query, key, value)

# With attention bias (supports more mask types)
from xformers.ops import LowerTriangularMask
output = memory_efficient_attention(
    query, key, value,
    attn_bias=LowerTriangularMask()  # Causal
)

Integration into Transformer
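The module in this subsection accepts either an explicit `attn_mask` or `is_causal=True`. For causal attention the two are interchangeable. In `scaled_dot_product_attention`, a boolean mask marks the positions that may be attended (True = attend). Prefer `is_causal=True`, because the Flash backend does not accept an arbitrary `attn_mask`; the two arguments should also not be combined. A small CPU check, with sizes chosen arbitrarily:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 2, 4, 64, 32
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# Lower-triangular boolean mask: query i may attend to keys 0..i.
# Shape [S, S] broadcasts over batch and heads.
causal_mask = torch.ones(S, S, dtype=torch.bool).tril()

out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_mask, out_flag, atol=1e-4))
```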
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional


class FlashMultiHeadAttention(nn.Module):
    """Multi-head attention using PyTorch's scaled_dot_product_attention,
    which dispatches to Flash Attention when the inputs allow it."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.dropout = dropout

        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=bias)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)

        # Split [B, S, 3, H, D] into three tensors of shape [B, S, H, D]
        q, k, v = qkv.unbind(dim=2)

        # PyTorch SDPA expects [B, H, S, D]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Flash attention (only eligible when attn_mask is None)
        output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal
        )

        # Reshape back: [B, H, S, D] -> [B, S, H*D]
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)

        return self.out_proj(output)


1.5 Performance Comparison
Benchmarks and Speed Improvements
Configuration: batch=8, heads=12, head_dim=64, float16 on A100 80GB
| Seq Length | Standard (ms) | Flash v2 (ms) | Speedup | Memory Reduction |
|---|---|---|---|---|
| 512 | 2.1 | 0.8 | 2.6x | 4x |
| 1024 | 7.5 | 1.9 | 3.9x | 8x |
| 2048 | 28.3 | 5.2 | 5.4x | 16x |
| 4096 | 108.7 | 14.8 | 7.3x | 32x |
| 8192 | 421.5 | 42.1 | 10.0x | 64x |
| 16384 | OOM | 142.3 | n/a | 128x |
Key observations:
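• Speedup increases with sequence length (2.6x at 512 tokens, 10.0x at 8192)
• Memory savings grow linearly with sequence length (4x at 512 tokens, 128x at 16384)
• Enables sequences that don't fit with standard attention (16384 tokens run out of memory without it)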
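The figures above are A100 measurements and will differ across GPUs, PyTorch versions and SDPA backends. Below is a minimal timing harness in the spirit of Exercise 2. The function names are illustrative. On a CPU-only machine it still runs, but then it says nothing about Flash Attention speed. CUDA kernels launch asynchronously, so the clock is only read after `torch.cuda.synchronize()`.

```python
import math
import time
import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    """Standard attention that materializes the full score matrix."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    return torch.softmax(scores, dim=-1) @ v


def time_ms(fn, *args, warmup=2, iters=5):
    """Average wall-clock time of fn(*args) in milliseconds."""
    for _ in range(warmup):
        fn(*args)
    if args[0].is_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if args[0].is_cuda:
        torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) * 1000 / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
results = {}
with torch.no_grad():
    for seq_len in (256, 512, 1024):
        q, k, v = (torch.randn(1, 4, seq_len, 64, device=device, dtype=dtype) for _ in range(3))
        results[seq_len] = (time_ms(naive_attention, q, k, v),
                            time_ms(F.scaled_dot_product_attention, q, k, v))

for seq_len, (t_naive, t_sdpa) in results.items():
    print(f"{seq_len:>6}  naive {t_naive:8.2f} ms  sdpa {t_sdpa:8.2f} ms  "
          f"speedup {t_naive / t_sdpa:.1f}x")
```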
Training Throughput (GPT-2 style model)
| Model Size | Seq Length | Standard (tokens/s) | Flash (tokens/s) | Improvement |
|---|---|---|---|---|
| 125M | 1024 | 48k | 82k | +71% |
| 350M | 1024 | 28k | 51k | +82% |
| 1.3B | 2048 | 8k | 18k | +125% |
| 2.7B | 2048 | OOM | 9k | n/a |
When to Use Flash Attention:
✓ Use when:
• Sequence length > 512
• Training large models
• Memory constrained
• Using float16 or bfloat16
• GPU supports it (Ampere+)
✗ May not help when:
• Very short sequences (<256)
• Need custom attention patterns (the Flash kernels do not accept arbitrary masks)
• Using float32 (the Flash kernels require float16 or bfloat16)
• CPU inference
• Older GPUs (Volta, Turing)
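The checklist above can be turned into a quick pre-flight check. This is a rough heuristic written for illustration. The function name `flash_sdpa_likely` is hypothetical, and the thresholds (compute capability 8.0+, head_dim up to 256) are assumptions. PyTorch's own dispatcher applies further conditions and silently falls back to another backend when they are not met.

```python
import torch


def flash_sdpa_likely(q: torch.Tensor, attn_mask=None) -> bool:
    """Heuristic guess (not PyTorch's real dispatch logic) of whether the
    Flash SDPA backend can serve inputs like q."""
    if not (q.is_cuda and torch.backends.cuda.flash_sdp_enabled()):
        return False  # CPU tensors, or Flash backend switched off
    if q.dtype not in (torch.float16, torch.bfloat16):
        return False  # Flash kernels do not take float32
    if attn_mask is not None:
        return False  # arbitrary masks go to another backend; prefer is_causal
    major, _ = torch.cuda.get_device_capability(q.device)
    # Assumed limits: Ampere (sm80) or newer, head_dim at most 256
    return major >= 8 and q.size(-1) <= 256


print(flash_sdpa_likely(torch.randn(1, 8, 1024, 64)))  # False (CPU tensor)
```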
Summary
Flash Attention Key Points:
| Aspect | Standard Attention | Flash Attention |
|---|---|---|
| Memory | O(n²) | O(n) |
| Speed | Slower | 2-10x faster |
| Max seq length | Limited by memory | Much longer |
| Implementation | Simple | Complex (CUDA) |
| GPU requirement | Any | Ampere+ (best) |
Usage Recommendations:
For PyTorch 2.0+: Use torch.nn.functional.scaled_dot_product_attention() - Automatically selects best backend!
For maximum features: pip install flash-attn - More flexible masking, better performance
For xFormers users: xformers.ops.memory_efficient_attention() - Good integration with transformers
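flash-attn is an optional, CUDA-only dependency, so code that follows these recommendations often tries it first and falls back to `scaled_dot_product_attention`. Below is a minimal sketch of that pattern. `attention_bshd` and `HAS_FLASH_ATTN` are hypothetical names, and the eligibility test (CUDA tensor, float16/bfloat16) is a simplification.

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # optional dependency
    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False


def attention_bshd(q, k, v, causal=False):
    """Attention on [batch, seq_len, num_heads, head_dim] tensors.

    Uses flash-attn when installed and the inputs are fp16/bf16 CUDA tensors,
    otherwise falls back to PyTorch's scaled_dot_product_attention.
    """
    if HAS_FLASH_ATTN and q.is_cuda and q.dtype in (torch.float16, torch.bfloat16):
        return flash_attn_func(q, k, v, dropout_p=0.0, causal=causal)
    # SDPA expects [batch, num_heads, seq_len, head_dim]
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=causal
    )
    return out.transpose(1, 2)  # back to [batch, seq_len, num_heads, head_dim]


x = torch.randn(2, 128, 4, 64)
out = attention_bshd(x, x, x, causal=True)
print(out.shape)  # torch.Size([2, 128, 4, 64])
```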
Exercises:
1. Run the simplified Flash Attention and verify it matches standard attention.
2. Benchmark PyTorch's scaled_dot_product_attention with different sequence lengths.
3. Modify the TransformerEncoder from Chapter 7 to use Flash Attention.
4. Compare memory usage between standard and Flash Attention for seq_len=4096.
5. Research Flash Attention v2 improvements over v1.
Next Section: In Section 2, we'll explore Mixture of Experts (MoE), which enables training much larger models efficiently.