Introduction
While diffusion models achieved their initial success in image generation, the underlying framework is remarkably general. The same principles - gradually adding noise and learning to reverse the process - can be applied to virtually any data type where we can define a meaningful noise process. This has led to an explosion of applications beyond images, including audio, molecules, human motion, and robot behavior.
The key insight is that diffusion models learn the data distribution without requiring explicit likelihood computation, making them applicable to complex, high-dimensional spaces where traditional generative models struggle. In this section, we explore how researchers have adapted diffusion models to diverse domains, each with unique challenges and representations.
| Domain | Data Type | Key Challenge | Notable Models |
|---|---|---|---|
| Audio | Spectrograms, waveforms | Temporal coherence, perceptual quality | AudioLDM, MusicGen, Stable Audio |
| Molecules | 3D coordinates, graphs | Physical constraints, stability | EDM, DiffDock, RFDiffusion |
| Motion | Joint trajectories | Biomechanical plausibility | MDM, MotionDiffuse, MoMask |
| Robotics | Action sequences | Physical feasibility, real-time | Diffusion Policy, Diffuser |
| Weather | Spatial fields | Multi-scale dynamics | GenCast |
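What makes this breadth possible is that the forward process only assumes the data is a tensor. As a minimal illustration (not taken from any of the models above), the closed-form noising step x_t = √(ᾱ_t) x_0 + √(1 − ᾱ_t) ε applies unchanged to a spectrogram, a set of atom coordinates, or a motion sequence:

```python
import numpy as np

def make_alpha_bars(num_steps=1000):
    """Linear beta schedule -> cumulative products of alpha = 1 - beta."""
    betas = np.linspace(1e-4, 0.02, num_steps)
    return np.cumprod(1.0 - betas)

def q_sample(x0, t, alpha_bars, rng):
    """Sample x_t ~ q(x_t | x_0); works for any tensor-shaped data."""
    noise = rng.standard_normal(x0.shape)
    return np.sqrt(alpha_bars[t]) * x0 + np.sqrt(1.0 - alpha_bars[t]) * noise

rng = np.random.default_rng(0)
alpha_bars = make_alpha_bars()

spectrogram = np.zeros((80, 400))     # audio: mel bins x frames
atom_positions = np.zeros((32, 3))    # molecule: atoms x xyz
motion = np.zeros((120, 22, 3))       # motion: frames x joints x xyz

for x0 in (spectrogram, atom_positions, motion):
    xt = q_sample(x0, t=999, alpha_bars=alpha_bars, rng=rng)
    assert xt.shape == x0.shape  # the same routine noises every modality
```

Only the denoising network is modality-specific; the schedule and the sampling loop carry over unchanged, which is why the rest of this section can focus on architectures and representations.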
Audio Generation
Audio generation with diffusion models has become a major research area, with applications ranging from sound effects to music to speech. Audio is a natural fit because it can be represented either as spectrograms (2D images, amenable to image-style architectures) or directly as waveforms, both of which can be diffused.
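For intuition about the spectrogram view, here is a minimal magnitude-spectrogram computation in NumPy (the frame length and hop size are illustrative; real systems additionally apply a mel filterbank and log scaling):

```python
import numpy as np

def magnitude_spectrogram(wave, n_fft=512, hop=128):
    """Frame a mono waveform and take the FFT magnitude of each windowed frame."""
    window = np.hanning(n_fft)
    frames = [wave[i:i + n_fft] * window
              for i in range(0, len(wave) - n_fft, hop)]
    return np.abs(np.fft.rfft(np.stack(frames), axis=-1)).T  # [freq, time]

sr = 16000
t = np.arange(sr) / sr
wave = np.sin(2 * np.pi * 440 * t)   # one second of a 440 Hz tone

spec = magnitude_spectrogram(wave)   # a 2D "image" a diffusion model can denoise
```

A pure tone shows up as a single bright horizontal band (around bin 440 · n_fft / sr ≈ 14 here), which is exactly the kind of 2D structure image-style U-Nets exploit.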
AudioLDM: Text-to-Audio
AudioLDM (Liu et al., 2023) applies latent diffusion to audio generation. Similar to Stable Diffusion for images, AudioLDM operates in a compressed latent space learned by a VAE, making generation faster and more efficient.
The architecture processes mel-spectrograms through a VAE encoder, applies diffusion in latent space, and decodes back to spectrograms. A vocoder (like HiFi-GAN) then converts spectrograms to waveforms:
import torch
import torch.nn as nn
from typing import Optional

class AudioLDM(nn.Module):
    """AudioLDM for text-to-audio generation."""

    def __init__(
        self,
        latent_dim: int = 8,
        mel_channels: int = 80,
        hidden_dim: int = 256,
        num_layers: int = 12,
    ):
        super().__init__()
        self.latent_dim = latent_dim  # needed later when sampling

        # VAE for mel-spectrogram compression
        self.vae = MelVAE(
            mel_channels=mel_channels,
            latent_dim=latent_dim,
        )

        # Latent diffusion U-Net (1D convolutions for temporal data)
        self.unet = AudioUNet(
            in_channels=latent_dim,
            out_channels=latent_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
        )

        # Text encoder (CLAP or T5)
        self.text_encoder = CLAPTextEncoder()

        # Vocoder for spectrogram to waveform
        self.vocoder = HiFiGAN()

    def encode_text(self, text: str) -> torch.Tensor:
        """Encode text prompt to conditioning embeddings."""
        return self.text_encoder(text)

    def forward(
        self,
        mel_spec: torch.Tensor,  # [B, mel_channels, T]
        t: torch.Tensor,
        text_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        # Encode to latent
        latent = self.vae.encode(mel_spec)

        # Apply diffusion model
        noise_pred = self.unet(latent, t, text_embeddings)

        return noise_pred

    @torch.no_grad()
    def generate(
        self,
        text: str,
        duration: float = 5.0,  # seconds
        num_inference_steps: int = 50,
        guidance_scale: float = 3.0,
    ) -> torch.Tensor:
        """Generate audio from text prompt."""
        device = next(self.parameters()).device

        # Calculate latent dimensions
        latent_length = int(duration * 100 / 4)  # 100 Hz mel, 4x compression
        latent = torch.randn(1, self.latent_dim, latent_length, device=device)

        # Encode text
        text_emb = self.encode_text(text)
        null_emb = self.encode_text("")

        # Diffusion sampling with classifier-free guidance
        for t in reversed(range(num_inference_steps)):
            t_tensor = torch.full((1,), t, device=device)

            # Classifier-free guidance
            noise_cond = self.unet(latent, t_tensor, text_emb)
            noise_uncond = self.unet(latent, t_tensor, null_emb)
            noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

            # DDPM update step
            latent = self.ddpm_step(latent, noise_pred, t)

        # Decode to mel-spectrogram
        mel_spec = self.vae.decode(latent)

        # Convert to waveform
        waveform = self.vocoder(mel_spec)

        return waveform


class MelVAE(nn.Module):
    """VAE for mel-spectrogram compression."""

    def __init__(self, mel_channels: int = 80, latent_dim: int = 8):
        super().__init__()

        # Encoder: mel -> latent (4x temporal downsampling)
        self.encoder = nn.Sequential(
            nn.Conv1d(mel_channels, 128, 3, 2, 1),
            nn.ReLU(),
            nn.Conv1d(128, 256, 3, 2, 1),
            nn.ReLU(),
            nn.Conv1d(256, latent_dim * 2, 3, 1, 1),  # mean and logvar
        )

        # Decoder: latent -> mel
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(latent_dim, 256, 4, 2, 1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 128, 4, 2, 1),
            nn.ReLU(),
            nn.Conv1d(128, mel_channels, 3, 1, 1),
        )

    def encode(self, mel: torch.Tensor) -> torch.Tensor:
        h = self.encoder(mel)
        mean, logvar = h.chunk(2, dim=1)
        std = torch.exp(0.5 * logvar)
        z = mean + std * torch.randn_like(std)  # reparameterization trick
        return z

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        return self.decoder(z)

Music Generation
Music generation presents unique challenges beyond general audio: long-term structure, harmonic consistency, and rhythmic coherence. Several approaches have emerged:
- MusicGen (Meta, 2023): An autoregressive transformer over multi-stream audio tokens, using a delay pattern for efficient codebook prediction (not diffusion-based, but a common point of comparison)
- Stable Audio (Stability AI, 2023): Latent diffusion with timing conditioning for variable-length generation
- Riffusion: Fine-tuned Stable Diffusion on spectrograms for music generation
class MusicDiffusion(nn.Module):
    """Diffusion model for music generation with structure conditioning."""

    def __init__(
        self,
        audio_channels: int = 2,  # Stereo
        hidden_dim: int = 512,
        num_layers: int = 16,
    ):
        super().__init__()

        # Waveform encoder (EnCodec-style)
        self.encoder = AudioEncoder(
            in_channels=audio_channels,
            out_channels=hidden_dim // 4,
            num_codebooks=4,
        )

        # Diffusion transformer
        self.transformer = DiffusionTransformer(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=8,
        )

        # Timing embedding (for variable-length generation)
        self.timing_embed = nn.Sequential(
            nn.Linear(2, hidden_dim),  # [start_time, total_duration]
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Music-specific conditioning
        self.genre_embed = nn.Embedding(100, hidden_dim)  # 100 genres
        self.tempo_embed = nn.Linear(1, hidden_dim)  # BPM
        self.key_embed = nn.Embedding(24, hidden_dim)  # 12 keys x 2 modes

    def forward(
        self,
        audio_tokens: torch.Tensor,  # [B, num_codebooks, T]
        t: torch.Tensor,
        text_embeddings: torch.Tensor,
        timing: Optional[torch.Tensor] = None,
        genre: Optional[torch.Tensor] = None,
        tempo: Optional[torch.Tensor] = None,
        key: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Encode audio
        x = self.encoder.encode_tokens(audio_tokens)

        # Add timing conditioning
        if timing is not None:
            x = x + self.timing_embed(timing).unsqueeze(1)

        # Add music-specific conditioning
        if genre is not None:
            x = x + self.genre_embed(genre).unsqueeze(1)
        if tempo is not None:
            x = x + self.tempo_embed(tempo).unsqueeze(1)
        if key is not None:
            x = x + self.key_embed(key).unsqueeze(1)

        # Apply transformer
        noise_pred = self.transformer(x, t, text_embeddings)

        return noise_pred


class AudioEncoder(nn.Module):
    """Neural audio codec for music compression."""

    def __init__(
        self,
        in_channels: int = 2,
        out_channels: int = 128,
        num_codebooks: int = 4,
        codebook_size: int = 1024,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks

        # Convolutional encoder
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels, 64, 7, 1, 3),
            nn.ELU(),
            ResBlock1D(64),
            nn.Conv1d(64, 128, 4, 2, 1),  # Downsample
            nn.ELU(),
            ResBlock1D(128),
            nn.Conv1d(128, 256, 4, 2, 1),
            nn.ELU(),
            ResBlock1D(256),
            nn.Conv1d(256, out_channels * num_codebooks, 4, 2, 1),
        )

        # Residual vector quantizers
        self.quantizers = nn.ModuleList([
            VectorQuantizer(out_channels, codebook_size)
            for _ in range(num_codebooks)
        ])

    def encode(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode audio to discrete tokens."""
        h = self.encoder(audio)
        h = h.reshape(h.shape[0], self.num_codebooks, -1, h.shape[-1])

        tokens = []
        for i, quantizer in enumerate(self.quantizers):
            token = quantizer.encode(h[:, i])
            tokens.append(token)

        return torch.stack(tokens, dim=1)  # [B, num_codebooks, T]

Speech Synthesis and Enhancement
Speech synthesis has also benefited from diffusion models, with applications in text-to-speech (TTS), voice conversion, and speech enhancement. Notable systems include:
- Grad-TTS: Score-based TTS with duration prediction
- DiffWave: Direct waveform generation with dilated convolutions
- VoiceBox (Meta): Large-scale speech generation with in-context learning
class DiffWave(nn.Module):
    """DiffWave for waveform generation."""

    def __init__(
        self,
        residual_channels: int = 64,
        dilation_cycle: int = 10,
        num_layers: int = 30,
    ):
        super().__init__()

        self.input_conv = nn.Conv1d(1, residual_channels, 1)

        # WaveNet-style dilated convolutions
        self.residual_layers = nn.ModuleList()
        for i in range(num_layers):
            dilation = 2 ** (i % dilation_cycle)
            self.residual_layers.append(
                ResidualBlock(
                    residual_channels,
                    dilation=dilation,
                )
            )

        self.output_conv = nn.Sequential(
            nn.Conv1d(residual_channels, residual_channels, 1),
            nn.ReLU(),
            nn.Conv1d(residual_channels, 1, 1),
        )

        # Diffusion embeddings
        self.diffusion_embed = SinusoidalEmbedding(residual_channels)
        self.diffusion_proj = nn.Linear(residual_channels, residual_channels)

        # Conditioning (e.g., mel-spectrogram for vocoding)
        self.cond_proj = nn.Conv1d(80, residual_channels, 1)

    def forward(
        self,
        x: torch.Tensor,  # [B, 1, T] noisy waveform
        t: torch.Tensor,  # [B] diffusion timesteps
        condition: Optional[torch.Tensor] = None,  # [B, 80, T//256]
    ) -> torch.Tensor:
        # Embed diffusion timestep
        t_emb = self.diffusion_embed(t)
        t_emb = self.diffusion_proj(t_emb)

        # Upsample condition to waveform resolution if provided
        if condition is not None:
            condition = nn.functional.interpolate(
                condition, size=x.shape[-1], mode="linear"
            )
            condition = self.cond_proj(condition)

        # Process through residual layers
        h = self.input_conv(x)
        skip_connections = []

        for layer in self.residual_layers:
            h, skip = layer(h, t_emb, condition)
            skip_connections.append(skip)

        # Sum skip connections (scaled to keep magnitudes stable)
        out = sum(skip_connections) / len(skip_connections) ** 0.5

        return self.output_conv(out)


class ResidualBlock(nn.Module):
    """WaveNet residual block with diffusion conditioning."""

    def __init__(self, channels: int, dilation: int):
        super().__init__()

        self.dilated_conv = nn.Conv1d(
            channels, 2 * channels, 3,
            padding=dilation, dilation=dilation
        )
        self.diffusion_proj = nn.Linear(channels, channels)
        self.cond_proj = nn.Conv1d(channels, 2 * channels, 1)
        self.output_proj = nn.Conv1d(channels, 2 * channels, 1)

    def forward(
        self,
        x: torch.Tensor,
        t_emb: torch.Tensor,
        condition: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Add diffusion timestep embedding to the input
        # (before the dilated conv, so channel counts match)
        h = x + self.diffusion_proj(t_emb).unsqueeze(-1)
        h = self.dilated_conv(h)

        # Add condition
        if condition is not None:
            h = h + self.cond_proj(condition)

        # Gated activation
        gate, filt = h.chunk(2, dim=1)
        h = torch.sigmoid(gate) * torch.tanh(filt)

        # Output projection into residual and skip paths
        out = self.output_proj(h)
        residual, skip = out.chunk(2, dim=1)

        return x + residual, skip

Molecular Generation
Molecular generation is a transformative application of diffusion models in drug discovery and materials science. Molecules can be represented as 3D point clouds (atom positions), graphs (atoms and bonds), or combinations of both. The key challenge is ensuring generated molecules satisfy chemical constraints.
Diffusion for Drug Discovery
E(3)-equivariant diffusion models generate molecules while respecting the symmetries of 3D space (rotations and translations). This is crucial because molecular properties are invariant to rigid transformations.
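Equivariance can be verified numerically: rotating the input coordinates and then applying the model should equal applying the model and then rotating the output. A toy EGNN-style position update (a simplified stand-in for a full equivariant network, with a hypothetical Gaussian distance weight) passes this check because it moves atoms along relative position vectors scaled by rotation-invariant weights:

```python
import numpy as np

def toy_position_update(pos):
    """EGNN-style update: move each atom along its relative vectors to
    other atoms, weighted by a rotation-invariant function of distance."""
    diff = pos[:, None, :] - pos[None, :, :]            # [N, N, 3] offsets
    dist = np.linalg.norm(diff, axis=-1, keepdims=True)
    w = np.exp(-dist)                                    # invariant weights
    np.fill_diagonal(w[..., 0], 0.0)                     # no self-interaction
    return pos + (w * diff).sum(axis=1)

rng = np.random.default_rng(0)
pos = rng.standard_normal((8, 3))

theta = 0.7  # rotation about the z-axis
R = np.array([[np.cos(theta), -np.sin(theta), 0.0],
              [np.sin(theta),  np.cos(theta), 0.0],
              [0.0, 0.0, 1.0]])

out_then_rotate = toy_position_update(pos) @ R.T
rotate_then_out = toy_position_update(pos @ R.T)
assert np.allclose(out_then_rotate, rotate_then_out)  # f(Rx) = R f(x)
```

The same check passes for translations, since the update depends only on relative positions. Real E(3)-equivariant networks extend this pattern with learned invariant weights and feature updates.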
The forward process adds Gaussian noise directly to the atom positions in 3D: with noise schedule ᾱ_t, the noised coordinates are x_t = √(ᾱ_t) x_0 + √(1 − ᾱ_t) ε with ε ~ N(0, I) (in practice the noise is projected to have zero center of mass, so the process is translation-invariant).
The model must be E(3)-equivariant, meaning predictions transform correctly under rotations and translations:
import torch
import torch.nn as nn
from e3nn import o3

class E3DiffusionMolecule(nn.Module):
    """E(3)-equivariant diffusion for molecular generation."""

    def __init__(
        self,
        num_atom_types: int = 100,
        hidden_dim: int = 128,
        num_layers: int = 6,
        max_radius: float = 5.0,
    ):
        super().__init__()
        self.max_radius = max_radius

        # Atom type embedding
        self.atom_embed = nn.Embedding(num_atom_types, hidden_dim)

        # E(3)-equivariant message passing layers
        self.layers = nn.ModuleList([
            EquivariantMessagePassingLayer(
                hidden_dim=hidden_dim,
                max_radius=max_radius,
            )
            for _ in range(num_layers)
        ])

        # Output heads
        self.pos_head = nn.Linear(hidden_dim, 3)  # Position noise prediction
        self.type_head = nn.Linear(hidden_dim, num_atom_types)  # Type prediction

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )

    def forward(
        self,
        positions: torch.Tensor,  # [B, N, 3] atom positions
        atom_types: torch.Tensor,  # [B, N] atom types
        t: torch.Tensor,  # [B] timesteps
        batch_idx: torch.Tensor,  # [B * N] batch indices
    ) -> tuple[torch.Tensor, torch.Tensor]:
        batch_size, num_atoms, _ = positions.shape

        # Flatten for processing
        positions_flat = positions.reshape(-1, 3)
        atom_types_flat = atom_types.reshape(-1)

        # Initial embeddings
        h = self.atom_embed(atom_types_flat)

        # Add time embedding
        t_emb = self.time_embed(t)
        h = h + t_emb[batch_idx]

        # Build edge indices based on distance
        edge_index, edge_vec = self.build_edges(positions_flat, batch_idx)

        # Message passing
        for layer in self.layers:
            h, positions_flat = layer(h, positions_flat, edge_index, edge_vec)

        # Predict noise
        pos_noise = self.pos_head(h).reshape(batch_size, num_atoms, 3)
        type_logits = self.type_head(h).reshape(batch_size, num_atoms, -1)

        return pos_noise, type_logits

    def build_edges(
        self,
        positions: torch.Tensor,
        batch_idx: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build edge indices based on distance cutoff."""
        # Pairwise offsets: diff[i, j] = positions[j] - positions[i]
        diff = positions.unsqueeze(0) - positions.unsqueeze(1)
        dist = torch.norm(diff, dim=-1)

        # Create edges for atoms within cutoff, excluding self-edges
        mask = (dist < self.max_radius) & (dist > 0)
        # Only connect atoms within the same molecule
        mask = mask & (batch_idx.unsqueeze(0) == batch_idx.unsqueeze(1))

        edge_index = mask.nonzero().T
        edge_vec = diff[edge_index[0], edge_index[1]]

        return edge_index, edge_vec


class EquivariantMessagePassingLayer(nn.Module):
    """E(3)-equivariant message passing."""

    def __init__(self, hidden_dim: int, max_radius: float):
        super().__init__()

        # Spherical harmonics for angular information
        self.spherical_harmonics = o3.SphericalHarmonics(
            irreps_out="1o",  # Vector representation
            normalize=True,
        )

        # Message function (node pair features + 16 radial basis features)
        self.message_net = nn.Sequential(
            nn.Linear(2 * hidden_dim + 16, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Radial basis for distance encoding
        self.rbf = GaussianRadialBasis(max_radius=max_radius, num_basis=16)

        # Update functions
        self.update_h = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.update_pos = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        h: torch.Tensor,  # [N, hidden_dim]
        positions: torch.Tensor,  # [N, 3]
        edge_index: torch.Tensor,  # [2, E]
        edge_vec: torch.Tensor,  # [E, 3]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        src, dst = edge_index

        # Compute edge features
        edge_dist = torch.norm(edge_vec, dim=-1, keepdim=True)
        edge_rbf = self.rbf(edge_dist.squeeze(-1))

        # Message computation
        h_src = h[src]
        h_dst = h[dst]
        message_input = torch.cat([h_src, h_dst, edge_rbf], dim=-1)
        messages = self.message_net(message_input)

        # Aggregate messages
        aggregated = torch.zeros_like(h)
        aggregated.index_add_(0, dst, messages)

        # Update node features
        h_new = self.update_h(torch.cat([h, aggregated], dim=-1))
        h = h + h_new

        # Update positions (equivariant: move along normalized edge vectors,
        # scaled by an invariant weight)
        pos_update_weight = self.update_pos(messages)
        edge_vec_normalized = edge_vec / (edge_dist + 1e-8)
        pos_updates = edge_vec_normalized * pos_update_weight

        pos_aggregated = torch.zeros_like(positions)
        pos_aggregated.index_add_(0, dst, pos_updates)
        positions = positions + pos_aggregated

        return h, positions

Protein Structure Generation
Protein structure prediction has been revolutionized by AI, with diffusion models now used for both structure prediction and de novo protein design. RFDiffusion (Watson et al., 2023) generates novel protein backbones that fold into desired shapes:
class ProteinDiffusion(nn.Module):
    """Diffusion model for protein backbone generation."""

    def __init__(
        self,
        num_residues: int = 100,
        hidden_dim: int = 256,
        num_layers: int = 12,
    ):
        super().__init__()

        # Backbone is represented as frames (rotation + translation per residue)
        self.frame_dim = 12  # 3x3 rotation + 3 translation

        # SE(3)-equivariant transformer (IPA-style)
        self.ipa_layers = nn.ModuleList([
            InvariantPointAttention(
                hidden_dim=hidden_dim,
                num_heads=8,
                num_query_points=8,
                num_value_points=8,
            )
            for _ in range(num_layers)
        ])

        # Single representation
        self.single_embed = nn.Linear(20, hidden_dim)  # 20 amino acids

        # Pair representation: dist (1) + offset (3) + flattened relative rotation (9)
        self.pair_embed = nn.Linear(13, hidden_dim)

        # Frame update
        self.frame_update = nn.Linear(hidden_dim, 6)  # axis-angle + translation

        # Time embedding
        self.time_embed = SinusoidalEmbedding(hidden_dim)

    def forward(
        self,
        frames: torch.Tensor,  # [B, L, 4, 4] backbone frames
        sequence: torch.Tensor,  # [B, L] amino acid sequence
        t: torch.Tensor,  # [B] timesteps
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Embed sequence
        single = self.single_embed(
            nn.functional.one_hot(sequence, 20).float()
        )

        # Add time embedding
        t_emb = self.time_embed(t).unsqueeze(1)
        single = single + t_emb

        # Compute pair features from frames
        pair = self.compute_pair_features(frames)
        pair = self.pair_embed(pair)

        # Apply IPA layers
        for ipa in self.ipa_layers:
            single = ipa(single, pair, frames)

        # Predict frame updates
        frame_update = self.frame_update(single)  # [B, L, 6]

        # Convert to rotation and translation components
        axis_angle = frame_update[..., :3]
        translation = frame_update[..., 3:]

        return axis_angle, translation

    def compute_pair_features(self, frames: torch.Tensor) -> torch.Tensor:
        """Compute pairwise features from backbone frames."""
        # Extract positions (last column of frame matrix)
        positions = frames[..., :3, 3]  # [B, L, 3]

        # Pairwise distances
        diff = positions.unsqueeze(2) - positions.unsqueeze(1)  # [B, L, L, 3]
        dist = torch.norm(diff, dim=-1, keepdim=True)  # [B, L, L, 1]

        # Relative orientation: R_i R_j^T for every residue pair
        rotations = frames[..., :3, :3]  # [B, L, 3, 3]
        rel_rot = torch.einsum(
            "bimk,bjlk->bijml",
            rotations, rotations,
        )  # [B, L, L, 3, 3]

        # Combine features: [B, L, L, 1 + 3 + 9]
        pair = torch.cat([dist, diff, rel_rot.flatten(-2)], dim=-1)

        return pair

Motion Synthesis
Motion synthesis generates human motion sequences from various conditions like text descriptions, music, or action labels. The data is typically represented as sequences of joint positions or rotations over time.
Human Motion Generation
Human motion is typically represented using skeletal formats like SMPL (Skinned Multi-Person Linear Model), which parameterizes body pose with per-joint rotations plus global translation parameters. A motion sequence is then a tensor x ∈ R^(T × J × d), where T is the number of frames, J is the number of joints, and d is the per-joint feature dimension.
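In practice the per-frame features are flattened into a single vector. As a rough sketch, the following packs joint positions, finite-difference velocities, and binary foot-contact flags into one feature vector per frame (the joint count, foot-joint indices, and contact threshold are illustrative, not the exact HumanML3D layout):

```python
import numpy as np

def pack_motion_features(joints, foot_joint_ids=(7, 8, 10, 11), eps=5e-3):
    """joints: [T, J, 3] positions -> [T-1, D] per-frame feature vectors."""
    T, J, _ = joints.shape
    feet = np.asarray(foot_joint_ids)
    pos = joints[:-1].reshape(T - 1, J * 3)                 # joint positions
    vel = (joints[1:] - joints[:-1]).reshape(T - 1, J * 3)  # finite-diff velocities
    foot_speed = np.linalg.norm(joints[1:, feet] - joints[:-1, feet], axis=-1)
    contact = (foot_speed < eps).astype(np.float32)         # binary contact flags
    return np.concatenate([pos, vel, contact], axis=-1)

motion = np.zeros((120, 22, 3))  # 120 frames, 22 joints (SMPL-like skeleton)
feats = pack_motion_features(motion)
# 22*3 positions + 22*3 velocities + 4 contact flags = 136 features per frame
```

Storing velocities and contact flags explicitly in the representation is what makes the geometric losses used by motion diffusion models cheap to compute.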
Motion Diffusion Model (MDM)
The Motion Diffusion Model (Tevet et al., 2023) applies diffusion to motion generation with text conditioning. Key design choices include:
- Simple transformer architecture: No U-Net, just a transformer encoder
- Geometric losses: Foot contact and velocity constraints for physical plausibility
- Classifier-free guidance: For text-to-motion generation
import torch
import torch.nn as nn
from typing import Optional

class MotionDiffusionModel(nn.Module):
    """Motion Diffusion Model for text-to-motion generation."""

    def __init__(
        self,
        motion_dim: int = 263,  # HumanML3D representation
        hidden_dim: int = 512,
        num_layers: int = 8,
        num_heads: int = 4,
        max_length: int = 196,
    ):
        super().__init__()
        self.motion_dim = motion_dim
        self.max_length = max_length

        # Motion embedding
        self.motion_embed = nn.Linear(motion_dim, hidden_dim)

        # Positional encoding
        self.pos_embed = nn.Parameter(
            torch.randn(1, max_length, hidden_dim) * 0.02
        )

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Text encoder (CLIP)
        self.text_encoder = CLIPTextEncoder()
        self.text_proj = nn.Linear(512, hidden_dim)  # CLIP dim -> hidden

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, motion_dim)

    def forward(
        self,
        motion: torch.Tensor,  # [B, T, motion_dim] noisy motion
        t: torch.Tensor,  # [B] timesteps
        text_embeddings: torch.Tensor,  # [B, hidden_dim] text condition
        mask: Optional[torch.Tensor] = None,  # [B, T] boolean length mask
    ) -> torch.Tensor:
        batch_size, seq_len, _ = motion.shape

        # Embed motion
        h = self.motion_embed(motion)  # [B, T, hidden_dim]

        # Add positional encoding
        h = h + self.pos_embed[:, :seq_len]

        # Add time embedding
        t_emb = self.time_embed(t).unsqueeze(1)  # [B, 1, hidden_dim]
        h = h + t_emb

        # Add text condition as first token
        text_token = self.text_proj(text_embeddings).unsqueeze(1)
        h = torch.cat([text_token, h], dim=1)

        # Create attention mask
        if mask is not None:
            # Extend mask for the text token
            text_mask = torch.ones(
                batch_size, 1, device=mask.device, dtype=mask.dtype
            )
            mask = torch.cat([text_mask, mask], dim=1)

        # Apply transformer (padding mask is True where positions are ignored)
        h = self.transformer(
            h, src_key_padding_mask=~mask if mask is not None else None
        )

        # Remove text token and project to motion
        h = h[:, 1:]  # [B, T, hidden_dim]
        noise_pred = self.output_proj(h)

        return noise_pred


def motion_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor,  # [B, T] boolean
    use_geometric_losses: bool = True,
) -> dict[str, torch.Tensor]:
    """Compute motion generation losses with geometric constraints."""
    losses = {}

    # Basic reconstruction loss
    losses["recon"] = (mask.unsqueeze(-1) * (pred - target) ** 2).sum() / mask.sum()

    if use_geometric_losses:
        # Velocity loss (smooth motion)
        pred_vel = pred[:, 1:] - pred[:, :-1]
        target_vel = target[:, 1:] - target[:, :-1]
        vel_mask = mask[:, 1:] & mask[:, :-1]
        losses["velocity"] = (
            vel_mask.unsqueeze(-1) * (pred_vel - target_vel) ** 2
        ).sum() / vel_mask.sum()

        # Foot contact loss (feet should not slide while in contact)
        # Assuming the last 4 dims are foot positions
        foot_positions = pred[..., -4:]
        foot_vel = torch.abs(foot_positions[:, 1:] - foot_positions[:, :-1])

        # Binary foot contact labels from ground truth
        foot_contact = target[..., -8:-4]
        losses["foot_contact"] = (
            foot_contact[:, :-1] * foot_vel
        ).mean()

    return losses


@torch.no_grad()
def generate_motion(
    model: MotionDiffusionModel,
    text: str,
    num_frames: int = 120,
    num_inference_steps: int = 50,
    guidance_scale: float = 2.5,
) -> torch.Tensor:
    """Generate motion from text description."""
    device = next(model.parameters()).device

    # Encode text
    text_emb = model.text_encoder(text)
    null_emb = model.text_encoder("")

    # Initialize from noise
    motion = torch.randn(1, num_frames, model.motion_dim, device=device)
    mask = torch.ones(1, num_frames, device=device, dtype=torch.bool)

    # DDPM noise schedule
    betas = torch.linspace(1e-4, 0.02, 1000, device=device)
    alphas = 1 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    timesteps = torch.linspace(999, 0, num_inference_steps).long()

    for t in timesteps:
        t_batch = t.unsqueeze(0).to(device)

        # Classifier-free guidance
        noise_cond = model(motion, t_batch, text_emb, mask)
        noise_uncond = model(motion, t_batch, null_emb, mask)
        noise_pred = noise_uncond + guidance_scale * (noise_cond - noise_uncond)

        # DDPM update
        alpha_bar = alpha_bars[t]
        alpha_bar_prev = (
            alpha_bars[t - 1] if t > 0 else torch.tensor(1.0, device=device)
        )
        beta = betas[t]

        # Predict x0
        x0_pred = (motion - torch.sqrt(1 - alpha_bar) * noise_pred) / torch.sqrt(alpha_bar)
        x0_pred = torch.clamp(x0_pred, -1, 1)

        # Posterior mean
        mean = (
            torch.sqrt(alpha_bar_prev) * beta / (1 - alpha_bar) * x0_pred +
            torch.sqrt(alphas[t]) * (1 - alpha_bar_prev) / (1 - alpha_bar) * motion
        )

        # Add noise (except for the last step)
        if t > 0:
            noise = torch.randn_like(motion)
            variance = beta * (1 - alpha_bar_prev) / (1 - alpha_bar)
            motion = mean + torch.sqrt(variance) * noise
        else:
            motion = mean

    return motion

Robotics and Planning
Robotics is a natural application for diffusion models, which can generate action sequences that are diverse, multimodal, and consistent with physical constraints. Unlike policies that regress a single best action, diffusion-based policies can model complex, multi-modal action distributions.
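The multimodality point is easy to see in one dimension. If a task has two equally valid actions, say −1 and +1 (steer left or right around an obstacle), a policy trained with MSE regression converges toward their mean, 0, which is not itself a valid action; a sampler that matches the demonstration distribution keeps both modes. A toy illustration, using empirical resampling as a stand-in for a trained diffusion policy:

```python
import numpy as np

rng = np.random.default_rng(0)

# Demonstrations with two equally valid action modes at -1 and +1
actions = np.concatenate([
    rng.normal(-1.0, 0.05, 500),
    rng.normal(+1.0, 0.05, 500),
])

# An MSE-optimal point prediction collapses to the mean of the modes,
# which is itself an invalid action
mse_action = actions.mean()                       # close to 0, between the modes

# A distribution-matching sampler (empirical resampling here, standing in
# for a trained diffusion policy) keeps both modes
samples = rng.choice(actions, size=1000)
mode_coverage = (np.abs(samples) > 0.5).mean()    # every sample lies near a valid mode
```

Diffusion policies inherit this mode-preserving behavior because they are trained to match the demonstration distribution rather than to minimize a pointwise regression error.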
Diffusion Policy for Robot Learning
Diffusion Policy (Chi et al., 2023) learns robot manipulation skills from demonstrations by modeling the action distribution with diffusion. Key advantages include:
- Multimodal actions: Can represent multiple valid solutions to a task
- Temporal consistency: Generates coherent action sequences rather than single actions
- Strong conditioning: Effective at learning from visual observations
import torch
import torch.nn as nn

class DiffusionPolicy(nn.Module):
    """Diffusion Policy for robot manipulation."""

    def __init__(
        self,
        action_dim: int = 7,  # 6-DoF + gripper
        obs_dim: int = 512,  # Visual observation embedding
        action_horizon: int = 16,  # Number of actions to predict
        obs_horizon: int = 2,  # Number of observation frames
        hidden_dim: int = 256,
    ):
        super().__init__()
        self.action_dim = action_dim
        self.action_horizon = action_horizon
        self.obs_horizon = obs_horizon

        # Visual encoder (ResNet + spatial softmax)
        self.visual_encoder = VisualEncoder(output_dim=obs_dim)

        # 1D U-Net for action sequence
        self.unet = ConditionalUNet1D(
            input_dim=action_dim,
            cond_dim=obs_dim * obs_horizon,
            hidden_dim=hidden_dim,
        )

        # Noise schedule
        self.register_buffer(
            "betas",
            torch.linspace(1e-4, 0.02, 100)
        )
        self.register_buffer("alphas", 1 - self.betas)
        self.register_buffer("alpha_bars", torch.cumprod(self.alphas, dim=0))

    def forward(
        self,
        actions: torch.Tensor,  # [B, T, action_dim]
        t: torch.Tensor,  # [B]
        obs: torch.Tensor,  # [B, obs_horizon, C, H, W]
    ) -> torch.Tensor:
        # Encode observations
        batch_size = obs.shape[0]
        obs_flat = obs.flatten(0, 1)  # [B * obs_horizon, C, H, W]
        obs_embed = self.visual_encoder(obs_flat)
        obs_embed = obs_embed.reshape(batch_size, -1)  # [B, obs_horizon * obs_dim]

        # Predict noise
        noise_pred = self.unet(actions, t, obs_embed)

        return noise_pred

    @torch.no_grad()
    def predict_action(
        self,
        obs: torch.Tensor,  # [B, obs_horizon, C, H, W]
        num_inference_steps: int = 10,
    ) -> torch.Tensor:
        """Predict action sequence from observations."""
        device = obs.device
        batch_size = obs.shape[0]

        # Initialize from noise
        actions = torch.randn(
            batch_size, self.action_horizon, self.action_dim,
            device=device
        )

        # DDPM sampling with a subset of timesteps
        timesteps = torch.linspace(99, 0, num_inference_steps).long()

        for t in timesteps:
            t_batch = t.expand(batch_size).to(device)
            noise_pred = self(actions, t_batch, obs)

            # DDPM update
            alpha_bar = self.alpha_bars[t]
            alpha_bar_prev = (
                self.alpha_bars[t - 1] if t > 0
                else torch.tensor(1.0, device=device)
            )
            beta = self.betas[t]

            x0_pred = (actions - torch.sqrt(1 - alpha_bar) * noise_pred) / torch.sqrt(alpha_bar)
            x0_pred = torch.clamp(x0_pred, -1, 1)

            if t > 0:
                noise = torch.randn_like(actions)
                variance = beta * (1 - alpha_bar_prev) / (1 - alpha_bar)
                actions = (
                    torch.sqrt(alpha_bar_prev) * x0_pred +
                    torch.sqrt(1 - alpha_bar_prev - variance) * noise_pred +
                    torch.sqrt(variance) * noise
                )
            else:
                actions = x0_pred

        return actions


class ConditionalUNet1D(nn.Module):
    """1D U-Net for action sequence denoising."""

    def __init__(
        self,
        input_dim: int,
        cond_dim: int,
        hidden_dim: int = 256,
    ):
        super().__init__()

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Condition projection
        self.cond_proj = nn.Sequential(
            nn.Linear(cond_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Encoder
        self.encoder = nn.ModuleList([
            Conv1DBlock(input_dim, hidden_dim // 2, hidden_dim),
            Conv1DBlock(hidden_dim // 2, hidden_dim, hidden_dim),
            Conv1DBlock(hidden_dim, hidden_dim * 2, hidden_dim),
        ])

        # Decoder (channel counts include the concatenated skip connections)
        self.decoder = nn.ModuleList([
            Conv1DBlock(hidden_dim * 2 + hidden_dim, hidden_dim, hidden_dim),
            Conv1DBlock(hidden_dim + hidden_dim // 2, hidden_dim // 2, hidden_dim),
            Conv1DBlock(hidden_dim // 2 + input_dim, hidden_dim // 2, hidden_dim),
        ])

        self.output = nn.Conv1d(hidden_dim // 2, input_dim, 1)

    def forward(
        self,
        x: torch.Tensor,  # [B, T, input_dim]
        t: torch.Tensor,  # [B]
        cond: torch.Tensor,  # [B, cond_dim]
    ) -> torch.Tensor:
        x = x.permute(0, 2, 1)  # [B, input_dim, T]

        # Embeddings
        t_emb = self.time_embed(t)  # [B, hidden_dim]
        cond_emb = self.cond_proj(cond)  # [B, hidden_dim]
        global_emb = t_emb + cond_emb

        # Encoder
        skip_connections = [x]
        h = x
        for encoder in self.encoder:
            h = encoder(h, global_emb)
            skip_connections.append(h)

        # Decoder
        for i, decoder in enumerate(self.decoder):
            skip = skip_connections[-(i + 2)]
            h = torch.cat([h, skip], dim=1)
            h = decoder(h, global_emb)

        output = self.output(h).permute(0, 2, 1)  # [B, T, input_dim]
        return output


class Conv1DBlock(nn.Module):
    """1D convolution block with conditioning."""

    def __init__(self, in_ch: int, out_ch: int, cond_dim: int):
        super().__init__()
        self.conv1 = nn.Conv1d(in_ch, out_ch, 5, padding=2)
        self.conv2 = nn.Conv1d(out_ch, out_ch, 5, padding=2)
        self.norm1 = nn.GroupNorm(8, out_ch)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.cond_proj = nn.Linear(cond_dim, out_ch)
        self.residual = nn.Conv1d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        h = self.conv1(x)
        h = self.norm1(h)
        # FiLM-style additive conditioning, broadcast over time
        h = h + self.cond_proj(cond).unsqueeze(-1)
        h = nn.functional.mish(h)
        h = self.conv2(h)
        h = self.norm2(h)
        h = nn.functional.mish(h)
        return h + self.residual(x)

Trajectory Planning with Diffusion
Diffuser (Janner et al., 2022) uses diffusion for trajectory planning, generating entire state-action sequences that optimize rewards while respecting dynamics:
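Before the full model, the core mechanism is worth isolating: the sampler is steered by the gradient of a learned value function, in the style of classifier guidance. A minimal standalone sketch (`value_guided_noise` is a hypothetical helper for illustration, not part of the Diffuser API):

```python
import torch

def value_guided_noise(noise_pred, x, value_fn, scale=1.0):
    """Adjust a noise prediction with the gradient of a value function
    (classifier-guidance style; illustrative sketch)."""
    # Re-enable autograd locally so this also works inside torch.no_grad()
    with torch.enable_grad():
        x = x.detach().requires_grad_(True)
        value = value_fn(x).sum()
        grad = torch.autograd.grad(value, x)[0]
    # Subtracting the value gradient from the predicted noise nudges the
    # denoised sample toward higher-value regions
    return noise_pred - scale * grad

# Toy check: for value(x) = -||x||^2 the gradient is -2x,
# so guiding a zero noise prediction yields +2x
x = torch.randn(4, 8)
guided = value_guided_noise(torch.zeros_like(x), x, lambda z: -(z ** 2).sum(dim=-1))
```

The same trick appears inside `plan()` below, applied to the terminal state of the trajectory.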
class Diffuser(nn.Module):
    """Diffusion-based trajectory planning."""

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        horizon: int = 32,
        hidden_dim: int = 256,
    ):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.transition_dim = state_dim + action_dim
        self.horizon = horizon

        # Temporal U-Net
        self.unet = TemporalUNet(
            transition_dim=self.transition_dim,
            hidden_dim=hidden_dim,
        )

        # Value function for guidance
        self.value_function = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.Mish(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(
        self,
        trajectories: torch.Tensor,  # [B, H, state_dim + action_dim]
        t: torch.Tensor,
    ) -> torch.Tensor:
        return self.unet(trajectories, t)

    @torch.no_grad()
    def plan(
        self,
        start_state: torch.Tensor,  # [B, state_dim]
        num_inference_steps: int = 20,
        guidance_scale: float = 1.0,
    ) -> torch.Tensor:
        """Plan trajectory from start state with value guidance."""
        device = start_state.device
        batch_size = start_state.shape[0]

        # Initialize random trajectory
        trajectory = torch.randn(
            batch_size, self.horizon, self.transition_dim,
            device=device
        )

        # Fix start state
        trajectory[:, 0, :self.state_dim] = start_state

        # Assumes 100 training timesteps
        timesteps = torch.linspace(99, 0, num_inference_steps, dtype=torch.long)

        for t in timesteps:
            t_batch = t.expand(batch_size).to(device)

            # Value guidance on terminal state; autograd must be
            # re-enabled locally because plan() runs under no_grad
            with torch.enable_grad():
                trajectory.requires_grad_(True)
                terminal_state = trajectory[:, -1, :self.state_dim]
                value = self.value_function(terminal_state).sum()
                grad = torch.autograd.grad(value, trajectory)[0]

            trajectory = trajectory.detach()

            # Diffusion update, steered toward high-value trajectories
            noise_pred = self(trajectory, t_batch)
            noise_pred = noise_pred - guidance_scale * grad

            # DDPM step (standard posterior update, defined elsewhere)
            trajectory = self.ddpm_step(trajectory, noise_pred, t)

            # Re-apply constraints
            trajectory[:, 0, :self.state_dim] = start_state

        return trajectory

Other Emerging Domains
Weather Prediction
Generative weather models use diffusion to produce ensemble forecasts that capture uncertainty. Google's GenCast (2024) outperforms traditional numerical weather prediction for medium-range forecasts:
| Model | Type | Lead Time | Key Innovation |
|---|---|---|---|
| GenCast | Diffusion | 15 days | Ensemble generation, uncertainty |
| GraphCast | GNN | 10 days | Graph neural network on sphere |
| Pangu-Weather | Transformer | 7 days | 3D Earth-specific attention |
| FourCastNet | Fourier | 7 days | Adaptive Fourier Neural Operator |
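GenCast's actual pipeline is far richer, but the reason a generative sampler yields an ensemble at all is simple: every stochastic sample is a plausible forecast, and the spread across samples estimates uncertainty. A toy sketch (`sample_fn` is a stand-in for a full diffusion sampler):

```python
import torch

def ensemble_forecast(sample_fn, initial_state, n_members=8):
    """Draw an ensemble of forecasts by repeated stochastic sampling."""
    members = torch.stack([sample_fn(initial_state) for _ in range(n_members)])
    return members.mean(dim=0), members.std(dim=0)  # point forecast + spread

# Toy stand-in for a diffusion sampler: truth plus stochastic perturbation
torch.manual_seed(0)
state = torch.zeros(16, 16)  # e.g. a coarse 2D pressure field
sample = lambda s: s + 0.1 * torch.randn_like(s)
mean, spread = ensemble_forecast(sample, state, n_members=64)
```

With enough members, the per-pixel spread recovers the sampler's noise scale, which is exactly the calibration property ensemble forecasts are scored on.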
Materials Science
Diffusion models are being applied to crystal structure generation and materials discovery. Unlike molecules, crystals have periodic boundary conditions and must respect space group symmetries:
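Periodicity changes even basic operations like measuring displacement: under the minimum-image convention, each component of a difference vector is wrapped into half a cell. A minimal sketch in fractional coordinates (`min_image_frac_delta` is an illustrative helper, not from any specific codebase):

```python
import torch

def min_image_frac_delta(f1: torch.Tensor, f2: torch.Tensor) -> torch.Tensor:
    """Minimum-image displacement in fractional coordinates:
    wrap each component of the difference into [-0.5, 0.5)."""
    d = f1 - f2
    return d - torch.round(d)

# Atoms on opposite faces of the cell are near neighbors,
# not a full cell apart
d = min_image_frac_delta(torch.tensor([0.95, 0.5, 0.5]),
                         torch.tensor([0.05, 0.5, 0.5]))
```

Message-passing layers for crystals apply this wrapping (and sum over periodic images) when building neighbor lists, which is what the `PeriodicMessagePassing` layers below are responsible for.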
class CrystalDiffusion(nn.Module):
    """Diffusion for periodic crystal structure generation."""

    def __init__(
        self,
        num_atom_types: int = 100,
        hidden_dim: int = 256,
        num_layers: int = 6,
    ):
        super().__init__()

        # Lattice parameters (a, b, c, alpha, beta, gamma)
        self.lattice_dim = 6

        # Atom features
        self.atom_embed = nn.Embedding(num_atom_types, hidden_dim)

        # Periodic-aware message passing
        self.layers = nn.ModuleList([
            PeriodicMessagePassing(hidden_dim)
            for _ in range(num_layers)
        ])

        # Output heads
        self.coord_head = nn.Linear(hidden_dim, 3)  # Fractional coordinates
        self.lattice_head = nn.Linear(hidden_dim, self.lattice_dim)

    def forward(
        self,
        frac_coords: torch.Tensor,  # [B, N, 3] fractional coordinates
        atom_types: torch.Tensor,   # [B, N]
        lattice: torch.Tensor,      # [B, 6] lattice parameters
        t: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Convert to Cartesian for message passing
        cart_coords = frac_to_cart(frac_coords, lattice)

        # Embed atoms
        h = self.atom_embed(atom_types)

        # Message passing with periodic images
        for layer in self.layers:
            h = layer(h, cart_coords, lattice)

        # Predict noise
        coord_noise = self.coord_head(h)
        lattice_noise = self.lattice_head(h.mean(dim=1))

        return coord_noise, lattice_noise


def frac_to_cart(frac_coords: torch.Tensor, lattice: torch.Tensor) -> torch.Tensor:
    """Convert fractional to Cartesian coordinates."""
    # lattice: [B, 6] -> (a, b, c, alpha, beta, gamma)
    a, b, c = lattice[:, 0:1], lattice[:, 1:2], lattice[:, 2:3]
    alpha, beta, gamma = lattice[:, 3:4], lattice[:, 4:5], lattice[:, 5:6]

    # Angles are given in degrees
    cos_alpha = torch.cos(alpha * torch.pi / 180)
    cos_beta = torch.cos(beta * torch.pi / 180)
    cos_gamma = torch.cos(gamma * torch.pi / 180)
    sin_gamma = torch.sin(gamma * torch.pi / 180)

    # Unit cell volume
    volume = a * b * c * torch.sqrt(
        1 - cos_alpha**2 - cos_beta**2 - cos_gamma**2 +
        2 * cos_alpha * cos_beta * cos_gamma
    )

    # Rows of the matrix are the lattice vectors in Cartesian coordinates
    matrix = torch.stack([
        torch.cat([a, torch.zeros_like(a), torch.zeros_like(a)], dim=-1),
        torch.cat([b * cos_gamma, b * sin_gamma, torch.zeros_like(b)], dim=-1),
        torch.cat([
            c * cos_beta,
            c * (cos_alpha - cos_beta * cos_gamma) / sin_gamma,
            volume / (a * b * sin_gamma)
        ], dim=-1),
    ], dim=1)  # [B, 3, 3]

    # cart = sum_j frac_j * lattice_vector_j, i.e. fractional coordinates
    # times the lattice-vector rows
    return torch.einsum("bnj,bji->bni", frac_coords, matrix)

References
Key papers for diffusion models beyond images:
- Liu et al. (2023). "AudioLDM: Text-to-Audio Generation with Latent Diffusion Models"
- Copet et al. (2023). "Simple and Controllable Music Generation" (MusicGen)
- Kong et al. (2021). "DiffWave: A Versatile Diffusion Model for Audio Synthesis"
- Hoogeboom et al. (2022). "Equivariant Diffusion for Molecule Generation in 3D"
- Watson et al. (2023). "De novo design of protein structure and function with RFdiffusion"
- Corso et al. (2023). "DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking"
- Tevet et al. (2023). "Human Motion Diffusion Model"
- Chi et al. (2023). "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion"
- Janner et al. (2022). "Planning with Diffusion for Flexible Behavior Synthesis"
- Price et al. (2024). "GenCast: Diffusion-based ensemble forecasting for medium-range weather"