Introduction
Throughout this book, we've encountered multiple perspectives on diffusion models: DDPM's discrete Markov chains, score-based models with denoising score matching, and SDE/ODE formulations. This section synthesizes these viewpoints into a unified framework that reveals their fundamental equivalences and guides practical choices.
Understanding this unified view enables informed decisions about model design, training objectives, and sampling algorithms. We'll see that seemingly different approaches are often mathematically equivalent, differing only in parameterization or emphasis.
The Big Picture
All diffusion-based generative models share a common structure: they learn to reverse a corruption process that transforms data into noise. The differences lie in how this process is parameterized, discretized, and trained.
Equivalence Map
The following table maps concepts across different formulations:
| Concept | DDPM | Score-Based | SDE |
|---|---|---|---|
| Forward process | q(x_t \| x_{t-1}) | Adding noise levels | Forward SDE: dx = f dt + g dW |
| Reverse process | p(x_{t-1} \| x_t) | Annealed Langevin | Reverse SDE |
| What model learns | epsilon(x_t, t) | score s(x, sigma) | score nabla log p_t |
| Training objective | MSE on noise | DSM | DSM / Score matching |
| Noise schedule | beta_t sequence | sigma levels | beta(t) or sigma(t) |
| Prior | N(0, I) | N(0, sigma_max^2 I) | Depends on SDE type |
Central Insight: All formulations learn the same fundamental quantity, the score function nabla_x log p_t(x). They differ only in how this score is parameterized, scaled, and used for sampling.
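To make this concrete, here is a toy sanity check (scalar stand-ins for the tensor helpers defined later in this section, not part of the main implementation): for 1-D data drawn from N(0, 1), the noised marginal x_t = sqrt(alpha_bar) * x_0 + sqrt(1 - alpha_bar) * epsilon is again N(0, 1), whose true score at x is simply -x. The optimal epsilon predictor is the posterior mean E[epsilon | x_t] = sqrt(1 - alpha_bar) * x_t, and pushing it through the epsilon-to-score relation recovers exactly that score:

```python
import math

def optimal_epsilon(x_t: float, alpha_bar: float) -> float:
    """Posterior mean E[epsilon | x_t] when the data itself is N(0, 1)."""
    return math.sqrt(1 - alpha_bar) * x_t

def score_from_epsilon(epsilon: float, alpha_bar: float) -> float:
    """The epsilon-to-score relation: score = -epsilon / sqrt(1 - alpha_bar)."""
    return -epsilon / math.sqrt(1 - alpha_bar)

# For N(0, 1) data the marginal p_t is N(0, 1) at every noise level,
# so the true score at x is -x, independent of alpha_bar.
for alpha_bar in (0.1, 0.5, 0.9):
    for x in (-2.0, 0.3, 1.7):
        s = score_from_epsilon(optimal_epsilon(x, alpha_bar), alpha_bar)
        assert abs(s - (-x)) < 1e-12
```

The same calculation goes through for any Gaussian data distribution; only the posterior mean changes.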
Parameterization Choices
The neural network can predict different quantities, all of which are equivalent up to scaling. The choice affects training dynamics and numerical stability.
Epsilon Prediction
The standard DDPM parameterization predicts the noise that was added:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def epsilon_prediction_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: NoiseSchedule
) -> torch.Tensor:
    """
    Standard DDPM epsilon-prediction objective.

    Model predicts: epsilon_theta(x_t, t) ≈ epsilon

    Relationship to score:
        score = -epsilon / sqrt(1 - alpha_bar_t)
    """
    # Sample timesteps
    t = torch.randint(0, noise_schedule.T, (x_0.shape[0],), device=x_0.device)

    # Sample noise
    epsilon = torch.randn_like(x_0)

    # Noise the data
    alpha_bar_t = noise_schedule.alpha_bar[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    # Predict noise
    epsilon_pred = model(x_t, t)

    # Simple MSE loss
    loss = F.mse_loss(epsilon_pred, epsilon)

    return loss


def epsilon_to_score(epsilon: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Convert epsilon prediction to score."""
    return -epsilon / torch.sqrt(1 - alpha_bar_t)


def score_to_epsilon(score: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Convert score to epsilon prediction."""
    return -score * torch.sqrt(1 - alpha_bar_t)
```

Score Prediction
Direct score prediction, as in NCSN, predicts the gradient of the log-density:
```python
def score_prediction_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    sigmas: torch.Tensor
) -> torch.Tensor:
    """
    Score-based objective (DSM).

    Model predicts: s_theta(x_t, sigma) ≈ nabla_x log p(x_t | x_0)

    The true conditional score is: -epsilon / sigma
    """
    # Sample noise levels
    sigma_idx = torch.randint(0, len(sigmas), (x_0.shape[0],))
    sigma = sigmas[sigma_idx].view(-1, 1, 1, 1)

    # Add noise
    epsilon = torch.randn_like(x_0)
    x_t = x_0 + sigma * epsilon

    # Predict score
    score_pred = model(x_t, sigma_idx)

    # True score
    score_true = -epsilon / sigma

    # DSM loss (weighted by sigma^2 for stability)
    loss = 0.5 * ((score_pred - score_true) ** 2 * sigma ** 2).mean()

    return loss
```

V-Prediction
V-prediction, introduced in progressive distillation, predicts a velocity that balances signal and noise:
```python
def v_prediction_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: NoiseSchedule
) -> torch.Tensor:
    """
    V-prediction objective (Salimans & Ho, 2022).

    Model predicts: v_theta(x_t, t) = sqrt(alpha_bar) * epsilon - sqrt(1 - alpha_bar) * x_0

    Benefits:
    - More stable gradients
    - Better for distillation
    - Symmetric treatment of signal and noise
    """
    t = torch.randint(0, noise_schedule.T, (x_0.shape[0],))
    epsilon = torch.randn_like(x_0)

    alpha_bar_t = noise_schedule.alpha_bar[t].view(-1, 1, 1, 1)
    sqrt_alpha = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alpha = torch.sqrt(1 - alpha_bar_t)

    # Noisy sample
    x_t = sqrt_alpha * x_0 + sqrt_one_minus_alpha * epsilon

    # Target v
    v_target = sqrt_alpha * epsilon - sqrt_one_minus_alpha * x_0

    # Predict v
    v_pred = model(x_t, t)

    loss = F.mse_loss(v_pred, v_target)

    return loss


def v_to_epsilon(v: torch.Tensor, x_t: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Convert v prediction to epsilon: epsilon = sqrt(alpha_bar) * v + sqrt(1 - alpha_bar) * x_t."""
    sqrt_alpha = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alpha = torch.sqrt(1 - alpha_bar_t)
    return sqrt_alpha * v + sqrt_one_minus_alpha * x_t


def v_to_x0(v: torch.Tensor, x_t: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Convert v prediction to x0: x0 = sqrt(alpha_bar) * x_t - sqrt(1 - alpha_bar) * v."""
    sqrt_alpha = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alpha = torch.sqrt(1 - alpha_bar_t)
    return sqrt_alpha * x_t - sqrt_one_minus_alpha * v
```

X0 Prediction
Direct prediction of the clean data, useful when interpretability matters:
```python
def x0_prediction_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: NoiseSchedule
) -> torch.Tensor:
    """
    X0-prediction objective.

    Model predicts: x0_theta(x_t, t) ≈ x_0

    Benefits:
    - Interpretable output (the denoised image)
    - Natural for image-to-image tasks

    Drawbacks:
    - Can be numerically unstable at high noise levels
    - Gradients can be noisy
    """
    t = torch.randint(0, noise_schedule.T, (x_0.shape[0],))
    epsilon = torch.randn_like(x_0)

    alpha_bar_t = noise_schedule.alpha_bar[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    # Predict x0
    x0_pred = model(x_t, t)

    # Simple MSE (may need weighting for stability)
    loss = F.mse_loss(x0_pred, x_0)

    return loss


def x0_to_epsilon(x0_pred: torch.Tensor, x_t: torch.Tensor, alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """Convert x0 prediction to epsilon."""
    sqrt_alpha = torch.sqrt(alpha_bar_t)
    sqrt_one_minus_alpha = torch.sqrt(1 - alpha_bar_t)
    return (x_t - sqrt_alpha * x0_pred) / sqrt_one_minus_alpha
```

| Prediction | Target | Score Relation | Best For |
|---|---|---|---|
| Epsilon | Added noise | s = -eps / sqrt(1-alpha_bar) | Standard training |
| Score | nabla log p | Direct | Theory, analysis |
| V | Velocity blend | s = -x_t - sqrt(a/(1-a)) * v | Distillation |
| X0 | Clean data | s = (sqrt(a)*x0 - x_t) / (1-a) | Interpretability |
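The conversions in this table can be checked mechanically. The following sketch (plain scalar arithmetic, independent of the tensor helpers above) builds x_t and v from a chosen x_0 and epsilon, then verifies that each conversion formula recovers the original quantity:

```python
import math

alpha_bar = 0.7
x0, eps = 1.3, -0.4

sa = math.sqrt(alpha_bar)        # sqrt(alpha_bar)
sb = math.sqrt(1 - alpha_bar)    # sqrt(1 - alpha_bar)

x_t = sa * x0 + sb * eps         # forward marginal
v = sa * eps - sb * x0           # v-prediction target

# Round-trip identities:
assert abs((sa * v + sb * x_t) - eps) < 1e-12     # v -> epsilon
assert abs((sa * x_t - sb * v) - x0) < 1e-12      # v -> x0
assert abs((x_t - sa * x0) / sb - eps) < 1e-12    # x0 -> epsilon

# The x0 score relation agrees with score = -eps / sqrt(1 - alpha_bar):
assert abs((sa * x0 - x_t) / (1 - alpha_bar) - (-eps / sb)) < 1e-12
```

Any particular values of alpha_bar, x0, and eps work here; the identities are algebraic, not statistical.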
Loss Weighting Strategies
Not all timesteps contribute equally to sample quality. Weighting the loss across timesteps can significantly improve results.
SNR-Based Weighting
The signal-to-noise ratio (SNR) provides a principled basis for weighting:
```python
def snr(alpha_bar_t: torch.Tensor) -> torch.Tensor:
    """
    Signal-to-noise ratio at timestep t.

    SNR(t) = alpha_bar_t / (1 - alpha_bar_t)

    High SNR = more signal, less noise (early timesteps)
    Low SNR = less signal, more noise (late timesteps)
    """
    return alpha_bar_t / (1 - alpha_bar_t)


def snr_weighted_loss(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: NoiseSchedule,
    weighting: str = 'uniform'
) -> torch.Tensor:
    """
    Loss with SNR-based weighting.

    Weighting options:
    - 'uniform': Standard unweighted loss
    - 'snr': Weight by SNR (emphasize low-noise, high-SNR timesteps)
    - 'truncated_snr': Clipped SNR weighting
    """
    t = torch.randint(0, noise_schedule.T, (x_0.shape[0],))
    epsilon = torch.randn_like(x_0)

    alpha_bar_t = noise_schedule.alpha_bar[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    epsilon_pred = model(x_t, t)

    # Per-sample MSE
    mse = ((epsilon_pred - epsilon) ** 2).mean(dim=[1, 2, 3])

    # Compute weights
    snr_t = snr(alpha_bar_t.squeeze())

    if weighting == 'uniform':
        weights = torch.ones_like(snr_t)
    elif weighting == 'snr':
        weights = snr_t
    elif weighting == 'truncated_snr':
        weights = snr_t.clamp(max=5.0)  # Truncate to avoid explosion
    else:
        raise ValueError(f"Unknown weighting: {weighting}")

    # Weighted loss
    loss = (weights * mse).mean()

    return loss
```

Min-SNR Weighting
Min-SNR weighting, introduced by Hang et al. (2023), caps the SNR-based weight at a constant gamma and improves sample quality in practice:
```python
def min_snr_gamma_weighting(
    model: nn.Module,
    x_0: torch.Tensor,
    noise_schedule: NoiseSchedule,
    gamma: float = 5.0
) -> torch.Tensor:
    """
    Min-SNR-gamma weighting (Hang et al., 2023).

    Weight = min(SNR(t), gamma) / SNR(t)

    This:
    - Downweights early timesteps (high SNR)
    - Keeps late timesteps (low SNR) at full weight
    - Prevents loss from being dominated by easy cases

    gamma=5 is the recommended default.
    """
    t = torch.randint(0, noise_schedule.T, (x_0.shape[0],))
    epsilon = torch.randn_like(x_0)

    alpha_bar_t = noise_schedule.alpha_bar[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    epsilon_pred = model(x_t, t)

    # Per-sample MSE
    mse = ((epsilon_pred - epsilon) ** 2).mean(dim=[1, 2, 3])

    # Min-SNR weighting
    snr_t = snr(alpha_bar_t.squeeze())
    min_snr_weight = snr_t.clamp(max=gamma) / snr_t

    loss = (min_snr_weight * mse).mean()

    return loss


class AdaptiveWeighting(nn.Module):
    """
    Learnable loss weighting that adapts during training.
    """
    def __init__(self, num_timesteps: int):
        super().__init__()
        self.log_weights = nn.Parameter(torch.zeros(num_timesteps))

    def get_weights(self, t: torch.Tensor) -> torch.Tensor:
        """Get normalized weights for timesteps."""
        weights = F.softmax(self.log_weights, dim=0)
        return weights[t] * len(self.log_weights)  # Scale to sum to T

    def weighted_loss(
        self,
        mse: torch.Tensor,
        t: torch.Tensor
    ) -> torch.Tensor:
        """Apply learned weighting."""
        weights = self.get_weights(t)
        return (weights * mse).mean()
```

Connecting All Frameworks
Now we can explicitly show how different frameworks are mathematically equivalent.
DDPM as Discretized VP-SDE
DDPM is the Euler-Maruyama discretization of the VP-SDE:
```python
import math


def show_ddpm_vpsde_equivalence():
    """
    DDPM forward process:
        q(x_t | x_{t-1}) = N(x_t; sqrt(1-beta_t) * x_{t-1}, beta_t * I)

    VP-SDE forward:
        dx = -0.5 * beta(t) * x * dt + sqrt(beta(t)) * dW

    Connection:
        With dt = 1/T and beta_t = beta(t/T) * dt, the DDPM transition
        is exactly the Euler-Maruyama discretization of the VP-SDE.
    """
    # DDPM parameters
    T = 1000
    beta_min = 0.0001
    beta_max = 0.02

    # Linear schedule
    betas_ddpm = torch.linspace(beta_min, beta_max, T)

    # Equivalent continuous parameters
    # beta_continuous(t) such that beta_ddpm[k] ≈ beta_continuous(k/T) * (1/T)
    beta_min_continuous = beta_min * T
    beta_max_continuous = beta_max * T

    # Verify equivalence
    for t in [0, 100, 500, 999]:
        t_continuous = t / T

        # DDPM cumulative product
        alpha_bar_ddpm = torch.prod(1 - betas_ddpm[:t+1])

        # VP-SDE marginal (closed form): alpha_bar(t) = exp(-int_0^t beta(s) ds)
        integral_beta = 0.5 * t_continuous**2 * (beta_max_continuous - beta_min_continuous) + \
                        t_continuous * beta_min_continuous
        alpha_bar_vpsde = math.exp(-integral_beta)

        print(f"t={t}: DDPM alpha_bar={alpha_bar_ddpm:.4f}, VP-SDE={alpha_bar_vpsde:.4f}")


class UnifiedDiffusionModel:
    """
    Unified view: the same model works with the DDPM, score, or SDE formulation.
    """
    def __init__(self, model: nn.Module, noise_schedule: NoiseSchedule):
        self.model = model
        self.schedule = noise_schedule

    def predict_epsilon(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Direct epsilon prediction."""
        return self.model(x_t, t)

    def predict_score(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Convert epsilon to score."""
        epsilon = self.predict_epsilon(x_t, t)
        alpha_bar_t = self.schedule.alpha_bar[t].view(-1, 1, 1, 1)
        return -epsilon / torch.sqrt(1 - alpha_bar_t)

    def predict_x0(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Convert epsilon to x0 prediction."""
        epsilon = self.predict_epsilon(x_t, t)
        alpha_bar_t = self.schedule.alpha_bar[t].view(-1, 1, 1, 1)
        sqrt_alpha = torch.sqrt(alpha_bar_t)
        sqrt_one_minus_alpha = torch.sqrt(1 - alpha_bar_t)
        return (x_t - sqrt_one_minus_alpha * epsilon) / sqrt_alpha
```

NCSN as Discretized VE-SDE
Similarly, NCSN (Noise Conditional Score Networks) corresponds to VE-SDE:
```python
def show_ncsn_vesde_equivalence():
    """
    NCSN noise levels:
        sigma_i = sigma_min * (sigma_max / sigma_min)^(i / (L-1))

    VE-SDE forward:
        dx = sqrt(d[sigma^2(t)]/dt) * dW

    where sigma(t) = sigma_min * (sigma_max / sigma_min)^t

    The geometric noise schedule in NCSN is the discrete version
    of the exponential schedule in VE-SDE.
    """
    # NCSN parameters
    L = 10  # Number of noise levels
    sigma_min = 0.01
    sigma_max = 50.0

    # NCSN sigmas
    sigmas_ncsn = sigma_min * (sigma_max / sigma_min) ** (torch.arange(L) / (L - 1))

    # VE-SDE continuous
    t_continuous = torch.linspace(0, 1, L)
    sigmas_vesde = sigma_min * (sigma_max / sigma_min) ** t_continuous

    print("NCSN sigmas:", sigmas_ncsn)
    print("VE-SDE sigmas:", sigmas_vesde)
    # They are identical!

    # The key difference:
    # - NCSN uses discrete annealed Langevin with fixed sigma levels
    # - VE-SDE uses continuous reverse SDE
    # Both learn the same score function
```

Connection to Flow Matching
Recent work on flow matching provides yet another perspective that generalizes diffusion. Instead of focusing on scores, flow matching directly learns a velocity field that transports noise to data.
Conditional Flow Matching
Flow matching learns an ODE that transports samples:
```python
class FlowMatchingModel:
    """
    Flow matching: learn velocity field v(x, t) that transports
    the noise distribution to the data distribution.

    dx/dt = v(x, t)

    Starting from x_0 ~ N(0, I) at t=0, we want x_1 ~ p_data at t=1.
    """
    def __init__(self, velocity_model: nn.Module):
        self.velocity_model = velocity_model

    def conditional_flow_matching_loss(
        self,
        x_1: torch.Tensor  # Data samples
    ) -> torch.Tensor:
        """
        Conditional Flow Matching (CFM) objective.

        Key insight: define a simple conditional path between
        x_0 (noise) and x_1 (data), then learn to match its velocity.

        Linear interpolation path:
            x_t = (1 - t) * x_0 + t * x_1

        Conditional velocity:
            u_t(x | x_0, x_1) = x_1 - x_0
        """
        batch_size = x_1.shape[0]

        # Sample initial noise
        x_0 = torch.randn_like(x_1)

        # Sample time uniformly
        t = torch.rand(batch_size, 1, 1, 1, device=x_1.device)

        # Linear interpolation
        x_t = (1 - t) * x_0 + t * x_1

        # Conditional velocity (constant for linear path)
        velocity_target = x_1 - x_0

        # Predict velocity
        velocity_pred = self.velocity_model(x_t, t.squeeze())

        # MSE loss
        loss = F.mse_loss(velocity_pred, velocity_target)

        return loss

    @torch.no_grad()
    def sample(
        self,
        num_samples: int,
        data_shape: tuple,
        num_steps: int = 100,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """Generate samples by integrating the learned ODE."""
        # Start from noise
        x = torch.randn(num_samples, *data_shape, device=device)

        dt = 1.0 / num_steps
        for step in range(num_steps):
            t = step / num_steps
            t_batch = torch.full((num_samples,), t, device=device)

            velocity = self.velocity_model(x, t_batch)
            x = x + velocity * dt

        return x
```

Optimal Transport Paths
Flow matching can use optimal transport to define more efficient paths:
```python
import numpy as np
import ot  # POT (Python Optimal Transport) library


class OptimalTransportFlowMatching:
    """
    OT-CFM: Use optimal transport to pair noise and data samples,
    potentially creating more efficient flow paths.
    """
    def __init__(self, velocity_model: nn.Module):
        self.velocity_model = velocity_model

    def ot_cfm_loss(
        self,
        x_1: torch.Tensor,  # Data samples
        use_minibatch_ot: bool = True
    ) -> torch.Tensor:
        """
        Optimal Transport Conditional Flow Matching.

        Instead of pairing each data point with independent noise,
        use OT to find better pairings within the minibatch.
        """
        batch_size = x_1.shape[0]
        device = x_1.device

        # Sample noise
        x_0 = torch.randn_like(x_1)

        if use_minibatch_ot:
            # Compute OT coupling within minibatch
            x_0 = self._ot_reorder(x_0, x_1)

        # Rest is same as standard CFM
        t = torch.rand(batch_size, 1, 1, 1, device=device)
        x_t = (1 - t) * x_0 + t * x_1
        velocity_target = x_1 - x_0

        velocity_pred = self.velocity_model(x_t, t.squeeze())
        loss = F.mse_loss(velocity_pred, velocity_target)

        return loss

    def _ot_reorder(
        self,
        x_0: torch.Tensor,
        x_1: torch.Tensor
    ) -> torch.Tensor:
        """
        Reorder x_0 samples to better match x_1 using OT.
        """
        # Flatten for distance computation
        x_0_flat = x_0.view(x_0.shape[0], -1).cpu().numpy()
        x_1_flat = x_1.view(x_1.shape[0], -1).cpu().numpy()

        # Compute cost matrix (squared Euclidean)
        M = ot.dist(x_0_flat, x_1_flat)

        # Solve OT with uniform marginals
        coupling = ot.emd(
            np.ones(len(x_0)) / len(x_0),
            np.ones(len(x_1)) / len(x_1),
            M
        )

        # Get permutation from coupling
        perm = coupling.argmax(axis=0)

        return x_0[perm]
```

Diffusion vs Flow Matching:
- Diffusion: Learn score, sample via reverse SDE/ODE with noise
- Flow Matching: Learn velocity, sample via ODE (no stochasticity)
- Both can achieve similar quality; flow matching may be simpler to train
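The "no stochasticity" point can be illustrated with a toy sketch (not part of the book's implementation): along the linear conditional path the velocity x_1 - x_0 is constant, so Euler integration transports any starting point exactly to its paired target with no noise injected at any step. In practice one integrates the learned marginal velocity field, where step count and solver choice do matter:

```python
def euler_transport(x0: float, x1: float, num_steps: int) -> float:
    """Integrate dx/dt = (x1 - x0), the conditional velocity of the linear path."""
    x = x0
    dt = 1.0 / num_steps
    for _ in range(num_steps):
        x += (x1 - x0) * dt  # deterministic update: no noise term
    return x

# The endpoint is recovered exactly (up to float rounding) for any step count.
for steps in (1, 10, 100):
    assert abs(euler_transport(-1.5, 2.0, steps) - 2.0) < 1e-9
```

Contrast this with reverse-SDE sampling, where fresh Gaussian noise enters at every step, so two runs from the same start produce different samples.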
Unified Implementation
Here's a unified implementation that supports all major formulations:
```python
class UnifiedGenerativeModel:
    """
    Unified framework supporting all diffusion/flow formulations.

    Supports:
    - DDPM (discrete, epsilon prediction)
    - Score-based (continuous, score prediction)
    - Flow matching (ODE, velocity prediction)
    - Multiple noise schedules (VP, VE, linear)
    """
    def __init__(
        self,
        model: nn.Module,
        framework: str = 'ddpm',
        schedule: str = 'linear',
        prediction: str = 'epsilon'
    ):
        self.model = model
        self.framework = framework
        self.prediction = prediction

        # Initialize schedule
        if schedule == 'linear':
            self.schedule = LinearSchedule()
        elif schedule == 'cosine':
            self.schedule = CosineSchedule()
        elif schedule == 'vp':
            self.schedule = VPSchedule()
        else:
            raise ValueError(f"Unknown schedule: {schedule}")

    def get_loss(self, x_0: torch.Tensor) -> torch.Tensor:
        """Compute training loss based on framework."""
        if self.framework == 'ddpm':
            return self._ddpm_loss(x_0)
        elif self.framework == 'score':
            return self._score_loss(x_0)
        elif self.framework == 'flow':
            return self._flow_loss(x_0)
        raise ValueError(f"Unknown framework: {self.framework}")

    def _ddpm_loss(self, x_0: torch.Tensor) -> torch.Tensor:
        """DDPM discrete loss."""
        t = torch.randint(0, self.schedule.T, (x_0.shape[0],))
        epsilon = torch.randn_like(x_0)

        alpha_bar = self.schedule.alpha_bar[t].view(-1, 1, 1, 1)
        x_t = torch.sqrt(alpha_bar) * x_0 + torch.sqrt(1 - alpha_bar) * epsilon

        if self.prediction == 'epsilon':
            target = epsilon
        elif self.prediction == 'x0':
            target = x_0
        elif self.prediction == 'v':
            target = torch.sqrt(alpha_bar) * epsilon - torch.sqrt(1 - alpha_bar) * x_0
        else:
            raise ValueError(f"Unknown prediction type: {self.prediction}")

        pred = self.model(x_t, t)
        return F.mse_loss(pred, target)

    def _score_loss(self, x_0: torch.Tensor) -> torch.Tensor:
        """Score-based continuous loss."""
        t = torch.rand(x_0.shape[0], device=x_0.device)
        sigma = self.schedule.sigma(t).view(-1, 1, 1, 1)
        epsilon = torch.randn_like(x_0)
        x_t = x_0 + sigma * epsilon

        score_target = -epsilon / sigma
        score_pred = self.model(x_t, t)

        # Weighted by sigma^2
        loss = 0.5 * ((score_pred - score_target) ** 2 * sigma ** 2).mean()
        return loss

    def _flow_loss(self, x_0: torch.Tensor) -> torch.Tensor:
        """Flow matching loss."""
        noise = torch.randn_like(x_0)
        t = torch.rand(x_0.shape[0], 1, 1, 1, device=x_0.device)
        x_t = (1 - t) * noise + t * x_0

        velocity_target = x_0 - noise
        velocity_pred = self.model(x_t, t.squeeze())

        return F.mse_loss(velocity_pred, velocity_target)

    @torch.no_grad()
    def sample(
        self,
        shape: tuple,
        num_steps: int = 50,
        device: str = 'cuda'
    ) -> torch.Tensor:
        """Generate samples using the appropriate method."""
        if self.framework in ['ddpm', 'score']:
            return self._sample_diffusion(shape, num_steps, device)
        elif self.framework == 'flow':
            return self._sample_flow(shape, num_steps, device)
        raise ValueError(f"Unknown framework: {self.framework}")

    def _sample_diffusion(self, shape, num_steps, device):
        """DDPM/Score-based sampling."""
        x = torch.randn(*shape, device=device)
        timesteps = torch.linspace(self.schedule.T - 1, 0, num_steps).long()

        for t in timesteps:
            t_batch = t.expand(shape[0]).to(device)
            x = self._diffusion_step(x, t_batch)

        return x

    def _diffusion_step(self, x, t):
        """One DDPM ancestral update (assumes the schedule exposes betas and alpha_bar)."""
        beta_t = self.schedule.betas[t].view(-1, 1, 1, 1)
        alpha_t = 1 - beta_t
        alpha_bar_t = self.schedule.alpha_bar[t].view(-1, 1, 1, 1)

        epsilon_pred = self.model(x, t)
        mean = (x - beta_t / torch.sqrt(1 - alpha_bar_t) * epsilon_pred) / torch.sqrt(alpha_t)

        # No noise on the final (t=0) step
        noise = torch.randn_like(x) if t[0] > 0 else torch.zeros_like(x)
        return mean + torch.sqrt(beta_t) * noise

    def _sample_flow(self, shape, num_steps, device):
        """Flow matching sampling (ODE)."""
        x = torch.randn(*shape, device=device)
        dt = 1.0 / num_steps

        for step in range(num_steps):
            t = step / num_steps
            t_batch = torch.full((shape[0],), t, device=device)
            velocity = self.model(x, t_batch)
            x = x + velocity * dt

        return x
```

Choosing the Right Framework
Given the equivalence of these frameworks, how should you choose? Consider these factors:
| Factor | DDPM | Score/SDE | Flow Matching |
|---|---|---|---|
| Implementation | Simplest | Moderate | Simple |
| Theory | Variational | Rigorous | Elegant |
| Sampling speed | Slower | Flexible | Faster |
| Likelihood | ELBO only | Exact (ODE) | Exact |
| Conditioning | Well-studied | Well-studied | Emerging |
| Distillation | Good | Good | Excellent |
Practical Recommendations:
- For standard image generation: DDPM with epsilon prediction and cosine schedule
- For fast sampling or distillation: Flow matching or v-prediction
- For theoretical analysis or likelihood: SDE/ODE framework
- For maximum flexibility: Unified framework that supports all options
Summary
This section has unified the major perspectives on diffusion models into a coherent framework. Key takeaways:
- All frameworks learn the same thing: the score function, just with different parameterizations.
- DDPM is discretized VP-SDE: the familiar discrete Markov chain is an Euler-Maruyama approximation to a continuous SDE.
- Prediction targets are equivalent: epsilon, score, v, and x0 predictions can all be converted to each other.
- Loss weighting matters: Min-SNR and similar strategies improve sample quality by balancing contributions across timesteps.
- Flow matching generalizes diffusion: by learning velocity fields directly, it provides a simpler alternative that often trains faster.
This unified understanding empowers you to make informed choices about model design, combine ideas from different frameworks, and push the boundaries of generative modeling. The field continues to evolve, but these fundamental connections will remain relevant as new techniques emerge.