Learning Objectives
By the end of this section, you will:
- Implement the complete attention module using PyTorch's MultiheadAttention
- Integrate attention with the BiLSTM encoder
- Trace the forward pass through all components
- Verify parameter counts match our calculations
- Understand the complete feature extraction pipeline
Why This Matters: This section synthesizes all attention concepts into working code. PyTorch's MultiheadAttention encapsulates the Q, K, V projections, scaled dot-product attention, and output projection in a single optimized module. Understanding how to use it correctly is essential for building Transformer-based architectures.
Complete Attention Module
We wrap PyTorch's MultiheadAttention with layer normalization, residual connection, and mean pooling for RUL prediction.
Module Definition
```python
class AttentionModule(nn.Module):
    """
    Multi-head self-attention with residual connection and layer norm.

    Architecture:
        Input → MHA → Dropout → Add(residual) → LayerNorm → MeanPool → Output

    Args:
        embed_dim: Input/output dimension (256 from BiLSTM)
        num_heads: Number of attention heads (8)
        dropout: Dropout rate for attention and residual (0.1)
    """

    def __init__(
        self,
        embed_dim: int = 256,
        num_heads: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()

        # Multi-head attention
        # embed_dim must be divisible by num_heads
        # Each head dimension: 256 / 8 = 32
        self.attention = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True  # Input: (batch, seq, embed)
        )

        # Layer normalization (post-norm)
        self.layer_norm = nn.LayerNorm(embed_dim)

        # Dropout for residual path
        self.dropout = nn.Dropout(dropout)
```

Key Parameters Explained
| Parameter | Value | Purpose |
|---|---|---|
| embed_dim | 256 | Matches BiLSTM output dimension |
| num_heads | 8 | Parallel attention mechanisms |
| dropout | 0.1 | Regularization during training |
| batch_first | True | Input shape (B, T, D) not (T, B, D) |
batch_first Parameter
PyTorch's MultiheadAttention defaults to (seq, batch, embed) ordering. Setting batch_first=True aligns with our CNN and LSTM outputs which use (batch, seq, embed) format.
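A quick sketch of the difference (module names here are illustrative, not from the chapter):

```python
import torch
import torch.nn as nn

x = torch.randn(4, 30, 256)  # (batch, seq, embed)

# Default convention: expects (seq, batch, embed)
mha_default = nn.MultiheadAttention(embed_dim=256, num_heads=8)
xt = x.transpose(0, 1)                  # (30, 4, 256)
out_default, _ = mha_default(xt, xt, xt)
print(out_default.shape)                # torch.Size([30, 4, 256])

# batch_first=True: accepts (batch, seq, embed) directly
mha_bf = nn.MultiheadAttention(embed_dim=256, num_heads=8, batch_first=True)
out, _ = mha_bf(x, x, x)
print(out.shape)                        # torch.Size([4, 30, 256])
```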
Forward Method
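The class definition above stops at `__init__`. As a minimal runnable sketch, here is the module restated with a forward pass that follows the dimension flow traced below (self-attention, dropout on the residual branch, post-norm, mean pooling):

```python
import torch
import torch.nn as nn

class AttentionModule(nn.Module):
    """Compact sketch of the attention module with its forward method."""

    def __init__(self, embed_dim: int = 256, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.attention = nn.MultiheadAttention(
            embed_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Self-attention: query, key, and value are all x
        attn_output, _ = self.attention(x, x, x, need_weights=False)
        # Residual connection with dropout, then layer norm (post-norm)
        x = self.layer_norm(x + self.dropout(attn_output))
        # Mean pooling over the time dimension: (B, T, D) -> (B, D)
        return x.mean(dim=1)

x = torch.randn(4, 30, 256)
out = AttentionModule()(x)
print(out.shape)  # torch.Size([4, 256])
```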
Dimension Flow
```
Input x: (B, 30, 256)
        ↓
MultiheadAttention(Q=x, K=x, V=x)
        ↓
attn_output: (B, 30, 256)
        ↓
Dropout(attn_output)
        ↓
Add: x + Dropout(attn_output)
        ↓
x: (B, 30, 256)
        ↓
LayerNorm(x)
        ↓
x: (B, 30, 256)
        ↓
Mean over dim=1
        ↓
Output: (B, 256)
```

Integration with BiLSTM
The attention module receives the BiLSTM output and produces a single 256-dimensional vector for each sample.
Complete Encoder Architecture
```python
class AMNL_Encoder(nn.Module):
    """
    Complete encoder: CNN → BiLSTM → Attention → Output

    This combines all feature extraction components:
    1. CNN: Extract local patterns (17 sensors → 64 features)
    2. BiLSTM: Encode temporal dependencies (64 → 256)
    3. Attention: Focus on informative timesteps (256 → 256)
    """

    def __init__(
        self,
        input_dim: int = 17,
        cnn_channels: list = [64, 128, 64],
        lstm_hidden: int = 128,
        lstm_layers: int = 2,
        attention_heads: int = 8,
        dropout: float = 0.1
    ):
        super().__init__()

        # CNN Feature Extractor (from Chapter 5)
        self.cnn = CNN_FeatureExtractor(
            in_channels=input_dim,
            channels=cnn_channels,
            dropout=0.2
        )

        # BiLSTM Encoder (from Chapter 6)
        self.bilstm = BiLSTM_Encoder(
            input_size=cnn_channels[-1],  # 64
            hidden_size=lstm_hidden,      # 128
            num_layers=lstm_layers,       # 2
            dropout=0.3
        )

        # Attention Module (this chapter)
        self.attention = AttentionModule(
            embed_dim=lstm_hidden * 2,    # 256 (bidirectional)
            num_heads=attention_heads,    # 8
            dropout=dropout               # 0.1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through complete encoder.

        Args:
            x: Raw sensor data (batch, seq_len, sensors)
               Shape: (B, 30, 17)

        Returns:
            Encoded representation (batch, 256)
        """
        # CNN: (B, 30, 17) → (B, 30, 64)
        x = self.cnn(x)

        # BiLSTM: (B, 30, 64) → (B, 30, 256)
        x = self.bilstm(x)

        # Attention: (B, 30, 256) → (B, 256)
        x = self.attention(x)

        return x
```

Module Connections
| Stage | Input Shape | Output Shape | Key Operation |
|---|---|---|---|
| CNN | (B, 30, 17) | (B, 30, 64) | Local pattern extraction |
| BiLSTM | (B, 30, 64) | (B, 30, 256) | Temporal encoding |
| Attention | (B, 30, 256) | (B, 256) | Weighted aggregation |
Dimension Compatibility
Each module's output dimension must match the next module's input. CNN outputs 64 (matching BiLSTM input_size), BiLSTM outputs 256 (matching attention embed_dim), and attention outputs 256 (ready for the prediction head).
Forward Pass Flow
Let us trace a concrete example through the complete encoder.
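The real `CNN_FeatureExtractor` and `BiLSTM_Encoder` live in Chapters 5 and 6, so the sketch below uses simplified stand-ins (a single `Conv1d` for the CNN stage, a bare `nn.LSTM` for the BiLSTM) purely to trace the shapes end to end:

```python
import torch
import torch.nn as nn

B, T, S = 4, 30, 17  # batch, window length, sensors
x = torch.randn(B, T, S)

# Stand-in for the CNN stage: one Conv1d mapping 17 -> 64 channels
# (the real CNN_FeatureExtractor has three conv blocks)
cnn = nn.Conv1d(S, 64, kernel_size=3, padding=1)
h = cnn(x.transpose(1, 2)).transpose(1, 2)   # (B, 30, 64)

# BiLSTM stage: 64 -> 2 * 128 = 256
bilstm = nn.LSTM(64, 128, num_layers=2, bidirectional=True, batch_first=True)
h, _ = bilstm(h)                             # (B, 30, 256)

# Attention stage: MHA + residual + norm + mean pool -> (B, 256)
mha = nn.MultiheadAttention(256, 8, batch_first=True)
attn_out, _ = mha(h, h, h, need_weights=False)
h = nn.LayerNorm(256)(h + attn_out)          # (B, 30, 256)
z = h.mean(dim=1)                            # (B, 256)

print(h.shape, z.shape)
```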
Attention Weight Visualization
To understand what the model focuses on, we can extract attention weights:
```python
def forward_with_attention(self, x: torch.Tensor):
    """
    Forward pass returning attention weights for visualization.

    Intended as a method of AMNL_Encoder, called on the BiLSTM output x.
    """
    # Self-attention with weights
    attn_output, attn_weights = self.attention.attention(
        query=x, key=x, value=x,
        need_weights=True,
        average_attn_weights=True  # Average across heads
    )
    # attn_weights: (B, 30, 30)
    # Each row i shows how position i attends to all positions

    # Complete forward pass
    x = x + self.attention.dropout(attn_output)
    x = self.attention.layer_norm(x)
    x = x.mean(dim=1)

    return x, attn_weights
```

Interpreting Attention Weights
Row i of the attention matrix shows which timesteps position i attends to. For RUL prediction, we often see higher attention on later timesteps (degradation signals) and transition points.
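A quick sanity check on that interpretation (standalone sketch, not part of the chapter's modules): each row of the averaged weight matrix is a softmax distribution over the 30 timesteps, so rows sum to one, and `argmax` picks the most-attended timestep per query position.

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(256, 8, batch_first=True)
x = torch.randn(2, 30, 256)
_, w = mha(x, x, x, need_weights=True, average_attn_weights=True)
print(w.shape)  # torch.Size([2, 30, 30])

# Each row is a probability distribution over the 30 timesteps
assert torch.allclose(w.sum(dim=-1), torch.ones(2, 30), atol=1e-5)

# Most-attended timestep for each query position in the first sample
print(w[0].argmax(dim=-1))  # shape (30,)
```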
Parameter Verification
Let us verify our parameter calculations match the implementation.
Attention Module Parameters
```python
def count_attention_params(model):
    """Count parameters in attention module."""
    params = {}

    for name, p in model.attention.named_parameters():
        params[name] = p.numel()

    # Prefix layer-norm names so they do not collide with attention names
    for name, p in model.layer_norm.named_parameters():
        params[f"layer_norm.{name}"] = p.numel()

    return params

# Create module
attn = AttentionModule(embed_dim=256, num_heads=8, dropout=0.1)
params = count_attention_params(attn)

# Expected output:
# in_proj_weight: 196608 (3 × 256 × 256 for Q, K, V)
# in_proj_bias: 768 (3 × 256)
# out_proj.weight: 65536 (256 × 256)
# out_proj.bias: 256
# layer_norm.weight: 256
# layer_norm.bias: 256
```

Parameter Breakdown
| Component | Shape | Parameters |
|---|---|---|
| in_proj_weight (Q,K,V) | (768, 256) | 196,608 |
| in_proj_bias | (768,) | 768 |
| out_proj.weight | (256, 256) | 65,536 |
| out_proj.bias | (256,) | 256 |
| layer_norm.weight | (256,) | 256 |
| layer_norm.bias | (256,) | 256 |
| Total | | 263,680 |
in_proj Packing
PyTorch packs the Q, K, V projections into a single in_proj matrix for efficiency. The shape (768, 256) = (3 × 256, 256) represents all three projections combined.
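The individual projections can be recovered by splitting `in_proj_weight` along its first dimension, as in this small sketch:

```python
import torch
import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=256, num_heads=8)
print(mha.in_proj_weight.shape)  # torch.Size([768, 256])

# Split the packed matrix back into the Q, K, V projections
w_q, w_k, w_v = mha.in_proj_weight.chunk(3, dim=0)
print(w_q.shape, w_k.shape, w_v.shape)  # three (256, 256) matrices
```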
Complete Encoder Parameters
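A generic counter can check these figures (the `count_params` helper is my own name, not from the chapter). The attention stage is verifiable in isolation against the table above; applied to the full `AMNL_Encoder`, the same utility should report roughly the 909K total quoted in the summary (that requires the Chapter 5/6 classes in scope).

```python
import torch.nn as nn

def count_params(module: nn.Module) -> int:
    """Sum of all parameter element counts in a module."""
    return sum(p.numel() for p in module.parameters())

# Attention stage: MultiheadAttention + LayerNorm, as in AttentionModule
attn = nn.MultiheadAttention(embed_dim=256, num_heads=8)
norm = nn.LayerNorm(256)
total = count_params(attn) + count_params(norm)
print(total)  # 263680
```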
Summary
In this section, we implemented the complete attention module:
- AttentionModule: Wraps MultiheadAttention with residual and layer norm
- Integration: CNN → BiLSTM → Attention pipeline
- Forward pass: (B, 30, 17) → (B, 256) encoding
- Parameters: ~264K for attention, ~909K total encoder
| Component | Input | Output | Parameters |
|---|---|---|---|
| CNN | (B, 30, 17) | (B, 30, 64) | ~53K |
| BiLSTM | (B, 30, 64) | (B, 30, 256) | ~592K |
| Attention | (B, 30, 256) | (B, 256) | ~264K |
| Total Encoder | (B, 30, 17) | (B, 256) | ~909K |
Chapter Complete: We have designed and implemented the complete multi-head self-attention mechanism. The encoder now transforms raw sensor sequences into rich 256-dimensional representations, with attention focusing on the most predictive timesteps. The next chapter introduces the prediction head and loss function for RUL regression.
With the attention-enhanced encoder complete, we now design the output layer for RUL prediction.