Chapter 4

Connection to Denoising Autoencoders

Understanding the Loss Function

Learning Objectives

By the end of this section, you will be able to:

  1. Understand the classical denoising autoencoder (DAE) framework and its training objective
  2. Derive the theoretical connection between DAE training and score function estimation
  3. Explain how diffusion models extend DAEs to a multi-scale denoising framework
  4. Appreciate the historical development from autoencoders to modern diffusion models

Classical Denoising Autoencoders

Denoising Autoencoders (DAEs), introduced by Vincent et al. (2008), learn representations by training to reconstruct clean data from corrupted inputs. This seemingly simple objective has deep connections to density estimation.

The DAE Objective

Given clean data $\mathbf{x}$, a DAE:

  1. Corrupts the input: $\tilde{\mathbf{x}} = \mathbf{x} + \sigma \boldsymbol{\epsilon}$, where $\boldsymbol{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$
  2. Learns to reconstruct: $\hat{\mathbf{x}} = f_\theta(\tilde{\mathbf{x}})$
  3. Minimizes reconstruction error: $L_{\text{DAE}} = \mathbb{E}_{\mathbf{x}, \boldsymbol{\epsilon}}\left[ \|f_\theta(\mathbf{x} + \sigma \boldsymbol{\epsilon}) - \mathbf{x}\|^2 \right]$
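Before any learning happens, the corruption step alone fixes a useful baseline: the identity "denoiser" $f(\tilde{\mathbf{x}}) = \tilde{\mathbf{x}}$ incurs loss $\mathbb{E}\|\sigma\boldsymbol{\epsilon}\|^2$, i.e. $\sigma^2$ per coordinate. A minimal sketch (the shapes and $\sigma$ here are illustrative, not from the text):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
sigma = 0.5
x = torch.randn(4096, 784)       # stand-in "clean" data
noise = torch.randn_like(x)
x_noisy = x + sigma * noise      # the corruption step

# The identity "denoiser" f(x~) = x~ pays the full noise energy:
# per-coordinate MSE approaches sigma^2 = 0.25
identity_mse = F.mse_loss(x_noisy, x)
print(f"identity-denoiser MSE: {identity_mse.item():.4f} (sigma^2 = {sigma**2})")
```

Any trained denoiser must beat this baseline by exploiting structure in $p(\mathbf{x})$.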

The noise level $\sigma$ is a hyperparameter. Different $\sigma$ values lead to different learned representations:

| Noise Level | What the DAE Learns | Representation |
|---|---|---|
| Small $\sigma$ | Local structure, fine details | Surface features |
| Medium $\sigma$ | Global structure, shapes | Semantic features |
| Large $\sigma$ | Basic statistics, mean | Coarse features |
The Key Question: Is there a deeper reason why learning to denoise produces useful representations? Yes: denoising is fundamentally connected to learning the score function of the data distribution.

DAE-Score Function Connection

A remarkable result from Vincent (2011) shows that the optimal denoiser directly estimates the score function of the noisy data distribution, i.e. the gradient of its log probability:

$$\nabla_{\tilde{\mathbf{x}}} \log p(\tilde{\mathbf{x}}) = \frac{f^*(\tilde{\mathbf{x}}) - \tilde{\mathbf{x}}}{\sigma^2}$$

where $f^*(\tilde{\mathbf{x}})$ is the optimal denoiser, i.e. the minimizer of the DAE objective.

Derivation

The optimal denoiser computes the posterior mean:

$$f^*(\tilde{\mathbf{x}}) = \mathbb{E}[\mathbf{x} \mid \tilde{\mathbf{x}}] = \int \mathbf{x} \, p(\mathbf{x} \mid \tilde{\mathbf{x}}) \, d\mathbf{x}$$

Using Bayes' rule and the Gaussian noise model:

$$p(\mathbf{x} \mid \tilde{\mathbf{x}}) = \frac{p(\tilde{\mathbf{x}} \mid \mathbf{x})\, p(\mathbf{x})}{p(\tilde{\mathbf{x}})} \propto \exp\left(-\frac{\|\tilde{\mathbf{x}} - \mathbf{x}\|^2}{2\sigma^2}\right) p(\mathbf{x})$$

Differentiating the log of the noisy marginal $p(\tilde{\mathbf{x}}) = \int p(\tilde{\mathbf{x}} \mid \mathbf{x})\, p(\mathbf{x})\, d\mathbf{x}$, and using $\nabla_{\tilde{\mathbf{x}}} \log p(\tilde{\mathbf{x}} \mid \mathbf{x}) = (\mathbf{x} - \tilde{\mathbf{x}})/\sigma^2$, the gradient becomes an average over the posterior:

$$\nabla_{\tilde{\mathbf{x}}} \log p(\tilde{\mathbf{x}}) = \frac{\mathbb{E}[\mathbf{x} \mid \tilde{\mathbf{x}}] - \tilde{\mathbf{x}}}{\sigma^2} = \frac{f^*(\tilde{\mathbf{x}}) - \tilde{\mathbf{x}}}{\sigma^2}$$

Tweedie's Formula Again

This is exactly Tweedie's formula that we saw in the score matching section! The optimal denoiser implements Tweedie's estimate of the clean data given the noisy observation.
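The identity can be checked end-to-end in a case where everything is analytic. If $p(\mathbf{x}) = \mathcal{N}(\mathbf{0}, \mathbf{I})$, the noisy marginal is $\mathcal{N}(\mathbf{0}, (1+\sigma^2)\mathbf{I})$ and the posterior mean is $\tilde{\mathbf{x}}/(1+\sigma^2)$. A short sketch (the specific $\sigma$ and test point are arbitrary) confirms Vincent's identity numerically:

```python
import torch

sigma = 0.7
x_noisy = torch.tensor([1.5, -2.0, 0.3])

# For clean data x ~ N(0, I) and x~ = x + sigma * eps:
#   optimal denoiser (posterior mean):  f*(x~) = x~ / (1 + sigma^2)
#   noisy marginal: x~ ~ N(0, (1 + sigma^2) I), so its score is
#   grad log p(x~) = -x~ / (1 + sigma^2)
f_star = x_noisy / (1 + sigma**2)
score_analytic = -x_noisy / (1 + sigma**2)

# Vincent's identity: score = (f*(x~) - x~) / sigma^2
score_from_denoiser = (f_star - x_noisy) / sigma**2

print(torch.allclose(score_from_denoiser, score_analytic))  # True
```

The two expressions agree for every $\tilde{\mathbf{x}}$, which is exactly what the derivation above promises.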

From Denoising to Score

Rearranging, if our network learns to predict the noise $\hat{\boldsymbol{\epsilon}} = \boldsymbol{\epsilon}_\theta(\tilde{\mathbf{x}})$ (as in DDPM), then:

$$f_\theta(\tilde{\mathbf{x}}) = \tilde{\mathbf{x}} - \sigma\, \boldsymbol{\epsilon}_\theta(\tilde{\mathbf{x}})$$

And the estimated score is:

$$\mathbf{s}_\theta(\tilde{\mathbf{x}}) = \frac{f_\theta(\tilde{\mathbf{x}}) - \tilde{\mathbf{x}}}{\sigma^2} = -\frac{\boldsymbol{\epsilon}_\theta(\tilde{\mathbf{x}})}{\sigma}$$

The Profound Connection: Training a network to predict noise is equivalent to training it to estimate the score function. Denoising and score matching are two views of the same underlying operation.
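The equivalence is purely algebraic, so it holds for any noise prediction, trained or not. A minimal check (the tensors below are arbitrary stand-ins for $\tilde{\mathbf{x}}$ and $\boldsymbol{\epsilon}_\theta(\tilde{\mathbf{x}})$):

```python
import torch

sigma = 0.3
x_noisy = torch.randn(5)
eps_pred = torch.randn(5)   # stand-in for eps_theta(x~)

# Denoiser implied by the noise prediction
f_theta = x_noisy - sigma * eps_pred

# Two routes to the same score estimate
score_via_denoiser = (f_theta - x_noisy) / sigma**2
score_via_noise = -eps_pred / sigma

print(torch.allclose(score_via_denoiser, score_via_noise))  # True
```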

Diffusion as Multi-Scale DAE

Classical DAEs use a single noise level, which creates a trade-off: small $\sigma$ captures fine details but has a limited receptive field; large $\sigma$ captures global structure but loses details.

Diffusion models solve this by training across all noise levels simultaneously - essentially a multi-scale denoising autoencoder.

Single DAE vs Diffusion Model

| Aspect | Classical DAE | Diffusion Model |
|---|---|---|
| Noise levels | Single fixed $\sigma$ | Continuous range $[0, T]$ |
| Training | One noise level at a time | All levels simultaneously |
| Architecture | Same network for one $\sigma$ | Time-conditioned network |
| Generation | Not directly possible | Iterative denoising |
| Score estimation | For one $\sigma$ only | Full multi-scale score |

The Multi-Scale Objective

The diffusion training objective can be written as a weighted integral over noise levels:

L=0Tw(t)Ex0,ϵ[ϵθ(xt,t)ϵ2]dtL = \int_0^T w(t) \cdot \mathbb{E}_{\mathbf{x}_0, \boldsymbol{\epsilon}}\left[ \|\boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t) - \boldsymbol{\epsilon}\|^2 \right] dt

where xt=αˉtx0+1αˉtϵ\mathbf{x}_t = \sqrt{\bar{\alpha}_t}\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_t}\boldsymbol{\epsilon}. This is a continuous family of DAE objectives, one for each noise level.
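To see how this family spans scales, the sketch below prints the signal and noise coefficients at a few timesteps. It assumes a common linear schedule ($\beta$ from $10^{-4}$ to $0.02$ over $T = 1000$ steps, matching typical DDPM defaults rather than anything fixed by the text):

```python
import torch

# Assumed linear beta schedule (a common DDPM default)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1 - betas, dim=0)

for t in [0, 250, 500, 750, 999]:
    signal = alpha_bar[t].sqrt().item()        # coefficient on x_0
    noise = (1 - alpha_bar[t]).sqrt().item()   # coefficient on epsilon
    print(f"t={t:4d}  signal={signal:.3f}  noise={noise:.3f}  SNR={signal / noise:.4f}")
```

Early timesteps are near-clean DAE problems (small effective $\sigma$); late timesteps are almost pure noise, so the network is denoising at every scale in between.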

Why Multi-Scale Matters

Training across all scales provides several advantages:

  1. Coarse-to-fine generation: High-noise timesteps establish global structure; low-noise timesteps refine details
  2. Consistent representations: The network learns a unified representation that captures features at all scales
  3. Smooth score field: Interpolating between noise levels creates smooth denoising trajectories
  4. Mode coverage: Different noise levels help the model discover and connect different modes of the distribution
```python
# Illustrating the connection: DAE vs Diffusion
import torch
import torch.nn as nn
import torch.nn.functional as F

class ClassicalDAE(nn.Module):
    """Single noise level denoising autoencoder."""

    def __init__(self, sigma: float = 0.1):
        super().__init__()
        self.sigma = sigma
        # Simple MLP denoiser
        self.net = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict clean data from noisy input."""
        return self.net(x)

    def loss(self, x_clean: torch.Tensor) -> torch.Tensor:
        """Compute DAE loss."""
        # Add noise
        noise = torch.randn_like(x_clean)
        x_noisy = x_clean + self.sigma * noise

        # Predict clean
        x_pred = self(x_noisy)

        # MSE reconstruction loss
        return F.mse_loss(x_pred, x_clean)

    def get_score(self, x_noisy: torch.Tensor) -> torch.Tensor:
        """Estimate score from denoiser output."""
        x_pred = self(x_noisy)
        # score = (predicted_clean - noisy) / sigma^2
        return (x_pred - x_noisy) / (self.sigma ** 2)


class DiffusionAsMultiScaleDAE(nn.Module):
    """Diffusion model viewed as multi-scale DAE."""

    def __init__(self, T: int = 1000):
        super().__init__()
        self.T = T

        # Time-conditioned denoiser
        self.net = nn.Sequential(
            nn.Linear(784 + 128, 512),  # +128 for time embedding
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
        )

        # Noise schedule
        betas = torch.linspace(1e-4, 0.02, T)
        self.register_buffer("alpha_bar", torch.cumprod(1 - betas, dim=0))

    def time_embedding(self, t: torch.Tensor) -> torch.Tensor:
        """Simple sinusoidal time embedding."""
        half_dim = 64
        emb = torch.exp(
            torch.arange(half_dim, device=t.device)
            * (-torch.log(torch.tensor(10000.0)) / half_dim)
        )
        emb = t.float().unsqueeze(1) * emb.unsqueeze(0)
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict noise (equivalent to predicting x_0 via Tweedie)."""
        t_emb = self.time_embedding(t)
        x_input = torch.cat([x_t, t_emb], dim=1)
        return self.net(x_input)

    def loss(self, x_0: torch.Tensor) -> torch.Tensor:
        """Multi-scale DAE loss (diffusion training)."""
        batch_size = x_0.shape[0]

        # Sample random timesteps (noise levels)
        t = torch.randint(0, self.T, (batch_size,), device=x_0.device)

        # Get noise level for each timestep
        alpha_bar_t = self.alpha_bar[t].unsqueeze(1)

        # Add noise (different sigma for each sample)
        noise = torch.randn_like(x_0)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise

        # Predict noise
        noise_pred = self(x_t, t)

        # MSE loss on noise prediction
        return F.mse_loss(noise_pred, noise)

    def get_score(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Get score estimate at noise level t."""
        # sigma_t = sqrt(1 - alpha_bar_t)
        alpha_bar_t = self.alpha_bar[t].unsqueeze(1)
        sigma_t = torch.sqrt(1 - alpha_bar_t)

        # Predict noise
        noise_pred = self(x_t, t)

        # score = -noise / sigma
        return -noise_pred / sigma_t


# Demonstrate equivalence
def show_dae_diffusion_connection():
    """Show that diffusion training is multi-scale DAE training."""
    x = torch.randn(32, 784)  # Batch of flattened 28x28 images

    # Classical DAE at one noise level
    dae = ClassicalDAE(sigma=0.5)
    dae_loss = dae.loss(x)
    print(f"Classical DAE loss (sigma=0.5): {dae_loss.item():.4f}")

    # Diffusion (multi-scale DAE)
    diffusion = DiffusionAsMultiScaleDAE(T=1000)
    diff_loss = diffusion.loss(x)
    print(f"Diffusion loss (multi-scale): {diff_loss.item():.4f}")

    # Show score estimation equivalence
    print("\nScore estimation comparison:")
    print("Both estimate: nabla_x log p(x) = -epsilon / sigma")


if __name__ == "__main__":
    show_dae_diffusion_connection()
```

Historical Context

Understanding the historical development illuminates why diffusion models work and how they relate to earlier generative models.

Timeline of Key Developments

| Year | Development | Key Insight |
|---|---|---|
| 2006 | Deep belief networks | Layer-wise pretraining |
| 2008 | Denoising autoencoders | Corruption improves representations |
| 2010 | Contractive autoencoders | Penalize sensitivity to input |
| 2011 | DAE-score connection | Denoising estimates score |
| 2014 | VAEs, GANs emerge | Direct generative modeling |
| 2015 | Deep diffusion models | Thermodynamic connection |
| 2019 | NCSN (score matching) | Langevin dynamics sampling |
| 2020 | DDPM | Simple loss, multi-scale DAE |
| 2021 | Score SDE | Unified continuous framework |

The Intellectual Lineage

Diffusion models sit at the intersection of several research threads:

  1. Autoencoders and representation learning: DAEs showed that denoising creates useful representations. Diffusion extends this to generative modeling.
  2. Score matching: The work on estimating score functions without knowing the normalizing constant directly enables training on unnormalized densities.
  3. Statistical physics: The diffusion process has direct analogies to non-equilibrium thermodynamics and Langevin dynamics.
  4. Variational inference: The ELBO derivation connects diffusion to the VAE framework, providing principled likelihood bounds.
Synthesis: DDPM's genius was recognizing that these threads could be unified: train a multi-scale DAE with a simple noise prediction loss, and you get a generative model with strong theoretical foundations and excellent sample quality.

Implementation

Let's implement a complete example showing how to use the DAE-score connection for generation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScoreBasedDenoiser(nn.Module):
    """
    Denoiser that explicitly computes the score function.

    This implementation makes the DAE-score connection explicit,
    supporting both denoising and score-based sampling.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        num_layers: int = 3,
        T: int = 1000,
    ):
        super().__init__()
        self.T = T
        self.input_dim = input_dim

        # Build network
        layers = [nn.Linear(input_dim + 128, hidden_dim), nn.SiLU()]
        for _ in range(num_layers - 1):
            layers.extend([nn.Linear(hidden_dim, hidden_dim), nn.SiLU()])
        layers.append(nn.Linear(hidden_dim, input_dim))
        self.net = nn.Sequential(*layers)

        # Noise schedule
        betas = torch.linspace(1e-4, 0.02, T)
        alphas = 1 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bar", alpha_bar)
        self.register_buffer("sigma", torch.sqrt(1 - alpha_bar))

    def time_embedding(self, t: torch.Tensor) -> torch.Tensor:
        """Sinusoidal time embedding."""
        half_dim = 64
        emb = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=torch.float32)
            * (-torch.log(torch.tensor(10000.0)) / half_dim)
        )
        emb = t.float().unsqueeze(1) * emb.unsqueeze(0)
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)

    def predict_noise(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Predict the noise component."""
        t_emb = self.time_embedding(t)
        return self.net(torch.cat([x_t, t_emb], dim=1))

    def predict_x0(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Predict clean data via Tweedie's formula.

        x_0 = (x_t - sigma_t * epsilon) / sqrt(alpha_bar_t)
        """
        alpha_bar_t = self.alpha_bar[t].unsqueeze(1)
        sigma_t = self.sigma[t].unsqueeze(1)

        epsilon_pred = self.predict_noise(x_t, t)
        x0_pred = (x_t - sigma_t * epsilon_pred) / torch.sqrt(alpha_bar_t)

        return x0_pred

    def get_score(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Get score function estimate.

        score = -epsilon / sigma = nabla_x log p(x_t)
        """
        sigma_t = self.sigma[t].unsqueeze(1)
        epsilon_pred = self.predict_noise(x_t, t)
        return -epsilon_pred / sigma_t

    def denoise(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Classical denoising: predict clean data.

        This is the DAE operation.
        """
        return self.predict_x0(x_t, t)

    def training_loss(self, x_0: torch.Tensor) -> dict[str, torch.Tensor]:
        """
        Compute training loss with diagnostics.

        Returns both the loss and useful diagnostic values.
        """
        batch_size = x_0.shape[0]
        device = x_0.device

        # Sample timesteps
        t = torch.randint(0, self.T, (batch_size,), device=device)

        # Add noise
        noise = torch.randn_like(x_0)
        alpha_bar_t = self.alpha_bar[t].unsqueeze(1)
        x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * noise

        # Predict noise
        noise_pred = self.predict_noise(x_t, t)

        # Compute loss
        loss = F.mse_loss(noise_pred, noise)

        # Compute x0 prediction for diagnostics
        with torch.no_grad():
            x0_pred = self.predict_x0(x_t, t)
            x0_mse = F.mse_loss(x0_pred, x_0)

        return {
            "loss": loss,
            "noise_mse": loss,
            "x0_mse": x0_mse,
        }


def langevin_sampling(
    score_fn,
    shape: tuple,
    n_steps: int = 1000,
    step_size: float = 0.01,
    noise_scale: float = 1.0,
    device: str = "cpu",
) -> torch.Tensor:
    """
    Sample using Langevin dynamics with a learned score.

    This demonstrates the score-sampling connection from DAE training.

    Args:
        score_fn: Function that takes x and returns a score estimate
        shape: Shape of samples to generate
        n_steps: Number of Langevin steps
        step_size: Step size epsilon
        noise_scale: Scale of injected noise
        device: Device to sample on

    Returns:
        Generated samples
    """
    # Initialize from noise
    x = torch.randn(shape, device=device)

    for _ in range(n_steps):
        # Get score estimate
        score = score_fn(x)

        # Langevin update: x = x + epsilon * score + sqrt(2*epsilon) * noise
        noise = torch.randn_like(x)
        x = x + step_size * score + noise_scale * torch.sqrt(2 * torch.tensor(step_size)) * noise

    return x


# Example: Training and sampling with the DAE-score connection
def dae_score_example():
    """Demonstrate the complete DAE-score-generation pipeline."""
    # Create model
    model = ScoreBasedDenoiser(input_dim=784, hidden_dim=256, T=1000)

    # Simulate some training data (2D Gaussian mixture flattened)
    def sample_data(n: int) -> torch.Tensor:
        # Simple mixture of 4 Gaussians
        centers = torch.tensor([[-2, -2], [-2, 2], [2, -2], [2, 2]]).float()
        idx = torch.randint(0, 4, (n,))
        samples = centers[idx] + 0.3 * torch.randn(n, 2)
        # Pad to 784 dimensions (like MNIST)
        return F.pad(samples, (0, 782))

    # Training loop (simplified)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    print("Training denoising autoencoder (score estimator)...")
    for step in range(100):
        x_0 = sample_data(64)
        loss_dict = model.training_loss(x_0)

        optimizer.zero_grad()
        loss_dict["loss"].backward()
        optimizer.step()

        if step % 20 == 0:
            print(f"Step {step}: noise_mse={loss_dict['noise_mse'].item():.4f}, "
                  f"x0_mse={loss_dict['x0_mse'].item():.4f}")

    print("\nDAE training complete!")
    print("The same model can now be used for:")
    print("1. Denoising (classical DAE operation)")
    print("2. Score estimation (nabla log p)")
    print("3. Generation (via Langevin or diffusion sampling)")


if __name__ == "__main__":
    dae_score_example()
```

Key Takeaways

  1. DAEs learn the score: The optimal denoiser implicitly estimates the score function via Tweedie's formula
  2. Diffusion = multi-scale DAE: Training across all noise levels creates a complete multi-scale score estimator
  3. Three equivalent views: Noise prediction, denoising, and score estimation are mathematically equivalent operations
  4. Historical synthesis: Diffusion models unify ideas from autoencoders, score matching, and statistical physics
  5. Practical benefit: Understanding the DAE connection helps interpret what the network learns at different noise levels
Looking Ahead: In the final section of this chapter, we'll conduct a numerical analysis of the loss function, examining gradient behavior, convergence properties, and practical debugging strategies.