Chapter 7

DDIM Implementation

Improved Sampling Methods

Learning Objectives

By the end of this section, you will:

  1. Implement a production-ready DDIM sampler with all features
  2. Master DDIM inversion for encoding real images into latent space
  3. Create smooth interpolations between images in latent space
  4. Implement image editing through latent manipulation
  5. Optimize DDIM for maximum performance

Hands-On Implementation

This section is entirely practical. We'll build a complete DDIM toolkit that you can use for generation, inversion, interpolation, and editing. The code is written to be reused, but validate it against your own model and noise schedule before relying on it.

Complete DDIM Implementation

Let's implement a full-featured DDIM sampler with all the bells and whistles:

🐍python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, List, Tuple, Union, Callable
from tqdm import tqdm
from dataclasses import dataclass


@dataclass
class DDIMConfig:
    """Configuration for DDIM sampler."""
    num_timesteps: int = 1000
    eta: float = 0.0  # 0 = deterministic, 1 = DDPM-like
    clip_denoised: bool = True
    clip_range: Tuple[float, float] = (-1.0, 1.0)


class DDIMSampler:
    """
    DDIM sampler with advanced features.

    Features:
    - Deterministic and stochastic sampling
    - Arbitrary timestep schedules
    - Progress tracking and callbacks
    - Precomputed schedule quantities
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        config: Optional[DDIMConfig] = None,
        device: str = "cuda"
    ):
        """
        Initialize DDIM sampler.

        Args:
            model: Trained noise prediction network
            alphas_cumprod: Cumulative product of alphas
            config: DDIM configuration
            device: Computation device
        """
        self.model = model
        self.device = device
        self.config = config or DDIMConfig()

        # Precompute schedule values
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.alphas_cumprod_prev = F.pad(
            self.alphas_cumprod[:-1], (1, 0), value=1.0
        )

        # Useful derived quantities
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

    def make_timesteps(
        self,
        num_steps: int,
        schedule: str = "uniform"
    ) -> torch.Tensor:
        """
        Create timestep sequence for sampling.

        Args:
            num_steps: Number of sampling steps
            schedule: "uniform", "quadratic", or "trailing"

        Returns:
            Tensor of timesteps in descending order
        """
        T = self.config.num_timesteps

        if schedule == "uniform":
            # Uniform spacing: 0, c, 2c, ... (guard against a zero stride)
            c = max(T // num_steps, 1)
            timesteps = np.arange(0, T, c)

        elif schedule == "quadratic":
            # Quadratic: more steps at lower noise
            timesteps = (
                (np.linspace(0, np.sqrt(T * 0.8), num_steps)) ** 2
            ).astype(int)

        elif schedule == "trailing":
            # Evenly spaced from 0 to T - 1 inclusive, so sampling
            # starts at the final (noisiest) timestep
            timesteps = np.round(
                np.linspace(0, T - 1, num_steps)
            ).astype(int)

        else:
            raise ValueError(f"Unknown schedule: {schedule}")

        # Ensure unique and sorted descending
        timesteps = np.unique(timesteps)[::-1]

        return torch.from_numpy(timesteps.copy()).long().to(self.device)

    @torch.no_grad()
    def sample(
        self,
        shape: Tuple[int, ...],
        num_steps: int = 50,
        eta: Optional[float] = None,
        x_T: Optional[torch.Tensor] = None,
        progress: bool = True,
        callback: Optional[Callable] = None,
        schedule: str = "uniform"
    ) -> torch.Tensor:
        """
        Generate samples using DDIM.

        Args:
            shape: Output shape (B, C, H, W)
            num_steps: Number of sampling steps
            eta: Stochasticity (None uses config default)
            x_T: Starting noise (None = random)
            progress: Show progress bar
            callback: Called each step with (step, x_t, x0_pred)
            schedule: Timestep schedule type

        Returns:
            Generated samples, approximately in [-1, 1]
        """
        eta = eta if eta is not None else self.config.eta

        # Initialize
        if x_T is None:
            x_t = torch.randn(shape, device=self.device)
        else:
            x_t = x_T.to(self.device)

        # Get timesteps
        timesteps = self.make_timesteps(num_steps, schedule)

        # Sentinel for the final step: t_prev = -1 means "fully denoised"
        # (alpha_bar = 1), which _ddim_step handles explicitly
        final_t = torch.tensor(-1, device=self.device)

        self.model.eval()

        # Sampling loop
        iterator = tqdm(range(len(timesteps)), desc="DDIM") if progress else range(len(timesteps))

        for i in iterator:
            t = timesteps[i]
            t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else final_t

            x_t, x0_pred = self._ddim_step(x_t, t, t_prev, eta)

            if callback is not None:
                callback(i, x_t, x0_pred)

        return x_t

    def _ddim_step(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        t_prev: torch.Tensor,
        eta: float
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Single DDIM denoising step.

        Returns:
            (x_{t_prev}, predicted_x0)
        """
        batch_size = x_t.shape[0]
        t_batch = t.expand(batch_size) if t.dim() == 0 else t

        # Get alpha values (t_prev < 0 marks the final step, where alpha_bar = 1)
        alpha_t = self.alphas_cumprod[t]
        alpha_t_prev = self.alphas_cumprod[t_prev] if t_prev >= 0 else torch.tensor(1.0, device=self.device)

        # Predict noise
        eps_pred = self.model(x_t, t_batch)

        # Predict x_0
        x0_pred = (x_t - self.sqrt_one_minus_alphas_cumprod[t] * eps_pred) / self.sqrt_alphas_cumprod[t]

        # Clip if configured
        if self.config.clip_denoised:
            x0_pred = x0_pred.clamp(*self.config.clip_range)

        # Compute sigma
        sigma_t = self._compute_sigma(alpha_t, alpha_t_prev, eta)

        # Direction pointing to x_t
        direction = torch.sqrt(1 - alpha_t_prev - sigma_t ** 2) * eps_pred

        # Compute x_{t_prev}
        x_prev = torch.sqrt(alpha_t_prev) * x0_pred + direction

        # Add noise if stochastic (never on the final step)
        if eta > 0 and t_prev >= 0:
            noise = torch.randn_like(x_t)
            x_prev = x_prev + sigma_t * noise

        return x_prev, x0_pred

    def _compute_sigma(
        self,
        alpha_t: torch.Tensor,
        alpha_t_prev: torch.Tensor,
        eta: float
    ) -> torch.Tensor:
        """Compute sigma for stochastic DDIM."""
        sigma = eta * torch.sqrt(
            (1 - alpha_t_prev) / (1 - alpha_t) *
            (1 - alpha_t / alpha_t_prev)
        )
        return sigma


# Example usage
def example_ddim_generation(model, noise_schedule):
    """Demonstrate DDIM generation.

    `model` is a trained noise-prediction network and `noise_schedule`
    is any object exposing an `alphas_cumprod` tensor.
    """
    config = DDIMConfig(
        num_timesteps=1000,
        eta=0.0,  # Deterministic
        clip_denoised=True,
        clip_range=(-1.0, 1.0)
    )

    sampler = DDIMSampler(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod,
        config=config,
        device="cuda"
    )

    # Generate 4 images with 50 steps
    samples = sampler.sample(
        shape=(4, 3, 64, 64),
        num_steps=50,
        eta=0.0,
        progress=True
    )

    print(f"Generated {samples.shape[0]} images")
    return samples
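
For reference, the update performed by `_ddim_step` can be written out as follows. This is a transcription of the code above, not a new derivation; the symbol $s$ for the next (less noisy) timestep in the schedule and $\mathbf{z}$ for fresh Gaussian noise are notation introduced here for illustration.

```latex
\begin{aligned}
\hat{\mathbf{x}}_0 &= \frac{\mathbf{x}_t - \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)}{\sqrt{\bar{\alpha}_t}} \\
\sigma_t &= \eta \sqrt{\frac{1 - \bar{\alpha}_s}{1 - \bar{\alpha}_t}} \sqrt{1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_s}} \\
\mathbf{x}_s &= \sqrt{\bar{\alpha}_s}\,\hat{\mathbf{x}}_0
  + \sqrt{1 - \bar{\alpha}_s - \sigma_t^2}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)
  + \sigma_t\,\mathbf{z}, \qquad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})
\end{aligned}
```

With $\eta = 0$ the noise term vanishes and the step is fully deterministic, which is the setting that inversion, interpolation, and editing all rely on.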

DDIM Inversion (Encoding)

One of DDIM's most powerful features is inversion: finding the latent code $\mathbf{x}_T$ that generates a given image $\mathbf{x}_0$. This enables image editing and manipulation.

The Inversion Process

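DDIM inversion runs the deterministic ($\eta = 0$) sampling update in the opposite direction, from low noise to high noise:

$$\mathbf{x}_{t+1} = \sqrt{\bar{\alpha}_{t+1}}\, \hat{\mathbf{x}}_0 + \sqrt{1 - \bar{\alpha}_{t+1}}\, \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$$

Here $\hat{\mathbf{x}}_0$ is the clean image predicted from $\mathbf{x}_t$ and $\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)$. Starting from a real image $\mathbf{x}_0$, we repeatedly re-noise it with the model's own predicted noise, not fresh random noise, until we reach the latent $\mathbf{x}_T$.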
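
To see where the inversion step comes from, it may help to write the deterministic sampling step next to it. The sketch below uses the same symbols as the equation above; the only approximation it assumes is that the noise prediction changes little between adjacent timesteps.

```latex
\begin{aligned}
\text{Sampling } (t+1 \to t):\quad
\mathbf{x}_t &= \sqrt{\bar{\alpha}_t}\,\hat{\mathbf{x}}_0
  + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}_\theta(\mathbf{x}_{t+1}, t+1) \\
\text{Assumption:}\quad
\boldsymbol{\epsilon}_\theta(\mathbf{x}_{t+1}, t+1) &\approx \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) \\
\text{Inversion } (t \to t+1):\quad
\frac{\mathbf{x}_{t+1}}{\sqrt{\bar{\alpha}_{t+1}}} &= \frac{\mathbf{x}_t}{\sqrt{\bar{\alpha}_t}}
  + \left( \sqrt{\frac{1 - \bar{\alpha}_{t+1}}{\bar{\alpha}_{t+1}}}
         - \sqrt{\frac{1 - \bar{\alpha}_t}{\bar{\alpha}_t}} \right)
    \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)
\end{aligned}
```

The last line is the inversion equation with $\hat{\mathbf{x}}_0$ substituted in and both sides divided by $\sqrt{\bar{\alpha}_{t+1}}$. The approximation in the middle line is the reason inversion is only approximately the inverse of sampling.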

🐍python
class DDIMInverter:
    """
    DDIM inversion for encoding images into latent space.

    Given an image x_0, finds the latent x_T such that
    DDIM(x_T) approximately equals x_0.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        num_timesteps: int = 1000,
        device: str = "cuda"
    ):
        self.model = model
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.T = num_timesteps
        self.device = device

        # Precompute
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

    @torch.no_grad()
    def invert(
        self,
        x_0: torch.Tensor,
        num_steps: int = 50,
        progress: bool = True
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Invert an image to find its latent code.

        Args:
            x_0: Input image in [-1, 1], shape (B, C, H, W)
            num_steps: Number of inversion steps
            progress: Show progress bar

        Returns:
            (x_T, trajectory) - latent code at the highest timestep of the
            schedule, and the intermediate states (one per timestep)
        """
        x_0 = x_0.to(self.device)

        # Create forward timesteps (0 -> T)
        timesteps = self._make_inversion_timesteps(num_steps)

        # Start from x_0
        x_t = x_0.clone()
        trajectory = [x_t.clone()]

        self.model.eval()

        iterator = tqdm(range(len(timesteps) - 1), desc="Inverting") if progress else range(len(timesteps) - 1)

        for i in iterator:
            t = timesteps[i]
            t_next = timesteps[i + 1]

            x_t = self._inversion_step(x_t, t, t_next, x_0)
            trajectory.append(x_t.clone())

        return x_t, trajectory

    def _inversion_step(
        self,
        x_t: torch.Tensor,
        t: int,
        t_next: int,
        x_0: torch.Tensor
    ) -> torch.Tensor:
        """
        Single DDIM inversion step (forward in time).

        Goes from x_t to x_{t_next} where t_next > t.
        """
        batch_size = x_t.shape[0]
        t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)

        # Predict noise at current timestep
        eps_pred = self.model(x_t, t_batch)

        # Predict x_0 from current x_t
        x0_pred = (x_t - self.sqrt_one_minus_alphas_cumprod[t] * eps_pred) / self.sqrt_alphas_cumprod[t]

        # Optional: use actual x_0 for more accurate inversion
        # This is "guided inversion"
        # x0_pred = x_0  # Uncomment for guided mode

        # Compute x_{t_next}
        direction = self.sqrt_one_minus_alphas_cumprod[t_next] * eps_pred
        x_next = self.sqrt_alphas_cumprod[t_next] * x0_pred + direction

        return x_next

    def _make_inversion_timesteps(self, num_steps: int) -> List[int]:
        """Create timesteps for inversion (ascending order)."""
        # Mirror DDIMSampler's "uniform" schedule so the final inversion state
        # sits exactly at the timestep where sampling starts. Appending T - 1
        # here would leave the latent one step noisier than the sampler assumes.
        step_size = max(self.T // num_steps, 1)
        return list(range(0, self.T, step_size))


def demonstrate_inversion(model, noise_schedule, image):
    """
    Demonstrate DDIM inversion and reconstruction.
    """
    inverter = DDIMInverter(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod,
        num_timesteps=noise_schedule.T
    )

    # The inverter does not clip x0 predictions, so clipping is disabled
    # here to keep the sampler a faithful mirror of the inversion
    sampler = DDIMSampler(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod,
        config=DDIMConfig(
            num_timesteps=noise_schedule.T,
            eta=0.0,
            clip_denoised=False
        )
    )

    # Invert the image
    x_T, inversion_trajectory = inverter.invert(
        x_0=image,
        num_steps=50
    )

    print(f"Inverted image shape: {image.shape}")
    print(f"Latent code shape: {x_T.shape}")

    # Reconstruct from latent (same number of steps as the inversion)
    reconstructed = sampler.sample(
        shape=image.shape,
        num_steps=50,
        x_T=x_T,
        eta=0.0  # Must be deterministic for reconstruction
    )

    # Compute reconstruction error (move the image to the sampler's device)
    mse = F.mse_loss(image.to(reconstructed.device), reconstructed).item()
    print(f"Reconstruction MSE: {mse:.6f}")

    return x_T, reconstructed, inversion_trajectory

Inversion Accuracy

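DDIM inversion is not exact. Each step uses the noise prediction at the current state as a stand-in for the prediction at the next, noisier state, so every step introduces a small error, and these errors accumulate. For better reconstruction, use more inversion steps (100-200) or guided inversion techniques that use the original image during the process.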
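
The effect of step count on round-trip error can be checked on a toy problem where no trained network is needed. The sketch below is an illustration under stated assumptions, not part of the toolkit: the data is 1-D Gaussian with standard deviation `data_std`, so the ideal noise prediction has a closed form (`oracle_eps`), and the linear beta schedule in `make_alphas_cumprod` is a common default chosen here for convenience. All function names are hypothetical.

```python
import math


def make_alphas_cumprod(T=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative products of (1 - beta_t) for an assumed linear beta schedule."""
    alphas_cumprod, running = [], 1.0
    for i in range(T):
        beta = beta_start + (beta_end - beta_start) * i / (T - 1)
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return alphas_cumprod


def oracle_eps(x_t, a_bar, data_std):
    """Exact E[eps | x_t] when the data is 1-D Gaussian N(0, data_std^2)."""
    return math.sqrt(1.0 - a_bar) * x_t / (a_bar * data_std ** 2 + 1.0 - a_bar)


def ddim_move(x, a_from, a_to, data_std):
    """One deterministic DDIM update between two noise levels (either direction)."""
    eps = oracle_eps(x, a_from, data_std)
    x0_pred = (x - math.sqrt(1.0 - a_from) * eps) / math.sqrt(a_from)
    return math.sqrt(a_to) * x0_pred + math.sqrt(1.0 - a_to) * eps


def round_trip_error(x0, num_steps, data_std=0.5, T=1000):
    """Invert x0 up the schedule, sample back down, and return |x_rec - x0|."""
    a = make_alphas_cumprod(T)
    ts = list(range(0, T, T // num_steps))  # uniform grid: 0, c, 2c, ...
    x = x0
    # Inversion: low noise -> high noise
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = ddim_move(x, a[t], a[t_next], data_std)
    # Sampling: high noise -> low noise, over the same grid
    for t, t_prev in zip(ts[:0:-1], ts[-2::-1]):
        x = ddim_move(x, a[t], a[t_prev], data_std)
    return abs(x - x0)


for n in (10, 50, 200):
    print(f"{n:4d} steps: round-trip error = {round_trip_error(0.8, n):.4f}")
```

Even with a perfect noise predictor the round trip is not exact, and the error shrinks as the step count grows. With a learned network the same discretization error is present on top of the model's own prediction error.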

Semantic Interpolation

With deterministic DDIM, we can smoothly interpolate between images by interpolating their latent codes:

$$\mathbf{x}_T^{\text{interp}} = (1 - \alpha)\, \mathbf{x}_T^{(1)} + \alpha\, \mathbf{x}_T^{(2)}$$

🐍python
class DDIMInterpolator:
    """
    Semantic interpolation between images using DDIM.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        num_timesteps: int = 1000,
        device: str = "cuda"
    ):
        self.model = model
        self.device = device

        self.inverter = DDIMInverter(
            model=model,
            alphas_cumprod=alphas_cumprod,
            num_timesteps=num_timesteps,
            device=device
        )

        self.sampler = DDIMSampler(
            model=model,
            alphas_cumprod=alphas_cumprod,
            config=DDIMConfig(eta=0.0),
            device=device
        )

    @torch.no_grad()
    def interpolate(
        self,
        image1: torch.Tensor,
        image2: torch.Tensor,
        num_frames: int = 10,
        inversion_steps: int = 50,
        sampling_steps: int = 50,
        interpolation_type: str = "linear"
    ) -> List[torch.Tensor]:
        """
        Create smooth interpolation between two images.

        Args:
            image1: First image in [-1, 1]
            image2: Second image in [-1, 1]
            num_frames: Number of interpolation frames
            inversion_steps: Steps for inversion
            sampling_steps: Steps for generation
            interpolation_type: "linear" or "spherical"

        Returns:
            List of interpolated images
        """
        # Invert both images to latent space
        z1, _ = self.inverter.invert(image1, num_steps=inversion_steps)
        z2, _ = self.inverter.invert(image2, num_steps=inversion_steps)

        # Generate interpolation weights
        alphas = np.linspace(0, 1, num_frames)

        interpolated = []

        for alpha in alphas:
            # Use a plain Python float so tensor arithmetic keeps its dtype
            alpha = float(alpha)

            # Interpolate in latent space
            if interpolation_type == "linear":
                z_interp = self._lerp(z1, z2, alpha)
            elif interpolation_type == "spherical":
                z_interp = self._slerp(z1, z2, alpha)
            else:
                raise ValueError(f"Unknown interpolation: {interpolation_type}")

            # Generate from interpolated latent
            sample = self.sampler.sample(
                shape=z_interp.shape,
                num_steps=sampling_steps,
                x_T=z_interp,
                progress=False
            )

            interpolated.append(sample)

        return interpolated

    def _lerp(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        alpha: float
    ) -> torch.Tensor:
        """Linear interpolation."""
        return (1 - alpha) * z1 + alpha * z2

    def _slerp(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        alpha: float
    ) -> torch.Tensor:
        """
        Spherical linear interpolation.

        Better for high-dimensional spaces like latent codes.
        Treats each input as a single flattened vector, so pass
        one image at a time (batch size 1).
        """
        # Flatten to vectors
        z1_flat = z1.flatten()
        z2_flat = z2.flatten()

        # Compute the angle between the two latents
        dot = torch.dot(z1_flat, z2_flat)
        dot = dot / (torch.norm(z1_flat) * torch.norm(z2_flat))
        dot = torch.clamp(dot, -1, 1)
        omega = torch.acos(dot)
        sin_omega = torch.sin(omega)

        # Edge case: (anti)parallel latents make sin(omega) ~ 0,
        # so fall back to linear interpolation
        if sin_omega.abs() < 1e-6:
            return self._lerp(z1, z2, alpha)

        # Spherical interpolation
        s1 = torch.sin((1 - alpha) * omega) / sin_omega
        s2 = torch.sin(alpha * omega) / sin_omega

        return s1 * z1 + s2 * z2


def create_interpolation_video(model, noise_schedule, image1, image2):
    """
    Create smooth interpolation between two images.
    """
    interpolator = DDIMInterpolator(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod
    )

    frames = interpolator.interpolate(
        image1=image1,
        image2=image2,
        num_frames=30,
        interpolation_type="spherical"  # Better for latent spaces
    )

    print(f"Generated {len(frames)} interpolation frames")

    # Convert to video-ready format
    frames_tensor = torch.stack([f.squeeze(0) for f in frames])
    frames_tensor = (frames_tensor + 1) / 2  # [-1, 1] -> [0, 1]

    return frames_tensor

Spherical vs Linear Interpolation

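For high-dimensional latent spaces, spherical interpolation (slerp) often produces smoother transitions. It interpolates along the geodesic on the hypersphere rather than cutting through it, which keeps the norm of the interpolated latent close to that of the endpoints. A linear midpoint of two independent Gaussian latents, by contrast, has a norm roughly 30% smaller than either endpoint.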
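
The norm argument is easy to verify numerically. The sketch below is a dependency-free illustration using plain Python lists in place of tensors; `lerp`, `slerp`, and `norm` are stand-in helpers written for this note, and the dimension 4096 is an arbitrary choice.

```python
import math
import random


def norm(z):
    """Euclidean norm of a vector stored as a list."""
    return math.sqrt(sum(v * v for v in z))


def lerp(z1, z2, alpha):
    """Linear interpolation between two vectors."""
    return [(1.0 - alpha) * a + alpha * b for a, b in zip(z1, z2)]


def slerp(z1, z2, alpha):
    """Spherical linear interpolation between two vectors."""
    cos_omega = sum(a * b for a, b in zip(z1, z2)) / (norm(z1) * norm(z2))
    cos_omega = max(-1.0, min(1.0, cos_omega))
    omega = math.acos(cos_omega)
    sin_omega = math.sin(omega)
    if abs(sin_omega) < 1e-6:
        # (Anti)parallel vectors: fall back to linear interpolation
        return lerp(z1, z2, alpha)
    s1 = math.sin((1.0 - alpha) * omega) / sin_omega
    s2 = math.sin(alpha * omega) / sin_omega
    return [s1 * a + s2 * b for a, b in zip(z1, z2)]


rng = random.Random(0)
dim = 4096
z1 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
z2 = [rng.gauss(0.0, 1.0) for _ in range(dim)]

# A standard Gaussian vector in `dim` dimensions has norm close to sqrt(dim)
expected = math.sqrt(dim)
lerp_ratio = norm(lerp(z1, z2, 0.5)) / expected
slerp_ratio = norm(slerp(z1, z2, 0.5)) / expected

print(f"lerp midpoint norm / sqrt(dim):  {lerp_ratio:.3f}")
print(f"slerp midpoint norm / sqrt(dim): {slerp_ratio:.3f}")
```

The linear midpoint lands well inside the shell where Gaussian latents concentrate, while the spherical midpoint stays on it. That is one plausible reason intermediate frames from linear interpolation can look washed out.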

Image Editing via Latent Space

DDIM inversion enables powerful image editing by manipulating the latent representation:

🐍python
class DDIMImageEditor:
    """
    Edit images by manipulating their latent representations.
    """

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        num_timesteps: int = 1000,
        device: str = "cuda"
    ):
        self.model = model
        self.device = device

        self.inverter = DDIMInverter(
            model=model,
            alphas_cumprod=alphas_cumprod,
            num_timesteps=num_timesteps,
            device=device
        )

        self.sampler = DDIMSampler(
            model=model,
            alphas_cumprod=alphas_cumprod,
            config=DDIMConfig(eta=0.0),
            device=device
        )

    @torch.no_grad()
    def edit_by_noise_injection(
        self,
        image: torch.Tensor,
        noise_scale: float = 0.3,
        edit_timestep: int = 200,
        inversion_steps: int = 50,
        sampling_steps: int = 50
    ) -> torch.Tensor:
        """
        Edit image by injecting noise at an intermediate timestep.

        This creates variations that preserve overall structure.

        Args:
            image: Input image in [-1, 1]
            noise_scale: Amount of noise to inject (0-1)
            edit_timestep: Timestep at which to inject noise
            inversion_steps: Steps for inversion
            sampling_steps: Steps for generation

        Returns:
            Edited image
        """
        # Invert and keep the full latent trajectory
        _, trajectory = self.inverter.invert(
            image,
            num_steps=inversion_steps
        )

        # Find the trajectory state at (or just below) edit_timestep,
        # and the timestep that state actually belongs to
        inv_timesteps = self.inverter._make_inversion_timesteps(inversion_steps)
        step_size = max(self.inverter.T // inversion_steps, 1)
        edit_index = min(edit_timestep // step_size, len(trajectory) - 1)
        t_edit = int(inv_timesteps[edit_index])

        x_t = trajectory[edit_index].clone()

        # Inject noise
        x_t = x_t + torch.randn_like(x_t) * noise_scale

        # Denoise from t_edit down to 0. sampler.sample() always starts at
        # the top of its schedule, so we step manually from t_edit instead.
        timesteps = self.sampler.make_timesteps(sampling_steps)
        timesteps = timesteps[timesteps < t_edit]
        start = torch.tensor([t_edit], device=self.device, dtype=torch.long)
        timesteps = torch.cat([start, timesteps])
        final_t = torch.tensor(-1, device=self.device)  # alpha_bar = 1

        self.sampler.model.eval()
        for i in range(len(timesteps)):
            t = timesteps[i]
            t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else final_t
            x_t, _ = self.sampler._ddim_step(x_t, t, t_prev, eta=0.0)

        return x_t

    @torch.no_grad()
    def edit_by_direction(
        self,
        image: torch.Tensor,
        direction: torch.Tensor,
        strength: float = 1.0,
        inversion_steps: int = 100,
        sampling_steps: int = 50
    ) -> torch.Tensor:
        """
        Edit image by moving in a semantic direction.

        The direction can be learned from pairs of images
        (e.g., smiling vs not smiling).

        Args:
            image: Input image in [-1, 1]
            direction: Edit direction in latent space
            strength: How far to move along direction
            inversion_steps: Steps for inversion
            sampling_steps: Steps for generation

        Returns:
            Edited image
        """
        # Invert
        x_T, _ = self.inverter.invert(image, num_steps=inversion_steps)

        # Apply direction
        x_T_edited = x_T + strength * direction.to(self.device)

        # Regenerate
        edited = self.sampler.sample(
            shape=x_T_edited.shape,
            num_steps=sampling_steps,
            x_T=x_T_edited,
            progress=False
        )

        return edited

    @torch.no_grad()
    def blend_images(
        self,
        images: List[torch.Tensor],
        weights: List[float],
        inversion_steps: int = 100,
        sampling_steps: int = 50
    ) -> torch.Tensor:
        """
        Blend multiple images by combining their latent codes.

        Args:
            images: List of images in [-1, 1]
            weights: Blending weights (should sum to 1)
            inversion_steps: Steps for inversion
            sampling_steps: Steps for generation

        Returns:
            Blended image
        """
        assert len(images) == len(weights)
        assert abs(sum(weights) - 1.0) < 1e-6

        # Invert all images
        latents = []
        for img in images:
            z, _ = self.inverter.invert(img, num_steps=inversion_steps)
            latents.append(z)

        # Weighted combination
        x_T_blend = sum(w * z for w, z in zip(weights, latents))

        # Generate
        blended = self.sampler.sample(
            shape=x_T_blend.shape,
            num_steps=sampling_steps,
            x_T=x_T_blend,
            progress=False
        )

        return blended


def learn_edit_direction(
    model,
    noise_schedule,
    positive_images: List[torch.Tensor],
    negative_images: List[torch.Tensor],
    inversion_steps: int = 100
) -> torch.Tensor:
    """
    Learn a semantic edit direction from two groups of images.

    Example: positive = smiling faces, negative = neutral faces
    The resulting direction can be used to add/remove smiles.
    """
    inverter = DDIMInverter(
        model=model,
        alphas_cumprod=noise_schedule.alphas_cumprod
    )

    # Invert positive examples
    positive_latents = []
    for img in positive_images:
        z, _ = inverter.invert(img, num_steps=inversion_steps)
        positive_latents.append(z)

    # Invert negative examples
    negative_latents = []
    for img in negative_images:
        z, _ = inverter.invert(img, num_steps=inversion_steps)
        negative_latents.append(z)

    # Compute mean difference
    pos_mean = torch.stack(positive_latents).mean(dim=0)
    neg_mean = torch.stack(negative_latents).mean(dim=0)

    direction = pos_mean - neg_mean

    # Normalize to unit length
    direction = direction / torch.norm(direction)

    return direction

Edit Direction Discovery

A common way to find edit directions is to contrast two labeled groups of images, as `learn_edit_direction` does. Subtracting the mean latent of the examples without an attribute from the mean latent of the examples with it gives a candidate semantic direction such as "add smile", "add glasses", or "make older".
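
The difference-of-means idea can be tried on synthetic vectors before spending compute on real inversions. The sketch below is a toy under stated assumptions: latents are 16-dimensional Gaussian lists, the attribute is planted along a single hypothetical axis (`true_axis`), and `mean_vector`, `learn_direction`, and `sample_latent` are helper names invented for this illustration.

```python
import math
import random


def mean_vector(vectors):
    """Element-wise mean of a list of equal-length vectors."""
    n = len(vectors)
    return [sum(column) / n for column in zip(*vectors)]


def learn_direction(positive, negative):
    """Unit-length difference between the two group means."""
    pos_mean = mean_vector(positive)
    neg_mean = mean_vector(negative)
    diff = [p - q for p, q in zip(pos_mean, neg_mean)]
    length = math.sqrt(sum(d * d for d in diff))
    return [d / length for d in diff]


rng = random.Random(0)
dim = 16

# Hypothetical ground truth: the attribute lives along coordinate 3
true_axis = [1.0 if i == 3 else 0.0 for i in range(dim)]


def sample_latent(shift):
    """Gaussian latent, shifted along the planted attribute axis."""
    return [rng.gauss(0.0, 1.0) + shift * a for a in true_axis]


negative = [sample_latent(0.0) for _ in range(200)]  # attribute absent
positive = [sample_latent(2.0) for _ in range(200)]  # attribute present

direction = learn_direction(positive, negative)

# Cosine similarity with the planted axis (both are unit vectors)
cosine = sum(d * a for d, a in zip(direction, true_axis))
print(f"cosine(learned direction, planted axis) = {cosine:.3f}")
```

Averaging over many examples is what makes this work: per-image variation cancels in the means while the shared attribute shift survives. With only a handful of examples per group, the recovered direction is much noisier.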

Implementation Optimizations

Here are key optimizations for production DDIM:

1. Batched Sampling

🐍python
def optimized_batch_sample(
    sampler: DDIMSampler,
    num_samples: int,
    batch_size: int = 16,
    num_steps: int = 50
) -> torch.Tensor:
    """
    Generate many samples efficiently with batching.
    """
    all_samples = []
    num_batches = (num_samples + batch_size - 1) // batch_size

    for i in range(num_batches):
        # The last batch may be smaller than batch_size
        current_batch = min(batch_size, num_samples - i * batch_size)

        samples = sampler.sample(
            shape=(current_batch, 3, 64, 64),
            num_steps=num_steps,
            progress=False
        )
        all_samples.append(samples)

    return torch.cat(all_samples, dim=0)[:num_samples]

2. Mixed Precision Sampling

🐍python
class DDIMSamplerAMP(DDIMSampler):
    """DDIM with automatic mixed precision for faster sampling."""

    @torch.no_grad()
    def sample(self, shape, num_steps=50, eta=None, x_T=None, progress=True,
               schedule="uniform", **kwargs):
        """Sample with AMP for speed."""
        eta = eta if eta is not None else self.config.eta

        if x_T is None:
            x_t = torch.randn(shape, device=self.device)
        else:
            x_t = x_T.to(self.device)

        timesteps = self.make_timesteps(num_steps, schedule)

        # Sentinel for the final step: t_prev = -1 means alpha_bar = 1
        final_t = torch.tensor(-1, device=self.device)

        self.model.eval()

        # Use autocast for faster inference
        with torch.amp.autocast('cuda'):
            for i in (tqdm(range(len(timesteps))) if progress else range(len(timesteps))):
                t = timesteps[i]
                t_prev = timesteps[i + 1] if i + 1 < len(timesteps) else final_t
                x_t, _ = self._ddim_step(x_t, t, t_prev, eta)

        return x_t

3. Compiled Model

🐍python
# PyTorch 2.0+ compilation for faster inference
def compile_for_sampling(model: nn.Module) -> nn.Module:
    """Compile model for faster DDIM sampling."""
    return torch.compile(
        model,
        mode="reduce-overhead",  # Optimize for repeated calls
        fullgraph=True  # Raises if the model cannot be captured as one graph
    )

# Usage
compiled_model = compile_for_sampling(model)
sampler = DDIMSampler(
    model=compiled_model,
    alphas_cumprod=noise_schedule.alphas_cumprod
)

# First sample is slow (compilation), subsequent are fast
samples = sampler.sample(shape=(4, 3, 64, 64), num_steps=50)

4. Caching for Repeated Sampling

🐍python
class CachedDDIMSampler(DDIMSampler):
    """
    DDIM sampler with caching for repeated operations.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._timestep_cache = {}

    def make_timesteps(self, num_steps, schedule="uniform"):
        """Cached timestep generation."""
        cache_key = (num_steps, schedule)
        if cache_key not in self._timestep_cache:
            self._timestep_cache[cache_key] = super().make_timesteps(num_steps, schedule)
        return self._timestep_cache[cache_key]

    @torch.no_grad()
    def sample_multiple_from_same_noise(
        self,
        x_T: torch.Tensor,
        num_steps_list: List[int],
    ) -> dict:
        """
        Generate at multiple step counts from same noise.

        Useful for quality comparison experiments.
        """
        results = {}

        for num_steps in sorted(num_steps_list, reverse=True):
            samples = self.sample(
                shape=x_T.shape,
                num_steps=num_steps,
                x_T=x_T.clone(),
                progress=False
            )
            results[num_steps] = samples

        return results

| Optimization | Speedup | Memory | Notes |
|---|---|---|---|
| Mixed Precision | 1.5-2x | 50% less | Minimal quality loss |
| Torch Compile | 1.3-2x | Same | PyTorch 2.0+ only |
| Batching | Linear | Linear increase | GPU utilization |
| Caching | Variable | Slight increase | Repeated timesteps |
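
Speedups like these depend heavily on hardware and model size, so it is worth measuring them on your own setup. The helper below is a minimal timing sketch, not part of the toolkit: `benchmark` and `speedup` are hypothetical names, and the optional `sync` argument is an assumption meant for GPU code, where you would typically pass `torch.cuda.synchronize` so that asynchronous kernels finish before the clock stops.

```python
import statistics
import time


def benchmark(fn, warmup=2, repeats=5, sync=None):
    """
    Median wall-clock time of fn() in seconds.

    warmup: untimed calls made first (absorbs one-off costs such as
            compilation or cache population)
    sync:   optional zero-argument callable run before reading the clock
    """
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()

    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        times.append(time.perf_counter() - start)

    return statistics.median(times)


def speedup(baseline_fn, optimized_fn, **kwargs):
    """Ratio of baseline time to optimized time (above 1 means faster)."""
    return benchmark(baseline_fn, **kwargs) / benchmark(optimized_fn, **kwargs)


# Stand-in workloads; replace with calls to your samplers
slow = lambda: sum(range(300_000))
fast = lambda: sum(range(30_000))
print(f"measured speedup: {speedup(slow, fast):.1f}x")
```

The warmup calls matter most for the compiled model, whose first invocation includes compilation time. Using the median instead of the mean keeps a single slow run from skewing the result.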

Summary

We've built a complete DDIM toolkit with:

  1. Full-featured sampler with configurable eta, progress tracking, and multiple timestep schedules
  2. DDIM inversion for encoding real images into latent space, enabling reconstruction and manipulation
  3. Semantic interpolation with both linear and spherical methods for smooth transitions between images
  4. Image editing capabilities including noise injection, directional editing, and multi-image blending
  5. Production optimizations for speed and efficiency

Coming Up Next

In the next section, we'll explore advanced samplers that go beyond DDIM: DPM-Solver for even faster sampling, and various ancestral sampling variants that balance speed, quality, and diversity.

The DDIM framework we've built here forms the foundation for modern diffusion model applications. Whether you're building an image generation service, an editing tool, or a creative application, these components provide the building blocks you need.