Chapter 16

Model Acceleration

Optimization and Deployment

Learning Objectives

By the end of this section, you will be able to:

  1. Understand why diffusion models are slow and identify the computational bottlenecks in the sampling process
  2. Implement knowledge distillation to train student models that require fewer sampling steps
  3. Apply progressive distillation to reduce sampling from 1000 steps to as few as 4 steps
  4. Understand consistency models and their ability to perform single-step generation
  5. Use Latent Consistency Models (LCM) for fast, high-quality image generation

The Speed Bottleneck

Diffusion models produce stunning results, but they have a fundamental speed problem. While a GAN generates an image in a single forward pass, a diffusion model requires hundreds to thousands of sequential denoising steps. Each step involves a full neural network forward pass through a large U-Net.

| Model Type | Steps Required | Typical Latency | GPU Memory |
|---|---|---|---|
| DDPM (original) | 1000 | ~60 seconds | ~8 GB |
| DDIM (accelerated) | 50-100 | ~3-6 seconds | ~8 GB |
| Stable Diffusion 1.5 | 20-50 | ~2-5 seconds | ~6 GB |
| LCM-LoRA | 4-8 | ~0.5-1 second | ~6 GB |
| SDXL Turbo | 1-4 | ~0.2-0.5 seconds | ~12 GB |

The Core Problem: Sequential dependencies in the denoising process prevent parallelization. Each step $x_{t-1}$ depends on the previous output $x_t$, so we cannot compute multiple steps simultaneously.
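The bottleneck is visible in the sampling loop itself. A minimal sketch, with hypothetical `model` and `step_fn` stand-ins for the U-Net and the scheduler update: each iterate feeds the next, so the loop is inherently serial.

```python
import torch

def sample(model, step_fn, timesteps, shape):
    """Sequential ancestral sampling: x_{t-1} depends on x_t, so the
    len(timesteps) network forward passes cannot run in parallel."""
    x = torch.randn(shape)         # start from pure noise x_T
    for t in timesteps:            # T, T-1, ..., 1
        eps = model(x, t)          # one full U-Net forward pass per step
        x = step_fn(x, eps, t)     # compute x_{t-1} from x_t
    return x
```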

Why Reducing Steps Is Hard

Naively using fewer DDPM steps dramatically degrades quality. The model learns the denoising distribution for adjacent timesteps, so when we skip steps we ask it to remove far more noise per step than it ever saw during training, violating the assumptions it was trained under.
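A quick numeric check (pure PyTorch, standard DDPM linear beta schedule) makes this concrete: a 20-step jump implies an effective per-step noise level more than an order of magnitude larger than any single-step transition the model was trained on.

```python
import torch

# Standard DDPM linear beta schedule over 1000 steps
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_bar = torch.cumprod(1 - betas, dim=0)

t = 500
# Effective noise of a transition x_t -> x_{t-k}: 1 - alpha_bar_t / alpha_bar_{t-k}
adjacent = 1 - alphas_bar[t] / alphas_bar[t - 1]    # k=1: what training saw
skipped = 1 - alphas_bar[t] / alphas_bar[t - 20]    # k=20: naive step-skipping
print(f"k=1:  {adjacent:.4f}")   # ~0.01
print(f"k=20: {skipped:.4f}")    # ~0.18, over an order of magnitude larger
```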

The key insight is that we need to retrain or distill models specifically for fewer-step generation. This leads us to the family of distillation techniques.


Knowledge Distillation

Knowledge distillation transfers the knowledge from a slow, high-quality "teacher" model to a fast "student" model. For diffusion models, the teacher uses many steps while the student uses fewer steps.

The Basic Approach

Instead of training the student to predict noise directly from data, we train it to match the teacher's predictions:

```python
import copy
import torch
import torch.nn as nn
from diffusers import DDPMScheduler, UNet2DModel

class DiffusionDistillation:
    """Knowledge distillation for diffusion models."""

    def __init__(self, teacher_model, student_model, teacher_steps=1000, student_steps=100):
        self.teacher = teacher_model.eval()  # Freeze teacher
        self.student = student_model

        self.teacher_scheduler = DDPMScheduler(num_train_timesteps=teacher_steps)
        self.student_scheduler = DDPMScheduler(num_train_timesteps=student_steps)

        # Each student step spans this many teacher steps
        self.step_ratio = teacher_steps // student_steps

    def distillation_loss(self, x_0, student_t):
        """Compute distillation loss at a given student timestep."""
        # Map the student timestep to the corresponding teacher timestep
        teacher_t = student_t * self.step_ratio

        # Add noise at the teacher timestep
        noise = torch.randn_like(x_0)
        x_t = self.teacher_scheduler.add_noise(x_0, noise, teacher_t)

        # Teacher takes step_ratio fine-grained steps (no gradients needed)
        with torch.no_grad():
            x_target = x_t
            for t in range(int(teacher_t), int(teacher_t) - self.step_ratio, -1):
                pred = self.teacher(x_target, t).sample
                x_target = self.teacher_scheduler.step(pred, t, x_target).prev_sample

        # Student takes ONE step from the same noisy input
        student_pred = self.student(x_t, student_t).sample
        x_student = self.student_scheduler.step(student_pred, student_t, x_t).prev_sample

        # Distillation loss: the student's single step should match
        # the teacher's multi-step result
        return nn.functional.mse_loss(x_student, x_target)
```

Teacher-Student Gap

The key insight is that the student learns to "jump ahead" by multiple steps at once. Instead of learning the precise single-step denoising, it learns the aggregate effect of multiple teacher steps.
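In symbols (my notation, not any paper's): writing $\Phi$ for one teacher denoising update and $k$ for the step ratio, the student's single update $\Phi^{\text{student}}_\theta$ is trained to reproduce the composition of $k$ teacher updates:

```latex
x^{\text{target}}_{t-k} = \underbrace{(\Phi \circ \cdots \circ \Phi)}_{k \text{ times}}(x_t),
\qquad
\mathcal{L}(\theta) = \left\| \Phi^{\text{student}}_\theta(x_t) - x^{\text{target}}_{t-k} \right\|_2^2 .
```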

Progressive Distillation

Progressive distillation (Salimans & Ho, 2022) iteratively halves the number of required sampling steps. Starting from a 1024-step model, we distill to 512 steps, then 256, 128, 64, 32, 16, 8, and finally 4.
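The halving schedule is trivial to enumerate (a throwaway helper, matching the `initial_steps=1024` default in the listing that follows):

```python
def halving_schedule(start: int = 1024, stop: int = 4) -> list[int]:
    """Step counts visited by progressive distillation."""
    steps = [start]
    while steps[-1] > stop:
        steps.append(steps[-1] // 2)
    return steps

print(halving_schedule())  # [1024, 512, 256, 128, 64, 32, 16, 8, 4]
```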

The Algorithm

At each stage, the student learns to match the teacher's output after two steps in a single step:

```python
import copy
import torch
import torch.nn.functional as F

class ProgressiveDistillation:
    """Progressive distillation: halve the step count at each stage."""

    # Note: self.add_noise and self.denoise_step are scheduler helpers
    # (forward diffusion / one denoising update), omitted for brevity.

    def __init__(self, model, initial_steps=1024, device="cuda"):
        self.model = model
        self.current_steps = initial_steps
        self.device = device

    def halve_steps(self, dataloader, num_epochs=100):
        """Distill to half the current number of steps."""
        target_steps = self.current_steps // 2
        print(f"Distilling: {self.current_steps} -> {target_steps} steps")

        # Clone model: teacher = frozen copy, student = trained in place
        teacher = copy.deepcopy(self.model).eval()
        student = self.model

        optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)

        for epoch in range(num_epochs):
            for batch in dataloader:
                x_0 = batch["images"].to(self.device)

                # Sample a student timestep (even teacher timesteps only)
                student_t = torch.randint(0, target_steps, (x_0.shape[0],),
                                          device=self.device) * 2

                # Add noise
                noise = torch.randn_like(x_0)
                x_t = self.add_noise(x_0, noise, student_t)

                # Teacher takes TWO steps
                with torch.no_grad():
                    # First teacher step: t -> t-1
                    teacher_pred_1 = teacher(x_t, student_t).sample
                    x_t_minus_1 = self.denoise_step(x_t, teacher_pred_1, student_t)

                    # Second teacher step: t-1 -> t-2
                    teacher_pred_2 = teacher(x_t_minus_1, student_t - 1).sample
                    x_target = self.denoise_step(x_t_minus_1, teacher_pred_2, student_t - 1)

                # Student takes ONE step (stride 2) to match the teacher's two
                student_pred = student(x_t, student_t).sample
                x_student = self.denoise_step(x_t, student_pred, student_t, stride=2)

                # Loss: match the denoised outputs
                loss = F.mse_loss(x_student, x_target)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        self.current_steps = target_steps
        return student

    def distill_to_n_steps(self, dataloader, target_steps=4):
        """Progressively halve until reaching target_steps."""
        while self.current_steps > target_steps:
            self.halve_steps(dataloader)
        return self.model
```

Results and Trade-offs

| Distillation Stage | Steps | FID | Quality Notes |
|---|---|---|---|
| Original DDPM | 1000 | ~3.0 | Reference quality |
| Stage 1 | 512 | ~3.1 | Minimal degradation |
| Stage 2 | 256 | ~3.3 | Still high quality |
| Stage 3 | 128 | ~3.8 | Slight softening |
| Stage 4 | 64 | ~4.5 | Noticeable but acceptable |
| Stage 5 | 32 | ~5.8 | Some detail loss |
| Stage 6 | 16 | ~7.5 | Clearly degraded |
| Stage 7 | 8 | ~10.2 | Significant quality loss |
| Stage 8 | 4 | ~15.0 | Usable for previews |

Consistency Models

Consistency Models (Song et al., 2023) take a fundamentally different approach. Instead of learning the step-by-step denoising process, they learn to map any point on the diffusion trajectory directly to the clean data.

The Consistency Property

The key idea is the self-consistency property: any two points on the same diffusion trajectory should map to the same clean image. Mathematically, for any $t, t' \in [0, T]$ on the same trajectory:

$$f_\theta(x_t, t) = f_\theta(x_{t'}, t')$$

At $t = 0$, this should equal $x_0$ (the clean data).
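A two-line sanity check of that boundary condition, using the same skip/out coefficients as the parameterization in the listing that follows: at $t = 0$ the skip path passes the input through untouched.

```python
import torch

c_skip = lambda t: 1 / (1 + t**2)
c_out = lambda t: t / torch.sqrt(1 + t**2)

t = torch.tensor(0.0)
print(c_skip(t).item(), c_out(t).item())  # 1.0 0.0 -> f(x, 0) = x
```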

```python
import torch
import torch.nn as nn

class ConsistencyModel(nn.Module):
    """Consistency model for single-step generation."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model  # U-Net-style network returning a tensor

        # Consistency function parameterization (c_skip(0)=1, c_out(0)=0)
        self.c_skip = lambda t: 1 / (1 + t**2)
        self.c_out = lambda t: t / torch.sqrt(1 + t**2)

    def forward(self, x_t, t):
        """Map a noisy input directly to a clean-data estimate."""
        # Base model prediction
        F_theta = self.base_model(x_t, t)

        # Broadcast the per-sample timestep over image dimensions
        t = t.view(-1, 1, 1, 1)

        # Consistency parameterization ensures f(x_0, 0) = x_0
        return self.c_skip(t) * x_t + self.c_out(t) * F_theta

    @torch.no_grad()
    def sample(self, noise, num_steps=1):
        """Generate samples (can use 1 or more steps)."""
        x = noise
        t_batch = torch.ones(x.shape[0], device=x.device)

        if num_steps == 1:
            # Single-step generation!
            return self.forward(x, t_batch)

        # Multi-step improves quality: alternate denoise and re-noise
        timesteps = torch.linspace(1, 0.001, num_steps + 1)
        for i in range(num_steps):
            t = timesteps[i] * t_batch

            # Map to a clean-data estimate
            x_0_hat = self.forward(x, t)

            # Add back a smaller amount of noise for the next step
            if i < num_steps - 1:
                t_next = timesteps[i + 1]
                x = x_0_hat + t_next * torch.randn_like(x)

        return x_0_hat
```

Training Consistency Models

There are two approaches to training consistency models:

  1. Consistency Distillation (CD): Distill from a pre-trained diffusion model by enforcing consistency along sampled trajectories
  2. Consistency Training (CT): Train from scratch using the consistency loss, without requiring a teacher model
```python
import copy
import torch
import torch.nn.functional as F

class ConsistencyDistillation:
    """Distill a diffusion model into a consistency model."""

    def __init__(self, teacher_diffusion, consistency_model):
        self.teacher = teacher_diffusion.eval()
        self.student = consistency_model
        self.ema_student = copy.deepcopy(consistency_model)

    def training_step(self, x_0):
        """One training step of consistency distillation."""
        batch_size = x_0.shape[0]

        # Sample timesteps in (0, 1]
        t = torch.rand(batch_size, device=x_0.device) * 0.999 + 0.001
        t_next = t - 1 / 1000  # one step earlier on the trajectory

        # Add noise at timestep t (simplified linear noise schedule)
        noise = torch.randn_like(x_0)
        x_t = x_0 + t.view(-1, 1, 1, 1) * noise

        # Teacher denoises one step: x_t -> x_{t_next}
        with torch.no_grad():
            teacher_pred = self.teacher(x_t, t).sample
            x_t_next = x_t - (t - t_next).view(-1, 1, 1, 1) * teacher_pred

        # Consistency loss: f(x_t, t) should equal f(x_{t_next}, t_next)
        student_output_t = self.student(x_t, t)

        with torch.no_grad():
            # Use the EMA model for the target (stabilizes training)
            target_output = self.ema_student(x_t_next, t_next)

        loss = F.mse_loss(student_output_t, target_output)
        return loss

    def update_ema(self, decay=0.999):
        """Update the EMA copy of the student model."""
        for ema_param, param in zip(self.ema_student.parameters(),
                                    self.student.parameters()):
            ema_param.data.mul_(decay).add_(param.data, alpha=1 - decay)
```
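The teacher-free variant, Consistency Training, is a small change to the same loop: the adjacent trajectory point is constructed from the known clean sample and the same noise draw rather than from a teacher step. A sketch under the same simplified noise schedule (`student` and `ema_student` are assumed to map `(x, t)` to a clean-image estimate):

```python
import torch
import torch.nn.functional as F

def consistency_training_loss(student, ema_student, x_0):
    """Consistency Training (CT): no teacher model required."""
    b = x_0.shape[0]
    t = torch.rand(b, device=x_0.device) * 0.999 + 0.001
    t_next = (t - 1 / 1000).clamp(min=0.001)

    # Use the SAME noise draw for both points on the trajectory
    noise = torch.randn_like(x_0)
    x_t = x_0 + t.view(-1, 1, 1, 1) * noise
    x_t_next = x_0 + t_next.view(-1, 1, 1, 1) * noise

    with torch.no_grad():
        target = ema_student(x_t_next, t_next)  # EMA target stabilizes training

    return F.mse_loss(student(x_t, t), target)
```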

Latent Consistency Models

Latent Consistency Models (LCM; Luo et al., 2023) combine the best of both worlds: they apply consistency model training in the latent space of Stable Diffusion, enabling 4-step generation with minimal quality loss.

Key Innovations

  • Latent Space Training: Work in the compressed latent space of a VAE, reducing computational cost
  • Classifier-Free Guidance Distillation: Bake the guidance scale into the model, eliminating the need for negative prompts
  • LCM-LoRA: Efficient fine-tuning using LoRA adapters, allowing quick adaptation to new base models
```python
import time
import torch
from diffusers import DiffusionPipeline, LCMScheduler

# Load Stable Diffusion XL
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.to("cuda")

prompt = "A majestic lion in a savanna, golden hour lighting, photorealistic"

# Baseline: standard SDXL at 50 steps (before loading any LoRA)
start = time.time()
standard_image = pipe(prompt, num_inference_steps=50).images[0]
standard_time = time.time() - start
print(f"Standard SDXL (50 steps): {standard_time:.2f}s")

# Load LCM-LoRA weights and switch to the LCM scheduler (critical!)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Generate with just 4 steps!
start = time.time()
lcm_image = pipe(
    prompt=prompt,
    num_inference_steps=4,      # Only 4 steps!
    guidance_scale=1.0,         # LCM works best with guidance_scale=1
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]
lcm_time = time.time() - start
print(f"LCM (4 steps): {lcm_time:.2f}s")
print(f"Speedup: {standard_time / lcm_time:.1f}x")
```

LCM-LoRA Flexibility

LCM-LoRA can be applied to custom fine-tuned models! If you have a DreamBooth or LoRA-trained model, you can combine it with LCM-LoRA for fast inference on your custom style or subject.

Adversarial Distillation

The latest breakthrough in distillation combines consistency training with adversarial objectives. Models like SDXL Turbo and SDXL Lightning achieve near-original quality in just 1-4 steps.

ADD: Adversarial Diffusion Distillation

SDXL Turbo (Sauer et al., 2023) introduces Adversarial Diffusion Distillation (ADD):
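Schematically (my notation, not the paper's), the student is trained on a weighted sum of a score-distillation term and an adversarial term; $\lambda$ corresponds to the `0.1` weighting used in the code below:

```latex
\mathcal{L}_{\text{ADD}}
  = \underbrace{\mathcal{L}_{\text{distill}}}_{\text{match the teacher's prediction}}
  + \lambda\,\underbrace{\mathcal{L}_{\text{adv}}}_{\text{fool the discriminator}},
  \qquad \lambda = 0.1 .
```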

```python
import torch
import torch.nn.functional as F

class AdversarialDiffusionDistillation:
    """
    Adversarial Diffusion Distillation (ADD) for 1-4 step generation.
    Combines:
    1. Score distillation from a teacher diffusion model
    2. An adversarial loss from a discriminator
    """

    def __init__(self, student, teacher, discriminator):
        self.student = student
        self.teacher = teacher.eval()
        self.discriminator = discriminator

    def training_step(self, x_0, text_embeddings):
        """Combined distillation + adversarial training step."""
        batch_size = x_0.shape[0]

        # Sample a timestep for the student (small t for ADD)
        student_t = torch.rand(batch_size, device=x_0.device) * 0.5  # t in [0, 0.5]

        # Noise the data at the sampled timestep
        noise = torch.randn_like(x_0)
        x_student_t = x_0 + student_t.view(-1, 1, 1, 1) * noise

        # Student prediction and its implied clean image
        student_pred = self.student(x_student_t, student_t, text_embeddings)
        x_generated = x_student_t - student_t.view(-1, 1, 1, 1) * student_pred

        # 1. Score distillation loss
        with torch.no_grad():
            # Teacher's noise prediction at the same timestep
            teacher_pred = self.teacher(x_student_t, student_t, text_embeddings)

        score_loss = F.mse_loss(student_pred, teacher_pred)

        # 2. Adversarial loss: the discriminator should classify
        #    generated images as real
        disc_fake = self.discriminator(x_generated, text_embeddings)
        adv_loss = F.binary_cross_entropy_with_logits(
            disc_fake, torch.ones_like(disc_fake)
        )

        # Combined loss
        total_loss = score_loss + 0.1 * adv_loss

        return total_loss, {
            "score_loss": score_loss.item(),
            "adv_loss": adv_loss.item(),
        }

    def discriminator_step(self, x_real, x_fake, text_embeddings):
        """Train the discriminator to distinguish real from generated."""
        disc_real = self.discriminator(x_real, text_embeddings)
        disc_fake = self.discriminator(x_fake.detach(), text_embeddings)

        loss_real = F.binary_cross_entropy_with_logits(
            disc_real, torch.ones_like(disc_real)
        )
        loss_fake = F.binary_cross_entropy_with_logits(
            disc_fake, torch.zeros_like(disc_fake)
        )

        return loss_real + loss_fake
```

Using SDXL Turbo

```python
import torch
from diffusers import AutoPipelineForText2Image

# Load SDXL Turbo
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.to("cuda")

# Generate in just 1 step!
prompt = "A cat wearing sunglasses on a beach"

image = pipe(
    prompt=prompt,
    num_inference_steps=1,      # Single step!
    guidance_scale=0.0,         # No CFG needed
).images[0]

# For slightly better quality, use 4 steps
image_4step = pipe(
    prompt=prompt,
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
```

Trade-off Reminder: Turbo models sacrifice some fine-grained control (like precise prompt following) for speed. For production, benchmark on your specific use case to find the right balance.

Practical Implementation

Choosing the Right Approach

| Use Case | Recommended Approach | Steps | Notes |
|---|---|---|---|
| Real-time preview | SDXL Turbo / Lightning | 1-4 | Fastest, good for iterating |
| Production quality | LCM-LoRA | 4-8 | Best quality-speed trade-off |
| Fine-tuned models | LCM-LoRA + your LoRA | 4-8 | Combine with custom styles |
| Maximum quality | Standard DDIM | 20-50 | When speed isn't critical |
| Research/ablations | Progressive distillation | Variable | When you need control |

Benchmark Script

```python
import time
import torch
from diffusers import (
    StableDiffusionXLPipeline,
    LCMScheduler,
    AutoPipelineForText2Image,
)

def benchmark_inference(pipe, prompt, num_steps, guidance_scale, num_runs=10):
    """Benchmark average inference latency."""
    # Warmup
    for _ in range(3):
        pipe(prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale)

    # Benchmark
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_runs):
        pipe(prompt, num_inference_steps=num_steps, guidance_scale=guidance_scale)
        torch.cuda.synchronize()

    return (time.time() - start) / num_runs

# Test different configurations
prompt = "A beautiful sunset over mountains, photorealistic"
results = []

# Standard SDXL (usual CFG scale)
pipe_std = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
for steps in [50, 25, 10]:
    results.append(("SDXL Standard", steps,
                    benchmark_inference(pipe_std, prompt, steps, guidance_scale=7.5)))
del pipe_std
torch.cuda.empty_cache()  # free VRAM before loading the next pipeline

# LCM-LoRA (guidance_scale=1.0)
pipe_lcm = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe_lcm.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe_lcm.scheduler = LCMScheduler.from_config(pipe_lcm.scheduler.config)
for steps in [8, 4, 2]:
    results.append(("LCM-LoRA", steps,
                    benchmark_inference(pipe_lcm, prompt, steps, guidance_scale=1.0)))
del pipe_lcm
torch.cuda.empty_cache()

# SDXL Turbo (no CFG)
pipe_turbo = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16
).to("cuda")
for steps in [4, 2, 1]:
    results.append(("SDXL Turbo", steps,
                    benchmark_inference(pipe_turbo, prompt, steps, guidance_scale=0.0)))

# Print results
print("\nBenchmark Results:")
print("-" * 50)
for model, steps, latency in results:
    print(f"{model:20s} | {steps:3d} steps | {latency * 1000:7.1f} ms")
```

Summary

Model acceleration is crucial for deploying diffusion models in production. We covered several key techniques:

  1. Knowledge Distillation: Training a student model to mimic a teacher's multi-step behavior in fewer steps
  2. Progressive Distillation: Iteratively halving steps (1000 → 4) while maintaining reasonable quality
  3. Consistency Models: Learning direct mappings to clean data, enabling single-step generation
  4. LCM/LCM-LoRA: Latent space consistency training with efficient LoRA adapters for 4-step generation
  5. Adversarial Distillation: Combining score distillation with GAN-style training (SDXL Turbo) for 1-4 step generation
Looking Ahead: In the next section, we'll explore quantization and efficiency techniques that further reduce memory usage and computational cost, enabling deployment on consumer hardware.

Key Papers

  • Progressive Distillation: Salimans & Ho (2022) - "Progressive Distillation for Fast Sampling of Diffusion Models"
  • Consistency Models: Song et al. (2023) - "Consistency Models"
  • LCM: Luo et al. (2023) - "Latent Consistency Models"
  • SDXL Turbo: Sauer et al. (2023) - "Adversarial Diffusion Distillation"