Beam search is the most widely used decoding strategy for machine translation. Instead of greedily selecting a single token at each step, beam search maintains multiple candidate sequences (the "beam"), exploring different paths through the output space to find higher-quality translations.
The Beam Search Algorithm
Core Concept
```text
Instead of keeping 1 best path (greedy), keep k best paths (beam).

At each step:
1. Expand each of k current hypotheses to all possible next tokens
2. Score all k × vocab_size candidates
3. Keep only the k best overall
4. Repeat until all beams reach EOS or max_length
```
Visual Example (Beam Width = 2)
```text
Step 0: Start with <bos>
  Beam: [(<bos>, score=0.0)]

Step 1: Expand <bos> → all vocab tokens
  Candidates:
    (<bos>, The)  score=-0.5
    (<bos>, A)    score=-0.8
    (<bos>, One)  score=-1.2
    (<bos>, My)   score=-1.5
    ...

  Keep top 2:
  Beam: [(<bos>, The, score=-0.5),
         (<bos>, A, score=-0.8)]

Step 2: Expand both hypotheses
  From "The":
    (<bos>, The, dog)  score=-1.0
    (<bos>, The, cat)  score=-1.2
    (<bos>, The, man)  score=-1.5
    ...

  From "A":
    (<bos>, A, dog)  score=-1.3
    (<bos>, A, cat)  score=-1.4
    ...

  All candidates (2 × vocab_size), keep top 2:
  Beam: [(<bos>, The, dog, score=-1.0),
         (<bos>, The, cat, score=-1.2)]

Step 3: Continue...
    (<bos>, The, dog, runs, score=-1.8)
    (<bos>, The, dog, barks, score=-2.0)
    (<bos>, The, cat, sleeps, score=-1.9)
    ...

  Keep top 2:
  Beam: [(<bos>, The, dog, runs, score=-1.8),
         (<bos>, The, cat, sleeps, score=-1.9)]

Step 4: One beam reaches EOS
    (<bos>, The, dog, runs, <eos>, score=-2.2) ← COMPLETE
    (<bos>, The, cat, sleeps, quietly, score=-2.5)

  Continue with incomplete beams until all done or max_length

Final: Return highest-scoring complete sequence
  "The dog runs"
```
Basic Implementation
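Before the PyTorch class, the expand, score, and prune loop can be sketched without any framework. This is a minimal illustration, not the decoder used in the rest of the section: `beam_search`, `TOY_MODEL`, and `toy_next_log_probs` are hypothetical names, the toy probabilities are made up, finished hypotheses simply leave the beam without being replaced, and no length penalty is applied.

```python
import math
from typing import Callable, Dict, List, Tuple

Hypothesis = Tuple[float, List[str]]  # (summed log-prob, tokens)


def beam_search(
    next_log_probs: Callable[[List[str]], Dict[str, float]],
    beam_size: int,
    max_length: int,
    bos: str = "<bos>",
    eos: str = "<eos>",
) -> Hypothesis:
    """Return the best-scoring hypothesis found with a fixed-width beam."""
    beam: List[Hypothesis] = [(0.0, [bos])]
    finished: List[Hypothesis] = []

    for _ in range(max_length):
        # Steps 1-2: expand every live hypothesis and score all candidates.
        candidates: List[Hypothesis] = []
        for score, tokens in beam:
            for token, log_p in next_log_probs(tokens).items():
                candidates.append((score + log_p, tokens + [token]))
        if not candidates:
            break

        # Step 3: keep only the k best overall.
        candidates.sort(key=lambda c: c[0], reverse=True)
        beam = []
        for cand in candidates[:beam_size]:
            (finished if cand[1][-1] == eos else beam).append(cand)

        # Step 4: stop once every surviving hypothesis has reached EOS.
        if not beam:
            break

    # Fall back to unfinished hypotheses if nothing reached EOS.
    return max(finished or beam, key=lambda c: c[0])


# Hypothetical toy model: next-token probabilities depend only on the last token.
TOY_MODEL: Dict[str, Dict[str, float]] = {
    "<bos>": {"The": 0.6, "A": 0.4},
    "The": {"dog": 0.5, "cat": 0.4, "<eos>": 0.1},
    "A": {"dog": 0.7, "cat": 0.3},
    "dog": {"runs": 0.6, "<eos>": 0.4},
    "cat": {"sleeps": 0.8, "<eos>": 0.2},
    "runs": {"<eos>": 1.0},
    "sleeps": {"<eos>": 1.0},
}


def toy_next_log_probs(tokens: List[str]) -> Dict[str, float]:
    return {tok: math.log(p) for tok, p in TOY_MODEL[tokens[-1]].items()}


best_score, best_tokens = beam_search(toy_next_log_probs, beam_size=2, max_length=10)
print(best_tokens, round(math.exp(best_score), 3))
# ['<bos>', 'The', 'dog', 'runs', '<eos>'] 0.18
```

Setting `beam_size=1` turns the same function into greedy decoding, which makes it easy to compare the two strategies on a small table like this one.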
BeamSearchDecoder Class
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Tuple
from dataclasses import dataclass


@dataclass
class BeamHypothesis:
    """A single hypothesis in beam search."""
    tokens: List[int]  # Generated token IDs
    score: float       # Log probability sum
    is_done: bool = False


class BeamSearchDecoder:
    """
    Beam search decoder for transformer models.

    Maintains multiple hypotheses and explores different paths
    to find higher-quality sequences.

    Args:
        model: Transformer model exposing `decode` and `output_projection`
        beam_size: Number of beams to maintain
        bos_token_id: Beginning of sequence token
        eos_token_id: End of sequence token
        pad_token_id: Padding token

    Example:
        >>> decoder = BeamSearchDecoder(model, beam_size=5)
        >>> best_sequence, best_score = decoder.decode(encoder_output, max_length=50)
    """

    def __init__(
        self,
        model: nn.Module,
        beam_size: int = 5,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        self.model = model
        self.beam_size = beam_size
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    @torch.no_grad()
    def decode(
        self,
        encoder_output: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        max_length: int = 100,
        length_penalty: float = 1.0,
        early_stopping: bool = True
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Perform beam search decoding.

        Args:
            encoder_output: Encoded source [1, src_len, d_model]
                (batch_size must be 1)
            src_mask: Source padding mask (4-D, batch dimension first)
            max_length: Maximum output length
            length_penalty: Exponent alpha in score / length ** alpha
            early_stopping: Stop when beam_size hypotheses complete

        Returns:
            best_sequence: Best hypothesis tokens [1, seq_len]
            best_score: Score of best hypothesis
        """
        self.model.eval()
        device = encoder_output.device

        # Expand encoder output for beam_size
        # [1, src_len, d_model] -> [beam_size, src_len, d_model]
        encoder_output = encoder_output.expand(self.beam_size, -1, -1)

        if src_mask is not None:
            src_mask = src_mask.expand(self.beam_size, -1, -1, -1)

        # Initialize beams
        # [beam_size, 1]
        beam_tokens = torch.full(
            (self.beam_size, 1),
            self.bos_token_id,
            dtype=torch.long,
            device=device
        )

        # Beam scores [beam_size]
        beam_scores = torch.zeros(self.beam_size, device=device)
        beam_scores[1:] = float('-inf')  # Only first beam active initially

        # Track completed hypotheses as (tokens, length-normalized score)
        completed_hypotheses: List[Tuple[torch.Tensor, float]] = []

        for step in range(max_length - 1):
            # Decode current beams
            decoder_output = self.model.decode(
                beam_tokens,
                encoder_output,
                memory_mask=src_mask
            )

            # Get logits for last position [beam_size, vocab_size]
            logits = self.model.output_projection(decoder_output[:, -1, :])
            log_probs = F.log_softmax(logits, dim=-1)

            vocab_size = log_probs.size(-1)

            # Calculate scores for all candidates
            # [beam_size, vocab_size]
            next_scores = beam_scores.unsqueeze(-1) + log_probs

            # Reshape to [beam_size * vocab_size]
            next_scores = next_scores.view(-1)

            # Take 2 * beam_size candidates so that beam_size live
            # hypotheses remain even if some candidates end in EOS
            top_scores, top_indices = torch.topk(
                next_scores, 2 * self.beam_size, sorted=True
            )

            # Convert flat indices to (beam_idx, token_idx)
            beam_indices = top_indices // vocab_size
            token_indices = top_indices % vocab_size

            # Build new beams
            new_beam_tokens = []
            new_beam_scores = []
            num_active = 0

            for score, beam_idx, token_idx in zip(
                top_scores, beam_indices, token_indices
            ):
                if num_active >= self.beam_size:
                    break

                # Get previous sequence
                prev_tokens = beam_tokens[beam_idx]

                # Create new sequence
                new_tokens = torch.cat([
                    prev_tokens,
                    token_idx.unsqueeze(0)
                ])

                if token_idx.item() == self.eos_token_id:
                    # Completed hypothesis: apply length penalty
                    final_score = score.item() / (len(new_tokens) ** length_penalty)
                    completed_hypotheses.append((new_tokens, final_score))

                    # Check early stopping
                    if early_stopping and len(completed_hypotheses) >= self.beam_size:
                        break
                else:
                    new_beam_tokens.append(new_tokens)
                    # Store a plain float so the list can be turned
                    # into a tensor together with the -inf padding below
                    new_beam_scores.append(score.item())
                    num_active += 1

            # Check if we should stop
            if early_stopping and len(completed_hypotheses) >= self.beam_size:
                break

            if num_active == 0:
                break

            # Pad to beam_size if needed (padding beams can never win)
            while len(new_beam_tokens) < self.beam_size:
                new_beam_tokens.append(new_beam_tokens[0])
                new_beam_scores.append(float('-inf'))

            # Update beams
            beam_tokens = torch.stack(new_beam_tokens[:self.beam_size])
            beam_scores = torch.tensor(
                new_beam_scores[:self.beam_size],
                device=device
            )

        # If no completed hypotheses, use best current beam
        # (its score is the raw, unnormalized log probability)
        if not completed_hypotheses:
            best_idx = beam_scores.argmax()
            return beam_tokens[best_idx].unsqueeze(0), beam_scores[best_idx]

        # Return best completed hypothesis
        best_hyp = max(completed_hypotheses, key=lambda x: x[1])
        return best_hyp[0].unsqueeze(0), torch.tensor(best_hyp[1], device=device)
```
Efficient Batch Beam Search
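The decoder above picks its candidates by flattening the `[beam_size, vocab_size]` score grid, taking the top entries, and recovering each entry's beam and token with integer division and modulo. The sketch below isolates that one step using plain Python lists in place of tensors. `select_top_candidates` is a hypothetical helper name, and the numbers are the Step 2 scores from the visual example, with a three-word vocabulary `["dog", "cat", "man"]` assumed for illustration.

```python
import heapq
from typing import List, Tuple


def select_top_candidates(
    beam_scores: List[float],
    log_probs: List[List[float]],
    k: int,
) -> List[Tuple[float, int, int]]:
    """Return the k best (score, beam_idx, token_idx) triples."""
    vocab_size = len(log_probs[0])
    # Flatten the [beam, vocab] grid row by row, adding each beam's
    # running score to the log-probability of every next token.
    flat = [
        beam_scores[b] + log_probs[b][t]
        for b in range(len(beam_scores))
        for t in range(vocab_size)
    ]
    top = heapq.nlargest(k, range(len(flat)), key=flat.__getitem__)
    # divmod(i, vocab_size) gives (i // vocab_size, i % vocab_size),
    # which is exactly (beam_idx, token_idx).
    return [(flat[i], *divmod(i, vocab_size)) for i in top]


beam_scores = [-0.5, -0.8]            # "The", "A"
log_probs = [
    [-0.5, -0.7, -1.0],               # after "The": dog, cat, man
    [-0.5, -0.6, -2.0],               # after "A":   dog, cat, man
]
top2 = select_top_candidates(beam_scores, log_probs, k=2)
for score, beam_idx, token_idx in top2:
    print(beam_idx, token_idx, round(score, 2))
```

Both winners come from beam 0 ("The"), so the hypothesis starting with "A" is dropped at this step, as in the trace.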
Optimized Implementation
```python
import torch
import torch.nn as nn


class EfficientBeamSearchDecoder:
    """
    Efficient beam search with proper batching and early stopping.

    Only the scoring helpers (length normalization and n-gram
    blocking) are shown here.
    """

    def __init__(
        self,
        model: nn.Module,
        beam_size: int = 5,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0,
        length_penalty: float = 1.0,
        no_repeat_ngram_size: int = 0
    ):
        self.model = model
        self.beam_size = beam_size
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id
        self.length_penalty = length_penalty
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def _length_normalize(
        self,
        scores: torch.Tensor,
        lengths: torch.Tensor
    ) -> torch.Tensor:
        """Apply length penalty to scores: score / length ** length_penalty."""
        return scores / (lengths.float() ** self.length_penalty)

    def _block_ngrams(
        self,
        log_probs: torch.Tensor,
        generated: torch.Tensor,
        n: int
    ) -> torch.Tensor:
        """Block tokens that would create repeated n-grams.

        Modifies `log_probs` in place and returns it.
        """
        if n <= 0 or generated.size(1) < n - 1:
            return log_probs

        batch_size = generated.size(0)

        for i in range(batch_size):
            seq = generated[i].tolist()
            # Last n-1 tokens; slicing from an explicit start index
            # yields an empty prefix when n == 1 (seq[-0:] would not)
            prefix = tuple(seq[len(seq) - (n - 1):])

            # Find earlier occurrences of the prefix and block the
            # token that followed each of them
            for j in range(len(seq) - n + 1):
                if tuple(seq[j:j + n - 1]) == prefix:
                    blocked_token = seq[j + n - 1]
                    log_probs[i, blocked_token] = float('-inf')

        return log_probs
```
Length Penalty
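Length normalization divides a hypothesis's summed log-probability by its length raised to the power α, which is what `_length_normalize` does above. The sketch below shows the effect in plain Python. The function names and the two hypotheses are illustrative assumptions; the scores are made-up values chosen so that the ranking flips between α = 0 and α = 1.

```python
from typing import List, Tuple


def length_normalized_score(log_prob: float, length: int, alpha: float) -> float:
    """Divide a summed log-probability by length ** alpha."""
    return log_prob / (length ** alpha)


def rank_hypotheses(
    hypotheses: List[Tuple[str, float, int]], alpha: float
) -> List[Tuple[str, float]]:
    """Sort (text, log_prob, length) hypotheses by normalized score, best first."""
    scored = [
        (text, length_normalized_score(log_prob, length, alpha))
        for text, log_prob, length in hypotheses
    ]
    return sorted(scored, key=lambda item: item[1], reverse=True)


# Hypothetical hypotheses: (label, summed log-prob, length in tokens).
HYPOTHESES = [("short", -2.0, 3), ("long", -3.0, 6)]

for alpha in (0.0, 1.0):
    best_label, best_score = rank_hypotheses(HYPOTHESES, alpha)[0]
    print(f"alpha={alpha}: best={best_label} ({best_score:.3f})")
# alpha=0.0: best=short (-2.000)
# alpha=1.0: best=long (-0.500)
```

With α = 0 the raw sum decides and the short hypothesis wins; with α = 1 the per-token average decides, and the long hypothesis wins because its tokens are individually more probable.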
Normalizing Scores by Length
```text
Example hypotheses (summed log-probability, length in tokens):
  "The dog"                        → Log-prob: -2.0, Length: 3
  "The big brown dog"              → Log-prob: -4.5, Length: 5
  "The very big brown fluffy dog"  → Log-prob: -8.0, Length: 7

Normalized score = log-prob / length^α (higher is better)

Effect of Different Length Penalties:

α = 0.0 (No normalization):
  "The dog":                        -2.0 / 3^0.0 = -2.000 ← Best
  "The big brown dog":              -4.5 / 5^0.0 = -4.500
  "The very big brown fluffy dog":  -8.0 / 7^0.0 = -8.000

α = 0.6 (Mild penalty):
  "The dog":                        -2.0 / 3^0.6 = -1.035 ← Best
  "The big brown dog":              -4.5 / 5^0.6 = -1.713
  "The very big brown fluffy dog":  -8.0 / 7^0.6 = -2.489

α = 1.0 (Linear penalty):
  "The dog":                        -2.0 / 3 = -0.667 ← Best
  "The big brown dog":              -4.5 / 5 = -0.900
  "The very big brown fluffy dog":  -8.0 / 7 = -1.143

α = 1.5 (Strong penalty):
  "The dog":                        -2.0 / 3^1.5 = -0.385 ← Best
  "The big brown dog":              -4.5 / 5^1.5 = -0.402
  "The very big brown fluffy dog":  -8.0 / 7^1.5 = -0.432

"The dog" stays best at every α here because it also has the best
per-token log-probability (-0.667). What changes is the margin: its
lead over the longest hypothesis shrinks from 6.000 at α = 0.0 to
0.047 at α = 1.5. A longer hypothesis overtakes a shorter one once
its per-token log-probability is close enough for the chosen α.

Interpretation:
• α = 0.0: No normalization (prefers shorter sequences)
• α = 0.6: Mild penalty (good balance, commonly used)
• α = 1.0: Linear penalty (compares average log-prob per token)
• α > 1.0: Strong penalty (increasingly favors longer sequences)

Typical values for translation: α ∈ [0.6, 1.0]
```
Diverse Beam Search
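The core idea of diverse beam search, penalizing tokens that earlier groups have already chosen, can be sketched for a single decoding step. This is a simplified illustration under stated assumptions: one beam per group, a count-based penalty subtracted from the log-probability, and made-up scores. `diverse_step` and `diversity_penalty` are hypothetical names, not part of the decoder classes in this section.

```python
from collections import Counter
from typing import Dict, List


def diverse_step(
    group_log_probs: List[Dict[str, float]],
    diversity_penalty: float,
) -> List[str]:
    """Pick one token per group, penalizing tokens chosen by earlier groups."""
    chosen: Counter = Counter()
    picks: List[str] = []
    for log_probs in group_log_probs:
        # Subtract the penalty once for every earlier group that
        # already picked the same token at this step.
        adjusted = {
            tok: lp - diversity_penalty * chosen[tok]
            for tok, lp in log_probs.items()
        }
        best = max(adjusted, key=adjusted.get)
        picks.append(best)
        chosen[best] += 1
    return picks


# Hypothetical scores: all three groups see the same distribution.
step_scores = {"dog": -0.5, "cat": -0.7, "man": -0.9}
groups = [step_scores, step_scores, step_scores]

print(diverse_step(groups, diversity_penalty=0.0))  # ['dog', 'dog', 'dog']
print(diverse_step(groups, diversity_penalty=0.5))  # ['dog', 'cat', 'man']
```

With no penalty every group picks the same token; with a penalty of 0.5 each later group is pushed toward a token the earlier groups did not use.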
Encouraging Diversity in Beams
```text
Problem with Standard Beam Search:
───────────────────────────────────
Beam 1: "The dog runs quickly through the park"
Beam 2: "The dog runs quickly through the garden"
Beam 3: "The dog runs quickly through the forest"
Beam 4: "The dog runs quickly through the yard"
Beam 5: "The dog runs quickly across the park"

All beams are very similar! Only differ in last 1-2 words.

Diverse Beam Search Solution:
─────────────────────────────
Divide beams into groups. Penalize tokens chosen by earlier groups.

Group 1 (beams 1-2):
  "The dog runs quickly..."
  "The dog runs fast..."

Group 2 (beams 3-4) - penalize "dog", "runs", "quickly", "fast":
  "A cat walks slowly..."
  "The man walks briskly..."

Group 3 (beams 5-6) - penalize all previous tokens:
  "One animal moves silently..."
  "My pet roams freely..."

Result: More diverse outputs for applications like:
• Generating multiple translation options
• Creative text generation
• Dialogue systems
```
Comparison: Greedy vs Beam Search
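A two-step toy model is enough to show both sides of the trade-off: greedy decoding commits to the most probable first word and misses a better sequence, while a wider beam finds it at the cost of scoring more candidates. The probabilities below are made-up values, `decode` is a hypothetical helper, and plain probabilities are multiplied instead of summing log-probabilities only because the sequences are two tokens long.

```python
import heapq
from typing import Dict, List, Tuple

# Hypothetical model: first-word probabilities, then second-word
# probabilities conditioned on the first word.
FIRST: Dict[str, float] = {"The": 0.6, "A": 0.4}
SECOND: Dict[str, Dict[str, float]] = {
    "The": {"dog": 0.55, "cat": 0.45},
    "A": {"dog": 0.9, "cat": 0.1},
}


def decode(beam_size: int) -> Tuple[List[str], float, int]:
    """Return (best sequence, its probability, number of candidates scored)."""
    scored = len(FIRST)
    # Step 1: keep the beam_size most probable first words.
    beam = heapq.nlargest(beam_size, ((p, [w]) for w, p in FIRST.items()))
    # Step 2: expand every kept word and score all continuations.
    candidates = [
        (p1 * p2, seq + [w2])
        for p1, seq in beam
        for w2, p2 in SECOND[seq[-1]].items()
    ]
    scored += len(candidates)
    prob, best = max(candidates)
    return best, prob, scored


for k in (1, 2):
    sequence, prob, scored = decode(k)
    print(f"k={k}: {sequence} p={prob:.2f} candidates={scored}")
# k=1: ['The', 'dog'] p=0.33 candidates=4
# k=2: ['A', 'dog'] p=0.36 candidates=6
```

Greedy decoding (k = 1) never considers "A" after the first step, so it cannot reach the most probable sequence overall.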
Performance Analysis
| Aspect | Greedy | Beam Search |
|---|---|---|
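| Time | O(n × V) | O(n × k × V) |
| Memory | O(n) | O(n × k) |
| Quality | Good | Better |
| Deterministic | Yes | Yes* |
| Diversity | None | Low-Medium |
| Implementation | Simple | Complex |
Note: n = sequence length, k = beam size, V = vocabulary size. Time counts the candidate scores examined during selection, so greedy decoding also scans V tokens per step; the cost of the model's forward passes is not included. *Beam search is deterministic for a fixed model, input, and beam size.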
When to Use Each
GREEDY:
- Real-time applications (speed critical)
- Short outputs
- When quality is "good enough"
- Resource-constrained environments
BEAM SEARCH:
- Machine translation (quality critical)
- Longer sequences
- When you need the "best" output
- Offline batch processing
Summary
| Parameter | Description | Typical Value |
|---|---|---|
| beam_size | Number of hypotheses | 4-10 |
| length_penalty | Length normalization | 0.6-1.0 |
| no_repeat_ngram | Block repeated n-grams | 2-3 |
| early_stopping | Stop when k complete | True |
Algorithm Complexity:
```text
Time:   O(seq_len × beam_size × vocab_size)
Memory: O(seq_len × beam_size)
```
Best Practices:
- Beam size: 4-5 is often optimal (diminishing returns beyond)
- Length penalty: Tune on validation set
- N-gram blocking: Prevents repetition
- Early stopping: Saves computation
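The n-gram blocking rule from the list above can be illustrated without any framework: given the tokens generated so far, find every token that would complete an n-gram already present in the sequence. This is a minimal sketch; `banned_tokens` is a hypothetical helper name and the token IDs are arbitrary.

```python
from typing import List, Set


def banned_tokens(tokens: List[int], n: int) -> Set[int]:
    """Tokens that would complete an n-gram already present in `tokens`."""
    if n <= 0 or len(tokens) < n - 1:
        return set()
    # The last n-1 tokens form the prefix of the next n-gram.
    prefix = tokens[len(tokens) - (n - 1):]
    banned: Set[int] = set()
    for start in range(len(tokens) - n + 1):
        # Wherever the same prefix occurred earlier, the token that
        # followed it would repeat an n-gram if generated again.
        if tokens[start:start + n - 1] == prefix:
            banned.add(tokens[start + n - 1])
    return banned


history = [1, 2, 1, 3, 1]
print(banned_tokens(history, n=2))  # {2, 3}: "1 2" and "1 3" already occurred
print(banned_tokens(history, n=3))  # set(): "3 1" has not occurred before
```

In a decoder, the log-probabilities of the banned tokens would be set to negative infinity before the top-k selection, so those candidates can never be kept.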
In the next section, we'll implement sampling-based decoding strategies: temperature scaling, top-k sampling, and nucleus (top-p) sampling. These methods introduce controlled randomness for more diverse and creative outputs.