Introduction
Being able to visualize and debug attention is crucial for:
- Understanding what your model has learned
- Debugging when something goes wrong
- Interpreting model predictions
- Building trust in model behavior
In this section, we'll build tools to visualize attention patterns and learn to diagnose common bugs.
6.1 Extracting Attention Weights
Modifying Attention to Return Weights
Our attention function already returns weights, but let's make sure we can capture them easily:
`attention.py`:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from typing import Optional, Tuple, List
import math


def scaled_dot_product_attention_with_weights(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Attention that always returns weights for visualization.
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    attention_weights = F.softmax(scores, dim=-1)
    attention_weights = torch.nan_to_num(attention_weights, nan=0.0)

    output = torch.matmul(attention_weights, value)

    return output, attention_weights


# Example: Get attention weights
Q = torch.randn(1, 5, 8)
K = torch.randn(1, 5, 8)
V = torch.randn(1, 5, 8)

output, weights = scaled_dot_product_attention_with_weights(Q, K, V)
print(f"Attention weights shape: {weights.shape}")  # [1, 5, 5]
```

6.2 Basic Attention Heatmap
Simple Heatmap Visualization
`visualize.py`:
```python
def plot_attention_heatmap(
    attention_weights: torch.Tensor,
    x_labels: Optional[List[str]] = None,
    y_labels: Optional[List[str]] = None,
    title: str = "Attention Weights",
    figsize: Tuple[int, int] = (8, 6),
    cmap: str = "Blues",
    save_path: Optional[str] = None
) -> None:
    """
    Plot a single attention head as a heatmap.

    Args:
        attention_weights: [seq_len_q, seq_len_k] attention matrix
        x_labels: Labels for keys (columns)
        y_labels: Labels for queries (rows)
        title: Plot title
        figsize: Figure size
        cmap: Colormap name
        save_path: Path to save figure (optional)
    """
    # Convert to numpy if tensor
    if isinstance(attention_weights, torch.Tensor):
        attention_weights = attention_weights.detach().cpu().numpy()

    seq_len_q, seq_len_k = attention_weights.shape

    # Default labels
    if x_labels is None:
        x_labels = [f"K{i}" for i in range(seq_len_k)]
    if y_labels is None:
        y_labels = [f"Q{i}" for i in range(seq_len_q)]

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)

    # Plot heatmap
    im = ax.imshow(attention_weights, cmap=cmap, aspect='auto', vmin=0, vmax=1)

    # Add colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.set_label("Attention Weight", rotation=270, labelpad=15)

    # Set ticks and labels
    ax.set_xticks(range(seq_len_k))
    ax.set_yticks(range(seq_len_q))
    ax.set_xticklabels(x_labels, rotation=45, ha='right')
    ax.set_yticklabels(y_labels)

    # Labels
    ax.set_xlabel("Key Position")
    ax.set_ylabel("Query Position")
    ax.set_title(title)

    # Add text annotations
    for i in range(seq_len_q):
        for j in range(seq_len_k):
            value = attention_weights[i, j]
            color = "white" if value > 0.5 else "black"
            ax.text(j, i, f"{value:.2f}", ha="center", va="center",
                    color=color, fontsize=8)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved to {save_path}")

    plt.close()


# Example with real tokens
tokens = ["The", "cat", "sat", "on", "mat"]
Q = torch.randn(1, 5, 8)
K = torch.randn(1, 5, 8)
V = torch.randn(1, 5, 8)

_, weights = scaled_dot_product_attention_with_weights(Q, K, V)

plot_attention_heatmap(
    weights[0],  # Remove batch dimension
    x_labels=tokens,
    y_labels=tokens,
    title="Self-Attention: The cat sat on mat",
    save_path="attention_heatmap.png"
)
```

6.3 Multi-Head Attention Visualization
Visualizing Multiple Heads
`visualize.py`:
```python
def plot_multi_head_attention(
    attention_weights: torch.Tensor,
    tokens: List[str],
    title: str = "Multi-Head Attention",
    figsize: Tuple[int, int] = (16, 4),
    save_path: Optional[str] = None
) -> None:
    """
    Plot attention weights for multiple heads in a single row.

    Args:
        attention_weights: [num_heads, seq_len, seq_len]
        tokens: List of token strings
        title: Overall title
        figsize: Figure size
        save_path: Path to save figure
    """
    if isinstance(attention_weights, torch.Tensor):
        attention_weights = attention_weights.detach().cpu().numpy()

    num_heads = attention_weights.shape[0]
    seq_len = attention_weights.shape[1]

    # Create subplots
    fig, axes = plt.subplots(1, num_heads, figsize=figsize)

    if num_heads == 1:
        axes = [axes]

    for head_idx, ax in enumerate(axes):
        weights = attention_weights[head_idx]

        im = ax.imshow(weights, cmap='Blues', vmin=0, vmax=1)

        ax.set_xticks(range(seq_len))
        ax.set_yticks(range(seq_len))
        ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=8)
        ax.set_yticklabels(tokens, fontsize=8)
        ax.set_title(f"Head {head_idx + 1}", fontsize=10)

    # Add a shared colorbar (note: tight_layout conflicts with a colorbar
    # spanning multiple axes, so we rely on constrained spacing instead)
    fig.colorbar(im, ax=axes, shrink=0.6, label="Attention Weight")

    fig.suptitle(title, fontsize=12, fontweight='bold')

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved to {save_path}")

    plt.close()


# Example: Simulate 4-head attention
num_heads = 4
seq_len = 6
tokens = ["The", "quick", "brown", "fox", "jumps", "over"]

# Simulate different attention patterns for each head
multi_head_weights = torch.zeros(num_heads, seq_len, seq_len)

# Head 1: Diagonal (self-attention)
multi_head_weights[0] = torch.eye(seq_len) * 0.7 + 0.3 / seq_len

# Head 2: Previous-token attention
for i in range(seq_len):
    if i > 0:
        multi_head_weights[1, i, i-1] = 0.6
        multi_head_weights[1, i, i] = 0.4

# Head 3: Attention to the first token
multi_head_weights[2, :, 0] = 0.7
for i in range(seq_len):
    multi_head_weights[2, i, i] = 0.3

# Head 4: Attention to a noun
multi_head_weights[3, :, 3] = 0.5  # "fox"
for i in range(seq_len):
    multi_head_weights[3, i, i] = 0.3

# Sharpen and normalize all heads in one pass
for i in range(num_heads):
    multi_head_weights[i] = F.softmax(multi_head_weights[i] * 3, dim=-1)

plot_multi_head_attention(
    multi_head_weights,
    tokens,
    title="Different Attention Patterns Across Heads",
    save_path="multi_head_attention.png"
)
```

6.4 Attention Pattern Analysis
Identifying Common Patterns
`analysis.py`:
```python
def analyze_attention_pattern(
    attention_weights: torch.Tensor,
    threshold: float = 0.3
) -> dict:
    """
    Analyze attention weights to identify patterns.

    Args:
        attention_weights: [seq_len, seq_len]
        threshold: Weight threshold for "significant" attention

    Returns:
        Dictionary with pattern analysis
    """
    if isinstance(attention_weights, torch.Tensor):
        weights = attention_weights.detach().cpu().numpy()
    else:
        weights = attention_weights

    seq_len = weights.shape[0]
    analysis = {}

    # 1. Diagonal dominance (self-attention)
    diagonal = np.diag(weights)
    analysis['diagonal_mean'] = diagonal.mean()
    analysis['is_diagonal_dominant'] = analysis['diagonal_mean'] > threshold

    # 2. First-token attention (like [CLS] in BERT)
    first_col = weights[:, 0]
    analysis['first_token_mean'] = first_col.mean()
    analysis['attends_to_first'] = analysis['first_token_mean'] > threshold

    # 3. Previous-token attention
    prev_attention = [weights[i, i-1] for i in range(1, seq_len)]
    analysis['previous_token_mean'] = np.mean(prev_attention) if prev_attention else 0
    analysis['is_previous_dominant'] = analysis['previous_token_mean'] > threshold

    # 4. Uniform attention (all weights similar)
    analysis['weight_std'] = weights.std()
    analysis['is_uniform'] = analysis['weight_std'] < 0.1

    # 5. Sparse attention (few high weights)
    above_threshold = (weights > threshold).sum()
    analysis['sparsity'] = 1 - (above_threshold / weights.size)

    # 6. Entropy (uncertainty); add a small epsilon to avoid log(0)
    eps = 1e-10
    entropy = -np.sum(weights * np.log(weights + eps), axis=-1)
    analysis['mean_entropy'] = entropy.mean()
    analysis['max_entropy'] = np.log(seq_len)  # Entropy of a uniform distribution

    return analysis


def print_attention_analysis(analysis: dict) -> None:
    """Pretty-print attention analysis."""
    print("=" * 50)
    print("ATTENTION PATTERN ANALYSIS")
    print("=" * 50)

    print("\n📊 Diagonal (Self) Attention:")
    print(f"  Mean diagonal weight: {analysis['diagonal_mean']:.3f}")
    print(f"  Dominant: {'✓ Yes' if analysis['is_diagonal_dominant'] else '✗ No'}")

    print("\n📊 First Token Attention:")
    print(f"  Mean attention to first: {analysis['first_token_mean']:.3f}")
    print(f"  Attends to first: {'✓ Yes' if analysis['attends_to_first'] else '✗ No'}")

    print("\n📊 Previous Token Attention:")
    print(f"  Mean previous attention: {analysis['previous_token_mean']:.3f}")
    print(f"  Previous dominant: {'✓ Yes' if analysis['is_previous_dominant'] else '✗ No'}")

    print("\n📊 Distribution Properties:")
    print(f"  Weight std: {analysis['weight_std']:.3f}")
    print(f"  Is uniform: {'✓ Yes' if analysis['is_uniform'] else '✗ No'}")
    print(f"  Sparsity: {analysis['sparsity']:.1%}")
    print(f"  Mean entropy: {analysis['mean_entropy']:.3f} / {analysis['max_entropy']:.3f}")

    print("=" * 50)


# Example
sample_weights = torch.eye(5) * 0.6 + 0.4 / 5  # Diagonal-dominant
sample_weights = F.softmax(sample_weights * 3, dim=-1)

analysis = analyze_attention_pattern(sample_weights)
print_attention_analysis(analysis)
```

6.5 Debugging Common Issues
Issue 1: Uniform Attention
Symptom: All attention weights are approximately equal (1/n).
`debug.py`:
```python
def debug_uniform_attention(Q: torch.Tensor, K: torch.Tensor) -> None:
    """Debug why attention might be uniform."""
    print("🔍 Debugging Uniform Attention")
    print("-" * 40)

    # Check Q and K statistics
    print(f"Q mean: {Q.mean():.4f}, std: {Q.std():.4f}")
    print(f"K mean: {K.mean():.4f}, std: {K.std():.4f}")

    # Check raw scores
    scores = torch.matmul(Q, K.transpose(-2, -1))
    print(f"Raw scores mean: {scores.mean():.4f}, std: {scores.std():.4f}")

    # Check scaled scores
    d_k = K.size(-1)
    scaled_scores = scores / math.sqrt(d_k)
    print(f"Scaled scores mean: {scaled_scores.mean():.4f}, std: {scaled_scores.std():.4f}")

    # Potential issues
    if Q.std() < 0.1 or K.std() < 0.1:
        print("⚠️ ISSUE: Q or K has very low variance - check initialization")

    if scaled_scores.std() < 0.1:
        print("⚠️ ISSUE: Scaled scores have low variance - attention will be uniform")

    if d_k > 512:
        print(f"⚠️ NOTE: Large d_k ({d_k}) - make sure scaling is applied")


# Example: Debugging
Q_bad = torch.ones(1, 5, 64) * 0.1  # Poor initialization
K_bad = torch.ones(1, 5, 64) * 0.1

debug_uniform_attention(Q_bad, K_bad)
```

Issue 2: Attention Collapse (All Weight on One Position)
Symptom: One position gets ~100% attention, others get ~0%.
`debug.py`:
```python
def debug_attention_collapse(
    attention_weights: torch.Tensor,
    threshold: float = 0.95
) -> None:
    """Debug attention collapse issue."""
    print("🔍 Debugging Attention Collapse")
    print("-" * 40)

    weights = attention_weights.detach()
    max_weights = weights.max(dim=-1).values

    collapsed_rows = (max_weights > threshold).sum().item()
    total_rows = weights.shape[-2]

    print(f"Rows with max weight > {threshold}: {collapsed_rows}/{total_rows}")

    if collapsed_rows > total_rows * 0.5:
        print("⚠️ ISSUE: More than 50% of queries have collapsed attention")
        print("   Possible causes:")
        print("   - Scores not being scaled (missing /sqrt(d_k))")
        print("   - Learning rate too high")
        print("   - Embeddings not properly normalized")

    # Check which positions are receiving all the attention
    dominant_positions = weights.argmax(dim=-1)
    position_counts = torch.bincount(dominant_positions.flatten())

    print("\nDominant position distribution:")
    for pos, count in enumerate(position_counts):
        if count > 0:
            print(f"  Position {pos}: {count} queries ({count/total_rows:.1%})")


# Example
collapsed_weights = torch.zeros(1, 5, 5)
collapsed_weights[:, :, 2] = 1.0  # All attention on position 2

debug_attention_collapse(collapsed_weights)
```

Issue 3: NaN in Attention
Symptom: Attention weights or output contains NaN.
`debug.py`:
```python
def debug_nan_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None
) -> None:
    """Step through attention to find the NaN source."""
    print("🔍 Debugging NaN in Attention")
    print("-" * 40)

    # Check inputs
    print(f"Q has NaN: {torch.isnan(Q).any()}")
    print(f"K has NaN: {torch.isnan(K).any()}")
    print(f"V has NaN: {torch.isnan(V).any()}")

    # Step 1: Scores
    d_k = K.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    print(f"Scores has NaN: {torch.isnan(scores).any()}")
    print(f"Scores has Inf: {torch.isinf(scores).any()}")

    # Step 2: After masking
    if mask is not None:
        scores_masked = scores.masked_fill(mask == 0, float('-inf'))
        all_masked_rows = (mask.sum(dim=-1) == 0).any()
        print(f"Any rows fully masked: {all_masked_rows}")
        if all_masked_rows:
            print("⚠️ ISSUE: Some rows have all positions masked!")
            print("   This causes softmax to produce NaN (exp(-inf)/exp(-inf))")
    else:
        scores_masked = scores

    # Step 3: After softmax
    weights = F.softmax(scores_masked, dim=-1)
    print(f"Weights has NaN: {torch.isnan(weights).any()}")

    # Step 4: Output
    output = torch.matmul(weights, V)
    print(f"Output has NaN: {torch.isnan(output).any()}")

    if torch.isnan(weights).any():
        print("\n⚠️ FIX: Add torch.nan_to_num(weights, nan=0.0) after softmax")


# Example: Trigger NaN with an all-masked row
Q = torch.randn(1, 3, 8)
K = torch.randn(1, 5, 8)
V = torch.randn(1, 5, 8)
bad_mask = torch.tensor([[[1, 1, 0, 0, 0],
                          [0, 0, 0, 0, 0],  # All masked!
                          [1, 1, 1, 0, 0]]])

debug_nan_attention(Q, K, V, bad_mask)
```

6.6 Interactive Attention Explorer
Building an Interactive Tool
`demo.py`:
```python
def interactive_attention_demo():
    """
    Interactive demo showing how inputs affect attention.
    """
    print("=" * 60)
    print("INTERACTIVE ATTENTION DEMONSTRATION")
    print("=" * 60)

    # Create a simple input
    tokens = ["I", "love", "cats", "and", "dogs"]
    seq_len = len(tokens)
    d_model = 4

    # Create embeddings that reflect semantic similarity
    embeddings = {
        "I":    torch.tensor([1.0, 0.0, 0.0, 0.0]),
        "love": torch.tensor([0.0, 1.0, 0.0, 0.0]),
        "cats": torch.tensor([0.0, 0.0, 1.0, 0.5]),  # Animal
        "and":  torch.tensor([0.0, 0.0, 0.0, 0.0]),
        "dogs": torch.tensor([0.0, 0.0, 0.8, 0.6]),  # Animal (similar to cats)
    }

    X = torch.stack([embeddings[t] for t in tokens]).unsqueeze(0)
    Q = K = V = X

    print("\n1. INPUT EMBEDDINGS:")
    for i, token in enumerate(tokens):
        print(f"  {token:8s}: {X[0, i].tolist()}")

    # Compute attention
    output, weights = scaled_dot_product_attention_with_weights(Q, K, V)

    print("\n2. ATTENTION WEIGHTS (what each token attends to):")
    weights_np = weights[0].numpy()
    header = "         " + " ".join([f"{t:8s}" for t in tokens])
    print(header)
    for i, token in enumerate(tokens):
        row = f"{token:8s} " + " ".join([f"{weights_np[i, j]:8.3f}" for j in range(seq_len)])
        print(row)

    print("\n3. KEY OBSERVATIONS:")
    # Find interesting patterns
    for i, token in enumerate(tokens):
        max_attn_idx = weights_np[i].argmax()
        max_attn_val = weights_np[i, max_attn_idx]
        print(f"  '{token}' attends most to '{tokens[max_attn_idx]}' ({max_attn_val:.3f})")

    # Note the similarity between cats and dogs
    cats_dogs_attn = weights_np[2, 4]  # cats attending to dogs
    dogs_cats_attn = weights_np[4, 2]  # dogs attending to cats
    print(f"\n  'cats' → 'dogs' attention: {cats_dogs_attn:.3f}")
    print(f"  'dogs' → 'cats' attention: {dogs_cats_attn:.3f}")
    print("  (High because they have similar embeddings - both are animals!)")

    return weights


# Run demo
weights = interactive_attention_demo()
```

6.7 Complete Visualization Module
`attention_visualization.py`:
```python
"""
attention_visualization.py
==========================
Tools for visualizing and debugging attention mechanisms.
"""

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from typing import Optional, List, Tuple, Dict
import math


def scaled_dot_product_attention_with_weights(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Attention that returns weights (repeated from 6.1 so the module is self-contained)."""
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    attention_weights = torch.nan_to_num(F.softmax(scores, dim=-1), nan=0.0)
    return torch.matmul(attention_weights, value), attention_weights


class AttentionVisualizer:
    """Comprehensive attention visualization toolkit."""

    def __init__(self, figsize: Tuple[int, int] = (10, 8)):
        self.figsize = figsize
        self.cmap = 'Blues'

    def plot_heatmap(
        self,
        weights: torch.Tensor,
        x_labels: Optional[List[str]] = None,
        y_labels: Optional[List[str]] = None,
        title: str = "Attention",
        ax: Optional[plt.Axes] = None,
        show_values: bool = True
    ) -> plt.Axes:
        """Plot a single attention heatmap."""
        if isinstance(weights, torch.Tensor):
            weights = weights.detach().cpu().numpy()

        if ax is None:
            _, ax = plt.subplots(figsize=self.figsize)

        im = ax.imshow(weights, cmap=self.cmap, vmin=0, vmax=1)

        seq_q, seq_k = weights.shape
        x_labels = x_labels or [str(i) for i in range(seq_k)]
        y_labels = y_labels or [str(i) for i in range(seq_q)]

        ax.set_xticks(range(seq_k))
        ax.set_yticks(range(seq_q))
        ax.set_xticklabels(x_labels, rotation=45, ha='right')
        ax.set_yticklabels(y_labels)
        ax.set_title(title)

        if show_values and seq_q <= 10 and seq_k <= 10:
            for i in range(seq_q):
                for j in range(seq_k):
                    color = "white" if weights[i, j] > 0.5 else "black"
                    ax.text(j, i, f"{weights[i, j]:.2f}",
                            ha="center", va="center", color=color, fontsize=8)

        return ax

    def plot_multi_head(
        self,
        weights: torch.Tensor,
        tokens: List[str],
        title: str = "Multi-Head Attention",
        save_path: Optional[str] = None
    ) -> None:
        """Plot all attention heads."""
        if isinstance(weights, torch.Tensor):
            weights = weights.detach().cpu().numpy()

        num_heads = weights.shape[0]
        cols = min(4, num_heads)
        rows = (num_heads + cols - 1) // cols

        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
        axes = np.array(axes).flatten()

        for i in range(num_heads):
            self.plot_heatmap(
                weights[i], tokens, tokens,
                f"Head {i+1}", axes[i], show_values=False
            )

        # Hide unused subplots
        for i in range(num_heads, len(axes)):
            axes[i].axis('off')

        fig.suptitle(title, fontsize=14, fontweight='bold')
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()

    def analyze(self, weights: torch.Tensor) -> Dict:
        """Analyze attention patterns."""
        if isinstance(weights, torch.Tensor):
            w = weights.detach().cpu().numpy()
        else:
            w = weights

        return {
            'diagonal_mean': np.diag(w).mean(),
            'first_col_mean': w[:, 0].mean(),
            'entropy': -np.sum(w * np.log(w + 1e-10), axis=-1).mean(),
            'max_weight_mean': w.max(axis=-1).mean(),
            'sparsity': (w < 0.1).mean()
        }


# Create a global visualizer instance
visualizer = AttentionVisualizer()


def plot_attention(weights, tokens=None, title="Attention", save_path=None):
    """Convenience function for quick plotting."""
    if tokens is None:
        tokens = [str(i) for i in range(weights.shape[-1])]

    fig, ax = plt.subplots(figsize=(8, 6))
    visualizer.plot_heatmap(weights, tokens, tokens, title, ax)
    plt.colorbar(ax.images[0], ax=ax, label="Weight")
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150)
    plt.close()


if __name__ == "__main__":
    # Demo
    torch.manual_seed(42)

    tokens = ["The", "cat", "sat", "on", "mat"]
    Q = torch.randn(1, 5, 32)
    K = torch.randn(1, 5, 32)
    V = torch.randn(1, 5, 32)

    _, weights = scaled_dot_product_attention_with_weights(Q, K, V)

    # Single-head visualization
    plot_attention(weights[0], tokens, "Self-Attention Demo", "demo_attention.png")
    print("Saved demo_attention.png")

    # Analysis
    analysis = visualizer.analyze(weights[0])
    print("\nAttention Analysis:")
    for key, value in analysis.items():
        print(f"  {key}: {value:.3f}")
```

Summary
Visualization Tools Built
| Tool | Purpose |
|---|---|
| `plot_attention_heatmap` | Single-head visualization |
| `plot_multi_head_attention` | All heads side by side |
| `analyze_attention_pattern` | Quantitative pattern analysis |
| `AttentionVisualizer` | Full-featured visualization class |
Debugging Checklist
When attention isn't working correctly:
- Check if attention is uniform (scores not scaled?)
- Check if attention collapsed (learning rate too high?)
- Check for NaN (all positions masked?)
- Check input embeddings (properly initialized?)
- Visualize attention patterns (do they make sense?)
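The checklist above can be folded into a quick programmatic sanity check. This is only a sketch: the function name and the 0.95/0.05 thresholds are illustrative choices, not standard values.

```python
import torch
import torch.nn.functional as F

def attention_sanity_check(weights: torch.Tensor) -> dict:
    """Run the debugging checklist on a [..., seq_q, seq_k] attention tensor."""
    row_sums = weights.sum(dim=-1)
    max_w = weights.max(dim=-1).values
    seq_k = weights.shape[-1]
    return {
        # NaN usually means a fully-masked row reached softmax
        "has_nan": bool(torch.isnan(weights).any()),
        # Each query's weights should form a probability distribution
        "rows_sum_to_one": bool(torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-5)),
        # Fraction of queries whose attention has collapsed onto a single key
        "collapsed_fraction": float((max_w > 0.95).float().mean()),
        # True when every weight sits within 0.05 of the uniform value 1/seq_k
        "near_uniform": bool((weights - 1.0 / seq_k).abs().max() < 0.05),
    }

# A peaked-but-not-collapsed distribution: no NaN, rows sum to 1
w = F.softmax(torch.eye(4) * 4.0, dim=-1)
print(attention_sanity_check(w))
```

Calling this on weights captured during training makes the checklist repeatable instead of a one-off visual inspection.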
Key Patterns to Recognize
| Pattern | Appearance | Meaning |
|---|---|---|
| Diagonal | High values on diagonal | Self-attention dominant |
| Column | One column highlighted | One position important |
| Previous | Diagonal shifted left | Sequential processing |
| Uniform | All values ~equal | Not learning relationships |
| Sparse | Few high values | Focused attention |
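The patterns in the table can also be detected mechanically. Below is a rough classifier sketch; the 0.3/0.1 thresholds and the check order are illustrative, and it will misfire on edge cases such as very short sequences.

```python
import numpy as np

def classify_pattern(weights: np.ndarray, threshold: float = 0.3) -> str:
    """Label a [seq_len, seq_len] attention matrix with one of the named patterns."""
    seq_len = weights.shape[0]
    if weights.std() < 0.1:
        return "uniform"
    # Check the sub-diagonal before the main diagonal: previous-token
    # heads usually keep some self-attention as well
    prev = np.mean([weights[i, i - 1] for i in range(1, seq_len)]) if seq_len > 1 else 0.0
    if prev > threshold:
        return "previous"
    # A single column with a high mean means one position dominates
    if weights.mean(axis=0).max() > threshold:
        return "column"
    if np.diag(weights).mean() > threshold:
        return "diagonal"
    return "sparse"

diag = np.eye(5) * 0.6 + 0.08  # strong self-attention; each row sums to 1
print(classify_pattern(diag))  # diagonal
```

A classifier like this is handy for scanning all heads of a trained model and flagging the ones worth visualizing by hand.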
Exercises
Visualization Exercises
- Create an animated visualization that shows how attention weights change during training.
- Build a visualization that compares attention patterns before and after fine-tuning.
- Implement a "diff" visualization that highlights where two attention matrices differ.
Debugging Exercises
- Intentionally break attention in 5 different ways and use the debugging tools to identify each issue.
- Write unit tests that verify attention weights satisfy expected properties (sum to 1, no NaN, etc.).
- Create a logging callback that records attention statistics during training and plots trends.
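As a starting point for the unit-test exercise, a minimal pytest-style property test might look like this. The `attention_weights` helper is a stand-in; swap in your own attention implementation.

```python
import torch
import torch.nn.functional as F

def attention_weights(Q, K, mask=None):
    # Stand-in: replace with your own attention implementation
    scores = Q @ K.transpose(-2, -1) / K.size(-1) ** 0.5
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    return F.softmax(scores, dim=-1)

def test_weights_are_probability_distributions():
    torch.manual_seed(0)
    w = attention_weights(torch.randn(2, 5, 8), torch.randn(2, 5, 8))
    assert torch.allclose(w.sum(dim=-1), torch.ones(2, 5), atol=1e-5)
    assert not torch.isnan(w).any()
    assert not torch.isinf(w).any()
    assert (w >= 0).all() and (w <= 1).all()

def test_causal_mask_blocks_future_positions():
    torch.manual_seed(0)
    mask = torch.tril(torch.ones(1, 4, 4))
    w = attention_weights(torch.randn(1, 4, 8), torch.randn(1, 4, 8), mask)
    # Masked (future) positions must receive exactly zero weight
    assert (w.masked_select(mask == 0) == 0).all()

test_weights_are_probability_distributions()
test_causal_mask_blocks_future_positions()
print("all attention property tests passed")
```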
Chapter Summary
In this chapter, you've mastered the attention mechanism:
- Intuition: Attention as weighted averaging based on similarity
- Mathematics: The scaled dot-product formula and why each component matters
- Numerical: Hand-computed examples for verification
- Implementation: Clean, tested PyTorch code
- Masking: Padding and causal masks for real-world use
- Visualization: Tools to see and debug attention
You now have a solid foundation for understanding transformers. In the next chapter, we'll extend single-head attention to multi-head attention, enabling the model to learn multiple types of relationships simultaneously.