Chapter 11: Training the Model

Training Monitoring

Learning Objectives

By the end of this section, you will be able to:

  1. Track training loss curves and identify healthy vs problematic training dynamics
  2. Compute FID during training to monitor generation quality
  3. Visualize generated samples at different training stages
  4. Build comprehensive monitoring dashboards using TensorBoard and Weights & Biases

The Big Picture

Unlike classification models where validation accuracy directly measures performance, diffusion models require more nuanced monitoring. The training loss measures how well the model predicts noise, but this doesn't directly tell you about generation quality.

The Monitoring Challenge: A model can have low training loss but still produce poor-quality images. Conversely, a slightly higher loss might produce visually superior results. You need multiple metrics to truly understand training progress.

Interactive Training Dashboard

[Interactive dashboard: a training-loss curve and an FID curve (lower is better) plotted over 100,000 training steps, alongside a grid of generated samples at the current step.]

Note: This visualization simulates typical diffusion model training dynamics. The loss curve shows the MSE/L1 loss between predicted and actual noise, while FID measures generation quality against real samples.


Loss Monitoring

Understanding the Loss Curve

The diffusion training loss is the MSE between predicted and actual noise:

```python
import numpy as np
from collections import deque

class LossMonitor:
    """Monitor training loss with smoothing and anomaly detection."""

    def __init__(self, window_size: int = 100):
        self.losses = []
        self.window = deque(maxlen=window_size)
        self.window_size = window_size

    def update(self, loss: float) -> dict:
        """Record a loss value and return statistics."""
        self.losses.append(loss)
        self.window.append(loss)

        stats = {
            "loss": loss,
            "loss_smooth": np.mean(self.window),
            "loss_std": np.std(self.window) if len(self.window) > 1 else 0,
        }

        # Detect anomalies once the window is full
        if len(self.losses) > self.window_size:
            recent_mean = np.mean(self.window)
            recent_std = np.std(self.window)
            if loss > recent_mean + 3 * recent_std:
                stats["anomaly"] = "spike"
            elif loss < recent_mean - 3 * recent_std:
                stats["anomaly"] = "drop"

        return stats

    def is_converging(self, patience: int = 1000) -> bool:
        """Check if loss is still decreasing."""
        if len(self.losses) < patience * 2:
            return True  # Not enough data

        recent = np.mean(self.losses[-patience:])
        earlier = np.mean(self.losses[-2*patience:-patience])
        return recent < earlier * 0.99  # 1% improvement threshold


# Usage in training loop
monitor = LossMonitor()

for epoch in range(num_epochs):
    for batch in dataloader:
        loss = train_step(batch)
        stats = monitor.update(loss.item())

        if "anomaly" in stats:
            print(f"Warning: Loss {stats['anomaly']} detected!")

    if not monitor.is_converging():
        print("Training may have converged - consider stopping")
```

What to Look For

| Pattern | What It Means | Action |
|---|---|---|
| Steady decrease | Healthy training | Continue training |
| Plateauing early | Model too small or LR too low | Increase capacity or LR |
| Oscillating wildly | LR too high | Reduce learning rate |
| Sudden spike | Bad batch or NaN | Check gradients, reduce LR |
| Gradual increase | Overfitting or schedule issue | Add regularization |
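For the "sudden spike" row, a common first response (alongside inspecting the data) is to skip batches that produce non-finite losses and to clip gradients so a single bad step cannot blow up the weights. A sketch of such a guarded update; the tiny linear model is an illustrative stand-in for the diffusion U-Net:

```python
import torch
import torch.nn as nn

model = nn.Linear(16, 16)  # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
MAX_GRAD_NORM = 1.0

def guarded_step(x, target):
    """Backprop with non-finite-loss skipping and gradient clipping."""
    loss = torch.nn.functional.mse_loss(model(x), target)
    if not torch.isfinite(loss):
        return None  # skip the bad batch entirely

    optimizer.zero_grad()
    loss.backward()
    # Bound the update size so one spike cannot derail training
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
    optimizer.step()
    return loss.item()

ok = guarded_step(torch.randn(8, 16), torch.randn(8, 16))      # normal batch
bad = guarded_step(torch.full((8, 16), float("nan")),
                   torch.randn(8, 16))                         # NaN batch, skipped
```

Pairing this with the `LossMonitor` anomaly flags gives you both detection and containment.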

FID During Training

FID (Fréchet Inception Distance) measures the distance between the feature distributions of generated and real images. Computing it during training provides direct feedback on generation quality:

```python
import torch
import numpy as np
from scipy import linalg
from torchvision.models import inception_v3, Inception_V3_Weights

class FIDCalculator:
    """Calculate FID between generated and real images."""

    def __init__(self, device: str = "cuda"):
        self.device = device
        # Load Inception v3 for feature extraction
        # (the `weights` argument replaces the deprecated `pretrained=True`)
        self.inception = inception_v3(
            weights=Inception_V3_Weights.DEFAULT, transform_input=False
        )
        self.inception.fc = torch.nn.Identity()  # Remove final layer
        self.inception = self.inception.to(device).eval()

    @torch.no_grad()
    def get_features(self, images: torch.Tensor) -> np.ndarray:
        """Extract Inception features from images."""
        # Images should be [B, 3, 299, 299] in range [0, 1]
        # Resize if needed
        if images.shape[-1] != 299:
            images = torch.nn.functional.interpolate(
                images, size=(299, 299), mode='bilinear'
            )

        # Normalize for Inception
        images = (images - 0.5) / 0.5

        features = self.inception(images.to(self.device))
        return features.cpu().numpy()

    def calculate_statistics(self, features: np.ndarray) -> tuple:
        """Calculate mean and covariance of features."""
        mu = np.mean(features, axis=0)
        sigma = np.cov(features, rowvar=False)
        return mu, sigma

    def calculate_fid(
        self,
        real_features: np.ndarray,
        fake_features: np.ndarray,
    ) -> float:
        """Calculate FID from features."""
        mu1, sigma1 = self.calculate_statistics(real_features)
        mu2, sigma2 = self.calculate_statistics(fake_features)

        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)

        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
        return float(fid)


def compute_fid_during_training(
    model,
    diffusion,
    real_loader,
    fid_calc: FIDCalculator,
    num_samples: int = 1000,
) -> float:
    """Compute FID during training."""
    model.eval()

    # Collect real features (can cache these)
    real_features = []
    for batch in real_loader:
        if len(real_features) * batch.shape[0] >= num_samples:
            break
        images = (batch + 1) / 2  # [-1, 1] -> [0, 1]
        real_features.append(fid_calc.get_features(images))
    real_features = np.concatenate(real_features)[:num_samples]

    # Generate samples and collect features
    fake_features = []
    batch_size = 32
    while len(fake_features) * batch_size < num_samples:
        samples = generate_samples(model, diffusion, batch_size)
        samples = (samples + 1) / 2  # [-1, 1] -> [0, 1]
        fake_features.append(fid_calc.get_features(samples))
    fake_features = np.concatenate(fake_features)[:num_samples]

    model.train()
    return fid_calc.calculate_fid(real_features, fake_features)
```

FID Computation Cost

Computing FID requires generating many samples (1000-50000), which is slow. Compute FID every few epochs, not every step. A common pattern is to compute FID every 5-10 epochs.
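Since the real dataset does not change during training, the real-side statistics (μ, Σ) can be computed once at startup and reused at every FID evaluation, as the "can cache these" comment above suggests. A sketch of such a cache; the helper name and file layout are illustrative assumptions, not the chapter's API:

```python
import tempfile
import numpy as np
from pathlib import Path

def cached_real_stats(features: np.ndarray, cache_path) -> tuple:
    """Compute (mu, sigma) for the real features once; reload on later calls."""
    path = Path(cache_path)
    if path.exists():
        data = np.load(path)
        return data["mu"], data["sigma"]
    mu = features.mean(axis=0)
    sigma = np.cov(features, rowvar=False)
    np.savez(str(path), mu=mu, sigma=sigma)
    return mu, sigma

rng = np.random.default_rng(0)
feats = rng.normal(size=(1000, 16))  # stand-in for Inception features
cache = Path(tempfile.mkdtemp()) / "real_stats.npz"

mu1, sigma1 = cached_real_stats(feats, cache)  # computes and writes the cache
mu2, sigma2 = cached_real_stats(feats, cache)  # loads from disk
```

With cached statistics, each evaluation only needs to extract features for the freshly generated samples and compute the FID formula directly from the two (μ, Σ) pairs.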

Sample Visualization

```python
import torch
import torchvision.utils as vutils
from pathlib import Path

class SampleVisualizer:
    """Visualize and save generated samples during training."""

    def __init__(self, save_dir: str, num_samples: int = 16):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.num_samples = num_samples
        self.fixed_noise = None  # For consistent comparison

    @torch.no_grad()
    def generate_and_save(
        self,
        model,
        diffusion,
        epoch: int,
        device: str = "cuda",
    ):
        """Generate samples and save as image grid."""
        model.eval()

        # Use fixed noise for consistent comparison across epochs
        if self.fixed_noise is None:
            self.fixed_noise = torch.randn(
                self.num_samples, 3, 64, 64, device=device
            )

        # Generate samples
        samples = self._sample(model, diffusion, self.fixed_noise.clone())

        # Create grid
        grid = vutils.make_grid(
            samples,
            nrow=int(self.num_samples ** 0.5),
            normalize=True,
            value_range=(-1, 1),
        )

        # Save
        save_path = self.save_dir / f"samples_epoch_{epoch:04d}.png"
        vutils.save_image(grid, str(save_path))

        model.train()
        return grid

    def _sample(self, model, diffusion, x):
        """DDPM sampling loop."""
        for t in reversed(range(diffusion.timesteps)):
            t_batch = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
            pred_noise = model(x, t_batch)

            alpha = diffusion.alphas[t]
            alpha_bar = diffusion.alphas_cumprod[t]
            beta = diffusion.betas[t]

            noise = torch.randn_like(x) if t > 0 else 0
            x = (1 / torch.sqrt(alpha)) * (
                x - beta / torch.sqrt(1 - alpha_bar) * pred_noise
            ) + torch.sqrt(beta) * noise

        return x.clamp(-1, 1)

    def create_progress_gif(self, epochs: list):
        """Create animated GIF showing training progress."""
        import imageio.v2 as imageio
        images = []
        for epoch in epochs:
            path = self.save_dir / f"samples_epoch_{epoch:04d}.png"
            if path.exists():
                images.append(imageio.imread(str(path)))

        if images:
            output_path = self.save_dir / "training_progress.gif"
            # 0.5 s per frame; `duration` is more portable across
            # imageio versions than the removed `fps` keyword
            imageio.mimsave(str(output_path), images, duration=0.5)
            print(f"Saved progress GIF to {output_path}")
```

Monitoring Dashboard

Putting it all together into a comprehensive monitoring system:

```python
import wandb
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

class TrainingMonitor:
    """Comprehensive training monitor combining all metrics."""

    def __init__(
        self,
        log_dir: str,
        use_wandb: bool = True,
        project_name: str = "diffusion",
    ):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Initialize loggers
        self.writer = SummaryWriter(log_dir=str(self.log_dir / "tensorboard"))

        if use_wandb:
            wandb.init(project=project_name, dir=str(self.log_dir))
            self.use_wandb = True
        else:
            self.use_wandb = False

        # Initialize trackers
        self.loss_monitor = LossMonitor()
        self.fid_calculator = FIDCalculator()
        self.sample_visualizer = SampleVisualizer(str(self.log_dir / "samples"))

        self.step = 0
        self.best_fid = float('inf')

    def log_step(self, loss: float):
        """Log a single training step."""
        self.step += 1
        stats = self.loss_monitor.update(loss)

        # Log to TensorBoard
        self.writer.add_scalar("loss/train", loss, self.step)
        self.writer.add_scalar("loss/smooth", stats["loss_smooth"], self.step)

        # Log to W&B
        if self.use_wandb:
            wandb.log({"loss": loss, "loss_smooth": stats["loss_smooth"]}, step=self.step)

        return stats

    def log_epoch(
        self,
        epoch: int,
        model,
        diffusion,
        real_loader,
        compute_fid: bool = True,
    ):
        """Log epoch-level metrics."""
        # Generate and save samples
        grid = self.sample_visualizer.generate_and_save(model, diffusion, epoch)
        self.writer.add_image("samples", grid, epoch)

        if self.use_wandb:
            wandb.log({"samples": wandb.Image(grid)}, step=self.step)

        # Compute FID periodically
        if compute_fid:
            fid = compute_fid_during_training(
                model, diffusion, real_loader, self.fid_calculator
            )
            self.writer.add_scalar("fid", fid, epoch)

            if self.use_wandb:
                wandb.log({"fid": fid}, step=self.step)

            if fid < self.best_fid:
                self.best_fid = fid
                return {"fid": fid, "is_best": True}

            return {"fid": fid, "is_best": False}

        return {}

    def finish(self):
        """Clean up logging."""
        self.writer.close()
        if self.use_wandb:
            wandb.finish()


# Usage in training
monitor = TrainingMonitor("./logs/experiment1", use_wandb=True)

for epoch in range(num_epochs):
    for batch in dataloader:
        loss = train_step(batch)
        monitor.log_step(loss.item())

    # Log epoch metrics (FID every 5 epochs)
    result = monitor.log_epoch(
        epoch, model, diffusion, val_loader,
        compute_fid=(epoch % 5 == 0)
    )

    if result.get("is_best"):
        save_checkpoint(model, "best_model.pt")

monitor.finish()
```

Key Takeaways

  1. Loss alone is insufficient: Use FID and visual inspection to truly assess generation quality.
  2. Use fixed noise for comparison: Generate samples with the same initial noise across epochs to see consistent progress.
  3. Compute FID periodically: It's expensive, so compute every 5-10 epochs, not every step.
  4. Watch for anomalies: Sudden spikes or divergence in loss often indicate gradient issues or bad batches.
  5. Log everything: Save samples, checkpoints, and metrics so you can analyze training after the fact.

Looking Ahead: In the next section, we'll cover common training issues and their solutions, including mode collapse, gradient problems, and memory optimization.