Chapter 9

Sampling Strategies

Autoregressive Generation

While greedy and beam search produce deterministic outputs, sampling-based methods introduce controlled randomness to generate more diverse and creative text. This section covers temperature scaling, top-k sampling, and nucleus (top-p) sampling.


Why Sampling?

The Diversity Problem

```text
Greedy/Beam Search:
  Input: "Tell me a joke"
  Output: "Why did the chicken cross the road? To get to the other side."

  Same input → Same output (always!)

Sampling:
  Input: "Tell me a joke"
  Output 1: "What do you call a fish without eyes? A fsh!"
  Output 2: "Why don't scientists trust atoms? They make up everything!"
  Output 3: "I told my wife she was drawing her eyebrows too high..."

  Same input → Different outputs (creative!)
```

When to Use Sampling

| Task | Recommended |
|---|---|
| Machine Translation | Beam search (accuracy) |
| Summarization | Beam search (faithfulness) |
| Creative Writing | Sampling (diversity) |
| Dialogue | Sampling (naturalness) |
| Code Generation | Low-temperature sampling |

Temperature Scaling

The Concept

Temperature controls the "sharpness" of the probability distribution:

```text
Softmax with temperature:
  P(token_i) = exp(logit_i / T) / Σ_j exp(logit_j / T)

T < 1.0: Sharper distribution (more confident, less random)
T = 1.0: Original distribution
T > 1.0: Flatter distribution (less confident, more random)
T → 0:   Becomes greedy (argmax)
T → ∞:   Becomes uniform random
```
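The following dependency-free sketch evaluates the formula on a few hypothetical logits. The values are illustrative, not taken from a real model. As T shrinks, the top token absorbs almost all of the mass. As T grows, the four probabilities converge toward 0.25 each.

```python
import math


def softmax_with_temperature(logits, temperature):
    """Return exp(x_i / T) / sum_j exp(x_j / T) for a list of logits."""
    scaled = [x / temperature for x in logits]
    # Subtracting the max before exponentiating avoids overflow and
    # does not change the softmax result.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [3.0, 2.5, 1.5, 0.5]  # hypothetical logits for four tokens

for t in (0.1, 0.5, 1.0, 2.0, 100.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t:>5}: " + ", ".join(f"{p:.3f}" for p in probs))
```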

Visual Example

```python
import torch
import torch.nn.functional as F


def demonstrate_temperature():
    """
    Demonstrate how temperature affects the probability distribution.
    """
    print("Temperature Scaling Demonstration")
    print("=" * 60)

    # Simulated logits
    logits = torch.tensor([3.0, 2.5, 1.5, 0.5, 0.0, -1.0, -2.0, -3.0])
    vocab = ["the", "a", "dog", "cat", "runs", "slowly", "very", "not"]

    print("\nOriginal logits:")
    for word, logit in zip(vocab, logits):
        print(f"  {word:8s}: {logit:.2f}")

    print("\n" + "-" * 60)
    print("Probability distributions at different temperatures:")
    print("-" * 60)

    temperatures = [0.5, 1.0, 1.5, 2.0]

    for temp in temperatures:
        probs = F.softmax(logits / temp, dim=0)
        # Shannon entropy in nats: higher means a flatter distribution
        entropy = -(probs * torch.log(probs + 1e-10)).sum()

        print(f"\nT = {temp}:")
        # Show the four most probable tokens with a text bar chart
        for word, prob in sorted(zip(vocab, probs), key=lambda x: -x[1].item())[:4]:
            bar = "█" * int(prob * 30)
            print(f"  {word:8s}: {prob:.3f} {bar}")
        print(f"  Entropy: {entropy:.3f}")


if __name__ == "__main__":
    demonstrate_temperature()
```

Implementation

```python
import torch
import torch.nn.functional as F


class TemperatureSampler:
    """
    Sample from a probability distribution with temperature scaling.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        self.temperature = temperature
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Sample from logits with temperature scaling.

        Args:
            logits: [batch, vocab_size]

        Returns:
            sampled_tokens: [batch]
        """
        # T = 0 is conventionally treated as greedy decoding (the T → 0 limit);
        # dividing by zero would otherwise produce inf/NaN probabilities.
        if self.temperature <= 0:
            return logits.argmax(dim=-1)

        # Apply temperature
        scaled_logits = logits / self.temperature

        # Convert to probabilities
        probs = F.softmax(scaled_logits, dim=-1)

        # Draw one token per row from the categorical distribution
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        return sampled
```
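One way to sanity-check a sampler is to draw many tokens and compare the empirical frequencies with the temperature-scaled probabilities. The sketch below does this with Python's standard `random.choices` in place of `torch.multinomial`, an assumption made to keep it dependency-free. The logits and the seed are arbitrary. It also treats T ≤ 0 as greedy decoding, which is the T → 0 limit from the formula block.

```python
import math
import random
from collections import Counter


def sample_with_temperature(logits, temperature, rng):
    """Draw one token index from softmax(logits / T); T <= 0 means greedy argmax."""
    if temperature <= 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    # random.choices accepts unnormalized weights, so no division by the sum
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def empirical_freqs(logits, temperature, n=20_000, seed=0):
    """Fraction of n seeded draws that landed on each token index."""
    rng = random.Random(seed)
    counts = Counter(sample_with_temperature(logits, temperature, rng) for _ in range(n))
    return [counts[i] / n for i in range(len(logits))]


logits = [2.0, 1.0, 0.0]  # hypothetical logits for a 3-token vocabulary
for t in (0.5, 1.0, 2.0):
    print(f"T={t}: {[round(f, 3) for f in empirical_freqs(logits, t)]}")
```

With 20,000 draws, each frequency typically lands within about 0.01 of the exact softmax value. Lower temperatures visibly concentrate the draws on index 0.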

Top-K Sampling

The Concept

Top-K sampling restricts the candidate set to the K most probable tokens:

```text
Original distribution:
  the: 0.40, a: 0.25, dog: 0.15, cat: 0.10, runs: 0.05, ...

Top-K (K=3):
  1. Keep only top 3: {the: 0.40, a: 0.25, dog: 0.15}
  2. Renormalize: {the: 0.50, a: 0.31, dog: 0.19}
  3. Sample from truncated distribution

Benefits:
  - Prevents sampling very unlikely tokens
  - Maintains some randomness within top candidates
  - Simple to implement and tune
```
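The renormalization step can be checked in a few lines of plain Python. The sketch keeps the K largest entries of a token-to-probability dict and divides each by their total, 0.80 here. It uses only the five listed tokens. The elided "..." mass is never in the top 3, so leaving it out does not change the result.

```python
def top_k_filter(probs, k):
    """Keep the k most probable tokens and renormalize them to sum to 1."""
    top = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    total = sum(p for _, p in top)
    return {tok: p / total for tok, p in top}


probs = {"the": 0.40, "a": 0.25, "dog": 0.15, "cat": 0.10, "runs": 0.05}
filtered = top_k_filter(probs, k=3)
print({tok: round(p, 2) for tok, p in filtered.items()})
```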

Implementation

```python
import torch
import torch.nn.functional as F


class TopKSampler:
    """
    Sample from the top-K most probable tokens.
    """

    def __init__(
        self,
        k: int = 50,
        temperature: float = 1.0,
        bos_token_id: int = 2,
        eos_token_id: int = 3
    ):
        self.k = k
        self.temperature = temperature
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Sample from the top-K tokens.

        Args:
            logits: [batch, vocab_size]

        Returns:
            sampled_tokens: [batch]
        """
        # Apply temperature
        scaled_logits = logits / self.temperature

        # Get top-K (K is clamped to the vocabulary size)
        k = min(self.k, logits.size(-1))
        top_k_values, top_k_indices = torch.topk(scaled_logits, k, dim=-1)

        # Mask out every non-top-K token by setting its logit to -inf
        filtered_logits = torch.full_like(scaled_logits, float('-inf'))
        filtered_logits.scatter_(1, top_k_indices, top_k_values)

        # Convert to probabilities (masked tokens get exactly 0)
        probs = F.softmax(filtered_logits, dim=-1)

        # Sample
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        return sampled

    def sample_efficient(self, logits: torch.Tensor) -> torch.Tensor:
        """
        More efficient implementation: softmax and sample over the K
        retained values only, instead of the full vocabulary.
        """
        # Apply temperature
        scaled_logits = logits / self.temperature

        # Get top-K (K is clamped to the vocabulary size)
        k = min(self.k, logits.size(-1))
        top_k_values, top_k_indices = torch.topk(scaled_logits, k, dim=-1)

        # Softmax over top-K only
        probs = F.softmax(top_k_values, dim=-1)

        # Sample a position within the top-K: [batch, 1]
        sampled_idx = torch.multinomial(probs, num_samples=1)

        # Map that position back to a vocabulary index
        sampled_tokens = top_k_indices.gather(1, sampled_idx).squeeze(-1)

        return sampled_tokens
```
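`sample_efficient` can skip the full-vocabulary softmax because exp(-inf) = 0. Softmax over the masked vector therefore gives the K survivors the same probabilities as softmax over the K values alone. The stdlib sketch below checks this numerically on hypothetical logits. The helper names are illustrative, not part of any library.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]  # math.exp(-inf) == 0.0
    s = sum(exps)
    return [e / s for e in exps]


def top_k_masked(logits, k):
    """Like TopKSampler.sample: mask non-top-k entries to -inf, softmax all."""
    keep = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    masked = [x if i in keep else float("-inf") for i, x in enumerate(logits)]
    return softmax(masked)


def top_k_compact(logits, k):
    """Like sample_efficient: softmax over the top-k values, then map back."""
    keep = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    probs = softmax([logits[i] for i in keep])
    full = [0.0] * len(logits)
    for idx, p in zip(keep, probs):
        full[idx] = p
    return full


logits = [1.2, -0.3, 2.7, 0.0, 2.1, -1.5]  # hypothetical logits
print(top_k_masked(logits, 3))
print(top_k_compact(logits, 3))
```

The two lists agree up to floating-point rounding. The compact version avoids work proportional to the vocabulary size once `topk` has run.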

Nucleus (Top-P) Sampling

The Concept

Top-P sampling takes tokens in order of decreasing probability and keeps the smallest set whose cumulative probability reaches or exceeds P:

```text
Original distribution (sorted):
  the:  0.35  cumsum: 0.35
  a:    0.25  cumsum: 0.60
  dog:  0.18  cumsum: 0.78
  cat:  0.12  cumsum: 0.90  ← Stop here for P=0.9
  runs: 0.06  cumsum: 0.96
  ...

Top-P (P=0.9):
  Keep: {the, a, dog, cat} (cumsum first reaches 0.90 at "cat")
  Renormalize and sample

Advantage over Top-K:
  - Adapts to distribution shape
  - Peaked distribution: few tokens kept
  - Flat distribution: many tokens kept
```
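The selection rule can be checked by hand on the distribution above. This sketch uses exact `fractions.Fraction` arithmetic so that the running sum lands exactly on 0.90 at "cat". In floating point, 0.35 + 0.25 + 0.18 + 0.12 may round slightly above or below 0.9. That kind of boundary case is where implementations can disagree by one token. The elided "..." tail, the remaining 0.04 of the mass, is represented by a single hypothetical "<rest>" entry.

```python
from fractions import Fraction


def nucleus(probs, p):
    """Smallest prefix (by descending probability) whose mass reaches p, renormalized."""
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    kept, mass = [], 0
    for tok, prob in ranked:
        # Stop once the tokens already kept reach p (always keep at least one)
        if kept and mass >= p:
            break
        kept.append((tok, prob))
        mass += prob
    return {tok: prob / mass for tok, prob in kept}


probs = {
    "the": Fraction("0.35"), "a": Fraction("0.25"), "dog": Fraction("0.18"),
    "cat": Fraction("0.12"), "runs": Fraction("0.06"),
    "<rest>": Fraction("0.04"),  # hypothetical stand-in for the elided tail
}
nuc = nucleus(probs, Fraction("0.9"))
print({tok: float(q) for tok, q in nuc.items()})
```

The same function works on plain floats. Fractions are used here only to make the boundary case exact.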

Implementation

```python
import torch
import torch.nn.functional as F


class NucleusSampler:
    """
    Nucleus (Top-P) sampling.

    Keeps the smallest set of tokens whose cumulative probability is >= p.
    Adapts dynamically to the shape of the distribution.
    """

    def __init__(
        self,
        p: float = 0.9,
        temperature: float = 1.0,
        bos_token_id: int = 2,
        eos_token_id: int = 3
    ):
        self.p = p
        self.temperature = temperature
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Sample from the nucleus (top-p) of the distribution.

        Args:
            logits: [batch, vocab_size]

        Returns:
            sampled_tokens: [batch]
        """
        # Apply temperature
        scaled_logits = logits / self.temperature

        # Sort by descending probability
        sorted_logits, sorted_indices = torch.sort(scaled_logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)

        # Probability mass of all tokens ranked *above* each position.
        # A token is dropped once that mass already reaches p, so the token
        # that crosses the threshold is kept.
        cumsum_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        sorted_mask = cumsum_before >= self.p
        sorted_mask[..., 0] = False  # always keep the most probable token

        # Mask out tokens outside the nucleus
        sorted_logits = sorted_logits.masked_fill(sorted_mask, float('-inf'))

        # Unsort back to vocabulary order (sorted_indices is a permutation,
        # so every position is overwritten)
        original_logits = torch.zeros_like(logits)
        original_logits.scatter_(1, sorted_indices, sorted_logits)

        # Sample from the filtered distribution
        probs = F.softmax(original_logits, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        return sampled
```

Top-P vs Top-K Comparison

```text
For P=0.9:

PEAKED distribution:
  Top-P keeps: {"the", "a"} (2 tokens)
  Top-K=4: {"the", "a", "dog", "cat"} (fixed 4 tokens)

FLAT distribution:
  Top-P keeps: {"the", "a", "dog", "cat", "runs", "slowly"} (6 tokens)
  Top-K=4: only 4 tokens regardless

Top-P adapts to distribution shape!
```
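To see the adaptive behavior numerically, this sketch counts how many tokens each rule keeps on two hypothetical 8-token distributions. The numbers are made up and chosen so the nucleus sizes match the 2-token and 6-token cases described above. No running sum sits exactly on 0.9, which avoids floating-point ties.

```python
def nucleus_size(probs, p):
    """Number of tokens in the smallest descending-probability prefix with mass >= p."""
    mass = 0.0
    for i, prob in enumerate(sorted(probs, reverse=True)):
        if mass >= p:
            return i
        mass += prob
    return len(probs)


def top_k_size(probs, k):
    """Top-K keeps min(k, vocab_size) tokens, whatever the shape."""
    return min(k, len(probs))


# Hypothetical distributions over the same 8 tokens (each sums to 1)
peaked = [0.70, 0.22, 0.03, 0.02, 0.01, 0.01, 0.005, 0.005]
flat = [0.19, 0.18, 0.16, 0.15, 0.13, 0.11, 0.05, 0.03]

for name, dist in (("peaked", peaked), ("flat", flat)):
    print(f"{name:6s}  top-p(0.9): {nucleus_size(dist, 0.9)} tokens   "
          f"top-k(4): {top_k_size(dist, 4)} tokens")
```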

Combined Sampler

Production Implementation

```python
import torch
import torch.nn.functional as F
from typing import Optional


class CombinedSampler:
    """
    Combined sampling with repetition penalty, temperature, top-k, and top-p.

    Applies in order:
    1. Repetition penalty (if `generated` is given and the penalty != 1.0)
    2. Temperature scaling (temperature <= 0 falls back to greedy argmax)
    3. Top-K filtering
    4. Top-P (nucleus) filtering
    5. Sample from the remaining distribution
    """

    def __init__(
        self,
        temperature: float = 1.0,
        top_k: int = 0,  # 0 = disabled
        top_p: float = 1.0,  # 1.0 = disabled
        repetition_penalty: float = 1.0,  # 1.0 = disabled
        bos_token_id: int = 2,
        eos_token_id: int = 3,
        pad_token_id: int = 0
    ):
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.repetition_penalty = repetition_penalty
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

    def _apply_repetition_penalty(
        self,
        logits: torch.Tensor,
        generated: torch.Tensor
    ) -> torch.Tensor:
        """
        Penalize tokens already present in `generated` (modifies `logits` in place).

        Positive logits are divided by the penalty and negative logits are
        multiplied by it, so a penalty > 1 lowers every nonzero logit.
        """
        for i in range(generated.size(0)):
            for token in generated[i].unique():
                if token.item() in (self.pad_token_id, self.bos_token_id):
                    continue  # special tokens are not penalized
                if logits[i, token] > 0:
                    logits[i, token] /= self.repetition_penalty
                else:
                    logits[i, token] *= self.repetition_penalty
        return logits

    def _apply_top_k(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply top-K filtering."""
        if self.top_k <= 0:
            return logits

        # The K-th largest logit in each row is the cutoff
        top_k = min(self.top_k, logits.size(-1))
        threshold = torch.topk(logits, top_k, dim=-1).values[:, -1:]

        # Mask out tokens below the cutoff (ties with the K-th value are kept)
        mask = logits < threshold
        logits = logits.masked_fill(mask, float('-inf'))

        return logits

    def _apply_top_p(self, logits: torch.Tensor) -> torch.Tensor:
        """Apply top-P (nucleus) filtering."""
        if self.top_p >= 1.0:
            return logits

        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        sorted_probs = F.softmax(sorted_logits, dim=-1)

        # Drop a token once the tokens ranked above it already reach top_p
        cumsum_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
        sorted_mask = cumsum_before >= self.top_p
        sorted_mask[..., 0] = False  # always keep the most probable token

        # Mask out tokens outside the nucleus
        sorted_logits = sorted_logits.masked_fill(sorted_mask, float('-inf'))

        # Unsort back to vocabulary order
        original_logits = torch.zeros_like(logits)
        original_logits.scatter_(1, sorted_indices, sorted_logits)

        return original_logits

    def sample(
        self,
        logits: torch.Tensor,
        generated: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Sample using the combined strategy.

        Args:
            logits: [batch, vocab_size]
            generated: [batch, seq_len] previously generated token ids,
                used for the repetition penalty

        Returns:
            sampled_tokens: [batch]
        """
        # 1. Apply repetition penalty (on a copy, leaving the caller's logits untouched)
        if generated is not None and self.repetition_penalty != 1.0:
            logits = self._apply_repetition_penalty(logits.clone(), generated)

        # 2. Apply temperature; T <= 0 is treated as greedy decoding
        if self.temperature <= 0:
            return logits.argmax(dim=-1)
        logits = logits / self.temperature

        # 3. Apply top-K
        logits = self._apply_top_k(logits)

        # 4. Apply top-P
        logits = self._apply_top_p(logits)

        # 5. Sample
        probs = F.softmax(logits, dim=-1)
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

        return sampled
```
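The sign-aware rule in `_apply_repetition_penalty` divides positive logits and multiplies negative ones, and it is easy to get backwards. The stdlib sketch below applies the same rule to a plain list of hypothetical logits. When a single token is penalized with a penalty > 1, its logit drops and so does its softmax probability, whether the logit started positive or negative. A logit of exactly 0 is left unchanged by this rule.

```python
import math


def apply_repetition_penalty(logits, generated_ids, penalty, skip_ids=()):
    """Sign-aware penalty: shrink positive logits, push negative ones further down."""
    out = list(logits)
    for tok in set(generated_ids):
        if tok in skip_ids:
            continue  # e.g. pad/bos ids, mirroring the class above
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


logits = [2.0, -1.0, 0.5, 0.0]  # hypothetical logits for a 4-token vocabulary
for tok in (0, 1):  # one positive-logit token, one negative-logit token
    penalized = apply_repetition_penalty(logits, [tok], penalty=1.3)
    print(f"token {tok}: logit {logits[tok]} -> {penalized[tok]:.3f}, "
          f"p {softmax(logits)[tok]:.3f} -> {softmax(penalized)[tok]:.3f}")
```

Dividing a negative logit instead would move it toward zero and *raise* the token's probability, which is the opposite of the intent.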

Choosing Sampling Parameters

Guidelines

```text
TEMPERATURE (0.0 - 2.0)
────────────────────────
0.0 - 0.5:  Very deterministic (0.0 = greedy)
0.5 - 0.8:  Conservative, high-quality outputs
0.8 - 1.0:  Balanced randomness
1.0 - 1.3:  Creative, varied outputs
1.3 - 2.0:  Very random, potentially incoherent

Recommended: 0.7-1.0 for most applications

TOP-K (0 - vocab_size)
────────────────────────
0:          Disabled
10-20:      Very restricted (safe but potentially boring)
30-50:      Balanced (common choice)
50-100:     More variety
100+:       Very permissive

Recommended: 40-50 for general use

TOP-P (0.0 - 1.0)
────────────────────────
0.5-0.8:    Restricted (more focused)
0.85-0.92:  Balanced (common choice)
0.92-0.98:  Permissive (more variety)
0.98-1.0:   Nearly unrestricted (1.0 = disabled)

Recommended: 0.9-0.95 for general use

TASK-SPECIFIC RECOMMENDATIONS
──────────────────────────────
Translation:      temp=0.0 (greedy) or beam search
Summarization:    temp=0.3, top_p=0.9
Dialogue:         temp=0.8, top_k=50, top_p=0.9
Creative Writing: temp=1.0-1.2, top_k=0, top_p=0.95
Code Generation:  temp=0.2-0.5, top_p=0.95
```

Summary

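| Strategy | Parameters | Pros | Cons |
|---|---|---|---|
| Temperature | T ∈ (0, ∞) | Simple, intuitive | Can be too random at high T |
| Top-K | K ∈ [1, V] (V = vocabulary size) | Prevents sampling rare tokens | Fixed candidate-set size, regardless of distribution shape |
| Top-P | P ∈ (0, 1] | Adapts to distribution shape | Requires sorting the vocabulary; slightly more complex |
| Combined | T, K, P | Fine-grained control; combines the strengths of each | More hyperparameters to tune |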

When to Use Each

  • Temperature only: Quick experimentation
  • Top-K only: When you want consistent candidate set size
  • Top-P only: When distribution shape varies
  • Combined: Production systems, fine-grained control

In the next section, we'll implement KV-caching, a crucial optimization that dramatically speeds up autoregressive generation by reusing the keys and values already computed for earlier tokens instead of recomputing them at every step.