Learning Objectives
By the end of this section, you will be able to:
- Compute Frechet Inception Distance (FID) as the standard metric for evaluating generative models
- Calculate Inception Score (IS) for measuring sample quality and diversity
- Analyze precision and recall to understand the fidelity-diversity tradeoff
- Build a complete evaluation pipeline for benchmarking diffusion models
The Big Picture
Evaluating generative models is fundamentally different from evaluating discriminative models. There is no ground truth label to compare against. Instead, we measure how well the generated distribution matches the real data distribution using statistical distance measures.
The Evaluation Challenge: A good generative model should produce samples that are both high-quality (realistic) and diverse (covering all modes of the data). No single metric captures both aspects perfectly, so we use multiple complementary metrics.
Model Evaluation Metrics Comparison
| Model | FID | IS | Precision | Recall | NFE |
|---|---|---|---|---|---|
| DDPM (1000 steps) | 3.17 | 9.46 | 83.0% | 58.0% | 1000 |
| DDIM (50 steps) | 4.67 | 9.12 | 82.0% | 55.0% | 50 |
| DDIM (100 steps) | 3.51 | 9.38 | 83.0% | 57.0% | 100 |
| DPM-Solver (20 steps) | 3.24 | 9.41 | 83.0% | 57.0% | 20 |
| EDM | 1.97 | 9.84 | 85.0% | 62.0% | 35 |
Metrics explained: FID (Frechet Inception Distance) measures image quality and diversity. IS (Inception Score) measures image quality and class distinctiveness. Precision measures fidelity to real data, while Recall measures mode coverage. NFE is the number of neural network forward passes needed for generation.
| Metric | Measures | Lower/Higher is Better | Computation Cost |
|---|---|---|---|
| FID | Distribution distance | Lower | Medium (50k samples) |
| IS | Quality + Diversity | Higher | Low (5-10k samples) |
| Precision | Sample quality/fidelity | Higher | Medium |
| Recall | Mode coverage/diversity | Higher | Medium |
| LPIPS | Perceptual similarity | Higher for diversity | Low |
Frechet Inception Distance (FID)
Understanding FID
FID measures the distance between the distribution of generated images and real images in the feature space of an Inception network. Modeling each set of features as a Gaussian, it compares the means and covariances of the two distributions:

$$\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\!\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{1/2}\right)$$

where $\mu_r, \Sigma_r$ are the mean and covariance of real image features, and $\mu_g, \Sigma_g$ are those of generated image features.
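The closed-form distance is easy to sanity-check on synthetic features before building the full pipeline. The `fid_from_stats` helper and the toy Gaussian "features" below are illustrative stand-ins, not part of any library:

```python
import numpy as np
from scipy import linalg

def fid_from_stats(mu1, sigma1, mu2, sigma2):
    """Closed-form Frechet distance between two Gaussians."""
    diff = mu1 - mu2
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard numerical imaginary noise
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2 * covmean))

def stats(x):
    """Mean and covariance of a feature matrix (n_samples, n_dims)."""
    return np.mean(x, axis=0), np.cov(x, rowvar=False)

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, size=(5000, 8))    # stand-in "real" features
close = rng.normal(0.05, 1.0, size=(5000, 8))  # well-matched "generated" features
far = rng.normal(1.0, 2.0, size=(5000, 8))     # badly mismatched features

print(fid_from_stats(*stats(real), *stats(close)))  # small
print(fid_from_stats(*stats(real), *stats(far)))    # much larger
```

A well-matched generator yields a score near zero; shifting the mean or scaling the covariance drives the score up, which is exactly the behavior the full Inception-feature version relies on.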
```python
import torch
import torch.nn as nn
import numpy as np
from scipy import linalg
from torchvision.models import inception_v3, Inception_V3_Weights
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Optional, Tuple


class FIDCalculator:
    """Frechet Inception Distance calculator."""

    def __init__(self, device: str = "cuda", dims: int = 2048):
        self.device = device
        self.dims = dims

        # Load Inception v3
        self.inception = inception_v3(
            weights=Inception_V3_Weights.IMAGENET1K_V1,
            transform_input=False,
        )
        # Remove final classification layer to expose the 2048-d pool features
        self.inception.fc = nn.Identity()
        self.inception = self.inception.to(device).eval()

        # Preprocessing
        self.resize = nn.Upsample(size=(299, 299), mode="bilinear", align_corners=False)

    def _preprocess(self, images: torch.Tensor) -> torch.Tensor:
        """Preprocess images for Inception."""
        # Resize to 299x299
        if images.shape[-1] != 299 or images.shape[-2] != 299:
            images = self.resize(images)

        # Convert from [-1, 1] to [0, 1], then apply ImageNet normalization
        images = (images + 1) / 2
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device)
        images = (images - mean[None, :, None, None]) / std[None, :, None, None]

        return images

    @torch.no_grad()
    def extract_features(self, images: torch.Tensor) -> np.ndarray:
        """Extract Inception features from a batch of images."""
        images = self._preprocess(images.to(self.device))
        features = self.inception(images)
        return features.cpu().numpy()

    @torch.no_grad()
    def extract_features_from_loader(
        self,
        loader: DataLoader,
        num_samples: Optional[int] = None,
    ) -> np.ndarray:
        """Extract features from a DataLoader."""
        features_list = []
        total = 0

        for batch in tqdm(loader, desc="Extracting features"):
            if isinstance(batch, (list, tuple)):
                batch = batch[0]  # Handle (images, labels) format

            features = self.extract_features(batch)
            features_list.append(features)

            total += batch.shape[0]
            if num_samples and total >= num_samples:
                break

        features = np.concatenate(features_list, axis=0)
        if num_samples:
            features = features[:num_samples]

        return features

    @staticmethod
    def compute_statistics(features: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        """Compute mean and covariance of features."""
        mu = np.mean(features, axis=0)
        sigma = np.cov(features, rowvar=False)
        return mu, sigma

    @staticmethod
    def compute_fid_from_statistics(
        mu1: np.ndarray,
        sigma1: np.ndarray,
        mu2: np.ndarray,
        sigma2: np.ndarray,
    ) -> float:
        """Compute FID from precomputed statistics."""
        # Squared distance between means
        diff = mu1 - mu2
        diff_squared = np.dot(diff, diff)

        # Matrix square root of the product of covariances
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)

        # Handle numerical issues: discard negligible imaginary components
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError(f"Imaginary component {m}")
            covmean = covmean.real

        fid = diff_squared + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
        return float(fid)

    def compute_fid(
        self,
        real_loader: DataLoader,
        generated_images: torch.Tensor,
        num_samples: int = 50000,
    ) -> float:
        """Compute FID between real and generated images.

        Args:
            real_loader: DataLoader for real images
            generated_images: Tensor of generated images
            num_samples: Number of samples to use (affects accuracy)

        Returns:
            FID score (lower is better)
        """
        print("Extracting features from real images...")
        real_features = self.extract_features_from_loader(real_loader, num_samples)

        print("Extracting features from generated images...")
        gen_features = []
        batch_size = 64
        for i in tqdm(range(0, len(generated_images), batch_size)):
            batch = generated_images[i:i + batch_size]
            gen_features.append(self.extract_features(batch))
        gen_features = np.concatenate(gen_features, axis=0)

        # Compute statistics and the distance between them
        mu_real, sigma_real = self.compute_statistics(real_features)
        mu_gen, sigma_gen = self.compute_statistics(gen_features)
        return self.compute_fid_from_statistics(mu_real, sigma_real, mu_gen, sigma_gen)


class PrecomputedFID:
    """FID computation with cached real image statistics."""

    def __init__(self, stats_path: Optional[str] = None, device: str = "cuda"):
        self.calculator = FIDCalculator(device=device)
        self.stats_path = stats_path
        self.real_mu = None
        self.real_sigma = None

        if stats_path:
            self.load_statistics(stats_path)

    def compute_and_cache_real_stats(
        self,
        real_loader: DataLoader,
        num_samples: int = 50000,
        save_path: Optional[str] = None,
    ):
        """Compute and optionally save real image statistics."""
        features = self.calculator.extract_features_from_loader(real_loader, num_samples)
        self.real_mu, self.real_sigma = self.calculator.compute_statistics(features)

        if save_path:
            np.savez(save_path, mu=self.real_mu, sigma=self.real_sigma)
            self.stats_path = save_path
            print(f"Saved statistics to {save_path}")

    def load_statistics(self, path: str):
        """Load precomputed statistics."""
        data = np.load(path)
        self.real_mu = data["mu"]
        self.real_sigma = data["sigma"]
        print(f"Loaded statistics from {path}")

    def compute_fid(self, generated_images: torch.Tensor) -> float:
        """Compute FID using cached real statistics."""
        if self.real_mu is None:
            raise ValueError(
                "Real statistics not loaded. Call compute_and_cache_real_stats first."
            )

        # Extract generated features in batches
        gen_features = []
        batch_size = 64
        for i in range(0, len(generated_images), batch_size):
            batch = generated_images[i:i + batch_size]
            gen_features.append(self.calculator.extract_features(batch))
        gen_features = np.concatenate(gen_features, axis=0)

        mu_gen, sigma_gen = self.calculator.compute_statistics(gen_features)
        return self.calculator.compute_fid_from_statistics(
            self.real_mu, self.real_sigma, mu_gen, sigma_gen
        )
```

FID Sample Size

FID estimates are biased by sample count: scores computed with fewer samples are systematically higher, so only compare models evaluated with the same number of samples (50k is the convention for published results).
| Dataset | SOTA FID (Diffusion) | Previous SOTA (GAN) |
|---|---|---|
| CIFAR-10 (32x32) | 1.97 (EDM) | 2.92 (StyleGAN2-ADA) |
| ImageNet 64x64 | 1.55 (EDM) | 2.10 (BigGAN-deep) |
| ImageNet 256x256 | 1.79 (DiT) | 3.60 (BigGAN-deep) |
| LSUN Bedroom | 1.90 (LDM) | 2.35 (StyleGAN2) |
Inception Score (IS)
Understanding IS
Inception Score measures both quality and diversity using a pre-trained Inception classifier. It computes the KL divergence between the conditional class distribution $p(y \mid x)$ and the marginal class distribution $p(y)$:

$$\mathrm{IS} = \exp\!\left(\mathbb{E}_{x \sim p_g}\left[D_{\mathrm{KL}}\big(p(y \mid x)\,\Vert\, p(y)\big)\right]\right)$$
A high IS means that each image is confidently classified (quality) and the overall class distribution is diverse (diversity).
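A minimal numpy sketch shows how the formula behaves at its extremes; the synthetic softmax predictions below are illustrative stand-ins for classifier outputs:

```python
import numpy as np

def inception_score(probs, eps=1e-12):
    """IS = exp(mean over x of KL(p(y|x) || p(y)))."""
    py = probs.mean(axis=0)  # marginal class distribution p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(py + eps)), axis=1)
    return float(np.exp(kl.mean()))

n, n_classes = 1000, 10

# Confident AND diverse: each image is classified decisively and the
# classes are used evenly -> IS approaches the number of classes
confident = np.full((n, n_classes), 0.01 / (n_classes - 1))
confident[np.arange(n), np.arange(n) % n_classes] = 0.99

# No confidence at all: p(y|x) equals p(y) -> IS = 1
uniform = np.full((n, n_classes), 1.0 / n_classes)

print(inception_score(confident))  # high, approaching n_classes
print(inception_score(uniform))    # 1.0
```

The same collapse to 1 happens if the generator is confident but always produces the same class, since then $p(y \mid x) \approx p(y)$: IS rewards confidence and diversity jointly.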
```python
import torch
import torch.nn as nn
import numpy as np
from torchvision.models import inception_v3, Inception_V3_Weights
from scipy.stats import entropy
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Tuple


class InceptionScoreCalculator:
    """Inception Score calculator."""

    def __init__(self, device: str = "cuda", splits: int = 10):
        self.device = device
        self.splits = splits

        # Load Inception v3 with its classification head intact
        self.inception = inception_v3(
            weights=Inception_V3_Weights.IMAGENET1K_V1,
            transform_input=False,
        )
        self.inception = self.inception.to(device).eval()

        # Preprocessing
        self.resize = nn.Upsample(size=(299, 299), mode="bilinear", align_corners=False)

    def _preprocess(self, images: torch.Tensor) -> torch.Tensor:
        """Preprocess images for Inception."""
        if images.shape[-1] != 299 or images.shape[-2] != 299:
            images = self.resize(images)

        # Convert from [-1, 1] to [0, 1], then apply ImageNet normalization
        images = (images + 1) / 2
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device)
        images = (images - mean[None, :, None, None]) / std[None, :, None, None]

        return images

    @torch.no_grad()
    def get_predictions(self, images: torch.Tensor) -> np.ndarray:
        """Get softmax class predictions from Inception."""
        images = self._preprocess(images.to(self.device))
        logits = self.inception(images)
        probs = torch.softmax(logits, dim=1)
        return probs.cpu().numpy()

    def compute_inception_score(
        self,
        images: torch.Tensor,
        batch_size: int = 64,
    ) -> Tuple[float, float]:
        """Compute Inception Score.

        Args:
            images: Generated images tensor
            batch_size: Batch size for processing

        Returns:
            (mean IS, std IS) across splits
        """
        # Get predictions for all images
        preds = []
        for i in tqdm(range(0, len(images), batch_size), desc="Computing IS"):
            batch = images[i:i + batch_size]
            preds.append(self.get_predictions(batch))
        preds = np.concatenate(preds, axis=0)

        # Compute IS independently on each split, then report mean and std
        scores = []
        split_size = len(preds) // self.splits

        for i in range(self.splits):
            split_preds = preds[i * split_size:(i + 1) * split_size]

            # p(y|x) is each row; p(y) is the average over the split
            py = np.mean(split_preds, axis=0)

            # KL(p(y|x) || p(y)) for each image
            kl_divs = []
            for p in split_preds:
                kl = entropy(p, py)
                if np.isfinite(kl):
                    kl_divs.append(kl)

            # IS = exp(mean KL)
            if kl_divs:
                scores.append(np.exp(np.mean(kl_divs)))

        return float(np.mean(scores)), float(np.std(scores))


def compute_is_from_loader(
    calculator: InceptionScoreCalculator,
    loader: DataLoader,
    num_samples: int = 50000,
) -> Tuple[float, float]:
    """Compute IS from a DataLoader of generated images."""
    all_images = []
    total = 0

    for batch in loader:
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        all_images.append(batch)
        total += batch.shape[0]
        if total >= num_samples:
            break

    images = torch.cat(all_images, dim=0)[:num_samples]
    return calculator.compute_inception_score(images)
```

IS Limitations

IS relies on an ImageNet-trained classifier and never looks at the real data distribution, so it is only meaningful for ImageNet-like images and cannot detect memorization or intra-class mode collapse.
Precision and Recall
Understanding the Tradeoff
Precision and Recall for generative models capture the quality-diversity tradeoff that FID conflates into a single number:
- Precision: What fraction of generated samples are realistic (fall within the real data manifold)?
- Recall: What fraction of the real data distribution is covered by generated samples?
```python
import numpy as np
from typing import Tuple
from sklearn.neighbors import NearestNeighbors


class PrecisionRecallCalculator:
    """Improved Precision and Recall for generative models.

    Based on "Improved Precision and Recall Metric for Assessing
    Generative Models" (Kynkaanniemi et al., 2019).
    """

    def __init__(self, feature_extractor, k: int = 3, device: str = "cuda"):
        self.feature_extractor = feature_extractor  # e.g., FIDCalculator
        self.k = k
        self.device = device

    def compute_manifold(self, features: np.ndarray) -> NearestNeighbors:
        """Fit a k-NN index as the manifold representation."""
        # k + 1 neighbors because each fitted point's nearest neighbor is itself
        nn = NearestNeighbors(n_neighbors=self.k + 1, algorithm="auto")
        nn.fit(features)
        return nn

    def compute_precision_recall(
        self,
        real_features: np.ndarray,
        fake_features: np.ndarray,
    ) -> Tuple[float, float]:
        """Compute precision and recall.

        Args:
            real_features: Features of real images
            fake_features: Features of generated images

        Returns:
            (precision, recall)
        """
        # Build manifolds
        print("Building real manifold...")
        real_nn = self.compute_manifold(real_features)
        print("Building fake manifold...")
        fake_nn = self.compute_manifold(fake_features)

        # Radii: distance to the k-th nearest neighbor (column 0 is the point itself)
        real_distances, _ = real_nn.kneighbors(real_features)
        real_radii = real_distances[:, -1]
        fake_distances, _ = fake_nn.kneighbors(fake_features)
        fake_radii = fake_distances[:, -1]

        # Precision: a fake sample counts as realistic if it falls inside the
        # radius ball of its nearest real sample (an approximation of membership
        # in the union of all real balls)
        print("Computing precision...")
        fake_to_real_distances, real_indices = real_nn.kneighbors(
            fake_features, n_neighbors=1
        )
        precision = float(np.mean(
            fake_to_real_distances[:, 0] <= real_radii[real_indices[:, 0]]
        ))

        # Recall: a real sample counts as covered if it falls inside the
        # radius ball of its nearest fake sample
        print("Computing recall...")
        real_to_fake_distances, fake_indices = fake_nn.kneighbors(
            real_features, n_neighbors=1
        )
        recall = float(np.mean(
            real_to_fake_distances[:, 0] <= fake_radii[fake_indices[:, 0]]
        ))

        return precision, recall

    def compute_density_coverage(
        self,
        real_features: np.ndarray,
        fake_features: np.ndarray,
    ) -> Tuple[float, float]:
        """Compute Density and Coverage metrics.

        Alternative to precision/recall that's less sensitive to outliers
        (Naeem et al., 2020).
        """
        # Build real manifold and per-sample radii
        real_nn = self.compute_manifold(real_features)
        real_distances, _ = real_nn.kneighbors(real_features)
        real_radii = real_distances[:, -1]

        # Density: for each fake sample, count how many real-sample balls
        # contain it (checked over its k+1 nearest real neighbors),
        # normalized by k
        fake_to_real_distances, real_indices = real_nn.kneighbors(fake_features)
        density_sum = 0
        for i in range(len(fake_features)):
            density_sum += np.sum(
                fake_to_real_distances[i] <= real_radii[real_indices[i]]
            )
        density = float(density_sum / (self.k * len(fake_features)))

        # Coverage: fraction of real samples whose ball contains at least
        # one fake sample
        fake_1nn = NearestNeighbors(n_neighbors=1).fit(fake_features)
        real_to_fake_distances, _ = fake_1nn.kneighbors(real_features)
        coverage = float(np.mean(real_to_fake_distances[:, 0] <= real_radii))

        return density, coverage


def compute_full_precision_recall(
    real_loader,
    generated_images,
    fid_calculator: FIDCalculator,
    k: int = 3,
    num_samples: int = 10000,
) -> dict:
    """Compute the full set of precision/recall metrics."""
    # Extract features
    real_features = fid_calculator.extract_features_from_loader(
        real_loader, num_samples
    )

    gen_features = []
    batch_size = 64
    for i in range(0, min(len(generated_images), num_samples), batch_size):
        batch = generated_images[i:i + batch_size]
        gen_features.append(fid_calculator.extract_features(batch))
    gen_features = np.concatenate(gen_features, axis=0)

    # Compute metrics
    pr_calc = PrecisionRecallCalculator(fid_calculator, k=k)
    precision, recall = pr_calc.compute_precision_recall(real_features, gen_features)
    density, coverage = pr_calc.compute_density_coverage(real_features, gen_features)

    return {
        "precision": precision,
        "recall": recall,
        "density": density,
        "coverage": coverage,
        "f1": 2 * precision * recall / (precision + recall + 1e-8),
    }
```

| Model Characteristic | High Precision | High Recall |
|---|---|---|
| Mode collapse | Yes (limited diversity) | No (misses modes) |
| Good quality | Yes (realistic samples) | Maybe |
| Good diversity | Maybe | Yes (covers all modes) |
| Overfitting | Yes (copying training data) | Yes (covers training data) |
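The mode-collapse row above can be reproduced in a toy 2-D setting: a generator that covers only one of two real modes scores high precision but only about half recall under the k-NN manifold definition. All data below is synthetic and purely illustrative:

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
# Real data: two well-separated modes
real = np.concatenate([
    rng.normal(0.0, 0.1, size=(500, 2)),
    rng.normal(5.0, 0.1, size=(500, 2)),
])
# Mode-collapsed generator: samples only the first mode
fake = rng.normal(0.0, 0.1, size=(1000, 2))

k = 3

def manifold(x):
    """k-NN index plus per-point radii (distance to k-th neighbor; col 0 is self)."""
    nn = NearestNeighbors(n_neighbors=k + 1).fit(x)
    dist, _ = nn.kneighbors(x)
    return nn, dist[:, -1]

real_nn, real_radii = manifold(real)
fake_nn, fake_radii = manifold(fake)

# Precision: fake sample falls inside its nearest real sample's ball
d, idx = real_nn.kneighbors(fake, n_neighbors=1)
precision = float(np.mean(d[:, 0] <= real_radii[idx[:, 0]]))

# Recall: real sample falls inside its nearest fake sample's ball
d, idx = fake_nn.kneighbors(real, n_neighbors=1)
recall = float(np.mean(d[:, 0] <= fake_radii[idx[:, 0]]))

print(f"precision={precision:.2f}, recall={recall:.2f}")
# High precision, recall around 0.5: every fake looks real, but half the
# real distribution is never covered
```

Swapping the roles (fake data spread over both modes plus unrealistic outliers) produces the opposite pattern: high recall, lower precision.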
Complete Evaluation Pipeline
```python
import torch
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Dict, Any
import json
from datetime import datetime
from tqdm import tqdm


@dataclass
class EvaluationResults:
    """Container for all evaluation metrics."""
    fid: float
    is_mean: float
    is_std: float
    precision: float
    recall: float
    density: float
    coverage: float
    num_samples: int
    timestamp: str
    model_name: str
    additional_info: Dict[str, Any]

    def to_dict(self) -> dict:
        return {
            "fid": self.fid,
            "inception_score": {"mean": self.is_mean, "std": self.is_std},
            "precision": self.precision,
            "recall": self.recall,
            "density": self.density,
            "coverage": self.coverage,
            "num_samples": self.num_samples,
            "timestamp": self.timestamp,
            "model_name": self.model_name,
            **self.additional_info,
        }

    def save(self, path: str):
        with open(path, "w") as f:
            json.dump(self.to_dict(), f, indent=2)


class ModelEvaluator:
    """Complete evaluation pipeline for generative models."""

    def __init__(
        self,
        real_stats_path: Optional[str] = None,
        device: str = "cuda",
        num_samples: int = 50000,
    ):
        self.device = device
        self.num_samples = num_samples
        self.real_stats_path = real_stats_path

        # Initialize calculators
        self.fid_calc = PrecomputedFID(stats_path=real_stats_path, device=device)
        self.is_calc = InceptionScoreCalculator(device=device)
        self.pr_calc = PrecisionRecallCalculator(self.fid_calc.calculator, device=device)

    def precompute_real_stats(self, real_loader, save_path: str):
        """Precompute and save real data statistics."""
        print("Computing real data statistics...")
        self.fid_calc.compute_and_cache_real_stats(
            real_loader,
            num_samples=self.num_samples,
            save_path=save_path,
        )
        self.real_stats_path = save_path

        # Also save raw features for precision/recall
        features = self.fid_calc.calculator.extract_features_from_loader(
            real_loader, self.num_samples
        )
        np.save(save_path.replace(".npz", "_features.npy"), features)

    def evaluate(
        self,
        generated_images: torch.Tensor,
        model_name: str = "unknown",
        compute_all: bool = True,
        additional_info: Optional[Dict[str, Any]] = None,
    ) -> EvaluationResults:
        """Run full evaluation on generated images.

        Args:
            generated_images: Tensor of generated images
            model_name: Name for logging
            compute_all: Whether to also compute precision/recall metrics
            additional_info: Additional metadata to include

        Returns:
            EvaluationResults with all metrics
        """
        print(f"Evaluating {len(generated_images)} generated images...")

        results = {
            "fid": 0.0, "is_mean": 0.0, "is_std": 0.0,
            "precision": 0.0, "recall": 0.0,
            "density": 0.0, "coverage": 0.0,
        }

        # Compute FID
        print("\nComputing FID...")
        results["fid"] = self.fid_calc.compute_fid(generated_images)
        print(f"FID: {results['fid']:.2f}")

        # Compute Inception Score
        print("\nComputing Inception Score...")
        results["is_mean"], results["is_std"] = self.is_calc.compute_inception_score(
            generated_images
        )
        print(f"IS: {results['is_mean']:.2f} +/- {results['is_std']:.2f}")

        if compute_all and self.real_stats_path is not None:
            # Compute Precision/Recall
            print("\nComputing Precision/Recall...")

            # Extract features from generated images
            gen_features = []
            batch_size = 64
            for i in range(0, len(generated_images), batch_size):
                batch = generated_images[i:i + batch_size]
                gen_features.append(
                    self.fid_calc.calculator.extract_features(batch)
                )
            gen_features = np.concatenate(gen_features, axis=0)

            # Load cached real features (saved by precompute_real_stats)
            real_features_path = self.real_stats_path.replace(".npz", "_features.npy")
            if Path(real_features_path).exists():
                real_features = np.load(real_features_path)

                results["precision"], results["recall"] = self.pr_calc.compute_precision_recall(
                    real_features[:len(gen_features)], gen_features
                )
                results["density"], results["coverage"] = self.pr_calc.compute_density_coverage(
                    real_features[:len(gen_features)], gen_features
                )

                print(f"Precision: {results['precision']:.4f}")
                print(f"Recall: {results['recall']:.4f}")
                print(f"Density: {results['density']:.4f}")
                print(f"Coverage: {results['coverage']:.4f}")

        return EvaluationResults(
            fid=results["fid"],
            is_mean=results["is_mean"],
            is_std=results["is_std"],
            precision=results["precision"],
            recall=results["recall"],
            density=results["density"],
            coverage=results["coverage"],
            num_samples=len(generated_images),
            timestamp=datetime.now().isoformat(),
            model_name=model_name,
            additional_info=additional_info or {},
        )


def evaluate_model_checkpoint(
    checkpoint_path: str,
    model_class,
    sampler_class,
    real_stats_path: str,
    output_dir: str,
    num_samples: int = 50000,
    num_steps: int = 50,
):
    """Evaluate a saved model checkpoint."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model
    print(f"Loading model from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    model = model_class(**checkpoint["config"])
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.cuda().eval()

    # Create sampler
    sampler = sampler_class(model, checkpoint["alphas_cumprod"])

    # Generate samples
    print(f"Generating {num_samples} samples...")
    generated = []
    batch_size = 64

    for i in tqdm(range(0, num_samples, batch_size)):
        current_batch = min(batch_size, num_samples - i)
        samples = sampler.sample(
            (current_batch, 3, 64, 64),  # Adjust resolution as needed
            num_steps=num_steps,
        )
        generated.append(samples.cpu())

    generated = torch.cat(generated, dim=0)

    # Evaluate
    evaluator = ModelEvaluator(
        real_stats_path=real_stats_path,
        num_samples=num_samples,
    )
    results = evaluator.evaluate(
        generated,
        model_name=Path(checkpoint_path).stem,
        additional_info={
            "checkpoint": checkpoint_path,
            "num_steps": num_steps,
        },
    )

    # Save results
    results_path = output_dir / f"eval_{results.model_name}_{results.timestamp}.json"
    results.save(str(results_path))
    print(f"\nResults saved to {results_path}")

    # Print summary
    print("\n" + "=" * 50)
    print("EVALUATION SUMMARY")
    print("=" * 50)
    print(f"Model: {results.model_name}")
    print(f"FID: {results.fid:.2f}")
    print(f"IS: {results.is_mean:.2f} +/- {results.is_std:.2f}")
    print(f"Precision: {results.precision:.4f}")
    print(f"Recall: {results.recall:.4f}")
    print("=" * 50)

    return results
```

Key Takeaways
- FID is the primary metric: Lower is better. Use 50k samples for reliable comparisons and always report sample count.
- IS measures quality + diversity: Higher is better, but only works well for ImageNet-like images.
- Precision and Recall separate concerns: Precision measures quality (are samples realistic?), Recall measures coverage (are all modes represented?).
- Cache real statistics: Precompute and save Inception features for faster evaluation iterations.
- Use multiple metrics: No single metric captures all aspects of generation quality. Report FID, IS, and Precision/Recall together.
Looking Ahead: Quantitative metrics only tell part of the story. The next section covers qualitative analysis techniques for understanding what your model has learned and identifying failure modes.