Chapter 9

Greedy Decoding

Autoregressive Generation

Greedy decoding is the simplest autoregressive generation strategy: at each step, select the token with the highest probability. It is fast and often produces good results, which makes it a useful baseline.


The Greedy Algorithm

Core Concept

text
At each step t:
  1. Get probability distribution P(yₜ | y₁, y₂, ..., yₜ₋₁)
  2. Select yₜ = argmax P(yₜ | ...)
  3. Append yₜ to sequence
  4. Repeat until EOS or max_length
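
To make steps 1 and 2 concrete, here is a minimal sketch of a single greedy step in plain Python. The three-word vocabulary and the logits are made up for illustration; a real model would produce one score per entry of its full vocabulary.

```python
import math


def softmax(logits):
    """Convert raw scores to probabilities (numerically stable form)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def greedy_step(logits):
    """Return the index of the most probable token and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return best, probs[best]


# Hypothetical vocabulary and logits for one decoding step
vocab = ["The", "A", "One"]
logits = [2.0, 1.6, 0.9]

index, prob = greedy_step(logits)
print(vocab[index], round(prob, 3))
```

Because softmax preserves the ordering of the scores, the argmax of the raw logits is the same token. A decoder that does not need probabilities can therefore skip the softmax.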

Visual Example

text
Source: "Der Hund läuft schnell"

Step 1: P(y₁|<bos>, encoder)
        "The": 0.45 ← MAX
        "A":   0.30
        "One": 0.15
        ...
        Select: "The"

Step 2: P(y₂|<bos>, The, encoder)
        "dog": 0.60 ← MAX
        "cat": 0.25
        "man": 0.10
        ...
        Select: "dog"

Step 3: P(y₃|<bos>, The, dog, encoder)
        "runs": 0.55 ← MAX
        "walks": 0.30
        "is": 0.10
        ...
        Select: "runs"

Step 4: P(y₄|<bos>, The, dog, runs, encoder)
        "fast": 0.50 ← MAX
        "quickly": 0.35
        ".": 0.10
        ...
        Select: "fast"

Step 5: P(y₅|<bos>, The, dog, runs, fast, encoder)
        "<eos>": 0.70 ← MAX
        ".": 0.20
        ...
        Select: "<eos>"

Final: "The dog runs fast"
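
The walk-through above can be replayed with a lookup table standing in for the model. This is only a sketch: the table holds the candidates listed in the example (the probability mass hidden behind "..." is left out), and `greedy_decode` is a name chosen here for illustration.

```python
# Next-token distributions from the worked example, keyed by the prefix so far.
TABLE = {
    (): {"The": 0.45, "A": 0.30, "One": 0.15},
    ("The",): {"dog": 0.60, "cat": 0.25, "man": 0.10},
    ("The", "dog"): {"runs": 0.55, "walks": 0.30, "is": 0.10},
    ("The", "dog", "runs"): {"fast": 0.50, "quickly": 0.35, ".": 0.10},
    ("The", "dog", "runs", "fast"): {"<eos>": 0.70, ".": 0.20},
}


def greedy_decode(table, max_length=10):
    """Pick the most probable token at each step until <eos> or max_length."""
    prefix = ()
    for _ in range(max_length):
        dist = table[prefix]
        token = max(dist, key=dist.get)  # greedy choice
        if token == "<eos>":
            break
        prefix += (token,)
    return " ".join(prefix)


print(greedy_decode(TABLE))  # The dog runs fast
```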

Implementation

Complete Greedy Decoder

python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List


class GreedyDecoder:
    """
    Greedy decoding for transformer models.

    At each step, selects the token with highest probability.
    Fast and deterministic, but may miss globally optimal sequences.

    Example:
        >>> decoder = GreedyDecoder(model, bos_token_id=2, eos_token_id=3)
        >>> tokens, _ = decoder.decode(encoder_output, max_length=50)
    """

    def __init__(
        self,
        model: nn.Module,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        """
        Args:
            model: Transformer model with decode method
            bos_token_id: Beginning of sequence token ID
            eos_token_id: End of sequence token ID
            pad_token_id: Padding token ID
        """
        self.model = model
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    @torch.no_grad()
    def decode(
        self,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        max_length: int = 100,
        return_scores: bool = False
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Perform greedy decoding.

        Args:
            encoder_output: Encoded source [batch, src_len, d_model]
            src_mask: Source padding mask [batch, 1, 1, src_len]
            max_length: Maximum output length (including BOS)
            return_scores: Whether to return token log-probabilities

        Returns:
            generated: Generated token IDs [batch, seq_len], starting with BOS
            scores: Log-probabilities of the generated tokens
                [batch, seq_len - 1] if return_scores=True, else None
        """
        self.model.eval()
        batch_size = encoder_output.size(0)
        device = encoder_output.device

        # Initialize with BOS token
        generated = torch.full(
            (batch_size, 1),
            self.bos_token_id,
            dtype=torch.long,
            device=device
        )

        # Track which sequences are done
        done = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # Optional: track scores
        all_scores = [] if return_scores else None

        for step in range(max_length - 1):
            # Run the decoder on everything generated so far
            # decoder_output: [batch, seq_len, d_model]
            decoder_output = self.model.decode(
                generated,
                encoder_output,
                memory_mask=src_mask
            )

            # Get logits for last position
            # [batch, vocab_size]
            logits = self.model.output_projection(decoder_output[:, -1, :])

            # Convert to log-probabilities
            log_probs = F.log_softmax(logits, dim=-1)

            # Greedy selection: argmax
            next_tokens = log_probs.argmax(dim=-1)  # [batch]

            # Get scores for selected tokens
            if return_scores:
                selected_scores = log_probs.gather(
                    dim=-1,
                    index=next_tokens.unsqueeze(-1)
                ).squeeze(-1)
                # Finished sequences only receive PAD, so they score 0
                selected_scores = selected_scores.masked_fill(done, 0.0)
                all_scores.append(selected_scores)

            # Replace with PAD for finished sequences
            next_tokens = next_tokens.masked_fill(done, self.pad_token_id)

            # Append to generated sequence
            generated = torch.cat([
                generated,
                next_tokens.unsqueeze(1)
            ], dim=1)

            # Update done status
            done = done | (next_tokens == self.eos_token_id)

            # Early stopping if all sequences are done
            if done.all():
                break

        # Stack scores if requested
        scores = None
        if return_scores and all_scores:
            scores = torch.stack(all_scores, dim=1)

        return generated, scores
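
The `done` mask is the part of the batched loop that is easiest to get wrong. The sketch below replays the same bookkeeping in plain Python, with scripted "model outputs" instead of a real network. The token IDs in `scripted_argmax` and the function name `batched_greedy` are invented for illustration.

```python
BOS, EOS, PAD = 2, 3, 0

# Hypothetical argmax results per step for a batch of two sequences.
# Sequence 0 emits EOS at step 2, sequence 1 at step 4.
scripted_argmax = [
    [11, 21],
    [EOS, 22],
    [99, 23],   # 99 is what the model "would" emit after EOS; it is ignored
    [99, EOS],
    [99, 99],   # never reached because of early stopping
]


def batched_greedy(steps, max_length=10):
    batch_size = len(steps[0])
    generated = [[BOS] for _ in range(batch_size)]
    done = [False] * batch_size
    for step_tokens in steps[: max_length - 1]:
        for i, token in enumerate(step_tokens):
            # Finished sequences are padded instead of extended
            token = PAD if done[i] else token
            generated[i].append(token)
            done[i] = done[i] or token == EOS
        if all(done):
            break  # every sequence has produced EOS
    return generated


for row in batched_greedy(scripted_argmax):
    print(row)
# [2, 11, 3, 0, 0]
# [2, 21, 22, 23, 3]
```

Both rows end up with the same length, which is what lets the real implementation keep the batch in a single rectangular tensor.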

Optimized Implementation with Early Stopping

Batch-Aware Early Stopping

python
class GreedyDecoderOptimized:
    """
    Optimized greedy decoder with proper batch handling.
    """

    def __init__(
        self,
        model: nn.Module,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        self.model = model
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    @torch.no_grad()
    def decode(
        self,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        max_length: int = 100,
        min_length: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Optimized greedy decoding with min/max length constraints.

        Returns:
            generated: [batch, seq_len]
            lengths: Actual length of each sequence [batch],
                counting BOS and EOS but not PAD
        """
        self.model.eval()
        batch_size = encoder_output.size(0)
        device = encoder_output.device

        # Initialize
        generated = torch.full(
            (batch_size, 1),
            self.bos_token_id,
            dtype=torch.long,
            device=device
        )

        # Track state
        done = torch.zeros(batch_size, dtype=torch.bool, device=device)
        lengths = torch.ones(batch_size, dtype=torch.long, device=device)

        for step in range(max_length - 1):
            # Decode
            decoder_output = self.model.decode(
                generated, encoder_output, memory_mask=src_mask
            )

            # Get next token logits: [batch, vocab_size]
            logits = self.model.output_projection(decoder_output[:, -1, :])

            # Apply minimum length constraint: EOS cannot be selected
            # during the first (min_length - 1) steps
            if step < min_length - 1:
                logits[:, self.eos_token_id] = float('-inf')

            # Greedy selection (argmax of logits == argmax of probabilities)
            next_tokens = logits.argmax(dim=-1)

            # Finished sequences only receive PAD
            next_tokens = next_tokens.masked_fill(done, self.pad_token_id)

            # Append
            generated = torch.cat([generated, next_tokens.unsqueeze(1)], dim=1)

            # Update lengths for non-finished sequences
            lengths = lengths + (~done).long()

            # Update done status
            done = done | (next_tokens == self.eos_token_id)

            if done.all():
                break

        return generated, lengths
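
The minimum-length rule is easier to check on a single row of logits. The following sketch mirrors the `step < min_length - 1` condition from the class above in plain Python. The logits are made up, `EOS = 3` matches the default `eos_token_id`, and `pick_token` is a name chosen here.

```python
EOS = 3  # hypothetical EOS token ID, matching the default above


def pick_token(logits, step, min_length):
    """Greedy pick with EOS disabled while step < min_length - 1."""
    scores = list(logits)  # copy so the caller's list is untouched
    if step < min_length - 1:
        scores[EOS] = float("-inf")
    return max(range(len(scores)), key=scores.__getitem__)


# Made-up logits where EOS (index 3) is the favourite
logits = [0.1, 0.5, 0.2, 2.0, 1.5]

print(pick_token(logits, step=0, min_length=3))  # 4 (EOS masked)
print(pick_token(logits, step=2, min_length=3))  # 3 (EOS allowed)
```

Under this rule `min_length=3` blocks EOS at steps 0 and 1, so EOS can appear no earlier than the third generated token after BOS.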

Greedy Decoding Analysis

Pros and Cons

text
ADVANTAGES:
─────────────────────────────────────────────────────
✓ Simple to implement
✓ Fast - O(n) forward passes for length n
✓ Deterministic - same input → same output
✓ No hyperparameters to tune
✓ Memory efficient - no need to store alternatives

DISADVANTAGES:
─────────────────────────────────────────────────────
✗ Locally optimal, not globally optimal
✗ Can get stuck in repetitive loops
✗ No diversity in outputs
✗ Cannot recover from early mistakes
✗ May miss better sequences that require "risky" first tokens
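
The O(n) figure counts forward passes. In the decoders shown earlier in this section, each pass receives the whole prefix generated so far, so the number of token positions fed to the decoder grows faster than n. A back-of-the-envelope sketch, assuming no early stop (the function name is made up):

```python
def decoding_cost(max_length):
    """Count forward passes and decoder input positions for a loop that
    re-feeds the full prefix at every step."""
    passes = 0
    positions = 0
    prefix_len = 1  # the sequence starts with BOS
    for _ in range(max_length - 1):
        passes += 1
        positions += prefix_len  # the whole prefix goes through the decoder
        prefix_len += 1          # one token is appended per step
    return passes, positions


print(decoding_cost(100))  # (99, 4950)
```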

Local vs Global Optimum

text
Suppose we're generating a translation with these probabilities:

Path A (Greedy):
    P("The") = 0.6  → P("cat") = 0.3 → P("sleeps") = 0.2
    Total: 0.6 × 0.3 × 0.2 = 0.036

Path B (Better):
    P("A")   = 0.3  → P("cat") = 0.8 → P("sleeps") = 0.9
    Total: 0.3 × 0.8 × 0.9 = 0.216

Greedy picks Path A because P("The") > P("A")
But Path B has 6× higher total probability!

This is why beam search can outperform greedy decoding.
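
The same comparison can be checked in log space, which is how sequence scores are usually accumulated in practice. This short sketch uses only the probabilities given above; `sequence_log_prob` is a helper name introduced here.

```python
import math

path_a = [0.6, 0.3, 0.2]  # "The" -> "cat" -> "sleeps" (greedy)
path_b = [0.3, 0.8, 0.9]  # "A"   -> "cat" -> "sleeps"


def sequence_log_prob(step_probs):
    """Sum of per-step log-probabilities, i.e. the log of their product."""
    return sum(math.log(p) for p in step_probs)


log_a = sequence_log_prob(path_a)
log_b = sequence_log_prob(path_b)

print(round(math.exp(log_a), 3))          # 0.036
print(round(math.exp(log_b), 3))          # 0.216
print(round(math.exp(log_b - log_a), 2))  # 6.0
```

Greedy commits to Path A at the first step because 0.6 > 0.3, and it never sees the much larger probabilities that follow "A".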

Repetition Problem

Detecting and Handling Repetition

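python
class GreedyDecoderWithRepetitionPenalty:
    """
    Greedy decoder with repetition penalty to avoid loops.
    """

    def __init__(
        self,
        model: nn.Module,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        self.model = model
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    def _apply_repetition_penalty(
        self,
        logits: torch.Tensor,
        generated: torch.Tensor,
        penalty: float = 1.2
    ) -> torch.Tensor:
        """
        Apply repetition penalty to discourage repeated tokens.

        Args:
            logits: [batch, vocab_size]
            generated: Previously generated tokens [batch, seq_len]
            penalty: Penalty factor (>1 discourages repetition)

        Returns:
            penalized_logits: [batch, vocab_size]
        """
        logits = logits.clone()  # leave the caller's tensor untouched

        # For each sequence in batch
        for i in range(generated.size(0)):
            # Unique tokens already generated, without PAD and BOS
            tokens = generated[i].unique()
            keep = (tokens != self.pad_token_id) & (tokens != self.bos_token_id)
            tokens = tokens[keep]

            # If logit > 0, divide by penalty; if < 0, multiply by penalty.
            # Either way the token becomes less likely.
            values = logits[i, tokens]
            logits[i, tokens] = torch.where(
                values > 0, values / penalty, values * penalty
            )

        return logits

    def _block_repeated_ngrams(
        self,
        logits: torch.Tensor,
        generated: torch.Tensor,
        n: int
    ) -> torch.Tensor:
        """
        Block tokens that would create repeated n-grams.
        """
        batch_size, seq_len = generated.shape

        if n < 1 or seq_len < n:
            return logits

        for i in range(batch_size):
            seq = generated[i].tolist()

            # Get the last (n-1) tokens (empty tuple when n == 1)
            prefix = tuple(seq[len(seq) - (n - 1):])

            # Find all n-grams starting with this prefix
            for j in range(len(seq) - n + 1):
                if tuple(seq[j:j+n-1]) == prefix:
                    # Block the token that would complete this n-gram
                    blocked_token = seq[j + n - 1]
                    logits[i, blocked_token] = float('-inf')

        return logits

    @torch.no_grad()
    def decode(
        self,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        max_length: int = 100,
        repetition_penalty: float = 1.2,
        no_repeat_ngram_size: int = 0
    ) -> torch.Tensor:
        """
        Greedy loop that applies both helpers before the argmax.
        Mirrors GreedyDecoder.decode; the two keyword names are illustrative.
        """
        self.model.eval()
        batch_size = encoder_output.size(0)
        device = encoder_output.device
        generated = torch.full(
            (batch_size, 1), self.bos_token_id, dtype=torch.long, device=device
        )
        done = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length - 1):
            decoder_output = self.model.decode(
                generated, encoder_output, memory_mask=src_mask
            )
            logits = self.model.output_projection(decoder_output[:, -1, :])

            if repetition_penalty != 1.0:
                logits = self._apply_repetition_penalty(
                    logits, generated, repetition_penalty
                )
            if no_repeat_ngram_size > 0:
                logits = self._block_repeated_ngrams(
                    logits, generated, no_repeat_ngram_size
                )

            next_tokens = logits.argmax(dim=-1).masked_fill(done, self.pad_token_id)
            generated = torch.cat([generated, next_tokens.unsqueeze(1)], dim=1)
            done = done | (next_tokens == self.eos_token_id)
            if done.all():
                break

        return generated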
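
The sign check in the penalty deserves a closer look. A minimal sketch on plain floats, using the default penalty of 1.2 (`penalize` is a name introduced here):

```python
def penalize(logit, penalty=1.2):
    """Scale a logit so that it always moves toward 'less likely'."""
    return logit / penalty if logit > 0 else logit * penalty


print(round(penalize(3.0), 2))   # 2.5
print(round(penalize(-3.0), 2))  # -3.6
```

Dividing a negative logit by the penalty would give -2.5, which is higher than -3.0 and would make the repeated token more likely. Multiplying instead keeps the adjustment pointing in the same direction for both signs.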

Repetition Problem Example

text
Without repetition handling, models can get stuck in loops:

Input: "Translate: Der Hund"

Step 1: "The"
Step 2: "The dog"
Step 3: "The dog is"
Step 4: "The dog is a"
Step 5: "The dog is a dog"
Step 6: "The dog is a dog is"
Step 7: "The dog is a dog is a"
Step 8: "The dog is a dog is a dog"
... (continues until max_length)

With repetition penalty (1.2):
- Once "dog" has been generated, its logit is scaled down at every later step
- The penalty is applied once per token, not once per occurrence
- A competing token wins the argmax when its logit is close enough

With n-gram blocking (n=3):
- After step 6 the sequence ends in "dog is"
- The 3-gram "dog is a" already occurred, so "a" is blocked at step 7
- No 3-gram can appear twice in the output
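
To see which token n-gram blocking actually forbids in the loop above, the blocking logic can be run on the words themselves. This is a sketch that follows the same prefix-matching idea as `_block_repeated_ngrams`, but works on a list of strings and returns the blocked tokens instead of editing logits. The function name is made up.

```python
def blocked_tokens(sequence, n):
    """Tokens that would complete an n-gram already present in `sequence`."""
    if n < 1 or len(sequence) < n:
        return set()
    # The last (n - 1) tokens form the start of the next n-gram
    prefix = tuple(sequence[len(sequence) - (n - 1):])
    blocked = set()
    for j in range(len(sequence) - n + 1):
        if tuple(sequence[j:j + n - 1]) == prefix:
            blocked.add(sequence[j + n - 1])
    return blocked


after_step_6 = ["The", "dog", "is", "a", "dog", "is"]
print(blocked_tokens(after_step_6, n=3))  # {'a'}
print(blocked_tokens(after_step_6, n=4))  # set()
```

With n=3 the loop is broken at step 7, because "dog is a" has already been produced. With n=4 nothing is blocked yet, so a larger n lets the repetition run longer before it is cut off.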

Complete Greedy Decoder for Translation

Production-Ready Implementation

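python
class TranslationGreedyDecoder:
    """
    Production-ready greedy decoder for machine translation.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        device: Optional[torch.device] = None
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Get special token IDs from tokenizer
        self.bos_id = tokenizer.bos_token_id
        self.eos_id = tokenizer.eos_token_id
        self.pad_id = tokenizer.pad_token_id

        self.model.to(self.device)
        self.model.eval()

    @torch.no_grad()
    def _decode(
        self,
        encoder_output: torch.Tensor,
        max_length: int = 100,
        repetition_penalty: float = 1.0,
        src_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Greedy loop shared by translate() and translate_batch().

        Returns:
            generated: Token IDs [batch, seq_len], including BOS, EOS and PAD
        """
        batch_size = encoder_output.size(0)
        generated = torch.full(
            (batch_size, 1), self.bos_id, dtype=torch.long, device=self.device
        )
        done = torch.zeros(batch_size, dtype=torch.bool, device=self.device)

        for _ in range(max_length - 1):
            decoder_output = self.model.decode(
                generated, encoder_output, memory_mask=src_mask
            )
            logits = self.model.output_projection(decoder_output[:, -1, :])

            if repetition_penalty != 1.0:
                # Lower the logits of tokens already in the sequence
                # (for brevity this also covers BOS and PAD)
                seen = logits.gather(1, generated)
                seen = torch.where(
                    seen > 0, seen / repetition_penalty, seen * repetition_penalty
                )
                logits = logits.scatter(1, generated, seen)

            # Greedy pick; finished sequences only receive PAD
            next_tokens = logits.argmax(dim=-1).masked_fill(done, self.pad_id)
            generated = torch.cat([generated, next_tokens.unsqueeze(1)], dim=1)

            done = done | (next_tokens == self.eos_id)
            if done.all():
                break

        return generated

    def translate(
        self,
        source_text: str,
        max_length: int = 100,
        repetition_penalty: float = 1.0
    ) -> str:
        """
        Translate a single source sentence.

        Args:
            source_text: Source language text
            max_length: Maximum output length
            repetition_penalty: Penalty for repeated tokens

        Returns:
            translation: Translated text
        """
        # Tokenize source
        source_ids = self.tokenizer.encode(source_text)
        source_tensor = torch.tensor([source_ids], device=self.device)

        # Encode source
        with torch.no_grad():
            encoder_output = self.model.encode(source_tensor)

        # Generate
        generated = self._decode(
            encoder_output,
            max_length=max_length,
            repetition_penalty=repetition_penalty
        )

        # Decode to text. Depending on the tokenizer, BOS/EOS/PAD may need
        # to be stripped from the IDs first.
        translation = self.tokenizer.decode(generated[0].tolist())

        return translation

    def translate_batch(
        self,
        source_texts: List[str],
        max_length: int = 100,
        batch_size: int = 32
    ) -> List[str]:
        """
        Translate multiple sentences.
        """
        translations = []

        for i in range(0, len(source_texts), batch_size):
            batch = source_texts[i:i + batch_size]

            # Tokenize and pad
            source_ids = [self.tokenizer.encode(text) for text in batch]
            max_src_len = max(len(ids) for ids in source_ids)
            padded = [
                ids + [self.pad_id] * (max_src_len - len(ids))
                for ids in source_ids
            ]
            source_tensor = torch.tensor(padded, device=self.device)

            # Encode
            # NOTE: padded source positions are not masked here. If the model
            # accepts a source padding mask, pass it to encode() and to
            # _decode(src_mask=...) so attention ignores the padding.
            with torch.no_grad():
                encoder_output = self.model.encode(source_tensor)

            # Generate
            generated = self._decode(encoder_output, max_length)

            # Decode
            for seq in generated:
                text = self.tokenizer.decode(seq.tolist())
                translations.append(text)

        return translations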
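
The generated ID sequences in this section start with BOS and may end with EOS followed by PAD. Whether `tokenizer.decode` strips these tokens depends on the tokenizer, which the text does not specify. If it does not, a helper along these lines could be applied before decoding. The name `trim_special_tokens` is hypothetical, and the default IDs are the ones used by the decoders earlier in this section.

```python
def trim_special_tokens(token_ids, bos_id=2, eos_id=3, pad_id=0):
    """Drop a leading BOS, cut at the first EOS, and remove PAD tokens."""
    ids = list(token_ids)
    if ids and ids[0] == bos_id:
        ids = ids[1:]
    if eos_id in ids:
        ids = ids[: ids.index(eos_id)]
    return [t for t in ids if t != pad_id]


print(trim_special_tokens([2, 11, 12, 3, 0, 0]))  # [11, 12]
```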

Summary

Aspect        | Description
Algorithm     | Select argmax at each step
Complexity    | O(n) forward passes
Deterministic | Yes
Quality       | Good for short sequences

When to Use Greedy

  • Real-time translation where speed matters
  • When outputs are short and simple
  • As a baseline for comparing other methods
  • When deterministic output is required

In the next section, we'll implement beam search, a more sophisticated decoding strategy that maintains multiple hypotheses and often finds better sequences than greedy decoding.