Learning Objectives
By the end of this section, you will be able to:
- Understand numerical precision formats (FP32, FP16, BF16, INT8, INT4) and their trade-offs
- Apply FP16/BF16 inference to halve memory usage with minimal quality loss
- Implement INT8 quantization using post-training quantization (PTQ) and quantization-aware training (QAT)
- Apply model pruning techniques to reduce model size and computation
- Export models to ONNX for deployment on various hardware platforms
The Memory Bottleneck
Diffusion models like Stable Diffusion XL have billions of parameters, creating significant memory challenges. A single SDXL model in full precision requires:
| Component | Parameters | FP32 Size | FP16 Size |
|---|---|---|---|
| U-Net | 2.6B | 10.4 GB | 5.2 GB |
| VAE | 84M | 336 MB | 168 MB |
| Text Encoder (CLIP) | 123M | 492 MB | 246 MB |
| Text Encoder (OpenCLIP) | 694M | 2.8 GB | 1.4 GB |
| Total | ~3.5B | ~14 GB | ~7 GB |
The Challenge: Running SDXL at full precision requires 14+ GB of VRAM just for model weights, before accounting for activations, gradients, or batch processing. This exceeds the capacity of most consumer GPUs.
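The table's memory column is just parameter count times bytes per parameter. A quick sanity check in plain Python (the helper name is ours, not from any library):

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Memory needed to store model weights, in gigabytes."""
    return num_params * bits_per_param / 8 / 1e9

# SDXL's ~3.5B parameters at different precisions
for bits, label in [(32, "FP32"), (16, "FP16"), (8, "INT8"), (4, "INT4")]:
    print(f"{label}: {weight_memory_gb(3.5e9, bits):.1f} GB")
# FP32: 14.0 GB, FP16: 7.0 GB, INT8: 3.5 GB, INT4: 1.8 GB
```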
Memory usage during inference includes:
- Model weights: The parameters stored in GPU memory
- Activations: Intermediate tensors during forward pass (scales with batch size and resolution)
- Optimizer states: Not needed for inference, but 2-3x weight size for training
- Attention memory: Attention maps grow with sequence length (latent resolution for self-attention); unlike autoregressive LLMs, diffusion U-Nets do not accumulate a KV cache across steps
Precision Formats
Modern deep learning uses various numerical formats, each with different trade-offs between precision, memory, and computational efficiency:
| Format | Bits | Range | Precision | Use Case |
|---|---|---|---|---|
| FP32 | 32 | ~10^38 | ~7 decimal digits | Training (default) |
| TF32 | 19 | ~10^38 | ~3 decimal digits | Ampere+ training |
| FP16 | 16 | ~65504 | ~3 decimal digits | Inference, mixed training |
| BF16 | 16 | ~10^38 | ~2 decimal digits | Training, inference |
| INT8 | 8 | -128 to 127 | Integer only | Quantized inference |
| INT4 | 4 | -8 to 7 | Integer only | Extreme compression |
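The "Range" column follows directly from each format's exponent and mantissa widths: the largest finite value of an IEEE-style float is (2 − 2^−mantissa_bits) × 2^max_exponent. A quick pure-Python check of that arithmetic (the helper is ours):

```python
def max_finite(exponent_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-style floating-point format."""
    max_exp = 2 ** (exponent_bits - 1) - 1  # maximum biased exponent
    return (2 - 2 ** -mantissa_bits) * 2.0 ** max_exp

print(f"FP16 max: {max_finite(5, 10)}")      # 65504.0
print(f"BF16 max: {max_finite(8, 7):.2e}")   # ~3.4e38
print(f"FP32 max: {max_finite(8, 23):.2e}")  # ~3.4e38
```

This is why BF16 shares FP32's ~10^38 range (same 8 exponent bits) while FP16 overflows above 65504.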
Understanding the Formats
```python
import torch

# FP32: Standard single precision
fp32_val = torch.tensor(3.14159265358979, dtype=torch.float32)
print(f"FP32: {fp32_val.item():.15f}")  # 3.141592741012573

# FP16: Half precision - smaller range and less precision than FP32
fp16_val = torch.tensor(3.14159265358979, dtype=torch.float16)
print(f"FP16: {fp16_val.item():.15f}")  # 3.140625000000000

# BF16: Brain Float - FP32 range, further reduced precision
bf16_val = torch.tensor(3.14159265358979, dtype=torch.bfloat16)
print(f"BF16: {bf16_val.item():.15f}")  # 3.140625000000000

# Check memory reduction
x_fp32 = torch.randn(1000, 1000, dtype=torch.float32)
x_fp16 = x_fp32.half()
x_bf16 = x_fp32.bfloat16()

print("\nMemory usage:")
print(f"FP32: {x_fp32.element_size() * x_fp32.numel() / 1e6:.1f} MB")
print(f"FP16: {x_fp16.element_size() * x_fp16.numel() / 1e6:.1f} MB")
print(f"BF16: {x_bf16.element_size() * x_bf16.numel() / 1e6:.1f} MB")
```
FP16 vs BF16
FP16 keeps more mantissa bits (higher precision) but its narrow exponent range overflows above 65504, while BF16 keeps the full FP32 exponent range at the cost of precision. This makes BF16 more robust for training, where large gradient values can overflow FP16.
FP16 and BF16 Inference
The simplest efficiency improvement is switching to half-precision inference. This cuts memory usage in half with typically negligible quality loss.
```python
from diffusers import StableDiffusionXLPipeline
import torch

# Load in FP16 (recommended for most cases)
pipe_fp16 = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",  # Load pre-converted FP16 weights
    use_safetensors=True,
)
pipe_fp16.to("cuda")

# For BF16 (better on newer GPUs like A100, H100)
pipe_bf16 = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
)
pipe_bf16.to("cuda")

# Memory comparison
def get_gpu_memory():
    return torch.cuda.memory_allocated() / 1e9

print(f"FP16 memory: {get_gpu_memory():.2f} GB")

# Enable memory-efficient attention (works with FP16/BF16).
# On PyTorch 2.0+, diffusers already uses native scaled-dot-product
# attention by default; xformers is an alternative implementation:
pipe_fp16.enable_xformers_memory_efficient_attention()
# Or trade some speed for lower memory with attention slicing:
# pipe_fp16.enable_attention_slicing()

# Generate image
image = pipe_fp16(
    "A majestic eagle soaring over mountains",
    num_inference_steps=30,
).images[0]
```
Automatic Mixed Precision (AMP)
For custom models, use PyTorch's AMP to automatically manage precision:
```python
import torch
from torch import autocast  # torch.cuda.amp.autocast is deprecated

class DiffusionInference:
    def __init__(self, model, scheduler):
        self.model = model.half().cuda()  # Convert model to FP16
        self.scheduler = scheduler

    @torch.no_grad()
    def sample(self, noise, timesteps, condition=None):
        x = noise

        for t in timesteps:
            # Autocast handles precision automatically
            with autocast("cuda", dtype=torch.float16):
                # Model forward pass in FP16
                noise_pred = self.model(x, t, condition)

            # Scheduler step (may need FP32 for numerical stability)
            with autocast("cuda", enabled=False):
                x = self.scheduler.step(
                    noise_pred.float(), t, x.float()
                ).prev_sample.half()

        return x

# For critical operations, disable autocast
def compute_loss(pred, target):
    with autocast("cuda", enabled=False):
        # Loss computation in FP32 for numerical stability
        return torch.nn.functional.mse_loss(pred.float(), target.float())
```
INT8 Quantization
INT8 quantization represents weights and activations using 8-bit integers instead of floating point, reducing memory by 4x compared to FP32 (2x compared to FP16).
Quantization Basics
The quantization formula maps floating-point values to integers:
q = round(x / s) + z

where s is the scale factor and z is the zero-point. Dequantization reverses this: x ≈ s · (q − z). Symmetric quantization sets z = 0, so q = round(x / s) and x ≈ s · q, which is what the example below implements.
```python
import torch

def quantize_tensor(x, num_bits=8):
    """Symmetric quantization to INT8."""
    # Compute scale based on max absolute value
    max_val = x.abs().max()
    scale = max_val / (2 ** (num_bits - 1) - 1)

    # Quantize: round, then clamp to the representable integer range
    q = torch.round(x / scale).clamp(
        -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    ).to(torch.int8)

    return q, scale

def dequantize_tensor(q, scale):
    """Dequantize INT8 back to float."""
    return q.float() * scale

# Example
x = torch.randn(1000, 1000)
q, scale = quantize_tensor(x)
x_reconstructed = dequantize_tensor(q, scale)

# Measure error
mse = torch.nn.functional.mse_loss(x, x_reconstructed)
print(f"Quantization MSE: {mse.item():.6f}")
print(f"Memory reduction: {x.element_size() / q.element_size()}x")
```
Post-Training Quantization (PTQ)
PTQ quantizes a pre-trained model without retraining, using calibration data to determine optimal scale factors:
```python
from optimum.quanto import Calibration, quantize, freeze, qint8
from diffusers import StableDiffusionPipeline
import torch

# Load the model
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

# Quantize the U-Net to INT8
quantize(pipe.unet, weights=qint8, activations=qint8)

# Calibrate with sample inputs (required for activation quantization);
# quanto records activation ranges inside the Calibration context
calibration_prompts = [
    "A photo of a cat",
    "A beautiful landscape",
    "A portrait of a person",
    "Abstract art with vibrant colors",
]

print("Calibrating quantized model...")
with Calibration():
    for prompt in calibration_prompts:
        with torch.no_grad():
            _ = pipe(prompt, num_inference_steps=10)

# Freeze the quantization parameters
freeze(pipe.unet)

# Now inference uses INT8
image = pipe(
    "A cyberpunk city at night",
    num_inference_steps=30,
).images[0]

# Check memory usage
print(f"Quantized model memory: {torch.cuda.memory_allocated()/1e9:.2f} GB")
```
Using bitsandbytes for 8-bit
```python
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, BitsAndBytesConfig
import torch

# Configure 8-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_enable_fp32_cpu_offload=False,
)

# In diffusers, bitsandbytes configs are applied per model component,
# so quantize the U-Net (the dominant memory cost) and pass it in
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    torch_dtype=torch.float16,
)
pipe.to("cuda")  # quantized components are skipped automatically

# Memory footprint is significantly reduced
print(f"8-bit memory usage: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# Generate normally
image = pipe(
    "A serene Japanese garden in autumn",
    num_inference_steps=30,
).images[0]
```
INT4 and Extreme Quantization
INT4 quantization pushes compression further, using only 4 bits per weight. This requires careful handling to maintain quality.
```python
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, BitsAndBytesConfig
import torch

# 4-bit quantization with NF4 (4-bit NormalFloat)
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # NF4 matches the distribution of NN weights
    bnb_4bit_use_double_quant=True,  # Double quantization for better compression
    bnb_4bit_compute_dtype=torch.float16,  # Compute in FP16
)

# As with 8-bit, apply the config to the U-Net component
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    subfolder="unet",
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    unet=unet,
    torch_dtype=torch.float16,
)
pipe.to("cuda")

print(f"4-bit memory usage: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# Quality may be slightly reduced
image = pipe(
    "A detailed oil painting of a sailing ship",
    num_inference_steps=30,
).images[0]
```
Comparison of Quantization Levels
| Precision | Memory | Speed | Quality | Best For |
|---|---|---|---|---|
| FP32 | 14 GB | 1.0x | 100% | Training, debugging |
| FP16/BF16 | 7 GB | 1.2-1.5x | ~99.9% | Production default |
| INT8 | 3.5 GB | 1.5-2x | ~99% | Memory-constrained |
| INT4 | 1.8 GB | 2-3x | ~95-98% | Consumer GPUs |
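The quality column can be made concrete: running the symmetric quantize/dequantize round trip from earlier at different bit-widths shows the reconstruction error growing roughly 4x for each bit removed (the error scale is proportional to 1/(2^(b-1) − 1)). A small pure-Python sketch, names ours:

```python
def quant_roundtrip_mse(xs, bits):
    """Symmetric quantize/dequantize a list of floats; return the MSE."""
    qmax = 2 ** (bits - 1) - 1
    scale = max(abs(x) for x in xs) / qmax
    recon = [round(x / scale) * scale for x in xs]
    return sum((x - r) ** 2 for x, r in zip(xs, recon)) / len(xs)

values = [0.731, -0.248, 0.592, -0.914, 0.106, 0.447, -0.655, 0.320]
for bits in (8, 4):
    print(f"INT{bits} round-trip MSE: {quant_roundtrip_mse(values, bits):.2e}")
```

INT8's error is orders of magnitude smaller than INT4's on the same data, which is why INT4 deployments lean on smarter codebooks like NF4 rather than plain uniform quantization.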
Model Pruning
Pruning removes unnecessary weights or entire network components to reduce model size and computation. There are several approaches:
Unstructured Pruning
Removes individual weights below a threshold, creating sparse matrices:
```python
import torch
import torch.nn.utils.prune as prune

def prune_model(model, amount=0.3):
    """Apply unstructured L1 pruning to all linear layers."""
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            prune.l1_unstructured(module, name='weight', amount=amount)
            prune.remove(module, 'weight')  # Make pruning permanent
    return model

def count_parameters(model):
    """Count non-zero parameters."""
    total = sum(p.numel() for p in model.parameters())
    nonzero = sum((p != 0).sum().item() for p in model.parameters())
    return total, nonzero, nonzero / total

# Example: prune a U-Net
unet = ...  # Your U-Net model
total, nonzero, ratio = count_parameters(unet)
print(f"Before pruning: {total:,} params, {ratio:.1%} non-zero")

pruned_unet = prune_model(unet, amount=0.3)
total, nonzero, ratio = count_parameters(pruned_unet)
print(f"After pruning: {total:,} params, {ratio:.1%} non-zero")
```
Structured Pruning
Removes entire channels or attention heads, enabling actual speedups without special sparse hardware:
```python
import torch
import torch.nn.utils.prune as prune

def structured_prune_conv(module, amount=0.2):
    """Prune entire output channels based on L2 norm."""
    prune.ln_structured(
        module,
        name='weight',
        amount=amount,
        n=2,    # L2 norm
        dim=0,  # Prune output channels
    )
    return module

def prune_attention_heads(attention_layer, heads_to_prune):
    """Zero out specific attention heads.

    True removal (smaller projection matrices) requires architecture-aware
    surgery; masking the head weights is the simpler first step.
    """
    num_heads = attention_layer.num_heads
    head_dim = attention_layer.head_dim

    # Create mask for heads to keep
    mask = torch.ones(num_heads)
    mask[heads_to_prune] = 0

    # Apply to Q, K, V projections
    with torch.no_grad():
        for proj in [attention_layer.q_proj, attention_layer.k_proj, attention_layer.v_proj]:
            weight = proj.weight.view(num_heads, head_dim, -1)
            weight = weight * mask.view(-1, 1, 1)
            proj.weight.data = weight.view(-1, weight.shape[-1])

    return attention_layer
```
Knowledge Distillation for Pruned Models
After pruning, fine-tune with knowledge distillation to recover quality:
```python
import torch

class PrunedModelDistillation:
    """Recover quality of a pruned model through distillation."""

    def __init__(self, teacher_model, pruned_student):
        self.teacher = teacher_model.eval()
        self.student = pruned_student

    def distillation_step(self, x, t, condition):
        # Teacher output (no gradients needed)
        with torch.no_grad():
            teacher_output = self.teacher(x, t, condition)

        # Student output
        student_output = self.student(x, t, condition)

        # Distillation loss: match the teacher's predictions
        loss = torch.nn.functional.mse_loss(student_output, teacher_output)

        # Optional: add original denoising loss
        noise = ...  # Original noise
        noise_loss = torch.nn.functional.mse_loss(student_output, noise)

        return 0.5 * loss + 0.5 * noise_loss
```
ONNX Export and Optimization
ONNX (Open Neural Network Exchange) enables deploying models on various hardware platforms and inference engines like TensorRT, OpenVINO, and Core ML.
```python
from optimum.exporters.onnx import main_export
from optimum.onnxruntime import ORTStableDiffusionPipeline

# Export Stable Diffusion to ONNX
main_export(
    model_name_or_path="runwayml/stable-diffusion-v1-5",
    output="sd-onnx/",
    task="stable-diffusion",
    opset=17,       # ONNX opset version
    fp16=True,      # Export in FP16
    device="cuda",  # FP16 export requires a GPU
)

# Load and run with ONNX Runtime
pipe = ORTStableDiffusionPipeline.from_pretrained(
    "sd-onnx/",
    provider="CUDAExecutionProvider",  # Or "TensorrtExecutionProvider"
)

image = pipe(
    "A beautiful sunset over the ocean",
    num_inference_steps=30,
).images[0]
```
TensorRT Optimization
For NVIDIA GPUs, TensorRT provides the best inference performance:
```python
# Install: pip install tensorrt torch-tensorrt

import torch
import torch_tensorrt

# Convert PyTorch model to TensorRT
def optimize_for_tensorrt(model, sample_input):
    """Compile model with TensorRT."""
    optimized_model = torch_tensorrt.compile(
        model,
        inputs=[sample_input],
        enabled_precisions={torch.half},  # FP16 inference
        workspace_size=1 << 30,           # 1GB workspace
        min_block_size=3,                 # Minimum ops per TensorRT block
        truncate_long_and_double=True,
    )
    return optimized_model

# For diffusers pipelines, torch.compile is the simplest route
# to fused kernels without manual TensorRT conversion
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe.to("cuda")

pipe.unet = torch.compile(
    pipe.unet,
    mode="reduce-overhead",  # Optimize for latency
    backend="inductor",
)
```
Platform-Specific Exports
| Platform | Format | Optimization | Use Case |
|---|---|---|---|
| NVIDIA GPU | TensorRT | Kernel fusion, FP16 | Servers, cloud |
| Intel CPU/GPU | OpenVINO | INT8, threading | Edge, on-prem |
| Apple Silicon | Core ML | Neural Engine | macOS, iOS |
| Mobile | ONNX + QNN | INT8, NPU | Android, iOS |
| Web | ONNX Runtime Web | WebGPU/WASM | Browser apps |
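Each target in the table has its own toolchain. A typical set of installs (package names current as of writing, and subject to change) looks like:

```shell
# NVIDIA GPUs: TensorRT via pip (also needs a matching CUDA install)
pip install tensorrt torch-tensorrt

# Intel CPUs/GPUs: OpenVINO plus optimum's OpenVINO exporter
pip install openvino "optimum[openvino]"

# Apple Silicon: Core ML conversion from PyTorch
pip install coremltools

# Browser: ONNX Runtime Web is an npm package, not a pip one
npm install onnxruntime-web
```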
Practical Workflows
Complete Optimization Pipeline
```python
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, BitsAndBytesConfig

class OptimizedDiffusionPipeline:
    """Production-ready optimized diffusion pipeline."""

    def __init__(
        self,
        model_id: str,
        precision: str = "fp16",  # fp32, fp16, bf16, int8, int4
        use_xformers: bool = True,
        compile_mode: str | None = None,  # None, "reduce-overhead", "max-autotune"
    ):
        self.precision = precision

        # Load with appropriate precision
        if precision in ("int4", "int8"):
            if precision == "int4":
                config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_compute_dtype=torch.float16,
                )
            else:
                config = BitsAndBytesConfig(load_in_8bit=True)
            # bitsandbytes configs apply per model component in diffusers
            unet = UNet2DConditionModel.from_pretrained(
                model_id, subfolder="unet",
                quantization_config=config, torch_dtype=torch.float16,
            )
            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                model_id, unet=unet, torch_dtype=torch.float16
            )
        else:
            dtype = {
                "fp32": torch.float32,
                "fp16": torch.float16,
                "bf16": torch.bfloat16,
            }[precision]
            self.pipe = StableDiffusionXLPipeline.from_pretrained(
                model_id,
                torch_dtype=dtype,
                # The fp16 variant only makes sense for half precision
                variant="fp16" if precision != "fp32" else None,
            )
        self.pipe.to("cuda")

        # Enable memory-efficient attention
        if use_xformers:
            try:
                self.pipe.enable_xformers_memory_efficient_attention()
            except Exception:
                self.pipe.enable_attention_slicing()

        # Compile for extra speed (PyTorch 2.0+)
        if compile_mode:
            self.pipe.unet = torch.compile(
                self.pipe.unet, mode=compile_mode
            )

        # Warmup
        self._warmup()

    def _warmup(self):
        """Warmup the pipeline to trigger compilation."""
        _ = self.pipe(
            "warmup", num_inference_steps=2, output_type="latent"
        )
        torch.cuda.synchronize()

    def generate(self, prompt: str, **kwargs):
        """Generate image with all optimizations."""
        return self.pipe(prompt, **kwargs).images[0]

    def benchmark(self, prompt: str, num_runs: int = 10):
        """Benchmark inference latency."""
        import time

        # Warmup
        for _ in range(3):
            self.generate(prompt, num_inference_steps=4)

        # Benchmark
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(num_runs):
            self.generate(prompt, num_inference_steps=20)
        torch.cuda.synchronize()

        elapsed = (time.time() - start) / num_runs
        return {
            "latency_ms": elapsed * 1000,
            "memory_gb": torch.cuda.max_memory_allocated() / 1e9,
        }

# Usage
pipeline = OptimizedDiffusionPipeline(
    "stabilityai/stable-diffusion-xl-base-1.0",
    precision="fp16",
    use_xformers=True,
    compile_mode="reduce-overhead",
)

# Benchmark
results = pipeline.benchmark("A cat sitting on a couch")
print(f"Latency: {results['latency_ms']:.1f}ms")
print(f"Memory: {results['memory_gb']:.2f}GB")
```
Memory Optimization Tips
```python
from diffusers import StableDiffusionXLPipeline
import torch

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    variant="fp16",
)

# 1. Enable sequential CPU offloading (saves the most memory, slowest)
pipe.enable_sequential_cpu_offload()

# 2. Or model CPU offloading (less memory savings, faster)
# pipe.enable_model_cpu_offload()

# 3. Enable VAE slicing for high-res images
pipe.enable_vae_slicing()

# 4. Enable VAE tiling for very high-res
pipe.enable_vae_tiling()

# 5. Use attention slicing
pipe.enable_attention_slicing(slice_size="auto")

# Now generation fits in ~4 GB of VRAM (slowly, due to offloading)
image = pipe(
    "A beautiful landscape",
    num_inference_steps=30,
).images[0]
```
Summary
Quantization and efficiency techniques are essential for deploying diffusion models in production:
- FP16/BF16: Simple 2x memory reduction with minimal quality loss - use as baseline for all deployments
- INT8 Quantization: 4x compression from FP32, requires calibration but maintains good quality
- INT4 Quantization: Extreme compression for memory-constrained devices, some quality trade-off
- Model Pruning: Remove unnecessary weights/heads, requires fine-tuning to recover quality
- ONNX/TensorRT: Platform-specific optimization for production deployment
Production Recommendation: Start with FP16 + xformers + torch.compile for the best latency-memory trade-off. Move to INT8 or INT4 only if memory constraints require it.
Looking Ahead: In the next section, we'll explore how to serve these optimized models at scale using TorchServe, Triton, and cloud infrastructure.