Chapter 9

Understanding Autoregressive Generation

Autoregressive Generation

Autoregressive generation is how transformers produce output sequences one token at a time. Each token is predicted based on all previous tokens, creating a chain of conditional probabilities. This section explains the fundamental concepts before we implement various decoding strategies.


What is Autoregressive Generation?

The Core Idea

πŸ“text
1Autoregressive = "self-regressing" = each output depends on previous outputs
2
3P(y₁, yβ‚‚, y₃, ..., yβ‚™) = P(y₁) Γ— P(yβ‚‚|y₁) Γ— P(y₃|y₁,yβ‚‚) Γ— ... Γ— P(yβ‚™|y₁,...,yₙ₋₁)
4
5Each token is predicted conditioned on ALL previous tokens.
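
To make the factorization concrete, the sketch below scores a three-token sequence with a toy conditional table. The table `COND` and its probabilities are invented for illustration and do not come from any model. Log-probabilities are summed instead of multiplying raw probabilities, a common way to avoid underflow on long sequences.

```python
import math

# Hypothetical conditional distributions P(next | prefix), keyed by prefix.
# The numbers are made up for illustration.
COND = {
    (): {"The": 0.6, "A": 0.4},
    ("The",): {"dog": 0.7, "cat": 0.3},
    ("The", "dog"): {"runs": 0.8, "sleeps": 0.2},
}


def sequence_log_prob(tokens):
    """Sum log P(y_t | y_<t) over the sequence (chain rule in log space)."""
    total = 0.0
    for t, token in enumerate(tokens):
        prefix = tuple(tokens[:t])
        total += math.log(COND[prefix][token])
    return total


log_p = sequence_log_prob(["The", "dog", "runs"])
joint = math.exp(log_p)  # equals 0.6 * 0.7 * 0.8
print(f"log P = {log_p:.4f}, P = {joint:.4f}")
```

Each factor in the product is one lookup in the table, which is exactly the role the decoder's softmax output plays at each generation step.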

Visual Example: Translation

πŸ“text
1Source (German): "Der Hund lΓ€uft"
2Target (English): "The dog runs"
3
4Step 1: Predict first token
5  Input:  [<bos>]
6  Output: P(y₁ | encoder("Der Hund lΓ€uft"), <bos>)
7  Prediction: "The"
8
9Step 2: Predict second token
10  Input:  [<bos>, The]
11  Output: P(yβ‚‚ | encoder("Der Hund lΓ€uft"), <bos>, The)
12  Prediction: "dog"
13
14Step 3: Predict third token
15  Input:  [<bos>, The, dog]
16  Output: P(y₃ | encoder("Der Hund lΓ€uft"), <bos>, The, dog)
17  Prediction: "runs"
18
19Step 4: Predict fourth token
20  Input:  [<bos>, The, dog, runs]
21  Output: P(yβ‚„ | encoder("Der Hund lΓ€uft"), <bos>, The, dog, runs)
22  Prediction: "<eos>"
23
24Stop: EOS token generated!
25Final output: "The dog runs"

Training vs Inference

Training: Teacher Forcing (Parallel)

During training, we know the target sequence and use teacher forcing:

πŸ“text
1Target: "<bos> The dog runs <eos>"
2
3Input to decoder:  [<bos>, The, dog, runs]     (all but last)
4Expected output:   [The, dog, runs, <eos>]     (all but first)
5
6With causal masking, all positions are computed in ONE forward pass!
7
8Position 0: sees [<bos>]               β†’ predict "The"
9Position 1: sees [<bos>, The]          β†’ predict "dog"
10Position 2: sees [<bos>, The, dog]     β†’ predict "runs"
11Position 3: sees [<bos>, The, dog, runs] β†’ predict "<eos>"
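
The shift and the mask described above can be sketched without any framework. The helper names `shift_for_teacher_forcing` and `causal_mask` are illustrative assumptions, not part of the model code in this chapter, and plain lists stand in for tensors.

```python
def shift_for_teacher_forcing(target):
    """Split a full target into decoder input (all but last) and labels (all but first)."""
    return target[:-1], target[1:]


def causal_mask(size):
    """mask[i][j] is True when position i may attend to position j (j <= i)."""
    return [[j <= i for j in range(size)] for i in range(size)]


target = ["<bos>", "The", "dog", "runs", "<eos>"]
decoder_input, labels = shift_for_teacher_forcing(target)
mask = causal_mask(len(decoder_input))

# Reproduce the "Position i: sees ... -> predict ..." listing
for i, label in enumerate(labels):
    visible = [tok for tok, allowed in zip(decoder_input, mask[i]) if allowed]
    print(f"Position {i}: sees {visible} -> predict {label!r}")
```

Because row i of the mask hides every later position, all four predictions can be computed in the same pass without any position seeing its own answer.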

Inference: Sequential Generation

During inference, we don't know the target:

πŸ“text
1Step 1: input [<bos>]                    β†’ predict "The"
2Step 2: input [<bos>, The]               β†’ predict "dog"
3Step 3: input [<bos>, The, dog]          β†’ predict "runs"
4Step 4: input [<bos>, The, dog, runs]    β†’ predict "<eos>"
5
6Each step requires a SEPARATE forward pass!
7(Unless we use KV-caching - covered later)

The Efficiency Gap

πŸ“text
1Training:  1 forward pass for entire sequence
2Inference: N forward passes for N tokens
3
4This is why generation is slow compared to training!
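
A rough way to quantify the gap is to count how many token positions pass through the decoder while generating N tokens. The function below is a back-of-the-envelope sketch: `positions_processed` is a hypothetical helper, it ignores the cost of attention itself, and it assumes an idealised cache that lets each step process only the newest token.

```python
def positions_processed(n_tokens, use_kv_cache=False):
    """Count decoder positions run through the model while generating n_tokens.

    Without a cache, step t re-runs all t tokens of the current prefix.
    With a cache, each step only runs the newest token.
    """
    if use_kv_cache:
        return n_tokens
    return sum(range(1, n_tokens + 1))  # 1 + 2 + ... + N = N(N+1)/2


for n in (4, 100, 1000):
    print(n, positions_processed(n), positions_processed(n, use_kv_cache=True))
```

Under these assumptions the uncached loop grows quadratically with output length, while the number of sequential steps stays at N either way.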

The Generation Loop

Basic Algorithm

🐍python
import torch
import torch.nn.functional as F
from typing import Optional


@torch.no_grad()  # inference only: skip gradient tracking
def generate_basic(
    model,
    encoder_output: torch.Tensor,
    start_token_id: int,
    end_token_id: int,
    max_length: int = 100
) -> torch.Tensor:
    """
    Basic autoregressive generation loop (greedy, batch size 1).

    Args:
        model: Transformer model exposing decode(tokens, encoder_output)
        encoder_output: Encoded source [1, src_len, d_model]
        start_token_id: BOS token ID
        end_token_id: EOS token ID
        max_length: Maximum number of tokens to generate after BOS

    Returns:
        generated: Generated token IDs [1, seq_len], including BOS
    """
    device = encoder_output.device

    # Start with BOS token
    generated = torch.tensor([[start_token_id]], dtype=torch.long, device=device)

    for _ in range(max_length):
        # Get logits for next token (re-runs the whole prefix each step)
        logits = model.decode(generated, encoder_output)  # [1, seq_len, vocab]

        # Take logits for last position only
        next_token_logits = logits[:, -1, :]  # [1, vocab]

        # Get highest probability token
        next_token = next_token_logits.argmax(dim=-1, keepdim=True)  # [1, 1]

        # Append to sequence
        generated = torch.cat([generated, next_token], dim=1)

        # Stop if EOS (.item() is valid because the batch size is 1)
        if next_token.item() == end_token_id:
            break

    return generated

Probability Distribution

From Logits to Probabilities

🐍python
import torch
import torch.nn.functional as F


def understand_logits():
    """
    Understand the relationship between logits, probabilities, and predictions.
    """
    print("Understanding Logits and Probabilities")
    print("=" * 60)

    # Simulated vocabulary
    vocab = ["<pad>", "<unk>", "<bos>", "<eos>", "The", "cat", "dog", "runs"]

    # Simulated logits from model (raw scores), one per vocabulary entry
    logits = torch.tensor([
        -10.0,  # <pad>
        -8.0,   # <unk>
        -5.0,   # <bos>
        -2.0,   # <eos>
        3.5,    # The
        2.1,    # cat
        3.8,    # dog  ← highest
        1.5     # runs
    ])

    print("\nRaw logits (unnormalized scores):")
    for word, score in zip(vocab, logits.tolist()):
        print(f"  {word:8s}: {score:6.2f}")

    # Convert to probabilities with softmax
    probs = F.softmax(logits, dim=0)

    print("\nProbabilities (after softmax):")
    for word, prob in zip(vocab, probs.tolist()):
        bar = "█" * int(prob * 50)
        print(f"  {word:8s}: {prob:.4f} {bar}")

    # Softmax preserves ordering, so argmax over logits equals argmax over probs
    prediction = logits.argmax().item()
    print(f"\nPrediction (argmax): '{vocab[prediction]}' (prob={probs[prediction].item():.4f})")

    # Show top-3
    top3 = torch.topk(probs, 3)
    print("\nTop 3 candidates:")
    for prob, idx in zip(top3.values.tolist(), top3.indices.tolist()):
        print(f"  {vocab[idx]:8s}: {prob:.4f}")
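
For readers who want to check the softmax step by hand, here is a dependency-free sketch that reuses the same simulated vocabulary and logits. The `softmax` helper is written for this illustration and is not how PyTorch implements it internally. Subtracting the maximum logit first is a standard trick that leaves the result unchanged while keeping `exp` from overflowing.

```python
import math


def softmax(logits):
    """Numerically stable softmax: subtract the max before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


vocab = ["<pad>", "<unk>", "<bos>", "<eos>", "The", "cat", "dog", "runs"]
logits = [-10.0, -8.0, -5.0, -2.0, 3.5, 2.1, 3.8, 1.5]

probs = softmax(logits)
best = max(range(len(vocab)), key=lambda i: probs[i])
print(f"argmax: {vocab[best]} (p={probs[best]:.4f})")
```

Softmax is monotonic, so the token with the largest logit is also the token with the largest probability.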

Decoding Strategies Overview

Different Ways to Select Tokens

πŸ“text
1Given probability distribution over vocabulary:
2
3P("The") = 0.35
4P("dog") = 0.40  ← highest
5P("cat") = 0.15
6P("runs") = 0.08
7P("<eos>") = 0.02
8
9Different decoding strategies:
10
111. GREEDY: Always pick highest probability
12   β†’ "dog" (deterministic)
13
142. SAMPLING: Sample from distribution
15   β†’ Could be any token (random)
16
173. TOP-K SAMPLING: Sample from top K tokens only
18   β†’ Sample from {"The", "dog", "cat"} if K=3
19
204. TOP-P (NUCLEUS) SAMPLING: Sample from smallest set summing to P
21   β†’ Sample from {"The", "dog"} if P=0.75 (0.35+0.40=0.75)
22
235. BEAM SEARCH: Maintain multiple hypotheses
24   β†’ Track best sequences, not just best tokens
25
266. TEMPERATURE: Adjust distribution sharpness
27   β†’ T<1: sharper (more greedy), T>1: flatter (more random)
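
The token-level strategies in this list can be tried on the example distribution with a few lines of plain Python. This is a minimal sketch under stated assumptions: the helper names (`top_k_filter`, `top_p_filter`, `apply_temperature`, `sample`) are illustrative and are not the implementations developed later in the chapter, and temperature is applied to probabilities by raising them to the power 1/T, which is equivalent to dividing the logits by T before the softmax.

```python
import random

PROBS = {"The": 0.35, "dog": 0.40, "cat": 0.15, "runs": 0.08, "<eos>": 0.02}


def renormalize(dist):
    total = sum(dist.values())
    return {tok: p / total for tok, p in dist.items()}


def greedy(dist):
    return max(dist, key=dist.get)


def top_k_filter(dist, k):
    """Keep the k most probable tokens, then renormalize."""
    kept = sorted(dist.items(), key=lambda kv: kv[1], reverse=True)[:k]
    return renormalize(dict(kept))


def top_p_filter(dist, p, eps=1e-12):
    """Keep the smallest set of top tokens whose cumulative mass reaches p."""
    kept, cumulative = {}, 0.0
    for tok, prob in sorted(dist.items(), key=lambda kv: kv[1], reverse=True):
        kept[tok] = prob
        cumulative += prob
        if cumulative >= p - eps:  # eps guards against float rounding
            break
    return renormalize(kept)


def apply_temperature(dist, temperature):
    """Same effect as softmax(logits / T): raise to 1/T and renormalize."""
    return renormalize({tok: prob ** (1.0 / temperature) for tok, prob in dist.items()})


def sample(dist, rng):
    tokens, weights = zip(*dist.items())
    return rng.choices(tokens, weights=weights, k=1)[0]


rng = random.Random(0)  # seeded so runs are repeatable
print("greedy:", greedy(PROBS))
print("top-k (k=3):", sorted(top_k_filter(PROBS, 3)))
print("top-p (p=0.75):", sorted(top_p_filter(PROBS, 0.75)))
print("P(dog) at T=0.5:", round(apply_temperature(PROBS, 0.5)["dog"], 3))
print("one sample:", sample(PROBS, rng))
```

Top-k and top-p only change which tokens are eligible; the final draw is still a weighted random sample from the renormalized distribution.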

Implementation Framework

GenerationMixin Class

🐍python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Dict
from dataclasses import dataclass, fields, replace


@dataclass
class GenerationConfig:
    """Configuration for text generation."""
    max_length: int = 100
    min_length: int = 1
    do_sample: bool = False
    temperature: float = 1.0
    top_k: int = 0  # 0 = disabled
    top_p: float = 1.0  # 1.0 = disabled
    num_beams: int = 1
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    early_stopping: bool = True
    pad_token_id: int = 0
    bos_token_id: int = 2
    eos_token_id: int = 3


class GenerationMixin:
    """
    Mixin class providing generation capabilities to transformer models.

    This provides a unified interface for various decoding strategies.
    """

    def generate(
        self,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        config: Optional[GenerationConfig] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Generate sequences using the specified strategy.

        Args:
            encoder_output: Encoded source [batch, src_len, d_model]
            src_mask: Source padding mask
            config: Generation configuration
            **kwargs: Override config parameters

        Returns:
            generated: Generated token IDs [batch, seq_len]
        """
        if config is None:
            config = GenerationConfig()

        # Apply kwargs overrides to a copy so the caller's config is not mutated.
        # Keys that are not GenerationConfig fields are ignored.
        valid = {f.name for f in fields(config)}
        overrides = {k: v for k, v in kwargs.items() if k in valid}
        config = replace(config, **overrides)

        # Select decoding strategy. _beam_search, _sample, and _greedy_decode are
        # not defined here; they are the strategies implemented in the sections
        # that follow.
        if config.num_beams > 1:
            return self._beam_search(encoder_output, src_mask, config)
        elif config.do_sample:
            return self._sample(encoder_output, src_mask, config)
        else:
            return self._greedy_decode(encoder_output, src_mask, config)

Key Concepts Summary

Autoregressive Generation Properties

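| Property | Description |
| --- | --- |
| Sequential | One token at a time |
| Conditional | Each token depends on all previous tokens |
| Non-deterministic (when sampling) | Sampling introduces randomness; greedy and beam search are deterministic |
| Variable length | Stops at EOS or max_length |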

Training vs Inference

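| Aspect | Training | Inference |
| --- | --- | --- |
| Target known | Yes (teacher forcing) | No |
| Forward passes | 1 per batch | N for N tokens (one per generated token) |
| Parallelization | Full sequence | Token-by-token |
| Masking | Causal mask | Sequential (future tokens not yet generated) |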

Decoding Strategies

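| Strategy | Deterministic | Quality | Diversity |
| --- | --- | --- | --- |
| Greedy | Yes | Medium | Low |
| Sampling | No | Variable | High |
| Top-K | No | Good | Medium |
| Top-P | No | Good | Medium |
| Beam Search | Yes* | High | Low |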

*Beam search is deterministic given the same input but explores multiple paths.
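
To see why tracking several hypotheses can beat picking the best token at each step, consider the toy two-step search below. Everything in it is an assumption made for illustration: the `COND` table and its probabilities are invented, and the `beam_search` function omits EOS handling and length penalties, so it is a sketch of the idea and not a full implementation.

```python
import math

# Hypothetical next-token distributions keyed by prefix; numbers are invented.
COND = {
    (): {"A": 0.5, "B": 0.4, "C": 0.1},
    ("A",): {"x": 0.4, "y": 0.3, "z": 0.3},
    ("B",): {"x": 0.9, "y": 0.05, "z": 0.05},
    ("C",): {"x": 0.4, "y": 0.3, "z": 0.3},
}


def beam_search(num_beams, steps=2):
    """Keep the num_beams highest-scoring prefixes at every step."""
    beams = [((), 0.0)]  # (prefix, cumulative log-probability)
    for _ in range(steps):
        candidates = [
            (prefix + (tok,), score + math.log(p))
            for prefix, score in beams
            for tok, p in COND[prefix].items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    return beams[0]


greedy_seq, greedy_score = beam_search(num_beams=1)  # beam width 1 is greedy
beam_seq, beam_score = beam_search(num_beams=2)
print(greedy_seq, round(math.exp(greedy_score), 2))
print(beam_seq, round(math.exp(beam_score), 2))
```

With one beam the search commits to "A" because it is the most likely first token, and ends with a sequence of probability 0.5 × 0.4 = 0.20. With two beams the prefix "B" survives the first step and leads to a sequence of probability 0.4 × 0.9 = 0.36. Both runs are deterministic, which matches the footnote above.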


In the next section, we'll implement greedy decoding, the simplest strategy, which always picks the highest probability token. While simple, it's often surprisingly effective and serves as a baseline for more sophisticated methods.