Autoregressive generation is how transformers produce output sequences one token at a time. Each token is predicted based on all previous tokens, creating a chain of conditional probabilities. This section explains the fundamental concepts before we implement various decoding strategies.
What is Autoregressive Generation?
The Core Idea
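Autoregressive means "self-regressing": each output depends on the previous outputs. By the chain rule of probability, the joint probability of a sequence factorizes into a product of conditional probabilities:

$$
P(y_1, y_2, y_3, \ldots, y_n) = P(y_1) \times P(y_2 \mid y_1) \times P(y_3 \mid y_1, y_2) \times \cdots \times P(y_n \mid y_1, \ldots, y_{n-1})
$$

Each token is predicted conditioned on ALL previous tokens.

Visual Example: Translation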
```text
Source (German): "Der Hund läuft"
Target (English): "The dog runs"

Step 1: Predict first token
  Input: [<bos>]
  Output: P(y₁ | encoder("Der Hund läuft"), <bos>)
  Prediction: "The"

Step 2: Predict second token
  Input: [<bos>, The]
  Output: P(y₂ | encoder("Der Hund läuft"), <bos>, The)
  Prediction: "dog"

Step 3: Predict third token
  Input: [<bos>, The, dog]
  Output: P(y₃ | encoder("Der Hund läuft"), <bos>, The, dog)
  Prediction: "runs"

Step 4: Predict fourth token
  Input: [<bos>, The, dog, runs]
  Output: P(y₄ | encoder("Der Hund läuft"), <bos>, The, dog, runs)
  Prediction: "<eos>"

Stop: EOS token generated!
Final output: "The dog runs"
```

Training vs Inference
Training: Teacher Forcing (Parallel)
During training, the target sequence is known, so we use teacher forcing: the decoder receives the ground-truth previous tokens as input instead of its own predictions.
```text
Target: "<bos> The dog runs <eos>"

Input to decoder: [<bos>, The, dog, runs]   (all but last)
Expected output:  [The, dog, runs, <eos>]   (all but first)

With causal masking, all positions are computed in ONE forward pass!

Position 0: sees [<bos>]                  → predict "The"
Position 1: sees [<bos>, The]             → predict "dog"
Position 2: sees [<bos>, The, dog]        → predict "runs"
Position 3: sees [<bos>, The, dog, runs]  → predict "<eos>"
```

Inference: Sequential Generation
During inference, the target is unknown, so each predicted token must be appended to the input before the next step:
```text
Step 1: input [<bos>]                 → predict "The"
Step 2: input [<bos>, The]            → predict "dog"
Step 3: input [<bos>, The, dog]       → predict "runs"
Step 4: input [<bos>, The, dog, runs] → predict "<eos>"

Each step requires a SEPARATE forward pass!
(KV-caching, covered later, avoids recomputing the earlier positions,
but the steps themselves remain sequential.)
```

The Efficiency Gap
```text
Training:  1 forward pass for the entire sequence
Inference: N forward passes for N tokens

This is why generation is slow compared to training!
```

The Generation Loop
Basic Algorithm
```python
import torch
import torch.nn.functional as F
from typing import Optional


@torch.no_grad()  # no gradients are needed at inference time
def generate_basic(
    model,
    encoder_output: torch.Tensor,
    start_token_id: int,
    end_token_id: int,
    max_length: int = 100
) -> torch.Tensor:
    """
    Basic autoregressive generation loop (greedy, batch size 1).

    Args:
        model: Transformer model exposing decode(tgt_tokens, encoder_output)
        encoder_output: Encoded source [1, src_len, d_model]
        start_token_id: BOS token ID
        end_token_id: EOS token ID
        max_length: Maximum number of tokens to generate

    Returns:
        generated: Generated token IDs [1, seq_len], starting with BOS
    """
    device = encoder_output.device

    # Start with BOS token
    generated = torch.tensor([[start_token_id]], device=device)

    for _ in range(max_length):
        # Re-run the decoder on the whole prefix (no KV-cache)
        logits = model.decode(generated, encoder_output)  # [1, seq_len, vocab]

        # Take logits for last position only
        next_token_logits = logits[:, -1, :]  # [1, vocab]

        # Get highest probability token
        next_token = next_token_logits.argmax(dim=-1, keepdim=True)  # [1, 1]

        # Append to sequence
        generated = torch.cat([generated, next_token], dim=1)

        # Stop if EOS
        if next_token.item() == end_token_id:
            break

    return generated
```

Probability Distribution
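Each iteration of the loop above produces a full distribution over the vocabulary, even though the loop keeps only the argmax. By the chain rule, the probability of the whole generated sequence is the product of the chosen tokens' per-step probabilities. The sketch below scores a sequence that way, summing log-probabilities to avoid numerical underflow on long sequences. `sequence_log_prob` and the toy logits are hypothetical illustrations, not part of the model interface.

```python
import torch
import torch.nn.functional as F


def sequence_log_prob(step_logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    """
    Score a sequence with the chain rule (hypothetical helper).

    Args:
        step_logits: Logits from each decoding step [seq_len, vocab];
                     row t is the distribution that tokens[t] was chosen from
        tokens: Chosen token IDs [seq_len]

    Returns:
        Scalar tensor: log P(y_1, ..., y_n) = sum over t of log P(y_t | y_<t)
    """
    log_probs = F.log_softmax(step_logits, dim=-1)                   # [seq_len, vocab]
    chosen = log_probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)  # [seq_len]
    return chosen.sum()


# Toy example: 3 decoding steps over a 4-token vocabulary
step_logits = torch.tensor([
    [2.0, 0.5, 0.1, -1.0],
    [0.0, 3.0, 0.2, 0.1],
    [0.3, 0.1, 0.0, 2.5],
])
tokens = step_logits.argmax(dim=-1)  # greedy choices: [0, 1, 3]

log_p = sequence_log_prob(step_logits, tokens)

# Same quantity computed directly as a product of per-step probabilities
probs = F.softmax(step_logits, dim=-1)
product = probs[torch.arange(3), tokens].prod()
print(log_p.exp().item(), product.item())  # the two values agree up to float error
```

Working in log space is also how sequence scores are usually compared between candidates, since adding log-probabilities is numerically safer than multiplying many small numbers.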
From Logits to Probabilities
```python
import torch
import torch.nn.functional as F


def understand_logits():
    """
    Understand the relationship between logits, probabilities, and predictions.
    """
    print("Understanding Logits and Probabilities")
    print("=" * 60)

    # Simulated vocabulary
    vocab = ["<pad>", "<unk>", "<bos>", "<eos>", "The", "cat", "dog", "runs"]

    # Simulated logits from model (raw scores)
    logits = torch.tensor([
        -10.0,  # <pad>
        -8.0,   # <unk>
        -5.0,   # <bos>
        -2.0,   # <eos>
        3.5,    # The
        2.1,    # cat
        3.8,    # dog  ← highest
        1.5     # runs
    ])

    print("\nRaw logits (unnormalized scores):")
    for word, score in zip(vocab, logits):
        print(f"  {word:8s}: {score.item():6.2f}")

    # Convert to probabilities with softmax
    probs = F.softmax(logits, dim=0)

    print("\nProbabilities (after softmax):")
    for word, prob in zip(vocab, probs):
        bar = "█" * int(prob * 50)
        print(f"  {word:8s}: {prob.item():.4f} {bar}")

    # Get prediction (softmax is monotonic, so argmax of logits == argmax of probs)
    prediction = logits.argmax().item()
    print(f"\nPrediction (argmax): '{vocab[prediction]}' (prob={probs[prediction].item():.4f})")

    # Show top-3
    top3 = torch.topk(probs, 3)
    print("\nTop 3 candidates:")
    for prob, idx in zip(top3.values, top3.indices):
        print(f"  {vocab[idx.item()]:8s}: {prob.item():.4f}")


if __name__ == "__main__":
    understand_logits()
```

Decoding Strategies Overview
Different Ways to Select Tokens
```text
Given a probability distribution over the vocabulary:

P("The")   = 0.35
P("dog")   = 0.40  ← highest
P("cat")   = 0.15
P("runs")  = 0.08
P("<eos>") = 0.02

Different decoding strategies:

1. GREEDY: Always pick the highest-probability token
   → "dog" (deterministic)

2. SAMPLING: Sample from the distribution
   → Could be any token (random)

3. TOP-K SAMPLING: Sample from the top K tokens only
   → Sample from {"The", "dog", "cat"} if K=3

4. TOP-P (NUCLEUS) SAMPLING: Sample from the smallest set whose probabilities sum to at least P
   → Sample from {"The", "dog"} if P=0.75 (0.40 + 0.35 = 0.75)

5. BEAM SEARCH: Maintain multiple hypotheses
   → Track best sequences, not just best tokens

6. TEMPERATURE: Divide the logits by T before softmax to adjust sharpness
   → T<1: sharper (more greedy), T>1: flatter (more random)
```

Implementation Framework
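As building blocks for this framework, the token-selection rules listed above can be sketched as filters on a single logit vector of shape [vocab], followed by one sampling step. This is a minimal sketch, not a library API: the helper names (`apply_temperature`, `top_k_filter`, `top_p_filter`, `select_token`) are hypothetical. They follow the `GenerationConfig` conventions used by the framework, where `top_k=0` and `top_p=1.0` mean "disabled".

```python
import torch
import torch.nn.functional as F


def apply_temperature(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """Divide logits by T (T must be > 0): T < 1 sharpens, T > 1 flattens."""
    return logits / temperature


def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep the k highest logits and set the rest to -inf (k <= 0 disables).
    Ties with the k-th value are all kept."""
    if k <= 0 or k >= logits.size(-1):
        return logits
    kth_value = torch.topk(logits, k).values[-1]
    return logits.masked_fill(logits < kth_value, float("-inf"))


def top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    """Keep the smallest set of tokens whose probabilities sum to >= p (p >= 1.0 disables)."""
    if p >= 1.0:
        return logits
    sorted_logits, sorted_idx = torch.sort(logits, descending=True)
    sorted_probs = F.softmax(sorted_logits, dim=-1)
    # Probability mass of the tokens ranked strictly above each token
    mass_before = sorted_probs.cumsum(dim=-1) - sorted_probs
    # Drop a token once the tokens above it already reach p (the top token is always kept)
    sorted_logits = sorted_logits.masked_fill(mass_before >= p, float("-inf"))
    # Undo the sort so positions line up with the vocabulary again
    return torch.empty_like(logits).scatter(-1, sorted_idx, sorted_logits)


def select_token(logits, do_sample=False, temperature=1.0, top_k=0, top_p=1.0, generator=None):
    """Greedy if do_sample is False, otherwise sample from the filtered distribution."""
    if not do_sample:
        return int(logits.argmax())
    logits = apply_temperature(logits, temperature)
    logits = top_k_filter(logits, top_k)
    logits = top_p_filter(logits, top_p)
    probs = F.softmax(logits, dim=-1)  # filtered tokens get probability 0
    return int(torch.multinomial(probs, num_samples=1, generator=generator))
```

For example, passing the overview distribution as log-probabilities (`torch.log(torch.tensor([0.35, 0.40, 0.15, 0.08, 0.02]))`), `top_k_filter(..., 3)` keeps "The", "dog", and "cat", while `top_p_filter(..., 0.7)` keeps only "dog" and "The". Temperature is applied before the filters so that top-p measures mass on the rescaled distribution, which is one common ordering but not the only one.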
GenerationMixin Class
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple, Dict
from dataclasses import dataclass, fields, replace


@dataclass
class GenerationConfig:
    """Configuration for text generation."""
    max_length: int = 100
    min_length: int = 1
    do_sample: bool = False
    temperature: float = 1.0
    top_k: int = 0        # 0 = disabled
    top_p: float = 1.0    # 1.0 = disabled
    num_beams: int = 1
    repetition_penalty: float = 1.0
    length_penalty: float = 1.0
    early_stopping: bool = True
    pad_token_id: int = 0
    bos_token_id: int = 2
    eos_token_id: int = 3


class GenerationMixin:
    """
    Mixin class providing generation capabilities to transformer models.

    This provides a unified interface for various decoding strategies.
    The _greedy_decode, _sample, and _beam_search methods are expected
    to be implemented in later sections (or by the model class).
    """

    @torch.no_grad()
    def generate(
        self,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        config: Optional[GenerationConfig] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Generate sequences using the specified strategy.

        Args:
            encoder_output: Encoded source [batch, src_len, d_model]
            src_mask: Source padding mask
            config: Generation configuration
            **kwargs: Override config parameters (e.g. num_beams=4)

        Returns:
            generated: Generated token IDs [batch, seq_len]
        """
        if config is None:
            config = GenerationConfig()

        # Reject typos instead of silently ignoring them
        unknown = set(kwargs) - {f.name for f in fields(config)}
        if unknown:
            raise ValueError(f"Unknown generation parameters: {sorted(unknown)}")

        # Override on a copy so the caller's config object is not mutated
        config = replace(config, **kwargs)

        # Select decoding strategy
        if config.num_beams > 1:
            return self._beam_search(encoder_output, src_mask, config)
        elif config.do_sample:
            return self._sample(encoder_output, src_mask, config)
        else:
            return self._greedy_decode(encoder_output, src_mask, config)
```

Key Concepts Summary
Autoregressive Generation Properties
| Property | Description |
|---|---|
| Sequential | One token at a time |
| Conditional | Each token depends on all previous tokens |
| Deterministic or stochastic | Greedy and beam search are deterministic; sampling introduces randomness |
| Variable length | Stops at EOS or max_length |
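In a batch, sequences reach EOS at different steps, so "variable length" in practice usually means padding everything after each sequence's first EOS. The sketch below shows one way to do that and to recover per-sequence lengths. It assumes the `GenerationConfig` default IDs (`pad_token_id=0`, `eos_token_id=3`), and the helper names `pad_after_eos` and `sequence_lengths` are hypothetical.

```python
import torch


def pad_after_eos(tokens: torch.Tensor, eos_token_id: int = 3, pad_token_id: int = 0) -> torch.Tensor:
    """Replace every token after each row's first EOS with padding. tokens: [batch, seq_len]."""
    is_eos = (tokens == eos_token_id).long()
    # Number of EOS tokens strictly before each position; > 0 means "already finished"
    eos_before = torch.cumsum(is_eos, dim=1) - is_eos
    return tokens.masked_fill(eos_before > 0, pad_token_id)


def sequence_lengths(tokens: torch.Tensor, eos_token_id: int = 3) -> torch.Tensor:
    """Length up to and including the first EOS; rows without EOS ran to max_length."""
    is_eos = tokens == eos_token_id
    first_eos = is_eos.long().argmax(dim=1)  # index of the first EOS (0 if there is none)
    full = torch.full_like(first_eos, tokens.size(1))
    return torch.where(is_eos.any(dim=1), first_eos + 1, full)


# Three sequences that stop at different steps
batch = torch.tensor([
    [2, 4, 6, 3, 7, 7],  # EOS at position 3
    [2, 5, 7, 7, 3, 4],  # EOS at position 4
    [2, 4, 5, 6, 7, 7],  # no EOS: stopped by max_length
])
print(pad_after_eos(batch))
print(sequence_lengths(batch))  # tensor([4, 5, 6])
```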
Training vs Inference
| Aspect | Training | Inference |
|---|---|---|
| Target known | Yes (teacher forcing) | No |
| Forward passes | 1 per batch | 1 per generated token (N for N tokens) |
| Parallelization | Full sequence | Token-by-token |
| Masking | Causal mask | Causal mask (implicit with KV-caching) |
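The "1 pass vs. N passes" row can be checked directly. The sketch below uses a hypothetical toy decoder (an embedding, one causally masked `torch.nn.MultiheadAttention` layer, and an output projection, with no encoder). It shows that a single teacher-forced pass with a causal mask yields the same next-token logits as running the model separately on each growing prefix, which is what inference does.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyCausalDecoder(nn.Module):
    """Hypothetical minimal decoder: embedding -> causal self-attention -> vocab logits."""

    def __init__(self, vocab_size: int = 10, d_model: int = 16, max_len: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        self.attn = nn.MultiheadAttention(d_model, num_heads=2, batch_first=True)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        seq_len = tokens.size(1)
        positions = torch.arange(seq_len, device=tokens.device)
        x = self.embed(tokens) + self.pos(positions)
        # Boolean mask: True above the diagonal = "may not attend to this future position"
        causal_mask = torch.ones(seq_len, seq_len, device=tokens.device).triu(diagonal=1).bool()
        h, _ = self.attn(x, x, x, attn_mask=causal_mask)
        return self.out(h + x)  # [batch, seq_len, vocab]


model = ToyCausalDecoder().eval()
decoder_input = torch.tensor([[2, 4, 6, 7]])  # e.g. [<bos>, The, dog, runs]

with torch.no_grad():
    # Training style: ONE forward pass scores every position at once
    parallel_logits = model(decoder_input)  # [1, 4, vocab]

    # Inference style: one forward pass per prefix, keeping only the last position
    stepwise_logits = torch.stack(
        [model(decoder_input[:, : t + 1])[:, -1, :] for t in range(decoder_input.size(1))],
        dim=1,
    )  # [1, 4, vocab]

print(torch.allclose(parallel_logits, stepwise_logits, atol=1e-5))  # True
```

The results match because, with the causal mask, position t can only see positions 0..t, exactly the information available when the prefix of length t+1 is processed alone. The N separate passes redo that work repeatedly, which is the inefficiency KV-caching targets.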
Decoding Strategies
| Strategy | Deterministic | Quality | Diversity |
|---|---|---|---|
| Greedy | Yes | Medium | Low |
| Sampling | No | Variable | High |
| Top-K | No | Good | Medium |
| Top-P | No | Good | Medium |
| Beam Search | Yes* | High | Low |
*Beam search is deterministic given the same input but explores multiple paths.
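The "Deterministic" column can be checked with the toy distribution from the strategy overview. In this sketch, greedy selection returns the same token on every call. Sampling with `torch.multinomial` spreads its picks across tokens and becomes reproducible only when the random generator is seeded. The seed value is arbitrary.

```python
import torch

# Distribution from the overview: The, dog, cat, runs, <eos>
probs = torch.tensor([0.35, 0.40, 0.15, 0.08, 0.02])

# Greedy: argmax never changes for the same input
greedy_picks = {int(probs.argmax()) for _ in range(100)}

# Sampling: 1000 independent draws land on several different tokens
draws = torch.multinomial(probs, num_samples=1000, replacement=True)
counts = torch.bincount(draws, minlength=len(probs))


def seeded_draws(seed: int) -> torch.Tensor:
    """Seeding a dedicated generator makes sampling repeatable."""
    g = torch.Generator().manual_seed(seed)
    return torch.multinomial(probs, num_samples=20, replacement=True, generator=g)


print(greedy_picks)                                   # {1}, i.e. always "dog"
print(counts / counts.sum())                          # roughly [0.35, 0.40, 0.15, 0.08, 0.02]
print(torch.equal(seeded_draws(0), seeded_draws(0)))  # True
```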
In the next section, we'll implement greedy decoding, the simplest strategy, which always picks the highest-probability token. While simple, it's often surprisingly effective and serves as a baseline for more sophisticated methods.