Learning Objectives
By the end of this section, you will:
- Understand the sampler landscape and how different methods relate
- Implement DPM-Solver for ultra-fast sampling (10-20 steps)
- Master Euler and Heun methods from the ODE perspective
- Compare ancestral sampling variants (DDPM, DPM++ SDE)
- Build a unified sampler framework supporting multiple methods
Beyond DDIM
The Sampler Landscape
Modern diffusion samplers can be organized along two axes: ODE vs. SDE (deterministic vs. stochastic) and solver order (first-order vs. higher-order).
| Sampler | Type | Order | Min Steps | Best Use Case |
|---|---|---|---|---|
| DDPM | SDE | 1st | 200+ | Maximum diversity |
| DDIM | ODE/SDE | 1st | 25-50 | Fast, deterministic |
| Euler | ODE | 1st | 25-50 | Simple, stable |
| Heun | ODE | 2nd | 15-25 | Better accuracy |
| DPM-Solver | ODE | 1st-3rd | 10-25 | Ultra-fast quality |
| DPM++ 2M | ODE | 2nd | 15-25 | Production standard |
| DPM++ SDE | SDE | 2nd | 20-35 | High diversity |
| UniPC | ODE | 3rd | 10-20 | State-of-the-art |
ODE vs SDE Perspective
Diffusion models can be viewed through two equivalent lenses: a reverse-time SDE, which injects fresh noise at every step, and the probability flow ODE, which produces the same marginal distributions but follows deterministic trajectories.
Key Insight
Because the probability flow ODE is deterministic and its trajectories are smooth, a solver can take much larger steps than a stochastic sampler. Higher-order methods exploit exactly this smoothness, which is why they reach good quality in 10-25 steps.
DPM-Solver
DPM-Solver reformulates diffusion sampling as solving an ODE in the log-SNR space, enabling efficient higher-order solvers. It achieves excellent quality with just 10-20 steps.
The Key Idea: Change of Variables
Instead of working in timestep $t$, DPM-Solver uses the half log-SNR $\lambda_t = \log(\alpha_t / \sigma_t) = \tfrac{1}{2}\log\big(\bar\alpha_t / (1 - \bar\alpha_t)\big)$ as the integration variable, where $\alpha_t = \sqrt{\bar\alpha_t}$ and $\sigma_t = \sqrt{1 - \bar\alpha_t}$.
This formulation has smoother dynamics, making higher-order ODE solvers more effective.
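Concretely, the change of variables matters because the probability flow ODE then admits an exact solution via an exponential integrator (the key identity behind DPM-Solver, in the notation above):

$$x_t \;=\; \frac{\alpha_t}{\alpha_s}\,x_s \;-\; \alpha_t \int_{\lambda_s}^{\lambda_t} e^{-\lambda}\,\hat\epsilon_\theta(x_\lambda, \lambda)\,d\lambda$$

The linear part is solved exactly; only the integral over the network's noise prediction needs numerical approximation. DPM-Solver-$k$ approximates it with a $k$-th-order Taylor expansion of $\hat\epsilon_\theta$ in $\lambda$.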
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
from dataclasses import dataclass
import numpy as np


@dataclass
class DPMSolverConfig:
    """Configuration for DPM-Solver."""
    num_timesteps: int = 1000
    order: int = 2  # 1, 2, or 3
    predict_type: str = "epsilon"  # "epsilon" or "v"
    thresholding: bool = False
    dynamic_threshold_ratio: float = 0.995


class DPMSolver:
    """
    DPM-Solver: Fast Solver for Diffusion Probabilistic Models.

    Implements DPM-Solver-1, DPM-Solver-2, and DPM-Solver-3 for
    ultra-fast high-quality sampling.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        config: Optional[DPMSolverConfig] = None,
        device: str = "cuda"
    ):
        self.model = model
        self.device = device
        self.config = config or DPMSolverConfig()

        # Store schedule
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.sigmas = torch.sqrt((1 - alphas_cumprod) / alphas_cumprod).to(device)

        # Half log-SNR: lambda_t = log(alpha_t / sigma_t)
        self.lambdas = torch.log(self.alphas_cumprod / (1 - self.alphas_cumprod)) / 2

    def get_timestep_schedule(
        self,
        num_steps: int,
        skip_type: str = "uniform"
    ) -> torch.Tensor:
        """Create timestep schedule for sampling (from t = T-1 down to t = 0)."""
        T = self.config.num_timesteps

        if skip_type == "uniform":
            timesteps = torch.linspace(T - 1, 0, num_steps + 1).long()
        elif skip_type == "logsnr":
            # Uniform in log-SNR space: every step covers the same h.
            # lambda is small at t = T-1 and large at t = 0, so the
            # schedule runs from lambda_min up to lambda_max.
            lambda_min, lambda_max = self.lambdas[-1], self.lambdas[0]
            lambdas_uniform = torch.linspace(
                lambda_min.item(), lambda_max.item(), num_steps + 1
            )
            timesteps = self._lambda_to_t(lambdas_uniform)
        elif skip_type == "quad":
            # Quadratic spacing (more steps at low noise)
            timesteps = (
                (torch.linspace(0, np.sqrt(T - 1), num_steps + 1) ** 2)
                .long()
                .flip(0)
            )
        else:
            raise ValueError(f"Unknown skip_type: {skip_type}")

        return timesteps.to(self.device)

    def _lambda_to_t(self, lambdas: torch.Tensor) -> torch.Tensor:
        """Convert lambda values to the nearest discrete timesteps."""
        timesteps = []
        for lam in lambdas:
            idx = torch.argmin(torch.abs(self.lambdas - lam))
            timesteps.append(idx)
        return torch.stack(timesteps)

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 20,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """
        Generate samples using DPM-Solver.

        Args:
            shape: Output shape (B, C, H, W)
            num_steps: Number of sampling steps
            x_T: Starting noise (None = random)
            progress: Show progress bar

        Returns:
            Generated samples in [-1, 1]
        """
        order = self.config.order

        # Initialize
        if x_T is None:
            x = torch.randn(shape, device=self.device)
        else:
            x = x_T.to(self.device)

        # Uniform in log-SNR, so the step size h is constant
        timesteps = self.get_timestep_schedule(num_steps, skip_type="logsnr")

        self.model.eval()

        # Buffer of recent noise predictions for multistep methods
        model_outputs: List[torch.Tensor] = []

        from tqdm import tqdm
        iterator = tqdm(range(num_steps), desc="DPM-Solver") if progress else range(num_steps)

        for i in iterator:
            t = timesteps[i]
            t_next = timesteps[i + 1]

            # Get model output
            model_output = self._get_model_output(x, t)
            model_outputs.append(model_output)

            # Warm up with lower-order steps until enough history exists
            if order == 1 or i == 0:
                x = self._dpm_solver_first_order_update(x, t, t_next, model_output)
            elif order == 2 or i == 1:
                x = self._dpm_solver_second_order_update(x, t, t_next, model_outputs[-2:])
            else:  # order == 3
                x = self._dpm_solver_third_order_update(x, t, t_next, model_outputs[-3:])

            # Keep only the last few outputs
            if len(model_outputs) > 3:
                model_outputs.pop(0)

        return x

    def _get_model_output(
        self,
        x: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """Get noise prediction from the model."""
        t_batch = t.expand(x.shape[0]) if t.dim() == 0 else t
        eps = self.model(x, t_batch)

        if self.config.thresholding:
            # Dynamic thresholding operates on the x_0 estimate, not on eps,
            # so convert, threshold, and convert back
            alpha_t = torch.sqrt(self.alphas_cumprod[t])
            sigma_t = torch.sqrt(1 - self.alphas_cumprod[t])
            x0 = (x - sigma_t * eps) / alpha_t
            x0 = self._dynamic_threshold(x0)
            eps = (x - alpha_t * x0) / sigma_t

        return eps

    def _dpm_solver_first_order_update(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        t_next: torch.Tensor,
        eps: torch.Tensor
    ) -> torch.Tensor:
        """First-order DPM-Solver update (algebraically identical to DDIM)."""
        alpha_t = torch.sqrt(self.alphas_cumprod[t])
        alpha_next = torch.sqrt(self.alphas_cumprod[t_next])
        sigma_next = torch.sqrt(1 - self.alphas_cumprod[t_next])

        # Step size in log-SNR space (positive: SNR grows as t decreases)
        h = self.lambdas[t_next] - self.lambdas[t]

        # Exact exponential-integrator step for a constant eps
        return (alpha_next / alpha_t) * x - sigma_next * torch.expm1(h) * eps

    def _dpm_solver_second_order_update(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        t_next: torch.Tensor,
        eps_list: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Second-order multistep DPM-Solver update (DPM-Solver-2).

        Linearly extrapolates the two most recent noise predictions.
        """
        eps_prev, eps_curr = eps_list

        alpha_t = torch.sqrt(self.alphas_cumprod[t])
        alpha_next = torch.sqrt(self.alphas_cumprod[t_next])
        sigma_next = torch.sqrt(1 - self.alphas_cumprod[t_next])

        h = self.lambdas[t_next] - self.lambdas[t]

        # With a log-SNR-uniform schedule every step has the same h,
        # so the step-size ratio r = h_prev / h is exactly 1
        r = 1.0
        D1 = (eps_curr - eps_prev) / r

        # First-order part
        x_next = (alpha_next / alpha_t) * x - sigma_next * torch.expm1(h) * eps_curr

        # Second-order correction
        x_next = x_next - 0.5 * sigma_next * torch.expm1(h) * D1

        return x_next

    def _dpm_solver_third_order_update(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        t_next: torch.Tensor,
        eps_list: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Third-order update (DPM-Solver-3) would add a quadratic
        extrapolation term. For brevity this falls back to the
        second-order update on the two most recent predictions.
        """
        eps_0, eps_1, eps_2 = eps_list
        return self._dpm_solver_second_order_update(x, t, t_next, [eps_1, eps_2])

    def _dynamic_threshold(
        self,
        x0: torch.Tensor
    ) -> torch.Tensor:
        """Dynamic thresholding of the x_0 estimate (from the Imagen paper)."""
        s = torch.quantile(
            torch.abs(x0).reshape(x0.shape[0], -1),
            self.config.dynamic_threshold_ratio,
            dim=1
        )
        s = torch.clamp(s, min=1.0).reshape(-1, 1, 1, 1)
        return torch.clamp(x0, -s, s) / s
```
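As a sanity check, the first-order update above (with $h = \lambda_{\text{next}} - \lambda_t$ in the `expm1` form) is algebraically identical to a deterministic DDIM step. A quick numerical confirmation on a toy schedule (the `alphas_cumprod` values here are illustrative, not a real noise schedule):

```python
import torch

# Toy decreasing alphas_cumprod (illustrative values only)
ac = torch.linspace(0.99, 0.01, 10)
t, t_next = 5, 2  # sampling moves toward t = 0

alpha_t, alpha_n = ac[t].sqrt(), ac[t_next].sqrt()
sigma_t, sigma_n = (1 - ac[t]).sqrt(), (1 - ac[t_next]).sqrt()

lam = torch.log(ac / (1 - ac)) / 2  # half log-SNR
h = lam[t_next] - lam[t]

x = torch.randn(4)
eps = torch.randn(4)

# DPM-Solver-1 update in log-SNR space
x_dpm = (alpha_n / alpha_t) * x - sigma_n * torch.expm1(h) * eps

# Deterministic DDIM update (eta = 0)
x0_pred = (x - sigma_t * eps) / alpha_t
x_ddim = alpha_n * x0_pred + sigma_n * eps

assert torch.allclose(x_dpm, x_ddim, atol=1e-5)
```

The two coefficients on `eps` agree because $\sigma_{\text{next}} e^{h} = \alpha_{\text{next}} \sigma_t / \alpha_t$; DPM-Solver's gains come from the higher-order terms, not from the first-order step itself.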
```python
# DPM++ 2M (multistep, very popular in production)
class DPMPlusPlus2M:
    """
    DPM++ 2M: DPM-Solver++ with a 2nd-order multistep method.

    One of the most popular samplers in production systems
    such as Stable Diffusion.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda"
    ):
        self.model = model
        self.device = device

        self.alphas_cumprod = alphas_cumprod.to(device)
        self.sigmas = torch.sqrt((1 - alphas_cumprod) / alphas_cumprod).to(device)

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 20,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """Generate samples using DPM++ 2M."""
        # Get sigma schedule
        sigmas = self._get_sigmas(num_steps)

        # Work in the VE parameterization x = x_0 + sigma * eps,
        # so unit-variance starting noise is scaled by sigma_max
        if x_T is None:
            x = torch.randn(shape, device=self.device) * sigmas[0]
        else:
            x = x_T.to(self.device) * sigmas[0]

        self.model.eval()

        old_denoised = None

        from tqdm import tqdm
        iterator = tqdm(range(len(sigmas) - 1), desc="DPM++ 2M") if progress else range(len(sigmas) - 1)

        for i in iterator:
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]

            # Compute timestep from sigma
            t = self._sigma_to_t(sigma)

            # Get denoised (x_0) prediction
            denoised = self._get_denoised(x, t, sigma)

            if old_denoised is None or sigma_next == 0:
                # First step or final step: first-order update
                x = (sigma_next / sigma) * x + (1 - sigma_next / sigma) * denoised
            else:
                # DPM++ uses t = -log(sigma) as its internal "time"
                t_curr = sigma.log().neg()
                t_next = sigma_next.log().neg()
                h = t_next - t_curr
                h_last = t_curr - sigmas[i - 1].log().neg()
                r = h_last / h

                # Second order: extrapolate the two most recent denoised predictions
                denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
                x = (sigma_next / sigma) * x + (1 - sigma_next / sigma) * denoised_d

            old_denoised = denoised

        return x

    def _get_sigmas(self, num_steps: int) -> torch.Tensor:
        """Karras sigma schedule (popular in practice), plus a final sigma = 0."""
        sigma_min = self.sigmas[-1]
        sigma_max = self.sigmas[0]

        rho = 7.0  # From the Karras et al. paper
        ramp = torch.linspace(0, 1, num_steps, device=self.device)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

        return torch.cat([sigmas, torch.zeros(1, device=self.device)])

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        """Nearest discrete timestep for a sigma value."""
        return torch.argmin(torch.abs(self.sigmas - sigma))

    def _get_denoised(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        sigma: torch.Tensor
    ) -> torch.Tensor:
        """Denoised (x_0) prediction from the eps model."""
        t_batch = t.expand(x.shape[0])

        # The model was trained in the VP parameterization, so rescale
        # the VE-space input: x_vp = x / sqrt(1 + sigma^2)
        eps = self.model(x / torch.sqrt(1 + sigma ** 2), t_batch)

        # In VE space: x = x_0 + sigma * eps
        return x - sigma * eps
```

DPM++ 2M in Practice
At 20-25 steps with the Karras schedule, DPM++ 2M is the default choice in many production pipelines; because it is deterministic, the same seed reproduces the same image.
Euler and Heun Methods
From the ODE perspective, diffusion sampling is just solving an initial value problem. Classical numerical methods like Euler and Heun can be applied directly.
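In the variance-exploding ($\sigma$) parameterization used by the samplers below, where $x = x_0 + \sigma\,\epsilon$, the probability flow ODE takes a particularly simple form (this is the Karras et al. formulation):

$$\frac{dx}{d\sigma} \;=\; \frac{x - D_\theta(x;\sigma)}{\sigma} \;=\; \hat\epsilon_\theta(x;\sigma)$$

Here $D_\theta$ is the denoised ($x_0$) prediction, so the derivative is simply the model's noise estimate, and an Euler step is $x \leftarrow x + (\sigma_{i+1} - \sigma_i)\,\hat\epsilon_\theta$.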
Euler Method (First-Order)
```python
class EulerSampler:
    """
    Simple Euler method for diffusion sampling.

    The simplest ODE solver, and a baseline for understanding
    the more complex methods.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda"
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.device = device

        # VE-space sigmas
        self.sigmas = torch.sqrt((1 - alphas_cumprod) / alphas_cumprod).to(device)

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 50,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """Sample using the Euler method."""
        sigmas = self._get_sigmas(num_steps)

        # VE parameterization: scale unit noise by sigma_max
        if x_T is None:
            x = torch.randn(shape, device=self.device) * sigmas[0]
        else:
            x = x_T.to(self.device) * sigmas[0]

        self.model.eval()

        from tqdm import tqdm
        iterator = tqdm(range(len(sigmas) - 1), desc="Euler") if progress else range(len(sigmas) - 1)

        for i in iterator:
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]

            t = self._sigma_to_t(sigma)
            t_batch = t.expand(x.shape[0])

            # Noise prediction (rescale the VE input for the VP-trained model)
            eps = self.model(x / torch.sqrt(1 + sigma ** 2), t_batch)

            # Derivative: dx/dsigma = (x - denoised) / sigma = eps
            d = eps

            # Euler step
            dt = sigma_next - sigma
            x = x + d * dt

        return x

    def _get_sigmas(self, num_steps: int) -> torch.Tensor:
        """Linear sigma schedule, ending at sigma = 0."""
        sigma_max = self.sigmas[0].item()
        sigma_min = self.sigmas[-1].item()
        sigmas = torch.linspace(sigma_max, sigma_min, num_steps, device=self.device)
        return torch.cat([sigmas, torch.zeros(1, device=self.device)])

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        """Find the nearest timestep for a sigma value."""
        return torch.argmin(torch.abs(self.sigmas - sigma))
```

Heun Method (Second-Order)
Heun's method uses a predictor-corrector approach for better accuracy:
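Concretely, a Heun step first takes an Euler (predictor) step, then re-evaluates the derivative at the predicted point and averages the two slopes:

$$\tilde{x} = x_i + (\sigma_{i+1} - \sigma_i)\, d(x_i, \sigma_i), \qquad x_{i+1} = x_i + \frac{\sigma_{i+1} - \sigma_i}{2}\,\Big(d(x_i, \sigma_i) + d(\tilde{x}, \sigma_{i+1})\Big)$$

with $d(x, \sigma) = (x - D_\theta(x;\sigma))/\sigma$ as before.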
```python
class HeunSampler:
    """
    Heun's method (2nd-order) for diffusion sampling.

    A predictor-corrector scheme: higher accuracy per step at the
    cost of two model evaluations per step.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda"
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.sigmas = torch.sqrt((1 - alphas_cumprod) / alphas_cumprod).to(device)
        self.device = device

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 30,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """Sample using Heun's method."""
        sigmas = self._get_sigmas(num_steps)

        # VE parameterization: scale unit noise by sigma_max
        if x_T is None:
            x = torch.randn(shape, device=self.device) * sigmas[0]
        else:
            x = x_T.to(self.device) * sigmas[0]

        self.model.eval()

        from tqdm import tqdm
        iterator = tqdm(range(len(sigmas) - 1), desc="Heun") if progress else range(len(sigmas) - 1)

        for i in iterator:
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]

            if sigma_next == 0:
                # Final step: plain Euler (no point correcting at sigma = 0)
                x = self._euler_step(x, sigma, sigma_next)
            else:
                # Heun step (predictor-corrector)
                x = self._heun_step(x, sigma, sigma_next)

        return x

    def _eps(self, x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        """Noise prediction (rescale the VE input for the VP-trained model)."""
        t = self._sigma_to_t(sigma)
        t_batch = t.expand(x.shape[0])
        return self.model(x / torch.sqrt(1 + sigma ** 2), t_batch)

    def _euler_step(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        sigma_next: torch.Tensor
    ) -> torch.Tensor:
        """Simple Euler step."""
        d = self._eps(x, sigma)
        dt = sigma_next - sigma
        return x + d * dt

    def _heun_step(
        self,
        x: torch.Tensor,
        sigma: torch.Tensor,
        sigma_next: torch.Tensor
    ) -> torch.Tensor:
        """Heun predictor-corrector step."""
        # Predictor (Euler)
        d_1 = self._eps(x, sigma)
        dt = sigma_next - sigma
        x_pred = x + d_1 * dt

        # Corrector: re-evaluate the derivative at the predicted point
        d_2 = self._eps(x_pred, sigma_next)

        # Average the two derivatives
        return x + dt * (d_1 + d_2) / 2

    def _get_sigmas(self, num_steps: int) -> torch.Tensor:
        """Linear sigma schedule, ending at sigma = 0."""
        sigma_max = self.sigmas[0].item()
        sigma_min = self.sigmas[-1].item()
        sigmas = torch.linspace(sigma_max, sigma_min, num_steps, device=self.device)
        return torch.cat([sigmas, torch.zeros(1, device=self.device)])

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        return torch.argmin(torch.abs(self.sigmas - sigma))
```

NFE vs Steps
Heun performs two model evaluations per step, so compare samplers by number of function evaluations (NFE) rather than raw step count: 30 Heun steps cost roughly as much as 60 Euler steps.
Ancestral Sampling Variants
While ODE samplers are deterministic, sometimes we want the diversity that comes from stochastic sampling. Here are the main ancestral (SDE) variants:
```python
class EulerAncestralSampler:
    """
    Euler Ancestral: Euler method with noise injection.

    More diversity than deterministic Euler, at the cost of
    requiring more steps for the same quality.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda",
        eta: float = 1.0  # Noise scale
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.sigmas = torch.sqrt((1 - alphas_cumprod) / alphas_cumprod).to(device)
        self.device = device
        self.eta = eta

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 50,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """Sample with ancestral noise injection."""
        sigmas = self._get_sigmas(num_steps)

        # VE parameterization: scale unit noise by sigma_max
        if x_T is None:
            x = torch.randn(shape, device=self.device) * sigmas[0]
        else:
            x = x_T.to(self.device) * sigmas[0]

        self.model.eval()

        from tqdm import tqdm
        iterator = tqdm(range(len(sigmas) - 1), desc="Euler-a") if progress else range(len(sigmas) - 1)

        for i in iterator:
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]

            t = self._sigma_to_t(sigma)
            t_batch = t.expand(x.shape[0])

            # Noise prediction (rescale the VE input for the VP-trained model)
            eps = self.model(x / torch.sqrt(1 + sigma ** 2), t_batch)

            # Split sigma_next into a deterministic target (sigma_down) and
            # injected noise (sigma_up): sigma_down^2 + sigma_up^2 = sigma_next^2
            sigma_up = min(
                sigma_next,
                self.eta * (sigma_next / sigma) * torch.sqrt(sigma**2 - sigma_next**2)
            )
            sigma_down = torch.sqrt(sigma_next**2 - sigma_up**2)

            # Deterministic Euler step down to sigma_down
            d = eps
            x = x + d * (sigma_down - sigma)

            # Ancestral noise injection
            if sigma_next > 0:
                x = x + torch.randn_like(x) * sigma_up

        return x

    def _get_sigmas(self, num_steps: int) -> torch.Tensor:
        """Linear sigma schedule, ending at sigma = 0."""
        sigma_max = self.sigmas[0].item()
        sigma_min = self.sigmas[-1].item()
        sigmas = torch.linspace(sigma_max, sigma_min, num_steps, device=self.device)
        return torch.cat([sigmas, torch.zeros(1, device=self.device)])

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        return torch.argmin(torch.abs(self.sigmas - sigma))
```
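The ancestral split above chooses $\sigma_\text{up}$ (fresh noise) and $\sigma_\text{down}$ (deterministic step target) so that the variance after the step matches $\sigma_\text{next}$ exactly: $\sigma_\text{down}^2 + \sigma_\text{up}^2 = \sigma_\text{next}^2$. A quick check with illustrative values:

```python
import torch

sigma = torch.tensor(2.0)
sigma_next = torch.tensor(1.0)
eta = 1.0

sigma_up = min(sigma_next, eta * (sigma_next / sigma) * torch.sqrt(sigma**2 - sigma_next**2))
sigma_down = torch.sqrt(sigma_next**2 - sigma_up**2)

# Deterministic step to sigma_down, then injected noise restores sigma_next
assert torch.isclose(sigma_down**2 + sigma_up**2, sigma_next**2)
```

Setting `eta = 0` forces `sigma_up = 0`, recovering the deterministic Euler step; larger `eta` shifts more of each step into fresh noise.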
```python
class DPMPlusPlusSDE:
    """
    DPM++ SDE: stochastic version of DPM++ with noise injection.

    Combines the diversity of SDE methods with the efficiency of DPM++.
    Popular for creative applications where diversity matters.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda",
        eta: float = 1.0,
        s_noise: float = 1.0
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.sigmas = torch.sqrt((1 - alphas_cumprod) / alphas_cumprod).to(device)
        self.device = device
        self.eta = eta
        self.s_noise = s_noise

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 25,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True
    ) -> torch.Tensor:
        """Sample with DPM++ SDE."""
        sigmas = self._get_sigmas(num_steps)

        # VE parameterization: scale unit noise by sigma_max
        if x_T is None:
            x = torch.randn(shape, device=self.device) * sigmas[0]
        else:
            x = x_T.to(self.device) * sigmas[0]

        self.model.eval()

        old_denoised = None

        from tqdm import tqdm
        iterator = tqdm(range(len(sigmas) - 1), desc="DPM++ SDE") if progress else range(len(sigmas) - 1)

        for i in iterator:
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]

            t = self._sigma_to_t(sigma)
            t_batch = t.expand(x.shape[0])

            # Denoised (x_0) estimate; rescale the VE input for the VP-trained model
            eps = self.model(x / torch.sqrt(1 + sigma ** 2), t_batch)
            denoised = x - sigma * eps

            # Ancestral split: sigma_down^2 + sigma_up^2 = sigma_next^2
            sigma_up = min(
                sigma_next,
                self.eta * (sigma_next / sigma) *
                torch.sqrt(sigma**2 - sigma_next**2 + 1e-8)
            )
            sigma_down = torch.sqrt(sigma_next**2 - sigma_up**2)

            if old_denoised is None or sigma_next == 0:
                # First (or final) step: first-order derivative
                d = (x - denoised) / sigma
            else:
                # Second order: extrapolate with the previous denoised estimate
                t_next = sigma_next.log().neg()
                t_curr = sigma.log().neg()
                h = t_next - t_curr

                h_last = t_curr - sigmas[i - 1].log().neg()
                r = h_last / h

                denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
                d = (x - denoised_d) / sigma

            # Deterministic step down to sigma_down
            x = x + d * (sigma_down - sigma)

            # Noise injection
            if sigma_next > 0:
                x = x + torch.randn_like(x) * self.s_noise * sigma_up

            old_denoised = denoised

        return x

    def _get_sigmas(self, num_steps: int) -> torch.Tensor:
        """Karras sigma schedule (rho = 7), plus a final sigma = 0."""
        sigma_min = self.sigmas[-1]
        sigma_max = self.sigmas[0]
        rho = 7.0
        ramp = torch.linspace(0, 1, num_steps, device=self.device)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return torch.cat([sigmas, torch.zeros(1, device=self.device)])

    def _sigma_to_t(self, sigma: torch.Tensor) -> torch.Tensor:
        return torch.argmin(torch.abs(self.sigmas - sigma))
```

| Sampler | Deterministic | Diversity | Recommended Steps |
|---|---|---|---|
| Euler | Yes | None | 40-50 |
| Euler Ancestral | No | High | 50-80 |
| DPM++ 2M | Yes | None | 15-25 |
| DPM++ SDE | No | Medium | 20-35 |
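Both DPM++ variants above use the Karras $\sigma$-schedule with $\rho = 7$, which spaces steps much more densely at low noise levels, where fine detail is resolved. A standalone sketch of the spacing (the `sigma_min`/`sigma_max` values are illustrative):

```python
import torch

rho = 7.0
sigma_min, sigma_max = 0.1, 10.0  # illustrative values

ramp = torch.linspace(0, 1, 8)
sigmas = (sigma_max ** (1 / rho) + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

# Endpoints hit sigma_max and sigma_min; each gap is smaller than the last
assert torch.isclose(sigmas[0], torch.tensor(sigma_max), rtol=1e-4)
assert torch.isclose(sigmas[-1], torch.tensor(sigma_min), rtol=1e-4)
gaps = sigmas[:-1] - sigmas[1:]
assert (gaps[:-1] > gaps[1:]).all()
```

With `rho = 1` the spacing would be linear in sigma; raising `rho` pushes ever more of the step budget toward the low-noise end.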
Unified Sampler Framework
Let's create a unified interface for all samplers:
```python
from enum import Enum
from typing import Union

class SamplerType(Enum):
    DDPM = "ddpm"
    DDIM = "ddim"
    EULER = "euler"
    EULER_ANCESTRAL = "euler_a"
    HEUN = "heun"
    DPM_SOLVER = "dpm_solver"
    DPM_2M = "dpm_pp_2m"
    DPM_SDE = "dpm_pp_sde"


class UnifiedSampler:
    """
    Unified interface for all diffusion samplers.

    Provides a single entry point with a consistent API
    regardless of the underlying sampler method.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda"
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.device = device

        # Samplers are created lazily on first use
        self._samplers = {}

    def _get_sampler(self, sampler_type: SamplerType):
        """Get or create a sampler instance."""
        if sampler_type not in self._samplers:
            if sampler_type == SamplerType.DDIM:
                # DDIMSampler / DDIMConfig come from the previous section
                self._samplers[sampler_type] = DDIMSampler(
                    self.model, self.alphas_cumprod,
                    config=DDIMConfig(), device=self.device
                )
            elif sampler_type == SamplerType.EULER:
                self._samplers[sampler_type] = EulerSampler(
                    self.model, self.alphas_cumprod, self.device
                )
            elif sampler_type == SamplerType.EULER_ANCESTRAL:
                self._samplers[sampler_type] = EulerAncestralSampler(
                    self.model, self.alphas_cumprod, self.device
                )
            elif sampler_type == SamplerType.HEUN:
                self._samplers[sampler_type] = HeunSampler(
                    self.model, self.alphas_cumprod, self.device
                )
            elif sampler_type == SamplerType.DPM_SOLVER:
                self._samplers[sampler_type] = DPMSolver(
                    self.model, self.alphas_cumprod,
                    config=DPMSolverConfig(), device=self.device
                )
            elif sampler_type == SamplerType.DPM_2M:
                self._samplers[sampler_type] = DPMPlusPlus2M(
                    self.model, self.alphas_cumprod, self.device
                )
            elif sampler_type == SamplerType.DPM_SDE:
                self._samplers[sampler_type] = DPMPlusPlusSDE(
                    self.model, self.alphas_cumprod, self.device
                )
            else:
                raise NotImplementedError(f"No sampler registered for {sampler_type}")

        return self._samplers[sampler_type]

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        sampler_type: Union[SamplerType, str] = SamplerType.DPM_2M,
        num_steps: int = 20,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True,
        **kwargs
    ) -> torch.Tensor:
        """
        Generate samples with the specified sampler.

        Args:
            shape: Output shape (B, C, H, W)
            sampler_type: Which sampler to use
            num_steps: Number of sampling steps
            x_T: Starting noise
            progress: Show progress bar
            **kwargs: Additional sampler-specific arguments

        Returns:
            Generated samples
        """
        if isinstance(sampler_type, str):
            sampler_type = SamplerType(sampler_type)

        sampler = self._get_sampler(sampler_type)

        return sampler.sample(
            shape=shape,
            num_steps=num_steps,
            x_T=x_T,
            progress=progress,
            **kwargs
        )

    def get_recommended_steps(self, sampler_type: SamplerType) -> int:
        """Recommended step count for each sampler."""
        recommendations = {
            SamplerType.DDPM: 1000,
            SamplerType.DDIM: 50,
            SamplerType.EULER: 50,
            SamplerType.EULER_ANCESTRAL: 60,
            SamplerType.HEUN: 30,
            SamplerType.DPM_SOLVER: 20,
            SamplerType.DPM_2M: 20,
            SamplerType.DPM_SDE: 25,
        }
        return recommendations.get(sampler_type, 50)


# Usage example
def demonstrate_unified_sampler(model, noise_schedule):
    """Compare several samplers starting from the same noise."""
    sampler = UnifiedSampler(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod
    )

    # Fix the initial noise so differences come from the sampler alone
    torch.manual_seed(42)
    x_T = torch.randn(4, 3, 64, 64, device="cuda")

    results = {}

    for sampler_type in [
        SamplerType.DDIM,
        SamplerType.EULER,
        SamplerType.HEUN,
        SamplerType.DPM_2M,
    ]:
        steps = sampler.get_recommended_steps(sampler_type)
        samples = sampler.sample(
            shape=(4, 3, 64, 64),
            sampler_type=sampler_type,
            num_steps=steps,
            x_T=x_T.clone(),
            progress=False
        )
        results[sampler_type.value] = {
            "samples": samples,
            "steps": steps
        }
        print(f"{sampler_type.value}: {steps} steps")

    return results
```

Summary
We've covered the landscape of modern diffusion samplers:
- DPM-Solver family: Uses log-SNR space for efficient higher-order solving, achieving excellent quality in 10-25 steps
- Euler/Heun methods: Classical ODE solvers adapted for diffusion, providing intuitive baselines
- Ancestral variants: Inject noise for diversity at the cost of requiring more steps
- Unified framework: Consistent API for experimenting with different samplers
| Use Case | Recommended Sampler | Steps | Why |
|---|---|---|---|
| Fast production | DPM++ 2M Karras | 20-25 | Best speed/quality |
| Maximum quality | DPM++ 2M | 50 | Extra refinement |
| Creative diversity | DPM++ SDE | 25-35 | Stochastic variety |
| Debugging/learning | Euler | 50 | Simple to understand |
| Baseline comparison | DDIM | 50 | Standard reference |
The sampler choice significantly impacts both generation speed and output quality. Understanding these methods deeply allows you to make informed decisions for your specific application requirements.