Chapter 12
Section 55 of 76

Generating Images

Generation and Evaluation

Learning Objectives

By the end of this section, you will be able to:

  1. Implement the complete DDPM sampling algorithm from pure noise to clean images
  2. Apply DDIM for accelerated deterministic sampling with fewer steps
  3. Configure sampling parameters for optimal quality-speed tradeoffs
  4. Generate large batches of images efficiently using GPU parallelism

The Big Picture

Image generation in diffusion models is the reverse of the forward process: we start from pure Gaussian noise and iteratively denoise to produce clean images. The trained model predicts the noise at each step, allowing us to reconstruct the original signal.

The Sampling Process: Starting from $x_T \sim \mathcal{N}(0, I)$, we iteratively compute $x_{t-1}$ from $x_t$ using the model's noise prediction. After $T$ steps, we obtain a clean sample $x_0$.
Step-by-Step Denoising Animation

[Interactive animation: the diffusion model progressively denoises a 1D signal from pure noise (t = T, alpha_bar ≈ 0, 100% noise) to clean data (t = 0), with a toggle between DDPM (stochastic) and DDIM (deterministic) sampling modes.]

What's Happening at Each Step?

1. Predict Noise (or x0)

The neural network takes the current noisy signal x_t and timestep t, then predicts either the noise epsilon or the clean signal x_0.

2. Compute Mean

Using the prediction, we compute the mean of p(x_{t-1}|x_t), which tells us where to move in the next step.

3. Add Noise (DDPM) or Not (DDIM)

DDPM adds scaled Gaussian noise to maintain the Markov property. DDIM skips this for deterministic, faster sampling.

4. Update x

We update x_{t-1} = mean + sigma * z (DDPM) or x_{t-1} = mean (DDIM), moving one step closer to the clean data.
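The four steps above can be condensed into a short loop. The sketch below assumes a linear beta schedule and uses a placeholder zero-noise predictor standing in for the trained network `epsilon_theta`, just to make the DDPM update rule concrete on a 1D toy signal:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)        # linear beta schedule (assumed)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

def epsilon_theta(x, t):
    # Placeholder: a trained network would predict the noise here.
    return torch.zeros_like(x)

x = torch.randn(16)                           # x_T: pure noise (1D toy signal)
for t in reversed(range(T)):
    eps = epsilon_theta(x, t)                                    # 1. predict noise
    mean = (x - betas[t] / torch.sqrt(1.0 - alpha_bars[t]) * eps) \
           / torch.sqrt(alphas[t])                               # 2. compute mean
    if t > 0:                                                    # 3. DDPM adds noise...
        x = mean + torch.sqrt(betas[t]) * torch.randn_like(x)    # 4. update x
    else:                                                        # ...except at t = 0
        x = mean
```

With DDIM, step 3 is skipped and a deterministic update (the DDIM reverse step below) replaces the stochastic one.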

Key Equations

DDPM Reverse Step:
x_{t-1} = (1/sqrt(alpha_t)) * (x_t - (beta_t/sqrt(1-alpha_bar_t)) * epsilon_theta) + sigma_t * z
DDIM Reverse Step:
x_{t-1} = sqrt(alpha_bar_{t-1}) * pred_x0 + sqrt(1-alpha_bar_{t-1}) * direction_pointing_to_x_t
| Method | Steps | Stochastic | Quality | Speed |
|---|---|---|---|---|
| DDPM | 1000 | Yes | Best | Slow |
| DDIM | 50-100 | No | Near DDPM | Fast |
| DDIM (eta=1) | 50-100 | Yes | Good | Fast |
| DPM-Solver | 10-20 | No | Good | Very fast |

DDPM Sampling

The DDPM Reverse Process

The DDPM sampling follows the reverse process derived from the variational lower bound. At each step, we predict the noise and compute the mean and variance of the reverse distribution:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta(x_t, t)\right) + \sigma_t z$$

where $z \sim \mathcal{N}(0, I)$ and $\sigma_t = \sqrt{\beta_t}$ or $\sigma_t = \sqrt{\tilde{\beta}_t}$.

```python
import torch
import torch.nn as nn
from typing import Optional, Callable
from tqdm import tqdm

class DDPMSampler:
    """DDPM sampling with various noise schedule options."""

    def __init__(
        self,
        model: nn.Module,
        betas: torch.Tensor,
        device: str = "cuda",
    ):
        self.model = model.to(device).eval()
        self.device = device

        # Precompute coefficients
        self.betas = betas.to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.alphas_cumprod_prev = torch.cat([
            torch.tensor([1.0], device=device),
            self.alphas_cumprod[:-1]
        ])

        # Coefficients for sampling
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)

        # Posterior variance beta_tilde_t
        self.posterior_variance = (
            self.betas * (1.0 - self.alphas_cumprod_prev) /
            (1.0 - self.alphas_cumprod)
        )

        self.timesteps = len(betas)

    @torch.no_grad()
    def sample(
        self,
        shape: tuple,
        num_steps: Optional[int] = None,
        return_intermediates: bool = False,
        progress_callback: Optional[Callable] = None,
    ) -> torch.Tensor:
        """Generate samples using the DDPM reverse process.

        Args:
            shape: Output shape (batch_size, channels, height, width)
            num_steps: Number of steps (defaults to full timesteps)
            return_intermediates: Whether to return intermediate samples
            progress_callback: Optional callback for progress updates

        Returns:
            Generated samples (and intermediates if requested)
        """
        if num_steps is None:
            num_steps = self.timesteps

        # Start from pure noise
        x = torch.randn(shape, device=self.device)

        intermediates = [x.clone()] if return_intermediates else None

        # Reverse process
        timesteps = torch.linspace(
            self.timesteps - 1, 0, num_steps, dtype=torch.long, device=self.device
        )

        for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
            t_batch = torch.full((shape[0],), t, device=self.device, dtype=torch.long)

            # Predict noise
            pred_noise = self.model(x, t_batch)

            # Compute the posterior mean
            beta = self.betas[t]
            mean = self.sqrt_recip_alphas[t] * (
                x - beta / self.sqrt_one_minus_alphas_cumprod[t] * pred_noise
            )

            # Add noise (except at t = 0)
            if t > 0:
                noise = torch.randn_like(x)
                sigma = torch.sqrt(self.posterior_variance[t])
                x = mean + sigma * noise
            else:
                x = mean

            if return_intermediates and i % max(num_steps // 10, 1) == 0:
                intermediates.append(x.clone())

            if progress_callback:
                progress_callback(i / num_steps)

        # Final clipping
        x = x.clamp(-1, 1)

        if return_intermediates:
            return x, intermediates
        return x

    @torch.no_grad()
    def sample_with_guidance(
        self,
        shape: tuple,
        condition: torch.Tensor,
        guidance_scale: float = 7.5,
        num_steps: Optional[int] = None,
    ) -> torch.Tensor:
        """Classifier-free guidance sampling."""
        if num_steps is None:
            num_steps = self.timesteps

        x = torch.randn(shape, device=self.device)
        batch_size = shape[0]

        timesteps = torch.linspace(
            self.timesteps - 1, 0, num_steps, dtype=torch.long, device=self.device
        )

        for t in tqdm(timesteps, desc="Sampling with guidance"):
            t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)

            # Conditional prediction
            pred_cond = self.model(x, t_batch, condition)

            # Unconditional prediction (null condition)
            null_condition = torch.zeros_like(condition)
            pred_uncond = self.model(x, t_batch, null_condition)

            # Classifier-free guidance: extrapolate away from the unconditional prediction
            pred_noise = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

            # DDPM update
            beta = self.betas[t]
            mean = self.sqrt_recip_alphas[t] * (
                x - beta / self.sqrt_one_minus_alphas_cumprod[t] * pred_noise
            )

            if t > 0:
                noise = torch.randn_like(x)
                sigma = torch.sqrt(self.posterior_variance[t])
                x = mean + sigma * noise
            else:
                x = mean

        return x.clamp(-1, 1)
```

Posterior Variance Choice

DDPM defines two variance options: $\sigma_t^2 = \beta_t$ (the forward-process variance) or $\sigma_t^2 = \tilde{\beta}_t$ (the posterior variance). The latter gives slightly better sample quality in most cases.
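The gap between the two options can be checked numerically. This sketch assumes the standard linear schedule with T = 1000:

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)
alpha_bars_prev = torch.cat([torch.tensor([1.0]), alpha_bars[:-1]])

# Option 1: forward-process variance, sigma_t^2 = beta_t
sigma2_beta = betas

# Option 2: posterior variance, beta_tilde_t = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
sigma2_tilde = betas * (1.0 - alpha_bars_prev) / (1.0 - alpha_bars)

# beta_tilde_t <= beta_t everywhere; the two choices differ most at small t
assert bool((sigma2_tilde <= sigma2_beta + 1e-12).all())
print(float(sigma2_tilde[1] / sigma2_beta[1]))    # well below 1 at early timesteps
print(float(sigma2_tilde[-1] / sigma2_beta[-1]))  # close to 1 at late timesteps
```

At $t = 0$ the posterior variance is exactly zero (the reverse step is deterministic there), while at large $t$ the two options nearly coincide, which is why the choice mostly affects the final, low-noise steps.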

DDIM Sampling

Deterministic Sampling

DDIM (Denoising Diffusion Implicit Models) enables deterministic sampling and allows using fewer steps without retraining. The key insight is that the same trained model can be used with a non-Markovian reverse process:

$$x_{t-1} = \sqrt{\bar{\alpha}_{t-1}} \underbrace{\left(\frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta(x_t, t)}{\sqrt{\bar{\alpha}_t}}\right)}_{\text{predicted } x_0} + \sqrt{1 - \bar{\alpha}_{t-1} - \sigma_t^2}\,\epsilon_\theta(x_t, t) + \sigma_t \epsilon_t$$
```python
import torch
import torch.nn as nn
from typing import List
from tqdm import tqdm

class DDIMSampler:
    """DDIM sampler for fast, deterministic generation."""

    def __init__(
        self,
        model: nn.Module,
        alphas_cumprod: torch.Tensor,
        device: str = "cuda",
    ):
        self.model = model.to(device).eval()
        self.device = device
        self.alphas_cumprod = alphas_cumprod.to(device)
        self.total_timesteps = len(alphas_cumprod)

    def _get_timesteps(self, num_steps: int) -> torch.Tensor:
        """Get evenly spaced timesteps, ordered from high to low."""
        timesteps = torch.linspace(
            0, self.total_timesteps - 1, num_steps, dtype=torch.long
        )
        return timesteps.flip(0).to(self.device)

    def _get_alpha_bars(self, timesteps: torch.Tensor) -> tuple:
        """Get alpha_bar and alpha_bar_prev for a descending timestep schedule."""
        alpha_bar = self.alphas_cumprod[timesteps]
        # "prev" is the next (lower) timestep in the schedule;
        # the final step maps to alpha_bar_prev = 1
        alpha_bar_prev = torch.cat([
            self.alphas_cumprod[timesteps[1:]],
            torch.tensor([1.0], device=self.device),
        ])
        return alpha_bar, alpha_bar_prev

    @torch.no_grad()
    def sample(
        self,
        shape: tuple,
        num_steps: int = 50,
        eta: float = 0.0,  # 0 = deterministic, 1 = DDPM-like stochasticity
        return_intermediates: bool = False,
    ) -> torch.Tensor:
        """Generate samples using DDIM.

        Args:
            shape: Output shape (batch_size, channels, height, width)
            num_steps: Number of sampling steps (can be far fewer than training steps)
            eta: Controls stochasticity (0 = deterministic, 1 = full noise)
            return_intermediates: Return intermediate samples

        Returns:
            Generated samples
        """
        batch_size = shape[0]

        # Get timestep schedule
        timesteps = self._get_timesteps(num_steps)

        # Start from noise
        x = torch.randn(shape, device=self.device)

        intermediates = [x.clone()] if return_intermediates else None

        for i, t in enumerate(tqdm(timesteps, desc=f"DDIM ({num_steps} steps)")):
            t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)

            # Get alpha values
            alpha_bar = self.alphas_cumprod[t]

            if i < len(timesteps) - 1:
                alpha_bar_prev = self.alphas_cumprod[timesteps[i + 1]]
            else:
                alpha_bar_prev = torch.tensor(1.0, device=self.device)

            # Predict noise
            pred_noise = self.model(x, t_batch)

            # Predict x0
            pred_x0 = (x - torch.sqrt(1 - alpha_bar) * pred_noise) / torch.sqrt(alpha_bar)
            pred_x0 = pred_x0.clamp(-1, 1)  # Clamp for stability

            # Compute sigma for stochasticity
            sigma = eta * torch.sqrt(
                (1 - alpha_bar_prev) / (1 - alpha_bar) *
                (1 - alpha_bar / alpha_bar_prev)
            )

            # Direction pointing to x_t
            dir_xt = torch.sqrt(1 - alpha_bar_prev - sigma**2) * pred_noise

            # Add noise if eta > 0
            if eta > 0 and i < len(timesteps) - 1:
                noise = torch.randn_like(x)
                x = torch.sqrt(alpha_bar_prev) * pred_x0 + dir_xt + sigma * noise
            else:
                x = torch.sqrt(alpha_bar_prev) * pred_x0 + dir_xt

            if return_intermediates and i % max(num_steps // 10, 1) == 0:
                intermediates.append(x.clone())

        x = x.clamp(-1, 1)

        if return_intermediates:
            return x, intermediates
        return x

    @torch.no_grad()
    def sample_interpolation(
        self,
        x0_start: torch.Tensor,
        x0_end: torch.Tensor,
        num_interpolations: int = 10,
        num_steps: int = 50,
    ) -> List[torch.Tensor]:
        """Generate interpolations between two images via latent space.

        Args:
            x0_start: Starting image
            x0_end: Ending image
            num_interpolations: Number of interpolation steps
            num_steps: Sampling steps for each interpolation

        Returns:
            List of interpolated images
        """
        # Encode both images to latents (add noise up to step T)
        timesteps = self._get_timesteps(num_steps)
        T = timesteps[0]

        # Add noise to both images
        noise_start = torch.randn_like(x0_start)
        noise_end = torch.randn_like(x0_end)

        alpha_bar_T = self.alphas_cumprod[T]
        xT_start = torch.sqrt(alpha_bar_T) * x0_start + torch.sqrt(1 - alpha_bar_T) * noise_start
        xT_end = torch.sqrt(alpha_bar_T) * x0_end + torch.sqrt(1 - alpha_bar_T) * noise_end

        # Generate interpolations
        interpolations = []
        for i in range(num_interpolations):
            alpha = i / (num_interpolations - 1)

            # Spherical interpolation in latent space
            xT_interp = self._slerp(xT_start, xT_end, alpha)

            # Denoise
            sample = self._denoise_from(xT_interp, timesteps)
            interpolations.append(sample)

        return interpolations

    def _slerp(
        self,
        x0: torch.Tensor,
        x1: torch.Tensor,
        alpha: float,
    ) -> torch.Tensor:
        """Spherical linear interpolation."""
        # Flatten for computation
        x0_flat = x0.view(x0.shape[0], -1)
        x1_flat = x1.view(x1.shape[0], -1)

        # Normalize
        x0_norm = x0_flat / x0_flat.norm(dim=-1, keepdim=True)
        x1_norm = x1_flat / x1_flat.norm(dim=-1, keepdim=True)

        # Compute angle
        dot = (x0_norm * x1_norm).sum(dim=-1, keepdim=True)
        omega = torch.acos(dot.clamp(-1, 1))

        # Interpolate
        sin_omega = torch.sin(omega)
        if sin_omega.abs().min() < 1e-6:
            # Fall back to linear interpolation for (near-)parallel latents
            result = (1 - alpha) * x0_flat + alpha * x1_flat
        else:
            result = (
                torch.sin((1 - alpha) * omega) / sin_omega * x0_flat +
                torch.sin(alpha * omega) / sin_omega * x1_flat
            )

        return result.view(x0.shape)

    @torch.no_grad()
    def _denoise_from(
        self,
        xT: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        """Denoise deterministically (eta = 0) from a given latent."""
        x = xT.clone()
        batch_size = x.shape[0]

        for i, t in enumerate(timesteps):
            t_batch = torch.full((batch_size,), t, device=self.device, dtype=torch.long)

            alpha_bar = self.alphas_cumprod[t]
            alpha_bar_prev = (
                self.alphas_cumprod[timesteps[i + 1]]
                if i < len(timesteps) - 1
                else torch.tensor(1.0, device=self.device)
            )

            pred_noise = self.model(x, t_batch)
            pred_x0 = (x - torch.sqrt(1 - alpha_bar) * pred_noise) / torch.sqrt(alpha_bar)
            pred_x0 = pred_x0.clamp(-1, 1)

            dir_xt = torch.sqrt(1 - alpha_bar_prev) * pred_noise
            x = torch.sqrt(alpha_bar_prev) * pred_x0 + dir_xt

        return x.clamp(-1, 1)
```
| Eta Value | Behavior | Use Case |
|---|---|---|
| 0.0 | Fully deterministic | Reproducible results, interpolation |
| 0.5 | Mild stochasticity | Balance of diversity and quality |
| 1.0 | DDPM-equivalent noise | Maximum diversity |

Sampling Strategies

Timestep Scheduling

The choice of timestep schedule significantly affects sample quality, especially when using fewer steps:

```python
import torch

class TimestepScheduler:
    """Various timestep scheduling strategies for sampling."""

    @staticmethod
    def uniform(total_timesteps: int, num_steps: int) -> torch.Tensor:
        """Uniform spacing - simple but not always optimal."""
        return torch.linspace(total_timesteps - 1, 0, num_steps).long()

    @staticmethod
    def quadratic(total_timesteps: int, num_steps: int) -> torch.Tensor:
        """Quadratic spacing - more steps at high noise levels."""
        t = torch.linspace(0, 1, num_steps)
        timesteps = (t ** 2) * (total_timesteps - 1)
        return timesteps.flip(0).long()

    @staticmethod
    def trailing(total_timesteps: int, num_steps: int) -> torch.Tensor:
        """Trailing - align with the original timesteps, starting from 0."""
        step_ratio = total_timesteps // num_steps
        timesteps = torch.arange(0, total_timesteps, step_ratio)[:num_steps]
        return timesteps.flip(0).long()

    @staticmethod
    def leading(total_timesteps: int, num_steps: int) -> torch.Tensor:
        """Leading - offset so the schedule includes the max timestep."""
        step_ratio = total_timesteps // num_steps
        timesteps = torch.arange(step_ratio - 1, total_timesteps, step_ratio)[:num_steps]
        return timesteps.flip(0).long()

    @staticmethod
    def karras(
        total_timesteps: int,
        num_steps: int,
        sigma_min: float = 0.002,
        sigma_max: float = 80.0,
        rho: float = 7.0,
    ) -> torch.Tensor:
        """Karras et al. sigma schedule - often better at low step counts."""
        ramp = torch.linspace(0, 1, num_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

        # Map sigmas back to discrete timesteps (simple linear mapping)
        timesteps = ((sigmas / sigma_max) * (total_timesteps - 1)).long()
        return timesteps


def compare_schedules():
    """Compare different schedules visually."""
    import matplotlib.pyplot as plt

    total_t = 1000
    num_steps = 50

    schedules = {
        "Uniform": TimestepScheduler.uniform(total_t, num_steps),
        "Quadratic": TimestepScheduler.quadratic(total_t, num_steps),
        "Trailing": TimestepScheduler.trailing(total_t, num_steps),
        "Karras": TimestepScheduler.karras(total_t, num_steps),
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    for name, timesteps in schedules.items():
        ax.plot(range(num_steps), timesteps.numpy(), label=name, marker=".")

    ax.set_xlabel("Step")
    ax.set_ylabel("Timestep")
    ax.set_title("Timestep Schedule Comparison")
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.show()
```

Noise Level Strategies

```python
import torch
from typing import Optional

class NoiseStrategies:
    """Different strategies for initial noise in sampling."""

    @staticmethod
    def standard_normal(
        shape: tuple,
        device: str = "cuda",
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        """Standard Gaussian noise."""
        if seed is not None:
            torch.manual_seed(seed)
        return torch.randn(shape, device=device)

    @staticmethod
    def truncated_normal(
        shape: tuple,
        truncation: float = 2.0,
        device: str = "cuda",
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        """Truncated Gaussian for more consistent results."""
        if seed is not None:
            torch.manual_seed(seed)

        noise = torch.randn(shape, device=device)
        # Clamp to the truncation range
        noise = noise.clamp(-truncation, truncation)
        return noise

    @staticmethod
    def pyramid_noise(
        shape: tuple,
        discount: float = 0.8,
        device: str = "cuda",
    ) -> torch.Tensor:
        """Multi-scale pyramid noise for better large-scale structure."""
        batch, channels, height, width = shape

        # Accumulate noise across scales, from full resolution downward
        noise = torch.zeros(shape, device=device)
        scale = 1.0

        h, w = height, width
        while h >= 4 and w >= 4:
            # Generate noise at this scale
            scale_noise = torch.randn(batch, channels, h, w, device=device)

            # Upscale to full resolution
            if h != height or w != width:
                scale_noise = torch.nn.functional.interpolate(
                    scale_noise,
                    size=(height, width),
                    mode="bilinear",
                    align_corners=False,
                )

            noise = noise + scale * scale_noise
            scale *= discount
            h, w = h // 2, w // 2

        # Normalize back to unit variance
        noise = noise / noise.std()
        return noise

    @staticmethod
    def latent_blend(
        shape: tuple,
        reference: torch.Tensor,
        blend_strength: float = 0.3,
        device: str = "cuda",
    ) -> torch.Tensor:
        """Blend random noise with a reference latent for style transfer."""
        random_noise = torch.randn(shape, device=device)
        # reference should be an encoded image latent
        blended = (1 - blend_strength) * random_noise + blend_strength * reference
        return blended
```

Batch Generation

```python
import torch
import torch.nn as nn
from pathlib import Path
from typing import Optional, Dict, Any
from tqdm import tqdm
import torchvision.utils as vutils
import json

class ImageGenerator:
    """Production-ready image generation with batching and saving."""

    def __init__(
        self,
        model: nn.Module,
        sampler,  # DDPMSampler or DDIMSampler
        output_dir: str,
        device: str = "cuda",
    ):
        self.model = model.to(device).eval()
        self.sampler = sampler
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.device = device

    def generate_batch(
        self,
        num_images: int,
        batch_size: int = 32,
        image_size: int = 64,
        channels: int = 3,
        num_steps: int = 50,
        save_individual: bool = True,
        save_grid: bool = True,
        grid_nrow: int = 8,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Generate a batch of images with progress tracking.

        Args:
            num_images: Total number of images to generate
            batch_size: Batch size for the GPU
            image_size: Output image size
            channels: Number of image channels
            num_steps: Sampling steps
            save_individual: Save each image separately
            save_grid: Save images as a grid
            grid_nrow: Number of images per row in the grid
            metadata: Optional metadata to save

        Returns:
            All generated images as a tensor
        """
        all_samples = []
        num_batches = (num_images + batch_size - 1) // batch_size

        print(f"Generating {num_images} images in {num_batches} batches...")

        for batch_idx in tqdm(range(num_batches), desc="Generating"):
            # The last batch may be smaller
            current_batch_size = min(
                batch_size,
                num_images - batch_idx * batch_size
            )

            shape = (current_batch_size, channels, image_size, image_size)

            # Generate samples and move them off the GPU immediately
            samples = self.sampler.sample(shape, num_steps=num_steps)
            all_samples.append(samples.cpu())

            # Save individual images
            if save_individual:
                for i, sample in enumerate(samples):
                    img_idx = batch_idx * batch_size + i
                    self._save_image(
                        sample,
                        self.output_dir / f"sample_{img_idx:05d}.png"
                    )

        # Concatenate all samples
        all_samples = torch.cat(all_samples, dim=0)

        # Save grid
        if save_grid:
            grid = vutils.make_grid(
                all_samples,
                nrow=grid_nrow,
                normalize=True,
                value_range=(-1, 1),
            )
            vutils.save_image(
                grid,
                self.output_dir / "samples_grid.png"
            )

        # Save metadata
        if metadata is not None:
            metadata.update({
                "num_images": num_images,
                "image_size": image_size,
                "num_steps": num_steps,
            })
            with open(self.output_dir / "metadata.json", "w") as f:
                json.dump(metadata, f, indent=2)

        print(f"Saved {num_images} images to {self.output_dir}")
        return all_samples

    def generate_with_seeds(
        self,
        seeds: list,
        image_size: int = 64,
        channels: int = 3,
        num_steps: int = 50,
    ) -> torch.Tensor:
        """Generate images with specific seeds for reproducibility."""
        samples = []

        for seed in tqdm(seeds, desc="Generating from seeds"):
            # Seeding before sampling makes the initial noise (and any
            # stochastic updates) reproducible for this image
            torch.manual_seed(seed)
            shape = (1, channels, image_size, image_size)

            sample = self.sampler.sample(shape, num_steps=num_steps)
            samples.append(sample)

            # Save with the seed in the filename
            self._save_image(
                sample[0],
                self.output_dir / f"seed_{seed:08d}.png"
            )

        return torch.cat(samples, dim=0)

    def generate_variations(
        self,
        reference: torch.Tensor,
        num_variations: int = 8,
        noise_strength: float = 0.3,
        num_steps: int = 50,
    ) -> torch.Tensor:
        """Generate variations of a reference image."""
        samples = [reference]

        for i in range(num_variations):
            # Blend the reference with noise; in a full implementation this
            # noisy latent would seed the sampler via proper DDIM inversion
            noise = torch.randn_like(reference)
            strength = noise_strength * (i + 1) / num_variations
            noisy = reference * (1 - strength) + noise * strength

            # Simplified: draw a fresh sample as a stand-in starting point
            sample = self.sampler.sample(reference.shape, num_steps=num_steps)
            samples.append(sample)

        return torch.cat(samples, dim=0)

    def _save_image(self, tensor: torch.Tensor, path: Path):
        """Save a single image tensor."""
        # Convert from [-1, 1] to [0, 1]
        tensor = (tensor + 1) / 2
        tensor = tensor.clamp(0, 1)
        vutils.save_image(tensor, str(path))


# Usage example (assumes a checkpoint loader and the training noise schedule)
def generate_samples():
    """Example generation script."""
    # Load model
    model = load_trained_model("checkpoints/best_model.pt")

    # Create sampler (DDIM for speed)
    sampler = DDIMSampler(model, alphas_cumprod, device="cuda")

    # Create generator
    generator = ImageGenerator(
        model=model,
        sampler=sampler,
        output_dir="./generated_samples",
    )

    # Generate 1000 images
    samples = generator.generate_batch(
        num_images=1000,
        batch_size=64,
        image_size=64,
        num_steps=50,
        save_individual=True,
        save_grid=True,
        metadata={"model": "ddpm_cifar10", "steps": 50, "sampler": "ddim"},
    )

    # Generate from specific seeds
    seed_samples = generator.generate_with_seeds(
        seeds=[42, 123, 456, 789, 1000],
        num_steps=50,
    )

    return samples
```

Memory Management

When generating many images, process in batches to avoid OOM errors. For 1000+ images, use batch_size=32-64 and save incrementally rather than keeping all images in memory.
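A minimal sketch of that incremental pattern, assuming a `sample_fn` stand-in for `sampler.sample` (`torch.randn` is used below only so the sketch runs standalone) and a hypothetical `save_batch` step where the comment indicates:

```python
import torch

def generate_incrementally(sample_fn, num_images: int, batch_size: int = 32,
                           shape=(3, 64, 64)):
    """Yield batches one at a time so only one batch lives on the GPU."""
    for start in range(0, num_images, batch_size):
        n = min(batch_size, num_images - start)
        batch = sample_fn((n, *shape))   # generate this batch
        yield start, batch.cpu()         # move off the GPU before the next batch

# Usage: save each batch as it arrives instead of accumulating everything
total = 0
for start, batch in generate_incrementally(torch.randn, 100, batch_size=32):
    # save_batch(batch, start)  # hypothetical: write PNGs / .pt shards here
    total += batch.shape[0]
print(total)  # 100
```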

Key Takeaways

  1. DDPM provides best quality: Use full 1000 steps for final production samples when quality matters most.
  2. DDIM enables fast iteration: Use 50-100 steps during development and for applications requiring speed.
  3. Eta controls diversity: Set eta=0 for deterministic results (interpolation, reproducibility) or eta=0.5-1.0 for variety.
  4. Timestep schedule matters: Karras schedule often produces better results than uniform spacing with fewer steps.
  5. Batch for efficiency: Process multiple images in parallel to maximize GPU utilization.
Looking Ahead: Now that we can generate images, we need to evaluate their quality quantitatively. The next section covers FID, IS, and other metrics for assessing generative model performance.