Introduction
Causal masking is the key mechanism that enables autoregressive generation. It prevents the decoder from "cheating" by looking at future tokens during training, ensuring the model learns to predict each token based only on previous context.
This section covers the implementation and usage of causal masks.
The Problem: Future Information Leakage
During Training
We feed the decoder the complete target sequence:
```
Target: "<bos> The dog runs <eos>"

Without masking, at position 2 ("dog"):
  Model sees: "<bos>", "The", "dog", "runs", "<eos>"
  Task: Predict "runs" (next token)

  Problem: "runs" is already visible in the input!
  Model can just copy instead of learning to predict.
```
The Solution: Causal Mask
```
With causal mask, at position 2 ("dog"):
  Model sees: "<bos>", "The", "dog", [MASKED], [MASKED]
  Task: Predict "runs"

  Now model must actually learn to predict!
```
Understanding the Causal Mask
Mask Structure
A causal (look-ahead) mask is a lower triangular matrix:
```
Sequence: [pos0, pos1, pos2, pos3]

Attention Mask (1 = attend, 0 = ignore):

       pos0 pos1 pos2 pos3
pos0 [   1    0    0    0  ]   ← pos0 can only see pos0
pos1 [   1    1    0    0  ]   ← pos1 can see pos0, pos1
pos2 [   1    1    1    0  ]   ← pos2 can see pos0, pos1, pos2
pos3 [   1    1    1    1  ]   ← pos3 can see all
```
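This structure can be produced in one line with `torch.tril`, which is exactly what the implementation later in this section builds on:

```python
import torch

# The 4x4 structure above is just a lower triangular matrix of ones
print(torch.tril(torch.ones(4, 4)))
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```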
Visual Representation
```
Target: "<bos>  The  dog  runs"
           0     1    2     3

Query position 0 ("<bos>"):
  Attends to: [<bos>]

Query position 1 ("The"):
  Attends to: [<bos>, The]

Query position 2 ("dog"):
  Attends to: [<bos>, The, dog]

Query position 3 ("runs"):
  Attends to: [<bos>, The, dog, runs]
```
Implementation
Basic Causal Mask
```python
import torch
import torch.nn as nn


def create_causal_mask(seq_len: int, device: torch.device = None) -> torch.Tensor:
    """
    Create a causal (look-ahead) mask for self-attention.

    Args:
        seq_len: Length of the sequence
        device: Device to create the tensor on

    Returns:
        mask: [1, 1, seq_len, seq_len] where:
            1 = attend (valid)
            0 = ignore (masked)

    Example:
        >>> mask = create_causal_mask(4)
        >>> mask.squeeze()
        tensor([[1., 0., 0., 0.],
                [1., 1., 0., 0.],
                [1., 1., 1., 0.],
                [1., 1., 1., 1.]])
    """
    # Create lower triangular matrix (including diagonal)
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))

    # Reshape for attention: [1, 1, seq_len, seq_len]
    # This allows broadcasting across batch and heads
    mask = mask.unsqueeze(0).unsqueeze(0)

    return mask


# Test
def test_causal_mask():
    mask = create_causal_mask(5)

    print("Causal Mask (seq_len=5):")
    print(mask.squeeze())

    # Verify properties
    assert mask.shape == (1, 1, 5, 5)
    assert mask[0, 0, 0, 1] == 0  # pos0 cannot see pos1
    assert mask[0, 0, 1, 0] == 1  # pos1 can see pos0
    assert mask[0, 0, 4, 4] == 1  # Last pos can see itself

    print("\n✓ Causal mask test passed!")


test_causal_mask()
```
Alternative: Using torch.triu
```python
def create_causal_mask_v2(seq_len: int, device: torch.device = None) -> torch.Tensor:
    """
    Create a causal mask using the upper triangular approach.

    Marks the upper triangle (excluding the diagonal) as blocked, then
    inverts it into the 1 = attend / 0 = ignore convention used above.
    """
    # Create mask where the upper triangle is 1 (positions to be masked)
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)

    # Convert to attention mask format (0 for masked positions)
    mask = mask == 0  # Invert: True where we CAN attend

    return mask.unsqueeze(0).unsqueeze(0).float()


# Both produce the same result
mask1 = create_causal_mask(4)
mask2 = create_causal_mask_v2(4)
print(f"Masks equal: {torch.equal(mask1, mask2)}")
```
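For reference, PyTorch also ships a helper that builds the same triangular structure, but with an *additive* convention (0.0 where attention is allowed, -inf where it is blocked) rather than the 1/0 multiplicative convention used in this section. A minimal sketch, assuming a recent PyTorch release where the helper is exposed as a static method (older releases require a `nn.Transformer` instance):

```python
import torch
import torch.nn as nn

# Built-in helper: additive mask, 0.0 on/below the diagonal, -inf above it
builtin = nn.Transformer.generate_square_subsequent_mask(4)

# Converting to the 1 = attend / 0 = ignore convention used here
as_binary = (builtin == 0).float().unsqueeze(0).unsqueeze(0)
print(torch.equal(as_binary, create_causal_mask(4)))  # True
```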
Combining Causal and Padding Masks
The Combined Mask
In real batches, we need both:
1. Causal mask: Don't look at future tokens
2. Padding mask: Don't attend to PAD tokens
```python
def create_decoder_mask(
    target_ids: torch.Tensor,
    pad_id: int = 0
) -> torch.Tensor:
    """
    Create combined causal + padding mask for decoder self-attention.

    Args:
        target_ids: [batch, tgt_len] target token IDs
        pad_id: Padding token ID

    Returns:
        mask: [batch, 1, tgt_len, tgt_len]
            1 = attend, 0 = ignore

    Example:
        >>> tgt = torch.tensor([[1, 2, 3, 0, 0], [1, 2, 0, 0, 0]])
        >>> mask = create_decoder_mask(tgt)
        # Combines causal (lower triangular) with padding (ignore PAD columns)
    """
    batch_size, tgt_len = target_ids.shape
    device = target_ids.device

    # 1. Causal mask: [1, 1, tgt_len, tgt_len]
    causal_mask = create_causal_mask(tgt_len, device)

    # 2. Padding mask: [batch, 1, 1, tgt_len]
    #    True where NOT padding (valid positions)
    padding_mask = (target_ids != pad_id).unsqueeze(1).unsqueeze(2).float()

    # 3. Combine: both conditions must be true to attend
    combined_mask = causal_mask * padding_mask

    return combined_mask


# Test combined mask
def test_combined_mask():
    # Batch with different padding
    target_ids = torch.tensor([
        [2, 10, 20, 30, 3, 0, 0],  # 5 real tokens + 2 padding
        [2, 10, 20, 3, 0, 0, 0],   # 4 real tokens + 3 padding
    ])

    mask = create_decoder_mask(target_ids, pad_id=0)

    print("Combined Decoder Mask")
    print("=" * 50)
    print(f"Target shape: {target_ids.shape}")
    print(f"Mask shape: {mask.shape}")

    print("\nSentence 1 mask (5 real tokens):")
    print(mask[0, 0])

    print("\nSentence 2 mask (4 real tokens):")
    print(mask[1, 0])

    # Verify: padding columns should be all zeros
    assert mask[0, 0, :, 5:].sum() == 0, "Padding should be masked"
    assert mask[1, 0, :, 4:].sum() == 0, "Padding should be masked"

    print("\n✓ Combined mask test passed!")


test_combined_mask()
```
Output:
```
Combined Decoder Mask
==================================================
Target shape: torch.Size([2, 7])
Mask shape: torch.Size([2, 1, 7, 7])

Sentence 1 mask (5 real tokens):
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0.],   # PAD rows still causal
        [1., 1., 1., 1., 1., 0., 0.]])

Sentence 2 mask (4 real tokens):
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.]])

✓ Combined mask test passed!
```
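If you use PyTorch's built-in attention layer instead of a from-scratch implementation, you normally do not build this combined matrix yourself: `nn.MultiheadAttention` takes the causal and padding constraints as separate arguments, with the opposite convention (True means "blocked"). A minimal sketch under that assumption, with made-up shapes and token IDs:

```python
import torch
import torch.nn as nn

pad_id = 0
target_ids = torch.tensor([
    [2, 10, 20, 30, 3, 0, 0],
    [2, 10, 20, 3, 0, 0, 0],
])
batch, tgt_len = target_ids.shape
d_model = 16

x = torch.randn(batch, tgt_len, d_model)  # hypothetical decoder embeddings
mha = nn.MultiheadAttention(d_model, num_heads=4, batch_first=True)

# Causal constraint: bool [tgt_len, tgt_len], True = this position may NOT be attended
causal = torch.triu(torch.ones(tgt_len, tgt_len, dtype=torch.bool), diagonal=1)

# Padding constraint: bool [batch, tgt_len], True = PAD token to ignore
key_padding = target_ids.eq(pad_id)

out, _ = mha(x, x, x, attn_mask=causal, key_padding_mask=key_padding)
print(out.shape)  # torch.Size([2, 7, 16])
```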
Applying the Mask in Attention
Mask Application
```python
import math


def masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor,
    dropout: nn.Dropout = None
) -> torch.Tensor:
    """
    Apply scaled dot-product attention with masking.

    Args:
        query: [batch, heads, seq_len, d_k]
        key: [batch, heads, seq_len, d_k]
        value: [batch, heads, seq_len, d_k]
        mask: [batch, 1, seq_len, seq_len] (1 = attend, 0 = mask)
        dropout: Optional dropout layer

    Returns:
        output: [batch, heads, seq_len, d_k]
    """
    d_k = query.size(-1)

    # Compute attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask: set masked positions to -inf
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax (masked positions become 0 probability)
    attn_weights = torch.softmax(scores, dim=-1)

    # Handle case where an entire row is masked (nan → 0)
    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)

    if dropout is not None:
        attn_weights = dropout(attn_weights)

    # Apply attention to values
    output = torch.matmul(attn_weights, value)

    return output


# Demonstrate masked attention
def demo_masked_attention():
    batch, heads, seq_len, d_k = 1, 1, 4, 8

    Q = torch.randn(batch, heads, seq_len, d_k)
    K = torch.randn(batch, heads, seq_len, d_k)
    V = torch.randn(batch, heads, seq_len, d_k)

    # Create causal mask
    mask = create_causal_mask(seq_len)

    # Compute attention
    output = masked_attention(Q, K, V, mask)

    print("Masked Attention Demo")
    print(f"Q, K, V shape: [{batch}, {heads}, {seq_len}, {d_k}]")
    print(f"Mask shape: {mask.shape}")
    print(f"Output shape: {output.shape}")

    # Show that each position only attends to allowed positions
    scores = torch.matmul(Q, K.transpose(-2, -1))
    scores_masked = scores.masked_fill(mask == 0, float('-inf'))

    print("\nRaw scores (position 2, all keys):")
    print(f"  {scores[0, 0, 2, :]}")
    print("After masking:")
    print(f"  {scores_masked[0, 0, 2, :]}")


demo_masked_attention()
```
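As a sanity check, the from-scratch function above can be compared against PyTorch's fused attention kernel, which performs the same causal masking internally when `is_causal=True`. A minimal sketch, assuming PyTorch 2.0 or newer; the two outputs should agree up to floating-point tolerance:

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, d_k = 1, 2, 4, 8
Q = torch.randn(batch, heads, seq_len, d_k)
K = torch.randn(batch, heads, seq_len, d_k)
V = torch.randn(batch, heads, seq_len, d_k)

# From-scratch path: explicit causal mask + masked_fill
ours = masked_attention(Q, K, V, create_causal_mask(seq_len))

# Built-in path: no mask tensor materialized, masking handled by is_causal
builtin = F.scaled_dot_product_attention(Q, K, V, is_causal=True)

print(torch.allclose(ours, builtin, atol=1e-5))  # expected: True
```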
Teacher Forcing
What is Teacher Forcing?
During training, we use the ground truth target as input:
```
Ground truth: "<bos> The dog runs <eos>"

Training with teacher forcing:
  Input:  ["<bos>", "The", "dog", "runs"]
  Output: ["The", "dog", "runs", "<eos>"]

Each position predicts the next token, using TRUE previous tokens.
```
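In code, this shift is usually nothing more than slicing the same tensor twice; a minimal sketch (the token IDs are made up for illustration, with `<bos>` = 2 and `<eos>` = 3):

```python
import torch

# Hypothetical target: "<bos> The dog runs <eos>" as token IDs
target = torch.tensor([[2, 45, 67, 89, 3]])

decoder_input = target[:, :-1]  # "<bos> The dog runs" → fed to the masked decoder
labels = target[:, 1:]          # "The dog runs <eos>" → what each position must predict

print(decoder_input)  # tensor([[ 2, 45, 67, 89]])
print(labels)         # tensor([[45, 67, 89,  3]])
```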
Why Teacher Forcing?
Without teacher forcing (autoregressive training):
```
Step 1: Input "<bos>"        → Predict (maybe wrong) "A"
Step 2: Input "<bos> A"      → Predict (maybe wrong) "cat"
Step 3: Input "<bos> A cat"  → Errors compound!

Problem: Early mistakes cascade through the sequence
Training is very slow and unstable
```
With teacher forcing:
```
Step 1: Input "<bos>"          → Predict "The" (target: "The")
Step 2: Input "<bos> The"      → Predict "dog" (using TRUE "The")
Step 3: Input "<bos> The dog"  → Predict "runs" (using TRUE previous)

Advantage: Each position gets correct context
Training is fast and stable
```
Causal Mask Enables Parallel Teacher Forcing
```python
def teacher_forcing_example():
    """
    Demonstrate how the causal mask enables parallel teacher forcing.
    """
    # Full target sequence (batch=1, seq_len=5)
    target = torch.tensor([[2, 45, 67, 89, 3]])  # <bos>, The, dog, runs, <eos>

    # With the causal mask, we can process all positions in parallel:
    #
    #   Position 0: sees [<bos>]                 → predict "The"
    #   Position 1: sees [<bos>, The]            → predict "dog"
    #   Position 2: sees [<bos>, The, dog]       → predict "runs"
    #   Position 3: sees [<bos>, The, dog, runs] → predict "<eos>"
    #
    # All computed in ONE forward pass!

    mask = create_causal_mask(5)
    print("Teacher Forcing with Causal Mask")
    print("=" * 50)
    print("\nTarget sequence: <bos> The dog runs <eos>")
    print("\nMask allows parallel computation:")

    for pos in range(5):
        visible = mask[0, 0, pos, :].nonzero().squeeze(-1).tolist()
        print(f"  Position {pos}: sees positions {visible}")

    print("\nAll 5 predictions computed in single forward pass!")


teacher_forcing_example()
```
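Because one masked forward pass yields a prediction for every position, the training loss for the whole sequence can also be computed in one shot. A minimal sketch with made-up shapes, assuming PAD id 0 so padded positions would be excluded from the loss:

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab_size = 1, 4, 1000

# Hypothetical decoder output: one logit vector per position, from a single forward pass
logits = torch.randn(batch, seq_len, vocab_size)

# Shifted targets: what each position should have predicted
labels = torch.tensor([[45, 67, 89, 3]])  # "The dog runs <eos>"

# One cross-entropy over all positions at once (PAD id 0 ignored)
loss = F.cross_entropy(
    logits.reshape(-1, vocab_size),  # [batch * seq_len, vocab_size]
    labels.reshape(-1),              # [batch * seq_len]
    ignore_index=0,
)
print(loss.item())
```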
Cross-Attention Mask
For Encoder-Decoder Attention
Cross-attention doesn't need causal masking (the decoder may see the full source), but it does need a padding mask:
```python
def create_cross_attention_mask(
    source_ids: torch.Tensor,
    target_len: int,
    pad_id: int = 0
) -> torch.Tensor:
    """
    Create mask for cross-attention (decoder attending to encoder).

    Args:
        source_ids: [batch, src_len] source token IDs
        target_len: Length of target sequence
        pad_id: Padding token ID

    Returns:
        mask: [batch, 1, target_len, src_len]
            Each target position can attend to all non-pad source positions
    """
    batch_size, src_len = source_ids.shape

    # Source padding mask: [batch, 1, 1, src_len]
    src_mask = (source_ids != pad_id).unsqueeze(1).unsqueeze(2).float()

    # Expand for all target positions: [batch, 1, tgt_len, src_len]
    cross_mask = src_mask.expand(-1, -1, target_len, -1)

    return cross_mask


# Test
def test_cross_attention_mask():
    source = torch.tensor([
        [10, 20, 30, 0, 0],   # 3 real + 2 pad
        [10, 20, 30, 40, 0],  # 4 real + 1 pad
    ])
    target_len = 6

    mask = create_cross_attention_mask(source, target_len, pad_id=0)

    print("Cross-Attention Mask")
    print(f"Source shape: {source.shape}")
    print(f"Target length: {target_len}")
    print(f"Mask shape: {mask.shape}")

    print("\nSentence 1 (3 source tokens):")
    print(mask[0, 0])  # All target positions see same source mask

    print("\nSentence 2 (4 source tokens):")
    print(mask[1, 0])


test_cross_attention_mask()
```
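A small design note: the expand to [batch, 1, tgt_len, src_len] is mainly for readability. Because `masked_fill` broadcasts, the un-expanded [batch, 1, 1, src_len] padding mask would already apply to every target row; a minimal sketch with made-up shapes:

```python
import torch

source_ids = torch.tensor([[10, 20, 30, 0, 0]])         # 3 real tokens + 2 PAD
src_mask = (source_ids != 0).unsqueeze(1).unsqueeze(2)  # [1, 1, 1, src_len], bool

# Hypothetical cross-attention scores: [batch, heads, tgt_len, src_len]
scores = torch.randn(1, 8, 6, 5)

# Broadcasting applies the same source-padding mask to every target position
masked = scores.masked_fill(src_mask == 0, float('-inf'))
print(masked[0, 0, 0])  # the two PAD columns are -inf, for every target row
```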
Summary
Mask Types in Decoder
| Mask | Used In | Purpose |
|---|---|---|
| Causal | Self-attention | Prevent future token access |
| Padding | Self & Cross | Ignore PAD tokens |
| Combined | Self-attention | Causal + padding |
| Cross | Cross-attention | Source padding only |
Implementation Checklist
- [ ] `create_causal_mask(seq_len)`: Lower triangular matrix
- [ ] `create_decoder_mask(target_ids)`: Combine causal + padding
- [ ] `create_cross_attention_mask(source_ids, target_len)`: Source padding only
- [ ] Apply mask with `masked_fill(mask == 0, -inf)`
- [ ] Handle all-masked rows with `nan_to_num`
Exercises
Implementation
1. Create a mask that allows attending to specific positions (e.g., only every other position).
2. Implement a sliding window causal mask (each position only sees last N tokens).
3. Add support for prefix-LM style masking (first K tokens bidirectional, rest causal).
Analysis
4. Visualize attention patterns with and without causal masking.
5. What happens to gradients if a position has all keys masked?
Next Section Preview
In the next section, we'll implement cross-attention, where decoder queries attend to encoder keys and values, enabling the translation model to access source sentence information.