Implementing Scaled Dot-Product Attention in PyTorch
Attention Mechanism From Scratch
Introduction
Now it's time to translate our mathematical understanding into working PyTorch code. We'll implement scaled dot-product attention as a clean, reusable function with extensive shape annotations at every step.
By the end of this section, you'll have a clean, fully tested attention function that you understand end to end.
Connection to Previous Section
In Section 2.3, we computed attention by hand for a 3-token sequence. As you implement each step below, verify your code produces the same results:
• Attention weights: [[0.507, 0.186, 0.307], [0.186, 0.507, 0.307], [0.274, 0.274, 0.452]]
• Output[0]: [0.814, 0.493, 0.507, 0.186]
If your implementation matches these values, you've got it right!
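You can reproduce these numbers in a few lines before writing the full function. This sketch uses the same 3-token input (with Q = K = V = X) that the unit tests later in this section use:

```python
import math
import torch
import torch.nn.functional as F

# The 3-token input from the hand calculation (Q = K = V = X)
X = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                  [0.0, 1.0, 0.0, 1.0],
                  [1.0, 1.0, 0.0, 0.0]])

weights = F.softmax(X @ X.T / math.sqrt(X.size(-1)), dim=-1)
output = weights @ X
print(weights)  # rows ~ [0.507, 0.186, 0.307], [0.186, 0.507, 0.307], [0.274, 0.274, 0.452]
print(output[0])  # ~ [0.814, 0.493, 0.507, 0.186]
```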
4.1 Setup and Imports
Setting Up Our Attention Implementation
attention.py
Line 1: Docstring Start
Triple quotes begin a multi-line docstring. This is a Python convention for documenting modules, classes, and functions. The docstring describes what this file does.
EXAMPLE
"""This is a docstring"""
Line 2: Module Title
A descriptive title for our module. Good documentation helps others (and future you) understand the code's purpose at a glance.
Line 5: Paper Reference
We reference the original Transformer paper. This is important for academic work and helps readers find the source material for the implementation.
Line 8: Mathematical Formula
The core attention formula we're implementing: QK^T computes similarity scores, dividing by √d_k scales them, softmax normalizes them into probabilities, and then we take a weighted sum of V.
EXAMPLE
softmax([1,2,3]) → [0.09, 0.24, 0.67]
Line 11: Import PyTorch
Import the main PyTorch library. This gives us access to tensors, automatic differentiation, and GPU acceleration. 'torch' is the standard alias.
EXAMPLE
x = torch.tensor([1, 2, 3])
Line 12: Neural Network Module
torch.nn contains neural network building blocks like Linear layers, Dropout, BatchNorm, and the base Module class for creating custom layers.
Line 18: Matplotlib Import
matplotlib.pyplot is the standard plotting library. We'll use it to visualize attention weight heatmaps to understand what the model attends to.
EXAMPLE
plt.plot([1,2,3], [1,4,9])
Line 19: Seaborn Import
Seaborn builds on matplotlib with prettier defaults and easier statistical visualizations. Great for heatmaps showing attention patterns.
EXAMPLE
sns.heatmap(attention_weights)
Line 20: NumPy Import
NumPy is the foundation for numerical computing in Python. We use it for array operations and converting between numpy arrays and torch tensors.
EXAMPLE
np.array([1, 2, 3])
Line 23: Random Seed
Setting a seed makes random operations reproducible. The same seed always produces the same 'random' numbers, crucial for debugging and research.
EXAMPLE
torch.manual_seed(42); torch.randn(3) → same result every time
Line 26: Device Selection
Check if a CUDA GPU is available. If yes, use GPU for faster computation. Otherwise, fall back to CPU. Modern deep learning needs GPU acceleration.
EXAMPLE
tensor.to(device) moves tensor to GPU/CPU
Line 27: Print Device
f-strings (f'...') allow embedding variables in strings. This confirms which device we're using - important for debugging performance issues.
EXAMPLE
f'Value: {x}' → 'Value: 42'
```python
"""
Scaled Dot-Product Attention Implementation
==========================================

This module implements the core attention mechanism from
"Attention Is All You Need" (Vaswani et al., 2017).

Formula: Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) x V
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple

# For visualization (optional but recommended)
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Set random seed for reproducibility
torch.manual_seed(42)

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
```
4.2 The Core Attention Function
Implementation with Shape Annotations
The Core Attention Function
attention.py
```python
def scaled_dot_product_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    dropout: Optional[nn.Dropout] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute scaled dot-product attention.

    Args:
        query: Query tensor of shape [batch, seq_len_q, d_k]
            or [batch, num_heads, seq_len_q, d_k]
        key: Key tensor of shape [batch, seq_len_k, d_k]
            or [batch, num_heads, seq_len_k, d_k]
        value: Value tensor of shape [batch, seq_len_k, d_v]
            or [batch, num_heads, seq_len_k, d_v]
        mask: Optional mask tensor of shape broadcastable to
            [batch, seq_len_q, seq_len_k] or [batch, num_heads, seq_len_q, seq_len_k].
            0/False = masked (ignore), 1/True = attend
        dropout: Optional dropout module to apply to attention weights

    Returns:
        output: Attended values of shape [batch, seq_len_q, d_v]
            or [batch, num_heads, seq_len_q, d_v]
        attention_weights: Attention weights of shape [batch, seq_len_q, seq_len_k]
            or [batch, num_heads, seq_len_q, seq_len_k]

    Example:
        >>> Q = torch.randn(2, 10, 64)  # batch=2, seq_len=10, d_k=64
        >>> K = torch.randn(2, 20, 64)  # batch=2, seq_len=20, d_k=64
        >>> V = torch.randn(2, 20, 64)  # batch=2, seq_len=20, d_v=64
        >>> output, weights = scaled_dot_product_attention(Q, K, V)
        >>> output.shape   # [2, 10, 64]
        >>> weights.shape  # [2, 10, 20]
    """
    # Get dimension of keys for scaling
    d_k = query.size(-1)

    # Step 1: Compute attention scores
    # Q @ K^T: [..., seq_len_q, d_k] @ [..., d_k, seq_len_k] -> [..., seq_len_q, seq_len_k]
    scores = torch.matmul(query, key.transpose(-2, -1))
    # Shape: [..., seq_len_q, seq_len_k]

    # Step 2: Scale by sqrt(d_k)
    scores = scores / math.sqrt(d_k)
    # Shape: [..., seq_len_q, seq_len_k] (unchanged)

    # Step 3: Apply mask (optional)
    if mask is not None:
        # Replace masked positions with -inf so softmax gives 0
        scores = scores.masked_fill(mask == 0, float('-inf'))
        # Shape: [..., seq_len_q, seq_len_k] (unchanged)

    # Step 4: Softmax over key dimension
    attention_weights = F.softmax(scores, dim=-1)
    # Shape: [..., seq_len_q, seq_len_k] (unchanged, but rows sum to 1)

    # Handle case where an entire row is masked (all -inf -> all NaN after softmax):
    # replace NaN with 0
    attention_weights = torch.nan_to_num(attention_weights, nan=0.0)

    # Step 5: Apply dropout (optional)
    if dropout is not None:
        attention_weights = dropout(attention_weights)

    # Step 6: Weighted sum of values
    # weights @ V: [..., seq_len_q, seq_len_k] @ [..., seq_len_k, d_v] -> [..., seq_len_q, d_v]
    output = torch.matmul(attention_weights, value)
    # Shape: [..., seq_len_q, d_v]

    return output, attention_weights
```
Why Return Attention Weights?
We return attention_weights separately for three reasons:
1. Interpretability: visualize what the model attends to
2. Debugging: verify masking works correctly (masked positions should have zero weight)
3. Analysis: study attention patterns to understand model behavior
In production inference, you might skip returning weights to save memory.
One of the most powerful ways to understand and debug attention is through visualization. Let's create functions to visualize attention patterns as heatmaps.
• Darker blue = higher attention weight (the token pays more attention there)
• Diagonal pattern = self-attention (tokens attend to themselves)
• Off-diagonal values = cross-token attention
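First, a single-heatmap helper. Later examples call it as visualize_attention; the sketch below is one reasonable reconstruction, with styling choices mirroring compare_attention_patterns:

```python
import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(weights, query_tokens, key_tokens, title="Attention Weights"):
    """Plot one attention-weight matrix [seq_q, seq_k] as an annotated heatmap."""
    w = weights.squeeze().detach().cpu().numpy()  # drop batch/head dims
    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        w, annot=True, fmt=".2f", cmap="Blues",
        xticklabels=key_tokens, yticklabels=query_tokens,
        ax=ax, vmin=0, vmax=1, square=True
    )
    ax.set_title(title)
    ax.set_xlabel("Keys")
    ax.set_ylabel("Queries")
    plt.tight_layout()
    plt.show()
    return ax
```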
Visualizing Multiple Attention Patterns
visualization.py
```python
def compare_attention_patterns(
    weights_list: list,
    titles: list,
    tokens: list = None
) -> None:
    """
    Compare multiple attention patterns side by side.

    Useful for comparing:
    - Before/after masking
    - Different heads in multi-head attention
    - Attention at different layers
    """
    n = len(weights_list)
    fig, axes = plt.subplots(1, n, figsize=(5 * n, 4))

    if n == 1:
        axes = [axes]

    for ax, weights, title in zip(axes, weights_list, titles):
        w = weights.squeeze().detach().cpu().numpy()
        if tokens is None:
            tokens = [f"T{i}" for i in range(w.shape[0])]

        sns.heatmap(
            w, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=tokens, yticklabels=tokens,
            ax=ax, vmin=0, vmax=1, square=True
        )
        ax.set_title(title)
        ax.set_xlabel('Keys')
        ax.set_ylabel('Queries')

    plt.tight_layout()
    plt.show()


# Example: Compare with and without masking
Q = K = V = torch.randn(1, 4, 8)

# Without mask
_, weights_no_mask = scaled_dot_product_attention(Q, K, V)

# With causal mask (can't attend to future)
causal_mask = torch.tril(torch.ones(4, 4)).unsqueeze(0)
_, weights_causal = scaled_dot_product_attention(Q, K, V, mask=causal_mask)

compare_attention_patterns(
    [weights_no_mask, weights_causal],
    ["No Mask (Bidirectional)", "Causal Mask (Autoregressive)"],
    tokens=["T0", "T1", "T2", "T3"]
)
```
What to Look For in Attention Visualizations
1. Diagonal dominance: tokens attending to themselves (common in early layers)
2. Uniform rows: a token attending equally to all positions (no strong preference)
3. Sparse patterns: a token attending to just 1-2 positions (highly focused)
4. Causal triangle: the lower-triangular pattern produced by causal masking
4.5 Real Token Example
Random tensors are great for testing shapes, but let's see attention on actual text to build intuition about what it captures.
Using Pre-trained Embeddings
real_tokens.py
```python
# Option 1: Using HuggingFace Transformers (if available)
try:
    from transformers import AutoTokenizer, AutoModel

    # Load a pre-trained model
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased')

    # Tokenize a sentence
    text = "The cat sat on the mat"
    inputs = tokenizer(text, return_tensors='pt')

    # Get embeddings (first layer only)
    with torch.no_grad():
        embeddings = model.embeddings.word_embeddings(inputs['input_ids'])

    # Apply our attention
    output, weights = scaled_dot_product_attention(embeddings, embeddings, embeddings)

    # Get tokens for visualization
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

    visualize_attention(weights, tokens, tokens, "BERT Embeddings: Self-Attention")

except ImportError:
    print("transformers not installed, using manual embeddings instead")


# Option 2: Create meaningful embeddings manually
def create_word_embeddings():
    """
    Create simple but meaningful embeddings for demonstration.

    We'll encode semantic features:
    - [is_noun, is_verb, is_article, is_preposition, rhymes_with_cat, ...]
    """
    # Simplified semantic embeddings (8 dimensions)
    embeddings = {
        "the": [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0],  # article
        "cat": [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0],  # noun, rhymes
        "sat": [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0],  # verb, rhymes
        "on":  [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.5, 0.0],  # preposition
        "mat": [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0],  # noun, rhymes
    }
    return embeddings


# Create embeddings for "the cat sat on the mat"
word_embs = create_word_embeddings()
sentence = ["the", "cat", "sat", "on", "the", "mat"]

X = torch.tensor([word_embs[w] for w in sentence]).unsqueeze(0).float()
print(f"Input shape: {X.shape}")  # [1, 6, 8]

# Compute attention
output, weights = scaled_dot_product_attention(X, X, X)

# Visualize
visualize_attention(weights, sentence, sentence, "Semantic Attention: 'the cat sat on the mat'")
```
Interpreting Real Attention Patterns
With semantic embeddings, you'll observe:
| Pattern | Meaning | Example |
|---|---|---|
| cat ↔ mat | High attention between rhyming nouns | Similar semantic features |
| sat → cat | Verb attending to its subject | Action relates to actor |
| the ↔ the | Articles attend to each other | Same syntactic role |
| on → (everywhere) | Prepositions attend broadly | Connect different parts |
Analyzing Attention Patterns
real_tokens.py
```python
# Analyze specific attention patterns
print("\nAttention Analysis:")
print("=" * 50)

for i, query_word in enumerate(sentence):
    # Get what this word attends to most
    word_weights = weights[0, i].detach().numpy()
    top_idx = word_weights.argmax()
    top_word = sentence[top_idx]
    top_weight = word_weights[top_idx]

    print(f"'{query_word}' attends most to '{top_word}' ({top_weight:.1%})")

# Illustrative output (exact winners depend on the embeddings; note that
# "cat" and "mat" have identical vectors, so their weights tie and argmax
# breaks the tie by position):
# 'the' attends most to 'the' (35.2%)
# 'cat' attends most to 'mat' (42.1%)  <- rhyming nouns!
# 'sat' attends most to 'cat' (38.7%)  <- verb to subject!
# 'on' attends most to 'on' (28.4%)
# 'the' attends most to 'the' (35.2%)
# 'mat' attends most to 'cat' (42.1%)  <- rhyming nouns!
```
Key Insight
Even with simple hand-crafted embeddings, attention discovers meaningful relationships:
• Semantic similarity (cat ↔ mat share features)
• Syntactic structure (verbs attend to nouns)
• Positional patterns (nearby words are often relevant)
With learned embeddings (like BERT), these patterns become even richer and capture nuanced linguistic relationships.
4.6 Testing Our Implementation
Unit Tests
test_attention.py
```python
def test_attention_shapes():
    """Test that output shapes are correct."""
    batch, seq_len_q, seq_len_k, d_k, d_v = 2, 10, 15, 64, 32

    Q = torch.randn(batch, seq_len_q, d_k)
    K = torch.randn(batch, seq_len_k, d_k)
    V = torch.randn(batch, seq_len_k, d_v)

    output, weights = scaled_dot_product_attention(Q, K, V)

    assert output.shape == (batch, seq_len_q, d_v), f"Expected {(batch, seq_len_q, d_v)}, got {output.shape}"
    assert weights.shape == (batch, seq_len_q, seq_len_k), f"Expected {(batch, seq_len_q, seq_len_k)}, got {weights.shape}"

    print("✓ Shape test passed")


def test_attention_weights_sum_to_one():
    """Test that attention weights sum to 1 along the key dimension."""
    Q = torch.randn(2, 5, 8)
    K = torch.randn(2, 10, 8)
    V = torch.randn(2, 10, 8)

    _, weights = scaled_dot_product_attention(Q, K, V)

    row_sums = weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-6), \
        f"Weights don't sum to 1: {row_sums}"

    print("✓ Weights sum to 1 test passed")


def test_masking():
    """Test that masked positions have zero attention weight."""
    Q = torch.randn(1, 3, 4)
    K = torch.randn(1, 5, 4)
    V = torch.randn(1, 5, 4)

    # Mask out positions 3 and 4
    mask = torch.tensor([[[1, 1, 1, 0, 0]]])  # [1, 1, 5]

    _, weights = scaled_dot_product_attention(Q, K, V, mask=mask)

    # Check masked positions have 0 weight
    assert torch.allclose(weights[:, :, 3:], torch.zeros(1, 3, 2), atol=1e-6), \
        "Masked positions should have 0 weight"

    print("✓ Masking test passed")


def test_numerical_example():
    """Test against the hand-computed example from Section 3."""
    X = torch.tensor([
        [1.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 0.0, 1.0],
        [1.0, 1.0, 0.0, 0.0],
    ]).unsqueeze(0)  # [1, 3, 4]

    Q = K = V = X
    output, weights = scaled_dot_product_attention(Q, K, V)

    # Expected values from hand calculation
    expected_weights = torch.tensor([[[0.5066, 0.1863, 0.3071],
                                      [0.1863, 0.5066, 0.3071],
                                      [0.2741, 0.2741, 0.4518]]])

    expected_output = torch.tensor([[[0.8137, 0.4934, 0.5066, 0.1863],
                                     [0.4934, 0.8137, 0.1863, 0.5066],
                                     [0.7259, 0.7259, 0.2741, 0.2741]]])

    assert torch.allclose(weights, expected_weights, atol=1e-4), \
        f"Weights mismatch:\nGot: {weights}\nExpected: {expected_weights}"

    assert torch.allclose(output, expected_output, atol=1e-4), \
        f"Output mismatch:\nGot: {output}\nExpected: {expected_output}"

    print("✓ Numerical example test passed")


# Run all tests
test_attention_shapes()
test_attention_weights_sum_to_one()
test_masking()
test_numerical_example()
print("\n✓ All tests passed!")
```
4.7 Handling Edge Cases
All Positions Masked
When all positions are masked (all -inf), softmax produces NaN:
All Positions Masked Test
test_attention.py
```python
def test_all_masked():
    """Handle the case where all keys are masked."""
    Q = torch.randn(1, 2, 4)
    K = torch.randn(1, 3, 4)
    V = torch.randn(1, 3, 4)

    # Mask out ALL positions
    mask = torch.zeros(1, 1, 3)

    output, weights = scaled_dot_product_attention(Q, K, V, mask=mask)

    # Our implementation converts NaN to 0
    assert not torch.isnan(output).any(), "Output contains NaN"
    assert not torch.isnan(weights).any(), "Weights contain NaN"

    # Weights should be all zeros when everything is masked
    assert torch.allclose(weights, torch.zeros_like(weights)), \
        "Weights should be 0 when all masked"

    print("✓ All-masked test passed")


test_all_masked()
```
Empty Sequences
Empty Sequences Test
test_attention.py
```python
def test_empty_sequence():
    """Handle the edge case of zero-length sequences."""
    try:
        Q = torch.randn(1, 0, 4)  # Empty query sequence
        K = torch.randn(1, 3, 4)
        V = torch.randn(1, 3, 4)

        output, weights = scaled_dot_product_attention(Q, K, V)
        assert output.shape == (1, 0, 4), f"Unexpected shape: {output.shape}"
        print("✓ Empty sequence test passed")
    except Exception as e:
        print(f"Empty sequence raised: {type(e).__name__}: {e}")


test_empty_sequence()
```
4.8 Common Pitfalls & Debugging
Here are the most common mistakes when implementing attention, and how to fix them:
| Pitfall | Symptom | Fix |
|---|---|---|
| Wrong transpose dimension | Shape mismatch error in matmul | Use .transpose(-2, -1), not .T or .transpose(0, 1) |
| Forgetting to scale | Vanishing gradients, peaked softmax | Divide scores by math.sqrt(d_k) before softmax |
| Mask shape mismatch | Broadcasting error or wrong masking | Ensure the mask is [B, 1, 1, seq_k] or [B, H, seq_q, seq_k] |
| Float16 overflow | NaN in attention scores | Use torch.finfo(dtype).min instead of float('-inf') |
| Softmax on wrong dim | Rows don't sum to 1 | Use dim=-1 (the last dimension = keys) |
| Mask convention confusion | Masked positions get attention | 0 = masked (ignore), 1 = attend; or use an additive -inf mask |
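The two mask conventions in the last row are easy to conflate. This quick sketch shows that the boolean convention (which our function expects) and the additive convention produce identical weights:

```python
import torch
import torch.nn.functional as F

scores = torch.randn(1, 3)  # unscaled attention scores for one query

# Boolean convention: 1 = attend, 0 = masked
bool_mask = torch.tensor([[1, 1, 0]])
w_bool = F.softmax(scores.masked_fill(bool_mask == 0, float("-inf")), dim=-1)

# Additive convention: add 0 to attended positions, -inf to masked ones
additive_mask = torch.zeros(1, 3)
additive_mask[bool_mask == 0] = float("-inf")
w_add = F.softmax(scores + additive_mask, dim=-1)

print(torch.allclose(w_bool, w_add))  # True: both zero out the masked key
```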
Debugging Checklist
debug_attention.py
```python
def debug_attention(Q, K, V, mask=None):
    """
    Comprehensive attention debugging function.
    Run this when attention produces unexpected results.
    """
    print("=" * 60)
    print("ATTENTION DEBUG REPORT")
    print("=" * 60)

    # 1. Check shapes
    print("\n1. SHAPES")
    print(f"   Q: {Q.shape}")
    print(f"   K: {K.shape}")
    print(f"   V: {V.shape}")
    if mask is not None:
        print(f"   mask: {mask.shape}")

    d_k = Q.size(-1)
    print(f"   d_k: {d_k}")

    # 2. Check for NaN/Inf in inputs
    print("\n2. INPUT VALIDITY")
    print(f"   Q has NaN: {torch.isnan(Q).any()}")
    print(f"   K has NaN: {torch.isnan(K).any()}")
    print(f"   V has NaN: {torch.isnan(V).any()}")
    print(f"   Q has Inf: {torch.isinf(Q).any()}")

    # 3. Compute scores step by step
    print("\n3. ATTENTION SCORES")
    scores = torch.matmul(Q, K.transpose(-2, -1))
    print(f"   Raw scores shape: {scores.shape}")
    print(f"   Raw scores range: [{scores.min():.3f}, {scores.max():.3f}]")

    scaled_scores = scores / math.sqrt(d_k)
    print(f"   Scaled scores range: [{scaled_scores.min():.3f}, {scaled_scores.max():.3f}]")

    # 4. Check masking
    if mask is not None:
        print("\n4. MASKING")
        print(f"   Mask sum (attended positions): {mask.sum()}")
        print(f"   Mask zeros (masked positions): {(mask == 0).sum()}")
        masked_scores = scaled_scores.masked_fill(mask == 0, float('-inf'))
        print(f"   Post-mask -inf count: {torch.isinf(masked_scores).sum()}")
    else:
        masked_scores = scaled_scores
        print("\n4. MASKING: None applied")

    # 5. Check softmax output
    print("\n5. SOFTMAX")
    weights = F.softmax(masked_scores, dim=-1)
    print(f"   Weights shape: {weights.shape}")
    print(f"   Weights range: [{weights.min():.4f}, {weights.max():.4f}]")
    print(f"   Row sums (should be ~1): {weights.sum(dim=-1)[0]}")
    print(f"   Has NaN: {torch.isnan(weights).any()}")

    # 6. Final output
    print("\n6. OUTPUT")
    output = torch.matmul(weights, V)
    print(f"   Output shape: {output.shape}")
    print(f"   Output range: [{output.min():.3f}, {output.max():.3f}]")
    print(f"   Has NaN: {torch.isnan(output).any()}")

    print("\n" + "=" * 60)
    return output, weights


# Example usage
Q = torch.randn(1, 4, 8)
K = torch.randn(1, 6, 8)
V = torch.randn(1, 6, 8)
mask = torch.ones(1, 1, 6)
mask[0, 0, 4:] = 0  # Mask the last 2 positions

debug_attention(Q, K, V, mask)
```
Pro Debugging Tips
1. Print shapes obsessively: most attention bugs are shape mismatches
2. Check row sums: if weights don't sum to 1, softmax is on the wrong dim
3. Visualize early: plot attention weights to spot masking issues
4. Test on tiny inputs: use 2-3 tokens first and verify by hand
5. Use torch.autograd.set_detect_anomaly(True): it catches NaN sources
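Tips 2 and 4 combine into a habit worth building. Here is a minimal sketch on a 2-token input small enough to check by hand:

```python
import math
import torch
import torch.nn.functional as F

# Tip 4: an input small enough to verify by hand.
# Q = K = one-hot vectors, d_k = 2, so scores = I / sqrt(2).
Q = K = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])  # [1, 2, 2]
scores = Q @ K.transpose(-2, -1) / math.sqrt(2)
weights = F.softmax(scores, dim=-1)

# Row 0 by hand: softmax([0.7071, 0]) = [e^0.7071, 1] / (e^0.7071 + 1) ~ [0.670, 0.330]
expected = torch.tensor([[[0.6698, 0.3302], [0.3302, 0.6698]]])
assert torch.allclose(weights, expected, atol=1e-3)

# Tip 2: every row of attention weights must sum to 1
assert torch.allclose(weights.sum(dim=-1), torch.ones(1, 2))
print("hand check passed")
```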
Float16/BFloat16 Considerations
Float16/BFloat16 Safe Attention
attention_fp16.py
```python
def scaled_dot_product_attention_fp16_safe(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Attention implementation safe for float16/bfloat16.

    Key differences from the float32 version:
    1. Use the dtype-appropriate minimum value instead of -inf
    2. Upcast to float32 for softmax if needed
    """
    d_k = query.size(-1)
    dtype = query.dtype

    # Compute scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask with dtype-safe minimum
    if mask is not None:
        # Use the dtype minimum instead of -inf to avoid overflow
        mask_value = torch.finfo(dtype).min
        scores = scores.masked_fill(mask == 0, mask_value)

    # For fp16, upcast to fp32 for a stable softmax, then cast back
    if dtype == torch.float16:
        attention_weights = F.softmax(scores.float(), dim=-1).to(dtype)
    else:
        attention_weights = F.softmax(scores, dim=-1)

    attention_weights = torch.nan_to_num(attention_weights, nan=0.0)

    output = torch.matmul(attention_weights, value)

    return output, attention_weights


# Test with float16
Q_fp16 = torch.randn(2, 10, 64, dtype=torch.float16, device=device)
K_fp16 = torch.randn(2, 10, 64, dtype=torch.float16, device=device)
V_fp16 = torch.randn(2, 10, 64, dtype=torch.float16, device=device)

output_fp16, _ = scaled_dot_product_attention_fp16_safe(Q_fp16, K_fp16, V_fp16)
print(f"FP16 output has NaN: {torch.isnan(output_fp16).any()}")  # Should be False
```
4.9 Attention as an nn.Module
For integration into larger models, wrap attention in a module:
Attention as nn.Module
attention_module.py
```python
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention as an nn.Module.

    This wrapper allows attention to be used as a layer in nn.Sequential
    or other module compositions.
    """

    def __init__(self, dropout: float = 0.0):
        """
        Args:
            dropout: Dropout rate applied to attention weights
        """
        super().__init__()
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_attention: bool = False
    ) -> torch.Tensor:
        """
        Args:
            query: [batch, seq_len_q, d_k]
            key: [batch, seq_len_k, d_k]
            value: [batch, seq_len_k, d_v]
            mask: Optional mask [batch, 1, seq_len_k] or [batch, seq_len_q, seq_len_k]
            return_attention: If True, also return the attention weights

        Returns:
            output: [batch, seq_len_q, d_v]
            attention_weights: (optional) [batch, seq_len_q, seq_len_k]
        """
        output, attention_weights = scaled_dot_product_attention(
            query, key, value, mask, self.dropout
        )

        if return_attention:
            return output, attention_weights
        return output


# Example usage
attention_layer = ScaledDotProductAttention(dropout=0.1)
attention_layer.train()  # Enable dropout

Q = torch.randn(2, 10, 64)
K = torch.randn(2, 20, 64)
V = torch.randn(2, 20, 64)

output = attention_layer(Q, K, V)
print(f"Output shape: {output.shape}")  # [2, 10, 64]
```
4.10 Alternative: Using einsum
For those who prefer Einstein notation, here's attention implemented with torch.einsum. This notation is often more readable for complex tensor operations.
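The single-head version reads almost like the formula itself. This is a sketch equivalent to the matmul implementation above (the name attention_einsum is ours):

```python
import math
import torch
import torch.nn.functional as F

def attention_einsum(Q, K, V):
    """Single-head attention via einsum; equivalent to the matmul version."""
    d_k = Q.size(-1)
    # 'bqd,bkd->bqk': contract over the feature dim d, keeping query/key positions
    scores = torch.einsum('bqd,bkd->bqk', Q, K) / math.sqrt(d_k)
    weights = F.softmax(scores, dim=-1)
    # 'bqk,bkd->bqd': weighted sum of value vectors over key positions
    output = torch.einsum('bqk,bkd->bqd', weights, V)
    return output, weights

Q = torch.randn(2, 10, 64)
K = torch.randn(2, 20, 64)
V = torch.randn(2, 20, 64)
out, w = attention_einsum(Q, K, V)
print(out.shape, w.shape)  # torch.Size([2, 10, 64]) torch.Size([2, 10, 20])
```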
Einsum becomes even more powerful with multi-head attention:
Multi-Head Attention with einsum
attention_einsum.py
```python
def multihead_attention_einsum(
    query: torch.Tensor,   # [batch, seq_q, num_heads, d_k]
    key: torch.Tensor,     # [batch, seq_k, num_heads, d_k]
    value: torch.Tensor,   # [batch, seq_k, num_heads, d_v]
    mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Multi-head attention using einsum.

    This version keeps heads as a separate dimension throughout,
    making the computation more explicit.
    """
    d_k = query.size(-1)

    # QK^T for all heads at once
    # 'bqhd,bkhd->bhqk': [batch, seq_q, heads, d_k] x [batch, seq_k, heads, d_k]
    # -> [batch, heads, seq_q, seq_k]
    scores = torch.einsum('bqhd,bkhd->bhqk', query, key) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    weights = F.softmax(scores, dim=-1)

    # Weighted sum of values for all heads at once
    # 'bhqk,bkhd->bqhd': [batch, heads, seq_q, seq_k] x [batch, seq_k, heads, d_v]
    # -> [batch, seq_q, heads, d_v]
    output = torch.einsum('bhqk,bkhd->bqhd', weights, value)

    return output


# Example with 8 heads
batch, seq_q, seq_k, num_heads, d_k = 2, 10, 20, 8, 64
Q = torch.randn(batch, seq_q, num_heads, d_k)
K = torch.randn(batch, seq_k, num_heads, d_k)
V = torch.randn(batch, seq_k, num_heads, d_k)

output = multihead_attention_einsum(Q, K, V)
print(f"Multi-head output shape: {output.shape}")  # [2, 10, 8, 64]
```
When to Use einsum vs matmul
Use einsum when:
• Operations involve 4+ dimensions
• You want to make dimension contractions explicit
• Reading the code should clarify which dimensions are being combined

Use matmul when:
• Simple 2D or 3D matrix multiplications
• Performance is critical (matmul can be slightly faster)
• Code follows standard linear algebra conventions
4.11 Comparison with PyTorch Built-in
PyTorch 2.0+ includes an optimized F.scaled_dot_product_attention function. Let's validate our implementation against it and understand the differences.
Validation Against Official Implementation
Validation Against PyTorch Built-in
validate_pytorch.py
```python
import torch
import torch.nn.functional as F

def validate_against_pytorch_builtin():
    """
    Compare our implementation with PyTorch's built-in attention.

    Available since PyTorch 2.0, the built-in function uses
    optimized kernels (Flash Attention, Memory-Efficient Attention, etc.).
    """
    # Create test inputs
    batch, seq_q, seq_k, d_model = 4, 32, 64, 128

    Q = torch.randn(batch, seq_q, d_model)
    K = torch.randn(batch, seq_k, d_model)
    V = torch.randn(batch, seq_k, d_model)

    # Our implementation
    output_ours, weights_ours = scaled_dot_product_attention(Q, K, V)

    # PyTorch built-in (note: different shape convention for some backends).
    # The built-in expects [batch, num_heads, seq, d_k] for multi-head;
    # for single-head, we can use [batch, 1, seq, d]
    Q_pt = Q.unsqueeze(1)  # [batch, 1, seq_q, d_model]
    K_pt = K.unsqueeze(1)
    V_pt = V.unsqueeze(1)

    output_builtin = F.scaled_dot_product_attention(Q_pt, K_pt, V_pt)
    output_builtin = output_builtin.squeeze(1)  # Remove head dimension

    # Compare
    match = torch.allclose(output_ours, output_builtin, atol=1e-5)
    max_diff = (output_ours - output_builtin).abs().max()

    print(f"✓ Outputs match: {match}")
    print(f"  Max difference: {max_diff:.2e}")

    return match


# Run validation
try:
    validate_against_pytorch_builtin()
except AttributeError:
    print("F.scaled_dot_product_attention requires PyTorch 2.0+")
    print(f"Your PyTorch version: {torch.__version__}")
```
PyTorch Backend Selection
PyTorch's implementation automatically selects the best backend:
backends.py
```python
# Check available backends (PyTorch 2.0+)
try:
    from torch.backends.cuda import (
        flash_sdp_enabled,
        math_sdp_enabled,
        mem_efficient_sdp_enabled
    )

    print("Available SDPA backends:")
    print(f"  Flash Attention:  {flash_sdp_enabled()}")
    print(f"  Memory Efficient: {mem_efficient_sdp_enabled()}")
    print(f"  Math fallback:    {math_sdp_enabled()}")

    # You can disable specific backends for debugging:
    # torch.backends.cuda.enable_flash_sdp(False)
    # torch.backends.cuda.enable_mem_efficient_sdp(False)

except ImportError:
    print("Backend checking requires PyTorch 2.0+")
```
| Backend | When Used | Key Benefit |
|---|---|---|
| Flash Attention | CUDA, specific GPU architectures | O(N) memory, fastest |
| Memory Efficient | CUDA, broader GPU support | O(√N) memory |
| Math (fallback) | CPU or unsupported GPU | Always works, standard O(N²) |
When to Use Built-in vs Custom
Recommendation
Use the PyTorch built-in for:
• Production code where performance matters
• Standard attention without custom modifications
• Automatic optimization selection

Use the custom implementation for:
• Learning and understanding attention
• Debugging attention patterns (when you need access to the weights)
• Custom masking or attention modifications
• Research experiments with novel attention variants
Our implementation stores the full attention matrix in memory. For long sequences, this becomes prohibitive. Modern architectures use optimized algorithms that avoid this limitation.
Flash Attention (Dao et al., 2022) computes attention in tiles, never materializing the full attention matrix:
Flash Attention Conceptual Overview
flash_attention.py
```python
"""
Flash Attention Conceptual Overview
====================================

Key insight: We don't need the full NxN attention matrix at once.
We can compute attention in tiles/blocks and accumulate results.

Standard Attention:
1. Compute full QK^T  -> O(N^2) memory
2. Apply softmax      -> O(N^2) memory
3. Multiply by V      -> O(N) memory

Flash Attention:
1. For each block of Q:
   a. Load the Q block into SRAM (fast memory)
   b. For each block of K, V:
      - Compute partial QK^T
      - Update the running softmax (online softmax algorithm)
      - Accumulate output
   c. Write the output block to HBM (slow memory)

Result: O(N) memory instead of O(N^2)!
"""

# Simplified tile-based attention (educational, not optimized)
def tiled_attention_demo(Q, K, V, tile_size=256):
    """
    Demonstrate the tiled attention concept.

    Note: This is for understanding only. Real Flash Attention
    uses CUDA kernels for GPU-optimized memory access patterns.
    """
    batch, seq_q, d_k = Q.shape
    seq_k = K.size(1)
    d_v = V.size(-1)

    output = torch.zeros(batch, seq_q, d_v, device=Q.device)

    # Process queries in tiles
    for q_start in range(0, seq_q, tile_size):
        q_end = min(q_start + tile_size, seq_q)
        Q_tile = Q[:, q_start:q_end]  # [batch, tile, d_k]

        # For each query tile, accumulate attention over all key tiles
        tile_output = torch.zeros(batch, q_end - q_start, d_v, device=Q.device)
        running_max = torch.full((batch, q_end - q_start, 1), float('-inf'), device=Q.device)
        running_sum = torch.zeros(batch, q_end - q_start, 1, device=Q.device)

        for k_start in range(0, seq_k, tile_size):
            k_end = min(k_start + tile_size, seq_k)
            K_tile = K[:, k_start:k_end]
            V_tile = V[:, k_start:k_end]

            # Compute attention scores for this tile
            scores = torch.matmul(Q_tile, K_tile.transpose(-2, -1)) / math.sqrt(d_k)

            # Online softmax: update the running max
            tile_max = scores.max(dim=-1, keepdim=True).values
            new_max = torch.maximum(running_max, tile_max)

            # Rescale previous accumulations and add the new tile
            exp_scores = torch.exp(scores - new_max)
            scale = torch.exp(running_max - new_max)

            running_sum = running_sum * scale + exp_scores.sum(dim=-1, keepdim=True)
            tile_output = tile_output * scale + torch.matmul(exp_scores, V_tile)
            running_max = new_max

        # Normalize by the total sum
        output[:, q_start:q_end] = tile_output / running_sum

    return output


# Compare outputs
Q = torch.randn(2, 512, 64)
K = torch.randn(2, 512, 64)
V = torch.randn(2, 512, 64)

output_standard, _ = scaled_dot_product_attention(Q, K, V)
output_tiled = tiled_attention_demo(Q, K, V, tile_size=128)

print(f"Outputs match: {torch.allclose(output_standard, output_tiled, atol=1e-5)}")
```
| Algorithm | Memory Complexity | Compute Complexity | When to Use |
|---|---|---|---|
| Standard Attention | O(N²) | O(N²d) | Short sequences (<512) |
| Flash Attention | O(N) | O(N²d) | Long sequences, GPU |
| Linear Attention | O(N) | O(Nd²) | Very long sequences |
| Sparse Attention | O(N√N) | O(N√N·d) | Extremely long sequences |
Using Flash Attention in Practice
PyTorch 2.0+ automatically uses Flash Attention when available:
`F.scaled_dot_product_attention(Q, K, V)`
For explicit control, use the xformers library or the official Flash Attention package.
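A minimal sketch of calling the built-in (PyTorch 2.0+): the fused kernel handles the √d_k scaling and softmax internally and dispatches to Flash Attention on supported GPUs, with a math fallback elsewhere. The shapes here are illustrative.

```python
import torch
import torch.nn.functional as F

# [batch, heads, seq, d_k] is the layout the fused kernels expect
Q = torch.randn(2, 8, 128, 64)
K = torch.randn(2, 8, 128, 64)
V = torch.randn(2, 8, 128, 64)

# Scaling and softmax happen inside the kernel; pass is_causal=True
# to apply a causal mask without materializing it
out = F.scaled_dot_product_attention(Q, K, V)

# Sanity check against the explicit formula
manual = torch.softmax(Q @ K.transpose(-2, -1) / 64 ** 0.5, dim=-1) @ V
print(torch.allclose(out, manual, atol=1e-4))
```

Note that no attention-weight matrix is returned: not materializing the [seq, seq] weights is precisely where the memory savings come from.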
4.14 Memory-Efficient Chunked Attention
When Flash Attention isn't available (e.g., CPU, unsupported GPU), you can still reduce peak memory by processing queries in chunks.
Chunked Attention
chunked_attention.py
```python
def chunked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    chunk_size: int = 512,
    mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Memory-efficient attention by processing queries in chunks.

    Instead of computing [batch, seq_q, seq_k] all at once,
    we compute [batch, chunk, seq_k] and concatenate.

    Memory savings: O(seq_q) -> O(chunk_size)

    Args:
        query: [batch, seq_q, d_k]
        key: [batch, seq_k, d_k]
        value: [batch, seq_k, d_v]
        chunk_size: Number of query positions to process at once
        mask: Optional mask
    """
    batch, seq_q, d_k = query.shape
    d_v = value.size(-1)

    outputs = []
    weights_list = []

    for i in range(0, seq_q, chunk_size):
        # Extract query chunk
        q_chunk = query[:, i:i+chunk_size]  # [batch, chunk, d_k]

        # Extract corresponding mask chunk if provided
        mask_chunk = None
        if mask is not None:
            if mask.dim() == 3:  # [batch, seq_q, seq_k]
                mask_chunk = mask[:, i:i+chunk_size]
            else:  # [batch, 1, seq_k] or similar
                mask_chunk = mask

        # Compute attention for this chunk
        # Note: K and V are always full - each query needs all keys
        out_chunk, w_chunk = scaled_dot_product_attention(
            q_chunk, key, value, mask_chunk
        )

        outputs.append(out_chunk)
        weights_list.append(w_chunk)

    # Concatenate chunks
    # Note: returning the full weights still costs O(seq_q * seq_k) memory;
    # skip weights_list entirely when you only need the output
    output = torch.cat(outputs, dim=1)
    weights = torch.cat(weights_list, dim=1)

    return output, weights


# Compare memory usage
def compare_memory_usage(seq_len, d_model, chunk_size=256):
    """Compare peak memory between standard and chunked attention."""
    Q = torch.randn(1, seq_len, d_model, device=device)
    K = torch.randn(1, seq_len, d_model, device=device)
    V = torch.randn(1, seq_len, d_model, device=device)

    if device.type == 'cuda':
        # Standard attention peak memory
        torch.cuda.reset_peak_memory_stats()
        output_std, _ = scaled_dot_product_attention(Q, K, V)
        peak_std = torch.cuda.max_memory_allocated() / 1024**2

        # Chunked attention peak memory
        torch.cuda.reset_peak_memory_stats()
        output_chunked, _ = chunked_attention(Q, K, V, chunk_size=chunk_size)
        peak_chunked = torch.cuda.max_memory_allocated() / 1024**2

        print(f"Sequence length: {seq_len}")
        print(f"  Standard attention peak memory: {peak_std:.1f} MB")
        print(f"  Chunked attention peak memory: {peak_chunked:.1f} MB")
        print(f"  Memory savings: {(1 - peak_chunked/peak_std)*100:.1f}%")
        print(f"  Outputs match: {torch.allclose(output_std, output_chunked, atol=1e-5)}")
    else:
        print("Memory tracking requires CUDA device")


# Run comparison (if GPU available)
if device.type == 'cuda':
    compare_memory_usage(2048, 512, chunk_size=256)
```
Gradient Checkpointing
For training, combine chunking with gradient checkpointing to save even more memory:
Memory Efficient Attention with Gradient Checkpointing
memory_efficient_attention.py
```python
from torch.utils.checkpoint import checkpoint


class MemoryEfficientAttention(nn.Module):
    """
    Attention with gradient checkpointing for reduced memory during training.

    Gradient checkpointing trades compute for memory:
    - Forward: Don't store intermediate activations
    - Backward: Recompute intermediates as needed

    Memory savings: ~50% during training
    Compute cost: ~25% increase (one extra forward pass)
    """

    def __init__(self, d_model: int, dropout: float = 0.0):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout) if dropout > 0 else None
        self.use_checkpointing = True  # Enable during training

    def _attention_forward(self, Q, K, V, mask):
        """Inner function that will be checkpointed."""
        d_k = Q.size(-1)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        weights = torch.nan_to_num(weights, nan=0.0)

        if self.dropout is not None and self.training:
            weights = self.dropout(weights)

        output = torch.matmul(weights, V)
        return output

    def forward(self, Q, K, V, mask=None):
        if self.training and self.use_checkpointing:
            # Use gradient checkpointing during training
            return checkpoint(
                self._attention_forward,
                Q, K, V, mask,
                use_reentrant=False  # Recommended for newer PyTorch
            )
        else:
            # Standard forward during inference
            return self._attention_forward(Q, K, V, mask)


# Example usage
attention = MemoryEfficientAttention(d_model=512, dropout=0.1)
attention.train()

Q = torch.randn(4, 1024, 512, requires_grad=True)
K = torch.randn(4, 1024, 512, requires_grad=True)
V = torch.randn(4, 1024, 512, requires_grad=True)

output = attention(Q, K, V)
loss = output.sum()
loss.backward()  # Gradients computed with checkpointing

print(f"Output shape: {output.shape}")
print(f"Q.grad shape: {Q.grad.shape}")
```
4.15 Understanding Gradient Flow
Understanding how gradients flow through attention helps with debugging training issues and designing architectures.
Gradient Computation
Gradient Analysis
gradient_analysis.py
```python
def analyze_attention_gradients():
    """
    Analyze gradient flow through the attention mechanism.

    Key insight: Gradients flow through BOTH paths:
    1. The attention weights (Q, K) - "what to attend to"
    2. The value combination (V) - "what information to extract"
    """
    # Create inputs with gradient tracking
    Q = torch.randn(1, 4, 8, requires_grad=True)
    K = torch.randn(1, 4, 8, requires_grad=True)
    V = torch.randn(1, 4, 8, requires_grad=True)

    # Forward pass
    output, weights = scaled_dot_product_attention(Q, K, V)

    # Simple loss: sum of outputs
    loss = output.sum()

    # Backward pass
    loss.backward()

    print("Gradient Analysis")
    print("=" * 50)
    print("\nInput shapes:")
    print(f"  Q: {Q.shape}")
    print(f"  K: {K.shape}")
    print(f"  V: {V.shape}")

    print("\nGradient shapes (should match inputs):")
    print(f"  ∂L/∂Q: {Q.grad.shape}")
    print(f"  ∂L/∂K: {K.grad.shape}")
    print(f"  ∂L/∂V: {V.grad.shape}")

    print("\nGradient magnitudes:")
    print(f"  ||∂L/∂Q||: {Q.grad.norm():.4f}")
    print(f"  ||∂L/∂K||: {K.grad.norm():.4f}")
    print(f"  ||∂L/∂V||: {V.grad.norm():.4f}")

    # Gradient through attention weights
    print("\nAttention weights statistics:")
    print(f"  weights shape: {weights.shape}")
    print(f"  weights range: [{weights.min():.4f}, {weights.max():.4f}]")
    print(f"  weights sum (per query): {weights.sum(dim=-1)}")

    return Q.grad, K.grad, V.grad


dQ, dK, dV = analyze_attention_gradients()
```
Gradient Flow Visualization
gradient_flow.py
```python
def visualize_gradient_flow():
    """
    Visualize how gradients flow through attention.

    This helps understand:
    - Which input positions receive the most gradient
    - How masking affects gradient flow
    - Whether gradients are well-distributed or concentrated
    """
    # Create simple inputs
    Q = torch.randn(1, 5, 4, requires_grad=True)
    K = torch.randn(1, 5, 4, requires_grad=True)
    V = torch.randn(1, 5, 4, requires_grad=True)

    # Compute attention
    output, weights = scaled_dot_product_attention(Q, K, V)

    # Loss on a specific output position
    # This shows which inputs affect output position 2
    loss = output[0, 2].sum()  # Only output position 2
    loss.backward()

    # Gradient magnitude per position
    q_grad_per_pos = Q.grad[0].norm(dim=-1)
    k_grad_per_pos = K.grad[0].norm(dim=-1)
    v_grad_per_pos = V.grad[0].norm(dim=-1)

    print("Gradient flow to output position 2:")
    print("=" * 50)
    print("\nAttention weights from position 2:")
    print(f"  {weights[0, 2].detach().numpy().round(3)}")

    print("\nGradient magnitude per Q position:")
    print(f"  {q_grad_per_pos.detach().numpy().round(3)}")
    print("  (Only position 2 has significant gradient)")

    print("\nGradient magnitude per K position:")
    print(f"  {k_grad_per_pos.detach().numpy().round(3)}")
    print("  (Gradient proportional to attention from pos 2)")

    print("\nGradient magnitude per V position:")
    print(f"  {v_grad_per_pos.detach().numpy().round(3)}")
    print("  (Gradient proportional to attention weight)")


visualize_gradient_flow()
```
Common Gradient Issues

| Issue | Symptom | Cause | Fix |
|---|---|---|---|
| Vanishing gradients | Q/K grads near zero | Peaked softmax (unscaled scores) | Ensure scaling by √d_k |
| NaN gradients | NaN in loss/grads | All positions masked → 0/0 | Use nan_to_num after softmax |
| Gradient explosion | Huge gradient values | Very large scores before softmax | Add gradient clipping |
| Uneven flow | Some positions never learn | Fixed attention patterns | Add dropout to attention weights |
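Two of the fixes above take only a few lines each. This is a hedged sketch: the `nn.Linear` model is a hypothetical stand-in for any module whose gradients you would clip in a real training loop.

```python
import torch
import torch.nn.functional as F

# Fix for NaN gradients: a fully masked row (all scores -inf) makes
# softmax return NaN; nan_to_num replaces that row with zeros so the
# NaN never reaches the loss or the backward pass.
scores = torch.full((1, 3, 3), float('-inf'))
weights = torch.nan_to_num(F.softmax(scores, dim=-1), nan=0.0)
print(weights.isnan().any())  # tensor(False)

# Fix for gradient explosion: clip the global gradient norm after
# backward() and before the optimizer step.
model = torch.nn.Linear(8, 8)  # hypothetical stand-in for your model
loss = model(torch.randn(4, 8)).pow(2).sum()
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```

`clip_grad_norm_` rescales all gradients in place so their combined norm is at most `max_norm`, which bounds the update size without changing the gradient direction.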
Gradient Correctness Test
gradient_check.py
```python
# Verify gradients with autograd.gradcheck
def test_gradient_correctness():
    """
    Use PyTorch's gradient checker to verify our implementation.

    gradcheck compares analytical gradients (from autograd)
    with numerical gradients (from finite differences).
    """
    from torch.autograd import gradcheck

    def attention_func(Q, K, V):
        output, _ = scaled_dot_product_attention(Q, K, V)
        return output

    # Use double precision for accurate numerical gradients
    Q = torch.randn(1, 3, 4, dtype=torch.float64, requires_grad=True)
    K = torch.randn(1, 3, 4, dtype=torch.float64, requires_grad=True)
    V = torch.randn(1, 3, 4, dtype=torch.float64, requires_grad=True)

    # Check gradients
    try:
        result = gradcheck(attention_func, (Q, K, V), eps=1e-6, atol=1e-4)
        print(f"✓ Gradient check passed: {result}")
    except Exception as e:
        print(f"✗ Gradient check failed: {e}")


test_gradient_correctness()
```
Key Gradient Insights
1. Gradients flow through two paths: (Q, K) determine where to attend; V determines what information to extract.
2. Softmax creates competition: increasing one attention weight decreases the others.
3. Masking stops gradients: masked positions receive zero gradient.
4. Scaling prevents saturation: without √d_k, softmax saturates and gradients vanish.
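The saturation point is easy to verify numerically. A small sketch (random vectors, illustrative sizes): for unit-variance q and k, the dot product q·k has variance d_k, so unscaled scores grow with dimension and push softmax toward a one-hot distribution.

```python
import torch

torch.manual_seed(0)
d_k = 512
q = torch.randn(d_k)
K = torch.randn(8, d_k)  # 8 candidate keys

raw = K @ q                # entries have std ~ sqrt(d_k), about 22.6
scaled = raw / d_k ** 0.5  # entries have std ~ 1

p_raw = torch.softmax(raw, dim=-1)
p_scaled = torch.softmax(scaled, dim=-1)

# Unscaled scores typically give a near one-hot distribution: the softmax
# is saturated, so almost no gradient flows to the other positions.
print(f"max weight, unscaled: {p_raw.max():.3f}")
print(f"max weight, scaled:   {p_scaled.max():.3f}")
```

Multiplying logits by a constant greater than one always sharpens the softmax, so the unscaled distribution is at least as peaked as the scaled one; with d_k = 512 the difference is dramatic.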
4.16 Complete Module Code
Here's the complete, production-ready implementation:
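As a compact reference, here is a minimal sketch consistent with the signature used throughout this section: it returns both the output and the attention weights, takes an optional mask, and applies the NaN guard for fully masked rows. Treat it as a summary of the steps built above rather than a drop-in replacement for your own version.

```python
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(
    query: torch.Tensor,                  # [batch, seq_q, d_k]
    key: torch.Tensor,                    # [batch, seq_k, d_k]
    value: torch.Tensor,                  # [batch, seq_k, d_v]
    mask: Optional[torch.Tensor] = None,  # 0 where attention is blocked
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V."""
    d_k = query.size(-1)

    # [batch, seq_q, seq_k] similarity scores, scaled to keep variance ~1
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    weights = F.softmax(scores, dim=-1)
    # Fully masked rows are all -inf -> softmax yields NaN; zero them out
    weights = torch.nan_to_num(weights, nan=0.0)

    output = torch.matmul(weights, value)  # [batch, seq_q, d_v]
    return output, weights
```

Without a mask, this matches `F.scaled_dot_product_attention` to within floating-point tolerance, which makes a good regression test.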
• Shape Annotations: essential for debugging attention
• Visualization: heatmaps reveal attention patterns for interpretability
• Real Tokens: semantic similarity drives attention weights
• Common Pitfalls: most bugs are shape mismatches or the wrong softmax dimension
• einsum Alternative: Einstein notation clarifies dimension contractions
• PyTorch Built-in: use F.scaled_dot_product_attention for production
• Flash Attention: O(N) memory via tiled computation
• Gradient Flow: gradients flow through Q/K (what to attend to) and V (what to extract)
Exercises
Implementation Exercises
Modify the attention function to return the raw scores (before softmax) as well. When would this be useful?
Implement a version of attention that supports different key and value dimensions (d_k ≠ d_v). What changes?
Create a "relative position attention" variant that adds position-based biases to the attention scores.
Visualization Exercises
Visualize attention patterns for the sentence "The bank of the river was steep" using BERT embeddings. Can you spot word sense disambiguation?
Create an animated visualization showing how attention patterns change as you train a simple model.
Implement a function that highlights the most-attended tokens in the original text given attention weights.
Performance Exercises
Benchmark chunked attention vs standard attention for sequence lengths 1024, 2048, 4096. At what point does chunking become beneficial?
Implement attention with torch.compile() (PyTorch 2.0+) and measure the speedup.
Profile memory usage during training with and without gradient checkpointing. What's the memory vs compute tradeoff?
Debugging Exercises
Intentionally introduce a bug (wrong softmax dim, missing scaling, etc.) and use the debug function to identify it.
Create a test case where attention produces NaN and fix it. What are all the possible causes?
Verify your implementation matches PyTorch's built-in to within 1e-5 for 10 different random inputs.
Next Section Preview
In the next section, we'll implement masking in detail: both padding masks for variable-length sequences and causal masks for autoregressive models. Understanding masking is crucial for building practical transformer models.