Chapter 7: Improved Sampling Methods

Problems with Ancestral Sampling

Learning Objectives

By the end of this section, you will:

  1. Understand the computational bottleneck of standard DDPM sampling
  2. Identify the stochasticity problem and why it prevents reproducibility
  3. Recognize the interpolation limitation in latent space navigation
  4. Quantify the quality-speed trade-off with empirical analysis
  5. Motivate the need for improved sampling methods like DDIM

Setting Up the Problem

Before we introduce solutions, we need to deeply understand what's wrong with the current approach. This section diagnoses the fundamental limitations of ancestral sampling, which will motivate the elegant solutions we cover in subsequent sections.

Ancestral Sampling Recap

In Chapter 6, we implemented the standard DDPM sampling algorithm, also known as ancestral sampling. Let's briefly recap how it works:

The Forward Process

The forward process gradually adds noise to data according to a predefined schedule:

$$q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\left(\mathbf{x}_t; \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0, (1 - \bar{\alpha}_t)\mathbf{I}\right)$$

This can be reparameterized as:

$$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\, \boldsymbol{\epsilon}, \quad \boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$
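This reparameterization is easy to verify in code. Below is a minimal sketch using a toy linear beta schedule; the schedule values are illustrative, not tied to any particular trained model:

```python
import torch

# Toy linear beta schedule (illustrative values)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)  # alpha_bar_t

def q_sample(x0, t, noise):
    """Sample x_t ~ q(x_t | x_0) via the reparameterization."""
    sqrt_ab = alphas_cumprod[t].sqrt()
    sqrt_one_minus_ab = (1.0 - alphas_cumprod[t]).sqrt()
    return sqrt_ab * x0 + sqrt_one_minus_ab * noise

x0 = torch.randn(1, 3, 8, 8)
noise = torch.randn_like(x0)
x_t = q_sample(x0, torch.tensor(999), noise)
# At t = T-1, alpha_bar_t is tiny, so x_t is almost pure noise
```

At the last timestep the data term is scaled by roughly $\sqrt{\bar{\alpha}_T} \approx 0$, which is exactly why sampling can start from pure Gaussian noise.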

The Reverse Process (Ancestral Sampling)

Ancestral sampling reverses this process one step at a time:

$$p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t) = \mathcal{N}\left(\mathbf{x}_{t-1}; \boldsymbol{\mu}_\theta(\mathbf{x}_t, t), \sigma_t^2 \mathbf{I}\right)$$

Where the mean is computed using the noise prediction:

$$\boldsymbol{\mu}_\theta(\mathbf{x}_t, t) = \frac{1}{\sqrt{\alpha_t}} \left( \mathbf{x}_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) \right)$$
```python
import math

import torch


# Standard DDPM ancestral sampling
def ancestral_sample(model, noise_schedule, shape, device="cuda"):
    """
    Generate samples using ancestral (DDPM) sampling.

    Args:
        model: Trained noise prediction network
        noise_schedule: Object containing alpha, alpha_bar values
        shape: Output shape (batch, channels, height, width)
        device: Computation device

    Returns:
        Generated samples in [-1, 1]
    """
    # Start from pure noise
    x_t = torch.randn(shape, device=device)

    # Iterate through all timesteps (T-1, T-2, ..., 0)
    for t in reversed(range(noise_schedule.T)):
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)

        # Predict noise
        eps_pred = model(x_t, t_batch)

        # Get schedule values
        alpha_t = noise_schedule.alphas[t]
        alpha_bar_t = noise_schedule.alphas_cumprod[t]
        beta_t = noise_schedule.betas[t]

        # Compute mean
        coef1 = 1 / math.sqrt(alpha_t)
        coef2 = beta_t / math.sqrt(1 - alpha_bar_t)
        mean = coef1 * (x_t - coef2 * eps_pred)

        # Add noise (except at t=0)
        if t > 0:
            noise = torch.randn_like(x_t)
            sigma_t = math.sqrt(beta_t)
            x_t = mean + sigma_t * noise
        else:
            x_t = mean

    return x_t
```

Key Observation

The algorithm requires $T$ neural network forward passes (typically $T = 1000$). Each step adds fresh Gaussian noise, making the process inherently stochastic.

The Computational Cost Problem

The most glaring issue with ancestral sampling is its computational cost. With $T = 1000$ timesteps, generating a single image requires 1000 forward passes through the neural network.

Quantifying the Cost

| Metric | DDPM (T=1000) | GAN | VAE | Ratio |
|---|---|---|---|---|
| Neural network passes | 1000 | 1 | 1 | 1000x slower |
| Time per image (V100) | ~20 s | ~0.02 s | ~0.01 s | 1000-2000x |
| Time for 50K images | ~12 days | ~17 min | ~8 min | ~1000x |
| FLOPs per image | ~10^15 | ~10^12 | ~10^11 | 1000-10000x |

Practical Impact

At inference time, this makes diffusion models impractical for real-time applications and expensive for large-scale generation. A single high-resolution image can take minutes on consumer hardware.
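To get a rough sense of this cost on your own hardware, one approach is to time a single forward pass and multiply by $T$. The sketch below uses a stand-in convolution as a hypothetical model, so the absolute numbers are meaningless; only the method carries over to a real network:

```python
import time

import torch
import torch.nn as nn

# Stand-in for a trained noise-prediction network (hypothetical tiny model)
model = nn.Conv2d(3, 3, kernel_size=3, padding=1)
x = torch.randn(1, 3, 64, 64)

with torch.no_grad():
    model(x)  # warm-up pass
    start = time.perf_counter()
    n_runs = 20
    for _ in range(n_runs):
        model(x)
    per_pass = (time.perf_counter() - start) / n_runs

# Sequential sampling cost scales linearly with T
T = 1000
print(f"~{per_pass * T:.2f}s per image at T={T} sequential passes")
```

Because the steps are sequential, this cost cannot be hidden by batching across timesteps; only batching across images helps.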

Why Can't We Just Use Fewer Steps?

A natural question: why not simply use fewer timesteps? The problem is that ancestral sampling degrades catastrophically when we skip steps:

```python
def ancestral_sample_with_stride(model, noise_schedule, shape, stride=1, device="cuda"):
    """
    Attempt to speed up by skipping timesteps.

    WARNING: This produces poor results!
    """
    x_t = torch.randn(shape, device=device)

    # Use strided timesteps
    timesteps = list(range(noise_schedule.T - 1, -1, -stride))

    for t in timesteps:
        t_batch = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps_pred = model(x_t, t_batch)

        # Problem: alpha values don't account for skipped steps!
        alpha_t = noise_schedule.alphas[t]
        alpha_bar_t = noise_schedule.alphas_cumprod[t]

        # The reverse process derivation assumes single steps;
        # skipping breaks the mathematical assumptions
        # ...

    return x_t  # Poor quality!
```

The issue is fundamental: the reverse-process transition probability $p_\theta(\mathbf{x}_{t-1} \mid \mathbf{x}_t)$ was derived assuming single-step transitions. When we skip from $t$ to $t - k$, the mathematical derivation no longer holds.
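One way to see the mismatch numerically: the correct "alpha" for a $k$-step jump is the ratio of cumulative products $\bar{\alpha}_t / \bar{\alpha}_{t-k}$, not the single-step $\alpha_t$ that naive striding reuses. A sketch with an illustrative linear schedule:

```python
import torch

# Toy linear schedule (illustrative values)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)

t, k = 800, 40  # jump from t to t-k

# What a naive strided step implicitly assumes about the transition:
single_step_alpha = alphas[t]

# What a k-step jump actually requires (ratio of the two marginals):
true_jump_alpha = alphas_cumprod[t] / alphas_cumprod[t - k]

print(single_step_alpha.item(), true_jump_alpha.item())
# single_step_alpha is close to 1; true_jump_alpha is much smaller,
# so reusing single-step coefficients badly underestimates the
# amount of noise that should be removed per jump
```

Any correct accelerated sampler has to recompute its coefficients from these jump-level marginals, which is exactly what DDIM does.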

Sampling Trajectory Visualizer

Compare DDPM vs DDIM sampling paths in a 2D space. Watch how different samplers navigate from noise to data.


Trajectory Comparison (Same Seed)

DDPM
  - Stochastic (adds noise each step)
  - Requires ~1000 steps
  - Different samples from same noise
  - Markov chain property

DDIM (eta = 0)
  - Deterministic sampling
  - Can skip steps (10-50 steps)
  - Same noise = same output
  - Non-Markovian process

DDIM (0 < eta <= 1)
  - Interpolates between both
  - eta = 1 recovers DDPM
  - Tunable diversity-quality trade-off
  - Flexible step skipping
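As a preview of the next section, the DDIM update summarized in the comparison above can be sketched in a few lines. This follows the standard DDIM update rule, with `eta` controlling how much noise is injected (the derivation comes later):

```python
import torch

def ddim_step(x_t, eps_pred, alpha_bar_t, alpha_bar_prev, eta=0.0):
    """One DDIM update from timestep t to a (possibly distant) earlier step."""
    # Predict x_0 from the current noise estimate
    x0_pred = (x_t - (1 - alpha_bar_t).sqrt() * eps_pred) / alpha_bar_t.sqrt()

    # eta = 0: fully deterministic; eta = 1: DDPM-level stochasticity
    sigma = eta * (
        ((1 - alpha_bar_prev) / (1 - alpha_bar_t)).sqrt()
        * (1 - alpha_bar_t / alpha_bar_prev).sqrt()
    )

    # Direction term that re-adds the (scaled) predicted noise
    dir_xt = (1 - alpha_bar_prev - sigma ** 2).sqrt() * eps_pred

    noise = sigma * torch.randn_like(x_t) if eta > 0 else 0.0
    return alpha_bar_prev.sqrt() * x0_pred + dir_xt + noise

# With eta = 0, identical inputs always give identical outputs
x_t = torch.randn(1, 3, 8, 8)
eps = torch.randn_like(x_t)
a_t, a_prev = torch.tensor(0.5), torch.tensor(0.8)
out1 = ddim_step(x_t, eps, a_t, a_prev, eta=0.0)
out2 = ddim_step(x_t, eps, a_t, a_prev, eta=0.0)
# out1 == out2 exactly (deterministic)
```

Note that the step takes `alpha_bar` values for *arbitrary* source and target timesteps, which is what makes step skipping legitimate here, unlike the naive striding above.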

Stochasticity Issues

The second major limitation is that ancestral sampling is inherently stochastic. At each timestep, we add fresh Gaussian noise:

$$\mathbf{x}_{t-1} = \boldsymbol{\mu}_\theta(\mathbf{x}_t, t) + \sigma_t \mathbf{z}, \quad \mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

This stochasticity causes several practical problems:

Problem 1: Non-Reproducibility

Even with the same initial noise $\mathbf{x}_T$, different runs produce different outputs:

```python
def demonstrate_non_reproducibility(model, sigma, T, device="cuda"):
    """Show that ancestral sampling isn't reproducible."""

    # Set seed for initial noise
    torch.manual_seed(42)
    initial_noise = torch.randn(1, 3, 64, 64, device=device)

    # Run sampling twice with the same initial noise
    results = []
    for run in range(2):
        x_t = initial_noise.clone()

        for t in reversed(range(T)):
            t_batch = torch.tensor([t], device=device)
            eps_pred = model(x_t, t_batch)

            # Compute mean (helper defined earlier)
            mean = compute_mean(x_t, eps_pred, t)

            # Add different noise each time!
            if t > 0:
                noise = torch.randn_like(x_t)  # Different every run
                x_t = mean + sigma[t] * noise
            else:
                x_t = mean

        results.append(x_t)

    # These will be DIFFERENT images
    mse = F.mse_loss(results[0], results[1])
    print(f"MSE between runs: {mse.item():.4f}")  # >> 0!

    return results
```

Problem 2: Inability to Encode Images

Because the mapping from noise to image is stochastic, there's no deterministic inverse. We cannot find which noise $\mathbf{x}_T$ corresponds to a given image $\mathbf{x}_0$:

```python
def attempt_inversion(model, image, noise_schedule):
    """
    Try to find the noise that generates this image.

    This is IMPOSSIBLE with ancestral sampling because:
    1. The process is stochastic
    2. Multiple noise vectors can map to similar images
    3. There's no deterministic encoder
    """
    # We can compute x_t from x_0 (forward process)
    t = noise_schedule.T - 1
    x_t = forward_process(image, t, noise_schedule)

    # But running the reverse process won't recover the same image!
    # Fresh noise is added at each step
    reconstructed = ancestral_sample_from(model, x_t, t, noise_schedule)

    # reconstructed != image (with high probability)
    return reconstructed  # Different from original!
```

Problem 3: No Semantic Interpolation

The stochasticity also prevents meaningful interpolation in latent space:

```python
def attempt_interpolation(model, noise_schedule, z1, z2, steps=10):
    """
    Try to interpolate between two latent codes.

    With stochastic sampling, this doesn't produce
    meaningful semantic transitions.
    """
    interpolated_samples = []

    for alpha in np.linspace(0, 1, steps):
        # Interpolate in latent space
        z_interp = (1 - alpha) * z1 + alpha * z2

        # Sample - but the noise added during sampling
        # overwhelms the interpolation signal!
        t_start = noise_schedule.T - 1
        sample = ancestral_sample_from(model, z_interp, t_start, noise_schedule)
        interpolated_samples.append(sample)

    # Result: discontinuous, inconsistent samples,
    # NOT a smooth semantic transition
    return interpolated_samples
```

Root Cause

The root cause of these issues is the accumulation of stochastic noise over 1000 timesteps. Even small per-step variance compounds into large total variance, destroying any structure we try to preserve.
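A back-of-the-envelope sketch of this accumulation, using $\sigma_t^2 = \beta_t$ as in the sampler above and an illustrative linear schedule. This ignores how the denoising mean contracts the state between steps, so it is an intuition for the total injected noise, not an exact variance of the output:

```python
import torch

# Toy linear schedule (illustrative)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)

# With sigma_t^2 = beta_t, the total variance injected over a full
# trajectory is the sum of the per-step variances
total_injected_variance = betas.sum()
print(f"Sum of per-step variances: {total_injected_variance:.2f}")
# Each beta_t is at most 0.02, yet the sum over 1000 steps is ~10
```

Individually negligible noise terms add up to something far larger than the data scale, which is why a small interpolation or encoding signal cannot survive the trajectory.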

The Interpolation Problem

One of the most useful capabilities in generative models is semantic interpolation: smoothly transitioning between two samples while maintaining coherent structure. GANs excel at this because they have deterministic generators.

GAN Interpolation (What We Want)

```python
# GAN interpolation is clean and deterministic
def gan_interpolation(generator, z1, z2, steps=10):
    """
    Smooth interpolation in GAN latent space.
    Each latent vector deterministically maps to an image.
    """
    interpolated = []
    for alpha in np.linspace(0, 1, steps):
        z = (1 - alpha) * z1 + alpha * z2
        image = generator(z)  # Deterministic!
        interpolated.append(image)

    return interpolated  # Smooth semantic transition
```

DDPM Interpolation (What We Get)

```python
# DDPM interpolation is noisy and inconsistent
def ddpm_interpolation_attempt(model, x1_T, x2_T, sigma, T, steps=10):
    """
    Attempt interpolation with ancestral sampling.
    Results are poor due to stochasticity.
    """
    interpolated = []

    for alpha in np.linspace(0, 1, steps):
        # Interpolate initial noise
        x_T = (1 - alpha) * x1_T + alpha * x2_T

        # Run ancestral sampling
        x_t = x_T.clone()
        for t in reversed(range(T)):
            # ... standard ancestral step computing `mean` ...
            if t > 0:
                x_t = mean + sigma[t] * torch.randn_like(x_t)  # Random!
            else:
                x_t = mean

        interpolated.append(x_t)

    # Problem: the random noise accumulated at each step
    # creates discontinuous jumps between frames
    return interpolated  # Jerky, inconsistent
```

The fundamental issue is that interpolating in the initial noise space doesn't translate to interpolation in the output space when the mapping is stochastic.


Quality vs Speed Trade-off

Let's quantify how sample quality degrades when we try to speed up ancestral sampling:

Naive Step Skipping

| Steps | Time (V100) | FID | Quality |
|---|---|---|---|
| 1000 | 20 s | 3.17 | Excellent |
| 500 | 10 s | 4.85 | Good |
| 250 | 5 s | 8.42 | Acceptable |
| 100 | 2 s | 18.73 | Poor |
| 50 | 1 s | 45.62 | Very Poor |
| 25 | 0.5 s | 89.31 | Unusable |

Exponential Degradation

Quality doesn't degrade linearly with fewer steps - it degrades exponentially. Below ~200 steps, ancestral sampling produces noticeably blurry or corrupted images.

Why Does Quality Degrade So Rapidly?

The mathematical reason lies in the accumulated error from skipping intermediate states:

$$\text{Skipping from } t \to t-k: \quad \mathbb{E}\left[\|\mathbf{x}_{t-k}^{\text{true}} - \mathbf{x}_{t-k}^{\text{approx}}\|^2\right] \approx O(k^2)$$

The error grows quadratically with the skip size, because:

  1. The neural network was trained to predict noise at specific noise levels
  2. Skipping creates a mismatch between expected and actual noise levels
  3. Each skipped step compounds the error from previous steps
  4. The Gaussian noise assumptions in the reverse process derivation break down
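This quadratic local error is the same phenomenon seen in first-order ODE solvers. As a toy illustration (on a known exponential decay, not the actual reverse process), a single first-order step of size $k$ incurs an error that roughly quadruples when $k$ doubles:

```python
import math

# Toy illustration: a single first-order (Euler) step of size k on
# x' = -x has local truncation error O(k^2), mirroring the skip-size
# analysis above
def one_step_error(k):
    x0 = 1.0
    exact = math.exp(-k)   # true solution after time k
    approx = x0 - k * x0   # one Euler step of size k
    return abs(exact - approx)

for k in [0.1, 0.2, 0.4]:
    print(k, one_step_error(k))
# Doubling k roughly quadruples the error (quadratic growth)
```

The analogy is only heuristic, but it captures why doubling the skip size costs much more than twice the accuracy.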

Analyzing These Problems

Here's code to empirically measure these limitations:

```python
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from typing import List


class AncestralSamplingAnalyzer:
    """
    Analyze the limitations of ancestral sampling.
    """

    def __init__(self, model, noise_schedule, device="cuda"):
        self.model = model
        self.ns = noise_schedule
        self.device = device
        self.model.eval()

    def measure_reproducibility(
        self,
        initial_noise: torch.Tensor,
        num_runs: int = 10
    ) -> dict:
        """
        Measure variance across runs with the same initial noise.

        Returns statistics on output variance.
        """
        samples = []

        for _ in range(num_runs):
            sample = self._ancestral_sample(initial_noise.clone())
            samples.append(sample)

        samples = torch.stack(samples)  # (num_runs, B, C, H, W)

        # Compute statistics
        mean_sample = samples.mean(dim=0)
        variance = samples.var(dim=0).mean()

        # Pairwise MSE
        pairwise_mse = []
        for i in range(num_runs):
            for j in range(i + 1, num_runs):
                mse = F.mse_loss(samples[i], samples[j])
                pairwise_mse.append(mse.item())

        return {
            "mean_sample": mean_sample,
            "variance": variance.item(),
            "mean_pairwise_mse": np.mean(pairwise_mse),
            "std_pairwise_mse": np.std(pairwise_mse),
        }

    def measure_step_degradation(
        self,
        num_samples: int = 100,
        step_counts: List[int] = [1000, 500, 250, 100, 50, 25]
    ) -> dict:
        """
        Measure how quality degrades with fewer steps.

        Uses MSE to full-step samples as a quality proxy.
        """
        results = {}

        # Generate reference samples with full steps
        torch.manual_seed(42)
        initial_noises = [
            torch.randn(1, 3, 64, 64, device=self.device)
            for _ in range(num_samples)
        ]

        reference_samples = []
        for noise in tqdm(initial_noises, desc="Full-step reference"):
            sample = self._ancestral_sample_strided(noise.clone(), stride=1)
            reference_samples.append(sample)
        reference_samples = torch.cat(reference_samples)

        # Compare with different step counts
        for steps in step_counts:
            stride = self.ns.T // steps

            reduced_samples = []
            for noise in tqdm(initial_noises, desc=f"{steps} steps"):
                sample = self._ancestral_sample_strided(noise.clone(), stride=stride)
                reduced_samples.append(sample)
            reduced_samples = torch.cat(reduced_samples)

            mse = F.mse_loss(reduced_samples, reference_samples)
            results[steps] = {
                "mse_vs_reference": mse.item(),
                "samples": reduced_samples[:5]  # Keep a few for visualization
            }

        return results

    def measure_interpolation_smoothness(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        num_points: int = 10,
        num_runs: int = 5
    ) -> dict:
        """
        Measure how smooth interpolation is.

        Smooth interpolation should have small differences
        between adjacent frames.
        """
        all_runs = []

        for run in range(num_runs):
            interpolated = []

            for alpha in np.linspace(0, 1, num_points):
                z_interp = (1 - alpha) * z1 + alpha * z2
                sample = self._ancestral_sample(z_interp)
                interpolated.append(sample)

            interpolated = torch.stack(interpolated)  # (num_points, B, C, H, W)
            all_runs.append(interpolated)

        all_runs = torch.stack(all_runs)  # (num_runs, num_points, B, C, H, W)

        # Measure smoothness: MSE between adjacent frames
        frame_diffs = []
        for i in range(num_points - 1):
            diff = F.mse_loss(
                all_runs[:, i],
                all_runs[:, i + 1]
            )
            frame_diffs.append(diff.item())

        # Measure consistency: variance across runs for the same alpha
        consistency_scores = []
        for i in range(num_points):
            frames_at_alpha = all_runs[:, i]
            variance = frames_at_alpha.var(dim=0).mean()
            consistency_scores.append(variance.item())

        return {
            "frame_differences": frame_diffs,
            "mean_frame_diff": np.mean(frame_diffs),
            "consistency_per_point": consistency_scores,
            "mean_consistency": np.mean(consistency_scores),
            "example_run": all_runs[0]
        }

    def _ancestral_sample(self, x_t: torch.Tensor) -> torch.Tensor:
        """Standard ancestral sampling."""
        with torch.no_grad():
            for t in reversed(range(self.ns.T)):
                t_batch = torch.full(
                    (x_t.shape[0],), t,
                    device=self.device, dtype=torch.long
                )
                eps_pred = self.model(x_t, t_batch)

                alpha_t = self.ns.alphas[t]
                alpha_bar_t = self.ns.alphas_cumprod[t]
                beta_t = self.ns.betas[t]

                coef1 = 1 / (alpha_t ** 0.5)
                coef2 = beta_t / ((1 - alpha_bar_t) ** 0.5)
                mean = coef1 * (x_t - coef2 * eps_pred)

                if t > 0:
                    sigma_t = beta_t ** 0.5
                    x_t = mean + sigma_t * torch.randn_like(x_t)
                else:
                    x_t = mean

        return x_t

    def _ancestral_sample_strided(
        self,
        x_t: torch.Tensor,
        stride: int
    ) -> torch.Tensor:
        """Ancestral sampling with step skipping."""
        timesteps = list(range(self.ns.T - 1, -1, -stride))

        with torch.no_grad():
            for t in timesteps:
                t_batch = torch.full(
                    (x_t.shape[0],), t,
                    device=self.device, dtype=torch.long
                )
                eps_pred = self.model(x_t, t_batch)

                alpha_t = self.ns.alphas[t]
                alpha_bar_t = self.ns.alphas_cumprod[t]
                beta_t = self.ns.betas[t]

                coef1 = 1 / (alpha_t ** 0.5)
                coef2 = beta_t / ((1 - alpha_bar_t) ** 0.5)
                mean = coef1 * (x_t - coef2 * eps_pred)

                if t > 0:
                    sigma_t = beta_t ** 0.5
                    x_t = mean + sigma_t * torch.randn_like(x_t)
                else:
                    x_t = mean

        return x_t


# Run the analysis
def analyze_ancestral_limitations(model, noise_schedule, device="cuda"):
    """
    Comprehensive analysis of ancestral sampling limitations.
    """
    analyzer = AncestralSamplingAnalyzer(model, noise_schedule, device)

    # 1. Reproducibility analysis
    print("=" * 50)
    print("1. REPRODUCIBILITY ANALYSIS")
    print("=" * 50)

    initial_noise = torch.randn(1, 3, 64, 64, device=device)
    repro_results = analyzer.measure_reproducibility(initial_noise, num_runs=10)

    print(f"Output variance: {repro_results['variance']:.4f}")
    print(f"Mean pairwise MSE: {repro_results['mean_pairwise_mse']:.4f}")
    print(f"Std pairwise MSE: {repro_results['std_pairwise_mse']:.4f}")
    print("(Non-zero values indicate stochastic outputs)")

    # 2. Step degradation analysis
    print("\n" + "=" * 50)
    print("2. STEP DEGRADATION ANALYSIS")
    print("=" * 50)

    degrad_results = analyzer.measure_step_degradation(
        num_samples=20,
        step_counts=[1000, 500, 250, 100, 50]
    )

    for steps, data in degrad_results.items():
        print(f"Steps: {steps:4d} | MSE vs reference: {data['mse_vs_reference']:.4f}")

    # 3. Interpolation analysis
    print("\n" + "=" * 50)
    print("3. INTERPOLATION SMOOTHNESS ANALYSIS")
    print("=" * 50)

    z1 = torch.randn(1, 3, 64, 64, device=device)
    z2 = torch.randn(1, 3, 64, 64, device=device)
    interp_results = analyzer.measure_interpolation_smoothness(z1, z2)

    print(f"Mean frame difference: {interp_results['mean_frame_diff']:.4f}")
    print(f"Mean consistency (variance): {interp_results['mean_consistency']:.4f}")
    print("(High values indicate poor interpolation)")

    return {
        "reproducibility": repro_results,
        "degradation": degrad_results,
        "interpolation": interp_results,
    }
```

Running the Analysis

This analysis code helps quantify exactly how severe each problem is for your specific model. Run it after training to understand the baseline before comparing with DDIM and other improved samplers.

Summary

We've identified three fundamental limitations of ancestral sampling:

| Problem | Cause | Impact | Can We Fix It? |
|---|---|---|---|
| Slow (1000 steps) | Sequential denoising | Impractical for real-time | Yes - DDIM, DPM-Solver |
| Stochastic | Noise at each step | No reproducibility/encoding | Yes - DDIM (eta=0) |
| Poor interpolation | Accumulated noise | No semantic control | Yes - deterministic sampling |

The key takeaways:

  1. Computational cost is the most obvious limitation - 1000 network evaluations per sample is far too slow for practical applications
  2. Stochasticity prevents reproducibility, image encoding, and meaningful latent space operations
  3. Quality degrades quadratically when we naively skip steps, making simple acceleration approaches ineffective
  4. These problems are interconnected - solving the stochasticity problem also helps with the others

The DDIM Solution

In the next section, we'll introduce DDIM (Denoising Diffusion Implicit Models), which elegantly solves all three problems by reformulating the reverse process as a non-Markovian, deterministic transformation. DDIM uses the same trained model but achieves 10-50x speedup with equal or better quality.

Understanding these limitations deeply is crucial because DDIM's design choices directly address each one. The deterministic formulation enables reproducibility and encoding; the non-Markovian structure allows arbitrary step counts without quality degradation.