Chapter 12

Evaluation Metrics

Generation and Evaluation

Learning Objectives

By the end of this section, you will be able to:

  1. Compute Frechet Inception Distance (FID) as the standard metric for evaluating generative models
  2. Calculate Inception Score (IS) for measuring sample quality and diversity
  3. Analyze precision and recall to understand the fidelity-diversity tradeoff
  4. Build a complete evaluation pipeline for benchmarking diffusion models

The Big Picture

Evaluating generative models is fundamentally different from evaluating discriminative models. There is no ground truth label to compare against. Instead, we measure how well the generated distribution matches the real data distribution using statistical distance measures.

The Evaluation Challenge: A good generative model should produce samples that are both high-quality (realistic) and diverse (covering all modes of the data). No single metric captures both aspects perfectly, so we use multiple complementary metrics.

Model Evaluation Metrics Comparison

[Interactive figure: radar chart comparing FID (lower = better), IS (higher = better), Precision, Recall, and Efficiency, plus an FID bar chart, for DDPM (1000 steps), DDIM (50 and 100 steps), DPM-Solver (20 steps), and EDM. The underlying numbers appear in the table below.]
| Model | FID | IS | Precision | Recall | NFE |
|---|---|---|---|---|---|
| DDPM (1000 steps) | 3.17 | 9.46 | 83.0% | 58.0% | 1000 |
| DDIM (50 steps) | 4.67 | 9.12 | 82.0% | 55.0% | 50 |
| DDIM (100 steps) | 3.51 | 9.38 | 83.0% | 57.0% | 100 |
| DPM-Solver (20 steps) | 3.24 | 9.41 | 83.0% | 57.0% | 20 |
| EDM | 1.97 | 9.84 | 85.0% | 62.0% | 35 |

Metrics explained: FID (Frechet Inception Distance) measures image quality and diversity. IS (Inception Score) measures image quality and class distinctiveness. Precision measures fidelity to real data, while Recall measures mode coverage. NFE is the number of neural network forward passes needed for generation.

| Metric | Measures | Lower/Higher is Better | Computation Cost |
|---|---|---|---|
| FID | Distribution distance | Lower | Medium (50k samples) |
| IS | Quality + Diversity | Higher | Low (5-10k samples) |
| Precision | Sample quality/fidelity | Higher | Medium |
| Recall | Mode coverage/diversity | Higher | Medium |
| LPIPS | Perceptual similarity | Higher for diversity | Low |

Frechet Inception Distance (FID)

Understanding FID

FID measures the distance between the distribution of generated images and real images in the feature space of an Inception network. It compares the mean and covariance of the two distributions:

$$\text{FID} = \|\mu_r - \mu_g\|^2 + \text{Tr}\left(\Sigma_r + \Sigma_g - 2\sqrt{\Sigma_r \Sigma_g}\right)$$

where $\mu_r, \Sigma_r$ are the mean and covariance of real image features, and $\mu_g, \Sigma_g$ are those of generated images.
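Before building the full pipeline, the formula is easy to sanity-check on Gaussians with known parameters. In one dimension it reduces to $(\mu_r - \mu_g)^2 + (\sigma_r - \sigma_g)^2$; the sketch below is a standalone NumPy/SciPy check, independent of the classes that follow:

```python
import numpy as np
from scipy import linalg

# 1-D Gaussians: N(0, 4) vs N(1, 1). The FID formula reduces to
# (mu_r - mu_g)^2 + (sigma_r - sigma_g)^2 = 1 + (2 - 1)^2 = 2.
mu_r, mu_g = np.array([0.0]), np.array([1.0])
sigma_r, sigma_g = np.array([[4.0]]), np.array([[1.0]])

covmean = linalg.sqrtm(sigma_r @ sigma_g)  # sqrt(4 * 1) = 2
diff = mu_r - mu_g
fid = float(diff @ diff + np.trace(sigma_r) + np.trace(sigma_g)
            - 2 * np.trace(covmean))
print(fid)  # 2.0
```

The same four-term computation, applied to 2048-dimensional Inception features, is exactly what the calculator below implements.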

```python
import torch
import torch.nn as nn
import numpy as np
from scipy import linalg
from torchvision.models import inception_v3, Inception_V3_Weights
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Optional, Tuple

class FIDCalculator:
    """Frechet Inception Distance calculator."""

    def __init__(
        self,
        device: str = "cuda",
        dims: int = 2048,
    ):
        self.device = device
        self.dims = dims

        # Load Inception v3
        self.inception = inception_v3(
            weights=Inception_V3_Weights.IMAGENET1K_V1,
            transform_input=False,
        )
        # Remove final classification layer
        self.inception.fc = nn.Identity()
        self.inception = self.inception.to(device).eval()

        # Preprocessing
        self.resize = nn.Upsample(size=(299, 299), mode="bilinear", align_corners=False)

    def _preprocess(self, images: torch.Tensor) -> torch.Tensor:
        """Preprocess images for Inception."""
        # Resize to 299x299
        if images.shape[-1] != 299 or images.shape[-2] != 299:
            images = self.resize(images)

        # Normalize from [-1, 1] to ImageNet normalization
        # First convert to [0, 1]
        images = (images + 1) / 2

        # Apply ImageNet normalization
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device)
        images = (images - mean[None, :, None, None]) / std[None, :, None, None]

        return images

    @torch.no_grad()
    def extract_features(
        self,
        images: torch.Tensor,
    ) -> np.ndarray:
        """Extract Inception features from a batch of images."""
        images = self._preprocess(images.to(self.device))
        features = self.inception(images)
        return features.cpu().numpy()

    @torch.no_grad()
    def extract_features_from_loader(
        self,
        loader: DataLoader,
        num_samples: Optional[int] = None,
    ) -> np.ndarray:
        """Extract features from a DataLoader."""
        features_list = []
        total = 0

        for batch in tqdm(loader, desc="Extracting features"):
            if isinstance(batch, (list, tuple)):
                batch = batch[0]  # Handle (images, labels) format

            features = self.extract_features(batch)
            features_list.append(features)

            total += batch.shape[0]
            if num_samples and total >= num_samples:
                break

        features = np.concatenate(features_list, axis=0)
        if num_samples:
            features = features[:num_samples]

        return features

    @staticmethod
    def compute_statistics(
        features: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute mean and covariance of features."""
        mu = np.mean(features, axis=0)
        sigma = np.cov(features, rowvar=False)
        return mu, sigma

    @staticmethod
    def compute_fid_from_statistics(
        mu1: np.ndarray,
        sigma1: np.ndarray,
        mu2: np.ndarray,
        sigma2: np.ndarray,
    ) -> float:
        """Compute FID from precomputed statistics."""
        # Compute mean difference
        diff = mu1 - mu2
        diff_squared = np.dot(diff, diff)

        # Compute sqrt of product of covariances
        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)

        # Handle numerical issues
        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError(f"Imaginary component {m}")
            covmean = covmean.real

        # Compute FID
        tr_covmean = np.trace(covmean)
        fid = diff_squared + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

        return float(fid)

    def compute_fid(
        self,
        real_loader: DataLoader,
        generated_images: torch.Tensor,
        num_samples: int = 50000,
    ) -> float:
        """Compute FID between real and generated images.

        Args:
            real_loader: DataLoader for real images
            generated_images: Tensor of generated images
            num_samples: Number of samples to use (affects accuracy)

        Returns:
            FID score (lower is better)
        """
        # Extract features
        print("Extracting features from real images...")
        real_features = self.extract_features_from_loader(
            real_loader, num_samples
        )

        print("Extracting features from generated images...")
        gen_features = []
        batch_size = 64
        for i in tqdm(range(0, len(generated_images), batch_size)):
            batch = generated_images[i:i + batch_size]
            gen_features.append(self.extract_features(batch))
        gen_features = np.concatenate(gen_features, axis=0)

        # Compute statistics
        mu_real, sigma_real = self.compute_statistics(real_features)
        mu_gen, sigma_gen = self.compute_statistics(gen_features)

        # Compute FID
        fid = self.compute_fid_from_statistics(
            mu_real, sigma_real,
            mu_gen, sigma_gen
        )

        return fid


class PrecomputedFID:
    """FID computation with cached real image statistics."""

    def __init__(
        self,
        stats_path: Optional[str] = None,
        device: str = "cuda",
    ):
        self.calculator = FIDCalculator(device=device)
        self.real_mu = None
        self.real_sigma = None

        if stats_path:
            self.load_statistics(stats_path)

    def compute_and_cache_real_stats(
        self,
        real_loader: DataLoader,
        num_samples: int = 50000,
        save_path: Optional[str] = None,
    ):
        """Compute and optionally save real image statistics."""
        features = self.calculator.extract_features_from_loader(
            real_loader, num_samples
        )
        self.real_mu, self.real_sigma = self.calculator.compute_statistics(features)

        if save_path:
            np.savez(
                save_path,
                mu=self.real_mu,
                sigma=self.real_sigma,
            )
            print(f"Saved statistics to {save_path}")

    def load_statistics(self, path: str):
        """Load precomputed statistics."""
        data = np.load(path)
        self.real_mu = data["mu"]
        self.real_sigma = data["sigma"]
        print(f"Loaded statistics from {path}")

    def compute_fid(
        self,
        generated_images: torch.Tensor,
    ) -> float:
        """Compute FID using cached real statistics."""
        if self.real_mu is None:
            raise ValueError("Real statistics not loaded. Call compute_and_cache_real_stats first.")

        # Extract generated features
        gen_features = []
        batch_size = 64
        for i in range(0, len(generated_images), batch_size):
            batch = generated_images[i:i + batch_size]
            gen_features.append(self.calculator.extract_features(batch))
        gen_features = np.concatenate(gen_features, axis=0)

        mu_gen, sigma_gen = self.calculator.compute_statistics(gen_features)

        return self.calculator.compute_fid_from_statistics(
            self.real_mu, self.real_sigma,
            mu_gen, sigma_gen
        )
```

FID Sample Size

FID accuracy depends heavily on sample size. Use at least 10,000 samples for reasonable estimates and 50,000 for publication-quality results. FID values are not comparable across different sample sizes.
| Dataset | SOTA FID (Diffusion) | Previous SOTA (GAN) |
|---|---|---|
| CIFAR-10 (32x32) | 1.97 (EDM) | 2.92 (StyleGAN2-ADA) |
| ImageNet 64x64 | 1.55 (EDM) | 2.10 (BigGAN-deep) |
| ImageNet 256x256 | 1.79 (DiT) | 3.60 (BigGAN-deep) |
| LSUN Bedroom | 1.90 (LDM) | 2.35 (StyleGAN2) |
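The sample-size caveat above is easy to demonstrate: even when two feature sets come from the same distribution (true FID = 0), the finite-sample estimate is biased upward, and the bias shrinks as the sample count grows. A standalone NumPy sketch, using synthetic 64-dimensional features as a toy stand-in for Inception activations:

```python
import numpy as np
from scipy import linalg

def fid_from_features(a: np.ndarray, b: np.ndarray) -> float:
    """FID between two feature sets via the Gaussian approximation."""
    mu_a, mu_b = a.mean(axis=0), b.mean(axis=0)
    sigma_a = np.cov(a, rowvar=False)
    sigma_b = np.cov(b, rowvar=False)
    covmean = linalg.sqrtm(sigma_a @ sigma_b)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard numerical noise
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(sigma_a) + np.trace(sigma_b)
                 - 2 * np.trace(covmean))

rng = np.random.default_rng(0)
dims = 64  # toy stand-in for Inception's 2048-d features
estimates = {}
for n in (500, 2000, 10000):
    a = rng.standard_normal((n, dims))
    b = rng.standard_normal((n, dims))
    estimates[n] = fid_from_features(a, b)
    print(f"n={n:>5}: estimated FID = {estimates[n]:.3f}")  # true FID is 0
```

The estimate stays strictly positive and falls as n grows, which is why FID values computed at different sample sizes cannot be compared directly.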

Inception Score (IS)

Understanding IS

Inception Score measures both quality and diversity using a pre-trained Inception classifier. It computes the KL divergence between the conditional class distribution and the marginal class distribution:

$$\text{IS} = \exp\left(\mathbb{E}_{x \sim p_g} \left[D_{KL}\left(p(y \mid x) \,\|\, p(y)\right)\right]\right)$$

A high IS means that each image is confidently classified (quality) and the overall class distribution is diverse (diversity).

```python
import torch
import torch.nn as nn
import numpy as np
from torchvision.models import inception_v3, Inception_V3_Weights
from scipy.stats import entropy
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Tuple

class InceptionScoreCalculator:
    """Inception Score calculator."""

    def __init__(
        self,
        device: str = "cuda",
        splits: int = 10,
    ):
        self.device = device
        self.splits = splits

        # Load Inception v3 with classification head
        self.inception = inception_v3(
            weights=Inception_V3_Weights.IMAGENET1K_V1,
            transform_input=False,
        )
        self.inception = self.inception.to(device).eval()

        # Preprocessing
        self.resize = nn.Upsample(size=(299, 299), mode="bilinear", align_corners=False)

    def _preprocess(self, images: torch.Tensor) -> torch.Tensor:
        """Preprocess images for Inception."""
        if images.shape[-1] != 299:
            images = self.resize(images)

        # Normalize to [0, 1] then to ImageNet stats
        images = (images + 1) / 2
        mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)
        std = torch.tensor([0.229, 0.224, 0.225], device=images.device)
        images = (images - mean[None, :, None, None]) / std[None, :, None, None]

        return images

    @torch.no_grad()
    def get_predictions(
        self,
        images: torch.Tensor,
    ) -> np.ndarray:
        """Get softmax predictions from Inception."""
        images = self._preprocess(images.to(self.device))
        logits = self.inception(images)
        probs = torch.softmax(logits, dim=1)
        return probs.cpu().numpy()

    def compute_inception_score(
        self,
        images: torch.Tensor,
        batch_size: int = 64,
    ) -> Tuple[float, float]:
        """Compute Inception Score.

        Args:
            images: Generated images tensor
            batch_size: Batch size for processing

        Returns:
            (mean IS, std IS) across splits
        """
        # Get all predictions
        preds = []
        for i in tqdm(range(0, len(images), batch_size), desc="Computing IS"):
            batch = images[i:i + batch_size]
            preds.append(self.get_predictions(batch))
        preds = np.concatenate(preds, axis=0)

        # Compute IS for each split
        scores = []
        split_size = len(preds) // self.splits

        for i in range(self.splits):
            start = i * split_size
            end = start + split_size
            split_preds = preds[start:end]

            # p(y|x) for each image
            py_given_x = split_preds

            # p(y) = average of p(y|x) over all images
            py = np.mean(py_given_x, axis=0, keepdims=True)

            # KL divergence for each image
            kl_divs = []
            for p in py_given_x:
                kl = entropy(p, py[0])
                if not np.isnan(kl) and not np.isinf(kl):
                    kl_divs.append(kl)

            # IS = exp(mean KL)
            if kl_divs:
                scores.append(np.exp(np.mean(kl_divs)))

        return float(np.mean(scores)), float(np.std(scores))


def compute_is_from_loader(
    calculator: InceptionScoreCalculator,
    loader: DataLoader,
    num_samples: int = 50000,
) -> Tuple[float, float]:
    """Compute IS from a DataLoader of generated images."""
    all_images = []
    total = 0

    for batch in loader:
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        all_images.append(batch)
        total += batch.shape[0]
        if total >= num_samples:
            break

    images = torch.cat(all_images, dim=0)[:num_samples]
    return calculator.compute_inception_score(images)
```
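Two toy cases make the definition concrete (plain NumPy/SciPy, not using the class above): perfectly confident predictions spread evenly over C classes give IS = C, while a generator collapsed onto a single class gives IS = 1 no matter how confident the classifier is.

```python
import numpy as np
from scipy.stats import entropy

def inception_score(preds: np.ndarray) -> float:
    """IS from an (N, num_classes) matrix of softmax predictions."""
    py = preds.mean(axis=0)                # marginal p(y)
    kls = [entropy(p, py) for p in preds]  # KL(p(y|x) || p(y))
    return float(np.exp(np.mean(kls)))

# Case 1: confident and diverse -- one-hot predictions cycling over 4 classes
diverse = np.eye(4)[np.arange(400) % 4]
is_diverse = inception_score(diverse)
print(is_diverse)    # 4.0 (= number of classes)

# Case 2: mode collapse -- every image classified as class 0
collapsed = np.tile(np.eye(4)[0], (400, 1))
is_collapsed = inception_score(collapsed)
print(is_collapsed)  # 1.0
```

IS therefore ranges from 1 to the number of classes, which is why CIFAR-scale models report values near 10 (the ImageNet-1k upper bound is 1000, but no model approaches it).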

IS Limitations

Inception Score has known limitations: it only works well for ImageNet-like images, ignores intra-class diversity, and can be gamed by mode dropping. Always use IS together with FID, never alone.

Precision and Recall

Understanding the Tradeoff

Precision and Recall for generative models capture the quality-diversity tradeoff that FID conflates into a single number:

  • Precision: What fraction of generated samples are realistic (fall within the real data manifold)?
  • Recall: What fraction of the real data distribution is covered by generated samples?
```python
import numpy as np
from typing import Tuple
from sklearn.neighbors import NearestNeighbors

class PrecisionRecallCalculator:
    """Improved Precision and Recall for generative models.

    Based on "Improved Precision and Recall Metric for Assessing
    Generative Models" (Kynkäänniemi et al., 2019).
    """

    def __init__(
        self,
        feature_extractor,  # e.g., FIDCalculator
        k: int = 3,
        device: str = "cuda",
    ):
        self.feature_extractor = feature_extractor
        self.k = k
        self.device = device

    def compute_manifold(
        self,
        features: np.ndarray,
    ) -> NearestNeighbors:
        """Compute k-NN manifold representation."""
        nn = NearestNeighbors(n_neighbors=self.k + 1, algorithm="auto")
        nn.fit(features)
        return nn

    def compute_precision_recall(
        self,
        real_features: np.ndarray,
        fake_features: np.ndarray,
    ) -> Tuple[float, float]:
        """Compute precision and recall.

        Args:
            real_features: Features of real images
            fake_features: Features of generated images

        Returns:
            (precision, recall)
        """
        # Build manifolds
        print("Building real manifold...")
        real_nn = self.compute_manifold(real_features)

        print("Building fake manifold...")
        fake_nn = self.compute_manifold(fake_features)

        # Radii: distance to the k-th nearest neighbor. When querying the
        # fitted set, each point is returned as its own nearest neighbor,
        # so the k-th true neighbor is the last of the k+1 results.
        real_distances, _ = real_nn.kneighbors(real_features)
        real_radii = real_distances[:, -1]

        fake_distances, _ = fake_nn.kneighbors(fake_features)
        fake_radii = fake_distances[:, -1]

        # Precision: fraction of fake samples in the real manifold.
        # A fake sample counts if its distance to its nearest real sample
        # is <= that real sample's radius.
        print("Computing precision...")
        fake_to_real_distances, real_indices = real_nn.kneighbors(
            fake_features, n_neighbors=1
        )
        fake_to_real_nearest = fake_to_real_distances[:, 0]
        real_indices = real_indices.flatten()
        precision = float(
            np.mean(fake_to_real_nearest <= real_radii[real_indices])
        )

        # Recall: fraction of real samples covered by the fake manifold.
        print("Computing recall...")
        real_to_fake_distances, fake_indices = fake_nn.kneighbors(
            real_features, n_neighbors=1
        )
        real_to_fake_nearest = real_to_fake_distances.flatten()
        fake_indices = fake_indices.flatten()
        recall = float(
            np.mean(real_to_fake_nearest <= fake_radii[fake_indices])
        )

        return precision, recall

    def compute_density_coverage(
        self,
        real_features: np.ndarray,
        fake_features: np.ndarray,
    ) -> Tuple[float, float]:
        """Compute Density and Coverage metrics (Naeem et al., 2020).

        Alternative to precision/recall that's less sensitive to outliers.
        """
        # Build real manifold and radii
        real_nn = self.compute_manifold(real_features)
        real_distances, _ = real_nn.kneighbors(real_features)
        real_radii = real_distances[:, -1]

        # Density: for each fake sample, count the real samples whose own
        # k-NN ball contains it (approximated over its k+1 nearest reals),
        # normalized by k.
        fake_to_real_distances, real_indices = real_nn.kneighbors(fake_features)
        density_sum = 0
        for i in range(len(fake_features)):
            density_sum += np.sum(
                fake_to_real_distances[i] <= real_radii[real_indices[i]]
            )
        density = density_sum / (self.k * len(fake_features))

        # Coverage: fraction of real samples whose own k-NN ball contains
        # at least one fake sample.
        fake_1nn = NearestNeighbors(n_neighbors=1).fit(fake_features)
        real_to_fake_distances, _ = fake_1nn.kneighbors(real_features)
        coverage = float(np.mean(real_to_fake_distances[:, 0] <= real_radii))

        return float(density), coverage


def compute_full_precision_recall(
    real_loader,
    generated_images,
    fid_calculator: FIDCalculator,  # from the FID section above
    k: int = 3,
    num_samples: int = 10000,
) -> dict:
    """Compute full precision/recall metrics."""
    # Extract features
    real_features = fid_calculator.extract_features_from_loader(
        real_loader, num_samples
    )

    gen_features = []
    batch_size = 64
    for i in range(0, min(len(generated_images), num_samples), batch_size):
        batch = generated_images[i:i + batch_size]
        gen_features.append(fid_calculator.extract_features(batch))
    gen_features = np.concatenate(gen_features, axis=0)

    # Compute metrics
    pr_calc = PrecisionRecallCalculator(fid_calculator, k=k)

    precision, recall = pr_calc.compute_precision_recall(
        real_features, gen_features
    )

    density, coverage = pr_calc.compute_density_coverage(
        real_features, gen_features
    )

    return {
        "precision": precision,
        "recall": recall,
        "density": density,
        "coverage": coverage,
        "f1": 2 * precision * recall / (precision + recall + 1e-8),
    }
```
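A synthetic 2-D example shows how these metrics separate failure modes: a mode-collapsed generator that samples only one of two real modes scores high precision but roughly 50% recall. The sketch below re-implements the same nearest-neighbor test in a self-contained form (plain NumPy + scikit-learn), so it runs without a feature extractor; the data and helper names are illustrative, not part of the metric's reference implementation.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

def knn_radii(x, k=3):
    # distance to each point's k-th nearest neighbor (excluding itself)
    nn = NearestNeighbors(n_neighbors=k + 1).fit(x)
    d, _ = nn.kneighbors(x)
    return nn, d[:, -1]

def precision_recall(real, fake, k=3):
    real_nn, real_r = knn_radii(real, k)
    fake_nn, fake_r = knn_radii(fake, k)
    # precision: fake point falls inside its nearest real point's ball
    d, idx = real_nn.kneighbors(fake, n_neighbors=1)
    precision = float(np.mean(d[:, 0] <= real_r[idx[:, 0]]))
    # recall: real point falls inside its nearest fake point's ball
    d, idx = fake_nn.kneighbors(real, n_neighbors=1)
    recall = float(np.mean(d[:, 0] <= fake_r[idx[:, 0]]))
    return precision, recall

rng = np.random.default_rng(0)
# real data: two well-separated 2-D modes
real = np.concatenate([rng.normal(0, 0.1, (500, 2)),
                       rng.normal(3, 0.1, (500, 2))])
# "mode-collapsed" generator: samples only the first mode
fake = rng.normal(0, 0.1, (1000, 2))
p, r = precision_recall(real, fake)
print(f"precision={p:.2f}, recall={r:.2f}")
# precision is high (samples look real); recall sits near 0.5
# because the second mode is never covered
```

Swapping in a generator that covers both modes pushes recall back up, which is exactly the fidelity-diversity separation the table below summarizes.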

| Model Characteristic | High Precision | High Recall |
|---|---|---|
| Mode collapse | Yes (limited diversity) | No (misses modes) |
| Good quality | Yes (realistic samples) | Maybe |
| Good diversity | Maybe | Yes (covers all modes) |
| Overfitting | Yes (copying training data) | Yes (covers training data) |

Complete Evaluation Pipeline

```python
import torch
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, Dict, Any
import json
from datetime import datetime
from tqdm import tqdm

# Assumes PrecomputedFID, InceptionScoreCalculator, and
# PrecisionRecallCalculator from the earlier sections.

@dataclass
class EvaluationResults:
    """Container for all evaluation metrics."""
    fid: float
    is_mean: float
    is_std: float
    precision: float
    recall: float
    density: float
    coverage: float
    num_samples: int
    timestamp: str
    model_name: str
    additional_info: Dict[str, Any]

    def to_dict(self) -> dict:
        return {
            "fid": self.fid,
            "inception_score": {"mean": self.is_mean, "std": self.is_std},
            "precision": self.precision,
            "recall": self.recall,
            "density": self.density,
            "coverage": self.coverage,
            "num_samples": self.num_samples,
            "timestamp": self.timestamp,
            "model_name": self.model_name,
            **self.additional_info,
        }

    def save(self, path: str):
        with open(path, "w") as f:
            json.dump(self.to_dict(), f, indent=2)


class ModelEvaluator:
    """Complete evaluation pipeline for generative models."""

    def __init__(
        self,
        real_stats_path: Optional[str] = None,
        device: str = "cuda",
        num_samples: int = 50000,
    ):
        self.device = device
        self.num_samples = num_samples
        self.real_stats_path = real_stats_path  # also locates cached features

        # Initialize calculators
        self.fid_calc = PrecomputedFID(stats_path=real_stats_path, device=device)
        self.is_calc = InceptionScoreCalculator(device=device)
        self.pr_calc = PrecisionRecallCalculator(self.fid_calc.calculator, device=device)

    def precompute_real_stats(
        self,
        real_loader,
        save_path: str,
    ):
        """Precompute and save real data statistics."""
        print("Computing real data statistics...")
        self.fid_calc.compute_and_cache_real_stats(
            real_loader,
            num_samples=self.num_samples,
            save_path=save_path,
        )
        self.real_stats_path = save_path

        # Also save features for precision/recall
        features = self.fid_calc.calculator.extract_features_from_loader(
            real_loader, self.num_samples
        )
        np.save(save_path.replace(".npz", "_features.npy"), features)

    def evaluate(
        self,
        generated_images: torch.Tensor,
        model_name: str = "unknown",
        compute_all: bool = True,
        additional_info: Optional[Dict[str, Any]] = None,
    ) -> EvaluationResults:
        """Run full evaluation on generated images.

        Args:
            generated_images: Tensor of generated images
            model_name: Name for logging
            compute_all: Whether to compute all metrics
            additional_info: Additional metadata to include

        Returns:
            EvaluationResults with all metrics
        """
        print(f"Evaluating {len(generated_images)} generated images...")

        results = {
            "fid": 0.0,
            "is_mean": 0.0,
            "is_std": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "density": 0.0,
            "coverage": 0.0,
        }

        # Compute FID
        print("\nComputing FID...")
        results["fid"] = self.fid_calc.compute_fid(generated_images)
        print(f"FID: {results['fid']:.2f}")

        # Compute Inception Score
        print("\nComputing Inception Score...")
        results["is_mean"], results["is_std"] = self.is_calc.compute_inception_score(
            generated_images
        )
        print(f"IS: {results['is_mean']:.2f} +/- {results['is_std']:.2f}")

        if compute_all and self.real_stats_path:
            # Compute Precision/Recall
            print("\nComputing Precision/Recall...")

            # Extract generated features
            gen_features = []
            batch_size = 64
            for i in range(0, len(generated_images), batch_size):
                batch = generated_images[i:i + batch_size]
                gen_features.append(
                    self.fid_calc.calculator.extract_features(batch)
                )
            gen_features = np.concatenate(gen_features, axis=0)

            # Load cached real features (saved by precompute_real_stats)
            real_features_path = self.real_stats_path.replace(
                ".npz", "_features.npy"
            )
            if Path(real_features_path).exists():
                real_features = np.load(real_features_path)

                pr_results = self.pr_calc.compute_precision_recall(
                    real_features[:len(gen_features)],
                    gen_features
                )
                results["precision"], results["recall"] = pr_results

                dc_results = self.pr_calc.compute_density_coverage(
                    real_features[:len(gen_features)],
                    gen_features
                )
                results["density"], results["coverage"] = dc_results

                print(f"Precision: {results['precision']:.4f}")
                print(f"Recall: {results['recall']:.4f}")
                print(f"Density: {results['density']:.4f}")
                print(f"Coverage: {results['coverage']:.4f}")

        return EvaluationResults(
            fid=results["fid"],
            is_mean=results["is_mean"],
            is_std=results["is_std"],
            precision=results["precision"],
            recall=results["recall"],
            density=results["density"],
            coverage=results["coverage"],
            num_samples=len(generated_images),
            timestamp=datetime.now().isoformat(),
            model_name=model_name,
            additional_info=additional_info or {},
        )


def evaluate_model_checkpoint(
    checkpoint_path: str,
    model_class,
    sampler_class,
    real_stats_path: str,
    output_dir: str,
    num_samples: int = 50000,
    num_steps: int = 50,
):
    """Evaluate a saved model checkpoint."""
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load model
    print(f"Loading model from {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    model = model_class(**checkpoint["config"])
    model.load_state_dict(checkpoint["model_state_dict"])
    model = model.cuda().eval()

    # Create sampler
    sampler = sampler_class(model, checkpoint["alphas_cumprod"])

    # Generate samples
    print(f"Generating {num_samples} samples...")
    generated = []
    batch_size = 64
    image_shape = (3, 64, 64)  # adjust to your model

    for i in tqdm(range(0, num_samples, batch_size)):
        current_batch = min(batch_size, num_samples - i)
        samples = sampler.sample(
            (current_batch, *image_shape),
            num_steps=num_steps
        )
        generated.append(samples.cpu())

    generated = torch.cat(generated, dim=0)

    # Evaluate
    evaluator = ModelEvaluator(
        real_stats_path=real_stats_path,
        num_samples=num_samples,
    )

    results = evaluator.evaluate(
        generated,
        model_name=Path(checkpoint_path).stem,
        additional_info={
            "checkpoint": checkpoint_path,
            "num_steps": num_steps,
        }
    )

    # Save results
    timestamp = results.timestamp.replace(":", "-")  # filesystem-safe name
    results_path = output_dir / f"eval_{results.model_name}_{timestamp}.json"
    results.save(str(results_path))
    print(f"\nResults saved to {results_path}")

    # Print summary
    print("\n" + "=" * 50)
    print("EVALUATION SUMMARY")
    print("=" * 50)
    print(f"Model: {results.model_name}")
    print(f"FID: {results.fid:.2f}")
    print(f"IS: {results.is_mean:.2f} +/- {results.is_std:.2f}")
    print(f"Precision: {results.precision:.4f}")
    print(f"Recall: {results.recall:.4f}")
    print("=" * 50)

    return results
```

Key Takeaways

  1. FID is the primary metric: Lower is better. Use 50k samples for reliable comparisons and always report sample count.
  2. IS measures quality + diversity: Higher is better, but only works well for ImageNet-like images.
  3. Precision and Recall separate concerns: Precision measures quality (are samples realistic?), Recall measures coverage (are all modes represented?).
  4. Cache real statistics: Precompute and save Inception features for faster evaluation iterations.
  5. Use multiple metrics: No single metric captures all aspects of generation quality. Report FID, IS, and Precision/Recall together.

Looking Ahead: Quantitative metrics only tell part of the story. The next section covers qualitative analysis techniques for understanding what your model has learned and identifying failure modes.