Chapter 14

Training Loop Architecture

Complete Training Script

Learning Objectives

By the end of this section, you will:

  1. Understand the training loop architecture for dual-task learning
  2. Structure epochs and batches for efficient training
  3. Integrate all optimization components (optimizer, scheduler, EMA)
  4. Handle dual-task forward and backward passes
  5. Implement the complete training epoch function
Why This Matters: The training loop is the heart of deep learning. A well-structured loop integrates all the components we have built (model, loss functions, optimizer, schedulers, and training enhancements) into a cohesive system that produces state-of-the-art results.

Training Loop Overview

The training loop orchestrates all components of the learning process.

High-Level Structure

πŸ“text
1AMNL Training Loop Architecture:
2
3β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
4β”‚                    INITIALIZATION                            β”‚
5β”‚  β€’ Set random seeds for reproducibility                      β”‚
6β”‚  β€’ Load datasets (train, test)                               β”‚
7β”‚  β€’ Create model, optimizer, scheduler                        β”‚
8β”‚  β€’ Initialize EMA, early stopping                            β”‚
9β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
10                          β”‚
11                          β–Ό
12β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
13β”‚                    EPOCH LOOP                                β”‚
14β”‚  for epoch in range(max_epochs):                             β”‚
15β”‚      β”‚                                                       β”‚
16β”‚      β”œβ”€β”€ Apply learning rate warmup (if epoch < warmup)      β”‚
17β”‚      β”œβ”€β”€ Update adaptive weight decay                        β”‚
18β”‚      β”‚                                                       β”‚
19β”‚      β”œβ”€β”€ TRAINING PHASE ─────────────────────────────────┐   β”‚
20β”‚      β”‚   for batch in train_loader:                      β”‚   β”‚
21β”‚      β”‚       β€’ Forward pass (dual-task)                  β”‚   β”‚
22β”‚      β”‚       β€’ Compute AMNL loss                         β”‚   β”‚
23β”‚      β”‚       β€’ Backward pass (with scaling)              β”‚   β”‚
24β”‚      β”‚       β€’ Gradient clipping                         β”‚   β”‚
25β”‚      β”‚       β€’ Optimizer step (with accumulation)        β”‚   β”‚
26β”‚      β”‚       β€’ EMA update                                β”‚   β”‚
27β”‚      β”‚   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β”‚
28β”‚      β”‚                                                       β”‚
29β”‚      β”œβ”€β”€ EVALUATION PHASE ───────────────────────────────┐   β”‚
30β”‚      β”‚   β€’ Apply EMA weights                             β”‚   β”‚
31β”‚      β”‚   β€’ Evaluate on test set                          β”‚   β”‚
32β”‚      β”‚   β€’ Restore training weights                      β”‚   β”‚
33β”‚      β”‚   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β”‚
34β”‚      β”‚                                                       β”‚
35β”‚      β”œβ”€β”€ Update scheduler (based on validation metric)       β”‚
36β”‚      β”œβ”€β”€ Check early stopping                                β”‚
37β”‚      β”œβ”€β”€ Save best model checkpoint                          β”‚
38β”‚      └── Log progress                                        β”‚
39β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
40                          β”‚
41                          β–Ό
42β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
43β”‚                    FINALIZATION                              β”‚
44β”‚  β€’ Restore best model weights                                β”‚
45β”‚  β€’ Final evaluation                                          β”‚
46β”‚  β€’ Save model and training history                           β”‚
47β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Component Integration

| Component         | Role                | When Applied                |
|-------------------|---------------------|-----------------------------|
| AdamW             | Weight updates      | After gradient accumulation |
| LR Warmup         | Gradual LR increase | First 10 epochs             |
| Cosine Annealing  | LR decay            | After warmup                |
| Weight Decay      | Regularization      | Every step (adaptive)       |
| Gradient Clipping | Stability           | Before optimizer step       |
| EMA               | Weight smoothing    | After optimizer step        |
| Early Stopping    | Prevent overfitting | End of each epoch           |

Epoch Structure

Each training epoch consists of distinct phases with specific responsibilities.

Phase 1: Pre-Epoch Setup

```python
# Pre-epoch setup
def prepare_epoch(
    epoch: int,
    optimizer: torch.optim.Optimizer,
    base_lr: float,
    warmup_epochs: int = 10,
    initial_wd: float = 1e-4
):
    """
    Prepare the optimizer for the current epoch.

    Applies learning rate warmup and adaptive weight decay.
    """
    # Learning rate warmup: ramp linearly from 10% of base_lr
    if epoch < warmup_epochs:
        warmup_factor = 0.1 + 0.9 * (epoch / warmup_epochs)
        for param_group in optimizer.param_groups:
            param_group['lr'] = base_lr * warmup_factor

    # Adaptive weight decay: relax regularization as training progresses
    if epoch < 100:
        current_wd = initial_wd
    elif epoch < 200:
        current_wd = initial_wd * 0.5
    else:
        current_wd = initial_wd * 0.1

    for param_group in optimizer.param_groups:
        param_group['weight_decay'] = current_wd

    return current_wd
```
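A quick standalone check of the warmup schedule: the factor ramps linearly from 0.1 at epoch 0 to 0.91 at epoch 9, so the full base learning rate only applies once warmup ends at epoch 10:

```python
# Warmup factor schedule used inside prepare_epoch (standalone check)
warmup_epochs = 10
factors = [0.1 + 0.9 * (epoch / warmup_epochs) for epoch in range(warmup_epochs)]

# Epoch 0 starts at 10% of the base LR; epoch 9 reaches 91%.
# From epoch 10 on, the warmup branch no longer fires and the
# scheduler controls the learning rate directly.
```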

Phase 2: Training Phase

The training phase processes all batches with gradient accumulation:

```python
# Training phase structure
model.train()
train_loss = 0.0
rul_loss_epoch = 0.0
health_loss_epoch = 0.0

for batch_idx, (sequences, targets) in enumerate(train_loader):
    # Move to device
    sequences = sequences.to(device)
    rul_targets = targets.to(device).view(-1, 1)

    # Generate health state labels
    health_targets = rul_to_health_state(targets.numpy())
    health_targets = torch.tensor(health_targets, dtype=torch.long).to(device)

    # Zero gradients at accumulation boundary
    if batch_idx % accumulation_steps == 0:
        optimizer.zero_grad()

    # Forward pass, loss computation, backward pass produce total_loss
    # (detailed in the next section)

    # Optimizer step at accumulation boundary
    if (batch_idx + 1) % accumulation_steps == 0:
        # Gradient clipping, optimizer step, EMA update
        pass

    # Accumulate losses for logging
    train_loss += total_loss.item()
```
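`rul_to_health_state` converts continuous RUL targets into discrete health-state class labels for the classification head. A minimal sketch of such a mapping; the three-class scheme and the thresholds here are illustrative assumptions, not the chapter's exact boundaries:

```python
import numpy as np

def rul_to_health_state(rul: np.ndarray,
                        healthy_threshold: float = 90.0,
                        degrading_threshold: float = 30.0) -> np.ndarray:
    """Map continuous RUL values to discrete health-state labels.

    Hypothetical 3-class scheme: 0 = healthy, 1 = degrading, 2 = critical.
    Thresholds are illustrative, not taken from the chapter.
    """
    states = np.full(rul.shape, 2, dtype=np.int64)   # critical by default
    states[rul > degrading_threshold] = 1            # degrading
    states[rul > healthy_threshold] = 0              # healthy
    return states
```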

Phase 3: Evaluation Phase

```python
# Evaluation phase structure
model.eval()

# Apply EMA weights for evaluation
if ema is not None:
    ema.apply_shadow(model)

# Comprehensive evaluation
eval_results = evaluate_model_comprehensive(model, test_dataset, device)

# Restore training weights
if ema is not None:
    ema.restore(model)

# Extract key metrics
rmse_last = eval_results['RMSE_last_cycle']
nasa_score = eval_results['nasa_score_paper']
```
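The evaluation phase depends on the `ExponentialMovingAverage` helper with an `update`/`apply_shadow`/`restore` interface. For reference, a minimal sketch of such a class; the chapter's actual implementation may differ in details such as the decay value or which buffers are tracked:

```python
import copy
import torch
import torch.nn as nn

class ExponentialMovingAverage:
    """Minimal EMA tracker (sketch; interface only is assumed)."""

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = decay
        # Shadow copy of all parameters and buffers
        self.shadow = {k: v.detach().clone()
                       for k, v in model.state_dict().items()}
        self.backup = {}

    @torch.no_grad()
    def update(self, model: nn.Module):
        # shadow = decay * shadow + (1 - decay) * current
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay)
            else:
                self.shadow[k] = v.detach().clone()

    def apply_shadow(self, model: nn.Module):
        # Swap in the smoothed weights for evaluation
        self.backup = copy.deepcopy(model.state_dict())
        model.load_state_dict(self.shadow)

    def restore(self, model: nn.Module):
        # Return to the live training weights
        model.load_state_dict(self.backup)
```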

Phase 4: Post-Epoch Updates

```python
# Post-epoch updates
# Update learning rate scheduler (after warmup)
if epoch >= warmup_epochs:
    scheduler.step(rmse_last)  # ReduceLROnPlateau

# Check early stopping
if early_stopping(rmse_last, model):
    print(f"Early stopping at epoch {epoch + 1}")
    break

# Save best model (copy imported at the top of the script)
if rmse_last < best_rmse:
    best_rmse = rmse_last
    best_model_state = copy.deepcopy(model.state_dict())

# Log progress
print(f"Epoch {epoch + 1}: RMSE={rmse_last:.2f}, NASA={nasa_score:.1f}")
```
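The `early_stopping` callable above returns `True` once the monitored metric stops improving. A minimal sketch of such a helper, with an assumed patience default; the chapter's actual class may also snapshot the best weights:

```python
class EarlyStopping:
    """Minimal early-stopping helper (sketch; names and defaults assumed).

    __call__ returns True once the monitored metric (lower is better)
    has not improved for `patience` consecutive epochs.
    """

    def __init__(self, patience: int = 20, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('inf')
        self.counter = 0

    def __call__(self, metric: float, model=None) -> bool:
        if metric < self.best - self.min_delta:
            # Improvement: reset the patience counter
            self.best = metric
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience
```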

Batch Processing

Each batch undergoes a complete forward-backward cycle for dual-task learning.

Forward Pass: Dual-Task Predictions

```python
# Forward pass with mixed precision
with torch.cuda.amp.autocast():
    # Get both predictions from the dual-task model
    rul_pred, health_pred = model(sequences)

    # Compute individual task losses
    rul_loss = weighted_mse_loss(rul_pred, rul_targets)
    health_loss = cross_entropy_loss(health_pred, health_targets)

    # EMA-based loss normalization for AMNL
    if rul_loss_ema is None:
        rul_loss_ema = rul_loss.item()
        health_loss_ema = health_loss.item()
    else:
        rul_loss_ema = 0.9 * rul_loss_ema + 0.1 * rul_loss.item()
        health_loss_ema = 0.9 * health_loss_ema + 0.1 * health_loss.item()

    # Normalize losses by their EMA values
    rul_scale = max(rul_loss_ema, 1e-6)
    health_scale = max(health_loss_ema, 1e-6)

    normalized_rul = rul_loss / rul_scale
    normalized_health = health_loss / health_scale

    # AMNL combined loss (equal weighting)
    total_loss = 0.5 * normalized_rul + 0.5 * normalized_health
```
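`weighted_mse_loss` is the RUL criterion developed earlier in the chapter. As a reference point, here is a hedged sketch of one common weighting scheme, penalizing RUL overestimates (predicting failure later than it actually occurs) more heavily; the chapter's actual weighting may differ:

```python
import torch

def weighted_mse_loss(pred: torch.Tensor, target: torch.Tensor,
                      over_weight: float = 2.0) -> torch.Tensor:
    """Sketch of a weighted MSE for RUL regression.

    Positive errors (pred > target, i.e. overestimated RUL) are weighted
    by `over_weight`; the factor of 2.0 is an illustrative assumption.
    """
    err = pred - target
    weights = torch.ones_like(err)
    weights[err > 0] = over_weight   # penalize late failure predictions more
    return (weights * err ** 2).mean()
```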

Backward Pass: Gradient Computation

```python
# Backward pass with gradient scaling
if scaler is not None:
    # Scale loss for accumulation and mixed precision
    scaled_loss = total_loss / accumulation_steps
    scaler.scale(scaled_loss).backward()

    # At accumulation boundary: clip, step, update
    if (batch_idx + 1) % accumulation_steps == 0:
        scaler.unscale_(optimizer)
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=1.0
        )
        scaler.step(optimizer)
        scaler.update()

        # Update EMA
        if ema is not None:
            ema.update(model)
else:
    # Standard backward pass (CPU or MPS)
    scaled_loss = total_loss / accumulation_steps
    scaled_loss.backward()

    if (batch_idx + 1) % accumulation_steps == 0:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(),
            max_norm=1.0
        )
        optimizer.step()

        if ema is not None:
            ema.update(model)
```
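A useful sanity check on the accumulation logic: because the loss is divided by `accumulation_steps` before `backward()`, the gradient accumulated over N micro-batches equals the gradient of a single batch N times larger, so the effective batch size is `batch_size * accumulation_steps`. A small standalone demonstration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
x = torch.randn(8, 4)
y = torch.randn(8, 1)
loss_fn = nn.MSELoss()

# One full-batch backward pass
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Two accumulated half-batches, each loss divided by accumulation_steps = 2
model.zero_grad()
for chunk_x, chunk_y in zip(x.chunk(2), y.chunk(2)):
    (loss_fn(model(chunk_x), chunk_y) / 2).backward()
accum_grad = model.weight.grad.clone()

# full_grad and accum_grad match up to floating-point error
```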

Implementation

Complete training epoch function.

Full Training Epoch

```python
from typing import Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_epoch(
    model: nn.Module,
    train_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    rul_criterion,
    health_criterion: nn.Module,
    scaler: Optional[torch.cuda.amp.GradScaler],
    ema: Optional[ExponentialMovingAverage],
    rul_loss_ema: Optional[float],
    health_loss_ema: Optional[float],
    accumulation_steps: int = 2,
    max_grad_norm: float = 1.0,
    device: torch.device = torch.device('cuda')
) -> dict:
    """
    Execute one training epoch.

    Args:
        model: Dual-task model
        train_loader: Training data loader
        optimizer: AdamW optimizer
        rul_criterion: Weighted MSE loss
        health_criterion: Cross-entropy loss
        scaler: GradScaler for mixed precision
        ema: EMA tracker
        rul_loss_ema: Running EMA of RUL loss
        health_loss_ema: Running EMA of health loss
        accumulation_steps: Gradient accumulation steps
        max_grad_norm: Maximum gradient norm
        device: Training device

    Returns:
        Dictionary with epoch statistics
    """
    model.train()

    total_loss = 0.0
    total_rul_loss = 0.0
    total_health_loss = 0.0
    total_grad_norm = 0.0
    num_updates = 0

    for batch_idx, (sequences, targets) in enumerate(train_loader):
        # Move data to device
        sequences = sequences.to(device)
        rul_targets = targets.to(device).view(-1, 1)

        # Generate health state labels
        health_targets = rul_to_health_state(targets.numpy())
        health_targets = torch.tensor(health_targets, dtype=torch.long).to(device)

        # Zero gradients at accumulation boundary
        if batch_idx % accumulation_steps == 0:
            optimizer.zero_grad()

        # Forward pass with mixed precision
        if scaler is not None:
            with torch.cuda.amp.autocast():
                rul_pred, health_pred = model(sequences)
                rul_loss = rul_criterion(rul_pred, rul_targets)
                health_loss = health_criterion(health_pred, health_targets)

                # EMA-based normalization
                if rul_loss_ema is None:
                    rul_loss_ema = rul_loss.item()
                    health_loss_ema = health_loss.item()
                else:
                    rul_loss_ema = 0.9 * rul_loss_ema + 0.1 * rul_loss.item()
                    health_loss_ema = 0.9 * health_loss_ema + 0.1 * health_loss.item()

                rul_scale = max(rul_loss_ema, 1e-6)
                health_scale = max(health_loss_ema, 1e-6)

                normalized_rul = rul_loss / rul_scale
                normalized_health = health_loss / health_scale

                # AMNL loss
                combined_loss = 0.5 * normalized_rul + 0.5 * normalized_health

            # Backward with scaling
            scaled_loss = combined_loss / accumulation_steps
            scaler.scale(scaled_loss).backward()

            # Optimizer step at accumulation boundary
            if (batch_idx + 1) % accumulation_steps == 0:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=max_grad_norm
                )
                scaler.step(optimizer)
                scaler.update()

                if ema is not None:
                    ema.update(model)

                total_grad_norm += grad_norm.item()
                num_updates += 1
        else:
            # Non-mixed-precision path (CPU or MPS)
            rul_pred, health_pred = model(sequences)
            rul_loss = rul_criterion(rul_pred, rul_targets)
            health_loss = health_criterion(health_pred, health_targets)

            # EMA normalization and AMNL (same as the mixed-precision path)
            if rul_loss_ema is None:
                rul_loss_ema = rul_loss.item()
                health_loss_ema = health_loss.item()
            else:
                rul_loss_ema = 0.9 * rul_loss_ema + 0.1 * rul_loss.item()
                health_loss_ema = 0.9 * health_loss_ema + 0.1 * health_loss.item()

            combined_loss = (
                0.5 * (rul_loss / max(rul_loss_ema, 1e-6))
                + 0.5 * (health_loss / max(health_loss_ema, 1e-6))
            )

            scaled_loss = combined_loss / accumulation_steps
            scaled_loss.backward()

            if (batch_idx + 1) % accumulation_steps == 0:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=max_grad_norm
                )
                optimizer.step()

                if ema is not None:
                    ema.update(model)

                total_grad_norm += grad_norm.item()
                num_updates += 1

        # Accumulate losses
        total_loss += combined_loss.item()
        total_rul_loss += rul_loss.item()
        total_health_loss += health_loss.item()

    return {
        'loss': total_loss / len(train_loader),
        'rul_loss': total_rul_loss / len(train_loader),
        'health_loss': total_health_loss / len(train_loader),
        'avg_grad_norm': total_grad_norm / max(num_updates, 1),
        'rul_loss_ema': rul_loss_ema,
        'health_loss_ema': health_loss_ema
    }
```
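The effect of the EMA-based normalization inside `train_epoch` is easy to verify numerically: dividing each task loss by its running EMA keeps both normalized terms near 1.0 even when the raw losses differ by orders of magnitude, so neither task dominates the combined gradient. The loss values below are illustrative only:

```python
# AMNL normalization demo with made-up loss magnitudes:
# a large regression loss (~1400) and a small classification loss (~0.85)
rul_loss_ema = None
health_loss_ema = None
for rul_loss, health_loss in [(1500.0, 0.9), (1400.0, 0.8), (1350.0, 0.85)]:
    if rul_loss_ema is None:
        rul_loss_ema, health_loss_ema = rul_loss, health_loss
    else:
        rul_loss_ema = 0.9 * rul_loss_ema + 0.1 * rul_loss
        health_loss_ema = 0.9 * health_loss_ema + 0.1 * health_loss
    # Both normalized terms stay close to 1.0 despite the ~1700x scale gap
    combined = 0.5 * rul_loss / rul_loss_ema + 0.5 * health_loss / health_loss_ema
```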

Summary

In this section, we covered the training loop architecture:

  1. Three main phases: Initialization, epoch loop, finalization
  2. Epoch structure: Pre-epoch setup, training, evaluation, post-epoch
  3. Batch processing: Dual-task forward, AMNL loss, backward with accumulation
  4. Component integration: All enhancements work together
  5. Mixed precision: Optional GradScaler for CUDA GPUs

| Phase      | Key Operations                             |
|------------|--------------------------------------------|
| Pre-epoch  | LR warmup, adaptive weight decay           |
| Training   | Forward, loss, backward, clip, step, EMA   |
| Evaluation | Apply EMA, evaluate, restore               |
| Post-epoch | Scheduler step, early stopping, checkpoint |

Looking Ahead: With the training loop structure defined, we next examine validation and model selection: how to evaluate models fairly and select the best checkpoint.
