Learning Objectives
By the end of this section, you will:
- Understand the training loop architecture for dual-task learning
- Structure epochs and batches for efficient training
- Integrate all optimization components (optimizer, scheduler, EMA)
- Handle dual-task forward and backward passes
- Implement the complete training epoch function
Why This Matters: The training loop is the heart of deep learning. A well-structured loop integrates all the components we have built (model, loss functions, optimizer, schedulers, and training enhancements) into a cohesive system that produces state-of-the-art results.
Training Loop Overview
The training loop orchestrates all components of the learning process.
High-Level Structure
```text
AMNL Training Loop Architecture:

┌───────────────────────────────────────────────────────────────
│ INITIALIZATION
│ • Set random seeds for reproducibility
│ • Load datasets (train, test)
│ • Create model, optimizer, scheduler
│ • Initialize EMA, early stopping
└───────────────────────────────────────────────────────────────
                              │
                              ▼
┌───────────────────────────────────────────────────────────────
│ EPOCH LOOP
│ for epoch in range(max_epochs):
│   │
│   ├─ Apply learning rate warmup (if epoch < warmup_epochs)
│   ├─ Update adaptive weight decay
│   │
│   ├─ TRAINING PHASE
│   │    for batch in train_loader:
│   │      • Forward pass (dual-task)
│   │      • Compute AMNL loss
│   │      • Backward pass (with scaling)
│   │      • Gradient clipping
│   │      • Optimizer step (with accumulation)
│   │      • EMA update
│   │
│   ├─ EVALUATION PHASE
│   │      • Apply EMA weights
│   │      • Evaluate on test set
│   │      • Restore training weights
│   │
│   ├─ Update scheduler (based on validation metric)
│   ├─ Check early stopping
│   ├─ Save best model checkpoint
│   └─ Log progress
└───────────────────────────────────────────────────────────────
                              │
                              ▼
┌───────────────────────────────────────────────────────────────
│ FINALIZATION
│ • Restore best model weights
│ • Final evaluation
│ • Save model and training history
└───────────────────────────────────────────────────────────────
```

Component Integration
| Component | Role | When Applied |
|---|---|---|
| AdamW | Weight updates | After gradient accumulation |
| LR Warmup | Gradual LR increase | First 10 epochs |
| ReduceLROnPlateau | LR decay on plateau | After warmup |
| Weight Decay | Regularization | Every step (adaptive) |
| Gradient Clipping | Stability | Before optimizer step |
| EMA | Weight smoothing | After optimizer step |
| Early Stopping | Prevent overfitting | End of each epoch |
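The warmup row in the table corresponds to a simple linear ramp of the learning rate. Here is a dependency-free sketch of the per-epoch factor, matching the formula used in the pre-epoch setup code later in this section:

```python
def warmup_factor(epoch: int, warmup_epochs: int = 10) -> float:
    """Linear ramp from 10% to 100% of the base LR over warmup_epochs."""
    if epoch >= warmup_epochs:
        return 1.0
    return 0.1 + 0.9 * (epoch / warmup_epochs)

# The first epochs train at a reduced LR, protecting randomly
# initialized weights from large early updates.
factors = [round(warmup_factor(e), 2) for e in range(12)]
# factors[0] == 0.1, factors[10] == 1.0
```

Starting at 10% of the base LR (rather than 0) keeps the very first updates non-trivial while still damping them.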
Epoch Structure
Each training epoch consists of distinct phases with specific responsibilities.
Phase 1: Pre-Epoch Setup
```python
# Pre-epoch setup
def prepare_epoch(
    epoch: int,
    optimizer: torch.optim.Optimizer,
    base_lr: float,
    warmup_epochs: int = 10,
    initial_wd: float = 1e-4
):
    """
    Prepare the optimizer for the current epoch.

    Applies learning rate warmup and adaptive weight decay.
    """
    # Learning rate warmup: ramp from 10% to 100% of base_lr
    if epoch < warmup_epochs:
        warmup_factor = 0.1 + 0.9 * (epoch / warmup_epochs)
        for param_group in optimizer.param_groups:
            param_group['lr'] = base_lr * warmup_factor

    # Adaptive weight decay: relax regularization as training progresses
    if epoch < 100:
        current_wd = initial_wd
    elif epoch < 200:
        current_wd = initial_wd * 0.5
    else:
        current_wd = initial_wd * 0.1

    for param_group in optimizer.param_groups:
        param_group['weight_decay'] = current_wd

    return current_wd
```

Phase 2: Training Phase
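The training-phase snippets call a `rul_to_health_state` helper that derives classification labels from RUL values. Its exact thresholds depend on the dataset; the version below is a minimal sketch with illustrative boundary values, not the actual thresholds:

```python
def rul_to_health_state(rul_values, boundaries=(30.0, 80.0)):
    """Map RUL values to discrete health-state class indices.

    0 = critical, 1 = degraded, 2 = healthy. The boundary values here
    are illustrative assumptions. Accepts any iterable of floats
    (including a NumPy array); a real implementation may vectorize.
    """
    return [
        0 if r < boundaries[0] else (1 if r < boundaries[1] else 2)
        for r in rul_values
    ]
```

The returned class indices can be wrapped with `torch.tensor(..., dtype=torch.long)` as the training code does.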
The training phase processes all batches with gradient accumulation:
```python
# Training phase structure
model.train()
train_loss = 0.0
rul_loss_epoch = 0.0
health_loss_epoch = 0.0

for batch_idx, (sequences, targets) in enumerate(train_loader):
    # Move to device
    sequences = sequences.to(device)
    rul_targets = targets.to(device).view(-1, 1)

    # Generate health state labels
    health_targets = rul_to_health_state(targets.numpy())
    health_targets = torch.tensor(health_targets, dtype=torch.long).to(device)

    # Zero gradients at accumulation boundary
    if batch_idx % accumulation_steps == 0:
        optimizer.zero_grad()

    # Forward pass, loss computation, backward pass
    # (detailed in the next section)

    # Optimizer step at accumulation boundary
    if (batch_idx + 1) % accumulation_steps == 0:
        # Gradient clipping, optimizer step, EMA update
        pass

    # Accumulate losses for logging
    train_loss += total_loss.item()
```

Phase 3: Evaluation Phase
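The evaluation phase swaps in EMA weights via `apply_shadow` and swaps them back via `restore`. A dependency-free sketch of that mechanic over a plain dict of floats (the real `ExponentialMovingAverage` iterates `model.named_parameters()`, but the logic is the same):

```python
class ExponentialMovingAverage:
    """Shadow copy of parameters, blended as decay*shadow + (1-decay)*param."""

    def __init__(self, params: dict, decay: float = 0.999):
        self.decay = decay
        self.shadow = dict(params)   # smoothed weights
        self.backup = {}             # live weights, saved during evaluation

    def update(self, params: dict):
        """Blend current weights into the shadow after each optimizer step."""
        for k, v in params.items():
            self.shadow[k] = self.decay * self.shadow[k] + (1 - self.decay) * v

    def apply_shadow(self, params: dict):
        """Swap smoothed weights in for evaluation, saving the live ones."""
        self.backup = dict(params)
        params.update(self.shadow)

    def restore(self, params: dict):
        """Swap the live training weights back after evaluation."""
        params.update(self.backup)
```

Evaluating on the smoothed weights typically gives more stable metrics, while `restore` ensures training continues from the un-smoothed state.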
```python
# Evaluation phase structure
model.eval()

# Apply EMA weights for evaluation
if ema is not None:
    ema.apply_shadow(model)

# Comprehensive evaluation
eval_results = evaluate_model_comprehensive(model, test_dataset, device)

# Restore training weights
if ema is not None:
    ema.restore(model)

# Extract key metrics
rmse_last = eval_results['RMSE_last_cycle']
nasa_score = eval_results['nasa_score_paper']
```

Phase 4: Post-Epoch Updates
```python
# Post-epoch updates
# Update learning rate scheduler (after warmup)
if epoch >= warmup_epochs:
    scheduler.step(rmse_last)  # ReduceLROnPlateau

# Check early stopping
if early_stopping(rmse_last, model):
    print(f"Early stopping at epoch {epoch + 1}")
    break

# Save best model
if rmse_last < best_rmse:
    best_rmse = rmse_last
    best_model_state = copy.deepcopy(model.state_dict())

# Log progress
print(f"Epoch {epoch + 1}: RMSE={rmse_last:.2f}, NASA={nasa_score:.1f}")
```

Batch Processing
Each batch undergoes a complete forward-backward cycle for dual-task learning.
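With gradient accumulation, `zero_grad` and the optimizer step fire only at boundaries, giving an effective batch size of `batch_size * accumulation_steps`. A small sketch of the boundary flags the batch loop uses:

```python
def accumulation_flags(num_batches: int, accumulation_steps: int):
    """Per batch: (zero gradients before this batch?, step after this batch?)."""
    return [
        (i % accumulation_steps == 0, (i + 1) % accumulation_steps == 0)
        for i in range(num_batches)
    ]
```

With `accumulation_steps=2`, gradients are zeroed before even-indexed batches and the optimizer steps after odd-indexed ones, so each update sees gradients from two batches.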
Forward Pass: Dual-Task Predictions
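The key idea in this pass is EMA-based loss normalization. In isolation, with made-up loss values (the 0.9/0.1 decay matches the training code; the magnitudes only illustrate that the raw RUL regression loss dwarfs the classification loss):

```python
def update_loss_ema(ema_value, new_loss, decay=0.9):
    """Running EMA of a scalar loss; None means uninitialized."""
    if ema_value is None:
        return new_loss
    return decay * ema_value + (1 - decay) * new_loss

# Dividing each task loss by its own EMA keeps both terms near 1.0,
# so the fixed 0.5/0.5 weighting stays meaningful even though the raw
# losses differ by orders of magnitude.
rul_ema, health_ema = None, None
for rul_loss, health_loss in [(400.0, 1.2), (380.0, 1.1), (350.0, 1.0)]:
    rul_ema = update_loss_ema(rul_ema, rul_loss)
    health_ema = update_loss_ema(health_ema, health_loss)
    total = (0.5 * rul_loss / max(rul_ema, 1e-6)
             + 0.5 * health_loss / max(health_ema, 1e-6))
```

Without normalization, the combined loss would be dominated by the RUL term and the classification head would barely train.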
```python
# Forward pass with mixed precision
with torch.cuda.amp.autocast():
    # Get both predictions from dual-task model
    rul_pred, health_pred = model(sequences)

    # Compute individual task losses
    rul_loss = weighted_mse_loss(rul_pred, rul_targets)
    health_loss = cross_entropy_loss(health_pred, health_targets)

    # EMA-based loss normalization for AMNL
    if rul_loss_ema is None:
        rul_loss_ema = rul_loss.item()
        health_loss_ema = health_loss.item()
    else:
        rul_loss_ema = 0.9 * rul_loss_ema + 0.1 * rul_loss.item()
        health_loss_ema = 0.9 * health_loss_ema + 0.1 * health_loss.item()

    # Normalize losses by their EMA values
    rul_scale = max(rul_loss_ema, 1e-6)
    health_scale = max(health_loss_ema, 1e-6)

    normalized_rul = rul_loss / rul_scale
    normalized_health = health_loss / health_scale

    # AMNL combined loss (equal weighting)
    total_loss = 0.5 * normalized_rul + 0.5 * normalized_health
```

Backward Pass: Gradient Computation
```python
# Backward pass with gradient scaling
if scaler is not None:
    # Scale loss for accumulation and mixed precision
    scaled_loss = total_loss / accumulation_steps
    scaler.scale(scaled_loss).backward()

    # At accumulation boundary: clip, step, update
    if (batch_idx + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=1.0
        )
        scaler.step(optimizer)
        scaler.update()

        # Update EMA
        if ema is not None:
            ema.update(model)
else:
    # Standard backward pass (CPU or MPS)
    scaled_loss = total_loss / accumulation_steps
    scaled_loss.backward()

    if (batch_idx + 1) % accumulation_steps == 0:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=1.0
        )
        optimizer.step()

        if ema is not None:
            ema.update(model)
```

Implementation
Complete training epoch function.
Full Training Epoch
```python
def train_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    rul_criterion,
    health_criterion: nn.Module,
    scaler: Optional[torch.cuda.amp.GradScaler],
    ema: Optional[ExponentialMovingAverage],
    rul_loss_ema: Optional[float],
    health_loss_ema: Optional[float],
    accumulation_steps: int = 2,
    max_grad_norm: float = 1.0,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Execute one training epoch.

    Args:
        model: Dual-task model
        train_loader: Training data loader
        optimizer: AdamW optimizer
        rul_criterion: Weighted MSE loss
        health_criterion: Cross-entropy loss
        scaler: GradScaler for mixed precision (None on CPU/MPS)
        ema: EMA tracker
        rul_loss_ema: Running EMA of RUL loss
        health_loss_ema: Running EMA of health loss
        accumulation_steps: Gradient accumulation steps
        max_grad_norm: Maximum gradient norm
        device: Training device

    Returns:
        Dictionary with epoch statistics
    """
    model.train()

    total_loss = 0.0
    total_rul_loss = 0.0
    total_health_loss = 0.0
    total_grad_norm = 0.0
    num_updates = 0

    for batch_idx, (sequences, targets) in enumerate(train_loader):
        # Move data to device
        sequences = sequences.to(device)
        rul_targets = targets.to(device).view(-1, 1)

        # Generate health state labels from the (CPU) targets
        health_targets = rul_to_health_state(targets.numpy())
        health_targets = torch.tensor(health_targets, dtype=torch.long).to(device)

        # Zero gradients at accumulation boundary
        if batch_idx % accumulation_steps == 0:
            optimizer.zero_grad()

        if scaler is not None:
            # Mixed precision path: autocast wraps only the forward pass
            with torch.cuda.amp.autocast():
                rul_pred, health_pred = model(sequences)
                rul_loss = rul_criterion(rul_pred, rul_targets)
                health_loss = health_criterion(health_pred, health_targets)

            # EMA-based normalization
            if rul_loss_ema is None:
                rul_loss_ema = rul_loss.item()
                health_loss_ema = health_loss.item()
            else:
                rul_loss_ema = 0.9 * rul_loss_ema + 0.1 * rul_loss.item()
                health_loss_ema = 0.9 * health_loss_ema + 0.1 * health_loss.item()

            rul_scale = max(rul_loss_ema, 1e-6)
            health_scale = max(health_loss_ema, 1e-6)

            normalized_rul = rul_loss / rul_scale
            normalized_health = health_loss / health_scale

            # AMNL loss (equal weighting of normalized task losses)
            combined_loss = 0.5 * normalized_rul + 0.5 * normalized_health

            # Backward with loss scaling
            scaled_loss = combined_loss / accumulation_steps
            scaler.scale(scaled_loss).backward()

            # Optimizer step at accumulation boundary
            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=max_grad_norm
                )
                scaler.step(optimizer)
                scaler.update()

                if ema is not None:
                    ema.update(model)

                total_grad_norm += grad_norm.item()
                num_updates += 1
        else:
            # Non-mixed precision path (CPU or MPS)
            rul_pred, health_pred = model(sequences)
            rul_loss = rul_criterion(rul_pred, rul_targets)
            health_loss = health_criterion(health_pred, health_targets)

            # EMA-based normalization (identical to the mixed precision path;
            # must also run here, otherwise the EMAs would never initialize)
            if rul_loss_ema is None:
                rul_loss_ema = rul_loss.item()
                health_loss_ema = health_loss.item()
            else:
                rul_loss_ema = 0.9 * rul_loss_ema + 0.1 * rul_loss.item()
                health_loss_ema = 0.9 * health_loss_ema + 0.1 * health_loss.item()

            combined_loss = (0.5 * (rul_loss / max(rul_loss_ema, 1e-6))
                             + 0.5 * (health_loss / max(health_loss_ema, 1e-6)))

            scaled_loss = combined_loss / accumulation_steps
            scaled_loss.backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=max_grad_norm
                )
                optimizer.step()

                if ema is not None:
                    ema.update(model)

                total_grad_norm += grad_norm.item()
                num_updates += 1

        # Accumulate losses for logging
        total_loss += combined_loss.item()
        total_rul_loss += rul_loss.item()
        total_health_loss += health_loss.item()

    return {
        'loss': total_loss / len(train_loader),
        'rul_loss': total_rul_loss / len(train_loader),
        'health_loss': total_health_loss / len(train_loader),
        'avg_grad_norm': total_grad_norm / max(num_updates, 1),
        'rul_loss_ema': rul_loss_ema,
        'health_loss_ema': health_loss_ema
    }
```

Summary
In this section, we covered the training loop architecture:
- Three main phases: Initialization, epoch loop, finalization
- Epoch structure: Pre-epoch setup, training, evaluation, post-epoch
- Batch processing: Dual-task forward, AMNL loss, backward with accumulation
- Component integration: All enhancements work together
- Mixed precision: Optional GradScaler for CUDA GPUs
| Phase | Key Operations |
|---|---|
| Pre-epoch | LR warmup, adaptive weight decay |
| Training | Forward, loss, backward, clip, step, EMA |
| Evaluation | Apply EMA, evaluate, restore |
| Post-epoch | Scheduler step, early stopping, checkpoint |
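The post-epoch code calls an `early_stopping(rmse_last, model)` helper that this section assumes rather than defines. A minimal patience-based sketch (the patience value is an assumption; a full version would also snapshot the best weights):

```python
class EarlyStopping:
    """Stop when the monitored metric has not improved for `patience` epochs."""

    def __init__(self, patience: int = 20, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.counter = 0

    def __call__(self, metric: float, model=None) -> bool:
        # Lower is better (e.g. RMSE). Reset the counter on improvement.
        if metric < self.best - self.min_delta:
            self.best = metric
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```

Returning `True` signals the epoch loop to `break`, matching how the post-epoch snippet uses it.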
Looking Ahead: With the training loop structure defined, we next examine validation and model selection: how to evaluate models fairly and select the best checkpoint.