Learning Objectives
By the end of this section, you will:
- Implement a production-ready DDIM sampler with all features
- Master DDIM inversion for encoding real images into latent space
- Create smooth interpolations between images in latent space
- Implement image editing through latent manipulation
- Optimize DDIM for maximum performance
Hands-On Implementation
Complete DDIM Implementation
Let's implement a full-featured DDIM sampler with all the bells and whistles:
1import torch
2import torch.nn as nn
3import torch.nn.functional as F
4import numpy as np
5from typing import Optional, List, Tuple, Union, Callable
6from tqdm import tqdm
7from dataclasses import dataclass
8
9
10@dataclass
11class DDIMConfig:
12 """Configuration for DDIM sampler."""
13 num_timesteps: int = 1000
14 eta: float = 0.0 # 0 = deterministic, 1 = DDPM-like
15 clip_denoised: bool = True
16 clip_range: Tuple[float, float] = (-1.0, 1.0)
17
18
19class DDIMSampler:
20 """
21 Production-ready DDIM sampler with advanced features.
22
23 Features:
24 - Deterministic and stochastic sampling
25 - DDIM inversion for image encoding
26 - Arbitrary timestep schedules
27 - Progress tracking and callbacks
28 - Memory-efficient implementation
29 """
30
31 def __init__(
32 self,
33 model: nn.Module,
34 alphas_cumprod: torch.Tensor,
35 config: Optional[DDIMConfig] = None,
36 device: str = "cuda"
37 ):
38 """
39 Initialize DDIM sampler.
40
41 Args:
42 model: Trained noise prediction network
43 alphas_cumprod: Cumulative product of alphas
44 config: DDIM configuration
45 device: Computation device
46 """
47 self.model = model
48 self.device = device
49 self.config = config or DDIMConfig()
50
51 # Precompute schedule values
52 self.alphas_cumprod = alphas_cumprod.to(device)
53 self.alphas_cumprod_prev = F.pad(
54 self.alphas_cumprod[:-1], (1, 0), value=1.0
55 )
56
57 # Useful derived quantities
58 self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
59 self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)
60
61 def make_timesteps(
62 self,
63 num_steps: int,
64 schedule: str = "uniform"
65 ) -> torch.Tensor:
66 """
67 Create timestep sequence for sampling.
68
69 Args:
70 num_steps: Number of sampling steps
71 schedule: "uniform", "quadratic", or "trailing"
72
73 Returns:
74 Tensor of timesteps in descending order
75 """
76 T = self.config.num_timesteps
77
78 if schedule == "uniform":
79 # Uniform spacing
80 c = T // num_steps
81 timesteps = np.asarray(list(range(0, T, c)))
82
83 elif schedule == "quadratic":
84 # Quadratic: more steps at lower noise
85 timesteps = (
86 (np.linspace(0, np.sqrt(T * 0.8), num_steps)) ** 2
87 ).astype(int)
88
89 elif schedule == "trailing":
90 # Trailing: used by some samplers
91 timesteps = np.round(
92 np.linspace(0, T - 1, num_steps)
93 ).astype(int)
94
95 else:
96 raise ValueError(f"Unknown schedule: {schedule}")
97
98 # Ensure unique and sorted descending
99 timesteps = np.unique(timesteps)[::-1]
100
101 return torch.from_numpy(timesteps.copy()).long().to(self.device)
102
103 @torch.no_grad()
104 def sample(
105 self,
106 shape: Tuple[int, ...],
107 num_steps: int = 50,
108 eta: Optional[float] = None,
109 x_T: Optional[torch.Tensor] = None,
110 progress: bool = True,
111 callback: Optional[Callable] = None,
112 schedule: str = "uniform"
113 ) -> torch.Tensor:
114 """
115 Generate samples using DDIM.
116
117 Args:
118 shape: Output shape (B, C, H, W)
119 num_steps: Number of sampling steps
120 eta: Stochasticity (None uses config default)
121 x_T: Starting noise (None = random)
122 progress: Show progress bar
123 callback: Called each step with (step, x_t, x0_pred)
124 schedule: Timestep schedule type
125
126 Returns:
127 Generated samples in [-1, 1]
128 """
129 eta = eta if eta is not None else self.config.eta
130
131 # Initialize
132 if x_T is None:
133 x_t = torch.randn(shape, device=self.device)
134 else:
135 x_t = x_T.to(self.device)
136
137 # Get timesteps
138 timesteps = self.make_timesteps(num_steps, schedule)
139
140 self.model.eval()
141
142 # Sampling loop
143 iterator = tqdm(range(len(timesteps)), desc="DDIM") if progress else range(len(timesteps))
144
145 for i in iterator:
146 t = timesteps[i]
147 t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else torch.tensor(0)
148
149 x_t, x0_pred = self._ddim_step(x_t, t, t_prev, eta)
150
151 if callback is not None:
152 callback(i, x_t, x0_pred)
153
154 return x_t
155
156 def _ddim_step(
157 self,
158 x_t: torch.Tensor,
159 t: torch.Tensor,
160 t_prev: torch.Tensor,
161 eta: float
162 ) -> Tuple[torch.Tensor, torch.Tensor]:
163 """
164 Single DDIM denoising step.
165
166 Returns:
167 (x_{t_prev}, predicted_x0)
168 """
169 batch_size = x_t.shape[0]
170 t_batch = t.expand(batch_size) if t.dim() == 0 else t
171
172 # Get alpha values
173 alpha_t = self.alphas_cumprod[t]
174 alpha_t_prev = self.alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=self.device)
175
176 # Predict noise
177 eps_pred = self.model(x_t, t_batch)
178
179 # Predict x_0
180 x0_pred = (x_t - self.sqrt_one_minus_alphas_cumprod[t] * eps_pred) / self.sqrt_alphas_cumprod[t]
181
182 # Clip if configured
183 if self.config.clip_denoised:
184 x0_pred = x0_pred.clamp(*self.config.clip_range)
185
186 # Compute sigma
187 sigma_t = self._compute_sigma(alpha_t, alpha_t_prev, eta)
188
189 # Direction pointing to x_t
190 direction = torch.sqrt(1 - alpha_t_prev - sigma_t ** 2) * eps_pred
191
192 # Compute x_{t-1}
193 x_prev = torch.sqrt(alpha_t_prev) * x0_pred + direction
194
195 # Add noise if stochastic
196 if eta > 0 and t_prev > 0:
197 noise = torch.randn_like(x_t)
198 x_prev = x_prev + sigma_t * noise
199
200 return x_prev, x0_pred
201
202 def _compute_sigma(
203 self,
204 alpha_t: torch.Tensor,
205 alpha_t_prev: torch.Tensor,
206 eta: float
207 ) -> torch.Tensor:
208 """Compute sigma for stochastic DDIM."""
209 sigma = eta * torch.sqrt(
210 (1 - alpha_t_prev) / (1 - alpha_t) *
211 (1 - alpha_t / alpha_t_prev)
212 )
213 return sigma
214
215
216# Example usage
217def example_ddim_generation():
218 """Demonstrate DDIM generation."""
219 # Assuming model and noise_schedule are defined
220 config = DDIMConfig(
221 num_timesteps=1000,
222 eta=0.0, # Deterministic
223 clip_denoised=True,
224 clip_range=(-1.0, 1.0)
225 )
226
227 sampler = DDIMSampler(
228 model=model,
229 alphas_cumprod=noise_schedule.alphas_cumprod,
230 config=config,
231 device="cuda"
232 )
233
234 # Generate 4 images with 50 steps
235 samples = sampler.sample(
236 shape=(4, 3, 64, 64),
237 num_steps=50,
238 eta=0.0,
239 progress=True
240 )
241
242 print(f"Generated {samples.shape[0]} images")
243 return samplesDDIM Inversion (Encoding)
One of DDIM's most powerful features is inversion: finding the latent code that generates a given image. This enables image editing and manipulation.
The Inversion Process
DDIM inversion runs the sampling process backwards in time:
Starting from a real image , we iteratively add noise to recover the latent .
1class DDIMInverter:
2 """
3 DDIM inversion for encoding images into latent space.
4
5 Given an image x_0, finds the latent x_T such that
6 DDIM(x_T) approximately equals x_0.
7 """
8
9 def __init__(
10 self,
11 model: nn.Module,
12 alphas_cumprod: torch.Tensor,
13 num_timesteps: int = 1000,
14 device: str = "cuda"
15 ):
16 self.model = model
17 self.alphas_cumprod = alphas_cumprod.to(device)
18 self.T = num_timesteps
19 self.device = device
20
21 # Precompute
22 self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
23 self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)
24
25 @torch.no_grad()
26 def invert(
27 self,
28 x_0: torch.Tensor,
29 num_steps: int = 50,
30 progress: bool = True
31 ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
32 """
33 Invert an image to find its latent code.
34
35 Args:
36 x_0: Input image in [-1, 1], shape (B, C, H, W)
37 num_steps: Number of inversion steps
38 progress: Show progress bar
39
40 Returns:
41 (x_T, trajectory) - latent code and intermediate states
42 """
43 x_0 = x_0.to(self.device)
44
45 # Create forward timesteps (0 -> T)
46 timesteps = self._make_inversion_timesteps(num_steps)
47
48 # Start from x_0
49 x_t = x_0.clone()
50 trajectory = [x_t.clone()]
51
52 self.model.eval()
53
54 iterator = tqdm(range(len(timesteps) - 1), desc="Inverting") if progress else range(len(timesteps) - 1)
55
56 for i in iterator:
57 t = timesteps[i]
58 t_next = timesteps[i + 1]
59
60 x_t = self._inversion_step(x_t, t, t_next, x_0)
61 trajectory.append(x_t.clone())
62
63 return x_t, trajectory
64
65 def _inversion_step(
66 self,
67 x_t: torch.Tensor,
68 t: int,
69 t_next: int,
70 x_0: torch.Tensor
71 ) -> torch.Tensor:
72 """
73 Single DDIM inversion step (forward in time).
74
75 Goes from x_t to x_{t_next} where t_next > t.
76 """
77 batch_size = x_t.shape[0]
78 t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)
79
80 # Predict noise at current timestep
81 eps_pred = self.model(x_t, t_batch)
82
83 # Get alpha values
84 alpha_t = self.alphas_cumprod[t]
85 alpha_t_next = self.alphas_cumprod[t_next]
86
87 # Predict x_0 from current x_t
88 x0_pred = (x_t - self.sqrt_one_minus_alphas_cumprod[t] * eps_pred) / self.sqrt_alphas_cumprod[t]
89
90 # Optional: use actual x_0 for more accurate inversion
91 # This is "guided inversion"
92 # x0_pred = x_0 # Uncomment for guided mode
93
94 # Compute x_{t_next}
95 direction = self.sqrt_one_minus_alphas_cumprod[t_next] * eps_pred
96 x_next = self.sqrt_alphas_cumprod[t_next] * x0_pred + direction
97
98 return x_next
99
100 def _make_inversion_timesteps(self, num_steps: int) -> List[int]:
101 """Create timesteps for inversion (ascending order)."""
102 step_size = self.T // num_steps
103 timesteps = list(range(0, self.T, step_size))
104 if timesteps[-1] != self.T - 1:
105 timesteps.append(self.T - 1)
106 return timesteps
107
108
109def demonstrate_inversion(model, noise_schedule, image):
110 """
111 Demonstrate DDIM inversion and reconstruction.
112 """
113 inverter = DDIMInverter(
114 model=model,
115 alphas_cumprod=noise_schedule.alphas_cumprod,
116 num_timesteps=noise_schedule.T
117 )
118
119 sampler = DDIMSampler(
120 model=model,
121 alphas_cumprod=noise_schedule.alphas_cumprod,
122 config=DDIMConfig(eta=0.0)
123 )
124
125 # Invert the image
126 x_T, inversion_trajectory = inverter.invert(
127 x_0=image,
128 num_steps=50
129 )
130
131 print(f"Inverted image shape: {image.shape}")
132 print(f"Latent code shape: {x_T.shape}")
133
134 # Reconstruct from latent
135 reconstructed = sampler.sample(
136 shape=image.shape,
137 num_steps=50,
138 x_T=x_T,
139 eta=0.0 # Must be deterministic for reconstruction
140 )
141
142 # Compute reconstruction error
143 mse = F.mse_loss(image, reconstructed).item()
144 print(f"Reconstruction MSE: {mse:.6f}")
145
146 return x_T, reconstructed, inversion_trajectoryInversion Accuracy
Semantic Interpolation
With deterministic DDIM, we can smoothly interpolate between images by interpolating their latent codes:
1class DDIMInterpolator:
2 """
3 Semantic interpolation between images using DDIM.
4 """
5
6 def __init__(
7 self,
8 model: nn.Module,
9 alphas_cumprod: torch.Tensor,
10 num_timesteps: int = 1000,
11 device: str = "cuda"
12 ):
13 self.model = model
14 self.device = device
15
16 self.inverter = DDIMInverter(
17 model=model,
18 alphas_cumprod=alphas_cumprod,
19 num_timesteps=num_timesteps,
20 device=device
21 )
22
23 self.sampler = DDIMSampler(
24 model=model,
25 alphas_cumprod=alphas_cumprod,
26 config=DDIMConfig(eta=0.0),
27 device=device
28 )
29
30 @torch.no_grad()
31 def interpolate(
32 self,
33 image1: torch.Tensor,
34 image2: torch.Tensor,
35 num_frames: int = 10,
36 inversion_steps: int = 50,
37 sampling_steps: int = 50,
38 interpolation_type: str = "linear"
39 ) -> List[torch.Tensor]:
40 """
41 Create smooth interpolation between two images.
42
43 Args:
44 image1: First image in [-1, 1]
45 image2: Second image in [-1, 1]
46 num_frames: Number of interpolation frames
47 inversion_steps: Steps for inversion
48 sampling_steps: Steps for generation
49 interpolation_type: "linear" or "spherical"
50
51 Returns:
52 List of interpolated images
53 """
54 # Invert both images to latent space
55 z1, _ = self.inverter.invert(image1, num_steps=inversion_steps)
56 z2, _ = self.inverter.invert(image2, num_steps=inversion_steps)
57
58 # Generate interpolation alphas
59 alphas = np.linspace(0, 1, num_frames)
60
61 interpolated = []
62
63 for alpha in alphas:
64 # Interpolate in latent space
65 if interpolation_type == "linear":
66 z_interp = self._lerp(z1, z2, alpha)
67 elif interpolation_type == "spherical":
68 z_interp = self._slerp(z1, z2, alpha)
69 else:
70 raise ValueError(f"Unknown interpolation: {interpolation_type}")
71
72 # Generate from interpolated latent
73 sample = self.sampler.sample(
74 shape=z_interp.shape,
75 num_steps=sampling_steps,
76 x_T=z_interp,
77 progress=False
78 )
79
80 interpolated.append(sample)
81
82 return interpolated
83
84 def _lerp(
85 self,
86 z1: torch.Tensor,
87 z2: torch.Tensor,
88 alpha: float
89 ) -> torch.Tensor:
90 """Linear interpolation."""
91 return (1 - alpha) * z1 + alpha * z2
92
93 def _slerp(
94 self,
95 z1: torch.Tensor,
96 z2: torch.Tensor,
97 alpha: float
98 ) -> torch.Tensor:
99 """
100 Spherical linear interpolation.
101
102 Better for high-dimensional spaces like latent codes.
103 """
104 # Normalize
105 z1_flat = z1.flatten()
106 z2_flat = z2.flatten()
107
108 # Compute angle
109 dot = torch.dot(z1_flat, z2_flat)
110 dot = dot / (torch.norm(z1_flat) * torch.norm(z2_flat))
111 dot = torch.clamp(dot, -1, 1)
112 omega = torch.acos(dot)
113
114 # Handle edge case
115 if torch.abs(omega) < 1e-10:
116 return self._lerp(z1, z2, alpha)
117
118 # Spherical interpolation
119 sin_omega = torch.sin(omega)
120 s1 = torch.sin((1 - alpha) * omega) / sin_omega
121 s2 = torch.sin(alpha * omega) / sin_omega
122
123 return s1 * z1 + s2 * z2
124
125
126def create_interpolation_video(model, noise_schedule, image1, image2):
127 """
128 Create smooth interpolation between two images.
129 """
130 interpolator = DDIMInterpolator(
131 model=model,
132 alphas_cumprod=noise_schedule.alphas_cumprod
133 )
134
135 frames = interpolator.interpolate(
136 image1=image1,
137 image2=image2,
138 num_frames=30,
139 interpolation_type="spherical" # Better for latent spaces
140 )
141
142 print(f"Generated {len(frames)} interpolation frames")
143
144 # Convert to video-ready format
145 frames_tensor = torch.stack([f.squeeze(0) for f in frames])
146 frames_tensor = (frames_tensor + 1) / 2 # [-1, 1] -> [0, 1]
147
148 return frames_tensorSpherical vs Linear Interpolation
Image Editing via Latent Space
DDIM inversion enables powerful image editing by manipulating the latent representation:
1class DDIMImageEditor:
2 """
3 Edit images by manipulating their latent representations.
4 """
5
6 def __init__(
7 self,
8 model: nn.Module,
9 alphas_cumprod: torch.Tensor,
10 num_timesteps: int = 1000,
11 device: str = "cuda"
12 ):
13 self.model = model
14 self.device = device
15
16 self.inverter = DDIMInverter(
17 model=model,
18 alphas_cumprod=alphas_cumprod,
19 num_timesteps=num_timesteps,
20 device=device
21 )
22
23 self.sampler = DDIMSampler(
24 model=model,
25 alphas_cumprod=alphas_cumprod,
26 config=DDIMConfig(eta=0.0),
27 device=device
28 )
29
30 @torch.no_grad()
31 def edit_by_noise_injection(
32 self,
33 image: torch.Tensor,
34 noise_scale: float = 0.3,
35 edit_timestep: int = 200,
36 inversion_steps: int = 50,
37 sampling_steps: int = 50
38 ) -> torch.Tensor:
39 """
40 Edit image by injecting noise at an intermediate timestep.
41
42 This creates variations that preserve overall structure.
43
44 Args:
45 image: Input image in [-1, 1]
46 noise_scale: Amount of noise to inject (0-1)
47 edit_timestep: Timestep at which to inject noise
48 inversion_steps: Steps for inversion
49 sampling_steps: Steps for generation
50
51 Returns:
52 Edited image
53 """
54 # Invert to full latent trajectory
55 _, trajectory = self.inverter.invert(
56 image,
57 num_steps=inversion_steps
58 )
59
60 # Find the state at edit_timestep
61 step_size = self.inverter.T // inversion_steps
62 edit_index = edit_timestep // step_size
63 edit_index = min(edit_index, len(trajectory) - 1)
64
65 x_edit = trajectory[edit_index].clone()
66
67 # Inject noise
68 noise = torch.randn_like(x_edit) * noise_scale
69 x_edit = x_edit + noise
70
71 # Sample from edited point
72 edited = self.sampler.sample(
73 shape=x_edit.shape,
74 num_steps=sampling_steps - edit_index,
75 x_T=x_edit,
76 progress=False
77 )
78
79 return edited
80
81 @torch.no_grad()
82 def edit_by_direction(
83 self,
84 image: torch.Tensor,
85 direction: torch.Tensor,
86 strength: float = 1.0,
87 inversion_steps: int = 100,
88 sampling_steps: int = 50
89 ) -> torch.Tensor:
90 """
91 Edit image by moving in a semantic direction.
92
93 The direction can be learned from pairs of images
94 (e.g., smiling vs not smiling).
95
96 Args:
97 image: Input image in [-1, 1]
98 direction: Edit direction in latent space
99 strength: How far to move along direction
100 inversion_steps: Steps for inversion
101 sampling_steps: Steps for generation
102
103 Returns:
104 Edited image
105 """
106 # Invert
107 x_T, _ = self.inverter.invert(image, num_steps=inversion_steps)
108
109 # Apply direction
110 x_T_edited = x_T + strength * direction.to(self.device)
111
112 # Regenerate
113 edited = self.sampler.sample(
114 shape=x_T_edited.shape,
115 num_steps=sampling_steps,
116 x_T=x_T_edited,
117 progress=False
118 )
119
120 return edited
121
122 @torch.no_grad()
123 def blend_images(
124 self,
125 images: List[torch.Tensor],
126 weights: List[float],
127 inversion_steps: int = 100,
128 sampling_steps: int = 50
129 ) -> torch.Tensor:
130 """
131 Blend multiple images by combining their latent codes.
132
133 Args:
134 images: List of images in [-1, 1]
135 weights: Blending weights (should sum to 1)
136 inversion_steps: Steps for inversion
137 sampling_steps: Steps for generation
138
139 Returns:
140 Blended image
141 """
142 assert len(images) == len(weights)
143 assert abs(sum(weights) - 1.0) < 1e-6
144
145 # Invert all images
146 latents = []
147 for img in images:
148 z, _ = self.inverter.invert(img, num_steps=inversion_steps)
149 latents.append(z)
150
151 # Weighted combination
152 x_T_blend = sum(w * z for w, z in zip(weights, latents))
153
154 # Generate
155 blended = self.sampler.sample(
156 shape=x_T_blend.shape,
157 num_steps=sampling_steps,
158 x_T=x_T_blend,
159 progress=False
160 )
161
162 return blended
163
164
165def learn_edit_direction(
166 model,
167 noise_schedule,
168 positive_images: List[torch.Tensor],
169 negative_images: List[torch.Tensor],
170 inversion_steps: int = 100
171) -> torch.Tensor:
172 """
173 Learn a semantic edit direction from image pairs.
174
175 Example: positive = smiling faces, negative = neutral faces
176 The resulting direction can be used to add/remove smiles.
177 """
178 inverter = DDIMInverter(
179 model=model,
180 alphas_cumprod=noise_schedule.alphas_cumprod
181 )
182
183 # Invert positive examples
184 positive_latents = []
185 for img in positive_images:
186 z, _ = inverter.invert(img, num_steps=inversion_steps)
187 positive_latents.append(z)
188
189 # Invert negative examples
190 negative_latents = []
191 for img in negative_images:
192 z, _ = inverter.invert(img, num_steps=inversion_steps)
193 negative_latents.append(z)
194
195 # Compute mean difference
196 pos_mean = torch.stack(positive_latents).mean(dim=0)
197 neg_mean = torch.stack(negative_latents).mean(dim=0)
198
199 direction = pos_mean - neg_mean
200
201 # Normalize
202 direction = direction / torch.norm(direction)
203
204 return directionEdit Direction Discovery
Implementation Optimizations
Here are key optimizations for production DDIM:
1. Batched Sampling
1def optimized_batch_sample(
2 sampler: DDIMSampler,
3 num_samples: int,
4 batch_size: int = 16,
5 num_steps: int = 50
6) -> torch.Tensor:
7 """
8 Generate many samples efficiently with batching.
9 """
10 all_samples = []
11 num_batches = (num_samples + batch_size - 1) // batch_size
12
13 for i in range(num_batches):
14 current_batch = min(batch_size, num_samples - i * batch_size)
15
16 samples = sampler.sample(
17 shape=(current_batch, 3, 64, 64),
18 num_steps=num_steps,
19 progress=False
20 )
21 all_samples.append(samples)
22
23 return torch.cat(all_samples, dim=0)[:num_samples]2. Mixed Precision Sampling
1class DDIMSamplerAMP(DDIMSampler):
2 """DDIM with automatic mixed precision for faster sampling."""
3
4 @torch.no_grad()
5 def sample(self, shape, num_steps=50, eta=None, x_T=None, progress=True, **kwargs):
6 """Sample with AMP for speed."""
7 eta = eta if eta is not None else self.config.eta
8
9 if x_T is None:
10 x_t = torch.randn(shape, device=self.device)
11 else:
12 x_t = x_T.to(self.device)
13
14 timesteps = self.make_timesteps(num_steps)
15
16 self.model.eval()
17
18 # Use autocast for faster inference
19 with torch.amp.autocast('cuda'):
20 for i in (tqdm(range(len(timesteps))) if progress else range(len(timesteps))):
21 t = timesteps[i]
22 t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else torch.tensor(0)
23 x_t, _ = self._ddim_step(x_t, t, t_prev, eta)
24
25 return x_t3. Compiled Model
1# PyTorch 2.0+ compilation for faster inference
2def compile_for_sampling(model: nn.Module) -> nn.Module:
3 """Compile model for faster DDIM sampling."""
4 return torch.compile(
5 model,
6 mode="reduce-overhead", # Optimize for repeated calls
7 fullgraph=True
8 )
9
10# Usage
11compiled_model = compile_for_sampling(model)
12sampler = DDIMSampler(
13 model=compiled_model,
14 alphas_cumprod=noise_schedule.alphas_cumprod
15)
16
17# First sample is slow (compilation), subsequent are fast
18samples = sampler.sample(shape=(4, 3, 64, 64), num_steps=50)4. Caching for Repeated Sampling
1class CachedDDIMSampler(DDIMSampler):
2 """
3 DDIM sampler with caching for repeated operations.
4 """
5
6 def __init__(self, *args, **kwargs):
7 super().__init__(*args, **kwargs)
8 self._timestep_cache = {}
9
10 def make_timesteps(self, num_steps, schedule="uniform"):
11 """Cached timestep generation."""
12 cache_key = (num_steps, schedule)
13 if cache_key not in self._timestep_cache:
14 self._timestep_cache[cache_key] = super().make_timesteps(num_steps, schedule)
15 return self._timestep_cache[cache_key]
16
17 @torch.no_grad()
18 def sample_multiple_from_same_noise(
19 self,
20 x_T: torch.Tensor,
21 num_steps_list: List[int],
22 ) -> dict:
23 """
24 Generate at multiple step counts from same noise.
25
26 Useful for quality comparison experiments.
27 """
28 results = {}
29
30 for num_steps in sorted(num_steps_list, reverse=True):
31 samples = self.sample(
32 shape=x_T.shape,
33 num_steps=num_steps,
34 x_T=x_T.clone(),
35 progress=False
36 )
37 results[num_steps] = samples
38
39 return results| Optimization | Speedup | Memory | Notes |
|---|---|---|---|
| Mixed Precision | 1.5-2x | 50% less | Minimal quality loss |
| Torch Compile | 1.3-2x | Same | PyTorch 2.0+ only |
| Batching | Linear | Linear increase | GPU utilization |
| Caching | Variable | Slight increase | Repeated timesteps |
Summary
We've built a complete DDIM toolkit with:
- Full-featured sampler with configurable eta, progress tracking, and multiple timestep schedules
- DDIM inversion for encoding real images into latent space, enabling reconstruction and manipulation
- Semantic interpolation with both linear and spherical methods for smooth transitions between images
- Image editing capabilities including noise injection, directional editing, and multi-image blending
- Production optimizations for speed and efficiency
Coming Up Next
The DDIM framework we've built here forms the foundation for modern diffusion model applications. Whether you're building an image generation service, an editing tool, or a creative application, these components provide the building blocks you need.