Learning Objectives
By the end of this section, you will be able to:
- Diagnose gradient issues including exploding/vanishing gradients and NaN losses
- Prevent mode collapse and maintain generation diversity
- Optimize memory usage for training large models on limited hardware
- Stabilize training with proper initialization and regularization techniques
- Debug common training failures using systematic approaches
The Big Picture
Training diffusion models is challenging due to the complex interplay of timestep conditioning, noise prediction, and the iterative sampling process. Unlike simpler neural networks, diffusion models must learn to predict noise at thousands of different noise levels while maintaining stable gradients throughout.
The Training Challenge: A diffusion model sees vastly different inputs depending on the timestep. At t near T (the final timestep), inputs are nearly pure noise; at t near 0, inputs are nearly clean images. The model must excel at all noise levels simultaneously.
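To see just how different the two ends of the trajectory are, the sketch below computes the signal-to-noise ratio ᾱ_t / (1 − ᾱ_t) under an assumed standard linear beta schedule (β from 1e-4 to 0.02 over 1000 steps, as in DDPM); the exact numbers depend on the schedule your model uses.

```python
import torch

# Assumed linear beta schedule (DDPM defaults); illustrative only
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# Signal-to-noise ratio alpha_bar / (1 - alpha_bar) at each end
snr_start = (alphas_cumprod[0] / (1 - alphas_cumprod[0])).item()  # t = 0: mostly signal
snr_end = (alphas_cumprod[-1] / (1 - alphas_cumprod[-1])).item()  # t = T-1: mostly noise

print(f"SNR at t=0:   {snr_start:.1f}")
print(f"SNR at t=999: {snr_end:.6f}")
```

The SNR spans many orders of magnitude, which is why the same network must behave almost like an identity map at one end and a pure denoiser at the other.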
| Issue Category | Symptoms | Typical Causes |
|---|---|---|
| Gradient Problems | NaN loss, exploding updates | Bad initialization, high LR |
| Mode Collapse | Low diversity, repeated outputs | Insufficient capacity, data imbalance |
| Memory Issues | OOM errors, slow training | Large batch size, inefficient architecture |
| Training Instability | Oscillating loss, poor convergence | LR schedule, batch normalization |
Gradient Problems
Detecting Gradient Issues
Gradient problems are the most common cause of training failure. They manifest as NaN losses, exploding gradients, or vanishing updates:
```python
import warnings
from typing import Dict

import torch
import torch.nn as nn


class GradientMonitor:
    """Monitor gradient health during training."""

    def __init__(self, model: nn.Module, warn_threshold: float = 10.0):
        self.model = model
        self.warn_threshold = warn_threshold
        self.gradient_history = []

    def check_gradients(self) -> Dict[str, float]:
        """Check gradient statistics after backward pass."""
        stats = {
            "max_grad": 0.0,
            "min_grad": float("inf"),
            "mean_grad": 0.0,
            "nan_params": 0,
            "zero_params": 0,
        }

        total_params = 0
        total_grad_sum = 0.0

        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad = param.grad.data

                # Check for NaN
                if torch.isnan(grad).any():
                    stats["nan_params"] += 1
                    warnings.warn(f"NaN gradient in {name}")
                    continue

                # Check for zeros
                if (grad == 0).all():
                    stats["zero_params"] += 1
                    continue

                grad_abs = grad.abs()
                stats["max_grad"] = max(stats["max_grad"], grad_abs.max().item())
                stats["min_grad"] = min(stats["min_grad"], grad_abs.min().item())
                total_grad_sum += grad_abs.mean().item() * param.numel()
                total_params += param.numel()

        if total_params > 0:
            stats["mean_grad"] = total_grad_sum / total_params

        self.gradient_history.append(stats)
        return stats

    def diagnose(self) -> str:
        """Provide a diagnosis based on recent gradient history."""
        if not self.gradient_history:
            return "No gradient history available"

        recent = self.gradient_history[-10:]
        avg_max = sum(s["max_grad"] for s in recent) / len(recent)
        avg_nan = sum(s["nan_params"] for s in recent) / len(recent)
        avg_zero = sum(s["zero_params"] for s in recent) / len(recent)

        issues = []

        if avg_nan > 0:
            issues.append(
                "NaN gradients detected. Try: lower LR, gradient clipping, "
                "check for log(0) or div by zero"
            )

        if avg_max > self.warn_threshold:
            issues.append(
                f"Large gradients ({avg_max:.2f}). Try: gradient clipping, "
                "lower LR, check normalization layers"
            )

        if avg_zero > len(list(self.model.parameters())) * 0.1:
            issues.append(
                "Many zero gradients. Try: check for dead ReLUs, "
                "verify data pipeline, check skip connections"
            )

        if avg_max < 1e-7:
            issues.append(
                "Vanishing gradients. Try: different initialization, "
                "residual connections, layer normalization"
            )

        return "\n".join(issues) if issues else "Gradients look healthy"


def apply_gradient_fixes(model: nn.Module) -> nn.Module:
    """Apply common gradient stabilization techniques."""

    # 1. Better initialization
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # Xavier/Glorot for most layers
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.GroupNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    # 2. Initialize the output layer to predict zero.
    # This is crucial for diffusion models!
    if hasattr(model, "out") and isinstance(model.out, nn.Conv2d):
        nn.init.zeros_(model.out.weight)
        if model.out.bias is not None:
            nn.init.zeros_(model.out.bias)

    return model
```

Gradient Clipping Strategies
Gradient clipping prevents exploding gradients but must be applied carefully:
```python
import warnings
from typing import Dict, Optional

import torch
import torch.nn as nn


class AdaptiveGradientClipper:
    """Adaptive gradient clipping based on historical norms."""

    def __init__(
        self,
        initial_clip: float = 1.0,
        adaptation_rate: float = 0.01,
        min_clip: float = 0.1,
        max_clip: float = 10.0,
    ):
        self.clip_value = initial_clip
        self.adaptation_rate = adaptation_rate
        self.min_clip = min_clip
        self.max_clip = max_clip
        self.grad_norm_ema = initial_clip

    def __call__(self, model: nn.Module) -> float:
        """Clip gradients and return the pre-clip gradient norm."""
        # Compute the global L2 gradient norm
        total_norm = 0.0
        for param in model.parameters():
            if param.grad is not None:
                total_norm += param.grad.data.norm(2).item() ** 2
        total_norm = total_norm ** 0.5

        # Update EMA of the gradient norm
        self.grad_norm_ema = (
            (1 - self.adaptation_rate) * self.grad_norm_ema
            + self.adaptation_rate * total_norm
        )

        # Adapt the clip value: allow up to 2x the typical norm
        self.clip_value = min(
            self.max_clip,
            max(self.min_clip, 2.0 * self.grad_norm_ema),
        )

        # Apply clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_value)

        return total_norm


# Training loop with gradient monitoring
def train_step_with_monitoring(
    model: nn.Module,
    batch: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    diffusion,
    grad_monitor: GradientMonitor,
    grad_clipper: AdaptiveGradientClipper,
    scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> Dict[str, float]:
    """Training step with full gradient monitoring."""

    model.train()
    optimizer.zero_grad()

    # Forward pass with optional mixed precision
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        t = torch.randint(
            0, diffusion.timesteps, (batch.shape[0],), device=batch.device
        )
        noise = torch.randn_like(batch)
        x_noisy = diffusion.q_sample(batch, t, noise)
        pred_noise = model(x_noisy, t)
        loss = nn.functional.mse_loss(pred_noise, noise)

    # Check for NaN loss
    if torch.isnan(loss):
        warnings.warn("NaN loss detected! Skipping update.")
        return {"loss": float("nan"), "skipped": True}

    # Backward pass
    if scaler is not None:
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
    else:
        loss.backward()

    # Monitor gradients
    grad_stats = grad_monitor.check_gradients()

    # Skip the update if gradients are bad
    if grad_stats["nan_params"] > 0:
        optimizer.zero_grad()
        return {"loss": loss.item(), "skipped": True, **grad_stats}

    # Clip gradients
    grad_norm = grad_clipper(model)

    # Optimizer step
    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()

    return {
        "loss": loss.item(),
        "grad_norm": grad_norm,
        "clip_value": grad_clipper.clip_value,
        "skipped": False,
        **grad_stats,
    }
```

Zero-Initialize Output Layer
As `apply_gradient_fixes` does above, zero-initialize the final output layer: the network then predicts zero noise on its very first step, which keeps the initial loss bounded and the early gradients well scaled.
Mode Collapse and Diversity
Understanding Mode Collapse
Mode collapse in diffusion models appears as reduced diversity in generated samples. Unlike GANs, diffusion models rarely suffer from severe mode collapse, but subtle forms can still occur:
```python
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans


class DiversityAnalyzer:
    """Analyze diversity of generated samples."""

    def __init__(self, feature_extractor: nn.Module, device: str = "cuda"):
        self.feature_extractor = feature_extractor.to(device).eval()
        self.device = device

    @torch.no_grad()
    def extract_features(self, images: torch.Tensor) -> np.ndarray:
        """Extract features from images."""
        features = self.feature_extractor(images.to(self.device))
        return features.cpu().numpy()

    def compute_diversity_metrics(
        self,
        generated_features: np.ndarray,
        real_features: np.ndarray,
    ) -> dict:
        """Compute various diversity metrics."""
        metrics = {}

        # 1. Feature spread (variance in feature space)
        gen_var = np.var(generated_features, axis=0).mean()
        real_var = np.var(real_features, axis=0).mean()
        metrics["variance_ratio"] = gen_var / (real_var + 1e-8)

        # 2. Pairwise distances
        gen_dists = self._pairwise_distances(generated_features)
        real_dists = self._pairwise_distances(real_features)
        metrics["mean_pairwise_distance"] = gen_dists.mean()
        metrics["distance_ratio"] = gen_dists.mean() / (real_dists.mean() + 1e-8)

        # 3. Number of distinct modes (via clustering)
        n_clusters = min(50, len(generated_features) // 10)
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        gen_clusters = kmeans.fit_predict(generated_features)

        # Cluster occupancy: fraction of clusters that received any samples
        cluster_counts = Counter(gen_clusters)
        metrics["mode_coverage"] = len(cluster_counts) / n_clusters

        # Check for dominant modes
        max_cluster_frac = max(cluster_counts.values()) / len(generated_features)
        metrics["max_mode_fraction"] = max_cluster_frac

        # 4. Coverage: how many real modes are covered
        real_clusters = kmeans.predict(real_features)
        gen_modes = set(gen_clusters)
        real_modes = set(real_clusters)
        metrics["coverage"] = len(gen_modes & real_modes) / len(real_modes)

        return metrics

    def _pairwise_distances(self, features: np.ndarray) -> np.ndarray:
        """Compute pairwise L2 distances on a subsample."""
        # Subsample for efficiency
        n = min(1000, len(features))
        idx = np.random.choice(len(features), n, replace=False)
        sample = features[idx]

        diffs = sample[:, None] - sample[None, :]
        dists = np.sqrt((diffs ** 2).sum(axis=-1))

        # Return upper triangle (excluding diagonal)
        return dists[np.triu_indices(n, k=1)]

    def diagnose_diversity(self, metrics: dict) -> str:
        """Provide a diagnosis based on diversity metrics."""
        issues = []

        if metrics["variance_ratio"] < 0.5:
            issues.append(
                "Low feature variance - samples may lack diversity. "
                "Try: longer training, more data augmentation"
            )

        if metrics["max_mode_fraction"] > 0.3:
            issues.append(
                f"Mode concentration detected ({metrics['max_mode_fraction']:.1%} "
                "in single mode). Try: check data balance, increase model capacity"
            )

        if metrics["coverage"] < 0.7:
            issues.append(
                f"Low mode coverage ({metrics['coverage']:.1%}). "
                "Some data modes not represented. Try: longer training, "
                "verify data preprocessing"
            )

        if metrics["distance_ratio"] < 0.8:
            issues.append(
                "Samples too similar. Try: increase sampling temperature, "
                "check for EMA decay issues"
            )

        return "\n".join(issues) if issues else "Diversity looks good"


def prevent_mode_collapse() -> dict:
    """Recommended settings to prevent mode collapse."""

    recommendations = {
        # Use dropout during training
        "dropout": 0.1,

        # Don't use too aggressive EMA
        "ema_decay": 0.9999,  # Not 0.99999

        # Sufficient model capacity
        "min_channels": 128,

        # Data augmentation
        "augmentation": {
            "horizontal_flip": True,
            "random_crop": True,
            "color_jitter": 0.1,
        },

        # Balanced sampling
        "balanced_sampling": True,

        # Sufficient training
        "min_steps": 100000,
    }

    return recommendations
```

| Symptom | Likely Cause | Solution |
|---|---|---|
| All samples look similar | EMA decay too high | Reduce EMA decay to 0.9999 |
| Missing classes/modes | Imbalanced dataset | Use balanced sampling |
| Blurry but uniform | Undertrained | Train longer, increase LR |
| Some modes overrepresented | Model capacity | Increase channels/layers |
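The "balanced sampling" fix from the table can be implemented with PyTorch's `WeightedRandomSampler`. This sketch uses a hypothetical per-sample `labels` tensor with a 9:1 class imbalance and weights each sample inversely to its class frequency, so batches become roughly class-balanced.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical per-sample class labels with a 9:1 imbalance
labels = torch.tensor([0] * 900 + [1] * 100)
class_counts = torch.bincount(labels).float()   # tensor([900., 100.])
sample_weights = 1.0 / class_counts[labels]     # rare class gets 9x the weight

sampler = WeightedRandomSampler(
    weights=sample_weights,
    num_samples=len(labels),
    replacement=True,
)
# In practice: DataLoader(dataset, batch_size=64, sampler=sampler)

# Check that the minority class is now drawn about half the time
drawn = torch.tensor(list(sampler))
frac_minority = (labels[drawn] == 1).float().mean().item()
print(f"Minority-class fraction after resampling: {frac_minority:.2f}")
```

Note that `replacement=True` is required here: minority samples must be drawn more than once per epoch to balance the batches.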
Memory Optimization
Gradient Checkpointing
Gradient checkpointing trades computation for memory by not storing all intermediate activations:
```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Note: Downsample, Upsample, AttentionBlock, and the sinusoidal time
# embedding (self.time_embed) / output head (self.out) are assumed to be
# defined as in the earlier U-Net chapters.


class MemoryEfficientUNet(nn.Module):
    """U-Net with gradient checkpointing for memory efficiency."""

    def __init__(
        self,
        in_channels: int = 3,
        model_channels: int = 128,
        num_res_blocks: int = 2,
        attention_resolutions: tuple = (16, 8),
        channel_mult: tuple = (1, 2, 4, 8),
        use_checkpoint: bool = True,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint

        # Build encoder blocks
        self.encoder_blocks = nn.ModuleList()
        ch = model_channels
        for level, mult in enumerate(channel_mult):
            out_ch = model_channels * mult
            for _ in range(num_res_blocks):
                self.encoder_blocks.append(
                    ResBlock(ch, out_ch, use_checkpoint=use_checkpoint)
                )
                ch = out_ch
            if level != len(channel_mult) - 1:
                self.encoder_blocks.append(Downsample(ch))

        # Middle blocks
        self.middle_blocks = nn.ModuleList([
            ResBlock(ch, ch, use_checkpoint=use_checkpoint),
            AttentionBlock(ch),
            ResBlock(ch, ch, use_checkpoint=use_checkpoint),
        ])

        # Build decoder blocks
        self.decoder_blocks = nn.ModuleList()
        for level, mult in reversed(list(enumerate(channel_mult))):
            out_ch = model_channels * mult
            for _ in range(num_res_blocks + 1):
                self.decoder_blocks.append(
                    ResBlock(ch + out_ch, out_ch, use_checkpoint=use_checkpoint)
                )
                ch = out_ch
            if level != 0:
                self.decoder_blocks.append(Upsample(ch))

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        emb = self.time_embed(t)

        # Encoder. Each ResBlock checkpoints itself when use_checkpoint=True,
        # so blocks are called directly here. Downsample/Upsample take no
        # time embedding.
        h = x
        hs = []
        for block in self.encoder_blocks:
            h = block(h, emb) if isinstance(block, ResBlock) else block(h)
            hs.append(h)

        # Middle blocks
        for block in self.middle_blocks:
            h = block(h, emb) if isinstance(block, ResBlock) else block(h)

        # Decoder with skip connections
        for block in self.decoder_blocks:
            if isinstance(block, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
                h = block(h, emb)
            else:
                h = block(h)

        return self.out(h)


class ResBlock(nn.Module):
    """Residual block with optional checkpointing."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_channels: int = 512,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint

        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.time_proj = nn.Linear(time_channels, out_channels)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip = nn.Identity()

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        # Only checkpoint during training: at inference there are no
        # activations to save, so checkpointing would just waste compute.
        if self.use_checkpoint and self.training:
            return checkpoint(self._forward, x, emb, use_reentrant=False)
        return self._forward(x, emb)

    def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        h = nn.functional.silu(h)
        h = self.conv1(h)

        # Add time embedding
        h = h + self.time_proj(emb)[:, :, None, None]

        h = self.norm2(h)
        h = nn.functional.silu(h)
        h = self.conv2(h)

        return h + self.skip(x)
```

Memory-Efficient Training Configuration
```python
import gc

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast


class MemoryOptimizedTrainer:
    """Training configuration optimized for memory."""

    def __init__(
        self,
        model: nn.Module,
        target_batch_size: int = 64,
        device_batch_size: int = 8,  # What fits in GPU memory
        use_amp: bool = True,
        use_compile: bool = True,  # PyTorch 2.0+
    ):
        self.model = model
        self.target_batch_size = target_batch_size
        self.device_batch_size = device_batch_size
        self.accumulation_steps = target_batch_size // device_batch_size
        self.use_amp = use_amp

        # Mixed precision
        self.scaler = GradScaler() if use_amp else None

        # Compile the model for better performance (PyTorch 2.0+)
        if use_compile and hasattr(torch, "compile"):
            self.model = torch.compile(
                model,
                mode="reduce-overhead",  # Or "max-autotune" for best perf
            )

    @staticmethod
    def estimate_memory(
        model: nn.Module,
        batch_size: int,
        image_size: int,
        use_amp: bool = True,
    ) -> dict:
        """Estimate memory requirements (a rough heuristic)."""

        # Count parameters
        params = sum(p.numel() for p in model.parameters())
        param_bytes = params * (2 if use_amp else 4)  # FP16 or FP32

        # Estimate activation memory (rough):
        # for a U-Net, roughly 10x parameter count per sample
        activation_bytes = params * 10 * batch_size * (2 if use_amp else 4)

        # Gradient memory
        grad_bytes = param_bytes

        # Optimizer states (Adam keeps 2 states per parameter)
        optim_bytes = params * 2 * 4  # Always FP32

        total_bytes = param_bytes + activation_bytes + grad_bytes + optim_bytes

        return {
            "parameters_mb": param_bytes / 1e6,
            "activations_mb": activation_bytes / 1e6,
            "gradients_mb": grad_bytes / 1e6,
            "optimizer_mb": optim_bytes / 1e6,
            "total_mb": total_bytes / 1e6,
            "total_gb": total_bytes / 1e9,
        }

    @staticmethod
    def optimize_for_memory() -> dict:
        """Return memory optimization settings."""
        return {
            # Enable memory-efficient attention
            "attention": "flash",  # or "xformers"

            # Use gradient checkpointing
            "gradient_checkpointing": True,

            # Use mixed precision
            "mixed_precision": "fp16",  # or "bf16" on Ampere+

            # Efficient data loading
            "num_workers": 4,
            "pin_memory": True,
            "prefetch_factor": 2,

            # Clear cache periodically
            "empty_cache_freq": 100,  # steps
        }

    def train_step(
        self,
        batch: torch.Tensor,
        optimizer: torch.optim.Optimizer,
        diffusion,
        step: int,
    ) -> float:
        """Memory-optimized training step with gradient accumulation."""

        # Split the batch into micro-batches for accumulation
        micro_batches = batch.split(self.device_batch_size)
        total_loss = 0.0

        for micro_batch in micro_batches:
            # Use AMP for the forward pass
            with autocast(enabled=self.use_amp):
                t = torch.randint(
                    0, diffusion.timesteps,
                    (micro_batch.shape[0],),
                    device=micro_batch.device,
                )
                noise = torch.randn_like(micro_batch)
                x_noisy = diffusion.q_sample(micro_batch, t, noise)
                pred = self.model(x_noisy, t)
                loss = nn.functional.mse_loss(pred, noise)
                loss = loss / self.accumulation_steps

            # Backward pass
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            total_loss += loss.item()

        # Optimizer step after accumulation
        if self.scaler is not None:
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()

        optimizer.zero_grad(set_to_none=True)  # More memory efficient

        # Periodic cache clearing
        if step % 100 == 0:
            gc.collect()
            torch.cuda.empty_cache()

        return total_loss * self.accumulation_steps
```

Flash Attention
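A minimal sketch of the drop-in call (shapes are illustrative), compared against a hand-written softmax attention to confirm they compute the same thing:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, heads, sequence length, head dim)
batch, heads, seq, dim = 2, 4, 256, 64
q = torch.randn(batch, heads, seq, dim)
k = torch.randn(batch, heads, seq, dim)
v = torch.randn(batch, heads, seq, dim)

# Fused, memory-efficient attention (PyTorch 2.0+)
out = F.scaled_dot_product_attention(q, k, v)

# Naive reference: softmax(QK^T / sqrt(d)) V
attn = torch.softmax(q @ k.transpose(-2, -1) / dim ** 0.5, dim=-1)
ref = attn @ v
print(torch.allclose(out, ref, atol=1e-4))
```

The fused version never materializes the full (seq x seq) attention matrix on supported backends, which is where the memory savings come from.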
Under the hood, `torch.nn.functional.scaled_dot_product_attention` dispatches to Flash Attention or another memory-efficient kernel whenever the hardware and dtypes allow it.

Training Instability
Learning Rate Issues
Learning rate problems are often the root cause of training instability:
```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    LinearLR,
    SequentialLR,
)


class StableTrainingConfig:
    """Configuration for stable diffusion model training."""

    @staticmethod
    def get_optimizer(
        model: nn.Module,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        betas: tuple = (0.9, 0.999),
    ) -> torch.optim.Optimizer:
        """Get optimizer with appropriate settings."""

        # Separate parameters for weight decay
        decay_params = []
        no_decay_params = []

        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # Don't apply weight decay to biases and normalization
            if "bias" in name or "norm" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        param_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        return AdamW(
            param_groups,
            lr=learning_rate,
            betas=betas,
            eps=1e-8,
        )

    @staticmethod
    def get_scheduler(
        optimizer: torch.optim.Optimizer,
        warmup_steps: int = 5000,
        total_steps: int = 500000,
        min_lr_ratio: float = 0.1,
    ):
        """Get learning rate scheduler with warmup."""

        # Linear warmup
        warmup_scheduler = LinearLR(
            optimizer,
            start_factor=0.01,
            end_factor=1.0,
            total_iters=warmup_steps,
        )

        # Cosine decay after warmup
        decay_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=total_steps - warmup_steps,
            eta_min=optimizer.param_groups[0]["lr"] * min_lr_ratio,
        )

        # Combine schedulers
        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, decay_scheduler],
            milestones=[warmup_steps],
        )

        return scheduler

    @staticmethod
    def diagnose_lr_issues(loss_history: list) -> str:
        """Diagnose learning rate issues from the loss curve."""

        if len(loss_history) < 100:
            return "Need more data to diagnose"

        issues = []

        # Check for oscillation
        recent = loss_history[-100:]
        diffs = [recent[i + 1] - recent[i] for i in range(len(recent) - 1)]
        sign_changes = sum(
            1 for i in range(len(diffs) - 1)
            if diffs[i] * diffs[i + 1] < 0
        )

        if sign_changes > 60:  # Too many sign changes
            issues.append(
                "High oscillation detected. Learning rate may be too high. "
                "Try reducing by 2-5x."
            )

        # Check for plateau
        first_half = sum(recent[:50]) / 50
        second_half = sum(recent[50:]) / 50

        if abs(first_half - second_half) / first_half < 0.01:
            issues.append(
                "Loss plateau detected. Try: increase LR, check gradients, "
                "verify data pipeline."
            )

        # Check for divergence
        if recent[-1] > recent[0] * 1.5:
            issues.append(
                "Loss increasing! Possible divergence. "
                "Immediately reduce LR or restart from checkpoint."
            )

        return "\n".join(issues) if issues else "Learning rate seems appropriate"


class TrainingStabilizer:
    """Apply various stabilization techniques."""

    @staticmethod
    def apply_spectral_norm(model: nn.Module) -> nn.Module:
        """Apply spectral normalization to conv and linear layers."""
        for name, module in model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # Don't apply to the output layer
                if "out" not in name:
                    nn.utils.spectral_norm(module)
        return model

    @staticmethod
    def apply_ema(
        model: nn.Module,
        ema_model: nn.Module,
        decay: float = 0.9999,
    ) -> None:
        """Update EMA model parameters."""
        with torch.no_grad():
            for ema_param, param in zip(
                ema_model.parameters(),
                model.parameters(),
            ):
                ema_param.data.mul_(decay).add_(param.data, alpha=1 - decay)

    @staticmethod
    def detect_nan(tensor: torch.Tensor, name: str = "tensor") -> bool:
        """Check for NaN/Inf values and warn."""
        if torch.isnan(tensor).any():
            print(f"WARNING: NaN detected in {name}")
            return True
        if torch.isinf(tensor).any():
            print(f"WARNING: Inf detected in {name}")
            return True
        return False
```

| Symptom | Diagnosis | Fix |
|---|---|---|
| Loss spikes randomly | LR too high during warmup | Slower warmup (10k steps) |
| Loss stuck after warmup | LR too low | Increase base LR 2-5x |
| Gradual divergence | No LR decay | Add cosine/linear decay |
| Training unstable after long time | Numerical issues | Use BF16 instead of FP16 |
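The last row of the table deserves a concrete illustration. BF16 keeps FP32's exponent range (FP16 overflows above ~65504), so large late-training activations stay finite and no GradScaler is needed. A minimal sketch using CPU autocast, with an assumed toy linear layer:

```python
import torch
import torch.nn as nn

# Toy layer and deliberately large-magnitude inputs (illustrative)
model = nn.Linear(16, 16)
x = torch.randn(4, 16) * 300.0

# BF16 autocast: no GradScaler needed, same exponent range as FP32
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = model(x)

print(y.dtype)                       # torch.bfloat16
print(torch.isinf(y).any().item())   # False: values stay representable
```

On GPU the same pattern is `torch.autocast(device_type="cuda", dtype=torch.bfloat16)`; BF16 requires Ampere-class hardware or newer.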
Debugging Toolkit
```python
from typing import Dict

import numpy as np
import torch
import torch.nn as nn


class DiffusionDebugger:
    """Comprehensive debugging toolkit for diffusion models."""

    def __init__(self, model: nn.Module, diffusion):
        self.model = model
        self.diffusion = diffusion
        self.debug_info = {}

    def run_diagnostics(
        self,
        sample_batch: torch.Tensor,
        verbose: bool = True,
    ) -> Dict:
        """Run comprehensive diagnostics."""

        results = {}

        # 1. Model output sanity check
        results["model_output"] = self._check_model_output(sample_batch)

        # 2. Gradient flow check
        results["gradients"] = self._check_gradient_flow(sample_batch)

        # 3. Timestep response check
        results["timestep_response"] = self._check_timestep_response(sample_batch)

        # 4. Noise schedule check
        results["noise_schedule"] = self._check_noise_schedule()

        # 5. Generate diagnostic samples
        results["sample_quality"] = self._check_sample_quality()

        if verbose:
            self._print_report(results)

        return results

    def _check_model_output(self, batch: torch.Tensor) -> Dict:
        """Check model output statistics."""
        self.model.eval()

        with torch.no_grad():
            t = torch.randint(
                0, self.diffusion.timesteps,
                (batch.shape[0],),
                device=batch.device,
            )
            noise = torch.randn_like(batch)
            x_noisy = self.diffusion.q_sample(batch, t, noise)
            pred = self.model(x_noisy, t)

        stats = {
            "mean": pred.mean().item(),
            "std": pred.std().item(),
            "min": pred.min().item(),
            "max": pred.max().item(),
            "has_nan": torch.isnan(pred).any().item(),
            "has_inf": torch.isinf(pred).any().item(),
        }

        # Expected: mean near 0, std near 1 (the model predicts noise)
        stats["healthy"] = (
            abs(stats["mean"]) < 0.5 and
            0.5 < stats["std"] < 2.0 and
            not stats["has_nan"] and
            not stats["has_inf"]
        )

        return stats

    def _check_gradient_flow(self, batch: torch.Tensor) -> Dict:
        """Check gradient flow through the model."""
        self.model.train()

        t = torch.randint(
            0, self.diffusion.timesteps,
            (batch.shape[0],),
            device=batch.device,
        )
        noise = torch.randn_like(batch)
        x_noisy = self.diffusion.q_sample(batch, t, noise)
        pred = self.model(x_noisy, t)
        loss = nn.functional.mse_loss(pred, noise)
        loss.backward()

        grad_stats = {}
        zero_grad_layers = []
        large_grad_layers = []

        for name, param in self.model.named_parameters():
            if param.grad is not None:
                grad_norm = param.grad.norm().item()
                if grad_norm == 0:
                    zero_grad_layers.append(name)
                elif grad_norm > 100:
                    large_grad_layers.append((name, grad_norm))

        grad_stats["zero_grad_layers"] = zero_grad_layers
        grad_stats["large_grad_layers"] = large_grad_layers
        grad_stats["healthy"] = (
            len(zero_grad_layers) == 0 and
            len(large_grad_layers) == 0
        )

        self.model.zero_grad()
        return grad_stats

    def _check_timestep_response(self, batch: torch.Tensor) -> Dict:
        """Check model response across different timesteps."""
        self.model.eval()

        timesteps_to_check = [0, 100, 500, 900, 999]
        responses = {}

        with torch.no_grad():
            x = batch[:4]  # Small batch
            noise = torch.randn_like(x)

            for t_val in timesteps_to_check:
                t = torch.full((x.shape[0],), t_val, device=x.device)
                x_noisy = self.diffusion.q_sample(x, t, noise)
                pred = self.model(x_noisy, t)

                responses[t_val] = {
                    "pred_std": pred.std().item(),
                    "mse_to_noise": nn.functional.mse_loss(pred, noise).item(),
                }

        # Check that the model responds differently to different timesteps
        stds = [r["pred_std"] for r in responses.values()]
        stats = {
            "responses": responses,
            "std_variation": np.std(stds),
            "healthy": np.std(stds) > 0.01,  # Should vary with timestep
        }

        return stats

    def _check_noise_schedule(self) -> Dict:
        """Verify noise schedule properties."""
        betas = self.diffusion.betas.cpu().numpy()
        alphas_cumprod = self.diffusion.alphas_cumprod.cpu().numpy()

        stats = {
            "beta_min": betas.min(),
            "beta_max": betas.max(),
            "alpha_bar_start": alphas_cumprod[0],
            "alpha_bar_end": alphas_cumprod[-1],
            "monotonic": all(
                alphas_cumprod[i] >= alphas_cumprod[i + 1]
                for i in range(len(alphas_cumprod) - 1)
            ),
        }

        # Check for a healthy schedule
        stats["healthy"] = (
            stats["beta_min"] > 0 and
            stats["beta_max"] < 0.1 and
            stats["alpha_bar_start"] > 0.99 and
            stats["alpha_bar_end"] < 0.01 and
            stats["monotonic"]
        )

        return stats

    def _check_sample_quality(self, num_samples: int = 4) -> Dict:
        """Generate samples and check basic quality."""
        self.model.eval()

        with torch.no_grad():
            # Quick sampling with fewer steps
            samples = self._quick_sample(num_samples, steps=50)

        stats = {
            "mean": samples.mean().item(),
            "std": samples.std().item(),
            "min": samples.min().item(),
            "max": samples.max().item(),
            "in_range": (
                samples.min().item() >= -1.5 and
                samples.max().item() <= 1.5
            ),
        }

        stats["healthy"] = stats["in_range"]
        return stats

    def _quick_sample(
        self,
        num_samples: int,
        steps: int = 50,
    ) -> torch.Tensor:
        """Quick sampling with DDIM for debugging."""
        device = next(self.model.parameters()).device
        shape = (num_samples, 3, 64, 64)

        x = torch.randn(shape, device=device)

        # DDIM with few steps
        step_size = self.diffusion.timesteps // steps
        timesteps = list(range(0, self.diffusion.timesteps, step_size))[::-1]

        for i, t in enumerate(timesteps):
            t_tensor = torch.full((num_samples,), t, device=device, dtype=torch.long)
            pred_noise = self.model(x, t_tensor)

            alpha_bar = self.diffusion.alphas_cumprod[t]
            alpha_bar_prev = (
                self.diffusion.alphas_cumprod[timesteps[i + 1]]
                if i < len(timesteps) - 1
                else torch.tensor(1.0, device=device)
            )

            # DDIM update
            pred_x0 = (x - torch.sqrt(1 - alpha_bar) * pred_noise) / torch.sqrt(alpha_bar)
            pred_x0 = pred_x0.clamp(-1, 1)

            x = (
                torch.sqrt(alpha_bar_prev) * pred_x0 +
                torch.sqrt(1 - alpha_bar_prev) * pred_noise
            )

        return x

    def _print_report(self, results: Dict) -> None:
        """Print a human-readable diagnostic report."""
        print("=" * 60)
        print("DIFFUSION MODEL DIAGNOSTIC REPORT")
        print("=" * 60)

        for category, stats in results.items():
            healthy = stats.get("healthy", "Unknown")
            status = "OK" if healthy else "ISSUE"
            print(f"\n{category.upper()}: [{status}]")

            for key, value in stats.items():
                if key != "healthy":
                    print(f"  {key}: {value}")

        print("\n" + "=" * 60)

        # Summary
        all_healthy = all(
            stats.get("healthy", True)
            for stats in results.values()
        )

        if all_healthy:
            print("All diagnostics passed!")
        else:
            print("Issues detected. Review the report above.")


# Usage example
def debug_training_issue(model, diffusion, sample_batch):
    """Debug a training issue."""
    debugger = DiffusionDebugger(model, diffusion)

    # Run full diagnostics
    results = debugger.run_diagnostics(sample_batch)

    # Get specific recommendations
    if not results["model_output"]["healthy"]:
        print("\nModel output issue detected!")
        print("Try: Check initialization, reduce model complexity")

    if not results["gradients"]["healthy"]:
        print("\nGradient issue detected!")
        print("Try: Gradient clipping, lower learning rate")

    if not results["timestep_response"]["healthy"]:
        print("\nTimestep conditioning issue!")
        print("Try: Check time embedding, verify conditioning mechanism")

    return results
```

Key Takeaways
- Zero-initialize output layers: This simple technique prevents many gradient issues at the start of training.
- Use gradient checkpointing: Essential for training large models on limited GPU memory.
- Monitor diversity: Track variance ratios and mode coverage to catch subtle collapse early.
- Warmup is critical: Use 5-10k step linear warmup to stabilize early training.
- Debug systematically: Use the diagnostic toolkit to identify issues rather than guessing.
Looking Ahead: With a fully trained model, we are ready to generate and evaluate samples. The next chapter covers the complete generation pipeline and evaluation metrics like FID, IS, and perceptual quality assessment.