Learning Objectives
By the end of this section, you will be able to:
- Design production-ready APIs for diffusion model inference with proper request handling
- Deploy models using TorchServe and Triton with optimized configurations
- Implement dynamic batching to maximize GPU utilization
- Scale horizontally using Kubernetes and GPU orchestration
- Monitor and optimize production diffusion model services
Deployment Challenges
Serving diffusion models in production presents unique challenges compared to traditional ML models:
| Challenge | Description | Impact |
|---|---|---|
| Long inference time | 2-30 seconds per image | Requires async processing |
| High GPU memory | 6-14 GB per model | Limited concurrent requests |
| Variable workloads | Bursty traffic patterns | Autoscaling complexity |
| Large model files | 2-10 GB per model | Slow cold starts |
| User expectations | Real-time feedback | Progress streaming needed |
Key Insight: Unlike classification models that return in milliseconds, diffusion models require fundamentally different serving patterns - async processing, progress updates, and careful resource management.
Serving Frameworks
TorchServe
TorchServe is PyTorch's native serving solution, offering easy integration with PyTorch models and the diffusers library.
# diffusion_handler.py - TorchServe Handler
import torch
from ts.torch_handler.base_handler import BaseHandler
from diffusers import StableDiffusionXLPipeline, LCMScheduler
import io
import base64
import json
from PIL import Image


class DiffusionHandler(BaseHandler):
    """TorchServe handler serving a Stable Diffusion XL pipeline.

    Lifecycle: ``initialize`` loads and warms up the pipeline once per
    worker process; each request then flows through
    ``preprocess`` -> ``inference`` -> ``postprocess``.
    """

    def __init__(self):
        super().__init__()
        self.pipe = None  # populated in initialize()

    def initialize(self, context):
        """Load the model once at worker startup."""
        self.manifest = context.manifest
        model_dir = context.system_properties.get("model_dir")

        # Load the pipeline in half precision to fit GPU memory.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_dir,
            torch_dtype=torch.float16,
            use_safetensors=True,
        )
        self.pipe.to("cuda")

        # Apply memory-efficient attention (xFormers).
        self.pipe.enable_xformers_memory_efficient_attention()

        # Optional: Use LCM for faster inference
        # self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)

        # Warmup pass so the first real request does not pay kernel
        # compilation / allocation cost.
        self.pipe("warmup", num_inference_steps=2, output_type="latent")
        torch.cuda.synchronize()

    def preprocess(self, requests):
        """Parse incoming requests into a list of parameter dicts."""
        inputs = []
        for req in requests:
            data = req.get("data") or req.get("body")
            if isinstance(data, (bytes, bytearray)):
                data = data.decode("utf-8")
            if isinstance(data, str):
                # json is imported at module level; no per-request import.
                data = json.loads(data)
            inputs.append(data)
        return inputs

    def inference(self, inputs):
        """Run diffusion inference for each parsed request."""
        results = []
        for input_data in inputs:
            prompt = input_data.get("prompt", "")
            negative_prompt = input_data.get("negative_prompt", "")
            num_steps = input_data.get("num_inference_steps", 20)
            guidance_scale = input_data.get("guidance_scale", 7.5)
            seed = input_data.get("seed", None)

            # A seeded generator makes output reproducible for the same inputs.
            generator = None
            if seed is not None:
                generator = torch.Generator("cuda").manual_seed(seed)

            with torch.inference_mode():
                image = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=num_steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                ).images[0]

            results.append(image)

        return results

    def postprocess(self, outputs):
        """Encode generated PIL images as base64 PNG payloads."""
        results = []
        for image in outputs:
            buffer = io.BytesIO()
            image.save(buffer, format="PNG")
            img_base64 = base64.b64encode(buffer.getvalue()).decode()
            results.append({"image": img_base64})
        return results

# Package and deploy with TorchServe
# 1. Create model archive
torch-model-archiver \
    --model-name stable-diffusion-xl \
    --version 1.0 \
    --handler diffusion_handler.py \
    --extra-files "model_index.json,scheduler/*,tokenizer/*,unet/*,vae/*,text_encoder/*" \
    --export-path model_store

# 2. Start TorchServe
torchserve \
    --start \
    --model-store model_store \
    --models stable-diffusion-xl=stable-diffusion-xl.mar \
    --ts-config config.properties

# 3. Test endpoint
curl -X POST http://localhost:8080/predictions/stable-diffusion-xl \
    -H "Content-Type: application/json" \
    -d '{"prompt": "A beautiful sunset over mountains"}'

NVIDIA Triton Inference Server
Triton offers advanced features like dynamic batching, model ensembles, and multi-framework support. It's ideal for high-throughput production deployments.
# model.py - Triton Python Backend
import triton_python_backend_utils as pb_utils
import torch
import numpy as np
from diffusers import StableDiffusionXLPipeline
import json


class TritonPythonModel:
    """Triton Python-backend model wrapping a Stable Diffusion XL pipeline."""

    def initialize(self, args):
        """Load the model during server startup."""
        self.model_config = json.loads(args["model_config"])

        # Half precision to fit GPU memory.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            "/models/sdxl-base",
            torch_dtype=torch.float16,
        )
        self.pipe.to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()

        # Warmup so the first request does not pay kernel-compilation cost.
        self.pipe("warmup", num_inference_steps=1, output_type="latent")

    def execute(self, requests):
        """Process inference requests (one response per request)."""
        responses = []

        for request in requests:
            prompt_tensor = pb_utils.get_input_tensor_by_name(request, "prompt")
            prompt = prompt_tensor.as_numpy()[0].decode("utf-8")

            # "steps" is declared optional in config.pbtxt; the handle is
            # None when absent, so test for None explicitly rather than
            # relying on object truthiness.
            steps_tensor = pb_utils.get_input_tensor_by_name(request, "steps")
            steps = int(steps_tensor.as_numpy()[0]) if steps_tensor is not None else 20

            with torch.inference_mode():
                image = self.pipe(
                    prompt=prompt,
                    num_inference_steps=steps,
                ).images[0]

            # Convert PIL image to a uint8 HWC array for the output tensor.
            image_np = np.array(image)
            output_tensor = pb_utils.Tensor("image", image_np)

            responses.append(
                pb_utils.InferenceResponse(output_tensors=[output_tensor])
            )

        return responses

    def finalize(self):
        """Release GPU memory on server shutdown."""
        del self.pipe
        torch.cuda.empty_cache()

# config.pbtxt - Triton Model Configuration
name: "stable_diffusion"
backend: "python"

# Dynamic batching is only honored when max_batch_size > 0; with the
# default of 0 Triton rejects the dynamic_batching settings below.
max_batch_size: 4

input [
  {
    name: "prompt"
    data_type: TYPE_STRING
    dims: [1]
  },
  {
    name: "steps"
    data_type: TYPE_INT32
    dims: [1]
    optional: true
  }
]

output [
  {
    name: "image"
    data_type: TYPE_UINT8
    dims: [1024, 1024, 3]
  }
]

instance_group [
  {
    count: 1
    kind: KIND_GPU
    gpus: [0]
  }
]

# Enable dynamic batching
dynamic_batching {
  preferred_batch_size: [1, 2, 4]
  max_queue_delay_microseconds: 100000
}
API Design
A production API for diffusion models needs to handle async processing, progress updates, and result retrieval. Here's a robust design:
from fastapi import FastAPI, BackgroundTasks, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
import uuid
import asyncio
from enum import Enum
import redis
import json

app = FastAPI(title="Diffusion API")
redis_client = redis.Redis(host="localhost", port=6379)


class JobStatus(str, Enum):
    """Lifecycle states of a generation job."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class GenerationRequest(BaseModel):
    """Validated parameters for a single image generation."""

    prompt: str = Field(..., min_length=1, max_length=2000)
    negative_prompt: Optional[str] = ""
    num_inference_steps: int = Field(default=20, ge=1, le=100)
    guidance_scale: float = Field(default=7.5, ge=1.0, le=20.0)
    width: int = Field(default=1024, ge=512, le=2048)
    height: int = Field(default=1024, ge=512, le=2048)
    seed: Optional[int] = None


class GenerationResponse(BaseModel):
    """Returned on job submission: id plus queue/ETA hints."""

    job_id: str
    status: JobStatus
    queue_position: Optional[int] = None
    estimated_time: Optional[float] = None


class JobResult(BaseModel):
    """Returned when polling a job: status, progress, and final result."""

    job_id: str
    status: JobStatus
    progress: Optional[float] = None
    image_url: Optional[str] = None
    error: Optional[str] = None
40
# Endpoints
@app.post("/generate", response_model=GenerationResponse)
async def create_generation(
    request: GenerationRequest,
    background_tasks: BackgroundTasks,
):
    """Submit a new generation job and return its id and queue position."""
    job_id = str(uuid.uuid4())

    # Persist job state with a 1-hour TTL so abandoned jobs expire.
    job_data = {
        "status": JobStatus.PENDING,
        "request": request.dict(),
        "progress": 0,
    }
    redis_client.setex(f"job:{job_id}", 3600, json.dumps(job_data))

    # Add to processing queue; llen after rpush includes this job,
    # so queue_position is 1-based.
    redis_client.rpush("generation_queue", job_id)
    queue_position = redis_client.llen("generation_queue")

    # Rough ETA: ~0.1 s per inference step, scaled by queue depth.
    # (The original comment claimed 2 s/step, contradicting the code.)
    estimated_time = request.num_inference_steps * 0.1 * queue_position

    return GenerationResponse(
        job_id=job_id,
        status=JobStatus.PENDING,
        queue_position=queue_position,
        estimated_time=estimated_time,
    )
71
@app.get("/jobs/{job_id}", response_model=JobResult)
async def get_job_status(job_id: str):
    """Check job status and, when finished, return the result or error."""
    job_data = redis_client.get(f"job:{job_id}")

    if not job_data:
        # Either the job never existed or its Redis TTL expired.
        raise HTTPException(status_code=404, detail="Job not found")

    job = json.loads(job_data)

    return JobResult(
        job_id=job_id,
        status=job["status"],
        progress=job.get("progress"),
        image_url=job.get("image_url"),
        error=job.get("error"),
    )
89
@app.delete("/jobs/{job_id}")
async def cancel_job(job_id: str):
    """Cancel a job that has not started processing yet."""
    job_data = redis_client.get(f"job:{job_id}")

    if not job_data:
        raise HTTPException(status_code=404, detail="Job not found")

    job = json.loads(job_data)

    # Jobs already running on a GPU cannot be aborted mid-generation.
    if job["status"] == JobStatus.PROCESSING:
        raise HTTPException(status_code=400, detail="Cannot cancel processing job")

    # Remove from the queue BEFORE deleting the record: the original order
    # let a worker pop the id and then fail to find the job data.
    redis_client.lrem("generation_queue", 1, job_id)
    redis_client.delete(f"job:{job_id}")

    return {"message": "Job cancelled"}
107
108# WebSocket for real-time progress
109from fastapi import WebSocket
110
111@app.websocket("/ws/jobs/{job_id}")
112async def job_progress_websocket(websocket: WebSocket, job_id: str):
113 """Stream job progress updates."""
114 await websocket.accept()
115
116 try:
117 while True:
118 job_data = redis_client.get(f"job:{job_id}")
119 if not job_data:
120 await websocket.close(code=1000, reason="Job not found")
121 break
122
123 job = json.loads(job_data)
124 await websocket.send_json({
125 "status": job["status"],
126 "progress": job.get("progress", 0),
127 })
128
129 if job["status"] in [JobStatus.COMPLETED, JobStatus.FAILED]:
130 break
131
132 await asyncio.sleep(0.5)
133
134 except Exception:
135 passWorker Process
# worker.py - Background worker for processing jobs
import io  # was missing: io.BytesIO is used in process_jobs below
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.callbacks import PipelineCallback
import redis
import json
import boto3  # For S3 upload

redis_client = redis.Redis(host="localhost", port=6379)
s3_client = boto3.client("s3")


class ProgressCallback(PipelineCallback):
    """Per-step diffusion callback that mirrors progress into Redis."""

    def __init__(self, job_id, redis_client):
        self.job_id = job_id
        self.redis = redis_client

    def __call__(self, pipe, step, timestep, callback_kwargs):
        # NOTE(review): assumes the pipeline exposes num_inference_steps;
        # confirm against the installed diffusers version.
        total_steps = pipe.num_inference_steps
        progress = (step + 1) / total_steps

        # Read-modify-write the job record and refresh its 1-hour TTL.
        # Guard against the record having expired or been cancelled, which
        # previously crashed on json.loads(None).
        raw = self.redis.get(f"job:{self.job_id}")
        if raw is not None:
            job_data = json.loads(raw)
            job_data["progress"] = progress
            self.redis.setex(f"job:{self.job_id}", 3600, json.dumps(job_data))

        return callback_kwargs
27
28def process_jobs():
29 """Main worker loop."""
30 # Load model once
31 pipe = StableDiffusionXLPipeline.from_pretrained(
32 "stabilityai/stable-diffusion-xl-base-1.0",
33 torch_dtype=torch.float16,
34 )
35 pipe.to("cuda")
36 pipe.enable_xformers_memory_efficient_attention()
37
38 print("Worker ready, waiting for jobs...")
39
40 while True:
41 # Block until job available
42 _, job_id = redis_client.blpop("generation_queue", timeout=0)
43 job_id = job_id.decode("utf-8")
44
45 try:
46 # Get job data
47 job_data = json.loads(redis_client.get(f"job:{job_id}"))
48 request = job_data["request"]
49
50 # Update status
51 job_data["status"] = "processing"
52 redis_client.setex(f"job:{job_id}", 3600, json.dumps(job_data))
53
54 # Generate image
55 generator = None
56 if request.get("seed"):
57 generator = torch.Generator("cuda").manual_seed(request["seed"])
58
59 callback = ProgressCallback(job_id, redis_client)
60
61 image = pipe(
62 prompt=request["prompt"],
63 negative_prompt=request.get("negative_prompt", ""),
64 num_inference_steps=request["num_inference_steps"],
65 guidance_scale=request["guidance_scale"],
66 width=request["width"],
67 height=request["height"],
68 generator=generator,
69 callback_on_step_end=callback,
70 ).images[0]
71
72 # Upload to S3
73 buffer = io.BytesIO()
74 image.save(buffer, format="PNG")
75 buffer.seek(0)
76
77 image_key = f"generations/{job_id}.png"
78 s3_client.upload_fileobj(
79 buffer,
80 "diffusion-outputs",
81 image_key,
82 ExtraArgs={"ContentType": "image/png"},
83 )
84
85 # Update job with result
86 job_data["status"] = "completed"
87 job_data["progress"] = 1.0
88 job_data["image_url"] = f"https://cdn.example.com/{image_key}"
89 redis_client.setex(f"job:{job_id}", 86400, json.dumps(job_data))
90
91 except Exception as e:
92 # Handle failure
93 job_data["status"] = "failed"
94 job_data["error"] = str(e)
95 redis_client.setex(f"job:{job_id}", 3600, json.dumps(job_data))
96
97if __name__ == "__main__":
98 process_jobs()Batching Strategies
Batching multiple requests together significantly improves GPU utilization and throughput. However, diffusion models present unique batching challenges.
Dynamic Batching
import asyncio
import uuid  # was missing: DynamicBatcher.submit uses uuid.uuid4()
from dataclasses import dataclass
from typing import List, Dict, Any
import torch
from collections import deque
import time


@dataclass
class BatchRequest:
    """A queued generation request awaiting batched execution."""

    request_id: str
    prompt: str
    num_steps: int
    future: asyncio.Future  # resolved with the generation result
14
class DynamicBatcher:
    """
    Collects requests and processes them in batches.

    Optimizes GPU utilization while keeping latency reasonable: requests
    are queued, and a background task drains the queue in batches of up to
    ``max_batch_size``, waiting at most ``max_wait_time`` seconds for a
    batch to fill.
    """

    def __init__(
        self,
        pipe,
        max_batch_size: int = 4,
        max_wait_time: float = 0.5,  # seconds to wait for a full batch
        same_steps_only: bool = True,  # batch only same-step requests
    ):
        self.pipe = pipe
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.same_steps_only = same_steps_only

        self.request_queue: deque = deque()
        self.processing = False
        self._lock = asyncio.Lock()

    async def submit(self, prompt: str, num_steps: int = 20) -> Dict[str, Any]:
        """Submit a request and wait for its result."""
        import uuid  # local import: uuid is missing from this file's header

        future = asyncio.get_running_loop().create_future()
        request = BatchRequest(
            request_id=str(uuid.uuid4()),
            prompt=prompt,
            num_steps=num_steps,
            future=future,
        )

        async with self._lock:
            self.request_queue.append(request)
            # Flip the flag under the lock so two concurrent submits cannot
            # both spawn a drain task (race in the original version).
            if not self.processing:
                self.processing = True
                asyncio.create_task(self._process_batches())

        return await future

    async def _process_batches(self):
        """Drain the queue until it stays empty for a full wait window."""
        while True:
            batch = await self._collect_batch()

            if not batch:
                # Queue stayed empty; release the flag so the next submit()
                # starts a fresh drain task.
                self.processing = False
                break

            try:
                results = await self._run_batch(batch)

                # Deliver results in submission order.
                for request, result in zip(batch, results):
                    request.future.set_result(result)

            except Exception as e:
                # One failed batch fails all of its requests.
                for request in batch:
                    request.future.set_exception(e)

    async def _collect_batch(self) -> List[BatchRequest]:
        """Collect compatible requests, bounded by max_wait_time."""
        batch: List[BatchRequest] = []
        start_time = time.time()
        target_steps = None

        while len(batch) < self.max_batch_size:
            elapsed = time.time() - start_time

            # Time out even when the batch is empty -- the original only
            # broke on a non-empty batch, so an idle queue busy-spun this
            # loop forever.
            if elapsed >= self.max_wait_time:
                break

            await asyncio.sleep(0.01)

            if self.request_queue:
                request = self.request_queue.popleft()

                # One pipeline call uses a single num_inference_steps, so
                # defer requests with a different step count.
                if (
                    self.same_steps_only
                    and target_steps is not None
                    and request.num_steps != target_steps
                ):
                    self.request_queue.appendleft(request)
                    continue

                batch.append(request)
                target_steps = request.num_steps

        return batch

    async def _run_batch(self, batch: List[BatchRequest]) -> List[Dict]:
        """Run one pipeline call for the whole batch off the event loop."""
        prompts = [r.prompt for r in batch]
        num_steps = batch[0].num_steps

        # Run in the default executor so the blocking GPU call does not
        # stall the event loop.
        loop = asyncio.get_running_loop()
        images = await loop.run_in_executor(
            None,
            lambda: self.pipe(
                prompt=prompts,
                num_inference_steps=num_steps,
            ).images,
        )

        return [{"image": img} for img in images]
131
132# Usage
133batcher = DynamicBatcher(pipe, max_batch_size=4, max_wait_time=0.5)
134
135# In your API endpoint
136async def generate(prompt: str):
137 return await batcher.submit(prompt, num_steps=20)Batching Considerations
| Factor | Impact | Recommendation |
|---|---|---|
| Different step counts | Cannot batch together efficiently | Normalize to standard step counts (4, 8, 20, 50) |
| Different resolutions | Requires padding or separate batches | Offer fixed resolution tiers |
| Memory constraints | Larger batches need more VRAM | Profile and set max based on GPU |
| Latency requirements | Waiting for batch increases latency | Tune max_wait_time based on SLOs |
Scaling Infrastructure
For production workloads, you need to scale horizontally across multiple GPUs and nodes. Kubernetes is the standard platform for this.
Kubernetes Deployment
# deployment.yaml
apiVersion: apps/v1
kind: Deployment
metadata:
  name: diffusion-worker
  labels:
    app: diffusion
spec:
  replicas: 3
  selector:
    matchLabels:
      app: diffusion
  template:
    metadata:
      labels:
        app: diffusion
    spec:
      containers:
      - name: worker
        image: your-registry/diffusion-worker:latest
        resources:
          limits:
            nvidia.com/gpu: 1
            memory: "32Gi"
            cpu: "8"
          requests:
            nvidia.com/gpu: 1
            memory: "24Gi"
            cpu: "4"
        env:
        - name: MODEL_PATH
          value: "/models/sdxl-base"
        - name: REDIS_URL
          value: "redis://redis-master:6379"
        volumeMounts:
        - name: model-cache
          mountPath: /models
        - name: shm
          mountPath: /dev/shm
      volumes:
      - name: model-cache
        persistentVolumeClaim:
          claimName: model-pvc
      - name: shm
        emptyDir:
          medium: Memory
          sizeLimit: "16Gi"
      nodeSelector:
        nvidia.com/gpu.product: "NVIDIA-A100-SXM4-80GB"
      tolerations:
      - key: "nvidia.com/gpu"
        operator: "Exists"
        effect: "NoSchedule"
---
# Horizontal Pod Autoscaler
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: diffusion-worker-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: diffusion-worker
  minReplicas: 2
  maxReplicas: 10
  metrics:
  - type: External
    external:
      metric:
        name: redis_queue_length
        selector:
          matchLabels:
            queue: generation_queue
      target:
        type: AverageValue
        averageValue: "5"
  behavior:
    scaleUp:
      stabilizationWindowSeconds: 60
      policies:
      - type: Pods
        value: 2
        periodSeconds: 60
    scaleDown:
      stabilizationWindowSeconds: 300
      policies:
      - type: Pods
        value: 1
        periodSeconds: 120

KEDA for Queue-Based Scaling
# keda-scaler.yaml
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
  name: diffusion-worker-scaler
spec:
  scaleTargetRef:
    name: diffusion-worker
  minReplicaCount: 1
  maxReplicaCount: 20
  pollingInterval: 15
  cooldownPeriod: 300
  triggers:
  - type: redis
    metadata:
      address: redis-master:6379
      listName: generation_queue
      listLength: "3"  # Scale up when 3+ jobs per worker
  - type: prometheus
    metadata:
      serverAddress: http://prometheus:9090
      metricName: gpu_utilization
      threshold: "70"
      query: |
        avg(nvidia_gpu_utilization{job="diffusion-worker"})

GPU Scheduling Considerations
- GPU Provisioning Time: New GPU nodes take 2-5 minutes to provision in cloud environments. Use predictive scaling.
- Spot/Preemptible Instances: Can reduce costs by 60-70% but require handling interruptions gracefully.
- Multi-GPU Nodes: A100 nodes often have 8 GPUs. Run multiple workers per node.
- Model Caching: Use shared volumes or object storage to avoid downloading models on each scale-up.
Monitoring and Observability
Comprehensive monitoring is essential for maintaining SLOs and optimizing costs.
# metrics.py - Prometheus metrics for diffusion service
from prometheus_client import Counter, Histogram, Gauge, start_http_server
import time

# Request metrics
REQUESTS_TOTAL = Counter(
    "diffusion_requests_total",
    "Total generation requests",
    ["status", "model"],
)

GENERATION_DURATION = Histogram(
    "diffusion_generation_duration_seconds",
    "Time to generate image",
    ["model", "steps"],
    # Buckets span fast few-step generations (~1-2 s) up to slow SDXL runs.
    buckets=[1, 2, 5, 10, 20, 30, 60, 120],
)

QUEUE_SIZE = Gauge(
    "diffusion_queue_size",
    "Number of jobs in queue",
)

# GPU metrics
GPU_MEMORY_USED = Gauge(
    "gpu_memory_used_bytes",
    "GPU memory in use",
    ["gpu_id"],
)

GPU_UTILIZATION = Gauge(
    "gpu_utilization_percent",
    "GPU compute utilization",
    ["gpu_id"],
)

# Batch metrics
BATCH_SIZE = Histogram(
    "diffusion_batch_size",
    "Number of requests per batch",
    buckets=[1, 2, 4, 8, 16],
)
43
def record_generation(model: str, steps: int, duration: float, success: bool):
    """Record one generation's outcome and (on success) its latency."""
    status = "success" if success else "error"
    REQUESTS_TOTAL.labels(status=status, model=model).inc()

    if success:
        # Failed runs are excluded so errors do not skew latency percentiles.
        GENERATION_DURATION.labels(model=model, steps=str(steps)).observe(duration)
51
52def update_gpu_metrics():
53 """Update GPU metrics (call periodically)."""
54 import pynvml
55 pynvml.nvmlInit()
56
57 device_count = pynvml.nvmlDeviceGetCount()
58 for i in range(device_count):
59 handle = pynvml.nvmlDeviceGetHandleByIndex(i)
60
61 # Memory
62 mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle)
63 GPU_MEMORY_USED.labels(gpu_id=str(i)).set(mem_info.used)
64
65 # Utilization
66 util = pynvml.nvmlDeviceGetUtilizationRates(handle)
67 GPU_UTILIZATION.labels(gpu_id=str(i)).set(util.gpu)
68
69# Start metrics server
70start_http_server(8000)Key Metrics to Monitor
| Metric | Target | Alert Threshold |
|---|---|---|
| p99 Latency | < 30s | > 60s for 5 min |
| GPU Utilization | 70-85% | < 50% or > 95% |
| Queue Depth | < 10 per worker | > 50 total |
| Error Rate | < 1% | > 5% |
| Memory Usage | < 90% | > 95% |
# Grafana dashboard query examples

# Average generation time by model (histogram sum / count; the original
# mixed avg() with a dangling trailing "by (model)", which is invalid)
sum(rate(diffusion_generation_duration_seconds_sum[5m])) by (model)
  / sum(rate(diffusion_generation_duration_seconds_count[5m])) by (model)

# Queue depth over time
diffusion_queue_size

# GPU utilization heatmap
avg(gpu_utilization_percent) by (gpu_id)

# Request rate
rate(diffusion_requests_total[1m])

# Error rate percentage
sum(rate(diffusion_requests_total{status="error"}[5m]))
  / sum(rate(diffusion_requests_total[5m])) * 100
GPU costs dominate diffusion model serving expenses. Here are strategies to optimize:
Cost Breakdown
| GPU Type | On-Demand/hr | Spot/hr | Images/hr (20 steps) | Cost/1000 Images |
|---|---|---|---|---|
| A10G | $1.00 | $0.35 | ~600 | $1.67 (on-demand) |
| A100 40GB | $3.50 | $1.20 | ~1200 | $2.92 (on-demand) |
| A100 80GB | $4.50 | $1.50 | ~1400 | $3.21 (on-demand) |
| H100 | $8.00 | $2.80 | ~2500 | $3.20 (on-demand) |
Optimization Strategies
- Use Spot Instances: 60-70% cost reduction. Handle interruptions with checkpointing and queue-based processing.
- Right-size GPUs: Match GPU to workload. A10G is often sufficient for SD 1.5/2.x; A100 needed for SDXL.
- Maximize Utilization: Target 70-85% GPU utilization. Use batching and queue management.
- Reduce Steps: Use LCM/Turbo models for 4-step generation (10x throughput increase).
- Time-based Scaling: Scale down during off-peak hours based on historical patterns.
# Cost-aware scheduler
class CostAwareScheduler:
    """Routes generation jobs to spot or on-demand worker pools by cost."""

    def __init__(self):
        # Spot workers are cheaper but may be interrupted at any time;
        # on-demand workers are reliable but more expensive per hour.
        self.spot_workers = []
        self.on_demand_workers = []

    def schedule(self, job, priority: str = "normal"):
        """Return the worker pool a job should run on.

        High-priority or deadline-sensitive jobs get on-demand capacity;
        everything else prefers spot capacity when any is available.
        """
        needs_reliability = priority == "high" or job.deadline_soon()
        if needs_reliability:
            return self.on_demand_workers
        return self.spot_workers if self.spot_workers else self.on_demand_workers

    def handle_spot_interruption(self, worker_id: str):
        """React to a spot-instance preemption notice for one worker."""
        # Re-queue whatever the worker was running so the job is not lost.
        current_job = self.get_worker_job(worker_id)
        if current_job:
            self.queue.push(current_job, priority="high")

        # Drop the preempted worker from the spot pool.
        self.spot_workers.remove(worker_id)
Serving diffusion models in production requires careful architecture and operational practices:
- Async Processing: Use queue-based architectures for long-running generation tasks
- Framework Choice: TorchServe for simplicity, Triton for high-throughput production
- API Design: Implement job submission, status polling, and WebSocket progress updates
- Dynamic Batching: Maximize GPU utilization by batching compatible requests
- Kubernetes Scaling: Use HPA or KEDA for queue-based auto-scaling
- Comprehensive Monitoring: Track latency, GPU utilization, queue depth, and error rates
- Cost Optimization: Use spot instances, right-sized GPUs, and distilled models
Looking Ahead: In the next chapter, we'll explore the future of diffusion models, including video generation, 3D generation, and other cutting-edge applications.