Learning Objectives
By the end of this section, you will be able to:
- Implement the complete DDPM sampling algorithm from pure noise to clean images
- Apply DDIM for accelerated deterministic sampling with fewer steps
- Configure sampling parameters for optimal quality-speed tradeoffs
- Generate large batches of images efficiently using GPU parallelism
The Big Picture
Image generation in diffusion models is the reverse of the forward process: we start from pure Gaussian noise and iteratively denoise to produce clean images. The trained model predicts the noise at each step, allowing us to reconstruct the original signal.
The Sampling Process: Starting from $x_T \sim \mathcal{N}(0, \mathbf{I})$, we iteratively compute $x_{t-1}$ from $x_t$ using the model's noise prediction. After $T$ steps, we obtain a clean sample $x_0$.
*Interactive demo: watch the diffusion model progressively denoise a 1D signal from pure noise to clean data.*
What's Happening at Each Step?
1. Predict Noise (or $x_0$)
The neural network takes the current noisy signal $x_t$ and timestep $t$, then predicts either the noise $\epsilon$ or the clean signal $x_0$.
2. Compute Mean
Using the prediction, we compute the mean of $p(x_{t-1} \mid x_t)$, which tells us where to move in the next step.
3. Add Noise (DDPM) or Not (DDIM)
DDPM adds scaled Gaussian noise to maintain the Markov property. DDIM skips this for deterministic, faster sampling.
4. Update $x$
We update $x_{t-1} = \mu + \sigma z$ (DDPM) or $x_{t-1} = \mu$ (DDIM), moving one step closer to the clean data.
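The four steps can be traced end-to-end for a single scalar value. The sketch below is a toy illustration, not the full sampler: the constant `pred_noise=0.5` stands in for a trained network's prediction, the 4-step schedule is made up, and the DDPM noise uses the simple $\sigma_t^2 = \beta_t$ choice.

```python
import math
import random

def reverse_step(x_t, t, alphas, alphas_cumprod, betas, pred_noise, ddim=False):
    """One reverse step for a scalar x_t, following the four steps above."""
    # Step 2: mean of p(x_{t-1} | x_t), built from the noise prediction (step 1)
    mean = (1.0 / math.sqrt(alphas[t])) * (
        x_t - betas[t] / math.sqrt(1.0 - alphas_cumprod[t]) * pred_noise
    )
    # Steps 3-4: DDPM adds scaled noise; DDIM (and the final t = 0 step) does not
    if ddim or t == 0:
        return mean
    return mean + math.sqrt(betas[t]) * random.gauss(0.0, 1.0)

# Tiny illustrative 4-step schedule
betas = [0.1, 0.2, 0.3, 0.4]
alphas = [1 - b for b in betas]
alphas_cumprod, acc = [], 1.0
for a in alphas:
    acc *= a
    alphas_cumprod.append(acc)

x = 1.5  # pretend this is pure noise at t = 3
for t in reversed(range(4)):
    x = reverse_step(x, t, alphas, alphas_cumprod, betas, pred_noise=0.5, ddim=True)
```

With `ddim=True` the trajectory is fully deterministic, so rerunning the loop from the same starting value reproduces the same output.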
Key Equations
The DDPM update:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, \mathbf{I})$$

The deterministic DDIM update ($\eta = 0$), where $\hat{x}_0$ is the model's estimate of the clean signal:

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t-1}}\,\epsilon_\theta(x_t, t), \qquad \hat{x}_0 = \frac{x_t - \sqrt{1-\bar\alpha_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar\alpha_t}}$$

How the common samplers compare:
| Method | Steps | Stochastic | Quality | Speed |
|---|---|---|---|---|
| DDPM | 1000 | Yes | Best | Slow |
| DDIM | 50-100 | No | Near DDPM | Fast |
| DDIM (eta=1) | 50-100 | Yes | Good | Fast |
| DPM-Solver | 10-20 | No | Good | Very Fast |
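Since each sampling step costs one network forward pass (two with classifier-free guidance), the speed column follows directly from the step counts. A quick sanity check:

```python
def num_function_evals(steps: int, guidance: bool = False) -> int:
    """Network forward passes per image: one per step, two with classifier-free guidance."""
    return steps * (2 if guidance else 1)

# DDPM at 1000 steps vs DDIM at 50 steps: 20x fewer network evaluations
print(num_function_evals(1000) / num_function_evals(50))   # 20.0
# Classifier-free guidance doubles the cost at any step count
print(num_function_evals(50, guidance=True))               # 100
```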
DDPM Sampling
The DDPM Reverse Process
The DDPM sampling follows the reverse process derived from the variational lower bound. At each step, we predict the noise and compute the mean and variance of the reverse distribution:
$$\mu_\theta(x_t, t) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t, t)\right)$$

where $\alpha_t = 1 - \beta_t$, $\bar\alpha_t = \prod_{s=1}^{t}\alpha_s$, and $\sigma_t^2 = \beta_t$ or $\sigma_t^2 = \tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t$.
```python
import torch
import torch.nn as nn
from typing import Optional, Callable
from tqdm import tqdm

class DDPMSampler:
    """DDPM sampling with various noise schedule options."""

    def __init__(
        self,
        model: nn.Module,
        betas: torch.Tensor,
        device: str = "cuda",
    ):
        self.model = model.to(device).eval()
        self.device = device

        # Precompute coefficients
        self.betas = betas.to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([
            torch.tensor([1.0], device=device),
            self.alphas_cumprod[:-1]
        ])

        # Coefficients for sampling
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Posterior variance (beta-tilde)
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) /
            (1.0 - self.alphas_cumprod)
        )

        self.timesteps = len(betas)

    @torch.no_grad()
    def sample(
        self,
        shape: tuple,
        num_steps: Optional[int] = None,
        return_intermediates: bool = False,
        progress_callback: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Generate samples using the DDPM reverse process.

        Args:
            shape: Output shape (batch_size, channels, height, width)
            num_steps: Number of steps (defaults to full timesteps).
                Note: the update below assumes consecutive timesteps, so
                subsampled schedules are only approximate; prefer DDIM
                when reducing steps.
            return_intermediates: Whether to return intermediate samples
            progress_callback: Optional callback for progress updates

        Returns:
            Generated samples (and intermediates if requested)
        """
        if num_steps is None:
            num_steps = self.timesteps

        # Start from pure noise
        x = torch.randn(shape, device=self.device)

        intermediates = [x.clone()] if return_intermediates else None

        # Reverse process, from t = T-1 down to t = 0
        timesteps = torch.linspace(
            self.timesteps - 1, 0, num_steps, dtype=torch.long, device=self.device
        )

        for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
            t_batch = torch.full((shape[0],), t, device=self.device, dtype=torch.long)

            # Predict noise
            pred_noise = self.model(x, t_batch)

            # Compute the posterior mean
            beta = self.betas[t]
            mean = self.sqrt_recip_alphas[t] * (
                x - beta / self.sqrt_one_minus_alphas_cumprod[t] * pred_noise
            )

            # Add noise (except at t = 0)
            if t > 0:
                noise = torch.randn_like(x)
                sigma = torch.sqrt(self.posterior_variance[t])
                x = mean + sigma * noise
            else:
                x = mean

            if return_intermediates and i % max(1, num_steps // 10) == 0:
                intermediates.append(x.clone())

            if progress_callback:
                progress_callback(i / num_steps)

        # Final clipping to the data range
        x = x.clamp(-1, 1)

        if return_intermediates:
            return x, intermediates
        return x

    @torch.no_grad()
    def sample_with_guidance(
        self,
        shape: tuple,
        condition: torch.Tensor,
        guidance_scale: float = 7.5,
        num_steps: Optional[int] = None,
    ) -> torch.Tensor:
        """Classifier-free guidance sampling."""
        if num_steps is None:
            num_steps = self.timesteps

        x = torch.randn(shape, device=self.device)
        batch_size = shape[0]

        timesteps = torch.linspace(
            self.timesteps - 1, 0, num_steps, dtype=torch.long, device=self.device
        )

        for t in tqdm(timesteps, desc="Sampling with guidance"):
            t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)

            # Conditional prediction
            pred_cond = self.model(x, t_batch, condition)

            # Unconditional prediction (null condition)
            null_condition = torch.zeros_like(condition)
            pred_uncond = self.model(x, t_batch, null_condition)

            # Classifier-free guidance: extrapolate away from the unconditional prediction
            pred_noise = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            # DDPM update
            beta = self.betas[t]
            mean = self.sqrt_recip_alphas[t] * (
                x - beta / self.sqrt_one_minus_alphas_cumprod[t] * pred_noise
            )

            if t > 0:
                noise = torch.randn_like(x)
                sigma = torch.sqrt(self.posterior_variance[t])
                x = mean + sigma * noise
            else:
                x = mean

        return x.clamp(-1, 1)
```

Posterior Variance Choice
The sampler above uses the posterior variance $\tilde\beta_t = \frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t$, the variance of the true posterior $q(x_{t-1} \mid x_t, x_0)$. Setting $\sigma_t^2 = \beta_t$ is the other standard choice; both work well in practice, with $\tilde\beta_t$ being slightly smaller.
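To see how the two variance choices differ numerically, the sketch below evaluates $\tilde\beta_t = \beta_t(1-\bar\alpha_{t-1})/(1-\bar\alpha_t)$ against $\beta_t$ on an illustrative linear beta schedule: $\tilde\beta_t$ is exactly zero at $t = 0$, never exceeds $\beta_t$, and approaches it as $t$ grows.

```python
# Linear beta schedule (values illustrative, matching the common DDPM setup)
T = 1000
betas = [1e-4 + (0.02 - 1e-4) * t / (T - 1) for t in range(T)]

# Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)
alphas_cumprod = []
acc = 1.0
for b in betas:
    acc *= 1.0 - b
    alphas_cumprod.append(acc)

def posterior_variance(t: int) -> float:
    """beta_tilde_t = beta_t * (1 - abar_{t-1}) / (1 - abar_t), with abar = 1 before t = 0."""
    abar_prev = alphas_cumprod[t - 1] if t > 0 else 1.0
    return betas[t] * (1.0 - abar_prev) / (1.0 - alphas_cumprod[t])

print(posterior_variance(0))               # exactly 0.0
print(posterior_variance(999) / betas[999])  # close to 1 at large t
```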
DDIM Sampling
Deterministic Sampling
DDIM (Denoising Diffusion Implicit Models) enables deterministic sampling and allows using far fewer steps without retraining. The key insight is that the same trained model can be used with a non-Markovian reverse process:

$$x_{t-1} = \sqrt{\bar\alpha_{t-1}}\,\hat{x}_0 + \sqrt{1-\bar\alpha_{t-1}-\sigma_t^2}\,\epsilon_\theta(x_t, t) + \sigma_t z, \qquad \sigma_t = \eta\,\sqrt{\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}}\,\sqrt{1-\frac{\bar\alpha_t}{\bar\alpha_{t-1}}}$$

where $\eta = 0$ recovers fully deterministic sampling and $\eta = 1$ matches DDPM-level stochasticity.
```python
import torch
import torch.nn as nn
from typing import List
from tqdm import tqdm

class DDIMSampler:
    """DDIM sampler for fast, deterministic generation."""

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda",
    ):
        self.model = model.to(device).eval()
        self.device = device
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.total_timesteps = len(alphas_cumprod)

    def _get_timesteps(self, num_steps: int) -> torch.Tensor:
        """Get evenly spaced timesteps, in descending order."""
        timesteps = torch.linspace(
            0, self.total_timesteps - 1, num_steps, dtype=torch.long
        )
        return timesteps.flip(0).to(self.device)

    def _get_alpha_bars(self, timesteps: torch.Tensor) -> tuple:
        """Get alpha_bar and alpha_bar_prev for descending timesteps."""
        alpha_bar = self.alphas_cumprod[timesteps]
        # "prev" means the next (lower) timestep in the schedule;
        # the final step targets alpha_bar = 1 (clean data)
        alpha_bar_prev = torch.cat([
            self.alphas_cumprod[timesteps[1:]],
            torch.tensor([1.0], device=self.device),
        ])
        return alpha_bar, alpha_bar_prev

    @torch.no_grad()
    def sample(
        self,
        shape: tuple,
        num_steps: int = 50,
        eta: float = 0.0,  # 0 = deterministic, 1 = DDPM-like stochasticity
        return_intermediates: bool = False,
    ) -> torch.Tensor:
        """Generate samples using DDIM.

        Args:
            shape: Output shape (batch_size, channels, height, width)
            num_steps: Number of sampling steps (can be far fewer than training steps)
            eta: Controls stochasticity (0 = deterministic, 1 = full noise)
            return_intermediates: Return intermediate samples

        Returns:
            Generated samples
        """
        batch_size = shape[0]

        # Get timestep schedule
        timesteps = self._get_timesteps(num_steps)

        # Start from noise
        x = torch.randn(shape, device=self.device)

        intermediates = [x.clone()] if return_intermediates else None

        for i, t in enumerate(tqdm(timesteps, desc=f"DDIM ({num_steps} steps)")):
            t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)

            # Get alpha_bar for this step and the next (lower) one
            alpha_bar = self.alphas_cumprod[t]
            if i < len(timesteps) - 1:
                alpha_bar_prev = self.alphas_cumprod[timesteps[i + 1]]
            else:
                alpha_bar_prev = torch.tensor(1.0, device=self.device)

            # Predict noise
            pred_noise = self.model(x, t_batch)

            # Predict x0
            pred_x0 = (x - torch.sqrt(1 - alpha_bar) * pred_noise) / torch.sqrt(alpha_bar)
            pred_x0 = pred_x0.clamp(-1, 1)  # Clamp for stability

            # Compute sigma for stochasticity
            sigma = eta * torch.sqrt(
                (1 - alpha_bar_prev) / (1 - alpha_bar) *
                (1 - alpha_bar / alpha_bar_prev)
            )

            # Direction pointing to x_t
            dir_xt = torch.sqrt(1 - alpha_bar_prev - sigma**2) * pred_noise

            # Add noise only if eta > 0 and this is not the final step
            if eta > 0 and i < len(timesteps) - 1:
                noise = torch.randn_like(x)
                x = torch.sqrt(alpha_bar_prev) * pred_x0 + dir_xt + sigma * noise
            else:
                x = torch.sqrt(alpha_bar_prev) * pred_x0 + dir_xt

            if return_intermediates and i % max(1, num_steps // 10) == 0:
                intermediates.append(x.clone())

        x = x.clamp(-1, 1)

        if return_intermediates:
            return x, intermediates
        return x

    @torch.no_grad()
    def sample_interpolation(
        self,
        x0_start: torch.Tensor,
        x0_end: torch.Tensor,
        num_interpolations: int = 10,
        num_steps: int = 50,
    ) -> List[torch.Tensor]:
        """Generate interpolations between two images via latent space.

        Args:
            x0_start: Starting image
            x0_end: Ending image
            num_interpolations: Number of interpolation points
            num_steps: Sampling steps for each interpolation

        Returns:
            List of interpolated images
        """
        # Encode both images to latents (add noise up to the first timestep)
        timesteps = self._get_timesteps(num_steps)
        T = timesteps[0]

        noise_start = torch.randn_like(x0_start)
        noise_end = torch.randn_like(x0_end)

        alpha_bar_T = self.alphas_cumprod[T]
        xT_start = torch.sqrt(alpha_bar_T) * x0_start + torch.sqrt(1 - alpha_bar_T) * noise_start
        xT_end = torch.sqrt(alpha_bar_T) * x0_end + torch.sqrt(1 - alpha_bar_T) * noise_end

        # Generate interpolations
        interpolations = []
        for i in range(num_interpolations):
            alpha = i / (num_interpolations - 1)

            # Spherical interpolation in latent space
            xT_interp = self._slerp(xT_start, xT_end, alpha)

            # Denoise deterministically
            sample = self._denoise_from(xT_interp, timesteps)
            interpolations.append(sample)

        return interpolations

    def _slerp(
        self,
        x0: torch.Tensor,
        x1: torch.Tensor,
        alpha: float,
    ) -> torch.Tensor:
        """Spherical linear interpolation between two latents."""
        # Flatten for the angle computation
        x0_flat = x0.view(x0.shape[0], -1)
        x1_flat = x1.view(x1.shape[0], -1)

        # Normalize to unit vectors
        x0_norm = x0_flat / x0_flat.norm(dim=-1, keepdim=True)
        x1_norm = x1_flat / x1_flat.norm(dim=-1, keepdim=True)

        # Angle between the latents
        dot = (x0_norm * x1_norm).sum(dim=-1, keepdim=True)
        omega = torch.acos(dot.clamp(-1, 1))

        sin_omega = torch.sin(omega)
        if sin_omega.abs().min() < 1e-6:
            # Nearly parallel: fall back to linear interpolation
            result = (1 - alpha) * x0_flat + alpha * x1_flat
        else:
            result = (
                torch.sin((1 - alpha) * omega) / sin_omega * x0_flat +
                torch.sin(alpha * omega) / sin_omega * x1_flat
            )

        return result.view(x0.shape)

    def _denoise_from(
        self,
        xT: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Deterministically denoise (eta = 0) from a given latent."""
        x = xT.clone()
        batch_size = x.shape[0]

        for i, t in enumerate(timesteps):
            t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)

            alpha_bar = self.alphas_cumprod[t]
            alpha_bar_prev = (
                self.alphas_cumprod[timesteps[i + 1]]
                if i < len(timesteps) - 1
                else torch.tensor(1.0, device=self.device)
            )

            pred_noise = self.model(x, t_batch)
            pred_x0 = (x - torch.sqrt(1 - alpha_bar) * pred_noise) / torch.sqrt(alpha_bar)
            pred_x0 = pred_x0.clamp(-1, 1)

            dir_xt = torch.sqrt(1 - alpha_bar_prev) * pred_noise
            x = torch.sqrt(alpha_bar_prev) * pred_x0 + dir_xt

        return x.clamp(-1, 1)
```

| Eta Value | Behavior | Use Case |
|---|---|---|
| 0.0 | Fully deterministic | Reproducible results, interpolation |
| 0.5 | Mild stochasticity | Balance of diversity and quality |
| 1.0 | DDPM-equivalent noise | Maximum diversity |
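Plugging numbers into the $\sigma_t$ formula makes the table concrete: $\sigma_t$ scales linearly with eta and vanishes at eta = 0. The $\bar\alpha$ values below are illustrative placeholders, not taken from a real schedule.

```python
import math

def ddim_sigma(eta: float, alpha_bar: float, alpha_bar_prev: float) -> float:
    """The sigma_t of the DDIM update, as computed inside DDIMSampler.sample."""
    return eta * math.sqrt(
        (1 - alpha_bar_prev) / (1 - alpha_bar)
        * (1 - alpha_bar / alpha_bar_prev)
    )

# Illustrative alpha_bar values for two adjacent schedule steps
abar, abar_prev = 0.30, 0.55

for eta in (0.0, 0.5, 1.0):
    print(f"eta={eta}: sigma={ddim_sigma(eta, abar, abar_prev):.4f}")
# eta = 0.0 gives sigma = 0 (fully deterministic); sigma grows linearly with eta
```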
Sampling Strategies
Timestep Scheduling
The choice of timestep schedule significantly affects sample quality, especially when using fewer steps:
```python
import torch

class TimestepScheduler:
    """Various timestep scheduling strategies for sampling."""

    @staticmethod
    def uniform(
        total_timesteps: int,
        num_steps: int,
    ) -> torch.Tensor:
        """Uniform spacing - simple but not always optimal."""
        return torch.linspace(total_timesteps - 1, 0, num_steps).long()

    @staticmethod
    def quadratic(
        total_timesteps: int,
        num_steps: int,
    ) -> torch.Tensor:
        """Quadratic spacing - denser steps at low noise levels (small t)."""
        t = torch.linspace(0, 1, num_steps)
        timesteps = (t ** 2) * (total_timesteps - 1)
        return timesteps.flip(0).long()

    @staticmethod
    def trailing(
        total_timesteps: int,
        num_steps: int,
    ) -> torch.Tensor:
        """Trailing spacing - includes the final timestep T-1."""
        step_ratio = total_timesteps // num_steps
        timesteps = torch.arange(step_ratio - 1, total_timesteps, step_ratio)[:num_steps]
        return timesteps.flip(0).long()

    @staticmethod
    def leading(
        total_timesteps: int,
        num_steps: int,
    ) -> torch.Tensor:
        """Leading spacing - starts at timestep 0 and misses T-1."""
        step_ratio = total_timesteps // num_steps
        timesteps = torch.arange(0, total_timesteps, step_ratio)[:num_steps]
        return timesteps.flip(0).long()

    @staticmethod
    def karras(
        total_timesteps: int,
        num_steps: int,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        rho: float = 7.0,
    ) -> torch.Tensor:
        """Karras et al. sigma schedule - often a strong default."""
        ramp = torch.linspace(0, 1, num_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

        # Crude linear sigma-to-timestep mapping; a faithful conversion
        # would invert the model's actual noise schedule
        timesteps = ((sigmas / sigma_max) * (total_timesteps - 1)).long()
        return timesteps


def compare_schedules():
    """Compare different schedules visually."""
    import matplotlib.pyplot as plt

    total_t = 1000
    num_steps = 50

    schedules = {
        "Uniform": TimestepScheduler.uniform(total_t, num_steps),
        "Quadratic": TimestepScheduler.quadratic(total_t, num_steps),
        "Trailing": TimestepScheduler.trailing(total_t, num_steps),
        "Karras": TimestepScheduler.karras(total_t, num_steps),
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    for name, timesteps in schedules.items():
        ax.plot(range(num_steps), timesteps.numpy(), label=name, marker=".")

    ax.set_xlabel("Step")
    ax.set_ylabel("Timestep")
    ax.set_title("Timestep Schedule Comparison")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()
```

Noise Level Strategies
```python
import torch
from typing import Optional

class NoiseStrategies:
    """Different strategies for initial noise in sampling."""

    @staticmethod
    def standard_normal(
        shape: tuple,
        device: str = "cuda",
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        """Standard Gaussian noise."""
        if seed is not None:
            torch.manual_seed(seed)
        return torch.randn(shape, device=device)

    @staticmethod
    def truncated_normal(
        shape: tuple,
        truncation: float = 2.0,
        device: str = "cuda",
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        """Clamped Gaussian for more consistent results.

        Note: clamping piles probability mass at the boundaries; it is a
        cheap approximation to true truncated sampling.
        """
        if seed is not None:
            torch.manual_seed(seed)

        noise = torch.randn(shape, device=device)
        return noise.clamp(-truncation, truncation)

    @staticmethod
    def pyramid_noise(
        shape: tuple,
        discount: float = 0.8,
        device: str = "cuda",
    ) -> torch.Tensor:
        """Multi-scale pyramid noise for better large-scale structure."""
        batch, channels, height, width = shape

        # Accumulate noise from full resolution down to coarse scales
        noise = torch.zeros(shape, device=device)
        scale = 1.0

        h, w = height, width
        while h >= 4 and w >= 4:
            # Generate noise at this scale
            scale_noise = torch.randn(batch, channels, h, w, device=device)

            # Upscale coarse noise to full resolution
            if h != height or w != width:
                scale_noise = torch.nn.functional.interpolate(
                    scale_noise,
                    size=(height, width),
                    mode="bilinear",
                    align_corners=False,
                )

            noise = noise + scale * scale_noise
            scale *= discount
            h, w = h // 2, w // 2

        # Renormalize to unit standard deviation
        return noise / noise.std()

    @staticmethod
    def latent_blend(
        shape: tuple,
        reference: torch.Tensor,
        blend_strength: float = 0.3,
        device: str = "cuda",
    ) -> torch.Tensor:
        """Blend random noise with a reference latent for style transfer."""
        random_noise = torch.randn(shape, device=device)
        # reference should be an encoded image latent on the same device
        return (1 - blend_strength) * random_noise + blend_strength * reference
```

Batch Generation
```python
import json
from pathlib import Path
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
import torchvision.utils as vutils
from tqdm import tqdm

class ImageGenerator:
    """Production-ready image generation with batching and saving."""

    def __init__(
        self,
        model: nn.Module,
        sampler,  # DDPMSampler or DDIMSampler
        output_dir: str,
        device: str = "cuda",
    ):
        self.model = model.to(device).eval()
        self.sampler = sampler
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.device = device

    def generate_batch(
        self,
        num_images: int,
        batch_size: int = 32,
        image_size: int = 64,
        channels: int = 3,
        num_steps: int = 50,
        save_individual: bool = True,
        save_grid: bool = True,
        grid_nrow: int = 8,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Generate a batch of images with progress tracking.

        Args:
            num_images: Total number of images to generate
            batch_size: Batch size for the GPU
            image_size: Output image size
            channels: Number of image channels
            num_steps: Sampling steps
            save_individual: Save each image separately
            save_grid: Save images as a grid
            grid_nrow: Number of images per row in the grid
            metadata: Optional metadata to save

        Returns:
            All generated images as a tensor
        """
        all_samples = []
        num_batches = (num_images + batch_size - 1) // batch_size

        print(f"Generating {num_images} images in {num_batches} batches...")

        for batch_idx in tqdm(range(num_batches), desc="Generating"):
            # The last batch may be smaller
            current_batch_size = min(
                batch_size,
                num_images - batch_idx * batch_size
            )

            shape = (current_batch_size, channels, image_size, image_size)

            # Generate samples and move them off the GPU immediately
            samples = self.sampler.sample(shape, num_steps=num_steps)
            all_samples.append(samples.cpu())

            # Save individual images
            if save_individual:
                for i, sample in enumerate(samples):
                    img_idx = batch_idx * batch_size + i
                    self._save_image(
                        sample,
                        self.output_dir / f"sample_{img_idx:05d}.png"
                    )

        # Concatenate all samples
        all_samples = torch.cat(all_samples, dim=0)

        # Save grid
        if save_grid:
            grid = vutils.make_grid(
                all_samples,
                nrow=grid_nrow,
                normalize=True,
                value_range=(-1, 1),
            )
            vutils.save_image(
                grid,
                self.output_dir / "samples_grid.png"
            )

        # Save metadata
        if metadata is not None:
            metadata.update({
                "num_images": num_images,
                "image_size": image_size,
                "num_steps": num_steps,
            })
            with open(self.output_dir / "metadata.json", "w") as f:
                json.dump(metadata, f, indent=2)

        print(f"Saved {num_images} images to {self.output_dir}")
        return all_samples

    def generate_with_seeds(
        self,
        seeds: list,
        image_size: int = 64,
        channels: int = 3,
        num_steps: int = 50,
    ) -> torch.Tensor:
        """Generate images with specific seeds for reproducibility."""
        samples = []

        for seed in tqdm(seeds, desc="Generating from seeds"):
            # Seed the global RNG; the sampler draws its initial noise
            # after this, so each seed reproduces the same image
            torch.manual_seed(seed)

            sample = self.sampler.sample(
                (1, channels, image_size, image_size),
                num_steps=num_steps,
            )
            samples.append(sample)

            # Save with the seed in the filename
            self._save_image(
                sample[0],
                self.output_dir / f"seed_{seed:08d}.png"
            )

        return torch.cat(samples, dim=0)

    def generate_variations(
        self,
        reference: torch.Tensor,
        num_variations: int = 8,
        noise_strength: float = 0.3,
        num_steps: int = 50,
    ) -> torch.Tensor:
        """Generate variations of a reference image.

        Simplified: this draws fresh samples. A full implementation would
        noise the reference to an intermediate timestep (scaled by
        noise_strength) or use DDIM inversion, then denoise from that latent.
        """
        samples = [reference]

        for _ in range(num_variations):
            sample = self.sampler.sample(reference.shape, num_steps=num_steps)
            samples.append(sample)

        return torch.cat(samples, dim=0)

    def _save_image(self, tensor: torch.Tensor, path: Path):
        """Save a single image tensor."""
        # Convert from [-1, 1] to [0, 1]
        tensor = (tensor + 1) / 2
        tensor = tensor.clamp(0, 1)
        vutils.save_image(tensor, str(path))


# Usage example (load_trained_model and alphas_cumprod are placeholders
# that come from your own training setup)
def generate_samples():
    """Example generation script."""
    # Load model
    model = load_trained_model("checkpoints/best_model.pt")

    # Create sampler (DDIM for speed)
    sampler = DDIMSampler(model, alphas_cumprod, device="cuda")

    # Create generator
    generator = ImageGenerator(
        model=model,
        sampler=sampler,
        output_dir="./generated_samples",
    )

    # Generate 1000 images
    samples = generator.generate_batch(
        num_images=1000,
        batch_size=64,
        image_size=64,
        num_steps=50,
        save_individual=True,
        save_grid=True,
        metadata={"model": "ddpm_cifar10", "steps": 50, "sampler": "ddim"},
    )

    # Generate from specific seeds
    seed_samples = generator.generate_with_seeds(
        seeds=[42, 123, 456, 789, 1000],
        num_steps=50,
    )

    return samples
```

Memory Management
Generation runs under `torch.no_grad()` inside the samplers, and `generate_batch` moves each finished batch to the CPU immediately, so GPU memory stays bounded regardless of `num_images`. If you still hit out-of-memory errors, lower `batch_size` first; it does not affect sample quality.
Key Takeaways
- DDPM provides the best quality: Use the full 1000 steps for final production samples when quality matters most.
- DDIM enables fast iteration: Use 50-100 steps during development and for applications requiring speed.
- Eta controls diversity: Set eta=0 for deterministic results (interpolation, reproducibility) or eta=0.5-1.0 for variety.
- Timestep schedule matters: At low step counts, the Karras schedule often produces better results than uniform spacing.
- Batch for efficiency: Process multiple images in parallel to maximize GPU utilization.
Looking Ahead: Now that we can generate images, we need to evaluate their quality quantitatively. The next section covers FID, IS, and other metrics for assessing generative model performance.