## Introduction
Modern image generation increasingly demands precise control over multiple aspects simultaneously: text describes content, images specify style, poses guide composition, and depth maps constrain spatial structure. Multi-modal conditioning addresses the challenge of combining these diverse signals into coherent generation guidance.
This section explores how different conditioning modalities can be combined, weighted, and scheduled throughout the diffusion process. We'll examine both the theoretical framework for multi-modal fusion and practical implementation strategies that enable fine-grained control over every aspect of generation.
## The Multi-Modal Challenge
Combining multiple conditioning signals is not straightforward. Different modalities may conflict, compete for influence, or interact in unexpected ways. Consider these challenges:
| Challenge | Example | Consequence |
|---|---|---|
| Semantic Conflict | Text: "cat" + Image: dog photo | Model must decide which to prioritize |
| Scale Mismatch | Strong ControlNet + weak text | Structure dominates, content ignored |
| Temporal Interference | Style applied too early/late | Style affects structure instead of texture |
| Compositional Ambiguity | Multiple subjects in different poses | Unclear which pose maps to which subject |
| Representation Gap | CLIP text vs. face embeddings | Different embedding spaces, different semantics |
The fundamental question is: how do we combine conditioning signals from different modalities in a way that respects the user's intent while producing coherent outputs?
Design Principle: Multi-modal conditioning should be compositional (combining independent signals), controllable (adjustable influence per modality), and predictable (consistent behavior across generations).
## Conditioning Modalities
Before discussing combination strategies, let's understand the characteristics of each major conditioning modality and how they influence the generation process.
### Text Conditioning
Text provides semantic guidance that shapes content, style, and abstract concepts. It operates primarily through cross-attention, allowing spatial features to query relevant text tokens:
```python
from typing import Dict, Optional

import torch
from transformers import (
    CLIPTextModel,
    CLIPTokenizer,
    T5EncoderModel,
    T5Tokenizer,
)


class TextConditioner:
    """
    Text conditioning via CLIP or T5 embeddings.
    Primary semantic signal for content and style.
    """
    def __init__(self, encoder_type: str = "clip"):
        if encoder_type == "clip":
            self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
            self.encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
            self.max_length = 77
        else:  # T5
            self.tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
            self.encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")
            self.max_length = 256

    def encode(self, text: str) -> Dict[str, Optional[torch.Tensor]]:
        """Encode text to conditioning embeddings."""
        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

        with torch.no_grad():
            outputs = self.encoder(**tokens)

        return {
            "embeddings": outputs.last_hidden_state,  # [1, seq_len, dim]
            "pooled": outputs.pooler_output if hasattr(outputs, "pooler_output") else None,
            "attention_mask": tokens.attention_mask
        }

    @property
    def influence_type(self) -> str:
        return "semantic"  # Affects content and style

    @property
    def injection_method(self) -> str:
        return "cross_attention"  # Through cross-attention layers
```

Text conditioning is global in nature: it influences the entire image rather than specific regions (unless attention-control techniques are used). This makes it ideal for specifying overall content and style but less precise for spatial details.
### Image Conditioning
Image conditioning (via IP-Adapter, style references, or image variations) provides visual reference signals that encode appearance, texture, and style in ways that text cannot capture:
```python
from typing import Dict

import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModel


class ImageConditioner:
    """
    Image conditioning via CLIP vision or specialized encoders.
    Provides style, texture, and appearance guidance.
    """
    def __init__(self, encoder_type: str = "clip_vision"):
        self.encoder_type = encoder_type

        if encoder_type == "clip_vision":
            self.encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
            self.processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
        elif encoder_type == "face_id":
            self.encoder = load_insightface_model()  # helper wrapping an InsightFace model
        elif encoder_type == "dino":
            self.encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14")

    def encode(self, image: PIL.Image.Image) -> Dict[str, torch.Tensor]:
        """Encode image to conditioning embeddings."""
        if self.encoder_type == "clip_vision":
            inputs = self.processor(images=image, return_tensors="pt")
            with torch.no_grad():
                outputs = self.encoder(**inputs)
            return {
                "patch_embeddings": outputs.last_hidden_state[:, 1:],  # Spatial (drops CLS token)
                "global_embedding": outputs.pooler_output,  # Global
            }
        elif self.encoder_type == "face_id":
            face_emb = self.encoder.get_embedding(image)
            return {"face_embedding": face_emb}

    @property
    def influence_type(self) -> str:
        return "appearance"  # Affects visual style and texture

    @property
    def injection_method(self) -> str:
        return "decoupled_cross_attention"  # Parallel to text attention
```

Different image encoders capture different aspects: CLIP captures semantic content, DINO captures visual structure, and face encoders capture identity. Choosing the right encoder depends on which aspect of the reference image should transfer.
### Spatial Conditioning
Spatial conditioning (ControlNet, T2I-Adapter, depth, pose) provides structural guidance that constrains the geometric layout of the output:
```python
from typing import Dict, List, Tuple, Union

import PIL.Image
import torch
from diffusers import ControlNetModel
from torchvision.transforms.functional import to_tensor


class SpatialConditioner:
    """
    Spatial conditioning via control signals.
    Provides structural and geometric constraints.
    """
    def __init__(self, condition_type: str):
        self.condition_type = condition_type

        # Preprocessors for different condition types
        self.preprocessors = {
            "canny": CannyDetector(),
            "depth": DepthEstimator(),
            "pose": OpenPoseDetector(),
            "normal": NormalEstimator(),
            "segmentation": SegmentationModel(),
        }

        # ControlNet models for each type
        self.control_nets = {
            ctype: ControlNetModel.from_pretrained(f"controlnet-{ctype}")
            for ctype in self.preprocessors
        }

    def encode(self, image: PIL.Image.Image) -> Dict[str, Union[torch.Tensor, str]]:
        """Extract and encode spatial condition."""
        # Extract condition map
        condition_map = self.preprocessors[self.condition_type](image)

        # Normalize to [0, 1] and add a batch dimension
        condition_tensor = to_tensor(condition_map).unsqueeze(0)

        return {
            "condition_image": condition_tensor,
            "condition_type": self.condition_type,
        }

    def get_control_output(
        self,
        condition: torch.Tensor,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Get ControlNet residuals for U-Net injection."""
        controlnet = self.control_nets[self.condition_type]

        down_block_residuals, mid_block_residual = controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=condition,
            return_dict=False
        )

        return down_block_residuals, mid_block_residual

    @property
    def influence_type(self) -> str:
        return "structural"  # Affects layout and geometry

    @property
    def injection_method(self) -> str:
        return "residual_addition"  # Added to U-Net features
```

Spatial conditions are local: they provide pixel-aligned guidance that directly constrains output structure. This makes them powerful for precise control, but they require careful balancing to avoid suppressing other conditioning signals.
## Combining Multiple Conditions
With multiple conditioning modalities available, the key question becomes how to combine them effectively. Two primary strategies emerge: attention-based fusion and adapter composition.
### Attention-Based Fusion
When multiple modalities inject through cross-attention (text, image, audio), their outputs can be combined through weighted addition or concatenation:
```python
from typing import Dict

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiModalCrossAttention(nn.Module):
    """
    Cross-attention that combines multiple conditioning modalities.
    """
    def __init__(self, hidden_dim: int, num_heads: int = 8):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads

        # Shared query projection
        self.to_q = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Per-modality KV projections
        self.modality_kv = nn.ModuleDict({
            modality: nn.ModuleDict({
                "to_k": nn.Linear(hidden_dim, hidden_dim, bias=False),
                "to_v": nn.Linear(hidden_dim, hidden_dim, bias=False),
            })
            for modality in ("text", "image", "audio")
        })

        self.to_out = nn.Linear(hidden_dim, hidden_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_embeddings: Dict[str, torch.Tensor],
        modality_weights: Dict[str, float]
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, N, D] U-Net features
            modality_embeddings: {"text": [B, T, D], "image": [B, I, D], ...}
            modality_weights: {"text": 1.0, "image": 0.6, ...}
        """
        B, N, D = hidden_states.shape
        head_dim = D // self.num_heads
        q = self.to_q(hidden_states)

        # Compute attention for each modality
        modality_outputs = []
        for modality, embeddings in modality_embeddings.items():
            if modality not in self.modality_kv or embeddings is None:
                continue

            weight = modality_weights.get(modality, 1.0)
            if weight == 0:
                continue

            k = self.modality_kv[modality]["to_k"](embeddings)
            v = self.modality_kv[modality]["to_v"](embeddings)

            # Scaled dot-product attention
            attn_output = F.scaled_dot_product_attention(
                q.view(B, N, self.num_heads, head_dim).transpose(1, 2),
                k.view(B, -1, self.num_heads, head_dim).transpose(1, 2),
                v.view(B, -1, self.num_heads, head_dim).transpose(1, 2)
            ).transpose(1, 2).reshape(B, N, D)

            modality_outputs.append(weight * attn_output)

        # Combine modality outputs (guard against all modalities being disabled)
        if not modality_outputs:
            return torch.zeros_like(hidden_states)
        combined = sum(modality_outputs) / len(modality_outputs)

        return self.to_out(combined)
```

This approach allows each modality to contribute independently, with weights controlling relative influence. The attention mechanism naturally handles semantic alignment between modalities.
### Adapter Composition
When using adapters like ControlNet and IP-Adapter simultaneously, their outputs must be combined without interference:
```python
from typing import Dict, Optional

import PIL.Image
import torch
from diffusers import ControlNetModel, StableDiffusionPipeline


class MultiAdapterPipeline:
    """
    Pipeline that composes multiple conditioning adapters.
    """
    def __init__(
        self,
        pipe: StableDiffusionPipeline,
        controlnets: Dict[str, ControlNetModel],
        ip_adapter: Optional["IPAdapter"] = None
    ):
        self.pipe = pipe
        self.controlnets = controlnets
        self.ip_adapter = ip_adapter

    def __call__(
        self,
        prompt: str,
        control_images: Dict[str, PIL.Image.Image],
        control_scales: Dict[str, float],
        ip_image: Optional[PIL.Image.Image] = None,
        ip_scale: float = 0.6,
        **kwargs
    ):
        """
        Generate with multiple control signals.

        Args:
            prompt: Text prompt
            control_images: {"canny": img, "depth": img, ...}
            control_scales: {"canny": 0.8, "depth": 0.5, ...}
            ip_image: Reference image for IP-Adapter
            ip_scale: IP-Adapter influence weight
        """
        # Prepare text embeddings
        text_embeddings = self.pipe._encode_prompt(prompt, ...)

        # Prepare IP-Adapter embeddings
        if self.ip_adapter and ip_image:
            image_embeddings = self.ip_adapter.encode_image(ip_image)
            self.ip_adapter.set_scale(ip_scale)
        else:
            image_embeddings = None

        # Prepare control conditions
        control_tensors = {}
        for name, img in control_images.items():
            control_tensors[name] = self.preprocess_control(img, name)

        # Diffusion loop with multi-adapter injection
        latents = self.prepare_latents(...)

        for t in self.scheduler.timesteps:
            # Accumulate ControlNet residuals
            total_down_residuals = None
            total_mid_residual = None

            for name, condition in control_tensors.items():
                scale = control_scales.get(name, 1.0)
                if scale == 0:
                    continue

                down_res, mid_res = self.controlnets[name](
                    latents, t, text_embeddings, condition
                )

                # Scale and accumulate
                if total_down_residuals is None:
                    total_down_residuals = [r * scale for r in down_res]
                    total_mid_residual = mid_res * scale
                else:
                    for i, r in enumerate(down_res):
                        total_down_residuals[i] += r * scale
                    total_mid_residual += mid_res * scale

            # U-Net prediction with combined control
            noise_pred = self.pipe.unet(
                latents, t,
                encoder_hidden_states=text_embeddings,
                ip_adapter_image_embeds=image_embeddings,
                down_block_additional_residuals=total_down_residuals,
                mid_block_additional_residual=total_mid_residual
            ).sample

            # Scheduler step
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return self.decode_latents(latents)
```

Multi-Adapter Composition Strategy:
- ControlNet residuals are additive: multiple ControlNets can sum their structural guidance
- IP-Adapter operates through attention: independent of ControlNet pathway
- Per-adapter scaling: each adapter's influence can be tuned independently
- Accumulation vs. averaging: sum preserves all signals, averaging normalizes intensity
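The accumulation logic can be checked in isolation. A minimal sketch (plain Python lists standing in for residual tensors, with hypothetical values) contrasts the two combination policies from the last bullet:

```python
def combine_residuals(residuals, scales, mode="sum"):
    """Combine per-adapter residuals (one list of floats per adapter).

    mode="sum" preserves every signal's full contribution;
    mode="average" normalizes total intensity by the adapter count.
    """
    scaled = [
        [r * scales.get(name, 1.0) for r in res]
        for name, res in residuals.items()
        if scales.get(name, 1.0) != 0
    ]
    if not scaled:
        return None
    combined = [sum(vals) for vals in zip(*scaled)]
    if mode == "average":
        combined = [v / len(scaled) for v in combined]
    return combined

# Hypothetical per-layer residual magnitudes for two ControlNets
residuals = {"canny": [1.0, 2.0], "depth": [3.0, 4.0]}
scales = {"canny": 0.5, "depth": 1.0}

summed = combine_residuals(residuals, scales, mode="sum")
averaged = combine_residuals(residuals, scales, mode="average")
print(summed)    # [3.5, 5.0]
print(averaged)  # [1.75, 2.5]
```

Summation is what the pipeline above uses: each adapter's guidance lands at full scaled strength, at the cost of stronger total residuals as adapters are added.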
## Dynamic Conditioning Weights
Static weights for each modality work for simple cases, but sophisticated control often requires dynamic weighting that varies spatially or temporally.
### Per-Layer Weight Control
Different U-Net layers respond to conditioning differently. Early layers establish structure while later layers refine details. Per-layer control exploits this:
```python
from typing import List, Tuple

import torch
from diffusers import ControlNetModel


class LayerWiseControlNet:
    """
    ControlNet with per-layer weight scheduling.
    """
    def __init__(self, controlnet: ControlNetModel, num_down_blocks: int = 12):
        self.controlnet = controlnet
        self.num_down_blocks = num_down_blocks

        # Default: uniform weights
        self.layer_weights = [1.0] * num_down_blocks

    def set_layer_weights(self, weights: List[float]):
        """Set per-layer control weights."""
        assert len(weights) == self.num_down_blocks
        self.layer_weights = weights

    def set_layer_weights_from_schedule(self, schedule: str):
        """Set weights from a named schedule."""
        if schedule == "structure_only":
            # Early layers strong, late layers weak
            self.layer_weights = [1.0, 1.0, 1.0, 0.8, 0.6, 0.4, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0]
        elif schedule == "style_only":
            # Late layers strong, early layers weak
            self.layer_weights = [0.0, 0.0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0]
        elif schedule == "balanced":
            # Uniform across layers
            self.layer_weights = [1.0] * self.num_down_blocks
        elif schedule == "detail_focus":
            # Emphasize middle layers
            self.layer_weights = [0.3, 0.5, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.5, 0.3, 0.1]

    def __call__(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Get layer-weighted residuals."""
        down_residuals, mid_residual = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_cond,
            return_dict=False
        )

        # Apply per-layer weights
        weighted_residuals = [
            residual * weight
            for residual, weight in zip(down_residuals, self.layer_weights)
        ]

        return weighted_residuals, mid_residual
```

This enables fine control over how structural guidance propagates through the network. Use "structure_only" when you want the control signal to affect layout without influencing fine details.
### Timestep-Based Scheduling
Different timesteps in the diffusion process serve different purposes: early steps (high noise) establish structure, late steps (low noise) refine details. Scheduling conditioning weights across timesteps enables powerful compositional effects:
```python
import math
from typing import Callable, Dict


class TimestepConditionScheduler:
    """
    Schedule conditioning weights across diffusion timesteps.
    """
    def __init__(self, num_timesteps: int = 50):
        self.num_timesteps = num_timesteps

        # Per-modality schedules
        self.schedules: Dict[str, Callable[[float], float]] = {}

    def add_schedule(self, modality: str, schedule_fn: Callable[[float], float]):
        """
        Add a schedule function for a modality.

        Args:
            modality: Name of the modality (e.g., "controlnet", "ip_adapter")
            schedule_fn: Function that takes a normalized timestep in [0, 1] and returns a weight
        """
        self.schedules[modality] = schedule_fn

    def get_weight(self, modality: str, step: int) -> float:
        """Get weight for modality at the current step."""
        if modality not in self.schedules:
            return 1.0

        # Normalize step to [0, 1] where 0 = start, 1 = end
        t_normalized = step / self.num_timesteps
        return self.schedules[modality](t_normalized)

    @staticmethod
    def cosine_decay(start: float = 1.0, end: float = 0.0) -> Callable[[float], float]:
        """Smooth cosine interpolation from start to end (a ramp when end > start)."""
        def schedule(t: float) -> float:
            return end + (start - end) * (1 + math.cos(math.pi * t)) / 2
        return schedule

    @staticmethod
    def linear_decay(start: float = 1.0, end: float = 0.0) -> Callable[[float], float]:
        """Linear interpolation from start to end."""
        def schedule(t: float) -> float:
            return start + (end - start) * t
        return schedule

    @staticmethod
    def step_function(threshold: float = 0.5, before: float = 1.0, after: float = 0.0) -> Callable[[float], float]:
        """Step function that switches at threshold."""
        def schedule(t: float) -> float:
            return before if t < threshold else after
        return schedule


# Example: structure early, style late
scheduler = TimestepConditionScheduler(num_timesteps=50)

# ControlNet strong at start, fades out
scheduler.add_schedule("controlnet", TimestepConditionScheduler.cosine_decay(1.0, 0.2))

# IP-Adapter weak at start, strengthens at end
scheduler.add_schedule("ip_adapter", TimestepConditionScheduler.cosine_decay(0.2, 1.0))

# Text stays constant
scheduler.add_schedule("text", lambda t: 1.0)
```

Intuition: early diffusion steps make coarse decisions (composition, layout) while late steps make fine decisions (texture, details). Schedule structure-focused conditions to fade out early and style-focused conditions to strengthen late for clean separation.
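To see the crossover these schedules produce, the cosine formula can be evaluated standalone (plain Python, mirroring `cosine_decay` with the same 1.0→0.2 and 0.2→1.0 endpoints from the example above):

```python
import math

def cosine_interp(start, end):
    """Same formula as cosine_decay: returns start at t=0 and end at t=1."""
    return lambda t: end + (start - end) * (1 + math.cos(math.pi * t)) / 2

controlnet_w = cosine_interp(1.0, 0.2)   # structure fades out
ip_adapter_w = cosine_interp(0.2, 1.0)   # style fades in

for step in range(0, 51, 10):
    t = step / 50
    print(f"t={t:.1f}  controlnet={controlnet_w(t):.2f}  ip_adapter={ip_adapter_w(t):.2f}")
```

The two weights are mirror images: structure starts dominant, the curves cross at 0.6 at the midpoint, and style dominates the final steps.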
## Unified Conditioning Frameworks
Rather than treating each modality as a separate add-on, unified frameworks attempt to create a single conditioning interface that handles arbitrary combinations.
### Composer Architecture
The Composer approach unifies conditioning by projecting all modalities into a shared representation space:
```python
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


class UnifiedConditioner(nn.Module):
    """
    Unified conditioning that handles arbitrary modality combinations.
    Projects all conditions to a shared representation space.
    """
    def __init__(
        self,
        unified_dim: int = 1024,
        max_condition_tokens: int = 256
    ):
        super().__init__()
        self.unified_dim = unified_dim
        self.max_tokens = max_condition_tokens

        # Modality-specific encoders and projectors
        self.modality_encoders = nn.ModuleDict()
        self.modality_projectors = nn.ModuleDict()
        self.modality_ids: Dict[str, int] = {}

        # Unified position encoding for the combined sequence
        self.position_encoding = nn.Embedding(max_condition_tokens, unified_dim)

        # Modality type encoding
        self.modality_encoding = nn.Embedding(16, unified_dim)  # Up to 16 modality types

        # Cross-modal transformer for fusion
        self.fusion_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=unified_dim,
                nhead=16,
                dim_feedforward=unified_dim * 4,
                batch_first=True
            ),
            num_layers=4
        )

    def register_modality(
        self,
        name: str,
        encoder: nn.Module,
        projector: nn.Module,
        modality_id: int
    ):
        """Register a new conditioning modality."""
        self.modality_encoders[name] = encoder
        self.modality_projectors[name] = projector
        self.modality_ids[name] = modality_id

    def forward(
        self,
        conditions: Dict[str, Any],
        condition_weights: Optional[Dict[str, float]] = None
    ) -> torch.Tensor:
        """
        Process and combine arbitrary conditions.

        Args:
            conditions: {"text": "prompt", "image": pil_image, "depth": tensor, ...}
            condition_weights: Optional per-modality weights

        Returns:
            unified_embeddings: [B, N, unified_dim] combined condition embeddings
        """
        if condition_weights is None:
            condition_weights = {k: 1.0 for k in conditions}

        all_embeddings = []
        position_offset = 0

        for modality_name, condition in conditions.items():
            if condition is None:
                continue

            weight = condition_weights.get(modality_name, 1.0)
            if weight == 0:
                continue

            # Encode with modality-specific encoder
            encoded = self.modality_encoders[modality_name](condition)

            # Project to unified space
            projected = self.modality_projectors[modality_name](encoded)
            B, N, D = projected.shape

            # Add position encoding
            positions = torch.arange(position_offset, position_offset + N, device=projected.device)
            projected = projected + self.position_encoding(positions)

            # Add modality encoding
            modality_id = torch.tensor([self.modality_ids[modality_name]], device=projected.device)
            projected = projected + self.modality_encoding(modality_id)

            # Apply weight
            projected = projected * weight

            all_embeddings.append(projected)
            position_offset += N

        # Concatenate all modality embeddings
        combined = torch.cat(all_embeddings, dim=1)  # [B, total_tokens, D]

        # Cross-modal fusion
        fused = self.fusion_transformer(combined)

        return fused
```

This architecture treats all conditions uniformly, allowing the fusion transformer to learn optimal cross-modal relationships. New modalities can be added by simply registering an encoder and projector for them.
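The bookkeeping in the forward pass, concatenating per-modality token sequences with running position offsets and skipping disabled modalities, can be exercised without any model weights. A minimal sketch (plain Python, hypothetical token counts; `assemble_condition_tokens` is an illustrative stand-in, not the module above):

```python
def assemble_condition_tokens(conditions, weights):
    """Concatenate per-modality token sequences with running position offsets.

    conditions: {modality_name: number_of_tokens}
    weights: {modality_name: weight}; weight 0 drops the modality entirely.
    Returns a list of (modality, position) pairs standing in for the
    combined token sequence fed to the fusion transformer.
    """
    tokens = []
    offset = 0
    for name, n_tokens in conditions.items():
        if weights.get(name, 1.0) == 0:
            continue
        tokens.extend((name, offset + i) for i in range(n_tokens))
        offset += n_tokens
    return tokens

seq = assemble_condition_tokens(
    {"text": 3, "image": 2, "depth": 2},
    {"text": 1.0, "image": 0.6, "depth": 0.0},  # depth disabled
)
print(seq)
# [('text', 0), ('text', 1), ('text', 2), ('image', 3), ('image', 4)]
```

Because positions are assigned after filtering, disabling a modality shifts the remaining tokens down rather than leaving a gap, matching how `position_offset` advances only for active modalities.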
## Practical Multi-Modal Workflows
Understanding the theory is one thing; applying it effectively requires practical experience. Here are proven workflows for common multi-modal scenarios:
1. Character in Pose with Style
   - Text: Describe the character and scene
   - IP-Adapter (scale 0.5): Reference image for character appearance
   - ControlNet-Pose (scale 0.8): Control the character's pose
   - Schedule: Pose strong early, IP-Adapter strengthens mid-way
2. Architectural Visualization
   - Text: "Modern architecture, daylight, photorealistic"
   - ControlNet-Depth (scale 1.0): 3D model depth render
   - IP-Adapter (scale 0.4): Reference for material/lighting style
   - Schedule: Depth constant, style fades in during late steps
3. Product Variation
   - Text: Describe desired changes
   - IP-Adapter (scale 0.8): Original product image
   - img2img with strength 0.5: Preserve product identity
   - No ControlNet: Allow flexibility in composition
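These recipes can be captured as plain configuration data and sanity-checked before a run. The sketch below (plain Python; the preset names and the `validate_workflow` helper are hypothetical, not part of any library) encodes the three workflows:

```python
WORKFLOWS = {
    "character_in_pose": {
        "ip_adapter": {"scale": 0.5, "schedule": "late"},
        "controlnet": {"type": "pose", "scale": 0.8, "schedule": "early"},
    },
    "architectural_viz": {
        "ip_adapter": {"scale": 0.4, "schedule": "late"},
        "controlnet": {"type": "depth", "scale": 1.0, "schedule": "constant"},
    },
    "product_variation": {
        "ip_adapter": {"scale": 0.8, "schedule": "constant"},
        "img2img_strength": 0.5,  # no ControlNet: keep composition flexible
    },
}

def validate_workflow(cfg):
    """Hypothetical sanity check: scales in [0, 1], known schedule names."""
    valid_schedules = {"constant", "early", "late", "mid"}
    for adapter, settings in cfg.items():
        if not isinstance(settings, dict):
            continue
        scale = settings.get("scale", 1.0)
        assert 0.0 <= scale <= 1.0, f"{adapter}: scale {scale} out of range"
        assert settings.get("schedule", "constant") in valid_schedules
    return True

all_valid = all(validate_workflow(cfg) for cfg in WORKFLOWS.values())
print(all_valid)  # True
```

Keeping workflows as data rather than code makes it easy to sweep scales and schedules when tuning a new combination.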
## Implementation
Here's a complete implementation of a multi-modal generation pipeline:
```python
import math
from typing import Dict, Optional

import PIL.Image
import torch
from diffusers import ControlNetModel, StableDiffusionPipeline
from torchvision import transforms


class MultiModalDiffusionPipeline:
    """
    Complete multi-modal generation pipeline.
    Supports text, image, and spatial conditioning with dynamic scheduling.
    """
    def __init__(
        self,
        base_model: str = "stabilityai/stable-diffusion-2-1",
        controlnet_models: Optional[Dict[str, str]] = None,
        ip_adapter_path: Optional[str] = None
    ):
        # Load base pipeline
        self.pipe = StableDiffusionPipeline.from_pretrained(
            base_model,
            torch_dtype=torch.float16
        ).to("cuda")

        # Load ControlNets
        self.controlnets = {}
        if controlnet_models:
            for name, path in controlnet_models.items():
                self.controlnets[name] = ControlNetModel.from_pretrained(
                    path, torch_dtype=torch.float16
                ).to("cuda")

        # Load IP-Adapter
        self.ip_adapter = None
        if ip_adapter_path:
            self.ip_adapter = IPAdapter(self.pipe, ip_adapter_path)

        # Initialize scheduler
        self.condition_scheduler = TimestepConditionScheduler()

    def configure_schedules(
        self,
        controlnet_schedule: str = "constant",
        ip_adapter_schedule: str = "constant"
    ):
        """Configure conditioning schedules."""
        schedules = {
            "constant": lambda t: 1.0,
            "early": TimestepConditionScheduler.cosine_decay(1.0, 0.0),
            "late": TimestepConditionScheduler.cosine_decay(0.0, 1.0),
            "mid": lambda t: math.sin(math.pi * t),
        }

        self.condition_scheduler.add_schedule(
            "controlnet", schedules[controlnet_schedule]
        )
        self.condition_scheduler.add_schedule(
            "ip_adapter", schedules[ip_adapter_schedule]
        )

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        control_images: Optional[Dict[str, PIL.Image.Image]] = None,
        control_scales: Optional[Dict[str, float]] = None,
        ip_image: Optional[PIL.Image.Image] = None,
        ip_scale: float = 0.6,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 50,
        height: int = 512,
        width: int = 512,
        use_scheduling: bool = True
    ) -> PIL.Image.Image:
        """
        Generate image with multi-modal conditioning.
        """
        # Encode text (returns [uncond, cond] embeddings for CFG)
        text_embeddings = self._encode_prompt(prompt, negative_prompt)

        # Encode IP-Adapter image
        if self.ip_adapter and ip_image:
            ip_embeddings = self.ip_adapter.encode_image(ip_image)
        else:
            ip_embeddings = None

        # Prepare control conditions
        control_tensors = {}
        if control_images:
            for name, img in control_images.items():
                if name in self.controlnets:
                    control_tensors[name] = self._prepare_control(img, height, width)

        control_scales = control_scales or {}

        # Prepare latents
        self.pipe.scheduler.set_timesteps(num_inference_steps)
        latents = torch.randn(
            1, 4, height // 8, width // 8,
            device="cuda", dtype=torch.float16
        )
        latents = latents * self.pipe.scheduler.init_noise_sigma

        # Diffusion loop
        for step, t in enumerate(self.pipe.scheduler.timesteps):
            # Get scheduled weights
            if use_scheduling:
                cn_weight_mult = self.condition_scheduler.get_weight("controlnet", step)
                ip_weight = ip_scale * self.condition_scheduler.get_weight("ip_adapter", step)
            else:
                cn_weight_mult = 1.0
                ip_weight = ip_scale

            # Set IP-Adapter scale for this step
            if self.ip_adapter:
                self.ip_adapter.set_scale(ip_weight)

            # Expand latents for CFG so ControlNet and U-Net see the same batch
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.pipe.scheduler.scale_model_input(latent_model_input, t)

            # Compute ControlNet residuals on the CFG batch
            down_residuals = None
            mid_residual = None

            for name, condition in control_tensors.items():
                base_scale = control_scales.get(name, 1.0)
                effective_scale = base_scale * cn_weight_mult

                d_res, m_res = self.controlnets[name](
                    latent_model_input, t,
                    encoder_hidden_states=text_embeddings,
                    controlnet_cond=torch.cat([condition] * 2),
                    return_dict=False
                )

                if down_residuals is None:
                    down_residuals = [r * effective_scale for r in d_res]
                    mid_residual = m_res * effective_scale
                else:
                    for i, r in enumerate(d_res):
                        down_residuals[i] += r * effective_scale
                    mid_residual += m_res * effective_scale

            # U-Net prediction
            noise_pred = self.pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                ip_adapter_image_embeds=ip_embeddings,
                down_block_additional_residuals=down_residuals,
                mid_block_additional_residual=mid_residual
            ).sample

            # CFG
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # Scheduler step
            latents = self.pipe.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode (0.18215 is the Stable Diffusion latent scaling factor)
        image = self.pipe.vae.decode(latents / 0.18215).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = transforms.ToPILImage()(image[0].cpu().float())

        return image
```

## Summary
Multi-modal conditioning enables unprecedented control over diffusion model generation by combining text, image, and spatial signals. Key takeaways:
- Modality characteristics matter: text provides semantic guidance, images specify appearance, and spatial conditions constrain structure. Understanding these roles enables effective combination.
- Decoupled pathways prevent interference: separate cross-attention for different modalities keeps their influences independent and controllable.
- Timestep scheduling separates concerns: applying structural conditions early and stylistic conditions late creates clean compositional effects.
- Per-layer control adds precision: different U-Net depths respond differently to conditioning, enabling fine-grained influence control.
- Unified frameworks scale: projecting all modalities to a shared space enables arbitrary combinations and cross-modal learning.
In the next chapter, we'll shift perspectives to examine diffusion models through the lens of score-based generative modeling and stochastic differential equations, unifying the discrete-step formulation with continuous-time dynamics.