What the Normaliser Must Do
Take a (B, T, F) sensor batch, take a (B, T) condition-label sequence, and return a (B, T, F) tensor where every cycle has been Z-scored against its own condition's pre-computed mean and std. Same shape in, same shape out. Differentiable. GPU-aware. State serialisable.
The Two-Line Math
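For every cell (b, t) with condition c = cond_seq[b, t]:

$$
\tilde{x}_{b,t} = \frac{x_{b,t} - \mu_c}{\sigma_c + \epsilon}
$$

where $x_{b,t} \in \mathbb{R}^F$ is the sensor vector of that cell, $\mu_c, \sigma_c \in \mathbb{R}^F$ are condition $c$'s pre-computed per-sensor mean and standard deviation, and $\epsilon$ (1e-8) keeps zero-variance sensors finite.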
That is the entire operation. In production code the complexity comes from three places: (1) looking up the per-(b, t) statistics efficiently, (2) handling float32 dtypes and GPU placement, and (3) shipping the train-time statistics through to inference.
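Read literally, the formula is a loop over every (b, t) cell: look up c, then Z-score the F-vector. The sketch below writes that literal form in NumPy, assuming (K, F) arrays means and stds indexed by condition ID; the name zscore_loop is hypothetical. It is far too slow for training, but it makes a convenient ground truth for any vectorised version.

```python
import numpy as np


def zscore_loop(X, cond_seq, means, stds, eps=1e-8):
    """Literal per-(b, t) evaluation of the per-condition Z-score.

    X: (B, T, F) sensor batch; cond_seq: (B, T) integer condition labels;
    means, stds: (K, F) per-condition statistics fitted on training data.
    """
    X = np.asarray(X, dtype=np.float64)
    B, T, _ = X.shape
    out = np.empty_like(X)
    for b in range(B):
        for t in range(T):
            c = cond_seq[b, t]                                   # line 1: which condition?
            out[b, t] = (X[b, t] - means[c]) / (stds[c] + eps)   # line 2: Z-score all F sensors
    return out
```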
Python: A Reference Implementation
Five-line core; advanced indexing handles the per-(b,t) lookup
🐍per_condition_zscore.py
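A minimal NumPy sketch of such a reference implementation, assuming (K, F) arrays means and stds indexed by integer condition ID and population std (ddof=0); the function names fit_condition_stats and per_condition_zscore are hypothetical, not necessarily those in the book's file. The normalising step itself is the gather-and-broadcast one-liner.

```python
import numpy as np


def fit_condition_stats(X, cond_seq, n_conditions):
    """Per-condition, per-sensor mean and std over every (b, t) cell.

    X: (B, T, F) training sensors; cond_seq: (B, T) integer labels in [0, K).
    Returns means, stds, each of shape (K, F).
    """
    F = X.shape[-1]
    flat_x = X.reshape(-1, F)          # (B*T, F)
    flat_c = cond_seq.reshape(-1)      # (B*T,)
    means = np.zeros((n_conditions, F))
    stds = np.ones((n_conditions, F))  # defaults kept for conditions never seen
    for c in range(n_conditions):
        rows = flat_x[flat_c == c]
        if len(rows) > 0:
            means[c] = rows.mean(axis=0)
            stds[c] = rows.std(axis=0)  # population std (ddof=0)
    return means, stds


def per_condition_zscore(X, cond_seq, means, stds, eps=1e-8):
    """Z-score each (b, t) cell against its own condition's statistics."""
    mu = means[cond_seq]               # gather: (K, F) indexed by (B, T) -> (B, T, F)
    sigma = stds[cond_seq]             # same gather for the stds
    return (X - mu) / (sigma + eps)    # eps keeps zero-variance sensors finite


rng = np.random.default_rng(0)
X = rng.normal(size=(4, 50, 3))
cond_seq = rng.integers(0, 3, size=(4, 50))
means, stds = fit_condition_stats(X, cond_seq, n_conditions=3)
X_norm = per_condition_zscore(X, cond_seq, means, stds)
print(X_norm.shape)  # (4, 50, 3)
```

Fitting and normalising are kept as separate functions because the statistics must come from the training split only and are then reused unchanged at inference.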
Advanced indexing is the trick.
means[cond_seq] performs a gather: for every (b, t) position it picks the F-vector means[cond_seq[b, t]]. The result has shape (B, T, F), exactly what is needed to broadcast against X.
PyTorch: As an nn.Module With Buffers
Same algorithm; buffers + .to(device) + state_dict
🐍per_condition_normaliser.py
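A minimal sketch of such a module, assuming the (K, F) statistics are fitted offline and passed to the constructor; the class name PerConditionNormaliser is hypothetical (borrowed from the file name), and eps is kept as a plain attribute rather than saved state. The statistics are registered with register_buffer, so they appear in state_dict and follow .to(device), but are never returned by .parameters().

```python
import torch
from torch import nn


class PerConditionNormaliser(nn.Module):
    """Per-condition Z-score with the train-time statistics stored as buffers."""

    def __init__(self, means, stds, eps: float = 1e-8):
        super().__init__()
        # Buffers: saved in state_dict and moved by .to(device), but not optimised.
        # .clone() avoids aliasing a caller's NumPy array or tensor.
        self.register_buffer("means", torch.as_tensor(means, dtype=torch.float32).clone())
        self.register_buffer("stds", torch.as_tensor(stds, dtype=torch.float32).clone())
        self.eps = eps

    def forward(self, x: torch.Tensor, cond_seq: torch.Tensor) -> torch.Tensor:
        # x: (B, T, F) float tensor; cond_seq: (B, T) integer condition IDs.
        idx = cond_seq.long()          # cast to int64 for advanced indexing
        mu = self.means[idx]           # gather -> (B, T, F)
        sigma = self.stds[idx]         # gather -> (B, T, F)
        return (x - mu) / (sigma + self.eps)


norm = PerConditionNormaliser(torch.zeros(3, 5), torch.ones(3, 5))
x = torch.randn(2, 10, 5)
cond_seq = torch.randint(0, 3, (2, 10))
print(norm(x, cond_seq).shape)   # torch.Size([2, 10, 5])
print(list(norm.state_dict()))   # ['means', 'stds']
```

Because every operation in forward is a differentiable tensor op, gradients flow back into x; the buffers receive none.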
End-to-End Test Against Manual Computation
| Test | What it checks | How |
|---|---|---|
| Shape preservation | Output shape == input shape | assert x_norm.shape == x.shape |
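| Per-condition zero-mean | Each condition's slice has per-feature mean ≈ 0 (on the data the stats were fit on) | x_norm[cond_seq == c].mean(dim=0) |
| Per-condition unit-std | Each condition's slice has per-feature std ≈ 1 (on the data the stats were fit on) | x_norm[cond_seq == c].std(dim=0) |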
| Device transparency | Output lives on same device as input | x.device == x_norm.device |
| state_dict round-trip | Save + load produces identical normaliser | torch.save / torch.load |
Run all five tests once. The normaliser is the kind of thin wrapper that gets waved through code review without scrutiny and then silently breaks gradient flow if it is wrong. Five tests, ten minutes, peace of mind.
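The five checks from the table can be sketched as plain assert-based code. This assumes a compact stand-in normaliser declared inline so the block runs on its own (the names PerConditionNormaliser, fit_stats and run_checks are hypothetical), and it uses population std (unbiased=False) both when fitting and when checking, so the unit-std check holds up to eps and float error. The state_dict round-trip goes through an in-memory io.BytesIO buffer instead of a file.

```python
import io

import torch
from torch import nn


class PerConditionNormaliser(nn.Module):
    # Compact stand-in: (K, F) buffers, gather by condition ID, Z-score.
    def __init__(self, means, stds, eps=1e-8):
        super().__init__()
        self.register_buffer("means", means.float().clone())
        self.register_buffer("stds", stds.float().clone())
        self.eps = eps

    def forward(self, x, cond_seq):
        idx = cond_seq.long()
        return (x - self.means[idx]) / (self.stds[idx] + self.eps)


def fit_stats(x, cond_seq, k):
    # Per-condition mean / population std over every (b, t) cell.
    f = x.shape[-1]
    flat_x, flat_c = x.reshape(-1, f), cond_seq.reshape(-1)
    means = torch.stack([flat_x[flat_c == c].mean(dim=0) for c in range(k)])
    stds = torch.stack([flat_x[flat_c == c].std(dim=0, unbiased=False) for c in range(k)])
    return means, stds


def run_checks(device="cpu"):
    torch.manual_seed(0)
    k, B, T, F = 3, 4, 60, 5
    x = torch.randn(B, T, F) * 4 + 7
    cond = torch.randint(0, k, (B, T))
    norm = PerConditionNormaliser(*fit_stats(x, cond, k)).to(device)
    x, cond = x.to(device), cond.to(device)
    x_norm = norm(x, cond)

    # 1. Shape preservation
    assert x_norm.shape == x.shape
    for c in range(k):
        sl = x_norm[cond == c]  # (n_c, F)
        # 2. Per-condition zero mean, per feature
        assert torch.allclose(sl.mean(dim=0), torch.zeros(F, device=device), atol=1e-4)
        # 3. Per-condition unit std, per feature
        assert torch.allclose(sl.std(dim=0, unbiased=False), torch.ones(F, device=device), atol=1e-4)
    # 4. Device transparency
    assert x_norm.device == x.device
    # 5. state_dict round-trip: save, load into a fresh module, compare outputs
    buf = io.BytesIO()
    torch.save(norm.state_dict(), buf)
    buf.seek(0)
    clone = PerConditionNormaliser(torch.zeros(k, F), torch.ones(k, F)).to(device)
    clone.load_state_dict(torch.load(buf))
    assert torch.equal(clone(x, cond), x_norm)
    return True


if __name__ == "__main__":
    run_checks("cpu")
    if torch.cuda.is_available():
        run_checks("cuda")
```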
Per-Cluster Normalisation Elsewhere
| Domain | Cluster | Normalisation step | Library |
|---|---|---|---|
| RUL (this book) | Operating regime | Per-condition Z-score | Custom (this section) |
| Speech recognition | Speaker | Cepstral mean / variance normalisation | Kaldi |
| Multi-site neuroimaging | Scanner site | ComBat | neuroCombat |
| Single-cell genomics | Cell type / batch | sctransform (SCTransform) | Seurat |
| Federated learning | Client | Local BatchNorm (FedBN) | FedML |
| Recommender systems | User | Per-user mean centring | Custom in every shop |
Three Implementation Pitfalls
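Pitfall 1: Stats and labels mismatched. The means/stds bundle is tied to the cluster IDs of one specific k-means fit. If you re-fit k-means later (for example with a different seed), the cluster IDs can be permuted while the stored statistics stay the same, so cycles are silently normalised with another condition's statistics. Always load the clusterer and the statistics from the SAME joblib bundle.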
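One way to honour that rule is to fit the clusterer and the statistics in one function and persist them as a single object, so they cannot drift apart. A minimal sketch, assuming scikit-learn's KMeans on operating-setting columns and joblib for persistence; the helper names, the dictionary keys and the file name bundle.joblib are hypothetical.

```python
import joblib
import numpy as np
from sklearn.cluster import KMeans


def fit_condition_bundle(settings, sensors, n_conditions=6, seed=0):
    """Fit k-means on operating settings, then stats indexed by ITS labels.

    settings: (N, S) operating-setting rows; sensors: (N, F) sensor rows.
    """
    km = KMeans(n_clusters=n_conditions, n_init=10, random_state=seed).fit(settings)
    labels = km.labels_
    means = np.stack([sensors[labels == c].mean(axis=0) for c in range(n_conditions)])
    stds = np.stack([sensors[labels == c].std(axis=0) for c in range(n_conditions)])
    return {"kmeans": km, "means": means, "stds": stds}


def save_bundle(bundle, path):
    joblib.dump(bundle, path)  # one file: model and stats always travel together


def load_bundle(path):
    bundle = joblib.load(path)
    return bundle["kmeans"], bundle["means"], bundle["stds"]
```

At inference, cond = km.predict(settings) followed by (sensors - means[cond]) / (stds[cond] + 1e-8) uses labels and statistics from the same fit by construction.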
Pitfall 2: Forgetting epsilon. Some (condition, sensor) pairs have zero or near-zero variance. Without the + 1e-8 in the denominator, dividing by a zero std yields inf or NaN values that propagate through the entire backward pass.
Pitfall 3: Wrong dtype on cond_seq. PyTorch advanced indexing expects int64 (long) index tensors. A float cond_seq raises an IndexError, and older PyTorch releases reject int32 as well. Calling .long() at the dataset boundary fixes it.
The point. Per-condition normalisation is a five-line Module that erases 99% of the regime variance and unlocks the rest of the framework. Trivial to implement, easy to get subtly wrong.
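Pitfalls 2 and 3 are both easy to reproduce in a few lines of PyTorch. The sketch below uses a single condition with one constant sensor (std exactly 0) and deliberately float labels; safe_zscore is a hypothetical helper that applies both fixes.

```python
import torch


def safe_zscore(x, cond_seq, means, stds, eps=1e-8):
    idx = cond_seq.long()                        # Pitfall 3: cast labels to int64
    return (x - means[idx]) / (stds[idx] + eps)  # Pitfall 2: eps in the denominator


means = torch.tensor([[5.0, 1.0]])  # one condition, two sensors
stds = torch.tensor([[0.0, 2.0]])   # sensor 0 is constant in this condition
x = torch.tensor([[[5.0, 3.0]]])    # (B=1, T=1, F=2)
cond = torch.zeros(1, 1)            # float labels: the Pitfall 3 mistake

naive = (x - means[cond.long()]) / stds[cond.long()]
print(naive)                        # sensor 0 becomes nan (0 / 0)

try:
    means[cond]                     # float index tensor
except IndexError as err:
    print("float labels rejected:", err)

print(safe_zscore(x, cond, means, stds))  # finite: sensor 0 maps to 0
```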
Takeaway
- The math is two lines. $c = $ cond_seq[b, t], then $\tilde{x}_{b,t} = (x_{b,t} - \mu_c) / (\sigma_c + \epsilon)$ per (b, t) cell.
- Implement as an nn.Module with two buffers. Use register_buffer for state that travels with the model but is not optimised.
- Advanced indexing handles the per-cycle lookup. means[cond_seq] is the one-line gather that produces the right (B, T, F) tensor.
- Test five things. Shape, per-condition zero-mean, per-condition unit-std, device transparency, state_dict round-trip.