Chapter 9
Section 47 of 75

Beam Search

Autoregressive Generation

Beam search is the most widely used decoding strategy for machine translation. Instead of greedily selecting a single token at each step, beam search maintains multiple candidate sequences (the "beam"). It explores several paths through the output space to find sequences with a higher overall probability, which usually yields better translations.


The Beam Search Algorithm

Core Concept

```text
Instead of keeping 1 best path (greedy), keep k best paths (beam).

At each step:
1. Expand each of the k current hypotheses with every possible next token
2. Score all k × vocab_size candidates (score = sum of token log-probabilities so far)
3. Keep only the k best overall
4. Repeat until all beams reach EOS or max_length
```
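The four steps can be sketched in plain Python, independent of any neural model. This is a minimal illustration under stated assumptions, not the decoder class developed later: `next_log_probs` is a hypothetical callback that returns next-token log-probabilities for a prefix, and the `TOY` table and `toy_log_probs` are invented for the example. Hypotheses that emit `<eos>` move to a finished list, so in this sketch the beam can shrink as hypotheses complete.

```python
import math
from typing import Callable, Dict, List, Tuple

Hypothesis = Tuple[Tuple[str, ...], float]  # (tokens, summed log-probability)


def beam_search(
    next_log_probs: Callable[[Tuple[str, ...]], Dict[str, float]],
    beam_size: int,
    max_length: int,
    bos: str = "<bos>",
    eos: str = "<eos>",
) -> List[Hypothesis]:
    """Return all hypotheses found, best-first by summed log-probability."""
    beam: List[Hypothesis] = [((bos,), 0.0)]
    finished: List[Hypothesis] = []
    for _ in range(max_length):
        # Steps 1-2: extend every hypothesis with every next token and score it
        candidates = [
            (tokens + (tok,), score + lp)
            for tokens, score in beam
            for tok, lp in next_log_probs(tokens).items()
        ]
        # Step 3: keep the beam_size best candidates across all hypotheses
        candidates.sort(key=lambda c: c[1], reverse=True)
        beam = []
        for tokens, score in candidates[:beam_size]:
            (finished if tokens[-1] == eos else beam).append((tokens, score))
        # Step 4: stop once every kept hypothesis has emitted EOS
        if not beam:
            break
    finished.extend(beam)  # hypotheses still open when max_length is hit
    return sorted(finished, key=lambda c: c[1], reverse=True)


# Hypothetical toy "model": the next-token distribution depends only on the last token
TOY = {
    "<bos>": {"The": 0.6, "A": 0.4},
    "The": {"dog": 0.55, "cat": 0.45},
    "A": {"dog": 0.9, "cat": 0.1},
    "dog": {"<eos>": 1.0},
    "cat": {"<eos>": 1.0},
}


def toy_log_probs(tokens: Tuple[str, ...]) -> Dict[str, float]:
    return {tok: math.log(p) for tok, p in TOY[tokens[-1]].items()}


beam_tokens, beam_score = beam_search(toy_log_probs, beam_size=2, max_length=5)[0]
greedy_tokens, greedy_score = beam_search(toy_log_probs, beam_size=1, max_length=5)[0]
print(beam_tokens, round(math.exp(beam_score), 2))
print(greedy_tokens, round(math.exp(greedy_score), 2))
```

With `beam_size=2` the search returns "A dog" (probability 0.4 × 0.9 = 0.36). Setting `beam_size=1` turns the same function into greedy decoding, which commits to "The" (0.6 > 0.4) and ends at 0.6 × 0.55 = 0.33. The better sequence is reachable only by keeping the locally worse prefix "A" alive for one more step.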

Visual Example (Beam Width = 2)

```text
Step 0: Start with <bos>
        Beam: [(<bos>, score=0.0)]

Step 1: Expand <bos> → all vocab tokens
        Candidates:
          (<bos>, The)  score=-0.5
          (<bos>, A)    score=-0.8
          (<bos>, One)  score=-1.2
          (<bos>, My)   score=-1.5
          ...

        Keep top 2:
        Beam: [(<bos>, The, score=-0.5),
               (<bos>, A, score=-0.8)]

Step 2: Expand both hypotheses
        From "The":
          (<bos>, The, dog)   score=-1.0
          (<bos>, The, cat)   score=-1.2
          (<bos>, The, man)   score=-1.5
          ...

        From "A":
          (<bos>, A, dog)     score=-1.3
          (<bos>, A, cat)     score=-1.4
          ...

        All candidates (2 × vocab_size), keep top 2:
        Beam: [(<bos>, The, dog, score=-1.0),
               (<bos>, The, cat, score=-1.2)]

Step 3: Continue...
        (<bos>, The, dog, runs, score=-1.8)
        (<bos>, The, dog, barks, score=-2.0)
        (<bos>, The, cat, sleeps, score=-1.9)
        ...

        Keep top 2:
        Beam: [(<bos>, The, dog, runs, score=-1.8),
               (<bos>, The, cat, sleeps, score=-1.9)]

Step 4: One beam reaches EOS
        (<bos>, The, dog, runs, <eos>, score=-2.2) ← COMPLETE
        (<bos>, The, cat, sleeps, quietly, score=-2.5)

        Continue with incomplete beams until all done or max_length

Final: Return highest-scoring complete sequence
       "The dog runs"
```
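The key point of Step 2 is that every candidate competes in one pool regardless of which hypothesis produced it: both survivors extend "The", and the "A" branch is dropped entirely. A small sketch of that pruning step, using only the candidate scores listed in the example (the elided "..." candidates are left out):

```python
import heapq

# Cumulative scores from Step 2 of the example (only the listed candidates)
candidates = {
    ("<bos>", "The", "dog"): -1.0,
    ("<bos>", "The", "cat"): -1.2,
    ("<bos>", "The", "man"): -1.5,
    ("<bos>", "A", "dog"): -1.3,
    ("<bos>", "A", "cat"): -1.4,
}

# Keep the 2 highest-scoring candidates across *all* parent hypotheses
beam = heapq.nlargest(2, candidates.items(), key=lambda item: item[1])
for tokens, score in beam:
    print(" ".join(tokens[1:]), score)
```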

Basic Implementation

BeamSearchDecoder Class

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple
from dataclasses import dataclass


@dataclass
class BeamHypothesis:
    """A single hypothesis in beam search."""
    tokens: List[int]  # Generated token IDs
    score: float       # Log probability sum
    is_done: bool = False


class BeamSearchDecoder:
    """
    Beam search decoder for transformer models.

    Maintains multiple hypotheses and explores different paths
    to find higher-scoring sequences.

    Args:
        model: Transformer model exposing decode() and output_projection
        beam_size: Number of beams to maintain
        bos_token_id: Beginning of sequence token
        eos_token_id: End of sequence token
        pad_token_id: Padding token

    Example:
        >>> decoder = BeamSearchDecoder(model, beam_size=5)
        >>> tokens, score = decoder.decode(encoder_output, max_length=50)
    """

    def __init__(
        self,
        model: nn.Module,
        beam_size: int = 5,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        self.model = model
        self.beam_size = beam_size
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    @torch.no_grad()
    def decode(
        self,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        max_length: int = 100,
        length_penalty: float = 1.0,
        early_stopping: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform beam search decoding.

        Args:
            encoder_output: Encoded source [1, src_len, d_model]
                           (batch_size must be 1)
            src_mask: Source padding mask (4-D, e.g. [1, 1, 1, src_len])
            max_length: Maximum output length, including <bos>
            length_penalty: Exponent α in score / length ** α
            early_stopping: Stop once beam_size hypotheses are complete

        Returns:
            best_sequence: Best hypothesis tokens [1, seq_len]
            best_score: Length-normalized score of the best hypothesis
        """
        self.model.eval()
        device = encoder_output.device

        # Expand encoder output for beam_size
        # [1, src_len, d_model] → [beam_size, src_len, d_model]
        encoder_output = encoder_output.expand(self.beam_size, -1, -1)

        if src_mask is not None:
            src_mask = src_mask.expand(self.beam_size, -1, -1, -1)

        # Initialize beams
        # [beam_size, 1]
        beam_tokens = torch.full(
            (self.beam_size, 1),
            self.bos_token_id,
            dtype=torch.long,
            device=device
        )

        # Beam scores [beam_size]
        # All beams start identical, so only the first is active initially;
        # otherwise step 1 would select the same token beam_size times
        beam_scores = torch.zeros(self.beam_size, device=device)
        beam_scores[1:] = float('-inf')

        # Track completed hypotheses as (tokens, normalized score)
        completed_hypotheses: List[Tuple[torch.Tensor, float]] = []

        for step in range(max_length - 1):  # <bos> already fills one position
            # Decode current beams
            decoder_output = self.model.decode(
                beam_tokens,
                encoder_output,
                memory_mask=src_mask
            )

            # Get logits for last position [beam_size, vocab_size]
            logits = self.model.output_projection(decoder_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)

            vocab_size = log_probs.size(-1)

            # Calculate scores for all candidates
            # [beam_size, vocab_size]
            next_scores = beam_scores.unsqueeze(-1) + log_probs

            # Reshape to [beam_size * vocab_size]
            next_scores = next_scores.view(-1)

            # Take the top 2 * beam_size candidates: even if up to beam_size
            # of them end in EOS, beam_size unfinished candidates remain
            top_scores, top_indices = torch.topk(
                next_scores, 2 * self.beam_size, sorted=True
            )

            # Convert flat indices to (beam_idx, token_idx)
            beam_indices = top_indices // vocab_size
            token_indices = top_indices % vocab_size

            # Build new beams
            new_beam_tokens = []
            new_beam_scores = []
            num_active = 0

            for score, beam_idx, token_idx in zip(
                top_scores, beam_indices, token_indices
            ):
                if num_active >= self.beam_size:
                    break

                # Get previous sequence
                prev_tokens = beam_tokens[beam_idx]

                # Create new sequence
                new_tokens = torch.cat([
                    prev_tokens,
                    token_idx.unsqueeze(0)
                ])

                if token_idx.item() == self.eos_token_id:
                    # Completed hypothesis: apply length penalty
                    final_score = score.item() / (len(new_tokens) ** length_penalty)
                    completed_hypotheses.append((new_tokens, final_score))

                    # Check early stopping
                    if early_stopping and len(completed_hypotheses) >= self.beam_size:
                        break
                else:
                    new_beam_tokens.append(new_tokens)
                    new_beam_scores.append(score.item())  # store plain floats
                    num_active += 1

            # Check if we should stop
            if early_stopping and len(completed_hypotheses) >= self.beam_size:
                break

            if num_active == 0:
                break

            # Pad to beam_size if needed (padded beams get -inf, so their
            # extensions rank last at the next step)
            while len(new_beam_tokens) < self.beam_size:
                new_beam_tokens.append(new_beam_tokens[0])
                new_beam_scores.append(float('-inf'))

            # Update beams
            beam_tokens = torch.stack(new_beam_tokens[:self.beam_size])
            beam_scores = torch.tensor(
                new_beam_scores[:self.beam_size],
                device=device
            )

        # If no hypothesis reached EOS, fall back to the best unfinished beam,
        # normalized the same way as completed hypotheses
        if not completed_hypotheses:
            best_idx = beam_scores.argmax()
            best_score = beam_scores[best_idx].item() / (beam_tokens.size(1) ** length_penalty)
            return beam_tokens[best_idx].unsqueeze(0), torch.tensor(best_score)

        # Return best completed hypothesis
        best_hyp = max(completed_hypotheses, key=lambda x: x[1])
        return best_hyp[0].unsqueeze(0), torch.tensor(best_hyp[1])
```
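The least obvious step in `decode` is recovering which beam and which token each flat top-k index refers to. A self-contained check with made-up logits (beam_size = 2, vocab_size = 5): every selected score should equal its parent beam's score plus that token's log-probability. Here `torch.div(..., rounding_mode="floor")` is used as an explicit equivalent of `//` for these non-negative indices.

```python
import torch

beam_size, vocab_size = 2, 5

# Cumulative scores of the current beams: [beam_size]
beam_scores = torch.tensor([-0.5, -0.8])
# Next-token log-probs for each beam, from made-up logits: [beam_size, vocab_size]
log_probs = torch.log_softmax(
    torch.tensor([[2.0, 0.0, 1.0, -1.0, 0.5],
                  [3.0, 0.0, 0.0, 0.0, 0.0]]),
    dim=-1,
)

# Same flattening as in decode(): [beam_size * vocab_size]
next_scores = (beam_scores.unsqueeze(-1) + log_probs).view(-1)
top_scores, top_indices = torch.topk(next_scores, 2 * beam_size)

# Row-major layout: flat_index = beam_idx * vocab_size + token_idx
beam_indices = torch.div(top_indices, vocab_size, rounding_mode="floor")
token_indices = top_indices % vocab_size

for s, b, t in zip(top_scores.tolist(), beam_indices.tolist(), token_indices.tolist()):
    print(f"beam {b}, token {t}: {s:.3f}")
```

The best candidate here comes from beam 1 even though beam 0 had the better running score, because beam 1's next-token distribution is much more confident.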

Optimized Implementation

```python
import torch
import torch.nn as nn


class EfficientBeamSearchDecoder:
    """
    Efficient beam search with proper batching and early stopping.
    """

    def __init__(
        self,
        model: nn.Module,
        beam_size: int = 5,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0,
        length_penalty: float = 1.0,
        no_repeat_ngram_size: int = 0
    ):
        self.model = model
        self.beam_size = beam_size
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.length_penalty = length_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def _length_normalize(
        self,
        scores: torch.Tensor,
        lengths: torch.Tensor
    ) -> torch.Tensor:
        """Apply length penalty: score / length ** length_penalty."""
        return scores / (lengths.float() ** self.length_penalty)

    def _block_ngrams(
        self,
        log_probs: torch.Tensor,
        generated: torch.Tensor,
        n: int
    ) -> torch.Tensor:
        """
        Block tokens that would create repeated n-grams.

        log_probs: [batch, vocab_size], modified in place
        generated: [batch, seq_len] tokens generated so far
        """
        if n == 0 or generated.size(1) < n:
            return log_probs

        batch_size = generated.size(0)

        for i in range(batch_size):
            seq = generated[i].tolist()
            # Last n-1 tokens (empty when n == 1, so every seen token is blocked)
            prefix = tuple(seq[len(seq) - (n - 1):])

            # Find earlier n-grams whose first n-1 tokens match the prefix
            for j in range(len(seq) - n + 1):
                if tuple(seq[j:j+n-1]) == prefix:
                    blocked_token = seq[j + n - 1]
                    log_probs[i, blocked_token] = float('-inf')

        return log_probs
```
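The rule inside `_block_ngrams` is easier to check in isolation. This sketch restates it as a hypothetical helper, `banned_next_tokens` (not part of the class above), that returns the set of forbidden next tokens for one sequence and then writes -inf into those positions of a dummy log-prob row. Returning a set keeps the rule testable without a model.

```python
from typing import List, Set

import torch


def banned_next_tokens(seq: List[int], n: int) -> Set[int]:
    """Tokens that would complete an n-gram already present in seq."""
    if n <= 0 or len(seq) < n:
        return set()
    # The last n-1 tokens (empty when n == 1, so every token seen so far is banned)
    prefix = tuple(seq[len(seq) - (n - 1):])
    banned: Set[int] = set()
    # For every complete n-gram whose first n-1 tokens equal the prefix,
    # emitting its last token now would repeat that n-gram
    for j in range(len(seq) - n + 1):
        if tuple(seq[j:j + n - 1]) == prefix:
            banned.add(seq[j + n - 1])
    return banned


generated = [5, 6, 7, 5, 6]  # the trigram (5, 6, 7) has already occurred
banned = banned_next_tokens(generated, 3)
print(banned)  # {7}

# Apply the ban to one row of next-token log-probs (vocab_size = 10 here)
log_probs = torch.zeros(10)
log_probs[torch.tensor(sorted(banned), dtype=torch.long)] = float("-inf")
```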

Length Penalty

Normalizing Scores by Length

```text
Example hypotheses:

Hypotheses (without length penalty):
  "The dog"                          → Log-prob: -2.0, Length: 3
  "The big brown dog"                → Log-prob: -4.5, Length: 5
  "The very big brown fluffy dog"    → Log-prob: -8.0, Length: 7

Effect of different length penalties (normalized score = log-prob / length^α):

α = 0.0 (No normalization):
  "The dog": -2.0 / 3^0.0 = -2.000  ← Best (prefers shorter)
  "The big brown dog": -4.5 / 5^0.0 = -4.500
  "The very big brown fluffy dog": -8.0 / 7^0.0 = -8.000

α = 0.6 (Mild penalty):
  "The dog": -2.0 / 3^0.6 = -1.035  ← Best
  "The big brown dog": -4.5 / 5^0.6 = -1.713
  "The very big brown fluffy dog": -8.0 / 7^0.6 = -2.489

α = 1.0 (Linear penalty):
  "The dog": -2.0 / 3 = -0.667  ← Best
  "The big brown dog": -4.5 / 5 = -0.900
  "The very big brown fluffy dog": -8.0 / 7 = -1.143

α = 1.5 (Strong penalty):
  "The dog": -2.0 / 3^1.5 = -0.385  ← Best
  "The big brown dog": -4.5 / 5^1.5 = -0.402
  "The very big brown fluffy dog": -8.0 / 7^1.5 = -0.432

In this example "The dog" wins at every α, because its per-token log-prob
is also the best (-0.667 vs -0.900 vs -1.143). Raising α only narrows the
gap between the shortest and longest hypothesis, from 6.0 at α = 0.0 to
0.047 at α = 1.5. When per-token log-probs are closer, a larger α can flip
the ranking toward the longer hypothesis.

Interpretation:
• α = 0.0: No normalization (prefers shorter sequences)
• α = 0.6: Mild penalty (good balance, commonly used)
• α = 1.0: Linear penalty (average log-prob per token, roughly length-neutral)
• α > 1.0: Strong penalty (increasingly favors longer sequences)

Typical values for translation: α ∈ [0.6, 1.0]
```
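The table can be recomputed directly with the same normalization the code uses, log_prob / length ** α. Besides the three example hypotheses, this sketch adds one hypothetical pair invented for illustration: "The big dog" (log-prob -2.4, length 4) has a better per-token log-prob than "The dog" (-0.6 vs -0.667), so its ranking depends on α.

```python
def normalized_score(log_prob: float, length: int, alpha: float) -> float:
    """Length-normalized score: summed log-prob divided by length ** alpha."""
    return log_prob / length ** alpha


def best_hypothesis(hyps: dict, alpha: float) -> str:
    """Name of the hypothesis with the highest normalized score."""
    return max(hyps, key=lambda h: normalized_score(*hyps[h], alpha))


# (log-prob, length) pairs from the example above
example = {
    "The dog": (-2.0, 3),
    "The big brown dog": (-4.5, 5),
    "The very big brown fluffy dog": (-8.0, 7),
}
# Hypothetical pair: the longer one has the better per-token log-prob
contrast = {"The dog": (-2.0, 3), "The big dog": (-2.4, 4)}

for alpha in (0.0, 0.6, 1.0, 1.5):
    scores = {h: round(normalized_score(lp, n, alpha), 3) for h, (lp, n) in example.items()}
    print(f"alpha={alpha}: {scores}")
    print(f"  best in example: {best_hypothesis(example, alpha)!r}, "
          f"best in contrast pair: {best_hypothesis(contrast, alpha)!r}")
```

In the contrast pair the winner switches from "The dog" to "The big dog" between α = 0.6 and α = 1.0 (the crossover is near α ≈ 0.63), which is the "larger α favors longer output" effect described in the interpretation.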

Encouraging Diversity in Beams

```text
Problem with Standard Beam Search:
───────────────────────────────────
Beam 1: "The dog runs quickly through the park"
Beam 2: "The dog runs quickly through the garden"
Beam 3: "The dog runs quickly through the forest"
Beam 4: "The dog runs quickly through the yard"
Beam 5: "The dog runs quickly across the park"

All beams are nearly identical: they differ only in the last 1-2 words.

Diverse Beam Search Solution:
─────────────────────────────
Split the beams into groups and decode the groups in order. At each step,
penalize the scores of tokens that earlier groups already chose at that
step (a soft penalty, not a hard ban).

Group 1 (beams 1-2):
  "The dog runs quickly..."
  "The dog runs fast..."

Group 2 (beams 3-4) - penalize group 1's choices ("dog", "runs", "quickly", "fast"):
  "A cat walks slowly..."
  "The man walks briskly..."

Group 3 (beams 5-6) - penalize the choices of groups 1 and 2:
  "One animal moves silently..."
  "My pet roams freely..."

Result: More diverse outputs for applications like:
• Generating multiple translation options
• Creative text generation
• Dialogue systems
```
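One way to realize the group penalty, sketched under assumptions: this follows the Hamming-diversity variant described in the Diverse Beam Search paper (Vijayakumar et al.), where a token's log-prob is reduced in proportion to how many beams in earlier groups picked it at the current step. `hamming_diversity_penalty` and `strength` are illustrative names, not a library API.

```python
import torch


def hamming_diversity_penalty(
    log_probs: torch.Tensor,       # [beams_in_group, vocab_size]
    earlier_tokens: torch.Tensor,  # 1-D LongTensor: tokens earlier groups chose at this step
    strength: float,
) -> torch.Tensor:
    """Subtract strength * (number of earlier-group beams that chose each token)."""
    counts = torch.bincount(earlier_tokens, minlength=log_probs.size(-1))
    return log_probs - strength * counts.to(log_probs.dtype)


# Group 2 scoring one step; both group-1 beams already picked token 0 at this step
log_probs = torch.log(torch.tensor([[0.5, 0.3, 0.1, 0.1]]))
penalized = hamming_diversity_penalty(log_probs, torch.tensor([0, 0]), strength=0.5)

print(log_probs.argmax(-1).item(), penalized.argmax(-1).item())  # 0 1
```

Because the penalty is subtracted rather than applied as -inf, a token that earlier groups used can still win if the model is confident enough, which matches the soft behavior described above.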

Comparison: Greedy vs Beam Search

Performance Analysis

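| Aspect | Greedy | Beam Search |
|---|---|---|
| Time | O(n × V) | O(n × k × V) |
| Memory | O(n) | O(n × k) |
| Quality | Good | Usually better |
| Deterministic | Yes | Yes* |
| Diversity | None | Low-Medium |
| Implementation | Simple | Complex |

Note: n = sequence length, k = beam size, V = vocabulary size. Time counts the candidate-scoring work; the model's forward passes come on top (k sequences per step for beam search). *Beam search is deterministic given the same model, beam size, and settings.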

When to Use Each

GREEDY:

  • Real-time applications (speed critical)
  • Short outputs
  • When quality is "good enough"
  • Resource-constrained environments

BEAM SEARCH:

  • Machine translation (quality critical)
  • Longer sequences
  • When you need the "best" output
  • Offline batch processing

Summary

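| Parameter | Description | Typical Value |
|---|---|---|
| beam_size | Number of hypotheses kept at each step | 4-10 |
| length_penalty | Exponent α for length normalization | 0.6-1.0 |
| no_repeat_ngram_size | Size of n-grams that may not repeat | 2-3 |
| early_stopping | Stop once beam_size hypotheses are complete | True |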

Algorithm Complexity:

```text
Time: O(seq_len × beam_size × vocab_size)
Memory: O(seq_len × beam_size)
```

Best Practices:

  1. Beam size: 4-5 is often optimal (diminishing returns beyond)
  2. Length penalty: Tune on validation set
  3. N-gram blocking: Prevents repetition
  4. Early stopping: Saves computation

In the next section, we'll implement sampling-based decoding strategies: temperature scaling, top-k sampling, and nucleus (top-p) sampling. These methods introduce controlled randomness for more diverse and creative outputs.