Introduction
3D generation represents one of the most exciting frontiers for diffusion models. While image and video generation operate in 2D pixel space, 3D generation must contend with fundamentally different representations: point clouds, meshes, neural radiance fields, and volumetric grids. The success of diffusion models in 2D has inspired a wave of research adapting these techniques to 3D, with methods like DreamFusion, Magic3D, and Zero-1-to-3 achieving remarkable results in text-to-3D and single-image-to-3D generation.
The core challenge in 3D generation is the choice of representation. Unlike images, which have a natural grid structure, 3D data can be represented in many ways, each with different trade-offs:
| Representation | Advantages | Challenges |
|---|---|---|
| Neural Radiance Fields | Photorealistic rendering, view-consistent | Slow training, requires many views |
| Point Clouds | Simple, memory-efficient | No surface connectivity, hole artifacts |
| Meshes | Industry-standard, efficient rendering | Fixed topology, hard to optimize |
| Voxels | Regular grid structure | Memory-intensive (cubic scaling) |
| Triplanes | Compact, learnable | Lower resolution than explicit methods |
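The memory column of this table can be made concrete with a quick back-of-the-envelope calculation: a dense voxel feature grid scales cubically with resolution, while a triplane stores only three 2D planes. The 32-channel feature size below is an assumption for illustration.

```python
def voxel_floats(res: int, channels: int = 32) -> int:
    """Number of float values in a dense voxel feature grid."""
    return channels * res ** 3

def triplane_floats(res: int, channels: int = 32) -> int:
    """Number of float values in three 2D feature planes (XY, XZ, YZ)."""
    return 3 * channels * res ** 2

for res in (64, 128, 256, 512):
    ratio = voxel_floats(res) / triplane_floats(res)
    print(f"res={res:4d}  voxel/triplane memory ratio: {ratio:.1f}x")
```

The ratio grows linearly with resolution (`res / 3`), which is why voxel grids become impractical at high resolutions while triplanes remain compact.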
The 3D Generation Revolution: In 2023-2024, text-to-3D generation progressed from research curiosity to practical tools. Systems like Wonder3D, Instant3D, and DMV3D can now generate 3D assets in seconds rather than hours, enabling real applications in gaming, e-commerce, and content creation.
3D Representations for Diffusion
The choice of 3D representation fundamentally affects how we can apply diffusion models. Each representation leads to different architectures, noise processes, and generation strategies.
Neural Radiance Fields (NeRF)
Neural Radiance Fields represent 3D scenes as continuous functions mapping 3D coordinates and viewing directions to colors and densities. Given a point $\mathbf{x} = (x, y, z)$ and viewing direction $\mathbf{d}$, a NeRF outputs:

$$F_\theta(\mathbf{x}, \mathbf{d}) = (\mathbf{c}, \sigma)$$

where $\mathbf{c}$ is the emitted color and $\sigma$ is the volume density. Images are rendered using volume rendering:

$$C(\mathbf{r}) = \int_{t_n}^{t_f} T(t)\, \sigma(\mathbf{r}(t))\, \mathbf{c}(\mathbf{r}(t), \mathbf{d})\, dt$$

where $T(t) = \exp\!\left(-\int_{t_n}^{t} \sigma(\mathbf{r}(s))\, ds\right)$ is the accumulated transmittance.
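In practice the integral is discretized along each ray: per-sample opacities $\alpha_i = 1 - \exp(-\sigma_i \delta_i)$ are composited front to back with weights $w_i = T_i \alpha_i$. A minimal numeric check (the toy densities and scalar "colors" below are made up for illustration) shows the key invariant that the weights plus the leftover transmittance always sum to one:

```python
import math

# Discrete volume rendering: alpha_i = 1 - exp(-sigma_i * delta_i),
# T_i = prod_{j<i} (1 - alpha_j), weight_i = T_i * alpha_i.
sigmas = [0.0, 2.0, 5.0, 0.5]   # toy densities along one ray
deltas = [0.1, 0.1, 0.1, 0.1]   # sample spacings
colors = [0.2, 0.9, 0.5, 0.1]   # scalar stand-ins for RGB

alphas = [1 - math.exp(-s * d) for s, d in zip(sigmas, deltas)]
weights, T = [], 1.0
for a in alphas:
    weights.append(T * a)  # contribution of this sample
    T *= 1 - a             # remaining transmittance

rendered = sum(w * c for w, c in zip(weights, colors))
print(sum(weights) + T)  # telescopes to exactly 1.0
```

The first sample has zero density, so its weight is zero; any background color can be blended in with the leftover transmittance `T`.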
```python
import torch
import torch.nn as nn

class NeRF(nn.Module):
    """Basic NeRF architecture for 3D representation."""

    def __init__(
        self,
        pos_dim: int = 3,
        dir_dim: int = 3,
        hidden_dim: int = 256,
        num_layers: int = 8,
        skip_layer: int = 4,
    ):
        super().__init__()
        self.skip_layer = skip_layer

        # Positional encoding frequencies
        self.pos_freq = 10
        self.dir_freq = 4

        pos_encoded_dim = pos_dim + 2 * self.pos_freq * pos_dim
        dir_encoded_dim = dir_dim + 2 * self.dir_freq * dir_dim

        # Position MLP
        self.pos_layers = nn.ModuleList()
        in_dim = pos_encoded_dim
        for i in range(num_layers):
            if i == skip_layer:
                in_dim += pos_encoded_dim  # Skip connection
            self.pos_layers.append(nn.Linear(in_dim, hidden_dim))
            in_dim = hidden_dim

        # Density head
        self.density_head = nn.Linear(hidden_dim, 1)

        # Color MLP (depends on viewing direction)
        self.feature_layer = nn.Linear(hidden_dim, hidden_dim)
        self.color_layers = nn.Sequential(
            nn.Linear(hidden_dim + dir_encoded_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, 3),
            nn.Sigmoid(),
        )

    def positional_encoding(self, x: torch.Tensor, num_freq: int) -> torch.Tensor:
        """Apply positional encoding to input coordinates."""
        encodings = [x]
        for i in range(num_freq):
            freq = 2 ** i * torch.pi
            encodings.extend([torch.sin(freq * x), torch.cos(freq * x)])
        return torch.cat(encodings, dim=-1)

    def forward(
        self,
        positions: torch.Tensor,   # [B, 3]
        directions: torch.Tensor,  # [B, 3]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Encode positions and directions
        pos_encoded = self.positional_encoding(positions, self.pos_freq)
        dir_encoded = self.positional_encoding(directions, self.dir_freq)

        # Process through position MLP
        h = pos_encoded
        for i, layer in enumerate(self.pos_layers):
            if i == self.skip_layer:  # Skip connection
                h = torch.cat([h, pos_encoded], dim=-1)
            h = torch.relu(layer(h))

        # Predict density
        density = torch.relu(self.density_head(h))

        # Predict color (view-dependent)
        features = self.feature_layer(h)
        color_input = torch.cat([features, dir_encoded], dim=-1)
        color = self.color_layers(color_input)

        return color, density


def volume_render(
    nerf: NeRF,
    rays_o: torch.Tensor,  # [B, 3] ray origins
    rays_d: torch.Tensor,  # [B, 3] ray directions
    near: float = 0.1,
    far: float = 4.0,
    num_samples: int = 64,
) -> torch.Tensor:
    """Volume rendering along rays."""
    batch_size = rays_o.shape[0]
    device = rays_o.device

    # Sample points along rays
    t_vals = torch.linspace(near, far, num_samples, device=device)
    t_vals = t_vals.unsqueeze(0).expand(batch_size, -1)  # [B, num_samples]

    # Add noise for stratified sampling
    mids = 0.5 * (t_vals[..., 1:] + t_vals[..., :-1])
    upper = torch.cat([mids, t_vals[..., -1:]], dim=-1)
    lower = torch.cat([t_vals[..., :1], mids], dim=-1)
    t_rand = torch.rand_like(t_vals)
    t_vals = lower + (upper - lower) * t_rand

    # Compute 3D points
    points = rays_o.unsqueeze(1) + t_vals.unsqueeze(-1) * rays_d.unsqueeze(1)
    points = points.reshape(-1, 3)  # [B * num_samples, 3]

    # Expand directions
    dirs = rays_d.unsqueeze(1).expand(-1, num_samples, -1).reshape(-1, 3)

    # Query NeRF
    colors, densities = nerf(points, dirs)
    colors = colors.reshape(batch_size, num_samples, 3)
    densities = densities.reshape(batch_size, num_samples, 1)

    # Volume rendering
    deltas = t_vals[..., 1:] - t_vals[..., :-1]
    deltas = torch.cat([deltas, torch.full_like(deltas[..., :1], 1e10)], dim=-1)
    deltas = deltas.unsqueeze(-1)

    alpha = 1 - torch.exp(-densities * deltas)
    transmittance = torch.cumprod(1 - alpha + 1e-10, dim=1)
    transmittance = torch.cat([
        torch.ones_like(transmittance[:, :1]),
        transmittance[:, :-1]
    ], dim=1)

    weights = alpha * transmittance
    rgb = (weights * colors).sum(dim=1)

    return rgb
```

Point Cloud Diffusion
Point clouds represent 3D shapes as unordered sets of 3D coordinates. This representation is memory-efficient and directly compatible with sensors like LiDAR, making it attractive for diffusion models. The key insight is that point clouds can be treated as sets, and diffusion can be applied directly to the coordinates:
The forward diffusion process adds Gaussian noise independently to each point's coordinates:

$$q(\mathbf{x}_t \mid \mathbf{x}_0) = \mathcal{N}\!\left(\mathbf{x}_t;\ \sqrt{\bar{\alpha}_t}\, \mathbf{x}_0,\ (1 - \bar{\alpha}_t)\mathbf{I}\right)$$
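Before the model code, the noise schedule itself is worth a sanity check: $\bar{\alpha}_t$ must decrease monotonically toward zero so that, at large $t$, the signal term vanishes and the points are nearly pure Gaussian noise. This uses the same linear beta schedule as the training loop below:

```python
import torch

# Linear beta schedule: alpha_bar_t = prod_{s<=t} (1 - beta_s)
num_timesteps = 1000
betas = torch.linspace(1e-4, 0.02, num_timesteps)
alpha_bars = torch.cumprod(1.0 - betas, dim=0)

# alpha_bar must be strictly decreasing: the signal coefficient
# sqrt(alpha_bar_t) shrinks while the noise coefficient grows
assert torch.all(alpha_bars[1:] < alpha_bars[:-1])
print(alpha_bars[0].item(), alpha_bars[-1].item())
```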
```python
import torch
import torch.nn as nn
from typing import Optional

class PointCloudDiffusion(nn.Module):
    """Diffusion model for point cloud generation."""

    def __init__(
        self,
        point_dim: int = 3,
        hidden_dim: int = 256,
        num_points: int = 2048,
        num_layers: int = 6,
        num_heads: int = 8,
        time_embed_dim: int = 128,
    ):
        super().__init__()
        self.num_points = num_points

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(time_embed_dim),
            nn.Linear(time_embed_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Point embedding
        self.point_embed = nn.Linear(point_dim, hidden_dim)

        # Transformer encoder for set processing
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * 4,
            dropout=0.1,
            activation="gelu",
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, point_dim)

    def forward(
        self,
        x: torch.Tensor,  # [B, N, 3] noisy points
        t: torch.Tensor,  # [B] timesteps
        condition: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size, num_points, _ = x.shape

        # Embed time
        t_emb = self.time_embed(t)  # [B, hidden_dim]
        t_emb = t_emb.unsqueeze(1).expand(-1, num_points, -1)

        # Embed points
        h = self.point_embed(x)  # [B, N, hidden_dim]
        h = h + t_emb

        # Add conditioning if provided
        if condition is not None:
            h = h + condition.unsqueeze(1)

        # Process through transformer
        h = self.transformer(h)

        # Predict noise
        noise_pred = self.output_proj(h)  # [B, N, 3]

        return noise_pred


class SinusoidalEmbedding(nn.Module):
    """Sinusoidal time embedding."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        device = t.device
        half_dim = self.dim // 2
        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t.unsqueeze(-1) * emb.unsqueeze(0)
        return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)


def train_point_cloud_diffusion(
    model: PointCloudDiffusion,
    dataloader: torch.utils.data.DataLoader,
    num_epochs: int = 100,
    lr: float = 1e-4,
    num_timesteps: int = 1000,
):
    """Train point cloud diffusion model."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    device = next(model.parameters()).device

    # Noise schedule
    betas = torch.linspace(1e-4, 0.02, num_timesteps, device=device)
    alphas = 1.0 - betas
    alpha_bars = torch.cumprod(alphas, dim=0)

    for epoch in range(num_epochs):
        for batch in dataloader:
            points = batch["points"].to(device)  # [B, N, 3]

            # Sample timesteps
            t = torch.randint(0, num_timesteps, (points.shape[0],), device=device)

            # Sample noise
            noise = torch.randn_like(points)

            # Add noise to points
            alpha_bar_t = alpha_bars[t].view(-1, 1, 1)
            noisy_points = torch.sqrt(alpha_bar_t) * points + \
                torch.sqrt(1 - alpha_bar_t) * noise

            # Predict noise
            noise_pred = model(noisy_points, t)

            # Compute loss
            loss = nn.functional.mse_loss(noise_pred, noise)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")
```

Mesh-Based Generation
Mesh generation is challenging because meshes have irregular topology: vertices connected by edges forming triangular faces. Recent work approaches this in several ways:
- Deform a template mesh: Start from a sphere or other template and learn vertex displacements
- Generate mesh tokens autoregressively: Treat vertices and faces as tokens, generate sequentially (MeshGPT, PolyGen)
- Extract mesh from implicit representation: Generate SDF or occupancy field, then apply Marching Cubes
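The third route is the easiest to make concrete: an SDF stores, at every grid cell, the signed distance to the surface (negative inside, positive outside), and marching cubes extracts the zero level set as a triangle mesh. A minimal self-contained example of such a field, for a sphere of radius 0.5 inside a $[-1, 1]^3$ grid:

```python
import numpy as np

# Signed distance field of a sphere on a regular 32^3 grid; a marching-cubes
# step (e.g. the PyMCubes package used below) would then extract the zero
# level set as a triangle mesh.
res, radius = 32, 0.5
coords = np.linspace(-1, 1, res)
x, y, z = np.meshgrid(coords, coords, coords, indexing="ij")
sdf = np.sqrt(x**2 + y**2 + z**2) - radius  # negative inside, positive outside

print(sdf[res // 2, res // 2, res // 2])  # near the center: negative (inside)
print(sdf[0, 0, 0])                       # corner: positive (outside)
```

Because the field changes sign across the surface, the zero crossing is well defined everywhere, which is exactly what marching cubes needs.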
```python
import torch
import torch.nn as nn
import mcubes  # Marching cubes implementation (PyMCubes package)

class SDFDiffusion(nn.Module):
    """Diffusion model for signed distance function generation."""

    def __init__(
        self,
        resolution: int = 64,
        hidden_dim: int = 256,
        num_layers: int = 8,
    ):
        super().__init__()
        self.resolution = resolution

        # 3D U-Net for SDF prediction
        self.unet = UNet3D(
            in_channels=1,
            out_channels=1,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
        )

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
        )

    def forward(
        self,
        sdf: torch.Tensor,  # [B, 1, D, H, W] noisy SDF grid
        t: torch.Tensor,    # [B] timesteps
    ) -> torch.Tensor:
        t_emb = self.time_embed(t)
        return self.unet(sdf, t_emb)

    def extract_mesh(
        self,
        sdf: torch.Tensor,
        threshold: float = 0.0,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Extract mesh from SDF using marching cubes."""
        sdf_np = sdf[0, 0].cpu().numpy()

        # Run marching cubes
        vertices, triangles = mcubes.marching_cubes(sdf_np, threshold)

        # Normalize vertices to [-1, 1]
        vertices = vertices / self.resolution * 2 - 1

        return torch.from_numpy(vertices), torch.from_numpy(triangles)


class UNet3D(nn.Module):
    """3D U-Net for volumetric processing."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_dim: int,
        num_layers: int,
    ):
        super().__init__()

        # Encoder
        self.encoders = nn.ModuleList([
            nn.Conv3d(in_channels, hidden_dim, 3, padding=1),
            *[
                nn.Sequential(
                    nn.Conv3d(hidden_dim * (2 ** i), hidden_dim * (2 ** (i + 1)), 3, 2, 1),
                    nn.GroupNorm(8, hidden_dim * (2 ** (i + 1))),
                    nn.SiLU(),
                )
                for i in range(num_layers // 2)
            ]
        ])

        # Decoder with skip connections
        self.decoders = nn.ModuleList([
            nn.Sequential(
                nn.ConvTranspose3d(
                    hidden_dim * (2 ** (num_layers // 2 - i)),
                    hidden_dim * (2 ** (num_layers // 2 - i - 1)),
                    4, 2, 1
                ),
                nn.GroupNorm(8, hidden_dim * (2 ** (num_layers // 2 - i - 1))),
                nn.SiLU(),
            )
            for i in range(num_layers // 2)
        ])

        self.output = nn.Conv3d(hidden_dim, out_channels, 3, padding=1)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        # Note: t_emb is unused in this simplified sketch; a full model would
        # inject it into each block (e.g., via scale/shift modulation)

        # Encoder
        skip_connections = []
        for encoder in self.encoders:
            x = encoder(x)
            skip_connections.append(x)

        # Decoder with skip connections
        for i, decoder in enumerate(self.decoders):
            x = decoder(x)
            if i < len(skip_connections) - 1:
                x = x + skip_connections[-(i + 2)]

        return self.output(x)
```

Text-to-3D Generation
Text-to-3D generation bridges the gap between natural language descriptions and 3D content. The key innovation enabling this is Score Distillation Sampling (SDS), introduced in DreamFusion, which leverages pretrained 2D diffusion models to optimize 3D representations.
DreamFusion: Score Distillation Sampling
DreamFusion (Poole et al., 2022) introduced the idea of using a pretrained text-to-image diffusion model to guide the optimization of a 3D representation (specifically, a NeRF). The key insight is that we can compute gradients through the rendering process without backpropagating through the diffusion model itself.
Given a 3D representation parameterized by $\theta$ and a differentiable renderer $g$, we render an image $x = g(\theta)$ from a random camera pose. The Score Distillation Sampling (SDS) loss is defined through its gradient:

$$\nabla_\theta \mathcal{L}_{\text{SDS}} = \mathbb{E}_{t, \epsilon}\!\left[ w(t)\, \big(\hat{\epsilon}_\phi(x_t; y, t) - \epsilon\big)\, \frac{\partial x}{\partial \theta} \right]$$

where $x_t$ is the noised rendering, $y$ is the text prompt, and $w(t)$ is a weighting function.
```python
import torch
import torch.nn as nn
from diffusers import StableDiffusionPipeline

class ScoreDistillationSampling:
    """Score Distillation Sampling for text-to-3D generation."""

    def __init__(
        self,
        model_name: str = "stabilityai/stable-diffusion-2-1",
        guidance_scale: float = 100.0,
        grad_clip: float = 1.0,
    ):
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
        )
        self.pipe.to("cuda")

        # Freeze diffusion model
        self.unet = self.pipe.unet
        self.vae = self.pipe.vae
        self.text_encoder = self.pipe.text_encoder
        self.tokenizer = self.pipe.tokenizer

        for param in self.unet.parameters():
            param.requires_grad = False
        for param in self.vae.parameters():
            param.requires_grad = False
        for param in self.text_encoder.parameters():
            param.requires_grad = False

        self.guidance_scale = guidance_scale
        self.grad_clip = grad_clip

        # Noise schedule
        self.scheduler = self.pipe.scheduler
        self.num_timesteps = self.scheduler.config.num_train_timesteps

    @torch.no_grad()
    def encode_text(self, prompt: str) -> torch.Tensor:
        """Encode text prompt to embeddings."""
        tokens = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids.to("cuda")

        return self.text_encoder(tokens)[0]

    def compute_sds_loss(
        self,
        rendered_image: torch.Tensor,  # [B, 3, H, W] in [-1, 1]
        text_embeddings: torch.Tensor,
        negative_embeddings: torch.Tensor,
        t_range: tuple[float, float] = (0.02, 0.98),
    ) -> torch.Tensor:
        """Compute SDS loss for rendered images."""
        batch_size = rendered_image.shape[0]
        device = rendered_image.device

        # Resize to latent resolution
        rendered_image = nn.functional.interpolate(
            rendered_image, (512, 512), mode="bilinear"
        )

        # Encode to latent space (keep gradients here: the SDS gradient must
        # flow back through the VAE encoder to the renderer)
        latents = self.vae.encode(
            rendered_image.half()
        ).latent_dist.sample() * 0.18215

        # Sample timestep
        min_t = int(t_range[0] * self.num_timesteps)
        max_t = int(t_range[1] * self.num_timesteps)
        t = torch.randint(min_t, max_t, (batch_size,), device=device)

        # Add noise
        noise = torch.randn_like(latents)
        noisy_latents = self.scheduler.add_noise(latents, noise, t)

        # Predict noise with classifier-free guidance
        latent_model_input = torch.cat([noisy_latents] * 2)
        t_input = torch.cat([t] * 2)
        text_input = torch.cat([negative_embeddings, text_embeddings])

        with torch.no_grad():
            noise_pred = self.unet(
                latent_model_input.half(),
                t_input,
                encoder_hidden_states=text_input.half(),
            ).sample

        # Classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        # SDS gradient (detach noise prediction)
        w = (1 - self.scheduler.alphas_cumprod.to(device)[t]).view(-1, 1, 1, 1)
        grad = w * (noise_pred.float() - noise)

        # Clip gradients
        grad = torch.nan_to_num(grad)
        grad = torch.clamp(grad, -self.grad_clip, self.grad_clip)

        # Compute loss (gradient descent direction)
        target = (latents - grad).detach()
        loss = 0.5 * nn.functional.mse_loss(latents, target, reduction="sum") / batch_size

        return loss


def train_dreamfusion(
    nerf: NeRF,
    sds: ScoreDistillationSampling,
    prompt: str,
    num_iterations: int = 10000,
    lr: float = 1e-3,
):
    """Train NeRF using DreamFusion approach.

    sample_random_camera, get_rays, and render_nerf are assumed to be
    provided elsewhere.
    """
    optimizer = torch.optim.Adam(nerf.parameters(), lr=lr)

    # Encode prompts
    text_embeddings = sds.encode_text(prompt)
    negative_embeddings = sds.encode_text("")

    for step in range(num_iterations):
        # Sample random camera pose
        camera = sample_random_camera()

        # Render image from NeRF
        rays_o, rays_d = get_rays(camera, image_size=64)
        rendered_image = render_nerf(nerf, rays_o, rays_d)  # [1, 3, 64, 64]

        # Compute SDS loss
        loss = sds.compute_sds_loss(
            rendered_image,
            text_embeddings,
            negative_embeddings,
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print(f"Step {step}, SDS Loss: {loss.item():.4f}")
```

SDS Limitations

Despite its impact, SDS has well-known failure modes: the high guidance scale produces over-saturated colors and over-smoothed geometry, views can collapse into multi-faced "Janus" artifacts, and the per-asset optimization takes on the order of an hour. The methods below each address some of these issues.
Magic3D: High-Resolution 3D
Magic3D (Lin et al., 2023) improves upon DreamFusion with a coarse-to-fine approach:
- Coarse stage: Optimize a low-resolution NeRF using SDS with the base diffusion model
- Fine stage: Convert to a textured mesh and optimize using SDS with a fine-tuned high-resolution diffusion model
This two-stage approach renders at 8x higher resolution (512x512 versus DreamFusion's 64x64) and is roughly 2x faster:
```python
class Magic3DPipeline:
    """Two-stage coarse-to-fine 3D generation.

    InstantNGP, DMTet, and sample_random_camera are assumed to be
    provided elsewhere.
    """

    def __init__(self):
        # Coarse stage: NeRF with instant-NGP
        self.coarse_nerf = InstantNGP(
            resolution=64,
            num_levels=16,
        )

        # Fine stage: DMTet mesh
        self.fine_mesh = DMTet(
            resolution=256,
            grid_scale=1.0,
        )

        # Diffusion models
        self.coarse_sds = ScoreDistillationSampling(
            model_name="stabilityai/stable-diffusion-2-1",
        )
        self.fine_sds = ScoreDistillationSampling(
            model_name="stabilityai/stable-diffusion-2-1-base",
        )

    def coarse_stage(
        self,
        prompt: str,
        num_iterations: int = 5000,
    ):
        """Optimize coarse NeRF representation."""
        optimizer = torch.optim.Adam(self.coarse_nerf.parameters(), lr=1e-2)

        text_emb = self.coarse_sds.encode_text(prompt)
        neg_emb = self.coarse_sds.encode_text("")

        for step in range(num_iterations):
            camera = sample_random_camera(radius_range=(1.5, 2.0))
            image = self.coarse_nerf.render(camera, resolution=64)

            loss = self.coarse_sds.compute_sds_loss(image, text_emb, neg_emb)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return self.coarse_nerf

    def fine_stage(
        self,
        prompt: str,
        coarse_nerf: InstantNGP,
        num_iterations: int = 3000,
    ):
        """Refine with textured mesh."""
        # Initialize mesh from coarse NeRF
        self.fine_mesh.initialize_from_nerf(coarse_nerf)

        optimizer = torch.optim.Adam([
            {"params": self.fine_mesh.vertices, "lr": 1e-4},
            {"params": self.fine_mesh.texture, "lr": 1e-3},
        ])

        text_emb = self.fine_sds.encode_text(prompt)
        neg_emb = self.fine_sds.encode_text("")

        for step in range(num_iterations):
            camera = sample_random_camera(radius_range=(1.5, 2.0))
            image = self.fine_mesh.render(camera, resolution=512)

            loss = self.fine_sds.compute_sds_loss(image, text_emb, neg_emb)

            # Add mesh regularization
            loss += 0.1 * self.fine_mesh.laplacian_smoothing()
            loss += 0.01 * self.fine_mesh.normal_consistency()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        return self.fine_mesh
```

ProlificDreamer: Variational Score Distillation
ProlificDreamer (Wang et al., 2023) addresses SDS's over-saturation by introducing Variational Score Distillation (VSD). Instead of comparing the pretrained model's prediction against the fixed injected noise, VSD trains a LoRA adapter on the diffusion model to match the distribution of rendered images:

$$\nabla_\theta \mathcal{L}_{\text{VSD}} = \mathbb{E}_{t, \epsilon, c}\!\left[ w(t)\, \big(\hat{\epsilon}_\phi(x_t; y, t) - \hat{\epsilon}_{\text{LoRA}}(x_t; y, t, c)\big)\, \frac{\partial x}{\partial \theta} \right]$$

where $\hat{\epsilon}_{\text{LoRA}}$ is a LoRA-adapted version of the diffusion model that learns to predict noise for the current 3D representation's renderings.
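The structural difference between the two gradients is small but consequential: SDS subtracts the injected noise $\epsilon$, while VSD subtracts the LoRA model's prediction. A toy comparison at a single timestep (the tensors below are stand-ins for real model predictions) shows that when the LoRA copy fits the renders well, the VSD gradient becomes smaller and less noisy:

```python
import torch

# Stand-in predictions at one timestep; in practice these come from the
# frozen pretrained model and its LoRA-adapted copy.
torch.manual_seed(0)
noise = torch.randn(1, 4, 8, 8)       # epsilon used to noise the latents
pred_pretrained = noise + 0.1         # pretrained model's noise prediction
pred_lora = noise + 0.02              # LoRA copy fits the renders closely
w = 0.7                               # weighting w(t)

grad_sds = w * (pred_pretrained - noise)      # SDS: compare against epsilon
grad_vsd = w * (pred_pretrained - pred_lora)  # VSD: compare against LoRA score

print(grad_sds.abs().mean().item(), grad_vsd.abs().mean().item())
```

As the LoRA adapter converges toward the pretrained model's behavior on the current renders, the VSD gradient vanishes, which is what prevents the over-sharpened, saturated solutions SDS tends toward.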
| Method | Quality | Speed | Key Innovation |
|---|---|---|---|
| DreamFusion | Good | ~1 hour | Score Distillation Sampling |
| Magic3D | High | ~40 min | Coarse-to-fine, mesh refinement |
| ProlificDreamer | Excellent | ~2 hours | Variational Score Distillation |
| DreamGaussian | Good | ~2 min | 3D Gaussian Splatting |
Multi-View Diffusion
Instead of optimizing a 3D representation from scratch, multi-view diffusion models generate consistent 2D views that can be lifted to 3D. This approach is often faster and produces more consistent results than SDS-based methods.
Zero-1-to-3: Novel View Synthesis
Zero-1-to-3 (Liu et al., 2023) fine-tunes Stable Diffusion to generate novel views of an object given a single input image and a camera transformation. The key is to condition the diffusion model on:
- The input image (via CLIP image embeddings)
- The relative camera transformation (azimuth, elevation, distance)
```python
import torch
from diffusers import StableDiffusionPipeline
from transformers import CLIPImageProcessor, CLIPVisionModel

class Zero123:
    """Zero-1-to-3 for novel view synthesis."""

    def __init__(self, model_path: str):
        # Load fine-tuned Stable Diffusion
        self.pipe = StableDiffusionPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
        ).to("cuda")

        # CLIP image encoder for conditioning
        self.clip_processor = CLIPImageProcessor.from_pretrained(
            "openai/clip-vit-large-patch14"
        )
        self.clip_encoder = CLIPVisionModel.from_pretrained(
            "openai/clip-vit-large-patch14"
        ).to("cuda")

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """Encode input image with CLIP."""
        processed = self.clip_processor(images=image, return_tensors="pt")
        embeddings = self.clip_encoder(
            processed.pixel_values.to("cuda")
        ).last_hidden_state
        return embeddings

    def encode_camera(
        self,
        azimuth: float,    # Horizontal rotation in degrees
        elevation: float,  # Vertical rotation in degrees
        distance: float,   # Camera distance
    ) -> torch.Tensor:
        """Encode relative camera transformation."""
        # Normalize to [-1, 1]
        az_norm = azimuth / 180.0
        el_norm = elevation / 90.0
        dist_norm = (distance - 1.5) / 0.5  # Assume distance in [1, 2]

        camera_emb = torch.tensor(
            [az_norm, el_norm, dist_norm],
            device="cuda",
            dtype=torch.float16,
        )
        return camera_emb

    @torch.no_grad()
    def generate_novel_view(
        self,
        input_image: torch.Tensor,
        azimuth: float,
        elevation: float,
        distance: float = 1.5,
        num_inference_steps: int = 50,
    ) -> torch.Tensor:
        """Generate novel view given input image and camera pose."""
        # Encode image
        image_embeddings = self.encode_image(input_image)

        # Encode camera
        camera_embeddings = self.encode_camera(azimuth, elevation, distance)

        # Combine embeddings (schematic; the released model concatenates the
        # camera parameters with the CLIP embedding and applies a learned
        # projection back to the conditioning dimension)
        conditioning = torch.cat([
            image_embeddings,
            camera_embeddings.unsqueeze(0).unsqueeze(0).expand(-1, 77, -1),
        ], dim=-1)

        # Generate
        output = self.pipe(
            prompt_embeds=conditioning,
            num_inference_steps=num_inference_steps,
        ).images[0]

        return output


def generate_multiview_set(
    zero123: Zero123,
    input_image: torch.Tensor,
    num_views: int = 8,
) -> list[torch.Tensor]:
    """Generate consistent multi-view images."""
    views = []
    azimuths = torch.linspace(0, 360, num_views + 1)[:-1]

    for az in azimuths:
        view = zero123.generate_novel_view(
            input_image,
            azimuth=az.item(),
            elevation=15.0,  # Slight overhead view
        )
        views.append(view)

    return views
```

MVDream: Multi-View Consistent Generation
MVDream (Shi et al., 2023) generates multiple consistent views simultaneously by extending the diffusion model to output 4 views at once. This is achieved through:
- Multi-view attention: Cross-attention between views ensures consistency
- Camera conditioning: Each view is conditioned on its camera pose
- 3D self-attention: Features are exchanged between views at each layer
```python
import torch
import torch.nn as nn

class MultiViewAttention(nn.Module):
    """Cross-attention between multiple views for consistency."""

    def __init__(
        self,
        hidden_dim: int,
        num_heads: int = 8,
        num_views: int = 4,
    ):
        super().__init__()
        self.num_views = num_views

        self.self_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, batch_first=True
        )
        self.cross_view_attention = nn.MultiheadAttention(
            hidden_dim, num_heads, batch_first=True
        )
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B * num_views, seq_len, hidden_dim]
        batch_times_views, seq_len, hidden_dim = x.shape
        batch_size = batch_times_views // self.num_views

        # Self-attention within each view
        x = x + self.self_attention(
            self.norm1(x), self.norm1(x), self.norm1(x)
        )[0]

        # Reshape for cross-view attention
        x = x.view(batch_size, self.num_views, seq_len, hidden_dim)
        x = x.permute(0, 2, 1, 3)  # [B, seq_len, num_views, hidden_dim]
        x = x.reshape(batch_size * seq_len, self.num_views, hidden_dim)

        # Cross-view attention (each position attends to same position in other views)
        x = x + self.cross_view_attention(
            self.norm2(x), self.norm2(x), self.norm2(x)
        )[0]

        # Reshape back
        x = x.view(batch_size, seq_len, self.num_views, hidden_dim)
        x = x.permute(0, 2, 1, 3)  # [B, num_views, seq_len, hidden_dim]
        x = x.reshape(batch_times_views, seq_len, hidden_dim)

        return x


class MVDream(nn.Module):
    """Multi-view diffusion model for 3D-consistent generation."""

    def __init__(
        self,
        base_model: nn.Module,
        num_views: int = 4,
    ):
        super().__init__()
        self.base_model = base_model
        self.num_views = num_views

        # Add multi-view attention to each transformer block
        self.mv_attentions = nn.ModuleList([
            MultiViewAttention(
                hidden_dim=block.hidden_dim,
                num_views=num_views,
            )
            for block in self.base_model.transformer_blocks
        ])

        # Camera pose encoding
        self.camera_encoder = nn.Sequential(
            nn.Linear(12, 256),   # 3x4 camera matrix
            nn.SiLU(),
            nn.Linear(256, 768),  # Match text embedding dim
        )

    def forward(
        self,
        x: torch.Tensor,  # [B, 4, C, H, W] - 4 views
        t: torch.Tensor,
        text_embeddings: torch.Tensor,
        camera_poses: torch.Tensor,  # [B, 4, 3, 4]
    ) -> torch.Tensor:
        batch_size = x.shape[0]

        # Flatten views into batch
        x = x.view(-1, *x.shape[2:])  # [B * 4, C, H, W]

        # Encode camera poses
        cam_emb = self.camera_encoder(
            camera_poses.view(-1, 12)
        )  # [B * 4, 768]

        # Combine with text embeddings
        text_embeddings = text_embeddings.unsqueeze(1).expand(-1, 4, -1, -1)
        text_embeddings = text_embeddings.reshape(-1, *text_embeddings.shape[2:])
        text_embeddings = text_embeddings + cam_emb.unsqueeze(1)

        # Run through base model with multi-view attention
        # (simplified - actual implementation modifies transformer blocks)
        output = self.base_model(x, t, text_embeddings)

        # Reshape back to views
        output = output.view(batch_size, 4, *output.shape[1:])

        return output
```

One-2-3-45: Fast Single-Image to 3D
One-2-3-45 (Liu et al., 2023) combines multi-view generation with fast 3D reconstruction. Given a single image, it:
- Uses Zero-1-to-3 to generate multi-view images (45 degree increments)
- Applies a fast neural surface reconstruction network to lift views to 3D
- Achieves 3D generation in approximately 45 seconds
Speed vs Quality Trade-off: One-2-3-45++ further improves speed to under 10 seconds by using a feed-forward network instead of optimization, trading some quality for dramatic speed improvements.
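The view-sampling schedule in step 1 is easy to sketch. The code below is a minimal sketch of that schedule; `generate_view` stands in for a Zero-1-to-3 call like `generate_novel_view` above:

```python
def view_schedule(step_deg: float = 45.0) -> list[float]:
    """Azimuth angles (degrees) covering a full turn at fixed increments."""
    n = int(360 / step_deg)
    return [i * step_deg for i in range(n)]

def generate_views(generate_view, elevation: float = 15.0) -> list:
    """Query a novel-view model once per azimuth in the schedule."""
    return [generate_view(azimuth=az, elevation=elevation)
            for az in view_schedule()]

print(view_schedule())
# [0.0, 45.0, 90.0, 135.0, 180.0, 225.0, 270.0, 315.0]
```

The eight resulting views, together with their known camera poses, are what the reconstruction network consumes in step 2.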
3D-Native Diffusion Models
Rather than lifting 2D images to 3D, 3D-native diffusion models operate directly on 3D representations, learning the data distribution in 3D space.
Triplane Diffusion
Triplanes provide a compact 3D representation by factorizing a 3D volume into three 2D feature planes (XY, XZ, YZ). Features at any 3D point $\mathbf{x}$ are obtained by projecting onto each plane and aggregating:

$$f(\mathbf{x}) = \big[\, F_{xy}(\pi_{xy}(\mathbf{x}))\, ;\ F_{xz}(\pi_{xz}(\mathbf{x}))\, ;\ F_{yz}(\pi_{yz}(\mathbf{x}))\, \big]$$

where $\pi$ denotes orthogonal projection onto the corresponding plane and $[\,\cdot\,;\,\cdot\,]$ is concatenation:
1import torch
2import torch.nn as nn
3
4class TriplaneRepresentation(nn.Module):
5 """Triplane representation for efficient 3D encoding."""
6
7 def __init__(
8 self,
9 resolution: int = 256,
10 feature_dim: int = 32,
11 ):
12 super().__init__()
13 self.resolution = resolution
14 self.feature_dim = feature_dim
15
16 # Three feature planes: XY, XZ, YZ
17 self.planes = nn.ParameterList([
18 nn.Parameter(torch.randn(1, feature_dim, resolution, resolution) * 0.01)
19 for _ in range(3)
20 ])
21
22 # MLP decoder
23 self.decoder = nn.Sequential(
24 nn.Linear(feature_dim * 3, 128),
25 nn.ReLU(),
26 nn.Linear(128, 128),
27 nn.ReLU(),
28 nn.Linear(128, 4), # RGB + density
29 )
30
31 def sample_features(
32 self,
33 points: torch.Tensor, # [B, N, 3] in [-1, 1]
34 ) -> torch.Tensor:
35 batch_size, num_points, _ = points.shape
36
37 # Project to each plane
38 xy_coords = points[..., :2] # [B, N, 2]
39 xz_coords = points[..., [0, 2]]
40 yz_coords = points[..., [1, 2]]
41
42 # Grid sample from each plane
43 xy_features = nn.functional.grid_sample(
44 self.planes[0].expand(batch_size, -1, -1, -1),
45 xy_coords.view(batch_size, num_points, 1, 2),
46 mode="bilinear",
47 align_corners=True,
48 ).view(batch_size, self.feature_dim, num_points)
49
50 xz_features = nn.functional.grid_sample(
51 self.planes[1].expand(batch_size, -1, -1, -1),
52 xz_coords.view(batch_size, num_points, 1, 2),
53 mode="bilinear",
54 align_corners=True,
55 ).view(batch_size, self.feature_dim, num_points)
56
57 yz_features = nn.functional.grid_sample(
58 self.planes[2].expand(batch_size, -1, -1, -1),
59 yz_coords.view(batch_size, num_points, 1, 2),
60 mode="bilinear",
61 align_corners=True,
62 ).view(batch_size, self.feature_dim, num_points)
63
64 # Aggregate features
65 features = torch.cat([xy_features, xz_features, yz_features], dim=1)
66 features = features.permute(0, 2, 1) # [B, N, feature_dim * 3]
67
68 return features
69
70 def forward(self, points: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
71 features = self.sample_features(points)
72 output = self.decoder(features)
73
74 rgb = torch.sigmoid(output[..., :3])
75 density = torch.relu(output[..., 3:])
76
77 return rgb, density
78
79
class TriplaneDiffusion(nn.Module):
    """Diffusion model operating on triplane representations."""

    def __init__(
        self,
        plane_resolution: int = 256,
        feature_dim: int = 32,
    ):
        super().__init__()

        # 2D U-Net for each plane (shared weights)
        self.unet = UNet2D(
            in_channels=feature_dim,
            out_channels=feature_dim,
            hidden_channels=128,
        )

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(256),
            nn.Linear(256, 512),
            nn.SiLU(),
            nn.Linear(512, 512),
        )

        # Cross-plane attention for consistency
        self.cross_plane_attention = nn.MultiheadAttention(
            embed_dim=feature_dim,
            num_heads=4,
            batch_first=True,
        )

    def forward(
        self,
        planes: torch.Tensor,  # [B, 3, C, H, W]
        t: torch.Tensor,
        condition: torch.Tensor | None = None,
    ) -> torch.Tensor:
        batch_size = planes.shape[0]
        t_emb = self.time_embed(t)

        # Process each plane with the shared 2D U-Net
        outputs = []
        for i in range(3):
            plane_out = self.unet(planes[:, i], t_emb, condition)
            outputs.append(plane_out)

        outputs = torch.stack(outputs, dim=1)  # [B, 3, C, H, W]

        # Cross-plane attention for consistency: at each spatial location,
        # attend over the three planes with channels as the embedding dim
        # (embed_dim=feature_dim), i.e. tokens of shape [3, C]
        c, h, w = outputs.shape[-3:]
        tokens = outputs.permute(0, 3, 4, 1, 2).reshape(batch_size * h * w, 3, c)
        attended = self.cross_plane_attention(tokens, tokens, tokens)[0]
        outputs = attended.reshape(batch_size, h, w, 3, c).permute(0, 3, 4, 1, 2)
        return outputs

Shap-E
OpenAI's Shap-E (Jun & Nichol, 2023) generates 3D assets by diffusing in a latent space of implicit neural representations. Key features include:
- Implicit function encoder: Maps 3D meshes/point clouds to latent codes
- Latent diffusion: Operates on compact latent representations
- Multi-modal conditioning: Supports both text and image conditioning
```python
import torch

from shap_e.diffusion.sample import sample_latents
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import decode_latent_mesh


def generate_3d_with_shap_e(
    prompt: str,
    batch_size: int = 1,
    guidance_scale: float = 15.0,
):
    """Generate 3D meshes using Shap-E."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the transmitter (latent decoder), text-conditional model,
    # and diffusion process
    xm = load_model("transmitter", device=device)
    model = load_model("text300M", device=device)
    diffusion = diffusion_from_config(load_config("diffusion"))

    # Generate latents with classifier-free guidance
    latents = sample_latents(
        batch_size=batch_size,
        model=model,
        diffusion=diffusion,
        guidance_scale=guidance_scale,
        model_kwargs=dict(texts=[prompt] * batch_size),
        progress=True,
        clip_denoised=True,
        use_fp16=True,
        use_karras=True,
        karras_steps=64,
    )

    # Decode each latent to a triangle mesh and export as PLY
    for i, latent in enumerate(latents):
        mesh = decode_latent_mesh(xm, latent).tri_mesh()
        with open(f"output_{i}.ply", "wb") as f:
            mesh.write_ply(f)

    return latents


# Example usage
latents = generate_3d_with_shap_e(
    prompt="a chair that looks like an avocado"
)
```

Point-E: Fast Point Cloud Generation
Point-E (Nichol et al., 2022) generates colored point clouds in two stages:
- Text-to-image: Generate a synthetic view using GLIDE
- Image-to-point-cloud: Diffusion model generates point cloud conditioned on the image
While its quality is lower than that of optimization-based methods, Point-E generates 3D assets in roughly 1-2 minutes on a single GPU, making it practical for rapid prototyping.
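The flow of the two-stage cascade can be sketched structurally as follows. Note that `TextToImageStage`, `ImageToPointCloudStage`, and `point_e_style_pipeline` are hypothetical stand-ins invented for illustration; the real Point-E pipeline uses GLIDE for the first stage and transformer-based point diffusion models for the second. This sketch only shows how conditioning passes between stages and what shapes flow through.

```python
import torch


class TextToImageStage:
    """Stand-in for the GLIDE text-to-image stage (hypothetical)."""

    def __call__(self, prompt: str) -> torch.Tensor:
        # A real implementation would run text-conditional image diffusion;
        # here we return a dummy 64x64 RGB "synthetic view".
        return torch.rand(3, 64, 64)


class ImageToPointCloudStage:
    """Stand-in for the image-conditioned point-cloud diffusion stage."""

    def __init__(self, num_points: int = 1024):
        self.num_points = num_points

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        # A real implementation would denoise point coordinates and colors
        # conditioned on image features; here we only emit the right shape.
        assert image.shape[0] == 3  # conditioning image must be RGB
        return torch.rand(self.num_points, 6)  # [N, xyz + rgb]


def point_e_style_pipeline(prompt: str) -> torch.Tensor:
    """Text -> synthetic view -> colored point cloud, as in Point-E."""
    view = TextToImageStage()(prompt)
    points = ImageToPointCloudStage(num_points=1024)(view)
    return points


cloud = point_e_style_pipeline("a red traffic cone")
print(cloud.shape)  # torch.Size([1024, 6])
```

The key design point is that the second stage never sees the text: all semantic information is funneled through the synthetic view, which is what lets the point-cloud model train purely on (image, point cloud) pairs.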
Practical Implementation
Here's a complete example combining multiple techniques for practical 3D generation:
```python
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class Generation3DConfig:
    """Configuration for 3D generation pipeline."""
    method: str = "dreamfusion"  # dreamfusion, magic3d, zero123
    num_iterations: int = 10000
    learning_rate: float = 1e-3
    guidance_scale: float = 100.0
    resolution: int = 64
    export_format: str = "mesh"  # mesh, point_cloud, nerf


class Unified3DGenerator:
    """Unified interface for 3D generation methods."""

    def __init__(self, config: Generation3DConfig):
        self.config = config
        self.device = torch.device("cuda")

        # Initialize based on method
        if config.method == "dreamfusion":
            self.generator = DreamFusionGenerator(config)
        elif config.method == "magic3d":
            self.generator = Magic3DGenerator(config)
        elif config.method == "zero123":
            self.generator = Zero123Generator(config)
        else:
            raise ValueError(f"Unknown method: {config.method}")

    def generate(
        self,
        prompt: str,
        image: Optional[torch.Tensor] = None,
        callback: Optional[Callable] = None,
    ) -> dict:
        """Generate 3D asset from text or image."""
        result = self.generator.generate(
            prompt=prompt,
            image=image,
            callback=callback,
        )

        # Export in requested format
        if self.config.export_format == "mesh":
            result["mesh"] = self.generator.export_mesh()
        elif self.config.export_format == "point_cloud":
            result["points"] = self.generator.export_point_cloud()

        return result


class DreamFusionGenerator:
    """DreamFusion-style generation with SDS."""

    def __init__(self, config: Generation3DConfig):
        self.config = config
        self.device = torch.device("cuda")

        # Initialize NeRF
        self.nerf = InstantNGP(
            resolution=config.resolution,
            bound=1.0,
        ).to(self.device)

        # Initialize SDS
        self.sds = ScoreDistillationSampling(
            guidance_scale=config.guidance_scale,
        )

        self.optimizer = torch.optim.Adam(
            self.nerf.parameters(),
            lr=config.learning_rate,
        )

    def generate(
        self,
        prompt: str,
        image: Optional[torch.Tensor] = None,
        callback: Optional[Callable] = None,
    ) -> dict:
        # Encode text (and an empty negative prompt for guidance)
        text_emb = self.sds.encode_text(prompt)
        neg_emb = self.sds.encode_text("")

        losses = []

        for step in range(self.config.num_iterations):
            # Sample a random orbit camera
            camera = sample_orbit_camera(
                radius=1.5,
                elevation_range=(-30, 60),
            )

            # Render the current NeRF from that camera
            rendered = self.nerf.render(camera)

            # Compute SDS loss
            loss = self.sds.compute_sds_loss(
                rendered, text_emb, neg_emb,
                t_range=(0.02, 0.98),
            )

            # Optimize
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            losses.append(loss.item())

            if callback and step % 100 == 0:
                callback(step, loss.item(), rendered)

        return {"losses": losses}

    def export_mesh(self, resolution: int = 256) -> dict:
        """Extract mesh using marching cubes."""
        # Build a dense query grid over [-1, 1]^3
        grid = torch.linspace(-1, 1, resolution, device=self.device)
        xx, yy, zz = torch.meshgrid(grid, grid, grid, indexing="ij")
        points = torch.stack([xx, yy, zz], dim=-1).reshape(-1, 3)

        # Evaluate density (viewing direction does not affect density)
        with torch.no_grad():
            _, density = self.nerf(
                points,
                torch.zeros_like(points),
            )

        density = density.reshape(resolution, resolution, resolution)

        # Marching cubes at a fixed density isovalue
        import mcubes
        vertices, triangles = mcubes.marching_cubes(
            density.cpu().numpy(), 10.0
        )

        # Normalize vertices back to [-1, 1]
        vertices = vertices / resolution * 2 - 1

        return {
            "vertices": vertices,
            "faces": triangles,
        }


# Example usage
if __name__ == "__main__":
    config = Generation3DConfig(
        method="dreamfusion",
        num_iterations=5000,
        guidance_scale=100.0,
    )

    generator = Unified3DGenerator(config)

    def progress_callback(step, loss, image):
        print(f"Step {step}: Loss = {loss:.4f}")

    result = generator.generate(
        prompt="a detailed 3D model of a medieval castle",
        callback=progress_callback,
    )

    # Export mesh
    mesh = result["mesh"]
    print(f"Generated mesh with {len(mesh['vertices'])} vertices")
```

Practical Tips for 3D Generation:
- Start with lower resolution (64-128) for faster iteration, then increase
- Use view-dependent prompts: add "front view", "back view" to reduce Janus faces
- Combine methods: use fast Zero-1-to-3 for initial shape, refine with SDS
- Monitor from multiple viewpoints during optimization to catch consistency issues early
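The view-dependent prompt tip can be automated with a small helper that appends a view phrase based on the sampled camera's azimuth, as DreamFusion-style pipelines commonly do. The angle thresholds below are a tunable assumption, not a fixed standard:

```python
def view_dependent_prompt(base_prompt: str, azimuth_deg: float) -> str:
    """Append a view phrase based on camera azimuth (0 deg = front).

    The 45/135/225/315-degree bins are a common convention in
    DreamFusion-style implementations; adjust to taste.
    """
    azimuth = azimuth_deg % 360.0
    if azimuth < 45 or azimuth >= 315:
        view = "front view"
    elif azimuth < 135:
        view = "side view"
    elif azimuth < 225:
        view = "back view"
    else:
        view = "side view"
    return f"{base_prompt}, {view}"


print(view_dependent_prompt("a medieval castle", 180.0))
# a medieval castle, back view
```

Calling this once per sampled camera and feeding the result to the text encoder gives each rendered view a prompt consistent with its pose, which helps reduce the Janus (multi-face) artifact.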
References
Key papers and resources for 3D diffusion models:
- Poole et al. (2022). "DreamFusion: Text-to-3D using 2D Diffusion" - Introduced Score Distillation Sampling
- Nichol et al. (2022). "Point-E: A System for Generating 3D Point Clouds from Complex Prompts" - Fast two-stage point cloud generation
- Lin et al. (2023). "Magic3D: High-Resolution Text-to-3D Content Creation" - Coarse-to-fine optimization
- Wang et al. (2023). "ProlificDreamer: High-Fidelity and Diverse Text-to-3D Generation with Variational Score Distillation"
- Liu et al. (2023). "Zero-1-to-3: Zero-shot One Image to 3D Object" - Novel view synthesis with diffusion
- Shi et al. (2023). "MVDream: Multi-view Diffusion for 3D Generation" - Multi-view consistent generation
- Jun & Nichol (2023). "Shap-E: Generating Conditional 3D Implicit Functions" - Latent 3D diffusion
- Tang et al. (2023). "DreamGaussian: Generative Gaussian Splatting for Efficient 3D Content Creation"
- Mildenhall et al. (2020). "NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis"