Chapter 7

Choosing a Sampler

Improved Sampling Methods

Learning Objectives

By the end of this section, you will:

  1. Apply a decision framework for selecting the right sampler
  2. Interpret benchmark results to make informed choices
  3. Match samplers to use cases for optimal results
  4. Optimize sampler performance for your specific setup
  5. Troubleshoot common issues with sampling quality

Practical Guidance

This section synthesizes everything we've learned about samplers into actionable guidance. After reading this, you'll be able to confidently select and configure the best sampler for any diffusion model application.

Decision Framework

When choosing a sampler, ask yourself these questions in order:

Question 1: Speed or Quality Priority?

| Priority | Recommended Path | Typical Steps |
|---|---|---|
| Speed (real-time) | DPM++ 2M or DDIM | 10-25 |
| Quality (offline) | DPM++ 2M or Euler | 30-50 |
| Maximum quality | DDPM or DPM++ at high steps | 100-200+ |

Question 2: Determinism Required?

| Requirement | Deterministic Samplers | Stochastic Samplers |
|---|---|---|
| Reproducibility needed | DDIM, Euler, DPM++ 2M | - |
| Image editing/inversion | DDIM (eta=0) | - |
| Maximum diversity | - | DDPM, DPM++ SDE, Euler-a |
| Creative exploration | - | DPM++ SDE, Euler-a |

Question 3: Guidance Scale?

When using classifier-free guidance (CFG), some samplers behave differently:

| CFG Scale | Recommended Sampler | Notes |
|---|---|---|
| Low (1-3) | Any sampler works | Minimal impact |
| Medium (5-7) | DPM++ 2M, Euler | Standard range |
| High (10+) | Euler, DDIM | Better stability |
| Very high (15+) | Euler with clipping | Risk of artifacts |

```python
def choose_sampler(
    speed_priority: str = "balanced",  # "fast", "balanced", "quality"
    deterministic: bool = True,
    cfg_scale: float = 7.5,
    use_case: str = "general"  # "general", "editing", "creative"
) -> str:
    """
    Decision tree for sampler selection.

    Returns recommended sampler name.
    """
    if use_case == "editing":
        # Editing requires deterministic sampling for inversion
        return "ddim"

    if speed_priority == "fast":
        if deterministic:
            return "dpm_pp_2m"  # 15-20 steps
        else:
            return "euler_a"  # 25-30 steps

    elif speed_priority == "balanced":
        if deterministic:
            if cfg_scale > 12:
                return "euler"  # More stable at high CFG
            else:
                return "dpm_pp_2m"  # Best speed/quality
        else:
            return "dpm_pp_sde"  # Good diversity

    else:  # quality
        if deterministic:
            return "euler"  # 50+ steps for quality
        else:
            return "ddpm"  # 200+ steps, maximum diversity


# Examples
print(choose_sampler("fast", True, 7.5))       # -> dpm_pp_2m
print(choose_sampler("balanced", False, 7.5))  # -> dpm_pp_sde
print(choose_sampler("quality", True, 15.0))   # -> euler
print(choose_sampler("balanced", True, 7.5, "editing"))  # -> ddim
```

Sampler Benchmarks

Here are comprehensive benchmarks comparing samplers on standard datasets:

CIFAR-10 (32x32) Results

| Sampler | Steps | FID | Time/Image | NFE |
|---|---|---|---|---|
| DDPM | 1000 | 3.17 | 5.2s | 1000 |
| DDIM | 50 | 4.16 | 0.26s | 50 |
| DDIM | 100 | 3.45 | 0.52s | 100 |
| Euler | 50 | 4.82 | 0.26s | 50 |
| Heun | 25 | 4.21 | 0.26s | 50 |
| DPM-Solver | 20 | 3.89 | 0.11s | 20 |
| DPM++ 2M | 20 | 3.42 | 0.11s | 20 |

ImageNet (256x256) Results

| Sampler | Steps | FID | Time/Image | NFE |
|---|---|---|---|---|
| DDPM | 1000 | 4.52 | 45s | 1000 |
| DDIM | 50 | 5.83 | 2.3s | 50 |
| DDIM | 250 | 4.71 | 11.5s | 250 |
| DPM++ 2M | 25 | 5.12 | 1.2s | 25 |
| DPM++ 2M | 50 | 4.68 | 2.3s | 50 |
| UniPC | 20 | 5.05 | 0.92s | 20 |

```python
import time
from typing import Dict, List

import torch


class SamplerBenchmark:
    """
    Benchmark samplers on quality and speed.
    """

    def __init__(
        self,
        model,
        noise_schedule,
        test_images: torch.Tensor,  # Real images for FID
        device: str = "cuda"
    ):
        self.model = model
        self.ns = noise_schedule
        self.test_images = test_images.to(device)
        self.device = device

        # Initialize samplers
        from unified_sampler import UnifiedSampler
        self.unified = UnifiedSampler(model, noise_schedule.alphas_cumprod, device)

    def benchmark_sampler(
        self,
        sampler_type: str,
        step_counts: List[int],
        num_samples: int = 1000,
        batch_size: int = 50
    ) -> Dict:
        """
        Benchmark a single sampler at various step counts.
        """
        results = {}

        for steps in step_counts:
            samples = []
            total_time = 0.0

            for i in range(0, num_samples, batch_size):
                batch = min(batch_size, num_samples - i)

                start = time.time()
                batch_samples = self.unified.sample(
                    shape=(batch, 3, 64, 64),
                    sampler_type=sampler_type,
                    num_steps=steps,
                    progress=False
                )
                total_time += time.time() - start

                samples.append(batch_samples)

            samples = torch.cat(samples)

            # Compute FID (simplified - a real implementation would use Inception features)
            fid = self._compute_fid_approximation(samples)

            results[steps] = {
                "fid": fid,
                "total_time": total_time,
                "time_per_image": total_time / num_samples,
                "samples_per_second": num_samples / total_time
            }

            print(f"{sampler_type} @ {steps} steps: FID={fid:.2f}, {total_time:.2f}s total")

        return results

    def _compute_fid_approximation(self, samples: torch.Tensor) -> float:
        """
        Simplified FID approximation.
        A real implementation should use an Inception network.
        """
        # Per-channel statistics of generated samples
        gen_mean = samples.mean(dim=[0, 2, 3])
        gen_var = samples.var(dim=[0, 2, 3])

        # Compare with real images
        real_mean = self.test_images.mean(dim=[0, 2, 3])
        real_var = self.test_images.var(dim=[0, 2, 3])

        # Simple approximation (not real FID, just for illustration)
        mean_diff = (gen_mean - real_mean).pow(2).sum()
        var_diff = (gen_var.sqrt() - real_var.sqrt()).pow(2).sum()

        return (mean_diff + var_diff).item()

    def compare_all_samplers(
        self,
        target_time: float = 1.0,  # seconds per image
        num_samples: int = 500
    ) -> Dict:
        """
        Compare all samplers at a similar computational budget.
        """
        # Step counts chosen so each sampler uses a similar NFE budget
        step_configs = {
            "ddim": 50,
            "euler": 50,
            "heun": 25,  # 2x NFE per step
            "dpm_pp_2m": 20,
            "dpm_pp_sde": 25,
            "euler_a": 50,
        }

        results = {}

        for sampler, steps in step_configs.items():
            print(f"\nBenchmarking {sampler}...")

            results[sampler] = self.benchmark_sampler(
                sampler_type=sampler,
                step_counts=[steps],
                num_samples=num_samples
            )

        # Print summary, best FID first
        print("\n" + "=" * 60)
        print("SUMMARY (similar time budget)")
        print("=" * 60)
        print(f"{'Sampler':<15} {'Steps':<8} {'FID':<10} {'Time/img':<10}")
        print("-" * 60)

        for sampler, data in sorted(results.items(), key=lambda x: list(x[1].values())[0]["fid"]):
            info = list(data.values())[0]
            steps = list(data.keys())[0]
            print(f"{sampler:<15} {steps:<8} {info['fid']:<10.2f} {info['time_per_image']:<10.3f}s")

        return results
```

Fair Comparison

When comparing samplers, always compare at equal NFE (Number of Function Evaluations) rather than equal step counts. Heun uses 2 NFE per step, so 25 Heun steps should be compared with 50 Euler steps for a fair evaluation.
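This adjustment is easy to automate. The per-step NFE counts below are typical for common implementations (multistep methods like DPM++ 2M reuse the previous evaluation; second-order methods like Heun evaluate twice per step), but treat them as assumptions to verify against your own sampler code:

```python
# Typical model evaluations per step (verify against your implementation;
# classifier-free guidance doubles every count below)
NFE_PER_STEP = {
    "ddim": 1,
    "euler": 1,
    "euler_a": 1,
    "heun": 2,        # second-order: two evaluations per step
    "dpm_pp_2m": 1,   # multistep: reuses the previous evaluation
    "dpm_pp_sde": 2,  # second-order SDE solver
}

def steps_for_budget(sampler: str, nfe_budget: int) -> int:
    """How many steps a sampler can take within a fixed NFE budget."""
    return nfe_budget // NFE_PER_STEP[sampler]

# A budget of 50 NFE buys 50 Euler steps but only 25 Heun steps
print(steps_for_budget("euler", 50))  # -> 50
print(steps_for_budget("heun", 50))   # -> 25
```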

Sampler Selection by Use Case

1. Production Image Generation API

| Requirement | Choice | Rationale |
|---|---|---|
| Sampler | DPM++ 2M Karras | Best speed/quality trade-off |
| Steps | 20-25 | Sub-second generation |
| CFG Scale | 7.5 | Standard value |
| Schedule | Karras | Better for fine details |

```python
# Production configuration
class ProductionConfig:
    sampler = "dpm_pp_2m"
    steps = 22
    cfg_scale = 7.5
    schedule = "karras"
    batch_size = 4

def generate_production(prompts, model, noise_schedule):
    """Production-ready generation function."""
    sampler = DPMPlusPlus2M(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod
    )

    images = sampler.sample(
        shape=(len(prompts), 3, 512, 512),
        num_steps=ProductionConfig.steps,
        progress=False
    )

    return images
```

2. Image Editing Application

| Requirement | Choice | Rationale |
|---|---|---|
| Sampler | DDIM (eta=0) | Required for inversion |
| Inversion Steps | 100-200 | High accuracy reconstruction |
| Sampling Steps | 50 | Quality regeneration |
| Deterministic | Yes | Reproducible edits |

```python
# Image editing configuration
class EditingConfig:
    inversion_sampler = "ddim"
    inversion_steps = 100
    sampling_steps = 50
    eta = 0.0  # Must be deterministic

def edit_image(image, edit_direction, strength, model, noise_schedule):
    """Edit an image using DDIM inversion."""
    # Invert
    inverter = DDIMInverter(model, noise_schedule.alphas_cumprod)
    x_T, _ = inverter.invert(image, num_steps=EditingConfig.inversion_steps)

    # Apply edit
    x_T_edited = x_T + strength * edit_direction

    # Regenerate
    sampler = DDIMSampler(
        model, noise_schedule.alphas_cumprod,
        config=DDIMConfig(eta=0.0)
    )
    edited = sampler.sample(
        shape=x_T_edited.shape,
        num_steps=EditingConfig.sampling_steps,
        x_T=x_T_edited
    )

    return edited
```

3. Creative Exploration

| Requirement | Choice | Rationale |
|---|---|---|
| Sampler | DPM++ SDE or Euler-a | Diversity through stochasticity |
| Steps | 25-35 | Balance speed and diversity |
| Noise Scale | 1.0 | Full stochastic effect |
| CFG Scale | 7-9 | Creative range |
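Mirroring the production and editing presets above, a creative-exploration configuration might look like this (a sketch; the field names follow the hypothetical config classes used earlier in this section):

```python
# Creative exploration configuration (sketch, mirroring ProductionConfig above)
class CreativeConfig:
    sampler = "dpm_pp_sde"   # or "euler_a" as a lighter-weight option
    steps = 30               # middle of the 25-35 range
    cfg_scale = 8.0          # within the creative 7-9 range
    s_noise = 1.0            # full stochastic effect
    num_variants = 4         # draw several seeds per prompt for diversity
```

Because the sampler is stochastic, running the same prompt `num_variants` times with different seeds produces distinct results; recording the seeds lets you reproduce any variant you later want to refine.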

4. High-Quality Renders

| Requirement | Choice | Rationale |
|---|---|---|
| Sampler | DPM++ 2M | High quality output |
| Steps | 50-100 | Maximum refinement |
| CFG Scale | 7-8 | Balanced guidance |
| Schedule | Karras | Better fine details |
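The same table as a preset, in the style of the earlier config classes (field names are illustrative):

```python
# High-quality render configuration (sketch, mirroring the presets above)
class QualityConfig:
    sampler = "dpm_pp_2m"
    steps = 75               # middle of the 50-100 range
    cfg_scale = 7.5          # within the balanced 7-8 range
    schedule = "karras"      # better fine detail
    x0_clip = True           # avoid color saturation artifacts
```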

Optimization Tips

Speed Optimizations

  1. Use torch.compile: 1.3-2x speedup on PyTorch 2.0+
  2. Enable mixed precision: FP16 inference with minimal quality loss
  3. Batch effectively: Maximize GPU utilization
  4. Use flash attention: Significant speedup for attention layers

```python
import time

import torch
from torch.amp import autocast


class OptimizedPipeline:
    """
    Optimized diffusion pipeline with all speedups.
    """

    def __init__(self, model, noise_schedule, device="cuda"):
        self.device = device

        # Compile model for speed
        self.model = torch.compile(
            model,
            mode="reduce-overhead",
            fullgraph=True
        )

        # Initialize sampler with compiled model
        self.sampler = DPMPlusPlus2M(
            model=self.model,
            alphas_cumprod=noise_schedule.alphas_cumprod,
            device=device
        )

    @torch.inference_mode()
    def generate(
        self,
        batch_size: int = 4,
        num_steps: int = 20,
        use_amp: bool = True
    ):
        """Generate with all optimizations."""
        shape = (batch_size, 3, 512, 512)

        if use_amp:
            with autocast("cuda", dtype=torch.float16):
                samples = self.sampler.sample(
                    shape=shape,
                    num_steps=num_steps,
                    progress=False
                )
        else:
            samples = self.sampler.sample(
                shape=shape,
                num_steps=num_steps,
                progress=False
            )

        return samples.float()  # Convert back to FP32 if needed


# Benchmark optimization impact
def benchmark_optimizations(model, noise_schedule):
    """Compare speed with and without optimizations."""
    # Baseline
    baseline_sampler = DPMPlusPlus2M(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod
    )

    start = time.time()
    for _ in range(10):
        baseline_sampler.sample((4, 3, 64, 64), num_steps=20, progress=False)
    baseline_time = time.time() - start

    # Optimized
    optimized = OptimizedPipeline(model, noise_schedule)

    # Warm up torch.compile
    optimized.generate(4, 20)

    start = time.time()
    for _ in range(10):
        optimized.generate(4, 20)
    optimized_time = time.time() - start

    print(f"Baseline: {baseline_time:.2f}s")
    print(f"Optimized: {optimized_time:.2f}s")
    print(f"Speedup: {baseline_time / optimized_time:.2f}x")
```

Quality Optimizations

  • Use Karras sigmas: Better for fine details at low step counts
  • Enable x0 clipping: Prevents color saturation artifacts
  • Dynamic thresholding: From Imagen paper, helps at high CFG
  • EMA model: Always use EMA weights for sampling
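Of these, dynamic thresholding is the least obvious to implement. A sketch follows, based on the description in the Imagen paper: at each step, clamp the predicted x0 to a per-sample percentile of its absolute values and rescale, rather than statically clipping to [-1, 1] (the function name and percentile default are illustrative):

```python
import torch

def dynamic_threshold(x0: torch.Tensor, percentile: float = 0.995) -> torch.Tensor:
    """Dynamic thresholding (sketch, after the Imagen paper).

    Clamps the predicted x0 to a per-sample percentile of its absolute
    values, then rescales back into [-1, 1]. This suppresses the
    saturation artifacts that static clipping causes at high CFG scales.
    """
    b = x0.shape[0]
    # Per-sample threshold: the `percentile` quantile of |x0|
    s = torch.quantile(x0.abs().reshape(b, -1), percentile, dim=1)
    # Never shrink below the static [-1, 1] range
    s = torch.clamp(s, min=1.0).view(b, 1, 1, 1)
    # Clamp to [-s, s], then rescale into [-1, 1]
    return x0.clamp(-s, s) / s
```

In a sampler loop this would be applied to the x0 prediction at every step, replacing a plain `x0.clamp(-1, 1)`.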

Common Issues and Solutions

Issue 1: Blurry or Noisy Outputs

| Symptom | Likely Cause | Solution |
|---|---|---|
| Very blurry | Too few steps | Increase steps to 25-50 |
| Noisy/grainy | Wrong sigma schedule | Use Karras schedule |
| Washed out colors | Missing x0 clipping | Enable clipping to [-1, 1] |

Issue 2: Color Saturation or Artifacts

| Symptom | Likely Cause | Solution |
|---|---|---|
| Oversaturated colors | CFG too high | Reduce to 7-8 |
| Color banding | FP16 precision loss | Use FP32 or AMP carefully |
| Repeated patterns | Poor sampler at low steps | Use DPM++ 2M |

Issue 3: Inconsistent Quality

When quality varies unpredictably from run to run, the usual culprits are off-scale model outputs, broken determinism, or numerical instability. A systematic diagnostic can isolate which:

```python
import torch

def diagnose_sampling_issues(model, noise_schedule) -> dict:
    """
    Diagnose common sampling issues.
    """
    issues = []

    # Test 1: Check if model outputs are in the expected range
    x_t = torch.randn(1, 3, 64, 64, device="cuda")
    t = torch.tensor([500], device="cuda")
    eps = model(x_t, t)

    eps_std = eps.std().item()
    if eps_std < 0.5:
        issues.append("Model outputs have low variance - check training")
    if eps_std > 2.0:
        issues.append("Model outputs have high variance - may cause instability")

    # Test 2: Check sampler consistency
    torch.manual_seed(42)
    x_T = torch.randn(1, 3, 64, 64, device="cuda")

    sampler = DDIMSampler(model, noise_schedule.alphas_cumprod, DDIMConfig(eta=0.0))

    samples = []
    for _ in range(3):
        sample = sampler.sample(
            shape=(1, 3, 64, 64),
            num_steps=50,
            x_T=x_T.clone(),
            progress=False
        )
        samples.append(sample)

    # Check that the deterministic sampler is actually deterministic
    var = torch.stack(samples).var(dim=0).mean().item()
    if var > 1e-6:
        issues.append("DDIM (eta=0) should be deterministic but isn't")

    # Test 3: Check for NaN/Inf
    test_sample = sampler.sample(
        shape=(1, 3, 64, 64),
        num_steps=50,
        progress=False
    )
    if torch.isnan(test_sample).any():
        issues.append("NaN values in output - check model and schedule")
    if torch.isinf(test_sample).any():
        issues.append("Inf values in output - numerical instability")

    # Test 4: Check output range
    if test_sample.min() < -3 or test_sample.max() > 3:
        issues.append("Output range outside expected [-3, 3] - check x0 clipping")

    return {
        "issues": issues,
        "eps_std": eps_std,
        "output_range": (test_sample.min().item(), test_sample.max().item()),
        "determinism_var": var
    }
```

Chapter Summary

In this chapter, we've comprehensively covered improved sampling methods for diffusion models:

Key Takeaways

  1. DDPM limitations: Ancestral sampling requires 1000 steps, is stochastic (non-reproducible), and prevents latent space manipulation
  2. DDIM solution: Non-Markovian formulation enables deterministic sampling with arbitrary step counts using the same trained model
  3. DDIM applications: Inversion for encoding, semantic interpolation, and image editing become possible with deterministic sampling
  4. Advanced samplers: DPM-Solver, Euler/Heun, and their variants achieve even faster sampling (10-25 steps) through better ODE solving
  5. Sampler selection: Choose based on speed/quality trade-off, determinism requirements, and specific use case

| Use Case | Sampler | Steps | Key Setting |
|---|---|---|---|
| Production API | DPM++ 2M Karras | 20-25 | eta=0, x0_clip=True |
| Image Editing | DDIM | 50-100 (inv), 50 (gen) | eta=0 |
| Creative Apps | DPM++ SDE | 25-35 | eta=1, s_noise=1 |
| Maximum Quality | Euler or DPM++ 2M | 50-100 | Karras schedule |
| Debugging | Euler | 50 | Simple, predictable |
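The same recommendations work well as a machine-readable preset map you could drop into a pipeline (values taken from the table above; the dictionary field names are illustrative):

```python
# Sampler presets from the summary table (field names are illustrative)
SAMPLER_PRESETS = {
    "production": {"sampler": "dpm_pp_2m", "schedule": "karras", "steps": 22,
                   "eta": 0.0, "x0_clip": True},
    "editing":    {"sampler": "ddim", "inversion_steps": 100, "steps": 50, "eta": 0.0},
    "creative":   {"sampler": "dpm_pp_sde", "steps": 30, "eta": 1.0, "s_noise": 1.0},
    "quality":    {"sampler": "dpm_pp_2m", "schedule": "karras", "steps": 75},
    "debugging":  {"sampler": "euler", "steps": 50},
}

def preset(use_case: str) -> dict:
    """Look up the recommended settings for a use case."""
    return dict(SAMPLER_PRESETS[use_case])  # copy, so callers can tweak safely
```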

Part III Complete

With this chapter, we've completed Part III: Architecture and Implementation. You now have all the tools to build, train, and efficiently sample from diffusion models. Part IV will cover advanced topics including conditional generation, latent diffusion, and state-of-the-art applications.

The choice of sampler can make a 10-100x difference in generation speed with minimal impact on quality. By understanding the principles behind each method, you can make informed decisions that optimize for your specific requirements - whether that's real-time generation, maximum quality, or creative exploration.