Chapter 11

U-Net for Semantic Segmentation

CNN Architectures

Why classification CNNs can't segment

Every CNN we built in §§11.1–11.7 was designed to throw spatial information away. Pooling layers halve resolution; the final global-average-pool flattens whatever survives into a single channel-wise vector; the classifier head produces one number per class. That is the right design when the answer is one label for the whole image — "cat", "1000-class ImageNet index 281", or "pneumonia: yes/no".

Segmentation needs the opposite. The output is a dense map with one label per pixel, at or near the input's resolution: a 572×572 cell-microscopy image must come back as a mask that marks every pixel as cell or not-cell. We need a network whose early layers trade resolution for receptive field, and whose later layers recover that resolution while keeping the high-level reasoning.

The resolution-recovery problem. Classification CNNs are downsamplers. Segmentation networks must be down-then-up samplers, with a way for the upsampling path to see the high-resolution detail that the downsampling path threw away. That last clause — how the up-path sees high-res detail — is what U-Net solves.

Three failure modes a pure encoder-decoder (no skip connections) hits in practice, all of which U-Net's skip pattern repairs:

  • Boundary blur. By the time information reaches the bottleneck (28×28 in the original U-Net), every cell membrane has been smeared across multiple feature-map pixels. The decoder cannot recover edges it does not have.
  • Small-object loss. A bottleneck neuron has a receptive field of ~140 pixels but sits on a 16-pixel grid; objects much smaller than that window get blended into the surrounding large-scale features.
  • Localization vs context tradeoff. The decoder needs both "what is this?" (semantic, from the bottleneck) and "exactly where?" (spatial, from the encoder's shallow layers). Without skips, only the first survives.
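The ~140-pixel figure in the second bullet can be checked with the standard receptive-field recurrence. A minimal sketch, assuming the original layout (two unpadded 3×3 convs plus one 2×2 stride-2 max-pool per encoder stage, then the two bottleneck convs); the helper name receptive_field is hypothetical:

```python
def receptive_field(layers):
    """Receptive field and cumulative stride after a stack of layers.

    Each layer is (kernel_size, stride). Standard recurrence: the receptive
    field grows by (k - 1) * jump, and jump (the spacing, in input pixels,
    between neighbouring features) multiplies by the stride.
    """
    rf, jump = 1, 1
    for k, s in layers:
        rf += (k - 1) * jump
        jump *= s
    return rf, jump

conv3 = (3, 1)   # unpadded 3x3 conv, stride 1
pool2 = (2, 2)   # 2x2 max-pool, stride 2

# Four encoder stages of (conv, conv, pool), then the two bottleneck convs.
unet_to_bottleneck = [conv3, conv3, pool2] * 4 + [conv3, conv3]
print(receptive_field(unet_to_bottleneck))  # (140, 16)
```

The second number is the bottleneck's grid spacing: neighbouring bottleneck features sit 16 input pixels apart.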

Architecture: symmetric encoder-decoder

U-Net is built from three pieces that reappear unchanged across the entire segmentation lineage we cover in §§09–11:

  1. Contracting path (encoder). Four stages of (3×3 conv → ReLU) ×2 → 2×2 max-pool. Channels double at each stage (64 → 128 → 256 → 512); spatial dimensions roughly halve at each pooling step.
  2. Bottleneck. Two more 3×3 convs at 1024 channels. This is the lowest spatial resolution but the deepest semantic representation.
  3. Expansive path (decoder). Four mirrored stages of up-conv (2×2 transposed conv) → concatenate with the matching encoder feature map → (3×3 conv → ReLU) ×2. Channels halve, spatial dimensions roughly double, at each stage.

The decoder is symmetric to the encoder, and at every level the encoder's output is copied across to the decoder through a skip connection. What "copied across" means is unpacked in "Skip connections: concat vs add" below.

Padding. The 3×3 convolutions in the original paper are unpadded, so each one shrinks H and W by 2. That is why a 572×572 input becomes a 388×388 mask. It is not a bug but an architectural choice: every output pixel is computed from a fully valid receptive field, with no zero-padded border. Modern reimplementations (including our PyTorch code below) typically use padded 3×3 convs so input and output sizes match.

Shapes the original U-Net produces, level by level:

Stage | Channels | H × W | Notes
--- | --- | --- | ---
Input | 1 | 572 × 572 | Single-channel cell-microscopy image
Encoder block 1 | 64 | 568 × 568 | (3×3 conv unpadded) × 2
After max-pool 1 | 64 | 284 × 284 | Halve spatial
Encoder block 2 | 128 | 280 × 280 | (3×3 conv unpadded) × 2
After max-pool 2 | 128 | 140 × 140 |
Encoder block 3 | 256 | 136 × 136 |
After max-pool 3 | 256 | 68 × 68 |
Encoder block 4 | 512 | 64 × 64 |
After max-pool 4 | 512 | 32 × 32 |
Bottleneck | 1024 | 28 × 28 | Lowest resolution, deepest semantics
Up-conv → concat → conv | 512 | 52 × 52 | Skip from encoder block 4 (cropped 64→56 to match the up-conv output)
Up-conv → concat → conv | 256 | 100 × 100 | Skip from encoder block 3 (cropped 136→104)
Up-conv → concat → conv | 128 | 196 × 196 | Skip from encoder block 2 (cropped 280→200)
Up-conv → concat → conv | 64 | 388 × 388 | Skip from encoder block 1 (cropped 568→392)
1×1 conv → softmax | 2 | 388 × 388 | 2 classes: cell / background
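The size arithmetic behind the table can be reproduced in a few lines of plain Python. This is a sketch that tracks a single side (H = W) and assumes the original layout: two unpadded 3×3 convs per stage, 2×2 max-pools, 2×2 up-convs that double the size, and each skip cropped to the up-conv output before the two convs that follow the concat. The helper names are hypothetical.

```python
def valid_convs(size, n=2):
    """Spatial size after n unpadded 3x3 convs (each trims 1 pixel per side)."""
    return size - 2 * n

def trace_unet(size=572, depth=4):
    """Follow H (= W) through the original, unpadded U-Net.

    Returns the output size and a list of (skip size, crop target) pairs.
    """
    skips = []
    for _ in range(depth):                    # contracting path
        size = valid_convs(size)
        skips.append(size)                    # skip is saved before the pool
        assert size % 2 == 0, "size must pool evenly"
        size //= 2                            # 2x2 max-pool
    size = valid_convs(size)                  # bottleneck: 32 -> 28
    crops = []
    for skip in reversed(skips):              # expansive path
        size *= 2                             # 2x2 up-conv doubles H and W
        crops.append((skip, size))            # crop the skip to the up-conv size
        size = valid_convs(size)              # two convs after the concat
    return size, crops

out, crops = trace_unet(572)
print(out)  # 388
for skip, target in crops:
    print(f"crop {skip} -> {target} ({(skip - target) // 2} px per side)")
# crop 64 -> 56 (4 px per side)
# crop 136 -> 104 (16 px per side)
# crop 280 -> 200 (40 px per side)
# crop 568 -> 392 (88 px per side)
```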

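Total parameters: about 31 million, and most of them sit at the bottom of the U. The two 1024-channel bottleneck convs hold about 14 million on their own. The decoder (about 12 million) is more than twice the size of the encoder (about 4.7 million), not smaller: the first conv of every decoder stage reads the concatenated, doubled channel count, and each learned 2×2 up-conv adds weights where the encoder's max-pool has none.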
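A rough tally of where those parameters live, as a plain-Python sketch. It assumes the original layout (no BatchNorm, one bias per conv output channel, 2×2 up-convs that halve the channels, and a final 1×1 conv to 2 classes); reimplementations that add BatchNorm or use bilinear upsampling will land on slightly different totals. The helper names are hypothetical.

```python
def conv_params(c_in, c_out, k):
    """Weights plus one bias per output channel for a k x k convolution."""
    return c_in * c_out * k * k + c_out

def double_conv(c_in, c_out):
    """Two 3x3 convs: c_in -> c_out -> c_out."""
    return conv_params(c_in, c_out, 3) + conv_params(c_out, c_out, 3)

widths = [64, 128, 256, 512]

encoder, c_in = 0, 1
for c in widths:
    encoder += double_conv(c_in, c)
    c_in = c

bottleneck = double_conv(512, 1024)

decoder, c_in = 0, 1024
for c in reversed(widths):
    decoder += conv_params(c_in, c, 2)  # 2x2 up-conv halves the channels
    decoder += double_conv(2 * c, c)    # first conv sees c (up) + c (skip)
    c_in = c

head = conv_params(64, 2, 1)            # final 1x1 conv to 2 classes
total = encoder + bottleneck + decoder + head
print(f"encoder {encoder:,}  bottleneck {bottleneck:,}  decoder {decoder:,}  total {total:,}")
# encoder 4,684,224  bottleneck 14,157,824  decoder 12,188,480  total 31,030,658
```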

Figure (interactive 3D view): the shape table rendered as a 3D scene. Encoder blocks descend on the left with channels growing (blue), the bottleneck sits at the bottom (purple), and decoder blocks ascend on the right with channels shrinking (amber). The dashed amber arcs over the top are the four skip connections, each feeding an encoder level into its mirror decoder level.

Shape flow (interactive)

Figure (interactive): the same U-Net drawn as a graph, annotated with each block's channels, spatial size, and parameter count, and with the cropping arithmetic that feeds each decoder level through its skip line.

Color key: encoder is blue (down), bottleneck is purple (deep), decoder is brown (up). The dashed arrows are the skip connections; their existence is what separates U-Net from a generic encoder-decoder.


Skip connections: concat vs add

ResNet (§11.5) uses skip connections that add a block's input to its output: $y = F(x) + x$. U-Net does something different: it concatenates along the channel axis:

$$y = \text{Conv}\big(\text{concat}(\text{up}(d), \text{crop}(s))\big), \qquad C_{\text{concat}} = C_{\text{up}} + C_{\text{skip}}$$

Why concat instead of add? The two paths carry different kinds of information. The encoder skip $s$ carries high-resolution spatial detail (where is the boundary?). The decoder up-conv output $\text{up}(d)$ carries semantic context (is this region cell or background?). Adding them forces each channel to hold a fixed, equally weighted sum of both, so the merge itself cannot decide how to combine them. Concatenating keeps both as separate channels, and the next 3×3 conv learns the fusion through its ordinary weights.
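A small NumPy check of the "concat is more expressive than add" argument. For brevity it uses a 1×1 conv (the same linearity argument carries over to a 3×3 kernel), random stand-ins for up(d) and crop(s), and a hypothetical helper conv1x1. Tying the two halves of the concat-conv weight reproduces add-then-conv exactly; untying them gives each source its own weights, which add-then-conv cannot do. Add also requires C_up = C_skip, while concat does not.

```python
import numpy as np

rng = np.random.default_rng(0)
C, H, W = 3, 4, 4
up = rng.standard_normal((C, H, W))     # stands in for up(d)
skip = rng.standard_normal((C, H, W))   # stands in for crop(s)

def conv1x1(weight, x):
    """1x1 convolution: weight (C_out, C_in), x (C_in, H, W) -> (C_out, H, W)."""
    return np.einsum("oc,chw->ohw", weight, x)

fused = np.concatenate([up, skip], axis=0)   # (2C, H, W)

# Tied halves [K | K]: concat + conv reproduces add + conv exactly.
K = rng.standard_normal((2, C))
tied = np.concatenate([K, K], axis=1)        # (2, 2C)
same_as_add = np.allclose(conv1x1(tied, fused), conv1x1(K, up + skip))

# Independent halves [K_up | K_skip]: each source gets its own weights.
K_up = rng.standard_normal((2, C))
K_skip = rng.standard_normal((2, C))
untied = np.concatenate([K_up, K_skip], axis=1)
separable = np.allclose(conv1x1(untied, fused),
                        conv1x1(K_up, up) + conv1x1(K_skip, skip))
print(same_as_add, separable)  # True True
```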

The cropping step. Because U-Net's convolutions are unpadded, the encoder feature map at level $\ell$ is slightly larger than the decoder feature map at the same level. We center-crop the encoder side by $(H_{\text{skip}} - H_{\text{up}})/2$ pixels per side before concatenating. Cropping from the center matters: the unpadded convolutions trim every border equally, so the central region of the skip is the part that lines up pixel-for-pixel with the decoder map.

For a quick mental contrast with ResNet: ResNet skips solve a training problem (gradient flow) by adding identity. U-Net skips solve a representation problem (high-res detail in the decoder) by concatenating learned features. The mechanisms look similar but address different bottlenecks. Cross-reference: §11.5 (ResNet add-skip derivation) vs §10.6 (transposed convolution arithmetic).


Skip-connection ablation playground

Words about "each skip carries different information" do not land until you see the network's output break in different ways. The widget below trains a small U-Net on a microscopy sample, then disables one skip at a time and re-renders the predicted mask. Toggle any single skip OFF.


What you should see. Disabling skip 1 (the highest-resolution, 64-channel skip) makes cell boundaries blurrier; the network knows where cells roughly are but smears the membrane. Disabling skip 4 (the deep, 512-channel skip) leaves boundaries sharp where cells are detected but causes the network to miss whole cells — the deep skip carried the "is there a cell here at all" signal.
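The page does not say exactly how the widget switches a skip off. One simple way, shown here as a hypothetical PyTorch sketch, is to replace the chosen skip tensors with zeros before they are concatenated: every concat keeps its channel count, so no decoder weights need rewiring, but that level receives no encoder detail. The function name ablate_skips and the list-of-skips convention are assumptions.

```python
import torch

def ablate_skips(skips, disabled):
    """Return the skip list with the chosen levels replaced by zeros.

    skips: encoder feature maps, index 0 = skip 1 (highest resolution).
    disabled: set of 1-based skip numbers to switch off.
    """
    return [torch.zeros_like(s) if level in disabled else s
            for level, s in enumerate(skips, start=1)]

torch.manual_seed(0)
# Padded-conv shapes for a 64x64 input: channels double as H and W halve.
skips = [torch.randn(1, c, 64 // 2**i, 64 // 2**i)
         for i, c in enumerate([64, 128, 256, 512])]
ablated = ablate_skips(skips, disabled={1})
print([bool(s.abs().sum() == 0) for s in ablated])  # [True, False, False, False]
```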

Asset attribution. The cell-microscopy input is a public-domain ImageJ sample (originally blobs.gif from the imagej.net image library; substituted for the originally-targeted Spindle.tif which 404'd on all known mirrors). The shown mask predictions are produced by a tiny U-Net trained on that single image for ~200 epochs — intentionally illustrative, not production quality. Regeneration script: scripts/generate_segmentation_assets.py.

Python from scratch: one decoder up-block

Before reaching for PyTorch, let's build the single mechanism that distinguishes U-Net from a generic encoder-decoder — one decoder up-block — using only NumPy. We use tiny tensors so every value can be printed and verified. The block does three things in order: center-crop the skip, concatenate along the channel axis, and convolve.

Up-block: crop + concat + 3×3 conv (NumPy)
🐍unet_up_block.py
Line 1: import numpy as np

NumPy is a numerical computing library for Python. It provides ndarray, a fast, memory-efficient N-dimensional array type. The array operations in this file (slicing, concatenation, element-wise multiply, sum) run as optimized C code under the hood; only the explicit convolution loop on lines 52–53 iterates in Python, so every step stays visible. We alias it as 'np' by convention so we can write np.array(), np.concatenate(), np.sum(), etc.

EXECUTION STATE
numpy = Library for numerical computing — provides ndarray, linear algebra, random numbers, and mathematical functions
as np = Creates alias 'np' so we write np.array() instead of numpy.array() — universal Python convention
Line 5: d = np.array([...]) (decoder feature map)

Defines the decoder feature map d. Shape (C_up=2, H=4, W=4). Channel 0 is a 4×4 checkerboard (alternating 1.0 / 0.0); channel 1 is uniform 0.5. These simulate what the transposed convolution in the up-block has already produced — the decoder's 'semantic' side that carries high-level context but at recovered resolution.

EXECUTION STATE
📚 np.array() = NumPy constructor: converts a nested Python list into a contiguous C array. Type is inferred as float64 here (because the literals are floats). Returns an ndarray object with .shape, .dtype, and full NumPy math support.
⬇ input: nested list — 3 levels deep = Outer list (length 2) = 2 channels. Each channel is a list of 4 rows. Each row is a list of 4 floats. np.array infers shape as (2, 4, 4).
d — shape (2, 4, 4) =
ch0 (checkerboard):
1.0  0.0  1.0  0.0
0.0  1.0  0.0  1.0
1.0  0.0  1.0  0.0
0.0  1.0  0.0  1.0

ch1 (uniform 0.5):
0.5  0.5  0.5  0.5
0.5  0.5  0.5  0.5
0.5  0.5  0.5  0.5
0.5  0.5  0.5  0.5
C_up = 2 = Number of channels coming from the previous decoder level. In the real U-Net this would be 512 at the deepest decoder stage, then 256, 128, and 64. We use 2 so every number is printable.
Line 18: s = np.array([...]) (encoder skip feature map)

Defines the encoder skip feature map s at this decoder level. Shape (C_skip=1, H=6, W=6). It is 6×6 rather than 4×4 because, in the original U-Net, the decoder map at this level has been trimmed by more unpadded 3×3 convolutions (down through the deeper stages and back up) than the skip has, so the skip is always slightly bigger. The values form a hollow square (a frame of 1s with 0s inside), standing in for the high-resolution spatial boundary detail the encoder captured.

EXECUTION STATE
s — shape (1, 6, 6) =
ch0 (hollow square):
0.0  0.0  0.0  0.0  0.0  0.0
0.0  1.0  1.0  1.0  1.0  0.0
0.0  1.0  0.0  0.0  1.0  0.0
0.0  1.0  0.0  0.0  1.0  0.0
0.0  1.0  1.0  1.0  1.0  0.0
0.0  0.0  0.0  0.0  0.0  0.0
C_skip = 1 = Number of skip channels. In the real U-Net: 64 / 128 / 256 / 512 depending on the level. We use 1 for clarity.
6×6 vs 4×4 = The skip is captured before the max-pool. Everything below it then loses pixels to further unpadded 3×3 convs (each trims 2) before the up-conv doubles the size again, so the up-conv output comes back smaller than the skip. At the deepest level: 64 → pool 32 → two convs 28 → up-conv 56, versus a 64×64 skip. The center-crop (Step 1 below) removes the outer border.
Line 28: H_skip, W_skip = s.shape[1], s.shape[2]

Read the spatial dimensions of the skip tensor. s.shape is (1, 6, 6); axis 0 is channels, axis 1 is height, axis 2 is width. We unpack axes 1 and 2 into two named variables so the crop arithmetic below is readable.

EXECUTION STATE
s.shape = (1, 6, 6) — (C_skip=1, H_skip=6, W_skip=6). NumPy .shape is a tuple; indexing with [1] gives the height.
H_skip = 6 — height of the encoder skip map
W_skip = 6 — width of the encoder skip map
Line 29: H_dec, W_dec = d.shape[1], d.shape[2]

Read the spatial dimensions of the decoder feature map. d.shape is (2, 4, 4). Same pattern as the line above: axis 1 = height, axis 2 = width.

EXECUTION STATE
d.shape = (2, 4, 4) — (C_up=2, H_dec=4, W_dec=4)
H_dec = 4 — height of the decoder feature map
W_dec = 4 — width of the decoder feature map
Line 30: off_h = (H_skip - H_dec) // 2

Compute the vertical crop offset. (6 - 4) = 2 pixels of excess; divide by 2 to center the crop. The integer floor division // is safe here because H_skip - H_dec is always even in the original U-Net: the up-conv doubles the map, so the skip and decoder sizes differ by twice the number of pixels trimmed at the coarser levels (64 - 56 = 2 × 4 at the deepest level).

EXECUTION STATE
H_skip - H_dec = 6 - 4 = 2 — 2 rows need to be discarded (1 from top, 1 from bottom)
📚 // (floor division) = Integer division that rounds toward negative infinity. 2 // 2 = 1. For positive operands it is the same as int(a/b). Used instead of / to get an int directly — NumPy slice indices must be integers.
off_h = 1 — start the crop 1 row from the top
Line 31: off_w = (W_skip - W_dec) // 2

Same calculation on the width axis. (6 - 4) // 2 = 1. The crop will start 1 column from the left.

EXECUTION STATE
off_w = 1 — start the crop 1 column from the left
Line 32: s_cropped = s[:, off_h:off_h + H_dec, off_w:off_w + W_dec]

Center-crop the skip tensor to the decoder's spatial size. NumPy slice notation: axis 0 is : (keep all channels), axis 1 is 1:1+4 = 1:5 (rows 1–4 inclusive), axis 2 is 1:1+4 = 1:5 (columns 1–4). The outer ring of zeros in the 6×6 map is discarded; what remains is the inner 4×4 hollow-square pattern.

EXECUTION STATE
📚 NumPy slicing [:, a:b, c:d] = Extracts a sub-array without copying data (returns a view). : on axis 0 keeps all channels (here just one) unchanged. a:b on axis 1 keeps rows from index a up to (not including) b. c:d on axis 2 keeps columns from index c up to (not including) d. Example: arr[1:5] on a length-6 axis returns 4 elements: indices 1,2,3,4.
⬇ slice on axis 1 (height): 1:5 = off_h=1, off_h+H_dec=5 → rows 1,2,3,4. Discards row 0 (all zeros) and row 5 (all zeros).
⬇ slice on axis 2 (width): 1:5 = off_w=1, off_w+W_dec=5 → columns 1,2,3,4. Discards col 0 (all zeros) and col 5 (all zeros).
⬆ s_cropped — shape (1, 4, 4) =
ch0 (inner hollow square):
1.0  1.0  1.0  1.0
1.0  0.0  0.0  1.0
1.0  0.0  0.0  1.0
1.0  1.0  1.0  1.0
Line 33: print("Cropped skip shape:", s_cropped.shape)

Verification print. Running this line emits: Cropped skip shape: (1, 4, 4). Confirms the crop produced exactly the spatial size we need to concatenate with d (which is also 4×4).

EXECUTION STATE
s_cropped.shape = (1, 4, 4) ✓ — matches (C_skip=1, H_dec=4, W_dec=4)
Line 36: fused = np.concatenate([d, s_cropped], axis=0)

Concatenate the decoder feature map and the cropped skip along the channel axis. This is the defining operation of a U-Net up-block. The two tensors each have spatial size (4, 4) but different channel counts: d has C_up=2 and s_cropped has C_skip=1. axis=0 stacks them along the channel dimension, yielding C_up + C_skip = 3 channels total.

EXECUTION STATE
📚 np.concatenate(arrays, axis) = NumPy function: joins a sequence of arrays along an existing axis. All arrays must have the same shape on every axis EXCEPT the concatenation axis. Example: np.concatenate([a(2,4,4), b(1,4,4)], axis=0) → shape (3,4,4). Contrast with np.stack() which creates a NEW axis.
⬇ arg 1: [d, s_cropped] = A Python list of two arrays. d has shape (2,4,4); s_cropped has shape (1,4,4). Both have the same H=4, W=4 on axes 1 and 2 — required for concatenation.
⬇ arg 2: axis=0 = Concatenate along the first axis (channels). axis=1 would stack rows (wrong — would give 8×4); axis=2 would stack columns (wrong — would give 4×8). axis=0 is correct because channels are axis 0.
⬆ fused — shape (3, 4, 4) =
ch0 (decoder ch0, checkerboard):
1.0  0.0  1.0  0.0
0.0  1.0  0.0  1.0
1.0  0.0  1.0  0.0
0.0  1.0  0.0  1.0

ch1 (decoder ch1, uniform 0.5):
0.5  0.5  0.5  0.5
0.5  0.5  0.5  0.5
0.5  0.5  0.5  0.5
0.5  0.5  0.5  0.5

ch2 (skip, hollow square):
1.0  1.0  1.0  1.0
1.0  0.0  0.0  1.0
1.0  0.0  0.0  1.0
1.0  1.0  1.0  1.0
→ why concat not add? = Adding would force a single channel slot to hold the sum of semantic context (decoder) and spatial detail (skip), weighted equally. Concatenating preserves them as separate channels; the next 3×3 conv then learns how to mix them as a regular weight matrix.
Line 43: W = np.zeros((1, 3, 3, 3))

Initialize the convolution weight tensor to all zeros. Shape (out_C=1, in_C=3, kH=3, kW=3). We have 1 output channel, 3 input channels (matching fused), and a 3×3 kernel. np.zeros creates a float64 array of zeros; we then fill the three 3×3 slices manually on the next three lines.

EXECUTION STATE
📚 np.zeros(shape) = NumPy function: allocates a new array filled with 0.0. shape is a tuple of ints. Default dtype is float64. Example: np.zeros((2,3)) → [[0.0, 0.0, 0.0],[0.0, 0.0, 0.0]]
W — shape (1, 3, 3, 3) = out_C=1: one output feature map. in_C=3: three input channels (matches fused channels). kH=3, kW=3: 3×3 spatial kernel. Total learnable weights: 1×3×3×3 = 27 floats.
Line 44: W[0, 0] = 1.0 (decoder ch0 weight)

Set the 3×3 kernel for output channel 0, input channel 0 (decoder checkerboard channel) to all 1.0. NumPy broadcasts the scalar 1.0 to fill the entire 3×3 slice. This means every element of the decoder checkerboard patch contributes its full value to the output.

EXECUTION STATE
W[0, 0] =
3×3 slice at out_ch=0, in_ch=0:
1.0  1.0  1.0
1.0  1.0  1.0
1.0  1.0  1.0
(all ones — sums the 3×3 patch of decoder ch0)
Line 45: W[0, 1] = 0.5 (decoder ch1 weight)

Set the 3×3 kernel for input channel 1 (decoder uniform-0.5 channel) to all 0.5. Every element of that patch contributes half its value. Since ch1 is uniform 0.5, this effectively contributes 0.5 × 0.5 × 9 = 2.25 to each output position.

EXECUTION STATE
W[0, 1] =
3×3 slice at out_ch=0, in_ch=1:
0.5  0.5  0.5
0.5  0.5  0.5
0.5  0.5  0.5
(half-weight — decoder uniform channel contributes 50%)
Line 46: W[0, 2] = 2.0 (skip ch2 weight)

Set the 3×3 kernel for input channel 2 (the skip channel) to all 2.0. This simulates a learned network decision to 'pay double attention to the skip'. In practice, the network would learn an asymmetric kernel but we use a uniform value here for clear arithmetic.

EXECUTION STATE
W[0, 2] =
3×3 slice at out_ch=0, in_ch=2:
2.0  2.0  2.0
2.0  2.0  2.0
2.0  2.0  2.0
(double-weight — skip channel contributes 2×)
Line 47: b = np.zeros(1)

Bias vector for the one output channel, initialized to 0. Shape (1,). In a trained network this is a learned scalar added to every spatial position of output channel 0 after the weighted sum. With b=0, y[0,i,j] = patch · W[0] + 0 = patch · W[0].

EXECUTION STATE
b = array([0.]) — shape (1,), one bias per output channel
Line 49: H_out = fused.shape[1] - 2

Compute output height for an unpadded 3×3 convolution. A 3×3 kernel centered at row i requires rows i-1 and i+1 to exist. So the first valid center is row 1 and the last is row H-2, giving H - 2 valid positions. Here 4 - 2 = 2. See §10.3 for the full convolution-arithmetic derivation: floor((H_in - kH) / stride) + 1 = (4-3)/1 + 1 = 2.

EXECUTION STATE
fused.shape[1] = 4 — height of the fused tensor
H_out = 2 — two output rows (one for top-left/top-right, one for bottom-left/bottom-right)
Line 50: W_out = fused.shape[2] - 2

Same calculation on the width axis. fused.shape[2] = 4, so W_out = 4 - 2 = 2. The output is a 2×2 spatial map with 1 channel: shape (1, 2, 2).

EXECUTION STATE
W_out = 2 — two output columns
Line 51: y = np.zeros((1, H_out, W_out))

Allocate the output tensor, shape (out_C=1, H_out=2, W_out=2), filled with zeros. We fill each element y[0, i, j] inside the nested for-loop below.

EXECUTION STATE
y — shape (1, 2, 2) =
[[0.0  0.0]
 [0.0  0.0]]
(all zeros — will be filled by the conv loop)
Line 52: for i in range(H_out):

Outer loop over the 2 valid output rows. i is the row index into y; the corresponding input patch starts at row i of fused and spans rows i through i+2 (inclusive). With H_out=2, this loop runs twice: i=0 (top row of output) and i=1 (bottom row).

LOOP TRACE · 2 iterations
i=0 (top output row)
patch rows in fused = fused[:, 0:3, :] — rows 0,1,2 of the 4-row fused tensor
i=1 (bottom output row)
patch rows in fused = fused[:, 1:4, :] — rows 1,2,3 of the 4-row fused tensor
Line 53: for j in range(W_out):

Inner loop over the 2 valid output columns. Combined with the outer loop, this produces 2×2 = 4 output positions total. For each (i,j) pair the code extracts a 3×3 patch from fused and applies the kernel.

LOOP TRACE · 4 iterations
i=0, j=0 → y[0,0,0]
patch[0] (decoder ch0) = [[1,0,1],[0,1,0],[1,0,1]]
patch[1] (decoder ch1) = [[.5,.5,.5],[.5,.5,.5],[.5,.5,.5]]
patch[2] (skip ch2) = [[1,1,1],[1,0,0],[1,0,0]]
ch0 ×1.0 = sum([[1,0,1],[0,1,0],[1,0,1]]×1.0) = 5.0
ch1 ×0.5 = sum([[.5,.5,.5],[.5,.5,.5],[.5,.5,.5]]×0.5) = 2.25
ch2 ×2.0 = sum([[1,1,1],[1,0,0],[1,0,0]]×2.0) = 5×2 = 10.0
y[0,0,0] = 5.0 + 2.25 + 10.0 = 17.25
i=0, j=1 → y[0,0,1]
patch[0] (decoder ch0) = [[0,1,0],[1,0,1],[0,1,0]]
patch[2] (skip ch2) = [[1,1,1],[0,0,1],[0,0,1]]
ch0 ×1.0 = sum = 4.0
ch1 ×0.5 = 2.25 (uniform patch, same as always)
ch2 ×2.0 = sum([[1,1,1],[0,0,1],[0,0,1]]×2.0) = 5×2 = 10.0
y[0,0,1] = 4.0 + 2.25 + 10.0 = 16.25
i=1, j=0 → y[0,1,0]
patch[0] (decoder ch0) = [[0,1,0],[1,0,1],[0,1,0]]
patch[2] (skip ch2) = [[1,0,0],[1,0,0],[1,1,1]]
ch0 ×1.0 = sum = 4.0
ch2 ×2.0 = sum([[1,0,0],[1,0,0],[1,1,1]]×2.0) = 5×2 = 10.0
y[0,1,0] = 4.0 + 2.25 + 10.0 = 16.25
i=1, j=1 → y[0,1,1]
patch[0] (decoder ch0) = [[1,0,1],[0,1,0],[1,0,1]]
patch[2] (skip ch2) = [[0,0,1],[0,0,1],[1,1,1]]
ch0 ×1.0 = sum = 5.0
ch2 ×2.0 = sum([[0,0,1],[0,0,1],[1,1,1]]×2.0) = 5×2 = 10.0
y[0,1,1] = 5.0 + 2.25 + 10.0 = 17.25
Line 55: patch = fused[:, i:i + 3, j:j + 3]

Extract the 3×3 spatial patch at position (i, j) across all 3 input channels. fused has shape (3, 4, 4); slicing [:, i:i+3, j:j+3] gives a (3, 3, 3) view — no data is copied. The first axis : keeps all 3 channels; the next two slices select the 3×3 spatial window.

EXECUTION STATE
📚 NumPy slice [:, i:i+3, j:j+3] = : on axis 0 = keep all channels. i:i+3 on axis 1 = 3 consecutive rows starting at i. j:j+3 on axis 2 = 3 consecutive columns starting at j. Result is a VIEW into fused — no copy, just a pointer.
patch — shape (3, 3, 3) = 3 channels × 3 rows × 3 cols = 27 values. ch0=decoder, ch1=decoder, ch2=skip.
Line 56: y[0, i, j] = np.sum(patch * W[0]) + b[0]

The convolution at position (i,j). Three steps: (1) element-wise multiply patch (3×3×3) with the weight tensor W[0] (3×3×3) — same shape, 27 multiplications; (2) sum all 27 products to a scalar — this is the dot product over all input channels and spatial positions; (3) add bias b[0]=0. This is the convolution formula written explicitly without any library magic.

EXECUTION STATE
📚 patch * W[0] = NumPy element-wise multiplication (*). patch and W[0] are both (3,3,3) — same shape, so no broadcasting needed. Returns a (3,3,3) array of products. This is NOT matrix multiply (@); it multiplies corresponding elements.
📚 np.sum() = NumPy function: sums all elements in the array into a single scalar. np.sum(arr) with no axis argument collapses all 27 elements of the (3,3,3) result. Equivalent to arr.flatten().sum().
⬇ y[0, 0, 0] (first output pixel) =
ch0 contribution: sum(patch[0] × 1.0) = 5.0
ch1 contribution: sum(patch[1] × 0.5) = 2.25
ch2 contribution: sum(patch[2] × 2.0) = 10.0
─────────────────────────
y[0,0,0] = 5.0 + 2.25 + 10.0 + 0 = 17.25
⬆ full output y[0] (2×2) =
17.25  16.25
16.25  17.25

The diagonal positions (i=0, j=0 and i=1, j=1) score higher because their decoder patches contain 5 checkerboard 1s, versus 4 at the off-diagonal positions. The skip contribution (10.0) is the same at all four positions because every 3×3 window of the 4×4 hollow square covers exactly 5 of its border 1s.
Line 58: print("Output shape:", y.shape)

Verification print. Emits: Output shape: (1, 2, 2). Confirms the unpadded 3×3 conv produced a (1, 2, 2) tensor from a (3, 4, 4) input — exactly the shape arithmetic from §10.3.

EXECUTION STATE
y.shape = (1, 2, 2) — 1 output channel, 2×2 spatial map ✓
Line 59: print("Output:\n", y[0])

Print the 2×2 output map. y[0] indexes the first (only) output channel, giving a plain 2×2 NumPy array. Expected output: [[17.25 16.25] [16.25 17.25]].

EXECUTION STATE
y[0] — shape (2, 2) =
17.25  16.25
16.25  17.25

The two distinct values reflect the two distinct decoder patch patterns (5 checkerboard 1s at the diagonal positions versus 4 at the off-diagonal ones). The skip contribution (10.0) is the same for all four positions.
 1  import numpy as np
 2
 3  # Tiny tensors so every value is printable.
 4  # Decoder side after a 2x2 stride-2 transposed conv: shape (C_up, 4, 4)
 5  d = np.array([
 6      [[1.0, 0.0, 1.0, 0.0],
 7       [0.0, 1.0, 0.0, 1.0],
 8       [1.0, 0.0, 1.0, 0.0],
 9       [0.0, 1.0, 0.0, 1.0]],   # channel 0 = checkerboard
10      [[0.5, 0.5, 0.5, 0.5],
11       [0.5, 0.5, 0.5, 0.5],
12       [0.5, 0.5, 0.5, 0.5],
13       [0.5, 0.5, 0.5, 0.5]],   # channel 1 = uniform 0.5
14  ])  # shape (2, 4, 4)
15
16  # Encoder skip side at the SAME level. With unpadded convs the skip is bigger.
17  # We use a 6x6 skip with 1 channel that we will center-crop to 4x4.
18  s = np.array([
19      [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
20       [0.0, 1.0, 1.0, 1.0, 1.0, 0.0],
21       [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
22       [0.0, 1.0, 0.0, 0.0, 1.0, 0.0],
23       [0.0, 1.0, 1.0, 1.0, 1.0, 0.0],
24       [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
25  ])  # shape (1, 6, 6)
26
27  # --- Step 1: center-crop the skip to the decoder's spatial size ---
28  H_skip, W_skip = s.shape[1], s.shape[2]
29  H_dec,  W_dec  = d.shape[1], d.shape[2]
30  off_h = (H_skip - H_dec) // 2
31  off_w = (W_skip - W_dec) // 2
32  s_cropped = s[:, off_h:off_h + H_dec, off_w:off_w + W_dec]
33  print("Cropped skip shape:", s_cropped.shape)
34
35  # --- Step 2: concatenate along the channel axis ---
36  fused = np.concatenate([d, s_cropped], axis=0)
37  print("Fused shape:", fused.shape)
38  print("Fused channel 0 (was decoder ch0):\n", fused[0])
39  print("Fused channel 2 (was skip):\n", fused[2])
40
41  # --- Step 3: a single 3x3 unpadded conv with hand-set weights ---
42  # Output channels = 1 for clarity. Weights have shape (out_C, in_C, kH, kW).
43  W = np.zeros((1, 3, 3, 3))
44  W[0, 0] = 1.0  # decoder ch0 contributes a 3x3 sum-of-1
45  W[0, 1] = 0.5  # decoder ch1 contributes half
46  W[0, 2] = 2.0  # skip contributes double
47  b = np.zeros(1)
48
49  H_out = fused.shape[1] - 2   # unpadded 3x3
50  W_out = fused.shape[2] - 2
51  y = np.zeros((1, H_out, W_out))
52  for i in range(H_out):
53      for j in range(W_out):
54          # 3x3 patch across all 3 input channels
55          patch = fused[:, i:i + 3, j:j + 3]    # shape (3, 3, 3)
56          y[0, i, j] = np.sum(patch * W[0]) + b[0]
57
58  print("Output shape:", y.shape)
59  print("Output:\n", y[0])

Run it and you should see Cropped skip shape: (1, 4, 4), Fused shape: (3, 4, 4), and a (1, 2, 2) output. We did the crop, the concat, and the 3×3 conv ourselves; that is the entire mechanism. The PyTorch implementation below stacks this same building block four times to get the full decoder.
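As a cross-check on the loop, the same unpadded 3×3 conv can be written as a single tensor contraction. This sketch rebuilds d, s, and W from the listing and uses numpy.lib.stride_tricks.sliding_window_view (available in NumPy 1.20 and later) to gather every 3×3 window at once; conv2d_valid is a hypothetical helper name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv2d_valid(x, weight, bias):
    """Unpadded, stride-1 convolution.

    x: (C_in, H, W); weight: (C_out, C_in, kH, kW); bias: (C_out,).
    """
    kh, kw = weight.shape[2:]
    # (C_in, H - kh + 1, W - kw + 1, kh, kw): a read-only view, no copy
    windows = sliding_window_view(x, (kh, kw), axis=(1, 2))
    # Sum over input channels and kernel positions for every output pixel
    return np.einsum("chwij,ocij->ohw", windows, weight) + bias[:, None, None]

# Rebuild the inputs from the listing above.
checker = (np.indices((4, 4)).sum(axis=0) % 2 == 0).astype(float)
d = np.stack([checker, np.full((4, 4), 0.5)])          # (2, 4, 4)
s = np.zeros((1, 6, 6))
s[0, 1:5, 1:5] = 1.0                                    # frame of 1s ...
s[0, 2:4, 2:4] = 0.0                                    # ... with a hollow centre
fused = np.concatenate([d, s[:, 1:5, 1:5]], axis=0)    # center-crop, then concat

W = np.zeros((1, 3, 3, 3))
W[0, 0], W[0, 1], W[0, 2] = 1.0, 0.5, 2.0
b = np.zeros(1)

y = conv2d_valid(fused, W, b)
print(y[0])
# [[17.25 16.25]
#  [16.25 17.25]]
```

Both versions compute the same sums; the loop simply makes each (i, j) step visible.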


PyTorch: full UNet module

The PyTorch implementation builds the same align-concat-conv mechanism we just did in NumPy, repeated 4 times in the decoder, with three small modernizations: padded 3×3 convolutions (so input and output spatial sizes match and the skip can be concatenated without cropping), BatchNorm after every convolution, and bilinear upsampling as the default decoder up-step (cheaper than ConvTranspose and free of the checkerboard artifact derived in §10.6).

We split the network into four files' worth of code: DoubleConv (the atom), Down (encoder stage), Up (decoder stage), and UNet (full module). Each builds on the one before it.

1. DoubleConv — the atom

DoubleConv: two padded 3×3 convs with BN+ReLU
🐍unet/double_conv.py
Line 4: class DoubleConv(nn.Module)

Defines a reusable building block. nn.Module is PyTorch's base class for any layer / sub-network — see §4.5 (Autograd) and §5.1 (nn.Module) for full background.

Line 11: def __init__(self, in_channels: int, out_channels: int)

Constructor takes the two channel counts. in_channels is what flows in; out_channels is what flows out. The network's overall channel-doubling pattern (1→64→128→256→512→1024) is set by the caller, not by this class.

EXECUTION STATE
in_channels = e.g. 64 (matches the previous block's out)
out_channels = e.g. 128 (the new channel count for this stage)
Line 14: 📚 nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)

📚 PyTorch's 2D convolution layer. With padding=1 and kernel_size=3, output H/W = input H/W. bias=False because the next layer is BatchNorm, which has its own bias (β) — having both is redundant.

EXAMPLE
in=(1, 64, 32, 32), out=(1, 128, 32, 32)
Line 15: 📚 nn.BatchNorm2d(out_channels)

📚 Per-channel batch normalization (cross-ref §5.5). Stabilizes training; allows higher learning rates. Its β (shift) parameter replaces the conv's bias term — that's why the conv above sets bias=False.

Line 16: 📚 nn.ReLU(inplace=True)

📚 ReLU activation, in-place to save memory. inplace=True overwrites the input tensor; safe here because the activation is the last consumer of the previous tensor.

Line 17: second 3×3 conv (mirror of line 14)

Same conv as line 14 but in_channels=out_channels now. The 'double' in DoubleConv refers to this pair.

Line 22: def forward(self, x: torch.Tensor) -> torch.Tensor

Forward pass. Just runs the Sequential module. PyTorch handles the gradient bookkeeping automatically.

EXECUTION STATE
x =
input tensor of shape (B, in_channels, H, W)
Lines 25–29: sanity check

Lines 26–29 verify the shape arithmetic: (1, 64, 32, 32) → (1, 128, 32, 32). Channels change; spatial size is preserved.

 1  import torch
 2  import torch.nn as nn
 3
 4  class DoubleConv(nn.Module):
 5      """Two padded 3x3 convolutions, each followed by BatchNorm and ReLU.
 6
 7      This is the atom from which both encoder and decoder are built.
 8      Padded convs (padding=1) keep H and W constant so output size = input size.
 9      """
10
11      def __init__(self, in_channels: int, out_channels: int):
12          super().__init__()
13          self.conv = nn.Sequential(
14              nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
15              nn.BatchNorm2d(out_channels),
16              nn.ReLU(inplace=True),
17              nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
18              nn.BatchNorm2d(out_channels),
19              nn.ReLU(inplace=True),
20          )
21
22      def forward(self, x: torch.Tensor) -> torch.Tensor:
23          return self.conv(x)
24
25  # Quick sanity check
26  x = torch.randn(1, 64, 32, 32)
27  block = DoubleConv(64, 128)
28  y = block(x)
29  print("DoubleConv: in=", tuple(x.shape), "→ out=", tuple(y.shape))
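The bias=False choice on line 14 can be checked directly. In this small sketch (channel counts are arbitrary, chosen for speed), two convs with identical kernels, one with a bias and one without, give the same output once passed through BatchNorm in training mode: BatchNorm subtracts each channel's batch mean, and a per-channel bias shifts that mean by exactly its own value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv_with_bias = nn.Conv2d(4, 8, kernel_size=3, padding=1, bias=True)
conv_no_bias = nn.Conv2d(4, 8, kernel_size=3, padding=1, bias=False)
with torch.no_grad():
    conv_no_bias.weight.copy_(conv_with_bias.weight)  # same kernels; only the bias differs

bn = nn.BatchNorm2d(8)  # a fresh module is in training mode and uses batch statistics
x = torch.randn(2, 4, 16, 16)
out_with_bias = bn(conv_with_bias(x))
out_no_bias = bn(conv_no_bias(x))
print(torch.allclose(out_with_bias, out_no_bias, atol=1e-5))  # True
```

The check relies on batch statistics; in eval mode BatchNorm normalizes with its running estimates instead.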

2. Down — encoder stage

Down: max-pool then DoubleConv
🐍unet/down.py
Line 1: class Down(nn.Module)

An encoder stage. Each U-Net encoder has 4 of these.

Line 8: def __init__(self, in_channels, out_channels)

in_channels = the previous stage's out_channels. out_channels = roughly 2× in_channels by U-Net's doubling convention.

Line 11: 📚 nn.MaxPool2d(kernel_size=2)

📚 2×2 max-pool with stride 2 (default). Halves H and W. Channels are unchanged. See §10.5 for the full pooling derivation.

EXAMPLE
in=(1, 64, 32, 32) → out=(1, 64, 16, 16)
Line 12: DoubleConv(in_channels, out_channels)

Then the same 2-conv block. After it, the stage has produced (1, out_channels, H/2, W/2).

Lines 15–16: forward

Just hands x to the Sequential. The Sequential applies max-pool then DoubleConv in order.

 1  class Down(nn.Module):
 2      """Encoder stage: 2x2 max-pool then DoubleConv.
 3
 4      H and W halve at the max-pool. Channels double inside DoubleConv
 5      (caller passes the right counts).
 6      """
 7
 8      def __init__(self, in_channels: int, out_channels: int):
 9          super().__init__()
10          self.down = nn.Sequential(
11              nn.MaxPool2d(kernel_size=2),
12              DoubleConv(in_channels, out_channels),
13          )
14
15      def forward(self, x: torch.Tensor) -> torch.Tensor:
16          return self.down(x)
17
18  # Sanity
19  x = torch.randn(1, 64, 32, 32)
20  stage = Down(64, 128)
21  y = stage(x)
22  print("Down: in=", tuple(x.shape), "→ out=", tuple(y.shape))   # expect (1, 128, 16, 16)

3. Up — decoder stage

Up: upsample → concat with skip → DoubleConv
🐍unet/up.py
3class Up(nn.Module)

Decoder stage: one of the four right-side blocks in the shape-flow diagram.

10def __init__(self, in_channels, out_channels, bilinear=True)

bilinear=True means use a non-learned bilinear upsample (cheaper, no checkerboard). bilinear=False switches to a learned 2×2 transposed conv (the original paper's choice; see §10.6 for the full derivation and the checkerboard-artifact discussion).

EXECUTION STATE
bilinear = True (default)
12📚 nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

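📚 Non-learned interpolation. mode='bilinear' computes each output pixel as a distance-weighted average of its 4 nearest input pixels. align_corners=True makes the centers of the corner pixels of input and output coincide; this changes the interpolation weights slightly but not the output size, which scale_factor=2 alone determines. Doubles H and W.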

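15📚 nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)

📚 Learned upsampling (see §10.6 for the full derivation). Stride=2 with kernel=2 doubles H and W exactly. In this U-Net wiring it also halves the channel count (in_channels → in_channels // 2), so that after concatenation with the in_channels // 2-channel skip, DoubleConv again receives in_channels channels. The // is integer division.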

19def forward(self, x, skip)

Two inputs: x is the decoder feature map coming from below, skip is the encoder feature map at the matching level (delivered by the parent UNet via the encoder's saved activations).

EXECUTION STATE
x =
shape (B, in_channels//2, H/2, W/2) with bilinear=True; (B, in_channels, H/2, W/2) with ConvTranspose
skip =
shape (B, in_channels//2, H, W)
22diff_h = skip.size(2) - x.size(2)

Off-by-one safety net. With even input sizes everywhere, diff_h and diff_w are 0. With odd sizes they can be 1. We pad x to match skip's spatial size before concat — a different fix than the original paper's center-crop, but achieves the same goal of making both sides concat-compatible.

EXECUTION STATE
diff_h = typically 0; 1 if input H was odd somewhere
diff_w = typically 0
24📚 F.pad(x, [left, right, top, bottom])

📚 Pads tensor with zeros on the spatial axes. The 4-element list is in the order PyTorch expects for (W_left, W_right, H_top, H_bottom). Splitting odd diffs as `d//2` and `d - d//2` gives a balanced pad.

26📚 torch.cat([skip, x], dim=1)

📚 Channel-axis concatenation. dim=1 is the channel axis (PyTorch's NCHW layout). The result, reassigned to x, has skip.channels + x.channels channels, which must equal the in_channels passed to the constructor.

27return self.conv(x)

DoubleConv on the concatenated tensor. The 3×3 convs inside learn how to fuse the two information streams.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Up(nn.Module):
    """Decoder stage: up-conv (or upsample), concat with skip, then DoubleConv.

    With padded convs and bilinear upsample, input/output sizes match
    cleanly without the cropping trick from the original paper.
    in_channels counts the channels AFTER the concat (decoder + skip).
    """

    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True):
        super().__init__()
        if bilinear:
            # Non-learned: keeps the channel count, so x must arrive with in_channels // 2
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        else:
            # Learned: maps in_channels -> in_channels // 2 while doubling H and W
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        x = self.up(x)
        # Pad if an odd H or W upstream left x one pixel short of skip
        diff_h = skip.size(2) - x.size(2)
        diff_w = skip.size(3) - x.size(3)
        x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
                      diff_h // 2, diff_h - diff_h // 2])
        x = torch.cat([skip, x], dim=1)   # concat along channel axis
        return self.conv(x)

# Sanity (bilinear=True, so x carries in_channels // 2 = 128 channels)
x   = torch.randn(1, 128, 16, 16)   # decoder input
sk  = torch.randn(1, 128, 32, 32)   # skip from matching encoder level
stage = Up(256, 128, bilinear=True)  # 128 + 128 = 256 channels after the concat
y = stage(x, sk)
print("Up: x=", tuple(x.shape), " skip=", tuple(sk.shape), " → out=", tuple(y.shape))
# expected: (1, 128, 32, 32)
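The padding arithmetic in `Up.forward` can be checked without PyTorch. The sketch below uses two hypothetical helpers, `balanced_pad` and `upsampled_size`, to reproduce the list passed to `F.pad` (ordered W_left, W_right, H_top, H_bottom, because `F.pad` pads the last dimension first) and to show where an odd size comes from. It assumes the 2×2 max-pool floors odd sizes, as PyTorch's default `ceil_mode=False` does.

```python
def balanced_pad(skip_hw, x_hw):
    """Return the F.pad-style list [w_left, w_right, h_top, h_bottom].

    skip_hw, x_hw: (H, W) spatial sizes of the skip and of the upsampled x.
    An odd difference puts the extra pixel on the right / bottom, matching
    the d // 2 and d - d // 2 split used in Up.forward.
    """
    diff_h = skip_hw[0] - x_hw[0]
    diff_w = skip_hw[1] - x_hw[1]
    return [diff_w // 2, diff_w - diff_w // 2,
            diff_h // 2, diff_h - diff_h // 2]


def upsampled_size(n):
    """Size after a floor-mode 2x2 max-pool followed by a 2x upsample."""
    return (n // 2) * 2


# An odd encoder size (e.g. 65) loses a pixel at the pool and comes back as 64
for h, w in [(64, 64), (65, 64), (65, 67)]:
    x_hw = (upsampled_size(h), upsampled_size(w))
    print((h, w), "->", balanced_pad((h, w), x_hw))
```

With input sizes divisible by 16 every difference is 0 and the pad is a no-op; the helper only matters when some stage saw an odd height or width.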

4. UNet — full module

UNet: end-to-end network
🐍unet/unet.py
1class UNet(nn.Module)

The whole network. Composes 5 encoder-side blocks (inc + 4 Downs) and 4 decoder-side blocks (Ups), then a 1×1 conv to map to n_classes output channels.

14def __init__(self, n_channels=1, n_classes=1, bilinear=True)

n_channels: input channels (1 for grayscale microscopy, 3 for RGB satellite imagery, etc.). n_classes: output channels — 1 for binary (cell vs background), C for C-way semantic segmentation. bilinear: see Up class above.

EXECUTION STATE
n_channels = 1 (grayscale microscopy)
n_classes = 1 (binary segmentation)
bilinear = True (cheaper, no checkerboard)
17self.inc = DoubleConv(n_channels, 64)

First block runs at input resolution; produces 64 channels. Doesn't down-sample yet.

18self.down1 = Down(64, 128)

First encoder stage. Output: (B, 128, H/2, W/2). Channels double; spatial halves.

22factor = 2 if bilinear else 1

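Bilinear upsampling cannot change the channel count (ConvTranspose2d, by contrast, halves it), so with bilinear=True the bottleneck is built with 1024 // 2 = 512 channels. After concatenation with x4 (512 channels) the first decoder stage still receives the 1024 channels that Up(1024, ...) expects. The same factor also halves the decoder's output widths, so the bilinear model is considerably smaller.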

EXECUTION STATE
factor = 2 (with bilinear=True)
23self.down4 = Down(512, 1024 // factor)

Bottleneck output: (B, 512, H/16, W/16) when bilinear=True; (B, 1024, H/16, W/16) when bilinear=False.

25self.up1 = Up(1024, 512 // factor, bilinear)

First decoder stage. Up's first arg is the SUM of decoder + skip channel counts. With bilinear=True: x5 has 512 ch, x4 has 512 ch → after upsample concat we have 1024 ch, conv produces 256 ch. With bilinear=False: x5 has 1024 ch → conv-transpose halves to 512, concat with x4's 512 → 1024 ch, conv produces 512 ch.

30self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

1×1 conv = per-pixel linear projection. Maps 64 final channels to n_classes logit channels. No activation — caller applies sigmoid (binary) or softmax (multi-class) inside the loss function.

33def forward(self, x)

Save every encoder activation (x1..x5) so they can feed back into the decoder via skip connections. This is the only place the U-shape topology shows up in code — the decoder Up calls take TWO arguments, the second being the matching saved activation.

EXECUTION STATE
x1 =
(B, 64, H, W) — feeds into up4 as skip
x2 =
(B, 128, H/2, W/2) — feeds into up3 as skip
x3 =
(B, 256, H/4, W/4) — feeds into up2 as skip
x4 =
(B, 512, H/8, W/8) — feeds into up1 as skip
x5 =
(B, 1024 or 512, H/16, W/16) — bottleneck
47Param count

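For n_channels=1, n_classes=1, bilinear=True you should see about 13.4M trainable parameters (13,390,209 with the DoubleConv above). The same class with bilinear=False has about 31.0M (31,036,481). Most of the difference comes from `factor` halving the bottleneck (1024 → 512 channels) and the decoder widths; replacing the learned ConvTranspose layers with parameter-free bilinear upsampling accounts for only about 2.8M of it.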

class UNet(nn.Module):
    """Full U-Net for binary semantic segmentation.

    Args
    ----
    n_channels : input channels (1 for grayscale, 3 for RGB)
    n_classes  : output channels (number of segmentation classes,
                 including any background class). For binary cell/background
                 segmentation this is typically 1 (logit per pixel) with
                 a sigmoid + BCE loss.
    bilinear   : True → bilinear upsample (cheaper); False → learned ConvTranspose
    """

    def __init__(self, n_channels: int = 1, n_classes: int = 1, bilinear: bool = True):
        super().__init__()
        # encoder
        self.inc   = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        # decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        # final 1x1 conv to n_classes channels
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1 = self.inc(x)        # (B,  64, H,    W)
        x2 = self.down1(x1)     # (B, 128, H/2,  W/2)
        x3 = self.down2(x2)     # (B, 256, H/4,  W/4)
        x4 = self.down3(x3)     # (B, 512, H/8,  W/8)
        x5 = self.down4(x4)     # (B, 1024 or 512, H/16, W/16)  ← bottleneck

        x  = self.up1(x5, x4)   # (B, 512 or 256, H/8,  W/8)
        x  = self.up2(x,  x3)   # (B, 256 or 128, H/4,  W/4)
        x  = self.up3(x,  x2)   # (B, 128 or  64, H/2,  W/2)
        x  = self.up4(x,  x1)   # (B,  64,        H,    W)
        return self.outc(x)     # (B, n_classes,  H,    W)

# Quick smoke test
model = UNet(n_channels=1, n_classes=1, bilinear=True)
x = torch.randn(2, 1, 256, 256)
y = model(x)
print("UNet: in=", tuple(x.shape), "→ logits=", tuple(y.shape))
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable params: {n_params / 1e6:.2f}M")
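The parameter count printed by the smoke test can be predicted by hand. The sketch below tallies weights with plain arithmetic, assuming the DoubleConv shown earlier (two bias-free 3×3 convs, each followed by a BatchNorm2d with 2 learnable parameters per channel) and, for bilinear=False, an up-conv that maps in_channels to in_channels // 2 with a bias (the standard U-Net wiring). Like the smoke test, it counts only trainable parameters. The function names are illustrative and not part of the UNet class.

```python
def double_conv_params(c_in, c_out):
    """Two bias-free 3x3 convs plus two BatchNorm2d layers (weight + bias each)."""
    return 9 * c_in * c_out + 9 * c_out * c_out + 2 * (2 * c_out)


def unet_params(n_channels=1, n_classes=1, bilinear=True):
    """Trainable-parameter count for the UNet class, built up stage by stage."""
    factor = 2 if bilinear else 1
    total = double_conv_params(n_channels, 64)          # inc
    total += double_conv_params(64, 128)                # down1
    total += double_conv_params(128, 256)               # down2
    total += double_conv_params(256, 512)               # down3
    total += double_conv_params(512, 1024 // factor)    # down4 (bottleneck)
    # (in_channels, out_channels) of up1..up4, as passed in UNet.__init__
    ups = [(1024, 512 // factor), (512, 256 // factor),
           (256, 128 // factor), (128, 64)]
    for c_in, c_out in ups:
        if not bilinear:
            # ConvTranspose2d(c_in, c_in // 2, kernel_size=2): 2x2 weights + bias
            total += c_in * (c_in // 2) * 4 + c_in // 2
        total += double_conv_params(c_in, c_out)
    total += 64 * n_classes + n_classes                 # outc: 1x1 conv with bias
    return total


print(f"bilinear=True : {unet_params(bilinear=True) / 1e6:.2f}M")
print(f"bilinear=False: {unet_params(bilinear=False) / 1e6:.2f}M")
```

If the smoke test prints a different number, the likeliest cause is a DoubleConv variant with conv biases or a separate mid-channel width.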

Loss functions for segmentation

Segmentation has a class-imbalance problem out of the gate: the foreground (cells, roads, organs) usually occupies less than 10% of the pixels. Plain pixel-wise BCE / cross-entropy is dominated by the many background pixels, which pushes predictions toward all-background. The three losses below are common remedies; pick by problem.

| Loss | Definition (informal) | When to use | Cross-ref |
|---|---|---|---|
| Dice | 1 − 2\|p ∩ g\| / (\|p\| + \|g\|) (overlap-based) | Heavy class imbalance, e.g. tumor < 1% of pixels | §5.4 (loss functions) |
| BCE + Dice | Pixel-wise BCE plus the Dice term | Most binary segmentation tasks; pairs a pixel-level signal with a shape-level signal | §5.4 |
| Focal Tversky | (1 − T)^γ, where the Tversky index T trades off FP against FN | Very small structures (vessels, ducts) where missing them is much worse than over-segmenting | Tversky loss: Salehi et al. 2017 (arXiv:1706.05721); focal variant: Abraham & Khan 2019 |
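The Focal Tversky row can be made concrete with a soft (probability-based) NumPy sketch. The weight names `w_fn` / `w_fp` and the defaults (0.7, 0.3, gamma=0.75) are illustrative assumptions; papers differ on which Greek letter weights which error and on whether the exponent is written γ or 1/γ, so check the convention of the paper you follow.

```python
import numpy as np

def focal_tversky_loss(p, g, w_fn=0.7, w_fp=0.3, gamma=0.75, eps=1.0):
    """Soft Focal Tversky loss (1 - T) ** gamma for one binary mask.

    p: predicted foreground probabilities in [0, 1]; g: {0, 1} ground truth.
    w_fn / w_fp weight false negatives / false positives; w_fn > w_fp makes
    missed foreground cost more than over-segmentation.
    Defaults are illustrative, not tuned values.
    """
    p = np.asarray(p, dtype=float).ravel()
    g = np.asarray(g, dtype=float).ravel()
    tp = np.sum(p * g)            # soft true positives
    fn = np.sum((1.0 - p) * g)    # foreground the model missed
    fp = np.sum(p * (1.0 - g))    # background the model claimed
    tversky = (tp + eps) / (tp + w_fn * fn + w_fp * fp + eps)
    return (1.0 - tversky) ** gamma


# Thin "vessel": 4 foreground pixels out of 100; the prediction finds only 2
g = np.zeros(100); g[:4] = 1
miss = np.zeros(100); miss[:2] = 1
print(focal_tversky_loss(miss, g, w_fn=0.7, w_fp=0.3))  # larger: misses weighted heavily
print(focal_tversky_loss(miss, g, w_fn=0.3, w_fp=0.7))  # same prediction, misses cheap
```

With `w_fn = w_fp = 0.5`, `gamma = 1` and `eps = 0`, the Tversky index reduces exactly to the soft Dice coefficient, which is a useful sanity check when swapping between the two.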

For binary cell segmentation we'll use BCE + Dice with equal weight. The Dice term:

$$\mathcal{L}_{\text{Dice}} = 1 - \frac{2 \sum_i p_i g_i + \varepsilon}{\sum_i p_i + \sum_i g_i + \varepsilon}$$

where $p_i \in [0, 1]$ is the predicted probability for pixel $i$, $g_i \in \{0, 1\}$ is the ground-truth label, and $\varepsilon$ is a small constant (e.g. 1) that makes the loss well-defined when both $p$ and $g$ are zero everywhere (the loss is then 0).
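A minimal NumPy sketch of the BCE + Dice objective, assuming the network outputs raw logits (as the UNet above does) and that both terms are weighted equally. The function names are illustrative; in PyTorch the BCE half would normally be `nn.BCEWithLogitsLoss`.

```python
import numpy as np

def soft_dice_loss(p, g, eps=1.0):
    """1 - (2 * sum(p * g) + eps) / (sum(p) + sum(g) + eps), as in the formula above."""
    p = np.asarray(p, dtype=float).ravel()
    g = np.asarray(g, dtype=float).ravel()
    return 1.0 - (2.0 * np.sum(p * g) + eps) / (np.sum(p) + np.sum(g) + eps)


def bce_dice_loss(logits, g, eps=1.0):
    """Equal-weight sum of mean pixel-wise BCE (computed on logits) and soft Dice."""
    z = np.asarray(logits, dtype=float).ravel()
    g = np.asarray(g, dtype=float).ravel()
    # Numerically stable BCE with logits: max(z, 0) - z * g + log(1 + exp(-|z|))
    bce = np.mean(np.maximum(z, 0.0) - z * g + np.log1p(np.exp(-np.abs(z))))
    p = 0.5 * (1.0 + np.tanh(0.5 * z))   # overflow-free sigmoid
    return bce + soft_dice_loss(p, g, eps)


# 2% foreground; the network predicts confident background everywhere
g = np.zeros(100); g[:2] = 1
logits = np.full(100, -4.0)
p = 0.5 * (1.0 + np.tanh(0.5 * logits))
print("Dice term :", soft_dice_loss(p, g))   # about 0.78
print("BCE + Dice:", bce_dice_loss(logits, g))
```

On this all-background prediction the BCE half is only about 0.1, because 98 of the 100 pixels are correct, while the Dice half stays near 0.78; that gap is why the Dice term is added.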


Training recipe

The original U-Net's training recipe — still a good default for biomedical sets:

  1. Patch sampling. Train on overlapping patches (e.g. 388×388 outputs from 572×572 inputs) instead of full images. Smaller memory, more diversity per epoch.
  2. Heavy augmentation. Random flips, 90° rotations, and elastic deformation — small smooth random vector fields applied to both the input and the mask. Critical for biomedical sets where labeled data is scarce.
  3. Boundary weight maps. The original paper precomputes a per-pixel weight $w(\mathbf{x})$ that assigns higher loss to pixels near cell boundaries (separating touching cells is the hardest sub-problem). Modern setups often drop this in favor of a Dice + boundary-loss combination, but it remains a clean inductive prior.
  4. Optimizer. The original paper used SGD with high momentum (0.99) and a single large tile per batch. A common modern default is Adam with $\text{lr} = 10^{-4}$ and batch size 4–16 (memory-bound). Learning-rate cosine decay or step decay both work; warmup is rarely worth it at this scale.
  5. Metrics. Intersection-over-Union (IoU) and Dice coefficient on a held-out validation set. For multi-class semantic segmentation: mean IoU (mIoU) averaged over classes, with explicit per-class breakdowns to catch class-imbalance failures.
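The metrics in item 5 reduce to a few set counts on hard (thresholded) masks. A NumPy sketch with hypothetical helper names; scoring two empty masks as a perfect match is a convention assumed here, and some benchmarks instead exclude such classes from the mean.

```python
import numpy as np

def iou_and_dice(pred, target):
    """Hard-mask IoU and Dice for one binary class (arrays of 0/1 or bool)."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    total = pred.sum() + target.sum()
    if union == 0:                     # both masks empty (assumed convention)
        return 1.0, 1.0
    return inter / union, 2 * inter / total


def mean_iou(pred_labels, target_labels, num_classes):
    """mIoU over classes plus the per-class list, from integer label maps."""
    pred_labels = np.asarray(pred_labels)
    target_labels = np.asarray(target_labels)
    per_class = []
    for c in range(num_classes):
        iou, _ = iou_and_dice(pred_labels == c, target_labels == c)
        per_class.append(float(iou))
    return float(np.mean(per_class)), per_class


iou, dice = iou_and_dice([[1, 1, 0, 0]], [[1, 0, 0, 0]])
print(iou, dice)        # Dice = 2 * IoU / (1 + IoU) for the same pair of masks
print(mean_iou([0, 0, 1, 2], [0, 1, 1, 2], num_classes=3))
```

Reporting the per-class list alongside the mean is what exposes a rare class that the mIoU average would otherwise hide.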

Real-world applications

U-Net's reach is the broadest of any architecture in this chapter. The same encoder-decoder + skip pattern appears, often unchanged, across domains where the input and the desired output are both image-shaped:

| Domain | Task | Why U-Net wins here |
|---|---|---|
| Biomedical (microscopy) | Cell, nucleus, organelle segmentation in EM and fluorescence images | Original domain. Scarce labels + heavy augmentation + per-boundary weights. |
| Biomedical (radiology) | Tumor / organ delineation in MRI, CT, ultrasound; 3D U-Net for volumetric scans | 3D U-Net (Çiçek 2016) extends every 2D op to 3D. Used in nnU-Net's clinical pipelines. |
| Biomedical (ophthalmology) | Retinal-vessel segmentation in OCT and fundus images | Tiny vessels need high-res skip 1 to resolve; deep skip provides shape prior. |
| Satellite (mapping) | Building footprint, road extraction (Inria Aerial, SpaceNet) | Wide receptive field via deep skip; high-res first skip preserves building corners. |
| Agriculture (remote sensing) | Crop-field delineation in Sentinel-2 multi-spectral imagery | n_channels = 13 (one per spectral band) is just a config knob in our UNet class. |
| Autonomous driving (semantic) | Drivable-surface and lane masking when SOTA precision is not required | DeepLab (§10) typically wins here, but U-Net remains a fast baseline. |
| AR / mobile | Real-time portrait segmentation (Pixel phones, Zoom backgrounds) | Trimmed U-Net runs at 30+ FPS on phones; bilinear upsample + lightweight backbone. |
| Industrial QA | Surface-defect detection on steel, glass, fabric production lines | Pre-trained on ImageNet then fine-tuned on a few hundred labeled defects. |
| Diffusion (generative) | U-Net is the denoising backbone for Stable Diffusion and DDPM | Same encoder-decoder + skip pattern, augmented with self-attention; see Ch 23. |

3D U-Net: from pixels to voxels

Every CT or MRI scan is a volumetric image — a stack of 2D slices that, taken together, form a 3D voxel grid. A liver tumour does not live on a single slice; it spans, e.g., 12 contiguous slices in an abdominal CT, with shape, edges, and local invasion only fully visible when those slices are reasoned about together. The 2D U-Net we just built operates on each slice independently. It throws away the third dimension before it ever convolves.

The cross-slice context problem. A 2D U-Net seeing a single 512×512 axial slice of a CT cannot tell whether a bright spot at $(x, y)$ is a tumour (which would persist across the next 5 slices) or a vessel cross-section (which would shift on the next slice). The fix is structural, not training-time: replace every 2D op with its 3D counterpart so a single forward pass sees a 3D voxel neighbourhood at every position.

Three failure modes a slice-by-slice 2D U-Net hits in practice, all of which 3D U-Net fixes:

  • Through-plane discontinuity. Predicted masks are sharp inside a slice but flicker between adjacent slices because the network never had a chance to enforce cross-slice consistency.
  • Anisotropic context loss. A 3×3 in-plane kernel covers a few millimetres of tissue (about 3 mm at 1 mm pixel spacing) but sees nothing beyond the current slice in the through-plane direction. A small, elongated structure visible across 4 slices can be entirely missed.
  • 3D shape priors discarded. Anatomical structures have characteristic 3D shapes (a kidney is a bean; a vertebra has a recognisable arch). 2D processing cannot use those priors.

The 3D U-Net was introduced by Çiçek et al. 2016 (MICCAI) for exactly this reason. It is structurally identical to the 2D U-Net you just built (same encoder-decoder, same concat skips), with every 2D primitive replaced by its 3D analogue. Two related lines of work explored adjacent points in the design space: the contemporary V-Net (Milletari et al. 2016) added residual blocks and introduced the now-ubiquitous soft Dice loss, and, five years later, the fully automated nnU-Net pipeline (Isensee et al. 2021, Nature Methods) made 3D U-Net the default winner on ~23 public benchmarks.


3D U-Net architecture (Çiçek 2016)

The recipe, summarised from §2 of the paper: the encoder ("analysis path") and decoder ("synthesis path") each have 4 resolution steps, and every 2D op is swapped one-for-one with its 3D analogue.

| 2D op (earlier in this section) | 3D op (Çiçek 2016, §2) | What changes |
|---|---|---|
| 3×3 conv (padded) + ReLU | 3×3×3 conv (unpadded) + ReLU | Kernel goes from 9 to 27 weights → 3× params per filter |
| 3×3 conv + BN + ReLU (twice per stage) | 3×3×3 conv + BN + ReLU (twice per stage) | BatchNorm3d normalises over (B, D, H, W); same idea |
| 2×2 max-pool (stride 2) | 2×2×2 max-pool (stride 2) | Halves D, H, W → 8× voxel reduction per stage |
| 2×2 transposed conv (stride 2) | 2×2×2 transposed conv (stride 2) | Doubles D, H, W exactly |
| Skip: encoder → decoder (concat axis 1) | Same: concat along channel axis | Identical mechanism, one more spatial axis |
| Final 1×1 conv → n_classes | Final 1×1×1 conv → n_classes | Same per-voxel linear projection |

One subtle but important departure from the 2D U-Net: channels double before the max-pool, not after. The paper credits this to Szegedy et al. 2015 (Rethinking Inception), whose guideline is to avoid representational bottlenecks. Doubling before the pool widens the feature map while it is still at full resolution, so the 8× voxel reduction of the following max-pool does not squeeze information through a narrow layer.

Exact shapes reported in the paper for the Xenopus-kidney experiments:

| Stage | Channels | Spatial (voxels) | Notes |
|---|---|---|---|
| Input | 3 | 132 × 132 × 116 | 3-channel confocal microscopy (Tomato-Lectin / DAPI / Beta-Catenin) |
| Encoder L1 (after 2× conv) | 32 → 64 | 132 × 132 × 116 → 128 × 128 × 112 | Doubling before pool: channels go 32 → 64 here, BEFORE the max-pool |
| After max-pool 1 | 64 | 64 × 64 × 56 | 2×2×2, stride 2 → 8× voxel reduction |
| Encoder L2 | 64 → 128 | 64 × 64 × 56 → 60 × 60 × 52 | two convs; each unpadded 3³ pair trims 4 voxels per axis |
| Encoder L3 | 128 → 256 | 30 × 30 × 26 → 26 × 26 × 22 | two convs (after max-pool 2) |
| Encoder L4 (deepest analysis) | 256 → 512 | 13 × 13 × 11 → 9 × 9 × 7 | two convs (after max-pool 3) |
| Bottleneck | 512 | 9 × 9 × 7 | deepest semantic representation |
| Decoder L4 → L1 (mirror) | 512 → 256 → 128 → 64 | 9 × 9 × 7 → 14 × 14 × 10 → 24 × 24 × 16 → 44 × 44 × 28 | 2×2×2 up-conv (doubles each axis) → concat center-cropped skip → 2× 3³ conv |
| Output | 3 | 44 × 44 × 28 | 3 labels: inside-tubule / tubule / background. Receptive field 155×155×180 µm³ |

Total parameters: 19,069,955 (exact figure from Çiçek 2016 §2). Batch size in the paper: 1; volumetric activations are too large for larger batches on 2016 hardware, and very small batches remain the norm today. BatchNorm placement: before each ReLU. With batch size 1, the statistics BatchNorm normalises with during training come from a single sample, so it behaves much like InstanceNorm at train time; modern pipelines therefore often switch to InstanceNorm (the nnU-Net default) or GroupNorm.
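The 19,069,955 figure can be reproduced from the channel widths in the shape table. The sketch below is one accounting that matches it exactly, under assumptions that are ours rather than stated in the paper: every 3×3×3 conv and the final 1×1×1 conv carry a bias, the 2×2×2 up-convolutions have none and keep their channel count, and BatchNorm's scale/shift parameters are not included. The helper name is hypothetical.

```python
def conv3d_params(c_in, c_out, k=3, bias=True):
    """Weights of a k x k x k Conv3d / ConvTranspose3d plus optional bias."""
    return c_in * c_out * k ** 3 + (c_out if bias else 0)


# Analysis path: two 3x3x3 convs per level, channels double BEFORE each pool
analysis = [(3, 32), (32, 64),        # level 1
            (64, 64), (64, 128),      # level 2
            (128, 128), (128, 256),   # level 3
            (256, 256), (256, 512)]   # level 4 (deepest)
# Synthesis path: (channels arriving from below, skip channels at this level)
synthesis = [(512, 256), (256, 128), (128, 64)]

total = sum(conv3d_params(ci, co) for ci, co in analysis)
for c_up, c_skip in synthesis:
    total += conv3d_params(c_up, c_up, k=2, bias=False)  # 2x2x2 up-conv (assumed bias-free)
    total += conv3d_params(c_up + c_skip, c_skip)        # first conv after the concat
    total += conv3d_params(c_skip, c_skip)               # second conv
total += conv3d_params(64, 3, k=1)                        # final 1x1x1 conv to 3 labels
print(f"{total:,}")
```

The 3×3×3 terms dominate the count: the same channel plan with 3×3 kernels would need roughly a third of the weights.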


Volumetric flow (interactive 3D)

Every box in the figure is a feature-map volume. The encoder column on the left shrinks the spatial cube while the channel count grows; the decoder column on the right does the reverse; dashed lines are the skip connections that copy across each level.


The shown spatial sizes (132 → 66 → 33 → 16 → 8) are didactic clean halves. The Çiçek 2016 paper uses unpadded 3³ convolutions, so the actual encoder outputs shrink by 4 voxels per axis at each stage (not exact halves) and the final output is 44×44×28 rather than 132×132×116. Modern padded reimplementations recover the clean halving you see here. The shape table in the previous section gives the paper's exact numbers.
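The paper's 132×132×116 → 44×44×28 mapping follows from per-axis arithmetic: each unpadded 3-wide conv trims 2 voxels, each pool halves, each up-conv doubles. The sketch below traces it with a hypothetical helper, assuming exactly two convs per level and three pools as in Çiçek 2016; the same function with four pools reproduces the original 2-D U-Net's 572 → 388 tile.

```python
def unet_valid_shapes(in_size, n_pools, convs_per_level=2, k=3):
    """Spatial size along ONE axis through an unpadded (valid-conv) U-Net.

    Returns (size after the convs at each encoder level, final output size).
    """
    shrink = convs_per_level * (k - 1)     # two valid 3-wide convs trim 4 voxels
    size, encoder = in_size, []
    for level in range(n_pools + 1):
        size -= shrink
        encoder.append(size)
        if level < n_pools:
            if size % 2:
                raise ValueError(f"odd size {size} before pool at level {level}")
            size //= 2                     # 2x max-pool
    for _ in range(n_pools):
        size = size * 2 - shrink           # 2x up-conv, then two valid convs
    return encoder, size


for axis, n in (("x", 132), ("y", 132), ("z", 116)):
    print(axis, n, "->", unet_valid_shapes(n, n_pools=3))
print("2-D U-Net", 572, "->", unet_valid_shapes(572, n_pools=4))
```

The odd-size check is why valid-conv U-Nets only accept particular input sizes; padded reimplementations avoid the constraint by keeping sizes constant through each conv.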


Python from scratch: one 3D up-block

Just like the 2D version, the cleanest way to internalise a 3D U-Net is to build the single new operation, the volumetric up-block, from scratch in NumPy on tiny printable tensors. In a real network the block does three things in order: upsample the decoder feature map by 2, concatenate the encoder skip along the channel axis, and apply a 3×3×3 convolution. Below we skip the upsample and focus on the concat + 3³ conv, which is the part that actually changes vs 2D.

3D up-block: concat + 3×3×3 conv (NumPy)
🐍unet3d_up_block.py
1import numpy as np

NumPy is the only library this file needs. Volumetric tensors are just 4-D ndarrays — one channel axis plus three spatial axes (D, H, W). Every operation we use (np.empty, np.arange, np.tile, np.concatenate, slicing, np.sum, np.zeros) generalises seamlessly from 2-D to 3-D — that is the whole point of using NumPy here, before reaching for nn.Conv3d.

EXECUTION STATE
numpy = Library for numerical computing. ndarrays support any rank (1-D, 2-D, 3-D, 4-D, …); the slicing and broadcasting rules are identical regardless of rank.
as np = Standard alias. We write np.concatenate, np.zeros, np.sum.
5d = np.empty((2, 4, 4, 4)) — decoder feature map

Allocate a 4-D array with shape (C_up=2, D=4, H=4, W=4) — i.e. 2 channels of a 4×4×4 voxel volume. That is 2 × 64 = 128 floats. np.empty allocates the buffer without initialising the values; we fill them on the next two lines.

EXECUTION STATE
📚 np.empty(shape) = NumPy function: allocates an uninitialised array of the given shape. Faster than np.zeros because no fill pass; values are whatever was already in memory. We use it because we are about to overwrite every element.
⬇ shape: (2, 4, 4, 4) = (C_up=2, D=4, H=4, W=4). Axis 0 = channels. Axes 1–3 = volumetric spatial axes. Total elements = 2 × 4 × 4 × 4 = 128. dtype defaults to float64.
→ vs 2-D = The 2-D version one section earlier used (C, H, W). The 3-D version adds D (depth) at axis 1, pushing H to axis 2 and W to axis 3. Every later slice grows by one colon for the new axis.
7d[0] = np.tile(np.arange(4).reshape(4,1,1), (1,4,4))

Fill channel 0 with a depth-gradient: every voxel at depth d gets value d. d=0 plane is all 0.0, d=1 plane is all 1.0, d=2 is all 2.0, d=3 is all 3.0. We use this so the convolution output below visibly varies with depth — exactly the cross-slice context that 2-D processing throws away.

EXECUTION STATE
📚 np.arange(4, dtype=float) = Returns [0., 1., 2., 3.]: floats, length 4. Like Python range() but as a NumPy array. dtype=float produces float64 values that match d's dtype (an int array assigned into d would be cast to float64 anyway, so this is for clarity rather than correctness).
📚 .reshape(4, 1, 1) = Reshape (4,) → (4, 1, 1). The two new size-1 axes are placeholders that broadcasting will expand. The data layout in memory is unchanged — only the shape metadata changes.
📚 np.tile(arr, (1, 4, 4)) = Repeats arr along each axis by the given factors. (4,1,1) tiled by (1,4,4) becomes (4,4,4) — the depth axis is left alone (factor 1) and each row/column is copied 4 times. Result: full 4×4×4 volume where d[0][k,:,:] = k.
⬆ d[0] — shape (4, 4, 4) =
depth=0:                depth=1:                depth=2:                depth=3:
0.0 0.0 0.0 0.0      1.0 1.0 1.0 1.0      2.0 2.0 2.0 2.0      3.0 3.0 3.0 3.0
0.0 0.0 0.0 0.0      1.0 1.0 1.0 1.0      2.0 2.0 2.0 2.0      3.0 3.0 3.0 3.0
0.0 0.0 0.0 0.0      1.0 1.0 1.0 1.0      2.0 2.0 2.0 2.0      3.0 3.0 3.0 3.0
0.0 0.0 0.0 0.0      1.0 1.0 1.0 1.0      2.0 2.0 2.0 2.0      3.0 3.0 3.0 3.0
9d[1] = 0.5 — uniform context channel

Fill channel 1 with the constant 0.5 across all 64 voxels. Broadcasting a scalar into a (4,4,4) slice fills every position. This channel carries no spatial detail; in a real network it would be a 'this region is foreground-ish' semantic signal. Because the value is uniform, ch1's contribution to the conv output is the same at every position, which makes it easy to verify.

EXECUTION STATE
broadcasting scalar → array = Assigning a Python float to a NumPy slice fills every element of that slice. d[1].shape = (4,4,4); after this line every one of those 64 entries is 0.5.
⬆ d[1] — shape (4, 4, 4), all 0.5 =
Every depth slice is the same 4×4 plane of 0.5s:
0.5 0.5 0.5 0.5
0.5 0.5 0.5 0.5
0.5 0.5 0.5 0.5
0.5 0.5 0.5 0.5
10print("d shape:", d.shape)

Verification print. Emits: d shape: (2, 4, 4, 4). One channels axis + three spatial axes — the canonical PyTorch/NumPy 3-D layout (without batch).

EXECUTION STATE
d.shape = (2, 4, 4, 4) — (C_up=2, D=4, H=4, W=4) ✓
11print("d[0] depth-slice 0 (all 0.0):", d[0, 0])

Print one 2-D slice of the 4-D tensor. d[0, 0] is channel 0 at depth 0 — a (4, 4) plane. Indexing a 4-D array with two ints reduces it to a 2-D array; equivalent to d[0][0]. Value: a 4×4 plane of zeros.

EXECUTION STATE
d[0, 0] =
Channel 0, depth 0:
0.0  0.0  0.0  0.0
0.0  0.0  0.0  0.0
0.0  0.0  0.0  0.0
0.0  0.0  0.0  0.0
12print("d[0] depth-slice 1 (all 1.0):", d[0, 1])

Channel 0 at depth 1 — a 4×4 plane of 1.0s. Together with the previous print, this confirms the depth-gradient: every voxel at depth d has value d.

EXECUTION STATE
d[0, 1] =
Channel 0, depth 1:
1.0  1.0  1.0  1.0
1.0  1.0  1.0  1.0
1.0  1.0  1.0  1.0
1.0  1.0  1.0  1.0
16s = np.ones((1, 4, 4, 4)) — encoder skip

The encoder skip at the matching level: a single channel of a 4×4×4 volume with uniform value 1.0. We assume the skip already has the same spatial size as the decoder feature map, as it does in padded U-Nets, so no center-crop is needed (contrast with the 2-D unpadded version above, which had to crop 6×6 → 4×4). The convolution below is still applied unpadded, purely to keep the output small enough to check by hand.

EXECUTION STATE
📚 np.ones(shape) = Allocates and fills an array with 1.0. Shape arg is a tuple. Default dtype is float64. np.ones((1,4,4,4)) ≡ np.full((1,4,4,4), 1.0).
⬆ s — shape (1, 4, 4, 4) =
1 channel × 64 voxels = 64 ones. Every depth slice is the same 4×4 plane:
1.0  1.0  1.0  1.0
1.0  1.0  1.0  1.0
1.0  1.0  1.0  1.0
1.0  1.0  1.0  1.0
→ why uniform? = Constant inputs make the convolution output trivially verifiable per-position. In a real network the skip would carry the encoder's high-resolution detail at this level — e.g. a tumour boundary in 3-D.
17print("s shape:", s.shape)

Verification print. Emits: s shape: (1, 4, 4, 4). Spatial axes match d (D=H=W=4); channel count differs (C_skip=1, C_up=2). Concatenation along axis 0 will yield 2+1 = 3 channels.

EXECUTION STATE
s.shape = (1, 4, 4, 4) ✓
20fused = np.concatenate([d, s], axis=0)

The defining U-Net move, lifted to 3-D: stack the decoder feature map and the encoder skip along the channel axis. axis=0 is channels in our (C, D, H, W) layout. d has 2 channels, s has 1; the result has 2+1 = 3. The three spatial axes are unchanged because they match exactly on both inputs.

EXECUTION STATE
📚 np.concatenate(arrays, axis) = Joins arrays along an existing axis. All inputs must agree on every axis EXCEPT the concat axis. Generalises to any rank — the same call works for 2-D and 3-D feature maps without code changes.
⬇ arg 1: [d, s] = Two arrays. d.shape=(2,4,4,4); s.shape=(1,4,4,4). Spatial axes (D,H,W)=(4,4,4) match — required for axis=0 concat.
⬇ arg 2: axis=0 = Concatenate along the FIRST axis (channels). axis=1, 2 or 3 would raise a ValueError here, because d and s disagree on axis 0 (2 vs 1 channels). Even with matching channel counts they would stack along a spatial axis (axis=1 would give an 8-deep volume), which is wrong. axis=0 is the only correct choice.
⬆ fused — shape (3, 4, 4, 4) = 3 channels: ch0=d[0] (depth-gradient), ch1=d[1] (uniform 0.5), ch2=s[0] (uniform 1.0). 192 floats total.
21print("fused shape:", fused.shape)

Verification print. Emits: fused shape: (3, 4, 4, 4). The next 3-D conv will see 3 input channels and a 4×4×4 voxel volume.

EXECUTION STATE
fused.shape = (3, 4, 4, 4) — 3 channels, 4×4×4 spatial ✓
25W = np.zeros((1, 3, 3, 3, 3)) — 3-D conv weight tensor

Allocate the convolution weight tensor for a 3-D conv. Shape (out_C=1, in_C=3, kD=3, kH=3, kW=3). Compared with a 2-D weight (out_C, in_C, kH, kW), the one extra axis is kD, which turns the kernel from a 3×3 patch into a 3×3×3 cube. A 3×3 kernel has 9 weights per input-output channel pair; a 3×3×3 kernel has 27. That factor of 3 is why 3-D convs have roughly 3× more parameters than 2-D convs with the same channel counts. The extra depth axis also multiplies the size of every activation map by D, which is why GPU memory is the dominant constraint in 3-D segmentation.

EXECUTION STATE
📚 np.zeros((1, 3, 3, 3, 3)) = Allocates a 5-D float64 array of all zeros. Total elements 1×3×3×3×3 = 81. We will fill three 3×3×3 slices (one per input channel) on the next three lines.
⬇ shape (1, 3, 3, 3, 3) = out_C=1: one output feature map. in_C=3: matches fused channels. kD=3, kH=3, kW=3: 3×3×3 spatial cube. Total weights: 1×3×27 = 81 floats. (Compare 2-D conv with same in/out: 1×3×9 = 27 — exactly 3× fewer.)
26W[0, 0] = 1.0 — depth-gradient channel weight

Set the 3×3×3 kernel for output channel 0, input channel 0 (the depth-gradient ch) to all 1.0. Broadcasting a scalar fills the 27-entry slice. Every voxel of the depth-gradient patch contributes its value to the output unchanged.

EXECUTION STATE
W[0, 0] — shape (3, 3, 3) = All 27 entries are 1.0. The kernel acts as a 'sum the 3³ patch' operator on the depth-gradient channel.
27W[0, 1] = 0.5 — uniform channel weight

3×3×3 kernel for input channel 1 (the uniform-0.5 channel) is all 0.5. Each entry contributes 0.5 × 0.5 = 0.25 to the output. With 27 patch entries the total contribution per output position is 27 × 0.25 = 6.75 — constant, since input ch1 is uniform.

EXECUTION STATE
W[0, 1] — shape (3, 3, 3) = All 27 entries are 0.5. Combined with the uniform input value 0.5, every output position gets contribution 0.5 × 0.5 × 27 = 6.75.
28W[0, 2] = 2.0 — skip channel weight

3×3×3 kernel for input channel 2 (the skip) is all 2.0. With the skip value uniform 1.0, every patch entry contributes 1.0 × 2.0 = 2.0; total contribution per output position = 27 × 2.0 = 54.0. Asymmetric weighting (skip > decoder) simulates a learned 'pay extra attention to the skip' decision.

EXECUTION STATE
W[0, 2] — shape (3, 3, 3) = All 27 entries are 2.0. Contribution per output position = 1.0 × 2.0 × 27 = 54.0 (constant — input is uniform).
29b = np.zeros(1)

Bias vector for the single output channel, initialised to 0. Shape (1,). In a trained network this is one learned scalar added to every voxel of the output.

EXECUTION STATE
b = array([0.]) — one bias per output channel
32D_out = fused.shape[1] - 2

Output depth for an unpadded 3-D conv. With kernel size 3 and stride 1, the depth axis loses 2 voxels (one at each end). 4 - 2 = 2 valid depth positions. Convolution arithmetic is identical to the 2-D case — see §10.3 — but applies independently to every spatial axis.

EXECUTION STATE
fused.shape[1] = 4 — depth of the fused tensor
D_out = 2 — two output depth slices
33H_out = fused.shape[2] - 2

Output height. 4 - 2 = 2.

EXECUTION STATE
H_out = 2
34W_out = fused.shape[3] - 2

Output width. 4 - 2 = 2.

EXECUTION STATE
W_out = 2
35y = np.zeros((1, D_out, H_out, W_out))

Allocate the output tensor: 1 channel × 2 × 2 × 2 = 8 output voxels. We fill each via the triple-nested loop below.

EXECUTION STATE
y — shape (1, 2, 2, 2) = Eight zeros laid out as two 2×2 depth slices.
36for od in range(D_out): — outer depth loop

Iterate over the 2 valid output depth positions. For each od, the input patch covers depths od, od+1, od+2 of the fused tensor. This is the new dimension vs 2-D — each output value now sees three input depth slices, capturing cross-slice context.

LOOP TRACE · 2 iterations
od=0
depth slices in patch = fused[:, 0:3, :, :] — depths 0, 1, 2 of fused
ch0 patch values per depth = depth 0: all 0.0; depth 1: all 1.0; depth 2: all 2.0. Sum over 9 entries per depth × 3 depths = 9×(0+1+2) = 27
od=1
depth slices in patch = fused[:, 1:4, :, :] — depths 1, 2, 3 of fused
ch0 patch values per depth = depth 1: all 1.0; depth 2: all 2.0; depth 3: all 3.0. Sum = 9×(1+2+3) = 54
37for oh in range(H_out): — middle height loop

Iterate over the 2 valid output height positions. Combined with the outer loop we now have 4 (od, oh) pairs.

LOOP TRACE · 2 iterations
oh=0
patch rows = rows 0,1,2
oh=1
patch rows = rows 1,2,3
38for ow in range(W_out): — inner width loop

Iterate over the 2 valid output width positions. Triple-nested loop produces 2×2×2 = 8 output positions total. Each iteration extracts a 3×3×3 patch from the (3, 4, 4, 4) fused tensor and convolves with the (3, 3, 3, 3) weight slice.

LOOP TRACE · 8 iterations
od=0, oh=0, ow=0 → y[0,0,0,0]
ch0 patch sum × 1.0 = 9×(0+1+2) = 27.0
ch1 patch sum × 0.5 = 0.5×0.5×27 = 6.75
ch2 patch sum × 2.0 = 1.0×2.0×27 = 54.0
y[0,0,0,0] = 27.0 + 6.75 + 54.0 = 87.75
od=0, oh=0, ow=1 → y[0,0,0,1]
all three channel contributions identical = Inputs invariant in W direction; same 27 / 6.75 / 54 totals
y[0,0,0,1] = 87.75
od=0, oh=1, ow=0 → y[0,0,1,0]
still no W-direction variation = Inputs uniform in (H, W) within each depth
y[0,0,1,0] = 87.75
od=0, oh=1, ow=1 → y[0,0,1,1]
y[0,0,1,1] = 87.75
od=1, oh=0, ow=0 → y[0,1,0,0]
ch0 patch sum × 1.0 = 9×(1+2+3) = 54.0 ← changed!
ch1 patch sum × 0.5 = 6.75 (uniform in d)
ch2 patch sum × 2.0 = 54.0 (uniform in d)
y[0,1,0,0] = 54.0 + 6.75 + 54.0 = 114.75
od=1, oh=0, ow=1 → y[0,1,0,1]
y[0,1,0,1] = 114.75
od=1, oh=1, ow=0 → y[0,1,1,0]
y[0,1,1,0] = 114.75
od=1, oh=1, ow=1 → y[0,1,1,1]
y[0,1,1,1] = 114.75
40patch = fused[:, od:od+3, oh:oh+3, ow:ow+3]

Extract a 3×3×3 patch across all 3 input channels. Slicing returns a (3, 3, 3, 3) view — no data is copied. Compared with 2-D, we now have one extra spatial colon-slice on axis 1 (depth). This single line is the only structural difference between a 2-D and 3-D conv-loop body.

EXECUTION STATE
📚 NumPy 4-D slice = : on axis 0 = all channels. od:od+3 on axis 1 (depth) = 3 depth slices. oh:oh+3 on axis 2 (height) = 3 rows. ow:ow+3 on axis 3 (width) = 3 cols. Returns a VIEW into fused.
patch — shape (3, 3, 3, 3) = 3 channels × 3 depths × 3 rows × 3 cols = 81 values per output position.
42y[0, od, oh, ow] = np.sum(patch * W[0]) + b[0]

The 3-D convolution at one output voxel. (1) element-wise multiply patch (3×3×3×3) with W[0] (3×3×3×3) — same shape, 81 multiplications; (2) sum all 81 products to a scalar; (3) add bias b[0]=0. This is the convolution formula written explicitly without any library magic — and it is exactly the 2-D formula with one extra spatial axis under the sum.

EXECUTION STATE
📚 patch * W[0] = Element-wise multiply two (3,3,3,3) arrays. Returns a (3,3,3,3) array of products. Not @ (matmul). Identical operation to the 2-D case but over 81 entries instead of 27.
📚 np.sum(...) = Collapses all 81 elements to one scalar. With no axis argument, sums over EVERY axis. Equivalent to .flatten().sum().
⬇ y[0, 0, 0, 0] = ch0 contribution: 0+1+2 across 9 entries each = 27.0 ch1 contribution: 0.5×0.5×27 = 6.75 ch2 contribution: 1.0×2.0×27 = 54.0 ────────────────────── y[0,0,0,0] = 27.0 + 6.75 + 54.0 = 87.75
⬆ full output y[0] =
depth=0 (od=0):
87.75   87.75
87.75   87.75

depth=1 (od=1):
114.75  114.75
114.75  114.75

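Values vary with depth because each output voxel sums over a window of three input depth slices. A 2-D conv applied slice-by-slice would also vary from slice to slice, but each of its outputs would see only one input slice; mixing neighbouring slices is what the 3-D kernel adds, and that is the whole point.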
44print("y shape:", y.shape)

Verification print. Emits: y shape: (1, 2, 2, 2). One output channel, 2×2×2 spatial volume — exactly what the unpadded shape arithmetic (4 − 3 + 1 = 2 along each spatial axis) predicts.

EXECUTION STATE
y.shape = (1, 2, 2, 2) ✓
45print("y[0] depth-slice 0:", y[0, 0])

Print the first depth slice of the output channel. y[0, 0] is shape (2, 2). Expected: every entry 87.75.

EXECUTION STATE
y[0, 0] =
Shape (2, 2):
87.75   87.75
87.75   87.75
46print("y[0] depth-slice 1:", y[0, 1])

Print the second depth slice. Every entry 114.75. The +27.0 jump between depth slices comes entirely from the depth-gradient channel — which is exactly the cross-slice context that motivates 3-D U-Net in the first place.

EXECUTION STATE
y[0, 1] =
Shape (2, 2):
114.75  114.75
114.75  114.75

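Difference between depth-0 and depth-1 outputs = 114.75 − 87.75 = 27.0. Shifting the window one step in depth raises each of the 27 depth-gradient entries in the patch by 1, and that channel's weight is 1.0, so the output rises by 27 × 1 × 1.0 = 27. The other two channels are constant along depth and cancel out of the difference.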
1import numpy as np
2
3# ─── Tiny volumetric tensors so every value is printable ───────────────────
4# Decoder side after a 2x2x2 stride-2 transposed conv: shape (C_up, D, H, W)
5# C_up = 2 channels, D=H=W=4 voxels along each spatial axis.
6d = np.empty((2, 4, 4, 4))
7# ch0 = depth-gradient: value at depth d is just d itself.
8d[0] = np.tile(np.arange(4, dtype=float).reshape(4, 1, 1), (1, 4, 4))
9# ch1 = uniform 0.5 (carries no spatial information; pure semantic context).
10d[1] = 0.5
11print("d shape:", d.shape)
12print("d[0] depth-slice 0 (all 0.0):\n", d[0, 0])
13print("d[0] depth-slice 1 (all 1.0):\n", d[0, 1])
14
15# Encoder skip at SAME level. Modern padded version: same spatial size as d.
16# Single channel, uniform 1.0 — stands in for high-resolution boundary detail.
17s = np.ones((1, 4, 4, 4))
18print("s shape:", s.shape)
19
20# --- Step 1: concatenate along the channel axis (axis 0) ---
21fused = np.concatenate([d, s], axis=0)
22print("fused shape:", fused.shape)   # (3, 4, 4, 4)
23
24# --- Step 2: hand-set 3D conv weights ---
25# Kernel shape (out_C, in_C, kD, kH, kW) — one extra spatial axis vs 2D.
26W = np.zeros((1, 3, 3, 3, 3))
27W[0, 0] = 1.0   # decoder ch0 (depth-gradient) full weight
28W[0, 1] = 0.5   # decoder ch1 (uniform) half weight
29W[0, 2] = 2.0   # skip channel double weight
30b = np.zeros(1)
31
32# --- Step 3: unpadded 3D convolution. Output spatial = 4 - 3 + 1 = 2. ---
33D_out = fused.shape[1] - 2
34H_out = fused.shape[2] - 2
35W_out = fused.shape[3] - 2
36y = np.zeros((1, D_out, H_out, W_out))
37for od in range(D_out):
38    for oh in range(H_out):
39        for ow in range(W_out):
40            # 3x3x3 patch across all 3 input channels = (3, 3, 3, 3)
41            patch = fused[:, od:od+3, oh:oh+3, ow:ow+3]
42            y[0, od, oh, ow] = np.sum(patch * W[0]) + b[0]
43
44print("y shape:", y.shape)
45print("y[0] depth-slice 0:\n", y[0, 0])
46print("y[0] depth-slice 1:\n", y[0, 1])

The pay-off is the last printed slice. Depth slice 0 is uniform 87.75; depth slice 1 is uniform 114.75. The +27.0 jump comes from the 3³ kernel summing three input depth values along the depth axis (0+1+2 vs 1+2+3, nine voxels each). A 2D conv applied slice-by-slice with the same per-channel weights would step by only 9 from one slice to the next, because each of its outputs sees a single input slice. That Δ = 27, a sum over a three-slice window, is the numerical signature of cross-slice context.
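To make the 2D-vs-3D contrast concrete, the sketch below rebuilds the same toy tensors and runs both a per-slice 2D convolution and the 3D convolution with the same per-channel weights (1.0, 0.5, 2.0). The names `y2d`, `y3d`, `W2` and `W3` are illustrative only.

```python
import numpy as np

# Same toy tensors as the listing above: fused has shape (3, 4, 4, 4).
d = np.empty((2, 4, 4, 4))
d[0] = np.tile(np.arange(4, dtype=float).reshape(4, 1, 1), (1, 4, 4))  # depth gradient
d[1] = 0.5
s = np.ones((1, 4, 4, 4))
fused = np.concatenate([d, s], axis=0)

ch_w = np.array([1.0, 0.5, 2.0])  # per-channel weights, as in W[0] above

# 2-D baseline: a 3x3 kernel per channel, applied to each depth slice on its own.
W2 = np.broadcast_to(ch_w[:, None, None], (3, 3, 3))           # (C, kH, kW)
y2d = np.zeros((4, 2, 2))
for z in range(4):
    for oh in range(2):
        for ow in range(2):
            patch = fused[:, z, oh:oh + 3, ow:ow + 3]          # (C, 3, 3): one slice only
            y2d[z, oh, ow] = np.sum(patch * W2)

# 3-D conv: a 3x3x3 kernel per channel that also slides along depth.
W3 = np.broadcast_to(ch_w[:, None, None, None], (3, 3, 3, 3))  # (C, kD, kH, kW)
y3d = np.zeros((2, 2, 2))
for od in range(2):
    for oh in range(2):
        for ow in range(2):
            patch = fused[:, od:od + 3, oh:oh + 3, ow:ow + 3]  # (C, 3, 3, 3)
            y3d[od, oh, ow] = np.sum(patch * W3)

print("2-D, one value per slice:  ", y2d[:, 0, 0])  # 20.25, 29.25, 38.25, 47.25 (step 9)
print("3-D, one value per window: ", y3d[:, 0, 0])  # 87.75, 114.75 (step 27)
```

Because this particular kernel is constant along depth, each 3D output equals the sum of three consecutive 2D outputs (20.25 + 29.25 + 38.25 = 87.75). A learned 3D kernel is free to weight the three slices differently, which is where its extra expressive power comes from.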


PyTorch: full UNet3D module

Now in PyTorch, with the same four-file decomposition we used for the 2D version — DoubleConv3D, Down3D, Up3D, UNet3D. Every line is the 3D counterpart of a line you have already read in the 2D PyTorch implementation. Read them side-by-side: the diff is the smallest possible while still capturing volumetric context.

1. DoubleConv3D — the volumetric atom

DoubleConv3D: two padded 3×3×3 convs with BN+ReLU
🐍unet3d/double_conv.py
1import torch

PyTorch core. Provides torch.Tensor, .randn(), and the autograd engine that records every op in this file for backprop.

2import torch.nn as nn

PyTorch's neural-network module zoo. Exposes Conv3d, BatchNorm3d, ReLU, MaxPool3d, ConvTranspose3d — every 3-D op we need. Each is the direct 3-D analogue of its 2-D counterpart used earlier in this section.

4class DoubleConv3D(nn.Module)

The atomic building block of UNet3D. Same 'two convs + BN + ReLU' pattern as 2-D, with every op replaced by its 3-D version. nn.Module hooks the layer into PyTorch's parameter, .to(device), .train()/.eval(), and serialisation systems.

12def __init__(self, in_channels: int, out_channels: int)

Same signature as the 2-D version. Caller controls the channel-doubling pattern (e.g. 32→64→128→256 for the encoder).

EXECUTION STATE
in_channels = e.g. 32 (output of the previous stage)
out_channels = e.g. 64 (this stage's channel count)
15📚 nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False)

📚 PyTorch's 3-D convolution. With kernel_size=3 and padding=1, output spatial size = input spatial size in every axis. Each filter has in_channels × 3 × 3 × 3 = 27·in_channels weights — three times the 2-D Conv2d count for the same channels. bias=False because BatchNorm3d below has its own learnable bias.

EXAMPLE
in=(1, 32, 16, 16, 16)  →  out=(1, 64, 16, 16, 16)
param count for this layer = 32·64·27 = 55,296
16📚 nn.BatchNorm3d(out_channels)

📚 Per-channel batch normalisation across (B, D, H, W). One γ and one β per channel; running mean/var tracked across batches. Identical role to BatchNorm2d but the running statistics are computed across the larger (B·D·H·W) volumetric population. Note the 3-D U-Net paper uses batch=1; with batch=1 BN's batch-mean is the per-sample mean, behaving like InstanceNorm. Modern reimplementations often switch to GroupNorm when batch=1.

17📚 nn.ReLU(inplace=True)

📚 ReLU activation, in-place to save GPU memory. Memory matters more in 3-D than in 2-D: a (1, 64, 64, 64, 64) activation is 64× larger than a (1, 64, 64, 64) 2-D activation.

18Second 3×3×3 conv (mirror of line 15)

Same conv as line 15 but in_channels=out_channels now. The 'double' in DoubleConv3D refers to this pair — two convs at this resolution before the next stage.

23def forward(self, x)

Forward pass. Passes x through the Sequential. Autograd records every op so backward() works.

EXECUTION STATE
x =
input tensor of shape (B, in_channels, D, H, W) — 5-D vs 2-D's 4-D
27Sanity tensor (1, 32, 16, 16, 16)

1 batch × 32 channels × 16³ voxels. 16³ × 32 × 4 bytes ≈ 0.5 MB — tiny, just to verify shapes. Real BraTS patches are typically (4, 128, 128, 128) ≈ 32 MB per sample.

29block(x)

Run DoubleConv3D forward. Output shape: (1, 64, 16, 16, 16) — channels doubled (32→64), every spatial axis preserved (padding=1 + kernel 3 keeps size).

30print("DoubleConv3D: in=...")

Verification print. Expect: DoubleConv3D: in= (1, 32, 16, 16, 16) → out= (1, 64, 16, 16, 16).

1import torch
2import torch.nn as nn
3
4class DoubleConv3D(nn.Module):
5    """Two padded 3x3x3 convolutions, each followed by BatchNorm3d and ReLU.
6
7    Direct 3-D analogue of the 2-D DoubleConv from earlier in this section:
8    every Conv2d → Conv3d, every BatchNorm2d → BatchNorm3d.
9    Padding=1 keeps D, H, W unchanged so output size = input size.
10    """
11
12    def __init__(self, in_channels: int, out_channels: int):
13        super().__init__()
14        self.conv = nn.Sequential(
15            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
16            nn.BatchNorm3d(out_channels),
17            nn.ReLU(inplace=True),
18            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
19            nn.BatchNorm3d(out_channels),
20            nn.ReLU(inplace=True),
21        )
22
23    def forward(self, x: torch.Tensor) -> torch.Tensor:
24        return self.conv(x)
25
26# Sanity check on a tiny volumetric tensor
27x = torch.randn(1, 32, 16, 16, 16)
28block = DoubleConv3D(32, 64)
29y = block(x)
30print("DoubleConv3D: in=", tuple(x.shape), "→ out=", tuple(y.shape))
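The BatchNorm3d note above suggests GroupNorm when training with batch=1. A minimal sketch of that swap, assuming 8 groups (an illustrative choice; out_channels must be divisible by it). The class name `DoubleConv3DGN` is hypothetical, not part of the listing.

```python
import torch
import torch.nn as nn

class DoubleConv3DGN(nn.Module):
    """Hypothetical DoubleConv3D variant with GroupNorm in place of BatchNorm3d."""

    def __init__(self, in_channels: int, out_channels: int, groups: int = 8):
        super().__init__()
        # GroupNorm normalises over channel groups within each sample,
        # so its statistics never depend on the batch size.
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(groups, out_channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(groups, out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)

block = DoubleConv3DGN(32, 64)
x = torch.randn(1, 32, 16, 16, 16)
with torch.no_grad():
    block.train()
    y_train = block(x)
    block.eval()
    y_eval = block(x)

n_params = sum(p.numel() for p in block.parameters())
print(tuple(y_eval.shape), n_params)     # (1, 64, 16, 16, 16) 166144
print(torch.allclose(y_train, y_eval))   # True: no running statistics to switch to
```

The parameter count is identical to the BatchNorm3d version (both norms carry one γ and one β per channel). The difference is behavioural: train and eval modes give the same output, whereas BatchNorm3d with batch=1 switches from per-sample statistics in training to running averages at inference.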

2. Down3D — encoder stage

Down3D: 2³ max-pool then DoubleConv3D
🐍unet3d/down.py
1class Down3D(nn.Module)

One encoder stage. The UNet3D below stacks 4 of these (five resolution levels), mirroring the 2-D U-Net. Çiçek 2016 (§2, Network Architecture) uses four resolution steps, i.e. three pooling stages.

9def __init__(self, in_channels, out_channels)

in_channels = previous stage's out_channels (e.g. 32). out_channels = roughly 2× in_channels following the doubling rule (e.g. 64).

EXECUTION STATE
in_channels = 32 (previous stage)
out_channels = 64 (channel-doubling per Çiçek 2016 §2)
12📚 nn.MaxPool3d(kernel_size=2)

📚 3-D max-pool with stride=2 (default = kernel_size). Halves D, H, AND W (each by factor 2). Channels are unchanged. Compared to 2-D MaxPool2d this throws away 7 of every 8 voxels (vs 3 of every 4 pixels in 2-D) — much more aggressive downsampling.

EXAMPLE
in=(1, 32, 16, 16, 16) → out=(1, 32, 8, 8, 8) — voxel count 4096 → 512 (8× reduction)
13DoubleConv3D(in_channels, out_channels)

After the pool, run DoubleConv3D. Final stage output: (B, out_channels, D/2, H/2, W/2).

16forward

Just hands x to the Sequential. The Sequential applies max-pool then DoubleConv3D in order.

20Sanity tensor (1, 32, 16, 16, 16)

Same input as DoubleConv3D's sanity tensor. Output should have 2× the channels and half the spatial extent in every axis.

22Stage forward

Expect: Down3D: in= (1, 32, 16, 16, 16) → out= (1, 64, 8, 8, 8). Voxel count drops 16³=4096 → 8³=512 (8× reduction); channels go 32 → 64 (2× increase). Net activation memory drops by 4×, freeing GPU memory for the deeper stages.

1class Down3D(nn.Module):
2    """Encoder stage: 2x2x2 max-pool then DoubleConv3D.
3
4    D, H, and W all halve at the max-pool; channels then change inside DoubleConv3D
5    (the caller passes the counts, usually doubling them). Çiçek 2016 instead
6    doubles channels before each pool, to avoid representational bottlenecks.
7    """
8
9    def __init__(self, in_channels: int, out_channels: int):
10        super().__init__()
11        self.down = nn.Sequential(
12            nn.MaxPool3d(kernel_size=2),
13            DoubleConv3D(in_channels, out_channels),
14        )
15
16    def forward(self, x: torch.Tensor) -> torch.Tensor:
17        return self.down(x)
18
19# Sanity
20x = torch.randn(1, 32, 16, 16, 16)
21stage = Down3D(32, 64)
22y = stage(x)
23print("Down3D: in=", tuple(x.shape), "→ out=", tuple(y.shape))   # expect (1, 64, 8, 8, 8)
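The 4× memory drop quoted for one Down3D stage can be tabulated for the whole encoder. A rough sketch, assuming fp32 activations, a single-channel 128³ patch and the trilinear=True channel widths of the UNet3D listed later; it sizes one output tensor per stage, not everything autograd keeps. The helper name `feature_map_mib` is hypothetical.

```python
# Rough fp32 size of each encoder stage's output feature map in the UNet3D
# listed later (trilinear=True widths), for one single-channel 128^3 patch.
BYTES_PER_FLOAT32 = 4

def feature_map_mib(channels: int, edge: int, batch: int = 1) -> float:
    """Size in MiB of one (batch, channels, edge, edge, edge) float32 tensor."""
    return batch * channels * edge ** 3 * BYTES_PER_FLOAT32 / 2 ** 20

stages = [("inc", 32, 128), ("down1", 64, 64), ("down2", 128, 32),
          ("down3", 256, 16), ("down4", 256, 8)]
for name, channels, edge in stages:
    mib = feature_map_mib(channels, edge)
    print(f"{name:6s} {channels:3d} ch x {edge:3d}^3: {mib:7.2f} MiB")
```

Each stage that doubles channels shrinks the map by 4× (256 → 64 → 16 → 4 MiB); down4 keeps 256 channels in the trilinear configuration, so it shrinks by the full 8×. Almost all activation memory therefore sits at the first, full-resolution level.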

3. Up3D — decoder stage

Up3D: 3D upsample → concat skip → DoubleConv3D
🐍unet3d/up.py
1import torch.nn.functional as F

Functional API for stateless ops like F.pad. Distinct from torch.nn (modules with state).

3class Up3D(nn.Module)

Decoder stage. Three sub-operations in order: upsample, concat skip, double-conv. The whole thing is the 3-D analogue of Up from the 2-D section earlier.

8def __init__(self, in_channels, out_channels, trilinear=True)

trilinear=True: cheap non-learned interpolation, immune to checkerboard artefacts. trilinear=False: learned 2×2×2 transposed conv as in the original Çiçek 2016 paper.

EXECUTION STATE
trilinear = True (default)
11📚 nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)

📚 Non-learned 3-D interpolation. mode='trilinear' interpolates from the 8 corners of the enclosing voxel cube (vs the 4 corners of a square for bilinear in 2-D). It doubles every spatial axis and leaves the channel count unchanged. align_corners=True aligns the centres of the corner voxels of the input and output grids, so the corner values are reproduced exactly.

15📚 nn.ConvTranspose3d(in_channels, in_channels // 2, kernel_size=2, stride=2)

📚 Learned 3-D upsample. Stride=2 with kernel=2 doubles D, H, W exactly, and the layer halves the channel count (in_channels → in_channels // 2) so that, after concatenation with the skip, the total is back to in_channels. The original 3-D U-Net paper uses this learned up-convolution. Risk: 3-D checkerboard artefacts (cube-shaped, not just square); see §10.6 for the 2-D analysis, which carries over to each axis unchanged.

19def forward(self, x, skip)

Two inputs: x is the decoder feature map from below; skip is the encoder feature map at the matching level (passed in by the parent UNet3D).

EXECUTION STATE
x =
(B, in_channels // 2, D/2, H/2, W/2) with trilinear=True, or (B, in_channels, D/2, H/2, W/2) with trilinear=False: the lower-resolution decoder map
skip =
(B, in_channels // 2, D, H, W): encoder skip at this level
22diff_d, diff_h, diff_w: three off-by-one safety nets

Three diffs vs the 2-D version's two. With even sizes at every level, all diffs are 0. If an axis had odd size before an encoder pool, max-pool floors it and the upsampled map comes back one voxel short, so that diff is 1. We pad x before the concat: same pattern as 2-D, one more axis.

EXECUTION STATE
diff_d = 0 (typical) or 1 (if D was odd before the pool)
diff_h = 0 typical
diff_w = 0 typical
25📚 F.pad(x, [w_left, w_right, h_top, h_bottom, d_front, d_back])

📚 Pads the tensor with zeros. PyTorch's pad list runs in REVERSE axis order: width first, then height, then depth. Six entries for 3-D vs four for 2-D. Splitting each diff as diff // 2 before and diff − diff // 2 after keeps the padding as symmetric as possible.

28📚 torch.cat([skip, x], dim=1)

📚 Channel-axis concatenation. dim=1 is channels in PyTorch's NCDHW layout (N=batch, C=channels, D=depth, H=height, W=width). The result has skip.channels + x.channels channels, which must equal in_channels.

29return self.conv(x)

DoubleConv3D on the fused tensor. The 3×3×3 convs inside learn how to mix decoder semantics with encoder skip details.

32Sanity tensors

x: decoder map at the lower resolution (8³ voxels, 64 channels = in_channels // 2, because trilinear upsampling keeps the channel count). skip: encoder map at the matching higher resolution (16³ voxels, 64 channels). The expected output is (1, 64, 16, 16, 16): same spatial size as skip, with out_channels = 64 channels.

35Stage forward

Up3D applies upsample (8³ → 16³, still 64 channels) → channel concat (64 + 64 = 128 = in_channels) → DoubleConv3D down to 64 channels.

1import torch.nn.functional as F
2
3class Up3D(nn.Module):
4    """Decoder stage: 2x2x2 up-conv (or trilinear upsample), concat with skip,
5    then DoubleConv3D.
6    """
7
8    def __init__(self, in_channels: int, out_channels: int, trilinear: bool = True):
9        super().__init__()
10        if trilinear:
11            self.up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)
12            self.conv = DoubleConv3D(in_channels, out_channels)
13        else:
14            # Çiçek 2016: 2×2×2 transposed conv, stride 2 per axis; halves channels
15            self.up = nn.ConvTranspose3d(in_channels, in_channels // 2,
16                                         kernel_size=2, stride=2)
17            self.conv = DoubleConv3D(in_channels, out_channels)
18
19    def forward(self, x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
20        x = self.up(x)
21        # Pad in case any spatial axis was odd before the encoder pool
22        diff_d = skip.size(2) - x.size(2)
23        diff_h = skip.size(3) - x.size(3)
24        diff_w = skip.size(4) - x.size(4)
25        x = F.pad(x, [diff_w // 2, diff_w - diff_w // 2,
26                      diff_h // 2, diff_h - diff_h // 2,
27                      diff_d // 2, diff_d - diff_d // 2])
28        x = torch.cat([skip, x], dim=1)   # concat along channel axis
29        return self.conv(x)
30
31# Sanity
32x  = torch.randn(1, 64, 8, 8, 8)     # trilinear keeps channels: x carries in_channels // 2
33sk = torch.randn(1, 64, 16, 16, 16)
34stage = Up3D(128, 64, trilinear=True)
35y = stage(x, sk)
36print("Up3D: x=", tuple(x.shape), " skip=", tuple(sk.shape), " → out=", tuple(y.shape))
37# expected: (1, 64, 16, 16, 16)
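To see the off-by-one case that Up3D's F.pad guards against, the sketch below pushes an odd-sized skip through the same max-pool and trilinear upsample, then applies the same padding recipe. The shapes (15, 16, 17) and the 8 channels are arbitrary illustrative choices.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

pool = nn.MaxPool3d(kernel_size=2)
up = nn.Upsample(scale_factor=2, mode="trilinear", align_corners=True)

skip = torch.randn(1, 8, 15, 16, 17)       # odd D and W on purpose
x_up = up(pool(skip))                       # D: 15 -> 7 -> 14, H: 16 -> 8 -> 16, W: 17 -> 8 -> 16
print(tuple(x_up.shape))                    # (1, 8, 14, 16, 16)

# Same recipe as Up3D.forward: diffs per axis, pad list in reverse axis order (W, H, D).
diff_d, diff_h, diff_w = (skip.size(i) - x_up.size(i) for i in (2, 3, 4))
x_pad = F.pad(x_up, [diff_w // 2, diff_w - diff_w // 2,
                     diff_h // 2, diff_h - diff_h // 2,
                     diff_d // 2, diff_d - diff_d // 2])
print(tuple(x_pad.shape))                   # (1, 8, 15, 16, 17): now concat-compatible with skip
```

Max-pool floors odd sizes, so the pool/upsample round trip loses one voxel on each odd axis; without the pad, torch.cat would fail on mismatched spatial sizes.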

4. UNet3D — full module

UNet3D: end-to-end volumetric network
🐍unet3d/unet3d.py
1class UNet3D(nn.Module)

Full network. Same topology as the 2-D UNet from earlier in this section: 5 encoder-side blocks (inc + 4 Down3Ds), 4 decoder-side blocks (Up3Ds), final 1×1×1 Conv3d to n_classes.

17def __init__(self, n_channels=1, n_classes=3, trilinear=True)

n_channels: 1 for single-modality input (e.g. CT or one MRI sequence); 4 for multi-modal BraTS (T1, T1ce, T2, FLAIR stacked along the channel axis). n_classes: for BraTS, 3 if you predict the three overlapping tumour regions (whole tumour / core / enhancing) with one sigmoid per channel, or 4 if you predict the label map (background + 3 tumour labels) with softmax. The Çiçek 2016 paper used n_classes=3 (inside-tubule / tubule / background).

EXECUTION STATE
n_channels = 1 default (single-modality CT or MRI)
n_classes = 3 default (per the Çiçek 2016 Xenopus kidney setup)
20self.inc = DoubleConv3D(n_channels, 32)

First block at full input resolution; produces 32 channels. Starting at 32 channels (not 64 as in the 2-D U-Net) follows Çiçek 2016 and keeps the full-resolution activations, the largest in the network, affordable: a 3-D feature map has an extra spatial axis, so each channel costs far more memory. The paper reports 19,069,955 parameters for its own network (§2); this implementation has a different depth and channel layout, so its count differs.

21self.down1 = Down3D(32, 64)

First encoder stage. Output: (B, 64, D/2, H/2, W/2). Same channel-doubling rule as 2-D, applied to D/H/W simultaneously (8× voxel reduction).

24factor = 2 if trilinear else 1

Same trick as in the 2-D version. Trilinear upsampling keeps the channel count, so the bottleneck is halved to 256 channels: upsampled x5 (256) + skip x4 (256) = 512, exactly the in_channels that Up3D(512, ...) expects.

EXECUTION STATE
factor = 2 (with trilinear=True)
25self.down4 = Down3D(256, 512 // factor)

Bottleneck. With trilinear=True: (B, 256, D/16, H/16, W/16). With trilinear=False: (B, 512, D/16, H/16, W/16).

27self.up1 = Up3D(512, 256 // factor, trilinear)

First decoder stage. Up3D's first arg is the SUM of decoder + skip channel counts after concat. With trilinear=True: x5 has 256 ch, x4 has 256 ch → after upsample-then-concat we have 512 ch, conv produces 128 ch.

32self.outc = nn.Conv3d(32, n_classes, kernel_size=1)

Final 1×1×1 convolution = per-voxel linear projection to n_classes logit channels. No activation — caller applies softmax (multi-class) or sigmoid (binary) inside the loss. The Çiçek 2016 paper uses weighted softmax loss; see the next section.

34def forward(self, x)

Save every encoder activation: x1..x4 become the skips for the decoder Up3D calls, and x5 is the bottleneck that feeds up1. Identical shape book-keeping to the 2-D UNet, with one extra spatial axis everywhere.

EXECUTION STATE
x1 =
(B, 32, D, H, W) — full-resolution skip → up4
x2 =
(B, 64, D/2, H/2, W/2) → up3
x3 =
(B, 128, D/4, H/4, W/4) → up2
x4 =
(B, 256, D/8, H/8, W/8) → up1
x5 =
(B, 256 or 512, D/16, H/16, W/16) — bottleneck
47Smoke test on (1, 1, 128, 128, 128)

128³ is a typical 3-D U-Net training patch (e.g. nnU-Net default for many BraTS-style problems). Activation memory at the input layer alone: 1·32·128³·4 bytes ≈ 256 MB — already a substantial fraction of GPU memory. This is why batch=1 patch-based training is the norm in 3-D segmentation.

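53Trainable params print

With trilinear=True, n_channels=1, n_classes=3 this prints about 10.04M (10,041,155 parameters by direct count), roughly half the 19,069,955 reported in Çiçek 2016 §2. The gap is expected: the paper's network has three pooling stages, doubles channels before each pool, and upsamples with transposed convolutions, whereas this implementation has four Down3D stages, doubles channels after each pool, and halves the bottleneck for the trilinear path.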

1class UNet3D(nn.Module):
2    """Full 3-D U-Net for volumetric semantic segmentation.
3
4    Channel widths borrow the Çiçek 2016 sequence 32 → 64 → 128 → 256 → 512, but
5    with four pools and doubling after each pool (bottleneck is 256 if trilinear).
6
7    Args
8    ----
9    n_channels : input channels (1 for single-modality MRI/CT,
10                 4 for multi-modal BraTS = T1, T1ce, T2, FLAIR)
11    n_classes  : output channels (number of segmentation classes;
12                 BraTS uses 4 = background + 3 tumour subregions)
13    trilinear  : True → trilinear upsample (cheaper, no checkerboard);
14                 False → learned ConvTranspose3d (Çiçek 2016 default)
15    """
16
17    def __init__(self, n_channels: int = 1, n_classes: int = 3, trilinear: bool = True):
18        super().__init__()
19        # encoder
20        self.inc   = DoubleConv3D(n_channels, 32)
21        self.down1 = Down3D(32, 64)
22        self.down2 = Down3D(64, 128)
23        self.down3 = Down3D(128, 256)
24        factor = 2 if trilinear else 1
25        self.down4 = Down3D(256, 512 // factor)
26        # decoder
27        self.up1 = Up3D(512, 256 // factor, trilinear)
28        self.up2 = Up3D(256, 128 // factor, trilinear)
29        self.up3 = Up3D(128, 64 // factor, trilinear)
30        self.up4 = Up3D(64, 32, trilinear)
31        # final 1x1x1 conv to n_classes channels (Çiçek 2016 §2, last paragraph)
32        self.outc = nn.Conv3d(32, n_classes, kernel_size=1)
33
34    def forward(self, x: torch.Tensor) -> torch.Tensor:
35        x1 = self.inc(x)        # (B,  32, D,    H,    W)
36        x2 = self.down1(x1)     # (B,  64, D/2,  H/2,  W/2)
37        x3 = self.down2(x2)     # (B, 128, D/4,  H/4,  W/4)
38        x4 = self.down3(x3)     # (B, 256, D/8,  H/8,  W/8)
39        x5 = self.down4(x4)     # (B, 256 or 512, D/16, H/16, W/16)  ← bottleneck
40
41        x = self.up1(x5, x4)    # (B, 128 or 256, D/8,  H/8,  W/8)
42        x = self.up2(x,  x3)    # (B,  64 or 128, D/4,  H/4,  W/4)
43        x = self.up3(x,  x2)    # (B,  32 or  64, D/2,  H/2,  W/2)
44        x = self.up4(x,  x1)    # (B,  32,        D,    H,    W)
45        return self.outc(x)     # (B, n_classes,  D,    H,    W)
46
47# Smoke test on a (1, 1, 128, 128, 128) volume — typical patch size
48model = UNet3D(n_channels=1, n_classes=3, trilinear=True)
49x = torch.randn(1, 1, 128, 128, 128)
50y = model(x)
51print("UNet3D: in=", tuple(x.shape), "→ logits=", tuple(y.shape))
52n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
53print(f"Trainable params: {n_params / 1e6:.2f}M")
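The printed parameter total can be checked by hand. A sketch of the arithmetic, assuming the listing's trilinear=True configuration: each DoubleConv3D(c_in, c_out) holds 27·c_in·c_out + 27·c_out² conv weights (no biases) plus 4·c_out BatchNorm parameters. The helper names are hypothetical.

```python
def double_conv3d_params(cin: int, cout: int) -> int:
    """Two bias-free 3x3x3 convs plus two BatchNorm3d layers (gamma and beta each)."""
    return 27 * cin * cout + 27 * cout * cout + 2 * (2 * cout)

def unet3d_trilinear_params(n_channels: int = 1, n_classes: int = 3) -> int:
    """Parameter count of the UNet3D listing with trilinear=True (no ConvTranspose3d)."""
    encoder = [(n_channels, 32), (32, 64), (64, 128), (128, 256), (256, 256)]
    decoder = [(512, 128), (256, 64), (128, 32), (64, 32)]   # (concat width, out)
    blocks = sum(double_conv3d_params(i, o) for i, o in encoder + decoder)
    return blocks + 32 * n_classes + n_classes                # final 1x1x1 conv with bias

total = unet3d_trilinear_params()
print(f"{total:,} params = {total / 1e6:.2f}M")              # 10,041,155 params = 10.04M

deep = (double_conv3d_params(128, 256) + double_conv3d_params(256, 256)
        + double_conv3d_params(512, 128))                     # down3, down4, up1
print(f"share in down3 + down4 + up1: {deep / total:.0%}")   # 84%
```

Parameters and activations concentrate at opposite ends of the U: three low-resolution blocks hold about 84% of the weights, while activation memory is dominated by the full-resolution blocks.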
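GPU memory budget. A single forward pass of the model above on a (1, 1, 128, 128, 128) input keeps several gigabytes of activations alive for backprop. Every 32-channel 128³ feature map is 256 MB in fp32, and by a rough count the two full-resolution blocks (inc and up4) alone store about ten such maps for the backward pass. Memory grows linearly with batch size and with patch volume, so a 16 GB card fills up quickly. Two routine fixes: (1) patch-based training: sample 96³ or 128³ patches from the full volume rather than feeding the whole scan; (2) gradient checkpointing: trade compute for memory by recomputing intermediate activations during the backward pass instead of storing them. Both are standard in modern 3D segmentation pipelines; nnU-Net, for example, chooses patch and batch size automatically to fit a fixed GPU memory budget.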
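Gradient checkpointing is available in PyTorch as torch.utils.checkpoint.checkpoint. A minimal sketch on a small stand-in block (not the UNet3D above) showing that the gradients match a plain backward pass. In a real UNet3D one might wrap the full-resolution blocks such as self.inc and self.up4; that choice is an assumption, not a recipe. The block is deliberately norm-free, because re-running the forward pass during backward would update BatchNorm running statistics a second time.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Stand-in for one full-resolution block; norm-free so recomputation has no side effects.
body = nn.Sequential(
    nn.Conv3d(1, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv3d(8, 8, kernel_size=3, padding=1), nn.ReLU(),
)
x = torch.randn(1, 1, 16, 16, 16)

# Plain: autograd stores every intermediate activation inside `body`.
body(x).sum().backward()
grad_plain = body[0].weight.grad.clone()
body.zero_grad()

# Checkpointed: intermediates are not stored; they are recomputed during backward.
checkpoint(body, x, use_reentrant=False).sum().backward()
grad_ckpt = body[0].weight.grad.clone()

print(torch.allclose(grad_plain, grad_ckpt))   # True: same gradients, less stored memory
```

The price is one extra forward pass through each wrapped block per training step.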

Sparse-annotation loss & 3D Dice

The Çiçek 2016 paper's key training contribution is not the 3D extension — that is mechanical — but the loss it uses to train from sparse 2D annotations on 3D volumes. Annotating every voxel of a 132×132×116 microscopy volume is impractical; annotating a few orthogonal xy/xz/yz slices is tractable. The trick: a per-voxel weighted softmax with the weight set to zero on unlabelled voxels:

$$\mathcal{L} = -\sum_{\mathbf{x} \in \Omega} w(\mathbf{x})\,\log p_{\ell(\mathbf{x})}(\mathbf{x}), \qquad w(\mathbf{x}) = 0 \text{ when } \mathbf{x} \text{ is unlabelled}$$

where $\Omega$ is the voxel grid, $p_c(\mathbf{x})$ is the softmax probability for class $c$ at voxel $\mathbf{x}$, $\ell(\mathbf{x})$ is the ground-truth label, and $w(\mathbf{x}) \geq 0$ is the per-voxel weight. Setting $w = 0$ on unlabelled voxels makes them invisible to the gradient, so the network learns only from the annotated slices but predicts on the whole volume at inference. Setting $w > 1$ on rare classes (small tumours) up-weights them: the same idea as class re-balancing in 2D, lifted to 3D.
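A minimal PyTorch sketch of this loss, assuming labels arrive as an integer volume plus a boolean mask of annotated voxels. The function name `sparse_weighted_ce` is hypothetical, and normalising by the total weight (rather than the plain sum in the formula) is an assumed choice that keeps the loss scale independent of how many slices were annotated.

```python
import torch
import torch.nn.functional as F

def sparse_weighted_ce(logits, labels, labelled, class_weights=None):
    """Weighted softmax cross-entropy with w(x) = 0 on unlabelled voxels.

    logits        : (B, C, D, H, W) raw class scores
    labels        : (B, D, H, W) int64; any valid class index where unlabelled
    labelled      : (B, D, H, W) bool, True where the voxel was annotated
    class_weights : optional (C,) tensor; entries > 1 up-weight rare classes
    """
    nll = F.cross_entropy(logits, labels, reduction="none")  # -log p_label(x), per voxel
    w = labelled.float()                                     # w(x) = 0 where unlabelled
    if class_weights is not None:
        w = w * class_weights[labels]
    return (w * nll).sum() / w.sum().clamp_min(1e-8)

torch.manual_seed(0)
logits = torch.randn(1, 3, 4, 4, 4, requires_grad=True)
labels = torch.randint(0, 3, (1, 4, 4, 4))
labelled = torch.zeros(1, 4, 4, 4, dtype=torch.bool)
labelled[:, 1] = True                                  # only depth slice 1 was annotated

loss = sparse_weighted_ce(logits, labels, labelled)
loss.backward()
print(logits.grad[:, :, 0].abs().max().item())         # 0.0: unlabelled slice gets no gradient

# Unweighted case: the built-in ignore_index route gives the same value.
ref = F.cross_entropy(logits, labels.masked_fill(~labelled, -100), ignore_index=-100)
print(torch.allclose(loss, ref))                       # True
```

For the unweighted case, F.cross_entropy's ignore_index is the simpler route; the explicit weight tensor becomes useful once per-class or per-voxel weights are added on top.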

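For fully-annotated volumetric tasks (BraTS, KiTS) the dominant loss is the 3D soft Dice loss popularised by Milletari et al. 2016 (V-Net). The form below uses plain sums in the denominator; V-Net's original formulation squares $p$ and $g$ there: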

$$\mathcal{L}_{\text{Dice}}^{(3D)} = 1 - \frac{2 \sum_{\mathbf{x}} p(\mathbf{x})\,g(\mathbf{x}) + \varepsilon}{\sum_{\mathbf{x}} p(\mathbf{x}) + \sum_{\mathbf{x}} g(\mathbf{x}) + \varepsilon}$$

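where $p(\mathbf{x})$ is the predicted foreground probability at voxel $\mathbf{x}$, $g(\mathbf{x}) \in \{0, 1\}$ is the ground-truth mask, and $\varepsilon$ is a small smoothing constant. This is identical to the 2D Dice from the loss-functions discussion above except that the sums run over voxels (3D) rather than pixels (2D); for multi-class problems it is typically computed per class and averaged. Production pipelines almost always use Dice + cross-entropy as a sum: cross-entropy gives a clean per-voxel signal early in training, and Dice gives the shape-level signal that handles the severe foreground/background imbalance typical of 3D tumour volumes (a glioma is often < 1% of brain voxels).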
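A sketch of the Dice + CE sum in PyTorch, under stated assumptions: softmax probabilities, one Dice score per class pooled over the whole batch (a "batch Dice"; per-sample Dice is an equally common choice), classes averaged with equal weight, and ε = 1e-5. The function names are hypothetical.

```python
import torch
import torch.nn.functional as F

def soft_dice_loss(probs, target_onehot, eps=1e-5):
    """1 - soft Dice per class, averaged over classes.

    probs, target_onehot : (B, C, D, H, W). Sums run over the batch and all voxels.
    """
    dims = (0, 2, 3, 4)
    inter = (probs * target_onehot).sum(dims)               # sum_x p(x) g(x), per class
    denom = probs.sum(dims) + target_onehot.sum(dims)       # sum_x p(x) + sum_x g(x)
    dice = (2 * inter + eps) / (denom + eps)
    return 1 - dice.mean()

def dice_ce_loss(logits, labels):
    """Dice + cross-entropy on logits (B, C, D, H, W) and int64 labels (B, D, H, W)."""
    n_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    onehot = F.one_hot(labels, n_classes).permute(0, 4, 1, 2, 3).float()
    return soft_dice_loss(probs, onehot) + F.cross_entropy(logits, labels)

torch.manual_seed(0)
labels = torch.randint(0, 3, (2, 8, 8, 8))
onehot = F.one_hot(labels, 3).permute(0, 4, 1, 2, 3).float()
print(soft_dice_loss(onehot, onehot).item())                     # 0.0 for a perfect prediction
print(dice_ce_loss(torch.zeros(2, 3, 8, 8, 8), labels).item())   # uninformative logits: well above 0
```

F.one_hot puts the class axis last, hence the permute back to the NCDHW layout that the softmax output uses.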


3D U-Net in production (BraTS, KiTS, nnU-Net)

3D U-Net is the workhorse of clinical-grade volumetric segmentation. The two patterns below reappear across essentially every public 3D segmentation benchmark since 2018: (1) a U-Net-shaped 3D backbone with concat skips, (2) Dice + CE loss with patch-based training. Variants of nnU-Net, which auto-configures both, have topped many public leaderboards.

| Benchmark / dataset | Task | Why 3D U-Net (and what wins) |
|---|---|---|
| BraTS (Brain Tumour Segmentation, 2012–present) | Multi-modal MRI segmentation of glioma into 3 sub-regions: whole tumour, tumour core, enhancing tumour | 4-channel input (T1, T1ce, T2, FLAIR). 3D context is critical because lesions span tens of slices. Most recent winners are nnU-Net variants. (Menze et al. 2015, IEEE TMI; Bakas et al. 2018) |
| KiTS19 / KiTS21 (Kidney Tumor) | Kidney + tumour + cyst segmentation in contrast-enhanced abdominal CT | 3D shape prior of the kidney is strong; 2D slice-by-slice loses both ends of the organ. Top KiTS19 entry was an ensembled 3D U-Net (Heller et al. 2021, Med. Image Anal.) |
| LiTS (Liver Tumor Segmentation) | Liver + lesion segmentation in abdominal CT | Tumours often span 5–20 slices and have weak 2D contrast; 3D conv recovers them. (Bilic et al. 2023, Med. Image Anal.) |
| Medical Segmentation Decathlon (MSD) | 10 different 3D tasks: brain, liver, hippocampus, lung, prostate, pancreas, hepatic vessels, spleen, colon, cardiac | The MSD was won by nnU-Net out-of-the-box on 7/10 tasks, demonstrating that a single self-configuring 3D U-Net pipeline beats task-specific custom networks. (Antonelli et al. 2022, Nat. Comm.) |
| Industrial CT (defect inspection) | Cracks, voids, inclusions in cast metal parts scanned with industrial CT | Same algorithm; medical pre-trained checkpoints can transfer well. 3D context separates roughly spherical voids (real defects) from ring artefacts (acquisition flaws). |
| Cryo-electron tomography | Particle picking / membrane segmentation in cellular tomograms | Tomograms have very low signal-to-noise and dense voxel annotation is infeasible, so 3D U-Net with a sparse-annotation loss (Çiçek-style) is well-suited. |
| Seismic interpretation | Salt-body / fault segmentation in 3D seismic volumes (energy industry) | The 3D structure of geological faults is fundamentally volumetric; the same 3D U-Net architecture is used unchanged on seismic amplitudes. |

Two practical pointers for anyone shipping a 3D U-Net:

  1. Patch sampling and overlap. Train on 96–128³ random patches; at inference use a sliding window with ~50% overlap and Gaussian-weighted averaging at patch borders. This is what nnU-Net does by default.
  2. Resampling to a canonical voxel size. Volumetric medical data has wildly varying voxel spacing across scanners. Resample every volume to a fixed mm-per-voxel (median of the training set) before training; reverse at inference. This single step often outweighs architecture tweaks.
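Both pointers can be sketched in a few lines of PyTorch. The code below assumes a model that maps (1, C_in, d, h, w) patches to (1, n_classes, d, h, w) logits, a volume at least one patch large on every axis, and a Gaussian whose sigma is 1/8 of the patch edge; these are illustrative choices, not nnU-Net's exact implementation, and all function names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def gaussian_weights(patch, sigma_scale=1 / 8):
    """Separable Gaussian over a patch, peak 1 at the centre; sigma = sigma_scale * edge."""
    w = torch.ones(patch)
    for axis, n in enumerate(patch):
        c = torch.arange(n, dtype=torch.float32) - (n - 1) / 2
        g = torch.exp(-0.5 * (c / (sigma_scale * n)) ** 2)
        shape = [1, 1, 1]
        shape[axis] = n
        w = w * g.view(shape)
    return w / w.max()

def window_starts(size, patch, step):
    """Start indices covering [0, size); the last window is flush with the end."""
    starts = list(range(0, size - patch + 1, step))
    if starts[-1] != size - patch:
        starts.append(size - patch)
    return starts

@torch.no_grad()
def sliding_window_predict(model, volume, n_classes, patch=(64, 64, 64), overlap=0.5):
    """Gaussian-weighted sliding-window inference.

    volume : (C_in, D, H, W), assumed at least `patch` along every axis (pad first if not).
    Returns (n_classes, D, H, W) weighted-average logits.
    """
    _, D, H, W = volume.shape
    weights = gaussian_weights(patch)
    steps = [max(1, int(p * (1 - overlap))) for p in patch]
    out = torch.zeros(n_classes, D, H, W)
    norm = torch.zeros(D, H, W)
    pd, ph, pw = patch
    for z in window_starts(D, pd, steps[0]):
        for y in window_starts(H, ph, steps[1]):
            for x in window_starts(W, pw, steps[2]):
                crop = volume[:, z:z + pd, y:y + ph, x:x + pw]
                logits = model(crop[None])[0]               # (n_classes, pd, ph, pw)
                out[:, z:z + pd, y:y + ph, x:x + pw] += logits * weights
                norm[z:z + pd, y:y + ph, x:x + pw] += weights
    return out / norm                                       # border voxels down-weighted

def resample_to_spacing(volume, spacing, target_spacing, mode="trilinear"):
    """Resample (C, D, H, W) from `spacing` to `target_spacing` (mm per voxel, d/h/w order).

    Use mode="nearest" (on a float copy) for label maps so class indices are not blended.
    """
    size = [max(1, round(n * s / t))
            for n, s, t in zip(volume.shape[1:], spacing, target_spacing)]
    kwargs = {"align_corners": False} if mode == "trilinear" else {}
    return F.interpolate(volume[None], size=size, mode=mode, **kwargs)[0]

# Usage with a toy per-voxel "network" (a 1x1x1 conv) standing in for UNet3D.
model = nn.Conv3d(1, 2, kernel_size=1).eval()
vol = torch.randn(1, 80, 72, 64)
pred = sliding_window_predict(model, vol, n_classes=2, patch=(32, 32, 32))
print(pred.shape)                                           # torch.Size([2, 80, 72, 64])
iso = resample_to_spacing(vol, spacing=(2.0, 1.0, 1.0), target_spacing=(1.0, 1.0, 1.0))
print(iso.shape)                                            # torch.Size([1, 160, 72, 64])
```

With a per-voxel toy model every window predicts the same value at a given voxel, so the weighted average reproduces a single full-volume pass exactly; with a real UNet3D the Gaussian weighting suppresses the less reliable predictions near patch borders.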

Other variants (Attention U-Net, U-Net++, …)

Beyond the 3D extension above, a decade of 2D U-Net variants has sharpened the network for specific failure modes — cluttered scenes, label scarcity, transformer backbones, and full clinical pipelines. None of them change the "encoder-decoder with concat skips" idea; each tweaks one piece:

| Variant | Year | What it changes | When to reach for it |
|---|---|---|---|
| Attention U-Net | 2018 | Adds attention gates on each skip connection to suppress irrelevant skip features | Cluttered backgrounds where the encoder skip carries noise the decoder must ignore |
| U-Net++ | 2018 | Replaces single skips with a nested grid of intermediate convs (the ++ in the name) | Deep supervision wanted; better gradient flow at intermediate decoder depths |
| TransUNet | 2021 | Replaces the bottleneck with a Vision Transformer (ViT) | Long-range context is critical (large organs, long roads). See Ch 18 for ViT. |
| nnU-Net | 2021 | Not a network change: a fully automated pipeline that picks U-Net hyperparams from data | Real clinical work. State-of-the-art on most medical-segmentation benchmarks out of the box. |
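To make the Attention U-Net row concrete, here is a simplified 3D additive attention gate in the spirit of Oktay et al. 2018. Assumption: the decoder gating signal has already been resampled to the skip's resolution (the paper computes the gate on a coarser grid and resamples the coefficients). The class name `AttentionGate3D` and the channel sizes are hypothetical.

```python
import torch
import torch.nn as nn

class AttentionGate3D(nn.Module):
    """Simplified additive attention gate on a skip connection.

    skip : encoder features (B, C_skip, D, H, W)
    gate : decoder features already resampled to the same (D, H, W)
    Returns the skip features scaled voxel-wise by a coefficient in (0, 1).
    """

    def __init__(self, c_skip: int, c_gate: int, c_inter: int):
        super().__init__()
        self.w_skip = nn.Conv3d(c_skip, c_inter, kernel_size=1, bias=False)
        self.w_gate = nn.Conv3d(c_gate, c_inter, kernel_size=1)
        self.psi = nn.Conv3d(c_inter, 1, kernel_size=1)

    def forward(self, skip: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        a = torch.relu(self.w_skip(skip) + self.w_gate(gate))  # additive attention
        alpha = torch.sigmoid(self.psi(a))                    # (B, 1, D, H, W)
        return skip * alpha                                   # broadcast over skip channels

ag = AttentionGate3D(c_skip=64, c_gate=64, c_inter=32)
skip = torch.randn(1, 64, 16, 16, 16)
gate = torch.randn(1, 64, 16, 16, 16)
out = ag(skip, gate)
print(tuple(out.shape))   # (1, 64, 16, 16, 16): same shape as the skip, values attenuated
```

In an Up3D-style decoder, the gated skip would take the place of skip in the torch.cat([skip, x], dim=1) call, with the upsampled x serving as the gate; the rest of the U-Net is unchanged.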

References

Primary paper. Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. In MICCAI 2015. arXiv:1505.04597.

Variants.

  • Çiçek, Ö., Abdulkadir, A., Lienkamp, S. S., Brox, T., & Ronneberger, O. (2016). 3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation. MICCAI 2016. arXiv:1606.06650. Source of the architecture that the 3D U-Net deep dive above adapts (4 resolution steps, channel doubling before max-pool, 19,069,955 params).
  • Milletari, F., Navab, N., & Ahmadi, S.-A. (2016). V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation. 3DV 2016. arXiv:1606.04797. — soft Dice loss for 3D segmentation; close design contemporary to 3D U-Net.
  • Szegedy, C., Vanhoucke, V., Ioffe, S., Shlens, J., & Wojna, Z. (2015). Rethinking the Inception Architecture for Computer Vision. arXiv:1512.00567. — source of the "avoid representational bottlenecks" rule cited in Çiçek 2016 §2 (channel doubling before pooling).
  • Oktay, O. et al. (2018). Attention U-Net: Learning Where to Look for the Pancreas. arXiv:1804.03999.
  • Zhou, Z., Siddiquee, M. M. R., Tajbakhsh, N., & Liang, J. (2018). UNet++: A Nested U-Net Architecture for Medical Image Segmentation. DLMIA. arXiv:1807.10165.
  • Chen, J. et al. (2021). TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation. arXiv:2102.04306.
  • Isensee, F., Jaeger, P. F., Kohl, S. A. A., Petersen, J., & Maier-Hein, K. H. (2021). nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature Methods 18, 203–211. DOI: 10.1038/s41592-020-01008-z.

Loss functions cited above.

  • Salehi, S. S. M., Erdogmus, D., & Gholipour, A. (2017). Tversky loss function for image segmentation using 3D fully convolutional deep networks. MLMI. arXiv:1706.05721.

3D segmentation benchmarks cited above.

  • Menze, B. H. et al. (2015). The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS). IEEE Transactions on Medical Imaging 34(10), 1993–2024. DOI: 10.1109/TMI.2014.2377694.
  • Bakas, S. et al. (2018). Identifying the Best Machine Learning Algorithms for Brain Tumor Segmentation, Progression Assessment, and Overall Survival Prediction in the BRATS Challenge. arXiv:1811.02629.
  • Heller, N. et al. (2021). The state of the art in kidney and kidney tumor segmentation in contrast-enhanced CT imaging: Results of the KiTS19 challenge. Medical Image Analysis 67, 101821. DOI: 10.1016/j.media.2020.101821.
  • Bilic, P. et al. (2023). The Liver Tumor Segmentation Benchmark (LiTS). Medical Image Analysis 84, 102680. DOI: 10.1016/j.media.2022.102680.
  • Antonelli, M. et al. (2022). The Medical Segmentation Decathlon. Nature Communications 13, 4128. DOI: 10.1038/s41467-022-30695-9.

Cross-references inside this book. Chapter 5 §4 (loss functions including Dice). Chapter 5 §5 (BatchNorm and GroupNorm). Chapter 10 §6 (transposed convolutions and the checkerboard artifact). Chapter 11 §5 (ResNet and add-style skip connections). Chapter 23 (Diffusion models — uses U-Net as the denoising backbone).
