Chapter 2

Closed-Form Sampling at Any Timestep

The Forward Diffusion Process

Learning Objectives

By the end of this section, you will be able to:

  1. Derive the closed-form expression $q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0, (1-\bar{\alpha}_t)\mathbf{I})$
  2. Explain why this enables efficient training by allowing direct sampling at any timestep
  3. Apply the reparameterization $\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}$
  4. Implement efficient forward sampling in PyTorch

The Key Insight

In Section 2.1, we defined the forward process as a sequence of $T$ steps, each adding a small amount of noise. Naively, to get $\mathbf{x}_t$, we would need to iterate through all $t$ steps. For $T = 1000$, this would be prohibitively slow during training.

The Key Question: Can we compute $\mathbf{x}_t$ directly from $\mathbf{x}_0$ in a single step?

Yes! Because each step adds Gaussian noise and scales by a linear factor, the entire chain collapses to a single Gaussian. This is one of the most important properties of diffusion models.
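This collapse is easy to check numerically. The sketch below (assuming a linear beta schedule, which is one common choice) multiplies the per-step signal scalings one at a time and compares the result against a single cumulative product:

```python
import torch

T = 1000
# Assumed linear beta schedule for illustration (float64 to avoid rounding drift)
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alphas = 1.0 - betas

# Iterative route: multiply the per-step signal scaling sqrt(alpha_s) for s = 1..t
t = 500
coef_iter = torch.tensor(1.0, dtype=torch.float64)
for s in range(t):
    coef_iter = coef_iter * torch.sqrt(alphas[s])

# Closed-form route: one cumulative product, then one square root
coef_closed = torch.sqrt(torch.cumprod(alphas, dim=0)[t - 1])

print(coef_iter.item(), coef_closed.item())  # identical up to rounding
```

The two routes give the same coefficient; the closed form just computes it in one shot.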


The Derivation

Let's derive the closed-form expression step by step. Recall the single-step transition:

$$\mathbf{x}_t = \sqrt{1-\beta_t}\,\mathbf{x}_{t-1} + \sqrt{\beta_t}\,\boldsymbol{\epsilon}_{t-1}$$

where $\boldsymbol{\epsilon}_{t-1} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$. Let's define $\alpha_t = 1 - \beta_t$ for convenience:

$$\mathbf{x}_t = \sqrt{\alpha_t}\,\mathbf{x}_{t-1} + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_{t-1}$$

Expanding the Recursion

Now substitute $\mathbf{x}_{t-1}$ in terms of $\mathbf{x}_{t-2}$:

$$\mathbf{x}_t = \sqrt{\alpha_t}\left(\sqrt{\alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{1-\alpha_{t-1}}\,\boldsymbol{\epsilon}_{t-2}\right) + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_{t-1}$$

$$= \sqrt{\alpha_t \alpha_{t-1}}\,\mathbf{x}_{t-2} + \sqrt{\alpha_t(1-\alpha_{t-1})}\,\boldsymbol{\epsilon}_{t-2} + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_{t-1}$$

Combining Gaussian Noise Terms

The sum of two independent Gaussians $a\boldsymbol{\epsilon}_1 + b\boldsymbol{\epsilon}_2$, where $\boldsymbol{\epsilon}_1, \boldsymbol{\epsilon}_2 \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, is equivalent in distribution to a single Gaussian $\sqrt{a^2 + b^2}\,\boldsymbol{\epsilon}$:

$$\sqrt{\alpha_t(1-\alpha_{t-1})}\,\boldsymbol{\epsilon}_{t-2} + \sqrt{1-\alpha_t}\,\boldsymbol{\epsilon}_{t-1} \sim \mathcal{N}\!\left(\mathbf{0}, \left[\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t)\right]\mathbf{I}\right)$$

The variance simplifies beautifully:

$$\alpha_t(1-\alpha_{t-1}) + (1-\alpha_t) = \alpha_t - \alpha_t\alpha_{t-1} + 1 - \alpha_t = 1 - \alpha_t\alpha_{t-1}$$
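This merge is easy to verify empirically. The sketch below uses assumed illustrative values for $\alpha_t$ and $\alpha_{t-1}$ and checks that the combined noise has the predicted variance:

```python
import torch

torch.manual_seed(0)
alpha_t, alpha_tm1 = 0.99, 0.98  # assumed illustrative values

eps1 = torch.randn(1_000_000)
eps2 = torch.randn(1_000_000)

# Merge the two independent noise terms from the expansion
combined = (alpha_t * (1 - alpha_tm1)) ** 0.5 * eps1 + (1 - alpha_t) ** 0.5 * eps2

# The merged variance should be 1 - alpha_t * alpha_{t-1}
predicted = 1 - alpha_t * alpha_tm1
print(combined.var().item(), predicted)  # both close to 0.0298
```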

The General Pattern

Continuing this expansion down to $\mathbf{x}_0$, repeatedly merging the noise terms with the same Gaussian sum property, we find:

$$\mathbf{x}_t = \sqrt{\alpha_t \alpha_{t-1} \cdots \alpha_1}\,\mathbf{x}_0 + \sqrt{1 - \alpha_t \alpha_{t-1} \cdots \alpha_1}\,\boldsymbol{\epsilon}$$

Defining $\bar{\alpha}_t = \prod_{s=1}^t \alpha_s$, we get the closed-form result:

$$\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}$$
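As a sanity check, the step-by-step recursion (with fresh noise at every step) should match the closed form in distribution. A sketch, assuming a linear beta schedule and a scalar "image" $x_0 = 1$:

```python
import torch

torch.manual_seed(0)
T, t = 1000, 300
betas = torch.linspace(1e-4, 0.02, T)  # assumed linear schedule
alphas = 1.0 - betas
alpha_bar_t = torch.cumprod(alphas, dim=0)[t - 1]

# Run the step-by-step recursion on many copies of the scalar x_0 = 1
x = torch.ones(100_000)
for s in range(t):
    x = torch.sqrt(alphas[s]) * x + torch.sqrt(betas[s]) * torch.randn_like(x)

# The closed form predicts mean sqrt(alpha_bar_t) * x_0 and variance 1 - alpha_bar_t
print(x.mean().item(), torch.sqrt(alpha_bar_t).item())
print(x.var().item(), (1 - alpha_bar_t).item())
```

The empirical mean and variance of the iterated samples agree with the closed-form prediction up to Monte Carlo error.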

The Distribution Form

This reparameterization tells us the distribution:

$$q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\!\left(\mathbf{x}_t;\ \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\ (1-\bar{\alpha}_t)\mathbf{I}\right)$$

The mean is $\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0$ and the variance is $(1-\bar{\alpha}_t)\mathbf{I}$.

The Reparameterization Trick

The reparameterization $\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}$ has a beautiful interpretation:

Term           | Expression                   | Meaning
Signal term    | √α̅_t · x₀                   | Scaled-down original signal
Noise term     | √(1−α̅_t) · ε                | Scaled Gaussian noise
Sum of squares | (√α̅_t)² + (√(1−α̅_t))² = 1  | Variance preserved

Think of it as a smooth interpolation between the clean data and pure noise:

  • At $t = 0$: $\bar{\alpha}_0 = 1$, so $\mathbf{x}_0 = 1 \cdot \mathbf{x}_0 + 0 \cdot \boldsymbol{\epsilon}$ (pure signal)
  • At $t = T$: $\bar{\alpha}_T \approx 0$, so $\mathbf{x}_T \approx 0 \cdot \mathbf{x}_0 + 1 \cdot \boldsymbol{\epsilon}$ (pure noise)
  • At intermediate $t$: a blend of both
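These endpoint values can be checked numerically. A small sketch, assuming a linear beta schedule (exact values depend on the schedule):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)  # assumed linear schedule
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

# Signal and noise coefficients at the start, middle, and end of the chain
for t in [0, T // 2, T - 1]:
    signal = torch.sqrt(alpha_bars[t]).item()
    noise = torch.sqrt(1 - alpha_bars[t]).item()
    print(f"t={t:3d}: signal coef {signal:.4f}, noise coef {noise:.4f}")
```

At $t = 0$ the signal coefficient is essentially 1; by $t = T-1$ it has collapsed toward 0 and the noise coefficient dominates.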

Visualizing the Decay

The visualization below shows how $\bar{\alpha}_t$ and its square-root coefficients evolve over timesteps; adjusting the noise schedule parameters changes the signal/noise balance.

[Figure: Understanding α̅_t decay — cumulative signal/noise coefficients over timesteps: $\bar{\alpha}_t$ (cumulative), $\sqrt{\bar{\alpha}_t}$ (signal coefficient), and $\sqrt{1-\bar{\alpha}_t}$ (noise coefficient), plotted for $t = 0$ to $1000$.]

For example, at $t = 500$ under the linear schedule ($\beta_{500} \approx 0.0100$), $\bar{\alpha}_{500} \approx 0.0780$, so the reparameterization formula gives

$$\mathbf{x}_{500} \approx 0.2793\,\mathbf{x}_0 + 0.9602\,\boldsymbol{\epsilon}$$

Note that the two coefficients do not sum to 1; their squares do ($0.2793^2 + 0.9602^2 = 1$), which is exactly the variance-preserving property from the table above.

Key Insight: $\bar{\alpha}_t = \prod_{s=1}^{t}(1-\beta_s)$ is the cumulative product of all alpha values. As $t$ increases, this product shrinks roughly exponentially, so the signal coefficient $\sqrt{\bar{\alpha}_t}$ decreases while the noise coefficient $\sqrt{1-\bar{\alpha}_t}$ increases. At $t = T$, $\mathbf{x}_T$ is nearly pure Gaussian noise.


Why This Matters for Training

The closed-form formula is essential for training diffusion models efficiently:

Computational Efficiency

Without the closed form, generating one training sample would require:

  • $T$ forward passes through the noise addition
  • $T$ random number generations
  • $O(T \cdot d)$ operations, where $d$ is the data dimension

With the closed form, we need only:

  • 1 lookup for $\sqrt{\bar{\alpha}_t}$
  • 1 random number generation
  • $O(d)$ operations

This is a $T\times$ speedup! For $T = 1000$, this makes training practical.
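A rough micro-benchmark illustrates the gap (a sketch only; timings are machine-dependent and the batch shape is an assumption):

```python
import time
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)  # assumed linear schedule
alphas = 1.0 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

x0 = torch.randn(16, 3, 32, 32)  # assumed batch of images

# Naive route: iterate all T noise-addition steps
start = time.perf_counter()
x = x0.clone()
for s in range(T):
    x = torch.sqrt(alphas[s]) * x + torch.sqrt(betas[s]) * torch.randn_like(x)
naive = time.perf_counter() - start

# Closed-form route: one lookup, one noise draw, one fused expression
start = time.perf_counter()
t = T - 1
x_t = torch.sqrt(alpha_bars[t]) * x0 + torch.sqrt(1 - alpha_bars[t]) * torch.randn_like(x0)
closed = time.perf_counter() - start

print(f"naive: {naive * 1e3:.1f} ms, closed form: {closed * 1e3:.3f} ms")
```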

Uniform Timestep Sampling

During training, we sample timesteps uniformly from $\{1, 2, \ldots, T\}$. The closed form lets us immediately compute $\mathbf{x}_t$ for any sampled $t$ without computing all intermediate states.

Training Algorithm Preview:
  1. Sample a batch of clean data $\mathbf{x}_0$
  2. Sample random timesteps $t$ uniformly
  3. Sample noise $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
  4. Compute $\mathbf{x}_t$ using the closed form
  5. Train the network to predict $\boldsymbol{\epsilon}$ from $\mathbf{x}_t$
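The five steps above can be sketched as a minimal training loop. The `Conv2d` model here is a hypothetical stand-in for a real time-conditioned U-Net, and the data and hyperparameters are illustrative assumptions:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)  # assumed linear schedule
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

# Stand-in noise-prediction network; a real model would condition on t
model = nn.Conv2d(3, 3, kernel_size=3, padding=1)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

for step in range(10):
    x0 = torch.randn(8, 3, 32, 32)                 # 1. batch of "clean" data
    t = torch.randint(0, T, (8,))                  # 2. uniform random timesteps
    eps = torch.randn_like(x0)                     # 3. Gaussian noise
    ab = alpha_bars[t].view(-1, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps   # 4. closed-form x_t
    loss = nn.functional.mse_loss(model(x_t), eps) # 5. predict the noise
    opt.zero_grad()
    loss.backward()
    opt.step()

print(f"final loss: {loss.item():.4f}")
```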

Implementation

Here is a complete, production-ready implementation of efficient forward sampling:

Efficient Closed-Form Forward Sampling
forward_diffusion.py

A walkthrough of the implementation below:

  • Dependencies: torch for tensor operations, plus Optional and Tuple for type hints on the arguments and return values.
  • ForwardDiffusion class: encapsulates the forward process with pre-computed coefficients. Instead of $T$ iterations, it computes $\mathbf{x}_t$ directly from $\mathbf{x}_0$ in a single step.
  • Pre-computed alphas: $\alpha_t = 1 - \beta_t$ converts the beta schedule to alpha values, and $\bar{\alpha}_t = \prod_{s=1}^t \alpha_s$ is the key quantity telling us how much signal remains. For example, cumprod([0.9, 0.8, 0.7]) = [0.9, 0.72, 0.504].
  • Pre-computed square roots: sqrt(alpha_bar) and sqrt(1 - alpha_bar) are stored since they appear directly in the sampling formula.
  • sample_timesteps: draws random timesteps uniformly during training; each image in the batch gets a different t.
  • add_noise: the core function, implementing the closed-form forward sampling without any loops. An optional pre-sampled noise argument lets you track the exact noise added (essential for training). The coefficients are reshaped from [B] to [B, 1, 1, 1] so they broadcast across channels and spatial dimensions, and a single line replaces $T$ sequential noise additions.
import math
import torch
from typing import Optional, Tuple

class ForwardDiffusion:
    """
    Forward diffusion process with efficient closed-form sampling.

    Instead of iterating through all timesteps, we can sample x_t
    directly from x_0 using the closed-form formula:

    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
    """

    def __init__(self, T: int = 1000, schedule: str = "cosine"):
        self.T = T

        # Compute beta schedule
        if schedule == "linear":
            self.betas = torch.linspace(0.0001, 0.02, T)
        elif schedule == "cosine":
            self.betas = self._cosine_schedule(T)
        else:
            raise ValueError(f"Unknown schedule: {schedule}")

        # Pre-compute all derived quantities
        self.alphas = 1.0 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

        # Store sqrt versions for efficient computation
        self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1.0 - self.alpha_bars)

    def _cosine_schedule(self, T: int, s: float = 0.008) -> torch.Tensor:
        """Compute cosine noise schedule."""
        def f(t):
            return math.cos(((t / T) + s) / (1 + s) * math.pi / 2) ** 2
        alpha_bars = torch.tensor([f(t) / f(0) for t in range(T + 1)])
        betas = 1 - alpha_bars[1:] / alpha_bars[:-1]
        return torch.clamp(betas, min=0.0001, max=0.999)

    def sample_timesteps(self, batch_size: int) -> torch.Tensor:
        """Sample random timesteps uniformly from [0, T-1]."""
        return torch.randint(0, self.T, (batch_size,))

    def add_noise(
        self,
        x_0: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample x_t directly from x_0 using the closed-form formula.

        Args:
            x_0: Clean data, shape [B, C, H, W]
            t: Timesteps, shape [B]
            noise: Optional pre-sampled noise, shape [B, C, H, W]

        Returns:
            x_t: Noised data at timestep t
            noise: The noise that was added (for training target)
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        # Get coefficients for each sample in batch
        # Shape: [B, 1, 1, 1] for broadcasting
        sqrt_alpha_bar = self.sqrt_alpha_bars[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alpha_bars[t].view(-1, 1, 1, 1)

        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon
        x_t = sqrt_alpha_bar * x_0 + sqrt_one_minus_alpha_bar * noise

        return x_t, noise


# Example usage
diffusion = ForwardDiffusion(T=1000, schedule="cosine")

# Sample a batch of clean images (e.g., from your dataset)
x_0 = torch.randn(8, 3, 64, 64)  # [batch, channels, height, width]

# Sample random timesteps for each image in batch
t = diffusion.sample_timesteps(batch_size=8)

# Get noised versions and the noise that was added
x_t, noise = diffusion.add_noise(x_0, t)

print(f"Original shape: {x_0.shape}")
print(f"Timesteps: {t}")
print(f"Noised shape: {x_t.shape}")
print(f"At t={t[0].item()}: alpha_bar = {diffusion.alpha_bars[t[0]]:.4f}")

Key Takeaways

  1. Closed-form exists: $q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0, (1-\bar{\alpha}_t)\mathbf{I})$
  2. Reparameterization: $\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\boldsymbol{\epsilon}$
  3. Derivation uses: recursion expansion + the Gaussian sum property
  4. Efficiency gain: $O(d)$ instead of $O(T \cdot d)$ per sample
  5. Interpretation: smooth interpolation from data to noise

Looking Ahead: In the next section, we'll examine the mathematical properties of the forward process, including variance preservation, convergence to the prior, and the information-theoretic perspective.