Chapter 6

Per-Condition Z-Score Implementation

Per-Condition Normalization

What the Normaliser Must Do

Take a (B, T, F) sensor batch, take a (B, T) condition-label sequence, and return a (B, T, F) tensor where every cycle has been Z-scored against its own condition's pre-computed mean and std. Same shape in, same shape out. Differentiable. GPU-aware. State serialisable.

The Two-Line Math

For every cell $X_{b,t,f}$ with condition $c = \text{cond}_{b,t}$:

$$Z_{b,t,f} \;=\; \frac{X_{b,t,f} - \mu_{c,f}}{\sigma_{c,f} + \varepsilon}.$$

That is the entire operation. Three sources of complexity in production code: (1) efficiently looking up the per-(b, t) statistics, (2) handling float32 / GPU placement, (3) shipping the train-time statistics through to inference.
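
To make the formula concrete, here is a minimal sketch that translates it literally into nested loops and compares the result with a vectorised gather. The function name `zscore_loop`, the toy shapes and the random data are assumptions made for illustration; they are not part of the book's code.

```python
import numpy as np

def zscore_loop(X, cond_seq, means, stds, eps=1e-8):
    """Literal translation of the formula: one (b, t, f) cell at a time."""
    B, T, F = X.shape
    Z = np.empty_like(X, dtype=np.float64)
    for b in range(B):
        for t in range(T):
            c = cond_seq[b, t]                      # condition of this cycle
            for f in range(F):
                Z[b, t, f] = (X[b, t, f] - means[c, f]) / (stds[c, f] + eps)
    return Z

# Toy data (hypothetical values, small shapes so the loops stay cheap)
rng = np.random.default_rng(0)
B, T, F, n_conds = 2, 5, 3, 4
X = rng.normal(1000.0, 50.0, size=(B, T, F))
cond_seq = rng.integers(0, n_conds, size=(B, T))
means = rng.normal(1000.0, 30.0, size=(n_conds, F))
stds = np.abs(rng.normal(size=(n_conds, F))) * 5 + 1

Z_loop = zscore_loop(X, cond_seq, means, stds)
Z_vec = (X - means[cond_seq]) / (stds[cond_seq] + 1e-8)   # vectorised gather

print(np.allclose(Z_loop, Z_vec))   # True
```

The loop version is far too slow for real batches, but it is a useful oracle: any vectorised implementation should agree with it to floating-point tolerance.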

Python: A Reference Implementation

Three-line core; advanced indexing handles the per-(b,t) lookup
🐍per_condition_zscore.py
import numpy as np

Standard alias.

def per_condition_zscore(X, cond_seq, means, stds):

Pure-NumPy reference implementation. Three-line core (two gathers and one division); everything else is type signatures.

EXECUTION STATE
input: X (B, T, F) = Sensor windows
input: cond_seq (B, T) = Per-cycle condition labels (int64)
input: means (n_cond, F) = Per-condition mean per sensor (from §6.2 fit)
input: stds (n_cond, F) = Per-condition std per sensor
returns: (B, T, F) = Z-scored windows; per-condition normalised
mu = means[cond_seq]

ADVANCED INDEXING. means is shape (n_cond, F); cond_seq is shape (B, T). Result is shape (B, T, F) - for every (b, t) position we pulled the F-vector of means corresponding to that cycle's condition.

EXECUTION STATE
advanced indexing = tensor[index_array] performs gather. Output shape = index_array.shape + tensor.shape[1:]. Same in NumPy and PyTorch.
Example: cond_seq[0, 0] = 3 → mu[0, 0] = means[3] - the mean vector for condition 3
sigma = stds[cond_seq]

Same gather on the stds tensor.

return (X - mu) / (sigma + 1e-8)

Element-wise normalisation. The 1e-8 epsilon prevents division by zero in the (rare) case that some condition has zero variance for some sensor.

EXECUTION STATE
+ 1e-8 = Numerical safety. Standard idiom in BatchNorm and LayerNorm too.
np.random.seed(0)

Determinism for the demo.

B, T, F, n_conds = 4, 30, 14, 6

Real C-MAPSS shapes - 14 informative sensors after §5.3 selection.

X = np.random.randn(B, T, F).astype(np.float32) * 50 + 1000

Fake input around 1000 with std 50.

cond = np.random.randint(0, n_conds, (B, T)).astype(np.int64)

Random per-cycle condition labels in [0, 6).

means = np.random.randn(n_conds, F).astype(np.float32) * 30 + 1000

Fake per-cond means around 1000.

stds = (np.abs(np.random.randn(n_conds, F)) * 5 + 1).astype(np.float32)

Fake per-cond stds, always >= 1 and mostly below about 16. abs() keeps them positive.

X_norm = per_condition_zscore(X, cond, means, stds)

Apply.

print("X.shape :", X.shape)

Sanity check.

EXECUTION STATE
Output = X.shape : (4, 30, 14)
print("X_norm.shape :", X_norm.shape)

Same shape as input.

EXECUTION STATE
Output = X_norm.shape : (4, 30, 14)
print("X_norm[0,0,:3]:", X_norm[0, 0, :3].round(4).tolist())

First sample, first cycle, first 3 sensors. The values depend on the random seed; only the shape invariant is guaranteed.

EXECUTION STATE
Output = Seed-dependent. Because the demo means/stds are random rather than fitted to X, the values are not unit-scale and often have magnitudes well above 1.
→ on real data = When means/stds are fitted on the same training data as X, X_norm.mean ~ 0 and X_norm.std ~ 1 within each condition.
import numpy as np

def per_condition_zscore(
    X: np.ndarray,         # (B, T, F)
    cond_seq: np.ndarray,  # (B, T) integer condition IDs
    means: np.ndarray,     # (n_conditions, F)
    stds: np.ndarray,      # (n_conditions, F)
) -> np.ndarray:
    """For every (b, t) cell, look up the per-condition stats by cond_seq[b, t]
    and apply (X - mu) / (sigma + eps)."""
    # Advanced indexing: means[cond_seq] has shape (B, T, F)
    mu    = means[cond_seq]
    sigma = stds[cond_seq]
    return (X - mu) / (sigma + 1e-8)


# ----- Quick verification -----
np.random.seed(0)
B, T, F, n_conds = 4, 30, 14, 6

X       = np.random.randn(B, T, F).astype(np.float32) * 50 + 1000
cond    = np.random.randint(0, n_conds, (B, T)).astype(np.int64)
means   = np.random.randn(n_conds, F).astype(np.float32) * 30 + 1000
stds    = (np.abs(np.random.randn(n_conds, F)) * 5 + 1).astype(np.float32)

X_norm  = per_condition_zscore(X, cond, means, stds)

print("X.shape       :", X.shape)
print("X_norm.shape  :", X_norm.shape)
print("X_norm[0,0,:3]:", X_norm[0, 0, :3].round(4).tolist())

# X.shape       : (4, 30, 14)
# X_norm.shape  : (4, 30, 14)
# X_norm[0,0,:3]: seed-dependent; NOT unit-variance here, because the demo
#                 means/stds are random rather than fitted to X
Advanced indexing is the trick. means[cond_seq] performs a gather: for every (b, t) position it picks the F-vector at means[cond_seq[b, t]]. The result has shape (B, T, F) - exactly what we need for broadcasting against X.
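
A tiny hand-checkable example can make the gather easier to see. The arrays below are made-up values chosen so that each row of `means` is recognisable in the output; they are an illustration only.

```python
import numpy as np

# Three conditions, two sensors (hypothetical numbers)
means = np.array([[10.0, 20.0],    # condition 0
                  [30.0, 40.0],    # condition 1
                  [50.0, 60.0]])   # condition 2

# B=2 windows, T=2 cycles each
cond_seq = np.array([[0, 2],
                     [1, 1]])

mu = means[cond_seq]   # gather: one F-vector per (b, t) position

print(mu.shape)    # (2, 2, 2)  = cond_seq.shape + means.shape[1:]
print(mu[0, 1])    # [50. 60.]  -> cond_seq[0, 1] is 2, so we get means[2]
```

The same result can be obtained with `np.take(means, cond_seq, axis=0)`; the bracket form is simply shorter.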

PyTorch: As an nn.Module With Buffers

Same algorithm; buffers + .to(device) + state_dict
🐍per_condition_normaliser.py
import numpy as np

For loading the means/stds.

import torch

Tensor type.

import torch.nn as nn

Module base class.

class PerConditionNormaliser(nn.Module):

The reference implementation. Promoted to a Module because we want it to (a) move to GPU with the rest of the model, (b) save in state_dict, (c) compose with other modules inside a parent model. Note that nn.Sequential passes a single input from layer to layer, so this two-argument forward needs a small wrapper before it can sit inside one.

def __init__(self, means, stds):

Constructor accepts the (n_conds, F) NumPy arrays from §6.2's bundle.

EXECUTION STATE
input: means (n_cond, F) = Per-condition sensor means
input: stds (n_cond, F) = Per-condition sensor stds
super().__init__()

Initialise nn.Module bookkeeping.

self.register_buffer("means", torch.from_numpy(means).float())

Buffer registration. Goes into state_dict; moves with .to(device); does NOT receive gradients. The right abstraction for fixed dataset statistics.

EXECUTION STATE
register_buffer vs register_parameter = Buffer: state, not learned. Parameter: learned by the optimiser. Means / stds are dataset statistics, not learnable.
self.register_buffer("stds", torch.from_numpy(stds).float())

Same registration for stds.

def forward(self, x, cond_seq):

Stateless forward. Two inputs - the (B, T, F) sensor batch and the (B, T) condition sequence.

EXECUTION STATE
input: x = (B, T, F) raw sensor values
input: cond_seq = (B, T) int64 condition labels per cycle
mu = self.means[cond_seq]

PyTorch advanced indexing. cond_seq is (B, T); means is (n_cond, F); result is (B, T, F).

EXECUTION STATE
Example: cond_seq[3, 17] = 5 → mu[3, 17] = self.means[5] - the F-vector of means for condition 5
sigma = self.stds[cond_seq]

Same trick on stds.

return (x - mu) / (sigma + 1e-8)

Element-wise normalisation. Output has the same shape as x.

torch.manual_seed(0)

Seeds PyTorch's RNG, which drives x and cond_seq further down. The means and stds are drawn with np.random, which has its own RNG; seed it with np.random.seed as well for a fully reproducible demo.

B, T, F, n_conds = 4, 30, 14, 6

Realistic shapes.

means = np.random.randn(n_conds, F).astype(np.float32) * 30 + 1000

Fake per-cond means around 1000.

stds = (np.abs(np.random.randn(n_conds, F)) * 5 + 1).astype(np.float32)

Fake per-cond stds in [1, ~16].

normer = PerConditionNormaliser(means, stds)

Instantiate.

print("buffers in state_dict:", list(normer.state_dict().keys()))

Verify both buffers are tracked.

EXECUTION STATE
Output = buffers in state_dict: ['means', 'stds']
device = "cuda" if torch.cuda.is_available() else "cpu"

Select device.

normer = normer.to(device)

.to(device) recursively moves all parameters AND buffers. After this both means and stds live on the selected device (the GPU when one is available, otherwise the CPU).

x = torch.randn(B, T, F, device=device) * 50 + 1000

Fake input on the same device.

cond_seq = torch.randint(0, n_conds, (B, T), device=device)

Random per-cycle condition IDs on device.

x_norm = normer(x, cond_seq)

Forward pass. PyTorch's __call__ routes to forward().

print("x_norm.shape:", tuple(x_norm.shape))

Verify shape is preserved.

EXECUTION STATE
Output = x_norm.shape: (4, 30, 14)
print("x_norm.device:", x_norm.device)

On the same device as the input and the buffers.

EXECUTION STATE
Output (CPU) = x_norm.device: cpu
Output (GPU) = x_norm.device: cuda:0
import numpy as np
import torch
import torch.nn as nn

class PerConditionNormaliser(nn.Module):
    """nn.Module wrapper around per-condition Z-score.

    Buffers (move with .to(device), saved in state_dict, NOT optimised):
        means : (n_conditions, F)
        stds  : (n_conditions, F)
    """

    def __init__(self, means: np.ndarray, stds: np.ndarray):
        super().__init__()
        self.register_buffer("means", torch.from_numpy(means).float())
        self.register_buffer("stds",  torch.from_numpy(stds).float())

    def forward(self, x: torch.Tensor, cond_seq: torch.Tensor) -> torch.Tensor:
        # x        : (B, T, F)
        # cond_seq : (B, T) int64
        mu    = self.means[cond_seq]            # (B, T, F)
        sigma = self.stds[cond_seq]             # (B, T, F)
        return (x - mu) / (sigma + 1e-8)


# ----- Use it -----
torch.manual_seed(0)   # seeds the torch-generated x and cond_seq
np.random.seed(0)      # seeds the NumPy-generated means and stds
B, T, F, n_conds = 4, 30, 14, 6
means = np.random.randn(n_conds, F).astype(np.float32) * 30 + 1000
stds  = (np.abs(np.random.randn(n_conds, F)) * 5 + 1).astype(np.float32)

normer = PerConditionNormaliser(means, stds)
print("buffers in state_dict:", list(normer.state_dict().keys()))

# Move to GPU if available - buffers move automatically
device = "cuda" if torch.cuda.is_available() else "cpu"
normer = normer.to(device)

x        = torch.randn(B, T, F, device=device) * 50 + 1000
cond_seq = torch.randint(0, n_conds, (B, T), device=device)

x_norm = normer(x, cond_seq)
print("x_norm.shape:", tuple(x_norm.shape))   # (4, 30, 14)
print("x_norm.device:", x_norm.device)
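
The buffer-versus-parameter distinction and the "differentiable" requirement can both be checked in a few lines. The sketch below uses a stripped-down stand-in class, `TinyNormaliser`, with hand-picked statistics; the class name and all numbers are assumptions for illustration, not the book's module.

```python
import torch
import torch.nn as nn

class TinyNormaliser(nn.Module):
    """Illustrative stand-in: same buffers-plus-gather idea, tensors passed in directly."""

    def __init__(self, means: torch.Tensor, stds: torch.Tensor):
        super().__init__()
        self.register_buffer("means", means)
        self.register_buffer("stds", stds)

    def forward(self, x, cond_seq):
        return (x - self.means[cond_seq]) / (self.stds[cond_seq] + 1e-8)

# Two conditions, two sensors (hypothetical values)
means = torch.tensor([[0.0, 10.0], [100.0, 200.0]])
stds = torch.tensor([[1.0, 2.0], [4.0, 5.0]])
norm = TinyNormaliser(means, stds)

# Buffers are saved with the model but are invisible to the optimiser
print(len(list(norm.parameters())))        # 0
print(sorted(norm.state_dict().keys()))    # ['means', 'stds']

# Gradients still flow back to the input: dz/dx = 1 / (sigma + eps)
x = torch.tensor([[[1.0, 12.0], [104.0, 205.0]]], requires_grad=True)  # (1, 2, 2)
cond_seq = torch.tensor([[0, 1]])                                      # (1, 2)
z = norm(x, cond_seq)
z.sum().backward()

print(torch.allclose(x.grad, 1.0 / stds[cond_seq]))   # True
print(norm.means.requires_grad)                       # False
```

Because the statistics are buffers, an optimiser built from `norm.parameters()` has nothing to update, while upstream layers that produce `x` would still receive gradients.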

End-to-End Test Against Manual Computation

| Test | What it checks | How |
|---|---|---|
| Shape preservation | Output shape == input shape | assert x_norm.shape == x.shape |
| Per-condition zero-mean | Each condition's slice has mean ~ 0 | x_norm[cond_seq == c].mean() |
| Per-condition unit-std | Each condition's slice has std ~ 1 | x_norm[cond_seq == c].std() |
| Device transparency | Output lives on same device as input | x.device == x_norm.device |
| state_dict round-trip | Save + load produces identical normaliser | torch.save / torch.load |

Run all five tests once. The normaliser is the kind of thin wrapper that gets passed through code review without scrutiny - and then silently breaks gradient flow if it is wrong. Five tests, ten minutes, peace of mind.
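
The five checks could be written roughly as follows. This is a sketch under stated assumptions: the `Normaliser` class is a minimal stand-in, the data are synthetic with made-up per-condition offsets and scales, and the statistics are fitted on that same data so that the zero-mean and unit-std checks are meaningful.

```python
import io

import torch
import torch.nn as nn


class Normaliser(nn.Module):
    """Minimal stand-in for the per-condition normaliser (illustrative only)."""

    def __init__(self, means: torch.Tensor, stds: torch.Tensor):
        super().__init__()
        self.register_buffer("means", means.float())
        self.register_buffer("stds", stds.float())

    def forward(self, x, cond_seq):
        return (x - self.means[cond_seq]) / (self.stds[cond_seq] + 1e-8)


torch.manual_seed(0)
B, T, F, n_conds = 64, 30, 4, 3

# Synthetic data: each condition has its own offset and scale (made-up values)
cond_seq = torch.randint(0, n_conds, (B, T))
offsets = torch.tensor([0.0, 500.0, 1000.0])
scales = torch.tensor([1.0, 10.0, 50.0])
x = torch.randn(B, T, F) * scales[cond_seq][..., None] + offsets[cond_seq][..., None]

# "Fit": per-condition mean and population std for each sensor
means = torch.stack([x[cond_seq == c].mean(dim=0) for c in range(n_conds)])
stds = torch.stack([x[cond_seq == c].std(dim=0, unbiased=False) for c in range(n_conds)])

norm = Normaliser(means, stds)
x_norm = norm(x, cond_seq)

# 1. Shape preservation
assert x_norm.shape == x.shape

# 2 + 3. Per-condition zero mean and unit std, checked per sensor
for c in range(n_conds):
    z = x_norm[cond_seq == c]                       # (N_c, F)
    assert torch.allclose(z.mean(dim=0), torch.zeros(F), atol=1e-3)
    assert torch.allclose(z.std(dim=0, unbiased=False), torch.ones(F), atol=1e-3)

# 4. Device transparency
assert x_norm.device == x.device

# 5. state_dict round-trip through an in-memory buffer
buf = io.BytesIO()
torch.save(norm.state_dict(), buf)
buf.seek(0)
restored = Normaliser(torch.zeros(n_conds, F), torch.ones(n_conds, F))
restored.load_state_dict(torch.load(buf))
assert torch.equal(restored(x, cond_seq), x_norm)

print("all five checks passed")
```

Checks 2 and 3 only hold when the statistics were fitted on the data being normalised; on held-out data expect values near, but not exactly at, 0 and 1.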

Per-Cluster Normalisation Elsewhere

| Domain | Cluster | Normalisation step | Library |
|---|---|---|---|
| RUL (this book) | Operating regime | Per-condition Z-score | Custom (this section) |
| Speech recognition | Speaker | Cepstral mean / variance normalisation | Kaldi |
| Multi-site neuroimaging | Scanner site | ComBat | neuroCombat |
| Single-cell genomics | Cell type / batch | scTransform / SCTransform | Seurat |
| Federated learning | Client | Local BatchNorm or FedBN | FedML |
| Recommender systems | User | Per-user mean centring | Custom in every shop |

Three Implementation Pitfalls

Pitfall 1: Stats and labels mismatched. The means / stds bundle is specific to one clustering fit. If you re-fit k-means later (for example with a different seed), the cluster IDs change but the bundle does not, so cycles are silently normalised with the wrong statistics. Always load both from the SAME joblib bundle.
Pitfall 2: Forgetting epsilon. Some (condition, sensor) pairs have zero or near-zero variance. Without + 1e-8, an exact zero produces inf or NaN (0/0), and those values propagate through the forward pass and the entire backward pass.
Pitfall 3: Wrong dtype on cond_seq. Index tensors must be integer-typed: a float cond_seq raises an IndexError, and some PyTorch versions also reject int32. Calling .long() at the dataset boundary works everywhere.
The point. Per-condition normalisation is a five-line Module that erases 99% of the regime variance and unlocks the rest of the framework. Trivial to implement, easy to get subtly wrong.
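
Pitfalls 2 and 3 are easy to reproduce on a toy example. The sketch below uses NumPy as a stand-in (PyTorch rejects float index tensors in a similar way); the statistics are hypothetical and deliberately include one zero-variance entry.

```python
import numpy as np

# Hypothetical toy statistics: condition 0 has zero variance on the only sensor
means = np.array([[5.0], [7.0]])          # (n_conds=2, F=1)
stds = np.array([[0.0], [2.0]])
X = np.array([[[5.0], [9.0]]])            # (B=1, T=2, F=1)
cond_seq = np.array([[0, 1]])             # (B=1, T=2)

# Pitfall 2: without epsilon the zero-variance cell becomes 0/0 = NaN
with np.errstate(divide="ignore", invalid="ignore"):
    no_eps = (X - means[cond_seq]) / stds[cond_seq]
with_eps = (X - means[cond_seq]) / (stds[cond_seq] + 1e-8)

print(np.isnan(no_eps).any())         # True
print(np.isfinite(with_eps).all())    # True

# Pitfall 3: float labels cannot be used as indices; cast once at the boundary
cond_float = np.array([[0.0, 1.0]])
try:
    means[cond_float]
    float_index_ok = True
except IndexError:
    float_index_ok = False

print(float_index_ok)                 # False
print(np.array_equal(means[cond_float.astype(np.int64)], means[cond_seq]))  # True
```

Note that epsilon only prevents non-finite values. A sensor with tiny but non-zero variance still yields very large Z-scores, so it is worth inspecting the fitted stds before training.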

Takeaway

  • The math is two lines. $(x - \mu_c) / (\sigma_c + \varepsilon)$ per (b, t) cell.
  • Implement as nn.Module with two buffers. register_buffer for state that travels with the model but is not optimised.
  • Advanced indexing handles the per-cycle lookup. means[cond_seq] is the one-line gather that produces the right (B, T, F) tensor.
  • Test five things. Shape, per-cond zero-mean, per-cond unit-std, device transparency, state_dict round-trip.