Chapter 4

Numerical Analysis of the Loss

Understanding the Loss Function

Learning Objectives

By the end of this section, you will be able to:

  1. Analyze the loss landscape of diffusion models and understand how it varies across timesteps
  2. Diagnose gradient behavior issues and implement solutions for stable training
  3. Apply numerical stability techniques specific to diffusion model training
  4. Debug common training problems using systematic analysis tools

Understanding the Loss Landscape

The diffusion loss landscape has unique characteristics that differ from standard supervised learning. Understanding these properties helps explain training dynamics and informs hyperparameter choices.

Loss Decomposition by Timestep

The total loss is an average over timesteps, but the per-timestep loss varies dramatically:

$$L = \frac{1}{T} \sum_{t=1}^{T} L_t = \frac{1}{T} \sum_{t=1}^{T} \mathbb{E}\left[ \|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta(\mathbf{x}_t, t)\|^2 \right]$$

At the optimal solution, what should $L_t$ look like? That depends on the irreducible uncertainty in each denoising task, that is, on how well $\boldsymbol{\epsilon}$ can be determined from $\mathbf{x}_t$ at all.

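| Timestep | Noise level $\sigma_t = \sqrt{1 - \bar{\alpha}_t}$ | Optimal loss behavior ($\boldsymbol{\epsilon}$-prediction) |
|---|---|---|
| $t = 1$ (low noise) | $\approx 0.01$ (linear schedule) | High: the noise is tiny relative to $\mathbf{x}_0$, so predicting it requires recovering $\mathbf{x}_0$ almost exactly |
| $t \approx T/2$ (mid) | Intermediate, strongly schedule-dependent | Intermediate |
| $t = T$ (high noise) | $\approx 1.0$ | Near zero: $\mathbf{x}_T$ is almost pure noise, so $\boldsymbol{\epsilon} \approx \mathbf{x}_T / \sqrt{1 - \bar{\alpha}_T}$ |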
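The noise levels in the table depend on the schedule. As a concrete reference, the sketch below evaluates $\bar{\alpha}_t$ and $\sigma_t = \sqrt{1 - \bar{\alpha}_t}$ for the linear DDPM schedule ($\beta$ from $10^{-4}$ to $0.02$ over $T = 1000$ steps). That schedule is an assumption made for illustration, and the helper name `noise_levels` is hypothetical.

```python
import torch


def noise_levels(T: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02):
    """Return alpha_bar_t and sigma_t = sqrt(1 - alpha_bar_t) for a linear beta schedule."""
    # Build the schedule in float64 so small quantities are not rounded away.
    betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float64)
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)
    sigma = torch.sqrt(1.0 - alpha_bar)
    return alpha_bar, sigma


alpha_bar, sigma = noise_levels()
for t in (1, 250, 500, 750, 1000):  # 1-indexed timesteps, as in the text
    i = t - 1
    print(f"t={t:4d}  alpha_bar={alpha_bar[i].item():.3e}  sigma={sigma[i].item():.3f}")
```

For this schedule $\sigma_t$ is already above 0.9 at $t = T/2$, while the cosine schedule keeps it near 0.7 there, which is why the mid-range noise level is best read as schedule-dependent.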

Expected Loss at Optimum

For an optimal network, the expected MSE on noise prediction equals the conditional variance of the noise given the noisy input, summed over dimensions:

$$L_t^* = \mathbb{E}_{\mathbf{x}_t}\left[ \operatorname{tr} \operatorname{Cov}[\boldsymbol{\epsilon} \mid \mathbf{x}_t] \right]$$

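This is the Bayes error: the uncertainty that remains even for a perfect predictor, which outputs the posterior mean $\mathbb{E}[\boldsymbol{\epsilon} \mid \mathbf{x}_t]$. If the data are isotropic Gaussian, it has a closed form:

$$L_t^* = d \cdot \frac{\bar{\alpha}_t \sigma_{\text{data}}^2}{\bar{\alpha}_t \sigma_{\text{data}}^2 + 1 - \bar{\alpha}_t} = d \cdot \left(1 - \frac{1 - \bar{\alpha}_t}{\bar{\alpha}_t \sigma_{\text{data}}^2 + 1 - \bar{\alpha}_t}\right)$$

where $d$ is the data dimensionality and $\sigma_{\text{data}}^2$ is the per-dimension data variance. For non-Gaussian data with the same covariance, this is the error of the best linear predictor and therefore an upper bound on $L_t^*$. The bound goes to zero as $\bar{\alpha}_t \to 0$, so the optimal $\boldsymbol{\epsilon}$-prediction loss vanishes at the high-noise end; at the low-noise end it is close to $d$ for Gaussian data, while structured real data can come in below this value.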
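For isotropic Gaussian data the Bayes-optimal error can be checked numerically. The optimal predictor is the posterior mean $\mathbb{E}[\boldsymbol{\epsilon} \mid \mathbf{x}_t] = \frac{\sqrt{1-\bar{\alpha}_t}}{\bar{\alpha}_t \sigma_{\text{data}}^2 + 1 - \bar{\alpha}_t}\,\mathbf{x}_t$, and its per-dimension error should match $\bar{\alpha}_t \sigma_{\text{data}}^2 / (\bar{\alpha}_t \sigma_{\text{data}}^2 + 1 - \bar{\alpha}_t)$. The sketch below assumes $\sigma_{\text{data}} = 0.5$ (an illustrative value) and compares the closed form with a Monte Carlo estimate; both function names are hypothetical.

```python
import torch


def gaussian_optimal_eps_loss(alpha_bar: torch.Tensor, sigma_data: float) -> torch.Tensor:
    """Per-dimension Bayes-optimal epsilon-MSE, assuming isotropic Gaussian data."""
    signal = alpha_bar * sigma_data**2
    return signal / (signal + 1.0 - alpha_bar)


def monte_carlo_optimal_eps_loss(alpha_bar: float, sigma_data: float,
                                 n: int = 200_000, seed: int = 0) -> float:
    """Empirical MSE of the exact posterior mean E[eps | x_t] under the same assumption."""
    g = torch.Generator().manual_seed(seed)
    x0 = sigma_data * torch.randn(n, generator=g, dtype=torch.float64)
    eps = torch.randn(n, generator=g, dtype=torch.float64)
    a = torch.tensor(alpha_bar, dtype=torch.float64)
    x_t = a.sqrt() * x0 + (1 - a).sqrt() * eps
    # Posterior mean of eps: Cov(eps, x_t) / Var(x_t) * x_t
    eps_hat = (1 - a).sqrt() / (a * sigma_data**2 + 1 - a) * x_t
    return ((eps - eps_hat) ** 2).mean().item()


sigma_data = 0.5
for ab in (0.9999, 0.5, 4e-5):
    closed = gaussian_optimal_eps_loss(torch.tensor(ab, dtype=torch.float64), sigma_data).item()
    mc = monte_carlo_optimal_eps_loss(ab, sigma_data)
    print(f"alpha_bar={ab:<8g} closed-form={closed:.5f} monte-carlo={mc:.5f}")
```

Under this assumption the optimum falls from about 1 per dimension at $\bar{\alpha}_t \approx 1$ to nearly 0 at $\bar{\alpha}_t \approx 0$.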

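Practical Insight: Compare your per-timestep loss with this profile. With $\boldsymbol{\epsilon}$-prediction, the loss at large $t$ should approach zero, since $\mathbf{x}_t$ is nearly pure noise there. If it does not, or if the loss is nearly flat across $t$, the network may be struggling with the conditioning mechanism or time embedding.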

Gradient Behavior Analysis

Understanding gradient flow in diffusion training is crucial for diagnosing and fixing training issues.

Gradient Magnitude Across Timesteps

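The gradient magnitude varies significantly with timestep. For a single sample, the chain rule gives $\nabla_\theta L_t = -2 \left(\frac{\partial \boldsymbol{\epsilon}_\theta}{\partial \theta}\right)^{\top} (\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta)$, so the gradient norm is bounded by the residual times the Jacobian norm:

$$\|\nabla_\theta L_t\| \le 2\, \|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\| \cdot \left\|\frac{\partial \boldsymbol{\epsilon}_\theta}{\partial \theta}\right\|$$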

Several factors affect this:

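  1. Prediction difficulty: Timesteps with more irreducible error produce larger residuals $\|\boldsymbol{\epsilon} - \boldsymbol{\epsilon}_\theta\|$; for $\boldsymbol{\epsilon}$-prediction this is the low-noise (small $t$) end
  2. Input scale: The scale of $\mathbf{x}_t$ affects gradient magnitudes through the network (the Jacobian term)
  3. Time embedding influence: How strongly the time embedding modulates the network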
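These factors can be measured directly. A single backward pass yields one gradient for the whole batch, mixing all sampled timesteps. One way to separate them, sketched below with a hypothetical helper `per_bin_grad_norms` and an assumed equal-width binning of $[0, T)$, is to differentiate each bin's mean loss separately with `torch.autograd.grad`.

```python
import torch


def per_bin_grad_norms(
    model: torch.nn.Module,
    per_sample_loss: torch.Tensor,
    timesteps: torch.Tensor,
    T: int,
    num_bins: int = 10,
) -> dict[int, float]:
    """L2 gradient norm of the mean loss restricted to each occupied timestep bin.

    Costs one extra backward pass per occupied bin; use for periodic diagnostics.
    """
    params = [p for p in model.parameters() if p.requires_grad]
    bins = (timesteps * num_bins // T).clamp(max=num_bins - 1)
    norms = {}
    for b in bins.unique().tolist():
        bin_loss = per_sample_loss[bins == b].mean()
        # retain_graph=True keeps the graph alive for the next bin (and a later backward)
        grads = torch.autograd.grad(bin_loss, params, retain_graph=True, allow_unused=True)
        sq = [g.pow(2).sum() for g in grads if g is not None]
        norms[b] = torch.stack(sq).sum().sqrt().item() if sq else 0.0
    return norms


# Toy usage: a linear "denoiser" and hand-picked timesteps spread over four bins
torch.manual_seed(0)
model = torch.nn.Linear(4, 4)
x = torch.randn(8, 4)
t = torch.tensor([0, 50, 120, 300, 450, 600, 800, 999])
per_sample_loss = ((model(x) - x) ** 2).mean(dim=1)
norms = per_bin_grad_norms(model, per_sample_loss, t, T=1000, num_bins=4)
print(norms)  # one gradient norm per occupied bin
```

Because `torch.autograd.grad` returns gradients instead of accumulating them into `.grad`, this can run before the regular `loss.backward()` on the same graph without disturbing the training update.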

Gradient Imbalance Problem

Without careful design, gradients from different timesteps can be severely imbalanced:

| Issue | Symptom | Solution |
|---|---|---|
| Early timestep domination | Fast loss decrease, then plateau | SNR weighting or v-prediction |
| Late timestep gradient explosion | Training instability | Gradient clipping, smaller learning rate |
| Poor time conditioning | Loss similar across timesteps | Better time embedding |
| Dead units at certain timesteps | Per-timestep loss stuck | Residual connections |
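"SNR weighting" covers several schemes. One widely used variant is Min-SNR-$\gamma$ weighting (Hang et al., 2023), whose $\boldsymbol{\epsilon}$-prediction form multiplies each sample's loss by $\min(\text{SNR}_t, \gamma) / \text{SNR}_t$ with $\text{SNR}_t = \bar{\alpha}_t / (1 - \bar{\alpha}_t)$. The sketch below uses $\gamma = 5$, the value suggested in that paper; the function names are hypothetical, and the code assumes $0 < \bar{\alpha}_t < 1$.

```python
import torch


def min_snr_weights(alpha_bar_t: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    """Min-SNR-gamma weights for epsilon-prediction (assumes 0 < alpha_bar_t < 1).

    The weight is 1 wherever SNR_t <= gamma (noisy timesteps) and gamma / SNR_t < 1
    at low-noise timesteps, which reduces their share of the gradient.
    """
    snr = alpha_bar_t / (1.0 - alpha_bar_t)
    return snr.clamp(max=gamma) / snr


def weighted_eps_loss(
    eps_pred: torch.Tensor,
    eps: torch.Tensor,
    alpha_bar_t: torch.Tensor,
    gamma: float = 5.0,
) -> torch.Tensor:
    """Per-sample epsilon-MSE, reweighted by each sample's Min-SNR weight."""
    per_sample = ((eps_pred - eps) ** 2).flatten(1).mean(dim=1)
    return (min_snr_weights(alpha_bar_t, gamma) * per_sample).mean()


w = min_snr_weights(torch.tensor([0.9999, 0.5, 4e-5]))
print(w)  # small weight for the clean timestep, 1.0 for the two noisy ones
```

The weighting changes only the loss; timestep sampling and the network's prediction target stay the same.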

Monitoring Gradient Health

Key metrics to track during training:

```python
import torch
from collections import defaultdict


class GradientMonitor:
    """
    Monitor gradient statistics during diffusion training.

    Bins timesteps and records, for every bin present in a batch, the per-sample
    losses and the batch-level gradient norm. A single backward pass yields one
    gradient for the whole batch, so the recorded norm is a batch-level proxy,
    not a true per-timestep gradient norm.
    """

    def __init__(self, T: int, num_timestep_bins: int = 10):
        self.T = T
        self.num_bins = num_timestep_bins
        self.bin_size = max(1, T // num_timestep_bins)

        # Statistics storage, keyed by bin index
        self.grad_norms: dict[int, list[float]] = defaultdict(list)
        self.loss_values: dict[int, list[float]] = defaultdict(list)

    def get_bin(self, t: int) -> int:
        """Get timestep bin index (the last bin absorbs any remainder)."""
        return min(t // self.bin_size, self.num_bins - 1)

    @torch.no_grad()
    def record_gradients(
        self,
        model: torch.nn.Module,
        timesteps: torch.Tensor,
        per_sample_loss: torch.Tensor,
    ) -> float:
        """
        Record gradient statistics after the backward pass.

        Call this after loss.backward() but before optimizer.step()
        (and before gradient clipping if you want the unclipped norm).
        """
        # Total L2 norm over all parameter gradients (one device sync at the end)
        norms = [p.grad.detach().norm(2) for p in model.parameters() if p.grad is not None]
        total_norm = torch.stack(norms).norm(2).item() if norms else 0.0

        # Each per-sample loss goes to its own timestep bin
        bins_in_batch = set()
        for t, loss in zip(timesteps.tolist(), per_sample_loss.detach().tolist()):
            bin_idx = self.get_bin(t)
            self.loss_values[bin_idx].append(loss)
            bins_in_batch.add(bin_idx)

        # The batch gradient norm is shared by every bin that appeared in this batch
        for bin_idx in bins_in_batch:
            self.grad_norms[bin_idx].append(total_norm)

        return total_norm

    def get_statistics(self) -> dict:
        """Get summary statistics per bin, in timestep order."""
        stats = {}
        for bin_idx in range(self.num_bins):
            if self.loss_values[bin_idx]:
                t_start = bin_idx * self.bin_size
                t_end = self.T if bin_idx == self.num_bins - 1 else (bin_idx + 1) * self.bin_size
                grads = self.grad_norms[bin_idx]
                losses = self.loss_values[bin_idx]
                stats[f"t_{t_start}-{t_end}"] = {
                    "mean_grad_norm": sum(grads) / len(grads),
                    "mean_loss": sum(losses) / len(losses),
                    "num_samples": len(losses),
                }
        return stats

    def print_report(self):
        """Print formatted gradient report."""
        stats = self.get_statistics()
        print("\nGradient Analysis Report")
        print("=" * 60)
        print(f"{'Timestep Range':<20} {'Grad Norm':<15} {'Loss':<15}")
        print("-" * 60)
        # get_statistics already returns bins in timestep order; sorting the
        # string keys would wrongly put "t_100-150" before "t_50-100".
        for key, values in stats.items():
            print(f"{key:<20} {values['mean_grad_norm']:<15.4f} {values['mean_loss']:<15.4f}")

    def clear(self):
        """Reset statistics."""
        self.grad_norms.clear()
        self.loss_values.clear()
```

Numerical Stability Considerations

Diffusion training involves several operations prone to numerical issues. Let's examine each and discuss solutions.

Schedule Value Computation

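Computing $\bar{\alpha}_t$ as a cumulative product produces very small values near $t = T$:

$$\bar{\alpha}_T = \prod_{s=1}^{T} (1 - \beta_s) \approx 10^{-4} \text{ to } 10^{-6}$$

For the linear schedule with $T = 1000$ this is about $4 \times 10^{-5}$. Such values are representable in float32, but they fall below the smallest normal float16 number (about $6.1 \times 10^{-5}$), and quantities derived from them, such as $1/\sqrt{\bar{\alpha}_t}$ and the log SNR, are sensitive to rounding. Computing the schedule in log space keeps these quantities well-behaved: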

```python
import torch


def compute_schedule_stable(betas: torch.Tensor) -> dict[str, torch.Tensor]:
    """
    Compute noise schedule quantities with attention to numerical precision.

    Works in float64 and in log space: log1p/expm1 avoid cancellation when
    beta_t or 1 - alpha_bar_t is tiny, and log SNR stays finite where the SNR
    itself would overflow or lose precision. Results are returned as float32.
    """
    betas = betas.to(torch.float64)
    alphas = 1.0 - betas

    # log(alpha_t) = log(1 - beta_t), computed accurately for small beta_t
    log_alphas = torch.log1p(-betas)
    log_alpha_bar = torch.cumsum(log_alphas, dim=0)
    alpha_bar = torch.exp(log_alpha_bar).clamp(min=1e-10)  # avoid exact zeros

    # 1 - alpha_bar_t = -expm1(log alpha_bar_t), accurate when alpha_bar_t is close to 1
    one_minus_alpha_bar = -torch.expm1(log_alpha_bar)

    # Derived quantities
    sqrt_alpha_bar = torch.sqrt(alpha_bar)
    sqrt_one_minus_alpha_bar = torch.sqrt(one_minus_alpha_bar)

    # SNR = alpha_bar / (1 - alpha_bar), kept in log space
    log_snr = log_alpha_bar - torch.log(one_minus_alpha_bar)
    snr = torch.exp(log_snr.clamp(max=20))  # clamp so exp() cannot overflow

    out = {
        "betas": betas,
        "alphas": alphas,
        "alpha_bar": alpha_bar,
        "sqrt_alpha_bar": sqrt_alpha_bar,
        "sqrt_one_minus_alpha_bar": sqrt_one_minus_alpha_bar,
        "log_snr": log_snr,
        "snr": snr,
    }
    return {k: v.to(torch.float32) for k, v in out.items()}
```
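Precision also matters at the low-noise end of the schedule. The sketch below uses NumPy purely for illustration, with the linear DDPM schedule as an assumed example. In float16, the first $\alpha_t = 1 - 10^{-4}$ is closer to 1.0 than to the next representable value below it, so it rounds to exactly 1.0 and the noise level of the first timestep disappears. This is one reason to build schedules in float64 (or at least float32) and cast only the final tensors if needed.

```python
import numpy as np

T = 1000
betas64 = np.linspace(1e-4, 0.02, T)            # float64 by default
alphas16 = (1.0 - betas64).astype(np.float16)    # the same alphas rounded to float16

# Noise level of the first timestep: sigma_1 = sqrt(1 - alpha_bar_1) = sqrt(beta_1)
sigma1_64 = np.sqrt(1.0 - (1.0 - betas64[0]))
sigma1_16 = np.sqrt(np.float16(1.0) - alphas16[0])

print(f"float64 sigma_1 = {sigma1_64:.4f}")          # about 0.0100
print(f"float16 sigma_1 = {float(sigma1_16):.4f}")   # 0.0: alpha_1 rounded to 1.0
print("alphas rounded to exactly 1.0 in float16:", int((alphas16 == 1.0).sum()))
```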

Loss Computation Stability

The MSE loss itself is generally stable, but per-sample losses can have high variance. Use these techniques:

```python
from typing import Optional

import torch
import torch.nn.functional as F


def stable_mse_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    reduction: str = "mean",
    clip_max: Optional[float] = 100.0,
) -> torch.Tensor:
    """
    MSE loss with optional clipping of the per-element squared error.

    Large prediction errors can destabilize training. Clamping caps their
    contribution to the loss; note that clamped elements receive zero
    gradient, so they are ignored rather than down-weighted. Pass
    clip_max=None to disable clipping.
    """
    # Per-element squared error
    squared_error = (pred - target) ** 2

    # Optional: clip extreme values
    if clip_max is not None:
        squared_error = squared_error.clamp(max=clip_max)

    # Reduce
    if reduction == "mean":
        return squared_error.mean()
    elif reduction == "sum":
        return squared_error.sum()
    elif reduction == "none":
        return squared_error
    else:
        raise ValueError(f"Unknown reduction: {reduction}")


def huber_diffusion_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    delta: float = 1.0,
) -> torch.Tensor:
    """
    Huber loss variant for diffusion, robust to outliers.

    Quadratic (0.5 * error**2) for |error| <= delta and linear beyond it, so
    the per-element gradient magnitude never exceeds delta. For small errors
    this is half the squared error, so its scale differs from MSE by 2x.
    """
    return F.huber_loss(pred, target, delta=delta, reduction="mean")
```
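The two options treat outliers differently, and a tiny gradient check makes that concrete. Clamping the squared error passes no gradient through clamped elements, while Huber loss keeps a gradient whose per-element magnitude is capped at `delta` (before averaging). The two-element inputs below are arbitrary illustrative values.

```python
import torch
import torch.nn.functional as F

target = torch.zeros(2)

# Clamped squared error: the outlier (error 20, squared 400 > 100) gets clamped.
pred_clamp = torch.tensor([0.5, 20.0], requires_grad=True)
((pred_clamp - target) ** 2).clamp(max=100.0).mean().backward()

# Huber loss with delta = 1: quadratic near zero, linear beyond delta.
pred_huber = torch.tensor([0.5, 20.0], requires_grad=True)
F.huber_loss(pred_huber, target, delta=1.0, reduction="mean").backward()

print(pred_clamp.grad.tolist())  # [0.5, 0.0]: the clamped outlier gets no gradient
print(pred_huber.grad.tolist())  # [0.25, 0.5]: the outlier's gradient is capped at delta / N
```

If outliers come from a few corrupted samples, ignoring them may be what you want; if they are legitimate but rare hard cases, Huber's bounded gradient keeps them in training.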

Gradient Clipping Strategies

Gradient clipping is essential for stable diffusion training:

| Method | When to use | Typical value |
|---|---|---|
| Global norm clip | Default choice | 1.0 to 5.0 |
| Per-parameter clip | Mixed precision | 1.0 |
| Adaptive (AdaGrad-style) | Varying gradient scales | N/A |
```python
import math

import torch


def clip_gradients(
    model: torch.nn.Module,
    max_norm: float = 1.0,
    norm_type: float = 2.0,
    log_overflow: bool = True,
) -> dict:
    """
    Clip gradients and return diagnostics.

    Args:
        model: Model with computed gradients
        max_norm: Maximum gradient norm
        norm_type: Type of norm (2.0 = L2, float("inf") = max absolute value)
        log_overflow: Whether to record the clip ratio when clipping occurs

    Returns:
        Dict with gradient statistics as plain Python floats/bools
    """
    # clip_grad_norm_ returns the total norm *before* clipping
    total_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(),
        max_norm=max_norm,
        norm_type=norm_type,
    )
    grad_norm = float(total_norm)

    stats = {
        "grad_norm": grad_norm,
        "clipped": grad_norm > max_norm,
        # A NaN/Inf norm signals overflow; clipping cannot repair it
        "finite": math.isfinite(grad_norm),
    }

    if log_overflow and stats["clipped"] and stats["finite"]:
        stats["clip_ratio"] = max_norm / grad_norm

    return stats
```
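The "global norm clip" rule is simple enough to write out by hand. The sketch below is an illustrative re-implementation (the helper name `manual_clip_grad_norm` is hypothetical) of the scaling that `torch.nn.utils.clip_grad_norm_` applies: compute one L2 norm over every gradient, and if it exceeds `max_norm`, multiply all gradients by `max_norm / norm`. The toy check compares it with the built-in on two identical models.

```python
import torch


def manual_clip_grad_norm(params, max_norm: float, eps: float = 1e-6) -> float:
    """Global L2-norm clipping written out by hand; returns the pre-clipping norm."""
    grads = [p.grad for p in params if p.grad is not None]
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g) for g in grads])
    )
    scale = max_norm / (total_norm + eps)
    if scale < 1.0:
        for g in grads:
            g.mul_(scale)  # one shared factor, so the overall gradient direction is kept
    return total_norm.item()


torch.manual_seed(0)
a = torch.nn.Linear(8, 8)
b = torch.nn.Linear(8, 8)
b.load_state_dict(a.state_dict())  # identical weights, hence identical gradients
x = torch.randn(16, 8)
for m in (a, b):
    (m(x) ** 2).sum().mul(100.0).backward()  # deliberately large loss and gradients

norm_manual = manual_clip_grad_norm(list(a.parameters()), max_norm=1.0)
norm_builtin = torch.nn.utils.clip_grad_norm_(b.parameters(), max_norm=1.0).item()
print(norm_manual, norm_builtin)
```

Because every gradient is scaled by the same factor, global clipping limits the step size without changing the relative update between layers, which is why it is the usual default.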

Debugging Training

When diffusion training fails or produces poor results, systematic debugging is essential. Here are common issues and diagnostic approaches.

Common Training Failures

| Symptom | Likely cause | Diagnostic |
|---|---|---|
| Loss explodes | Learning rate too high, gradient issues | Monitor grad norms per timestep |
| Loss stuck high | Time embedding broken, architecture issue | Check loss per timestep |
| Good loss, bad samples | Sampling bug, schedule mismatch | Verify sampling matches training |
| Blurry samples | Undertrained, too few steps | Train longer, check per-t loss |
| Artifacts at t ≈ T | High-noise prediction failing | Increase capacity for large t |
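For the "Loss stuck high" row, a quick check on whether the time embedding is wired up at all is to feed the same $\mathbf{x}_t$ with two different timesteps and measure how much the output moves. The helper `time_sensitivity` and the two toy models below are hypothetical illustrations; they assume the model is called as `model(x_t, t)`, like the training code in this section.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def time_sensitivity(model: nn.Module, x_t: torch.Tensor, T: int) -> float:
    """Relative change in model output when only the timestep changes (0 means t is ignored)."""
    b = x_t.shape[0]
    out_low = model(x_t, torch.zeros(b, dtype=torch.long, device=x_t.device))
    out_high = model(x_t, torch.full((b,), T - 1, dtype=torch.long, device=x_t.device))
    return ((out_high - out_low).norm() / (out_low.norm() + 1e-8)).item()


class IgnoresTime(nn.Module):
    """Toy model with a conditioning bug: t is accepted but never used."""

    def __init__(self):
        super().__init__()
        self.net = nn.Linear(16, 16)

    def forward(self, x, t):
        return self.net(x)


class UsesTime(IgnoresTime):
    """Toy model that adds a scalar function of t to every output."""

    def forward(self, x, t):
        return self.net(x) + (t.float() / 1000.0).unsqueeze(-1)


torch.manual_seed(0)
x = torch.randn(8, 16)
ignores, uses = IgnoresTime(), UsesTime()
print(time_sensitivity(ignores, x, T=1000))  # 0.0: output does not depend on t
print(time_sensitivity(uses, x, T=1000))     # clearly above 0
```

A value of exactly zero points to a wiring bug; a small but nonzero value can also be normal early in training, so it is most useful when tracked over time.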

Diagnostic Training Loop

```python
import torch
import torch.nn as nn
from dataclasses import dataclass, field


@dataclass
class TrainingDiagnostics:
    """Container for training diagnostics."""
    step: int = 0
    loss_history: list[float] = field(default_factory=list)
    per_timestep_loss: dict[int, list[float]] = field(default_factory=dict)
    grad_norms: list[float] = field(default_factory=list)
    clip_events: int = 0

    def add_timestep_loss(self, t: int, loss: float):
        self.per_timestep_loss.setdefault(t, []).append(loss)


def diagnostic_training_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    x_0: torch.Tensor,
    noise_schedule: dict,
    diagnostics: TrainingDiagnostics,
    grad_clip: float = 1.0,
) -> dict:
    """
    Training step with comprehensive diagnostics.

    Returns detailed information about the training dynamics.
    """
    batch_size = x_0.shape[0]
    device = x_0.device
    alpha_bar = noise_schedule["alpha_bar"].to(device)
    T = len(alpha_bar)

    # Sample timesteps uniformly from [0, T)
    t = torch.randint(0, T, (batch_size,), device=device)

    # Sample noise
    epsilon = torch.randn_like(x_0)

    # Create noisy input; reshape alpha_bar_t to broadcast over non-batch dims
    alpha_bar_t = alpha_bar[t].view(-1, *([1] * (x_0.dim() - 1)))
    x_t = torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    # Forward pass
    epsilon_pred = model(x_t, t)

    # Per-sample loss (mean over all non-batch dimensions)
    per_sample_loss = ((epsilon_pred - epsilon) ** 2).flatten(1).mean(dim=1)

    # Record per-timestep losses
    for ti, loss_i in zip(t.tolist(), per_sample_loss.detach().tolist()):
        diagnostics.add_timestep_loss(ti, loss_i)

    # Total loss
    loss = per_sample_loss.mean()

    # Backward pass
    optimizer.zero_grad()
    loss.backward()

    # Gradient clipping; clip_grad_norm_ returns the norm *before* clipping
    grad_norm_before = float(torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip))
    clipped = grad_norm_before > grad_clip
    if clipped:
        diagnostics.clip_events += 1

    # Optimizer step
    optimizer.step()

    # Update diagnostics
    diagnostics.step += 1
    diagnostics.loss_history.append(loss.item())
    diagnostics.grad_norms.append(grad_norm_before)

    return {
        "loss": loss.item(),
        "grad_norm": grad_norm_before,
        "clipped": clipped,
        "timesteps_sampled": t.tolist(),
        "per_sample_loss_mean": per_sample_loss.mean().item(),
        "per_sample_loss_std": per_sample_loss.std().item() if batch_size > 1 else 0.0,
    }


def analyze_diagnostics(diagnostics: TrainingDiagnostics, T: int, num_bins: int = 10) -> dict:
    """
    Analyze collected diagnostics to identify issues.

    Returns a report with potential problems and recommendations.
    The thresholds are heuristics, not established rules.
    """
    report = {"issues": [], "recommendations": []}
    history = diagnostics.loss_history

    # Check loss trend: compare the first and last 100 steps
    if len(history) > 100:
        recent = history[-100:]
        early = history[:100]
        if sum(recent) / len(recent) > sum(early) / len(early):
            report["issues"].append("Loss increasing over time")
            report["recommendations"].append("Reduce learning rate")

    # Check gradient clipping frequency
    if diagnostics.step > 0:
        clip_rate = diagnostics.clip_events / diagnostics.step
        if clip_rate > 0.5:
            report["issues"].append(f"High gradient clipping rate: {clip_rate:.2%}")
            report["recommendations"].append("Reduce learning rate or increase clip threshold")

    # Check spread of loss across timestep bins (bins of width ~T / num_bins)
    bin_losses: dict[int, list[float]] = {}
    for t, losses in diagnostics.per_timestep_loss.items():
        b = min(t * num_bins // T, num_bins - 1)
        bin_losses.setdefault(b, []).extend(losses)
    bin_means = [sum(v) / len(v) for v in bin_losses.values() if len(v) >= 10]
    if len(bin_means) >= 2 and max(bin_means) > 5 * min(bin_means):
        # With epsilon-prediction some spread is expected (loss falls toward zero
        # at large t), so treat this as a prompt to inspect the per-timestep curve.
        report["issues"].append("Large variance in per-timestep loss")
        report["recommendations"].append("Consider SNR-based weighting")

    # Check for stuck loss: near-zero variance over the last 100 steps
    if len(history) > 500:
        recent = history[-100:]
        mean_recent = sum(recent) / len(recent)
        variance = sum((x - mean_recent) ** 2 for x in recent) / len(recent)
        if variance < 1e-6:
            report["issues"].append("Loss appears stuck")
            report["recommendations"].append("Check time conditioning, increase learning rate")

    return report
```
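A uniform draw with `torch.randint` can leave whole regions of the timestep range empty in a given batch, which makes per-timestep statistics noisy. Stratified sampling is a simple alternative: each batch element gets its own equal-width slice of $[0, T)$ and a uniform draw inside it. The helper below is a hypothetical sketch of that idea, not part of any library.

```python
import torch


def sample_timesteps_stratified(batch_size: int, T: int, device=None) -> torch.Tensor:
    """Draw one timestep from each of batch_size equal-width strata of [0, T)."""
    # Stratum i covers [i * T / B, (i + 1) * T / B); jitter uniformly inside it.
    offsets = torch.rand(batch_size, device=device)
    t = ((torch.arange(batch_size, device=device) + offsets) * T / batch_size).long()
    t = t.clamp(max=T - 1)  # guard against floating-point round-up to T
    # Shuffle so the timestep is not correlated with position in the batch.
    return t[torch.randperm(batch_size, device=device)]


t = sample_timesteps_stratified(8, 1000)
print(sorted(t.tolist()))  # one value in each of [0, 125), [125, 250), ..., [875, 1000)
```

Each timestep remains equally likely marginally, so the loss stays an unbiased estimate of the uniform-$t$ objective; only the batch-to-batch variance changes.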

Implementation

Let's put together a complete training script with all stability and diagnostic features:

```python
import logging
import math
from dataclasses import dataclass, field
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class StableTrainingConfig:
    """Configuration for numerically stable diffusion training."""
    # Optimization
    learning_rate: float = 1e-4
    weight_decay: float = 0.0
    grad_clip: float = 1.0
    warmup_steps: int = 1000

    # Stability
    use_ema: bool = True
    ema_decay: float = 0.9999
    mixed_precision: bool = False  # not used by this minimal trainer
    loss_clip: Optional[float] = 100.0

    # Monitoring
    log_interval: int = 100
    eval_interval: int = 1000
    checkpoint_interval: int = 10000


@dataclass
class TrainingDiagnostics:
    """Minimal diagnostics container (a subset of the one in the debugging section)."""
    step: int = 0
    loss_history: list[float] = field(default_factory=list)
    grad_norms: list[float] = field(default_factory=list)


class EMAModel:
    """Exponential Moving Average of model parameters."""

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        self.shadow = {
            name: p.detach().clone()
            for name, p in model.named_parameters() if p.requires_grad
        }
        self.backup = {}

    @torch.no_grad()
    def update(self, model: nn.Module):
        for name, p in model.named_parameters():
            if p.requires_grad:
                # shadow = decay * shadow + (1 - decay) * param, in place
                self.shadow[name].lerp_(p.detach(), 1.0 - self.decay)

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.backup[name] = p.detach().clone()
                p.copy_(self.shadow[name])  # copy values; do not alias the shadow tensor

    @torch.no_grad()
    def restore(self, model: nn.Module):
        for name, p in model.named_parameters():
            if p.requires_grad:
                p.copy_(self.backup[name])
        self.backup = {}


class StableDiffusionTrainer:
    """
    Numerically stable diffusion model trainer.

    Implements common practices for stable training:
    - Gradient clipping
    - EMA
    - Learning rate warmup
    - Loss and gradient-norm logging; per-timestep loss during evaluation
    """

    def __init__(self, model: nn.Module, noise_schedule: dict, config: StableTrainingConfig):
        self.model = model
        self.config = config
        self.device = next(model.parameters()).device

        # Store schedule
        self.alpha_bar = noise_schedule["alpha_bar"].to(self.device)
        self.T = len(self.alpha_bar)

        # Optimizer (warmup is applied manually in train_step)
        self.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay,
        )

        # EMA
        self.ema = EMAModel(model, config.ema_decay) if config.use_ema else None

        # Diagnostics
        self.diagnostics = TrainingDiagnostics()
        self.step = 0

    def get_lr_scale(self) -> float:
        """Linear warmup; uses step + 1 so the very first update is not at lr = 0."""
        if self.config.warmup_steps > 0 and self.step < self.config.warmup_steps:
            return (self.step + 1) / self.config.warmup_steps
        return 1.0

    def _noisy_input(self, x_0: torch.Tensor, t: torch.Tensor, epsilon: torch.Tensor) -> torch.Tensor:
        """x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * epsilon."""
        alpha_bar_t = self.alpha_bar[t].view(-1, *([1] * (x_0.dim() - 1)))
        return torch.sqrt(alpha_bar_t) * x_0 + torch.sqrt(1 - alpha_bar_t) * epsilon

    def train_step(self, x_0: torch.Tensor) -> dict:
        """Execute a single training step with the stability measures above."""
        self.model.train()
        x_0 = x_0.to(self.device)
        batch_size = x_0.shape[0]

        # Learning rate warmup
        lr = self.config.learning_rate * self.get_lr_scale()
        for pg in self.optimizer.param_groups:
            pg["lr"] = lr

        # Sample timesteps and noise, build the noisy input
        t = torch.randint(0, self.T, (batch_size,), device=self.device)
        epsilon = torch.randn_like(x_0)
        x_t = self._noisy_input(x_0, t, epsilon)

        # Forward pass and per-sample loss
        epsilon_pred = self.model(x_t, t)
        per_sample_loss = ((epsilon_pred - epsilon) ** 2).flatten(1).mean(1)

        # Optional loss clipping; clamped samples contribute no gradient
        if self.config.loss_clip is not None:
            per_sample_loss = per_sample_loss.clamp(max=self.config.loss_clip)
        loss = per_sample_loss.mean()

        # Backward
        self.optimizer.zero_grad()
        loss.backward()

        # Gradient clipping (returns the pre-clipping norm)
        grad_norm = float(
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.grad_clip)
        )

        # Optimizer step and EMA update
        self.optimizer.step()
        if self.ema is not None:
            self.ema.update(self.model)

        # Record diagnostics
        self.step += 1
        self.diagnostics.step = self.step
        self.diagnostics.loss_history.append(loss.item())
        self.diagnostics.grad_norms.append(grad_norm)

        return {"loss": loss.item(), "grad_norm": grad_norm, "lr": lr, "step": self.step}

    @torch.no_grad()
    def evaluate(self, dataloader: DataLoader) -> dict:
        """Evaluate with EMA weights; assumes the dataloader yields x_0 tensors."""
        self.model.eval()
        if self.ema is not None:
            self.ema.apply_shadow(self.model)

        total_loss, num_samples = 0.0, 0
        per_timestep_loss = torch.zeros(self.T, device=self.device)
        per_timestep_count = torch.zeros(self.T, device=self.device)

        try:
            for x_0 in dataloader:
                x_0 = x_0.to(self.device)
                batch_size = x_0.shape[0]

                # Random timesteps; per-timestep coverage improves with more batches
                t = torch.randint(0, self.T, (batch_size,), device=self.device)
                epsilon = torch.randn_like(x_0)
                x_t = self._noisy_input(x_0, t, epsilon)

                epsilon_pred = self.model(x_t, t)
                per_sample_loss = ((epsilon_pred - epsilon) ** 2).flatten(1).mean(1)

                # Vectorized per-timestep accumulation
                per_timestep_loss.index_add_(0, t, per_sample_loss)
                per_timestep_count.index_add_(0, t, torch.ones_like(per_sample_loss))

                total_loss += per_sample_loss.sum().item()
                num_samples += batch_size
        finally:
            # Restore the training weights even if evaluation fails
            if self.ema is not None:
                self.ema.restore(self.model)

        # Per-timestep means (timesteps never sampled stay at 0)
        mask = per_timestep_count > 0
        per_timestep_mean = torch.zeros_like(per_timestep_loss)
        per_timestep_mean[mask] = per_timestep_loss[mask] / per_timestep_count[mask]

        return {
            "eval_loss": total_loss / max(num_samples, 1),
            "per_timestep_loss": per_timestep_mean.cpu().tolist(),
        }


class TinyDenoiser(nn.Module):
    """Placeholder MLP epsilon-predictor for flattened 28x28 inputs (illustrative only)."""

    def __init__(self, data_dim: int = 784, time_dim: int = 128, hidden: int = 512):
        super().__init__()
        self.time_dim = time_dim
        self.net = nn.Sequential(
            nn.Linear(data_dim + time_dim, hidden),
            nn.SiLU(),
            nn.Linear(hidden, data_dim),
        )

    def time_embedding(self, t: torch.Tensor) -> torch.Tensor:
        # Sinusoidal embedding of the integer timestep
        half = self.time_dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device) / half)
        args = t.float()[:, None] * freqs[None, :]
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    def forward(self, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([x_t, self.time_embedding(t)], dim=-1))


# Example usage
def train_diffusion_model(num_demo_steps: int = 5):
    """Training example with stability features, smoke-tested on random data."""
    # Setup: linear schedule, built in float64 and cast to float32
    T = 1000
    betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
    alpha_bar = torch.cumprod(1 - betas, dim=0).float()
    noise_schedule = {"alpha_bar": alpha_bar}

    config = StableTrainingConfig(
        learning_rate=1e-4,
        grad_clip=1.0,
        warmup_steps=1000,
        use_ema=True,
    )

    model = TinyDenoiser()
    trainer = StableDiffusionTrainer(model, noise_schedule, config)

    logger.info("Starting stable diffusion training")
    logger.info(f"Config: {config}")

    # With a real dataset, the loop would look like:
    # for epoch in range(num_epochs):
    #     for batch in dataloader:
    #         metrics = trainer.train_step(batch)
    #         if trainer.step % config.log_interval == 0:
    #             logger.info(f"Step {trainer.step}: {metrics}")
    metrics = None
    for _ in range(num_demo_steps):
        metrics = trainer.train_step(torch.randn(16, 784))
    logger.info(f"Step {trainer.step}: {metrics}")

    return trainer


if __name__ == "__main__":
    train_diffusion_model()
```

Key Takeaways

  1. Loss varies by timestep: With $\boldsymbol{\epsilon}$-prediction, the converged loss is high at small $t$ (low noise) and falls toward zero at large $t$ (high noise), where $\mathbf{x}_t$ is nearly pure noise
  2. Monitor gradients per-timestep: Gradient imbalance across timesteps is a common source of training issues
  3. Use log-space for schedule: Compute $\bar{\alpha}_t$ in log space to avoid precision loss in very small products and in derived quantities such as the SNR
  4. Gradient clipping is essential: Use global norm clipping with a threshold around 1.0
  5. Diagnostic tools accelerate debugging: Systematic monitoring of per-timestep loss and gradients identifies issues quickly
  6. EMA improves sample quality: Exponential moving average of weights typically produces better samples
Chapter Summary: We've now covered the complete theory of diffusion training objectives - from the simplified loss and weighting strategies to the deep connections with denoising autoencoders and practical numerical considerations. With this foundation, you're ready to implement and debug diffusion models effectively.