Chapter 2

Understanding and Implementing Masks

Attention Mechanism From Scratch

Introduction

Masking is essential for practical attention mechanisms. Without masks, attention would:

  1. Attend to padding tokens in variable-length sequences
  2. Allow the decoder to "cheat" by looking at future tokens

In this section, we'll see why masks are necessary, how they work mathematically, and how to implement both padding masks and causal (look-ahead) masks.


5.1 Why Masking is Necessary

Problem 1: Variable-Length Sequences

Real-world sequences have different lengths. To batch them efficiently, we pad shorter sequences:

Original sequences:
  "cat sat"       → [cat, sat]           (length 2)
  "dog ran fast"  → [dog, ran, fast]     (length 3)
  "bird"          → [bird]               (length 1)

After padding (max_len=3):
  [cat, sat, PAD]
  [dog, ran, fast]
  [bird, PAD, PAD]

The Problem: Without masking, tokens would attend to PAD tokens, incorporating meaningless information.
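As a concrete sketch of this batching step (using PyTorch's `pad_sequence` utility; the token IDs here are made up for illustration):

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical token IDs for the three sentences above
seqs = [
    torch.tensor([5, 3]),     # "cat sat"
    torch.tensor([7, 4, 8]),  # "dog ran fast"
    torch.tensor([1]),        # "bird"
]

# Right-pad every sequence with PAD id 0 up to the longest length (3)
batch = pad_sequence(seqs, batch_first=True, padding_value=0)
print(batch)
# tensor([[5, 3, 0],
#         [7, 4, 8],
#         [1, 0, 0]])
```

The padding masks built later in this section recover exactly which of these positions are real tokens.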

Problem 2: Autoregressive Generation

In language models and decoders, we generate tokens one at a time. During training, we must prevent the model from seeing future tokens:

Target: "The cat sat"

Position 0: Predicting "The"  → Can see: nothing
Position 1: Predicting "cat"  → Can see: "The"
Position 2: Predicting "sat"  → Can see: "The cat"

The Problem: Without causal masking, the model could cheat by looking at what it's supposed to predict.


5.2 How Masking Works

The Mechanism

Masks work by setting attention scores to -∞ before softmax:

softmax([1.0, 2.0, -∞]) = [0.27, 0.73, 0.00]
                                        ↑
                    Masked position gets zero weight

Since exp(-∞) = 0, masked positions contribute nothing to the weighted sum.
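You can verify this numerically with a quick one-off check:

```python
import torch
import torch.nn.functional as F

scores = torch.tensor([1.0, 2.0, float('-inf')])
weights = F.softmax(scores, dim=-1)
print(weights)  # tensor([0.2689, 0.7311, 0.0000])
```

The masked position gets exactly zero weight, and the remaining weights still sum to 1.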

Mask Types

Mask type      Purpose                    Shape
-------------  -------------------------  -------------------------
Padding mask   Ignore PAD tokens          [batch, 1, seq_len]
Causal mask    Prevent future attention   [seq_len, seq_len]
Combined       Both effects               [batch, seq_len, seq_len]

5.3 Padding Masks

Creating a Padding Mask

masks.py

import torch
import torch.nn.functional as F

def create_padding_mask(
    seq: torch.Tensor,
    pad_token_id: int = 0
) -> torch.Tensor:
    """
    Create a padding mask for attention.

    Args:
        seq: Token IDs of shape [batch, seq_len]
        pad_token_id: The ID used for padding (default: 0)

    Returns:
        mask: Boolean tensor of shape [batch, 1, seq_len]
              True = attend, False = ignore

    Example:
        >>> seq = torch.tensor([[1, 2, 3, 0, 0],
        ...                     [4, 5, 0, 0, 0]])
        >>> mask = create_padding_mask(seq, pad_token_id=0)
        >>> mask
        tensor([[[ True,  True,  True, False, False]],
                [[ True,  True, False, False, False]]])
    """
    # [batch, seq_len] -> [batch, 1, seq_len]
    # The extra dimension allows broadcasting with attention scores
    mask = (seq != pad_token_id).unsqueeze(1)
    return mask


# Example
sequences = torch.tensor([
    [5, 3, 2, 0, 0],  # 3 real tokens, 2 padding
    [7, 4, 8, 6, 0],  # 4 real tokens, 1 padding
    [1, 0, 0, 0, 0],  # 1 real token, 4 padding
])

padding_mask = create_padding_mask(sequences, pad_token_id=0)
print("Sequences:")
print(sequences)
print("\nPadding mask:")
print(padding_mask)
print(f"\nShape: {padding_mask.shape}")

Output:

Sequences:
tensor([[5, 3, 2, 0, 0],
        [7, 4, 8, 6, 0],
        [1, 0, 0, 0, 0]])

Padding mask:
tensor([[[ True,  True,  True, False, False]],
        [[ True,  True,  True,  True, False]],
        [[ True, False, False, False, False]]])

Shape: torch.Size([3, 1, 5])

Applying Padding Mask

attention.py

def attention_with_padding_mask(Q, K, V, padding_mask):
    """
    Attention that ignores padding tokens in keys/values.

    Args:
        Q: [batch, seq_len_q, d_k]
        K: [batch, seq_len_k, d_k]
        V: [batch, seq_len_k, d_v]
        padding_mask: [batch, 1, seq_len_k]

    Returns:
        output: [batch, seq_len_q, d_v]
    """
    d_k = Q.size(-1)

    # Compute scores: [batch, seq_len_q, seq_len_k]
    scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)

    # Apply mask: [batch, 1, seq_len_k] broadcasts to [batch, seq_len_q, seq_len_k]
    scores = scores.masked_fill(~padding_mask, float('-inf'))

    # Softmax and output
    weights = F.softmax(scores, dim=-1)
    weights = torch.nan_to_num(weights, nan=0.0)
    output = torch.matmul(weights, V)

    return output, weights


# Example
batch, seq_len, d_model = 2, 5, 4
Q = torch.randn(batch, seq_len, d_model)
K = torch.randn(batch, seq_len, d_model)
V = torch.randn(batch, seq_len, d_model)

# Sequences with different padding
seq = torch.tensor([
    [1, 2, 3, 0, 0],  # 3 valid tokens
    [1, 2, 3, 4, 0],  # 4 valid tokens
])
padding_mask = create_padding_mask(seq)

output, weights = attention_with_padding_mask(Q, K, V, padding_mask)

print("Attention weights (batch 0, first query):")
print(weights[0, 0])  # Should have 0 weights on positions 3, 4

5.4 Causal (Look-Ahead) Masks

Creating a Causal Mask

masks.py

def create_causal_mask(seq_len: int) -> torch.Tensor:
    """
    Create a causal (look-ahead) mask for autoregressive models.

    Position i can only attend to positions j where j <= i.

    Args:
        seq_len: Length of the sequence

    Returns:
        mask: Boolean tensor of shape [seq_len, seq_len]
              True = attend, False = ignore

    Example:
        >>> mask = create_causal_mask(4)
        >>> mask
        tensor([[ True, False, False, False],
                [ True,  True, False, False],
                [ True,  True,  True, False],
                [ True,  True,  True,  True]])
    """
    # Create lower triangular matrix
    # torch.tril keeps lower triangle (including diagonal)
    mask = torch.tril(torch.ones(seq_len, seq_len)).bool()
    return mask


# Example
causal_mask = create_causal_mask(5)
print("Causal mask (5x5):")
print(causal_mask.int())  # Print as integers for clarity

Output:

Causal mask (5x5):
tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]])

Interpretation:

  • Row 0 (position 0): Can only attend to position 0
  • Row 1 (position 1): Can attend to positions 0, 1
  • Row 2 (position 2): Can attend to positions 0, 1, 2
  • etc.

Alternative: Using torch.triu

masks.py

def create_causal_mask_v2(seq_len: int) -> torch.Tensor:
    """Alternative implementation using triu (upper triangular)."""
    # Create matrix of ones, then mask upper triangle (excluding diagonal)
    mask = torch.ones(seq_len, seq_len)
    mask = torch.triu(mask, diagonal=1)  # Upper triangle, offset by 1
    mask = (mask == 0)  # Invert: True where we CAN attend
    return mask

# Verify both methods produce same result
assert torch.equal(create_causal_mask(5), create_causal_mask_v2(5))
print("Both methods produce identical masks ✓")

Visualizing Causal Attention

visualize.py

import matplotlib.pyplot as plt

def visualize_causal_mask(seq_len: int):
    """Visualize what each position can attend to."""
    mask = create_causal_mask(seq_len)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(mask.float(), cmap='Blues')

    ax.set_xlabel('Key Position (attending TO)')
    ax.set_ylabel('Query Position (attending FROM)')
    ax.set_title('Causal Mask: White = Cannot Attend')

    # Add grid
    ax.set_xticks(range(seq_len))
    ax.set_yticks(range(seq_len))
    ax.grid(True, linewidth=0.5)

    plt.tight_layout()
    plt.savefig('causal_mask.png', dpi=150)
    plt.close()
    print("Saved causal_mask.png")

visualize_causal_mask(8)

5.5 Combining Masks

Combining Padding and Causal Masks

For a decoder, we need both:

  1. Don't attend to padding
  2. Don't attend to future positions

masks.py

def create_combined_mask(
    seq: torch.Tensor,
    pad_token_id: int = 0
) -> torch.Tensor:
    """
    Create a combined padding + causal mask for decoder self-attention.

    Args:
        seq: Token IDs of shape [batch, seq_len]
        pad_token_id: Padding token ID

    Returns:
        mask: [batch, seq_len, seq_len]
              True = attend, False = ignore
    """
    batch_size, seq_len = seq.shape

    # Padding mask: [batch, 1, seq_len]
    padding_mask = create_padding_mask(seq, pad_token_id)

    # Causal mask: [seq_len, seq_len] -> [1, seq_len, seq_len]
    causal_mask = create_causal_mask(seq_len).unsqueeze(0)

    # Combine: position (i, j) is True only if
    #   (a) key token j is not padding:   [batch, 1, seq_len]
    #   (b) query token i is not padding: [batch, seq_len, 1]
    #   (c) j <= i (causal):              [1, seq_len, seq_len]
    combined_mask = padding_mask & padding_mask.transpose(1, 2) & causal_mask

    return combined_mask


# Example
seq = torch.tensor([
    [5, 3, 2, 0, 0],  # 3 real tokens + 2 padding
    [7, 4, 8, 6, 0],  # 4 real tokens + 1 padding
])

combined = create_combined_mask(seq)
print("Combined mask for sequence 0:")
print(combined[0].int())

Output:

Combined mask for sequence 0:
tensor([[1, 0, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]])

Notice:

  • Rows 3, 4 are all zeros (padding positions don't attend to anything)
  • Columns 3, 4 are all zeros (nothing attends to padding)
  • Upper triangle is masked (causal constraint)

5.6 Source-Target Masks for Encoder-Decoder

Cross-Attention Mask

In encoder-decoder attention:

  • Query comes from decoder (may have padding)
  • Key/Value come from encoder (may have padding)
  • No causal constraint (decoder can see all encoder positions)
masks.py

def create_cross_attention_mask(
    src_seq: torch.Tensor,
    tgt_seq: torch.Tensor,
    pad_token_id: int = 0
) -> torch.Tensor:
    """
    Create mask for encoder-decoder cross-attention.

    Args:
        src_seq: Source (encoder) token IDs [batch, src_len]
        tgt_seq: Target (decoder) token IDs [batch, tgt_len]
        pad_token_id: Padding token ID

    Returns:
        mask: [batch, tgt_len, src_len]
              Decoder positions can attend to non-padding encoder positions
    """
    # We only need to mask padding in the source (encoder) sequence
    # [batch, 1, src_len]
    src_mask = create_padding_mask(src_seq, pad_token_id)

    # Broadcast to [batch, tgt_len, src_len]
    # All target positions see the same source mask
    batch_size, tgt_len = tgt_seq.shape
    cross_mask = src_mask.expand(batch_size, tgt_len, -1)

    return cross_mask


# Example
src = torch.tensor([
    [1, 2, 3, 0, 0],  # German: "Der Hund ist [PAD] [PAD]"
    [4, 5, 6, 7, 0],  # German: "Die Katze sitzt hier [PAD]"
])

tgt = torch.tensor([
    [10, 11, 12, 0],  # English: "The dog is [PAD]"
    [13, 14, 15, 16], # English: "The cat sits here"
])

cross_mask = create_cross_attention_mask(src, tgt)
print("Cross-attention mask for batch 0:")
print(cross_mask[0].int())

Output:

Cross-attention mask for batch 0:
tensor([[1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0],
        [1, 1, 1, 0, 0]])

All decoder positions can attend to encoder positions 0, 1, 2 (non-padding).


5.7 Complete Mask Utilities Module

masks.py

"""
masks.py
========
Mask utilities for Transformer attention mechanisms.
"""

import torch
from typing import Optional


def create_padding_mask(
    seq: torch.Tensor,
    pad_token_id: int = 0
) -> torch.Tensor:
    """
    Create padding mask: [batch, 1, seq_len]
    True = attend, False = ignore (padding)
    """
    return (seq != pad_token_id).unsqueeze(1)


def create_causal_mask(
    seq_len: int,
    device: Optional[torch.device] = None
) -> torch.Tensor:
    """
    Create causal (look-ahead) mask: [seq_len, seq_len]
    True = attend, False = ignore (future)
    """
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).bool()
    return mask


def create_decoder_mask(
    seq: torch.Tensor,
    pad_token_id: int = 0
) -> torch.Tensor:
    """
    Create combined causal + padding mask for decoder self-attention.
    Returns: [batch, seq_len, seq_len]
    """
    batch_size, seq_len = seq.shape
    device = seq.device

    padding_mask = create_padding_mask(seq, pad_token_id)  # [batch, 1, seq_len]
    causal_mask = create_causal_mask(seq_len, device)      # [seq_len, seq_len]

    # Combine key-side padding, query-side padding, and causality:
    # [batch, 1, seq_len] & [batch, seq_len, 1] & [1, seq_len, seq_len]
    combined = padding_mask & padding_mask.transpose(1, 2) & causal_mask.unsqueeze(0)
    return combined


def create_encoder_mask(
    seq: torch.Tensor,
    pad_token_id: int = 0
) -> torch.Tensor:
    """
    Create padding mask for encoder self-attention.
    Returns: [batch, 1, seq_len]
    """
    return create_padding_mask(seq, pad_token_id)


def create_cross_attention_mask(
    src_seq: torch.Tensor,
    tgt_len: int,
    pad_token_id: int = 0
) -> torch.Tensor:
    """
    Create mask for encoder-decoder cross-attention.
    Returns: [batch, 1, src_len] (broadcasts to [batch, tgt_len, src_len])
    """
    # tgt_len is unused here because broadcasting supplies the target
    # dimension; it is kept in the signature for API symmetry.
    return create_padding_mask(src_seq, pad_token_id)


# Comprehensive test
def test_masks():
    """Test all mask creation functions."""
    # Test sequences with padding
    seq = torch.tensor([
        [1, 2, 3, 0, 0],
        [4, 5, 6, 7, 0]
    ])

    # Padding mask
    pad_mask = create_padding_mask(seq)
    assert pad_mask.shape == (2, 1, 5)
    assert pad_mask[0, 0, :3].all()  # First 3 are True
    assert not pad_mask[0, 0, 3:].any()  # Last 2 are False
    print("✓ Padding mask test passed")

    # Causal mask
    causal = create_causal_mask(5)
    assert causal.shape == (5, 5)
    assert causal[0, 0] and not causal[0, 1]  # Position 0 sees only itself
    assert causal[4, :].all()  # Position 4 sees all
    print("✓ Causal mask test passed")

    # Decoder mask (combined)
    dec_mask = create_decoder_mask(seq)
    assert dec_mask.shape == (2, 5, 5)
    # Padding rows should be all False
    assert not dec_mask[0, 3, :].any()
    assert not dec_mask[0, 4, :].any()
    print("✓ Decoder mask test passed")

    # Cross-attention mask
    src = torch.tensor([[1, 2, 0, 0]])
    cross_mask = create_cross_attention_mask(src, tgt_len=3)
    assert cross_mask.shape == (1, 1, 4)
    print("✓ Cross-attention mask test passed")

    print("\n✓ All mask tests passed!")


if __name__ == "__main__":
    test_masks()

5.8 Common Masking Mistakes

Mistake 1: Wrong Mask Value

# ❌ WRONG: Using 0 to mask
scores = scores.masked_fill(mask == 0, 0)  # Doesn't work!
# softmax([1, 2, 0]) = [0.24, 0.67, 0.09]  # Still contributes!

# ✅ CORRECT: Using -inf to mask
scores = scores.masked_fill(mask == 0, float('-inf'))
# softmax([1, 2, -inf]) = [0.27, 0.73, 0.00]  # Properly masked

Mistake 2: Wrong Broadcasting

# ❌ WRONG: Mask doesn't broadcast correctly
mask = torch.ones(batch, seq_len).bool()       # [batch, seq_len]
scores = torch.randn(batch, seq_len, seq_len)  # [batch, seq_q, seq_k]
# scores.masked_fill(~mask, float('-inf'))  # Aligns [batch, seq_len] against
#                                           # the wrong dimensions!

# ✅ CORRECT: Add dimension for broadcasting
mask = mask.unsqueeze(1)  # [batch, 1, seq_len]
# Now broadcasts: [batch, 1, seq_len] -> [batch, seq_q, seq_len]

Mistake 3: Mask Polarity

# ❌ WRONG: mask built with inverted polarity (True = padding)
mask = (seq == pad_token_id)                # True where padding
scores.masked_fill(~mask, float('-inf'))    # Masks the real tokens!

# ✅ CORRECT: True = attend, False = masked
mask = (seq != pad_token_id)                # True where NOT padding
scores.masked_fill(~mask, float('-inf'))
# Or equivalently:
scores.masked_fill(mask == 0, float('-inf'))

Mistake 4: Forgetting to Handle NaN

# When ALL positions are masked, softmax produces NaN
neg_inf = float('-inf')
scores = torch.tensor([[neg_inf, neg_inf, neg_inf]])
weights = F.softmax(scores, dim=-1)  # [nan, nan, nan]!

# ✅ Handle with nan_to_num
weights = torch.nan_to_num(weights, nan=0.0)  # [0., 0., 0.]

Summary

Mask Types Reference

Mask      Purpose             Shape                   Used In
--------  ------------------  ----------------------  -----------------
Padding   Ignore PAD tokens   [batch, 1, seq_k]       All attention
Causal    Block future        [seq_q, seq_k]          Decoder self-attn
Combined  Both                [batch, seq_q, seq_k]   Decoder self-attn
Cross     Source padding      [batch, 1, src_len]     Cross-attention

Key Implementation Points

  1. Use -∞ for masking: masked_fill(mask == 0, float('-inf'))
  2. Broadcasting: Add dimensions with .unsqueeze() for proper broadcasting
  3. Polarity: True/1 = attend, False/0 = ignore
  4. NaN handling: Use torch.nan_to_num() for all-masked cases
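For reference, recent PyTorch releases (2.x) bundle all four points into the fused `F.scaled_dot_product_attention`, whose boolean `attn_mask` uses the same polarity as this chapter (True = attend). A minimal sketch, assuming PyTorch ≥ 2.0:

```python
import torch
import torch.nn.functional as F

batch, seq_len, d = 2, 5, 8
Q, K, V = (torch.randn(batch, seq_len, d) for _ in range(3))

# Explicit boolean causal mask, True = attend (same convention as above);
# it broadcasts against the [batch, seq_q, seq_k] attention weights
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
out_masked = F.scaled_dot_product_attention(Q, K, V, attn_mask=causal)

# Built-in shortcut: let PyTorch construct the causal mask itself
out_causal = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(out_masked, out_causal, atol=1e-5))
```

Hand-rolling masks as in this section is still worth doing once: it makes the polarity and broadcasting conventions explicit before you rely on the fused kernel.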

Exercises

Implementation Exercises

  1. Create a "prefix" mask that allows attending only to the first N tokens.
  2. Implement a "window" mask that allows each position to attend only to positions within ±k.
  3. Create a "block diagonal" mask for efficient attention on very long sequences.

Debugging Exercises

  1. Given attention weights that are all uniform (1/n everywhere), what might be wrong with the masking?
  2. If you see NaN in your outputs, trace through the forward pass to find where they originate.
  3. Write a test that verifies masked positions receive exactly zero attention weight.

Next Section Preview

In the next section, we'll build visualization tools for attention. Being able to see what your attention mechanism is doing is invaluable for debugging and understanding model behavior.