Chapter 14

Multi-Modal Conditioning

Advanced Conditioning Techniques

Introduction

Modern image generation increasingly demands precise control over multiple aspects simultaneously: text describes content, images specify style, poses guide composition, and depth maps constrain spatial structure. Multi-modal conditioning addresses the challenge of combining these diverse signals into coherent generation guidance.

This section explores how different conditioning modalities can be combined, weighted, and scheduled throughout the diffusion process. We'll examine both the theoretical framework for multi-modal fusion and practical implementation strategies that enable fine-grained control over every aspect of generation.


The Multi-Modal Challenge

Combining multiple conditioning signals is not straightforward. Different modalities may conflict, compete for influence, or interact in unexpected ways. Consider these challenges:

| Challenge | Example | Consequence |
|---|---|---|
| Semantic Conflict | Text: "cat" + Image: dog photo | Model must decide which to prioritize |
| Scale Mismatch | Strong ControlNet + weak text | Structure dominates, content ignored |
| Temporal Interference | Style applied too early/late | Style affects structure instead of texture |
| Compositional Ambiguity | Multiple subjects in different poses | Unclear which pose maps to which subject |
| Representation Gap | CLIP text vs. face embeddings | Different embedding spaces, different semantics |

The fundamental question is: how do we combine conditioning signals from different modalities in a way that respects the user's intent while producing coherent outputs?

Design Principle: Multi-modal conditioning should be compositional (combining independent signals), controllable (adjustable influence per modality), and predictable (consistent behavior across generations).

Conditioning Modalities

Before discussing combination strategies, let's understand the characteristics of each major conditioning modality and how they influence the generation process.

Text Conditioning

Text provides semantic guidance that shapes content, style, and abstract concepts. It operates primarily through cross-attention, allowing spatial features to query relevant text tokens:

```python
from typing import Dict, Optional

import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5Tokenizer


class TextConditioner:
    """
    Text conditioning via CLIP or T5 embeddings.
    Primary semantic signal for content and style.
    """
    def __init__(self, encoder_type: str = "clip"):
        if encoder_type == "clip":
            self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
            self.encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
            self.max_length = 77  # CLIP's fixed context length
        elif encoder_type == "t5":
            self.tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-xxl")
            self.encoder = T5EncoderModel.from_pretrained("google/t5-v1_1-xxl")
            self.max_length = 256  # token budget chosen for this example
        else:
            raise ValueError(f"Unknown encoder_type: {encoder_type!r}")
        self.encoder.eval()

    @torch.no_grad()
    def encode(self, text: str) -> Dict[str, Optional[torch.Tensor]]:
        """Encode text to conditioning embeddings."""
        tokens = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        outputs = self.encoder(**tokens)

        return {
            "embeddings": outputs.last_hidden_state,  # [1, seq_len, dim]
            # CLIP provides a pooled (EOS-token) embedding; T5 encoders do not
            "pooled": getattr(outputs, "pooler_output", None),
            "attention_mask": tokens.attention_mask
        }

    @property
    def influence_type(self) -> str:
        return "semantic"  # Affects content and style

    @property
    def injection_method(self) -> str:
        return "cross_attention"  # Through cross-attention layers
```

Text conditioning is global in nature: it influences the entire image rather than specific regions (unless using attention control techniques). This makes it ideal for specifying overall content and style but less precise for spatial details.
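To make "spatial features query text tokens" concrete, here is a toy single-head cross-attention in NumPy. The sizes and random projection matrices are illustrative assumptions, not values from any particular model. The point is that every spatial position attends over the same set of text tokens, which is why text conditioning acts globally unless the attention maps are manipulated.

```python
import numpy as np


def cross_attention(features, text_tokens, w_q, w_k, w_v):
    """Single-head cross-attention: spatial features query text tokens.

    features:    [N, D] flattened U-Net feature map (N = H * W positions)
    text_tokens: [T, C] text encoder output (T tokens of width C)
    Returns the attended values [N, D] and the attention map [N, T].
    """
    q = features @ w_q                        # [N, d]
    k = text_tokens @ w_k                     # [T, d]
    v = text_tokens @ w_v                     # [T, D]
    scores = q @ k.T / np.sqrt(q.shape[-1])   # [N, T]
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    attn = np.exp(scores)
    attn /= attn.sum(axis=-1, keepdims=True)  # softmax over text tokens
    return attn @ v, attn


rng = np.random.default_rng(0)
N, D, T, C, d = 16, 32, 77, 48, 8  # toy sizes; T = 77 mirrors CLIP's context length
features = rng.standard_normal((N, D))
text = rng.standard_normal((T, C))
w_q = rng.standard_normal((D, d))
w_k = rng.standard_normal((C, d))
w_v = rng.standard_normal((C, D))

out, attn = cross_attention(features, text, w_q, w_k, w_v)
print(out.shape, attn.shape)  # (16, 32) (16, 77)
```

Each row of `attn` is one spatial position's distribution over all 77 tokens, so a word in the prompt can influence any region of the image.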

Image Conditioning

Image conditioning (via IP-Adapter, style references, or image variations) provides visual reference signals that encode appearance, texture, and style in ways that text cannot capture:

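```python
from typing import Dict

import torch
import torchvision.transforms as T
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel


class ImageConditioner:
    """
    Image conditioning via CLIP vision or specialized encoders.
    Provides style, texture, and appearance guidance.
    """
    def __init__(self, encoder_type: str = "clip_vision"):
        self.encoder_type = encoder_type

        if encoder_type == "clip_vision":
            self.encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
            self.processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
        elif encoder_type == "face_id":
            # Hypothetical helper: wraps a face-recognition model (e.g. InsightFace)
            # and exposes get_embedding(image)
            self.encoder = load_insightface_model()
        elif encoder_type == "dino":
            self.encoder = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14").eval()
            # ImageNet normalization; 224 is a multiple of DINOv2's 14-pixel patch size
            self.dino_transform = T.Compose([
                T.Resize(224),
                T.CenterCrop(224),
                T.ToTensor(),
                T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            ])
        else:
            raise ValueError(f"Unknown encoder_type: {encoder_type!r}")

    @torch.no_grad()
    def encode(self, image: Image.Image) -> Dict[str, torch.Tensor]:
        """Encode image to conditioning embeddings."""
        if self.encoder_type == "clip_vision":
            inputs = self.processor(images=image, return_tensors="pt")
            outputs = self.encoder(**inputs)
            return {
                "patch_embeddings": outputs.last_hidden_state[:, 1:],  # Spatial (CLS dropped)
                "global_embedding": outputs.pooler_output,  # Global (pooled CLS)
            }
        if self.encoder_type == "face_id":
            return {"face_embedding": self.encoder.get_embedding(image)}

        # DINOv2: forward_features returns normalized CLS and patch tokens
        x = self.dino_transform(image.convert("RGB")).unsqueeze(0)
        feats = self.encoder.forward_features(x)
        return {
            "patch_embeddings": feats["x_norm_patchtokens"],
            "global_embedding": feats["x_norm_clstoken"],
        }

    @property
    def influence_type(self) -> str:
        return "appearance"  # Affects visual style and texture

    @property
    def injection_method(self) -> str:
        return "decoupled_cross_attention"  # Parallel to text attention
```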

Different image encoders capture different aspects: CLIP captures semantic content, DINO captures visual structure, face encoders capture identity. Choosing the right encoder depends on what aspect of the reference image should transfer.

Spatial Conditioning

Spatial conditioning (ControlNet, T2I-Adapter, depth, pose) provides structural guidance that constrains the geometric layout of the output:

```python
from typing import Any, Dict, List, Tuple

import torch
from diffusers import ControlNetModel
from PIL import Image
from torchvision.transforms.functional import to_tensor


class SpatialConditioner:
    """
    Spatial conditioning via control signals.
    Provides structural and geometric constraints.
    """
    def __init__(self, condition_type: str):
        self.condition_type = condition_type

        # Preprocessors for different condition types. These class names are
        # placeholders for real detectors that return a PIL control map.
        self.preprocessors = {
            "canny": CannyDetector(),
            "depth": DepthEstimator(),
            "pose": OpenPoseDetector(),
            "normal": NormalEstimator(),
            "segmentation": SegmentationModel(),
        }
        if condition_type not in self.preprocessors:
            raise ValueError(f"Unknown condition_type: {condition_type!r}")

        # Load only the ControlNet for the active condition type.
        # "controlnet-{type}" is a placeholder, not a real checkpoint ID.
        self.control_nets = {
            condition_type: ControlNetModel.from_pretrained(f"controlnet-{condition_type}")
        }

    def encode(self, image: Image.Image) -> Dict[str, Any]:
        """Extract and encode spatial condition."""
        # Extract condition map (assumed to be a PIL image)
        condition_map = self.preprocessors[self.condition_type](image)

        # to_tensor converts a PIL image to a [C, H, W] float tensor in [0, 1]
        condition_tensor = to_tensor(condition_map).unsqueeze(0)

        return {
            "condition_image": condition_tensor,
            "condition_type": self.condition_type,
        }

    def get_control_output(
        self,
        condition: torch.Tensor,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Get ControlNet residuals for U-Net injection."""
        controlnet = self.control_nets[self.condition_type]

        down_block_residuals, mid_block_residual = controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=condition,
            return_dict=False
        )

        return down_block_residuals, mid_block_residual

    @property
    def influence_type(self) -> str:
        return "structural"  # Affects layout and geometry

    @property
    def injection_method(self) -> str:
        return "residual_addition"  # Added to U-Net features
```

Spatial conditions are local: they provide pixel-aligned guidance that directly constrains output structure. This makes them powerful for precise control but requires careful balancing to avoid suppressing other conditioning signals.
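The locality is visible in how a ControlNet residual is injected: the residual has the same spatial shape as the U-Net feature it is added to, so it only changes positions where the control branch produced a non-zero signal. The NumPy sketch below uses a hand-made residual as an assumption; in a real ControlNet the residuals come from the trained control branch.

```python
import numpy as np


def inject_control(skip_features, control_residual, conditioning_scale=1.0):
    """ControlNet-style injection: add the scaled residual to a U-Net skip feature.

    Both arrays share the shape [C, H, W], so the addition is position by position.
    """
    return skip_features + conditioning_scale * control_residual


rng = np.random.default_rng(0)
skip = rng.standard_normal((4, 8, 8))

# Toy residual that is non-zero only in the top-left quadrant,
# standing in for a control map with structure in one region
residual = np.zeros_like(skip)
residual[:, :4, :4] = 1.0

out = inject_control(skip, residual, conditioning_scale=0.8)
changed = ~np.isclose(out, skip).all(axis=0)  # [H, W] mask of modified positions
print(changed.sum())  # 16 (the 4x4 quadrant)
```

The `conditioning_scale` multiplier is the knob that has to be balanced against the other signals: large values let the residual dominate the features at every position it touches.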


Combining Multiple Conditions

With multiple conditioning modalities available, the key question becomes how to combine them effectively. Two primary strategies emerge: attention-based fusion and adapter composition.

Attention-Based Fusion

When multiple modalities inject through cross-attention (text, image, audio), their outputs can be combined through weighted addition or concatenation:

```python
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiModalCrossAttention(nn.Module):
    """
    Cross-attention that combines multiple conditioning modalities.
    """
    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        context_dims: Optional[Dict[str, int]] = None
    ):
        super().__init__()
        assert hidden_dim % num_heads == 0, "hidden_dim must be divisible by num_heads"
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        self.head_dim = hidden_dim // num_heads

        # Embedding width of each modality's encoder (defaults to hidden_dim)
        if context_dims is None:
            context_dims = {m: hidden_dim for m in ("text", "image", "audio")}

        # Shared query projection
        self.to_q = nn.Linear(hidden_dim, hidden_dim, bias=False)

        # Per-modality K/V projections map each embedding space to hidden_dim
        self.modality_kv = nn.ModuleDict({
            name: nn.ModuleDict({
                "to_k": nn.Linear(dim, hidden_dim, bias=False),
                "to_v": nn.Linear(dim, hidden_dim, bias=False),
            })
            for name, dim in context_dims.items()
        })

        self.to_out = nn.Linear(hidden_dim, hidden_dim)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """[B, L, D] -> [B, heads, L, head_dim]"""
        B, L, _ = x.shape
        return x.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_embeddings: Dict[str, torch.Tensor],
        modality_weights: Dict[str, float]
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: [B, N, D] U-Net features
            modality_embeddings: {"text": [B, T, C_text], "image": [B, I, C_image], ...}
            modality_weights: {"text": 1.0, "image": 0.6, ...}
        """
        B, N, D = hidden_states.shape
        q = self._split_heads(self.to_q(hidden_states))

        # Compute attention for each modality
        modality_outputs = []
        for modality, embeddings in modality_embeddings.items():
            if modality not in self.modality_kv or embeddings is None:
                continue

            weight = modality_weights.get(modality, 1.0)
            if weight == 0:
                continue

            k = self._split_heads(self.modality_kv[modality]["to_k"](embeddings))
            v = self._split_heads(self.modality_kv[modality]["to_v"](embeddings))

            # Scaled dot-product attention, then merge heads back to [B, N, D]
            attn_output = F.scaled_dot_product_attention(q, k, v)
            attn_output = attn_output.transpose(1, 2).reshape(B, N, D)

            modality_outputs.append(weight * attn_output)

        if not modality_outputs:
            # No active modality: contribute nothing to the residual stream
            return torch.zeros_like(hidden_states)

        # Sum the weighted outputs (decoupled cross-attention style). Averaging
        # would shrink every modality's contribution whenever another is added.
        combined = torch.stack(modality_outputs).sum(dim=0)

        return self.to_out(combined)
```

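This approach lets each modality contribute independently, with weights controlling each one's influence. Because all modalities share the same query projection, every spatial location attends to whichever tokens in each modality are most relevant to it; the per-modality key and value projections are what map each embedding space into a form those shared queries can match.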
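Whether the weighted per-modality outputs are summed or averaged matters. Summing, as IP-Adapter's decoupled cross-attention does (text output plus a scaled image output), keeps the text contribution fixed as modalities are added; averaging divides it by the number of active modalities. The NumPy sketch below compares the two rules on random arrays standing in for real attention outputs; `combine` is a hypothetical helper.

```python
import numpy as np


def combine(outputs, weights, mode="sum"):
    """Combine per-modality attention outputs; weight 0 disables a modality."""
    weighted = [weights.get(m, 1.0) * out for m, out in outputs.items()
                if weights.get(m, 1.0) != 0]
    total = sum(weighted)
    return total / len(weighted) if mode == "mean" else total


rng = np.random.default_rng(0)
text_out = rng.standard_normal((16, 32))
image_out = rng.standard_normal((16, 32))
outputs = {"text": text_out, "image": image_out}

text_only = combine({"text": text_out}, {"text": 1.0})
summed = combine(outputs, {"text": 1.0, "image": 0.6}, mode="sum")
averaged = combine(outputs, {"text": 1.0, "image": 0.6}, mode="mean")

# Summing: the text term is unchanged; the image adds 0.6 * image_out on top
print(np.allclose(summed - 0.6 * image_out, text_only))          # True
# Averaging: the text term is halved once a second modality is active
print(np.allclose(averaged - 0.3 * image_out, 0.5 * text_only))  # True
```

With summation, setting the image weight to 0 recovers text-only behavior exactly, which makes the weight easy to reason about.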

Adapter Composition

When using adapters like ControlNet and IP-Adapter simultaneously, their outputs must be combined without interference:

```python
from typing import Dict, Optional

import torch
from diffusers import ControlNetModel, StableDiffusionPipeline
from PIL import Image


class MultiAdapterPipeline:
    """
    Pipeline that composes multiple conditioning adapters.
    Simplified sketch: classifier-free guidance is omitted, and
    preprocess_control, prepare_latents, and decode_latents are helper
    methods not shown here. IPAdapter is a hypothetical wrapper exposing
    encode_image() and set_scale().
    """
    def __init__(
        self,
        pipe: StableDiffusionPipeline,
        controlnets: Dict[str, ControlNetModel],
        ip_adapter: Optional["IPAdapter"] = None
    ):
        self.pipe = pipe
        self.scheduler = pipe.scheduler
        self.controlnets = controlnets
        self.ip_adapter = ip_adapter

    @torch.no_grad()
    def __call__(
        self,
        prompt: str,
        control_images: Dict[str, Image.Image],
        control_scales: Dict[str, float],
        ip_image: Optional[Image.Image] = None,
        ip_scale: float = 0.6,
        num_inference_steps: int = 50,
        **kwargs
    ):
        """
        Generate with multiple control signals.

        Args:
            prompt: Text prompt
            control_images: {"canny": img, "depth": img, ...}
            control_scales: {"canny": 0.8, "depth": 0.5, ...}
            ip_image: Reference image for IP-Adapter
            ip_scale: IP-Adapter influence weight
            num_inference_steps: Number of sampler steps
        """
        # Prepare text embeddings
        text_embeddings = self.pipe._encode_prompt(prompt, ...)

        # Prepare IP-Adapter embeddings
        if self.ip_adapter is not None and ip_image is not None:
            image_embeddings = self.ip_adapter.encode_image(ip_image)
            self.ip_adapter.set_scale(ip_scale)
        else:
            image_embeddings = None

        # Prepare control conditions
        control_tensors = {
            name: self.preprocess_control(img, name)
            for name, img in control_images.items()
        }

        # Diffusion loop with multi-adapter injection
        self.scheduler.set_timesteps(num_inference_steps)
        latents = self.prepare_latents(...)

        for t in self.scheduler.timesteps:
            # Some schedulers (e.g. Euler) rescale the model input
            latent_input = self.scheduler.scale_model_input(latents, t)

            # Accumulate ControlNet residuals
            total_down_residuals = None
            total_mid_residual = None

            for name, condition in control_tensors.items():
                scale = control_scales.get(name, 1.0)
                if scale == 0:
                    continue

                down_res, mid_res = self.controlnets[name](
                    latent_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    controlnet_cond=condition,
                    return_dict=False  # return a (down_residuals, mid_residual) tuple
                )

                # Scale and accumulate (sum, not average)
                if total_down_residuals is None:
                    total_down_residuals = [r * scale for r in down_res]
                    total_mid_residual = mid_res * scale
                else:
                    for i, r in enumerate(down_res):
                        total_down_residuals[i] += r * scale
                    total_mid_residual += mid_res * scale

            # U-Net prediction with combined control. In diffusers, IP-Adapter
            # image embeddings reach the U-Net through added_cond_kwargs once
            # the adapter weights are loaded (e.g. with pipe.load_ip_adapter).
            added_cond_kwargs = (
                {"image_embeds": image_embeddings} if image_embeddings is not None else None
            )
            noise_pred = self.pipe.unet(
                latent_input,
                t,
                encoder_hidden_states=text_embeddings,
                added_cond_kwargs=added_cond_kwargs,
                down_block_additional_residuals=total_down_residuals,
                mid_block_additional_residual=total_mid_residual
            ).sample

            # Scheduler step
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return self.decode_latents(latents)
```

Multi-Adapter Composition Strategy:
  • ControlNet residuals are additive: multiple ControlNets can sum their structural guidance
  • IP-Adapter operates through attention: independent of ControlNet pathway
  • Per-adapter scaling: each adapter's influence can be tuned independently
  • Accumulation vs. averaging: sum preserves all signals, averaging normalizes intensity
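The points above can be checked on toy data. The sketch below uses NumPy arrays in place of ControlNet outputs (the three residual shapes are arbitrary) and a hypothetical `combine_residuals` helper that combines per-level residuals from several ControlNets the same way as the pipeline loop, with an optional averaging mode to show the normalization trade-off.

```python
from typing import Dict, List, Optional, Tuple

import numpy as np

Residuals = Tuple[List[np.ndarray], np.ndarray]  # (per-level down residuals, mid residual)


def combine_residuals(per_net: Dict[str, Residuals], scales: Dict[str, float],
                      average: bool = False) -> Optional[Residuals]:
    """Scale each ControlNet's residuals and combine them level by level."""
    active = [(name, res) for name, res in per_net.items() if scales.get(name, 1.0) != 0]
    if not active:
        return None  # no control: the U-Net runs without additional residuals
    down_total, mid_total = None, None
    for name, (down, mid) in active:
        s = scales.get(name, 1.0)
        if down_total is None:
            down_total = [s * r for r in down]
            mid_total = s * mid
        else:
            down_total = [acc + s * r for acc, r in zip(down_total, down)]
            mid_total = mid_total + s * mid
    if average:  # normalize intensity by the number of active ControlNets
        k = len(active)
        down_total = [r / k for r in down_total]
        mid_total = mid_total / k
    return down_total, mid_total


levels = [(4, 8, 8), (4, 4, 4), (4, 2, 2)]
canny = ([np.ones(s) for s in levels], np.ones((4, 1, 1)))
depth = ([2 * np.ones(s) for s in levels], 2 * np.ones((4, 1, 1)))
nets = {"canny": canny, "depth": depth}

down, mid = combine_residuals(nets, {"canny": 0.8, "depth": 0.5})
print(down[0][0, 0, 0], mid[0, 0, 0])  # 1.8 1.8  (0.8 * 1 + 0.5 * 2)
```

Summing keeps each ControlNet at its nominal strength, so total intensity grows with the number of nets; averaging keeps intensity bounded but weakens each individual signal.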

Dynamic Conditioning Weights

Static weights for each modality work for simple cases, but sophisticated control often requires dynamic weighting that varies spatially or temporally.

Per-Layer Weight Control

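Different U-Net layers respond to conditioning differently. A ControlNet returns one residual per U-Net skip connection (12 for SD 1.x/2.x ControlNets), ordered from the shallow, full-resolution features to the deep, low-resolution ones, plus one residual for the middle block. Because each depth works at a different spatial scale, weighting the residuals separately changes which scales of the control map the output follows. Per-layer control exploits this: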

```python
from typing import List, Tuple

import torch
from diffusers import ControlNetModel


class LayerWiseControlNet:
    """
    ControlNet with per-layer weight scheduling.
    Index 0 is the shallowest (full-resolution) down residual; the last index
    is the deepest (lowest-resolution) one.
    """
    # Named presets, written for the 12 down residuals of SD 1.x/2.x ControlNets
    PRESETS = {
        # Early (shallow) residuals strong, late (deep) residuals weak
        "structure_only": [1.0, 1.0, 1.0, 0.8, 0.6, 0.4, 0.2, 0.1, 0.0, 0.0, 0.0, 0.0],
        # Late (deep) residuals strong, early (shallow) residuals weak
        "style_only": [0.0, 0.0, 0.1, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0],
        # Uniform across layers
        "balanced": [1.0] * 12,
        # Emphasize middle layers
        "detail_focus": [0.3, 0.5, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0, 0.8, 0.5, 0.3, 0.1],
    }

    def __init__(self, controlnet: ControlNetModel, num_down_blocks: int = 12):
        self.controlnet = controlnet
        self.num_down_blocks = num_down_blocks  # number of down residuals

        # Default: uniform weights
        self.layer_weights = [1.0] * num_down_blocks
        self.mid_weight = 1.0  # weight for the mid-block residual

    def set_layer_weights(self, weights: List[float], mid_weight: float = 1.0):
        """Set per-layer control weights."""
        if len(weights) != self.num_down_blocks:
            raise ValueError(f"Expected {self.num_down_blocks} weights, got {len(weights)}")
        self.layer_weights = list(weights)
        self.mid_weight = mid_weight

    def set_layer_weights_from_schedule(self, schedule: str):
        """Set weights from named schedule."""
        if schedule not in self.PRESETS:
            raise ValueError(f"Unknown schedule: {schedule!r}")
        self.set_layer_weights(self.PRESETS[schedule])  # also checks the length

    def __call__(
        self,
        sample: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor
    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
        """Get layer-weighted residuals."""
        down_residuals, mid_residual = self.controlnet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            controlnet_cond=controlnet_cond,
            return_dict=False
        )

        # Apply per-layer weights
        weighted_residuals = [
            residual * weight
            for residual, weight in zip(down_residuals, self.layer_weights)
        ]

        return weighted_residuals, mid_residual * self.mid_weight
```

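This gives fine control over how structural guidance propagates through the network. Treat the presets as starting points rather than fixed rules: which residual depths carry layout and which carry fine detail varies with the base model and condition type, so check that a preset such as "structure_only" really constrains layout without suppressing detail before relying on it.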
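Rather than hand-typing twelve values, per-residual weights can be generated from a ramp. The sketch below is illustrative: `linear_ramp`, `log_ramp`, and `apply_weights` are hypothetical helpers, and the 13-entry log ramp (0.1 at the shallowest residual up to 1.0 at the mid block) resembles the soft weighting diffusers applies to ControlNet residuals in guess mode.

```python
import numpy as np


def linear_ramp(start: float, end: float, n: int) -> list:
    """n weights linearly interpolated from start (index 0) to end (index n-1)."""
    return [float(w) for w in np.linspace(start, end, n)]


def log_ramp(start: float, end: float, n: int) -> list:
    """n weights evenly spaced in log space from start to end (both > 0)."""
    return [float(w) for w in np.logspace(np.log10(start), np.log10(end), n)]


def apply_weights(down_residuals, mid_residual, down_weights, mid_weight=1.0):
    """Scale each down residual by its own weight, plus the mid residual."""
    if len(down_weights) != len(down_residuals):
        raise ValueError("need one weight per down residual")
    scaled = [w * r for w, r in zip(down_weights, down_residuals)]
    return scaled, mid_weight * mid_residual


# 12 down residuals + 1 mid residual, as for an SD 1.x/2.x ControlNet
weights = log_ramp(0.1, 1.0, 13)
down_w, mid_w = weights[:-1], weights[-1]

down = [np.ones((2, 2)) for _ in range(12)]  # toy residuals
mid = np.ones((2, 2))
scaled_down, scaled_mid = apply_weights(down, mid, down_w, mid_w)
print(round(down_w[0], 3), round(mid_w, 3))  # 0.1 1.0
```

Swapping the ramp's endpoints emphasizes the shallow, high-resolution residuals instead, and `linear_ramp` gives a gentler transition than the log spacing.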

Timestep-Based Scheduling

Different timesteps in the diffusion process serve different purposes: early steps (high noise) establish structure, late steps (low noise) refine details. Scheduling conditioning weights across timesteps enables powerful compositional effects:

```python
import math
from typing import Callable, Dict


class TimestepConditionScheduler:
    """
    Schedule conditioning weights across diffusion timesteps.
    """
    def __init__(self, num_timesteps: int = 50):
        self.num_timesteps = num_timesteps

        # Per-modality schedules
        self.schedules: Dict[str, Callable[[float], float]] = {}

    def add_schedule(self, modality: str, schedule_fn: Callable[[float], float]):
        """
        Add a schedule function for a modality.

        Args:
            modality: Name of the modality (e.g., "controlnet", "ip_adapter")
            schedule_fn: Function that takes normalized timestep [0, 1] and returns weight
        """
        self.schedules[modality] = schedule_fn

    def get_weight(self, modality: str, step: int) -> float:
        """Get weight for modality at current step."""
        if modality not in self.schedules:
            return 1.0

        # Normalize step index 0..num_timesteps-1 to [0, 1] (0 = first step, 1 = last)
        t_normalized = step / max(self.num_timesteps - 1, 1)
        return self.schedules[modality](t_normalized)

    @staticmethod
    def cosine_decay(start: float = 1.0, end: float = 0.0) -> Callable[[float], float]:
        """Smooth cosine interpolation from start to end (end may exceed start)."""
        def schedule(t: float) -> float:
            return end + (start - end) * (1 + math.cos(math.pi * t)) / 2
        return schedule

    @staticmethod
    def linear_decay(start: float = 1.0, end: float = 0.0) -> Callable[[float], float]:
        """Linear interpolation from start to end."""
        def schedule(t: float) -> float:
            return start + (end - start) * t
        return schedule

    @staticmethod
    def step_function(threshold: float = 0.5, before: float = 1.0, after: float = 0.0) -> Callable[[float], float]:
        """Step function that switches at threshold."""
        def schedule(t: float) -> float:
            return before if t < threshold else after
        return schedule


# Example: Structure early, style late
scheduler = TimestepConditionScheduler(num_timesteps=50)

# ControlNet strong at start, fades out
scheduler.add_schedule("controlnet", TimestepConditionScheduler.cosine_decay(1.0, 0.2))

# IP-Adapter weak at start, strengthens at end
scheduler.add_schedule("ip_adapter", TimestepConditionScheduler.cosine_decay(0.2, 1.0))

# Text stays constant
scheduler.add_schedule("text", lambda t: 1.0)
```

Intuition: Early diffusion steps make coarse decisions (composition, layout) while late steps make fine decisions (texture, details). Scheduling structure-focused conditions to dominate early and then fade out, and style-focused conditions to strengthen late, helps keep the two from interfering.
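Printing the example schedules at a few steps shows the crossover. This standalone sketch re-implements the cosine interpolation from `cosine_decay` and normalizes the step index by `num_steps - 1`, an assumption made so that the final step lands exactly on the end value.

```python
import math


def cosine_interp(start: float, end: float):
    """Cosine-eased interpolation from start (t = 0) to end (t = 1)."""
    return lambda t: end + (start - end) * (1 + math.cos(math.pi * t)) / 2


def normalized(step: int, num_steps: int) -> float:
    """Map step index 0..num_steps-1 onto [0, 1] so the last step reaches 1."""
    return step / max(num_steps - 1, 1)


num_steps = 50
controlnet = cosine_interp(1.0, 0.2)  # strong early, fades out
ip_adapter = cosine_interp(0.2, 1.0)  # weak early, strengthens late

for step in (0, 25, 49):
    t = normalized(step, num_steps)
    print(step, round(controlnet(t), 2), round(ip_adapter(t), 2))
# 0 1.0 0.2
# 25 0.59 0.61
# 49 0.2 1.0
```

The two curves cross near the midpoint, so structure and style each get a phase of the trajectory where they dominate.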

Unified Conditioning Frameworks

Rather than treating each modality as a separate add-on, unified frameworks attempt to create a single conditioning interface that handles arbitrary combinations.

Composer Architecture

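Composer-style approaches unify conditioning by projecting all modalities into a shared representation space. The sketch below follows that idea with a fusion transformer; it is a simplified illustration rather than a reproduction of the original Composer model: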

```python
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


class UnifiedConditioner(nn.Module):
    """
    Unified conditioning that handles arbitrary modality combinations.
    Projects all conditions to a shared representation space.
    """
    def __init__(
        self,
        unified_dim: int = 1024,
        max_condition_tokens: int = 256,
        max_modalities: int = 16
    ):
        super().__init__()
        self.unified_dim = unified_dim
        self.max_tokens = max_condition_tokens

        # Modality-specific encoders, projectors, and type IDs
        self.modality_encoders = nn.ModuleDict()
        self.modality_projectors = nn.ModuleDict()
        self.modality_ids: Dict[str, int] = {}

        # Unified position encoding for combined sequence
        self.position_encoding = nn.Embedding(max_condition_tokens, unified_dim)

        # Modality type encoding (one row per modality type)
        self.modality_encoding = nn.Embedding(max_modalities, unified_dim)

        # Cross-modal transformer for fusion
        self.fusion_transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=unified_dim,
                nhead=16,
                dim_feedforward=unified_dim * 4,
                batch_first=True
            ),
            num_layers=4
        )

    def register_modality(
        self,
        name: str,
        encoder: nn.Module,
        projector: nn.Module,
        modality_id: int
    ):
        """Register a new conditioning modality."""
        num_types = self.modality_encoding.num_embeddings
        if not 0 <= modality_id < num_types:
            raise ValueError(f"modality_id must be in [0, {num_types})")
        self.modality_encoders[name] = encoder
        self.modality_projectors[name] = projector
        self.modality_ids[name] = modality_id

    def forward(
        self,
        conditions: Dict[str, Any],
        condition_weights: Optional[Dict[str, float]] = None
    ) -> torch.Tensor:
        """
        Process and combine arbitrary conditions.

        Args:
            conditions: {"text": "prompt", "image": pil_image, "depth": tensor, ...}
            condition_weights: Optional per-modality weights

        Returns:
            unified_embeddings: [B, N, unified_dim] combined condition embeddings
        """
        if condition_weights is None:
            condition_weights = {k: 1.0 for k in conditions}

        all_embeddings = []
        position_offset = 0

        for modality_name, condition in conditions.items():
            if condition is None:
                continue

            weight = condition_weights.get(modality_name, 1.0)
            if weight == 0:
                continue

            # Encode with modality-specific encoder
            encoded = self.modality_encoders[modality_name](condition)

            # Project to unified space
            projected = self.modality_projectors[modality_name](encoded)
            B, N, D = projected.shape

            if position_offset + N > self.max_tokens:
                raise ValueError("Combined conditions exceed max_condition_tokens")

            # Add position encoding (positions keep counting across modalities)
            positions = torch.arange(position_offset, position_offset + N, device=projected.device)
            projected = projected + self.position_encoding(positions)

            # Add modality encoding
            modality_id = torch.tensor([self.modality_ids[modality_name]], device=projected.device)
            projected = projected + self.modality_encoding(modality_id)

            # Apply weight (a soft control: LayerNorm inside the fusion
            # transformer can partly normalize this scaling away)
            projected = projected * weight

            all_embeddings.append(projected)
            position_offset += N

        if not all_embeddings:
            raise ValueError("No active conditions to combine")

        # Concatenate all modality embeddings
        combined = torch.cat(all_embeddings, dim=1)  # [B, total_tokens, D]

        # Cross-modal fusion
        fused = self.fusion_transformer(combined)

        return fused
```

This architecture treats all conditions uniformly, letting the fusion transformer learn cross-modal relationships during training. A new modality is added by registering its encoder and projector, although its projector (and usually the fusion layers) must then be trained on data for that modality.
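The token bookkeeping in `forward` is easy to get wrong, so here it is on toy data. This NumPy sketch uses random tables in place of the learned `nn.Embedding` layers, a hypothetical `modality_ids` registry, and arbitrary token counts; it shows that positions keep counting across modalities while the modality embedding tags each span.

```python
import numpy as np

rng = np.random.default_rng(0)
D, MAX_TOKENS, NUM_TYPES = 16, 64, 4

position_table = rng.standard_normal((MAX_TOKENS, D))  # stands in for nn.Embedding
modality_table = rng.standard_normal((NUM_TYPES, D))
modality_ids = {"text": 0, "image": 1, "depth": 2}     # hypothetical registry

# Per-modality token sequences already projected to width D (toy sizes)
projected = {
    "text": rng.standard_normal((10, D)),
    "image": rng.standard_normal((4, D)),
    "depth": rng.standard_normal((8, D)),
}
weights = {"text": 1.0, "image": 0.6, "depth": 0.8}

pieces, offset, spans = [], 0, {}
for name, tokens in projected.items():
    n = tokens.shape[0]
    if offset + n > MAX_TOKENS:
        raise ValueError("combined sequence exceeds the position table")
    x = tokens + position_table[offset:offset + n] + modality_table[modality_ids[name]]
    pieces.append(weights[name] * x)
    spans[name] = (offset, offset + n)
    offset += n

combined = np.concatenate(pieces, axis=0)
print(combined.shape, spans)
# (22, 16) {'text': (0, 10), 'image': (10, 14), 'depth': (14, 22)}
```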


Practical Multi-Modal Workflows

Understanding the theory is one thing; applying it effectively requires practical experience. Here are starting configurations for common multi-modal scenarios; treat the scales as initial values to tune for your models and inputs:

  1. Character in Pose with Style
    • Text: Describe the character and scene
    • IP-Adapter (scale 0.5): Reference image for character appearance
    • ControlNet-Pose (scale 0.8): Control the character's pose
    • Schedule: Pose strong early, IP-Adapter strengthens mid-way
  2. Architectural Visualization
    • Text: "Modern architecture, daylight, photorealistic"
    • ControlNet-Depth (scale 1.0): 3D model depth render
    • IP-Adapter (scale 0.4): Reference for material/lighting style
    • Schedule: Depth constant, style fades in during late steps
  3. Product Variation
    • Text: Describe desired changes
    • IP-Adapter (scale 0.8): Original product image
    • img2img with strength 0.5: Preserve product identity
    • No ControlNet: Allow flexibility in composition
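These recipes can be stored as presets and passed to a pipeline such as the implementation that follows. The `WorkflowPreset` dataclass is a hypothetical container; its schedule names match the keys used by `configure_schedules`, and mapping "IP-Adapter strengthens mid-way" to the `late` ramp (which crosses half strength at the midpoint) is an interpretation.

```python
from dataclasses import dataclass, field
from typing import Dict, Optional


@dataclass
class WorkflowPreset:
    """Starting-point settings for one multi-modal workflow (values to tune)."""
    control_scales: Dict[str, float] = field(default_factory=dict)
    ip_scale: float = 0.0
    controlnet_schedule: str = "constant"
    ip_adapter_schedule: str = "constant"
    img2img_strength: Optional[float] = None  # None = generate from noise


PRESETS = {
    "character_pose_style": WorkflowPreset(
        control_scales={"pose": 0.8}, ip_scale=0.5,
        controlnet_schedule="early",   # pose strong early
        ip_adapter_schedule="late",    # reference ramps up through the middle
    ),
    "architectural_viz": WorkflowPreset(
        control_scales={"depth": 1.0}, ip_scale=0.4,
        controlnet_schedule="constant",  # depth held constant
        ip_adapter_schedule="late",      # style fades in during late steps
    ),
    "product_variation": WorkflowPreset(
        control_scales={},  # no ControlNet: composition stays flexible
        ip_scale=0.8,
        img2img_strength=0.5,  # preserve product identity
    ),
}

print(sorted(PRESETS))
# ['architectural_viz', 'character_pose_style', 'product_variation']
```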

Implementation

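Here's a fuller implementation of a multi-modal generation pipeline. It combines the pieces above with classifier-free guidance and timestep scheduling, while prompt encoding, control-image preprocessing, and the IP-Adapter wrapper remain helpers that are not shown: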

```python
import math
from typing import Dict, Optional

import torch
from diffusers import ControlNetModel, StableDiffusionPipeline
from PIL import Image
from torchvision import transforms


class MultiModalDiffusionPipeline:
    """
    Multi-modal generation pipeline.
    Supports text, image, and spatial conditioning with dynamic scheduling.

    Assumptions: IPAdapter is a hypothetical wrapper exposing encode_image()
    and set_scale(); _encode_prompt (returns [uncond, cond] embeddings stacked
    on the batch axis) and _prepare_control (returns a [1, 3, H, W] float16
    CUDA tensor) are helper methods not shown. TimestepConditionScheduler is
    the class defined earlier in this section.
    """
    def __init__(
        self,
        base_model: str = "stabilityai/stable-diffusion-2-1",
        controlnet_models: Optional[Dict[str, str]] = None,
        ip_adapter_path: Optional[str] = None
    ):
        # Load base pipeline
        self.pipe = StableDiffusionPipeline.from_pretrained(
            base_model,
            torch_dtype=torch.float16
        ).to("cuda")

        # Load ControlNets (each must be trained for the same base model)
        self.controlnets = {}
        if controlnet_models:
            for name, path in controlnet_models.items():
                self.controlnets[name] = ControlNetModel.from_pretrained(
                    path, torch_dtype=torch.float16
                ).to("cuda")

        # Load IP-Adapter
        self.ip_adapter = None
        if ip_adapter_path:
            self.ip_adapter = IPAdapter(self.pipe, ip_adapter_path)

        # Initialize condition scheduler
        self.condition_scheduler = TimestepConditionScheduler()

    def configure_schedules(
        self,
        controlnet_schedule: str = "constant",
        ip_adapter_schedule: str = "constant"
    ):
        """Configure conditioning schedules."""
        schedules = {
            "constant": lambda t: 1.0,
            "early": TimestepConditionScheduler.cosine_decay(1.0, 0.0),
            "late": TimestepConditionScheduler.cosine_decay(0.0, 1.0),
            "mid": lambda t: math.sin(math.pi * t),
        }

        self.condition_scheduler.add_schedule(
            "controlnet", schedules[controlnet_schedule]
        )
        self.condition_scheduler.add_schedule(
            "ip_adapter", schedules[ip_adapter_schedule]
        )

    @torch.no_grad()
    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        control_images: Optional[Dict[str, Image.Image]] = None,
        control_scales: Optional[Dict[str, float]] = None,
        ip_image: Optional[Image.Image] = None,
        ip_scale: float = 0.6,
        guidance_scale: float = 7.5,
        num_inference_steps: int = 50,
        height: int = 512,
        width: int = 512,
        use_scheduling: bool = True
    ) -> Image.Image:
        """
        Generate image with multi-modal conditioning.
        """
        device = "cuda"
        scheduler = self.pipe.scheduler

        # Encode text: [uncond, cond] along the batch axis for CFG
        text_embeddings = self._encode_prompt(prompt, negative_prompt)

        # Encode IP-Adapter image (assumed batched as [uncond, cond])
        if self.ip_adapter is not None and ip_image is not None:
            ip_embeddings = self.ip_adapter.encode_image(ip_image)
        else:
            ip_embeddings = None

        # Prepare control conditions, duplicated to match the CFG batch
        control_tensors = {}
        if control_images:
            for name, img in control_images.items():
                if name in self.controlnets:
                    cond = self._prepare_control(img, height, width)
                    control_tensors[name] = torch.cat([cond] * 2)

        control_scales = control_scales or {}

        # Set the sampler's timesteps and keep the condition schedules in sync
        scheduler.set_timesteps(num_inference_steps, device=device)
        self.condition_scheduler.num_timesteps = len(scheduler.timesteps)

        # Prepare latents
        latents = torch.randn(
            1, 4, height // 8, width // 8,
            device=device, dtype=torch.float16
        )
        latents = latents * scheduler.init_noise_sigma

        # Diffusion loop
        for step, t in enumerate(scheduler.timesteps):
            # Get scheduled weights
            if use_scheduling:
                cn_weight_mult = self.condition_scheduler.get_weight("controlnet", step)
                ip_weight = ip_scale * self.condition_scheduler.get_weight("ip_adapter", step)
            else:
                cn_weight_mult = 1.0
                ip_weight = ip_scale

            # Set IP-Adapter scale for this step
            if self.ip_adapter is not None:
                self.ip_adapter.set_scale(ip_weight)

            # Expand latents for CFG and apply scheduler-specific input scaling
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # Compute ControlNet residuals on the same doubled batch as the U-Net
            down_residuals = None
            mid_residual = None

            for name, condition in control_tensors.items():
                effective_scale = control_scales.get(name, 1.0) * cn_weight_mult
                if effective_scale == 0:
                    continue

                d_res, m_res = self.controlnets[name](
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings,
                    controlnet_cond=condition,
                    conditioning_scale=effective_scale,  # scales every residual
                    return_dict=False
                )

                if down_residuals is None:
                    down_residuals, mid_residual = list(d_res), m_res
                else:
                    down_residuals = [a + b for a, b in zip(down_residuals, d_res)]
                    mid_residual = mid_residual + m_res

            # U-Net prediction; diffusers passes IP-Adapter image embeddings
            # through added_cond_kwargs
            noise_pred = self.pipe.unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                added_cond_kwargs=(
                    {"image_embeds": ip_embeddings} if ip_embeddings is not None else None
                ),
                down_block_additional_residuals=down_residuals,
                mid_block_additional_residual=mid_residual
            ).sample

            # CFG
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # Scheduler step
            latents = scheduler.step(noise_pred, t, latents).prev_sample

        # Decode with the VAE's configured latent scaling factor
        image = self.pipe.vae.decode(latents / self.pipe.vae.config.scaling_factor).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        image = transforms.ToPILImage()(image[0].float().cpu())

        return image
```
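The guidance step at the end of each iteration is a linear extrapolation, ε = ε_uncond + s · (ε_cond − ε_uncond). The NumPy sketch below isolates it on random arrays (the shapes are toy assumptions) and also shows the batch bookkeeping the loop relies on: latents are doubled into an [unconditional, conditional] batch, so any per-sample input passed alongside them, such as a ControlNet condition image, must be doubled too, and the prediction is split back in the same order.

```python
import numpy as np


def classifier_free_guidance(noise_uncond, noise_cond, guidance_scale):
    """eps = eps_uncond + s * (eps_cond - eps_uncond)."""
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)


rng = np.random.default_rng(0)
latents = rng.standard_normal((1, 4, 8, 8))
batched = np.concatenate([latents] * 2)           # [uncond, cond] batch
control = rng.standard_normal((1, 3, 64, 64))
batched_control = np.concatenate([control] * 2)  # condition must match the batch

noise_pred = rng.standard_normal((2, 4, 8, 8))  # stands in for the U-Net output
uncond, cond = np.split(noise_pred, 2)          # same order as torch's chunk(2)
guided = classifier_free_guidance(uncond, cond, 7.5)
print(batched.shape, guided.shape)  # (2, 4, 8, 8) (1, 4, 8, 8)
```

A scale of 1 returns the conditional prediction unchanged and 0 returns the unconditional one; values above 1 push the sample further in the direction the prompt indicates.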

Summary

Multi-modal conditioning enables much finer control over diffusion model generation by combining text, image, and spatial signals. Key takeaways:

  • Modality characteristics matter: text provides semantic guidance, images specify appearance, and spatial conditions constrain structure. Understanding these roles enables effective combination.
  • Decoupled pathways reduce interference: separate cross-attention for different modalities keeps their influences largely independent and controllable.
  • Timestep scheduling separates concerns: applying structural conditions early and stylistic conditions late helps keep layout and appearance from interfering.
  • Per-layer control adds precision: different U-Net depths respond differently to conditioning, enabling fine-grained influence control.
  • Unified frameworks scale: projecting all modalities to a shared space enables arbitrary combinations and cross-modal learning.

In the next chapter, we'll shift perspectives to examine diffusion models through the lens of score-based generative modeling and stochastic differential equations, unifying the discrete-step formulation with continuous-time dynamics.