Learning Objectives
By the end of this section, you will be able to:
- Organize visual inspection of generated samples to identify quality issues systematically
- Explore the latent space to understand model behavior and interpolation quality
- Identify common failure modes and their root causes
- Create comprehensive evaluation reports combining quantitative and qualitative analysis
The Big Picture
While quantitative metrics like FID provide numerical comparisons, they can miss important aspects of generation quality that humans notice immediately. Qualitative analysis complements metrics by revealing what the model has learned, where it struggles, and how it behaves in edge cases.
The Qualitative Perspective: A model with FID of 5.0 might produce beautiful samples 90% of the time but have severe artifacts in 10% of cases. Another model with FID of 7.0 might be more consistent but less impressive. Only visual inspection reveals these differences.
| Analysis Type | What It Reveals | When to Use |
|---|---|---|
| Random samples | Overall quality distribution | First check after training |
| Curated samples | Best-case capabilities | Paper figures, demos |
| Failure cases | Systematic problems | Debugging, improvement |
| Interpolations | Latent space smoothness | Understanding representations |
| Class-conditional | Per-class quality variations | Identifying weak classes |
Visual Inspection Techniques
Sample Grid Generation
Create organized grids of samples for systematic evaluation:
```python
1import torch
2import torchvision.utils as vutils
3import matplotlib.pyplot as plt
4from pathlib import Path
5from typing import Optional, List, Callable
6import numpy as np
7
class SampleInspector:
    """Tools for visual inspection of generated diffusion samples.

    Wraps a model and its sampler and writes inspection figures (sample
    grids, step-count comparisons, pixel statistics) into ``output_dir``.
    Samples are assumed to lie in [-1, 1] (see the grid normalization).
    """

    def __init__(
        self,
        model,
        sampler,
        output_dir: str,
        device: str = "cuda",
    ):
        """
        Args:
            model: Diffusion model; moved to ``device`` and set to eval mode.
            sampler: Object exposing ``sample(shape, num_steps=...)`` that
                returns a batch of images in [-1, 1].
            output_dir: Directory for saved figures (created if missing).
            device: Torch device string.
        """
        self.model = model.to(device).eval()
        self.sampler = sampler
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.device = device

    def generate_random_grid(
        self,
        num_samples: int = 64,
        num_steps: int = 50,
        image_size: int = 64,
        nrow: int = 8,
        title: str = "Random Samples",
        save_name: str = "random_grid.png",
    ) -> torch.Tensor:
        """Generate and save a grid of random samples; returns the batch."""
        shape = (num_samples, 3, image_size, image_size)
        samples = self.sampler.sample(shape, num_steps=num_steps)

        # Create figure
        fig, ax = plt.subplots(figsize=(12, 12))
        grid = vutils.make_grid(
            samples,
            nrow=nrow,
            normalize=True,
            value_range=(-1, 1),
            padding=2,
        )
        ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
        ax.set_title(title, fontsize=14)
        ax.axis("off")

        # Save
        plt.savefig(self.output_dir / save_name, dpi=150, bbox_inches="tight")
        plt.close()

        return samples

    def generate_seeded_grid(
        self,
        seeds: List[int],
        num_steps: int = 50,
        image_size: int = 64,
        nrow: int = 8,
        title: str = "Seeded Samples",
        save_name: str = "seeded_grid.png",
    ) -> torch.Tensor:
        """Generate one sample per seed for reproducible comparisons.

        Seeding immediately before each sampler call makes every image a
        pure function of its seed.  (Fix: the original also allocated a
        noise tensor here that the sampler never consumed.)
        """
        samples = []
        for seed in seeds:
            torch.manual_seed(seed)
            sample = self.sampler.sample(
                (1, 3, image_size, image_size), num_steps=num_steps
            )
            samples.append(sample)

        samples = torch.cat(samples, dim=0)

        # Figure with one labelled panel per seed.
        nrows = (len(seeds) + nrow - 1) // nrow
        fig, axes = plt.subplots(nrows, nrow, figsize=(12, 12 * nrows / nrow))
        # atleast_1d + ravel handles every subplot layout (1x1, 1xN, MxN)
        # uniformly, unlike the original chained conditional.
        axes = np.atleast_1d(axes).ravel()

        for seed, sample, ax in zip(seeds, samples, axes):
            img = (sample.permute(1, 2, 0).cpu().numpy() + 1) / 2
            ax.imshow(img.clip(0, 1))
            ax.set_title(f"Seed: {seed}", fontsize=10)
            ax.axis("off")

        # Hide unused axes
        for j in range(len(seeds), len(axes)):
            axes[j].axis("off")

        plt.suptitle(title, fontsize=14)
        plt.tight_layout()
        plt.savefig(self.output_dir / save_name, dpi=150, bbox_inches="tight")
        plt.close()

        return samples

    def compare_timesteps(
        self,
        num_steps_list: Optional[List[int]] = None,
        seed: int = 42,
        image_size: int = 64,
        save_name: str = "timestep_comparison.png",
    ) -> None:
        """Compare samples generated with different numbers of steps.

        Re-seeds the RNG before every sampler call so each step count draws
        the same random stream — i.e. the panels start from the same noise
        and only the step count varies.  (Fix: the original seeded once and
        allocated an unused ``initial_noise`` tensor, so the panels were not
        actually comparable.  NOTE(review): assumes the sampler draws its
        initial noise from the global torch RNG — confirm.)
        """
        if num_steps_list is None:
            # Default lives here, not in the signature, to avoid a shared
            # mutable default argument.
            num_steps_list = [10, 25, 50, 100, 200, 500, 1000]

        shape = (1, 3, image_size, image_size)
        fig, axes = plt.subplots(
            1, len(num_steps_list), figsize=(3 * len(num_steps_list), 3)
        )

        for i, num_steps in enumerate(num_steps_list):
            torch.manual_seed(seed)  # identical RNG stream for every panel
            sample = self.sampler.sample(shape, num_steps=num_steps)
            img = (sample[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
            axes[i].imshow(img.clip(0, 1))
            axes[i].set_title(f"{num_steps} steps", fontsize=10)
            axes[i].axis("off")

        plt.suptitle("Effect of Sampling Steps", fontsize=14)
        plt.tight_layout()
        plt.savefig(self.output_dir / save_name, dpi=150, bbox_inches="tight")
        plt.close()

    def analyze_sample_statistics(
        self,
        num_samples: int = 1000,
        num_steps: int = 50,
        image_size: int = 64,
    ) -> dict:
        """Analyze statistical properties (moments, range, outliers) of samples.

        Args:
            num_samples: Approximate number of images to generate.
            num_steps: Sampler steps per image.
            image_size: Square image resolution (new parameter; the original
                hard-coded 64).
        """
        samples = []
        batch_size = 64
        shape = (batch_size, 3, image_size, image_size)

        # Ceil division so num_samples < batch_size still yields one batch
        # (fix: the original floor division generated zero batches and
        # crashed in torch.cat for small num_samples).
        num_batches = (num_samples + batch_size - 1) // batch_size
        for _ in range(num_batches):
            batch = self.sampler.sample(shape, num_steps=num_steps)
            samples.append(batch.cpu())

        samples = torch.cat(samples, dim=0)[:num_samples]

        # Compute statistics
        stats = {
            "mean": samples.mean().item(),
            "std": samples.std().item(),
            "min": samples.min().item(),
            "max": samples.max().item(),
            "per_channel_mean": samples.mean(dim=(0, 2, 3)).tolist(),
            "per_channel_std": samples.std(dim=(0, 2, 3)).tolist(),
        }

        # Values outside [-1.5, 1.5] count as outliers for data nominally
        # normalized to [-1, 1].
        outlier_count = (
            (samples < -1.5).sum().item() +
            (samples > 1.5).sum().item()
        )
        stats["outlier_fraction"] = outlier_count / samples.numel()

        # Per-channel histograms with the nominal [-1, 1] range marked.
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))
        colors = ["red", "green", "blue"]
        channel_names = ["Red", "Green", "Blue"]

        for i in range(3):
            channel_data = samples[:, i].flatten().numpy()
            axes[i].hist(channel_data, bins=100, alpha=0.7, color=colors[i])
            axes[i].set_title(f"{channel_names[i]} Channel Distribution")
            axes[i].set_xlabel("Pixel Value")
            axes[i].set_ylabel("Frequency")
            axes[i].axvline(-1, color="black", linestyle="--", alpha=0.5)
            axes[i].axvline(1, color="black", linestyle="--", alpha=0.5)

        plt.tight_layout()
        plt.savefig(self.output_dir / "sample_statistics.png", dpi=150)
        plt.close()
        return stats

Latent Space Exploration
Interpolation Analysis
Smooth interpolations between samples indicate a well-structured latent space. Sudden changes or artifacts during interpolation suggest problems:
```python
1import torch
2import numpy as np
3import matplotlib.pyplot as plt
4from typing import List, Tuple
5
class LatentExplorer:
    """Explore the latent space of a diffusion model."""

    def __init__(self, model, sampler, device: str = "cuda"):
        """
        Args:
            model: Diffusion model; moved to ``device`` and set to eval mode.
            sampler: Object exposing ``sample(shape, num_steps=...)``.
            device: Torch device string.
        """
        self.model = model.to(device).eval()
        self.sampler = sampler
        self.device = device

    def linear_interpolation(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        num_steps: int = 10,
    ) -> List[torch.Tensor]:
        """Linear interpolation between two latent codes.

        Returns ``num_steps`` tensors from ``z1`` (alpha=0) to ``z2``
        (alpha=1) inclusive.
        """
        return [
            (1 - alpha) * z1 + alpha * z2
            for alpha in np.linspace(0, 1, num_steps)
        ]

    def spherical_interpolation(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        num_steps: int = 10,
    ) -> List[torch.Tensor]:
        """Spherical interpolation (slerp) between two latent codes.

        Slerp follows the great circle between the two codes, which better
        preserves the norm statistics expected of Gaussian latents.  Falls
        back to linear interpolation when the codes are (anti-)parallel,
        where slerp is numerically unstable.
        """
        # Flatten for computation
        z1_flat = z1.view(z1.shape[0], -1)
        z2_flat = z2.view(z2.shape[0], -1)

        # Angle between the normalized codes.
        z1_norm = z1_flat / z1_flat.norm(dim=-1, keepdim=True)
        z2_norm = z2_flat / z2_flat.norm(dim=-1, keepdim=True)
        dot = (z1_norm * z2_norm).sum(dim=-1, keepdim=True)
        omega = torch.acos(dot.clamp(-1, 1))
        sin_omega = torch.sin(omega)

        # The degeneracy test is loop-invariant, so decide once up front
        # (fix: the original re-evaluated it on every interpolation step).
        # NOTE(review): if ANY batch element is degenerate the whole batch
        # falls back to lerp — this matches the original behavior.
        degenerate = sin_omega.abs().min() < 1e-6

        interpolations = []
        for alpha in np.linspace(0, 1, num_steps):
            if degenerate:
                z = (1 - alpha) * z1_flat + alpha * z2_flat
            else:
                z = (
                    torch.sin((1 - alpha) * omega) / sin_omega * z1_flat +
                    torch.sin(alpha * omega) / sin_omega * z2_flat
                )
            interpolations.append(z.view(z1.shape))

        return interpolations
61
62 def generate_interpolation_grid(
63 self,
64 num_pairs: int = 4,
65 num_interpolations: int = 8,
66 num_steps: int = 50,
67 image_size: int = 64,
68 use_slerp: bool = True,
69 save_path: str = "interpolation_grid.png",
70 ) -> None:
71 """Generate a grid showing interpolations between random pairs."""
72 fig, axes = plt.subplots(num_pairs, num_interpolations, figsize=(2 * num_interpolations, 2 * num_pairs))
73
74 for row in range(num_pairs):
75 # Generate two random starting points
76 z1 = torch.randn(1, 3, image_size, image_size, device=self.device)
77 z2 = torch.randn(1, 3, image_size, image_size, device=self.device)
78
79 # Interpolate
80 if use_slerp:
81 latents = self.spherical_interpolation(z1, z2, num_interpolations)
82 else:
83 latents = self.linear_interpolation(z1, z2, num_interpolations)
84
85 # Generate samples
86 for col, z in enumerate(latents):
87 sample = self.sampler.sample(z.shape, num_steps=num_steps)
88 img = (sample[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
89 axes[row, col].imshow(img.clip(0, 1))
90 axes[row, col].axis("off")
91
92 plt.suptitle("Latent Space Interpolations", fontsize=14)
93 plt.tight_layout()
94 plt.savefig(save_path, dpi=150, bbox_inches="tight")
95 plt.close()
96
97 def latent_walk(
98 self,
99 start_z: torch.Tensor,
100 directions: int = 4,
101 steps_per_direction: int = 5,
102 step_size: float = 0.5,
103 num_sampling_steps: int = 50,
104 save_path: str = "latent_walk.png",
105 ) -> None:
106 """Random walk in latent space from a starting point."""
107 fig, axes = plt.subplots(directions, steps_per_direction * 2 + 1, figsize=(20, 8))
108
109 for d in range(directions):
110 # Random direction
111 direction = torch.randn_like(start_z)
112 direction = direction / direction.norm() * step_size
113
114 # Walk in both directions
115 samples = []
116 for i in range(-steps_per_direction, steps_per_direction + 1):
117 z = start_z + i * direction
118 sample = self.sampler.sample(z.shape, num_steps=num_sampling_steps)
119 samples.append(sample)
120
121 # Display
122 for col, sample in enumerate(samples):
123 img = (sample[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
124 axes[d, col].imshow(img.clip(0, 1))
125 axes[d, col].axis("off")
126
127 # Mark center
128 if col == steps_per_direction:
129 axes[d, col].patch.set_edgecolor("red")
130 axes[d, col].patch.set_linewidth(3)
131
132 plt.suptitle("Latent Space Walk", fontsize=14)
133 plt.tight_layout()
134 plt.savefig(save_path, dpi=150, bbox_inches="tight")
135 plt.close()
136
137 def analyze_interpolation_smoothness(
138 self,
139 num_interpolations: int = 50,
140 num_pairs: int = 100,
141 num_steps: int = 50,
142 ) -> dict:
143 """Quantify interpolation smoothness using perceptual distance."""
144 from torchvision.models import vgg16, VGG16_Weights
145 import torch.nn as nn
146
147 # Load VGG for perceptual distance
148 vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16]
149 vgg = vgg.to(self.device).eval()
150
151 def perceptual_distance(x1, x2):
152 f1 = vgg(x1)
153 f2 = vgg(x2)
154 return nn.functional.mse_loss(f1, f2).item()
155
156 # Measure distances along interpolations
157 all_distances = []
158 all_acceleration = [] # Second derivative
159
160 for _ in range(num_pairs):
161 z1 = torch.randn(1, 3, 64, 64, device=self.device)
162 z2 = torch.randn(1, 3, 64, 64, device=self.device)
163
164 latents = self.spherical_interpolation(z1, z2, num_interpolations)
165
166 # Generate samples
167 samples = [
168 self.sampler.sample(z.shape, num_steps=num_steps)
169 for z in latents
170 ]
171
172 # Compute consecutive distances
173 distances = []
174 for i in range(len(samples) - 1):
175 d = perceptual_distance(samples[i], samples[i + 1])
176 distances.append(d)
177 all_distances.append(d)
178
179 # Compute acceleration (smoothness measure)
180 for i in range(len(distances) - 1):
181 accel = abs(distances[i + 1] - distances[i])
182 all_acceleration.append(accel)
183
184 return {
185 "mean_distance": np.mean(all_distances),
186 "std_distance": np.std(all_distances),
187 "max_distance": np.max(all_distances),
188 "mean_acceleration": np.mean(all_acceleration),
189 "std_acceleration": np.std(all_acceleration),
190 "smoothness_score": 1.0 / (1.0 + np.mean(all_acceleration)),
        }

Failure Mode Analysis
Identifying Common Issues
Systematically categorize and analyze failure cases:
```python
1import torch
2import torch.nn as nn
3from typing import Dict, List, Tuple
4import numpy as np
5from collections import defaultdict
6import matplotlib.pyplot as plt
7from pathlib import Path
8
class FailureModeAnalyzer:
    """Analyze and categorize failure modes in generated samples.

    All detectors assume image tensors shaped (N, 3, H, W) with values
    nominally in [-1, 1].
    """

    def __init__(self, output_dir: str):
        """
        Args:
            output_dir: Directory for reports and figures (created if missing).
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Not used internally; kept for callers that record categories.
        self.failure_categories = defaultdict(list)

    def detect_artifacts(
        self,
        samples: torch.Tensor,
        threshold: float = 0.1,
    ) -> Dict[str, List[int]]:
        """Detect common artifact types; returns sample indices per type.

        Args:
            samples: (N, 3, H, W) batch in [-1, 1].
            threshold: Fraction of clipped pixels above which a sample is
                flagged as value-clipped.
        """
        artifacts = {
            "high_frequency_noise": [],
            "color_banding": [],
            "value_clipping": [],
            "blur": [],
            "repeated_patterns": [],
        }

        for i, sample in enumerate(samples):
            # 4-neighbour Laplacian of the grayscale image; its variance
            # separates noisy (high) from blurry (low) samples.
            gray = sample.mean(dim=0)
            laplacian = torch.abs(
                gray[2:, 1:-1] + gray[:-2, 1:-1] +
                gray[1:-1, 2:] + gray[1:-1, :-2] -
                4 * gray[1:-1, 1:-1]
            )
            if laplacian.var() > 0.5:  # empirical noise cutoff
                artifacts["high_frequency_noise"].append(i)

            # Color banding: suspiciously few distinct pixel values.
            if len(torch.unique(sample)) < sample.numel() * 0.1:
                artifacts["color_banding"].append(i)

            # Value clipping: too much mass at the [-1, 1] boundaries.
            clipped = (
                (sample < -0.99).sum() + (sample > 0.99).sum()
            ).item()
            if clipped > sample.numel() * threshold:
                artifacts["value_clipping"].append(i)

            # Blur: very low Laplacian variance.
            if laplacian.var() < 0.01:  # empirical blur cutoff
                artifacts["blur"].append(i)

            # Tiling/repetition: high spatial autocorrelation in any channel.
            for c in range(3):
                if self._compute_autocorrelation(sample[c]) > 0.5:
                    artifacts["repeated_patterns"].append(i)
                    break

        return artifacts

    def _compute_autocorrelation(self, image: torch.Tensor) -> float:
        """Max |correlation| of the image with small axis-aligned shifts.

        NOTE(review): assumes images larger than the 8-pixel shifts; very
        small inputs would produce empty slices — confirm expected sizes.
        """
        shifts = [(4, 0), (0, 4), (8, 0), (0, 8)]

        max_corr = 0
        for dy, dx in shifts:
            if dy > 0:
                shifted = image[dy:, :]
                original = image[:-dy, :]
            else:
                shifted = image[:, dx:]
                original = image[:, :-dx]

            corr = torch.corrcoef(
                torch.stack([original.flatten(), shifted.flatten()])
            )[0, 1]

            # Constant regions yield NaN correlations; skip them.
            if not torch.isnan(corr):
                max_corr = max(max_corr, corr.abs().item())

        return max_corr

    def analyze_diversity_failure(
        self,
        samples: torch.Tensor,
        threshold: float = 0.9,
    ) -> Dict[str, object]:
        """Detect mode collapse / low diversity via pairwise cosine similarity.

        Returns summary statistics plus near-duplicate cluster counts.
        (Fix: the original annotated the return as ``Dict[str, any]``,
        using the ``any`` builtin instead of a type.)
        """
        # Guard: pairwise statistics are undefined for < 2 samples
        # (fix: the original divided by zero here).
        if len(samples) < 2:
            n = len(samples)
            return {
                "mean_similarity": 0.0,
                "max_similarity": 0.0,
                "duplicate_fraction": 0.0,
                "num_clusters": n,
                "largest_cluster_size": n,
                "unique_samples_fraction": 1.0 if n else 0.0,
            }

        # Cosine similarity between flattened samples; clamp the norms so
        # an all-zero sample cannot produce NaNs (fix).
        samples_flat = samples.view(len(samples), -1)
        norms = samples_flat.norm(dim=1, keepdim=True).clamp_min(1e-12)
        samples_norm = samples_flat / norms

        similarities = torch.mm(samples_norm, samples_norm.t())

        # Off-diagonal entries only (self-similarity is always 1).
        mask = ~torch.eye(len(samples), dtype=torch.bool, device=samples.device)
        pairwise_sims = similarities[mask]

        very_similar_pairs = (pairwise_sims > threshold).sum().item()
        total_pairs = len(pairwise_sims)

        # Find near-duplicate clusters
        clusters = self._find_similar_clusters(similarities.cpu().numpy(), threshold)

        return {
            "mean_similarity": pairwise_sims.mean().item(),
            "max_similarity": pairwise_sims.max().item(),
            "duplicate_fraction": very_similar_pairs / total_pairs,
            "num_clusters": len(clusters),
            "largest_cluster_size": max(len(c) for c in clusters) if clusters else 0,
            "unique_samples_fraction": len([c for c in clusters if len(c) == 1]) / len(samples),
        }

    def _find_similar_clusters(
        self,
        similarity_matrix: np.ndarray,
        threshold: float,
    ) -> List[List[int]]:
        """Greedy single-pass clustering of samples by similarity.

        Each unvisited sample seeds a cluster containing every later sample
        whose similarity to the seed exceeds ``threshold``.
        """
        n = len(similarity_matrix)
        visited = set()
        clusters = []

        for i in range(n):
            if i in visited:
                continue

            cluster = [i]
            visited.add(i)

            for j in range(i + 1, n):
                if j not in visited and similarity_matrix[i, j] > threshold:
                    cluster.append(j)
                    visited.add(j)

            clusters.append(cluster)

        return clusters

    def generate_failure_report(
        self,
        samples: torch.Tensor,
        save_path: str = "failure_report.txt",
    ) -> Dict[str, object]:
        """Write a text failure report and return its data as a dict."""
        # Detect artifacts
        artifacts = self.detect_artifacts(samples)

        # Analyze diversity
        diversity = self.analyze_diversity_failure(samples)

        # Compile report
        report = {
            "total_samples": len(samples),
            "artifacts": {
                k: len(v) for k, v in artifacts.items()
            },
            "diversity": diversity,
        }

        # Overall quality: average of artifact-free rate and diversity rate.
        artifact_count = sum(len(v) for v in artifacts.values())
        artifact_rate = artifact_count / len(samples)
        diversity_score = 1 - diversity["duplicate_fraction"]

        report["quality_score"] = (1 - artifact_rate) * 0.5 + diversity_score * 0.5

        # Write report
        with open(self.output_dir / save_path, "w") as f:
            f.write("=" * 60 + "\n")
            f.write("FAILURE ANALYSIS REPORT\n")
            f.write("=" * 60 + "\n\n")

            f.write(f"Total samples analyzed: {report['total_samples']}\n\n")

            f.write("ARTIFACTS:\n")
            f.write("-" * 40 + "\n")
            for artifact_type, count in report["artifacts"].items():
                pct = count / len(samples) * 100
                f.write(f" {artifact_type}: {count} ({pct:.1f}%)\n")

            f.write("\nDIVERSITY:\n")
            f.write("-" * 40 + "\n")
            f.write(f" Mean pairwise similarity: {diversity['mean_similarity']:.3f}\n")
            f.write(f" Max pairwise similarity: {diversity['max_similarity']:.3f}\n")
            f.write(f" Near-duplicate fraction: {diversity['duplicate_fraction']:.3f}\n")
            f.write(f" Unique samples: {diversity['unique_samples_fraction']:.1%}\n")

            f.write("\n" + "=" * 60 + "\n")
            f.write(f"OVERALL QUALITY SCORE: {report['quality_score']:.3f}\n")
            f.write("=" * 60 + "\n")

        # Save examples of each failure type
        self._save_failure_examples(samples, artifacts)

        return report

    def _save_failure_examples(
        self,
        samples: torch.Tensor,
        artifacts: Dict[str, List[int]],
    ) -> None:
        """Save a grid of example images for each non-empty failure type."""
        for artifact_type, indices in artifacts.items():
            if len(indices) == 0:
                continue

            # Take up to 16 examples, 4 per row.
            example_indices = indices[:16]
            examples = samples[example_indices]

            nrows = (len(examples) + 3) // 4
            fig, axes = plt.subplots(nrows, 4, figsize=(12, 3 * nrows))
            # Uniform handling of the 1x1 / 1xN / MxN subplot layouts.
            axes = np.atleast_1d(axes).ravel()

            for i, (idx, sample) in enumerate(zip(example_indices, examples)):
                if i < len(axes):
                    img = (sample.permute(1, 2, 0).cpu().numpy() + 1) / 2
                    axes[i].imshow(img.clip(0, 1))
                    axes[i].set_title(f"Sample {idx}", fontsize=8)
                    axes[i].axis("off")

            for j in range(len(examples), len(axes)):
                axes[j].axis("off")

            plt.suptitle(f"Failure Mode: {artifact_type}", fontsize=14)
            plt.tight_layout()
            plt.savefig(
                self.output_dir / f"failure_{artifact_type}.png",
                dpi=150, bbox_inches="tight"
            )
            plt.close()

| Failure Mode | Cause | Solution |
|---|---|---|
| High-frequency noise | Undertrained, wrong LR | Train longer, adjust LR |
| Blurry outputs | Low model capacity | Increase channels/depth |
| Repeated patterns | Mode collapse | Check diversity, EMA decay |
| Color artifacts | Normalization issues | Check data preprocessing |
| Clipped values | Wrong value range | Verify [-1,1] normalization |
Comprehensive Evaluation Report
```python
1import torch
2import json
3from datetime import datetime
4from pathlib import Path
5from typing import Dict, Any, Optional
6import matplotlib.pyplot as plt
7from dataclasses import dataclass, asdict
8
@dataclass
class EvaluationReport:
    """Complete evaluation report combining all analyses."""

    # --- Model identification ---
    model_name: str
    checkpoint_path: str
    timestamp: str

    # --- Quantitative metrics ---
    fid: float
    inception_score_mean: float
    inception_score_std: float
    precision: float
    recall: float

    # --- Qualitative analysis ---
    sample_statistics: Dict[str, float]
    interpolation_smoothness: float
    failure_analysis: Dict[str, Any]

    # --- Generation configuration ---
    num_samples: int
    num_steps: int
    image_size: int

    def to_dict(self) -> dict:
        """Return the report as a plain (JSON-serializable) dictionary."""
        return asdict(self)

    def save(self, path: str):
        """Write the report to ``path`` as pretty-printed JSON."""
        with open(path, "w") as f:
            json.dump(self.to_dict(), f, indent=2)
41
class ComprehensiveEvaluator:
    """Complete evaluation pipeline combining all analyses."""

    def __init__(
        self,
        model,
        sampler,
        fid_calculator,
        is_calculator,
        output_dir: str,
        device: str = "cuda",
    ):
        """
        Args:
            model: Diffusion model; moved to ``device`` and set to eval mode.
            sampler: Object exposing ``sample(shape, num_steps=...)``.
            fid_calculator: Object exposing ``compute_fid(samples)``.
            is_calculator: Object exposing ``compute_inception_score(samples)``.
            output_dir: Root directory for all evaluation outputs.
            device: Torch device string.
        """
        self.model = model.to(device).eval()
        self.sampler = sampler
        self.fid_calc = fid_calculator
        self.is_calc = is_calculator
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.device = device

        # Sub-analyzers write into subdirectories of the output directory.
        # Fix: the original joined on the raw ``output_dir`` string
        # (``output_dir / "samples"``), which raises TypeError; join on the
        # Path object instead.
        self.inspector = SampleInspector(model, sampler, str(self.output_dir / "samples"), device)
        self.explorer = LatentExplorer(model, sampler, device)
        self.failure_analyzer = FailureModeAnalyzer(str(self.output_dir / "failures"))

    def evaluate(
        self,
        model_name: str,
        checkpoint_path: str,
        num_samples: int = 10000,
        num_steps: int = 50,
        image_size: int = 64,
        compute_all: bool = True,
    ) -> "EvaluationReport":
        """Run the complete evaluation pipeline.

        Args:
            model_name: Label used in the report and output filenames.
            checkpoint_path: Recorded for provenance only (not loaded here).
            num_samples: Number of images to generate for the metrics.
            num_steps: Sampler steps per image.
            image_size: Square image resolution.
            compute_all: If False, skip the expensive latent-space
                exploration and record a smoothness score of 0.

        Returns:
            The compiled EvaluationReport (also saved to JSON on disk).
        """
        print("=" * 60)
        print("COMPREHENSIVE MODEL EVALUATION")
        print("=" * 60)
        print(f"Model: {model_name}")
        print(f"Checkpoint: {checkpoint_path}")
        print(f"Samples: {num_samples}")
        print()

        # Generate samples in batches; the final batch may overshoot, so
        # trim back to exactly num_samples afterwards.
        print("Generating samples...")
        samples = []
        batch_size = 64
        shape = (batch_size, 3, image_size, image_size)

        for _ in range(0, num_samples, batch_size):
            batch = self.sampler.sample(shape, num_steps=num_steps)
            samples.append(batch.cpu())

        samples = torch.cat(samples, dim=0)[:num_samples]
        print(f"Generated {len(samples)} samples")

        # 1. Quantitative metrics
        print("\nComputing quantitative metrics...")
        fid = self.fid_calc.compute_fid(samples)
        is_mean, is_std = self.is_calc.compute_inception_score(samples)
        print(f" FID: {fid:.2f}")
        print(f" IS: {is_mean:.2f} +/- {is_std:.2f}")

        # Precision/Recall (simplified - need real features)
        precision, recall = 0.0, 0.0  # Would compute with real data

        # 2. Visual inspection
        print("\nGenerating visual inspection grids...")
        self.inspector.generate_random_grid(64, num_steps, image_size)
        self.inspector.generate_seeded_grid(list(range(16)), num_steps, image_size)
        self.inspector.compare_timesteps()

        # 3. Sample statistics
        print("\nAnalyzing sample statistics...")
        sample_stats = self.inspector.analyze_sample_statistics(
            num_samples=min(1000, num_samples),
            num_steps=num_steps,
        )
        print(f" Mean: {sample_stats['mean']:.4f}")
        print(f" Std: {sample_stats['std']:.4f}")
        print(f" Outlier fraction: {sample_stats['outlier_fraction']:.4f}")

        # 4. Latent space exploration (optional; expensive)
        if compute_all:
            print("\nExploring latent space...")
            self.explorer.generate_interpolation_grid(
                num_pairs=4,
                num_interpolations=8,
                num_steps=num_steps,
                save_path=str(self.output_dir / "interpolations.png"),
            )

            smoothness_analysis = self.explorer.analyze_interpolation_smoothness(
                num_interpolations=20,
                num_pairs=10,
                num_steps=num_steps,
            )
            print(f" Smoothness score: {smoothness_analysis['smoothness_score']:.4f}")
        else:
            smoothness_analysis = {"smoothness_score": 0.0}

        # 5. Failure analysis (on a subset to bound cost)
        print("\nAnalyzing failure modes...")
        failure_report = self.failure_analyzer.generate_failure_report(
            samples[:1000],  # Analyze subset
        )
        print(f" Quality score: {failure_report['quality_score']:.4f}")
        for artifact, count in failure_report["artifacts"].items():
            if count > 0:
                print(f" {artifact}: {count}")

        # Compile report
        report = EvaluationReport(
            model_name=model_name,
            checkpoint_path=checkpoint_path,
            timestamp=datetime.now().isoformat(),
            fid=fid,
            inception_score_mean=is_mean,
            inception_score_std=is_std,
            precision=precision,
            recall=recall,
            sample_statistics=sample_stats,
            interpolation_smoothness=smoothness_analysis["smoothness_score"],
            failure_analysis=failure_report,
            num_samples=num_samples,
            num_steps=num_steps,
            image_size=image_size,
        )

        # Save report
        report_path = self.output_dir / f"evaluation_report_{model_name}.json"
        report.save(str(report_path))
        print(f"\nReport saved to {report_path}")

        # Generate summary figure
        self._generate_summary_figure(report)

        return report

    def _generate_summary_figure(self, report: "EvaluationReport") -> None:
        """Compose a one-page visual summary from saved figures and metrics."""
        fig = plt.figure(figsize=(16, 10))

        # Create grid layout
        gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

        # Title
        fig.suptitle(
            f"Evaluation Report: {report.model_name}\n"
            f"FID: {report.fid:.2f} | IS: {report.inception_score_mean:.2f}",
            fontsize=14
        )

        # Random-sample grid (reloaded from the inspector's output).
        ax1 = fig.add_subplot(gs[0:2, 0:2])
        grid_path = self.output_dir / "samples" / "random_grid.png"
        if grid_path.exists():
            img = plt.imread(str(grid_path))
            ax1.imshow(img)
            ax1.set_title("Random Samples")
        ax1.axis("off")

        # Headline metrics rescaled onto a comparable bar range.
        ax2 = fig.add_subplot(gs[0, 2])
        metrics = ["FID (inv)", "IS", "Quality"]
        values = [
            100 / (report.fid + 1),  # Invert FID for visualization
            report.inception_score_mean,
            report.failure_analysis.get("quality_score", 0) * 10,
        ]
        ax2.bar(metrics, values, color=["blue", "green", "orange"])
        ax2.set_title("Quality Metrics")

        # Pie chart of detected artifact categories (if any were found).
        ax3 = fig.add_subplot(gs[0, 3])
        artifacts = report.failure_analysis.get("artifacts", {})
        if sum(artifacts.values()) > 0:
            labels = [k for k, v in artifacts.items() if v > 0]
            sizes = [v for v in artifacts.values() if v > 0]
            ax3.pie(sizes, labels=labels, autopct="%1.1f%%")
            ax3.set_title("Artifact Distribution")

        # Text panel with the headline sample statistics.
        ax4 = fig.add_subplot(gs[1, 2:])
        stats = report.sample_statistics
        text = (
            f"Mean: {stats.get('mean', 0):.4f}\n"
            f"Std: {stats.get('std', 0):.4f}\n"
            f"Outliers: {stats.get('outlier_fraction', 0):.2%}\n"
            f"Smoothness: {report.interpolation_smoothness:.4f}"
        )
        ax4.text(0.1, 0.5, text, fontsize=12, verticalalignment="center")
        ax4.set_title("Statistics")
        ax4.axis("off")

        # Interpolation strip (present only if latent exploration ran).
        ax5 = fig.add_subplot(gs[2, :])
        interp_path = self.output_dir / "interpolations.png"
        if interp_path.exists():
            img = plt.imread(str(interp_path))
            ax5.imshow(img)
            ax5.set_title("Latent Interpolations")
        ax5.axis("off")

        plt.savefig(
            self.output_dir / f"summary_{report.model_name}.png",
            dpi=150, bbox_inches="tight"
        )
        plt.close()

Key Takeaways
- Visual inspection is essential: FID cannot capture all quality aspects. Always examine samples visually.
- Use consistent seeds: Generate samples from fixed seeds to track progress and compare models fairly.
- Explore the latent space: Smooth interpolations indicate good representations. Artifacts during interpolation reveal problems.
- Systematically categorize failures: Different failure modes have different root causes and solutions.
- Create comprehensive reports: Combine quantitative metrics with qualitative analysis for complete understanding.
Looking Ahead: The final section provides an interactive demonstration bringing together everything we have learned: model loading, sampling, and real-time generation.