Learning Objectives
By the end of this section, you will be able to:
- Understand why diffusion models are slow and identify the computational bottlenecks in the sampling process
- Implement knowledge distillation to train student models that require fewer sampling steps
- Apply progressive distillation to reduce sampling from 1000 steps to as few as 4 steps
- Understand consistency models and their ability to perform single-step generation
- Use Latent Consistency Models (LCM) for fast, high-quality image generation
The Speed Bottleneck
Diffusion models produce stunning results, but they have a fundamental speed problem. While a GAN generates an image in a single forward pass, a diffusion model requires hundreds to thousands of sequential denoising steps. Each step involves a full neural network forward pass through a large U-Net.
| Model Type | Steps Required | Typical Latency | GPU Memory |
|---|---|---|---|
| DDPM (Original) | 1000 | ~60 seconds | ~8 GB |
| DDIM (Accelerated) | 50-100 | ~3-6 seconds | ~8 GB |
| Stable Diffusion 1.5 | 20-50 | ~2-5 seconds | ~6 GB |
| LCM-LoRA | 4-8 | ~0.5-1 second | ~6 GB |
| SDXL Turbo | 1-4 | ~0.2-0.5 seconds | ~12 GB |
The Core Problem: Sequential dependencies in the denoising process prevent parallelization. Each step depends on the previous output, so we cannot compute multiple steps simultaneously.
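To see why step count dominates latency, it helps to model sampling as a strictly sequential loop. This is a toy cost model, not a measurement; the 60 ms per-step figure is only illustrative:

```python
def sampling_latency_s(num_steps, per_step_ms=60.0):
    """Total latency when every denoising step must wait for the previous one."""
    # Each step consumes the previous step's output, so steps cannot be
    # reordered or run in parallel: latency scales linearly with step count.
    return num_steps * per_step_ms / 1000.0

for steps in (1000, 50, 4, 1):
    print(f"{steps:5d} steps -> {sampling_latency_s(steps):6.2f} s")
```

Reducing the step count is therefore the single highest-leverage knob: every other optimization only shrinks the per-step constant.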
Why Reducing Steps Is Hard
Naively running DDPM with fewer steps dramatically degrades quality. The model is trained to predict the denoising distribution for small step sizes; when we skip steps at inference, we violate that training assumption and the discretization errors compound.
The key insight is that we need to retrain or distill models specifically for fewer-step generation. This leads us to the family of distillation techniques.
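As a loose analogy (a toy ODE, not an actual diffusion model), consider Euler integration of dx/dt = -x: the same update rule that is accurate with many small steps becomes badly inaccurate when we simply enlarge the step size. This is why few-step sampling requires retraining, not mere step skipping:

```python
import math

def euler_decay(x0, t_end, num_steps):
    """Integrate dx/dt = -x from 0 to t_end with `num_steps` Euler steps."""
    x, dt = x0, t_end / num_steps
    for _ in range(num_steps):
        x = x - dt * x  # one update sized for step dt
    return x

exact = 1.0 * math.exp(-5.0)  # true solution at t = 5
err_many = abs(euler_decay(1.0, 5.0, 1000) - exact)
err_few = abs(euler_decay(1.0, 5.0, 4) - exact)
print(err_many, err_few)  # coarse stepping is far less accurate
```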
Knowledge Distillation
Knowledge distillation transfers the knowledge from a slow, high-quality "teacher" model to a fast "student" model. For diffusion models, the teacher uses many steps while the student uses fewer steps.
The Basic Approach
Instead of training the student to predict noise directly from data, we train it to match the teacher's predictions:
```python
import torch
import torch.nn as nn
from diffusers import DDPMScheduler

class DiffusionDistillation:
    """Knowledge distillation for diffusion models (illustrative sketch)."""

    def __init__(self, teacher_model, student_model, teacher_steps=1000, student_steps=100):
        self.teacher = teacher_model.eval()  # Freeze teacher
        self.student = student_model

        self.teacher_scheduler = DDPMScheduler(num_train_timesteps=teacher_steps)
        self.student_scheduler = DDPMScheduler(num_train_timesteps=student_steps)

        # Each student step corresponds to `step_ratio` teacher steps
        self.step_ratio = teacher_steps // student_steps

    def distillation_loss(self, x_0, student_t):
        """Compute distillation loss at student timesteps `student_t` (batch of ints)."""
        # Map student timesteps to the corresponding teacher timesteps
        teacher_t = student_t * self.step_ratio

        # Noise the clean images at the teacher timestep
        noise = torch.randn_like(x_0)
        x_t = self.teacher_scheduler.add_noise(x_0, noise, teacher_t)

        # Teacher prediction (no gradients needed)
        with torch.no_grad():
            teacher_pred = self.teacher(x_t, teacher_t).sample

        # Student sees the same noisy input at its own (coarser) timestep
        student_pred = self.student(x_t, student_t).sample

        # Distillation loss: match the teacher's prediction. A fuller
        # implementation would instead match the teacher's result after
        # `step_ratio` denoising steps, as progressive distillation does.
        return nn.functional.mse_loss(student_pred, teacher_pred)
```
Progressive Distillation
Progressive distillation (Salimans & Ho, 2022) iteratively halves the number of required steps. Starting from a 1024-step model, we distill to 512 steps, then 256, 128, 64, 32, 16, 8, and finally 4.
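Following the 1024-step schedule used in the implementation (powers of two halve cleanly), the sequence of stages is easy to enumerate; a minimal sketch:

```python
def halving_schedule(initial_steps=1024, target_steps=4):
    """List the step counts visited when repeatedly halving."""
    steps = [initial_steps]
    while steps[-1] > target_steps:
        steps.append(steps[-1] // 2)
    return steps

print(halving_schedule())  # [1024, 512, 256, 128, 64, 32, 16, 8, 4]
```

Nine entries means eight distillation stages in total, each of which needs its own round of training.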
The Algorithm
At each stage, the student learns to match the teacher's output after two steps in a single step:
```python
import copy
import torch
import torch.nn.functional as F

class ProgressiveDistillation:
    """Progressive distillation: halve the number of steps at each stage.

    Sketch only: `add_noise` and `denoise_step` stand in for the usual
    scheduler operations (forward noising and one reverse update).
    """

    def __init__(self, model, initial_steps=1024):
        self.model = model
        self.current_steps = initial_steps

    def halve_steps(self, dataloader, num_epochs=100):
        """Distill to half the current number of steps."""
        target_steps = self.current_steps // 2
        print(f"Distilling: {self.current_steps} -> {target_steps} steps")

        # Teacher = frozen copy of the current model; student trains in place
        teacher = copy.deepcopy(self.model).eval()
        student = self.model

        optimizer = torch.optim.AdamW(student.parameters(), lr=1e-5)
        device = next(student.parameters()).device

        for epoch in range(num_epochs):
            for batch in dataloader:
                x_0 = batch["images"].to(device)

                # Sample student timesteps (even teacher timesteps only)
                student_t = torch.randint(0, target_steps, (x_0.shape[0],), device=device) * 2

                # Add noise
                noise = torch.randn_like(x_0)
                x_t = self.add_noise(x_0, noise, student_t)

                # Teacher takes TWO steps
                with torch.no_grad():
                    # First teacher step: t -> t-1
                    teacher_pred_1 = teacher(x_t, student_t).sample
                    x_t_minus_1 = self.denoise_step(x_t, teacher_pred_1, student_t)

                    # Second teacher step: t-1 -> t-2
                    teacher_pred_2 = teacher(x_t_minus_1, student_t - 1).sample
                    x_target = self.denoise_step(x_t_minus_1, teacher_pred_2, student_t - 1)

                # Student takes ONE step (stride 2) to match the teacher's two
                student_pred = student(x_t, student_t).sample
                x_student = self.denoise_step(x_t, student_pred, student_t, stride=2)

                # Loss: match the denoised outputs
                loss = F.mse_loss(x_student, x_target)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        self.current_steps = target_steps
        return student

    def distill_to_n_steps(self, dataloader, target_steps=4):
        """Progressively distill until reaching the target step count."""
        while self.current_steps > target_steps:
            self.halve_steps(dataloader)
        return self.model
```

Results and Trade-offs
| Distillation Stage | Steps | FID | Quality Notes |
|---|---|---|---|
| Original DDPM | 1000 | ~3.0 | Reference quality |
| Stage 1 | 512 | ~3.1 | Minimal degradation |
| Stage 2 | 256 | ~3.3 | Still high quality |
| Stage 3 | 128 | ~3.8 | Slight softening |
| Stage 4 | 64 | ~4.5 | Noticeable but acceptable |
| Stage 5 | 32 | ~5.8 | Some detail loss |
| Stage 6 | 16 | ~7.5 | Clearly degraded |
| Stage 7 | 8 | ~10.2 | Significant quality loss |
| Stage 8 | 4 | ~15.0 | Usable for previews |
Consistency Models
Consistency Models (Song et al., 2023) take a fundamentally different approach. Instead of learning the step-by-step denoising process, they learn to map any point on the diffusion trajectory directly to the clean data.
The Consistency Property
The key idea is the self-consistency property: any two points on the same diffusion trajectory should map to the same clean image. Mathematically, for a consistency function f and any two points (x_t, t) and (x_t', t') on the same trajectory:

f(x_t, t) = f(x_t', t')

At the smallest timestep t = ε ≈ 0, the boundary condition requires f(x_ε, ε) = x_ε ≈ x_0 (the clean data).
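The c_skip/c_out parameterization used in the implementation below makes this boundary condition hold by construction; a quick standalone check:

```python
import torch

# Parameterization from the consistency model below:
# f(x, t) = c_skip(t) * x + c_out(t) * F_theta(x, t)
c_skip = lambda t: 1 / (1 + t**2)
c_out = lambda t: t / torch.sqrt(1 + t**2)

x = torch.randn(4, 3, 8, 8)    # dummy "noisy" input
F_theta = torch.randn_like(x)  # stand-in for the network output

t0 = torch.tensor(0.0)
f_at_zero = c_skip(t0) * x + c_out(t0) * F_theta

# At t = 0: c_skip = 1 and c_out = 0, so f(x, 0) = x regardless of F_theta
assert torch.allclose(f_at_zero, x)
```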
```python
import torch
import torch.nn as nn

class ConsistencyModel(nn.Module):
    """Consistency model for single-step generation (sketch)."""

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model  # U-Net; assumed here to return a tensor

        # Consistency parameterization: at t = 0, c_skip = 1 and c_out = 0,
        # so f(x, 0) = x by construction
        self.c_skip = lambda t: 1 / (1 + t**2)
        self.c_out = lambda t: t / torch.sqrt(1 + t**2)

    def forward(self, x_t, t):
        """Map a noisy input directly to a clean-data estimate."""
        F_theta = self.base_model(x_t, t)

        # Broadcast the per-sample coefficients over image dimensions
        c_skip = self.c_skip(t).view(-1, 1, 1, 1)
        c_out = self.c_out(t).view(-1, 1, 1, 1)
        return c_skip * x_t + c_out * F_theta

    @torch.no_grad()
    def sample(self, noise, num_steps=1):
        """Generate samples (1 step works; more steps improve quality)."""
        x = noise
        batch = x.shape[0]

        if num_steps == 1:
            # Single-step generation!
            return self.forward(x, torch.ones(batch, device=x.device))

        # Multi-step: repeatedly estimate x_0, then re-noise to a smaller t
        timesteps = torch.linspace(1, 0.001, num_steps + 1)
        for i in range(num_steps):
            t = timesteps[i] * torch.ones(batch, device=x.device)

            # Map current point to a clean-data estimate
            x_0_hat = self.forward(x, t)

            # Re-noise for the next (smaller) timestep, except after the last step
            if i < num_steps - 1:
                t_next = timesteps[i + 1]
                x = x_0_hat + t_next * torch.randn_like(x)

        return x_0_hat
```

Training Consistency Models
There are two approaches to training consistency models:
- Consistency Distillation (CD): Distill from a pre-trained diffusion model by enforcing consistency along sampled trajectories
- Consistency Training (CT): Train from scratch using the consistency loss, without requiring a teacher model
```python
import copy
import torch
import torch.nn.functional as F

class ConsistencyDistillation:
    """Distill a diffusion model into a consistency model (sketch)."""

    def __init__(self, teacher_diffusion, consistency_model):
        self.teacher = teacher_diffusion.eval()
        self.student = consistency_model
        self.ema_student = copy.deepcopy(consistency_model)

    def training_step(self, x_0):
        """One training step of consistency distillation."""
        batch_size = x_0.shape[0]

        # Sample timesteps in (0.001, 1.0)
        t = torch.rand(batch_size, device=x_0.device) * 0.999 + 0.001
        t_next = t - 1 / 1000  # one step earlier on the trajectory

        # Add noise at timestep t (variance-exploding style: x_t = x_0 + t * noise)
        noise = torch.randn_like(x_0)
        x_t = x_0 + t.view(-1, 1, 1, 1) * noise

        # Teacher denoises one step along the trajectory: x_t -> x_{t_next}
        with torch.no_grad():
            teacher_pred = self.teacher(x_t, t).sample
            x_t_next = x_t - (t - t_next).view(-1, 1, 1, 1) * teacher_pred

        # Consistency loss: f(x_t, t) should equal f(x_{t_next}, t_next)
        student_output_t = self.student(x_t, t)

        with torch.no_grad():
            # Use the EMA model for the target (stabilizes training)
            target_output = self.ema_student(x_t_next, t_next)

        return F.mse_loss(student_output_t, target_output)

    def update_ema(self, decay=0.999):
        """Update the exponential moving average of the student."""
        for ema_param, param in zip(self.ema_student.parameters(),
                                    self.student.parameters()):
            ema_param.data.mul_(decay).add_(param.data, alpha=1 - decay)
```

Latent Consistency Models
Latent Consistency Models (LCM) (Luo et al., 2023) combine the best of both worlds: they apply consistency model training in the latent space of Stable Diffusion, enabling 4-step generation with minimal quality loss.
Key Innovations
- Latent Space Training: Work in the compressed latent space of a VAE, reducing computational cost
- Classifier-Free Guidance Distillation: Bake the guidance scale into the model, eliminating the need for negative prompts
- LCM-LoRA: Efficient fine-tuning using LoRA adapters, allowing quick adaptation to new base models
```python
from diffusers import DiffusionPipeline, LCMScheduler
import torch
import time

# Load Stable Diffusion XL
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.to("cuda")

prompt = "A majestic lion in a savanna, golden hour lighting, photorealistic"

# Baseline: standard SDXL at 50 steps (before loading the LCM-LoRA)
start = time.time()
standard_image = pipe(prompt, num_inference_steps=50).images[0]
standard_time = time.time() - start
print(f"Standard SDXL (50 steps): {standard_time:.2f}s")

# Load LCM-LoRA weights and switch to the LCM scheduler (critical!)
pipe.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Generate with just 4 steps!
start = time.time()
lcm_image = pipe(
    prompt=prompt,
    num_inference_steps=4,  # Only 4 steps!
    guidance_scale=1.0,     # LCM works best with guidance_scale around 1
    generator=torch.Generator("cuda").manual_seed(42),
).images[0]
lcm_time = time.time() - start
print(f"LCM (4 steps): {lcm_time:.2f}s")
print(f"Speedup: {standard_time / lcm_time:.1f}x")
```
Adversarial Distillation
The most recent distillation techniques combine score distillation with adversarial objectives. Models like SDXL Turbo and SDXL Lightning achieve near-original quality in just 1-4 steps.
ADD: Adversarial Diffusion Distillation
SDXL Turbo (Sauer et al., 2023) introduces Adversarial Diffusion Distillation (ADD):
```python
import torch
import torch.nn.functional as F

class AdversarialDiffusionDistillation:
    """
    Adversarial Diffusion Distillation (ADD) for 1-4 step generation (sketch).
    Combines:
      1. Score distillation from a teacher diffusion model
      2. An adversarial loss from a discriminator
    """

    def __init__(self, student, teacher, discriminator):
        self.student = student
        self.teacher = teacher.eval()
        self.discriminator = discriminator

    def training_step(self, x_0, text_embeddings):
        """Combined distillation + adversarial training step."""
        batch_size = x_0.shape[0]

        # Sample small timesteps for the student (t in [0, 0.5] for ADD)
        student_t = torch.rand(batch_size, device=x_0.device) * 0.5

        # Noise the input at the sampled timestep
        noise = torch.randn_like(x_0)
        x_t = x_0 + student_t.view(-1, 1, 1, 1) * noise

        # Student prediction and its one-step generation
        student_pred = self.student(x_t, student_t, text_embeddings)
        x_generated = x_t - student_t.view(-1, 1, 1, 1) * student_pred

        # 1. Score distillation loss: match the teacher's prediction
        with torch.no_grad():
            teacher_pred = self.teacher(x_t, student_t, text_embeddings)
        score_loss = F.mse_loss(student_pred, teacher_pred)

        # 2. Adversarial loss: the discriminator should classify
        #    the generation as real
        disc_fake = self.discriminator(x_generated, text_embeddings)
        adv_loss = F.binary_cross_entropy_with_logits(
            disc_fake, torch.ones_like(disc_fake)
        )

        # Combined loss
        total_loss = score_loss + 0.1 * adv_loss

        return total_loss, {
            "score_loss": score_loss.item(),
            "adv_loss": adv_loss.item(),
        }

    def discriminator_step(self, x_real, x_fake, text_embeddings):
        """Train the discriminator to distinguish real from generated images."""
        disc_real = self.discriminator(x_real, text_embeddings)
        disc_fake = self.discriminator(x_fake.detach(), text_embeddings)

        loss_real = F.binary_cross_entropy_with_logits(
            disc_real, torch.ones_like(disc_real)
        )
        loss_fake = F.binary_cross_entropy_with_logits(
            disc_fake, torch.zeros_like(disc_fake)
        )
        return loss_real + loss_fake
```

Using SDXL Turbo
```python
from diffusers import AutoPipelineForText2Image
import torch

# Load SDXL Turbo
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.to("cuda")

prompt = "A cat wearing sunglasses on a beach"

# Generate in just 1 step!
image = pipe(
    prompt=prompt,
    num_inference_steps=1,  # Single step!
    guidance_scale=0.0,     # No CFG needed
).images[0]

# For slightly better quality, use 4 steps
image_4step = pipe(
    prompt=prompt,
    num_inference_steps=4,
    guidance_scale=0.0,
).images[0]
```

Trade-off Reminder: Turbo models sacrifice some fine-grained control (such as precise prompt following) for speed. For production, benchmark on your specific use case to find the right balance.
Practical Implementation
Choosing the Right Approach
| Use Case | Recommended Approach | Steps | Notes |
|---|---|---|---|
| Real-time preview | SDXL Turbo / Lightning | 1-4 | Fastest, good for iterating |
| Production quality | LCM-LoRA | 4-8 | Best quality-speed trade-off |
| Fine-tuned models | LCM-LoRA + your LoRA | 4-8 | Combine with custom styles |
| Maximum quality | Standard DDIM | 20-50 | When speed isn't critical |
| Research/ablations | Progressive distillation | Variable | When you need control |
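If you select the approach programmatically (for example, in a generation service), the table can be encoded as a simple lookup. The keys and strings below are this guide's suggestions, not any library's API:

```python
# Illustrative mapping of use case -> (approach, step range), per the table above
RECOMMENDATIONS = {
    "realtime_preview": ("SDXL Turbo / Lightning", (1, 4)),
    "production": ("LCM-LoRA", (4, 8)),
    "fine_tuned": ("LCM-LoRA + custom LoRA", (4, 8)),
    "max_quality": ("Standard DDIM", (20, 50)),
}

def recommend(use_case: str) -> str:
    """Return a human-readable recommendation for a known use case."""
    approach, (lo, hi) = RECOMMENDATIONS[use_case]
    return f"{approach}: {lo}-{hi} steps"

print(recommend("production"))  # LCM-LoRA: 4-8 steps
```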
Benchmark Script
```python
import time
import torch
from diffusers import (
    StableDiffusionXLPipeline,
    LCMScheduler,
    AutoPipelineForText2Image,
)

def benchmark_inference(pipe, prompt, num_steps, num_runs=10):
    """Return mean latency (seconds) over `num_runs` generations."""
    # Warmup
    for _ in range(3):
        pipe(prompt, num_inference_steps=num_steps, guidance_scale=1.0)

    # Benchmark
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_runs):
        pipe(prompt, num_inference_steps=num_steps, guidance_scale=1.0)
    torch.cuda.synchronize()

    return (time.time() - start) / num_runs

# Test different configurations
prompt = "A beautiful sunset over mountains, photorealistic"
results = []

# Standard SDXL
pipe_std = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")

for steps in [50, 25, 10]:
    results.append(("SDXL Standard", steps, benchmark_inference(pipe_std, prompt, steps)))

# Free GPU memory before loading the next pipeline
del pipe_std
torch.cuda.empty_cache()

# LCM-LoRA
pipe_lcm = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
pipe_lcm.load_lora_weights("latent-consistency/lcm-lora-sdxl")
pipe_lcm.scheduler = LCMScheduler.from_config(pipe_lcm.scheduler.config)

for steps in [8, 4, 2]:
    results.append(("LCM-LoRA", steps, benchmark_inference(pipe_lcm, prompt, steps)))

del pipe_lcm
torch.cuda.empty_cache()

# SDXL Turbo
pipe_turbo = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16,
).to("cuda")

for steps in [4, 2, 1]:
    results.append(("SDXL Turbo", steps, benchmark_inference(pipe_turbo, prompt, steps)))

# Print results
print("\nBenchmark Results:")
print("-" * 50)
for model, steps, latency in results:
    print(f"{model:20s} | {steps:3d} steps | {latency * 1000:7.1f} ms")
```

Summary
Model acceleration is crucial for deploying diffusion models in production. We covered several key techniques:
- Knowledge Distillation: Training a student model to mimic a teacher's multi-step behavior in fewer steps
- Progressive Distillation: Iteratively halving steps (e.g., 1024 → 4) while maintaining reasonable quality
- Consistency Models: Learning direct mappings to clean data, enabling single-step generation
- LCM/LCM-LoRA: Latent space consistency training with efficient LoRA adapters for 4-step generation
- Adversarial Distillation: Combining score distillation with GAN-style training (SDXL Turbo) for 1-4 step generation
Looking Ahead: In the next section, we'll explore quantization and efficiency techniques that further reduce memory usage and computational cost, enabling deployment on consumer hardware.
Key Papers
- Progressive Distillation: Salimans & Ho (2022) - "Progressive Distillation for Fast Sampling of Diffusion Models"
- Consistency Models: Song et al. (2023) - "Consistency Models"
- LCM: Luo et al. (2023) - "Latent Consistency Models"
- SDXL Turbo: Sauer et al. (2023) - "Adversarial Diffusion Distillation"