1import torch
2import torch.nn as nn
3import math
4from typing import Optional, Tuple, Dict
5from dataclasses import dataclass
6
@dataclass
class DiffusionConfig:
    """Configuration for the diffusion process.

    Consumed by ``GaussianDiffusion``; ``beta_start``/``beta_end`` are only
    read by the "linear" and "quadratic" schedules (the cosine schedule
    derives its betas from the cosine alpha-bar curve instead).
    """
    num_timesteps: int = 1000  # total number of diffusion steps T
    schedule: str = "cosine"  # "linear", "cosine", "quadratic"
    beta_start: float = 0.0001  # beta at t=0 (linear/quadratic schedules only)
    beta_end: float = 0.02  # beta at t=T-1 (linear/quadratic schedules only)
14
15
class GaussianDiffusion(nn.Module):
    """
    Forward (noising) half of a DDPM-style Gaussian diffusion process.

    On construction this module:
      - builds the beta noise schedule (linear, cosine, or quadratic),
      - pre-computes every derived quantity needed for closed-form sampling
        of q(x_t | x_0) and for the posterior q(x_{t-1} | x_t, x_0),
      - registers them all as buffers, so ``.to(device)`` / ``state_dict``
        carry them along without treating them as trainable parameters.
    """

    def __init__(self, config: DiffusionConfig):
        super().__init__()
        self.config = config
        self.T = config.num_timesteps

        # Noise schedule: betas[t] is the per-step variance added at step t.
        beta = self._compute_schedule(config)
        self.register_buffer("betas", beta)
        self.register_buffer("alphas", 1.0 - beta)
        self.register_buffer("alpha_bars", torch.cumprod(self.alphas, dim=0))

        # Derived quantities used by q_sample / predict_x0_from_noise.
        abar = self.alpha_bars
        self.register_buffer("sqrt_alpha_bars", torch.sqrt(abar))
        self.register_buffer("sqrt_one_minus_alpha_bars", torch.sqrt(1.0 - abar))
        self.register_buffer("sqrt_recip_alpha_bars", torch.sqrt(1.0 / abar))
        self.register_buffer("sqrt_recip_alpha_bars_minus_one",
                             torch.sqrt(1.0 / abar - 1.0))

        # alpha_bar_{t-1}, with the t=0 entry defined as 1 (no noise yet).
        abar_prev = torch.cat([torch.tensor([1.0]), abar[:-1]])
        self.register_buffer("alpha_bars_prev", abar_prev)

        # Posterior variance of q(x_{t-1} | x_t, x_0):
        #   beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
        # Entry t=0 is exactly 0, hence the clamp before taking the log.
        post_var = beta * (1.0 - abar_prev) / (1.0 - abar)
        self.register_buffer("posterior_variance", post_var)
        self.register_buffer("posterior_log_variance_clipped",
                             torch.log(torch.clamp(post_var, min=1e-20)))

        # Coefficients of the posterior mean: coef1 multiplies x_0, coef2 x_t.
        self.register_buffer("posterior_mean_coef1",
                             beta * torch.sqrt(abar_prev) / (1.0 - abar))
        self.register_buffer("posterior_mean_coef2",
                             (1.0 - abar_prev) * torch.sqrt(self.alphas) / (1.0 - abar))

    def _compute_schedule(self, config: DiffusionConfig) -> torch.Tensor:
        """Return the length-T beta schedule selected by ``config.schedule``.

        Raises:
            ValueError: if ``config.schedule`` is not one of
                "linear", "cosine", "quadratic".
        """
        T = config.num_timesteps

        if config.schedule == "cosine":
            # Cosine alpha-bar curve; computed in float64 for accuracy, then
            # converted back to float32. Betas are recovered from the ratio
            # of consecutive alpha-bars and clamped to a stable range.
            s = 0.008
            grid = torch.arange(T + 1, dtype=torch.float64) / T
            abar = torch.cos((grid + s) / (1 + s) * math.pi / 2) ** 2
            abar = abar / abar[0]  # normalize so alpha_bar(0) == 1
            beta = 1 - abar[1:] / abar[:-1]
            return torch.clamp(beta, min=0.0001, max=0.999).float()

        if config.schedule == "linear":
            return torch.linspace(config.beta_start, config.beta_end, T)

        if config.schedule == "quadratic":
            # Linear in sqrt(beta) space, then squared.
            root = torch.linspace(config.beta_start**0.5, config.beta_end**0.5, T)
            return root ** 2

        raise ValueError(f"Unknown schedule: {config.schedule}")

    def q_sample(
        self,
        x_0: torch.Tensor,
        t: torch.Tensor,
        noise: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Sample from q(x_t | x_0) in closed form:

            x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

        Args:
            x_0: Clean data, e.g. [B, C, H, W] (any trailing shape works)
            t: Integer timesteps, shape [B]
            noise: Optional pre-sampled standard-normal noise, same shape as x_0

        Returns:
            (x_t, noise): the noised data and the noise that was mixed in.
        """
        if noise is None:
            noise = torch.randn_like(x_0)

        signal_scale = self._extract(self.sqrt_alpha_bars, t, x_0.shape)
        noise_scale = self._extract(self.sqrt_one_minus_alpha_bars, t, x_0.shape)
        return signal_scale * x_0 + noise_scale * noise, noise

    def _extract(self, a: torch.Tensor, t: torch.Tensor, x_shape: tuple) -> torch.Tensor:
        """Gather per-timestep scalars from ``a`` at ``t`` and reshape so the
        result broadcasts against a tensor of shape ``x_shape``."""
        gathered = a.gather(-1, t)
        broadcast_shape = (t.shape[0],) + (1,) * (len(x_shape) - 1)
        return gathered.reshape(broadcast_shape)

    def get_snr(self, t: torch.Tensor) -> torch.Tensor:
        """Signal-to-noise ratio alpha_bar_t / (1 - alpha_bar_t) at timestep t."""
        abar = self._extract(self.alpha_bars, t, t.shape)
        return abar / (1.0 - abar)

    def predict_x0_from_noise(self, x_t: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Invert the forward process: recover x_0 from x_t and predicted noise."""
        recip = self._extract(self.sqrt_recip_alpha_bars, t, x_t.shape)
        recip_m1 = self._extract(self.sqrt_recip_alpha_bars_minus_one, t, x_t.shape)
        return recip * x_t - recip_m1 * noise

    def q_posterior_mean_variance(
        self,
        x_0: torch.Tensor,
        x_t: torch.Tensor,
        t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Mean, variance, and clipped log-variance of q(x_{t-1} | x_t, x_0).

        This true posterior is the training target for the reverse process.
        """
        coef_x0 = self._extract(self.posterior_mean_coef1, t, x_t.shape)
        coef_xt = self._extract(self.posterior_mean_coef2, t, x_t.shape)
        mean = coef_x0 * x_0 + coef_xt * x_t
        variance = self._extract(self.posterior_variance, t, x_t.shape)
        log_variance = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return mean, variance, log_variance
148
149
# Example instantiation
config = DiffusionConfig(num_timesteps=1000, schedule="cosine")
diffusion = GaussianDiffusion(config)

# Test on a batch: one random timestep per batch element.
x_0 = torch.randn(4, 3, 64, 64)
# Sample t in [0, T); use config.num_timesteps rather than a hard-coded 1000
# so the demo stays valid (no out-of-range indexing) if T is changed above.
t = torch.randint(0, config.num_timesteps, (4,))
x_t, noise = diffusion.q_sample(x_0, t)

print(f"Config: T={config.num_timesteps}, schedule={config.schedule}")
print(f"x_0 shape: {x_0.shape}, x_t shape: {x_t.shape}")
print(f"At t={t[0].item()}: alpha_bar={diffusion.alpha_bars[t[0]]:.4f}")