Chapter 7
Section 37 of 76

Advanced Samplers Overview

Improved Sampling Methods

Learning Objectives

By the end of this section, you will:

  1. Understand the sampler landscape and how different methods relate
  2. Implement DPM-Solver for ultra-fast sampling (10-20 steps)
  3. Master Euler and Heun methods from the ODE perspective
  4. Compare ancestral sampling variants (Euler Ancestral, DPM++ SDE)
  5. Build a unified sampler framework supporting multiple methods

Beyond DDIM

While DDIM provides a 10-20x speedup over DDPM, modern samplers such as DPM-Solver can reach near-optimal quality in just 10-25 steps. This section surveys the current state of the art in diffusion sampling.

The Sampler Landscape

Modern diffusion samplers can be organized along two axes: ODE vs. SDE (deterministic vs. stochastic) and solver order (first-order vs. higher-order).

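| Sampler | Type | Order | Typical Steps | Best Use Case |
|---------|------|-------|---------------|---------------|
| DDPM | SDE | 1st | 200+ | Maximum diversity |
| DDIM | ODE/SDE | 1st | 25-50 | Fast, deterministic |
| Euler | ODE | 1st | 25-50 | Simple, stable |
| Heun | ODE | 2nd | 15-25 | Better accuracy |
| DPM-Solver | ODE | 1st-3rd | 10-25 | Ultra-fast quality |
| DPM++ 2M | ODE | 2nd | 15-25 | Production standard |
| DPM++ SDE | SDE | 2nd | 20-35 | High diversity |
| UniPC | ODE | 3rd | 10-20 | State-of-the-art |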

ODE vs SDE Perspective

Diffusion models can be viewed through two equivalent lenses:

$$\text{SDE: } d\mathbf{x} = f(\mathbf{x}, t)\,dt + g(t)\,d\mathbf{w}$$
$$\text{ODE: } \frac{d\mathbf{x}}{dt} = f(\mathbf{x}, t) - \frac{1}{2} g(t)^2 \nabla_{\mathbf{x}} \log p_t(\mathbf{x})$$

The ODE formulation (called the probability flow ODE) gives the same marginal distributions as the SDE but follows deterministic trajectories.

Key Insight

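Most modern fast samplers are built on solving the probability flow ODE more efficiently; the ancestral and SDE variants add controlled noise on top of the same machinery. The score function $\nabla_{\mathbf{x}} \log p_t(\mathbf{x})$ is directly related to the noise prediction $\boldsymbol{\epsilon}_\theta$: for $\mathbf{x}_t = \alpha_t \mathbf{x}_0 + \sigma_t \boldsymbol{\epsilon}$ with $\alpha_t = \sqrt{\bar{\alpha}_t}$ and $\sigma_t = \sqrt{1 - \bar{\alpha}_t}$, we have $\nabla_{\mathbf{x}} \log p_t(\mathbf{x}_t) \approx -\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) / \sigma_t$.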
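As a quick sanity check of the score/noise relation, consider a toy case assumed purely for illustration: standard normal data, $\mathbf{x}_0 \sim \mathcal{N}(0, I)$, under the VP schedule with $\alpha_t = \sqrt{\bar{\alpha}_t}$ and $\sigma_t = \sqrt{1 - \bar{\alpha}_t}$. Then $\mathbf{x}_t$ is standard normal for every $t$ (because $\alpha_t^2 + \sigma_t^2 = 1$), its exact score is $-\mathbf{x}_t$, and the optimal noise predictor is $\mathbb{E}[\boldsymbol{\epsilon} \mid \mathbf{x}_t] = \sigma_t \mathbf{x}_t$. The helper names and the linear beta schedule below are hypothetical and exist only for this sketch.

```python
import torch


def score_from_eps(eps: torch.Tensor, sigma_t: torch.Tensor) -> torch.Tensor:
    """Convert a noise prediction into a score estimate: grad log p_t(x) = -eps / sigma_t."""
    return -eps / sigma_t


def ideal_eps_standard_normal(x_t: torch.Tensor, sigma_t: torch.Tensor) -> torch.Tensor:
    """Optimal eps-predictor when x_0 ~ N(0, I): E[eps | x_t] = sigma_t * x_t."""
    return sigma_t * x_t


# Hypothetical linear beta schedule, used only to pick a sigma_t
betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
sigma_t = torch.sqrt(1 - alphas_cumprod[500])

x_t = torch.randn(8, dtype=torch.float64)
score = score_from_eps(ideal_eps_standard_normal(x_t, sigma_t), sigma_t)
exact_score = -x_t  # score of N(0, I)
print(torch.allclose(score, exact_score))  # True
```

With a trained network, `ideal_eps_standard_normal` is replaced by $\boldsymbol{\epsilon}_\theta$, and the same conversion turns its output into a score estimate.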

DPM-Solver

DPM-Solver reformulates diffusion sampling as solving an ODE in the log-SNR space, enabling efficient higher-order solvers. It achieves excellent quality with just 10-20 steps.

The Key Idea: Change of Variables

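Instead of working in timestep $t$, DPM-Solver uses $\lambda_t = \log(\alpha_t / \sigma_t)$, which is half the log signal-to-noise ratio (SNR $= \alpha_t^2 / \sigma_t^2$), where $\alpha_t = \sqrt{\bar{\alpha}_t}$ and $\sigma_t = \sqrt{1 - \bar{\alpha}_t}$. In this variable the probability flow ODE becomes semi-linear:

$$\frac{d\mathbf{x}_\lambda}{d\lambda} = \frac{d \log \alpha_\lambda}{d\lambda}\,\mathbf{x}_\lambda - \sigma_\lambda\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_\lambda, t(\lambda))$$

The linear term can be integrated exactly, giving $\mathbf{x}_t = \frac{\alpha_t}{\alpha_s}\mathbf{x}_s - \alpha_t \int_{\lambda_s}^{\lambda_t} e^{-\lambda}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_\lambda, t(\lambda))\,d\lambda$, so only the smooth $\boldsymbol{\epsilon}_\theta$ integral has to be approximated. Treating $\boldsymbol{\epsilon}_\theta$ as constant over the step gives DPM-Solver-1 (which coincides with DDIM); Taylor-expanding it in $\lambda$ gives the higher-order solvers, which work well because $\boldsymbol{\epsilon}_\theta$ varies smoothly in $\lambda$.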
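The first-order update is equivalent to DDIM, as the `DPMSolver` code notes. The sketch below checks this numerically. It assumes a linear beta schedule and uses random stand-ins for $\mathbf{x}_s$ and $\boldsymbol{\epsilon}_\theta$ (no trained model), comparing the DPM-Solver-1 step $\mathbf{x}_t = \frac{\alpha_t}{\alpha_s}\mathbf{x}_s - \sigma_t (e^{h} - 1)\,\boldsymbol{\epsilon}$, with $h = \lambda_t - \lambda_s$, against the deterministic DDIM step.

```python
import torch

# Hypothetical linear beta schedule; float64 so the comparison is tight
betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)
alpha = alphas_cumprod.sqrt()        # alpha_t
sigma = (1 - alphas_cumprod).sqrt()  # sigma_t
lam = alpha.log() - sigma.log()      # lambda_t = log(alpha_t / sigma_t)


def dpm_solver_1(x, eps, s, t):
    """First-order DPM-Solver step from timestep s to an earlier timestep t."""
    h = lam[t] - lam[s]  # h > 0 because lambda grows as noise decreases
    return (alpha[t] / alpha[s]) * x - sigma[t] * torch.expm1(h) * eps


def ddim(x, eps, s, t):
    """Deterministic DDIM step: predict x_0, then re-noise it to level t."""
    x0 = (x - sigma[s] * eps) / alpha[s]
    return alpha[t] * x0 + sigma[t] * eps


torch.manual_seed(0)
x = torch.randn(4, dtype=torch.float64)    # stand-in for x_s
eps = torch.randn(4, dtype=torch.float64)  # stand-in for eps_theta(x_s, s)
x_dpm = dpm_solver_1(x, eps, 800, 600)
x_ddim = ddim(x, eps, 800, 600)
print(torch.allclose(x_dpm, x_ddim))  # True
```

Using `torch.expm1` rather than `torch.exp(h) - 1` keeps the coefficient accurate when `h` is small, which happens when many steps are used.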

```python
import torch
import torch.nn as nn
from typing import Optional, Tuple, List
from dataclasses import dataclass
import numpy as np
from tqdm import tqdm


@dataclass
class DPMSolverConfig:
    """Configuration for DPM-Solver."""
    num_timesteps: int = 1000
    order: int = 2  # 1, 2, or 3
    predict_type: str = "epsilon"  # "epsilon" or "v"
    thresholding: bool = False
    dynamic_threshold_ratio: float = 0.995


class DPMSolver:
    """
    DPM-Solver: Fast Solver for Diffusion Probabilistic Models.

    Implements the multistep DPM-Solver updates of order 1, 2 and 3
    (noise-prediction form) for fast, high-quality sampling.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        config: Optional[DPMSolverConfig] = None,
        device: str = "cuda"
    ):
        self.model = model
        self.device = device
        self.config = config or DPMSolverConfig()

        # VP schedule: x_t = alpha_t * x_0 + sigma_t * eps
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.alphas = torch.sqrt(self.alphas_cumprod)
        self.sigmas = torch.sqrt(1 - self.alphas_cumprod)

        # lambda_t = log(alpha_t / sigma_t), half the log-SNR; decreases as t grows
        self.lambdas = torch.log(self.alphas) - torch.log(self.sigmas)

    def get_timestep_schedule(
        self,
        num_steps: int,
        skip_type: str = "uniform"
    ) -> torch.Tensor:
        """Create a decreasing timestep schedule (high noise -> low noise)."""
        T = self.config.num_timesteps

        if skip_type == "uniform":
            timesteps = torch.linspace(T - 1, 0, num_steps + 1).round().long()
        elif skip_type == "logsnr":
            # Uniform in lambda, from the noisiest timestep (smallest lambda)
            # to the cleanest one (largest lambda)
            lambda_start = self.lambdas[-1].item()
            lambda_end = self.lambdas[0].item()
            lambdas_uniform = torch.linspace(lambda_start, lambda_end, num_steps + 1)
            timesteps = self._lambda_to_t(lambdas_uniform)
        elif skip_type == "quad":
            # Quadratic spacing (denser near t = 0, i.e. at low noise)
            timesteps = (
                (torch.linspace(0, np.sqrt(T - 1), num_steps + 1) ** 2)
                .round()
                .long()
                .flip(0)
            )
        else:
            raise ValueError(f"Unknown skip_type: {skip_type}")

        return timesteps.to(self.device)

    def _lambda_to_t(self, lambdas: torch.Tensor) -> torch.Tensor:
        """Map each lambda value to the timestep whose lambda is nearest."""
        dists = torch.abs(self.lambdas.cpu()[None, :] - lambdas[:, None])
        return torch.argmin(dists, dim=1)

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 20,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True,
        skip_type: str = "logsnr"
    ) -> torch.Tensor:
        """
        Generate samples using multistep DPM-Solver.

        Args:
            shape: Output shape (B, C, H, W)
            num_steps: Number of sampling steps (one model call each)
            x_T: Starting noise (None = random)
            progress: Show progress
            skip_type: Timestep spacing ("logsnr", "uniform" or "quad")

        Returns:
            Generated samples (in [-1, 1] if the training data was)
        """
        order = self.config.order

        # Initialize
        if x_T is None:
            x = torch.randn(shape, device=self.device)
        else:
            x = x_T.to(self.device)

        # Nearest-neighbour rounding can repeat a timestep (h = 0); drop repeats
        timesteps = torch.unique_consecutive(
            self.get_timestep_schedule(num_steps, skip_type=skip_type)
        )

        self.model.eval()

        # History buffers for the multistep updates
        model_outputs: List[torch.Tensor] = []
        output_timesteps: List[torch.Tensor] = []

        n = len(timesteps) - 1
        iterator = tqdm(range(n), desc="DPM-Solver") if progress else range(n)

        for i in iterator:
            t = timesteps[i]
            t_next = timesteps[i + 1]

            model_outputs.append(self._get_model_output(x, t))
            output_timesteps.append(t)

            # Warm-up: use the highest order the history supports
            step_order = min(order, len(model_outputs))
            if step_order == 1:
                x = self._dpm_solver_first_order_update(
                    x, t, t_next, model_outputs[-1]
                )
            elif step_order == 2:
                x = self._dpm_solver_second_order_update(
                    x, output_timesteps[-2:], t_next, model_outputs[-2:]
                )
            else:  # order == 3
                x = self._dpm_solver_third_order_update(
                    x, output_timesteps[-3:], t_next, model_outputs[-3:]
                )

            # Only the last `order` evaluations are ever needed
            if len(model_outputs) > order:
                model_outputs.pop(0)
                output_timesteps.pop(0)

        return x

    def _get_model_output(
        self,
        x: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """Return the noise prediction eps_theta(x, t)."""
        t_batch = t.expand(x.shape[0]) if t.dim() == 0 else t
        out = self.model(x, t_batch)
        alpha_t, sigma_t = self.alphas[t], self.sigmas[t]

        if self.config.predict_type == "v":
            # v = alpha * eps - sigma * x_0  =>  eps = alpha * v + sigma * x
            eps = alpha_t * out + sigma_t * x
        elif self.config.predict_type == "epsilon":
            eps = out
        else:
            raise ValueError(f"Unknown predict_type: {self.config.predict_type}")

        if self.config.thresholding:
            # Dynamic thresholding acts on the x_0 prediction, not on eps
            x0 = (x - sigma_t * eps) / alpha_t
            x0 = self._dynamic_threshold(x0)
            eps = (x - alpha_t * x0) / sigma_t

        return eps

    def _dpm_solver_first_order_update(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        t_next: torch.Tensor,
        eps: torch.Tensor
    ) -> torch.Tensor:
        """
        First-order DPM-Solver update (equivalent to DDIM).
        """
        # h > 0: lambda increases as we move toward lower noise
        h = self.lambdas[t_next] - self.lambdas[t]
        return (
            (self.alphas[t_next] / self.alphas[t]) * x
            - self.sigmas[t_next] * torch.expm1(h) * eps
        )

    def _dpm_solver_second_order_update(
        self,
        x: torch.Tensor,
        ts: List[torch.Tensor],
        t_next: torch.Tensor,
        eps_list: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Second-order multistep DPM-Solver update (DPM-Solver-2).

        Reuses the previous eps (linear extrapolation in lambda)
        instead of spending an extra model call.
        """
        t_prev, t = ts
        eps_prev, eps_curr = eps_list

        lambda_prev, lambda_t = self.lambdas[t_prev], self.lambdas[t]
        h = self.lambdas[t_next] - lambda_t
        r0 = (lambda_t - lambda_prev) / h

        # Finite-difference estimate of h * d(eps)/d(lambda)
        D1 = (eps_curr - eps_prev) / r0
        phi = torch.expm1(h)
        sigma_next = self.sigmas[t_next]

        return (
            (self.alphas[t_next] / self.alphas[t]) * x
            - sigma_next * phi * eps_curr
            - 0.5 * sigma_next * phi * D1
        )

    def _dpm_solver_third_order_update(
        self,
        x: torch.Tensor,
        ts: List[torch.Tensor],
        t_next: torch.Tensor,
        eps_list: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Third-order multistep DPM-Solver update (DPM-Solver-3).

        Uses quadratic extrapolation over the last three evaluations.
        """
        t_2, t_1, t = ts            # oldest -> newest
        eps_2, eps_1, eps_0 = eps_list

        lam_2, lam_1, lam_0 = self.lambdas[t_2], self.lambdas[t_1], self.lambdas[t]
        h = self.lambdas[t_next] - lam_0
        r0 = (lam_0 - lam_1) / h
        r1 = (lam_1 - lam_2) / h

        # First and second finite differences of eps in lambda
        D1_0 = (eps_0 - eps_1) / r0
        D1_1 = (eps_1 - eps_2) / r1
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (D1_0 - D1_1) / (r0 + r1)

        phi = torch.expm1(h)
        sigma_next = self.sigmas[t_next]

        return (
            (self.alphas[t_next] / self.alphas[t]) * x
            - sigma_next * phi * eps_0
            - sigma_next * (phi / h - 1.0) * D1
            - sigma_next * ((phi - h) / h**2 - 0.5) * D2
        )

    def _dynamic_threshold(
        self,
        x0: torch.Tensor
    ) -> torch.Tensor:
        """Dynamic thresholding from the Imagen paper, applied to an x_0 prediction."""
        s = torch.quantile(
            torch.abs(x0).reshape(x0.shape[0], -1),
            self.config.dynamic_threshold_ratio,
            dim=1
        )
        s = torch.clamp(s, min=1.0)
        s = s.reshape(-1, *([1] * (x0.dim() - 1)))
        return torch.clamp(x0, -s, s) / s


# DPM++ 2M (multistep, very popular in production)
class DPMPlusPlus2M:
    """
    DPM++ 2M: DPM-Solver++ with 2nd-order multistep method.

    Works in the Karras (VE) parameterization x = x_0 + sigma * eps,
    with sigma = sqrt((1 - alpha_bar) / alpha_bar). A VP-trained model
    is fed x / sqrt(1 + sigma^2), which is its usual input scale.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda"
    ):
        self.model = model
        self.device = device

        self.alphas_cumprod = alphas_cumprod.to(device)
        # Increasing with t: sigmas[0] is the smallest, sigmas[-1] the largest
        self.sigmas = torch.sqrt((1 - self.alphas_cumprod) / self.alphas_cumprod)

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 20,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """Generate samples using DPM++ 2M (x_T is unit-variance noise)."""
        sigmas = self._get_sigmas(num_steps)

        noise = torch.randn(shape, device=self.device) if x_T is None else x_T.to(self.device)
        # Scale unit-variance noise to the VE state at sigma_max
        x = noise * torch.sqrt(1 + sigmas[0] ** 2)

        self.model.eval()

        old_denoised = None

        n = len(sigmas) - 1
        iterator = tqdm(range(n), desc="DPM++ 2M") if progress else range(n)

        for i in iterator:
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]

            # Get denoised (x_0) prediction
            denoised = self._get_denoised(x, sigma)

            if old_denoised is None or sigma_next == 0:
                # First step or final step: first-order update
                x = (sigma_next / sigma) * x + (1 - sigma_next / sigma) * denoised
            else:
                # Second order: extrapolate with the previous denoised estimate,
                # using step sizes in lambda = -log(sigma)
                h = sigma.log() - sigma_next.log()
                h_last = sigmas[i - 1].log() - sigma.log()
                r = h_last / h

                denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
                x = (sigma_next / sigma) * x + (1 - sigma_next / sigma) * denoised_d

            old_denoised = denoised

        return x

    def _get_sigmas(self, num_steps: int) -> torch.Tensor:
        """Karras et al. schedule (num_steps levels) followed by a final 0."""
        sigma_min = self.sigmas[0].item()
        sigma_max = self.sigmas[-1].item()

        rho = 7.0  # From Karras paper
        ramp = torch.linspace(0, 1, num_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

        return torch.cat([sigmas, torch.zeros(1)]).to(self.device)

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        """Convert sigma to the nearest discrete timestep (no interpolation)."""
        return torch.argmin(torch.abs(self.sigmas - sigma))

    def _get_denoised(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor
    ) -> torch.Tensor:
        """Get denoised prediction (x_0 estimate) from an eps-prediction model."""
        t = self._sigma_to_t(sigma)
        c_in = 1 / torch.sqrt(sigma ** 2 + 1)  # VE input -> VP input scaling
        eps = self.model(x * c_in, t.expand(x.shape[0]))

        # In VE form x = x_0 + sigma * eps
        return x - sigma * eps
```

DPM++ 2M in Practice

DPM++ 2M is a popular default in production front-ends such as the Stable Diffusion WebUI, offering a good balance of speed (15-25 steps) and quality. The implementation above already uses the Karras sigma schedule; interfaces usually list this combination as "DPM++ 2M Karras", and the Karras spacing often allows fewer steps than a uniform schedule.
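The Karras schedule interpolates linearly in $\sigma^{1/\rho}$ and then raises the result back to the power $\rho$, which packs the steps toward low noise levels. Below is a standalone sketch of the same formula used in `DPMPlusPlus2M._get_sigmas`; the function name `karras_sigmas` and the `sigma_min`/`sigma_max` values are illustrative assumptions rather than values tied to any particular model, and `rho = 7` is the value the code above takes from the Karras paper.

```python
import torch


def karras_sigmas(num_steps: int, sigma_min: float, sigma_max: float,
                  rho: float = 7.0) -> torch.Tensor:
    """Noise levels spaced uniformly in sigma**(1/rho), with a final 0 appended."""
    ramp = torch.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return torch.cat([sigmas, torch.zeros(1)])


sigmas = karras_sigmas(10, sigma_min=0.03, sigma_max=80.0)
gaps = sigmas[:-2] - sigmas[1:-1]  # sizes of consecutive steps between nonzero levels
print(sigmas)
print(gaps)  # shrinking gaps: more resolution near low noise
```

Because the gaps shrink monotonically, most of the step budget is spent where fine details are resolved; increasing `rho` concentrates the steps even more strongly at low noise.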

Euler and Heun Methods

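From the ODE perspective, diffusion sampling is an initial value problem: start from noise at the largest noise level and integrate the probability flow ODE down to zero noise, so classical numerical methods like Euler and Heun apply directly. The samplers below use the Karras et al. parameterization $\mathbf{x} = \mathbf{x}_0 + \sigma \boldsymbol{\epsilon}$ with $\sigma = \sqrt{(1 - \bar{\alpha}_t)/\bar{\alpha}_t}$, in which the probability flow ODE takes the simple form $d\mathbf{x}/d\sigma = (\mathbf{x} - \hat{\mathbf{x}}_0)/\sigma = \boldsymbol{\epsilon}_\theta$, so $\sigma$ plays the role of $t$ in the update rules.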

Euler Method (First-Order)

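$$\mathbf{x}_{t-\Delta t} = \mathbf{x}_t - \Delta t \cdot f(\mathbf{x}_t, t)$$

Here $f = d\mathbf{x}/dt$ and $\Delta t > 0$ is the step size; the minus sign reflects that we integrate backward, from high noise to low noise.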
```python
from tqdm import tqdm


class EulerSampler:
    """
    Simple Euler method for diffusion sampling.

    Integrates dx/dsigma = eps_theta in the Karras (VE) parameterization
    x = x_0 + sigma * eps. This is the simplest ODE solver and provides
    a baseline for understanding more complex methods.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda"
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.device = device

        # VE noise levels sigma_t = sqrt((1 - a_t) / a_t), increasing with t
        self.sigmas = torch.sqrt((1 - self.alphas_cumprod) / self.alphas_cumprod)

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 50,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """Sample using the Euler method (x_T is unit-variance noise)."""
        sigmas = self._get_sigmas(num_steps)

        noise = torch.randn(shape, device=self.device) if x_T is None else x_T.to(self.device)
        # Scale unit-variance noise to the VE state at sigma_max
        x = noise * torch.sqrt(1 + sigmas[0] ** 2)

        self.model.eval()

        n = len(sigmas) - 1
        iterator = tqdm(range(n), desc="Euler") if progress else range(n)

        for i in iterator:
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]

            # Derivative: dx/dsigma = (x - denoised) / sigma = eps
            d = self._eps(x, sigma)

            # Euler step (dt < 0: we move from high to low noise)
            dt = sigma_next - sigma
            x = x + d * dt

        return x

    def _eps(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Noise prediction from a VP-trained model at VE state x."""
        t = self._sigma_to_t(sigma)
        c_in = 1 / torch.sqrt(sigma ** 2 + 1)  # VE input -> VP input scaling
        return self.model(x * c_in, t.expand(x.shape[0]))

    def _get_sigmas(self, num_steps: int) -> torch.Tensor:
        """Linear sigma schedule from sigma_max to sigma_min, then a final 0.

        Simple but coarse at low noise; the Karras schedule often works better.
        """
        sigma_max = self.sigmas[-1].item()
        sigma_min = self.sigmas[0].item()
        sigmas = torch.linspace(sigma_max, sigma_min, num_steps)
        return torch.cat([sigmas, torch.zeros(1)]).to(self.device)

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        """Find the nearest discrete timestep for a sigma value."""
        return torch.argmin(torch.abs(self.sigmas - sigma))
```

Heun Method (Second-Order)

Heun's method uses a predictor-corrector approach for better accuracy:

$$\text{Predict: } \mathbf{x}' = \mathbf{x}_t - \Delta t \cdot f(\mathbf{x}_t, t)$$
$$\text{Correct: } \mathbf{x}_{t-\Delta t} = \mathbf{x}_t - \frac{\Delta t}{2} \left[ f(\mathbf{x}_t, t) + f(\mathbf{x}', t - \Delta t) \right]$$
```python
from tqdm import tqdm


class HeunSampler:
    """
    Heun's method (2nd-order) for diffusion sampling.

    Uses predictor-corrector approach for higher accuracy
    at the cost of 2 function evaluations per step.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda"
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        # VE noise levels, increasing with t
        self.sigmas = torch.sqrt((1 - self.alphas_cumprod) / self.alphas_cumprod)
        self.device = device

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 30,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """Sample using Heun's method (x_T is unit-variance noise)."""
        sigmas = self._get_sigmas(num_steps)

        noise = torch.randn(shape, device=self.device) if x_T is None else x_T.to(self.device)
        # Scale unit-variance noise to the VE state at sigma_max
        x = noise * torch.sqrt(1 + sigmas[0] ** 2)

        self.model.eval()

        n = len(sigmas) - 1
        iterator = tqdm(range(n), desc="Heun") if progress else range(n)

        for i in iterator:
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]

            if sigma_next == 0:
                # Final step: plain Euler (a corrector would need the model at sigma = 0)
                x = self._euler_step(x, sigma, sigma_next)
            else:
                # Heun step (predictor-corrector)
                x = self._heun_step(x, sigma, sigma_next)

        return x

    def _euler_step(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        sigma_next: torch.Tensor
    ) -> torch.Tensor:
        """Simple Euler step: x + (sigma_next - sigma) * eps."""
        return x + self._eps(x, sigma) * (sigma_next - sigma)

    def _heun_step(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        sigma_next: torch.Tensor
    ) -> torch.Tensor:
        """Heun predictor-corrector step."""
        dt = sigma_next - sigma

        # Predictor (Euler)
        d_1 = self._eps(x, sigma)
        x_pred = x + d_1 * dt

        # Corrector: slope at the predicted point
        d_2 = self._eps(x_pred, sigma_next)

        # Average the two derivatives
        return x + dt * (d_1 + d_2) / 2

    def _eps(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Noise prediction from a VP-trained model at VE state x."""
        t = self._sigma_to_t(sigma)
        c_in = 1 / torch.sqrt(sigma ** 2 + 1)  # VE input -> VP input scaling
        return self.model(x * c_in, t.expand(x.shape[0]))

    def _get_sigmas(self, num_steps: int) -> torch.Tensor:
        """Linear sigma schedule from sigma_max to sigma_min, then a final 0."""
        sigma_max = self.sigmas[-1].item()
        sigma_min = self.sigmas[0].item()
        sigmas = torch.linspace(sigma_max, sigma_min, num_steps)
        return torch.cat([sigmas, torch.zeros(1)]).to(self.device)

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        return torch.argmin(torch.abs(self.sigmas - sigma))
```

NFE vs Steps

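The Number of Function Evaluations (NFE), i.e. the number of model calls, is a fairer basis than step count for comparing samplers. Heun uses 2 NFE per step, so 25 Heun steps cost about 50 NFE (slightly fewer when the final step falls back to Euler, as in the implementation above) and should be compared against 50 Euler steps (50 NFE).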
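To see why NFE is the fair yardstick, the sketch below compares Euler and Heun at an equal budget of 40 NFE on a toy problem whose exact answer is known. It assumes hypothetical 1-D Gaussian data $\mathcal{N}(0, s^2)$, for which the ideal denoiser is $D(x;\sigma) = x\,s^2/(s^2+\sigma^2)$ and the probability flow ODE $dx/d\sigma = (x - D)/\sigma$ has the closed-form solution $x(\sigma) \propto \sqrt{s^2 + \sigma^2}$. No neural network is involved, and the schedule is a plain uniform grid in $\sigma$.

```python
import math

S2 = 1.0  # variance s^2 of the hypothetical 1-D data distribution N(0, s^2)


def d(x: float, sigma: float) -> float:
    """Exact slope dx/dsigma = (x - D(x; sigma)) / sigma with D = x * s^2 / (s^2 + sigma^2)."""
    return x * sigma / (S2 + sigma ** 2)


def euler(x: float, sigmas: list) -> float:
    for s, s_next in zip(sigmas[:-1], sigmas[1:]):
        x = x + (s_next - s) * d(x, s)  # 1 NFE per step
    return x


def heun(x: float, sigmas: list) -> float:
    for s, s_next in zip(sigmas[:-1], sigmas[1:]):
        d1 = d(x, s)
        x_pred = x + (s_next - s) * d1         # predictor (Euler)
        d2 = d(x_pred, s_next)                 # slope is well defined at sigma = 0 here
        x = x + (s_next - s) * 0.5 * (d1 + d2)  # 2 NFE per step
    return x


def linspace(a: float, b: float, n: int) -> list:
    return [a + (b - a) * i / (n - 1) for i in range(n)]


sigma_max, x_start = 10.0, 1.0
exact = x_start * math.sqrt(S2 / (S2 + sigma_max ** 2))  # closed-form x at sigma = 0

nfe = 40
x_euler = euler(x_start, linspace(sigma_max, 0.0, nfe + 1))     # 40 steps
x_heun = heun(x_start, linspace(sigma_max, 0.0, nfe // 2 + 1))  # 20 steps
print(f"exact={exact:.5f} euler_err={abs(x_euler - exact):.5f} heun_err={abs(x_heun - exact):.5f}")
```

On this smooth toy problem the second-order method should land closer to the exact value at the same NFE; with a learned denoiser and a different schedule the size of the gap will vary.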

Ancestral Sampling Variants

While ODE samplers are deterministic, sometimes we want the diversity that comes from stochastic sampling. Here are the main ancestral (SDE) variants:

```python
from tqdm import tqdm


class EulerAncestralSampler:
    """
    Euler Ancestral: Euler method with noise injection.

    Provides more diversity than deterministic Euler at the
    cost of requiring more steps for quality.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda",
        eta: float = 1.0  # Noise scale (0 = deterministic Euler)
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        # VE noise levels, increasing with t
        self.sigmas = torch.sqrt((1 - self.alphas_cumprod) / self.alphas_cumprod)
        self.device = device
        self.eta = eta

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 50,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """Sample with ancestral noise injection (x_T is unit-variance noise)."""
        sigmas = self._get_sigmas(num_steps)

        noise = torch.randn(shape, device=self.device) if x_T is None else x_T.to(self.device)
        # Scale unit-variance noise to the VE state at sigma_max
        x = noise * torch.sqrt(1 + sigmas[0] ** 2)

        self.model.eval()

        n = len(sigmas) - 1
        iterator = tqdm(range(n), desc="Euler-a") if progress else range(n)

        for i in iterator:
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]

            # Get noise prediction (= dx/dsigma)
            eps = self._eps(x, sigma)

            # Split the step: deterministic move to sigma_down, then fresh noise
            # of std sigma_up, so that sigma_down^2 + sigma_up^2 = sigma_next^2
            sigma_up = torch.minimum(
                sigma_next,
                self.eta * (sigma_next / sigma) * torch.sqrt(sigma**2 - sigma_next**2)
            )
            sigma_down = torch.sqrt(torch.clamp(sigma_next**2 - sigma_up**2, min=0))

            # Deterministic Euler step
            x = x + eps * (sigma_down - sigma)

            # Add ancestral noise
            if sigma_next > 0:
                x = x + torch.randn_like(x) * sigma_up

        return x

    def _eps(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Noise prediction from a VP-trained model at VE state x."""
        t = self._sigma_to_t(sigma)
        c_in = 1 / torch.sqrt(sigma ** 2 + 1)
        return self.model(x * c_in, t.expand(x.shape[0]))

    def _get_sigmas(self, num_steps: int) -> torch.Tensor:
        """Linear sigma schedule from sigma_max to sigma_min, then a final 0."""
        sigma_max = self.sigmas[-1].item()
        sigma_min = self.sigmas[0].item()
        sigmas = torch.linspace(sigma_max, sigma_min, num_steps)
        return torch.cat([sigmas, torch.zeros(1)]).to(self.device)

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        return torch.argmin(torch.abs(self.sigmas - sigma))


class DPMPlusPlusSDE:
    """
    DPM++ SDE: Stochastic version of DPM++ with noise injection.

    Follows the multistep "DPM++ 2M SDE" update (midpoint variant):
    eta = 0 recovers deterministic DPM++ 2M, eta = 1 injects the full
    amount of fresh noise at every step. Popular for creative
    applications where diversity matters.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda",
        eta: float = 1.0,
        s_noise: float = 1.0
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.sigmas = torch.sqrt((1 - self.alphas_cumprod) / self.alphas_cumprod)
        self.device = device
        self.eta = eta
        self.s_noise = s_noise

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 25,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """Sample with DPM++ 2M SDE (x_T is unit-variance noise)."""
        sigmas = self._get_sigmas(num_steps)

        noise = torch.randn(shape, device=self.device) if x_T is None else x_T.to(self.device)
        x = noise * torch.sqrt(1 + sigmas[0] ** 2)

        self.model.eval()

        old_denoised, h_last = None, None

        n = len(sigmas) - 1
        iterator = tqdm(range(n), desc="DPM++ SDE") if progress else range(n)

        for i in iterator:
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]

            denoised = self._get_denoised(x, sigma)

            if sigma_next == 0:
                # Final step: jump straight to the denoised estimate
                x = denoised
            else:
                # Step size in lambda = -log(sigma); eta_h is the noise-driven part
                h = sigma.log() - sigma_next.log()
                eta_h = self.eta * h

                # First-order exponential-integrator step (shrinks x by extra e^{-eta*h})
                x = (sigma_next / sigma) * torch.exp(-eta_h) * x \
                    - torch.expm1(-h - eta_h) * denoised

                if old_denoised is not None:
                    # Second-order (midpoint) correction from the previous estimate
                    r = h_last / h
                    x = x - 0.5 * torch.expm1(-h - eta_h) * (1 / r) * (denoised - old_denoised)

                if self.eta > 0:
                    # Re-inject exactly the variance removed by the e^{-eta*h} factor
                    noise_std = sigma_next * torch.sqrt(-torch.expm1(-2 * eta_h))
                    x = x + torch.randn_like(x) * noise_std * self.s_noise

                h_last = h

            old_denoised = denoised

        return x

    def _get_denoised(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """x_0 estimate from a VP eps-prediction model at VE state x."""
        t = self._sigma_to_t(sigma)
        c_in = 1 / torch.sqrt(sigma ** 2 + 1)
        eps = self.model(x * c_in, t.expand(x.shape[0]))
        return x - sigma * eps

    def _get_sigmas(self, num_steps: int) -> torch.Tensor:
        # Karras schedule (num_steps levels) followed by a final 0
        sigma_min = self.sigmas[0].item()
        sigma_max = self.sigmas[-1].item()
        rho = 7.0
        ramp = torch.linspace(0, 1, num_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return torch.cat([sigmas, torch.zeros(1)]).to(self.device)

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        return torch.argmin(torch.abs(self.sigmas - sigma))
```
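
| Sampler | Deterministic | Stochastic Diversity | Recommended Steps |
|---------|---------------|----------------------|-------------------|
| Euler | Yes | None | 40-50 |
| Euler Ancestral | No | High | 50-80 |
| DPM++ 2M | Yes | None | 15-25 |
| DPM++ SDE | No | Medium | 20-35 |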
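The noise bookkeeping behind the ancestral samplers can be checked in isolation. This minimal sketch reuses the `EulerAncestralSampler` formulas; the function name `ancestral_step` is ours, not a library API. It splits a step from $\sigma$ to $\sigma_{\text{next}}$ into a deterministic move to $\sigma_{\text{down}}$ plus fresh noise of standard deviation $\sigma_{\text{up}}$, and the two always recombine to the target level, $\sigma_{\text{down}}^2 + \sigma_{\text{up}}^2 = \sigma_{\text{next}}^2$.

```python
import math


def ancestral_step(sigma: float, sigma_next: float, eta: float = 1.0):
    """Return (sigma_down, sigma_up) for one ancestral step sigma -> sigma_next."""
    sigma_up = min(sigma_next,
                   eta * (sigma_next / sigma) * math.sqrt(sigma ** 2 - sigma_next ** 2))
    sigma_down = math.sqrt(sigma_next ** 2 - sigma_up ** 2)
    return sigma_down, sigma_up


for eta in (0.0, 0.5, 1.0):
    down, up = ancestral_step(2.0, 1.0, eta)
    # Variances recombine to the target level: down^2 + up^2 == sigma_next^2
    print(eta, round(down, 4), round(up, 4), round(down ** 2 + up ** 2, 6))
```

With `eta = 0` no noise is added and the step is plain Euler; with `eta = 1` the deterministic part lands at $\sigma_{\text{down}} = \sigma_{\text{next}}^2 / \sigma$, so a larger share of each step is replaced by fresh noise, which is where the extra diversity comes from.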

Unified Sampler Framework

Let's create a unified interface for all samplers:

```python
from enum import Enum
from typing import Union

# DDIMSampler and DDIMConfig are assumed to be the classes from the DDIM section.


class SamplerType(Enum):
    DDPM = "ddpm"
    DDIM = "ddim"
    EULER = "euler"
    EULER_ANCESTRAL = "euler_a"
    HEUN = "heun"
    DPM_SOLVER = "dpm_solver"
    DPM_2M = "dpm_pp_2m"
    DPM_SDE = "dpm_pp_sde"


class UnifiedSampler:
    """
    Unified interface for all diffusion samplers.

    Provides a single entry point with consistent API
    regardless of underlying sampler method.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda"
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.device = device

        # Samplers are created lazily on first use and then cached
        self._samplers = {}

    def _get_sampler(self, sampler_type: SamplerType):
        """Get or create sampler instance."""
        if sampler_type not in self._samplers:
            if sampler_type == SamplerType.DDIM:
                self._samplers[sampler_type] = DDIMSampler(
                    self.model, self.alphas_cumprod,
                    config=DDIMConfig(), device=self.device
                )
            elif sampler_type == SamplerType.EULER:
                self._samplers[sampler_type] = EulerSampler(
                    self.model, self.alphas_cumprod, self.device
                )
            elif sampler_type == SamplerType.EULER_ANCESTRAL:
                self._samplers[sampler_type] = EulerAncestralSampler(
                    self.model, self.alphas_cumprod, self.device
                )
            elif sampler_type == SamplerType.HEUN:
                self._samplers[sampler_type] = HeunSampler(
                    self.model, self.alphas_cumprod, self.device
                )
            elif sampler_type == SamplerType.DPM_SOLVER:
                self._samplers[sampler_type] = DPMSolver(
                    self.model, self.alphas_cumprod,
                    config=DPMSolverConfig(), device=self.device
                )
            elif sampler_type == SamplerType.DPM_2M:
                self._samplers[sampler_type] = DPMPlusPlus2M(
                    self.model, self.alphas_cumprod, self.device
                )
            elif sampler_type == SamplerType.DPM_SDE:
                self._samplers[sampler_type] = DPMPlusPlusSDE(
                    self.model, self.alphas_cumprod, self.device
                )
            else:
                # e.g. DDPM has no entry in this registry
                raise NotImplementedError(
                    f"No sampler registered for '{sampler_type.value}'"
                )

        return self._samplers[sampler_type]

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        sampler_type: Union[SamplerType, str] = SamplerType.DPM_2M,
        num_steps: int = 20,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True,
        **kwargs
    ) -> torch.Tensor:
        """
        Generate samples with specified sampler.

        Args:
            shape: Output shape (B, C, H, W)
            sampler_type: Which sampler to use (enum or its string value)
            num_steps: Number of sampling steps
            x_T: Starting noise
            progress: Show progress bar
            **kwargs: Extra arguments forwarded to the chosen sampler's
                sample() method (it must accept them)

        Returns:
            Generated samples
        """
        if isinstance(sampler_type, str):
            sampler_type = SamplerType(sampler_type)

        sampler = self._get_sampler(sampler_type)

        return sampler.sample(
            shape=shape,
            num_steps=num_steps,
            x_T=x_T,
            progress=progress,
            **kwargs
        )

    def get_recommended_steps(self, sampler_type: Union[SamplerType, str]) -> int:
        """Get recommended step count for sampler."""
        if isinstance(sampler_type, str):
            sampler_type = SamplerType(sampler_type)
        recommendations = {
            SamplerType.DDPM: 1000,
            SamplerType.DDIM: 50,
            SamplerType.EULER: 50,
            SamplerType.EULER_ANCESTRAL: 60,
            SamplerType.HEUN: 30,
            SamplerType.DPM_SOLVER: 20,
            SamplerType.DPM_2M: 20,
            SamplerType.DPM_SDE: 25,
        }
        return recommendations.get(sampler_type, 50)


# Usage example
def demonstrate_unified_sampler(model, noise_schedule):
    """Compare several samplers starting from the same noise."""
    sampler = UnifiedSampler(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod
    )

    # Fix initial noise so every sampler starts from the same draw
    torch.manual_seed(42)
    x_T = torch.randn(4, 3, 64, 64, device=sampler.device)

    results = {}

    for sampler_type in [
        SamplerType.DDIM,
        SamplerType.EULER,
        SamplerType.HEUN,
        SamplerType.DPM_2M,
    ]:
        steps = sampler.get_recommended_steps(sampler_type)
        samples = sampler.sample(
            shape=(4, 3, 64, 64),
            sampler_type=sampler_type,
            num_steps=steps,
            x_T=x_T.clone(),
            progress=False
        )
        results[sampler_type.value] = {
            "samples": samples,
            "steps": steps
        }
        print(f"{sampler_type.value}: {steps} steps")

    return results
```

Summary

We've covered the landscape of modern diffusion samplers:

  1. DPM-Solver family: Uses log-SNR space for efficient higher-order solving, achieving excellent quality in 10-25 steps
  2. Euler/Heun methods: Classical ODE solvers adapted for diffusion, providing intuitive baselines
  3. Ancestral variants: Inject noise for diversity at the cost of requiring more steps
  4. Unified framework: Consistent API for experimenting with different samplers

| Use Case | Recommended Sampler | Steps | Why |
|----------|---------------------|-------|-----|
| Fast production | DPM++ 2M Karras | 20-25 | Best speed/quality |
| Maximum quality | DPM++ 2M | 50 | Extra refinement |
| Creative diversity | DPM++ SDE | 25-35 | Stochastic variety |
| Debugging/learning | Euler | 50 | Simple to understand |
| Baseline comparison | DDIM | 50 | Standard reference |

Coming Up Next

In the final section of this chapter, we'll provide practical guidance for choosing the right sampler for your specific use case, including benchmarks, quality comparisons, and decision flowcharts.

The sampler choice significantly impacts both generation speed and output quality. Understanding these methods deeply allows you to make informed decisions for your specific application requirements.