Chapter 16

Flash Attention

Advanced Architectures

Introduction

This section covers Flash Attention, an IO-aware algorithm that computes exact attention (not an approximation) while using far less memory, which substantially speeds up transformer training and inference. We'll first build intuition for how the algorithm works and then see how to use it in practice.


1.1 The Problem with Standard Attention

Memory Bottleneck

```python
import math
from typing import Optional

import torch
import torch.nn as nn


def attention_memory_analysis():
    """
    Analyze memory usage of standard attention.
    """
    print("=" * 70)
    print("STANDARD ATTENTION: MEMORY ANALYSIS")
    print("=" * 70)

    print("""
    STANDARD ATTENTION ALGORITHM:
    ─────────────────────────────

    Input: Q, K, V each of shape [batch, heads, seq_len, head_dim]

    Step 1: Compute attention scores
            S = Q @ K^T
            Shape: [batch, heads, seq_len, seq_len]
            Memory: O(n²) where n = seq_len

    Step 2: Apply softmax
            P = softmax(S / sqrt(d))
            Shape: [batch, heads, seq_len, seq_len]
            Memory: O(n²)

    Step 3: Compute output
            O = P @ V
            Shape: [batch, heads, seq_len, head_dim]
            Memory: O(n * d)

    Total memory for attention scores: O(n²)


    MEMORY USAGE EXAMPLES (one attention layer):
    ────────────────────────────────────────────

    Configuration:
    - batch_size = 8
    - num_heads = 12
    - head_dim = 64
    - dtype = float16 (2 bytes)
    """)

    batch_size = 8
    num_heads = 12
    head_dim = 64  # does not affect the size of the n x n matrix
    bytes_per_element = 2  # float16

    seq_lengths = [512, 1024, 2048, 4096, 8192, 16384]

    print(f"{'Seq Length':<12} {'Attention Matrix':<20} {'Memory (GB)':<15} {'Status'}")
    print("-" * 70)

    for seq_len in seq_lengths:
        # Attention matrix: [batch, heads, seq, seq]
        attention_elements = batch_size * num_heads * seq_len * seq_len
        attention_bytes = attention_elements * bytes_per_element
        attention_gb = attention_bytes / (1024 ** 3)

        # Rough status for this single matrix (weights, other activations,
        # and the other layers need memory too)
        if attention_gb < 4:
            status = "✓ Fits in 8GB GPU"
        elif attention_gb < 12:
            status = "⚠ Needs 16GB+ GPU"
        elif attention_gb < 24:
            status = "⚠ Needs 24GB+ GPU"
        else:
            status = "✗ Too large for consumer GPUs"

        matrix = f"{seq_len}x{seq_len}"
        print(f"{seq_len:<12} {matrix:<20} {attention_gb:<15.2f} {status}")

    print("""

    THE QUADRATIC PROBLEM:
    ──────────────────────

    Memory grows as O(n²):
    - Double sequence length → 4x memory
    - 4x sequence length → 16x memory

    This limits:
    • Maximum sequence length
    • Batch size
    • Model size that fits in memory

    For long documents, code, or conversations, standard attention
    becomes impractical!
    """)


attention_memory_analysis()
```
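The table covers a single layer's attention matrix. When training with standard attention, autograd also keeps each layer's softmax output for the backward pass, so the footprint grows with depth. A rough estimate under simplifying assumptions: exactly one fp16 `[batch, heads, n, n]` tensor saved per layer, nothing else counted; the 12-layer depth and the helper name `attention_activation_gib` are illustrative, not from a specific model.

```python
def attention_activation_gib(batch: int, heads: int, seq_len: int,
                             num_layers: int, bytes_per_element: int = 2) -> float:
    """Rough GiB of n x n attention probabilities kept for the backward pass.

    Assumes exactly one [batch, heads, seq_len, seq_len] tensor saved per
    layer; real frameworks may also save dropout masks or pre-softmax scores.
    """
    elements = batch * heads * seq_len * seq_len * num_layers
    return elements * bytes_per_element / 1024 ** 3


for n in (1024, 2048, 4096):
    gib = attention_activation_gib(batch=8, heads=12, seq_len=n, num_layers=12)
    print(f"seq_len={n}: {gib:.2f} GiB across 12 layers")
```

Under these assumptions the loop reports 2.25, 9.00 and 36.00 GiB: the quadratic term multiplied by the number of layers.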

1.2 Flash Attention Algorithm

Core Idea: Tiling and Recomputation

```python
def flash_attention_explained():
    """
    Explain the Flash Attention algorithm.
    """
    print("=" * 70)
    print("FLASH ATTENTION ALGORITHM")
    print("=" * 70)

    print("""
    KEY INSIGHT: Memory Hierarchy
    ─────────────────────────────

    GPU Memory Hierarchy (NVIDIA A100):
    ┌────────────────────────────────────────────────────────────────┐
    │  HBM (High Bandwidth Memory) - 40-80 GB                        │
    │  └── Slower to access (~1.5-2.0 TB/s)                          │
    │      └── Where model weights & activations are stored          │
    │                                                                │
    │  SRAM (on-chip) - 192 KB per SM, ~20 MB total                  │
    │  └── Much faster (~19 TB/s)                                    │
    │      └── Where computation happens                             │
    │                                                                │
    │  SRAM has ~10x the bandwidth of HBM but is far smaller         │
    └────────────────────────────────────────────────────────────────┘

    Standard Attention Problem:
    - Computes the full N×N attention matrix
    - Must store it in slow HBM
    - Then loads it back for softmax and the matmul with V
    - Many slow memory transfers

    FLASH ATTENTION SOLUTION: Tiling
    ─────────────────────────────────

    Instead of computing the full attention matrix:

    Standard:
    ┌───────────────────────────────────────┐
    │                                       │
    │         N × N Attention Matrix        │
    │         (Stored in HBM)               │
    │                                       │
    └───────────────────────────────────────┘

    Flash Attention:
    ┌────┬────┬────┬────┬────┬────┬────┬────┐
    │ B₁ │ B₂ │ B₃ │ B₄ │ B₅ │ B₆ │ B₇ │ B₈ │ ← Computed tile by tile
    ├────┼────┼────┼────┼────┼────┼────┼────┤   in SRAM
    │ B₉ │ B₁₀│ B₁₁│ B₁₂│ B₁₃│ B₁₄│ B₁₅│ B₁₆│
    ├────┴────┴────┴────┴────┴────┴────┴────┤
    │...                                    │
    └───────────────────────────────────────┘

    Each tile is loaded into fast SRAM, used, and discarded.
    The full attention matrix is never materialized in HBM!


    ALGORITHM SKETCH (forward pass):
    ────────────────────────────────

    Split Q, K, V into blocks along the sequence dimension:
    Q = [Q₁, Q₂, ..., Q_r]  (r blocks of query positions)
    K = [K₁, K₂, ..., K_c]  (c blocks of key positions)
    V = [V₁, V₂, ..., V_c]  (same blocks as K)

    For each Q block Qᵢ:
        Initialize O = 0, l = 0, m = -∞  (m and l hold one value per query row)

        For each K, V block (Kⱼ, Vⱼ):
            # Compute local attention scores
            S = Qᵢ @ Kⱼᵀ / √d

            # Online softmax with running row-wise max
            m_new = max(m, rowmax(S))
            P = exp(S - m_new)
            l_new = exp(m - m_new) * l + rowsum(P)

            # Rescale the old output, then add this block's contribution
            O = exp(m - m_new) * O + P @ Vⱼ

            m = m_new
            l = l_new

        O = O / l  # Normalize once at the end


    THE TRICK: Online Softmax
    ─────────────────────────

    Normal softmax needs all values to compute the max and the sum.

    Online softmax computes them incrementally:
    1. Track the running maximum (m)
    2. Track the running sum of exponentials (l)
    3. Rescale previous results whenever the max changes

    This allows processing blocks without storing the full matrix!


    RECOMPUTATION (backward pass):
    ──────────────────────────────

    Standard attention saves the N×N matrix P for the backward pass.
    Flash Attention saves only O and the per-row statistics (m, l),
    then recomputes S and P block by block during the backward pass.
    The extra FLOPs are cheaper than the HBM traffic they replace.
    """)


flash_attention_explained()
```
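The online-softmax update can be checked without a GPU. The sketch below is pure Python for a single query row, and the helper names (`scores_for`, `naive_attention_row`, `online_attention_row`) are our own. It visits the keys in blocks using the running max `m` and running sum `l` from the sketch above, then compares the result with the ordinary two-pass softmax for several block sizes.

```python
import math
from typing import List


def scores_for(q: List[float], keys: List[List[float]]) -> List[float]:
    """Scaled dot products q . k / sqrt(d) for one query against several keys."""
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]


def naive_attention_row(q, keys, values):
    """Reference: full softmax over all scores, then a weighted sum of values."""
    s = scores_for(q, keys)
    m = max(s)
    w = [math.exp(si - m) for si in s]
    total = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / total
            for c in range(len(values[0]))]


def online_attention_row(q, keys, values, block_size=2):
    """Same result, visiting keys block by block with a running m and l."""
    m = -math.inf                # running max of all scores seen so far
    l = 0.0                      # running sum of exp(score - m)
    o = [0.0] * len(values[0])   # running unnormalized output
    for start in range(0, len(keys), block_size):
        s = scores_for(q, keys[start:start + block_size])
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s))
        p = [math.exp(si - m_new) for si in s]
        correction = math.exp(m - m_new)  # 0.0 on the first block (m = -inf)
        l = correction * l + sum(p)
        o = [correction * oc + sum(pi * v[c] for pi, v in zip(p, v_blk))
             for c, oc in enumerate(o)]
        m = m_new
    return [oc / l for oc in o]  # normalize once at the end


q = [0.5, -1.0, 2.0]
keys = [[1.0, 0.0, 0.5], [0.2, 0.3, -1.0], [2.0, 1.0, 0.0],
        [-0.5, 0.4, 1.5], [0.0, -2.0, 0.3]]
values = [[1.0, 2.0], [0.0, -1.0], [3.0, 0.5], [-2.0, 1.0], [0.5, 0.5]]

print("two-pass:", naive_attention_row(q, keys, values))
for bs in (1, 2, 3):
    print(f"block_size={bs}:", online_attention_row(q, keys, values, block_size=bs))
```

All block sizes agree with the two-pass result up to floating-point rounding, which is the property that lets Flash Attention drop the full score matrix.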

1.3 Simplified Flash Attention Implementation

Educational Implementation

```python
import math
from typing import Optional

import torch
import torch.nn as nn


class SimplifiedFlashAttention(nn.Module):
    """
    Simplified Flash Attention implementation for education.

    Note: This is NOT optimized like the real implementation.
    The actual Flash Attention uses custom CUDA kernels and also tiles
    the queries; here only K and V are split into blocks.
    """

    def __init__(self, block_size: int = 64, dropout: float = 0.0):
        """
        Initialize Flash Attention.

        Args:
            block_size: Number of key/value positions per block
            dropout: Dropout probability applied to the attention weights
        """
        super().__init__()
        self.block_size = block_size
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Flash attention forward pass.

        Args:
            query: [batch, heads, seq_len, head_dim]
            key: [batch, heads, seq_len, head_dim]
            value: [batch, heads, seq_len, head_dim]
            attn_mask: Optional boolean mask broadcastable to
                [batch, heads, seq_len, seq_len]; True = may attend

        Returns:
            Output [batch, heads, seq_len, head_dim]
        """
        batch_size, num_heads, seq_len, head_dim = query.shape
        scale = 1.0 / math.sqrt(head_dim)

        # Running statistics, one entry per query row
        output = torch.zeros_like(query)  # unnormalized output
        stats_shape = (batch_size, num_heads, seq_len, 1)
        row_max = torch.full(stats_shape, float('-inf'), device=query.device, dtype=query.dtype)
        row_sum = torch.zeros(stats_shape, device=query.device, dtype=query.dtype)

        # Process K, V blocks
        for kv_start in range(0, seq_len, self.block_size):
            kv_end = min(kv_start + self.block_size, seq_len)
            k_block = key[:, :, kv_start:kv_end, :]  # [B, H, block, d]
            v_block = value[:, :, kv_start:kv_end, :]

            # Scores of all queries against this key block: [B, H, seq_len, block]
            scores = torch.matmul(query, k_block.transpose(-2, -1)) * scale

            if attn_mask is not None:
                mask_block = attn_mask[..., kv_start:kv_end]
                scores = scores.masked_fill(~mask_block, float('-inf'))

            # Online softmax: new running maximum per row
            new_max = torch.maximum(row_max, scores.amax(dim=-1, keepdim=True))
            # Rows fully masked so far have new_max = -inf; use 0 there so the
            # exponentials below give 0 instead of NaN (from -inf - -inf)
            safe_max = new_max.masked_fill(new_max == float('-inf'), 0.0)

            # Factor that rescales everything accumulated under the old max
            old_scale = torch.exp(row_max - safe_max)
            exp_scores = torch.exp(scores - safe_max)  # masked entries -> 0

            row_sum = old_scale * row_sum + exp_scores.sum(dim=-1, keepdim=True)
            # Dropout acts on the unnormalized weights while the normalizer uses
            # the undropped sum, which matches dropout(softmax(S)) @ V
            output = old_scale * output + torch.matmul(self.dropout(exp_scores), v_block)
            row_max = new_max

        # Final normalization (fully masked rows give NaN, as with standard softmax)
        return output / row_sum


def compare_attention_methods():
    """
    Compare standard vs flash attention.
    """
    print("Comparing Standard vs Flash Attention")
    print("=" * 60)

    # Small example for demonstration
    batch_size, num_heads, seq_len, head_dim = 2, 4, 256, 64
    block_size = 64

    torch.manual_seed(42)
    query = torch.randn(batch_size, num_heads, seq_len, head_dim)
    key = torch.randn(batch_size, num_heads, seq_len, head_dim)
    value = torch.randn(batch_size, num_heads, seq_len, head_dim)

    def standard_attention(q, k, v):
        scale = 1.0 / math.sqrt(q.size(-1))
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale
        attn_weights = torch.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, v)

    flash_attn = SimplifiedFlashAttention(block_size=block_size)

    standard_output = standard_attention(query, key, value)
    flash_output = flash_attn(query, key, value)

    diff = (standard_output - flash_output).abs()
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()

    print(f"\nInput shape: [{batch_size}, {num_heads}, {seq_len}, {head_dim}]")
    print(f"Block size: {block_size}")
    print("\nOutput comparison:")
    print(f"  Max difference: {max_diff:.6e}")
    print(f"  Mean difference: {mean_diff:.6e}")
    print(f"  Match: {'✓ Yes' if max_diff < 1e-5 else '✗ No'}")

    # Largest score tensor each version materializes (float32 = 4 bytes)
    standard_mem = batch_size * num_heads * seq_len * seq_len * 4
    flash_mem = batch_size * num_heads * seq_len * block_size * 4

    print("\nMemory comparison (score tensor):")
    print(f"  Standard attention matrix: {standard_mem / 1024:.1f} KB")
    print(f"  Simplified flash (one K/V block): {flash_mem / 1024:.1f} KB")
    print(f"  Reduction: {standard_mem / flash_mem:.1f}x")


compare_attention_methods()
```
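`SimplifiedFlashAttention` tiles only the keys and values, so it still builds a `[batch, heads, seq_len, block_size]` score tensor. The sketch below, a hypothetical `tiled_attention` helper in plain PyTorch, also splits the queries into blocks as in the algorithm sketch of 1.2, so only a `[batch, heads, block_q, block_kv]` tile exists at any time. It additionally returns the per-row log-sum-exp `m + log(l)`, the single statistic FlashAttention-2 keeps for recomputation in the backward pass, and with causal masking it skips key blocks that lie entirely in the future. It is far slower than a fused kernel and only illustrates the bookkeeping.

```python
import math
from typing import Tuple

import torch


def tiled_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    block_q: int = 32, block_kv: int = 32,
                    causal: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """Exact attention computed one [block_q x block_kv] tile at a time.

    q, k: [batch, heads, seq_len, head_dim]; v: [batch, heads, seq_len, v_dim].
    Returns (output, logsumexp) with logsumexp[..., i] = m_i + log(l_i).
    """
    B, H, N, D = q.shape
    scale = 1.0 / math.sqrt(D)
    out = q.new_empty((B, H, N, v.shape[-1]))
    lse = q.new_empty((B, H, N))
    for qs in range(0, N, block_q):                          # outer loop: query blocks
        qe = min(qs + block_q, N)
        q_blk = q[:, :, qs:qe]
        m = q.new_full((B, H, qe - qs, 1), float("-inf"))    # running row max
        l = q.new_zeros((B, H, qe - qs, 1))                  # running row sum
        o = q.new_zeros((B, H, qe - qs, v.shape[-1]))        # unnormalized output
        for ks in range(0, N, block_kv):                     # inner loop: K/V blocks
            if causal and ks >= qe:
                break  # this tile and all later ones are entirely in the future
            ke = min(ks + block_kv, N)
            s = (q_blk @ k[:, :, ks:ke].transpose(-2, -1)) * scale
            if causal:
                q_idx = torch.arange(qs, qe, device=q.device).unsqueeze(-1)
                k_idx = torch.arange(ks, ke, device=q.device).unsqueeze(0)
                s = s.masked_fill(k_idx > q_idx, float("-inf"))
            # Key 0 is visible to every query, so m_new is finite from the first tile
            m_new = torch.maximum(m, s.amax(dim=-1, keepdim=True))
            p = torch.exp(s - m_new)
            correction = torch.exp(m - m_new)
            l = correction * l + p.sum(dim=-1, keepdim=True)
            o = correction * o + p @ v[:, :, ks:ke]
            m = m_new
        out[:, :, qs:qe] = o / l
        lse[:, :, qs:qe] = (m + torch.log(l)).squeeze(-1)
    return out, lse


torch.manual_seed(0)
q, k, v = (torch.randn(2, 3, 100, 16) for _ in range(3))
ref = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(16), dim=-1) @ v
out, lse = tiled_attention(q, k, v, block_q=32, block_kv=24)
print(torch.allclose(out, ref, atol=1e-5))
```

A sequence length of 100 with block sizes 32 and 24 deliberately leaves partial tiles at the edges, which is where off-by-one bugs usually hide.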

1.4 Using Flash Attention in Practice

PyTorch Native and FlashAttention Library

Option 1: PyTorch Native (2.0+)

PyTorch 2.0 and later include torch.nn.functional.scaled_dot_product_attention (SDPA), which automatically dispatches to a Flash Attention kernel when the hardware and inputs allow it.

```python
import torch
import torch.nn.functional as F

# These report whether each backend is *enabled*, not whether the
# current GPU and inputs can actually use it
print(torch.backends.cuda.flash_sdp_enabled())          # Flash Attention
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # Memory-efficient
print(torch.backends.cuda.math_sdp_enabled())           # Standard math

# scaled_dot_product_attention picks the best eligible backend automatically
query = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
key = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)
value = torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16)

# The Flash backend is eligible when (exact rules vary by PyTorch version):
# - the GPU supports it (Ampere, Ada, Hopper)
# - inputs are float16 or bfloat16
# - attn_mask is None (is_causal=True is supported)
output = F.scaled_dot_product_attention(
    query, key, value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False  # Set True for decoder (causal) self-attention
)

# Force the Flash backend only; raises an error if it cannot be used.
# torch.backends.cuda.sdp_kernel is deprecated in recent PyTorch releases
# in favor of torch.nn.attention.sdpa_kernel.
with torch.backends.cuda.sdp_kernel(
    enable_flash=True,
    enable_math=False,
    enable_mem_efficient=False
):
    output = F.scaled_dot_product_attention(query, key, value)
```
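On CUDA, `is_causal=True` is the masking form the Flash backend accepts; as far as we know, passing an explicit `attn_mask` makes PyTorch choose another backend instead. The two spellings are mathematically identical, which the CPU-runnable check below confirms against a hand-written masked softmax. It assumes PyTorch 2.0+ and SDPA's boolean-mask convention (True = may attend); sizes are arbitrary.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 2, 4, 128, 32
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# 1) Built-in causal masking (the form the Flash backend supports on CUDA)
out_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# 2) The same mask spelled out as a boolean tensor (True = may attend)
causal_mask = torch.ones(S, S, dtype=torch.bool).tril()
out_masked = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)

# 3) Hand-written reference: mask future positions with -inf before softmax
scores = (q @ k.transpose(-2, -1)) / math.sqrt(D)
scores = scores.masked_fill(~causal_mask, float("-inf"))
out_ref = torch.softmax(scores, dim=-1) @ v

print(torch.allclose(out_causal, out_ref, atol=1e-5))
print(torch.allclose(out_masked, out_ref, atol=1e-5))
```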

Option 2: FlashAttention Library

The original implementation by Tri Dao, distributed as the flash-attn package, with more features than the PyTorch built-in.

```bash
# Installation
pip install flash-attn --no-build-isolation
```

```python
import torch
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func

# Basic usage
# Input shape: [batch, seq_len, num_heads, head_dim]
# Note: different from PyTorch's [batch, num_heads, seq_len, head_dim]!
query = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
key = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
value = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)

# Flash attention
output = flash_attn_func(
    query, key, value,
    dropout_p=0.0,
    softmax_scale=None,  # None means 1/sqrt(head_dim)
    causal=False
)

# Q, K, V packed together: [batch, seq_len, 3, num_heads, head_dim]
qkv = torch.randn(2, 1024, 3, 8, 64, device='cuda', dtype=torch.float16)
output = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, causal=True)
```
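Because flash-attn and xFormers use `[batch, seq_len, heads, head_dim]` while PyTorch SDPA uses `[batch, heads, seq_len, head_dim]`, switching libraries mostly means moving the heads axis. A small CPU-only sketch of the conversions (no flash-attn install needed); `x` is a random stand-in for the output of a fused QKV projection, and the `[Q | K | V]` column order is an assumption about how that projection is laid out.

```python
import torch

B, S, H, D = 2, 16, 8, 64

# PyTorch SDPA layout: [batch, heads, seq_len, head_dim]
q_sdpa = torch.randn(B, H, S, D)

# flash-attn / xFormers layout: [batch, seq_len, heads, head_dim]
q_fa = q_sdpa.transpose(1, 2).contiguous()

# Back again for F.scaled_dot_product_attention
q_back = q_fa.transpose(1, 2)

# Packed QKV from a fused projection of width 3 * H * D (random stand-in).
# Viewing the last dim as (3, H, D) assumes columns are ordered [Q | K | V],
# each part laid out head by head.
x = torch.randn(B, S, 3 * H * D)
qkv_packed = x.view(B, S, 3, H, D)        # layout used by flash_attn_qkvpacked_func
q_p, k_p, v_p = qkv_packed.unbind(dim=2)  # each [B, S, H, D]

print(q_fa.shape, qkv_packed.shape, q_p.shape)
```

Calling `.contiguous()` after the transpose is a conservative choice: it costs a copy but hands the kernel a densely packed tensor.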

Option 3: xFormers Library

```bash
# Installation
pip install xformers
```

```python
import torch
from xformers.ops import LowerTriangularMask, memory_efficient_attention

# Input shape: [batch, seq_len, num_heads, head_dim] (same as flash-attn)
query = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
key = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)
value = torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16)

output = memory_efficient_attention(query, key, value)

# Masks are passed as attention-bias objects; LowerTriangularMask is causal
output = memory_efficient_attention(
    query, key, value,
    attn_bias=LowerTriangularMask()
)
```

Integration into Transformer

```python
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class FlashMultiHeadAttention(nn.Module):
    """Multi-head attention using Flash Attention via PyTorch SDPA."""

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True
    ):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.dropout = dropout

        # One fused projection for Q, K and V
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=bias)
        self.out_proj = nn.Linear(d_model, d_model, bias=bias)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False
    ) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape

        # Project to Q, K, V: [B, S, 3, H, D]
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)

        # Split into three [B, S, H, D] tensors
        q, k, v = qkv.unbind(dim=2)

        # PyTorch SDPA expects [B, H, S, D]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Dispatches to Flash Attention when eligible. SDPA raises an error if
        # both attn_mask and is_causal=True are given; pass one or the other.
        output = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=attn_mask,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal
        )

        # Reshape back: [B, H, S, D] -> [B, S, H*D]
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)

        return self.out_proj(output)
```
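`FlashMultiHeadAttention` passes `attn_mask` straight to SDPA. A common case is a key-padding mask for variable-length batches: a `[batch, seq_len]` boolean tensor reshaped to `[batch, 1, 1, seq_len]` so it broadcasts over heads and query positions. The check below runs on CPU with illustrative lengths and confirms that padded keys have no influence on the output. Keep in mind that, as far as we know, PyTorch's Flash backend does not accept an explicit mask, so such calls fall back to another backend.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, S, D = 2, 4, 6, 8
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

# Illustrative lengths: sequence 0 has 6 real tokens, sequence 1 has 4
lengths = torch.tensor([6, 4])
key_padding = torch.arange(S).unsqueeze(0) < lengths.unsqueeze(1)  # [B, S], True = real token

# [B, S] -> [B, 1, 1, S]: the same key mask for every head and query position
attn_mask = key_padding[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Overwrite the padded keys/values of sequence 1; the outputs must not change
k2, v2 = k.clone(), v.clone()
k2[1, :, 4:] = 100.0
v2[1, :, 4:] = -100.0
out2 = F.scaled_dot_product_attention(q, k2, v2, attn_mask=attn_mask)
print(torch.allclose(out, out2, atol=1e-6))
```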

1.5 Performance Comparison

Benchmarks and Speed Improvements

Configuration: batch=8, heads=12, head_dim=64, float16 on A100 80GB

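| Seq Length | Standard (ms) | Flash v2 (ms) | Speedup | Memory Reduction |
|---|---|---|---|---|
| 512 | 2.1 | 0.8 | 2.6x | 4x |
| 1024 | 7.5 | 1.9 | 3.9x | 8x |
| 2048 | 28.3 | 5.2 | 5.4x | 16x |
| 4096 | 108.7 | 14.8 | 7.3x | 32x |
| 8192 | 421.5 | 42.1 | 10.0x | 64x |
| 16384 | OOM | 142.3 | n/a (standard OOM) | 128x |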

Key observations:

• Speedup increases with sequence length
• Memory savings are dramatic
• Enables sequences that don't fit with standard attention
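The figures above come from one specific GPU setup and will differ on other hardware. A minimal timing harness for reproducing this kind of comparison (and for Exercise 2), assuming PyTorch 2.0+; `naive_attention` and `time_ms` are our own helper names. It calls `torch.cuda.synchronize()` because CUDA kernels launch asynchronously. Without a GPU it falls back to a small CPU smoke run whose timings say nothing about Flash Attention kernels.

```python
import math
import time

import torch
import torch.nn.functional as F


def naive_attention(q, k, v):
    """Standard attention that materializes the full score matrix."""
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v


def time_ms(fn, *args, warmup=2, iters=5):
    """Average wall-clock time of fn(*args) in milliseconds."""
    with torch.no_grad():
        for _ in range(warmup):
            fn(*args)
        if args[0].is_cuda:
            torch.cuda.synchronize()  # finish warm-up kernels before timing
        start = time.perf_counter()
        for _ in range(iters):
            fn(*args)
        if args[0].is_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before stopping the clock
    return (time.perf_counter() - start) * 1000 / iters


device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    dtype, batch, heads, seq_lens = torch.float16, 8, 12, (512, 1024, 2048, 4096)
else:  # small CPU-only smoke run
    dtype, batch, heads, seq_lens = torch.float32, 2, 4, (128, 256, 512)

for n in seq_lens:
    q, k, v = (torch.randn(batch, heads, n, 64, device=device, dtype=dtype) for _ in range(3))
    t_naive = time_ms(naive_attention, q, k, v)
    t_sdpa = time_ms(F.scaled_dot_product_attention, q, k, v)
    print(f"seq_len={n:5d}  naive {t_naive:8.2f} ms  sdpa {t_sdpa:8.2f} ms")
```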

Training Throughput (GPT-2 style model)

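| Model Size | Seq Length | Standard (tok/s) | Flash (tok/s) | Improvement |
|---|---|---|---|---|
| 125M | 1024 | 48k | 82k | +71% |
| 350M | 1024 | 28k | 51k | +82% |
| 1.3B | 2048 | 8k | 18k | +125% |
| 2.7B | 2048 | OOM | 9k | n/a (standard OOM) |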

When to Use Flash Attention:

āœ“ Use when:

• Sequence length > 512
• Training large models
• Memory constrained
• Using float16 or bfloat16
• GPU supports it (Ampere+)

āœ— May not help when:

• Very short sequences (<256)
• Need custom attention patterns
• Using float32 (the Flash kernels require float16 or bfloat16)
• CPU inference
• Older GPUs (Volta, Turing), which FlashAttention-2 does not support


Summary

Flash Attention Key Points:

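| Aspect | Standard Attention | Flash Attention |
|---|---|---|
| Extra memory (beyond Q, K, V, O) | O(n²) | O(n) |
| Speed | Baseline | 2-10x faster in the benchmarks above |
| Max sequence length | Limited by memory | Much longer |
| Implementation | Simple | Complex (custom CUDA kernels) |
| GPU requirement | Any | Ampere or newer for FlashAttention-2 (FlashAttention-1 also runs on Turing) |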
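To put the memory row in numbers: the standard path stores an n × n matrix, while the forward pass in the sketch of 1.2 keeps only two statistics (m and l) per query row. A back-of-the-envelope comparison for the 8-batch, 12-head configuration at n = 8192, assuming fp16 scores and float32 statistics; the helper names are illustrative.

```python
def standard_scores_bytes(batch, heads, seq_len, bytes_per_element=2):
    """One materialized [batch, heads, n, n] fp16 attention matrix."""
    return batch * heads * seq_len * seq_len * bytes_per_element


def flash_row_stats_bytes(batch, heads, seq_len, bytes_per_stat=4):
    """Two float32 statistics (m and l) per query row, as in the 1.2 sketch."""
    return batch * heads * seq_len * 2 * bytes_per_stat


n = 8192
print(f"standard: {standard_scores_bytes(8, 12, n) / 1024 ** 3:.1f} GiB")
print(f"flash statistics: {flash_row_stats_bytes(8, 12, n) / 1024 ** 2:.1f} MiB")
```

This prints 12.0 GiB versus 6.0 MiB: the quadratic matrix is replaced by a vector of per-row statistics that grows linearly with n.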

Usage Recommendations:

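For PyTorch 2.0+: Use torch.nn.functional.scaled_dot_product_attention(), which automatically selects the best available backend.

For maximum features: pip install flash-attn. It adds options such as variable-length (packed) sequences, sliding-window attention, and ALiBi, and is often the fastest option on supported GPUs.

For xFormers users: xformers.ops.memory_efficient_attention(). It supports several attention-bias types, such as causal and block-diagonal masks.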

Exercises:

1. Run the simplified Flash Attention and verify it matches standard attention.
2. Benchmark PyTorch's scaled_dot_product_attention with different sequence lengths.
3. Modify the TransformerEncoder from Chapter 7 to use Flash Attention.
4. Compare memory usage between standard and Flash Attention for seq_len=4096.
5. Research Flash Attention v2 improvements over v1.

Next Section: In Section 2, we'll explore Mixture of Experts (MoE), which enables training much larger models efficiently.
