Chapter 15
Section 68 of 76

Diffusion as SDEs

Score-Based and SDE Perspective

Introduction

The discrete-time formulation of diffusion models (DDPM) is actually a discretization of an underlying continuous-time process. By viewing diffusion through the lens of stochastic differential equations (SDEs), we gain powerful theoretical tools and practical sampling algorithms that transcend specific discrete schedules.

This section develops the SDE perspective on diffusion models, showing how the forward noising process and reverse denoising process can be described by elegant continuous-time equations. This viewpoint unifies different diffusion variants, enables exact likelihood computation, and opens pathways to more efficient sampling.


Continuous-Time Diffusion

In the discrete DDPM formulation, we have $T$ timesteps with a fixed noise schedule. But what happens as $T \to \infty$ and the step size approaches zero?

From Discrete to Continuous

Consider the discrete forward process:

Discrete Forward Process: $x_t = \sqrt{\alpha_t}\, x_{t-1} + \sqrt{1 - \alpha_t}\, \epsilon_t$

As the step size $\Delta t \to 0$, this becomes a continuous stochastic process described by an SDE.

🐍python
# Intuition: Discrete to Continuous Transition
"""
Discrete (DDPM):
    x_t = sqrt(alpha_t) * x_{t-1} + sqrt(1 - alpha_t) * epsilon_t

Let dt = 1/T and t = k/T for k = 0, 1, ..., T

As T -> infinity (dt -> 0):
    x_{t+dt} - x_t = f(x_t, t) * dt + g(t) * sqrt(dt) * z

where z ~ N(0, I) and:
    - f(x, t) is the drift coefficient (deterministic part)
    - g(t) is the diffusion coefficient (stochastic part)

This is an Ito SDE: dx = f(x, t) dt + g(t) dW,
where W is a Wiener process (Brownian motion).
"""

import torch
import torch.nn as nn  # used by the model classes later in this section
import numpy as np

def discrete_to_continuous_params(alpha_schedule: torch.Tensor, num_steps: int) -> dict:
    """
    Convert a discrete schedule to continuous SDE parameters.

    For the VP-SDE: dx = -0.5 * beta(t) * x * dt + sqrt(beta(t)) * dW

    Returns the beta(t) values on a uniform time grid.
    """
    # Cumulative product of alphas
    alpha_bar = torch.cumprod(alpha_schedule, dim=0)

    # Continuous time t in [0, 1]
    t = torch.linspace(0, 1, num_steps)

    # log(alpha_bar(t)) = -0.5 * integral of beta(s) ds from 0 to t,
    # so beta(t) = -2 * d/dt[log(alpha_bar(t))]
    log_alpha_bar = torch.log(alpha_bar)

    # Numerical differentiation
    beta = -2 * torch.gradient(log_alpha_bar, spacing=(t,))[0]

    return {
        'beta': beta,
        't': t,
        'alpha_bar': alpha_bar
    }

The continuous formulation reveals that discrete diffusion models are numerical approximations to underlying SDEs. Different discretization choices lead to different algorithms with different trade-offs.
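The relation $\beta(t) = -2\,\frac{d}{dt}\log\bar\alpha(t)$ used above can be sanity-checked numerically. A minimal sketch, assuming an illustrative linear DDPM schedule (the endpoint values 1e-4 and 0.02 are conventional but arbitrary here): differentiate $\log\bar\alpha$, then re-integrate the result and confirm the original curve comes back.

```python
import torch

# Illustrative linear DDPM schedule (endpoint values are assumptions)
T = 1000
betas_disc = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - betas_disc, dim=0)

# Continuous time grid and beta(t) = -2 * d/dt log(alpha_bar(t))
t = torch.linspace(0, 1, T)
log_ab = torch.log(alpha_bar)
beta_cont = -2 * torch.gradient(log_ab, spacing=(t,))[0]

# Re-integrating -0.5 * beta(t) should recover log(alpha_bar)
# up to small discretization error at the boundaries
recovered = -0.5 * torch.cumulative_trapezoid(beta_cont, t)
max_err = (recovered - (log_ab[1:] - log_ab[0])).abs().max()
```

If the round trip works, `max_err` stays small, which is exactly the consistency the conversion function relies on.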


SDE Formulation

Stochastic differential equations describe continuous-time random processes. The general form of an SDE is:

General Ito SDE: $dx = f(x, t)\, dt + g(t)\, dW$

where $f(x, t)$ is the drift coefficient, $g(t)$ is the diffusion coefficient, and $dW$ is the infinitesimal Wiener process increment.

The Forward SDE

The forward diffusion process that gradually adds noise can be written as:

🐍python
class ForwardSDE:
    """
    Forward SDE: dx = f(x, t)dt + g(t)dW

    Transforms the data distribution p_data(x) into the prior p_T(x).
    """
    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0, T: float = 1.0):
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.T = T

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        """Time-dependent noise schedule beta(t)."""
        # Linear schedule in continuous time
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def drift_coef(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Drift coefficient f(x, t) for the VP-SDE.
        f(x, t) = -0.5 * beta(t) * x
        """
        return -0.5 * self.beta(t).unsqueeze(-1) * x

    def diffusion_coef(self, t: torch.Tensor) -> torch.Tensor:
        """
        Diffusion coefficient g(t) for the VP-SDE.
        g(t) = sqrt(beta(t))
        """
        return torch.sqrt(self.beta(t))

    def marginal_prob(self, x_0: torch.Tensor, t: torch.Tensor) -> tuple:
        """
        Compute the mean and std of p(x_t | x_0).

        For the VP-SDE:
            mean = x_0 * exp(-0.5 * integral(beta(s), 0, t))
            std = sqrt(1 - exp(-integral(beta(s), 0, t)))
        """
        # log of the mean coefficient: -0.5 * integral of beta from 0 to t
        log_coef = -0.25 * t ** 2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min
        mean_coef = torch.exp(log_coef)
        std = torch.sqrt(1.0 - torch.exp(2.0 * log_coef))

        mean = x_0 * mean_coef.unsqueeze(-1)
        return mean, std

    def sample_x_t(self, x_0: torch.Tensor, t: torch.Tensor) -> tuple:
        """Sample from p(x_t | x_0)."""
        mean, std = self.marginal_prob(x_0, t)
        z = torch.randn_like(x_0)
        x_t = mean + std.unsqueeze(-1) * z
        return x_t, z

The forward SDE has the remarkable property that its marginal distributions can be computed in closed form. Given an initial point $x_0$, we can directly sample $x_t$ at any time $t$ without simulating the entire trajectory.
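This closed form can be verified directly. The sketch below (scalar paths, with the illustrative beta_min/beta_max values used throughout this section) simulates the forward VP-SDE with Euler-Maruyama and compares the empirical moments of $x_t$ against the analytic marginal:

```python
import torch

torch.manual_seed(0)
beta_min, beta_max = 0.1, 20.0
beta = lambda t: beta_min + t * (beta_max - beta_min)

def marginal(x0, t):
    # Closed-form p(x_t | x_0) for the VP-SDE with a linear beta schedule
    log_coef = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min
    mean_coef = torch.exp(torch.tensor(log_coef))
    std = torch.sqrt(1.0 - torch.exp(torch.tensor(2.0 * log_coef)))
    return x0 * mean_coef, std

# Euler-Maruyama simulation of many scalar forward paths from x_0 = 2
N, steps, t_end = 200_000, 500, 0.5
dt = t_end / steps
x = torch.full((N,), 2.0)
for i in range(steps):
    t = i * dt
    x = x - 0.5 * beta(t) * x * dt + (beta(t) * dt) ** 0.5 * torch.randn(N)

mean_cf, std_cf = marginal(2.0, t_end)
# Empirical mean/std of x should match mean_cf/std_cf to within ~1-2%
```

The agreement is what makes single-step training possible: no trajectory simulation is needed to construct noisy training pairs.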

The Reverse SDE

The key theoretical result enabling diffusion models is Anderson's theorem: every forward SDE has a corresponding reverse SDE that runs backward in time:

Reverse SDE (Anderson, 1982): $dx = \left[f(x, t) - g(t)^2 \nabla_x \log p_t(x)\right] dt + g(t)\, d\bar{W}$

where $d\bar{W}$ is the reverse-time Wiener process increment. The score $\nabla_x \log p_t(x)$ is the only unknown quantity.

🐍python
class ReverseSDE:
    """
    Reverse SDE for generation.

    Given a trained score model s_theta(x, t) ≈ nabla_x log p_t(x),
    we can reverse the forward process to generate samples.
    """
    def __init__(self, forward_sde: ForwardSDE, score_model: nn.Module):
        self.forward_sde = forward_sde
        self.score_model = score_model

    def drift_coef(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Reverse drift: f(x, t) - g(t)^2 * score(x, t)

        The score term reverses the diffusion.
        """
        # Forward drift
        f = self.forward_sde.drift_coef(x, t)

        # Diffusion coefficient squared (g^2 = beta for the VP-SDE)
        g_sq = self.forward_sde.beta(t).unsqueeze(-1)

        # Learned score
        score = self.score_model(x, t)

        # Reverse drift
        return f - g_sq * score

    def diffusion_coef(self, t: torch.Tensor) -> torch.Tensor:
        """Same diffusion coefficient as the forward SDE."""
        return self.forward_sde.diffusion_coef(t)

    def sample(
        self,
        num_samples: int,
        data_shape: tuple,
        num_steps: int = 1000,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """
        Generate samples by solving the reverse SDE.
        Uses Euler-Maruyama discretization.
        """
        # Start from the prior distribution (Gaussian noise)
        x = torch.randn(num_samples, *data_shape, device=device)

        # Time runs from T down to 0
        dt = self.forward_sde.T / num_steps
        timesteps = torch.linspace(self.forward_sde.T, 0, num_steps + 1, device=device)

        for i in range(num_steps):
            t = timesteps[i]
            t_batch = t.expand(num_samples)

            # Reverse SDE coefficients
            drift = self.drift_coef(x, t_batch)
            diffusion = self.diffusion_coef(t_batch)

            # Euler-Maruyama: x_{t-dt} = x_t - drift * dt + diffusion * sqrt(dt) * z
            z = torch.randn_like(x) if i < num_steps - 1 else torch.zeros_like(x)
            x = x - drift * dt + diffusion.unsqueeze(-1) * np.sqrt(dt) * z

        return x

Why the Reverse SDE Works:
  • The forward SDE adds noise: the drift shrinks data toward the origin while the diffusion term injects randomness
  • The score function points toward high-density regions of p_t(x)
  • The g(t)^2 * score term counteracts the noise addition, effectively reversing it
  • The reverse process generates samples by gradually removing noise
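These points can be checked end to end in one dimension, where the score is available in closed form. A sketch with illustrative values (data distribution N(1, 0.5^2) under the VP-SDE): running the reverse SDE with the exact score should recover the data moments.

```python
import torch

torch.manual_seed(0)
beta_min, beta_max = 0.1, 20.0
beta = lambda t: beta_min + t * (beta_max - beta_min)

def mean_coef(t):
    # a(t) = exp(-0.5 * integral of beta), linear schedule
    return torch.exp(torch.tensor(-0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min))

mu0, s0 = 1.0, 0.5  # data distribution N(mu0, s0^2); its score is analytic

def score(x, t):
    # p_t = N(mu0 * a(t), s0^2 * a(t)^2 + 1 - a(t)^2) under the VP-SDE
    a = mean_coef(t)
    var = s0**2 * a**2 + 1.0 - a**2
    return -(x - mu0 * a) / var

# Euler-Maruyama on the reverse SDE, starting from the N(0, 1) prior at t = 1
N, steps = 100_000, 1000
dt = 1.0 / steps
x = torch.randn(N)
for i in range(steps):
    t = 1.0 - i * dt
    g2 = beta(t)                          # g(t)^2 = beta(t) for the VP-SDE
    drift = -0.5 * g2 * x - g2 * score(x, t)
    z = torch.randn(N) if i < steps - 1 else torch.zeros(N)
    x = x - drift * dt + (g2 * dt) ** 0.5 * z

# x should now be approximately N(mu0, s0^2)
```

With the exact score, the only errors left are discretization and Monte Carlo noise, which is why the recovered moments land close to (1, 0.5).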

Types of Diffusion SDEs

Different choices of drift and diffusion coefficients lead to different diffusion model families, each with distinct properties.

Variance Preserving SDE (VP-SDE)

The VP-SDE, which underlies DDPM, preserves the total variance of the data during diffusion:

| Property | VP-SDE formula |
| --- | --- |
| Forward SDE | dx = -0.5 * beta(t) * x * dt + sqrt(beta(t)) * dW |
| Marginal mean | E[x_t \| x_0] = x_0 * exp(-0.5 * integral(beta)) |
| Marginal variance | Var[x_t \| x_0] = 1 - exp(-integral(beta)) |
| Prior | N(0, I) as t -> infinity |
| Key property | Var[x_t] stays bounded (approximately 1) |

🐍python
class VPSDE:
    """
    Variance Preserving SDE (DDPM continuous-time limit).

    dx = -0.5 * beta(t) * x * dt + sqrt(beta(t)) * dW

    Properties:
    - Data variance is preserved (approximately 1 throughout)
    - Converges to the N(0, I) prior
    - Corresponds to DDPM with a linear noise schedule
    """
    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        """Linear noise schedule."""
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def sde(self, x: torch.Tensor, t: torch.Tensor):
        """Return drift and diffusion coefficients."""
        beta_t = self.beta(t).unsqueeze(-1)
        drift = -0.5 * beta_t * x
        diffusion = torch.sqrt(beta_t)
        return drift, diffusion

    def marginal_prob(self, x_0: torch.Tensor, t: torch.Tensor):
        """Closed-form marginal p(x_t | x_0)."""
        log_mean_coef = -0.25 * t ** 2 * (self.beta_max - self.beta_min) - 0.5 * t * self.beta_min
        mean = x_0 * torch.exp(log_mean_coef).unsqueeze(-1)
        std = torch.sqrt(1 - torch.exp(2 * log_mean_coef))
        return mean, std

    def prior_sampling(self, shape: tuple, device: str = 'cuda') -> torch.Tensor:
        """Sample from the prior N(0, I)."""
        return torch.randn(*shape, device=device)

Variance Exploding SDE (VE-SDE)

The VE-SDE, which underlies NCSN (Noise Conditional Score Networks), allows variance to grow unboundedly:

🐍python
class VESDE:
    """
    Variance Exploding SDE (NCSN continuous-time limit).

    dx = sqrt(d[sigma^2(t)]/dt) * dW

    Properties:
    - No drift term (pure noise addition)
    - Variance grows without bound
    - Prior is N(0, sigma_max^2 * I)
    - Corresponds to SMLD/NCSN noise levels
    """
    def __init__(self, sigma_min: float = 0.01, sigma_max: float = 50.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def sigma(self, t: torch.Tensor) -> torch.Tensor:
        """Geometric schedule for the noise level."""
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** t

    def sde(self, x: torch.Tensor, t: torch.Tensor):
        """Return drift (zero) and diffusion coefficients."""
        sigma_t = self.sigma(t)
        # d[sigma^2]/dt = 2 * sigma * d[sigma]/dt
        # sigma(t) = sigma_min * r^t, where r = sigma_max / sigma_min
        # d[sigma]/dt = sigma(t) * log(r)
        drift = torch.zeros_like(x)
        diffusion = sigma_t * np.sqrt(2 * np.log(self.sigma_max / self.sigma_min))
        return drift, diffusion.unsqueeze(-1).expand_as(x)

    def marginal_prob(self, x_0: torch.Tensor, t: torch.Tensor):
        """Closed-form marginal p(x_t | x_0)."""
        # Mean is just x_0 (no drift); std is sigma(t)
        mean = x_0
        std = self.sigma(t)
        return mean, std

    def prior_sampling(self, shape: tuple, device: str = 'cuda') -> torch.Tensor:
        """Sample from the prior N(0, sigma_max^2 * I)."""
        return torch.randn(*shape, device=device) * self.sigma_max

Sub-VP SDE

The sub-VP SDE is a variant that maintains even tighter variance control:

🐍python
class SubVPSDE:
    """
    Sub-VP SDE: a variant with reduced variance growth.

    dx = -0.5 * beta(t) * x * dt + sqrt(beta(t) * (1 - exp(-2 * integral(beta)))) * dW

    Properties:
    - Variance is always strictly less than 1
    - Tighter control than the VP-SDE
    - Can improve likelihood bounds
    """
    def __init__(self, beta_min: float = 0.1, beta_max: float = 20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def beta(self, t: torch.Tensor) -> torch.Tensor:
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def sde(self, x: torch.Tensor, t: torch.Tensor):
        """Return drift and diffusion coefficients."""
        beta_t = self.beta(t)

        # Integral of beta from 0 to t (linear schedule)
        integral_beta = 0.5 * t ** 2 * (self.beta_max - self.beta_min) + t * self.beta_min

        drift = -0.5 * beta_t.unsqueeze(-1) * x

        # Sub-VP diffusion: sqrt(beta * (1 - exp(-2 * integral)))
        discount = 1.0 - torch.exp(-2 * integral_beta)
        diffusion = torch.sqrt(beta_t * discount)

        return drift, diffusion.unsqueeze(-1).expand_as(x)

    def marginal_prob(self, x_0: torch.Tensor, t: torch.Tensor):
        """Closed-form marginal."""
        integral_beta = 0.5 * t ** 2 * (self.beta_max - self.beta_min) + t * self.beta_min
        mean_coef = torch.exp(-0.5 * integral_beta)
        mean = x_0 * mean_coef.unsqueeze(-1)

        # std = 1 - exp(-integral), so the variance (1 - exp(-integral))^2
        # is strictly below the VP-SDE variance 1 - mean_coef^2
        std = 1.0 - torch.exp(-integral_beta)
        return mean, std

| SDE type | Drift f(x, t) | Diffusion g(t) | Prior |
| --- | --- | --- | --- |
| VP-SDE | -0.5 * beta(t) * x | sqrt(beta(t)) | N(0, I) |
| VE-SDE | 0 | sqrt(d[sigma^2]/dt) | N(0, sigma_max^2 * I) |
| Sub-VP | -0.5 * beta(t) * x | sqrt(beta(t) * (1 - exp(-2 * integral(beta)))) | N(0, I) |

SDE Trajectory Visualization

The following visualization shows how different SDE formulations create different diffusion trajectories. Observe how VP-SDE, VE-SDE, and sub-VP SDE each transform data to noise (forward) and recover data from noise (reverse) in distinct ways.
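As a rough numerical companion to the visualization, the closed-form marginal standard deviations (for unit-variance data, using the illustrative schedule parameters from this section) already capture the qualitative difference between the three SDEs:

```python
import torch

# A few probe times in [0, 1]
t = torch.tensor([0.0, 0.25, 0.5, 0.75, 1.0])

# VP / sub-VP: linear beta(t) schedule
bmin, bmax = 0.1, 20.0
log_c = -0.25 * t**2 * (bmax - bmin) - 0.5 * t * bmin  # log of the mean coefficient
vp_std = torch.sqrt(1.0 - torch.exp(2 * log_c))         # VP: bounded by 1
subvp_std = 1.0 - torch.exp(2 * log_c)                  # sub-VP: strictly below VP

# VE: geometric sigma(t) schedule
smin, smax = 0.01, 50.0
ve_std = smin * (smax / smin) ** t                      # VE: grows to sigma_max
```

Printing these three tensors side by side shows the VP and sub-VP noise levels saturating near 1 while the VE noise level explodes toward sigma_max.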


Probability Flow ODE

A remarkable property of the SDE framework is that every forward SDE has an associated probability flow ODE that produces the same marginal distributions but without stochastic noise:

Probability Flow ODE: $dx = \left[f(x, t) - \frac{1}{2} g(t)^2 \nabla_x \log p_t(x)\right] dt$

This ODE is deterministic and shares marginal distributions with the SDE at all times.
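This equivalence of marginals can be checked in the same one-dimensional Gaussian setting as before (analytic score, illustrative values): integrating the probability flow ODE backward deterministically transports prior samples to the data distribution, and rerunning it from the same noise gives bit-identical results.

```python
import torch

torch.manual_seed(0)
beta_min, beta_max = 0.1, 20.0
beta = lambda t: beta_min + t * (beta_max - beta_min)
mu0, s0 = 1.0, 0.5  # data distribution N(mu0, s0^2)

def score(x, t):
    # Analytic score of p_t = N(mu0 * a(t), s0^2 * a(t)^2 + 1 - a(t)^2)
    a = torch.exp(torch.tensor(-0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min))
    var = s0**2 * a**2 + 1.0 - a**2
    return -(x - mu0 * a) / var

def ode_drift(x, t):
    # f(x,t) - 0.5 * g(t)^2 * score(x,t), with f = -0.5 * beta * x and g^2 = beta
    return -0.5 * beta(t) * x - 0.5 * beta(t) * score(x, t)

# Deterministic Euler integration from t = 1 down to t = 0
N, steps = 100_000, 1000
dt = 1.0 / steps
x = torch.randn(N)
x_T = x.clone()  # keep the initial noise for a second run
for i in range(steps):
    t = 1.0 - i * dt
    x = x - ode_drift(x, t) * dt

# Rerun from the same initial noise: the flow has no randomness
y = x_T
for i in range(steps):
    t = 1.0 - i * dt
    y = y - ode_drift(y, t) * dt
```

The two runs agree exactly, and the transported samples match the data moments (mean ~1, std ~0.5), illustrating both determinism and marginal equivalence.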

Deterministic Sampling

The probability flow ODE enables deterministic sampling: given the same initial noise, you always get the same output. This is crucial for reproducibility and editing.

🐍python
class ProbabilityFlowODE:
    """
    Deterministic ODE with the same marginals as the diffusion SDE.

    dx = [f(x,t) - 0.5 * g(t)^2 * score(x,t)] dt

    No stochasticity - a purely deterministic flow.
    """
    def __init__(self, sde, score_model: nn.Module):
        self.sde = sde
        self.score_model = score_model

    def drift(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """ODE drift coefficient."""
        # SDE coefficients
        sde_drift, sde_diffusion = self.sde.sde(x, t)

        # Learned score
        score = self.score_model(x, t)

        # Probability flow drift: f(x,t) - 0.5 * g(t)^2 * score(x,t)
        g_sq = sde_diffusion ** 2
        return sde_drift - 0.5 * g_sq * score

    def sample_ode(
        self,
        x_T: torch.Tensor,
        num_steps: int = 1000,
        method: str = 'euler'
    ) -> torch.Tensor:
        """
        Solve the ODE from t=1 to t=0 using the specified method.

        Args:
            x_T: Initial condition (samples from the prior)
            num_steps: Number of ODE steps
            method: 'euler' or 'rk4' (Runge-Kutta)
        """
        dt = -1.0 / num_steps  # Negative: integrating backward in time
        x = x_T.clone()

        for step in range(num_steps):
            t = 1.0 - step / num_steps
            t_batch = torch.full((x.shape[0],), t, device=x.device)

            if method == 'euler':
                x = x + self.drift(x, t_batch) * dt
            elif method == 'rk4':
                x = self._rk4_step(x, t_batch, dt)

        return x

    def _rk4_step(self, x: torch.Tensor, t: torch.Tensor, dt: float) -> torch.Tensor:
        """Fourth-order Runge-Kutta step."""
        k1 = self.drift(x, t)
        k2 = self.drift(x + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = self.drift(x + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = self.drift(x + dt * k3, t + dt)
        return x + (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)


def deterministic_encode_decode(
    ode: ProbabilityFlowODE,
    x: torch.Tensor,
    num_steps: int = 1000
) -> torch.Tensor:
    """
    Encode data to the latent space and decode back.
    Reconstructs the original data up to discretization error.
    """
    device = x.device
    batch = x.shape[0]

    # Forward: data -> noise (t from 0 to 1)
    dt = 1.0 / num_steps
    z = x.clone()
    for step in range(num_steps):
        t = torch.full((batch,), step / num_steps, device=device)
        z = z + ode.drift(z, t) * dt

    # Backward: noise -> data (t from 1 to 0)
    x_recon = z.clone()
    for step in range(num_steps):
        t = torch.full((batch,), 1.0 - step / num_steps, device=device)
        x_recon = x_recon - ode.drift(x_recon, t) * dt

    return x_recon

Exact Likelihood Computation

The ODE formulation enables exact log-likelihood computation via the change of variables formula (instantaneous change of variables):

🐍python
def compute_log_likelihood(
    ode: ProbabilityFlowODE,
    x: torch.Tensor,
    num_steps: int = 1000,
    hutchinson_samples: int = 10
) -> torch.Tensor:
    """
    Compute log p(x) using the probability flow ODE.

    Uses the instantaneous change of variables formula:
    log p_0(x_0) = log p_T(x_T) + integral_0^T trace(df/dx) dt

    The trace is estimated with Hutchinson's trace estimator.
    """
    batch_size = x.shape[0]
    device = x.device

    # Initialize
    z = x.clone()
    delta_log_p = torch.zeros(batch_size, device=device)

    dt = 1.0 / num_steps

    for step in range(num_steps):
        t = step / num_steps
        t_tensor = torch.full((batch_size,), t, device=device)

        # Estimate the trace of the drift Jacobian with Hutchinson's estimator.
        # Detach first so requires_grad_ always acts on a leaf tensor.
        z = z.detach().requires_grad_(True)
        drift = ode.drift(z, t_tensor)

        trace_estimate = torch.zeros(batch_size, device=device)
        for _ in range(hutchinson_samples):
            # Random probe vector for the Hutchinson estimator
            v = torch.randn_like(z)

            # v^T * Jacobian via a vector-Jacobian product, then dot with v
            vjp = torch.autograd.grad(
                drift, z,
                grad_outputs=v,
                create_graph=False,
                retain_graph=True
            )[0]
            trace_estimate += (vjp * v).sum(dim=-1)

        trace_estimate /= hutchinson_samples

        # Euler update of the ODE state and the log-density integral
        z = (z + drift * dt).detach()
        delta_log_p += trace_estimate.detach() * dt

    # Log probability under the standard Gaussian prior
    log_p_prior = -0.5 * (z ** 2).sum(dim=-1) - 0.5 * z.shape[-1] * np.log(2 * np.pi)

    # Total log probability (plus sign, matching the formula above)
    log_p_x = log_p_prior + delta_log_p

    return log_p_x

This exact likelihood computation is a key advantage of the SDE/ODE framework over discrete diffusion models, which can only provide variational bounds.
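In one dimension the trace reduces to a scalar derivative, so the change-of-variables formula can be verified against a known density. A sketch in the same analytic-score Gaussian setup (illustrative values; the exact terminal density is used in place of the N(0, I) prior to isolate discretization error):

```python
import math

beta_min, beta_max = 0.1, 20.0
mu0, s0 = 1.0, 0.5  # data distribution N(mu0, s0^2)

def beta(t):
    return beta_min + t * (beta_max - beta_min)

def a(t):
    # Mean coefficient exp(-0.5 * integral of beta from 0 to t)
    return math.exp(-0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min)

def var(t):
    # Marginal variance of p_t under the VP-SDE
    return s0**2 * a(t)**2 + 1.0 - a(t)**2

def ode_drift(x, t):
    # f(x,t) - 0.5 * g(t)^2 * score(x,t) with the analytic Gaussian score
    return -0.5 * beta(t) * (x - (x - mu0 * a(t)) / var(t))

def ddrift_dx(t):
    # In 1-D, the "trace" of the Jacobian is just d(drift)/dx
    return -0.5 * beta(t) * (1.0 - 1.0 / var(t))

def log_p0(x0, steps=2000):
    # Integrate the state and the change-of-variables term from t=0 to t=1
    dt = 1.0 / steps
    x, delta = x0, 0.0
    for i in range(steps):
        t = i * dt
        delta += ddrift_dx(t) * dt
        x = x + ode_drift(x, t) * dt
    # Exact terminal density p_1 = N(mu0 * a(1), var(1))
    m1, v1 = mu0 * a(1.0), var(1.0)
    log_p1 = -0.5 * math.log(2 * math.pi * v1) - (x - m1) ** 2 / (2 * v1)
    return log_p1 + delta

def ref(x0):
    # Analytic log N(x0; mu0, s0^2) for comparison
    return -0.5 * math.log(2 * math.pi * s0**2) - (x0 - mu0) ** 2 / (2 * s0**2)
```

Evaluating `log_p0` at a few test points matches `ref` closely, confirming both the formula and the sign of the trace term.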


Numerical SDE/ODE Solvers

Solving SDEs and ODEs numerically requires careful choice of discretization methods. The choice affects both sample quality and computational efficiency.
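As a minimal illustration of why solver choice matters, the sketch below compares fixed-step Euler against classical RK4 on a toy ODE with a known solution (dx/dt = -x; the step count of 100 is illustrative):

```python
import math

def euler(f, x0, t0, t1, n):
    # Fixed-step forward Euler
    dt = (t1 - t0) / n
    x, t = x0, t0
    for _ in range(n):
        x += f(x, t) * dt
        t += dt
    return x

def rk4(f, x0, t0, t1, n):
    # Classical fourth-order Runge-Kutta
    dt = (t1 - t0) / n
    x, t = x0, t0
    for _ in range(n):
        k1 = f(x, t)
        k2 = f(x + 0.5 * dt * k1, t + 0.5 * dt)
        k3 = f(x + 0.5 * dt * k2, t + 0.5 * dt)
        k4 = f(x + dt * k3, t + dt)
        x += (dt / 6) * (k1 + 2 * k2 + 2 * k3 + k4)
        t += dt
    return x

# Toy ODE dx/dt = -x with exact solution x(1) = e^{-1}
f = lambda x, t: -x
exact = math.exp(-1.0)
err_euler = abs(euler(f, 1.0, 0.0, 1.0, 100) - exact)
err_rk4 = abs(rk4(f, 1.0, 0.0, 1.0, 100) - exact)
```

At the same step count, RK4 is many orders of magnitude more accurate, which is why higher-order solvers let probability flow ODE sampling use far fewer score-model evaluations.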

Euler-Maruyama Method

The simplest SDE solver, analogous to Euler's method for ODEs:

🐍python
def euler_maruyama_step(
    x: torch.Tensor,
    t: float,
    dt: float,
    sde,
    score_model: nn.Module
) -> torch.Tensor:
    """
    Single Euler-Maruyama step for the reverse SDE.

    x_{t-dt} = x_t - [f(x,t) - g(t)^2 * score(x,t)] * dt + g(t) * sqrt(dt) * z
    """
    batch_size = x.shape[0]
    device = x.device

    t_batch = torch.full((batch_size,), t, device=device)

    # SDE coefficients
    drift, diffusion = sde.sde(x, t_batch)

    # Learned score
    score = score_model(x, t_batch)

    # Reverse drift
    reverse_drift = drift - diffusion ** 2 * score

    # Stochastic noise (skipped on the final step)
    z = torch.randn_like(x) if t > dt else torch.zeros_like(x)

    # Euler-Maruyama update (going backward in time)
    x_new = x - reverse_drift * dt + diffusion * np.sqrt(dt) * z

    return x_new


def euler_maruyama_sampler(
    sde,
    score_model: nn.Module,
    shape: tuple,
    num_steps: int = 1000,
    device: str = 'cuda'
) -> torch.Tensor:
    """Complete Euler-Maruyama sampling loop."""
    # Initialize from the prior
    x = sde.prior_sampling(shape, device)

    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = 1.0 - step * dt
        x = euler_maruyama_step(x, t, dt, sde, score_model)

    return x

Predictor-Corrector Samplers

Predictor-corrector methods combine an ODE/SDE solver (predictor) with Langevin MCMC steps (corrector) for improved sample quality:

🐍python
class PredictorCorrectorSampler:
    """
    Predictor-corrector sampling for diffusion SDEs.

    Alternates between:
    1. Predictor: a reverse SDE step that moves toward the data distribution
    2. Corrector: Langevin MCMC steps that refine the samples

    This improves sample quality by correcting numerical errors.
    """
    def __init__(
        self,
        sde,
        score_model: nn.Module,
        predictor: str = 'euler',
        corrector: str = 'langevin',
        snr: float = 0.16,  # Signal-to-noise ratio for the corrector
        n_steps_corrector: int = 1
    ):
        self.sde = sde
        self.score_model = score_model
        self.predictor = predictor
        self.corrector = corrector
        self.snr = snr
        self.n_steps_corrector = n_steps_corrector

    def predictor_step(
        self,
        x: torch.Tensor,
        t: float,
        dt: float
    ) -> torch.Tensor:
        """Predictor step (reverse SDE)."""
        batch_size = x.shape[0]
        device = x.device
        t_batch = torch.full((batch_size,), t, device=device)

        drift, diffusion = self.sde.sde(x, t_batch)
        score = self.score_model(x, t_batch)

        reverse_drift = drift - diffusion ** 2 * score
        z = torch.randn_like(x) if t > dt else torch.zeros_like(x)

        return x - reverse_drift * dt + diffusion * np.sqrt(dt) * z

    def corrector_step(
        self,
        x: torch.Tensor,
        t: float
    ) -> torch.Tensor:
        """Langevin corrector step."""
        batch_size = x.shape[0]
        device = x.device
        t_batch = torch.full((batch_size,), t, device=device)

        for _ in range(self.n_steps_corrector):
            # Learned score
            score = self.score_model(x, t_batch)

            # Langevin step size based on the score magnitude
            grad_norm = torch.norm(score.view(batch_size, -1), dim=-1).mean()
            step_size = (self.snr / grad_norm) ** 2 * 2

            # Langevin update (torch.sqrt, since step_size is a tensor)
            z = torch.randn_like(x)
            x = x + step_size * score + torch.sqrt(2 * step_size) * z

        return x

    def sample(
        self,
        shape: tuple,
        num_steps: int = 1000,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """Generate samples using the predictor-corrector method."""
        # Initialize from the prior
        x = self.sde.prior_sampling(shape, device)

        dt = 1.0 / num_steps
        for step in range(num_steps):
            t = 1.0 - step * dt

            # Corrector
            x = self.corrector_step(x, t)

            # Predictor
            x = self.predictor_step(x, t, dt)

        return x

Implementation

Here's a complete SDE-based diffusion model implementation:

🐍python
class SDEDiffusionModel:
    """
    Complete SDE-based diffusion model.

    Supports:
    - Multiple SDE types (VP, VE, sub-VP)
    - Multiple samplers (Euler-Maruyama, PC, ODE)
    - Exact likelihood computation
    """
    def __init__(
        self,
        sde_type: str = 'vpsde',
        beta_min: float = 0.1,
        beta_max: float = 20.0,
        sigma_min: float = 0.01,
        sigma_max: float = 50.0
    ):
        # Initialize the SDE
        if sde_type == 'vpsde':
            self.sde = VPSDE(beta_min, beta_max)
        elif sde_type == 'vesde':
            self.sde = VESDE(sigma_min, sigma_max)
        elif sde_type == 'subvpsde':
            self.sde = SubVPSDE(beta_min, beta_max)
        else:
            raise ValueError(f"Unknown SDE type: {sde_type}")

        self.score_model = None

    def set_score_model(self, model: nn.Module):
        """Set the trained score model."""
        self.score_model = model

    def train_step(
        self,
        x: torch.Tensor,
        optimizer: torch.optim.Optimizer
    ) -> float:
        """Single training step with the DSM objective."""
        batch_size = x.shape[0]
        device = x.device

        # Sample random times uniformly
        t = torch.rand(batch_size, device=device)

        # Noisy samples from the closed-form marginal
        mean, std = self.sde.marginal_prob(x, t)
        z = torch.randn_like(x)
        x_t = mean + std.unsqueeze(-1) * z

        # Score target: -z / std
        score_target = -z / std.unsqueeze(-1)

        # Predicted score
        score_pred = self.score_model(x_t, t)

        # DSM loss (weighted by std^2)
        loss = 0.5 * ((score_pred - score_target) ** 2 * std.unsqueeze(-1) ** 2)
        loss = loss.sum(dim=-1).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return loss.item()

    @torch.no_grad()
    def sample(
        self,
        shape: tuple,
        num_steps: int = 1000,
        sampler: str = 'euler',
        device: str = 'cuda'
    ) -> torch.Tensor:
        """Generate samples."""
        self.score_model.eval()

        if sampler == 'euler':
            return euler_maruyama_sampler(
                self.sde, self.score_model, shape, num_steps, device
            )
        elif sampler == 'pc':
            pc_sampler = PredictorCorrectorSampler(
                self.sde, self.score_model
            )
            return pc_sampler.sample(shape, num_steps, device)
        elif sampler == 'ode':
            ode = ProbabilityFlowODE(self.sde, self.score_model)
            x_T = self.sde.prior_sampling(shape, device)
            return ode.sample_ode(x_T, num_steps)
        else:
            raise ValueError(f"Unknown sampler: {sampler}")

    def log_likelihood(
        self,
        x: torch.Tensor,
        num_steps: int = 1000
    ) -> torch.Tensor:
        """
        Compute the exact log-likelihood.

        Note: not wrapped in torch.no_grad(), because the Hutchinson
        trace estimator inside compute_log_likelihood needs autograd.
        """
        self.score_model.eval()
        ode = ProbabilityFlowODE(self.sde, self.score_model)
        return compute_log_likelihood(ode, x, num_steps)

Summary

The SDE framework provides a powerful continuous-time perspective on diffusion models with several key advantages:

  • Unified framework: VP-SDE, VE-SDE, and sub-VP SDE are all special cases of the general SDE formulation, enabling principled comparison and combination.
  • Reverse SDE: Anderson's theorem guarantees that any forward diffusion can be reversed given only the score function.
  • Probability flow ODE: The deterministic ODE counterpart enables exact likelihood computation and reproducible generation.
  • Flexible samplers: Euler-Maruyama, predictor-corrector, and ODE solvers provide different quality-speed trade-offs.
  • Theoretical foundation: Continuous-time analysis provides rigorous understanding of convergence, training dynamics, and sample quality.

In the next section, we'll synthesize the score-based and SDE perspectives into a unified framework that connects all major diffusion model variants and reveals their fundamental relationships.