Chapter 15
Section 69 of 76

Unified Framework

Score-Based and SDE Perspective

Introduction

Throughout this book, we've encountered multiple perspectives on diffusion models: DDPM's discrete Markov chains, score-based models with denoising score matching, and SDE/ODE formulations. This section synthesizes these viewpoints into a unified framework that reveals their fundamental equivalences and guides practical choices.

Understanding this unified view enables informed decisions about model design, training objectives, and sampling algorithms. We'll see that seemingly different approaches are often mathematically equivalent, differing only in parameterization or emphasis.


The Big Picture

All diffusion-based generative models share a common structure: they learn to reverse a corruption process that transforms data into noise. The differences lie in how this process is parameterized, discretized, and trained.

Equivalence Map

The following table maps concepts across different formulations:

| Concept | DDPM | Score-Based | SDE |
|---|---|---|---|
| Forward process | q(x_t \| x_{t-1}) | Adding noise levels | Forward SDE: dx = f dt + g dW |
| Reverse process | p(x_{t-1} \| x_t) | Annealed Langevin | Reverse SDE |
| What model learns | epsilon(x_t, t) | score s(x, sigma) | score nabla log p_t |
| Training objective | MSE on noise | DSM | DSM / Score matching |
| Noise schedule | beta_t sequence | sigma levels | beta(t) or sigma(t) |
| Prior | N(0, I) | N(0, sigma_max^2 I) | Depends on SDE type (VP: N(0, I); VE: N(0, sigma_max^2 I)) |
Central Insight: All formulations learn the same fundamental quantity: the score function $\nabla_x \log p_t(x)$. They differ only in how this score is parameterized, scaled, and used for sampling.
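
Here is a quick numerical check of this claim in a toy setting, assumed purely for illustration: 1-D standard-normal data. With $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$, the marginal $p_t$ is again $\mathcal{N}(0, 1)$, so its score is exactly $-x$. The best possible noise predictor is $\mathbb{E}[\epsilon \mid x_t]$. Rescaling it by $-1/\sqrt{1-\bar\alpha_t}$ should recover that score. The sketch estimates $\mathbb{E}[\epsilon \mid x_t]$ by least squares, which is exact here because all variables are jointly Gaussian.

```python
import torch

torch.manual_seed(0)

# Toy assumption: 1-D standard-normal data, so p_t = N(0, 1) for any alpha_bar
# and the true score is simply -x.
alpha_bar = torch.tensor(0.3)
n = 200_000
x_0 = torch.randn(n)
eps = torch.randn(n)
x_t = alpha_bar.sqrt() * x_0 + (1 - alpha_bar).sqrt() * eps

# The MSE-optimal epsilon predictor is E[eps | x_t]; for jointly Gaussian
# variables it is linear in x_t, so a least-squares slope recovers it.
slope = (x_t * eps).mean() / (x_t * x_t).mean()

# Convert that optimal epsilon predictor into a score and compare with -x.
x_grid = torch.linspace(-3, 3, 7)
score_from_eps = -(slope * x_grid) / (1 - alpha_bar).sqrt()
true_score = -x_grid

print("estimated slope:", slope.item(), "theory:", (1 - alpha_bar).sqrt().item())
print("max score error:", (score_from_eps - true_score).abs().max().item())
```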

Parameterization Choices

The neural network can predict different quantities. Given x_t and the noise schedule, each one can be converted into the others (and into the score) by a fixed affine transformation. Even so, the choice affects training dynamics and numerical stability.

Epsilon Prediction

The standard DDPM parameterization predicts the noise that was added:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Protocol


class NoiseSchedule(Protocol):
    """Interface assumed by the losses in this section."""
    T: int                   # number of discrete timesteps
    alpha_bar: torch.Tensor  # shape [T]: cumulative products of (1 - beta_t)


def epsilon_prediction_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: NoiseSchedule
) -> torch.Tensor:
    """
    Standard DDPM epsilon-prediction objective.

    Model predicts: epsilon_theta(x_t, t) ≈ epsilon

    Relationship to score:
        score = -epsilon / sqrt(1 - alpha_bar_t)
    """
    # Sample one timestep per example, on the same device as the data
    t = torch.randint(0, noise_schedule.T, (x_0.shape[0],), device=x_0.device)

    # Sample noise
    epsilon = torch.randn_like(x_0)

    # Noise the data in closed form; view() broadcasts over [B, C, H, W] images
    alpha_bar_t = noise_schedule.alpha_bar.to(x_0.device)[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    # Predict noise
    epsilon_pred = model(x_t, t)

    # Simple MSE loss
    loss = F.mse_loss(epsilon_pred, epsilon)

    return loss


def epsilon_to_score(epsilon: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Convert epsilon prediction to score."""
    return -epsilon / torch.sqrt(1 - alpha_bar_t)


def score_to_epsilon(score: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Convert score to epsilon prediction."""
    return -score * torch.sqrt(1 - alpha_bar_t)
```
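
The loss above needs only two things from noise_schedule: the number of steps T and the cumulative products alpha_bar. Below is a minimal, hypothetical stand-in built on DDPM's linear beta schedule (beta from 1e-4 to 0.02 over 1000 steps), followed by the same closed-form noising step the loss performs. LinearBetaSchedule is an illustrative name, not a library class.

```python
import torch


class LinearBetaSchedule:
    """Hypothetical minimal schedule: the linear betas of the original DDPM setup."""

    def __init__(self, T: int = 1000, beta_min: float = 1e-4, beta_max: float = 0.02):
        self.T = T
        self.betas = torch.linspace(beta_min, beta_max, T)
        # alpha_bar[t] = prod_{s <= t} (1 - beta_s), strictly decreasing in t
        self.alpha_bar = torch.cumprod(1.0 - self.betas, dim=0)


schedule = LinearBetaSchedule()
print(f"alpha_bar[0]   = {schedule.alpha_bar[0].item():.4f}")    # nearly clean
print(f"alpha_bar[T-1] = {schedule.alpha_bar[-1].item():.6f}")  # nearly pure noise

# Noising a batch exactly as epsilon_prediction_loss does
x_0 = torch.rand(4, 3, 8, 8) * 2 - 1  # fake images scaled to [-1, 1]
t = torch.randint(0, schedule.T, (x_0.shape[0],))
alpha_bar_t = schedule.alpha_bar[t].view(-1, 1, 1, 1)
x_t = alpha_bar_t.sqrt() * x_0 + (1 - alpha_bar_t).sqrt() * torch.randn_like(x_0)
print(x_t.shape)
```

An object like this can be passed directly as the noise_schedule argument of epsilon_prediction_loss.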

Score Prediction

Direct score prediction, as in NCSN, predicts the gradient of the log-density:

```python
import torch
import torch.nn as nn


def score_prediction_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    sigmas: torch.Tensor
) -> torch.Tensor:
    """
    Score-based objective (denoising score matching, DSM).

    Model predicts: s_theta(x_t, sigma) ≈ nabla_x log p_sigma(x_t)
    It is trained against the conditional score nabla_x log p(x_t | x_0),
    which for x_t = x_0 + sigma * epsilon equals -epsilon / sigma.
    """
    # Sample one noise level per example
    sigma_idx = torch.randint(0, len(sigmas), (x_0.shape[0],), device=x_0.device)
    sigma = sigmas.to(x_0.device)[sigma_idx].view(-1, 1, 1, 1)

    # Add noise (no signal scaling: this is the VE / NCSN setting)
    epsilon = torch.randn_like(x_0)
    x_t = x_0 + sigma * epsilon

    # Predict score (the network is conditioned on the noise-level index)
    score_pred = model(x_t, sigma_idx)

    # Conditional (target) score
    score_true = -epsilon / sigma

    # DSM loss with weight lambda(sigma) = sigma^2, which keeps every noise level on a comparable scale
    loss = 0.5 * ((score_pred - score_true) ** 2 * sigma ** 2).mean()

    return loss
```
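
The $\sigma^2$ weight is what links this objective to epsilon prediction. Since $\sigma^2 \lVert s + \epsilon/\sigma \rVert^2 = \lVert \sigma s + \epsilon \rVert^2$, the weighted DSM loss is an epsilon-MSE for the implied predictor $\hat\epsilon = -\sigma s$. Here $\sigma$ plays the role that $\sqrt{1-\bar\alpha_t}$ plays in DDPM. The check below uses random tensors as stand-ins for the network output.

```python
import torch

torch.manual_seed(0)
sigma = torch.tensor([0.01, 0.5, 5.0, 50.0]).view(-1, 1, 1, 1)
epsilon = torch.randn(4, 3, 8, 8)
score_pred = torch.randn(4, 3, 8, 8)  # stand-in for s_theta(x_t, sigma)

# Sigma^2-weighted DSM loss, as in score_prediction_loss
dsm = 0.5 * ((score_pred - (-epsilon / sigma)) ** 2 * sigma ** 2).mean()

# The same quantity written as an epsilon-MSE with eps_hat = -sigma * score
eps_hat = -sigma * score_pred
eps_mse = 0.5 * ((eps_hat - epsilon) ** 2).mean()

print(dsm.item(), eps_mse.item())
```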

V-Prediction

V-prediction, introduced in progressive distillation, predicts a velocity that balances signal and noise:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def v_prediction_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: "NoiseSchedule"  # exposes T and alpha_bar
) -> torch.Tensor:
    """
    V-prediction objective (Salimans & Ho, 2022).

    Model predicts: v_theta(x_t, t) ≈ v = sqrt(alpha_bar) * epsilon - sqrt(1 - alpha_bar) * x_0

    Benefits:
    - More stable gradients
    - Better for distillation
    - Symmetric treatment of signal and noise
    """
    t = torch.randint(0, noise_schedule.T, (x_0.shape[0],), device=x_0.device)
    epsilon = torch.randn_like(x_0)

    alpha_bar_t = noise_schedule.alpha_bar.to(x_0.device)[t].view(-1, 1, 1, 1)
    sqrt_alpha = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alpha = torch.sqrt(1 - alpha_bar_t)

    # Noisy sample
    x_t = sqrt_alpha * x_0 + sqrt_one_minus_alpha * epsilon

    # Target v
    v_target = sqrt_alpha * epsilon - sqrt_one_minus_alpha * x_0

    # Predict v
    v_pred = model(x_t, t)

    loss = F.mse_loss(v_pred, v_target)

    return loss


def v_to_epsilon(v: torch.Tensor, x_t: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Convert v prediction to epsilon: epsilon = sqrt(alpha_bar) * v + sqrt(1 - alpha_bar) * x_t."""
    sqrt_alpha = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alpha = torch.sqrt(1 - alpha_bar_t)
    return sqrt_alpha * v + sqrt_one_minus_alpha * x_t


def v_to_x0(v: torch.Tensor, x_t: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Convert v prediction to x0: x0 = sqrt(alpha_bar) * x_t - sqrt(1 - alpha_bar) * v."""
    sqrt_alpha = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alpha = torch.sqrt(1 - alpha_bar_t)
    return sqrt_alpha * x_t - sqrt_one_minus_alpha * v
```
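
One way to see what v-prediction changes is to compare its error with the epsilon error it implies. Since $\epsilon = \sqrt{\bar\alpha_t}\,v + \sqrt{1-\bar\alpha_t}\,x_t$, an error $\Delta$ in $v$ becomes an error $\sqrt{\bar\alpha_t}\,\Delta$ in $\epsilon$. It follows that v-MSE $=$ epsilon-MSE $/\bar\alpha_t =$ epsilon-MSE $\cdot (1 + 1/\mathrm{SNR})$. Relative to the epsilon objective, v-prediction therefore puts extra weight on low-SNR (high-noise) timesteps. The sketch checks this identity using a randomly perturbed "prediction" in place of a trained model.

```python
import torch

torch.manual_seed(0)
alpha_bar = torch.tensor([0.999, 0.5, 0.01]).view(-1, 1, 1, 1)
a, s = alpha_bar.sqrt(), (1 - alpha_bar).sqrt()

x_0 = torch.randn(3, 3, 8, 8)
eps = torch.randn(3, 3, 8, 8)
x_t = a * x_0 + s * eps
v_true = a * eps - s * x_0

v_pred = v_true + 0.1 * torch.randn_like(v_true)  # imperfect "model" output
eps_from_v = a * v_pred + s * x_t                 # epsilon implied by the v prediction

v_mse = ((v_pred - v_true) ** 2).mean(dim=[1, 2, 3])
eps_mse = ((eps_from_v - eps) ** 2).mean(dim=[1, 2, 3])
snr = alpha_bar.view(-1) / (1 - alpha_bar.view(-1))

# Per-sample identity: v-MSE = eps-MSE * (1 + 1/SNR) = eps-MSE / alpha_bar
print(v_mse)
print(eps_mse * (1 + 1 / snr))
```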

X0 Prediction

Direct prediction of the clean data, useful when interpretability matters:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def x0_prediction_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: "NoiseSchedule"  # exposes T and alpha_bar
) -> torch.Tensor:
    """
    X0-prediction objective.

    Model predicts: x0_theta(x_t, t) ≈ x_0

    Benefits:
    - Interpretable output (the denoised image)
    - Natural for image-to-image tasks

    Drawbacks:
    - Converting to epsilon or score divides by sqrt(1 - alpha_bar_t),
      which amplifies prediction errors at low noise levels
    - Plain x0-MSE equals epsilon-MSE / SNR(t), a different implicit
      timestep weighting than the epsilon objective
    """
    t = torch.randint(0, noise_schedule.T, (x_0.shape[0],), device=x_0.device)
    epsilon = torch.randn_like(x_0)

    alpha_bar_t = noise_schedule.alpha_bar.to(x_0.device)[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    # Predict x0
    x0_pred = model(x_t, t)

    # Simple MSE (may need weighting for stability)
    loss = F.mse_loss(x0_pred, x_0)

    return loss


def x0_to_epsilon(x0_pred: torch.Tensor, x_t: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Convert x0 prediction to epsilon."""
    sqrt_alpha = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alpha = torch.sqrt(1 - alpha_bar_t)
    return (x_t - sqrt_alpha * x0_pred) / sqrt_one_minus_alpha
```
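
| Prediction | Target | Score relation (a = alpha_bar_t) | Best for |
|---|---|---|---|
| Epsilon | Added noise | s = -eps / sqrt(1 - a) | Standard training |
| Score | nabla log p | Direct | Theory, analysis |
| V | Velocity blend of noise and data | s = -(sqrt(a) * v + sqrt(1 - a) * x_t) / sqrt(1 - a) | Distillation |
| X0 | Clean data | s = (sqrt(a) * x0 - x_t) / (1 - a) | Interpretability |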
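
Every target is a deterministic function of $x_0$, $\epsilon$ and $x_t$, so the score relations can be verified numerically. The sketch builds $x_t$ from known $x_0$ and $\epsilon$, converts each target to a score, and compares the results. Here a stands for $\bar\alpha_t$, and its value is arbitrary.

```python
import torch

torch.manual_seed(0)
a = torch.tensor(0.6)  # alpha_bar_t at some timestep (arbitrary choice)
x_0 = torch.randn(2, 3, 4, 4)
eps = torch.randn(2, 3, 4, 4)

# Forward noising and the corresponding v target
x_t = a.sqrt() * x_0 + (1 - a).sqrt() * eps
v = a.sqrt() * eps - (1 - a).sqrt() * x_0

# Score implied by each target, all derived from x_t = sqrt(a) x_0 + sqrt(1 - a) eps
score_from_eps = -eps / (1 - a).sqrt()
score_from_x0 = (a.sqrt() * x_0 - x_t) / (1 - a)
score_from_v = -(a.sqrt() * v + (1 - a).sqrt() * x_t) / (1 - a).sqrt()

print(torch.allclose(score_from_eps, score_from_x0, atol=1e-5))
print(torch.allclose(score_from_eps, score_from_v, atol=1e-5))
```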

Loss Weighting Strategies

Not all timesteps contribute equally to sample quality. Reweighting the per-timestep loss changes which noise levels the network prioritizes, and this can noticeably improve sample quality and convergence speed.

SNR-Based Weighting

The signal-to-noise ratio (SNR) provides a principled basis for weighting:

```python
import torch
import torch.nn as nn


def snr(alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """
    Signal-to-noise ratio at timestep t.

    SNR(t) = alpha_bar_t / (1 - alpha_bar_t)

    High SNR = more signal, less noise (early timesteps)
    Low SNR = less signal, more noise (late timesteps)
    """
    return alpha_bar_t / (1 - alpha_bar_t)


def snr_weighted_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: "NoiseSchedule",  # exposes T and alpha_bar
    weighting: str = 'uniform'
) -> torch.Tensor:
    """
    Loss with SNR-based weighting.

    Weighting options:
    - 'uniform': Standard unweighted loss
    - 'snr': Weight by SNR (emphasizes low-noise, high-SNR timesteps;
             for epsilon-MSE this equals an unweighted x0-MSE)
    - 'truncated_snr': SNR weight clipped at 5, i.e. min(SNR, 5)
    """
    t = torch.randint(0, noise_schedule.T, (x_0.shape[0],), device=x_0.device)
    epsilon = torch.randn_like(x_0)

    alpha_bar_t = noise_schedule.alpha_bar.to(x_0.device)[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    epsilon_pred = model(x_t, t)

    # Per-sample MSE
    mse = ((epsilon_pred - epsilon) ** 2).mean(dim=[1, 2, 3])

    # Per-sample weights (view(-1) keeps a 1-D shape even for batch size 1)
    snr_t = snr(alpha_bar_t.view(-1))

    if weighting == 'uniform':
        weights = torch.ones_like(snr_t)
    elif weighting == 'snr':
        weights = snr_t
    elif weighting == 'truncated_snr':
        weights = snr_t.clamp(max=5.0)  # Truncate to avoid huge weights near t = 0
    else:
        raise ValueError(f"Unknown weighting: {weighting}")

    # Weighted loss
    loss = (weights * mse).mean()

    return loss
```
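
To see how these options behave, it helps to evaluate SNR along an actual schedule. Assume DDPM's linear beta schedule (1e-4 to 0.02 over 1000 steps). SNR then falls from roughly $10^4$ at t = 0 to below $10^{-3}$ at t = 999. Plain 'snr' weighting therefore puts by far the largest weights on the earliest, nearly noise-free timesteps, while truncation caps them. The printout also includes the Min-SNR weight used in the next subsection.

```python
import torch

# Assumed schedule: DDPM's linear betas, 1e-4 to 0.02 over T = 1000 steps
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1 - betas, dim=0)
snr = alpha_bar / (1 - alpha_bar)

gamma = 5.0
for t in [0, 10, 100, 250, 500, 999]:
    s = snr[t].item()
    print(
        f"t={t:4d}  SNR={s:10.4g}  "
        f"truncated={min(s, gamma):6.3f}  "
        f"min-SNR (eps)={min(s, gamma) / s:6.3f}"
    )
```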

Min-SNR Weighting

Min-SNR-gamma weighting, introduced by Hang et al. (2023), clips the SNR-based weight at a threshold gamma. The authors report that it speeds up convergence and improves sample quality:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def min_snr_gamma_weighting(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: "NoiseSchedule",  # exposes T and alpha_bar
    gamma: float = 5.0
) -> torch.Tensor:
    """
    Min-SNR-gamma weighting (Hang et al., 2023), for epsilon prediction.

    Weight = min(SNR(t), gamma) / SNR(t)

    This:
    - Downweights early timesteps (high SNR)
    - Keeps late timesteps (low SNR) at full weight
    - Prevents loss from being dominated by easy cases

    gamma=5 is the recommended default.
    """
    t = torch.randint(0, noise_schedule.T, (x_0.shape[0],), device=x_0.device)
    epsilon = torch.randn_like(x_0)

    alpha_bar_t = noise_schedule.alpha_bar.to(x_0.device)[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    epsilon_pred = model(x_t, t)

    # Per-sample MSE
    mse = ((epsilon_pred - epsilon) ** 2).mean(dim=[1, 2, 3])

    # Min-SNR weighting, with SNR(t) = alpha_bar / (1 - alpha_bar)
    alpha_bar_flat = alpha_bar_t.view(-1)
    snr_t = alpha_bar_flat / (1 - alpha_bar_flat)
    min_snr_weight = snr_t.clamp(max=gamma) / snr_t

    loss = (min_snr_weight * mse).mean()

    return loss


class AdaptiveWeighting(nn.Module):
    """
    Learnable loss weighting that adapts during training.

    Caution: minimizing the weighted loss with respect to log_weights alone
    pushes the softmax toward whichever timesteps have the smallest MSE, so
    used on its own it collapses onto the easiest timesteps. The weights need
    an extra regularizer or a separate objective.
    """
    def __init__(self, num_timesteps: int):
        super().__init__()  # registers log_weights as a trainable parameter
        self.log_weights = nn.Parameter(torch.zeros(num_timesteps))

    def get_weights(self, t: torch.Tensor) -> torch.Tensor:
        """Get normalized weights for timesteps."""
        weights = F.softmax(self.log_weights, dim=0)
        return weights[t] * len(self.log_weights)  # Scaled so the mean weight over all T timesteps is 1

    def weighted_loss(
        self,
        mse: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """Apply learned weighting."""
        weights = self.get_weights(t)
        return (weights * mse).mean()
```
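
The formula min(SNR, gamma) / SNR above is specific to epsilon prediction. The conversion formulas earlier in this section give epsilon-MSE $= \mathrm{SNR} \cdot$ x0-MSE and v-MSE $= (\mathrm{SNR}+1) \cdot$ x0-MSE. Keeping the same effective weight $\min(\mathrm{SNR}, \gamma)$ on the x0 error therefore needs a different multiplier for each target. The sketch below is a small helper for this; min_snr_weight is an illustrative name, not part of any library.

```python
import torch


def min_snr_weight(snr: torch.Tensor, gamma: float = 5.0, prediction: str = "epsilon") -> torch.Tensor:
    """Per-sample Min-SNR-gamma multiplier for the MSE of a given prediction target.

    All variants apply the same effective weight min(SNR, gamma) to the x0 error,
    using eps-MSE = SNR * x0-MSE and v-MSE = (SNR + 1) * x0-MSE.
    """
    clipped = torch.clamp(snr, max=gamma)
    if prediction == "x0":
        return clipped
    if prediction == "epsilon":
        return clipped / snr
    if prediction == "v":
        return clipped / (snr + 1)
    raise ValueError(f"unknown prediction type: {prediction}")


snr = torch.tensor([100.0, 5.0, 0.5, 0.01])
for kind in ["epsilon", "v", "x0"]:
    print(kind, min_snr_weight(snr, prediction=kind))
```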

Connecting All Frameworks

Now we can explicitly show how different frameworks are mathematically equivalent.

DDPM as Discretized VP-SDE

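DDPM is a discretization of the VP-SDE. To first order in beta_t, each DDPM transition is an Euler-Maruyama step of the VP-SDE, and the two processes coincide as the number of steps T grows: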

```python
import math
import torch
import torch.nn as nn


def show_ddpm_vpsde_equivalence():
    """
    DDPM forward process:
        q(x_t | x_{t-1}) = N(x_t; sqrt(1-beta_t) * x_{t-1}, beta_t * I)

    VP-SDE forward:
        dx = -0.5 * beta(t) * x * dt + sqrt(beta(t)) * dW

    Connection:
        With dt = 1/T and beta_t = beta(t/T) * dt, the DDPM transition matches
        the Euler-Maruyama discretization of the VP-SDE to first order in
        beta_t (sqrt(1 - beta_t) ≈ 1 - beta_t / 2), and the two processes
        coincide as T -> infinity.
    """
    # DDPM parameters
    T = 1000
    beta_min = 0.0001
    beta_max = 0.02

    # Linear schedule
    betas_ddpm = torch.linspace(beta_min, beta_max, T)

    # Equivalent continuous parameters:
    # beta_continuous(t) such that beta_ddpm[k] ≈ beta_continuous(k/T) * (1/T)
    beta_min_continuous = beta_min * T
    beta_max_continuous = beta_max * T

    # Compare alpha_bar after t + 1 DDPM steps with the VP-SDE marginal at time (t + 1) / T
    for t in [0, 100, 500, 999]:
        t_continuous = (t + 1) / T

        # DDPM cumulative product over steps 0..t
        alpha_bar_ddpm = torch.prod(1 - betas_ddpm[:t + 1]).item()

        # VP-SDE marginal (closed form): exp(-integral_0^t beta(s) ds)
        integral_beta = (0.5 * t_continuous**2 * (beta_max_continuous - beta_min_continuous)
                         + t_continuous * beta_min_continuous)
        alpha_bar_vpsde = math.exp(-integral_beta)  # plain float, so math.exp rather than torch.exp

        print(f"t={t}: DDPM alpha_bar={alpha_bar_ddpm:.4f}, VP-SDE={alpha_bar_vpsde:.4f}")


class UnifiedDiffusionModel:
    """
    Unified view: same model works with DDPM, Score, or SDE formulation.
    """
    def __init__(self, model: nn.Module, noise_schedule: "NoiseSchedule"):
        self.model = model
        self.schedule = noise_schedule

    def predict_epsilon(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Direct epsilon prediction."""
        return self.model(x_t, t)

    def predict_score(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Convert epsilon to score."""
        epsilon = self.predict_epsilon(x_t, t)
        alpha_bar_t = self.schedule.alpha_bar[t].view(-1, 1, 1, 1)
        return -epsilon / torch.sqrt(1 - alpha_bar_t)

    def predict_x0(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Convert epsilon to x0 prediction."""
        epsilon = self.predict_epsilon(x_t, t)
        alpha_bar_t = self.schedule.alpha_bar[t].view(-1, 1, 1, 1)
        sqrt_alpha = torch.sqrt(alpha_bar_t)
        sqrt_one_minus_alpha = torch.sqrt(1 - alpha_bar_t)
        return (x_t - sqrt_one_minus_alpha * epsilon) / sqrt_alpha
```
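
How close is a DDPM transition to an Euler-Maruyama step? One Euler-Maruyama step of the VP-SDE with $\beta(t)\,dt = \beta_t$ is $x_t = (1 - \beta_t/2)\,x_{t-1} + \sqrt{\beta_t}\,z$, whereas DDPM uses $\sqrt{1-\beta_t}$ as the mean coefficient. Both steps use noise variance $\beta_t$. Since $\sqrt{1-\beta} = 1 - \beta/2 - \beta^2/8 - \dots$, the mean coefficients differ by about $\beta_t^2/8$. The check below runs on the same linear schedule.

```python
import torch

# Linear DDPM betas (same range as in show_ddpm_vpsde_equivalence)
betas = torch.linspace(1e-4, 0.02, 1000, dtype=torch.float64)

ddpm_mean_coef = torch.sqrt(1 - betas)  # DDPM: x_t = sqrt(1 - beta_t) x_{t-1} + sqrt(beta_t) z
em_mean_coef = 1 - 0.5 * betas         # Euler-Maruyama step of the VP-SDE
# Both steps add noise with variance beta_t, so only the mean coefficient differs.

gap = (em_mean_coef - ddpm_mean_coef).abs()
print(f"max coefficient gap: {gap.max().item():.2e}")
print(f"beta_max**2 / 8:     {0.02**2 / 8:.2e}")
```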

NCSN as Discretized VE-SDE

Similarly, NCSN (Noise Conditional Score Networks) corresponds to VE-SDE:

```python
import torch


def show_ncsn_vesde_equivalence():
    """
    NCSN noise levels:
        sigma_i = sigma_min * (sigma_max / sigma_min)^(i / (L-1))

    VE-SDE forward:
        dx = sqrt(d[sigma^2(t)]/dt) * dW

    where sigma(t) = sigma_min * (sigma_max / sigma_min)^t

    The geometric noise schedule in NCSN is the discrete version
    of the exponential schedule in VE-SDE.
    """
    # NCSN parameters
    L = 10  # Number of noise levels
    sigma_min = 0.01
    sigma_max = 50.0

    # NCSN sigmas
    sigmas_ncsn = sigma_min * (sigma_max / sigma_min) ** (torch.arange(L) / (L - 1))

    # VE-SDE sigma(t) evaluated on an evenly spaced grid t in [0, 1]
    t_continuous = torch.linspace(0, 1, L)
    sigmas_vesde = sigma_min * (sigma_max / sigma_min) ** t_continuous

    print("NCSN sigmas:", sigmas_ncsn)
    print("VE-SDE sigmas:", sigmas_vesde)
    # Identical up to floating-point rounding: i / (L - 1) is the same grid as linspace(0, 1, L)

    # The key difference:
    # - NCSN uses discrete annealed Langevin with fixed sigma levels
    # - VE-SDE uses continuous reverse SDE
    # Both learn the same score function
```
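
NCSN's sampler, annealed Langevin dynamics, can be tried on a toy problem where the score is known in closed form. Assume 1-D data x ~ N(mu, 1). The sigma-perturbed density is then N(mu, 1 + sigma^2), with score -(x - mu) / (1 + sigma^2). The update and the step size alpha_i = eps * sigma_i^2 / sigma_L^2 follow the NCSN recipe. The values of eps and the number of steps per level are illustrative choices for this toy, not tuned values, and so is the Gaussian initialization.

```python
import math
import torch

torch.manual_seed(0)

mu = 2.0  # toy data distribution: x ~ N(mu, 1)
sigma_min, sigma_max, L = 0.01, 50.0, 10
# Geometric noise levels, ordered from largest to smallest for sampling
sigmas = sigma_max * (sigma_min / sigma_max) ** (torch.arange(L) / (L - 1))


def score(x: torch.Tensor, sigma: float) -> torch.Tensor:
    """Exact score of the perturbed density N(mu, 1 + sigma^2); stands in for s_theta."""
    return -(x - mu) / (1 + sigma**2)


def annealed_langevin(n: int, eps: float = 1e-4, steps_per_level: int = 100) -> torch.Tensor:
    x = sigma_max * torch.randn(n)  # Gaussian initialization at the widest noise level
    for sigma in sigmas.tolist():
        alpha = eps * sigma**2 / sigma_min**2  # step size: eps * sigma_i^2 / sigma_L^2
        for _ in range(steps_per_level):
            z = torch.randn(n)
            x = x + 0.5 * alpha * score(x, sigma) + math.sqrt(alpha) * z
    return x


samples = annealed_langevin(20_000)
print(f"mean {samples.mean().item():.3f} (target {mu}), std {samples.std().item():.3f} (target about 1)")
```

Large step sizes at high noise move samples quickly toward the data region. The small steps at the end only refine them, which is the design idea behind annealing.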

Connection to Flow Matching

Recent work on flow matching provides yet another perspective that generalizes diffusion. Instead of focusing on scores, flow matching directly learns a velocity field that transports noise to data.
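
For the linear path x_t = (1 - t) x_0 + t x_1 with Gaussian noise x_0, which the code below uses, the marginal velocity and the score are closely tied. One can show that $v(x, t) = (x + (1 - t)\,\nabla_x \log p_t(x)) / t$ for $t \in (0, 1)$, so a flow-matching velocity carries the same information as a diffusion score. Below is a numerical check that assumes 1-D Gaussian data x_1 ~ N(mu, tau^2), so that both sides have closed forms.

```python
import torch

mu, tau = 1.5, 0.7  # assumed data distribution x_1 ~ N(mu, tau^2); noise x_0 ~ N(0, 1)
x = torch.linspace(-4.0, 4.0, 9, dtype=torch.float64)

matches = {}
for t in [0.1, 0.5, 0.9]:
    mean_t = t * mu
    var_t = (t * tau) ** 2 + (1 - t) ** 2  # variance of x_t = (1 - t) x_0 + t x_1

    # Marginal velocity E[x_1 - x_0 | x_t = x] via Gaussian conditioning
    cov_x1 = t * tau**2  # Cov(x_1, x_t)
    cov_x0 = 1 - t       # Cov(x_0, x_t)
    v_direct = mu + (cov_x1 - cov_x0) * (x - mean_t) / var_t

    # The same velocity recovered from the score of p_t = N(mean_t, var_t)
    score = -(x - mean_t) / var_t
    v_from_score = (x + (1 - t) * score) / t

    matches[t] = torch.allclose(v_direct, v_from_score)
    print(t, matches[t])
```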

Conditional Flow Matching

Flow matching learns an ODE that transports samples:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FlowMatchingModel:
    """
    Flow matching: learn velocity field v(x, t) that transports
    noise distribution to data distribution.

    dx/dt = v(x, t)

    Starting from x_0 ~ N(0, I) at t=0, we want x_1 ~ p_data at t=1.
    """
    def __init__(self, velocity_model: nn.Module):
        self.velocity_model = velocity_model

    def conditional_flow_matching_loss(
        self,
        x_1: torch.Tensor  # Data samples
    ) -> torch.Tensor:
        """
        Conditional Flow Matching (CFM) objective.

        Key insight: define a simple conditional path between
        x_0 (noise) and x_1 (data), then learn to match its velocity.

        Linear interpolation path:
            x_t = (1 - t) * x_0 + t * x_1

        Conditional velocity:
            u_t(x | x_0, x_1) = x_1 - x_0
        """
        batch_size = x_1.shape[0]

        # Sample initial noise
        x_0 = torch.randn_like(x_1)

        # Sample time uniformly (shaped to broadcast over [B, C, H, W])
        t = torch.rand(batch_size, 1, 1, 1, device=x_1.device)

        # Linear interpolation
        x_t = (1 - t) * x_0 + t * x_1

        # Conditional velocity (constant for linear path)
        velocity_target = x_1 - x_0

        # Predict velocity (view(-1) keeps shape [B] even when B == 1)
        velocity_pred = self.velocity_model(x_t, t.view(-1))

        # MSE loss
        loss = F.mse_loss(velocity_pred, velocity_target)

        return loss

    @torch.no_grad()
    def sample(
        self,
        num_samples: int,
        data_shape: tuple,
        num_steps: int = 100,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """Generate samples by integrating the learned ODE."""
        # Start from noise
        x = torch.randn(num_samples, *data_shape, device=device)

        dt = 1.0 / num_steps
        for step in range(num_steps):
            t = step / num_steps
            t_batch = torch.full((num_samples,), t, device=device)

            # Forward Euler step along the learned velocity field
            velocity = self.velocity_model(x, t_batch)
            x = x + velocity * dt

        return x
```
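
The Euler loop in sample can be exercised without training by plugging in a velocity field that is known exactly. For 1-D Gaussian data x_1 ~ N(mu, tau^2) and the linear path above, the marginal velocity is v(x, t) = mu + (t tau^2 - (1 - t)) (x - t mu) / (t^2 tau^2 + (1 - t)^2). Integrating this field with the same Euler scheme should carry N(0, 1) noise to approximately N(mu, tau^2). The toy distribution and the step count are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
mu, tau = 1.5, 0.7  # assumed data distribution N(mu, tau^2)


def exact_velocity(x: torch.Tensor, t: float) -> torch.Tensor:
    """Marginal velocity E[x_1 - x_0 | x_t = x] for Gaussian data and noise."""
    var_t = (t * tau) ** 2 + (1 - t) ** 2
    return mu + (t * tau**2 - (1 - t)) * (x - t * mu) / var_t


num_steps = 100
x = torch.randn(50_000)  # x_0 ~ N(0, 1)
dt = 1.0 / num_steps
for step in range(num_steps):
    t = step / num_steps
    x = x + exact_velocity(x, t) * dt  # same Euler update as FlowMatchingModel.sample

print(f"mean {x.mean().item():.3f} (target {mu}), std {x.std().item():.3f} (target {tau})")
```

Any remaining mismatch in the standard deviation comes from the Euler discretization and shrinks as num_steps grows.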

Optimal Transport Paths

Flow matching can also use optimal transport to pair noise and data samples within a minibatch. This tends to produce shorter, straighter paths:

```python
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class OptimalTransportFlowMatching:
    """
    OT-CFM: Use optimal transport to pair noise and data samples,
    potentially creating more efficient flow paths.
    """
    def __init__(self, velocity_model: nn.Module):
        self.velocity_model = velocity_model

    def ot_cfm_loss(
        self,
        x_1: torch.Tensor,  # Data samples
        use_minibatch_ot: bool = True
    ) -> torch.Tensor:
        """
        Optimal Transport Conditional Flow Matching.

        Instead of pairing each data point with independent noise,
        use OT to find better pairings within the minibatch.
        """
        batch_size = x_1.shape[0]
        device = x_1.device

        # Sample noise
        x_0 = torch.randn_like(x_1)

        if use_minibatch_ot:
            # Compute OT coupling within minibatch
            x_0 = self._ot_reorder(x_0, x_1)

        # Rest is same as standard CFM
        t = torch.rand(batch_size, 1, 1, 1, device=device)
        x_t = (1 - t) * x_0 + t * x_1
        velocity_target = x_1 - x_0

        velocity_pred = self.velocity_model(x_t, t.view(-1))
        loss = F.mse_loss(velocity_pred, velocity_target)

        return loss

    def _ot_reorder(
        self,
        x_0: torch.Tensor,
        x_1: torch.Tensor
    ) -> torch.Tensor:
        """
        Reorder x_0 samples to better match x_1 using OT.
        Uses POT (Python Optimal Transport) library.
        """
        import ot

        # Flatten for distance computation (POT works on NumPy arrays)
        x_0_flat = x_0.detach().reshape(x_0.shape[0], -1).cpu().numpy()
        x_1_flat = x_1.detach().reshape(x_1.shape[0], -1).cpu().numpy()

        # Compute cost matrix (squared Euclidean)
        M = ot.dist(x_0_flat, x_1_flat)

        # Solve exact OT between uniform empirical measures; with equal batch
        # sizes an optimal coupling is a permutation matrix scaled by 1/n
        n = len(x_0_flat)
        coupling = ot.emd(np.full(n, 1.0 / n), np.full(n, 1.0 / n), M)

        # perm[j] = index of the noise sample paired with x_1[j]
        perm = torch.from_numpy(coupling.argmax(axis=0)).to(x_0.device)

        return x_0[perm]
```
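
Why pair samples with OT at all? Independent pairing produces longer, criss-crossing straight-line paths, so the average squared length ||x_1 - x_0||^2 is larger than necessary. With equal minibatch sizes and uniform weights, an optimal coupling is a permutation. For a tiny batch it can therefore be found by brute force. The sketch below swaps POT for itertools.permutations purely to make the effect visible. It is only practical for a handful of points.

```python
import itertools
import torch

torch.manual_seed(0)
n = 6
x_1 = torch.randn(n, 2) + torch.tensor([4.0, 0.0])  # toy 2-D "data" cloud
x_0 = torch.randn(n, 2)                             # Gaussian noise

# Squared Euclidean cost, the same cost used in _ot_reorder
cost = torch.cdist(x_0, x_1) ** 2


def pairing_cost(perm) -> float:
    """Mean squared path length when x_0[perm[j]] is paired with x_1[j]."""
    return cost[list(perm), list(range(n))].mean().item()


independent_cost = pairing_cost(range(n))  # default pairing: x_0[j] with x_1[j]
# Exhaustive search over all n! pairings finds the exact OT permutation
best_perm = min(itertools.permutations(range(n)), key=pairing_cost)
ot_cost = pairing_cost(best_perm)

x_0_paired = x_0[list(best_perm)]  # analogous to x_0[perm] in _ot_reorder
print(f"independent pairing: {independent_cost:.3f}, OT pairing: {ot_cost:.3f}")
```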
Diffusion vs Flow Matching:
  • Diffusion: learn the score; sample with the reverse SDE (stochastic) or the probability-flow ODE (deterministic)
  • Flow matching: learn a velocity field; sample by integrating an ODE (deterministic)
  • Both can reach similar quality; flow matching can be simpler to train

Unified Implementation

Here's a unified implementation that supports all major formulations:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# LinearSchedule, CosineSchedule and VPSchedule are the schedule classes from
# earlier sections. DDPM use assumes they expose `T` and a 1-D `alpha_bar`
# tensor; the 'score' framework additionally assumes a continuous VE-style
# noise level `sigma(t)` for t in [0, 1].


class UnifiedGenerativeModel:
    """
    Unified framework supporting all diffusion/flow formulations.

    Supports:
    - DDPM (discrete, epsilon / x0 / v prediction)
    - Score-based (continuous, score prediction)
    - Flow matching (ODE, velocity prediction)
    - Multiple noise schedules (linear, cosine, VP)
    """
    def __init__(
        self,
        model: nn.Module,
        framework: str = 'ddpm',
        schedule: str = 'linear',
        prediction: str = 'epsilon'
    ):
        if framework not in ('ddpm', 'score', 'flow'):
            raise ValueError(f"Unknown framework: {framework}")
        if prediction not in ('epsilon', 'x0', 'v'):
            raise ValueError(f"Unknown prediction type: {prediction}")
        self.model = model
        self.framework = framework
        self.prediction = prediction  # only used by the 'ddpm' framework

        # Initialize schedule
        schedules = {'linear': LinearSchedule, 'cosine': CosineSchedule, 'vp': VPSchedule}
        if schedule not in schedules:
            raise ValueError(f"Unknown schedule: {schedule}")
        self.schedule = schedules[schedule]()

    def get_loss(self, x_0: torch.Tensor) -> torch.Tensor:
        """Compute training loss based on framework."""
        if self.framework == 'ddpm':
            return self._ddpm_loss(x_0)
        if self.framework == 'score':
            return self._score_loss(x_0)
        return self._flow_loss(x_0)

    def _ddpm_loss(self, x_0: torch.Tensor) -> torch.Tensor:
        """DDPM discrete loss."""
        t = torch.randint(0, self.schedule.T, (x_0.shape[0],), device=x_0.device)
        epsilon = torch.randn_like(x_0)

        alpha_bar = self.schedule.alpha_bar.to(x_0.device)[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * epsilon

        if self.prediction == 'epsilon':
            target = epsilon
        elif self.prediction == 'x0':
            target = x_0
        else:  # 'v'
            target = torch.sqrt(alpha_bar) * epsilon - torch.sqrt(1 - alpha_bar) * x_0

        pred = self.model(x_t, t)
        return F.mse_loss(pred, target)

    def _score_loss(self, x_0: torch.Tensor) -> torch.Tensor:
        """Score-based continuous loss (VE perturbation)."""
        t = torch.rand(x_0.shape[0], device=x_0.device)
        sigma = self.schedule.sigma(t).view(-1, 1, 1, 1)
        epsilon = torch.randn_like(x_0)
        x_t = x_0 + sigma * epsilon

        score_target = -epsilon / sigma
        score_pred = self.model(x_t, t)

        # Weighted by sigma^2
        loss = 0.5 * ((score_pred - score_target) ** 2 * sigma ** 2).mean()
        return loss

    def _flow_loss(self, x_0: torch.Tensor) -> torch.Tensor:
        """Flow matching loss. Here x_0 is data and `noise` is the t = 0 endpoint."""
        noise = torch.randn_like(x_0)
        t = torch.rand(x_0.shape[0], 1, 1, 1, device=x_0.device)
        x_t = (1 - t) * noise + t * x_0

        velocity_target = x_0 - noise
        velocity_pred = self.model(x_t, t.view(-1))

        return F.mse_loss(velocity_pred, velocity_target)

    @torch.no_grad()
    def sample(
        self,
        shape: tuple,
        num_steps: int = 50,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """Generate samples using the method that matches the framework."""
        if self.framework == 'ddpm':
            return self._sample_ddpm(shape, num_steps, device)
        if self.framework == 'score':
            return self._sample_score(shape, num_steps, device)
        return self._sample_flow(shape, num_steps, device)

    def _to_epsilon(self, pred, x_t, alpha_bar):
        """Convert the network output to an epsilon prediction."""
        if self.prediction == 'epsilon':
            return pred
        if self.prediction == 'x0':
            return (x_t - torch.sqrt(alpha_bar) * pred) / torch.sqrt(1 - alpha_bar)
        return torch.sqrt(alpha_bar) * pred + torch.sqrt(1 - alpha_bar) * x_t  # 'v'

    def _sample_ddpm(self, shape, num_steps, device):
        """Deterministic DDIM-style (eta = 0) sampling over num_steps strided timesteps."""
        x = torch.randn(*shape, device=device)
        alpha_bar_all = self.schedule.alpha_bar.to(device)
        timesteps = torch.linspace(self.schedule.T - 1, 0, num_steps).round().long().tolist()

        for i, t in enumerate(timesteps):
            alpha_bar = alpha_bar_all[t]
            # alpha_bar of the next (less noisy) timestep; 1.0 means fully denoised
            if i + 1 < len(timesteps):
                alpha_bar_prev = alpha_bar_all[timesteps[i + 1]]
            else:
                alpha_bar_prev = torch.ones((), device=device)

            t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
            epsilon = self._to_epsilon(self.model(x, t_batch), x, alpha_bar)
            x0_pred = (x - torch.sqrt(1 - alpha_bar) * epsilon) / torch.sqrt(alpha_bar)
            x = torch.sqrt(alpha_bar_prev) * x0_pred + torch.sqrt(1 - alpha_bar_prev) * epsilon

        return x

    def _sample_score(self, shape, num_steps, device):
        """Euler steps of the VE probability-flow ODE, dx/dsigma = -sigma * score."""
        t_grid = torch.linspace(1.0, 0.0, num_steps + 1, device=device)
        sigmas = self.schedule.sigma(t_grid)
        x = sigmas[0] * torch.randn(*shape, device=device)

        for i in range(num_steps):
            t_batch = t_grid[i].expand(shape[0])
            score = self.model(x, t_batch)
            x = x - (sigmas[i + 1] - sigmas[i]) * sigmas[i] * score

        return x

    def _sample_flow(self, shape, num_steps, device):
        """Flow matching sampling (Euler integration of the learned ODE)."""
        x = torch.randn(*shape, device=device)
        dt = 1.0 / num_steps

        for step in range(num_steps):
            t = step / num_steps
            t_batch = torch.full((shape[0],), t, device=device)
            velocity = self.model(x, t_batch)
            x = x + velocity * dt

        return x
```

Choosing the Right Framework

Given the equivalence of these frameworks, how should you choose? Consider these factors:

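| Factor | DDPM | Score/SDE | Flow Matching |
|---|---|---|---|
| Implementation | Simplest | Moderate | Simple |
| Theory | Variational (ELBO) | Continuous-time SDE/ODE | ODE / continuous normalizing flows |
| Sampling speed | Slower | Flexible | Faster |
| Likelihood | ELBO only | Exact (via probability-flow ODE) | Exact (via ODE) |
| Conditioning | Well-studied | Well-studied | Emerging |
| Distillation | Good | Good | Excellent |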
Practical Recommendations:
  1. For standard image generation: DDPM with epsilon prediction and cosine schedule
  2. For fast sampling or distillation: Flow matching or v-prediction
  3. For theoretical analysis or likelihood: SDE/ODE framework
  4. For maximum flexibility: Unified framework that supports all options
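
Recommendation 1 mentions the cosine schedule. The sketch below follows the improved-DDPM formulation of Nichol & Dhariwal: $\bar\alpha_t = f(t)/f(0)$ with $f(t) = \cos^2\!\big(\tfrac{t/T + s}{1 + s}\cdot\tfrac{\pi}{2}\big)$, offset s = 0.008, and betas clipped at 0.999. The function name is illustrative.

```python
import math
import torch


def cosine_schedule(T: int = 1000, s: float = 0.008, max_beta: float = 0.999):
    """Cosine noise schedule (improved DDPM): returns (betas, alpha_bar), each of length T."""
    steps = torch.arange(T + 1, dtype=torch.float64)
    f = torch.cos(((steps / T + s) / (1 + s)) * math.pi / 2) ** 2
    alpha_bar_cont = f / f[0]                     # alpha_bar at t = 0, 1, ..., T
    betas = 1 - alpha_bar_cont[1:] / alpha_bar_cont[:-1]
    betas = betas.clamp(max=max_beta)             # avoid a singular final step
    alpha_bar = torch.cumprod(1 - betas, dim=0)   # alpha_bar for steps 1..T
    return betas.float(), alpha_bar.float()


betas, alpha_bar = cosine_schedule()
print(f"alpha_bar at t=1: {alpha_bar[0].item():.5f}, "
      f"t=T/2: {alpha_bar[499].item():.4f}, t=T: {alpha_bar[-1].item():.2e}")
```

Compared with the linear schedule, $\bar\alpha_t$ here decays more gradually in the middle of the trajectory. The clipping only affects the final step, where $f(T)$ is essentially zero.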

Summary

This section has unified the major perspectives on diffusion models into a coherent framework. Key takeaways:

  • All frameworks learn the same thing: the score function $\nabla_x \log p_t(x)$, just with different parameterizations.
  • DDPM is discretized VP-SDE: the familiar discrete Markov chain is an Euler-Maruyama approximation to a continuous SDE.
  • Prediction targets are equivalent: epsilon, score, v, and x0 predictions can all be converted to each other.
  • Loss weighting matters: Min-SNR and similar strategies improve sample quality by balancing contributions across timesteps.
  • Flow matching generalizes diffusion: by learning velocity fields directly, it provides an alternative that can be simpler to train.

This unified understanding empowers you to make informed choices about model design, combine ideas from different frameworks, and push the boundaries of generative modeling. The field continues to evolve, but these fundamental connections will remain relevant as new techniques emerge.