Introduction
The discrete-time formulation of diffusion models (DDPM) is actually a discretization of an underlying continuous-time process. By viewing diffusion through the lens of stochastic differential equations (SDEs), we gain powerful theoretical tools and practical sampling algorithms that transcend specific discrete schedules.
This section develops the SDE perspective on diffusion models, showing how the forward noising process and reverse denoising process can be described by elegant continuous-time equations. This viewpoint unifies different diffusion variants, enables exact likelihood computation, and opens pathways to more efficient sampling.
Continuous-Time Diffusion
In the discrete DDPM formulation, we have T timesteps with a fixed noise schedule beta_1, ..., beta_T. But what happens as T -> infinity and the step size approaches zero?
From Discrete to Continuous
Consider the discrete forward process:
Discrete Forward Process: x_t = sqrt(alpha_t) * x_{t-1} + sqrt(1 - alpha_t) * epsilon_t. As the step size dt -> 0 (T -> infinity), this becomes a continuous stochastic process described by an SDE.
```python
# Intuition: Discrete to Continuous Transition
"""
Discrete (DDPM):
    x_t = sqrt(alpha_t) * x_{t-1} + sqrt(1 - alpha_t) * epsilon_t

Let dt = 1/T and t = k/T for k = 0, 1, ..., T

As T -> infinity (dt -> 0):
    x_{t+dt} - x_t = f(x_t, t) * dt + g(t) * sqrt(dt) * z

Where z ~ N(0, I) and:
    - f(x, t) is the drift coefficient (deterministic part)
    - g(t) is the diffusion coefficient (stochastic part)

This is an Ito SDE: dx = f(x, t)dt + g(t)dW
where dW is the Wiener process (Brownian motion)
"""

import torch
import torch.nn as nn
import numpy as np

def discrete_to_continuous_params(alpha_schedule: torch.Tensor, num_steps: int) -> dict:
    """
    Convert a discrete schedule to continuous SDE parameters.

    For VP-SDE: dx = -0.5 * beta(t) * x * dt + sqrt(beta(t)) * dW

    Returns beta(t) function parameters.
    """
    # Compute cumulative alphas
    alpha_bar = torch.cumprod(alpha_schedule, dim=0)

    # Convert to continuous time t in [0, 1]
    t = torch.linspace(0, 1, num_steps)

    # log(alpha_bar(t)) = integral of -0.5 * beta(s) ds from 0 to t
    # So beta(t) = -2 * d/dt[log(alpha_bar(t))]
    log_alpha_bar = torch.log(alpha_bar)

    # Numerical differentiation
    beta = -2 * torch.gradient(log_alpha_bar, spacing=(t,))[0]

    return {
        'beta': beta,
        't': t,
        'alpha_bar': alpha_bar
    }
```

The continuous formulation reveals that discrete diffusion models are numerical approximations to underlying SDEs. Different discretization choices lead to different algorithms with different trade-offs.
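This correspondence is easy to check numerically. The standalone numpy sketch below (the linear schedule with beta_min = 0.1, beta_max = 20 is assumed as an example) compares the log of the discrete DDPM cumulative product alpha_bar, with per-step beta_k = beta(k/T)/T, against its continuous-time limit -integral_0^t beta(s) ds; the gap shrinks roughly like 1/T:

```python
import numpy as np

def beta_cont(t, beta_min=0.1, beta_max=20.0):
    """Linear continuous-time noise schedule beta(t)."""
    return beta_min + t * (beta_max - beta_min)

def log_alpha_bar_discrete(T):
    """log(alpha_bar) for a T-step DDPM with beta_k = beta(k/T) / T."""
    k = np.arange(1, T + 1)
    betas = beta_cont(k / T) / T
    return np.cumsum(np.log1p(-betas))

def log_alpha_bar_continuous(t, beta_min=0.1, beta_max=20.0):
    """Continuous limit: log(alpha_bar(t)) = -integral_0^t beta(s) ds."""
    return -(beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2)

T = 1000
t_grid = np.arange(1, T + 1) / T
gap = np.max(np.abs(log_alpha_bar_discrete(T) - log_alpha_bar_continuous(t_grid)))
# gap shrinks as the discretization is refined (roughly 1/T)
```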
SDE Formulation
Stochastic differential equations describe continuous-time random processes. The general form of an SDE is:
General Ito SDE: dx = f(x, t) dt + g(t) dW, where f(x, t) is the drift coefficient, g(t) is the diffusion coefficient, and dW is the infinitesimal Wiener process increment.
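To make the notation concrete before specializing to diffusion, here is a standalone numpy sketch (an illustrative toy, not part of the diffusion model) that simulates a classic Ito SDE, the Ornstein-Uhlenbeck process dx = -theta * x dt + sigma dW, with the Euler-Maruyama rule x_{t+dt} = x_t + f * dt + g * sqrt(dt) * z, and checks the long-run variance against the known stationary value sigma^2 / (2 * theta):

```python
import numpy as np

def simulate_ou(theta=1.0, sigma=np.sqrt(2.0), T=10.0, num_steps=1000,
                num_paths=5000, seed=0):
    """Euler-Maruyama simulation of dx = -theta * x dt + sigma dW."""
    rng = np.random.default_rng(seed)
    dt = T / num_steps
    x = np.zeros(num_paths)  # start all paths at the origin
    for _ in range(num_steps):
        z = rng.standard_normal(num_paths)
        # drift f(x, t) = -theta * x, diffusion g(t) = sigma
        x = x - theta * x * dt + sigma * np.sqrt(dt) * z
    return x

x_final = simulate_ou()
# Stationary variance is sigma^2 / (2 * theta) = 2 / 2 = 1, so
# x_final.var() should be close to 1 and x_final.mean() close to 0.
```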
The Forward SDE
The forward diffusion process that gradually adds noise can be written as dx = f(x, t) dt + g(t) dW; in the variance-preserving (VP) case used below, f(x, t) = -0.5 * beta(t) * x and g(t) = sqrt(beta(t)):
```python
class ForwardSDE:
    """
    Forward SDE: dx = f(x, t)dt + g(t)dW

    Transforms data distribution p_data(x) into prior distribution p_T(x).
    """
    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0, T: float = 1.0):
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.T = T

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        """Time-dependent noise schedule beta(t)."""
        # Linear schedule in continuous time
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def drift_coef(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Drift coefficient f(x, t) for VP-SDE.
        f(x, t) = -0.5 * beta(t) * x
        """
        return -0.5 * self.beta(t).unsqueeze(-1) * x

    def diffusion_coef(self, t: torch.Tensor) -> torch.Tensor:
        """
        Diffusion coefficient g(t) for VP-SDE.
        g(t) = sqrt(beta(t))
        """
        return torch.sqrt(self.beta(t))

    def marginal_prob(self, x_0: torch.Tensor, t: torch.Tensor) -> tuple:
        """
        Compute mean and std of p(x_t | x_0).

        For VP-SDE:
            mean = x_0 * exp(-0.5 * integral(beta(s), 0, t))
            std = sqrt(1 - exp(-integral(beta(s), 0, t)))
        """
        # log of the mean coefficient: -0.5 * integral of beta(s) from 0 to t
        log_coef = -0.25 * t ** 2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min
        mean_coef = torch.exp(log_coef)
        std = torch.sqrt(1.0 - torch.exp(2.0 * log_coef))

        mean = x_0 * mean_coef.unsqueeze(-1)
        return mean, std

    def sample_x_t(self, x_0: torch.Tensor, t: torch.Tensor) -> tuple:
        """Sample from p(x_t | x_0)."""
        mean, std = self.marginal_prob(x_0, t)
        z = torch.randn_like(x_0)
        x_t = mean + std.unsqueeze(-1) * z
        return x_t, z
```

The forward SDE has the remarkable property that its marginal distributions can be computed in closed form: given an initial point x_0, we can sample x_t directly at any time t without simulating the entire trajectory.
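This closed-form claim can be verified by brute force. The numpy sketch below (deliberately self-contained rather than reusing ForwardSDE, with the default beta_min = 0.1, beta_max = 20 schedule) simulates the VP forward SDE step by step from x_0 = 2 and compares the empirical statistics at t = 1 with the closed-form marginal:

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0

def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)

def closed_form_marginal(x0, t):
    """Closed-form mean and std of p(x_t | x_0) for the VP-SDE."""
    integral = BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2
    mean_coef = np.exp(-0.5 * integral)
    std = np.sqrt(1.0 - np.exp(-integral))
    return x0 * mean_coef, std

def simulate_forward(x0, num_steps=1000, num_paths=20000, seed=0):
    """Euler-Maruyama on dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dW."""
    rng = np.random.default_rng(seed)
    dt = 1.0 / num_steps
    x = np.full(num_paths, float(x0))
    for k in range(num_steps):
        t = k * dt
        z = rng.standard_normal(num_paths)
        x = x - 0.5 * beta(t) * x * dt + np.sqrt(beta(t) * dt) * z
    return x

x1 = simulate_forward(2.0)
mean_th, std_th = closed_form_marginal(2.0, 1.0)
# Empirical mean/std of the simulated endpoints should match the
# closed form up to Monte Carlo and discretization error.
```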
The Reverse SDE
The key theoretical result enabling diffusion models is Anderson's Theorem: any forward SDE has a corresponding reverse SDE that runs time backwards:
Reverse SDE (Anderson, 1982): dx = [f(x, t) - g(t)^2 * grad_x log p_t(x)] dt + g(t) dW_bar, where dW_bar is the reverse-time Wiener process increment. The score grad_x log p_t(x) is the only unknown quantity.
```python
class ReverseSDE:
    """
    Reverse SDE for generation.

    Given a trained score model s_theta(x, t) ≈ nabla_x log p_t(x),
    we can reverse the forward process to generate samples.
    """
    def __init__(self, forward_sde: ForwardSDE, score_model: nn.Module):
        self.forward_sde = forward_sde
        self.score_model = score_model

    def drift_coef(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Reverse drift: f(x, t) - g(t)^2 * score(x, t)

        The score term reverses the diffusion!
        """
        # Forward drift
        f = self.forward_sde.drift_coef(x, t)

        # Diffusion coefficient squared: g(t)^2 = beta(t) for the VP-SDE
        g_sq = self.forward_sde.beta(t).unsqueeze(-1)

        # Learned score
        score = self.score_model(x, t)

        # Reverse drift
        return f - g_sq * score

    def diffusion_coef(self, t: torch.Tensor) -> torch.Tensor:
        """Same diffusion coefficient as forward SDE."""
        return self.forward_sde.diffusion_coef(t)

    def sample(
        self,
        num_samples: int,
        data_shape: tuple,
        num_steps: int = 1000,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """
        Generate samples by solving the reverse SDE.
        Uses Euler-Maruyama discretization.
        """
        # Start from prior distribution (Gaussian noise)
        x = torch.randn(num_samples, *data_shape, device=device)

        # Time goes from T to 0
        dt = self.forward_sde.T / num_steps
        timesteps = torch.linspace(self.forward_sde.T, 0, num_steps + 1, device=device)

        for i in range(num_steps):
            t = timesteps[i]
            t_batch = t.expand(num_samples)

            # Reverse SDE step
            drift = self.drift_coef(x, t_batch)
            diffusion = self.diffusion_coef(t_batch)

            # Euler-Maruyama: x_{t-dt} = x_t - drift * dt + diffusion * sqrt(dt) * z
            z = torch.randn_like(x) if i < num_steps - 1 else torch.zeros_like(x)
            x = x - drift * dt + diffusion.unsqueeze(-1) * np.sqrt(dt) * z

        return x
```

Why the Reverse SDE Works:
- Forward SDE adds noise: drift shrinks data toward origin, diffusion adds randomness
- Score function points toward high density regions of p_t(x)
- g(t)^2 * score term counteracts the noise addition, effectively reversing it
- The reverse process generates samples by gradually removing noise
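For a toy data distribution the score is available in closed form, which lets us test the reverse SDE without training anything. The standalone numpy sketch below (the Gaussian data distribution N(3, 0.5^2) is an arbitrary example) plugs the exact VP-SDE score into Euler-Maruyama reverse steps and recovers the data distribution from pure N(0, 1) noise:

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0
M0, S0 = 3.0, 0.5  # data distribution N(M0, S0^2)

def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)

def mean_coef(t):
    integral = BETA_MIN * t + 0.5 * (BETA_MAX - BETA_MIN) * t ** 2
    return np.exp(-0.5 * integral)

def exact_score(x, t):
    """Score of p_t = N(a*M0, a^2*S0^2 + 1 - a^2), with a = mean_coef(t)."""
    a = mean_coef(t)
    var_t = a ** 2 * S0 ** 2 + (1.0 - a ** 2)
    return -(x - a * M0) / var_t

def reverse_sde_sample(num_paths=20000, num_steps=2000, seed=0):
    """Euler-Maruyama on dx = [f - g^2 * score] dt + g dW_bar, from t=1 to 0."""
    rng = np.random.default_rng(seed)
    dt = 1.0 / num_steps
    x = rng.standard_normal(num_paths)  # start from the N(0, 1) prior
    for k in range(num_steps):
        t = 1.0 - k * dt
        g2 = beta(t)  # g(t)^2 for the VP-SDE
        drift = -0.5 * g2 * x - g2 * exact_score(x, t)
        z = rng.standard_normal(num_paths) if k < num_steps - 1 else 0.0
        x = x - drift * dt + np.sqrt(g2 * dt) * z
    return x

samples = reverse_sde_sample()
# samples should be approximately N(3, 0.5^2)
```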
Types of Diffusion SDEs
Different choices of drift and diffusion coefficients lead to different diffusion model families, each with distinct properties.
Variance Preserving SDE (VP-SDE)
The VP-SDE, which underlies DDPM, preserves the total variance of the data during diffusion:
| Property | VP-SDE Formula |
|---|---|
| Forward SDE | dx = -0.5 * beta(t) * x * dt + sqrt(beta(t)) * dW |
| Marginal mean | E[x_t \| x_0] = x_0 * exp(-0.5 * integral(beta)) |
| Marginal variance | Var[x_t \| x_0] = 1 - exp(-integral(beta)) |
| Prior | N(0, I) as t -> infinity |
| Key property | Var[x_t] stays bounded (approximately 1) |
```python
class VPSDE:
    """
    Variance Preserving SDE (DDPM continuous-time limit).

    dx = -0.5 * beta(t) * x * dt + sqrt(beta(t)) * dW

    Properties:
    - Data variance is preserved (approximately 1 throughout)
    - Converges to N(0, I) prior
    - Corresponds to DDPM with linear noise schedule
    """
    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        """Linear noise schedule."""
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def sde(self, x: torch.Tensor, t: torch.Tensor):
        """Return drift and diffusion coefficients."""
        beta_t = self.beta(t).unsqueeze(-1)
        drift = -0.5 * beta_t * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    def marginal_prob(self, x_0: torch.Tensor, t: torch.Tensor):
        """Closed-form marginal p(x_t | x_0)."""
        log_mean_coef = -0.25 * t ** 2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min
        mean = x_0 * torch.exp(log_mean_coef).unsqueeze(-1)
        std = torch.sqrt(1 - torch.exp(2 * log_mean_coef))
        return mean, std

    def prior_sampling(self, shape: tuple, device: str = 'cuda') -> torch.Tensor:
        """Sample from prior N(0, I)."""
        return torch.randn(*shape, device=device)
```

Variance Exploding SDE (VE-SDE)
The VE-SDE, which underlies NCSN (Noise Conditional Score Networks), allows variance to grow unboundedly:
```python
class VESDE:
    """
    Variance Exploding SDE (NCSN continuous-time limit).

    dx = sqrt(d[sigma^2(t)]/dt) * dW

    Properties:
    - No drift term (pure noise addition)
    - Variance grows without bound
    - Prior is N(0, sigma_max^2 * I)
    - Corresponds to SMLD/NCSN noise levels
    """
    def __init__(self, sigma_min: float = 0.01, sigma_max: float = 50.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        """Geometric schedule for noise level."""
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

    def sde(self, x: torch.Tensor, t: torch.Tensor):
        """Return drift (zero) and diffusion coefficients."""
        sigma_t = self.sigma(t)
        # d[sigma^2]/dt = 2 * sigma * d[sigma]/dt
        # sigma(t) = sigma_min * r^t, where r = sigma_max/sigma_min
        # d[sigma]/dt = sigma(t) * log(r)
        drift = torch.zeros_like(x)
        diffusion = sigma_t * np.sqrt(2 * np.log(self.sigma_max / self.sigma_min))
        return drift, diffusion.unsqueeze(-1).expand_as(x)

    def marginal_prob(self, x_0: torch.Tensor, t: torch.Tensor):
        """Closed-form marginal p(x_t | x_0)."""
        # Mean is just x_0 (no drift)
        mean = x_0
        # Std is sigma(t)
        std = self.sigma(t)
        return mean, std

    def prior_sampling(self, shape: tuple, device: str = 'cuda') -> torch.Tensor:
        """Sample from prior N(0, sigma_max^2 * I)."""
        return torch.randn(*shape, device=device) * self.sigma_max
```

Sub-VP SDE
The sub-VP SDE is a variant that maintains even tighter variance control:
```python
class SubVPSDE:
    """
    Sub-VP SDE: A variant with reduced variance growth.

    dx = -0.5 * beta(t) * x * dt + sqrt(beta(t) * (1 - exp(-2 * integral(beta)))) * dW

    Properties:
    - Variance is always strictly less than 1
    - Tighter control than VP-SDE
    - Can improve likelihood bounds
    """
    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def sde(self, x: torch.Tensor, t: torch.Tensor):
        """Return drift and diffusion coefficients."""
        beta_t = self.beta(t)

        # Compute integral of beta from 0 to t
        integral_beta = 0.5 * t ** 2 * (self.beta_max - self.beta_min) + t * self.beta_min

        drift = -0.5 * beta_t.unsqueeze(-1) * x

        # Sub-VP diffusion: sqrt(beta * (1 - exp(-2 * integral)))
        discount = 1.0 - torch.exp(-2 * integral_beta)
        diffusion = torch.sqrt(beta_t * discount)

        return drift, diffusion.unsqueeze(-1).expand_as(x)

    def marginal_prob(self, x_0: torch.Tensor, t: torch.Tensor):
        """Closed-form marginal."""
        integral_beta = 0.5 * t ** 2 * (self.beta_max - self.beta_min) + t * self.beta_min
        mean_coef = torch.exp(-0.5 * integral_beta)
        mean = x_0 * mean_coef.unsqueeze(-1)

        # Sub-VP std is 1 - exp(-integral), so the variance (std^2) is
        # strictly less than the VP variance 1 - mean_coef^2
        std = 1.0 - torch.exp(-integral_beta)
        return mean, std
```

| SDE Type | Drift f(x,t) | Diffusion g(t) | Prior |
|---|---|---|---|
| VP-SDE | -0.5 * beta(t) * x | sqrt(beta(t)) | N(0, I) |
| VE-SDE | 0 | sqrt(d[sigma^2]/dt) | N(0, sigma_max^2 * I) |
| Sub-VP | -0.5 * beta(t) * x | sqrt(beta * (1-exp(-2*int))) | N(0, I) |
SDE Trajectory Visualization
The following visualization shows how different SDE formulations create different diffusion trajectories. Observe how VP-SDE, VE-SDE, and sub-VP SDE each transform data to noise (forward) and recover data from noise (reverse) in distinct ways.
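The same contrast can be generated numerically. The standalone numpy sketch below (using the default schedules assumed above: beta in [0.1, 20] for VP, sigma in [0.01, 50] for VE) simulates 1-D forward trajectories and compares the spread of their endpoints: VP endpoints stay near unit variance while VE endpoints explode toward sigma_max:

```python
import numpy as np

def forward_endpoints(kind, x0=1.0, num_steps=1000, num_paths=2000, seed=0):
    """Euler-Maruyama forward simulation for the VP or VE SDE (1-D)."""
    rng = np.random.default_rng(seed)
    dt = 1.0 / num_steps
    x = np.full(num_paths, float(x0))
    for k in range(num_steps):
        t = k * dt
        if kind == 'vp':
            # dx = -0.5 * beta(t) * x dt + sqrt(beta(t)) dW
            b = 0.1 + t * (20.0 - 0.1)
            drift, g = -0.5 * b * x, np.sqrt(b)
        else:
            # 've': dx = sqrt(d[sigma^2]/dt) dW with geometric sigma(t)
            sigma_min, sigma_max = 0.01, 50.0
            sigma_t = sigma_min * (sigma_max / sigma_min) ** t
            drift = 0.0
            g = sigma_t * np.sqrt(2.0 * np.log(sigma_max / sigma_min))
        x = x + drift * dt + g * np.sqrt(dt) * rng.standard_normal(num_paths)
    return x

vp_std = forward_endpoints('vp').std()  # stays near 1
ve_std = forward_endpoints('ve').std()  # grows toward sigma_max = 50
```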
Probability Flow ODE
A remarkable property of the SDE framework is that every forward SDE has an associated probability flow ODE that produces the same marginal distributions but without stochastic noise:
Probability Flow ODE: dx = [f(x, t) - 0.5 * g(t)^2 * grad_x log p_t(x)] dt. This ODE is deterministic and shares marginal distributions with the SDE at all times.
Deterministic Sampling
The probability flow ODE enables deterministic sampling: given the same initial noise, you always get the same output. This is crucial for reproducibility and editing.
```python
class ProbabilityFlowODE:
    """
    Deterministic ODE with same marginals as the diffusion SDE.

    dx = [f(x,t) - 0.5 * g(t)^2 * score(x,t)] dt

    No stochasticity - purely deterministic flow.
    """
    def __init__(self, sde, score_model: nn.Module):
        self.sde = sde
        self.score_model = score_model

    def drift(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """ODE drift coefficient."""
        # Get SDE coefficients
        sde_drift, sde_diffusion = self.sde.sde(x, t)

        # Get score
        score = self.score_model(x, t)

        # Probability flow drift: f(x,t) - 0.5 * g(t)^2 * score(x,t)
        g_sq = sde_diffusion ** 2
        return sde_drift - 0.5 * g_sq * score

    def _drift_at(self, x: torch.Tensor, t: float) -> torch.Tensor:
        """Evaluate the drift at a scalar time (SDE coefficients expect a batch)."""
        t_batch = torch.full((x.shape[0],), t, device=x.device)
        return self.drift(x, t_batch)

    def sample_ode(
        self,
        x_T: torch.Tensor,
        num_steps: int = 1000,
        method: str = 'euler'
    ) -> torch.Tensor:
        """
        Solve the ODE from t=1 to t=0 using the specified method.

        Args:
            x_T: Initial condition (samples from prior)
            num_steps: Number of ODE steps
            method: 'euler' or 'rk4' (Runge-Kutta)
        """
        dt = 1.0 / num_steps
        x = x_T.clone()

        for step in range(num_steps):
            t = 1.0 - step / num_steps

            if method == 'euler':
                # Integrating backward in time: x_{t-dt} = x_t - drift * dt
                x = x - self._drift_at(x, t) * dt
            elif method == 'rk4':
                x = self._rk4_step(x, t, -dt)

        return x

    def _rk4_step(self, x: torch.Tensor, t: float, dt: float) -> torch.Tensor:
        """Fourth-order Runge-Kutta step (dt is negative for reverse time)."""
        k1 = self._drift_at(x, t)
        k2 = self._drift_at(x + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = self._drift_at(x + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = self._drift_at(x + dt * k3, t + dt)
        return x + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)


def deterministic_encode_decode(
    ode: ProbabilityFlowODE,
    x: torch.Tensor,
    num_steps: int = 1000
) -> torch.Tensor:
    """
    Encode data to latent space and decode back.
    Should reconstruct the original data (cycle consistency up to solver error).
    """
    # Forward: data -> noise
    dt = 1.0 / num_steps
    z = x.clone()
    for step in range(num_steps):
        t = step / num_steps
        z = z + ode._drift_at(z, t) * dt

    # Backward: noise -> data
    x_recon = z.clone()
    for step in range(num_steps):
        t = 1.0 - step / num_steps
        x_recon = x_recon - ode._drift_at(x_recon, t) * dt

    return x_recon
```

Exact Likelihood Computation
The ODE formulation enables exact log-likelihood computation via the instantaneous change of variables formula:
```python
def compute_log_likelihood(
    ode: ProbabilityFlowODE,
    x: torch.Tensor,
    num_steps: int = 1000,
    hutchinson_samples: int = 10
) -> torch.Tensor:
    """
    Compute log p(x) using the probability flow ODE.

    Uses the instantaneous change of variables formula:
        log p_0(x_0) = log p_T(x_T) + integral_0^T trace(df/dx) dt

    The trace is estimated using Hutchinson's trace estimator.
    """
    batch_size = x.shape[0]
    device = x.device

    # Initialize
    z = x.clone()
    delta_log_p = torch.zeros(batch_size, device=device)

    dt = 1.0 / num_steps

    for step in range(num_steps):
        t = step / num_steps
        t_tensor = torch.full((batch_size,), t, device=device)

        # Estimate trace of the Jacobian using Hutchinson's estimator.
        # enable_grad makes this work even inside torch.no_grad() contexts.
        with torch.enable_grad():
            z = z.detach().requires_grad_(True)
            drift = ode.drift(z, t_tensor)

            trace_estimate = torch.zeros(batch_size, device=device)
            for _ in range(hutchinson_samples):
                # Random probe vector for the Hutchinson estimator
                v = torch.randn_like(z)

                # Compute v^T * Jacobian, then contract with v
                vjp = torch.autograd.grad(
                    drift, z,
                    grad_outputs=v,
                    create_graph=False,
                    retain_graph=True
                )[0]
                trace_estimate += (vjp * v).sum(dim=-1)

            trace_estimate /= hutchinson_samples

        # Update state and accumulated log-density change
        z = z.detach() + drift.detach() * dt
        delta_log_p += trace_estimate.detach() * dt

    # Log probability under the standard Gaussian prior
    log_p_prior = -0.5 * (z ** 2).sum(dim=-1) - 0.5 * z.shape[-1] * np.log(2 * np.pi)

    # Total log probability: prior term plus accumulated trace integral
    log_p_x = log_p_prior + delta_log_p

    return log_p_x
```

This exact likelihood computation is a key advantage of the SDE/ODE framework over discrete diffusion models, which can only provide variational bounds.
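The Hutchinson estimator at the heart of this computation is easy to validate in isolation: for a linear map f(x) = A x the Jacobian is simply A, so the average of v^T A v over Gaussian probes v should approach trace(A). A standalone numpy sketch (the 5x5 random matrix is an arbitrary example):

```python
import numpy as np

def hutchinson_trace(A, num_samples=10000, seed=0):
    """Estimate trace(A) as the average of v^T A v over Gaussian probes v."""
    rng = np.random.default_rng(seed)
    d = A.shape[0]
    est = 0.0
    for _ in range(num_samples):
        v = rng.standard_normal(d)
        est += v @ A @ v
    return est / num_samples

rng = np.random.default_rng(42)
A = rng.standard_normal((5, 5))  # Jacobian of the linear map f(x) = A x
exact = np.trace(A)
approx = hutchinson_trace(A)
# approx converges to exact at the usual 1/sqrt(num_samples) Monte Carlo rate
```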
Numerical SDE/ODE Solvers
Solving SDEs and ODEs numerically requires careful choice of discretization methods. The choice affects both sample quality and computational efficiency.
Euler-Maruyama Method
The simplest SDE solver, analogous to Euler's method for ODEs:
```python
def euler_maruyama_step(
    x: torch.Tensor,
    t: float,
    dt: float,
    sde,
    score_model: nn.Module
) -> torch.Tensor:
    """
    Single Euler-Maruyama step for the reverse SDE.

    x_{t-dt} = x_t - [f(x,t) - g(t)^2 * score(x,t)] * dt + g(t) * sqrt(dt) * z
    """
    batch_size = x.shape[0]
    device = x.device

    t_batch = torch.full((batch_size,), t, device=device)

    # Get SDE coefficients
    drift, diffusion = sde.sde(x, t_batch)

    # Get score
    score = score_model(x, t_batch)

    # Reverse drift
    reverse_drift = drift - diffusion ** 2 * score

    # Stochastic noise (omitted on the final step)
    z = torch.randn_like(x) if t > dt else torch.zeros_like(x)

    # Euler-Maruyama update (going backward in time)
    x_new = x - reverse_drift * dt + diffusion * np.sqrt(dt) * z

    return x_new


def euler_maruyama_sampler(
    sde,
    score_model: nn.Module,
    shape: tuple,
    num_steps: int = 1000,
    device: str = 'cuda'
) -> torch.Tensor:
    """Complete Euler-Maruyama sampling loop."""
    # Initialize from prior
    x = sde.prior_sampling(shape, device)

    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = 1.0 - step * dt
        x = euler_maruyama_step(x, t, dt, sde, score_model)

    return x
```

Predictor-Corrector Samplers
Predictor-corrector methods combine an ODE/SDE solver (predictor) with Langevin MCMC steps (corrector) for improved sample quality:
```python
class PredictorCorrectorSampler:
    """
    Predictor-Corrector sampling for diffusion SDEs.

    Alternates between:
    1. Predictor: ODE/SDE step to move toward data distribution
    2. Corrector: Langevin MCMC steps to refine samples

    This improves sample quality by correcting numerical errors.
    """
    def __init__(
        self,
        sde,
        score_model: nn.Module,
        predictor: str = 'euler',
        corrector: str = 'langevin',
        snr: float = 0.16,  # Signal-to-noise ratio for corrector
        n_steps_corrector: int = 1
    ):
        self.sde = sde
        self.score_model = score_model
        self.predictor = predictor
        self.corrector = corrector
        self.snr = snr
        self.n_steps_corrector = n_steps_corrector

    def predictor_step(
        self,
        x: torch.Tensor,
        t: float,
        dt: float
    ) -> torch.Tensor:
        """Predictor step (reverse SDE)."""
        batch_size = x.shape[0]
        device = x.device
        t_batch = torch.full((batch_size,), t, device=device)

        drift, diffusion = self.sde.sde(x, t_batch)
        score = self.score_model(x, t_batch)

        reverse_drift = drift - diffusion ** 2 * score
        z = torch.randn_like(x) if t > dt else torch.zeros_like(x)

        return x - reverse_drift * dt + diffusion * np.sqrt(dt) * z

    def corrector_step(
        self,
        x: torch.Tensor,
        t: float
    ) -> torch.Tensor:
        """Langevin corrector step."""
        batch_size = x.shape[0]
        device = x.device
        t_batch = torch.full((batch_size,), t, device=device)

        for _ in range(self.n_steps_corrector):
            # Get score
            score = self.score_model(x, t_batch)

            # Langevin step size based on score magnitude
            grad_norm = torch.norm(score.view(batch_size, -1), dim=-1).mean()
            step_size = (self.snr / grad_norm) ** 2 * 2

            # Langevin update (torch.sqrt since step_size is a tensor)
            z = torch.randn_like(x)
            x = x + step_size * score + torch.sqrt(2 * step_size) * z

        return x

    def sample(
        self,
        shape: tuple,
        num_steps: int = 1000,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """Generate samples using predictor-corrector method."""
        # Initialize from prior
        x = self.sde.prior_sampling(shape, device)

        dt = 1.0 / num_steps
        for step in range(num_steps):
            t = 1.0 - step * dt

            # Corrector
            x = self.corrector_step(x, t)

            # Predictor
            x = self.predictor_step(x, t, dt)

        return x
```

Implementation
Here's a complete SDE-based diffusion model implementation:
```python
class SDEDiffusionModel:
    """
    Complete SDE-based diffusion model.

    Supports:
    - Multiple SDE types (VP, VE, sub-VP)
    - Multiple samplers (Euler-Maruyama, PC, ODE)
    - Exact likelihood computation
    """
    def __init__(
        self,
        sde_type: str = 'vpsde',
        beta_min: float = 0.1,
        beta_max: float = 20.0,
        sigma_min: float = 0.01,
        sigma_max: float = 50.0
    ):
        # Initialize SDE
        if sde_type == 'vpsde':
            self.sde = VPSDE(beta_min, beta_max)
        elif sde_type == 'vesde':
            self.sde = VESDE(sigma_min, sigma_max)
        elif sde_type == 'subvpsde':
            self.sde = SubVPSDE(beta_min, beta_max)
        else:
            raise ValueError(f"Unknown SDE type: {sde_type}")

        self.score_model = None

    def set_score_model(self, model: nn.Module):
        """Set trained score model."""
        self.score_model = model

    def train_step(
        self,
        x: torch.Tensor,
        optimizer: torch.optim.Optimizer
    ) -> float:
        """Single training step with DSM objective."""
        batch_size = x.shape[0]
        device = x.device

        # Sample random times uniformly in (eps, 1], avoiding t = 0
        # where std -> 0 and the score target diverges
        eps = 1e-5
        t = torch.rand(batch_size, device=device) * (1.0 - eps) + eps

        # Get noisy samples
        mean, std = self.sde.marginal_prob(x, t)
        z = torch.randn_like(x)
        x_t = mean + std.unsqueeze(-1) * z

        # Score target: -z / std
        score_target = -z / std.unsqueeze(-1)

        # Predict score
        score_pred = self.score_model(x_t, t)

        # DSM loss (weighted by std^2)
        loss = 0.5 * ((score_pred - score_target) ** 2 * std.unsqueeze(-1) ** 2)
        loss = loss.sum(dim=-1).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    @torch.no_grad()
    def sample(
        self,
        shape: tuple,
        num_steps: int = 1000,
        sampler: str = 'euler',
        device: str = 'cuda'
    ) -> torch.Tensor:
        """Generate samples."""
        self.score_model.eval()

        if sampler == 'euler':
            return euler_maruyama_sampler(
                self.sde, self.score_model, shape, num_steps, device
            )
        elif sampler == 'pc':
            pc_sampler = PredictorCorrectorSampler(
                self.sde, self.score_model
            )
            return pc_sampler.sample(shape, num_steps, device)
        elif sampler == 'ode':
            ode = ProbabilityFlowODE(self.sde, self.score_model)
            x_T = self.sde.prior_sampling(shape, device)
            return ode.sample_ode(x_T, num_steps)
        else:
            raise ValueError(f"Unknown sampler: {sampler}")

    def log_likelihood(
        self,
        x: torch.Tensor,
        num_steps: int = 1000
    ) -> torch.Tensor:
        """Compute exact log-likelihood.

        Note: not wrapped in torch.no_grad() -- the Hutchinson trace
        estimate needs autograd through the score model.
        """
        self.score_model.eval()
        ode = ProbabilityFlowODE(self.sde, self.score_model)
        return compute_log_likelihood(ode, x, num_steps)
```

Summary
The SDE framework provides a powerful continuous-time perspective on diffusion models with several key advantages:
- Unified framework: VP-SDE, VE-SDE, and sub-VP SDE are all special cases of the general SDE formulation, enabling principled comparison and combination.
- Reverse SDE: Anderson's theorem guarantees that any forward diffusion can be reversed given only the score function.
- Probability flow ODE: The deterministic ODE counterpart enables exact likelihood computation and reproducible generation.
- Flexible samplers: Euler-Maruyama, predictor-corrector, and ODE solvers provide different quality-speed trade-offs.
- Theoretical foundation: Continuous-time analysis provides rigorous understanding of convergence, training dynamics, and sample quality.
In the next section, we'll synthesize the score-based and SDE perspectives into a unified framework that connects all major diffusion model variants and reveals their fundamental relationships.