While greedy and beam search produce deterministic outputs, sampling-based methods introduce controlled randomness to generate more diverse and creative text. This section covers temperature scaling, top-k sampling, and nucleus (top-p) sampling.
Why Sampling?
The Diversity Problem
```text
Greedy/Beam Search:
  Input: "Tell me a joke"
  Output: "Why did the chicken cross the road? To get to the other side."

  Same input → Same output (always!)

Sampling:
  Input: "Tell me a joke"
  Output 1: "What do you call a fish without eyes? A fsh!"
  Output 2: "Why don't scientists trust atoms? They make up everything!"
  Output 3: "I told my wife she was drawing her eyebrows too high..."

  Same input → Different outputs (creative!)
```
When to Use Sampling
| Task | Recommended Strategy |
|---|---|
| Machine Translation | Beam search (accuracy) |
| Summarization | Beam search (faithfulness) |
| Creative Writing | Sampling (diversity) |
| Dialogue | Sampling (naturalness) |
| Code Generation | Low-temperature sampling |
Temperature Scaling
The Concept
Temperature controls the "sharpness" of the probability distribution:
```text
Softmax with temperature:
  P(token_i) = exp(logit_i / T) / Σ_j exp(logit_j / T)

T < 1.0: Sharper distribution (more confident, less random)
T = 1.0: Original distribution
T > 1.0: Flatter distribution (less confident, more random)
T → 0:   Becomes greedy (argmax)
T → ∞:   Becomes uniform random
```
Visual Example
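The two limits listed above can be checked numerically before the fuller demonstration. This is a minimal sketch in plain Python (only `math`, no PyTorch); the helper `softmax_with_temperature` is hypothetical and not part of the section's sampler classes.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Softmax of logits / temperature; subtracting the max keeps exp() from overflowing."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


# Same toy logits as the demonstration below
logits = [3.0, 2.5, 1.5, 0.5, 0.0, -1.0, -2.0, -3.0]

for t in (0.05, 1.0, 100.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T = {t:6.2f}: largest prob {max(probs):.4f}, smallest prob {min(probs):.4f}")
```

At T = 0.05 nearly all of the mass sits on the largest logit (the argmax), and at T = 100 every probability is close to 1/8, matching the T → 0 and T → ∞ rows. Subtracting the maximum matters at small T, where logit / T grows large; `F.softmax` takes care of this internally.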
```python
import torch
import torch.nn.functional as F


def demonstrate_temperature():
    """
    Demonstrate how temperature affects the probability distribution.
    """
    print("Temperature Scaling Demonstration")
    print("=" * 60)

    # Simulated logits for a toy 8-word vocabulary
    logits = torch.tensor([3.0, 2.5, 1.5, 0.5, 0.0, -1.0, -2.0, -3.0])
    vocab = ["the", "a", "dog", "cat", "runs", "slowly", "very", "not"]

    print("\nOriginal logits:")
    for word, logit in zip(vocab, logits.tolist()):
        print(f"  {word:8s}: {logit:.2f}")

    print("\n" + "-" * 60)
    print("Probability distributions at different temperatures:")
    print("-" * 60)

    temperatures = [0.5, 1.0, 1.5, 2.0]

    for temp in temperatures:
        probs = F.softmax(logits / temp, dim=0)
        # Shannon entropy (nats): rises as the distribution flattens
        entropy = -(probs * torch.log(probs + 1e-10)).sum().item()

        print(f"\nT = {temp}:")
        # Show the four most probable words with a text bar chart
        for word, prob in sorted(zip(vocab, probs.tolist()), key=lambda x: -x[1])[:4]:
            bar = "█" * int(prob * 30)
            print(f"  {word:8s}: {prob:.3f} {bar}")
        print(f"  Entropy: {entropy:.3f}")


if __name__ == "__main__":
    demonstrate_temperature()
```
Implementation
```python
import torch
import torch.nn.functional as F


class TemperatureSampler:
    """
    Sample from a probability distribution with temperature scaling.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        if temperature < 0:
            raise ValueError("temperature must be non-negative")
        self.temperature = temperature
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Sample from logits with temperature scaling.

        Args:
            logits: [batch, vocab_size]

        Returns:
            sampled_tokens: [batch]
        """
        # T → 0 is the greedy limit; handle it explicitly to avoid dividing by zero
        if self.temperature == 0:
            return logits.argmax(dim=-1)

        # Apply temperature
        scaled_logits = logits / self.temperature

        # Convert to probabilities
        probs = F.softmax(scaled_logits, dim=-1)

        # Sample one token per batch row
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        return sampled
```
Top-K Sampling
The Concept
Top-K sampling restricts the candidate set to the K most probable tokens:
```text
Original distribution:
  the: 0.40, a: 0.25, dog: 0.15, cat: 0.10, runs: 0.05, ...

Top-K (K=3):
  1. Keep only top 3: {the: 0.40, a: 0.25, dog: 0.15}
  2. Renormalize (divide by their sum, 0.80): {the: 0.50, a: 0.31, dog: 0.19}
  3. Sample from truncated distribution

Benefits:
  - Prevents sampling very unlikely tokens
  - Maintains some randomness within top candidates
  - Simple to implement and tune
```
Implementation
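Before the sampler class, the renormalization step from the example above can be reproduced directly. This plain-Python sketch uses a hypothetical helper `top_k_renormalize` that takes a token-to-probability dict, keeps the K largest entries, and divides by their sum; the elided tail (`...`) of the example distribution is simply left out.

```python
def top_k_renormalize(probs, k):
    """Keep the k most probable tokens and rescale them to sum to 1."""
    # Sort (token, prob) pairs by descending probability and keep the first k
    top = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:k]
    total = sum(p for _, p in top)
    return {token: p / total for token, p in top}


probs = {"the": 0.40, "a": 0.25, "dog": 0.15, "cat": 0.10, "runs": 0.05}
for token, p in top_k_renormalize(probs, k=3).items():
    print(f"{token}: {p:.2f}")
```

The printed values (0.50, 0.31, 0.19) are the renormalized probabilities from the example: 0.40, 0.25 and 0.15 each divided by 0.80.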
```python
import torch
import torch.nn.functional as F


class TopKSampler:
    """
    Sample from the top-K most probable tokens.
    """

    def __init__(
        self,
        k: int = 50,
        temperature: float = 1.0,
        bos_token_id: int = 2,
        eos_token_id: int = 3
    ):
        self.k = k
        self.temperature = temperature
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Sample from top-K tokens.

        Args:
            logits: [batch, vocab_size]

        Returns:
            sampled_tokens: [batch]
        """
        # Apply temperature
        scaled_logits = logits / self.temperature

        # Get top-K (clamped so K never exceeds the vocabulary size)
        k = min(self.k, logits.size(-1))
        top_k_values, top_k_indices = torch.topk(scaled_logits, k, dim=-1)

        # Mask out non-top-K tokens by setting them to -inf
        filtered_logits = torch.full_like(scaled_logits, float('-inf'))
        filtered_logits.scatter_(1, top_k_indices, top_k_values)

        # Convert to probabilities (masked tokens get probability 0)
        probs = F.softmax(filtered_logits, dim=-1)

        # Sample
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        return sampled

    def sample_efficient(self, logits: torch.Tensor) -> torch.Tensor:
        """
        More efficient implementation: softmax and sample over only the K candidates.
        """
        # Apply temperature
        scaled_logits = logits / self.temperature

        # Get top-K (clamped to the vocabulary size)
        k = min(self.k, logits.size(-1))
        top_k_values, top_k_indices = torch.topk(scaled_logits, k, dim=-1)

        # Softmax over top-K only: [batch, k]
        probs = F.softmax(top_k_values, dim=-1)

        # Sample a position within the top-K: [batch]
        sampled_idx = torch.multinomial(probs, num_samples=1).squeeze(-1)

        # Map back to vocabulary indices
        sampled_tokens = top_k_indices.gather(1, sampled_idx.unsqueeze(1)).squeeze(-1)

        return sampled_tokens
```
Nucleus (Top-P) Sampling
The Concept
Top-P sampling keeps the smallest set of tokens whose cumulative probability reaches or exceeds P:
```text
Original distribution (sorted):
  the:  0.35   cumsum: 0.35
  a:    0.25   cumsum: 0.60
  dog:  0.18   cumsum: 0.78
  cat:  0.12   cumsum: 0.90  ← Stop here for P=0.9
  runs: 0.06   cumsum: 0.96
  ...

Top-P (P=0.9):
  Keep: {the, a, dog, cat}  (cumsum first reaches 0.90 at "cat")
  Renormalize and sample

Advantage over Top-K:
  - Adapts to distribution shape
  - Peaked distribution: few tokens kept
  - Flat distribution: many tokens kept
```
Implementation
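The cutoff rule in the example above can be traced in plain Python before moving to tensors. This sketch uses a hypothetical helper `nucleus` that walks the tokens in descending order and stops at the first one whose cumulative probability reaches p. The tolerance `eps` is an assumption added here so that floating-point round-off does not leave an exact 0.90 just below the threshold. Only the five listed tokens are included.

```python
def nucleus(probs, p, eps=1e-9):
    """Return the smallest prefix of tokens (by descending probability) whose
    cumulative probability reaches p, renormalized to sum to 1."""
    ranked = sorted(probs.items(), key=lambda item: item[1], reverse=True)
    kept, cumsum = [], 0.0
    for token, prob in ranked:
        kept.append((token, prob))
        cumsum += prob
        if cumsum >= p - eps:  # this token reaches the threshold: stop here
            break
    total = sum(prob for _, prob in kept)
    return {token: prob / total for token, prob in kept}


probs = {"the": 0.35, "a": 0.25, "dog": 0.18, "cat": 0.12, "runs": 0.06}
print(list(nucleus(probs, p=0.9)))  # tokens in the nucleus for P=0.9
print(list(nucleus(probs, p=0.5)))  # a smaller P gives a smaller nucleus
```

With P=0.9 the nucleus is {the, a, dog, cat}, as in the example; with P=0.5 it shrinks to {the, a}, because the cumulative probability already reaches 0.60 after the second token.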
```python
import torch
import torch.nn.functional as F


class NucleusSampler:
    """
    Nucleus (Top-P) sampling.

    Keeps the smallest set of tokens with cumulative probability >= p.
    Adapts dynamically to the shape of the distribution.
    """

    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        bos_token_id: int = 2,
        eos_token_id: int = 3
    ):
        self.p = p
        self.temperature = temperature
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Sample from the nucleus (top-p) of the distribution.

        Args:
            logits: [batch, vocab_size]

        Returns:
            sampled_tokens: [batch]
        """
        # Apply temperature
        scaled_logits = logits / self.temperature

        # Sort by descending probability
        sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)

        # Cumulative probabilities
        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

        # Find cutoff: mark every token from the one that reaches p onward,
        # then shift right by one so the token that reaches p is kept
        sorted_mask = cumsum_probs >= self.p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False  # always keep the most probable token

        # Mask tokens outside the nucleus (-inf → probability 0)
        sorted_logits[sorted_mask] = float('-inf')

        # Unsort back to original vocabulary order
        original_logits = torch.zeros_like(logits)
        original_logits.scatter_(1, sorted_indices, sorted_logits)

        # Sample from filtered distribution
        probs = F.softmax(original_logits, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        return sampled
```
Top-P vs Top-K Comparison
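The comparison that follows can be reproduced with two made-up distributions, one peaked and one flat. The probabilities below are hypothetical, chosen only so that the token counts match the comparison. The helper `nucleus_size` (also hypothetical) applies the same cutoff rule as `NucleusSampler`, stopping at the first token whose cumulative probability reaches P, while top-K always keeps K tokens.

```python
def nucleus_size(probs, p):
    """Number of tokens top-p keeps: stop at the first token whose cumulative
    probability reaches p (tokens ranked by descending probability)."""
    cumsum = 0.0
    for count, prob in enumerate(sorted(probs, reverse=True), start=1):
        cumsum += prob
        if cumsum >= p:
            return count
    return len(probs)


# Hypothetical distributions (probabilities only, each summing to 1)
peaked = [0.75, 0.17, 0.04, 0.02, 0.01, 0.01]
flat = [0.20, 0.18, 0.16, 0.15, 0.13, 0.10, 0.05, 0.03]

for name, dist in [("peaked", peaked), ("flat", flat)]:
    k_kept = min(4, len(dist))
    print(f"{name:6s}: top-p (P=0.9) keeps {nucleus_size(dist, 0.9)}, top-K=4 keeps {k_kept}")
```

The same P=0.9 keeps 2 tokens for the peaked distribution and 6 for the flat one, while K=4 is fixed in both cases.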
```text
For P=0.9:

PEAKED distribution:
  Top-P keeps: {"the", "a"}  (2 tokens)
  Top-K=4:     {"the", "a", "dog", "cat"}  (fixed 4 tokens)

FLAT distribution:
  Top-P keeps: {"the", "a", "dog", "cat", "runs", "slowly"}  (6 tokens)
  Top-K=4:     only 4 tokens regardless

Top-P adapts to distribution shape!
```
Combined Sampler
Production Implementation
```python
from typing import Optional

import torch
import torch.nn.functional as F


class CombinedSampler:
    """
    Combined sampling with repetition penalty, temperature, top-k, and top-p.

    Applies in order:
    1. Repetition penalty (only if `generated` is given and penalty != 1.0)
    2. Temperature scaling (temperature == 0 falls back to greedy argmax)
    3. Top-K filtering
    4. Top-P (nucleus) filtering
    5. Sample from remaining distribution
    """

    def __init__(
        self,
        temperature: float = 1.0,
        top_k: int = 0,                   # 0 = disabled
        top_p: float = 1.0,               # 1.0 = disabled
        repetition_penalty: float = 1.0,  # 1.0 = disabled
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    def _apply_repetition_penalty(
        self,
        logits: torch.Tensor,
        generated: torch.Tensor
    ) -> torch.Tensor:
        """Apply repetition penalty to previously generated tokens (in place).

        Positive logits are divided and negative logits multiplied, so a
        penalty > 1 makes every repeated token less likely.
        """
        for i in range(generated.size(0)):
            for token in generated[i].unique().tolist():
                # Special tokens are not penalized
                if token in (self.pad_token_id, self.bos_token_id):
                    continue
                if logits[i, token] > 0:
                    logits[i, token] /= self.repetition_penalty
                else:
                    logits[i, token] *= self.repetition_penalty
        return logits

    def _apply_top_k(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply top-K filtering."""
        if self.top_k <= 0:
            return logits

        # The K-th largest logit in each row is the cutoff
        top_k = min(self.top_k, logits.size(-1))
        threshold = torch.topk(logits, top_k, dim=-1).values[:, -1:]

        # Mask tokens below the cutoff (ties at the cutoff are all kept)
        mask = logits < threshold
        return logits.masked_fill(mask, float('-inf'))

    def _apply_top_p(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply top-P (nucleus) filtering."""
        if self.top_p >= 1.0:
            return logits

        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        cumsum_probs = torch.cumsum(sorted_probs, dim=-1)

        # Mark tokens from the one that reaches top_p onward, then shift
        # right so the token that reaches the threshold is kept
        sorted_mask = cumsum_probs >= self.top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False

        # Mask tokens outside the nucleus
        sorted_logits = sorted_logits.masked_fill(sorted_mask, float('-inf'))

        # Unsort back to vocabulary order
        original_logits = torch.zeros_like(logits)
        original_logits.scatter_(1, sorted_indices, sorted_logits)

        return original_logits

    def sample(
        self,
        logits: torch.Tensor,
        generated: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Sample using combined strategy.

        Args:
            logits: [batch, vocab_size]
            generated: Previously generated tokens for repetition penalty

        Returns:
            sampled_tokens: [batch]
        """
        # 1. Apply repetition penalty (on a copy, leaving the caller's logits intact)
        if generated is not None and self.repetition_penalty != 1.0:
            logits = self._apply_repetition_penalty(logits.clone(), generated)

        # temperature == 0 is the greedy limit; avoid dividing by zero
        if self.temperature == 0:
            return logits.argmax(dim=-1)

        # 2. Apply temperature
        logits = logits / self.temperature

        # 3. Apply top-K
        logits = self._apply_top_k(logits)

        # 4. Apply top-P
        logits = self._apply_top_p(logits)

        # 5. Sample
        probs = F.softmax(logits, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        return sampled
```
Choosing Sampling Parameters
Guidelines
```text
TEMPERATURE (0.0 - 2.0)
────────────────────────
0.0 - 0.5: Very deterministic (nearly greedy)
0.5 - 0.8: Conservative, high-quality outputs
0.8 - 1.0: Balanced randomness
1.0 - 1.3: Creative, varied outputs
1.3 - 2.0: Very random, potentially incoherent

Recommended: 0.7-1.0 for most applications

TOP-K (0 - vocab_size)
────────────────────────
0: Disabled
10-20: Very restricted (safe but potentially boring)
30-50: Balanced (common choice)
50-100: More variety
100+: Very permissive

Recommended: 40-50 for general use

TOP-P (0.0 - 1.0)
────────────────────────
0.5-0.8: Restricted (more focused)
0.85-0.92: Balanced (common choice)
0.92-0.98: Permissive (more variety)
0.98-1.0: Nearly unrestricted

Recommended: 0.9-0.95 for general use

TASK-SPECIFIC RECOMMENDATIONS
──────────────────────────────
Translation: temp=0.0 (greedy) or beam search
Summarization: temp=0.3, top_p=0.9
Dialogue: temp=0.8, top_k=50, top_p=0.9
Creative Writing: temp=1.0-1.2, top_k=0, top_p=0.95
Code Generation: temp=0.2-0.5, top_p=0.95
```
Summary
| Strategy | Parameters | Pros | Cons |
|---|---|---|---|
| Temperature | T ∈ (0, ∞) | Simple, intuitive | Never excludes tokens outright; high T can be incoherent |
| Top-K | K ∈ [1, V] (V = vocab size) | Prevents rare tokens | Fixed K ignores distribution shape |
| Top-P | P ∈ (0, 1] | Adapts to distribution | Slightly more complex (needs a sort) |
| Combined | T, K, P | Fine-grained control over all three | More hyperparameters |
When to Use Each
- Temperature only: Quick experimentation
- Top-K only: When you want consistent candidate set size
- Top-P only: When distribution shape varies
- Combined: Production systems, fine-grained control
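Each sampler above handles a single decoding step. During generation, that step runs inside an autoregressive loop: start from BOS, turn the current prefix into next-token logits, sample, append, and stop at EOS or a length limit. The sketch below is a plain-Python illustration under stated assumptions. `toy_logits` is a hypothetical stand-in for a model. The token ids reuse the `bos_token_id=2` / `eos_token_id=3` defaults from the classes above. `sample_step` is a hypothetical stdlib re-implementation of the temperature → top-K → top-P pipeline, not the `CombinedSampler` API.

```python
import math
import random

BOS, EOS = 2, 3  # same defaults as the samplers' bos_token_id / eos_token_id


def sample_step(logits, rng, temperature=1.0, top_k=0, top_p=1.0):
    """One decoding step: temperature, then top-K, then top-P, then sample."""
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])  # greedy
    scaled = [x / temperature for x in logits]
    # Token ids ranked by descending scaled logit
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    m = scaled[order[0]]
    weights = [math.exp(scaled[i] - m) for i in order]
    total = sum(weights)
    probs = [w / total for w in weights]
    # Top-P: keep the shortest prefix of `order` whose mass reaches top_p
    cumsum, cutoff = 0.0, len(order)
    for n, prob in enumerate(probs, start=1):
        cumsum += prob
        if cumsum >= top_p:
            cutoff = n
            break
    # random.choices renormalizes the kept weights internally
    return rng.choices(order[:cutoff], weights=probs[:cutoff], k=1)[0]


def toy_logits(prefix, vocab_size=8):
    """Hypothetical model: strongly favors (last + 2) % vocab_size,
    gives EOS a smaller boost, and leaves every other token at 0."""
    nxt = (prefix[-1] + 2) % vocab_size
    return [3.0 if i == nxt else 1.0 if i == EOS else 0.0 for i in range(vocab_size)]


def generate(max_new_tokens=10, seed=0, **sampling_kwargs):
    rng = random.Random(seed)  # seeded so runs are reproducible
    tokens = [BOS]
    for _ in range(max_new_tokens):
        next_token = sample_step(toy_logits(tokens), rng, **sampling_kwargs)
        tokens.append(next_token)
        if next_token == EOS:
            break
    return tokens


print(generate(temperature=0))  # greedy: same output on every run
print(generate(temperature=0.8, top_k=5, top_p=0.9, seed=1))
```

With this toy model, greedy decoding never picks EOS and runs to `max_new_tokens`, while the sampled run can stop early because EOS falls inside the top-P nucleus at these settings. Every step recomputes logits for the whole prefix, which is the redundancy a real model's KV cache removes.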
In the next section, we'll implement KV-caching—a crucial optimization that dramatically speeds up autoregressive generation by avoiding redundant computation.