Chapter 2: Attention Mechanism From Scratch

Attention Visualization and Debugging

Introduction

Being able to visualize and debug attention is crucial for:

  1. Understanding what your model has learned
  2. Debugging when something goes wrong
  3. Interpreting model predictions
  4. Building trust in model behavior

In this section, we'll build tools to visualize attention patterns and learn to diagnose common bugs.


6.1 Extracting Attention Weights

Modifying Attention to Return Weights

Our attention function already returns weights, but let's make sure we can capture them easily:

attention.py

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from typing import Optional, Tuple, List
import math


def scaled_dot_product_attention_with_weights(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Attention that always returns weights for visualization.
    """
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    attention_weights = F.softmax(scores, dim=-1)
    # Fully masked rows produce NaN after softmax; zero them out
    attention_weights = torch.nan_to_num(attention_weights, nan=0.0)

    output = torch.matmul(attention_weights, value)

    return output, attention_weights


# Example: Get attention weights
Q = torch.randn(1, 5, 8)
K = torch.randn(1, 5, 8)
V = torch.randn(1, 5, 8)

output, weights = scaled_dot_product_attention_with_weights(Q, K, V)
print(f"Attention weights shape: {weights.shape}")  # [1, 5, 5]
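If you're using PyTorch's built-in `nn.MultiheadAttention` instead of the scratch implementation, you don't need a custom function: the forward call can return the weights directly. One caveat: `average_attn_weights` requires PyTorch 1.11 or newer.

```python
import torch
import torch.nn as nn

# need_weights=True returns the attention weights alongside the output;
# average_attn_weights=False keeps one matrix per head instead of the mean.
mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
x = torch.randn(1, 5, 8)

output, weights = mha(x, x, x, need_weights=True, average_attn_weights=False)
print(output.shape)   # torch.Size([1, 5, 8])
print(weights.shape)  # torch.Size([1, 2, 5, 5]) -> [batch, heads, query, key]
```

The per-head weight matrices slot straight into the plotting functions built in this section.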

6.2 Basic Attention Heatmap

Simple Heatmap Visualization

visualize.py

def plot_attention_heatmap(
    attention_weights: torch.Tensor,
    x_labels: Optional[List[str]] = None,
    y_labels: Optional[List[str]] = None,
    title: str = "Attention Weights",
    figsize: Tuple[int, int] = (8, 6),
    cmap: str = "Blues",
    save_path: Optional[str] = None
) -> None:
    """
    Plot a single attention head as a heatmap.

    Args:
        attention_weights: [seq_len_q, seq_len_k] attention matrix
        x_labels: Labels for keys (columns)
        y_labels: Labels for queries (rows)
        title: Plot title
        figsize: Figure size
        cmap: Colormap name
        save_path: Path to save figure (optional)
    """
    # Convert to numpy if tensor
    if isinstance(attention_weights, torch.Tensor):
        attention_weights = attention_weights.detach().cpu().numpy()

    seq_len_q, seq_len_k = attention_weights.shape

    # Default labels
    if x_labels is None:
        x_labels = [f"K{i}" for i in range(seq_len_k)]
    if y_labels is None:
        y_labels = [f"Q{i}" for i in range(seq_len_q)]

    # Create figure
    fig, ax = plt.subplots(figsize=figsize)

    # Plot heatmap
    im = ax.imshow(attention_weights, cmap=cmap, aspect='auto', vmin=0, vmax=1)

    # Add colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.set_label("Attention Weight", rotation=270, labelpad=15)

    # Set ticks and labels
    ax.set_xticks(range(seq_len_k))
    ax.set_yticks(range(seq_len_q))
    ax.set_xticklabels(x_labels, rotation=45, ha='right')
    ax.set_yticklabels(y_labels)

    # Labels
    ax.set_xlabel("Key Position")
    ax.set_ylabel("Query Position")
    ax.set_title(title)

    # Add text annotations
    for i in range(seq_len_q):
        for j in range(seq_len_k):
            value = attention_weights[i, j]
            color = "white" if value > 0.5 else "black"
            ax.text(j, i, f"{value:.2f}", ha="center", va="center",
                    color=color, fontsize=8)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved to {save_path}")

    plt.close()


# Example with real tokens
tokens = ["The", "cat", "sat", "on", "mat"]
Q = torch.randn(1, 5, 8)
K = torch.randn(1, 5, 8)
V = torch.randn(1, 5, 8)

_, weights = scaled_dot_product_attention_with_weights(Q, K, V)

plot_attention_heatmap(
    weights[0],  # Remove batch dimension
    x_labels=tokens,
    y_labels=tokens,
    title="Self-Attention: The cat sat on mat",
    save_path="attention_heatmap.png"
)
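A useful smoke test for the heatmap: with a causal mask, the plot should show an exact lower-triangular pattern. A self-contained sketch that recomputes the masked weights inline:

```python
import torch
import torch.nn.functional as F
import math

Q = torch.randn(1, 5, 8)
K = torch.randn(1, 5, 8)

# Lower-triangular causal mask: position i may only attend to positions <= i
causal_mask = torch.tril(torch.ones(1, 5, 5))

scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(8)
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)

# Everything strictly above the diagonal must be exactly zero
print((weights[0].triu(1) == 0).all())  # tensor(True)
```

Passing `weights[0]` with token labels into `plot_attention_heatmap` should then render the characteristic triangle; if it doesn't, the mask is being applied incorrectly.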

6.3 Multi-Head Attention Visualization

Visualizing Multiple Heads

visualize.py

def plot_multi_head_attention(
    attention_weights: torch.Tensor,
    tokens: List[str],
    title: str = "Multi-Head Attention",
    figsize: Tuple[int, int] = (16, 4),
    save_path: Optional[str] = None
) -> None:
    """
    Plot attention weights for multiple heads in a single row.

    Args:
        attention_weights: [num_heads, seq_len, seq_len]
        tokens: List of token strings
        title: Overall title
        figsize: Figure size
        save_path: Path to save figure
    """
    if isinstance(attention_weights, torch.Tensor):
        attention_weights = attention_weights.detach().cpu().numpy()

    num_heads = attention_weights.shape[0]
    seq_len = attention_weights.shape[1]

    # Create subplots
    fig, axes = plt.subplots(1, num_heads, figsize=figsize)

    if num_heads == 1:
        axes = [axes]

    for head_idx, ax in enumerate(axes):
        weights = attention_weights[head_idx]

        im = ax.imshow(weights, cmap='Blues', vmin=0, vmax=1)

        ax.set_xticks(range(seq_len))
        ax.set_yticks(range(seq_len))
        ax.set_xticklabels(tokens, rotation=45, ha='right', fontsize=8)
        ax.set_yticklabels(tokens, fontsize=8)
        ax.set_title(f"Head {head_idx + 1}", fontsize=10)

    # Add colorbar
    fig.colorbar(im, ax=axes, shrink=0.6, label="Attention Weight")

    fig.suptitle(title, fontsize=12, fontweight='bold')
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"Saved to {save_path}")

    plt.close()


# Example: Simulate 4-head attention
num_heads = 4
seq_len = 6
tokens = ["The", "quick", "brown", "fox", "jumps", "over"]

# Simulate different attention patterns for each head
multi_head_weights = torch.zeros(num_heads, seq_len, seq_len)

# Head 1: Diagonal (self-attention)
multi_head_weights[0] = torch.eye(seq_len) * 0.7 + 0.3 / seq_len

# Head 2: Previous token attention
for i in range(seq_len):
    if i > 0:
        multi_head_weights[1, i, i-1] = 0.6
    multi_head_weights[1, i, i] = 0.4

# Head 3: Attention to first token
multi_head_weights[2, :, 0] = 0.7
for i in range(seq_len):
    multi_head_weights[2, i, i] = 0.3

# Head 4: Attention to nouns
multi_head_weights[3, :, 3] = 0.5  # "fox"
for i in range(seq_len):
    multi_head_weights[3, i, i] = 0.3

# Normalize all heads with a single softmax
# (applying softmax twice would distort the hand-crafted patterns)
for i in range(num_heads):
    multi_head_weights[i] = F.softmax(multi_head_weights[i] * 3, dim=-1)

plot_multi_head_attention(
    multi_head_weights,
    tokens,
    title="Different Attention Patterns Across Heads",
    save_path="multi_head_attention.png"
)

6.4 Attention Pattern Analysis

Identifying Common Patterns

analysis.py

def analyze_attention_pattern(
    attention_weights: torch.Tensor,
    threshold: float = 0.3
) -> dict:
    """
    Analyze attention weights to identify patterns.

    Args:
        attention_weights: [seq_len, seq_len]
        threshold: Weight threshold for "significant" attention

    Returns:
        Dictionary with pattern analysis
    """
    if isinstance(attention_weights, torch.Tensor):
        weights = attention_weights.detach().cpu().numpy()
    else:
        weights = attention_weights

    seq_len = weights.shape[0]
    analysis = {}

    # 1. Diagonal dominance (self-attention)
    diagonal = np.diag(weights)
    analysis['diagonal_mean'] = diagonal.mean()
    analysis['is_diagonal_dominant'] = analysis['diagonal_mean'] > threshold

    # 2. First token attention (like [CLS] in BERT)
    first_col = weights[:, 0]
    analysis['first_token_mean'] = first_col.mean()
    analysis['attends_to_first'] = analysis['first_token_mean'] > threshold

    # 3. Previous token attention
    prev_attention = []
    for i in range(1, seq_len):
        prev_attention.append(weights[i, i-1])
    analysis['previous_token_mean'] = np.mean(prev_attention) if prev_attention else 0
    analysis['is_previous_dominant'] = analysis['previous_token_mean'] > threshold

    # 4. Uniform attention (all weights similar)
    analysis['weight_std'] = weights.std()
    analysis['is_uniform'] = analysis['weight_std'] < 0.1

    # 5. Sparse attention (few high weights)
    above_threshold = (weights > threshold).sum()
    analysis['sparsity'] = 1 - (above_threshold / weights.size)

    # 6. Entropy (uncertainty)
    # Avoid log(0) by adding small epsilon
    eps = 1e-10
    entropy = -np.sum(weights * np.log(weights + eps), axis=-1)
    analysis['mean_entropy'] = entropy.mean()
    analysis['max_entropy'] = np.log(seq_len)  # Uniform distribution entropy

    return analysis


def print_attention_analysis(analysis: dict) -> None:
    """Pretty print attention analysis."""
    print("=" * 50)
    print("ATTENTION PATTERN ANALYSIS")
    print("=" * 50)

    print(f"\n📊 Diagonal (Self) Attention:")
    print(f"   Mean diagonal weight: {analysis['diagonal_mean']:.3f}")
    print(f"   Dominant: {'✓ Yes' if analysis['is_diagonal_dominant'] else '✗ No'}")

    print(f"\n📊 First Token Attention:")
    print(f"   Mean attention to first: {analysis['first_token_mean']:.3f}")
    print(f"   Attends to first: {'✓ Yes' if analysis['attends_to_first'] else '✗ No'}")

    print(f"\n📊 Previous Token Attention:")
    print(f"   Mean previous attention: {analysis['previous_token_mean']:.3f}")
    print(f"   Previous dominant: {'✓ Yes' if analysis['is_previous_dominant'] else '✗ No'}")

    print(f"\n📊 Distribution Properties:")
    print(f"   Weight std: {analysis['weight_std']:.3f}")
    print(f"   Is uniform: {'✓ Yes' if analysis['is_uniform'] else '✗ No'}")
    print(f"   Sparsity: {analysis['sparsity']:.1%}")
    print(f"   Mean entropy: {analysis['mean_entropy']:.3f} / {analysis['max_entropy']:.3f}")

    print("=" * 50)


# Example
sample_weights = torch.eye(5) * 0.6 + 0.4 / 5  # Diagonal-dominant
sample_weights = F.softmax(sample_weights * 3, dim=-1)

analysis = analyze_attention_pattern(sample_weights)
print_attention_analysis(analysis)

6.5 Debugging Common Issues

Issue 1: Uniform Attention

Symptom: All attention weights are approximately equal (1/n).

debug.py

def debug_uniform_attention(Q: torch.Tensor, K: torch.Tensor) -> None:
    """Debug why attention might be uniform."""
    print("🔍 Debugging Uniform Attention")
    print("-" * 40)

    # Check Q and K statistics
    print(f"Q mean: {Q.mean():.4f}, std: {Q.std():.4f}")
    print(f"K mean: {K.mean():.4f}, std: {K.std():.4f}")

    # Check raw scores
    scores = torch.matmul(Q, K.transpose(-2, -1))
    print(f"Raw scores mean: {scores.mean():.4f}, std: {scores.std():.4f}")

    # Check scaled scores
    d_k = K.size(-1)
    scaled_scores = scores / math.sqrt(d_k)
    print(f"Scaled scores mean: {scaled_scores.mean():.4f}, std: {scaled_scores.std():.4f}")

    # Potential issues
    if Q.std() < 0.1 or K.std() < 0.1:
        print("⚠️ ISSUE: Q or K has very low variance - check initialization")

    if scaled_scores.std() < 0.1:
        print("⚠️ ISSUE: Scaled scores have low variance - attention will be uniform")

    if d_k > 512:
        print(f"⚠️ NOTE: Large d_k ({d_k}) - make sure scaling is applied")


# Example: Debugging
Q_bad = torch.ones(1, 5, 64) * 0.1  # Poor initialization
K_bad = torch.ones(1, 5, 64) * 0.1

debug_uniform_attention(Q_bad, K_bad)

Issue 2: Attention Collapse (All Weight on One Position)

Symptom: One position gets ~100% attention, others get ~0%.

debug.py

def debug_attention_collapse(
    attention_weights: torch.Tensor,
    threshold: float = 0.95
) -> None:
    """Debug attention collapse issue."""
    print("🔍 Debugging Attention Collapse")
    print("-" * 40)

    weights = attention_weights.detach()
    max_weights = weights.max(dim=-1).values

    collapsed_rows = (max_weights > threshold).sum().item()
    total_rows = weights.shape[-2]

    print(f"Rows with max weight > {threshold}: {collapsed_rows}/{total_rows}")

    if collapsed_rows > total_rows * 0.5:
        print("⚠️ ISSUE: More than 50% of queries have collapsed attention")
        print("   Possible causes:")
        print("   - Scores not being scaled (missing /sqrt(d_k))")
        print("   - Learning rate too high")
        print("   - Embeddings not properly normalized")

    # Check which positions are receiving all attention
    dominant_positions = weights.argmax(dim=-1)
    position_counts = torch.bincount(dominant_positions.flatten())

    print(f"\nDominant position distribution:")
    for pos, count in enumerate(position_counts.tolist()):  # plain ints for clean formatting
        if count > 0:
            print(f"   Position {pos}: {count} queries ({count / total_rows:.1%})")


# Example
collapsed_weights = torch.zeros(1, 5, 5)
collapsed_weights[:, :, 2] = 1.0  # All attention on position 2

debug_attention_collapse(collapsed_weights)
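The first cause on that list (a missing /sqrt(d_k)) is easy to reproduce: with a large d_k, the raw dot-product scores have a large spread, and softmax saturates toward a one-hot distribution. A small illustrative sketch:

```python
import torch
import torch.nn.functional as F
import math

torch.manual_seed(0)
d_k = 512
Q = torch.randn(1, 5, d_k)
K = torch.randn(1, 5, d_k)

scores = torch.matmul(Q, K.transpose(-2, -1))  # score std grows with sqrt(d_k)

unscaled = F.softmax(scores, dim=-1)            # near one-hot: collapse
scaled = F.softmax(scores / math.sqrt(d_k), dim=-1)

print(f"Unscaled mean max weight: {unscaled.max(dim=-1).values.mean():.3f}")
print(f"Scaled mean max weight:   {scaled.max(dim=-1).values.mean():.3f}")
```

The unscaled version concentrates nearly all weight on one position per row, which is exactly the pattern `debug_attention_collapse` flags.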

Issue 3: NaN in Attention

Symptom: Attention weights or output contains NaN.

debug.py

def debug_nan_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: Optional[torch.Tensor] = None
) -> None:
    """Step through attention to find NaN source."""
    print("🔍 Debugging NaN in Attention")
    print("-" * 40)

    # Check inputs
    print(f"Q has NaN: {torch.isnan(Q).any()}")
    print(f"K has NaN: {torch.isnan(K).any()}")
    print(f"V has NaN: {torch.isnan(V).any()}")

    # Step 1: Scores
    d_k = K.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    print(f"Scores has NaN: {torch.isnan(scores).any()}")
    print(f"Scores has Inf: {torch.isinf(scores).any()}")

    # Step 2: After masking
    if mask is not None:
        scores_masked = scores.masked_fill(mask == 0, float('-inf'))
        all_masked_rows = (mask.sum(dim=-1) == 0).any()
        print(f"Any rows fully masked: {all_masked_rows}")
        if all_masked_rows:
            print("⚠️ ISSUE: Some rows have all positions masked!")
            print("   This causes softmax to produce NaN (exp(-inf)/exp(-inf))")
    else:
        scores_masked = scores

    # Step 3: After softmax
    weights = F.softmax(scores_masked, dim=-1)
    print(f"Weights has NaN: {torch.isnan(weights).any()}")

    # Step 4: Output
    output = torch.matmul(weights, V)
    print(f"Output has NaN: {torch.isnan(output).any()}")

    if torch.isnan(weights).any():
        print("\n⚠️ FIX: Add torch.nan_to_num(weights, nan=0.0) after softmax")


# Example: Trigger NaN with all-masked row
Q = torch.randn(1, 3, 8)
K = torch.randn(1, 5, 8)
V = torch.randn(1, 5, 8)
bad_mask = torch.tensor([[[1, 1, 0, 0, 0],
                          [0, 0, 0, 0, 0],  # All masked!
                          [1, 1, 1, 0, 0]]])

debug_nan_attention(Q, K, V, bad_mask)
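Once the diagnosis points at a fully masked row, the fix the debugger suggests can be folded into the attention computation itself. The helper name `safe_masked_attention` is just for this sketch; the guard is the same `nan_to_num` already used in `scaled_dot_product_attention_with_weights` from 6.1, applied here to the failing example:

```python
import torch
import torch.nn.functional as F
import math

def safe_masked_attention(Q, K, V, mask=None):
    """Attention that tolerates fully masked rows by zeroing their weights."""
    d_k = Q.size(-1)
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    weights = torch.nan_to_num(weights, nan=0.0)  # fully masked rows become all-zero
    return torch.matmul(weights, V), weights

Q = torch.randn(1, 3, 8)
K = torch.randn(1, 5, 8)
V = torch.randn(1, 5, 8)
bad_mask = torch.tensor([[[1, 1, 0, 0, 0],
                          [0, 0, 0, 0, 0],  # fully masked row
                          [1, 1, 1, 0, 0]]])

out, w = safe_masked_attention(Q, K, V, bad_mask)
print(torch.isnan(out).any())  # tensor(False)
```

The fully masked query now produces an all-zero output row rather than NaN; whether that silent behavior is acceptable (versus raising an error) depends on your application.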

6.6 Interactive Attention Explorer

Building an Interactive Tool

demo.py

def interactive_attention_demo():
    """
    Interactive demo showing how inputs affect attention.
    """
    print("=" * 60)
    print("INTERACTIVE ATTENTION DEMONSTRATION")
    print("=" * 60)

    # Create simple input
    tokens = ["I", "love", "cats", "and", "dogs"]
    seq_len = len(tokens)
    d_model = 4

    # Create embeddings that reflect semantic similarity
    embeddings = {
        "I":     torch.tensor([1.0, 0.0, 0.0, 0.0]),
        "love":  torch.tensor([0.0, 1.0, 0.0, 0.0]),
        "cats":  torch.tensor([0.0, 0.0, 1.0, 0.5]),  # Animal
        "and":   torch.tensor([0.0, 0.0, 0.0, 0.0]),
        "dogs":  torch.tensor([0.0, 0.0, 0.8, 0.6]),  # Animal (similar to cats)
    }

    X = torch.stack([embeddings[t] for t in tokens]).unsqueeze(0)
    Q = K = V = X

    print("\n1. INPUT EMBEDDINGS:")
    for i, token in enumerate(tokens):
        print(f"   {token:8s}: {X[0, i].tolist()}")

    # Compute attention
    output, weights = scaled_dot_product_attention_with_weights(Q, K, V)

    print("\n2. ATTENTION WEIGHTS (what each token attends to):")
    weights_np = weights[0].numpy()
    header = "         " + " ".join([f"{t:8s}" for t in tokens])
    print(header)
    for i, token in enumerate(tokens):
        row = f"{token:8s} " + " ".join([f"{weights_np[i, j]:8.3f}" for j in range(seq_len)])
        print(row)

    print("\n3. KEY OBSERVATIONS:")
    # Find interesting patterns
    for i, token in enumerate(tokens):
        max_attn_idx = weights_np[i].argmax()
        max_attn_val = weights_np[i, max_attn_idx]
        print(f"   '{token}' attends most to '{tokens[max_attn_idx]}' ({max_attn_val:.3f})")

    # Note similarity between cats and dogs
    cats_dogs_attn = weights_np[2, 4]  # cats attending to dogs
    dogs_cats_attn = weights_np[4, 2]  # dogs attending to cats
    print(f"\n   'cats' → 'dogs' attention: {cats_dogs_attn:.3f}")
    print(f"   'dogs' → 'cats' attention: {dogs_cats_attn:.3f}")
    print("   (High because they have similar embeddings - both are animals!)")

    return weights


# Run demo
weights = interactive_attention_demo()

6.7 Complete Visualization Module

attention_visualization.py

"""
attention_visualization.py
==========================
Tools for visualizing and debugging attention mechanisms.
"""

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from typing import Optional, List, Tuple, Dict
import math


def scaled_dot_product_attention_with_weights(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Attention that returns weights (included so this module is self-contained)."""
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    weights = torch.nan_to_num(weights, nan=0.0)
    return torch.matmul(weights, value), weights


class AttentionVisualizer:
    """Comprehensive attention visualization toolkit."""

    def __init__(self, figsize: Tuple[int, int] = (10, 8)):
        self.figsize = figsize
        self.cmap = 'Blues'

    def plot_heatmap(
        self,
        weights: torch.Tensor,
        x_labels: Optional[List[str]] = None,
        y_labels: Optional[List[str]] = None,
        title: str = "Attention",
        ax: Optional[plt.Axes] = None,
        show_values: bool = True
    ) -> plt.Axes:
        """Plot single attention heatmap."""
        if isinstance(weights, torch.Tensor):
            weights = weights.detach().cpu().numpy()

        if ax is None:
            _, ax = plt.subplots(figsize=self.figsize)

        im = ax.imshow(weights, cmap=self.cmap, vmin=0, vmax=1)

        seq_q, seq_k = weights.shape
        x_labels = x_labels or [str(i) for i in range(seq_k)]
        y_labels = y_labels or [str(i) for i in range(seq_q)]

        ax.set_xticks(range(seq_k))
        ax.set_yticks(range(seq_q))
        ax.set_xticklabels(x_labels, rotation=45, ha='right')
        ax.set_yticklabels(y_labels)
        ax.set_title(title)

        if show_values and seq_q <= 10 and seq_k <= 10:
            for i in range(seq_q):
                for j in range(seq_k):
                    color = "white" if weights[i, j] > 0.5 else "black"
                    ax.text(j, i, f"{weights[i, j]:.2f}",
                            ha="center", va="center", color=color, fontsize=8)

        return ax

    def plot_multi_head(
        self,
        weights: torch.Tensor,
        tokens: List[str],
        title: str = "Multi-Head Attention",
        save_path: Optional[str] = None
    ) -> None:
        """Plot all attention heads."""
        if isinstance(weights, torch.Tensor):
            weights = weights.detach().cpu().numpy()

        num_heads = weights.shape[0]
        cols = min(4, num_heads)
        rows = (num_heads + cols - 1) // cols

        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4))
        axes = np.array(axes).flatten()

        for i in range(num_heads):
            self.plot_heatmap(
                weights[i], tokens, tokens,
                f"Head {i+1}", axes[i], show_values=False
            )

        # Hide unused subplots
        for i in range(num_heads, len(axes)):
            axes[i].axis('off')

        fig.suptitle(title, fontsize=14, fontweight='bold')
        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=150, bbox_inches='tight')
        plt.close()

    def analyze(self, weights: torch.Tensor) -> Dict:
        """Analyze attention patterns."""
        if isinstance(weights, torch.Tensor):
            w = weights.detach().cpu().numpy()
        else:
            w = weights

        return {
            'diagonal_mean': np.diag(w).mean(),
            'first_col_mean': w[:, 0].mean(),
            'entropy': -np.sum(w * np.log(w + 1e-10), axis=-1).mean(),
            'max_weight_mean': w.max(axis=-1).mean(),
            'sparsity': (w < 0.1).mean()
        }


# Create global visualizer instance
visualizer = AttentionVisualizer()


def plot_attention(weights, tokens=None, title="Attention", save_path=None):
    """Convenience function for quick plotting."""
    if tokens is None:
        tokens = [str(i) for i in range(weights.shape[-1])]

    fig, ax = plt.subplots(figsize=(8, 6))
    visualizer.plot_heatmap(weights, tokens, tokens, title, ax)
    plt.colorbar(ax.images[0], ax=ax, label="Weight")
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150)
    plt.close()


if __name__ == "__main__":
    # Demo
    torch.manual_seed(42)

    tokens = ["The", "cat", "sat", "on", "mat"]
    Q = torch.randn(1, 5, 32)
    K = torch.randn(1, 5, 32)
    V = torch.randn(1, 5, 32)

    _, weights = scaled_dot_product_attention_with_weights(Q, K, V)

    # Single head visualization
    plot_attention(weights[0], tokens, "Self-Attention Demo", "demo_attention.png")
    print("Saved demo_attention.png")

    # Analysis
    analysis = visualizer.analyze(weights[0])
    print("\nAttention Analysis:")
    for key, value in analysis.items():
        print(f"  {key}: {value:.3f}")
Summary

Visualization Tools Built

Tool                        Purpose
plot_attention_heatmap      Single-head visualization
plot_multi_head_attention   All heads side by side
analyze_attention_pattern   Quantitative pattern analysis
AttentionVisualizer         Full-featured visualization class

Debugging Checklist

When attention isn't working correctly:

  • Check if attention is uniform (scores not scaled?)
  • Check if attention collapsed (learning rate too high?)
  • Check for NaN (all positions masked?)
  • Check input embeddings (properly initialized?)
  • Visualize attention patterns (do they make sense?)
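Several of these checks can be automated as assertions that run on every forward pass during debugging. A minimal sketch (the helper name `check_attention_weights` is mine):

```python
import torch

def check_attention_weights(weights: torch.Tensor, atol: float = 1e-5) -> None:
    """Assert basic invariants every attention weight matrix should satisfy."""
    assert not torch.isnan(weights).any(), "NaN in attention weights"
    assert not torch.isinf(weights).any(), "Inf in attention weights"
    assert (weights >= 0).all(), "negative attention weight"
    row_sums = weights.sum(dim=-1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=atol), \
        "rows do not sum to 1 (a fully masked row sums to 0)"

# Example: valid weights pass silently
w = torch.softmax(torch.randn(1, 5, 5), dim=-1)
check_attention_weights(w)
print("all checks passed")
```

Note the row-sum check intentionally fails on fully masked rows zeroed by `nan_to_num`; relax it if that behavior is expected in your model.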

Key Patterns to Recognize

Pattern     Appearance                  Meaning
Diagonal    High values on diagonal     Self-attention dominant
Column      One column highlighted      One position important
Previous    Diagonal shifted left       Sequential processing
Uniform     All values roughly equal    Not learning relationships
Sparse      Few high values             Focused attention

Exercises

Visualization Exercises

  1. Create an animated visualization that shows how attention weights change during training.
  2. Build a visualization that compares attention patterns before and after fine-tuning.
  3. Implement a "diff" visualization that highlights where two attention matrices differ.

Debugging Exercises

  1. Intentionally break attention in 5 different ways and use the debugging tools to identify each issue.
  2. Write unit tests that verify attention weights satisfy expected properties (sum to 1, no NaN, etc.).
  3. Create a logging callback that records attention statistics during training and plots trends.

Chapter Summary

In this chapter, you've mastered the attention mechanism:

  1. Intuition: Attention as weighted averaging based on similarity
  2. Mathematics: The scaled dot-product formula and why each component matters
  3. Numerical: Hand-computed examples for verification
  4. Implementation: Clean, tested PyTorch code
  5. Masking: Padding and causal masks for real-world use
  6. Visualization: Tools to see and debug attention

You now have a solid foundation for understanding transformers. In the next chapter, we'll extend single-head attention to multi-head attention, enabling the model to learn multiple types of relationships simultaneously.