Learning Objectives
By the end of this section, you will be able to:
- Identify class imbalance: Recognize when a dataset has imbalanced class distributions and measure the severity
- Apply data-level solutions: Use oversampling, undersampling, and weighted sampling strategies in PyTorch
- Use algorithm-level solutions: Implement class-weighted loss functions and focal loss
- Choose appropriate metrics: Select evaluation metrics that properly assess performance on imbalanced data
What is Class Imbalance?
Class imbalance occurs when one or more classes in a classification dataset have significantly fewer samples than others. This is extremely common in real-world applications:
| Domain | Minority Class | Typical Imbalance Ratio |
|---|---|---|
| Fraud Detection | Fraudulent transactions | 1:1000 to 1:10000 |
| Medical Diagnosis | Rare diseases | 1:100 to 1:1000 |
| Spam Detection | Spam emails (or ham) | Varies, often 1:10 |
| Defect Detection | Defective items | 1:100 to 1:10000 |
| Churn Prediction | Churning customers | 1:5 to 1:20 |
Real-World Example: In credit card fraud detection, less than 0.1% of transactions are fraudulent. If you have 1 million transactions, only ~1,000 might be fraud. Training a model on this data without handling the imbalance will lead to poor fraud detection.
Why Imbalance Matters
Standard training on imbalanced data leads to models that ignore minority classes. Here's why:
The Accuracy Trap
Consider a fraud detection dataset with 99% legitimate and 1% fraudulent transactions:
```python
# A model that ALWAYS predicts "legitimate"
# achieves 99% accuracy but catches ZERO fraud!

y_true = [0] * 990 + [1] * 10  # 99% class 0, 1% class 1
y_pred = [0] * 1000  # Always predict class 0

accuracy = sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)
print(f"Accuracy: {accuracy:.1%}")  # 99.0% - looks great!

fraud_caught = sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 1)
print(f"Fraud caught: {fraud_caught}")  # 0 - completely useless!
```

Never Trust Accuracy Alone
On imbalanced data, accuracy mostly reflects performance on the majority class. Always check per-class metrics before trusting a high accuracy number.
Gradient Domination
During training, the loss is averaged across all samples. With imbalanced data:
- The majority class contributes most of the gradient signal
- The model learns to minimize majority class errors
- Minority class patterns get "drowned out"
- The model becomes biased toward predicting the majority class
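A back-of-the-envelope calculation makes the scale of the problem concrete. This sketch (plain Python, assuming for simplicity that every sample contributes the same per-sample loss) shows how much of the averaged loss comes from the majority class:

```python
# Hypothetical batch: 990 majority samples, 10 minority samples,
# each contributing an identical per-sample loss of 0.5
n_major, n_minor = 990, 10
per_sample_loss = 0.5

total_loss = (n_major + n_minor) * per_sample_loss
majority_share = (n_major * per_sample_loss) / total_loss

print(f"Majority class share of the loss: {majority_share:.1%}")  # 99.0%
```

Even when the model performs equally badly on both classes, 99% of the gradient signal pushes it toward fitting the majority class.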
Measuring Imbalance
Before addressing imbalance, you need to quantify it:
Imbalance Ratio
```python
from collections import Counter

def analyze_class_distribution(labels):
    """Analyze and report class distribution statistics."""
    counts = Counter(labels)
    total = len(labels)

    print("Class Distribution:")
    for class_id, count in sorted(counts.items()):
        pct = count / total * 100
        print(f"  Class {class_id}: {count:,} samples ({pct:.2f}%)")

    # Calculate imbalance ratio
    max_count = max(counts.values())
    min_count = min(counts.values())
    imbalance_ratio = max_count / min_count

    print(f"\nImbalance Ratio: {imbalance_ratio:.1f}:1")

    # Severity classification
    if imbalance_ratio < 3:
        severity = "Mild"
    elif imbalance_ratio < 10:
        severity = "Moderate"
    elif imbalance_ratio < 100:
        severity = "Significant"
    else:
        severity = "Extreme"

    print(f"Severity: {severity}")
    return counts

# Example usage
labels = [0] * 9000 + [1] * 900 + [2] * 100
analyze_class_distribution(labels)
```

| Imbalance Ratio | Severity | Recommended Actions |
|---|---|---|
| < 3:1 | Mild | Standard training often works; monitor metrics |
| 3:1 - 10:1 | Moderate | Class weights usually sufficient |
| 10:1 - 100:1 | Significant | Combine sampling + weighted loss |
| > 100:1 | Extreme | Aggressive techniques needed; consider problem reformulation |
Data-Level Solutions
Data-level solutions modify the training dataset to achieve better balance.
Oversampling
Oversampling increases minority class representation by duplicating or generating new samples.
Random Oversampling
The simplest approach: randomly duplicate minority class samples.
```python
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader

def create_balanced_sampler(labels):
    """Create a sampler that balances class frequencies."""
    labels = torch.tensor(labels)

    # Count samples per class
    class_counts = torch.bincount(labels)

    # Weight each sample by the inverse of its class frequency
    weights = 1.0 / class_counts[labels].float()

    # Sampling with replacement lets minority samples repeat
    sampler = WeightedRandomSampler(
        weights=weights,
        num_samples=len(labels),
        replacement=True
    )
    return sampler

# Example usage (dataset is assumed to be defined elsewhere)
labels = [0] * 900 + [1] * 100  # 9:1 imbalance
sampler = create_balanced_sampler(labels)

# This DataLoader will sample each class equally often
dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler  # Replaces shuffle=True
)
```

Oversampling Trade-off
Oversampling repeats the same minority examples, which can lead to overfitting on them; pairing it with data augmentation mitigates this.
SMOTE (Synthetic Minority Over-sampling Technique)
SMOTE creates synthetic minority samples by interpolating between existing ones:

$$x_{\text{new}} = x_i + \lambda \cdot (x_{\text{nn}} - x_i), \qquad \lambda \sim \mathcal{U}(0, 1)$$

where $x_i$ is a minority sample and $x_{\text{nn}}$ is one of its $k$ nearest neighbors.
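The interpolation step itself is simple. Here's a minimal NumPy sketch of just that step (a toy illustration, not full SMOTE, which would first find each minority sample's k nearest neighbors):

```python
import numpy as np

rng = np.random.default_rng(42)

x_i = np.array([1.0, 2.0])   # a minority-class sample
x_nn = np.array([3.0, 4.0])  # one of its nearest neighbors (assumed precomputed)

lam = rng.uniform(0, 1)           # random interpolation factor in [0, 1)
x_new = x_i + lam * (x_nn - x_i)  # synthetic sample on the segment between them

print(x_new)  # lies between x_i and x_nn, componentwise
```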
```python
# Note: SMOTE works on feature vectors, not raw images
# Use the imbalanced-learn library for production SMOTE

from imblearn.over_sampling import SMOTE

# Assume X is a feature matrix, y is labels
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X, y)

print(f"Original: {len(X)} samples")
print(f"After SMOTE: {len(X_resampled)} samples")
```

SMOTE Limitations
Because SMOTE interpolates in feature space, it can generate unrealistic samples near class boundaries, degrades in high dimensions, and does not apply directly to raw images or text.
Undersampling
Undersampling reduces majority class samples to match minority class size.
```python
import numpy as np
from torch.utils.data import Subset

def create_undersampled_dataset(dataset, labels):
    """Undersample majority classes to match the minority class size."""
    labels = np.array(labels)
    classes, counts = np.unique(labels, return_counts=True)
    min_count = counts.min()

    selected_indices = []
    for cls in classes:
        cls_indices = np.where(labels == cls)[0]
        # Randomly select min_count samples from this class
        selected = np.random.choice(cls_indices, min_count, replace=False)
        selected_indices.extend(selected)

    return Subset(dataset, selected_indices)

# Example: 9000 majority, 100 minority -> 100 each
undersampled = create_undersampled_dataset(dataset, labels)
print(f"Original: {len(dataset)}, Undersampled: {len(undersampled)}")
```

Information Loss
Undersampling discards potentially informative majority-class data; in the example above, 8,900 of the 9,000 majority samples are thrown away.
Choosing Between Oversampling and Undersampling
| Situation | Recommendation | Reasoning |
|---|---|---|
| Large dataset, moderate imbalance | Undersampling | Plenty of data; training faster |
| Small dataset | Oversampling + augmentation | Can't afford to lose data |
| Extreme imbalance | Combination + weighted loss | Multiple techniques needed |
| Complex minority patterns | Oversampling | Need multiple views of minority class |
Algorithm-Level Solutions
Algorithm-level solutions modify the training process rather than the data.
Weighted Loss Functions
Assign higher weight to minority-class errors, making the model pay more attention to them:

$$\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} w_{y_i} \log p_{i, y_i}$$

where $w_c$ is the weight for class $c$, $n_c$ is the number of samples in class $c$, $N$ is the total number of samples, and $K$ is the number of classes. Common weighting schemes:
| Scheme | Formula | Effect |
|---|---|---|
| Inverse frequency | $w_c = N / (K \cdot n_c)$ | Proportional to class rarity |
| Inverse square root | $w_c = 1 / \sqrt{n_c}$ | Less aggressive than inverse frequency |
| Effective number | $w_c = (1 - \beta) / (1 - \beta^{n_c})$ | Accounts for data overlap |
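As a quick sanity check, the three schemes can be computed for the earlier example distribution of [9000, 900, 100]. This sketch uses NumPy; β = 0.999 is a typical choice for the effective-number scheme:

```python
import numpy as np

counts = np.array([9000, 900, 100])  # example class counts
N, K = counts.sum(), len(counts)

inv_freq = N / (K * counts)           # inverse frequency
inv_sqrt = 1.0 / np.sqrt(counts)      # inverse square root
beta = 0.999
eff_num = (1 - beta) / (1 - beta ** counts)  # effective number

print("Inverse frequency:", inv_freq)
print("Inverse sqrt:     ", inv_sqrt)
print("Effective number: ", eff_num)
```

All three assign the largest weight to the rarest class; they differ in how steeply the weight grows as a class gets rarer.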
Focal Loss
Focal loss (from RetinaNet) down-weights easy examples, focusing training on hard examples:

$$\text{FL}(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)$$

where $p_t$ is the model's probability for the true class, $\gamma$ is the focusing parameter (typically 2), and $\alpha_t$ is the class weight.
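Before the implementation, a quick numeric check of the focusing term shows how sharply it discounts confident predictions (plain Python, with the typical gamma of 2):

```python
# How the focal term (1 - pt)^gamma scales the per-sample loss
gamma = 2.0
for pt in (0.9, 0.6, 0.3):
    focal_weight = (1 - pt) ** gamma
    print(f"pt = {pt}: focal weight = {focal_weight:.3f}")
# A confident pt = 0.9 example is weighted ~1/50th as much as a hard pt = 0.3 example
```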
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal Loss for imbalanced classification."""

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha          # Per-class weights (optional)
        self.gamma = gamma          # Focusing parameter
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Per-sample cross entropy
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # Probability of the true class: ce = -log(pt), so pt = exp(-ce)
        pt = torch.exp(-ce_loss)

        # Down-weight easy examples
        focal_weight = (1 - pt) ** self.gamma

        # Apply class weights if provided
        if self.alpha is not None:
            alpha_t = self.alpha[targets]
            focal_loss = alpha_t * focal_weight * ce_loss
        else:
            focal_loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# Usage
alpha = torch.tensor([0.25, 0.75])  # Weight for each class
criterion = FocalLoss(alpha=alpha, gamma=2.0)
```

Focal Loss Intuition
When the model is already confident (probability of the true class near 1), the focal factor shrinks toward zero, so easy examples contribute almost nothing and training concentrates on hard, often minority-class, examples.
Evaluation Metrics
Choose metrics that reveal true performance on imbalanced data:
Confusion Matrix Metrics
| Metric | Formula | Best For |
|---|---|---|
| Precision | TP / (TP + FP) | When false positives are costly |
| Recall (Sensitivity) | TP / (TP + FN) | When missing positives is costly |
| F1 Score | 2 * (Precision * Recall) / (Precision + Recall) | Balance of precision and recall |
| Specificity | TN / (TN + FP) | When negatives matter too |
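These formulas can be checked by hand. Here's a toy worked example with hypothetical confusion-matrix counts (8 true positives, 2 false negatives, 4 false positives, 986 true negatives):

```python
TP, FN, FP, TN = 8, 2, 4, 986  # hypothetical counts for a 1% positive class

precision = TP / (TP + FP)
recall = TP / (TP + FN)
f1 = 2 * precision * recall / (precision + recall)
specificity = TN / (TN + FP)
accuracy = (TP + TN) / (TP + TN + FP + FN)

print(f"Precision:   {precision:.3f}")   # 0.667
print(f"Recall:      {recall:.3f}")      # 0.800
print(f"F1:          {f1:.3f}")          # 0.727
print(f"Specificity: {specificity:.3f}") # 0.996
print(f"Accuracy:    {accuracy:.3f}")    # 0.994 - high despite minority-class errors
```

Note how accuracy stays near 99% even though a third of the positive predictions are wrong; precision, recall, and F1 surface that weakness.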
```python
from sklearn.metrics import (
    precision_score, recall_score, f1_score,
    confusion_matrix, classification_report,
    balanced_accuracy_score, roc_auc_score
)

def evaluate_imbalanced(y_true, y_pred, y_prob=None):
    """Comprehensive evaluation for imbalanced classification."""

    print("Confusion Matrix:")
    print(confusion_matrix(y_true, y_pred))
    print()

    print("Classification Report:")
    print(classification_report(y_true, y_pred))

    # Balanced accuracy (mean of per-class recall)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    print(f"Balanced Accuracy: {bal_acc:.4f}")

    # ROC-AUC if probabilities are available
    if y_prob is not None:
        roc_auc = roc_auc_score(y_true, y_prob)
        print(f"ROC-AUC: {roc_auc:.4f}")
```

Which Metric to Prioritize?
| Scenario | Primary Metric | Reasoning |
|---|---|---|
| Fraud detection | Recall + precision | Catch fraud (recall) but don't block valid transactions (precision) |
| Medical screening | Recall | Missing a disease is worse than extra tests |
| Spam filtering | Precision | Blocking good email is worse than letting spam through |
| General imbalanced | F1 or balanced accuracy | Good balance for most cases |
| Threshold tuning | ROC-AUC or PR-AUC | Threshold-independent evaluation |
PR-AUC vs ROC-AUC
Under heavy imbalance, PR-AUC is usually the more informative of the two: the large pool of true negatives keeps the false positive rate low, so ROC-AUC can look optimistic even when precision is poor.
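To see the gap concretely, here's a small synthetic comparison (the scores are made up for illustration): three positives score high, two score low, and ten negatives land in the mid-range.

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

# Synthetic imbalanced problem: 95 negatives, 5 positives
y_true = np.array([0] * 95 + [1] * 5)

# Hypothetical scores: most negatives score low, 10 score mid-range;
# 3 positives score high, 2 score below the mid-range negatives
y_score = np.array(
    [0.1] * 85 + [0.4, 0.45, 0.5, 0.55, 0.6] * 2 + [0.9, 0.8, 0.7, 0.35, 0.3]
)

roc_auc = roc_auc_score(y_true, y_score)
pr_auc = average_precision_score(y_true, y_score)  # a common PR-AUC estimate

print(f"ROC-AUC: {roc_auc:.3f}")  # high: few negatives outrank any positive
print(f"PR-AUC:  {pr_auc:.3f}")   # much lower: mid-range negatives wreck precision
```

The same ranking yields a ROC-AUC above 0.95 but a PR-AUC around 0.72, because the handful of mid-range false positives barely moves the false positive rate yet sharply lowers precision.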
PyTorch Implementation
Complete Pipeline Example
Here's a complete training pipeline that combines the key techniques for imbalanced classification:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler, TensorDataset
from sklearn.metrics import classification_report, balanced_accuracy_score

def train_imbalanced_classifier(
    X_train, y_train, X_val, y_val,
    model,
    epochs=50,
    batch_size=32,
    use_weighted_sampling=True,
    use_weighted_loss=True
):
    """Train a classifier on imbalanced data with best practices."""

    # Analyze class distribution
    y_train_t = torch.tensor(y_train)
    class_counts = torch.bincount(y_train_t)
    print(f"Class distribution: {class_counts.tolist()}")
    imbalance_ratio = class_counts.max().item() / class_counts.min().item()
    print(f"Imbalance ratio: {imbalance_ratio:.1f}:1")

    # Setup loss function (inverse-frequency class weights)
    if use_weighted_loss:
        weights = len(y_train) / (len(class_counts) * class_counts.float())
        criterion = nn.CrossEntropyLoss(weight=weights)
        print(f"Class weights: {weights.tolist()}")
    else:
        criterion = nn.CrossEntropyLoss()

    # Setup data loader with optional weighted sampling
    train_dataset = TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train))

    if use_weighted_sampling:
        sample_weights = 1.0 / class_counts[y_train_t].float()
        sampler = WeightedRandomSampler(sample_weights, len(y_train), replacement=True)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    else:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    val_dataset = TensorDataset(torch.FloatTensor(X_val), torch.LongTensor(y_val))
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Training loop with model selection on balanced accuracy
    best_bal_acc = 0.0
    best_model_state = {k: v.clone() for k, v in model.state_dict().items()}
    best_preds, best_labels = [], []

    for epoch in range(epochs):
        model.train()
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                outputs = model(X_batch)
                preds = outputs.argmax(dim=1)
                all_preds.extend(preds.tolist())
                all_labels.extend(y_batch.tolist())

        bal_acc = balanced_accuracy_score(all_labels, all_preds)
        if bal_acc > best_bal_acc:
            best_bal_acc = bal_acc
            # Clone tensors so later epochs don't overwrite the snapshot
            best_model_state = {k: v.clone() for k, v in model.state_dict().items()}
            best_preds, best_labels = all_preds, all_labels

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: Balanced Accuracy = {bal_acc:.4f}")

    # Restore the best model and report on its predictions
    model.load_state_dict(best_model_state)
    print(f"\nBest Balanced Accuracy: {best_bal_acc:.4f}")
    print("\nFinal Classification Report:")
    print(classification_report(best_labels, best_preds))

    return model
```

Best Practices
Do's
- Always analyze class distribution first: Know your imbalance ratio before choosing solutions
- Use balanced accuracy or F1 for model selection: Not regular accuracy
- Stratify train/val/test splits: Ensure all splits have similar class proportions
- Combine techniques: Weighted sampling + weighted loss often works best
- Use appropriate augmentation: Especially for oversampled minority class
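The stratification point above is one line in scikit-learn. A minimal sketch with made-up data:

```python
import numpy as np
from sklearn.model_selection import train_test_split

y = np.array([0] * 950 + [1] * 50)     # 19:1 imbalance
X = np.arange(len(y)).reshape(-1, 1)   # placeholder features

# stratify=y keeps the class proportions identical across splits
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

print(f"Train positives: {y_train.sum()} / {len(y_train)}")  # 40 / 800
print(f"Val positives:   {y_val.sum()} / {len(y_val)}")      # 10 / 200
```

Without `stratify`, a random 20% split of a rare class can easily end up with far fewer (or zero) minority samples in validation, making the metrics unstable.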
Don'ts
- Don't rely on accuracy alone: Misleading for imbalanced data
- Don't oversample test/validation data: Only balance training data
- Don't use SMOTE on test data: Test must reflect real distribution
- Don't assume one technique works for all: Experiment with combinations
Quick Check
You have a fraud detection dataset with 1% fraud and 99% legitimate transactions. What's wrong with using accuracy as your primary metric?
Summary
Handling imbalanced data requires a multi-pronged approach:
| Technique | When to Use | PyTorch Implementation |
|---|---|---|
| Weighted Sampling | Most cases, especially moderate imbalance | WeightedRandomSampler |
| Class-Weighted Loss | Almost always beneficial | CrossEntropyLoss(weight=...) |
| Focal Loss | Extreme imbalance, many easy examples | Custom FocalLoss class |
| Oversampling + Augmentation | Small minority class | WeightedRandomSampler + transforms |
| Undersampling | Large dataset, computational constraints | Subset with random selection |
Key Takeaway
No single technique fixes class imbalance. Measure the imbalance first, combine data-level and algorithm-level remedies as severity demands, and judge the result with imbalance-aware metrics rather than raw accuracy.
Exercises
Diagnostic Exercises
- You have a dataset with class distribution [8000, 1500, 500]. Calculate the imbalance ratio and recommend appropriate techniques.
- A medical diagnosis model has 95% accuracy but only 40% recall for the disease class. Is this acceptable? What should you do?
- Your colleague trained with weighted sampling but is evaluating on a weighted test set too. What's wrong with this approach?
Solution Hints
- E1: Ratio is 8000:500 = 16:1 (significant). Use weighted loss + sampling. Consider focal loss if results are poor.
- E2: Not acceptable - missing 60% of disease cases is dangerous. Increase class weight for disease, lower decision threshold, or use recall-optimized threshold.
- E3: Test data should reflect real distribution, never balanced. Weighted evaluation on weighted test gives misleading metrics. Only balance training data.
Coding Exercises
- Implement class weighting comparison: Train the same model on MNIST with only 100 samples of digit "9" vs 1000 samples of other digits. Compare accuracy, balanced accuracy, and per-class recall with and without class weights.
- Threshold tuning: Train a binary classifier on imbalanced data. Plot precision and recall vs. decision threshold. Find the threshold that maximizes F1 score.
- Sampling strategies: Implement and compare three approaches: (a) no balancing, (b) weighted sampling only, (c) weighted sampling + weighted loss. Use balanced accuracy as the evaluation metric.
Exercise Tips
- Exercise 1: Create imbalanced MNIST by subsampling digit 9. Use sklearn's classification_report to see per-class metrics.
- Exercise 2: Use model outputs with sigmoid, then test thresholds from 0.1 to 0.9. sklearn.metrics.precision_recall_curve is helpful.
- Exercise 3: Keep validation set unchanged across all experiments for fair comparison. Only modify training setup.
In the next chapter, we'll explore Backpropagation and Automatic Differentiation—the mathematical foundation that makes neural network training possible.