Learning Objectives
By the end of this section, you will:
- Understand the DDIM reformulation and how it differs from DDPM
- Derive the DDIM update equation from first principles
- Master the eta parameter for controlling stochasticity
- Learn why DDIM can skip steps without quality degradation
- Compare theoretical properties of DDPM vs DDIM sampling
Same Model, Different Sampler
The Key Insight
DDIM (Denoising Diffusion Implicit Models) is built on a crucial insight: the training objective for diffusion models doesn't uniquely specify the reverse process. There are infinitely many reverse processes that lead to the same training loss.
The Training Objective (Revisited)
Recall that DDPM training minimizes:

$$\mathcal{L}_{\text{simple}} = \mathbb{E}_{x_0,\, \epsilon,\, t}\left[\left\| \epsilon - \epsilon_\theta(x_t, t) \right\|^2\right]$$

This objective only requires the model to predict the noise $\epsilon$ added at timestep $t$. It says nothing about how we should use this prediction to sample.
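To make the objective concrete, here is a minimal sketch in code. The `model` is a placeholder for any noise-prediction network; the schedule names (`alphas_cumprod`) match the sampler below, but this is an illustration, not a specific training loop:

```python
import torch
import torch.nn.functional as F

def ddpm_loss(model, x0, alphas_cumprod):
    """Simplified DDPM training loss: MSE between sampled and predicted noise."""
    B = x0.shape[0]
    t = torch.randint(0, len(alphas_cumprod), (B,))      # random timestep per sample
    a_bar = alphas_cumprod[t].view(B, 1, 1, 1)           # broadcast over image dims
    eps = torch.randn_like(x0)                           # the noise we will ask the model to predict
    x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps   # forward process q(x_t | x_0)
    eps_pred = model(x_t, t)                             # model sees x_t and t, predicts eps
    return F.mse_loss(eps_pred, eps)
```

Note that nothing in this loss refers to how sampling will be done: that freedom is exactly what DDIM exploits.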
DDPM's Choice (Markovian)
DDPM chose to make the reverse process Markovian: each step depends only on the previous step. This led to the stochastic update:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z$$

where $z \sim \mathcal{N}(0, I)$ is fresh noise added at each step.
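As a reference point before the DDIM derivation, here is a minimal sketch of one ancestral DDPM step matching the update above (names like `betas` are illustrative; they come from whatever noise schedule you trained with):

```python
import torch

def ddpm_step(x_t, eps_pred, t, betas, alphas_cumprod):
    """One stochastic DDPM update: x_t -> x_{t-1}, adding fresh noise z."""
    alpha_t = 1.0 - betas[t]
    a_bar_t = alphas_cumprod[t]
    # Posterior mean: (x_t - beta_t / sqrt(1 - a_bar_t) * eps) / sqrt(alpha_t)
    mean = (x_t - betas[t] / torch.sqrt(1 - a_bar_t) * eps_pred) / torch.sqrt(alpha_t)
    if t > 0:
        z = torch.randn_like(x_t)            # fresh noise: this is what makes DDPM stochastic
        return mean + torch.sqrt(betas[t]) * z
    return mean                              # no noise is added at the final step
```

The `torch.randn_like` call at every step is the key difference from DDIM with eta = 0, which adds no noise at all.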
DDIM's Choice (Non-Markovian)
DDIM makes a different choice: the reverse process can be non-Markovian, meaning each step can implicitly depend on $x_0$. This leads to a family of reverse processes parameterized by $\sigma_t$:

$$q_\sigma(x_{t-1} \mid x_t, x_0) = \mathcal{N}\!\left(\sqrt{\bar\alpha_{t-1}}\,x_0 + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\cdot\frac{x_t - \sqrt{\bar\alpha_t}\,x_0}{\sqrt{1 - \bar\alpha_t}},\; \sigma_t^2 I\right)$$
The Magic
Every member of this family has the same marginals $q(x_t \mid x_0)$ as the DDPM forward process, so a single trained network serves every choice of $\sigma_t$: only the sampling procedure changes.
Mathematical Derivation
Let's derive the DDIM update rule step by step. The key is to think in terms of prediction rather than noise prediction.
Step 1: Predict $x_0$ from $x_t$
From the forward process, we know:

$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon$$

If we knew the noise $\epsilon$, we could solve for $x_0$:

$$x_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon}{\sqrt{\bar\alpha_t}}$$

Since our model predicts $\epsilon_\theta(x_t, t) \approx \epsilon$, we can estimate:

$$\hat{x}_0 = \frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}$$
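This inversion can be sanity-checked numerically: given the true noise, the formula recovers $x_0$ exactly (synthetic tensors, no trained model required):

```python
import torch

torch.manual_seed(0)
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)

x0 = torch.randn(2, 3, 8, 8)
eps = torch.randn_like(x0)
t = 500
a_bar = alphas_cumprod[t]

x_t = a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eps          # forward process
x0_rec = (x_t - (1 - a_bar).sqrt() * eps) / a_bar.sqrt()    # invert using the true noise

print(torch.allclose(x0_rec, x0, atol=1e-4))
```

With a trained model, $\epsilon_\theta$ replaces the true noise and the recovery becomes an estimate $\hat{x}_0$ rather than exact.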
Step 2: Choose the Direction
Now we want to go from $x_t$ to $x_{t-1}$. We know that $x_{t-1}$ should be closer to $x_0$ (less noisy). The forward process tells us:

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,x_0 + \sqrt{1 - \bar\alpha_{t-1}}\,\epsilon_{t-1}$$

Substituting our prediction $\hat{x}_0$:

$$x_{t-1} \approx \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_{t-1}}\,\epsilon_{t-1}$$
Step 3: The Noise Direction
Here's where DDIM differs from DDPM. Instead of drawing fresh noise for $\epsilon_{t-1}$, DDIM reuses the predicted noise direction $\epsilon_\theta(x_t, t)$, reserving variance $\sigma_t^2$ for optional stochasticity:

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\,\epsilon_\theta(x_t, t) + \sigma_t z$$

Expanding $\hat{x}_0$:

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\left(\frac{x_t - \sqrt{1 - \bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}\right) + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\,\epsilon_\theta(x_t, t) + \sigma_t z$$
Determinism Achieved
Setting $\sigma_t = 0$ removes the $\sigma_t z$ term entirely, so the update becomes deterministic: the same $x_T$ always produces the same $x_0$.
The DDIM Update Rule
Let's simplify and implement the DDIM update. Writing $\bar\alpha_t = \prod_{s=1}^{t} \alpha_s$ for the cumulative alpha product, the deterministic ($\eta = 0$) DDIM update is:

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1 - \bar\alpha_{t-1}}\,\epsilon_\theta(x_t, t)$$
```python
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn


class DDIMSampler:
    """
    DDIM (Denoising Diffusion Implicit Models) sampler.

    Unlike DDPM, DDIM provides:
    - Deterministic sampling (when eta=0)
    - Arbitrary number of sampling steps
    - Faster generation with minimal quality loss
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        timesteps: int = 1000,
        device: str = "cuda"
    ):
        """
        Initialize DDIM sampler.

        Args:
            model: Trained noise prediction network
            alphas_cumprod: Cumulative product of alphas from noise schedule
            timesteps: Total training timesteps
            device: Computation device
        """
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.T = timesteps
        self.device = device

    def sample(
        self,
        shape: tuple,
        num_steps: int = 50,
        eta: float = 0.0,
        x_T: Optional[torch.Tensor] = None,
        return_trajectory: bool = False
    ) -> torch.Tensor:
        """
        Generate samples using DDIM.

        Args:
            shape: Output shape (batch, channels, height, width)
            num_steps: Number of sampling steps (can be << T)
            eta: Stochasticity parameter (0 = deterministic, 1 = DDPM-like)
            x_T: Optional fixed starting noise (defaults to fresh Gaussian noise)
            return_trajectory: If True, return all intermediate states

        Returns:
            Generated samples in [-1, 1]
        """
        # Create descending subsequence of timesteps
        # e.g., for num_steps=50, T=1000: [980, 960, ..., 20, 0]
        timesteps = self._get_timestep_sequence(num_steps)

        # Start from provided noise, or pure noise
        if x_T is None:
            x_t = torch.randn(shape, device=self.device)
        else:
            x_t = x_T.to(self.device)

        trajectory = [x_t] if return_trajectory else None

        self.model.eval()
        with torch.no_grad():
            # Walk the descending sequence: each iteration denoises t -> t_prev
            for i in range(len(timesteps) - 1):
                t = timesteps[i]
                t_prev = timesteps[i + 1]

                x_t = self._ddim_step(x_t, t, t_prev, eta)

                if return_trajectory:
                    trajectory.append(x_t)

        if return_trajectory:
            return torch.stack(trajectory)
        return x_t

    def _ddim_step(
        self,
        x_t: torch.Tensor,
        t: int,
        t_prev: int,
        eta: float
    ) -> torch.Tensor:
        """
        Single DDIM update step.

        Args:
            x_t: Current noisy sample
            t: Current timestep
            t_prev: Target (earlier) timestep
            eta: Stochasticity parameter

        Returns:
            Updated sample x_{t_prev}
        """
        batch_size = x_t.shape[0]

        # Get cumulative alpha values at the two timesteps
        alpha_bar_t = self.alphas_cumprod[t]
        alpha_bar_t_prev = self.alphas_cumprod[t_prev]

        # Predict noise
        t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)
        eps_pred = self.model(x_t, t_batch)

        # Predict x_0:
        # x_0 = (x_t - sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_bar_t)
        x0_pred = (x_t - torch.sqrt(1 - alpha_bar_t) * eps_pred) / torch.sqrt(alpha_bar_t)

        # Optional: clip x_0 prediction to the data range [-1, 1]
        x0_pred = torch.clamp(x0_pred, -1, 1)

        # Compute sigma for optional stochasticity
        sigma_t = self._compute_sigma(alpha_bar_t, alpha_bar_t_prev, eta)

        # Deterministic "direction pointing to x_t"
        direction = torch.sqrt(1 - alpha_bar_t_prev - sigma_t**2) * eps_pred

        # Compute x_{t_prev}
        x_prev = torch.sqrt(alpha_bar_t_prev) * x0_pred + direction

        # Add noise if stochastic (never at the final step)
        if eta > 0 and t_prev > 0:
            noise = torch.randn_like(x_t)
            x_prev = x_prev + sigma_t * noise

        return x_prev

    def _compute_sigma(
        self,
        alpha_bar_t: torch.Tensor,
        alpha_bar_t_prev: torch.Tensor,
        eta: float
    ) -> torch.Tensor:
        """
        Compute sigma for stochastic DDIM.

        When eta = 0, sigma = 0 (deterministic)
        When eta = 1, matches DDPM variance
        """
        # Variance formula from the DDIM paper
        sigma = eta * torch.sqrt(
            (1 - alpha_bar_t_prev) / (1 - alpha_bar_t) *
            (1 - alpha_bar_t / alpha_bar_t_prev)
        )
        return sigma

    def _get_timestep_sequence(self, num_steps: int) -> List[int]:
        """
        Create evenly spaced timesteps in descending order.

        For num_steps=50, T=1000:
        Returns [980, 960, ..., 20, 0]
        """
        # Method 1: Linear spacing (simple)
        step_size = self.T // num_steps
        timesteps = list(range(0, self.T, step_size))[:num_steps]
        return timesteps[::-1]  # Descending order for sampling

    def _get_timestep_sequence_quadratic(self, num_steps: int) -> List[int]:
        """
        Create quadratically spaced timesteps.

        Spends more steps at low noise levels (where details matter).
        """
        timesteps = (
            (np.linspace(0, np.sqrt(self.T), num_steps) ** 2)
            .astype(int)
            .tolist()
        )
        timesteps = sorted(set(timesteps), reverse=True)
        return timesteps


# Usage example
def generate_with_ddim(model, noise_schedule, num_images=4, num_steps=50):
    """
    Generate images using DDIM sampling.

    Args:
        model: Trained U-Net
        noise_schedule: NoiseSchedule object with alphas_cumprod
        num_images: Number of images to generate
        num_steps: Sampling steps (e.g., 50 instead of 1000)

    Returns:
        Generated images
    """
    sampler = DDIMSampler(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod,
        timesteps=noise_schedule.T,
        device="cuda"
    )

    # Generate samples
    samples = sampler.sample(
        shape=(num_images, 3, 64, 64),
        num_steps=num_steps,
        eta=0.0  # Fully deterministic
    )

    # Convert to [0, 1] for visualization
    samples = (samples + 1) / 2
    samples = samples.clamp(0, 1)

    return samples
```

Key Implementation Details
The Eta Parameter
The $\eta$ parameter controls the amount of stochasticity in DDIM:
| Eta Value | Behavior | Properties | Use Case |
|---|---|---|---|
| eta = 0 | Fully deterministic | Same x_T -> same x_0 | Reproducibility, interpolation |
| eta = 0.5 | Partially stochastic | Some randomness | Balance diversity/quality |
| eta = 1.0 | DDPM-equivalent variance | Maximum diversity | Maximum sample diversity |
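The eta = 1.0 row can be verified directly: with $\eta = 1$, $\sigma_t^2$ reduces to the DDPM posterior variance $\tilde\beta_t = \frac{1 - \bar\alpha_{t-1}}{1 - \bar\alpha_t}\,\beta_t$, since $1 - \bar\alpha_t/\bar\alpha_{t-1} = 1 - \alpha_t = \beta_t$. A quick numerical check on a linear beta schedule (a standalone sketch, not tied to the sampler class):

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1 - betas, dim=0)

t = 500
a_bar_t, a_bar_prev = alphas_cumprod[t], alphas_cumprod[t - 1]

# DDIM sigma^2 with eta = 1
sigma_sq = (1 - a_bar_prev) / (1 - a_bar_t) * (1 - a_bar_t / a_bar_prev)

# DDPM posterior variance beta-tilde
beta_tilde = (1 - a_bar_prev) / (1 - a_bar_t) * betas[t]

print(torch.allclose(sigma_sq, beta_tilde, atol=1e-6))
```

So eta smoothly interpolates between a fully deterministic sampler and one with DDPM's noise level.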
```python
def demonstrate_eta_effect(model, noise_schedule, initial_noise, num_runs=5):
    """
    Show how eta affects sample diversity.
    """
    sampler = DDIMSampler(model, noise_schedule.alphas_cumprod)

    eta_values = [0.0, 0.25, 0.5, 0.75, 1.0]

    for eta in eta_values:
        samples = []

        # Run multiple times from the SAME initial noise, passing it via the
        # sampler's x_T argument. Only the in-loop noise differs between runs,
        # so cross-run variance isolates the effect of eta.
        for _ in range(num_runs):
            sample = sampler.sample(
                shape=tuple(initial_noise.shape),
                num_steps=50,
                eta=eta,
                x_T=initial_noise.clone()
            )
            samples.append(sample)

        samples = torch.cat(samples)

        # Compute variance across runs
        variance = samples.var(dim=0).mean().item()

        print(f"eta = {eta:.2f} | Cross-run variance: {variance:.6f}")

# Expected output:
# eta = 0.00 | Cross-run variance: 0.000000 (all identical!)
# eta = 0.25 | Cross-run variance: 0.002341
# eta = 0.50 | Cross-run variance: 0.008756
# eta = 0.75 | Cross-run variance: 0.018234
# eta = 1.00 | Cross-run variance: 0.031567
```

eta = 0 is Special
Only at $\eta = 0$ does the same $x_T$ always map to the same image; this determinism is what makes inversion and interpolation possible.
Skipping Steps Correctly
Unlike DDPM, DDIM can skip timesteps without quality degradation. The key is that the update is written directly in terms of the cumulative alpha values $\bar\alpha_t$.
Why DDPM Can't Skip
DDPM uses step-wise alphas $\alpha_t = 1 - \beta_t$: its update is derived for the single transition from $t$ to $t-1$. When skipping, the math breaks because $\beta_t$ describes only one step of noise; it is not the variance of the composed multi-step transition.
Why DDIM Can Skip
DDIM uses cumulative alphas, which directly relate any $x_t$ to $x_0$:

$$x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1 - \bar\alpha_t}\,\epsilon$$

This relationship holds for any $t$, regardless of which intermediate steps we skip!
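This can be checked directly: composing the per-step alphas between any two timesteps reproduces the ratio of cumulative alphas, so a single jump using $\bar\alpha$ lands where the step-by-step path would (a small numerical sketch):

```python
import torch

betas = torch.linspace(1e-4, 0.02, 1000)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

t, t_skip = 700, 300

# Product of the step-wise alphas from t_skip+1 .. t ...
composed = torch.prod(alphas[t_skip + 1 : t + 1])

# ... equals the ratio of cumulative alphas, so the jump t -> t_skip is consistent
ratio = alphas_cumprod[t] / alphas_cumprod[t_skip]

print(torch.allclose(composed, ratio, atol=1e-4))
```

Because the DDIM update only ever reads $\bar\alpha_t$ and $\bar\alpha_{t'}$, it never needs the intermediate steps at all.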
```python
from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F


class DDIMSamplerAdvanced(DDIMSampler):
    """
    Advanced DDIM sampler with multiple timestep selection strategies.

    Inherits _ddim_step from DDIMSampler.
    """

    def __init__(self, model, alphas_cumprod, T=1000, device="cuda"):
        super().__init__(model, alphas_cumprod, timesteps=T, device=device)

    def get_timesteps(
        self,
        num_steps: int,
        schedule: str = "uniform"
    ) -> List[int]:
        """
        Get timestep sequence for sampling.

        Args:
            num_steps: Desired number of steps
            schedule: "uniform", "quadratic", or "log"

        Returns:
            List of timesteps in descending order
        """
        if schedule == "uniform":
            # Evenly spaced: [999, 979, 959, ..., 19]
            step_size = self.T // num_steps
            timesteps = list(range(self.T - 1, -1, -step_size))[:num_steps]

        elif schedule == "quadratic":
            # More steps at low noise levels
            timesteps = (
                (np.linspace(0, np.sqrt(self.T), num_steps) ** 2)
                .astype(int)
                .tolist()
            )
            timesteps = sorted(set(timesteps), reverse=True)

        elif schedule == "log":
            # Logarithmic spacing
            timesteps = (
                np.exp(np.linspace(0, np.log(self.T), num_steps))
                .astype(int)
                .tolist()
            )
            timesteps = sorted(set(timesteps), reverse=True)

        else:
            raise ValueError(f"Unknown schedule: {schedule}")

        # Ensure we start at T-1 and end at 0
        if timesteps[0] != self.T - 1:
            timesteps[0] = self.T - 1
        if timesteps[-1] != 0:
            timesteps.append(0)

        return timesteps

    def sample_with_custom_steps(
        self,
        shape: tuple,
        timesteps: List[int],
        eta: float = 0.0,
        x_T: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Sample with a custom (descending) timestep sequence.

        Allows fine-grained control over which steps to use.
        """
        # Start from provided noise or generate new
        if x_T is None:
            x_t = torch.randn(shape, device=self.device)
        else:
            x_t = x_T.to(self.device)

        self.model.eval()
        with torch.no_grad():
            # Each iteration denoises timesteps[i] -> timesteps[i + 1]
            for i in range(len(timesteps) - 1):
                x_t = self._ddim_step(x_t, timesteps[i], timesteps[i + 1], eta)

        return x_t


# Compare different step counts
def compare_step_counts(model, noise_schedule, reference_samples=None):
    """
    Compare DDIM quality at different step counts.
    """
    sampler = DDIMSamplerAdvanced(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod,
        T=noise_schedule.T
    )

    step_counts = [1000, 250, 100, 50, 25, 10]
    results = {}

    # Fix initial noise for fair comparison
    torch.manual_seed(42)
    x_T = torch.randn(16, 3, 64, 64, device="cuda")

    for steps in step_counts:
        timesteps = sampler.get_timesteps(steps, schedule="uniform")

        samples = sampler.sample_with_custom_steps(
            shape=(16, 3, 64, 64),
            timesteps=timesteps,
            eta=0.0,
            x_T=x_T.clone()
        )

        # Compute MSE vs full-step reference
        if reference_samples is not None:
            mse = F.mse_loss(samples, reference_samples).item()
            results[steps] = {"mse": mse, "samples": samples}
        else:
            results[steps] = {"samples": samples}

        print(f"Steps: {steps:4d} | Timesteps used: {len(timesteps)}")

    return results
```

Step Count Recommendations
In practice, 25-50 uniform steps with $\eta = 0$ recover most of the full-schedule quality (see the table below); use more steps only when fine detail is critical.
DDPM vs DDIM Comparison
Let's summarize the key differences between DDPM and DDIM sampling:
| Property | DDPM (Ancestral) | DDIM (eta=0) |
|---|---|---|
| Determinism | Stochastic | Deterministic |
| Minimum steps | ~200+ for quality | 25-50 sufficient |
| Speed (typical) | ~20s per image | ~1-2s per image |
| Reproducibility | Not reproducible | Fully reproducible |
| Image encoding | Not possible | Possible (inversion) |
| Interpolation | Poor (noisy) | Smooth semantic |
| Same trained model | Yes | Yes |
```python
import time

import torch
import torch.nn.functional as F


def comprehensive_comparison(model, noise_schedule):
    """
    Side-by-side comparison of DDPM vs DDIM.

    Assumes ddpm_sample / ddpm_sample_from helpers from the DDPM section.
    """
    print("=" * 60)
    print("DDPM vs DDIM Comparison")
    print("=" * 60)

    # 1. Speed comparison

    # DDPM (1000 steps)
    start = time.time()
    ddpm_samples = ddpm_sample(model, noise_schedule, shape=(4, 3, 64, 64))
    ddpm_time = time.time() - start

    # DDIM (50 steps)
    start = time.time()
    ddim_sampler = DDIMSamplerAdvanced(
        model, noise_schedule.alphas_cumprod, T=noise_schedule.T
    )
    ts_50 = ddim_sampler.get_timesteps(50, schedule="uniform")
    ddim_samples = ddim_sampler.sample_with_custom_steps((4, 3, 64, 64), ts_50)
    ddim_time = time.time() - start

    print("\n1. SPEED COMPARISON")
    print(f"   DDPM (1000 steps): {ddpm_time:.2f}s")
    print(f"   DDIM (50 steps):   {ddim_time:.2f}s")
    print(f"   Speedup: {ddpm_time / ddim_time:.1f}x")

    # 2. Reproducibility test
    print("\n2. REPRODUCIBILITY TEST")

    x_T = torch.randn(1, 3, 64, 64, device="cuda")

    # DDPM from the same initial noise (still stochastic inside the loop)
    ddpm_run1 = ddpm_sample_from(model, x_T.clone(), noise_schedule)
    ddpm_run2 = ddpm_sample_from(model, x_T.clone(), noise_schedule)
    ddpm_diff = F.mse_loss(ddpm_run1, ddpm_run2).item()

    # DDIM from the same initial noise: pass x_T explicitly
    ddim_run1 = ddim_sampler.sample_with_custom_steps(
        (1, 3, 64, 64), ts_50, eta=0.0, x_T=x_T.clone()
    )
    ddim_run2 = ddim_sampler.sample_with_custom_steps(
        (1, 3, 64, 64), ts_50, eta=0.0, x_T=x_T.clone()
    )
    ddim_diff = F.mse_loss(ddim_run1, ddim_run2).item()

    print(f"   DDPM: MSE between runs = {ddpm_diff:.6f} (stochastic)")
    print(f"   DDIM: MSE between runs = {ddim_diff:.6f} (deterministic)")

    # 3. Quality at different step counts
    print("\n3. QUALITY vs STEPS")

    for steps in [100, 50, 25, 10]:
        ts = ddim_sampler.get_timesteps(steps, schedule="uniform")
        samples = ddim_sampler.sample_with_custom_steps((4, 3, 64, 64), ts, eta=0.0)
        # Note: not directly comparable to the DDPM samples above, since each call
        # starts from different noise. In practice, use FID for proper evaluation.
        print(f"   DDIM {steps:4d} steps: samples generated")

    return {
        "ddpm_time": ddpm_time,
        "ddim_time": ddim_time,
        "speedup": ddpm_time / ddim_time,
    }
```

Summary
DDIM represents a fundamental shift in how we think about diffusion sampling:
- Non-Markovian formulation: By allowing each step to depend on the $\hat{x}_0$ prediction, DDIM breaks free from the step-by-step constraints of DDPM
- Determinism through eta=0: Setting $\eta = 0$ eliminates all stochasticity, enabling reproducible generation
- Arbitrary step counts: Using cumulative alphas allows skipping from any $t$ to any $t' < t$ without mathematical inconsistency
- Same trained model: No retraining required; DDIM is purely a sampling-time modification
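The "smooth semantic interpolation" property from the comparison table follows from determinism: interpolate two noise tensors and decode each with eta = 0. Gaussian latents are usually interpolated spherically (slerp) rather than linearly, to preserve their norm. A sketch (the commented sampler call assumes an API like `sample_with_custom_steps` that accepts a fixed `x_T`):

```python
import torch

def slerp(z1: torch.Tensor, z2: torch.Tensor, alpha: float) -> torch.Tensor:
    """Spherical interpolation between two noise tensors (preserves Gaussian-like norm)."""
    z1_flat, z2_flat = z1.flatten(), z2.flatten()
    cos_angle = torch.dot(z1_flat, z2_flat) / (z1_flat.norm() * z2_flat.norm())
    theta = torch.acos(torch.clamp(cos_angle, -1.0, 1.0))
    return (
        torch.sin((1 - alpha) * theta) * z1 + torch.sin(alpha * theta) * z2
    ) / torch.sin(theta)

# With a deterministic (eta=0) sampler, each interpolated x_T maps to a coherent image:
# for a in torch.linspace(0, 1, 8):
#     x_T = slerp(z1, z2, a.item())
#     img = sampler.sample_with_custom_steps((1, 3, 64, 64), timesteps, eta=0.0, x_T=x_T)
```

With a stochastic sampler, the in-loop noise would wash out the correspondence between neighboring latents, which is why interpolation is listed as "poor" for DDPM.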
Coming Up Next
The DDIM framework opens the door to many applications that were impossible with stochastic sampling. Image editing via inversion, controllable generation through latent manipulation, and real-time generation all become feasible once we have a deterministic mapping between noise and images.