Greedy decoding is the simplest autoregressive generation strategy: at each step, select the token with the highest probability. It is fast and, despite its simplicity, often produces good results, making it a useful baseline.
The Greedy Algorithm
Core Concept
๐text
At each step t:
  1. Get the probability distribution P(y_t | y_1, y_2, ..., y_{t-1})
  2. Select y_t = argmax P(y_t | y_1, y_2, ..., y_{t-1})
  3. Append y_t to the sequence
  4. Repeat until EOS or max_length
Visual Example
๐text
Source: "Der Hund läuft schnell"

Step 1: P(y_1 | <bos>, encoder)
    "The": 0.45  ← MAX
    "A": 0.30
    "One": 0.15
    ...
    Select: "The"

Step 2: P(y_2 | <bos>, The, encoder)
    "dog": 0.60  ← MAX
    "cat": 0.25
    "man": 0.10
    ...
    Select: "dog"

Step 3: P(y_3 | <bos>, The, dog, encoder)
    "runs": 0.55  ← MAX
    "walks": 0.30
    "is": 0.10
    ...
    Select: "runs"

Step 4: P(y_4 | <bos>, The, dog, runs, encoder)
    "fast": 0.50  ← MAX
    "quickly": 0.35
    ".": 0.10
    ...
    Select: "fast"

Step 5: P(y_5 | <bos>, The, dog, runs, fast, encoder)
    "<eos>": 0.70  ← MAX
    ".": 0.20
    ...
    Select: "<eos>"

Final: "The dog runs fast"
Implementation
Complete Greedy Decoder
๐python
import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class GreedyDecoder:
    """
    Greedy decoding for transformer models.

    At each step, selects the token with highest probability.
    Fast and deterministic, but may miss globally optimal sequences.

    Example:
        >>> decoder = GreedyDecoder(model, bos_token_id=2, eos_token_id=3)
        >>> tokens, scores = decoder.decode(encoder_output, max_length=50)
    """

    def __init__(
        self,
        model: nn.Module,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        """
        Args:
            model: Model exposing ``decode(tokens, encoder_output, memory_mask=...)``
                and ``output_projection(hidden)`` (the two calls made by ``decode``)
            bos_token_id: Beginning of sequence token ID
            eos_token_id: End of sequence token ID
            pad_token_id: Padding token ID
        """
        self.model = model
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    @torch.no_grad()
    def decode(
        self,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        max_length: int = 100,
        return_scores: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Perform greedy decoding.

        Args:
            encoder_output: Encoded source [batch, src_len, d_model]
            src_mask: Source padding mask, forwarded unchanged to
                ``model.decode`` as ``memory_mask``
            max_length: Maximum output length, counting the leading BOS token
            return_scores: Whether to return token log-probabilities

        Returns:
            generated: Generated token IDs [batch, seq_len], starting with BOS;
                positions after a row's EOS hold ``pad_token_id``
            scores: If ``return_scores`` is True, the log-probability of each
                generated token [batch, seq_len - 1] (BOS has no score), with
                0.0 at the padded positions of finished rows so that summing a
                row gives that sequence's log-probability. ``None`` otherwise,
                or when no decoding step was run.
        """
        # Side effect kept from the original: the model is switched to eval mode.
        self.model.eval()
        batch_size = encoder_output.size(0)
        device = encoder_output.device

        # Every row starts with BOS.
        generated = torch.full(
            (batch_size, 1),
            self.bos_token_id,
            dtype=torch.long,
            device=device
        )

        # Rows that have already emitted EOS.
        done = torch.zeros(batch_size, dtype=torch.bool, device=device)

        all_scores: List[torch.Tensor] = []

        for _ in range(max_length - 1):
            # The whole prefix is re-decoded each step; only the last position
            # is needed to pick the next token.
            decoder_output = self.model.decode(
                generated,
                encoder_output,
                memory_mask=src_mask
            )
            logits = self.model.output_projection(decoder_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)

            # Greedy selection: argmax over the vocabulary -> [batch]
            next_tokens = log_probs.argmax(dim=-1)

            if return_scores:
                selected_scores = log_probs.gather(
                    dim=-1,
                    index=next_tokens.unsqueeze(-1)
                ).squeeze(-1)
                # Finished rows emit PAD below, so the argmax score is not the
                # score of anything actually generated: record 0.0 instead.
                all_scores.append(selected_scores.masked_fill(done, 0.0))

            # Finished rows keep receiving PAD so the batch stays rectangular.
            next_tokens = next_tokens.masked_fill(done, self.pad_token_id)
            generated = torch.cat([generated, next_tokens.unsqueeze(1)], dim=1)

            done = done | (next_tokens == self.eos_token_id)

            # Early stopping once every row has produced EOS.
            if done.all():
                break

        scores = torch.stack(all_scores, dim=1) if all_scores else None
        return generated, scores


# Optimized Implementation with Early Stopping
Batch-Aware Early Stopping
๐python
class GreedyDecoderOptimized:
    """
    Greedy decoder with batch-aware early stopping, a minimum-length
    constraint and per-sequence length tracking.
    """

    def __init__(
        self,
        model: nn.Module,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        """
        Args:
            model: Model exposing ``decode(tokens, encoder_output, memory_mask=...)``
                and ``output_projection(hidden)`` (the two calls made by ``decode``)
            bos_token_id: Beginning of sequence token ID
            eos_token_id: End of sequence token ID
            pad_token_id: Padding token ID
        """
        self.model = model
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    @torch.no_grad()
    def decode(
        self,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        max_length: int = 100,
        min_length: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Greedy decoding with min/max length constraints.

        Args:
            encoder_output: Encoded source; only its batch size (dim 0) and
                device are read here, the tensor itself goes to ``model.decode``
            src_mask: Forwarded unchanged to ``model.decode`` as ``memory_mask``
            max_length: Maximum output length, counting the leading BOS token
            min_length: EOS is suppressed during the first ``min_length - 1``
                decoding steps, so no row can finish before then

        Returns:
            generated: [batch, seq_len], starting with BOS; positions after a
                row's EOS hold ``pad_token_id``
            lengths: Number of non-padding tokens in each row, counting BOS
                and (when emitted) EOS [batch]
        """
        # Side effect kept from the original: the model is switched to eval mode.
        self.model.eval()
        batch_size = encoder_output.size(0)
        device = encoder_output.device

        generated = torch.full(
            (batch_size, 1),
            self.bos_token_id,
            dtype=torch.long,
            device=device
        )

        done = torch.zeros(batch_size, dtype=torch.bool, device=device)
        # Starts at 1 because BOS is already in the sequence.
        lengths = torch.ones(batch_size, dtype=torch.long, device=device)

        for step in range(max_length - 1):
            decoder_output = self.model.decode(
                generated, encoder_output, memory_mask=src_mask
            )
            logits = self.model.output_projection(decoder_output[:, -1, :])

            # Minimum length: make EOS unselectable during the early steps.
            if step < min_length - 1:
                logits[:, self.eos_token_id] = float('-inf')

            # Greedy pick; rows that already finished get PAD instead.
            next_tokens = logits.argmax(dim=-1).masked_fill(
                done, self.pad_token_id
            )
            generated = torch.cat([generated, next_tokens.unsqueeze(1)], dim=1)

            # Only rows still running at the start of this step grow by one.
            lengths = lengths + (~done).long()

            done = done | (next_tokens == self.eos_token_id)
            if done.all():
                break

        return generated, lengths


# Greedy Decoding Analysis
Pros and Cons
๐text
ADVANTAGES:
──────────────────────────────────────────────────
✓ Simple to implement
✓ Fast - O(n) forward passes for length n
✓ Deterministic - same input → same output
✓ No hyperparameters to tune
✓ Memory efficient - no need to store alternatives

DISADVANTAGES:
──────────────────────────────────────────────────
✗ Locally optimal, not globally optimal
✗ Can get stuck in repetitive loops
✗ No diversity in outputs
✗ Cannot recover from early mistakes
✗ May miss better sequences that require "risky" first tokens
Local vs Global Optimum
๐text
Suppose we're generating a translation with these probabilities:

Path A (Greedy):
    P("The") = 0.6 → P("cat") = 0.3 → P("sleeps") = 0.2
    Total: 0.6 × 0.3 × 0.2 = 0.036

Path B (Better):
    P("A") = 0.3 → P("cat") = 0.8 → P("sleeps") = 0.9
    Total: 0.3 × 0.8 × 0.9 = 0.216

Greedy picks Path A because P("The") > P("A")
But Path B has 6× higher total probability!

This is why beam search can outperform greedy decoding.
Repetition Problem
Detecting and Handling Repetition
๐python
class GreedyDecoderWithRepetitionPenalty:
    """
    Greedy decoder helpers for avoiding loops: a repetition penalty on
    already-generated tokens and hard blocking of repeated n-grams.
    """

    def __init__(
        self,
        model: nn.Module,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        """
        Args:
            model: Stored on the instance; not used by the helpers below
            bos_token_id: Beginning of sequence token ID (never penalized)
            eos_token_id: End of sequence token ID
            pad_token_id: Padding token ID (never penalized)
        """
        self.model = model
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    def _apply_repetition_penalty(
        self,
        logits: torch.Tensor,
        generated: torch.Tensor,
        penalty: float = 1.2
    ) -> torch.Tensor:
        """
        Apply repetition penalty to discourage repeated tokens.

        Each token that occurs in a row of ``generated`` (other than PAD and
        BOS) has its logit in the same row of ``logits`` pushed down once,
        however many times the token occurs.

        Args:
            logits: [batch, vocab_size]; modified in place
            generated: Previously generated tokens [batch, seq_len]
            penalty: Penalty factor (>1 discourages repetition)

        Returns:
            penalized_logits: [batch, vocab_size] (the same tensor as ``logits``)
        """
        # Current logit of every token position in the history: [batch, seq_len]
        seen_logits = logits.gather(1, generated)

        # Positive logits are divided and non-positive ones multiplied, so the
        # score always moves down for penalty > 1.
        penalized = torch.where(
            seen_logits > 0, seen_logits / penalty, seen_logits * penalty
        )

        # PAD and BOS keep their original logits.
        exempt = (generated == self.pad_token_id) | (generated == self.bos_token_id)
        penalized = torch.where(exempt, seen_logits, penalized)

        # Duplicate tokens in a row all write the same value (computed from
        # the same pre-penalty logit), so the penalty is applied only once.
        logits.scatter_(1, generated, penalized)
        return logits

    def _block_repeated_ngrams(
        self,
        logits: torch.Tensor,
        generated: torch.Tensor,
        n: int
    ) -> torch.Tensor:
        """
        Block tokens that would create repeated n-grams.

        For each row, if the last ``n - 1`` generated tokens already occurred
        earlier followed by some token ``t``, the logit of ``t`` is set to
        ``-inf`` so the same n-gram cannot be produced again.

        Args:
            logits: [batch, vocab_size]; modified in place
            generated: Previously generated tokens [batch, seq_len]
            n: N-gram size. ``n == 1`` blocks every token already present in
                the row; ``n <= 0`` disables blocking.

        Returns:
            logits: [batch, vocab_size] (the same tensor as ``logits``)
        """
        _, seq_len = generated.shape

        # Nothing to block when disabled or when the row holds no complete n-gram.
        if n < 1 or seq_len < n:
            return logits

        prefix_len = n - 1
        for row, seq in enumerate(generated.tolist()):
            # Sliced from the front so prefix_len == 0 gives an empty prefix
            # (seq[-0:] would be the whole sequence).
            prefix = seq[seq_len - prefix_len:]

            for start in range(seq_len - n + 1):
                if seq[start:start + prefix_len] == prefix:
                    # This token completed the n-gram last time; forbid it now.
                    logits[row, seq[start + prefix_len]] = float('-inf')

        return logits


# Repetition Problem Example
๐text
Without repetition handling, models can get stuck in loops:

Input: "Translate: Der Hund"

Step 1: "The"
Step 2: "The dog"
Step 3: "The dog is"
Step 4: "The dog is a"
Step 5: "The dog is a dog"
Step 6: "The dog is a dog is"
Step 7: "The dog is a dog is a"
Step 8: "The dog is a dog is a dog"
... (continues until max_length)

With repetition penalty (1.2):
- Once "dog" has appeared, its logit is scaled down at every later step
- Eventually other tokens become more likely

With n-gram blocking (n=3):
- After "The dog is a dog is", the last two tokens are "dog is"
- The 3-gram "dog is a" already occurred, so "a" is blocked at step 7
- Prevents exact repetition of any 3-gram
Complete Greedy Decoder for Translation
Production-Ready Implementation
๐python
class TranslationGreedyDecoder:
    """
    Greedy decoder for machine translation: tokenizes the source, runs the
    model's encoder, decodes greedily and converts the result back to text.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        device: Optional[torch.device] = None
    ):
        """
        Args:
            model: Model exposing ``encode(src)``,
                ``decode(tokens, encoder_output, memory_mask=...)`` and
                ``output_projection(hidden)``; it is moved to ``device`` and
                put in eval mode here
            tokenizer: Object exposing ``encode(text)``, ``decode(ids)`` and
                the ``bos_token_id`` / ``eos_token_id`` / ``pad_token_id``
                attributes
            device: Target device; defaults to CUDA when available, else CPU
        """
        self.model = model
        self.tokenizer = tokenizer
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Get special token IDs from tokenizer
        self.bos_id = tokenizer.bos_token_id
        self.eos_id = tokenizer.eos_token_id
        self.pad_id = tokenizer.pad_token_id

        self.model.to(self.device)
        self.model.eval()

    def translate(
        self,
        source_text: str,
        max_length: int = 100,
        repetition_penalty: float = 1.0
    ) -> str:
        """
        Translate a single source sentence.

        Args:
            source_text: Source language text
            max_length: Maximum output length, counting the leading BOS token
            repetition_penalty: Penalty for repeated tokens (1.0 disables it)

        Returns:
            translation: Translated text
        """
        source_ids = self.tokenizer.encode(source_text)
        source_tensor = torch.tensor(
            [source_ids], dtype=torch.long, device=self.device
        )

        with torch.no_grad():
            encoder_output = self.model.encode(source_tensor)

        generated = self._decode(
            encoder_output,
            max_length=max_length,
            repetition_penalty=repetition_penalty
        )

        return self.tokenizer.decode(
            self._strip_special_tokens(generated[0].tolist())
        )

    def translate_batch(
        self,
        source_texts: List[str],
        max_length: int = 100,
        batch_size: int = 32,
        repetition_penalty: float = 1.0
    ) -> List[str]:
        """
        Translate multiple sentences, ``batch_size`` at a time.

        Args:
            source_texts: Source language texts
            max_length: Maximum output length, counting the leading BOS token
            batch_size: Number of sentences decoded together
            repetition_penalty: Penalty for repeated tokens (1.0 disables it)

        Returns:
            translations: One translated text per input, in input order
        """
        translations: List[str] = []

        for start in range(0, len(source_texts), batch_size):
            batch = source_texts[start:start + batch_size]

            # Tokenize and right-pad to the longest source in the batch.
            source_ids = [self.tokenizer.encode(text) for text in batch]
            max_src_len = max(len(ids) for ids in source_ids)
            padded = [
                ids + [self.pad_id] * (max_src_len - len(ids))
                for ids in source_ids
            ]
            source_tensor = torch.tensor(
                padded, dtype=torch.long, device=self.device
            )

            # NOTE(review): no padding mask is passed to the encoder or the
            # decoder, so padded source positions are visible to the model.
            # The mask convention of model.encode / memory_mask is not visible
            # here -- confirm and wire one through if the model needs it.
            with torch.no_grad():
                encoder_output = self.model.encode(source_tensor)

            generated = self._decode(
                encoder_output,
                max_length=max_length,
                repetition_penalty=repetition_penalty
            )

            translations.extend(
                self.tokenizer.decode(self._strip_special_tokens(seq))
                for seq in generated.tolist()
            )

        return translations

    @torch.no_grad()
    def _decode(
        self,
        encoder_output: torch.Tensor,
        max_length: int = 100,
        repetition_penalty: float = 1.0,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Greedy decoding loop shared by ``translate`` and ``translate_batch``.

        Args:
            encoder_output: Output of ``model.encode``; dim 0 is the batch
            max_length: Maximum output length, counting the leading BOS token
            repetition_penalty: Penalty for repeated tokens (1.0 disables it)
            src_mask: Forwarded unchanged to ``model.decode`` as ``memory_mask``

        Returns:
            generated: Token IDs [batch, seq_len], starting with BOS; positions
                after a row's EOS hold the pad ID
        """
        batch_size = encoder_output.size(0)
        device = encoder_output.device

        generated = torch.full(
            (batch_size, 1), self.bos_id, dtype=torch.long, device=device
        )
        done = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length - 1):
            decoder_output = self.model.decode(
                generated, encoder_output, memory_mask=src_mask
            )
            logits = self.model.output_projection(decoder_output[:, -1, :])

            if repetition_penalty != 1.0:
                logits = self._penalize_repeats(
                    logits, generated, repetition_penalty
                )

            # Greedy pick; rows that already finished get PAD instead.
            next_tokens = logits.argmax(dim=-1).masked_fill(done, self.pad_id)
            generated = torch.cat([generated, next_tokens.unsqueeze(1)], dim=1)

            done = done | (next_tokens == self.eos_id)
            if done.all():
                break

        return generated

    def _penalize_repeats(
        self,
        logits: torch.Tensor,
        generated: torch.Tensor,
        penalty: float
    ) -> torch.Tensor:
        """Return a copy of ``logits`` with already-generated tokens penalized.

        Positive logits are divided by ``penalty`` and non-positive ones
        multiplied by it; PAD and BOS are left untouched.
        """
        seen_logits = logits.gather(1, generated)
        penalized = torch.where(
            seen_logits > 0, seen_logits / penalty, seen_logits * penalty
        )
        exempt = (generated == self.pad_id) | (generated == self.bos_id)
        return logits.scatter(
            1, generated, torch.where(exempt, seen_logits, penalized)
        )

    def _strip_special_tokens(self, token_ids: List[int]) -> List[int]:
        """Drop the leading BOS, everything from the first EOS on, and PADs."""
        if token_ids and token_ids[0] == self.bos_id:
            token_ids = token_ids[1:]
        if self.eos_id in token_ids:
            token_ids = token_ids[:token_ids.index(self.eos_id)]
        return [tok for tok in token_ids if tok != self.pad_id]


# Summary
| Aspect | Description |
|---|---|
| Algorithm | Select argmax at each step |
| Complexity | O(n) forward passes |
| Deterministic | Yes |
| Quality | Good for short sequences |
When to Use Greedy
- Real-time translation where speed matters
- When outputs are short and simple
- As a baseline for comparing other methods
- When deterministic output is required
In the next section, we'll implement beam search—a more sophisticated decoding strategy that maintains multiple hypotheses and often finds better sequences than greedy decoding.