Chapter 8

Causal Masking for Autoregressive Generation

Transformer Decoder

Introduction

Causal masking is the key mechanism that enables autoregressive generation. It prevents the decoder from "cheating" by looking at future tokens during training, ensuring the model learns to predict each token based only on previous context.

This section covers the implementation and usage of causal masks.


The Problem: Future Information Leakage

During Training

We feed the decoder the complete target sequence:

πŸ“text
1Target: "<bos> The dog runs <eos>"
2
3Without masking, at position 2 ("dog"):
4  Model sees: "<bos>", "The", "dog", "runs", "<eos>"
5  Task: Predict "runs" (next token)
6
7  Problem: "runs" is already visible in the input!
8  Model can just copy instead of learning to predict.

The Solution: Causal Mask

πŸ“text
1With causal mask, at position 2 ("dog"):
2  Model sees: "<bos>", "The", "dog", [MASKED], [MASKED]
3  Task: Predict "runs"
4
5  Now model must actually learn to predict!

Understanding the Causal Mask

Mask Structure

A causal (look-ahead) mask is a lower triangular matrix:

πŸ“text
1Sequence: [pos0, pos1, pos2, pos3]
2
3Attention Mask (1 = attend, 0 = ignore):
4
5        pos0  pos1  pos2  pos3
6pos0 [   1     0     0     0  ]  ← pos0 can only see pos0
7pos1 [   1     1     0     0  ]  ← pos1 can see pos0, pos1
8pos2 [   1     1     1     0  ]  ← pos2 can see pos0, pos1, pos2
9pos3 [   1     1     1     1  ]  ← pos3 can see all
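The lower triangular pattern above is exactly what `torch.tril` produces, so you can sanity-check it in one line:

```python
import torch

# Lower triangular matrix of ones: row i has ones in columns 0..i
mask = torch.tril(torch.ones(4, 4))
print(mask)
# tensor([[1., 0., 0., 0.],
#         [1., 1., 0., 0.],
#         [1., 1., 1., 0.],
#         [1., 1., 1., 1.]])
```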

Visual Representation

πŸ“text
1Target: "<bos> The dog runs"
2         0     1    2    3
3
4Query position 0 ("<bos>"):
5  Attends to: [<bos>]
6
7Query position 1 ("The"):
8  Attends to: [<bos>, The]
9
10Query position 2 ("dog"):
11  Attends to: [<bos>, The, dog]
12
13Query position 3 ("runs"):
14  Attends to: [<bos>, The, dog, runs]

Implementation

Basic Causal Mask

🐍python
from typing import Optional

import torch
import torch.nn as nn


def create_causal_mask(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Create a causal (look-ahead) mask for self-attention.

    Args:
        seq_len: Length of the sequence
        device: Device to create the tensor on

    Returns:
        mask: [1, 1, seq_len, seq_len] where:
              1 = attend (valid)
              0 = ignore (masked)

    Example:
        >>> mask = create_causal_mask(4)
        >>> mask.squeeze()
        tensor([[1., 0., 0., 0.],
                [1., 1., 0., 0.],
                [1., 1., 1., 0.],
                [1., 1., 1., 1.]])
    """
    # Create lower triangular matrix (including diagonal)
    mask = torch.tril(torch.ones(seq_len, seq_len, device=device))

    # Reshape for attention: [1, 1, seq_len, seq_len]
    # This allows broadcasting across batch and heads
    mask = mask.unsqueeze(0).unsqueeze(0)

    return mask


# Test
def test_causal_mask():
    mask = create_causal_mask(5)

    print("Causal Mask (seq_len=5):")
    print(mask.squeeze())

    # Verify properties
    assert mask.shape == (1, 1, 5, 5)
    assert mask[0, 0, 0, 1] == 0  # pos0 cannot see pos1
    assert mask[0, 0, 1, 0] == 1  # pos1 can see pos0
    assert mask[0, 0, 4, 4] == 1  # Last pos can see itself

    print("\n✓ Causal mask test passed!")


test_causal_mask()

Alternative: Using torch.triu

🐍python
def create_causal_mask_v2(seq_len: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Create a causal mask using the upper triangular approach.

    Marks the upper triangle (excluding the diagonal) as the positions
    to block, then inverts back to the 1 = attend / 0 = ignore convention.
    """
    # Upper triangle is 1 where attention must be blocked
    mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)

    # Invert: True (1) where we CAN attend
    mask = mask == 0

    return mask.unsqueeze(0).unsqueeze(0).float()


# Both produce the same result
mask1 = create_causal_mask(4)
mask2 = create_causal_mask_v2(4)
print(f"Masks equal: {torch.equal(mask1, mask2)}")
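A third common convention (used, for example, by PyTorch's built-in `nn.Transformer` modules) is an *additive* float mask: 0.0 where attention is allowed and -inf where it is blocked, added directly to the scores before softmax. A minimal sketch of that convention:

```python
import torch


def create_additive_causal_mask(seq_len: int) -> torch.Tensor:
    """Additive causal mask: 0.0 = attend, -inf = blocked."""
    upper = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    return torch.zeros(seq_len, seq_len).masked_fill(upper, float('-inf'))


scores = torch.randn(4, 4)
# Adding -inf has the same effect as masked_fill(mask == 0, -inf)
weights = torch.softmax(scores + create_additive_causal_mask(4), dim=-1)
print(weights)  # upper triangle is exactly 0 after softmax
```

Whichever convention you pick (multiplicative 0/1 or additive 0/-inf), keep it consistent across your codebase; mixing the two is a common source of silent bugs.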

Combining Causal and Padding Masks

The Combined Mask

In real batches, we need both:

1. Causal mask: Don't look at future tokens

2. Padding mask: Don't attend to PAD tokens

🐍python
def create_decoder_mask(
    target_ids: torch.Tensor,
    pad_id: int = 0
) -> torch.Tensor:
    """
    Create combined causal + padding mask for decoder self-attention.

    Args:
        target_ids: [batch, tgt_len] target token IDs
        pad_id: Padding token ID

    Returns:
        mask: [batch, 1, tgt_len, tgt_len]
              1 = attend, 0 = ignore

    Example:
        >>> tgt = torch.tensor([[1, 2, 3, 0, 0], [1, 2, 0, 0, 0]])
        >>> mask = create_decoder_mask(tgt)
        # Combines causal (lower triangular) with padding (ignore PAD columns)
    """
    batch_size, tgt_len = target_ids.shape
    device = target_ids.device

    # 1. Causal mask: [1, 1, tgt_len, tgt_len]
    causal_mask = create_causal_mask(tgt_len, device)

    # 2. Padding mask: [batch, 1, 1, tgt_len]
    # 1.0 where NOT padding (valid positions)
    padding_mask = (target_ids != pad_id).unsqueeze(1).unsqueeze(2).float()

    # 3. Combine: both conditions must hold to attend
    combined_mask = causal_mask * padding_mask

    return combined_mask


# Test combined mask
def test_combined_mask():
    # Batch with different padding
    target_ids = torch.tensor([
        [2, 10, 20, 30, 3, 0, 0],   # 5 real tokens + 2 padding
        [2, 10, 20, 3, 0, 0, 0],    # 4 real tokens + 3 padding
    ])

    mask = create_decoder_mask(target_ids, pad_id=0)

    print("Combined Decoder Mask")
    print("=" * 50)
    print(f"Target shape: {target_ids.shape}")
    print(f"Mask shape: {mask.shape}")

    print("\nSentence 1 mask (5 real tokens):")
    print(mask[0, 0])

    print("\nSentence 2 mask (4 real tokens):")
    print(mask[1, 0])

    # Verify: padding columns should be all zeros
    assert mask[0, 0, :, 5:].sum() == 0, "Padding should be masked"
    assert mask[1, 0, :, 4:].sum() == 0, "Padding should be masked"

    print("\n✓ Combined mask test passed!")


test_combined_mask()

Output:

πŸ“text
1Combined Decoder Mask
2==================================================
3Target shape: torch.Size([2, 7])
4Mask shape: torch.Size([2, 1, 7, 7])
5
6Sentence 1 mask (5 real tokens):
7tensor([[1., 0., 0., 0., 0., 0., 0.],
8        [1., 1., 0., 0., 0., 0., 0.],
9        [1., 1., 1., 0., 0., 0., 0.],
10        [1., 1., 1., 1., 0., 0., 0.],
11        [1., 1., 1., 1., 1., 0., 0.],
12        [1., 1., 1., 1., 1., 0., 0.],  # PAD rows still causal
13        [1., 1., 1., 1., 1., 0., 0.]])
14
15Sentence 2 mask (4 real tokens):
16tensor([[1., 0., 0., 0., 0., 0., 0.],
17        [1., 1., 0., 0., 0., 0., 0.],
18        [1., 1., 1., 0., 0., 0., 0.],
19        [1., 1., 1., 1., 0., 0., 0.],
20        [1., 1., 1., 1., 0., 0., 0.],
21        [1., 1., 1., 1., 0., 0., 0.],
22        [1., 1., 1., 1., 0., 0., 0.]])
23
24βœ“ Combined mask test passed!

Applying the Mask in Attention

Mask Application

🐍python
import math


def masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    mask: torch.Tensor,
    dropout: Optional[nn.Dropout] = None
) -> torch.Tensor:
    """
    Apply scaled dot-product attention with masking.

    Args:
        query: [batch, heads, seq_len, d_k]
        key: [batch, heads, seq_len, d_k]
        value: [batch, heads, seq_len, d_k]
        mask: [batch, 1, seq_len, seq_len] (1=attend, 0=mask)
        dropout: Optional dropout layer

    Returns:
        output: [batch, heads, seq_len, d_k]
    """
    d_k = query.size(-1)

    # Compute scaled attention scores
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask: set masked positions to -inf
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax (masked positions become 0 probability)
    attn_weights = torch.softmax(scores, dim=-1)

    # Handle rows where every key is masked (softmax over all -inf gives nan)
    attn_weights = torch.nan_to_num(attn_weights, nan=0.0)

    if dropout is not None:
        attn_weights = dropout(attn_weights)

    # Apply attention weights to values
    output = torch.matmul(attn_weights, value)

    return output


# Demonstrate masked attention
def demo_masked_attention():
    batch, heads, seq_len, d_k = 1, 1, 4, 8

    Q = torch.randn(batch, heads, seq_len, d_k)
    K = torch.randn(batch, heads, seq_len, d_k)
    V = torch.randn(batch, heads, seq_len, d_k)

    # Create causal mask
    mask = create_causal_mask(seq_len)

    # Compute attention
    output = masked_attention(Q, K, V, mask)

    print("Masked Attention Demo")
    print(f"Q, K, V shape: [{batch}, {heads}, {seq_len}, {d_k}]")
    print(f"Mask shape: {mask.shape}")
    print(f"Output shape: {output.shape}")

    # Show that each position only attends to allowed positions
    scores = torch.matmul(Q, K.transpose(-2, -1))
    scores_masked = scores.masked_fill(mask == 0, float('-inf'))

    print("\nRaw scores (position 2, all keys):")
    print(f"  {scores[0, 0, 2, :]}")
    print("After masking:")
    print(f"  {scores_masked[0, 0, 2, :]}")


demo_masked_attention()
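For reference, PyTorch 2.x ships a fused kernel, `torch.nn.functional.scaled_dot_product_attention`, whose `is_causal=True` flag applies the same lower triangular mask internally. A quick check (assuming PyTorch >= 2.0) that it matches the manual computation:

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 1, 4, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Manual path: scale, causal masked_fill, softmax, weighted sum
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.size(-1))
causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))
scores = scores.masked_fill(~causal, float('-inf'))
manual = torch.softmax(scores, dim=-1) @ v

# Fused built-in with internal causal masking
fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, fused, atol=1e-5))
```

In production code the fused version is usually preferred for speed and memory; the manual version remains useful for debugging and for custom mask patterns.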

Teacher Forcing

What is Teacher Forcing?

During training, we use the ground truth target as input:

πŸ“text
1Ground truth: "<bos> The dog runs <eos>"
2
3Training with teacher forcing:
4  Input:  ["<bos>", "The",  "dog",  "runs"]
5  Output: ["The",   "dog",  "runs", "<eos>"]
6
7Each position predicts the next token, using TRUE previous tokens.
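The input/output shift above is a one-liner in code. A minimal sketch of how the training pair and loss are typically built (the token IDs and vocabulary size here are made up for illustration):

```python
import torch
import torch.nn.functional as F

# Full target: <bos> The dog runs <eos>  (hypothetical IDs)
target = torch.tensor([[2, 45, 67, 89, 3]])

decoder_input = target[:, :-1]  # [<bos>, The, dog, runs]
labels = target[:, 1:]          # [The,   dog, runs, <eos>]

# Stand-in logits from a decoder forward pass: [batch, seq_len, vocab_size]
vocab_size = 100
logits = torch.randn(1, decoder_input.size(1), vocab_size)

# Cross-entropy over every position at once (parallel teacher forcing)
loss = F.cross_entropy(logits.reshape(-1, vocab_size), labels.reshape(-1))
print(decoder_input.tolist(), labels.tolist(), loss.item())
```

With real padded batches you would also pass `ignore_index=pad_id` to `F.cross_entropy` so PAD positions do not contribute to the loss.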

Why Teacher Forcing?

Without teacher forcing (autoregressive training):

πŸ“text
1Step 1: Input "<bos>" β†’ Predict (maybe wrong) "A"
2Step 2: Input "<bos> A" β†’ Predict (maybe wrong) "cat"
3Step 3: Input "<bos> A cat" β†’ Errors compound!
4
5Problem: Early mistakes cascade through sequence
6Training is very slow and unstable

With teacher forcing:

πŸ“text
1Step 1: Input "<bos>" β†’ Predict "The" (target: "The")
2Step 2: Input "<bos> The" β†’ Predict "dog" (using TRUE "The")
3Step 3: Input "<bos> The dog" β†’ Predict "runs" (using TRUE previous)
4
5Advantage: Each position gets correct context
6Training is fast and stable

Causal Mask Enables Parallel Teacher Forcing

🐍python
def teacher_forcing_example():
    """
    Demonstrate how the causal mask enables parallel teacher forcing.
    """
    # Full target sequence (batch=1, seq_len=5)
    target = torch.tensor([[2, 45, 67, 89, 3]])  # <bos>, The, dog, runs, <eos>

    # With the causal mask, we can process all positions in parallel:
    #
    # Position 0: sees [<bos>]                  → predict "The"
    # Position 1: sees [<bos>, The]             → predict "dog"
    # Position 2: sees [<bos>, The, dog]        → predict "runs"
    # Position 3: sees [<bos>, The, dog, runs]  → predict "<eos>"
    #
    # All computed in ONE forward pass!

    mask = create_causal_mask(5)
    print("Teacher Forcing with Causal Mask")
    print("=" * 50)
    print("\nTarget sequence: <bos> The dog runs <eos>")
    print("\nMask allows parallel computation:")

    for pos in range(5):
        visible = mask[0, 0, pos, :].nonzero().squeeze(-1).tolist()
        print(f"  Position {pos}: sees positions {visible}")

    print("\nAll 5 predictions computed in a single forward pass!")


teacher_forcing_example()

Cross-Attention Mask

For Encoder-Decoder Attention

Cross-attention doesn't need a causal mask (the decoder is allowed to see the entire source sentence), but it still needs a padding mask for the source:

🐍python
def create_cross_attention_mask(
    source_ids: torch.Tensor,
    target_len: int,
    pad_id: int = 0
) -> torch.Tensor:
    """
    Create mask for cross-attention (decoder attending to encoder).

    Args:
        source_ids: [batch, src_len] source token IDs
        target_len: Length of the target sequence
        pad_id: Padding token ID

    Returns:
        mask: [batch, 1, target_len, src_len]
              Each target position can attend to all non-pad source positions
    """
    batch_size, src_len = source_ids.shape

    # Source padding mask: [batch, 1, 1, src_len]
    src_mask = (source_ids != pad_id).unsqueeze(1).unsqueeze(2).float()

    # Expand for all target positions: [batch, 1, tgt_len, src_len]
    cross_mask = src_mask.expand(-1, -1, target_len, -1)

    return cross_mask


# Test
def test_cross_attention_mask():
    source = torch.tensor([
        [10, 20, 30, 0, 0],   # 3 real + 2 pad
        [10, 20, 30, 40, 0],  # 4 real + 1 pad
    ])
    target_len = 6

    mask = create_cross_attention_mask(source, target_len, pad_id=0)

    print("Cross-Attention Mask")
    print(f"Source shape: {source.shape}")
    print(f"Target length: {target_len}")
    print(f"Mask shape: {mask.shape}")

    print("\nSentence 1 (3 source tokens):")
    print(mask[0, 0])  # All target positions see the same source mask

    print("\nSentence 2 (4 source tokens):")
    print(mask[1, 0])


test_cross_attention_mask()

Summary

Mask Types in Decoder

Mask      Used In           Purpose
--------  ----------------  ---------------------------
Causal    Self-attention    Prevent future token access
Padding   Self & cross      Ignore PAD tokens
Combined  Self-attention    Causal + padding
Cross     Cross-attention   Source padding only

Implementation Checklist

- [ ] create_causal_mask(seq_len): Lower triangular matrix

- [ ] create_decoder_mask(target_ids): Combine causal + padding

- [ ] create_cross_attention_mask(source_ids, tgt_len): Source padding only

- [ ] Apply mask with masked_fill(mask == 0, -inf)

- [ ] Handle all-masked rows with nan_to_num


Exercises

Implementation

1. Create a mask that allows attending to specific positions (e.g., only every other position).

2. Implement a sliding window causal mask (each position only sees last N tokens).

3. Add support for prefix-LM style masking (first K tokens bidirectional, rest causal).

Analysis

4. Visualize attention patterns with and without causal masking.

5. What happens to gradients if a position has all keys masked?


Next Section Preview

In the next section, we'll implement cross-attention, where decoder queries attend to encoder keys and values, enabling the translation model to access source sentence information.