Masking is essential for practical attention mechanisms. Without masks, attention would:

- Attend to padding tokens in variable-length sequences
- Allow the decoder to "cheat" by looking at future tokens

In this section, we'll see why masks are necessary, how they work mathematically, and how to implement both padding masks and causal (look-ahead) masks.
5.1 Why Masking is Necessary
Problem 1: Variable-Length Sequences
Real-world sequences have different lengths. To batch them efficiently, we pad shorter sequences:
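For instance, batching two sequences of token IDs with lengths 3 and 5 might look like this (a sketch; the token IDs and the PAD id 0 are illustrative):

```python
import torch

# Two sequences of different lengths, padded with PAD id 0 to a common length
batch = [
    [5, 3, 2],        # length 3
    [7, 4, 8, 6, 9],  # length 5
]
max_len = max(len(s) for s in batch)
padded = torch.tensor([s + [0] * (max_len - len(s)) for s in batch])
print(padded)
# Without a mask, attention at every position would also attend to the 0s.
```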
Each row i of the causal mask allows attention only to positions 0 through i: row 0 attends to position 0 alone, row 1 to positions 0 and 1, row 2 (position 2) to positions 0, 1, 2, and so on.
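The `create_causal_mask` referenced in the code below isn't shown in this excerpt; a minimal sketch consistent with the row-by-row pattern above, assuming the chapter's convention that True = attend:

```python
import torch


def create_causal_mask(seq_len: int) -> torch.Tensor:
    """Boolean mask where entry [i, j] is True iff position i may attend to j <= i."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))


print(create_causal_mask(3).int())
# Row i has ones in columns 0..i and zeros after
```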
Alternative: Using torch.triu
```python
# masks.py

def create_causal_mask_v2(seq_len: int) -> torch.Tensor:
    """Alternative implementation using triu (upper triangular)."""
    # Create matrix of ones, then mask upper triangle (excluding diagonal)
    mask = torch.ones(seq_len, seq_len)
    mask = torch.triu(mask, diagonal=1)  # Upper triangle, offset by 1
    mask = (mask == 0)  # Invert: True where we CAN attend
    return mask


# Verify both methods produce same result
assert torch.equal(create_causal_mask(5), create_causal_mask_v2(5))
print("Both methods produce identical masks ✓")
```
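To see how such a mask is actually applied to attention scores, here is a short sketch (the scores are random and the mask is built inline with `torch.tril`; True = attend, and disallowed positions are filled with -inf so the softmax assigns them zero weight):

```python
import torch
import torch.nn.functional as F

seq_len = 4
scores = torch.randn(seq_len, seq_len)  # raw query-key scores

# True = attend; fill disallowed positions with -inf before the softmax
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
masked_scores = scores.masked_fill(~mask, float('-inf'))
weights = F.softmax(masked_scores, dim=-1)

print(weights)
# Each row sums to 1; entries above the diagonal are exactly 0
```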
Visualizing Causal Attention
```python
# visualize.py
import matplotlib.pyplot as plt


def visualize_causal_mask(seq_len: int):
    """Visualize what each position can attend to."""
    mask = create_causal_mask(seq_len)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(mask.float(), cmap='Blues')

    ax.set_xlabel('Key Position (attending TO)')
    ax.set_ylabel('Query Position (attending FROM)')
    ax.set_title('Causal Mask: White = Cannot Attend')

    # Add grid
    ax.set_xticks(range(seq_len))
    ax.set_yticks(range(seq_len))
    ax.grid(True, linewidth=0.5)

    plt.tight_layout()
    plt.savefig('causal_mask.png', dpi=150)
    plt.close()
    print("Saved causal_mask.png")


visualize_causal_mask(8)
```
5.5 Combining Masks
Combining Padding and Causal Masks
For a decoder, we need both:
- Don't attend to padding
- Don't attend to future positions
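The combined mask below builds on a padding-mask helper whose definition isn't shown in this excerpt. A minimal sketch of `create_padding_mask`, assuming the True = attend convention and the `[batch, 1, seq_len]` shape used in this chapter:

```python
import torch


def create_padding_mask(seq: torch.Tensor, pad_token_id: int = 0) -> torch.Tensor:
    """Boolean mask of shape [batch, 1, seq_len]: True where the token is NOT padding."""
    return (seq != pad_token_id).unsqueeze(1)


seq = torch.tensor([[5, 3, 2, 0, 0]])  # 3 real tokens + 2 padding
print(create_padding_mask(seq).int())  # [[[1, 1, 1, 0, 0]]]
```

The middle dimension of size 1 lets the mask broadcast across all query positions when applied to `[batch, seq_q, seq_k]` score tensors.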
```python
def create_combined_mask(
    seq: torch.Tensor,
    pad_token_id: int = 0
) -> torch.Tensor:
    """
    Create a combined padding + causal mask for decoder self-attention.

    Args:
        seq: Token IDs of shape [batch, seq_len]
        pad_token_id: Padding token ID

    Returns:
        mask: [batch, seq_len, seq_len]
              True = attend, False = ignore
    """
    batch_size, seq_len = seq.shape

    # Padding mask: [batch, 1, seq_len]
    padding_mask = create_padding_mask(seq, pad_token_id)

    # Causal mask: [seq_len, seq_len] -> [1, seq_len, seq_len]
    causal_mask = create_causal_mask(seq_len).unsqueeze(0)

    # Combine: Both conditions must be True to attend
    # [batch, 1, seq_len] AND [1, seq_len, seq_len] -> [batch, seq_len, seq_len]
    combined_mask = padding_mask & causal_mask

    return combined_mask


# Example
seq = torch.tensor([
    [5, 3, 2, 0, 0],  # 3 real tokens + 2 padding
    [7, 4, 8, 6, 0],  # 4 real tokens + 1 padding
])

combined = create_combined_mask(seq)
print("Combined mask for sequence 0:")
print(combined[0].int())
```
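Worked out self-contained for the first sequence, the AND of the two masks gives a lower-triangular pattern with the padding columns zeroed everywhere (a sketch reproducing the logic above inline):

```python
import torch

seq = torch.tensor([[5, 3, 2, 0, 0]])  # 3 real tokens + 2 padding
seq_len = seq.shape[1]

padding = (seq != 0).unsqueeze(1)                                                 # [1, 1, 5]
causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).unsqueeze(0)  # [1, 5, 5]
combined = padding & causal                                                       # [1, 5, 5]

print(combined[0].int())
# Rows 0-2 follow the causal triangle up to column 2;
# columns 3-4 are zero in every row because they are padding.
```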
Mistake 3: Inverted Mask Polarity

```python
# ❌ WRONG: mask is built with 1 = padding, but applied as if 1 = attend
mask = (seq == pad_token_id)         # True where padding
scores.masked_fill(mask == 0, -inf)  # Masks the real tokens, keeps padding!

# ✅ CORRECT: 1 = attend, 0 = masked
mask = (seq != pad_token_id)         # True where NOT padding
scores.masked_fill(~mask, -inf)      # or equivalently:
scores.masked_fill(mask == 0, -inf)
```
Mistake 4: Forgetting to Handle NaN
```python
import torch
import torch.nn.functional as F

inf = float('inf')

# When ALL positions are masked, softmax produces NaN
scores = torch.tensor([[-inf, -inf, -inf]])
weights = F.softmax(scores, dim=-1)  # [nan, nan, nan]!

# ✅ Handle with nan_to_num
weights = torch.nan_to_num(weights, nan=0.0)  # [0, 0, 0]
```
Summary
Mask Types Reference
| Mask | Purpose | Shape | Used In |
|---|---|---|---|
| Padding | Ignore PAD tokens | `[batch, 1, seq_k]` | All attention |
| Causal | Block future | `[seq_q, seq_k]` | Decoder self-attn |
| Combined | Both | `[batch, seq_q, seq_k]` | Decoder self-attn |
| Cross | Source padding | `[batch, 1, src_len]` | Cross-attention |
Key Implementation Points
- Use -∞ for masking: `masked_fill(mask == 0, float('-inf'))`
- Broadcasting: add dimensions with `.unsqueeze()` for proper broadcasting
- Polarity: True/1 = attend, False/0 = ignore
- NaN handling: use `torch.nan_to_num()` for all-masked cases
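The four points combine in one short runnable sketch (the scores and mask values are illustrative; batch item 1 is deliberately all-padding to trigger the NaN case):

```python
import torch
import torch.nn.functional as F

scores = torch.randn(2, 4, 4)  # [batch, seq_q, seq_k]
mask = torch.tensor([[1, 1, 0, 0],
                     [0, 0, 0, 0]], dtype=torch.bool)  # item 1 is fully masked

# Broadcasting: [batch, seq_k] -> [batch, 1, seq_k]
mask = mask.unsqueeze(1)

# Polarity: True = attend; use -inf where we must NOT attend
masked = scores.masked_fill(mask == 0, float('-inf'))

# NaN handling: an all-masked row softmaxes to NaN; zero it out
weights = torch.nan_to_num(F.softmax(masked, dim=-1), nan=0.0)
print(weights[1])  # all zeros: every position in batch item 1 was masked
```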
Exercises
Implementation Exercises
1. Create a "prefix" mask that allows attending only to the first N tokens.
2. Implement a "window" mask that allows each position to attend only to positions within ±k.
3. Create a "block diagonal" mask for efficient attention on very long sequences.
Debugging Exercises
1. Given attention weights that are all uniform (1/n everywhere), what might be wrong with the masking?
2. If you see NaN in your outputs, trace through the forward pass to find where they originate.
3. Write a test that verifies masked positions receive exactly zero attention weight.
Next Section Preview
In the next section, we'll build visualization tools for attention. Being able to see what your attention mechanism is doing is invaluable for debugging and understanding model behavior.