Chapter 17

Beyond Images

The Future of Diffusion Models

Introduction

While diffusion models achieved their initial success in image generation, the underlying framework is remarkably general. The same principles - gradually adding noise and learning to reverse the process - can be applied to virtually any data type where we can define a meaningful noise process. This has led to an explosion of applications beyond images, including audio, molecules, human motion, and robot behavior.

The key insight is that diffusion models learn the data distribution without requiring explicit likelihood computation, making them applicable to complex, high-dimensional spaces where traditional generative models struggle. In this section, we explore how researchers have adapted diffusion models to diverse domains, each with unique challenges and representations.
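To make this generality concrete, here is a minimal sketch of the shared forward process: the same closed-form sample from q(x_t | x_0) applies whether the tensor holds pixels, atom coordinates, or joint angles (the linear beta schedule below is illustrative):

```python
import torch

def forward_diffuse(x0: torch.Tensor, t: int, alpha_bars: torch.Tensor) -> torch.Tensor:
    """Sample x_t ~ q(x_t | x_0) for any tensor-shaped data."""
    noise = torch.randn_like(x0)
    a = alpha_bars[t]
    return a.sqrt() * x0 + (1 - a).sqrt() * noise

# Standard linear beta schedule (illustrative values)
betas = torch.linspace(1e-4, 0.02, 1000)
alpha_bars = torch.cumprod(1 - betas, dim=0)

# The same process applies to an "image", a molecule, or a motion clip
for shape in [(3, 64, 64), (29, 3), (120, 263)]:
    x0 = torch.randn(shape)
    xt = forward_diffuse(x0, t=999, alpha_bars=alpha_bars)
    assert xt.shape == x0.shape

# By the final step almost no signal remains
print(alpha_bars[-1].item() < 1e-4)  # True
```

Everything domain-specific (symmetries, physical constraints, conditioning) lives in the denoising network and the representation, not in the noising process itself.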

| Domain | Data Type | Key Challenge | Notable Models |
|---|---|---|---|
| Audio | Spectrograms, waveforms | Temporal coherence, perceptual quality | AudioLDM, MusicGen, Stable Audio |
| Molecules | 3D coordinates, graphs | Physical constraints, stability | EDM, DiffDock, RFDiffusion |
| Motion | Joint trajectories | Biomechanical plausibility | MDM, MotionDiffuse, MoMask |
| Robotics | Action sequences | Physical feasibility, real-time control | Diffusion Policy, Diffuser |
| Weather | Spatial fields | Multi-scale dynamics | GenCast, GraphCast |

Audio Generation

Audio generation with diffusion models has become a major research area, with applications ranging from sound effects to music to speech. The central observation is that audio can be represented either as spectrograms (2D, image-like) or directly as waveforms, and both representations are amenable to diffusion.
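As a concrete illustration (a sketch using plain `torch.stft` rather than the mel-scaled filterbanks production systems typically use), a one-second waveform becomes a 2D magnitude spectrogram that an image-style diffusion model can treat as a single-channel image:

```python
import torch

sample_rate = 16000
waveform = torch.randn(sample_rate)  # 1 second of (random placeholder) audio

n_fft, hop_length = 1024, 256
window = torch.hann_window(n_fft)
stft = torch.stft(
    waveform, n_fft=n_fft, hop_length=hop_length,
    window=window, return_complex=True,
)
spectrogram = stft.abs()  # magnitude: [n_fft // 2 + 1 freq bins, time frames]

print(spectrogram.shape)  # torch.Size([513, 63]) -- an image-like 2D tensor
```

Going the other direction (spectrogram back to waveform) is lossy because phase is discarded, which is why spectrogram-based systems pair the diffusion model with a neural vocoder.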

AudioLDM: Text-to-Audio

AudioLDM (Liu et al., 2023) applies latent diffusion to audio generation. Similar to Stable Diffusion for images, AudioLDM operates in a compressed latent space learned by a VAE, making generation faster and more efficient.

The architecture processes mel-spectrograms through a VAE encoder, applies diffusion in latent space, and decodes back to spectrograms. A vocoder (like HiFi-GAN) then converts spectrograms to waveforms:

```python
import torch
import torch.nn as nn
from typing import Optional

class AudioLDM(nn.Module):
    """AudioLDM for text-to-audio generation.

    MelVAE is defined below; AudioUNet, CLAPTextEncoder, HiFiGAN, and
    ddpm_step are assumed helper components.
    """

    def __init__(
        self,
        latent_dim: int = 8,
        mel_channels: int = 80,
        hidden_dim: int = 256,
        num_layers: int = 12,
    ):
        super().__init__()
        self.latent_dim = latent_dim

        # VAE for mel-spectrogram compression
        self.vae = MelVAE(
            mel_channels=mel_channels,
            latent_dim=latent_dim,
        )

        # Latent diffusion U-Net (1D convolutions for temporal data)
        self.unet = AudioUNet(
            in_channels=latent_dim,
            out_channels=latent_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
        )

        # Text encoder (CLAP or T5)
        self.text_encoder = CLAPTextEncoder()

        # Vocoder for spectrogram to waveform
        self.vocoder = HiFiGAN()

    def encode_text(self, text: str) -> torch.Tensor:
        """Encode text prompt to conditioning embeddings."""
        return self.text_encoder(text)

    def forward(
        self,
        mel_spec: torch.Tensor,  # [B, mel_channels, T]
        t: torch.Tensor,
        text_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        # Encode to latent
        latent = self.vae.encode(mel_spec)

        # Apply diffusion model
        noise_pred = self.unet(latent, t, text_embeddings)

        return noise_pred

    @torch.no_grad()
    def generate(
        self,
        text: str,
        duration: float = 5.0,  # seconds
        num_inference_steps: int = 50,
        guidance_scale: float = 3.0,
    ) -> torch.Tensor:
        """Generate audio from text prompt."""
        device = next(self.parameters()).device

        # Calculate latent dimensions
        latent_length = int(duration * 100 / 4)  # 100 Hz mel, 4x compression
        latent = torch.randn(1, self.latent_dim, latent_length, device=device)

        # Encode text
        text_emb = self.encode_text(text)
        null_emb = self.encode_text("")

        # Diffusion sampling with CFG
        for t in reversed(range(num_inference_steps)):
            t_tensor = torch.full((1,), t, device=device)

            # Classifier-free guidance
            noise_cond = self.unet(latent, t_tensor, text_emb)
            noise_uncond = self.unet(latent, t_tensor, null_emb)
            noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

            # DDPM update step
            latent = self.ddpm_step(latent, noise_pred, t)

        # Decode to mel-spectrogram
        mel_spec = self.vae.decode(latent)

        # Convert to waveform
        waveform = self.vocoder(mel_spec)

        return waveform


class MelVAE(nn.Module):
    """VAE for mel-spectrogram compression."""

    def __init__(self, mel_channels: int = 80, latent_dim: int = 8):
        super().__init__()

        # Encoder: mel -> latent (4x temporal downsampling)
        self.encoder = nn.Sequential(
            nn.Conv1d(mel_channels, 128, 3, 2, 1),
            nn.ReLU(),
            nn.Conv1d(128, 256, 3, 2, 1),
            nn.ReLU(),
            nn.Conv1d(256, latent_dim * 2, 3, 1, 1),  # mean and logvar
        )

        # Decoder: latent -> mel
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(latent_dim, 256, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Conv1d(128, mel_channels, 3, 1, 1),
        )

    def encode(self, mel: torch.Tensor) -> torch.Tensor:
        h = self.encoder(mel)
        mean, logvar = h.chunk(2, dim=1)
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)  # reparameterization trick
        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)
```

Music Generation

Music generation presents unique challenges beyond general audio: long-term structure, harmonic consistency, and rhythmic coherence. Several approaches have emerged:

  • MusicGen (Meta, 2023): An autoregressive transformer (not diffusion) that models parallel EnCodec token streams efficiently using a delay pattern
  • Stable Audio (Stability AI, 2023): Latent diffusion with timing conditioning for variable-length generation
  • Riffusion: Stable Diffusion fine-tuned on spectrogram images for music generation

```python
class MusicDiffusion(nn.Module):
    """Diffusion model for music generation with structure conditioning.

    AudioEncoder is defined below; DiffusionTransformer, ResBlock1D, and
    VectorQuantizer are assumed helper components.
    """

    def __init__(
        self,
        audio_channels: int = 2,  # Stereo
        hidden_dim: int = 512,
        num_layers: int = 16,
    ):
        super().__init__()

        # Waveform encoder (EnCodec-style)
        self.encoder = AudioEncoder(
            in_channels=audio_channels,
            out_channels=hidden_dim // 4,
            num_codebooks=4,
        )

        # Diffusion transformer
        self.transformer = DiffusionTransformer(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=8,
        )

        # Timing embedding (for variable-length generation)
        self.timing_embed = nn.Sequential(
            nn.Linear(2, hidden_dim),  # [start_time, total_duration]
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Music-specific conditioning
        self.genre_embed = nn.Embedding(100, hidden_dim)  # 100 genres
        self.tempo_embed = nn.Linear(1, hidden_dim)  # BPM
        self.key_embed = nn.Embedding(24, hidden_dim)  # 12 keys x 2 modes

    def forward(
        self,
        audio_tokens: torch.Tensor,  # [B, num_codebooks, T]
        t: torch.Tensor,
        text_embeddings: torch.Tensor,
        timing: Optional[torch.Tensor] = None,
        genre: Optional[torch.Tensor] = None,
        tempo: Optional[torch.Tensor] = None,
        key: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Encode audio
        x = self.encoder.encode_tokens(audio_tokens)

        # Add timing conditioning
        if timing is not None:
            x = x + self.timing_embed(timing).unsqueeze(1)

        # Add music-specific conditioning
        if genre is not None:
            x = x + self.genre_embed(genre).unsqueeze(1)
        if tempo is not None:
            x = x + self.tempo_embed(tempo).unsqueeze(1)
        if key is not None:
            x = x + self.key_embed(key).unsqueeze(1)

        # Apply transformer
        noise_pred = self.transformer(x, t, text_embeddings)

        return noise_pred


class AudioEncoder(nn.Module):
    """Neural audio codec for music compression."""

    def __init__(
        self,
        in_channels: int = 2,
        out_channels: int = 128,
        num_codebooks: int = 4,
        codebook_size: int = 1024,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks

        # Convolutional encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, 64, 7, 1, 3),
            nn.ELU(),
            ResBlock1D(64),
            nn.Conv1d(64, 128, 4, 2, 1),  # Downsample
            nn.ELU(),
            ResBlock1D(128),
            nn.Conv1d(128, 256, 4, 2, 1),
            nn.ELU(),
            ResBlock1D(256),
            nn.Conv1d(256, out_channels * num_codebooks, 4, 2, 1),
        )

        # Residual vector quantizers
        self.quantizers = nn.ModuleList([
            VectorQuantizer(out_channels, codebook_size)
            for _ in range(num_codebooks)
        ])

    def encode(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode audio to discrete tokens."""
        h = self.encoder(audio)
        h = h.reshape(h.shape[0], self.num_codebooks, -1, h.shape[-1])

        tokens = []
        for i, quantizer in enumerate(self.quantizers):
            token = quantizer.encode(h[:, i])
            tokens.append(token)

        return torch.stack(tokens, dim=1)  # [B, num_codebooks, T]
```
Speech Synthesis and Enhancement

Speech synthesis has also benefited from diffusion models, with applications in text-to-speech (TTS), voice conversion, and speech enhancement. Notable systems include:

  • Grad-TTS: Score-based TTS with duration prediction
  • DiffWave: Direct waveform generation with dilated convolutions
  • VoiceBox (Meta): Large-scale speech generation with in-context learning, built on flow matching, a close relative of diffusion

```python
class DiffWave(nn.Module):
    """DiffWave for waveform generation.

    SinusoidalEmbedding and the ResidualBlock below are the only
    non-standard components.
    """

    def __init__(
        self,
        residual_channels: int = 64,
        dilation_cycle: int = 10,
        num_layers: int = 30,
    ):
        super().__init__()

        self.input_conv = nn.Conv1d(1, residual_channels, 1)

        # WaveNet-style dilated convolutions with cycling dilation
        self.residual_layers = nn.ModuleList()
        for i in range(num_layers):
            dilation = 2 ** (i % dilation_cycle)
            self.residual_layers.append(
                ResidualBlock(
                    residual_channels,
                    dilation=dilation,
                )
            )

        self.output_conv = nn.Sequential(
            nn.Conv1d(residual_channels, residual_channels, 1),
            nn.ReLU(),
            nn.Conv1d(residual_channels, 1, 1),
        )

        # Diffusion timestep embeddings
        self.diffusion_embed = SinusoidalEmbedding(residual_channels)
        self.diffusion_proj = nn.Linear(residual_channels, residual_channels)

        # Conditioning (e.g., mel-spectrogram for vocoding)
        self.cond_proj = nn.Conv1d(80, residual_channels, 1)

    def forward(
        self,
        x: torch.Tensor,  # [B, 1, T] noisy waveform
        t: torch.Tensor,  # [B] diffusion timesteps
        condition: Optional[torch.Tensor] = None,  # [B, 80, T//256]
    ) -> torch.Tensor:
        # Embed diffusion timestep
        t_emb = self.diffusion_embed(t)
        t_emb = self.diffusion_proj(t_emb)

        # Upsample condition to waveform resolution if provided
        if condition is not None:
            condition = nn.functional.interpolate(
                condition, size=x.shape[-1], mode="linear"
            )
            condition = self.cond_proj(condition)

        # Process through residual layers
        h = self.input_conv(x)
        skip_connections = []

        for layer in self.residual_layers:
            h, skip = layer(h, t_emb, condition)
            skip_connections.append(skip)

        # Sum skip connections, normalized by sqrt(num_layers)
        out = sum(skip_connections) / len(skip_connections) ** 0.5

        return self.output_conv(out)


class ResidualBlock(nn.Module):
    """WaveNet residual block with diffusion conditioning."""

    def __init__(self, channels: int, dilation: int):
        super().__init__()

        self.dilated_conv = nn.Conv1d(
            channels, 2 * channels, 3,
            padding=dilation, dilation=dilation
        )
        self.diffusion_proj = nn.Linear(channels, channels)
        self.cond_proj = nn.Conv1d(channels, 2 * channels, 1)
        self.output_proj = nn.Conv1d(channels, 2 * channels, 1)

    def forward(
        self,
        x: torch.Tensor,
        t_emb: torch.Tensor,
        condition: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Add the diffusion timestep before the dilated conv, where channel
        # dimensions still match (the conv doubles the channel count)
        h = x + self.diffusion_proj(t_emb).unsqueeze(-1)
        h = self.dilated_conv(h)

        # Add condition
        if condition is not None:
            h = h + self.cond_proj(condition)

        # Gated activation
        gate, filt = h.chunk(2, dim=1)
        h = torch.sigmoid(gate) * torch.tanh(filt)

        # Output projection into residual and skip paths
        out = self.output_proj(h)
        residual, skip = out.chunk(2, dim=1)

        return x + residual, skip
```

Molecular Generation

Molecular generation is a transformative application of diffusion models in drug discovery and materials science. Molecules can be represented as 3D point clouds (atom positions), graphs (atoms and bonds), or combinations of both. The key challenge is ensuring generated molecules satisfy chemical constraints.

Diffusion for Drug Discovery

E(3)-equivariant diffusion models generate molecules while respecting the symmetries of 3D space (rotations and translations). This is crucial because molecular properties are invariant to rigid transformations.
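One way to see why this matters: pairwise distances, which equivariant architectures build their messages from, are unchanged by any rigid transformation. A quick numerical check (a sketch; the random orthogonal matrix comes from a QR decomposition):

```python
import torch

torch.manual_seed(0)
positions = torch.randn(10, 3)  # 10 atoms in 3D

# Random orthogonal matrix via QR decomposition of a Gaussian matrix
q, r = torch.linalg.qr(torch.randn(3, 3))
rotation = q * torch.sign(torch.diagonal(r))  # fix column signs
translation = torch.randn(3)

transformed = positions @ rotation.T + translation

# Pairwise distances are invariant under rotations and translations,
# which is exactly the property equivariant message passing exploits
d0 = torch.cdist(positions, positions)
d1 = torch.cdist(transformed, transformed)
assert torch.allclose(d0, d1, atol=1e-5)
```

A network built only from such invariant quantities (plus position updates along relative vectors) is equivariant by construction, with no data augmentation needed.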

The forward process adds noise to atom positions in 3D:

$$q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\left(\mathbf{x}_t;\, \sqrt{\bar{\alpha}_t}\,\mathbf{x}_0,\, (1-\bar{\alpha}_t)\mathbf{I}\right)$$

The model must be E(3)-equivariant, meaning predictions transform correctly under rotations and translations:

$$\epsilon_\theta(R\mathbf{x} + \mathbf{t}) = R\,\epsilon_\theta(\mathbf{x})$$

```python
import torch
import torch.nn as nn

class E3DiffusionMolecule(nn.Module):
    """E(3)-equivariant diffusion for molecular generation.

    EquivariantMessagePassingLayer is defined below; SinusoidalEmbedding
    and GaussianRadialBasis are assumed helper components.
    """

    def __init__(
        self,
        num_atom_types: int = 100,
        hidden_dim: int = 128,
        num_layers: int = 6,
        max_radius: float = 5.0,
    ):
        super().__init__()
        self.max_radius = max_radius

        # Atom type embedding
        self.atom_embed = nn.Embedding(num_atom_types, hidden_dim)

        # E(3)-equivariant message passing layers
        self.layers = nn.ModuleList([
            EquivariantMessagePassingLayer(
                hidden_dim=hidden_dim,
                max_radius=max_radius,
            )
            for _ in range(num_layers)
        ])

        # Output heads
        self.pos_head = nn.Linear(hidden_dim, 3)  # Position noise prediction
        self.type_head = nn.Linear(hidden_dim, num_atom_types)  # Type prediction

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )

    def forward(
        self,
        positions: torch.Tensor,  # [B, N, 3] atom positions
        atom_types: torch.Tensor,  # [B, N] atom types
        t: torch.Tensor,  # [B] timesteps
        batch_idx: torch.Tensor,  # [B * N] batch indices
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_atoms, _ = positions.shape

        # Flatten for processing
        positions_flat = positions.reshape(-1, 3)
        atom_types_flat = atom_types.reshape(-1)

        # Initial embeddings
        h = self.atom_embed(atom_types_flat)

        # Add time embedding
        t_emb = self.time_embed(t)
        h = h + t_emb[batch_idx]

        # Build edge indices based on distance
        edge_index, edge_vec = self.build_edges(positions_flat, batch_idx)

        # Message passing
        for layer in self.layers:
            h, positions_flat = layer(h, positions_flat, edge_index, edge_vec)

        # Predict noise
        pos_noise = self.pos_head(h).reshape(batch_size, num_atoms, 3)
        type_logits = self.type_head(h).reshape(batch_size, num_atoms, -1)

        return pos_noise, type_logits

    def build_edges(
        self,
        positions: torch.Tensor,
        batch_idx: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build edge indices based on distance cutoff."""
        # Pairwise offsets: diff[i, j] = positions[j] - positions[i]
        diff = positions.unsqueeze(0) - positions.unsqueeze(1)
        dist = torch.norm(diff, dim=-1)

        # Edges between distinct atoms within the cutoff, same molecule only
        mask = (dist < self.max_radius) & (dist > 0)
        mask = mask & (batch_idx.unsqueeze(0) == batch_idx.unsqueeze(1))

        edge_index = mask.nonzero().T
        edge_vec = diff[edge_index[0], edge_index[1]]

        return edge_index, edge_vec


class EquivariantMessagePassingLayer(nn.Module):
    """E(3)-equivariant message passing (EGNN-style).

    Messages use only invariant inputs (node features and distances);
    position updates are along relative edge vectors, which keeps the
    layer equivariant by construction.
    """

    def __init__(self, hidden_dim: int, max_radius: float):
        super().__init__()

        # Message function over invariant inputs
        self.message_net = nn.Sequential(
            nn.Linear(2 * hidden_dim + 16, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Radial basis for distance encoding
        self.rbf = GaussianRadialBasis(max_radius=max_radius, num_basis=16)

        # Update functions
        self.update_h = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.update_pos = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        h: torch.Tensor,  # [N, hidden_dim]
        positions: torch.Tensor,  # [N, 3]
        edge_index: torch.Tensor,  # [2, E]
        edge_vec: torch.Tensor,  # [E, 3]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        src, dst = edge_index

        # Compute invariant edge features
        edge_dist = torch.norm(edge_vec, dim=-1, keepdim=True)
        edge_rbf = self.rbf(edge_dist.squeeze(-1))

        # Message computation
        h_src = h[src]
        h_dst = h[dst]
        message_input = torch.cat([h_src, h_dst, edge_rbf], dim=-1)
        messages = self.message_net(message_input)

        # Aggregate messages
        aggregated = torch.zeros_like(h)
        aggregated.index_add_(0, dst, messages)

        # Update node features
        h_new = self.update_h(torch.cat([h, aggregated], dim=-1))
        h = h + h_new

        # Update positions along relative vectors (equivariant)
        pos_update_weight = self.update_pos(messages)
        edge_vec_normalized = edge_vec / (edge_dist + 1e-8)
        pos_updates = edge_vec_normalized * pos_update_weight

        pos_aggregated = torch.zeros_like(positions)
        pos_aggregated.index_add_(0, dst, pos_updates)
        positions = positions + pos_aggregated

        return h, positions
```

Protein Structure Generation

Protein structure prediction has been revolutionized by AI, with diffusion models now used for both structure prediction and de novo protein design. RFDiffusion (Watson et al., 2023) generates novel protein backbones that fold into desired shapes:

```python
class ProteinDiffusion(nn.Module):
    """Diffusion model for protein backbone generation.

    The backbone is represented as one rigid frame per residue: a 4x4
    homogeneous matrix holding a rotation and a translation.
    InvariantPointAttention and SinusoidalEmbedding are assumed helpers.
    """

    def __init__(
        self,
        num_residues: int = 100,
        hidden_dim: int = 256,
        num_layers: int = 12,
    ):
        super().__init__()

        # SE(3)-equivariant transformer (IPA-style)
        self.ipa_layers = nn.ModuleList([
            InvariantPointAttention(
                hidden_dim=hidden_dim,
                num_heads=8,
                num_query_points=8,
                num_value_points=8,
            )
            for _ in range(num_layers)
        ])

        # Single (per-residue) representation
        self.single_embed = nn.Linear(20, hidden_dim)  # 20 amino acids

        # Pair representation: dist (1) + offset (3) + relative rotation (9)
        self.pair_embed = nn.Linear(13, hidden_dim)

        # Frame update: axis-angle rotation + translation
        self.frame_update = nn.Linear(hidden_dim, 6)

        # Time embedding
        self.time_embed = SinusoidalEmbedding(hidden_dim)

    def forward(
        self,
        frames: torch.Tensor,  # [B, L, 4, 4] backbone frames
        sequence: torch.Tensor,  # [B, L] amino acid sequence
        t: torch.Tensor,  # [B] timesteps
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_residues = sequence.shape

        # Embed sequence
        single = self.single_embed(
            nn.functional.one_hot(sequence, 20).float()
        )

        # Add time embedding
        t_emb = self.time_embed(t).unsqueeze(1)
        single = single + t_emb

        # Compute pair features from frames
        pair = self.compute_pair_features(frames)
        pair = self.pair_embed(pair)

        # Apply IPA layers
        for ipa in self.ipa_layers:
            single = ipa(single, pair, frames)

        # Predict frame updates
        frame_update = self.frame_update(single)  # [B, L, 6]

        # Split into rotation (axis-angle) and translation components
        axis_angle = frame_update[..., :3]
        translation = frame_update[..., 3:]

        return axis_angle, translation

    def compute_pair_features(self, frames: torch.Tensor) -> torch.Tensor:
        """Compute pairwise features from backbone frames."""
        # Extract positions (translation column of the 4x4 frame)
        positions = frames[..., :3, 3]  # [B, L, 3]

        # Pairwise offsets and distances
        diff = positions.unsqueeze(2) - positions.unsqueeze(1)  # [B, L, L, 3]
        dist = torch.norm(diff, dim=-1, keepdim=True)  # [B, L, L, 1]

        # Relative orientations R_i^T R_j between residue pairs
        rotations = frames[..., :3, :3]  # [B, L, 3, 3]
        rel_rot = torch.einsum(
            "bipa,bjpc->bijac",
            rotations, rotations
        )  # [B, L, L, 3, 3]

        # Combine features: 1 + 3 + 9 = 13 per pair
        pair = torch.cat([dist, diff, rel_rot.flatten(-2)], dim=-1)

        return pair
```

Motion Synthesis

Motion synthesis generates human motion sequences from various conditions like text descriptions, music, or action labels. The data is typically represented as sequences of joint positions or rotations over time.

Human Motion Generation

Human motion is typically represented using skeletal formats like SMPL (Skinned Multi-Person Linear Model), which parameterizes body pose with 72 joint-angle parameters (24 joints with 3 axis-angle values each) plus 3 global translation parameters:

$$\mathbf{M} \in \mathbb{R}^{T \times (J \times 3 + 3)}$$

where $T$ is the number of frames and $J$ is the number of joints.
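Packing a pose sequence into this flat representation is straightforward. A sketch with SMPL-like dimensions (24 joints, axis-angle rotations; all tensors here are random placeholders):

```python
import torch

T, J = 120, 24  # 120 frames, 24 joints (SMPL-like)

joint_rotations = torch.randn(T, J, 3)   # axis-angle rotation per joint
root_translation = torch.randn(T, 3)     # global translation per frame

# Flatten to M in R^{T x (J * 3 + 3)}
motion = torch.cat([joint_rotations.reshape(T, -1), root_translation], dim=-1)
print(motion.shape)  # torch.Size([120, 75])
```

Each frame is then just a fixed-width vector, so a motion clip is a 2D tensor that sequence models can diffuse directly.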

Motion Diffusion Model (MDM)

The Motion Diffusion Model (Tevet et al., 2023) applies diffusion to motion generation with text conditioning. Key design choices include:

  • Simple transformer architecture: No U-Net, just a transformer encoder
  • Geometric losses: Foot contact and velocity constraints for physical plausibility
  • Classifier-free guidance: For text-to-motion generation
```python
import torch
import torch.nn as nn
from typing import Optional

class MotionDiffusionModel(nn.Module):
    """Motion Diffusion Model for text-to-motion generation.

    SinusoidalEmbedding and CLIPTextEncoder are assumed helper components.
    """

    def __init__(
        self,
        motion_dim: int = 263,  # HumanML3D representation
        hidden_dim: int = 512,
        num_layers: int = 8,
        num_heads: int = 4,
        max_length: int = 196,
    ):
        super().__init__()
        self.motion_dim = motion_dim
        self.max_length = max_length

        # Motion embedding
        self.motion_embed = nn.Linear(motion_dim, hidden_dim)

        # Learned positional encoding
        self.pos_embed = nn.Parameter(
            torch.randn(1, max_length, hidden_dim) * 0.02
        )

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Text encoder (CLIP)
        self.text_encoder = CLIPTextEncoder()
        self.text_proj = nn.Linear(512, hidden_dim)  # CLIP dim -> hidden

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, motion_dim)

    def forward(
        self,
        motion: torch.Tensor,  # [B, T, motion_dim] noisy motion
        t: torch.Tensor,  # [B] timesteps
        text_embeddings: torch.Tensor,  # [B, hidden_dim] text condition
        mask: Optional[torch.Tensor] = None,  # [B, T] boolean length mask
    ) -> torch.Tensor:
        batch_size, seq_len, _ = motion.shape

        # Embed motion
        h = self.motion_embed(motion)  # [B, T, hidden_dim]

        # Add positional encoding
        h = h + self.pos_embed[:, :seq_len]

        # Add time embedding
        t_emb = self.time_embed(t).unsqueeze(1)  # [B, 1, hidden_dim]
        h = h + t_emb

        # Prepend text condition as the first token
        text_token = self.text_proj(text_embeddings).unsqueeze(1)
        h = torch.cat([text_token, h], dim=1)

        # Extend the mask to cover the text token
        if mask is not None:
            text_mask = torch.ones(
                batch_size, 1, device=mask.device, dtype=mask.dtype
            )
            mask = torch.cat([text_mask, mask], dim=1)

        # Apply transformer (padding positions are masked out)
        h = self.transformer(
            h, src_key_padding_mask=~mask if mask is not None else None
        )

        # Remove text token and project back to motion space
        h = h[:, 1:]  # [B, T, hidden_dim]
        noise_pred = self.output_proj(h)

        return noise_pred


def motion_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor,  # [B, T] boolean length mask
    use_geometric_losses: bool = True,
) -> dict[str, torch.Tensor]:
    """Compute motion generation losses with geometric constraints."""
    losses = {}

    # Basic reconstruction loss
    losses["recon"] = (mask.unsqueeze(-1) * (pred - target) ** 2).sum() / mask.sum()

    if use_geometric_losses:
        # Velocity loss (smooth motion)
        pred_vel = pred[:, 1:] - pred[:, :-1]
        target_vel = target[:, 1:] - target[:, :-1]
        vel_mask = mask[:, 1:] & mask[:, :-1]
        losses["velocity"] = (
            vel_mask.unsqueeze(-1) * (pred_vel - target_vel) ** 2
        ).sum() / vel_mask.sum()

        # Foot contact loss (feet should not slide while in contact)
        # Assuming the last 4 dims are foot positions
        foot_positions = pred[..., -4:]
        foot_vel = torch.abs(foot_positions[:, 1:] - foot_positions[:, :-1])

        # Binary foot contact labels from ground truth
        foot_contact = target[..., -8:-4]
        losses["foot_contact"] = (
            foot_contact[:, :-1] * foot_vel
        ).mean()

    return losses


@torch.no_grad()
def generate_motion(
    model: MotionDiffusionModel,
    text: str,
    num_frames: int = 120,
    num_inference_steps: int = 50,
    guidance_scale: float = 2.5,
) -> torch.Tensor:
    """Generate motion from text description."""
    device = next(model.parameters()).device

    # Encode text
    text_emb = model.text_encoder(text)
    null_emb = model.text_encoder("")

    # Initialize from noise
    motion = torch.randn(1, num_frames, model.motion_dim, device=device)
    mask = torch.ones(1, num_frames, device=device, dtype=torch.bool)

    # DDPM noise schedule
    betas = torch.linspace(1e-4, 0.02, 1000, device=device)
    alphas = 1 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    timesteps = torch.linspace(999, 0, num_inference_steps, dtype=torch.long)

    for t in timesteps:
        t_batch = t.unsqueeze(0).to(device)

        # Classifier-free guidance
        noise_cond = model(motion, t_batch, text_emb, mask)
        noise_uncond = model(motion, t_batch, null_emb, mask)
        noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # DDPM update (simplified: uses t - 1 rather than the previous
        # sampled timestep when steps are skipped)
        alpha_bar = alpha_bars[t]
        alpha_bar_prev = (
            alpha_bars[t - 1] if t > 0 else torch.tensor(1.0, device=device)
        )
        beta = betas[t]

        # Predict x0
        x0_pred = (motion - torch.sqrt(1 - alpha_bar) * noise_pred) / torch.sqrt(alpha_bar)
        x0_pred = torch.clamp(x0_pred, -1, 1)

        # Posterior mean
        mean = (
            torch.sqrt(alpha_bar_prev) * beta / (1 - alpha_bar) * x0_pred +
            torch.sqrt(alphas[t]) * (1 - alpha_bar_prev) / (1 - alpha_bar) * motion
        )

        # Add noise (except for the last step)
        if t > 0:
            noise = torch.randn_like(motion)
            variance = beta * (1 - alpha_bar_prev) / (1 - alpha_bar)
            motion = mean + torch.sqrt(variance) * noise
        else:
            motion = mean

    return motion
```

Robotics and Planning

Robotics is a natural application for diffusion models, which can generate action sequences that are diverse and multimodal while respecting physical constraints. Unlike policies trained with simple mean-squared-error regression, diffusion-based policies can model complex, multimodal action distributions.
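The multimodality point can be seen with a toy example (a hedged sketch, not the Diffusion Policy algorithm itself): if demonstrations avoid an obstacle on the left (action -1) or the right (action +1) with equal probability, an MSE-trained regressor converges to the mean action 0, exactly the invalid "middle" behavior, while a generative policy that matches the data distribution samples one of the valid modes:

```python
import torch

torch.manual_seed(0)

# Demonstrated actions: half steer left (-1.0), half steer right (+1.0)
demos = torch.cat([torch.full((500,), -1.0), torch.full((500,), 1.0)])

# The MSE-optimal point estimate is the mean: an action no expert ever took
mse_action = demos.mean()
print(mse_action.item())  # 0.0

# A distribution-matching model instead samples a valid mode
sampled = demos[torch.randint(len(demos), (1,))].item()
assert sampled in (-1.0, 1.0)
```

Diffusion models are one way to get this distribution-matching behavior in high-dimensional, continuous action spaces.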

Diffusion Policy for Robot Learning

Diffusion Policy (Chi et al., 2023) learns robot manipulation skills from demonstrations by modeling the action distribution with diffusion. Key advantages include:

  • Multimodal actions: Can represent multiple valid solutions to a task
  • Temporal consistency: Generates coherent action sequences rather than single actions
  • Strong conditioning: Effective at learning from visual observations
```python
import torch
import torch.nn as nn


class DiffusionPolicy(nn.Module):
    """Diffusion Policy for robot manipulation."""

    def __init__(
        self,
        action_dim: int = 7,  # 6-DoF pose + gripper
        obs_dim: int = 512,  # Visual observation embedding
        action_horizon: int = 16,  # Number of actions to predict
        obs_horizon: int = 2,  # Number of observation frames
        hidden_dim: int = 256,
    ):
        super().__init__()
        self.action_dim = action_dim
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

        # Visual encoder (e.g. ResNet + spatial softmax), defined elsewhere
        self.visual_encoder = VisualEncoder(output_dim=obs_dim)

        # 1D U-Net over the action sequence
        self.unet = ConditionalUNet1D(
            input_dim=action_dim,
            cond_dim=obs_dim * obs_horizon,
            hidden_dim=hidden_dim,
        )

        # Noise schedule (100 training timesteps)
        self.register_buffer("betas", torch.linspace(1e-4, 0.02, 100))
        self.register_buffer("alphas", 1 - self.betas)
        self.register_buffer("alpha_bars", torch.cumprod(self.alphas, dim=0))

    def forward(
        self,
        actions: torch.Tensor,  # [B, T, action_dim]
        t: torch.Tensor,  # [B]
        obs: torch.Tensor,  # [B, obs_horizon, C, H, W]
    ) -> torch.Tensor:
        # Encode observations
        batch_size = obs.shape[0]
        obs_flat = obs.flatten(0, 1)  # [B * obs_horizon, C, H, W]
        obs_embed = self.visual_encoder(obs_flat)
        obs_embed = obs_embed.reshape(batch_size, -1)  # [B, obs_horizon * obs_dim]

        # Predict noise
        noise_pred = self.unet(actions, t, obs_embed)

        return noise_pred

    @torch.no_grad()
    def predict_action(
        self,
        obs: torch.Tensor,  # [B, obs_horizon, C, H, W]
        num_inference_steps: int = 10,
    ) -> torch.Tensor:
        """Predict action sequence from observations."""
        device = obs.device
        batch_size = obs.shape[0]

        # Initialize from noise
        actions = torch.randn(
            batch_size, self.action_horizon, self.action_dim,
            device=device
        )

        # Sampling with a subset of timesteps (for brevity, t - 1 stands in
        # for the previous step of the subsampled schedule)
        timesteps = torch.linspace(99, 0, num_inference_steps, dtype=torch.long)

        for t in timesteps:
            t_batch = t.expand(batch_size).to(device)
            noise_pred = self(actions, t_batch, obs)

            alpha_bar = self.alpha_bars[t]
            alpha_bar_prev = self.alpha_bars[t - 1] if t > 0 else torch.ones_like(alpha_bar)
            beta = self.betas[t]

            x0_pred = (actions - torch.sqrt(1 - alpha_bar) * noise_pred) / torch.sqrt(alpha_bar)
            x0_pred = torch.clamp(x0_pred, -1, 1)

            # DDIM-style update with the DDPM posterior variance
            if t > 0:
                noise = torch.randn_like(actions)
                variance = beta * (1 - alpha_bar_prev) / (1 - alpha_bar)
                actions = (
                    torch.sqrt(alpha_bar_prev) * x0_pred +
                    torch.sqrt(1 - alpha_bar_prev - variance) * noise_pred +
                    torch.sqrt(variance) * noise
                )
            else:
                actions = x0_pred

        return actions


class ConditionalUNet1D(nn.Module):
    """1D U-Net for action sequence denoising."""

    def __init__(
        self,
        input_dim: int,
        cond_dim: int,
        hidden_dim: int = 256,
    ):
        super().__init__()

        # Time embedding (SinusoidalEmbedding defined elsewhere)
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Condition projection
        self.cond_proj = nn.Sequential(
            nn.Linear(cond_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Encoder
        self.encoder = nn.ModuleList([
            Conv1DBlock(input_dim, hidden_dim // 2, hidden_dim),
            Conv1DBlock(hidden_dim // 2, hidden_dim, hidden_dim),
            Conv1DBlock(hidden_dim, hidden_dim * 2, hidden_dim),
        ])

        # Decoder (input channels include the matching skip connection)
        self.decoder = nn.ModuleList([
            Conv1DBlock(hidden_dim * 2 + hidden_dim, hidden_dim, hidden_dim),
            Conv1DBlock(hidden_dim + hidden_dim // 2, hidden_dim // 2, hidden_dim),
            Conv1DBlock(hidden_dim // 2 + input_dim, hidden_dim // 2, hidden_dim),
        ])

        self.output = nn.Conv1d(hidden_dim // 2, input_dim, 1)

    def forward(
        self,
        x: torch.Tensor,  # [B, T, input_dim]
        t: torch.Tensor,  # [B]
        cond: torch.Tensor,  # [B, cond_dim]
    ) -> torch.Tensor:
        x = x.permute(0, 2, 1)  # [B, input_dim, T]

        # Embeddings
        t_emb = self.time_embed(t)  # [B, hidden_dim]
        cond_emb = self.cond_proj(cond)  # [B, hidden_dim]
        global_emb = t_emb + cond_emb

        # Encoder
        skip_connections = [x]
        h = x
        for encoder in self.encoder:
            h = encoder(h, global_emb)
            skip_connections.append(h)

        # Decoder
        for i, decoder in enumerate(self.decoder):
            skip = skip_connections[-(i + 2)]
            h = torch.cat([h, skip], dim=1)
            h = decoder(h, global_emb)

        output = self.output(h).permute(0, 2, 1)  # [B, T, input_dim]
        return output


class Conv1DBlock(nn.Module):
    """1D convolution block with conditioning."""

    def __init__(self, in_ch: int, out_ch: int, cond_dim: int):
        super().__init__()
        self.conv1 = nn.Conv1d(in_ch, out_ch, 5, padding=2)
        self.conv2 = nn.Conv1d(out_ch, out_ch, 5, padding=2)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.cond_proj = nn.Linear(cond_dim, out_ch)
        self.residual = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        h = self.conv1(x)
        h = self.norm1(h)
        h = h + self.cond_proj(cond).unsqueeze(-1)  # FiLM-style additive conditioning
        h = nn.functional.mish(h)
        h = self.conv2(h)
        h = self.norm2(h)
        h = nn.functional.mish(h)
        return h + self.residual(x)
```

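The class above defines only the denoiser and sampler; training minimizes the standard noise-prediction loss on demonstration data. A minimal sketch of one training step, using a toy stand-in denoiser so the snippet is self-contained (the `ToyDenoiser`, batch shapes, and hyperparameters are illustrative assumptions, not the paper's setup):

```python
import torch
import torch.nn as nn

def diffusion_bc_loss(model, actions, cond, alpha_bars):
    """One behavior-cloning diffusion step: noise demonstration actions
    at a random timestep and regress the predicted noise on the true noise."""
    B = actions.shape[0]
    t = torch.randint(0, len(alpha_bars), (B,))
    ab = alpha_bars[t].view(B, 1, 1)
    noise = torch.randn_like(actions)
    noisy = torch.sqrt(ab) * actions + torch.sqrt(1 - ab) * noise
    noise_pred = model(noisy, t, cond)
    return nn.functional.mse_loss(noise_pred, noise)

class ToyDenoiser(nn.Module):
    """Tiny MLP stand-in for the conditional U-Net (illustration only)."""
    def __init__(self, horizon=16, action_dim=7, cond_dim=32):
        super().__init__()
        self.horizon, self.action_dim = horizon, action_dim
        self.net = nn.Sequential(
            nn.Linear(horizon * action_dim + cond_dim + 1, 128),
            nn.Mish(),
            nn.Linear(128, horizon * action_dim),
        )

    def forward(self, actions, t, cond):
        B = actions.shape[0]
        inp = torch.cat([actions.reshape(B, -1), cond, t.float().unsqueeze(1)], dim=1)
        return self.net(inp).reshape(B, self.horizon, self.action_dim)

betas = torch.linspace(1e-4, 0.02, 100)
alpha_bars = torch.cumprod(1 - betas, dim=0)
model = ToyDenoiser()
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

actions = torch.randn(8, 16, 7).clamp(-1, 1)  # demonstration action chunks
cond = torch.randn(8, 32)                     # observation embeddings

loss = diffusion_bc_loss(model, actions, cond, alpha_bars)
opt.zero_grad()
loss.backward()
opt.step()
```

At deployment time, Diffusion Policy typically executes only the first few of the predicted actions before re-observing and re-planning (receding-horizon control), which keeps the policy reactive while preserving temporal consistency.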
Trajectory Planning with Diffusion

Diffuser (Janner et al., 2022) uses diffusion for trajectory planning, generating entire state-action sequences that optimize rewards while respecting dynamics:

```python
class Diffuser(nn.Module):
    """Diffusion-based trajectory planning."""

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        horizon: int = 32,
        hidden_dim: int = 256,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.transition_dim = state_dim + action_dim
        self.horizon = horizon

        # Temporal U-Net (defined elsewhere)
        self.unet = TemporalUNet(
            transition_dim=self.transition_dim,
            hidden_dim=hidden_dim,
        )

        # Value function for guidance
        self.value_function = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        trajectories: torch.Tensor,  # [B, H, state_dim + action_dim]
        t: torch.Tensor,
    ) -> torch.Tensor:
        return self.unet(trajectories, t)

    @torch.no_grad()
    def plan(
        self,
        start_state: torch.Tensor,  # [B, state_dim]
        num_inference_steps: int = 20,
        guidance_scale: float = 1.0,
    ) -> torch.Tensor:
        """Plan trajectory from start state with value guidance."""
        device = start_state.device
        batch_size = start_state.shape[0]

        # Initialize random trajectory
        trajectory = torch.randn(
            batch_size, self.horizon, self.transition_dim,
            device=device
        )

        # Fix start state (inpainting-style constraint)
        trajectory[:, 0, :self.state_dim] = start_state

        timesteps = torch.linspace(99, 0, num_inference_steps, dtype=torch.long)

        for t in timesteps:
            t_batch = t.expand(batch_size).to(device)

            # Value guidance on terminal state; gradients must be
            # re-enabled inside the no_grad sampling loop
            with torch.enable_grad():
                trajectory.requires_grad_(True)
                terminal_state = trajectory[:, -1, :self.state_dim]
                value = self.value_function(terminal_state).sum()
                grad = torch.autograd.grad(value, trajectory)[0]

            trajectory = trajectory.detach()

            # Diffusion update: steer the noise prediction up the value gradient
            noise_pred = self(trajectory, t_batch)
            noise_pred = noise_pred - guidance_scale * grad

            # DDPM posterior step (standard update, defined separately)
            trajectory = self.ddpm_step(trajectory, noise_pred, t)

            # Re-apply constraints
            trajectory[:, 0, :self.state_dim] = start_state

        return trajectory
```
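The `ddpm_step` helper is not shown above. A sketch of the standard DDPM posterior update it would perform, written as a standalone function (the 100-step schedule matches the sampler; the function name and signature are assumptions):

```python
import torch

def ddpm_step(x, noise_pred, t, betas=None):
    """One DDPM posterior step: estimate x0 from the noise prediction,
    then sample x_{t-1} from the posterior q(x_{t-1} | x_t, x0)."""
    if betas is None:
        betas = torch.linspace(1e-4, 0.02, 100)
    alphas = 1 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    alpha_bar = alpha_bars[t]
    alpha_bar_prev = alpha_bars[t - 1] if t > 0 else torch.tensor(1.0)
    beta = betas[t]

    # Estimate the clean trajectory
    x0 = (x - torch.sqrt(1 - alpha_bar) * noise_pred) / torch.sqrt(alpha_bar)

    # Posterior mean
    mean = (
        torch.sqrt(alpha_bar_prev) * beta / (1 - alpha_bar) * x0 +
        torch.sqrt(alphas[t]) * (1 - alpha_bar_prev) / (1 - alpha_bar) * x
    )

    # Final step is deterministic; earlier steps add posterior noise
    if t > 0:
        variance = beta * (1 - alpha_bar_prev) / (1 - alpha_bar)
        return mean + torch.sqrt(variance) * torch.randn_like(x)
    return mean
```

At `t = 0` the update reduces to returning the denoised estimate: if the model predicts the noise exactly, the clean trajectory is recovered.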

Other Emerging Domains

Weather Prediction

Generative weather models use diffusion to produce ensemble forecasts that capture uncertainty. Google DeepMind's GenCast (Price et al., 2024) outperforms traditional numerical weather prediction systems on medium-range forecasts:

| Model | Type | Lead Time | Key Innovation |
|---|---|---|---|
| GenCast | Diffusion | 15 days | Ensemble generation, uncertainty |
| GraphCast | GNN | 10 days | Graph neural network on sphere |
| Pangu-Weather | Transformer | 7 days | 3D Earth-specific attention |
| FourCastNet | Fourier | 7 days | Adaptive Fourier Neural Operator |
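Ensemble forecasting with a diffusion model amounts to drawing several independent samples from the same initial condition and summarizing their spread. A schematic sketch of this idea (the `sample_forecast` stand-in is a placeholder, not GenCast's API):

```python
import torch

def ensemble_forecast(sample_forecast, init_state, num_members=8):
    """Draw independent diffusion samples and summarize them: the member
    mean is the point forecast, the member std the forecast uncertainty."""
    members = torch.stack([sample_forecast(init_state) for _ in range(num_members)])
    return members.mean(dim=0), members.std(dim=0)

# Placeholder "sampler": a fixed field plus noise, standing in for the
# stochastic reverse process of a trained diffusion model
def sample_forecast(init_state):
    return init_state + 0.1 * torch.randn_like(init_state)

init = torch.zeros(4, 4)  # toy 2D weather field
mean, spread = ensemble_forecast(sample_forecast, init, num_members=32)
```

Because each sample is an independent draw from the learned distribution, regions where the members disagree are exactly the regions where the forecast is uncertain.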

Materials Science

Diffusion models are being applied to crystal structure generation and materials discovery. Unlike molecules, crystals have periodic boundary conditions and must respect space group symmetries:

```python
class CrystalDiffusion(nn.Module):
    """Diffusion for periodic crystal structure generation."""

    def __init__(
        self,
        num_atom_types: int = 100,
        hidden_dim: int = 256,
        num_layers: int = 6,
    ):
        super().__init__()

        # Lattice parameters (a, b, c, alpha, beta, gamma)
        self.lattice_dim = 6

        # Atom features
        self.atom_embed = nn.Embedding(num_atom_types, hidden_dim)

        # Periodic-aware message passing (defined elsewhere)
        self.layers = nn.ModuleList([
            PeriodicMessagePassing(hidden_dim)
            for _ in range(num_layers)
        ])

        # Output heads
        self.coord_head = nn.Linear(hidden_dim, 3)  # Fractional coordinates
        self.lattice_head = nn.Linear(hidden_dim, self.lattice_dim)

    def forward(
        self,
        frac_coords: torch.Tensor,  # [B, N, 3] fractional coordinates
        atom_types: torch.Tensor,   # [B, N]
        lattice: torch.Tensor,      # [B, 6] lattice parameters
        t: torch.Tensor,            # timestep (embedding omitted for brevity)
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Convert to Cartesian for message passing
        cart_coords = frac_to_cart(frac_coords, lattice)

        # Embed atoms
        h = self.atom_embed(atom_types)

        # Message passing with periodic images
        for layer in self.layers:
            h = layer(h, cart_coords, lattice)

        # Predict noise
        coord_noise = self.coord_head(h)
        lattice_noise = self.lattice_head(h.mean(dim=1))

        return coord_noise, lattice_noise


def frac_to_cart(frac_coords: torch.Tensor, lattice: torch.Tensor) -> torch.Tensor:
    """Convert fractional to Cartesian coordinates."""
    # lattice: [B, 6] -> (a, b, c, alpha, beta, gamma)
    a, b, c = lattice[:, 0:1], lattice[:, 1:2], lattice[:, 2:3]
    alpha, beta, gamma = lattice[:, 3:4], lattice[:, 4:5], lattice[:, 5:6]

    # Angles are given in degrees
    cos_alpha = torch.cos(alpha * torch.pi / 180)
    cos_beta = torch.cos(beta * torch.pi / 180)
    cos_gamma = torch.cos(gamma * torch.pi / 180)
    sin_gamma = torch.sin(gamma * torch.pi / 180)

    # Cell volume
    volume = a * b * c * torch.sqrt(
        1 - cos_alpha**2 - cos_beta**2 - cos_gamma**2 +
        2 * cos_alpha * cos_beta * cos_gamma
    )

    # Rows of the matrix are the lattice vectors a, b, c
    matrix = torch.stack([
        torch.cat([a, torch.zeros_like(a), torch.zeros_like(a)], dim=-1),
        torch.cat([b * cos_gamma, b * sin_gamma, torch.zeros_like(b)], dim=-1),
        torch.cat([
            c * cos_beta,
            c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma,
            volume / (a * b * sin_gamma)
        ], dim=-1),
    ], dim=1)  # [B, 3, 3]

    # cart = frac @ matrix: each point is a fractional combination of lattice vectors
    return torch.einsum("bnj,bji->bni", frac_coords, matrix)
```
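As a sanity check on the fractional-to-Cartesian conversion, a cubic cell (equal edges, all angles 90°) should map fractional coordinates (0.5, 0.5, 0.5) to the cell center. A compact, self-contained version of the conversion with that check (the helper is duplicated here so the snippet runs on its own):

```python
import torch

def frac_to_cart(frac_coords, lattice):
    """Fractional -> Cartesian via the lattice matrix (rows = lattice vectors)."""
    a, b, c = lattice[:, 0:1], lattice[:, 1:2], lattice[:, 2:3]
    al, be, ga = (lattice[:, i:i + 1] * torch.pi / 180 for i in (3, 4, 5))
    ca, cb, cg, sg = torch.cos(al), torch.cos(be), torch.cos(ga), torch.sin(ga)
    vol = a * b * c * torch.sqrt(1 - ca**2 - cb**2 - cg**2 + 2 * ca * cb * cg)
    matrix = torch.stack([
        torch.cat([a, torch.zeros_like(a), torch.zeros_like(a)], dim=-1),
        torch.cat([b * cg, b * sg, torch.zeros_like(b)], dim=-1),
        torch.cat([c * cb, c * (ca - cb * cg) / sg, vol / (a * b * sg)], dim=-1),
    ], dim=1)  # [B, 3, 3]
    return torch.einsum("bnj,bji->bni", frac_coords, matrix)

# Cubic cell: a = b = c = 4 angstroms, all angles 90 degrees
lattice = torch.tensor([[4.0, 4.0, 4.0, 90.0, 90.0, 90.0]])
frac = torch.tensor([[[0.5, 0.5, 0.5]]])
print(frac_to_cart(frac, lattice))  # ≈ tensor([[[2., 2., 2.]]])
```

For the cubic cell the lattice matrix is simply 4·I, so the conversion scales each fractional coordinate by the edge length; non-orthogonal cells mix the components through the off-diagonal entries.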

References

Key papers for diffusion models beyond images:

  • Liu et al. (2023). "AudioLDM: Text-to-Audio Generation with Latent Diffusion Models"
  • Copet et al. (2023). "Simple and Controllable Music Generation" (MusicGen)
  • Kong et al. (2021). "DiffWave: A Versatile Diffusion Model for Audio Synthesis"
  • Hoogeboom et al. (2022). "Equivariant Diffusion for Molecule Generation in 3D"
  • Watson et al. (2023). "De novo design of protein structure and function with RFdiffusion"
  • Corso et al. (2023). "DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking"
  • Tevet et al. (2023). "Human Motion Diffusion Model"
  • Chi et al. (2023). "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
  • Janner et al. (2022). "Planning with Diffusion for Flexible Behavior Synthesis"
  • Price et al. (2024). "GenCast: Diffusion-based ensemble forecasting for medium-range weather"

The Generality of Diffusion

The remarkable success of diffusion models across such diverse domains underscores the generality of the framework. The key insight is that any data type on which we can define a meaningful noise process, and train a model to reverse it, is amenable to diffusion. As researchers continue to explore new domains, we can expect diffusion models to become a foundational tool across science and engineering.