Chapter 17
Section 76 of 76

Open Problems and Research Directions

The Future of Diffusion Models

Introduction

Despite remarkable progress, diffusion models remain an active area of research with numerous open problems and exciting research directions. From fundamental questions about sampling efficiency and theoretical foundations to practical concerns around safety, controllability, and societal impact, there is much work to be done. This section surveys the most pressing challenges and promising research directions in the field.

Understanding these open problems is crucial for researchers and practitioners alike. For researchers, they represent opportunities for impactful contributions. For practitioners, awareness of current limitations helps set realistic expectations and informs design decisions.

| Research Area | Key Challenge | Current State | Impact |
|---|---|---|---|
| Sampling Speed | Reduce steps without quality loss | 4-step viable, 1-step emerging | Real-time applications |
| Controllability | Fine-grained, composable control | ControlNet, IP-Adapter work well | Creative workflows |
| Theory | Understand why diffusion works | Limited understanding | Better algorithms |
| Safety | Prevent misuse, detect fakes | Active research area | Trust and authenticity |
| Architectures | Beyond U-Net | DiT showing promise | Scaling efficiency |

Faster Sampling Methods

The iterative nature of diffusion sampling remains a fundamental limitation. While significant progress has been made (from 1000 steps to 4-8 steps for comparable quality), the goal of single-step generation with high quality remains elusive.

Towards Single-Step Generation

Current approaches to single-step generation include:

  • Consistency Models: Learn to map any point on the trajectory directly to data, enabling single-step generation
  • Adversarial Distillation: Add discriminator loss to guide the model toward realistic single-step outputs
  • Rectified Flow: Learn straighter trajectories that require fewer discretization steps
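The consistency-model idea rests on a boundary condition: the learned function $f(\mathbf{x}, t)$ must return its input unchanged at the smallest noise level $\epsilon$. The scalar sketch below shows the skip-connection parameterization used for this in the Consistency Models paper; the constants (`SIGMA_DATA = 0.5`, `EPS = 0.002`, `T_MAX = 80.0`) are assumed typical values for images, and `network` stands in for any hypothetical trained model.

```python
import math
from typing import Callable

SIGMA_DATA = 0.5  # assumed standard deviation of the normalized data
EPS = 0.002       # smallest noise level, where the trajectory ends
T_MAX = 80.0      # largest noise level, where sampling starts


def c_skip(t: float) -> float:
    return SIGMA_DATA**2 / ((t - EPS) ** 2 + SIGMA_DATA**2)


def c_out(t: float) -> float:
    return SIGMA_DATA * (t - EPS) / math.sqrt(SIGMA_DATA**2 + t**2)


def consistency_fn(network: Callable[[float, float], float], x: float, t: float) -> float:
    """f(x, t) = c_skip(t) * x + c_out(t) * F(x, t).

    At t = EPS, c_skip = 1 and c_out = 0, so f(x, EPS) = x whatever the network
    outputs: the boundary condition holds by construction, not by training.
    """
    return c_skip(t) * x + c_out(t) * network(x, t)


def one_step_sample(network: Callable[[float, float], float], z: float) -> float:
    """Single-step generation: scale unit noise z to level T_MAX, then apply f once."""
    return consistency_fn(network, T_MAX * z, T_MAX)
```

Training then pushes $f$ to give the same output for neighbouring points on one trajectory; that consistency is what lets `one_step_sample` jump from noise to data in one evaluation.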

The fundamental trade-off is between mode coverage (generating diverse outputs) and single-step quality. Iterative refinement naturally improves sample quality, so reducing steps requires finding alternative mechanisms for this refinement.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader


class RectifiedFlowModel(nn.Module):
    """Rectified flow: learn a velocity field along straight noise-to-data paths.

    Convention: x_0 is noise (t = 0) and x_1 is data (t = 1).
    """

    def __init__(self, backbone: nn.Module):
        super().__init__()
        # Assumed interface: backbone(x_t, t, condition) returns a tensor shaped like x_t
        self.backbone = backbone

    def get_velocity(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        condition: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Predict velocity field v(x_t, t)."""
        return self.backbone(x_t, t, condition)

    def forward_flow(
        self,
        x_0: torch.Tensor,
        x_1: torch.Tensor,
        t: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Compute the interpolant x_t and its target velocity."""
        # Linear interpolation (rectified flow uses straight lines)
        x_t = t * x_1 + (1 - t) * x_0
        # Along a straight line the velocity is constant: the direction from x_0 to x_1
        target_velocity = x_1 - x_0
        return x_t, target_velocity

    @torch.no_grad()
    def sample(
        self,
        x_0: torch.Tensor,  # Starting point (noise)
        num_steps: int = 1,  # 1 step is accurate only once paths are nearly straight (e.g. after reflow)
        condition: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Generate samples by Euler integration of dx/dt = v(x, t) from t = 0 to t = 1."""
        x = x_0
        dt = 1.0 / num_steps
        for i in range(num_steps):
            t = torch.full((x.shape[0],), i / num_steps, device=x.device)
            v = self.get_velocity(x, t, condition)
            x = x + v * dt
        return x


def flow_matching_loss(
    model: RectifiedFlowModel,
    x_0: torch.Tensor,
    x_1: torch.Tensor,
) -> torch.Tensor:
    """MSE between predicted and target velocity at a random time per sample."""
    # Shape [B, 1, 1, 1] so t broadcasts over channels and pixels
    t = torch.rand(x_1.shape[0], 1, 1, 1, device=x_1.device)
    x_t, target_v = model.forward_flow(x_0, x_1, t)
    # flatten() (not squeeze()) keeps shape [B] even when B == 1
    pred_v = model.get_velocity(x_t, t.flatten())
    return F.mse_loss(pred_v, target_v)


def train_rectified_flow(
    model: RectifiedFlowModel,
    dataloader: DataLoader,
    num_epochs: int = 100,
    lr: float = 1e-4,
):
    """Train a rectified flow on independently paired (noise, data) samples."""
    device = next(model.parameters()).device
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for epoch in range(num_epochs):
        for batch in dataloader:
            x_1 = batch["images"].to(device)  # Data samples
            x_0 = torch.randn_like(x_1)       # Noise samples
            loss = flow_matching_loss(model, x_0, x_1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


def train_on_pairs(
    model: RectifiedFlowModel,
    paired_data: list[tuple[torch.Tensor, torch.Tensor]],
    num_epochs: int = 1,
    lr: float = 1e-4,
):
    """Fine-tune on fixed (noise, sample) couplings produced by the model itself."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    model.train()
    for _ in range(num_epochs):
        for x_0, x_1 in paired_data:
            loss = flow_matching_loss(model, x_0, x_1)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


def reflow(
    model: RectifiedFlowModel,
    dataloader: DataLoader,
    num_iterations: int = 3,
    sample_steps: int = 10,
):
    """Iteratively straighten trajectories through reflow."""
    device = next(model.parameters()).device
    for iteration in range(num_iterations):
        # Pair each noise sample with the output the current model maps it to;
        # only the batch shape is taken from the dataset
        model.eval()
        paired_data = []
        for batch in dataloader:
            x_0 = torch.randn(batch["images"].shape, device=device)
            x_1_pred = model.sample(x_0, num_steps=sample_steps)
            paired_data.append((x_0, x_1_pred))

        # Retraining on these deterministic couplings straightens the trajectories
        train_on_pairs(model, paired_data)
        print(f"Reflow iteration {iteration + 1} complete")
```
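Why can `sample` ever use a single step? A minimal scalar illustration, assuming one fixed (noise, data) pair and a velocity known as a function of time only: along a straight path the velocity is constant, so one Euler step lands exactly on the endpoint, while along a curved path (here a hypothetical trigonometric interpolation) one step overshoots badly. In a trained model the velocity also depends on $\mathbf{x}$, and paths from different noise samples can cross and bend the learned field, which is what reflow aims to reduce.

```python
import math
from typing import Callable


def euler_along_path(velocity: Callable[[float], float], x0: float, num_steps: int) -> float:
    """Integrate dx/dt = velocity(t) from t = 0 to t = 1 with Euler steps."""
    x, dt = x0, 1.0 / num_steps
    for i in range(num_steps):
        x += velocity(i * dt) * dt
    return x


X0, X1 = 1.0, 2.0  # one (noise, data) pair, scalars for clarity


def straight_velocity(t: float) -> float:
    # Path x_t = (1 - t) * X0 + t * X1: the velocity does not depend on t
    return X1 - X0


def curved_velocity(t: float) -> float:
    # Path x_t = cos(pi t / 2) * X0 + sin(pi t / 2) * X1, differentiated w.r.t. t
    a = math.pi * t / 2
    return (math.pi / 2) * (-math.sin(a) * X0 + math.cos(a) * X1)


for steps in (1, 4, 16):
    err_straight = abs(euler_along_path(straight_velocity, X0, steps) - X1)
    err_curved = abs(euler_along_path(curved_velocity, X0, steps) - X1)
    print(f"{steps:2d} steps: straight {err_straight:.1e}, curved {err_curved:.1e}")
```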

Optimal Transport and Flow Matching

Flow matching provides an alternative perspective on diffusion that connects to optimal transport. Instead of defining a forward noising process and learning to reverse it, flow matching directly learns a vector field that transports noise to data:

$$\frac{d\mathbf{x}_t}{dt} = v_\theta(\mathbf{x}_t, t)$$

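The optimal transport formulation seeks, among all maps $T$ that push the noise distribution onto the data distribution, the one with minimum expected transport cost (the Monge problem with quadratic cost):

$$\min_{T:\, T_\# p_{\text{noise}} = p_{\text{data}}} \mathbb{E}_{\mathbf{z} \sim p_{\text{noise}}} \left[ \|T(\mathbf{z}) - \mathbf{z}\|^2 \right]$$

Recent work shows that learning straighter paths (closer to optimal transport) enables faster sampling while maintaining quality: a straight path has constant velocity, so an ODE solver can follow it accurately with very few steps.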
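One way to push training toward straighter paths, used by some flow matching variants, is a minibatch OT coupling: before computing the loss, re-pair noise and data within each minibatch so the total squared distance is minimal, which solves the problem above on the minibatch. The sketch below brute-forces the matching over permutations, so it is only practical for tiny batches; a real implementation would use an assignment solver such as `scipy.optimize.linear_sum_assignment`. The function names are illustrative.

```python
import itertools
import random
from typing import Sequence

Point = Sequence[float]


def pairing_cost(noise: list[Point], data: list[Point], perm: Sequence[int]) -> float:
    """Mean squared distance when noise[i] is paired with data[perm[i]]."""
    total = 0.0
    for i, j in enumerate(perm):
        total += sum((a - b) ** 2 for a, b in zip(noise[i], data[j]))
    return total / len(noise)


def minibatch_ot_pairing(noise: list[Point], data: list[Point]) -> list[int]:
    """Exact OT matching within a tiny minibatch, by brute force over permutations."""
    best = min(
        itertools.permutations(range(len(data))),
        key=lambda perm: pairing_cost(noise, data, perm),
    )
    return list(best)


random.seed(0)
noise = [(random.gauss(0, 1), random.gauss(0, 1)) for _ in range(6)]
data = [(random.gauss(3, 0.5), random.gauss(0, 0.5)) for _ in range(6)]
identity = list(range(6))
ot = minibatch_ot_pairing(noise, data)
print(f"random pairing {pairing_cost(noise, data, identity):.2f}, OT pairing {pairing_cost(noise, data, ot):.2f}")
```

The re-paired `(noise[i], data[ot[i]])` tuples would then be fed to the usual flow matching loss in place of random pairs.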


Enhanced Controllability

While current methods like ControlNet and IP-Adapter provide impressive control, significant challenges remain in fine-grained control, compositionality, and semantic editing.

Fine-Grained Spatial Control

Current challenges in spatial control include:

  • Object-level control: Independently controlling multiple objects in a scene
  • Attribute binding: Ensuring attributes (colors, textures) are correctly associated with objects
  • Counting: Generating exactly N objects (diffusion models notoriously struggle with counting)
  • Negation: Reliably excluding specified concepts

```python
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


class CompositionalController:
    """Controller for compositional generation with multiple objects.

    `base_model` is a hypothetical latent-diffusion wrapper assumed to provide:
      encode_text(prompt) -> text embedding
      unet(latent, t, text_emb) -> predicted noise, same shape as latent
      scheduler: a diffusers-style scheduler (set_timesteps, timesteps, step)
      decode_latent(latent) -> image
    """

    def __init__(self, base_model: nn.Module):
        self.base_model = base_model

    @torch.no_grad()
    def generate_compositional(
        self,
        object_descriptions: list[dict],
        # Each dict: {"prompt": str, "bbox": (x1, y1, x2, y2)}, coordinates in [0, 1]
        background_prompt: str,
        num_steps: int = 50,
    ) -> torch.Tensor:
        """Generate an image with one region-restricted prompt per object."""
        model = self.base_model
        scheduler = model.scheduler
        device = next(model.parameters()).device

        # Initialize latent (64 x 64, e.g. 512 x 512 pixels with an 8x VAE)
        latent = torch.randn(1, 4, 64, 64, device=device) * scheduler.init_noise_sigma

        # Encode prompts
        bg_emb = model.encode_text(background_prompt)
        obj_embs = [model.encode_text(obj["prompt"]) for obj in object_descriptions]

        # Soft spatial masks, computed once since they do not change during sampling
        masks = [
            self.smooth_mask(self.create_bbox_mask(obj["bbox"], size=(64, 64))).to(device)
            for obj in object_descriptions
        ]

        # Sampling with compositional guidance, from most to least noisy timestep
        scheduler.set_timesteps(num_steps, device=device)
        for t in scheduler.timesteps:
            t_tensor = t.reshape(1)

            # Background prediction everywhere, object predictions inside their masks
            composed_noise = model.unet(latent, t_tensor, bg_emb)
            for mask, obj_emb in zip(masks, obj_embs):
                noise_obj = model.unet(latent, t_tensor, obj_emb)
                # Smooth blending at boundaries; later objects win where boxes overlap
                composed_noise = composed_noise * (1 - mask) + noise_obj * mask

            # One denoising step with the composed prediction
            latent = scheduler.step(composed_noise, t, latent).prev_sample

        return model.decode_latent(latent)

    def create_bbox_mask(
        self,
        bbox: tuple[float, float, float, float],
        size: tuple[int, int],
    ) -> torch.Tensor:
        """Create a binary [1, 1, H, W] mask from a normalized bounding box."""
        x1, y1, x2, y2 = bbox
        h, w = size
        mask = torch.zeros(1, 1, h, w)
        mask[:, :, int(y1 * h):int(y2 * h), int(x1 * w):int(x2 * w)] = 1.0
        return mask

    def smooth_mask(self, mask: torch.Tensor, sigma: float = 2.0) -> torch.Tensor:
        """Apply a Gaussian blur for smooth blending at mask boundaries."""
        kernel_size = int(6 * sigma) | 1  # Bitwise OR with 1 makes the size odd
        return TF.gaussian_blur(mask, kernel_size=kernel_size, sigma=sigma)
```

Semantic and Attribute Editing

Image editing with diffusion models remains challenging, particularly for:

  • Local edits: Changing only specific regions while preserving the rest
  • Attribute manipulation: Changing age, expression, or style while preserving identity
  • Object insertion/removal: Adding or removing objects coherently

Current approaches include inversion-based methods (DDIM inversion, null-text inversion), attention manipulation (prompt-to-prompt), and specialized fine-tuning. However, these often fail for complex edits or introduce artifacts.

```python
import torch
import torch.nn as nn


class SemanticEditor:
    """Semantic image editing using DDIM inversion and attention injection.

    `model` is a hypothetical latent-diffusion wrapper assumed to provide:
      vae: a diffusers-style AutoencoderKL
      unet(latent, t, text_emb) -> predicted noise
      encode_text(prompt) -> text embedding
      scheduler.alphas_cumprod: 1-D tensor over the training timesteps
      store_attention(), get_stored_attention(), inject_attention(maps):
        hooks that record or override the UNet's cross-attention maps
    """

    VAE_SCALE = 0.18215  # Stable Diffusion v1 latent scaling factor

    def __init__(self, model: nn.Module, num_inference_steps: int = 50):
        self.model = model
        self.num_inference_steps = num_inference_steps
        num_train_steps = len(model.scheduler.alphas_cumprod)
        # Ascending: timesteps[0] = 0 (nearly clean), timesteps[-1] = last training step
        self.timesteps = torch.linspace(0, num_train_steps - 1, num_inference_steps).long()

    def ddim_move(
        self,
        latent: torch.Tensor,
        noise_pred: torch.Tensor,
        t_from: int | torch.Tensor,
        t_to: int | torch.Tensor,
    ) -> torch.Tensor:
        """Deterministic DDIM update from t_from to t_to (works in either direction)."""
        alphas = self.model.scheduler.alphas_cumprod.to(latent.device)
        a_from, a_to = alphas[int(t_from)], alphas[int(t_to)]
        # Predict the clean latent, then re-noise it to level t_to with the same noise
        x0_pred = (latent - torch.sqrt(1 - a_from) * noise_pred) / torch.sqrt(a_from)
        return torch.sqrt(a_to) * x0_pred + torch.sqrt(1 - a_to) * noise_pred

    @torch.no_grad()
    def ddim_inversion(
        self,
        image: torch.Tensor,
        prompt: str,
    ) -> list[torch.Tensor]:
        """Invert an image to a latent trajectory; trajectory[i] sits at timesteps[i]."""
        # Encode image
        latent = self.model.vae.encode(image).latent_dist.mean * self.VAE_SCALE
        text_emb = self.model.encode_text(prompt)
        trajectory = [latent]

        # Run DDIM in reverse: from low noise towards high noise
        for i in range(self.num_inference_steps - 1):
            t_cur, t_next = self.timesteps[i], self.timesteps[i + 1]
            t_tensor = t_cur.reshape(1).to(latent.device)
            noise_pred = self.model.unet(latent, t_tensor, text_emb)
            latent = self.ddim_move(latent, noise_pred, t_cur, t_next)
            trajectory.append(latent)

        return trajectory

    @torch.no_grad()
    def edit_with_attention_replacement(
        self,
        source_image: torch.Tensor,
        source_prompt: str,
        target_prompt: str,
        edit_strength: float = 0.8,
    ) -> torch.Tensor:
        """Edit an image by reusing the source attention maps.

        edit_strength is the fraction of (early, noisy) denoising steps that
        inject source attention; higher values preserve more of the layout.
        """
        source_trajectory = self.ddim_inversion(source_image, source_prompt)
        source_emb = self.model.encode_text(source_prompt)
        target_emb = self.model.encode_text(target_prompt)

        # Start from inverted noise
        latent = source_trajectory[-1]
        edit_idx = int(self.num_inference_steps * (1 - edit_strength))

        for i in reversed(range(1, self.num_inference_steps)):
            t_cur, t_prev = self.timesteps[i], self.timesteps[i - 1]
            t_tensor = t_cur.reshape(1).to(latent.device)

            # Record the source attention maps at the same noise level
            with self.model.store_attention():
                self.model.unet(source_trajectory[i], t_tensor, source_emb)
                source_attns = self.model.get_stored_attention()

            if i > edit_idx:
                # Early steps: use source attention for structure
                with self.model.inject_attention(source_attns):
                    noise_pred = self.model.unet(latent, t_tensor, target_emb)
            else:
                # Later steps: use target attention for appearance
                noise_pred = self.model.unet(latent, t_tensor, target_emb)

            # Deterministic DDIM step, mirroring the inversion
            latent = self.ddim_move(latent, noise_pred, t_cur, t_prev)

        return self.model.vae.decode(latent / self.VAE_SCALE).sample
```

Theoretical Understanding

Despite their empirical success, our theoretical understanding of diffusion models remains limited. Several fundamental questions remain open:

Convergence and Sample Complexity

Key theoretical questions include:

  • Sample complexity: How many samples are needed to learn the score function to a given accuracy?
  • Convergence rates: How fast does the generated distribution converge to the true data distribution?
  • Discretization error: What is the error introduced by using finite sampling steps?

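Recent theoretical work has established convergence bounds under various assumptions (for example, a Lipschitz score and finite second moments of the data). If the score function is learned with $L^2$ error $\epsilon$, the contribution of score error to the total variation distance between generated and true distributions scales as:

$$\text{TV}(p_{\text{gen}}, p_{\text{data}}) \lesssim C \cdot \epsilon \cdot \sqrt{T}$$

where $T$ is the time horizon of the forward process (for a fixed step size, proportional to the number of steps) and $C$ is a constant. The full bounds add separate terms for discretization error, which shrinks with the step size, and for starting the reverse process from a Gaussian instead of the exact noised distribution, which decays exponentially in $T$.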
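The discretization term can be isolated in a toy case where everything else is exact. The sketch below assumes 1-D Gaussian data $\mathcal{N}(\mu, s^2)$, the Ornstein-Uhlenbeck forward process $dx = -x\,dt + \sqrt{2}\,dW$, and the exact score, so the only error comes from Euler-discretizing the probability-flow ODE. In 1-D that ODE's solution is the monotone (affine) map between the two Gaussian marginals, which gives an exact reference.

```python
import math


def marginal(t: float, mu: float, s: float) -> tuple[float, float]:
    """Mean and variance of x_t = e^{-t} x_0 + sqrt(1 - e^{-2t}) z with x_0 ~ N(mu, s^2)."""
    decay = math.exp(-t)
    return mu * decay, s**2 * decay**2 + 1 - decay**2


def pf_ode_velocity(x: float, t: float, mu: float, s: float) -> float:
    """Probability-flow ODE: dx/dt = f(x) - 0.5 g^2 * score = -x - score(x, t)."""
    m, v = marginal(t, mu, s)
    score = -(x - m) / v  # exact score of the Gaussian marginal
    return -x - score


def euler_sample(x_T: float, T: float, num_steps: int, mu: float, s: float) -> float:
    """Integrate the probability-flow ODE backwards from t = T to t = 0."""
    x, dt = x_T, T / num_steps
    for i in range(num_steps):
        t = T - i * dt
        x -= pf_ode_velocity(x, t, mu, s) * dt
    return x


def exact_sample(x_T: float, T: float, mu: float, s: float) -> float:
    """Exact ODE solution: the affine map between the two Gaussian marginals."""
    m, v = marginal(T, mu, s)
    return mu + (x_T - m) * s / math.sqrt(v)


mu, s, T, x_T = 2.0, 0.5, 3.0, -1.0
target = exact_sample(x_T, T, mu, s)
for n in (50, 500, 5000):
    print(f"{n:5d} steps: error {abs(euler_sample(x_T, T, n, mu, s) - target):.2e}")
```

The error shrinks roughly in proportion to the step size, as expected for a first-order solver; with a learned model, score error is added on top of this.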

Score Estimation Theory

The score function $\nabla_{\mathbf{x}} \log p_t(\mathbf{x})$ is central to diffusion models. Open questions include:

  • Curse of dimensionality: How does score estimation scale with dimension? Can we exploit structure to avoid exponential complexity?
  • Architecture choices: Why do U-Nets work well for score estimation? What inductive biases are important?
  • Low-density regions: Score estimation is unstable in low-density regions. How can we improve robustness?

```python
from typing import Callable

import torch
import torch.nn as nn


@torch.no_grad()
def analyze_score_estimation_error(
    model: nn.Module,
    true_score: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
    test_data: torch.Tensor,
    timesteps: torch.Tensor,
    alphas_cumprod: torch.Tensor,
) -> dict:
    """Analyze score estimation error across timesteps.

    Requires true_score(x_t, t) to be known analytically (e.g. for synthetic
    data). model(x_t, t) is assumed to predict the added noise (epsilon).
    """
    errors = []
    batch_size = test_data.shape[0]

    for t in timesteps:
        # Add noise at timestep t
        alpha_bar_t = alphas_cumprod[int(t)].to(test_data.device)
        noise = torch.randn_like(test_data)
        x_t = torch.sqrt(alpha_bar_t) * test_data + torch.sqrt(1 - alpha_bar_t) * noise

        # True score (analytic)
        true_s = true_score(x_t, t)

        # Predicted score: for the epsilon parameterization, score = -eps / sqrt(1 - alpha_bar_t)
        t_batch = torch.full((batch_size,), int(t), device=test_data.device)
        pred_noise = model(x_t, t_batch)
        pred_s = -pred_noise / torch.sqrt(1 - alpha_bar_t)

        mse = ((pred_s - true_s) ** 2).mean()
        errors.append({"timestep": int(t), "mse": mse.item()})

    worst = max(errors, key=lambda e: e["mse"])
    return {
        "errors": errors,
        "mean_error": sum(e["mse"] for e in errors) / len(errors),
        "max_error_timestep": worst["timestep"],
    }
```
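`analyze_score_estimation_error` needs a `true_score`, which is only available analytically for synthetic data. Gaussian mixtures are a convenient choice because noising keeps them Gaussian mixtures: under $\mathbf{x}_t = \sqrt{\bar\alpha_t}\,\mathbf{x}_0 + \sqrt{1-\bar\alpha_t}\,\boldsymbol{\epsilon}$, each component $\mathcal{N}(m, s^2)$ becomes $\mathcal{N}(\sqrt{\bar\alpha_t}\,m,\ \bar\alpha_t s^2 + 1 - \bar\alpha_t)$. A scalar sketch (a batched PyTorch version would use the same formulas; the function names are illustrative):

```python
import math
from typing import Iterator, Sequence


def _noised_components(
    alpha_bar: float, weights: Sequence[float], means: Sequence[float], stds: Sequence[float]
) -> Iterator[tuple[float, float, float]]:
    """Yield (weight, mean, variance) of each component after noising."""
    for w, m, s in zip(weights, means, stds):
        yield w, math.sqrt(alpha_bar) * m, alpha_bar * s**2 + 1 - alpha_bar


def mixture_log_density(x, alpha_bar, weights, means, stds) -> float:
    """log p_t(x) of the noised 1-D mixture, computed with log-sum-exp."""
    logs = [
        math.log(w) - 0.5 * math.log(2 * math.pi * v) - (x - m) ** 2 / (2 * v)
        for w, m, v in _noised_components(alpha_bar, weights, means, stds)
    ]
    top = max(logs)
    return top + math.log(sum(math.exp(lp - top) for lp in logs))


def mixture_score(x, alpha_bar, weights, means, stds) -> float:
    """d/dx log p_t(x): component scores weighted by posterior responsibilities."""
    comps = list(_noised_components(alpha_bar, weights, means, stds))
    logs = [math.log(w) - 0.5 * math.log(v) - (x - m) ** 2 / (2 * v) for w, m, v in comps]
    top = max(logs)
    resp = [math.exp(lp - top) for lp in logs]
    total = sum(resp)
    return sum(r / total * (-(x - m) / v) for r, (_, m, v) in zip(resp, comps))
```

The noise prediction that minimizes the training loss is $-\sqrt{1-\bar\alpha_t}$ times this score, the same relation used to convert `pred_noise` into `pred_s` above.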

Safety and Ethics

The ability of diffusion models to generate highly realistic content raises significant safety and ethical concerns. These include misinformation through deepfakes, non-consensual imagery, copyright issues, and bias in generated content.

Content Authenticity and Detection

Detecting AI-generated content is an active research area. Current approaches include:

  • Statistical artifacts: AI-generated images may have detectable patterns in the frequency domain or in pixel statistics
  • Learned detectors: Train classifiers to distinguish real from generated content
  • Fingerprinting: Identify specific models based on their generation artifacts

```python
import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader


class AIImageDetector(nn.Module):
    """Detector for AI-generated images: RGB backbone plus a frequency branch."""

    def __init__(
        self,
        backbone: str = "resnet50",
        num_classes: int = 2,  # Real vs AI-generated
    ):
        super().__init__()

        # Pretrained ResNet-style backbone (get_model needs torchvision >= 0.14)
        self.backbone = torchvision.models.get_model(backbone, weights="DEFAULT")

        # Replace the ImageNet classifier with a 256-d feature head
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        # Frequency analysis branch
        self.freq_branch = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(64, 32),
        )

        # Combine branches: 256 backbone features + 32 frequency features
        self.classifier = nn.Linear(256 + 32, num_classes)

    def compute_frequency_features(self, x: torch.Tensor) -> torch.Tensor:
        """Log-magnitude spectrum of the grayscale image, repeated to 3 channels."""
        gray = 0.299 * x[:, 0] + 0.587 * x[:, 1] + 0.114 * x[:, 2]

        # 2D FFT over the spatial dims only (fftshift's default would also shift the batch dim)
        fft = torch.fft.fft2(gray)
        fft_shift = torch.fft.fftshift(fft, dim=(-2, -1))
        # log1p compresses the huge dynamic range of FFT magnitudes
        magnitude = torch.log1p(torch.abs(fft_shift))

        return magnitude.unsqueeze(1).expand(-1, 3, -1, -1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        backbone_out = self.backbone(x)  # [B, 256]
        freq_out = self.freq_branch(self.compute_frequency_features(x))  # [B, 32]
        combined = torch.cat([backbone_out, freq_out], dim=1)
        return self.classifier(combined)


@torch.no_grad()
def evaluate_detector(
    detector: AIImageDetector,
    real_dataloader: DataLoader,
    fake_dataloader: DataLoader,
) -> float:
    """Accuracy over real (label 0) and generated (label 1) images."""
    device = next(detector.parameters()).device
    detector.eval()
    correct = total = 0
    for loader, label in ((real_dataloader, 0), (fake_dataloader, 1)):
        for batch in loader:
            images = batch["image"].to(device)
            preds = detector(images).argmax(dim=1)
            correct += (preds == label).sum().item()
            total += images.shape[0]
    return correct / max(total, 1)


def train_detector(
    detector: AIImageDetector,
    real_dataloader: DataLoader,
    fake_dataloader: DataLoader,
    num_epochs: int = 10,
    val_real_dataloader: DataLoader | None = None,
    val_fake_dataloader: DataLoader | None = None,
):
    """Train AI image detector."""
    device = next(detector.parameters()).device
    optimizer = torch.optim.Adam(detector.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        detector.train()
        # zip stops at the shorter loader; every step sees both classes
        for real_batch, fake_batch in zip(real_dataloader, fake_dataloader):
            real_images = real_batch["image"].to(device)
            fake_images = fake_batch["image"].to(device)

            # Create labels
            real_labels = torch.zeros(real_images.shape[0], dtype=torch.long, device=device)
            fake_labels = torch.ones(fake_images.shape[0], dtype=torch.long, device=device)

            images = torch.cat([real_images, fake_images])
            labels = torch.cat([real_labels, fake_labels])

            loss = criterion(detector(images), labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report accuracy on held-out data (training accuracy would be optimistic)
        if val_real_dataloader is not None and val_fake_dataloader is not None:
            accuracy = evaluate_detector(detector, val_real_dataloader, val_fake_dataloader)
            print(f"Epoch {epoch}, Accuracy: {accuracy:.2%}")
```
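The "statistical artifacts" approach is often illustrated with the azimuthally averaged power spectrum: upsampling layers in generators can leave periodic patterns that appear as excess energy at high frequencies. The NumPy sketch below computes that 1-D profile as an illustrative hand-crafted feature, separate from the learned detector above; the synthetic checkerboard is an assumed stand-in for such an artifact.

```python
import numpy as np


def radial_power_spectrum(gray: np.ndarray, num_bins: int = 32) -> np.ndarray:
    """Azimuthally averaged log power spectrum of a 2-D grayscale image.

    Bin 0 holds the lowest frequencies; the last bin holds the highest
    (the corners of the centred spectrum).
    """
    h, w = gray.shape
    spectrum = np.fft.fftshift(np.fft.fft2(gray - gray.mean()))
    power = np.log1p(np.abs(spectrum) ** 2)
    # Distance of every frequency from the centre, where fftshift puts DC
    yy, xx = np.indices((h, w))
    radius = np.hypot(yy - h // 2, xx - w // 2)
    ring = np.minimum((radius / radius.max() * num_bins).astype(int), num_bins - 1)
    # Average the power over each ring
    sums = np.bincount(ring.ravel(), weights=power.ravel(), minlength=num_bins)
    counts = np.bincount(ring.ravel(), minlength=num_bins)
    return sums / np.maximum(counts, 1)


rows, cols = np.indices((64, 64))
smooth = cols / 63.0                                 # smooth horizontal ramp
artifact = smooth + 0.1 * (-1.0) ** (rows + cols)    # plus a faint checkerboard
print("highest-frequency bin:", radial_power_spectrum(smooth)[-1], radial_power_spectrum(artifact)[-1])
```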

Watermarking Generated Content

Watermarking embeds imperceptible signals in generated content to enable later identification. Key challenges include:

  • Robustness: Watermarks should survive common transformations (compression, cropping, screenshots)
  • Invisibility: Watermarks should not affect visual quality
  • Capacity: Encode sufficient information (model ID, timestamp, user ID)

```python
import torch
import torch.nn as nn


class DiffusionWatermarker:
    """Watermarking for diffusion-generated images (sketch).

    WatermarkEncoder and WatermarkDecoder are hypothetical learned networks:
      encoder(bits [1, message_length]) -> pattern shaped like the latent [1, 4, 64, 64]
      decoder(image [1, 3, H, W]) -> one logit per bit [1, message_length]
    `key` is stored for keyed variants (e.g. seeding the pattern); this sketch does not use it.
    """

    def __init__(
        self,
        key: bytes,
        message_length: int = 48,  # 48 bits = 6 ASCII characters
    ):
        self.key = key
        self.message_length = message_length

        # Watermark encoder/decoder networks (hypothetical, see docstring)
        self.encoder = WatermarkEncoder(message_length)
        self.decoder = WatermarkDecoder(message_length)

    @torch.no_grad()
    def embed_during_generation(
        self,
        model: nn.Module,
        prompt: str,
        message: str,  # Message to embed
        num_steps: int = 50,
    ) -> torch.Tensor:
        """Embed watermark during diffusion sampling.

        `model` is assumed to expose encode_text, unet(latent, t, emb) -> noise,
        a diffusers-style `scheduler`, and a diffusers-style `vae`.
        """
        device = next(model.parameters()).device

        # Encode message to bits and generate the watermark pattern
        binary_message = self.string_to_binary(message)
        watermark = self.encoder(binary_message.unsqueeze(0).to(device))

        # Initialize latent
        latent = torch.randn(1, 4, 64, 64, device=device)
        text_emb = model.encode_text(prompt)

        model.scheduler.set_timesteps(num_steps, device=device)
        for i, t in enumerate(model.scheduler.timesteps):
            noise_pred = model.unet(latent, t.reshape(1), text_emb)

            # Inject watermark in the early (noisiest) 70% of steps (survives better)
            if i < 0.7 * num_steps:
                latent = latent + 0.01 * watermark

            latent = model.scheduler.step(noise_pred, t, latent).prev_sample

        return model.vae.decode(latent / 0.18215).sample

    @torch.no_grad()
    def detect_watermark(self, image: torch.Tensor, threshold: float = 0.9) -> dict:
        """Decode the embedded bits; `threshold` is an illustrative confidence cutoff."""
        probs = torch.sigmoid(self.decoder(image).squeeze(0))  # logits -> probabilities
        bits = (probs > 0.5).float()

        # Mean probability of the decoded bit values: 0.5 = chance, 1.0 = certain
        confidence = torch.where(bits == 1, probs, 1 - probs).mean().item()
        if confidence < threshold:
            return {"detected": False, "confidence": confidence}
        return {
            "detected": True,
            "message": self.binary_to_string(bits),
            "confidence": confidence,
        }

    def string_to_binary(self, s: str) -> torch.Tensor:
        """Encode an ASCII string as exactly `message_length` bits (zero-padded)."""
        bits = "".join(format(ord(c), "08b") for c in s)
        if len(bits) > self.message_length:
            raise ValueError(f"message needs {len(bits)} bits; capacity is {self.message_length}")
        bits = bits.ljust(self.message_length, "0")
        return torch.tensor([int(b) for b in bits], dtype=torch.float32)

    def binary_to_string(self, binary: torch.Tensor) -> str:
        """Decode bits to text, dropping the zero padding."""
        bits = "".join(str(int(b)) for b in binary.flatten())
        chars = [chr(int(bits[i:i + 8], 2)) for i in range(0, len(bits) - 7, 8)]
        return "".join(chars).rstrip("\x00")
```
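A decoded bit string still has to be turned into a yes/no decision. One common approach, sketched here under the assumption that bits decoded from an unwatermarked image are independent fair coin flips, is to count how many bits match the expected message and choose the smallest match threshold whose false-positive rate stays below a target. The function names are illustrative.

```python
from math import comb


def false_positive_rate(num_bits: int, threshold: int) -> float:
    """P(at least `threshold` of `num_bits` decoded bits match by chance).

    Assumes each bit from an unwatermarked image matches the expected
    message independently with probability 1/2 (a binomial tail).
    """
    tail = sum(comb(num_bits, k) for k in range(threshold, num_bits + 1))
    return tail / 2**num_bits


def min_threshold(num_bits: int, target_fpr: float) -> int:
    """Smallest number of matching bits whose false-positive rate is <= target_fpr."""
    # threshold = num_bits + 1 has rate 0, so the search always terminates
    return next(
        t for t in range(num_bits + 2) if false_positive_rate(num_bits, t) <= target_fpr
    )


def bit_matches(decoded: list[int], expected: list[int]) -> int:
    """Number of positions where the decoded bits agree with the expected message."""
    return sum(d == e for d, e in zip(decoded, expected))


threshold = min_threshold(48, 1e-6)
print(f"Declare a 48-bit watermark present if at least {threshold} bits match")
```

This ties capacity to reliability: more bits allow a lower false-positive rate at the same fraction of correctly decoded bits.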

Bias and Fairness

Diffusion models inherit and can amplify biases present in training data:

  • Demographic bias: Over- or under-representation of certain groups
  • Stereotyping: Associating occupations, attributes, or behaviors with specific demographics
  • Geographic bias: Western-centric training data leads to limited diversity

Measuring Bias: Systematic evaluation requires generating large samples across demographic prompts and measuring representation statistics. Projects like DALL-E's fairness evaluations provide frameworks for such audits.
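The representation statistics mentioned above can be summarized by comparing the observed share of each group among generated images with a reference distribution. A minimal sketch using total variation distance; the labels are assumed to come from annotators or a separate (hypothetical) attribute classifier, and the reference distribution is a choice the auditor must justify.

```python
from collections import Counter


def representation_stats(labels: list[str], reference: dict[str, float]) -> dict:
    """Compare observed group shares with a reference distribution.

    labels: one group label per generated image, for a single prompt
    reference: expected share per group, summing to 1
    """
    counts = Counter(labels)
    groups = set(reference) | set(counts)
    observed = {g: counts[g] / len(labels) for g in groups}
    # Total variation distance: 0 = identical shares, 1 = no overlap at all
    tv = 0.5 * sum(abs(observed[g] - reference.get(g, 0.0)) for g in groups)
    return {"observed": observed, "tv_distance": tv}


# e.g. 100 images for one occupation prompt, each labelled with a group
labels = ["group_a"] * 75 + ["group_b"] * 25
stats = representation_stats(labels, {"group_a": 0.5, "group_b": 0.5})
print(stats["observed"], round(stats["tv_distance"], 3))
```

Repeating this per prompt (for example, across a list of occupations) gives a table of distances that can be tracked across model versions.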

Emerging Architectures

While U-Net has been the dominant architecture, new designs are emerging that may offer better scaling properties and generation quality.

Diffusion Transformers (DiT)

Diffusion Transformers (Peebles & Xie, 2023) replace U-Net with a pure transformer architecture, showing better scaling behavior and enabling the use of techniques from language models:

```python
import math

import torch
import torch.nn as nn


class SinusoidalEmbedding(nn.Module):
    """Sinusoidal timestep features, as in Transformer position encodings."""

    def __init__(self, dim: int, max_period: float = 10000.0):
        super().__init__()
        self.dim = dim  # assumed even
        self.max_period = max_period

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(half, device=t.device, dtype=torch.float32) / half
        )
        args = t.float()[:, None] * freqs[None, :]  # [B, half]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # [B, dim]


class DiT(nn.Module):
    """Diffusion Transformer architecture."""

    def __init__(
        self,
        input_size: int = 32,  # Latent spatial size
        patch_size: int = 2,
        in_channels: int = 4,
        hidden_size: int = 1152,
        depth: int = 28,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        class_dropout_prob: float = 0.1,
        num_classes: int = 1000,
        freq_dim: int = 256,
    ):
        super().__init__()
        self.input_size = input_size
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.class_dropout_prob = class_dropout_prob
        self.num_patches = (input_size // patch_size) ** 2

        # Patch embedding: non-overlapping p x p patches -> tokens
        self.patch_embed = nn.Conv2d(
            in_channels, hidden_size,
            kernel_size=patch_size, stride=patch_size
        )

        # Learned position embedding (the DiT paper uses fixed 2-D sin-cos embeddings)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches, hidden_size))
        nn.init.normal_(self.pos_embed, std=0.02)

        # Time embedding: sinusoidal features -> MLP
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(freq_dim),
            nn.Linear(freq_dim, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

        # Class embedding; index `num_classes` is the "null" label used when the
        # label is dropped for classifier-free guidance
        self.class_embed = nn.Embedding(num_classes + 1, hidden_size)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size=hidden_size, num_heads=num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        ])

        # Final layer
        self.final_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.final_linear = nn.Linear(hidden_size, patch_size ** 2 * in_channels)

        # AdaLN modulation for final layer (shift, scale)
        self.final_adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size),
        )

        # Zero-init the output layers so the untrained model predicts zeros
        for layer in (self.final_adaLN[-1], self.final_linear):
            nn.init.zeros_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(
        self,
        x: torch.Tensor,      # [B, C, H, W] noisy latent
        t: torch.Tensor,      # [B] timesteps
        y: torch.Tensor,      # [B] class labels
    ) -> torch.Tensor:
        # Label dropout for classifier-free guidance (training only)
        if self.training and self.class_dropout_prob > 0:
            drop = torch.rand(y.shape[0], device=y.device) < self.class_dropout_prob
            y = torch.where(drop, torch.full_like(y, self.num_classes), y)

        # Patchify and embed
        x = self.patch_embed(x)  # [B, hidden, H/p, W/p]
        x = x.flatten(2).transpose(1, 2)  # [B, num_patches, hidden]
        x = x + self.pos_embed

        # Conditioning: one vector per sample
        c = self.time_embed(t) + self.class_embed(y)  # [B, hidden]

        # Apply transformer blocks
        for block in self.blocks:
            x = block(x, c)

        # Final layer with AdaLN
        shift, scale = self.final_adaLN(c).chunk(2, dim=-1)
        x = self.final_norm(x)
        x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = self.final_linear(x)  # [B, num_patches, p * p * C]

        return self.unpatchify(x)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Convert patch tokens [B, N, p*p*C] back to an image [B, C, H, W]."""
        p = self.patch_size
        c = self.in_channels
        h = w = self.input_size // p
        x = x.reshape(-1, h, w, p, p, c)
        x = torch.einsum("bhwpqc->bchpwq", x)
        return x.reshape(-1, c, h * p, w * p)


class DiTBlock(nn.Module):
    """DiT transformer block with AdaLN-Zero."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()

        # Pre-norm (no affine params - modulated by AdaLN)
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)

        # Self-attention
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)

        # MLP
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden, hidden_size),
        )

        # AdaLN modulation (produces shift, scale, gate for attn and mlp)
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size),
        )
        # The "Zero" in AdaLN-Zero: gates start at 0, so every block starts as the identity
        nn.init.zeros_(self.adaLN[-1].weight)
        nn.init.zeros_(self.adaLN[-1].bias)

    def forward(
        self,
        x: torch.Tensor,  # [B, N, hidden] tokens
        c: torch.Tensor,  # [B, hidden] conditioning (time + class)
    ) -> torch.Tensor:
        # Get modulation parameters
        shift_attn, scale_attn, gate_attn, shift_mlp, scale_mlp, gate_mlp = (
            self.adaLN(c).chunk(6, dim=-1)
        )

        # Attention with AdaLN
        h = self.norm1(x)
        h = h * (1 + scale_attn.unsqueeze(1)) + shift_attn.unsqueeze(1)
        h = self.attn(h, h, h, need_weights=False)[0]
        x = x + gate_attn.unsqueeze(1) * h

        # MLP with AdaLN
        h = self.norm2(x)
        h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        h = self.mlp(h)
        x = x + gate_mlp.unsqueeze(1) * h

        return x
```
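Because a DiT operates on patch tokens, its sequence length, and with it the quadratic cost of self-attention, is set by the latent size and the patch size. The arithmetic below assumes an 8x-downsampling VAE (so a 256 × 256 image gives the 32 × 32 latent of the constructor defaults above) and counts query-key pairs in one attention layer; the helper names are illustrative.

```python
def num_tokens(image_size: int, patch_size: int = 2, vae_factor: int = 8) -> int:
    """Number of patch tokens for a square image encoded by a VAE."""
    latent_size = image_size // vae_factor
    return (latent_size // patch_size) ** 2


def attention_pairs(tokens: int) -> int:
    """Query-key pairs in one full self-attention layer: quadratic in tokens."""
    return tokens * tokens


for image_size in (256, 512, 1024):
    for patch_size in (2, 4, 8):
        n = num_tokens(image_size, patch_size)
        print(f"{image_size}px, p={patch_size}: {n} tokens, {attention_pairs(n):,} attention pairs")
```

Doubling the resolution multiplies the token count by 4 and the attention cost by 16, while a larger patch size trades spatial detail for fewer tokens.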

State Space Models for Diffusion

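State space models (SSMs) such as Mamba offer linear-time alternatives to attention that could enable efficient processing of very long sequences, which benefits high-resolution image and video generation:

| Architecture | Token Mixing | Cost in Sequence Length n | Best For |
|---|---|---|---|
| U-Net | Convolutions, plus attention at low resolutions | O(n) for convolutions; O(n²) in attention layers | Most current models |
| Transformer (DiT) | Global self-attention | O(n²) | Large-scale training |
| SSM (Mamba) | Linear recurrence (scan) | O(n) | Long sequences, video |
| Hybrid | Mixed | Varies | Emerging approaches |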
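The linear scaling of SSMs comes from their recurrent form. Below is a minimal NumPy sketch of a time-invariant, diagonal linear SSM with a single input channel, assuming its parameters are already discretized; it checks that the O(L) recurrence matches the equivalent convolution. Mamba additionally makes the step size and the B and C projections depend on the input (its "selective" mechanism) and computes the recurrence with a parallel scan; both are omitted here.

```python
import numpy as np


def ssm_scan(x: np.ndarray, a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
    """Recurrent form: h_t = a * h_{t-1} + b * x_t, y_t = c . h_t.

    x has shape (L,); a, b, c have shape (N,), one entry per diagonal state.
    One pass over the sequence: O(L * N) time and O(N) memory for the state.
    """
    h = np.zeros(a.shape)
    y = np.empty(len(x))
    for t, x_t in enumerate(x):
        h = a * h + b * x_t
        y[t] = c @ h
    return y


def ssm_convolution(x: np.ndarray, a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
    """Equivalent convolutional form with kernel K_k = c . (a**k * b); O(L^2) as written."""
    L = len(x)
    kernel = np.array([c @ (a**k * b) for k in range(L)])
    return np.array([kernel[: t + 1] @ x[t::-1] for t in range(L)])


rng = np.random.default_rng(0)
x = rng.standard_normal(20)
a = rng.uniform(0.5, 0.95, size=4)  # |a| < 1 keeps the recurrence stable
b, c = rng.standard_normal(4), rng.standard_normal(4)
print(np.allclose(ssm_scan(x, a, b, c), ssm_convolution(x, a, b, c)))
```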

Future Outlook

The field of diffusion models is evolving rapidly. Several trends suggest where research may head:

  1. Unified multimodal models: Models that handle text, images, audio, video, and 3D in a single framework
  2. World models: Using diffusion for learning world dynamics and planning in robotics and autonomous systems
  3. Scientific applications: Accelerating drug discovery, materials design, and protein engineering
  4. Real-time generation: Enabling interactive creative tools and responsive AI systems
  5. Personalization: Efficiently adapting models to individual preferences and styles

The Bigger Picture: Diffusion models represent more than just a generative modeling technique. They provide a framework for learning complex distributions through iterative refinement, a principle that may extend to many problems beyond content generation, including optimization, inference, and decision-making.

References

Key papers on open problems and future directions:

  • Song et al. (2023). "Consistency Models" - Single-step generation
  • Liu et al. (2023). "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow" - Rectified flow and reflow
  • Peebles & Xie (2023). "Scalable Diffusion Models with Transformers" - DiT architecture
  • Lipman et al. (2023). "Flow Matching for Generative Modeling" - Optimal transport perspective
  • Karras et al. (2024). "Analyzing and Improving the Training Dynamics of Diffusion Models"
  • Zhang et al. (2023). "Adding Conditional Control to Text-to-Image Diffusion Models" (ControlNet)
  • Hertz et al. (2023). "Prompt-to-Prompt Image Editing with Cross-Attention Control"
  • Wen et al. (2023). "Tree-Ring Watermarks: Fingerprints for Diffusion Images that are Invisible and Robust"

Contributing to the Field

Diffusion models are an active area with many opportunities for impactful contributions. Whether improving sampling efficiency, enhancing controllability, developing safety mechanisms, or expanding to new domains, there is much important work to be done. We hope this book has provided the foundation needed to engage with and contribute to this exciting field.