Chapter 12
15 min read
Section 57 of 76

Qualitative Analysis

Generation and Evaluation

Learning Objectives

By the end of this section, you will be able to:

  1. Organize visual inspection of generated samples to identify quality issues systematically
  2. Explore the latent space to understand model behavior and interpolation quality
  3. Identify common failure modes and their root causes
  4. Create comprehensive evaluation reports combining quantitative and qualitative analysis

The Big Picture

While quantitative metrics like FID provide numerical comparisons, they can miss important aspects of generation quality that humans notice immediately. Qualitative analysis complements metrics by revealing what the model has learned, where it struggles, and how it behaves in edge cases.

The Qualitative Perspective: A model with FID of 5.0 might produce beautiful samples 90% of the time but have severe artifacts in 10% of cases. Another model with FID of 7.0 might be more consistent but less impressive. Only visual inspection reveals these differences.
Analysis Type | What It Reveals | When to Use
Random samples | Overall quality distribution | First check after training
Curated samples | Best-case capabilities | Paper figures, demos
Failure cases | Systematic problems | Debugging, improvement
Interpolations | Latent space smoothness | Understanding representations
Class-conditional | Per-class quality variations | Identifying weak classes

Visual Inspection Techniques

Sample Grid Generation

Create organized grids of samples for systematic evaluation:

Python
1import torch
2import torchvision.utils as vutils
3import matplotlib.pyplot as plt
4from pathlib import Path
5from typing import Optional, List, Callable
6import numpy as np
7
class SampleInspector:
    """Tools for visual inspection of generated samples.

    Wraps a model/sampler pair and writes inspection figures (sample grids,
    step-count comparisons, pixel statistics) into ``output_dir``. Samples
    are assumed to be RGB images with values in [-1, 1].
    """

    def __init__(
        self,
        model,
        sampler,
        output_dir: str,
        device: str = "cuda",
    ):
        """
        Args:
            model: Generative model; moved to ``device`` and set to eval mode.
            sampler: Object exposing ``sample(shape, num_steps=...)``.
            output_dir: Directory for saved figures (created if missing).
            device: Torch device string.
        """
        self.model = model.to(device).eval()
        self.sampler = sampler
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.device = device

    def generate_random_grid(
        self,
        num_samples: int = 64,
        num_steps: int = 50,
        image_size: int = 64,
        nrow: int = 8,
        title: str = "Random Samples",
        save_name: str = "random_grid.png",
    ) -> torch.Tensor:
        """Generate and save a grid of random samples.

        Returns:
            The generated samples, shape (num_samples, 3, image_size, image_size).
        """
        shape = (num_samples, 3, image_size, image_size)
        samples = self.sampler.sample(shape, num_steps=num_steps)

        # Create figure
        fig, ax = plt.subplots(figsize=(12, 12))
        grid = vutils.make_grid(
            samples,
            nrow=nrow,
            normalize=True,
            value_range=(-1, 1),
            padding=2,
        )
        ax.imshow(grid.permute(1, 2, 0).cpu().numpy())
        ax.set_title(title, fontsize=14)
        ax.axis("off")

        # Save
        plt.savefig(self.output_dir / save_name, dpi=150, bbox_inches="tight")
        plt.close()

        return samples

    def generate_seeded_grid(
        self,
        seeds: List[int],
        num_steps: int = 50,
        image_size: int = 64,
        nrow: int = 8,
        title: str = "Seeded Samples",
        save_name: str = "seeded_grid.png",
    ) -> torch.Tensor:
        """Generate samples from specific seeds for reproducibility.

        NOTE(review): the sampler only receives the noise *shape*; the seeded
        noise tensor itself is never passed in. Reproducibility therefore
        relies on the sampler drawing its own noise after ``torch.manual_seed``
        — confirm against the sampler API.
        """
        samples = []
        for seed in seeds:
            torch.manual_seed(seed)
            noise = torch.randn(1, 3, image_size, image_size, device=self.device)
            sample = self.sampler.sample(noise.shape, num_steps=num_steps)
            samples.append(sample)

        samples = torch.cat(samples, dim=0)

        # Create figure with seed labels
        num_rows = (len(seeds) + nrow - 1) // nrow
        fig, axes = plt.subplots(
            num_rows,
            nrow,
            figsize=(12, 12 * num_rows / nrow),
        )
        # Bug fix: plt.subplots squeezes singleton dimensions (bare Axes for
        # 1x1, 1-D array for one row); normalize to a flat 1-D array so the
        # loop below works for any number of seeds, including exactly one.
        axes = np.asarray(axes).reshape(-1)

        for seed, sample, ax in zip(seeds, samples, axes):
            img = (sample.permute(1, 2, 0).cpu().numpy() + 1) / 2
            ax.imshow(img.clip(0, 1))
            ax.set_title(f"Seed: {seed}", fontsize=10)
            ax.axis("off")

        # Hide unused axes
        for j in range(len(seeds), len(axes)):
            axes[j].axis("off")

        plt.suptitle(title, fontsize=14)
        plt.tight_layout()
        plt.savefig(self.output_dir / save_name, dpi=150, bbox_inches="tight")
        plt.close()

        return samples

    def compare_timesteps(
        self,
        num_steps_list: Optional[List[int]] = None,
        seed: int = 42,
        image_size: int = 64,
        save_name: str = "timestep_comparison.png",
    ) -> None:
        """Compare samples generated with different numbers of sampling steps.

        Args:
            num_steps_list: Step counts to compare. Defaults to
                [10, 25, 50, 100, 200, 500, 1000]. (Replaces the original
                mutable default argument.)
            seed: RNG seed used before drawing the initial noise.
            image_size: Spatial size of the generated image.
            save_name: Output filename within ``output_dir``.
        """
        if num_steps_list is None:
            num_steps_list = [10, 25, 50, 100, 200, 500, 1000]

        torch.manual_seed(seed)
        initial_noise = torch.randn(1, 3, image_size, image_size, device=self.device)

        fig, axes = plt.subplots(1, len(num_steps_list), figsize=(3 * len(num_steps_list), 3))
        # Bug fix: with a single entry, plt.subplots returns a bare Axes
        # object that cannot be indexed; force a 1-D array.
        axes = np.asarray(axes).reshape(-1)

        for i, num_steps in enumerate(num_steps_list):
            sample = self.sampler.sample(
                initial_noise.shape,
                num_steps=num_steps,
            )
            img = (sample[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
            axes[i].imshow(img.clip(0, 1))
            axes[i].set_title(f"{num_steps} steps", fontsize=10)
            axes[i].axis("off")

        plt.suptitle("Effect of Sampling Steps", fontsize=14)
        plt.tight_layout()
        plt.savefig(self.output_dir / save_name, dpi=150, bbox_inches="tight")
        plt.close()

    def analyze_sample_statistics(
        self,
        num_samples: int = 1000,
        num_steps: int = 50,
        image_size: int = 64,
    ) -> dict:
        """Analyze statistical properties of generated samples.

        Generates ``num_samples`` images in batches, then reports global and
        per-channel mean/std, value range, and the fraction of extreme
        outliers (outside [-1.5, 1.5]); per-channel histograms are saved
        alongside.

        Bug fix: the original dropped the remainder batch and produced zero
        samples (crashing ``torch.cat``) whenever ``num_samples`` was smaller
        than the internal batch size.
        """
        samples = []
        batch_size = 64
        remaining = num_samples
        while remaining > 0:
            current = min(batch_size, remaining)
            batch = self.sampler.sample(
                (current, 3, image_size, image_size), num_steps=num_steps
            )
            samples.append(batch.cpu())
            remaining -= current

        samples = torch.cat(samples, dim=0)

        # Compute statistics
        stats = {
            "mean": samples.mean().item(),
            "std": samples.std().item(),
            "min": samples.min().item(),
            "max": samples.max().item(),
            "per_channel_mean": samples.mean(dim=(0, 2, 3)).tolist(),
            "per_channel_std": samples.std(dim=(0, 2, 3)).tolist(),
        }

        # Check for outliers far outside the nominal [-1, 1] range
        outlier_count = (
            (samples < -1.5).sum().item() +
            (samples > 1.5).sum().item()
        )
        stats["outlier_fraction"] = outlier_count / samples.numel()

        # Create per-channel histograms
        fig, axes = plt.subplots(1, 3, figsize=(15, 4))
        colors = ["red", "green", "blue"]
        channel_names = ["Red", "Green", "Blue"]

        for i in range(3):
            channel_data = samples[:, i].flatten().numpy()
            axes[i].hist(channel_data, bins=100, alpha=0.7, color=colors[i])
            axes[i].set_title(f"{channel_names[i]} Channel Distribution")
            axes[i].set_xlabel("Pixel Value")
            axes[i].set_ylabel("Frequency")
            # Dashed markers at the expected [-1, 1] boundaries
            axes[i].axvline(-1, color="black", linestyle="--", alpha=0.5)
            axes[i].axvline(1, color="black", linestyle="--", alpha=0.5)

        plt.tight_layout()
        plt.savefig(self.output_dir / "sample_statistics.png", dpi=150)
        plt.close()

        return stats

Latent Space Exploration

Interpolation Analysis

Smooth interpolations between samples indicate a well-structured latent space. Sudden changes or artifacts during interpolation suggest problems:

Python
1import torch
2import numpy as np
3import matplotlib.pyplot as plt
4from typing import List, Tuple
5
class LatentExplorer:
    """Explore the latent space of a diffusion model.

    NOTE(review): the interpolation helpers produce latent tensors, but the
    sampler is invoked with ``z.shape`` only — the latent values themselves
    never reach the sampler, so the visualized "interpolations" are
    independent draws unless the sampler consumes the latent internally.
    Confirm the sampler API (e.g. an initial-noise argument) before relying
    on these figures.
    """

    def __init__(self, model, sampler, device: str = "cuda"):
        self.model = model.to(device).eval()
        self.sampler = sampler
        self.device = device

    def linear_interpolation(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        num_steps: int = 10,
    ) -> List[torch.Tensor]:
        """Return ``num_steps`` evenly spaced linear blends from z1 to z2
        (endpoints included)."""
        interpolations = []
        for alpha in np.linspace(0, 1, num_steps):
            z = (1 - alpha) * z1 + alpha * z2
            interpolations.append(z)
        return interpolations

    def spherical_interpolation(
        self,
        z1: torch.Tensor,
        z2: torch.Tensor,
        num_steps: int = 10,
    ) -> List[torch.Tensor]:
        """Spherical interpolation (slerp) between two latent codes.

        Bug fix: the original fell back to linear interpolation for the
        *entire batch* whenever any single element was near-collinear
        (``sin_omega.abs().min() < 1e-6``). The fallback is now elementwise,
        and the slerp denominator is clamped so the unselected branch cannot
        produce NaN/inf.
        """
        # Flatten for computation
        z1_flat = z1.view(z1.shape[0], -1)
        z2_flat = z2.view(z2.shape[0], -1)

        # Normalize to unit vectors for the angle computation
        z1_norm = z1_flat / z1_flat.norm(dim=-1, keepdim=True)
        z2_norm = z2_flat / z2_flat.norm(dim=-1, keepdim=True)

        # Angle between the codes; clamp guards acos against rounding
        dot = (z1_norm * z2_norm).sum(dim=-1, keepdim=True)
        omega = torch.acos(dot.clamp(-1, 1))

        sin_omega = torch.sin(omega)
        # Per-element degenerate mask: (nearly) collinear pairs use lerp
        degenerate = sin_omega.abs() < 1e-6
        sin_safe = sin_omega.clamp_min(1e-12)

        interpolations = []
        for alpha in np.linspace(0, 1, num_steps):
            linear = (1 - alpha) * z1_flat + alpha * z2_flat
            slerp = (
                torch.sin((1 - alpha) * omega) / sin_safe * z1_flat +
                torch.sin(alpha * omega) / sin_safe * z2_flat
            )
            z = torch.where(degenerate, linear, slerp)
            interpolations.append(z.view(z1.shape))

        return interpolations

    def generate_interpolation_grid(
        self,
        num_pairs: int = 4,
        num_interpolations: int = 8,
        num_steps: int = 50,
        image_size: int = 64,
        use_slerp: bool = True,
        save_path: str = "interpolation_grid.png",
    ) -> None:
        """Generate a grid showing interpolations between random pairs."""
        fig, axes = plt.subplots(num_pairs, num_interpolations, figsize=(2 * num_interpolations, 2 * num_pairs))
        # Robustness: plt.subplots squeezes singleton dimensions; force a 2-D
        # grid so axes[row, col] indexing works when num_pairs == 1.
        axes = np.asarray(axes).reshape(num_pairs, num_interpolations)

        for row in range(num_pairs):
            # Two random endpoints in latent space
            z1 = torch.randn(1, 3, image_size, image_size, device=self.device)
            z2 = torch.randn(1, 3, image_size, image_size, device=self.device)

            # Interpolate
            if use_slerp:
                latents = self.spherical_interpolation(z1, z2, num_interpolations)
            else:
                latents = self.linear_interpolation(z1, z2, num_interpolations)

            # Generate samples (NOTE(review): only z.shape reaches the
            # sampler here — see class docstring)
            for col, z in enumerate(latents):
                sample = self.sampler.sample(z.shape, num_steps=num_steps)
                img = (sample[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
                axes[row, col].imshow(img.clip(0, 1))
                axes[row, col].axis("off")

        plt.suptitle("Latent Space Interpolations", fontsize=14)
        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close()

    def latent_walk(
        self,
        start_z: torch.Tensor,
        directions: int = 4,
        steps_per_direction: int = 5,
        step_size: float = 0.5,
        num_sampling_steps: int = 50,
        save_path: str = "latent_walk.png",
    ) -> None:
        """Random walk in latent space from a starting point."""
        num_cols = steps_per_direction * 2 + 1
        fig, axes = plt.subplots(directions, num_cols, figsize=(20, 8))
        # Robustness: force a 2-D grid for directions == 1 (see above)
        axes = np.asarray(axes).reshape(directions, num_cols)

        for d in range(directions):
            # Random unit direction scaled to step_size
            direction = torch.randn_like(start_z)
            direction = direction / direction.norm() * step_size

            # Walk symmetrically in both directions from the start point
            samples = []
            for i in range(-steps_per_direction, steps_per_direction + 1):
                z = start_z + i * direction
                sample = self.sampler.sample(z.shape, num_steps=num_sampling_steps)
                samples.append(sample)

            # Display
            for col, sample in enumerate(samples):
                img = (sample[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
                axes[d, col].imshow(img.clip(0, 1))
                axes[d, col].axis("off")

                # Red border marks the unshifted center sample
                if col == steps_per_direction:
                    axes[d, col].patch.set_edgecolor("red")
                    axes[d, col].patch.set_linewidth(3)

        plt.suptitle("Latent Space Walk", fontsize=14)
        plt.tight_layout()
        plt.savefig(save_path, dpi=150, bbox_inches="tight")
        plt.close()

    def analyze_interpolation_smoothness(
        self,
        num_interpolations: int = 50,
        num_pairs: int = 100,
        num_steps: int = 50,
    ) -> dict:
        """Quantify interpolation smoothness using perceptual distance.

        Measures VGG-feature MSE between consecutive frames of slerp
        interpolations; the "acceleration" (change in consecutive distances)
        is a proxy for smoothness. All model calls run under ``no_grad`` so
        the analysis does not accumulate autograd graphs.
        """
        from torchvision.models import vgg16, VGG16_Weights
        import torch.nn as nn

        # Load VGG for perceptual distance
        vgg = vgg16(weights=VGG16_Weights.IMAGENET1K_V1).features[:16]
        vgg = vgg.to(self.device).eval()

        def perceptual_distance(x1, x2):
            # NOTE(review): VGG was trained on ImageNet-normalized RGB;
            # inputs here are in [-1, 1]. Distances remain usable for
            # relative comparison, but confirm whether renormalization
            # is intended.
            with torch.no_grad():
                f1 = vgg(x1)
                f2 = vgg(x2)
                return nn.functional.mse_loss(f1, f2).item()

        # Measure distances along interpolations
        all_distances = []
        all_acceleration = []  # Discrete second derivative of distance

        for _ in range(num_pairs):
            z1 = torch.randn(1, 3, 64, 64, device=self.device)
            z2 = torch.randn(1, 3, 64, 64, device=self.device)

            latents = self.spherical_interpolation(z1, z2, num_interpolations)

            # Generate samples without tracking gradients
            with torch.no_grad():
                samples = [
                    self.sampler.sample(z.shape, num_steps=num_steps)
                    for z in latents
                ]

            # Consecutive perceptual distances along the path
            distances = []
            for i in range(len(samples) - 1):
                d = perceptual_distance(samples[i], samples[i + 1])
                distances.append(d)
                all_distances.append(d)

            # Acceleration (change of distance) as the smoothness measure
            for i in range(len(distances) - 1):
                accel = abs(distances[i + 1] - distances[i])
                all_acceleration.append(accel)

        return {
            "mean_distance": float(np.mean(all_distances)),
            "std_distance": float(np.std(all_distances)),
            "max_distance": float(np.max(all_distances)),
            "mean_acceleration": float(np.mean(all_acceleration)),
            "std_acceleration": float(np.std(all_acceleration)),
            "smoothness_score": 1.0 / (1.0 + float(np.mean(all_acceleration))),
        }

Failure Mode Analysis

Identifying Common Issues

Systematically categorize and analyze failure cases:

Python
1import torch
2import torch.nn as nn
3from typing import Dict, List, Tuple
4import numpy as np
5from collections import defaultdict
6import matplotlib.pyplot as plt
7from pathlib import Path
8
class FailureModeAnalyzer:
    """Analyze and categorize failure modes in generated samples.

    Samples are expected as (N, 3, H, W) tensors with values in [-1, 1].
    Figures and the text report are written into ``output_dir``.
    """

    def __init__(self, output_dir: str):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.failure_categories = defaultdict(list)

    def detect_artifacts(
        self,
        samples: torch.Tensor,
        threshold: float = 0.1,
    ) -> Dict[str, List[int]]:
        """Detect various types of artifacts in samples.

        Args:
            samples: Batch of images, shape (N, 3, H, W), values in [-1, 1].
            threshold: Fraction of clipped pixels above which a sample is
                flagged for value clipping.

        Returns:
            Mapping from artifact name to the list of offending sample indices.
        """
        artifacts = {
            "high_frequency_noise": [],
            "color_banding": [],
            "value_clipping": [],
            "blur": [],
            "repeated_patterns": [],
        }

        for i, sample in enumerate(samples):
            # High-frequency noise: variance of a 4-neighbor Laplacian of the
            # grayscale image.
            gray = sample.mean(dim=0)
            laplacian = torch.abs(
                gray[2:, 1:-1] + gray[:-2, 1:-1] +
                gray[1:-1, 2:] + gray[1:-1, :-2] -
                4 * gray[1:-1, 1:-1]
            )
            if laplacian.var() > 0.5:
                artifacts["high_frequency_noise"].append(i)

            # Color banding: suspiciously few distinct pixel values.
            unique_values = len(torch.unique(sample))
            if unique_values < sample.numel() * 0.1:
                artifacts["color_banding"].append(i)

            # Value clipping: too many pixels piled at the [-1, 1] boundaries.
            clipped = (
                (sample < -0.99).sum() + (sample > 0.99).sum()
            ).item()
            if clipped > sample.numel() * threshold:
                artifacts["value_clipping"].append(i)

            # Blur: almost no high-frequency content at all.
            if laplacian.var() < 0.01:
                artifacts["blur"].append(i)

            # Repeated patterns: strong spatial autocorrelation in any channel.
            for c in range(3):
                if self._compute_autocorrelation(sample[c]) > 0.5:
                    artifacts["repeated_patterns"].append(i)
                    break

        return artifacts

    def _compute_autocorrelation(self, image: torch.Tensor) -> float:
        """Return the max |correlation| between the image and shifted copies.

        High values indicate tiling/repetition artifacts.
        """
        h, w = image.shape
        shifts = [(4, 0), (0, 4), (8, 0), (0, 8)]

        max_corr = 0.0
        for dy, dx in shifts:
            # Robustness: skip shifts with no overlap (small images would
            # otherwise produce empty slices).
            if dy >= h or dx >= w:
                continue
            if dy > 0:
                shifted = image[dy:, :]
                original = image[:-dy, :]
            else:
                shifted = image[:, dx:]
                original = image[:, :-dx]

            corr = torch.corrcoef(
                torch.stack([original.flatten(), shifted.flatten()])
            )[0, 1]

            # Constant regions yield NaN correlation; ignore them.
            if not torch.isnan(corr):
                max_corr = max(max_corr, corr.abs().item())

        return max_corr

    def analyze_diversity_failure(
        self,
        samples: torch.Tensor,
        threshold: float = 0.9,
    ) -> dict:
        """Detect mode collapse or low diversity via pairwise cosine similarity.

        Args:
            samples: Batch of images, shape (N, C, H, W).
            threshold: Cosine similarity above which two samples count as
                near-duplicates.
        """
        # Robustness: with fewer than two samples there are no pairs to
        # compare (the original divided by zero); report trivially diverse.
        if len(samples) < 2:
            return {
                "mean_similarity": 0.0,
                "max_similarity": 0.0,
                "duplicate_fraction": 0.0,
                "num_clusters": len(samples),
                "largest_cluster_size": len(samples),
                "unique_samples_fraction": 1.0 if len(samples) else 0.0,
            }

        # Cosine similarity on flattened samples; clamped norm prevents NaNs
        # from an all-zero sample.
        samples_flat = samples.view(len(samples), -1)
        norms = samples_flat.norm(dim=1, keepdim=True).clamp_min(1e-12)
        samples_norm = samples_flat / norms

        similarities = torch.mm(samples_norm, samples_norm.t())

        # Exclude self-similarity on the diagonal
        mask = ~torch.eye(len(samples), dtype=torch.bool, device=samples.device)
        pairwise_sims = similarities[mask]

        very_similar_pairs = (pairwise_sims > threshold).sum().item()
        total_pairs = len(pairwise_sims)

        # Greedy clustering of near-duplicates
        clusters = self._find_similar_clusters(similarities.cpu().numpy(), threshold)

        return {
            "mean_similarity": pairwise_sims.mean().item(),
            "max_similarity": pairwise_sims.max().item(),
            "duplicate_fraction": very_similar_pairs / total_pairs,
            "num_clusters": len(clusters),
            "largest_cluster_size": max(len(c) for c in clusters) if clusters else 0,
            "unique_samples_fraction": len([c for c in clusters if len(c) == 1]) / len(samples),
        }

    def _find_similar_clusters(
        self,
        similarity_matrix: np.ndarray,
        threshold: float,
    ) -> List[List[int]]:
        """Greedily cluster samples whose similarity exceeds ``threshold``.

        Each unvisited sample seeds a cluster and absorbs all later samples
        similar to it; singleton clusters mean unique samples.
        """
        n = len(similarity_matrix)
        visited = set()
        clusters = []

        for i in range(n):
            if i in visited:
                continue

            cluster = [i]
            visited.add(i)

            for j in range(i + 1, n):
                if j not in visited and similarity_matrix[i, j] > threshold:
                    cluster.append(j)
                    visited.add(j)

            clusters.append(cluster)

        return clusters

    def generate_failure_report(
        self,
        samples: torch.Tensor,
        save_path: str = "failure_report.txt",
    ) -> dict:
        """Generate a comprehensive failure analysis report.

        Runs artifact detection and diversity analysis, writes a text report
        plus example images per failure type, and returns the compiled
        statistics including an overall quality score in [0, 1].
        """
        # Detect artifacts
        artifacts = self.detect_artifacts(samples)

        # Analyze diversity
        diversity = self.analyze_diversity_failure(samples)

        # Compile report
        report = {
            "total_samples": len(samples),
            "artifacts": {
                k: len(v) for k, v in artifacts.items()
            },
            "diversity": diversity,
        }

        # Overall quality: equal weight on artifact-free rate and diversity
        artifact_count = sum(len(v) for v in artifacts.values())
        artifact_rate = artifact_count / len(samples)
        diversity_score = 1 - diversity["duplicate_fraction"]

        report["quality_score"] = (1 - artifact_rate) * 0.5 + diversity_score * 0.5

        # Write report
        with open(self.output_dir / save_path, "w") as f:
            f.write("=" * 60 + "\n")
            f.write("FAILURE ANALYSIS REPORT\n")
            f.write("=" * 60 + "\n\n")

            f.write(f"Total samples analyzed: {report['total_samples']}\n\n")

            f.write("ARTIFACTS:\n")
            f.write("-" * 40 + "\n")
            for artifact_type, count in report["artifacts"].items():
                pct = count / len(samples) * 100
                f.write(f"  {artifact_type}: {count} ({pct:.1f}%)\n")

            f.write("\nDIVERSITY:\n")
            f.write("-" * 40 + "\n")
            f.write(f"  Mean pairwise similarity: {diversity['mean_similarity']:.3f}\n")
            f.write(f"  Max pairwise similarity: {diversity['max_similarity']:.3f}\n")
            f.write(f"  Near-duplicate fraction: {diversity['duplicate_fraction']:.3f}\n")
            f.write(f"  Unique samples: {diversity['unique_samples_fraction']:.1%}\n")

            f.write("\n" + "=" * 60 + "\n")
            f.write(f"OVERALL QUALITY SCORE: {report['quality_score']:.3f}\n")
            f.write("=" * 60 + "\n")

        # Save examples of each failure type
        self._save_failure_examples(samples, artifacts)

        return report

    def _save_failure_examples(
        self,
        samples: torch.Tensor,
        artifacts: Dict[str, List[int]],
    ) -> None:
        """Save up to 16 example images for each detected failure type."""
        for artifact_type, indices in artifacts.items():
            if len(indices) == 0:
                continue

            # Take up to 16 examples
            example_indices = indices[:16]
            examples = samples[example_indices]

            num_rows = (len(examples) + 3) // 4
            fig, axes = plt.subplots(
                num_rows, 4,
                figsize=(12, 3 * num_rows)
            )
            # Bug fix: normalize the squeezed subplots result to a flat 1-D
            # array so indexing works even for a single example.
            axes = np.asarray(axes).reshape(-1)

            for i, (idx, sample) in enumerate(zip(example_indices, examples)):
                if i < len(axes):
                    img = (sample.permute(1, 2, 0).cpu().numpy() + 1) / 2
                    axes[i].imshow(img.clip(0, 1))
                    axes[i].set_title(f"Sample {idx}", fontsize=8)
                    axes[i].axis("off")

            for j in range(len(examples), len(axes)):
                axes[j].axis("off")

            plt.suptitle(f"Failure Mode: {artifact_type}", fontsize=14)
            plt.tight_layout()
            plt.savefig(
                self.output_dir / f"failure_{artifact_type}.png",
                dpi=150, bbox_inches="tight"
            )
            plt.close()
Failure Mode | Cause | Solution
High-frequency noise | Undertrained, wrong LR | Train longer, adjust LR
Blurry outputs | Low model capacity | Increase channels/depth
Repeated patterns | Mode collapse | Check diversity, EMA decay
Color artifacts | Normalization issues | Check data preprocessing
Clipped values | Wrong value range | Verify [-1,1] normalization

Comprehensive Evaluation Report

Python
1import torch
2import json
3from datetime import datetime
4from pathlib import Path
5from typing import Dict, Any, Optional
6import matplotlib.pyplot as plt
7from dataclasses import dataclass, asdict
8
@dataclass
class EvaluationReport:
    """Complete evaluation report combining all analyses."""
    # --- Model identification ---
    model_name: str
    checkpoint_path: str
    timestamp: str

    # --- Quantitative metrics ---
    fid: float
    inception_score_mean: float
    inception_score_std: float
    precision: float
    recall: float

    # --- Qualitative analysis ---
    sample_statistics: Dict[str, float]
    interpolation_smoothness: float
    failure_analysis: Dict[str, Any]

    # --- Generation configuration ---
    num_samples: int
    num_steps: int
    image_size: int

    def to_dict(self) -> dict:
        """Return the report as a plain dictionary of dataclass fields."""
        return asdict(self)

    def save(self, path: str):
        """Serialize the report to *path* as pretty-printed JSON."""
        with open(path, "w") as f:
            f.write(json.dumps(self.to_dict(), indent=2))
40
41
class ComprehensiveEvaluator:
    """Complete evaluation pipeline combining all analyses.

    Orchestrates quantitative metrics (FID, IS), visual inspection, sample
    statistics, latent-space exploration, and failure-mode analysis, then
    compiles everything into an ``EvaluationReport`` plus a summary figure.
    """

    def __init__(
        self,
        model,
        sampler,
        fid_calculator,
        is_calculator,
        output_dir: str,
        device: str = "cuda",
    ):
        """
        Args:
            model: Generative model; moved to ``device`` and set to eval mode.
            sampler: Object exposing ``sample(shape, num_steps=...)``.
            fid_calculator: Object exposing ``compute_fid(samples)``.
            is_calculator: Object exposing ``compute_inception_score(samples)``.
            output_dir: Root directory for all evaluation artifacts.
            device: Torch device string.
        """
        self.model = model.to(device).eval()
        self.sampler = sampler
        self.fid_calc = fid_calculator
        self.is_calc = is_calculator
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.device = device

        # Sub-analyzers write into subdirectories of the root output dir.
        # Bug fix: the original applied ``/`` to the raw ``output_dir``
        # *string* parameter (a TypeError at construction); use the Path
        # object stored on ``self`` instead.
        self.inspector = SampleInspector(model, sampler, str(self.output_dir / "samples"), device)
        self.explorer = LatentExplorer(model, sampler, device)
        self.failure_analyzer = FailureModeAnalyzer(str(self.output_dir / "failures"))

    def evaluate(
        self,
        model_name: str,
        checkpoint_path: str,
        num_samples: int = 10000,
        num_steps: int = 50,
        image_size: int = 64,
        compute_all: bool = True,
    ) -> EvaluationReport:
        """Run the complete evaluation pipeline.

        Args:
            model_name: Identifier used in filenames and the report.
            checkpoint_path: Recorded in the report for provenance.
            num_samples: Number of images to generate for metrics.
            num_steps: Sampler steps per image.
            image_size: Spatial size of generated images.
            compute_all: If False, skip the (expensive) latent-space
                smoothness analysis.

        Returns:
            The compiled ``EvaluationReport`` (also saved as JSON).
        """
        print("=" * 60)
        print("COMPREHENSIVE MODEL EVALUATION")
        print("=" * 60)
        print(f"Model: {model_name}")
        print(f"Checkpoint: {checkpoint_path}")
        print(f"Samples: {num_samples}")
        print()

        # Generate samples in fixed-size batches; over-generation is trimmed
        # back to exactly num_samples below.
        print("Generating samples...")
        samples = []
        batch_size = 64
        shape = (batch_size, 3, image_size, image_size)

        for i in range(0, num_samples, batch_size):
            batch = self.sampler.sample(shape, num_steps=num_steps)
            samples.append(batch.cpu())

        samples = torch.cat(samples, dim=0)[:num_samples]
        print(f"Generated {len(samples)} samples")

        # 1. Quantitative metrics
        print("\nComputing quantitative metrics...")
        fid = self.fid_calc.compute_fid(samples)
        is_mean, is_std = self.is_calc.compute_inception_score(samples)
        print(f"  FID: {fid:.2f}")
        print(f"  IS: {is_mean:.2f} +/- {is_std:.2f}")

        # Precision/Recall (simplified - need real features)
        precision, recall = 0.0, 0.0  # Would compute with real data

        # 2. Visual inspection
        print("\nGenerating visual inspection grids...")
        self.inspector.generate_random_grid(64, num_steps, image_size)
        self.inspector.generate_seeded_grid(list(range(16)), num_steps, image_size)
        self.inspector.compare_timesteps()

        # 3. Sample statistics
        print("\nAnalyzing sample statistics...")
        sample_stats = self.inspector.analyze_sample_statistics(
            num_samples=min(1000, num_samples),
            num_steps=num_steps,
        )
        print(f"  Mean: {sample_stats['mean']:.4f}")
        print(f"  Std: {sample_stats['std']:.4f}")
        print(f"  Outlier fraction: {sample_stats['outlier_fraction']:.4f}")

        # 4. Latent space exploration (optional: slow)
        if compute_all:
            print("\nExploring latent space...")
            self.explorer.generate_interpolation_grid(
                num_pairs=4,
                num_interpolations=8,
                num_steps=num_steps,
                save_path=str(self.output_dir / "interpolations.png"),
            )

            smoothness_analysis = self.explorer.analyze_interpolation_smoothness(
                num_interpolations=20,
                num_pairs=10,
                num_steps=num_steps,
            )
            print(f"  Smoothness score: {smoothness_analysis['smoothness_score']:.4f}")
        else:
            smoothness_analysis = {"smoothness_score": 0.0}

        # 5. Failure analysis (on a subset to bound runtime)
        print("\nAnalyzing failure modes...")
        failure_report = self.failure_analyzer.generate_failure_report(
            samples[:1000],  # Analyze subset
        )
        print(f"  Quality score: {failure_report['quality_score']:.4f}")
        for artifact, count in failure_report["artifacts"].items():
            if count > 0:
                print(f"  {artifact}: {count}")

        # Compile report
        report = EvaluationReport(
            model_name=model_name,
            checkpoint_path=checkpoint_path,
            timestamp=datetime.now().isoformat(),
            fid=fid,
            inception_score_mean=is_mean,
            inception_score_std=is_std,
            precision=precision,
            recall=recall,
            sample_statistics=sample_stats,
            interpolation_smoothness=smoothness_analysis["smoothness_score"],
            failure_analysis=failure_report,
            num_samples=num_samples,
            num_steps=num_steps,
            image_size=image_size,
        )

        # Save report
        report_path = self.output_dir / f"evaluation_report_{model_name}.json"
        report.save(str(report_path))
        print(f"\nReport saved to {report_path}")

        # Generate summary figure
        self._generate_summary_figure(report)

        return report

    def _generate_summary_figure(self, report: EvaluationReport) -> None:
        """Generate a one-page visual summary of the evaluation."""
        fig = plt.figure(figsize=(16, 10))

        # Create grid layout
        gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

        # Title
        fig.suptitle(
            f"Evaluation Report: {report.model_name}\n"
            f"FID: {report.fid:.2f} | IS: {report.inception_score_mean:.2f}",
            fontsize=14
        )

        # Load and display the previously saved sample grid (if present)
        ax1 = fig.add_subplot(gs[0:2, 0:2])
        grid_path = self.output_dir / "samples" / "random_grid.png"
        if grid_path.exists():
            img = plt.imread(str(grid_path))
            ax1.imshow(img)
        ax1.set_title("Random Samples")
        ax1.axis("off")

        # Metrics bar chart (FID inverted so higher is better everywhere)
        ax2 = fig.add_subplot(gs[0, 2])
        metrics = ["FID (inv)", "IS", "Quality"]
        values = [
            100 / (report.fid + 1),  # Invert FID for visualization
            report.inception_score_mean,
            report.failure_analysis.get("quality_score", 0) * 10,
        ]
        ax2.bar(metrics, values, color=["blue", "green", "orange"])
        ax2.set_title("Quality Metrics")

        # Failure distribution pie chart
        ax3 = fig.add_subplot(gs[0, 3])
        artifacts = report.failure_analysis.get("artifacts", {})
        if sum(artifacts.values()) > 0:
            labels = [k for k, v in artifacts.items() if v > 0]
            sizes = [v for v in artifacts.values() if v > 0]
            ax3.pie(sizes, labels=labels, autopct="%1.1f%%")
        ax3.set_title("Artifact Distribution")

        # Sample statistics as text
        ax4 = fig.add_subplot(gs[1, 2:])
        stats = report.sample_statistics
        text = (
            f"Mean: {stats.get('mean', 0):.4f}\n"
            f"Std: {stats.get('std', 0):.4f}\n"
            f"Outliers: {stats.get('outlier_fraction', 0):.2%}\n"
            f"Smoothness: {report.interpolation_smoothness:.4f}"
        )
        ax4.text(0.1, 0.5, text, fontsize=12, verticalalignment="center")
        ax4.set_title("Statistics")
        ax4.axis("off")

        # Interpolation grid (if present)
        ax5 = fig.add_subplot(gs[2, :])
        interp_path = self.output_dir / "interpolations.png"
        if interp_path.exists():
            img = plt.imread(str(interp_path))
            ax5.imshow(img)
        ax5.set_title("Latent Interpolations")
        ax5.axis("off")

        plt.savefig(
            self.output_dir / f"summary_{report.model_name}.png",
            dpi=150, bbox_inches="tight"
        )
        plt.close()

Key Takeaways

  1. Visual inspection is essential: FID cannot capture all quality aspects. Always examine samples visually.
  2. Use consistent seeds: Generate samples from fixed seeds to track progress and compare models fairly.
  3. Explore the latent space: Smooth interpolations indicate good representations. Artifacts during interpolation reveal problems.
  4. Systematically categorize failures: Different failure modes have different root causes and solutions.
  5. Create comprehensive reports: Combine quantitative metrics with qualitative analysis for complete understanding.
Looking Ahead: The final section provides an interactive demonstration bringing together everything we have learned: model loading, sampling, and real-time generation.