Chapter 11

Common Issues and Solutions

Training the Model

Learning Objectives

By the end of this section, you will be able to:

  1. Diagnose gradient issues including exploding/vanishing gradients and NaN losses
  2. Prevent mode collapse and maintain generation diversity
  3. Optimize memory usage for training large models on limited hardware
  4. Stabilize training with proper initialization and regularization techniques
  5. Debug common training failures using systematic approaches

The Big Picture

Training diffusion models is challenging because of the interplay between timestep conditioning, noise prediction, and the iterative sampling process. Unlike a typical supervised network, a diffusion model must learn to predict noise at every noise level in the schedule (commonly $T = 1000$ discrete timesteps) while keeping gradients stable throughout.

The Training Challenge: A diffusion model sees vastly different inputs depending on the timestep. At $t \approx T$, inputs are nearly pure noise; at $t \approx 0$, inputs are nearly clean images. The model must perform well at all noise levels simultaneously.

| Issue Category | Symptoms | Typical Causes |
|---|---|---|
| Gradient Problems | NaN loss, exploding updates | Bad initialization, high LR |
| Mode Collapse | Low diversity, repeated outputs | Insufficient capacity, data imbalance |
| Memory Issues | OOM errors, slow training | Large batch size, inefficient architecture |
| Training Instability | Oscillating loss, poor convergence | LR schedule, batch normalization |
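
To make the contrast between timesteps concrete, the sketch below assumes the linear $\beta$ schedule from the original DDPM paper ($\beta$ from $10^{-4}$ to $0.02$ over $T = 1000$ steps); other schedules, such as cosine, change the numbers but not the picture. It prints how much of $x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon$ comes from the clean image and how much from the noise at a few timesteps.

```python
import torch

# Assumed schedule: DDPM's linear betas from 1e-4 to 0.02 over T = 1000 steps
T = 1000
betas = torch.linspace(1e-4, 0.02, T, dtype=torch.float64)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar[t] = prod_{s<=t} (1 - beta_s)


def coefficients(t: int) -> tuple:
    """Weights on x0 and on the noise in x_t = sqrt(ab) * x0 + sqrt(1 - ab) * eps."""
    ab = alpha_bar[t].item()
    return ab ** 0.5, (1.0 - ab) ** 0.5


for t in (0, 250, 500, 750, T - 1):
    signal, noise = coefficients(t)
    print(f"t={t:4d}  image weight={signal:.4f}  noise weight={noise:.4f}")
```

Near $t = 0$ the input is almost entirely image; near $t = T - 1$ the image contributes well under 1% of the signal, which is why a single network has to handle very different input statistics.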

Gradient Problems

Detecting Gradient Issues

Gradient problems are the most common cause of training failure. They manifest as NaN losses, exploding gradients, or vanishing updates:

```python
import warnings
from typing import Dict, Optional

import torch
import torch.nn as nn


class GradientMonitor:
    """Monitor gradient health during training."""

    def __init__(
        self,
        model: nn.Module,
        warn_threshold: float = 10.0,
        nan_threshold: int = 3,
    ):
        self.model = model
        self.warn_threshold = warn_threshold
        self.nan_threshold = nan_threshold  # consecutive bad steps before escalating
        self.nan_count = 0
        self.gradient_history = []

    def check_gradients(self) -> Dict[str, float]:
        """Check gradient statistics after the backward pass (after unscaling, with AMP)."""
        stats = {
            "max_grad": 0.0,
            "min_grad": float("inf"),
            "mean_grad": 0.0,
            "nan_params": 0,   # tensors containing any NaN or Inf gradient
            "zero_params": 0,  # tensors whose gradient is exactly zero
        }

        total_params = 0
        total_grad_sum = 0.0

        for name, param in self.model.named_parameters():
            if param.grad is None:
                continue
            grad = param.grad.detach()

            # Check for NaN or Inf (Inf is typical of FP16 overflow)
            if not torch.isfinite(grad).all():
                stats["nan_params"] += 1
                warnings.warn(f"Non-finite gradient in {name}")
                continue

            # Check for all-zero gradients
            if (grad == 0).all():
                stats["zero_params"] += 1
                continue

            grad_abs = grad.abs()
            stats["max_grad"] = max(stats["max_grad"], grad_abs.max().item())
            stats["min_grad"] = min(stats["min_grad"], grad_abs.min().item())
            total_grad_sum += grad_abs.sum().item()
            total_params += param.numel()

        if total_params > 0:
            stats["mean_grad"] = total_grad_sum / total_params

        # Count consecutive steps with non-finite gradients
        self.nan_count = self.nan_count + 1 if stats["nan_params"] > 0 else 0
        self.gradient_history.append(stats)
        return stats

    def diagnose(self) -> str:
        """Provide diagnosis based on gradient history."""
        if not self.gradient_history:
            return "No gradient history available"

        recent = self.gradient_history[-10:]
        avg_max = sum(s["max_grad"] for s in recent) / len(recent)
        avg_nan = sum(s["nan_params"] for s in recent) / len(recent)
        avg_zero = sum(s["zero_params"] for s in recent) / len(recent)

        issues = []

        if avg_nan > 0:
            issues.append(
                "NaN/Inf gradients detected. Try: lower LR, gradient clipping, "
                "check for log(0) or division by zero"
            )

        if self.nan_count >= self.nan_threshold:
            issues.append(
                f"Non-finite gradients for {self.nan_count} consecutive steps. "
                "Try: restart from the last good checkpoint with a lower LR"
            )

        if avg_max > self.warn_threshold:
            issues.append(
                f"Large gradients ({avg_max:.2f}). Try: gradient clipping, "
                "lower LR, check normalization layers"
            )

        if avg_zero > len(list(self.model.parameters())) * 0.1:
            issues.append(
                "Many zero gradients. Try: check for dead ReLUs, "
                "verify data pipeline, check skip connections"
            )

        if avg_max < 1e-7:
            issues.append(
                "Vanishing gradients. Try: different initialization, "
                "residual connections, layer normalization"
            )

        return "\n".join(issues) if issues else "Gradients look healthy"


def apply_gradient_fixes(model: nn.Module, config: Optional[Dict] = None) -> nn.Module:
    """Apply common gradient stabilization techniques."""

    # 1. Better initialization
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            # Xavier/Glorot for most layers
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.GroupNorm) and module.affine:
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    # 2. Initialize the output layer to predict zero.
    # This is crucial for diffusion models! It must run after step 1,
    # which would otherwise overwrite it.
    out = getattr(model, "out", None)
    if isinstance(out, nn.Conv2d):
        nn.init.zeros_(out.weight)
        if out.bias is not None:
            nn.init.zeros_(out.bias)

    return model
```
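
GradientMonitor tells you that a gradient became NaN, but not which operation produced it. PyTorch's anomaly mode (torch.autograd.detect_anomaly) raises an error naming the backward function that first returned NaN. The sketch below triggers a classic case on purpose: the derivative of sqrt at 0 is infinite, and multiplying it by a zero upstream gradient gives 0/0. Anomaly mode slows training considerably, so enable it only while debugging.

```python
import torch


def find_nan_source() -> str:
    """Run a forward/backward pass that produces a NaN gradient and report where it arose."""
    x = torch.zeros(3, requires_grad=True)
    with torch.autograd.detect_anomaly():
        # sqrt's backward computes grad / (2 * sqrt(x)); at x = 0 with an upstream
        # gradient of 0 (because of the "* 0.0") this is 0 / 0, i.e. NaN
        y = (x.sqrt() * 0.0).sum()
        try:
            y.backward()
        except RuntimeError as err:
            return str(err)
    return "no NaN detected"


print(find_nan_source())
```

The error message names the offending backward node (here the sqrt backward), which points you straight at the forward operation to guard, for example by adding a small epsilon inside sqrt or log.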

Gradient Clipping Strategies

Gradient clipping prevents exploding gradients but must be applied carefully:

```python
import math
import warnings
from typing import Dict, Optional

import torch
from torch import nn

# GradientMonitor is the class from the previous listing.


class AdaptiveGradientClipper:
    """Adaptive gradient clipping based on an EMA of recent gradient norms."""

    def __init__(
        self,
        initial_clip: float = 1.0,
        adaptation_rate: float = 0.01,
        min_clip: float = 0.1,
        max_clip: float = 10.0,
    ):
        self.clip_value = initial_clip
        self.adaptation_rate = adaptation_rate
        self.min_clip = min_clip
        self.max_clip = max_clip
        self.grad_norm_ema = initial_clip

    def __call__(self, model: nn.Module) -> float:
        """Clip gradients and return the gradient norm measured before clipping."""
        grads = [p.grad.detach() for p in model.parameters() if p.grad is not None]
        if not grads:
            return 0.0

        # Global L2 norm over all parameters (a single host sync)
        total_norm = torch.stack([g.norm(2) for g in grads]).norm(2).item()

        # Update the EMA, ignoring non-finite norms so one bad step cannot poison it
        if math.isfinite(total_norm):
            self.grad_norm_ema = (
                (1 - self.adaptation_rate) * self.grad_norm_ema
                + self.adaptation_rate * total_norm
            )

        # Allow norms up to 2x the recent average, within [min_clip, max_clip]
        self.clip_value = min(
            self.max_clip,
            max(self.min_clip, 2.0 * self.grad_norm_ema)
        )

        # Apply clipping
        torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=self.clip_value
        )

        return total_norm


# Training loop with gradient monitoring
def train_step_with_monitoring(
    model: nn.Module,
    batch: torch.Tensor,
    optimizer: torch.optim.Optimizer,
    diffusion,
    grad_monitor: "GradientMonitor",
    grad_clipper: AdaptiveGradientClipper,
    scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> Dict[str, float]:
    """Training step with full gradient monitoring."""

    model.train()
    optimizer.zero_grad(set_to_none=True)

    # Forward pass with mixed precision
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        t = torch.randint(
            0, diffusion.timesteps,
            (batch.shape[0],),
            device=batch.device
        )
        noise = torch.randn_like(batch)
        x_noisy = diffusion.q_sample(batch, t, noise)
        pred_noise = model(x_noisy, t)
        loss = nn.functional.mse_loss(pred_noise, noise)

    # Check for NaN/Inf loss
    if not torch.isfinite(loss):
        warnings.warn("Non-finite loss detected! Skipping update.")
        return {"loss": float("nan"), "skipped": True}

    # Backward pass
    if scaler is not None:
        scaler.scale(loss).backward()
        # Unscale in place so monitoring and clipping see the true gradients
        scaler.unscale_(optimizer)
    else:
        loss.backward()

    # Monitor gradients
    grad_stats = grad_monitor.check_gradients()

    # Skip the update if any gradient is NaN/Inf
    if grad_stats["nan_params"] > 0:
        if scaler is not None:
            # Let the scaler record the overflow, shrink its scale, and reset its
            # per-step state; otherwise the next unscale_() call raises an error
            scaler.update()
        optimizer.zero_grad(set_to_none=True)
        return {"loss": loss.item(), "skipped": True, **grad_stats}

    # Clip gradients
    grad_norm = grad_clipper(model)

    # Optimizer step
    if scaler is not None:
        scaler.step(optimizer)
        scaler.update()
    else:
        optimizer.step()

    return {
        "loss": loss.item(),
        "grad_norm": grad_norm,
        "clip_value": grad_clipper.clip_value,
        "skipped": False,
        **grad_stats,
    }
```
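
Both clipping paths in this chapter rely on torch.nn.utils.clip_grad_norm_. Two of its properties matter here: it rescales all gradients by one common factor, so the update direction is preserved, and it returns the total norm measured before clipping, which is the value worth logging. The toy example below uses a single parameter whose gradient is [3, 4] (norm 5).

```python
import torch

# A single parameter whose gradient is [3, 4]: its L2 norm is 5
param = torch.nn.Parameter(torch.zeros(2))
param.grad = torch.tensor([3.0, 4.0])

# clip_grad_norm_ rescales all gradients together (preserving their direction)
# and returns the total norm measured *before* clipping
norm_before = torch.nn.utils.clip_grad_norm_([param], max_norm=1.0)

print(norm_before.item())   # 5.0
print(param.grad)           # approximately [0.6, 0.8]: same direction, norm 1
```

With AMP, call scaler.unscale_(optimizer) before clipping, as in train_step_with_monitoring; otherwise the threshold is compared against gradients that are still multiplied by the loss scale.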

Zero-Initialize Output Layer

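A crucial but often overlooked technique is to initialize the final output layer to zero, so the model predicts $\hat{\epsilon} = 0$ at the start of training. The initial loss is then about 1 (the variance of the unit-Gaussian target noise) instead of a large, arbitrary value produced by a randomly initialized output head, and the first updates stay small and well-behaved. On the very first step only the output layer receives gradient; the rest of the network starts learning as soon as those weights move away from zero. This single change can prevent many early training failures.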
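
The effect is easy to verify on a toy network. In this sketch the two-layer "network" (hidden conv plus output conv) is hypothetical and stands in for a U-Net; the point is what happens at step 0 when the output conv is zero-initialized.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-layer network: a hidden conv followed by a zero-initialized output conv
hidden = nn.Conv2d(3, 16, 3, padding=1)
out = nn.Conv2d(16, 3, 3, padding=1)
nn.init.zeros_(out.weight)
nn.init.zeros_(out.bias)

x_noisy = torch.randn(4, 3, 32, 32)
noise = torch.randn(4, 3, 32, 32)

pred = out(nn.functional.silu(hidden(x_noisy)))
loss = nn.functional.mse_loss(pred, noise)
loss.backward()

print(pred.abs().max().item())                     # 0.0: the model predicts zero noise
print(loss.item())                                 # close to 1.0, the variance of the target
print(out.weight.grad.abs().sum().item() > 0)      # True: the output layer learns first
print(hidden.weight.grad.abs().sum().item() == 0)  # True: earlier layers get zero gradient at step 0
```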

Mode Collapse and Diversity

Understanding Mode Collapse

Mode collapse in diffusion models appears as reduced diversity in generated samples. Unlike GANs, diffusion models rarely suffer from severe mode collapse, but subtle forms can still occur:

```python
from collections import Counter

import numpy as np
import torch
import torch.nn as nn
from scipy.spatial.distance import pdist
from sklearn.cluster import KMeans


class DiversityAnalyzer:
    """Analyze diversity of generated samples."""

    def __init__(self, feature_extractor: nn.Module, device: str = "cuda"):
        self.feature_extractor = feature_extractor.to(device).eval()
        self.device = device

    @torch.no_grad()
    def extract_features(self, images: torch.Tensor) -> np.ndarray:
        """Extract one flat feature vector per image."""
        features = self.feature_extractor(images.to(self.device))
        return features.flatten(start_dim=1).cpu().numpy()

    def compute_diversity_metrics(
        self,
        generated_features: np.ndarray,
        real_features: np.ndarray,
    ) -> dict:
        """Compute various diversity metrics."""

        metrics = {}

        # 1. Feature spread (variance in feature space)
        gen_var = np.var(generated_features, axis=0).mean()
        real_var = np.var(real_features, axis=0).mean()
        metrics["variance_ratio"] = gen_var / (real_var + 1e-8)

        # 2. Pairwise distances
        gen_dists = self._pairwise_distances(generated_features)
        real_dists = self._pairwise_distances(real_features)
        metrics["mean_pairwise_distance"] = gen_dists.mean()
        metrics["distance_ratio"] = gen_dists.mean() / (real_dists.mean() + 1e-8)

        # 3. Modes via clustering. Cluster the *real* features so that clusters
        #    approximate data modes, then see where generated samples land.
        n_clusters = max(2, min(50, len(real_features) // 10))
        kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
        real_clusters = kmeans.fit_predict(real_features)
        gen_clusters = kmeans.predict(generated_features)

        # Fraction of clusters that receive at least one generated sample
        cluster_counts = Counter(gen_clusters)
        metrics["mode_coverage"] = len(cluster_counts) / n_clusters

        # Check for dominant modes
        max_cluster_frac = max(cluster_counts.values()) / len(generated_features)
        metrics["max_mode_fraction"] = max_cluster_frac

        # 4. Coverage: fraction of real-data modes that generated samples reach
        gen_modes = set(gen_clusters)
        real_modes = set(real_clusters)
        metrics["coverage"] = len(gen_modes & real_modes) / len(real_modes)

        return metrics

    def _pairwise_distances(
        self, features: np.ndarray, max_samples: int = 1000, seed: int = 0
    ) -> np.ndarray:
        """Pairwise L2 distances between up to max_samples randomly chosen rows."""
        rng = np.random.default_rng(seed)
        n = min(max_samples, len(features))
        idx = rng.choice(len(features), n, replace=False)

        # pdist returns the condensed upper triangle (pairs i < j) without
        # materializing an n x n x feature_dim difference array
        return pdist(features[idx], metric="euclidean")

    def diagnose_diversity(self, metrics: dict) -> str:
        """Provide diagnosis based on diversity metrics."""
        issues = []

        if metrics["variance_ratio"] < 0.5:
            issues.append(
                "Low feature variance - samples may lack diversity. "
                "Try: longer training, more data augmentation"
            )

        if metrics["max_mode_fraction"] > 0.3:
            issues.append(
                f"Mode concentration detected ({metrics['max_mode_fraction']:.1%} "
                "in single mode). Try: check data balance, increase model capacity"
            )

        if metrics["coverage"] < 0.7:
            issues.append(
                f"Low mode coverage ({metrics['coverage']:.1%}). "
                "Some data modes not represented. Try: longer training, "
                "verify data preprocessing"
            )

        if metrics["distance_ratio"] < 0.8:
            issues.append(
                "Samples too similar. Try: increase sampling temperature, "
                "check for EMA decay issues"
            )

        return "\n".join(issues) if issues else "Diversity looks good"


def prevent_mode_collapse(config: dict) -> dict:
    """Return recommended settings that help prevent mode collapse."""

    recommendations = {
        # Use dropout during training
        "dropout": 0.1,

        # Don't use too aggressive EMA
        "ema_decay": 0.9999,  # Not 0.99999

        # Sufficient model capacity
        "min_channels": 128,

        # Data augmentation
        "augmentation": {
            "horizontal_flip": True,
            "random_crop": True,
            "color_jitter": 0.1,
        },

        # Balanced sampling
        "balanced_sampling": True,

        # Sufficient training
        "min_steps": 100000,
    }

    return recommendations
```
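
A quick way to reason about ema_decay: an EMA with decay d averages over roughly 1/(1 - d) recent steps, and, assuming the EMA copy starts from the initial weights, those initial weights still carry a share of d^n after n updates. The arithmetic below shows why 0.99999 is risky for a run of about 100,000 steps (the min_steps value above): roughly e^-1, about 37%, of the EMA would still come from the untrained starting point.

```python
def ema_horizon(decay: float) -> float:
    """Approximate number of recent steps an EMA with this decay averages over."""
    return 1.0 / (1.0 - decay)


def initial_weight_share(decay: float, steps: int) -> float:
    """Share of the EMA still coming from the starting weights after `steps` updates
    (assuming the EMA copy is initialized to the starting weights)."""
    return decay ** steps


for decay in (0.999, 0.9999, 0.99999):
    print(
        f"decay={decay}: horizon ~ {ema_horizon(decay):,.0f} steps, "
        f"initial-weight share after 100k steps = {initial_weight_share(decay, 100_000):.2e}"
    )
```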

| Symptom | Likely Cause | Solution |
|---|---|---|
| All samples look similar | EMA decay too high | Reduce EMA decay to 0.9999 |
| Missing classes/modes | Imbalanced dataset | Use balanced sampling |
| Blurry but uniform | Undertrained | Train longer, increase LR |
| Some modes overrepresented | Model capacity | Increase channels/layers |
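
Balanced sampling can be implemented with PyTorch's WeightedRandomSampler. This is a minimal sketch that assumes an integer class label is available for every training example (the labels below are hypothetical): each example is weighted by the inverse frequency of its class, so a 90/10 split is drawn roughly 50/50.

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler


def make_balanced_sampler(labels, num_samples: int, seed: int = 0) -> WeightedRandomSampler:
    """Draw examples with probability inversely proportional to their class frequency."""
    labels = torch.as_tensor(labels)
    class_counts = torch.bincount(labels)
    # Every class gets the same total weight, split evenly among its examples
    weights = 1.0 / class_counts[labels].double()
    generator = torch.Generator().manual_seed(seed)
    return WeightedRandomSampler(
        weights, num_samples=num_samples, replacement=True, generator=generator
    )


# Hypothetical, heavily imbalanced labels: 900 examples of class 0, 100 of class 1
labels = [0] * 900 + [1] * 100
sampler = make_balanced_sampler(labels, num_samples=10_000)
counts = Counter(labels[i] for i in sampler)
print(counts)  # roughly 5,000 draws per class
```

In training, pass it as DataLoader(dataset, batch_size=..., sampler=sampler) and leave shuffle unset, since the sampler already randomizes the order. Full balancing also shifts the learned distribution toward equal class frequencies; milder weights (for example, inverse square root of the class count) are a common compromise.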

Memory Optimization

Gradient Checkpointing

Gradient checkpointing trades computation for memory by not storing all intermediate activations:

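```python
import math
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


def timestep_embedding(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Sinusoidal timestep embedding of shape (batch, dim)."""
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000.0)
        * torch.arange(half, device=t.device, dtype=torch.float32) / half
    )
    args = t.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


class ResBlock(nn.Module):
    """Residual block with optional checkpointing (for standalone use)."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        time_channels: int = 512,
        use_checkpoint: bool = False,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint

        self.norm1 = nn.GroupNorm(32, in_channels)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.time_proj = nn.Linear(time_channels, out_channels)
        self.norm2 = nn.GroupNorm(32, out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)

        if in_channels != out_channels:
            self.skip = nn.Conv2d(in_channels, out_channels, 1)
        else:
            self.skip = nn.Identity()

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        if self.use_checkpoint and self.training:
            return checkpoint(self._forward, x, emb, use_reentrant=False)
        return self._forward(x, emb)

    def _forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        h = nn.functional.silu(h)
        h = self.conv1(h)

        # Add time embedding
        h = h + self.time_proj(emb)[:, :, None, None]

        h = self.norm2(h)
        h = nn.functional.silu(h)
        h = self.conv2(h)

        return h + self.skip(x)


class Downsample(nn.Module):
    """Strided 3x3 convolution that halves the spatial resolution."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, stride=2, padding=1)

    def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.conv(x)


class Upsample(nn.Module):
    """Nearest-neighbor upsampling followed by a 3x3 convolution."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)

    def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.conv(nn.functional.interpolate(x, scale_factor=2, mode="nearest"))


class AttentionBlock(nn.Module):
    """Single-head spatial self-attention using PyTorch's fused attention kernel."""

    def __init__(self, channels: int):
        super().__init__()
        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x: torch.Tensor, emb: Optional[torch.Tensor] = None) -> torch.Tensor:
        b, c, h, w = x.shape
        q, k, v = self.qkv(self.norm(x)).reshape(b, 3, c, h * w).unbind(dim=1)
        # scaled_dot_product_attention expects (batch, tokens, channels)
        out = nn.functional.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
        )
        return x + self.proj(out.transpose(1, 2).reshape(b, c, h, w))


class MemoryEfficientUNet(nn.Module):
    """U-Net with gradient checkpointing for memory efficiency.

    Input height and width must be divisible by 2 ** (len(channel_mult) - 1).
    """

    def __init__(
        self,
        in_channels: int = 3,
        model_channels: int = 128,
        num_res_blocks: int = 2,
        attention_resolutions: tuple = (16, 8),  # unused: attention only in the middle
        channel_mult: tuple = (1, 2, 4, 8),
        use_checkpoint: bool = True,
    ):
        super().__init__()
        self.use_checkpoint = use_checkpoint
        self.model_channels = model_channels
        time_dim = model_channels * 4

        self.time_embed = nn.Sequential(
            nn.Linear(model_channels, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim),
        )
        self.input_conv = nn.Conv2d(in_channels, model_channels, 3, padding=1)

        # Build encoder blocks, recording the channel count of every skip tensor.
        # Blocks do not checkpoint themselves; the U-Net wraps each one instead,
        # which avoids nested checkpointing.
        self.encoder_blocks = nn.ModuleList()
        skip_channels = [model_channels]
        ch = model_channels
        for level, mult in enumerate(channel_mult):
            out_ch = model_channels * mult
            for _ in range(num_res_blocks):
                self.encoder_blocks.append(ResBlock(ch, out_ch, time_dim))
                ch = out_ch
                skip_channels.append(ch)
            if level != len(channel_mult) - 1:
                self.encoder_blocks.append(Downsample(ch))
                skip_channels.append(ch)

        # Middle blocks
        self.middle_blocks = nn.ModuleList([
            ResBlock(ch, ch, time_dim),
            AttentionBlock(ch),
            ResBlock(ch, ch, time_dim),
        ])

        # Build decoder blocks; each ResBlock consumes one skip tensor
        self.decoder_blocks = nn.ModuleList()
        for level, mult in reversed(list(enumerate(channel_mult))):
            out_ch = model_channels * mult
            for _ in range(num_res_blocks + 1):
                self.decoder_blocks.append(
                    ResBlock(ch + skip_channels.pop(), out_ch, time_dim)
                )
                ch = out_ch
            if level != 0:
                self.decoder_blocks.append(Upsample(ch))

        self.out_norm = nn.GroupNorm(32, ch)
        self.out = nn.Conv2d(ch, in_channels, 3, padding=1)

    def _run(self, block: nn.Module, h: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Run one block, checkpointing it only while training."""
        if self.use_checkpoint and self.training:
            # Activations inside `block` are recomputed in backward instead of stored
            return checkpoint(block, h, emb, use_reentrant=False)
        return block(h, emb)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        emb = self.time_embed(timestep_embedding(t, self.model_channels))

        h = self.input_conv(x)
        hs = [h]
        for block in self.encoder_blocks:
            h = self._run(block, h, emb)
            hs.append(h)

        # Middle blocks
        for block in self.middle_blocks:
            h = self._run(block, h, emb)

        # Decoder with skip connections
        for block in self.decoder_blocks:
            if isinstance(block, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = self._run(block, h, emb)

        return self.out(nn.functional.silu(self.out_norm(h)))
```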
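
To see the compute-for-memory trade concretely, the sketch below uses a hypothetical CountingBlock (not part of the U-Net above) that counts how often its forward pass runs. Under torch.utils.checkpoint.checkpoint the forward executes twice, once normally and once during backward to rebuild the discarded activations, while the resulting gradients match the non-checkpointed run.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class CountingBlock(nn.Module):
    """Hypothetical block that counts how many times its forward pass runs."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.calls = 0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.calls += 1
        return nn.functional.silu(self.conv(x))


torch.manual_seed(0)
block = CountingBlock(8)
x = torch.randn(2, 8, 16, 16)

# Plain run: intermediate activations are kept for backward
block(x).sum().backward()
plain_grad = block.conv.weight.grad.clone()
plain_calls = block.calls

# Checkpointed run: activations inside the block are dropped and recomputed in backward
block.zero_grad()
block.calls = 0
checkpoint(block, x, use_reentrant=False).sum().backward()
ckpt_grad = block.conv.weight.grad.clone()
ckpt_calls = block.calls

print(plain_calls, ckpt_calls)                # 1 2
print(torch.allclose(plain_grad, ckpt_grad))  # True
```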

Memory-Efficient Training Configuration

```python
import gc

import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler


class MemoryOptimizedTrainer:
    """Training configuration optimized for memory."""

    def __init__(
        self,
        model: nn.Module,
        target_batch_size: int = 64,
        device_batch_size: int = 8,  # micro-batch size that fits in GPU memory
        use_amp: bool = True,
        use_compile: bool = True,  # PyTorch 2.0+
    ):
        self.model = model
        self.target_batch_size = target_batch_size
        self.device_batch_size = device_batch_size
        self.accumulation_steps = target_batch_size // device_batch_size
        self.use_amp = use_amp

        # Mixed precision
        self.scaler = GradScaler() if use_amp else None

        # Compile model for better performance (PyTorch 2.0+). The default mode is
        # used here: "reduce-overhead" relies on CUDA graphs, which can use more memory.
        if use_compile and hasattr(torch, "compile"):
            self.model = torch.compile(model)

    @staticmethod
    def estimate_memory(
        model: nn.Module,
        batch_size: int,
        image_size: int,
        use_amp: bool = True,
    ) -> dict:
        """Rough estimate of training memory; measure the real peak with
        torch.cuda.max_memory_allocated()."""

        # Count parameters
        params = sum(p.numel() for p in model.parameters())

        # Under autocast, weights and gradients stay in FP32; only activations shrink
        param_bytes = params * 4
        grad_bytes = params * 4

        # Optimizer states (Adam/AdamW keep two FP32 states per parameter)
        optim_bytes = params * 2 * 4

        # Activation memory: crude heuristic of ~10x the parameter count per sample
        # for a U-Net. It ignores image_size, although activations grow with resolution.
        activation_bytes = params * 10 * batch_size * (2 if use_amp else 4)

        total_bytes = param_bytes + activation_bytes + grad_bytes + optim_bytes

        return {
            "parameters_mb": param_bytes / 1e6,
            "activations_mb": activation_bytes / 1e6,
            "gradients_mb": grad_bytes / 1e6,
            "optimizer_mb": optim_bytes / 1e6,
            "total_mb": total_bytes / 1e6,
            "total_gb": total_bytes / 1e9,
        }

    @staticmethod
    def optimize_for_memory() -> dict:
        """Return memory optimization settings."""
        return {
            # Enable memory efficient attention
            "attention": "flash",  # or "xformers"

            # Use gradient checkpointing
            "gradient_checkpointing": True,

            # Use mixed precision
            "mixed_precision": "fp16",  # or "bf16" on Ampere and later

            # Efficient data loading
            "num_workers": 4,
            "pin_memory": True,
            "prefetch_factor": 2,

            # Clear cache periodically
            "empty_cache_freq": 100,  # steps
        }

    def train_step(
        self,
        batch: torch.Tensor,
        optimizer: torch.optim.Optimizer,
        diffusion,
        step: int,
    ) -> float:
        """Memory-optimized training step with gradient accumulation.

        Returns the mean loss over the micro-batches.
        """

        # Split batch for accumulation; use the actual number of micro-batches
        # in case len(batch) differs from target_batch_size
        micro_batches = batch.split(self.device_batch_size)
        num_micro = len(micro_batches)
        total_loss = 0.0

        for micro_batch in micro_batches:
            # Use AMP for forward pass
            with autocast(enabled=self.use_amp):
                t = torch.randint(
                    0, diffusion.timesteps,
                    (micro_batch.shape[0],),
                    device=micro_batch.device
                )
                noise = torch.randn_like(micro_batch)
                x_noisy = diffusion.q_sample(micro_batch, t, noise)
                pred = self.model(x_noisy, t)
                loss = nn.functional.mse_loss(pred, noise)
                loss = loss / num_micro

            # Backward pass (gradients accumulate across micro-batches)
            if self.scaler is not None:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            total_loss += loss.item()

        # Optimizer step after accumulation
        if self.scaler is not None:
            self.scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.scaler.step(optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            optimizer.step()

        optimizer.zero_grad(set_to_none=True)  # frees gradient memory until next backward

        # Periodic cache clearing (returns cached, unused blocks to the driver)
        if step % 100 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Each micro-batch loss was divided by num_micro, so the sum is the mean loss
        return total_loss
```
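
Gradient accumulation works because gradients add up across backward calls: dividing each micro-batch loss by the number of micro-batches makes the accumulated gradient equal to the full-batch gradient, for a mean-reduced loss and equal-sized micro-batches. A minimal check with a hypothetical linear model:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
inputs = torch.randn(8, 4)
targets = torch.randn(8, 1)

# Full batch: one backward pass on the mean loss over all 8 examples
model.zero_grad()
nn.functional.mse_loss(model(inputs), targets).backward()
full_grad = model.weight.grad.clone()

# Accumulated: 4 micro-batches of 2, each loss scaled by 1 / num_micro
model.zero_grad()
micro_inputs, micro_targets = inputs.split(2), targets.split(2)
num_micro = len(micro_inputs)
for xb, yb in zip(micro_inputs, micro_targets):
    (nn.functional.mse_loss(model(xb), yb) / num_micro).backward()
accum_grad = model.weight.grad.clone()

print(torch.allclose(full_grad, accum_grad, atol=1e-6))  # True
```

If the last micro-batch is smaller, dividing by the count weights its examples slightly more than the others; scaling each loss by len(micro_batch) / len(batch) removes that bias.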

Flash Attention

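On modern GPUs (Ampere and later), use Flash Attention to cut attention memory substantially: it computes attention in tiles without ever materializing the full tokens × tokens score matrix, so memory grows linearly rather than quadratically with the number of tokens. It is built into PyTorch 2.0+ via torch.nn.functional.scaled_dot_product_attention, which selects a FlashAttention kernel automatically when the device, dtype (FP16/BF16), and input shapes allow it.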
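
Switching to the fused kernel is usually a one-line change. The sketch below compares a naive attention implementation with scaled_dot_product_attention on random CPU tensors (shapes are arbitrary); both compute softmax(QK^T / sqrt(d)) V, but only the naive version explicitly builds the full score matrix.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)
# (batch, heads, tokens, head_dim); for a U-Net feature map, tokens = H * W
q, k, v = (torch.randn(2, 4, 64, 32) for _ in range(3))

# Naive attention materializes the full (tokens x tokens) score matrix
scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
naive = scores.softmax(dim=-1) @ v

# Fused attention: PyTorch picks a backend (FlashAttention, memory-efficient,
# or the plain math kernel) based on device, dtype, and input shapes
fused = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(naive, fused, atol=1e-4))  # True
```

On CPU this only checks numerical equivalence; the memory savings appear on supported GPUs with FP16/BF16 inputs. Recent PyTorch releases also provide the torch.nn.attention.sdpa_kernel context manager to restrict which backends may be used, which helps confirm that the flash kernel is actually selected.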

Training Instability

Learning Rate Issues

Learning rate problems are often the root cause of training instability:

```python
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    LinearLR,
    SequentialLR,
)


class StableTrainingConfig:
    """Configuration for stable diffusion model training."""

    @staticmethod
    def get_optimizer(
        model: nn.Module,
        learning_rate: float = 1e-4,
        weight_decay: float = 0.01,
        betas: tuple = (0.9, 0.999),
    ) -> torch.optim.Optimizer:
        """Get optimizer with appropriate settings."""

        # Separate parameters for weight decay
        decay_params = []
        no_decay_params = []

        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            # Don't apply weight decay to biases and normalization parameters
            # (all 1-D tensors, plus anything inside a norm layer)
            if param.ndim <= 1 or "norm" in name:
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        param_groups = [
            {"params": decay_params, "weight_decay": weight_decay},
            {"params": no_decay_params, "weight_decay": 0.0},
        ]

        return AdamW(
            param_groups,
            lr=learning_rate,
            betas=betas,
            eps=1e-8,
        )

    @staticmethod
    def get_scheduler(
        optimizer: torch.optim.Optimizer,
        warmup_steps: int = 5000,
        total_steps: int = 500000,
        min_lr_ratio: float = 0.1,
    ):
        """Get learning rate scheduler with warmup."""

        # Read the base LR first: creating LinearLR immediately rescales each
        # group's "lr" by start_factor, which would make eta_min 100x too small
        base_lr = optimizer.param_groups[0]["lr"]

        # Linear warmup
        warmup_scheduler = LinearLR(
            optimizer,
            start_factor=0.01,
            end_factor=1.0,
            total_iters=warmup_steps,
        )

        # Cosine decay after warmup
        decay_scheduler = CosineAnnealingLR(
            optimizer,
            T_max=total_steps - warmup_steps,
            eta_min=base_lr * min_lr_ratio,
        )

        # Combine schedulers
        scheduler = SequentialLR(
            optimizer,
            schedulers=[warmup_scheduler, decay_scheduler],
            milestones=[warmup_steps],
        )

        return scheduler

    @staticmethod
    def diagnose_lr_issues(loss_history: list) -> str:
        """Diagnose learning rate issues from the loss curve.

        Per-step diffusion losses are noisy because each batch samples different
        timesteps, so pass losses averaged over blocks of steps and treat the
        results as rough heuristics.
        """

        if len(loss_history) < 100:
            return "Need more data to diagnose"

        issues = []

        # Check for oscillation: sign changes between consecutive differences
        recent = loss_history[-100:]
        diffs = [recent[i + 1] - recent[i] for i in range(len(recent) - 1)]
        sign_changes = sum(
            1 for i in range(len(diffs) - 1)
            if diffs[i] * diffs[i + 1] < 0
        )

        # Pure i.i.d. noise already flips sign about 2/3 of the time (~65 of 98),
        # so only flag counts clearly above that level
        if sign_changes > 75:
            issues.append(
                "High oscillation detected. Learning rate may be too high. "
                "Try reducing by 2-5x."
            )

        # Check for plateau
        first_half = sum(recent[:50]) / 50
        second_half = sum(recent[50:]) / 50

        if abs(first_half - second_half) / max(abs(first_half), 1e-12) < 0.01:
            issues.append(
                "Loss plateau detected. Try: increase LR, check gradients, "
                "verify data pipeline."
            )

        # Check for divergence
        if recent[-1] > recent[0] * 1.5:
            issues.append(
                "Loss increasing! Possible divergence. "
                "Immediately reduce LR or restart from checkpoint."
            )

        return "\n".join(issues) if issues else "Learning rate seems appropriate"


class TrainingStabilizer:
    """Apply various stabilization techniques."""

    @staticmethod
    def apply_spectral_norm(model: nn.Module) -> nn.Module:
        """Apply spectral normalization to conv and linear layers."""
        for name, module in model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                # Don't apply to the output layer (any module whose name contains "out")
                if "out" not in name:
                    nn.utils.spectral_norm(module)
        return model

    @staticmethod
    def apply_ema(
        model: nn.Module,
        ema_model: nn.Module,
        decay: float = 0.9999,
    ) -> None:
        """Update EMA model parameters in place.

        Only parameters are averaged; copy buffers (e.g., BatchNorm running
        statistics) separately if the model has any.
        """
        with torch.no_grad():
            for ema_param, param in zip(
                ema_model.parameters(),
                model.parameters()
            ):
                ema_param.mul_(decay).add_(param, alpha=1 - decay)

    @staticmethod
    def detect_nan(tensor: torch.Tensor, name: str = "tensor") -> bool:
        """Check for NaN or Inf values and warn."""
        if torch.isnan(tensor).any():
            print(f"WARNING: NaN detected in {name}")
            return True
        if torch.isinf(tensor).any():
            print(f"WARNING: Inf detected in {name}")
            return True
        return False
```
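
As an alternative to chaining two schedulers with SequentialLR, the same warmup-plus-cosine shape can be written as one multiplier function passed to LambdaLR, which is easy to unit-test in isolation. This is a sketch with defaults chosen to mirror get_scheduler above (1% start factor, 5,000 warmup steps, decay to 10% of the base LR), not a replacement mandated by the configuration.

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def warmup_cosine(step: int, warmup_steps: int = 5000, total_steps: int = 500_000,
                  start_factor: float = 0.01, min_lr_ratio: float = 0.1) -> float:
    """LR multiplier: linear warmup from start_factor to 1, then cosine decay to min_lr_ratio."""
    if step < warmup_steps:
        return start_factor + (1.0 - start_factor) * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr_ratio + (1.0 - min_lr_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))


model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
# LambdaLR sets lr = base_lr * warmup_cosine(step); call scheduler.step() after optimizer.step()
scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine)

for step in (0, 2500, 5000, 250_000, 500_000):
    print(step, round(warmup_cosine(step), 4))
```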

| Symptom | Diagnosis | Fix |
|---|---|---|
| Loss spikes randomly | LR too high during warmup | Slower warmup (10k steps) |
| Loss stuck after warmup | LR too low | Increase base LR 2-5x |
| Gradual divergence | No LR decay | Add cosine/linear decay |
| Training unstable after a long time | Numerical issues | Use BF16 instead of FP16 |
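
The last row comes down to number formats. FP16 has a far smaller exponent range than BF16 (largest finite value 65504 versus about 3.4 × 10^38), so large activations or unscaled gradients can overflow to inf in FP16; BF16 keeps FP32's range at the cost of fewer mantissa bits, and usually needs no loss scaling. A quick check with torch.finfo:

```python
import torch

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    info = torch.finfo(dtype)
    # max: largest finite value; eps: gap between 1.0 and the next representable number
    print(f"{str(dtype):15s} max={info.max:.3e}  eps={info.eps:.3e}")

value = torch.tensor(70000.0)
as_fp16 = value.to(torch.float16)   # above FP16's max (65504), so it becomes inf
as_bf16 = value.to(torch.bfloat16)  # stays finite, rounded to BF16's coarser spacing
print(as_fp16, as_bf16)
```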

Debugging Toolkit

🐍python
import torch
import torch.nn as nn
from typing import Dict, Tuple
import numpy as np

class DiffusionDebugger:
    """Comprehensive debugging toolkit for diffusion models.

    Assumes `diffusion` exposes `timesteps`, `betas`, `alphas_cumprod`,
    and `q_sample(x0, t, noise)`, and that `model(x_t, t)` predicts noise.
    """

    def __init__(self, model: nn.Module, diffusion):
        self.model = model
        self.diffusion = diffusion
        self.debug_info = {}

    def run_diagnostics(
        self,
        sample_batch: torch.Tensor,
        verbose: bool = True,
    ) -> Dict:
        """Run comprehensive diagnostics."""
        # Remember the current mode so diagnostics don't leave the model in eval()
        was_training = self.model.training
        results = {}

        # 1. Model output sanity check
        results["model_output"] = self._check_model_output(sample_batch)

        # 2. Gradient flow check
        results["gradients"] = self._check_gradient_flow(sample_batch)

        # 3. Timestep response check
        results["timestep_response"] = self._check_timestep_response(sample_batch)

        # 4. Noise schedule check
        results["noise_schedule"] = self._check_noise_schedule()

        # 5. Generate diagnostic samples with the same (C, H, W) as the data
        results["sample_quality"] = self._check_sample_quality(sample_batch)

        self.model.train(was_training)

        if verbose:
            self._print_report(results)

        return results

    def _check_model_output(self, batch: torch.Tensor) -> Dict:
        """Check model output statistics."""
        self.model.eval()

        with torch.no_grad():
            t = torch.randint(
                0, self.diffusion.timesteps,
                (batch.shape[0],),
                device=batch.device
            )
            noise = torch.randn_like(batch)
            x_noisy = self.diffusion.q_sample(batch, t, noise)
            pred = self.model(x_noisy, t)

        stats = {
            "mean": pred.mean().item(),
            "std": pred.std().item(),
            "min": pred.min().item(),
            "max": pred.max().item(),
            "has_nan": torch.isnan(pred).any().item(),
            "has_inf": torch.isinf(pred).any().item(),
        }

        # Expected: mean near 0, std near 1 (the target is unit Gaussian noise)
        stats["healthy"] = (
            abs(stats["mean"]) < 0.5 and
            0.5 < stats["std"] < 2.0 and
            not stats["has_nan"] and
            not stats["has_inf"]
        )

        return stats

    def _check_gradient_flow(self, batch: torch.Tensor) -> Dict:
        """Check gradient flow through model."""
        self.model.train()
        # Clear stale gradients so only this backward pass is measured
        self.model.zero_grad(set_to_none=True)

        t = torch.randint(
            0, self.diffusion.timesteps,
            (batch.shape[0],),
            device=batch.device
        )
        noise = torch.randn_like(batch)
        x_noisy = self.diffusion.q_sample(batch, t, noise)
        pred = self.model(x_noisy, t)
        loss = nn.functional.mse_loss(pred, noise)
        loss.backward()

        zero_grad_layers = []
        unused_layers = []      # trainable, but received no gradient (disconnected?)
        large_grad_layers = []  # exploding (> 100) or non-finite gradient norms

        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            if param.grad is None:
                unused_layers.append(name)
                continue
            grad_norm = param.grad.norm().item()
            # NaN compares False to everything, so test finiteness explicitly
            if not np.isfinite(grad_norm) or grad_norm > 100:
                large_grad_layers.append((name, grad_norm))
            elif grad_norm == 0:
                zero_grad_layers.append(name)

        grad_stats = {
            "zero_grad_layers": zero_grad_layers,
            "unused_layers": unused_layers,
            "large_grad_layers": large_grad_layers,
        }
        grad_stats["healthy"] = (
            len(zero_grad_layers) == 0 and
            len(unused_layers) == 0 and
            len(large_grad_layers) == 0
        )

        self.model.zero_grad(set_to_none=True)
        return grad_stats

    def _check_timestep_response(self, batch: torch.Tensor) -> Dict:
        """Check model response across different timesteps."""
        self.model.eval()

        # Spread checks over the schedule ({0, 100, 500, 900, 999} when T = 1000)
        T = self.diffusion.timesteps
        timesteps_to_check = sorted({0, T // 10, T // 2, 9 * T // 10, T - 1})
        responses = {}

        with torch.no_grad():
            x = batch[:4]  # Small batch
            noise = torch.randn_like(x)

            for t_val in timesteps_to_check:
                t = torch.full((x.shape[0],), t_val, device=x.device, dtype=torch.long)
                x_noisy = self.diffusion.q_sample(x, t, noise)
                pred = self.model(x_noisy, t)

                responses[t_val] = {
                    "pred_std": pred.std().item(),
                    "mse_to_noise": nn.functional.mse_loss(pred, noise).item(),
                }

        # Coarse heuristic: output statistics should vary with t. Because x_noisy
        # also changes with t, passing does not prove the time embedding is used.
        stds = [r["pred_std"] for r in responses.values()]
        stats = {
            "responses": responses,
            "std_variation": float(np.std(stds)),
            "healthy": bool(np.std(stds) > 0.01),
        }

        return stats

    def _check_noise_schedule(self) -> Dict:
        """Verify noise schedule properties."""
        betas = self.diffusion.betas.detach().cpu().numpy()
        alphas_cumprod = self.diffusion.alphas_cumprod.detach().cpu().numpy()

        stats = {
            "beta_min": float(betas.min()),
            "beta_max": float(betas.max()),
            "alpha_bar_start": float(alphas_cumprod[0]),
            "alpha_bar_end": float(alphas_cumprod[-1]),
            # alpha_bar must never increase with t
            "monotonic": bool(np.all(np.diff(alphas_cumprod) <= 0)),
        }

        # Healthy: nearly clean at t = 0, nearly pure noise at t = T
        stats["healthy"] = (
            stats["beta_min"] > 0 and
            stats["beta_max"] < 0.1 and
            stats["alpha_bar_start"] > 0.99 and
            stats["alpha_bar_end"] < 0.01 and
            stats["monotonic"]
        )

        return stats

    def _check_sample_quality(self, batch: torch.Tensor, num_samples: int = 4) -> Dict:
        """Generate samples and check basic quality."""
        self.model.eval()

        with torch.no_grad():
            # Quick sampling with fewer steps, matching the data's (C, H, W)
            samples = self._quick_sample(tuple(batch.shape[1:]), num_samples, steps=50)

        stats = {
            "mean": samples.mean().item(),
            "std": samples.std().item(),
            "min": samples.min().item(),
            "max": samples.max().item(),
            "finite": bool(torch.isfinite(samples).all()),
            "in_range": (
                samples.min().item() >= -1.5 and
                samples.max().item() <= 1.5
            ),
        }

        # The last DDIM step returns the clamped x0 estimate, so this mainly catches NaN/Inf
        stats["healthy"] = stats["finite"] and stats["in_range"]
        return stats

    def _quick_sample(
        self,
        sample_shape: Tuple[int, ...],
        num_samples: int,
        steps: int = 50,
    ) -> torch.Tensor:
        """Quick deterministic DDIM sampling (eta = 0) for debugging."""
        device = next(self.model.parameters()).device
        shape = (num_samples, *sample_shape)
        alphas_cumprod = self.diffusion.alphas_cumprod.to(device)

        x = torch.randn(shape, device=device)

        # Evenly spaced subsequence of timesteps, from most to least noisy
        step_size = max(1, self.diffusion.timesteps // steps)
        timesteps = list(range(0, self.diffusion.timesteps, step_size))[::-1]

        for i, t in enumerate(timesteps):
            t_tensor = torch.full((num_samples,), t, device=device, dtype=torch.long)
            pred_noise = self.model(x, t_tensor)

            alpha_bar = alphas_cumprod[t]
            alpha_bar_prev = (
                alphas_cumprod[timesteps[i + 1]]
                if i < len(timesteps) - 1
                else torch.tensor(1.0, device=device)  # final step returns x0
            )

            # DDIM update: estimate x0, then re-noise it to the previous timestep
            pred_x0 = (x - torch.sqrt(1 - alpha_bar) * pred_noise) / torch.sqrt(alpha_bar)
            pred_x0 = pred_x0.clamp(-1, 1)

            x = (
                torch.sqrt(alpha_bar_prev) * pred_x0 +
                torch.sqrt(1 - alpha_bar_prev) * pred_noise
            )

        return x

    def _print_report(self, results: Dict) -> None:
        """Print human-readable diagnostic report."""
        print("=" * 60)
        print("DIFFUSION MODEL DIAGNOSTIC REPORT")
        print("=" * 60)

        for category, stats in results.items():
            healthy = stats.get("healthy")
            if healthy is None:
                status = "UNKNOWN"
            else:
                status = "OK" if healthy else "ISSUE"
            print(f"\n{category.upper()}: [{status}]")

            for key, value in stats.items():
                if key != "healthy":
                    print(f"  {key}: {value}")

        print("\n" + "=" * 60)

        # Summary
        all_healthy = all(
            stats.get("healthy", True)
            for stats in results.values()
        )

        if all_healthy:
            print("All diagnostics passed!")
        else:
            print("Issues detected. Review the report above.")


# Usage example
def debug_training_issue(model, diffusion, sample_batch):
    """Debug a training issue."""
    debugger = DiffusionDebugger(model, diffusion)

    # Run full diagnostics
    results = debugger.run_diagnostics(sample_batch)

    # Get specific recommendations
    if not results["model_output"]["healthy"]:
        print("\nModel output issue detected!")
        print("Try: Check initialization, reduce model complexity")

    if not results["gradients"]["healthy"]:
        print("\nGradient issue detected!")
        print("Try: Gradient clipping, lower learning rate")

    if not results["timestep_response"]["healthy"]:
        print("\nTimestep conditioning issue!")
        print("Try: Check time embedding, verify conditioning mechanism")

    if not results["noise_schedule"]["healthy"]:
        print("\nNoise schedule issue!")
        print("Try: Check the beta range and that alpha_bar is near 0 at t = T")

    return results
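`DiffusionDebugger` never states what the `diffusion` argument must provide; reading its methods, it needs at least `timesteps`, `betas`, `alphas_cumprod`, and a `q_sample(x0, t, noise)` method. The sketch below is a minimal object satisfying that interface, assuming the DDPM linear β schedule (1e-4 to 0.02 over 1000 steps). The class name `LinearDiffusion` is hypothetical, and your own diffusion class may store these tensors differently; it also exposes `alphas` for completeness.

```python
import torch


class LinearDiffusion:
    """Hypothetical minimal forward process with a linear beta schedule."""

    def __init__(
        self,
        timesteps: int = 1000,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
    ):
        self.timesteps = timesteps
        # Build the schedule in float64 so the long cumulative product stays accurate
        betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
        alphas = 1.0 - betas
        self.betas = betas.float()
        self.alphas = alphas.float()
        self.alphas_cumprod = torch.cumprod(alphas, dim=0).float()

    def q_sample(
        self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor
    ) -> torch.Tensor:
        """Draw x_t = sqrt(a_bar_t) * x0 + sqrt(1 - a_bar_t) * noise."""
        # Index on t's device, then move to x0's device and broadcast over (C, H, W)
        a_bar = self.alphas_cumprod.to(t.device)[t].to(x0.device)
        a_bar = a_bar.reshape(-1, *([1] * (x0.dim() - 1)))
        return a_bar.sqrt() * x0 + (1.0 - a_bar).sqrt() * noise
```

With T = 1000, this schedule gives an ᾱ of about 4e-5 at the last step, comfortably below the 0.01 threshold in `_check_noise_schedule`. Shortening T to 100 without rescaling the betas leaves ᾱ at the last step around 0.36, which is exactly the kind of mismatch that check is meant to flag.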

Key Takeaways

  1. Zero-initialize output layers: This simple technique prevents many gradient issues at the start of training.
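  2. Use gradient checkpointing: It trades extra compute (recomputing activations during the backward pass) for much lower activation memory, which is often what lets a large model fit on limited GPU memory.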
  3. Monitor diversity: Track variance ratios and mode coverage to catch subtle collapse early.
  4. Warmup is critical: Use 5-10k step linear warmup to stabilize early training.
  5. Debug systematically: Use the diagnostic toolkit to identify issues rather than guessing.
Looking Ahead: With a fully trained model, we are ready to generate and evaluate samples. The next chapter covers the complete generation pipeline and evaluation metrics like FID, IS, and perceptual quality assessment.