Learning Objectives
By the end of this section, you will be able to:
- Track training loss curves and identify healthy vs problematic training dynamics
- Compute FID during training to monitor generation quality
- Visualize generated samples at different training stages
- Build comprehensive monitoring dashboards using TensorBoard and Weights & Biases
The Big Picture
Unlike classification models where validation accuracy directly measures performance, diffusion models require more nuanced monitoring. The training loss measures how well the model predicts noise, but this doesn't directly tell you about generation quality.
The Monitoring Challenge: A model can have low training loss but still produce poor-quality images. Conversely, a slightly higher loss might produce visually superior results. You need multiple metrics to truly understand training progress.
Interactive Training Dashboard
[Interactive visualization: training loss curve, FID score (lower is better), and generated samples at different training steps.]
Note: This visualization simulates typical diffusion model training dynamics. The loss curve shows the MSE/L1 loss between predicted and actual noise, while FID measures generation quality against real samples.
Loss Monitoring
Understanding the Loss Curve
The diffusion training loss is the MSE between predicted and actual noise:
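Before monitoring this loss, it helps to see where it comes from. The sketch below shows one plausible `train_step` (the `train_step` referenced in the usage loops later in this section is not defined there; this version is an assumption, using the `model(x_t, t)` and `diffusion.alphas_cumprod` interfaces that appear in the sampling code further down):

```python
import torch
import torch.nn.functional as F

def train_step_sketch(model, diffusion, batch):
    """One DDPM training step (sketch): add noise at a random timestep,
    then compute the MSE between the model's prediction and the true noise.
    The caller is expected to run loss.backward() and optimizer.step()."""
    t = torch.randint(0, diffusion.timesteps, (batch.shape[0],), device=batch.device)
    noise = torch.randn_like(batch)

    # Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    alpha_bar = diffusion.alphas_cumprod[t].view(-1, 1, 1, 1)
    x_t = torch.sqrt(alpha_bar) * batch + torch.sqrt(1 - alpha_bar) * noise

    pred_noise = model(x_t, t)
    return F.mse_loss(pred_noise, noise)
```

This is the quantity the monitors below track step by step.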
```python
import torch
import numpy as np
from collections import deque

class LossMonitor:
    """Monitor training loss with smoothing and anomaly detection."""

    def __init__(self, window_size: int = 100):
        self.losses = []
        self.window = deque(maxlen=window_size)
        self.window_size = window_size

    def update(self, loss: float) -> dict:
        """Record a loss value and return statistics."""
        self.losses.append(loss)
        self.window.append(loss)

        stats = {
            "loss": loss,
            "loss_smooth": np.mean(self.window),
            "loss_std": np.std(self.window) if len(self.window) > 1 else 0.0,
        }

        # Detect anomalies once a full window of history is available
        if len(self.losses) > self.window_size:
            recent_mean = np.mean(self.window)
            recent_std = np.std(self.window)
            if loss > recent_mean + 3 * recent_std:
                stats["anomaly"] = "spike"
            elif loss < recent_mean - 3 * recent_std:
                stats["anomaly"] = "drop"

        return stats

    def is_converging(self, patience: int = 1000) -> bool:
        """Check if loss is still decreasing."""
        if len(self.losses) < patience * 2:
            return True  # Not enough data

        recent = np.mean(self.losses[-patience:])
        earlier = np.mean(self.losses[-2 * patience:-patience])
        return recent < earlier * 0.99  # 1% improvement threshold


# Usage in training loop
monitor = LossMonitor()

for epoch in range(num_epochs):
    for batch in dataloader:
        loss = train_step(batch)
        stats = monitor.update(loss.item())

        if "anomaly" in stats:
            print(f"Warning: Loss {stats['anomaly']} detected!")

    if not monitor.is_converging():
        print("Training may have converged - consider stopping")
```

What to Look For
| Pattern | What It Means | Action |
|---|---|---|
| Steady decrease | Healthy training | Continue training |
| Plateauing early | Model too small or LR too low | Increase capacity or LR |
| Oscillating wildly | LR too high | Reduce learning rate |
| Sudden spike | Bad batch or NaN | Check gradients, reduce LR |
| Gradual increase | Overfitting or schedule issue | Add regularization |
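The "sudden spike" pattern can be flagged automatically with the same rolling-window rule used in `LossMonitor` above. A standalone sketch on a synthetic decaying loss curve (pure NumPy; the curve and the injected spike are made up for illustration):

```python
import numpy as np
from collections import deque

def detect_spikes(losses, window_size=100, k=3.0):
    """Return indices where a loss exceeds rolling mean + k * rolling std."""
    window = deque(maxlen=window_size)
    spikes = []
    for i, loss in enumerate(losses):
        if len(window) == window_size:
            mean, std = np.mean(window), np.std(window)
            if loss > mean + k * std:
                spikes.append(i)
        window.append(loss)
    return spikes

# Synthetic decaying loss with one injected spike (e.g. a bad batch)
rng = np.random.default_rng(0)
losses = 1.0 / np.sqrt(np.arange(1, 501)) + rng.normal(0, 0.005, 500)
losses[300] += 0.5
print(detect_spikes(losses))
```

Because the window only contains past values, the detector reacts immediately at the spike itself, then absorbs it into the statistics so the following steps are not re-flagged.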
FID During Training
FID (Fréchet Inception Distance) measures how close the distribution of generated images is to that of real images. Computing it during training provides direct feedback on generation quality:
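Concretely, FID fits a Gaussian to the Inception features of each image set (mean μ, covariance Σ) and computes the Fréchet distance between the two Gaussians:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```

Here μ_r, Σ_r are the statistics of real-image features and μ_g, Σ_g those of generated-image features. This is exactly what `calculate_fid` evaluates term by term.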
```python
import torch
import numpy as np
from scipy import linalg
from torchvision.models import inception_v3

class FIDCalculator:
    """Calculate FID between generated and real images."""

    def __init__(self, device: str = "cuda"):
        self.device = device
        # Load Inception v3 for feature extraction
        self.inception = inception_v3(weights="DEFAULT", transform_input=False)
        self.inception.fc = torch.nn.Identity()  # Remove final classification layer
        self.inception = self.inception.to(device).eval()

    @torch.no_grad()
    def get_features(self, images: torch.Tensor) -> np.ndarray:
        """Extract Inception features from images in [0, 1]."""
        # Inception v3 expects [B, 3, 299, 299]; resize if needed
        if images.shape[-1] != 299:
            images = torch.nn.functional.interpolate(
                images, size=(299, 299), mode="bilinear", align_corners=False
            )

        # Scale to [-1, 1] for Inception
        images = (images - 0.5) / 0.5

        features = self.inception(images.to(self.device))
        return features.cpu().numpy()

    def calculate_statistics(self, features: np.ndarray) -> tuple:
        """Calculate mean and covariance of features."""
        mu = np.mean(features, axis=0)
        sigma = np.cov(features, rowvar=False)
        return mu, sigma

    def calculate_fid(
        self,
        real_features: np.ndarray,
        fake_features: np.ndarray,
    ) -> float:
        """Calculate FID from two sets of features."""
        mu1, sigma1 = self.calculate_statistics(real_features)
        mu2, sigma2 = self.calculate_statistics(fake_features)

        diff = mu1 - mu2
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)

        # sqrtm can return small imaginary components from numerical error
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
        return float(fid)


def compute_fid_during_training(
    model,
    diffusion,
    real_loader,
    fid_calc: FIDCalculator,
    num_samples: int = 1000,
) -> float:
    """Compute FID between real data and freshly generated samples."""
    model.eval()

    # Collect real features (these can be computed once and cached)
    real_features = []
    collected = 0
    for batch in real_loader:
        if collected >= num_samples:
            break
        images = (batch + 1) / 2  # [-1, 1] -> [0, 1]
        real_features.append(fid_calc.get_features(images))
        collected += batch.shape[0]
    real_features = np.concatenate(real_features)[:num_samples]

    # Generate samples and collect features
    # (generate_samples is a user-provided sampling function, e.g. a DDPM loop)
    fake_features = []
    batch_size = 32
    generated = 0
    while generated < num_samples:
        samples = generate_samples(model, diffusion, batch_size)
        samples = (samples + 1) / 2  # [-1, 1] -> [0, 1]
        fake_features.append(fid_calc.get_features(samples))
        generated += batch_size
    fake_features = np.concatenate(fake_features)[:num_samples]

    model.train()
    return fid_calc.calculate_fid(real_features, fake_features)
```

FID Computation Cost: generating the required samples (typically 1,000+) through the full reverse process is slow, so compute FID every few epochs rather than every step, and cache the real-image features.
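As a sanity check on the FID formula itself, the closed form can be evaluated directly on synthetic Gaussian statistics: identical distributions give FID ≈ 0, and shifting the mean by a vector a (with identity covariances) gives FID = ‖a‖². A minimal standalone sketch, assuming SciPy is available:

```python
import numpy as np
from scipy import linalg

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """Closed-form Frechet distance between two Gaussian distributions."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # Discard numerical imaginary parts
    return float(diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean))

d = 8
mu, eye = np.zeros(d), np.eye(d)
print(fid_from_stats(mu, eye, mu, eye))        # identical distributions -> ~0.0
print(fid_from_stats(mu, eye, mu + 0.5, eye))  # ||a||^2 = 8 * 0.25 = 2.0
```

Checks like this are useful before trusting FID numbers from a full pipeline, where sampling noise and feature extraction add their own variance.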
Sample Visualization
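The sampler below implements the standard DDPM reverse update, using the model's noise prediction at each step:

```latex
x_{t-1} = \frac{1}{\sqrt{\alpha_t}}
  \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \hat{\epsilon}_\theta(x_t, t) \right)
  + \sqrt{\beta_t}\, z,
\qquad z \sim \mathcal{N}(0, I) \text{ for } t > 0, \quad z = 0 \text{ at } t = 0
```

This uses the common choice σ_t = √β_t for the sampling noise scale.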
```python
import torch
import torchvision.utils as vutils
from pathlib import Path

class SampleVisualizer:
    """Visualize and save generated samples during training."""

    def __init__(self, save_dir: str, num_samples: int = 16):
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.num_samples = num_samples
        self.fixed_noise = None  # For consistent comparison across epochs

    @torch.no_grad()
    def generate_and_save(
        self,
        model,
        diffusion,
        epoch: int,
        device: str = "cuda",
    ):
        """Generate samples and save as an image grid."""
        model.eval()

        # Use fixed noise for consistent comparison across epochs
        if self.fixed_noise is None:
            self.fixed_noise = torch.randn(
                self.num_samples, 3, 64, 64, device=device
            )

        # Generate samples
        samples = self._sample(model, diffusion, self.fixed_noise.clone())

        # Create grid
        grid = vutils.make_grid(
            samples,
            nrow=int(self.num_samples ** 0.5),
            normalize=True,
            value_range=(-1, 1),
        )

        # Save
        save_path = self.save_dir / f"samples_epoch_{epoch:04d}.png"
        vutils.save_image(grid, str(save_path))

        model.train()
        return grid

    def _sample(self, model, diffusion, x):
        """DDPM sampling loop."""
        for t in reversed(range(diffusion.timesteps)):
            t_batch = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
            pred_noise = model(x, t_batch)

            alpha = diffusion.alphas[t]
            alpha_bar = diffusion.alphas_cumprod[t]
            beta = diffusion.betas[t]

            # No noise is added at the final (t = 0) step
            noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
            x = (1 / torch.sqrt(alpha)) * (
                x - beta / torch.sqrt(1 - alpha_bar) * pred_noise
            ) + torch.sqrt(beta) * noise

        return x.clamp(-1, 1)

    def create_progress_gif(self, epochs: list):
        """Create an animated GIF showing training progress."""
        import imageio
        images = []
        for epoch in epochs:
            path = self.save_dir / f"samples_epoch_{epoch:04d}.png"
            if path.exists():
                images.append(imageio.imread(str(path)))

        if images:
            output_path = self.save_dir / "training_progress.gif"
            imageio.mimsave(str(output_path), images, fps=2)
            print(f"Saved progress GIF to {output_path}")
```

Monitoring Dashboard
Putting it all together into a comprehensive monitoring system:
```python
import wandb
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

class TrainingMonitor:
    """Comprehensive training monitor combining all metrics."""

    def __init__(
        self,
        log_dir: str,
        use_wandb: bool = True,
        project_name: str = "diffusion",
    ):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        # Initialize loggers
        self.writer = SummaryWriter(log_dir=str(self.log_dir / "tensorboard"))

        self.use_wandb = use_wandb
        if use_wandb:
            wandb.init(project=project_name, dir=str(self.log_dir))

        # Initialize trackers
        self.loss_monitor = LossMonitor()
        self.fid_calculator = FIDCalculator()
        self.sample_visualizer = SampleVisualizer(str(self.log_dir / "samples"))

        self.step = 0
        self.best_fid = float("inf")

    def log_step(self, loss: float):
        """Log a single training step."""
        self.step += 1
        stats = self.loss_monitor.update(loss)

        # Log to TensorBoard
        self.writer.add_scalar("loss/train", loss, self.step)
        self.writer.add_scalar("loss/smooth", stats["loss_smooth"], self.step)

        # Log to W&B
        if self.use_wandb:
            wandb.log({"loss": loss, "loss_smooth": stats["loss_smooth"]}, step=self.step)

        return stats

    def log_epoch(
        self,
        epoch: int,
        model,
        diffusion,
        real_loader,
        compute_fid: bool = True,
    ):
        """Log epoch-level metrics."""
        # Generate and save samples
        grid = self.sample_visualizer.generate_and_save(model, diffusion, epoch)
        self.writer.add_image("samples", grid, epoch)

        if self.use_wandb:
            wandb.log({"samples": wandb.Image(grid)}, step=self.step)

        # Compute FID periodically
        if compute_fid:
            fid = compute_fid_during_training(
                model, diffusion, real_loader, self.fid_calculator
            )
            self.writer.add_scalar("fid", fid, epoch)

            if self.use_wandb:
                wandb.log({"fid": fid}, step=self.step)

            if fid < self.best_fid:
                self.best_fid = fid
                return {"fid": fid, "is_best": True}

            return {"fid": fid, "is_best": False}

        return {}

    def finish(self):
        """Clean up logging."""
        self.writer.close()
        if self.use_wandb:
            wandb.finish()


# Usage in training
monitor = TrainingMonitor("./logs/experiment1", use_wandb=True)

for epoch in range(num_epochs):
    for batch in dataloader:
        loss = train_step(batch)
        monitor.log_step(loss.item())

    # Log epoch metrics (FID every 5 epochs)
    result = monitor.log_epoch(
        epoch, model, diffusion, val_loader,
        compute_fid=(epoch % 5 == 0)
    )

    if result.get("is_best"):
        save_checkpoint(model, "best_model.pt")

monitor.finish()
```

Key Takeaways
- Loss alone is insufficient: Use FID and visual inspection to truly assess generation quality.
- Use fixed noise for comparison: Generate samples with the same initial noise across epochs to see consistent progress.
- Compute FID periodically: It's expensive, so compute every 5-10 epochs, not every step.
- Watch for anomalies: Sudden spikes or divergence in loss often indicate gradient issues or bad batches.
- Log everything: Save samples, checkpoints, and metrics so you can analyze training after the fact.
Looking Ahead: In the next section, we'll cover common training issues and their solutions, including mode collapse, gradient problems, and memory optimization.