Chapter 7

Handling Imbalanced Data

Data Loading and Processing

Learning Objectives

By the end of this section, you will be able to:

  1. Identify class imbalance: Recognize when a dataset has imbalanced class distributions and measure the severity
  2. Apply data-level solutions: Use oversampling, undersampling, and weighted sampling strategies in PyTorch
  3. Use algorithm-level solutions: Implement class-weighted loss functions and focal loss
  4. Choose appropriate metrics: Select evaluation metrics that properly assess performance on imbalanced data

What is Class Imbalance?

Class imbalance occurs when one or more classes in a classification dataset have significantly fewer samples than others. This is extremely common in real-world applications:

| Domain | Minority Class | Typical Imbalance Ratio |
|---|---|---|
| Fraud Detection | Fraudulent transactions | 1:1000 to 1:10000 |
| Medical Diagnosis | Rare diseases | 1:100 to 1:1000 |
| Spam Detection | Spam emails (or ham) | Varies, often 1:10 |
| Defect Detection | Defective items | 1:100 to 1:10000 |
| Churn Prediction | Churning customers | 1:5 to 1:20 |

Real-World Example: In credit card fraud detection, less than 0.1% of transactions are fraudulent. If you have 1 million transactions, only ~1,000 might be fraud. Training a model on this data without handling the imbalance will lead to poor fraud detection.

Why Imbalance Matters

Standard training on imbalanced data leads to models that ignore minority classes. Here's why:

The Accuracy Trap

Consider a fraud detection dataset with 99% legitimate and 1% fraudulent transactions:

🐍accuracy_trap.py

```python
# A model that ALWAYS predicts "legitimate"
# achieves 99% accuracy but catches ZERO fraud!

y_true = [0] * 990 + [1] * 10  # 99% class 0, 1% class 1
y_pred = [0] * 1000            # Always predict class 0

accuracy = sum(p == t for p, t in zip(y_pred, y_true)) / len(y_true)
print(f"Accuracy: {accuracy:.1%}")  # 99.0% - looks great!

fraud_caught = sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 1)
print(f"Fraud caught: {fraud_caught}")  # 0 - completely useless!
```

Never Trust Accuracy Alone

On imbalanced datasets, accuracy is misleading. A 99% accurate model might be completely useless if it fails to detect the minority class that you actually care about.
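Continuing the accuracy-trap example, minority-aware metrics expose what accuracy hides. A minimal pure-Python sketch (the full metric definitions appear later in this section):

```python
# Same always-majority predictor as above; check minority-aware metrics
y_true = [0] * 990 + [1] * 10
y_pred = [0] * 1000  # always predict the majority class

# Recall for the fraud class: fraction of actual frauds that were caught
tp = sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 1)
fn = sum(1 for p, t in zip(y_pred, y_true) if p == 0 and t == 1)
recall = tp / (tp + fn)
print(f"Fraud recall: {recall:.1%}")  # 0.0% -- no fraud is ever caught

# Balanced accuracy: mean of per-class recalls
tn = sum(1 for p, t in zip(y_pred, y_true) if p == 0 and t == 0)
fp = sum(1 for p, t in zip(y_pred, y_true) if p == 1 and t == 0)
specificity = tn / (tn + fp)
balanced_acc = (recall + specificity) / 2
print(f"Balanced accuracy: {balanced_acc:.1%}")  # 50.0% -- chance level
```

The same model that scores 99% on accuracy scores exactly chance level on balanced accuracy.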

Gradient Domination

During training, the loss is averaged across all samples. With imbalanced data:

  1. The majority class contributes most of the gradient signal
  2. The model learns to minimize majority class errors
  3. Minority class patterns get "drowned out"
  4. The model becomes biased toward predicting the majority class
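The "drowning out" in the steps above can be made concrete with a back-of-the-envelope sketch (the per-sample loss value is an illustrative assumption, valid when predictions are near-uniform early in training):

```python
import math

# Share of the summed loss (and hence gradient signal) per class.
# Early in training, softmax outputs are near uniform, so every sample
# contributes roughly the same per-sample cross-entropy, about log(num_classes).
class_counts = {"legitimate": 990, "fraud": 10}
per_sample_loss = math.log(2)  # binary problem, near-uniform predictions

total = sum(n * per_sample_loss for n in class_counts.values())
shares = {name: n * per_sample_loss / total for name, n in class_counts.items()}
for name, share in shares.items():
    print(f"{name}: {share:.1%} of the gradient signal")
# legitimate dominates at 99.0%; fraud contributes only 1.0%
```

With 99% of the gradient pushing toward the majority class, the minority class barely moves the weights.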

Measuring Imbalance

Before addressing imbalance, you need to quantify it:

Imbalance Ratio

$$\text{Imbalance Ratio} = \frac{\text{Majority Class Samples}}{\text{Minority Class Samples}}$$
🐍measure_imbalance.py

```python
from collections import Counter

def analyze_class_distribution(labels):
    """Analyze and report class distribution statistics."""
    counts = Counter(labels)
    total = len(labels)

    print("Class Distribution:")
    for class_id, count in sorted(counts.items()):
        pct = count / total * 100
        print(f"  Class {class_id}: {count:,} samples ({pct:.2f}%)")

    # Calculate imbalance ratio
    max_count = max(counts.values())
    min_count = min(counts.values())
    imbalance_ratio = max_count / min_count

    print(f"\nImbalance Ratio: {imbalance_ratio:.1f}:1")

    # Severity classification
    if imbalance_ratio < 3:
        severity = "Mild"
    elif imbalance_ratio < 10:
        severity = "Moderate"
    elif imbalance_ratio < 100:
        severity = "Significant"
    else:
        severity = "Extreme"

    print(f"Severity: {severity}")
    return counts

# Example usage
labels = [0] * 9000 + [1] * 900 + [2] * 100
analyze_class_distribution(labels)
```
| Imbalance Ratio | Severity | Recommended Actions |
|---|---|---|
| < 3:1 | Mild | Standard training often works; monitor metrics |
| 3:1 - 10:1 | Moderate | Class weights usually sufficient |
| 10:1 - 100:1 | Significant | Combine sampling + weighted loss |
| > 100:1 | Extreme | Aggressive techniques needed; consider problem reformulation |

Data-Level Solutions

Data-level solutions modify the training dataset to achieve better balance.

Oversampling

Oversampling increases minority class representation by duplicating or generating new samples.

Random Oversampling

The simplest approach: randomly duplicate minority class samples.

🐍random_oversampling.py

```python
import torch
from torch.utils.data import WeightedRandomSampler, DataLoader

def create_balanced_sampler(labels):
    """Create a sampler that balances class frequencies."""
    # Count samples per class
    class_counts = torch.bincount(torch.tensor(labels))

    # Compute weight for each sample (inverse of its class frequency)
    weights = 1.0 / class_counts[labels]

    # Create sampler
    sampler = WeightedRandomSampler(
        weights=weights,
        num_samples=len(labels),
        replacement=True  # Allow sampling the same item multiple times
    )
    return sampler

# Example usage (`dataset` is the Dataset whose targets are `labels`)
labels = [0] * 900 + [1] * 100  # 9:1 imbalance
sampler = create_balanced_sampler(labels)

# This DataLoader will sample each class equally often
dataloader = DataLoader(
    dataset,
    batch_size=32,
    sampler=sampler  # Replaces shuffle=True
)
```

Oversampling Trade-off

Oversampling can lead to overfitting on minority class examples since the same samples are seen repeatedly. Combine with data augmentation to mitigate this.

SMOTE (Synthetic Minority Over-sampling Technique)

SMOTE creates synthetic minority samples by interpolating between existing ones:

$$x_{\text{new}} = x_i + \lambda \cdot (x_j - x_i), \quad \lambda \sim \text{Uniform}(0, 1)$$

Where $x_i$ is a minority sample and $x_j$ is one of its k-nearest neighbors.
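The interpolation step itself is only a few lines; a minimal NumPy sketch (`smote_interpolate` is a hypothetical helper that skips the k-nearest-neighbor search full SMOTE performs):

```python
import numpy as np

rng = np.random.default_rng(0)

def smote_interpolate(x_i, x_j, rng):
    """Create one synthetic sample on the segment between x_i and x_j."""
    lam = rng.uniform(0.0, 1.0)  # lambda ~ Uniform(0, 1)
    return x_i + lam * (x_j - x_i)

# Two minority-class feature vectors; in real SMOTE, x_j would be one of
# x_i's k-nearest neighbors within the minority class
x_i = np.array([1.0, 2.0])
x_j = np.array([3.0, 4.0])
x_new = smote_interpolate(x_i, x_j, rng)
print(x_new)  # a point on the segment between [1, 2] and [3, 4]
```

Because the synthetic point lies between two real minority samples, it is plausible under the minority distribution rather than an exact duplicate.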

🐍smote_example.py

```python
# Note: SMOTE works on feature vectors, not raw images
# Use the imbalanced-learn library for production SMOTE

from imblearn.over_sampling import SMOTE

# Assume X is the feature matrix, y the labels
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X, y)

print(f"Original: {len(X)} samples")
print(f"After SMOTE: {len(X_resampled)} samples")
```

SMOTE Limitations

SMOTE works best with tabular/feature data. For images, prefer random oversampling combined with data augmentation (rotation, flipping, color jitter). For text, consider paraphrase augmentation.

Undersampling

Undersampling reduces majority class samples to match minority class size.

🐍undersampling.py

```python
import numpy as np
from torch.utils.data import Subset

def create_undersampled_dataset(dataset, labels):
    """Undersample majority class to match minority class."""
    labels = np.array(labels)
    classes, counts = np.unique(labels, return_counts=True)
    min_count = counts.min()

    selected_indices = []
    for cls in classes:
        cls_indices = np.where(labels == cls)[0]
        # Randomly select min_count samples from this class
        selected = np.random.choice(cls_indices, min_count, replace=False)
        selected_indices.extend(selected)

    return Subset(dataset, selected_indices)

# Example: 9000 majority, 100 minority -> 100 each
undersampled = create_undersampled_dataset(dataset, labels)
print(f"Original: {len(dataset)}, Undersampled: {len(undersampled)}")
```

Information Loss

Undersampling discards data. With extreme imbalance (e.g., 10000:100), you lose 99% of majority class samples. Use this cautiously and consider ensemble approaches (multiple undersampled subsets).
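The ensemble idea from the note can be sketched as follows; `undersampled_index_sets` is a hypothetical helper that draws several independent balanced index sets, so an ensemble trained on them collectively sees much more of the majority data:

```python
import numpy as np

def undersampled_index_sets(labels, n_subsets=5, seed=0):
    """Draw several balanced index sets; each undersamples the majority
    class independently, so different subsets cover different majority data."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    min_count = counts.min()

    subsets = []
    for _ in range(n_subsets):
        indices = []
        for cls in classes:
            cls_indices = np.where(labels == cls)[0]
            indices.extend(rng.choice(cls_indices, min_count, replace=False))
        subsets.append(np.array(indices))
    return subsets

labels = [0] * 900 + [1] * 100
subsets = undersampled_index_sets(labels)
print([len(s) for s in subsets])  # [200, 200, 200, 200, 200]
```

Each subset keeps all 100 minority samples and a fresh random 100 of the 900 majority samples; you would train one model per subset and average (or vote over) their predictions.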

Choosing Between Oversampling and Undersampling

| Situation | Recommendation | Reasoning |
|---|---|---|
| Large dataset, moderate imbalance | Undersampling | Plenty of data; faster training |
| Small dataset | Oversampling + augmentation | Can't afford to lose data |
| Extreme imbalance | Combination + weighted loss | Multiple techniques needed |
| Complex minority patterns | Oversampling | Need multiple views of minority class |

Algorithm-Level Solutions

Algorithm-level solutions modify the training process rather than the data.

Weighted Loss Functions

Assign higher weight to minority class errors, making the model pay more attention to them:

$$\mathcal{L}_{\text{weighted}} = -\sum_{i=1}^{N} w_{y_i} \cdot y_i \log(\hat{y}_i)$$

Where $w_c$ is the weight for class $c$. Common weighting schemes:

| Scheme | Formula | Effect |
|---|---|---|
| Inverse frequency | $w_c = N / (K \cdot n_c)$ | Proportional to class rarity |
| Inverse square root | $w_c = 1 / \sqrt{n_c}$ | Less aggressive than inverse |
| Effective number | $w_c = (1 - \beta) / (1 - \beta^{n_c})$ | Accounts for data overlap |
Computing and Using Class Weights

The implementation below supports three schemes:

  • Inverse frequency: makes rare classes count more. A class with 10x fewer samples gets 10x higher weight.
  • Square root inverse: less aggressive than linear inverse. A class with 100x fewer samples gets only 10x higher weight.
  • Effective number: from the 'Class-Balanced Loss' paper. Accounts for diminishing returns of additional samples due to data overlap.

🐍class_weights.py
```python
import torch
import torch.nn as nn

def compute_class_weights(labels, method='inverse'):
    """Compute class weights for imbalanced data."""
    labels = torch.tensor(labels)
    class_counts = torch.bincount(labels).float()
    num_classes = len(class_counts)
    total_samples = len(labels)

    if method == 'inverse':
        # Inverse frequency weighting
        weights = total_samples / (num_classes * class_counts)
    elif method == 'sqrt_inverse':
        # Inverse square root (less aggressive)
        weights = 1.0 / torch.sqrt(class_counts)
        weights = weights / weights.sum() * num_classes  # Normalize
    elif method == 'effective':
        # Effective number weighting (beta=0.9999)
        beta = 0.9999
        effective_num = 1.0 - torch.pow(beta, class_counts)
        weights = (1.0 - beta) / effective_num
        weights = weights / weights.sum() * num_classes
    else:
        raise ValueError(f"Unknown method: {method}")

    return weights

# Example usage
labels = [0] * 900 + [1] * 100
weights = compute_class_weights(labels, method='inverse')
print(f"Class weights: {weights}")  # approx [0.56, 5.0]

# Use with CrossEntropyLoss
criterion = nn.CrossEntropyLoss(weight=weights)
```

Focal Loss

Focal loss (from RetinaNet) down-weights easy examples, focusing training on hard examples:

$$\text{FL}(p_t) = -\alpha_t (1 - p_t)^\gamma \log(p_t)$$

Where $p_t$ is the model's probability for the true class, $\gamma$ is the focusing parameter (typically 2), and $\alpha_t$ is the class weight.

🐍focal_loss.py

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal Loss for imbalanced classification."""

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha  # Class weights (optional)
        self.gamma = gamma  # Focusing parameter
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Compute cross entropy
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # Get probabilities for true class
        pt = torch.exp(-ce_loss)

        # Compute focal weight
        focal_weight = (1 - pt) ** self.gamma

        # Apply class weights if provided
        if self.alpha is not None:
            alpha_t = self.alpha[targets]
            focal_loss = alpha_t * focal_weight * ce_loss
        else:
            focal_loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        return focal_loss

# Usage
alpha = torch.tensor([0.25, 0.75])  # Weight for each class
criterion = FocalLoss(alpha=alpha, gamma=2.0)
```

Focal Loss Intuition

When $\gamma = 2$ and a sample is correctly classified with 90% confidence ($p_t = 0.9$), the loss is multiplied by $(1 - 0.9)^2 = 0.01$, a 100x reduction. But for a hard example with 50% confidence, the multiplier is $(1 - 0.5)^2 = 0.25$, only a 4x reduction. This focuses learning on hard examples.
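The modulating factors quoted above are easy to verify directly:

```python
# Verify the modulating factor (1 - p_t)^gamma from the intuition above
gamma = 2

easy_pt = 0.9  # confidently correct prediction
hard_pt = 0.5  # uncertain prediction

easy_factor = (1 - easy_pt) ** gamma
hard_factor = (1 - hard_pt) ** gamma
print(f"Easy example loss scaled by {easy_factor:.2f}")  # ~0.01 (100x reduction)
print(f"Hard example loss scaled by {hard_factor:.2f}")  # 0.25 (4x reduction)
print(f"Relative emphasis on the hard example: {hard_factor / easy_factor:.0f}x")
```

So under focal loss the hard example carries roughly 25 times the per-sample weight of the easy one, even before any class weighting.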

Evaluation Metrics

Choose metrics that reveal true performance on imbalanced data:

Confusion Matrix Metrics

| Metric | Formula | Best For |
|---|---|---|
| Precision | TP / (TP + FP) | When false positives are costly |
| Recall (Sensitivity) | TP / (TP + FN) | When missing positives is costly |
| F1 Score | 2 · Precision · Recall / (Precision + Recall) | Balance of precision and recall |
| Specificity | TN / (TN + FP) | When negatives matter too |
🐍metrics.py

```python
from sklearn.metrics import (
    precision_score, recall_score, f1_score,
    confusion_matrix, classification_report,
    balanced_accuracy_score, roc_auc_score
)

def evaluate_imbalanced(y_true, y_pred, y_prob=None):
    """Comprehensive evaluation for imbalanced classification."""

    print("Confusion Matrix:")
    print(confusion_matrix(y_true, y_pred))
    print()

    print("Classification Report:")
    print(classification_report(y_true, y_pred))

    # Balanced accuracy (average of recall per class)
    bal_acc = balanced_accuracy_score(y_true, y_pred)
    print(f"Balanced Accuracy: {bal_acc:.4f}")

    # ROC-AUC if probabilities available
    if y_prob is not None:
        roc_auc = roc_auc_score(y_true, y_prob)
        print(f"ROC-AUC: {roc_auc:.4f}")
```

Which Metric to Prioritize?

| Scenario | Primary Metric | Reasoning |
|---|---|---|
| Fraud detection | Recall + precision | Catch fraud (recall) but don't block valid transactions (precision) |
| Medical screening | Recall | Missing a disease is worse than extra tests |
| Spam filtering | Precision | Blocking good email is worse than letting spam through |
| General imbalanced | F1 or balanced accuracy | Good balance for most cases |
| Threshold tuning | ROC-AUC or PR-AUC | Threshold-independent evaluation |

PR-AUC vs ROC-AUC

For highly imbalanced data, Precision-Recall AUC is often more informative than ROC-AUC. ROC-AUC can be overly optimistic when negatives vastly outnumber positives because specificity (used in ROC) becomes easy to achieve.
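The gap between the two scores is easy to demonstrate with a small synthetic example (scikit-learn assumed available; the score values are fabricated for illustration):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

# 990 negatives, 10 positives: a scorer that ranks most negatives low,
# but lets 50 "hard" negatives outscore every positive
y_true = np.array([0] * 990 + [1] * 10)
y_score = np.concatenate([
    np.full(940, 0.10),  # easy negatives (indices 0-939)
    np.full(50, 0.95),   # hard negatives outscoring the positives
    np.full(10, 0.90),   # the 10 positives
])

roc = roc_auc_score(y_true, y_score)
ap = average_precision_score(y_true, y_score)  # summarizes the PR curve
print(f"ROC-AUC: {roc:.3f}")  # ~0.95 -- looks strong
print(f"PR-AUC:  {ap:.3f}")   # ~0.17 -- reveals poor precision on positives
```

ROC-AUC stays high because the positives outrank 940 of the 990 negatives; PR-AUC collapses because any threshold that recovers the positives also admits the 50 hard false positives.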

PyTorch Implementation

Here's how to implement the key techniques in PyTorch:

PyTorch Imbalanced Data Handler

The helper class below bundles both approaches:

  • Inverse weighting: the standard formula total / (num_classes * class_count), which equalizes each class's contribution to the loss.
  • Sample weights: each sample gets weight 1/class_frequency, so minority samples are more likely to be selected by the sampler.
  • Combining techniques: using both weighted sampling AND weighted loss often gives the best results, especially for extreme imbalance.

🐍imbalanced_handler.py
```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler

class ImbalancedDataHandler:
    """Helper class for handling imbalanced datasets in PyTorch."""

    def __init__(self, labels):
        self.labels = torch.tensor(labels)
        self.class_counts = torch.bincount(self.labels)
        self.num_classes = len(self.class_counts)
        self.num_samples = len(labels)

    def get_class_weights(self, scheme='inverse'):
        """Get weights for CrossEntropyLoss."""
        if scheme == 'inverse':
            weights = self.num_samples / (self.num_classes * self.class_counts.float())
        elif scheme == 'balanced':
            weights = self.class_counts.max().float() / self.class_counts.float()
        else:
            weights = torch.ones(self.num_classes)
        return weights

    def get_sampler(self):
        """Get WeightedRandomSampler for DataLoader."""
        # Weight per sample = inverse of its class frequency
        sample_weights = 1.0 / self.class_counts[self.labels].float()
        sampler = WeightedRandomSampler(
            weights=sample_weights,
            num_samples=self.num_samples,
            replacement=True
        )
        return sampler

# Usage example (`dataset` is the Dataset aligned with `labels`)
labels = [0] * 900 + [1] * 100  # Imbalanced labels
handler = ImbalancedDataHandler(labels)

# Option 1: Weighted loss
class_weights = handler.get_class_weights('inverse')
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Option 2: Balanced sampling
sampler = handler.get_sampler()
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)

# Option 3: Combine both (often most effective)
class_weights = handler.get_class_weights('inverse')
criterion = nn.CrossEntropyLoss(weight=class_weights)
sampler = handler.get_sampler()
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
```

Complete Pipeline Example

Here's a complete training pipeline for imbalanced classification:

🐍complete_pipeline.py

```python
import copy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler, TensorDataset
from sklearn.metrics import classification_report, balanced_accuracy_score

def train_imbalanced_classifier(
    X_train, y_train, X_val, y_val,
    model,
    epochs=50,
    batch_size=32,
    use_weighted_sampling=True,
    use_weighted_loss=True
):
    """Train a classifier on imbalanced data with best practices."""

    # Analyze class distribution
    class_counts = torch.bincount(torch.tensor(y_train))
    print(f"Class distribution: {class_counts.tolist()}")
    imbalance_ratio = class_counts.max().item() / class_counts.min().item()
    print(f"Imbalance ratio: {imbalance_ratio:.1f}:1")

    # Setup loss function
    if use_weighted_loss:
        weights = len(y_train) / (len(class_counts) * class_counts.float())
        criterion = nn.CrossEntropyLoss(weight=weights)
        print(f"Class weights: {weights.tolist()}")
    else:
        criterion = nn.CrossEntropyLoss()

    # Setup data loader with optional weighted sampling
    train_dataset = TensorDataset(torch.FloatTensor(X_train), torch.LongTensor(y_train))

    if use_weighted_sampling:
        sample_weights = 1.0 / class_counts[y_train].float()
        sampler = WeightedRandomSampler(sample_weights, len(y_train), replacement=True)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
    else:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    val_dataset = TensorDataset(torch.FloatTensor(X_val), torch.LongTensor(y_val))
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    best_bal_acc = 0.0
    best_model_state = copy.deepcopy(model.state_dict())
    for epoch in range(epochs):
        model.train()
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                outputs = model(X_batch)
                preds = outputs.argmax(dim=1)
                all_preds.extend(preds.tolist())
                all_labels.extend(y_batch.tolist())

        bal_acc = balanced_accuracy_score(all_labels, all_preds)
        if bal_acc > best_bal_acc:
            best_bal_acc = bal_acc
            # Deep-copy so later training steps don't overwrite the saved weights
            best_model_state = copy.deepcopy(model.state_dict())

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: Balanced Accuracy = {bal_acc:.4f}")

    # Restore best model and print final report
    # (note: the report below reflects the last epoch's predictions)
    model.load_state_dict(best_model_state)
    print(f"\nBest Balanced Accuracy: {best_bal_acc:.4f}")
    print("\nFinal Classification Report:")
    print(classification_report(all_labels, all_preds))

    return model
```

Best Practices

Do's

  1. Always analyze class distribution first: Know your imbalance ratio before choosing solutions
  2. Use balanced accuracy or F1 for model selection: Not regular accuracy
  3. Stratify train/val/test splits: Ensure all splits have similar class proportions
  4. Combine techniques: Weighted sampling + weighted loss often works best
  5. Use appropriate augmentation: Especially for oversampled minority class
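The stratified-split recommendation above is a one-liner with scikit-learn's `train_test_split`; a minimal sketch with toy data:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# 9:1 imbalanced toy labels with placeholder features
y = np.array([0] * 900 + [1] * 100)
X = np.arange(len(y)).reshape(-1, 1)

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Both splits keep the ~10% minority proportion
print(f"Train minority fraction: {(y_train == 1).mean():.2f}")  # 0.10
print(f"Val minority fraction:   {(y_val == 1).mean():.2f}")    # 0.10
```

Without `stratify=y`, a random split of a rare class can leave the validation set with almost no minority samples, making the metrics noisy or undefined.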

Don'ts

  • Don't rely on accuracy alone: Misleading for imbalanced data
  • Don't oversample test/validation data: Only balance training data
  • Don't use SMOTE on test data: Test must reflect real distribution
  • Don't assume one technique works for all: Experiment with combinations

Quick Check

You have a fraud detection dataset with 1% fraud and 99% legitimate transactions. What's wrong with using accuracy as your primary metric?


Summary

Handling imbalanced data requires a multi-pronged approach:

| Technique | When to Use | PyTorch Implementation |
|---|---|---|
| Weighted Sampling | Most cases, especially moderate imbalance | `WeightedRandomSampler` |
| Class-Weighted Loss | Almost always beneficial | `CrossEntropyLoss(weight=...)` |
| Focal Loss | Extreme imbalance, many easy examples | Custom `FocalLoss` class |
| Oversampling + Augmentation | Small minority class | `WeightedRandomSampler` + transforms |
| Undersampling | Large dataset, computational constraints | `Subset` with random selection |

Key Takeaway

The combination of weighted random sampling in the DataLoader plus class-weighted loss is a robust baseline that works well for most imbalanced problems. Add focal loss or more aggressive techniques for extreme cases.

Exercises

Diagnostic Exercises

  1. You have a dataset with class distribution [8000, 1500, 500]. Calculate the imbalance ratio and recommend appropriate techniques.
  2. A medical diagnosis model has 95% accuracy but only 40% recall for the disease class. Is this acceptable? What should you do?
  3. Your colleague trained with weighted sampling but is evaluating on a weighted test set too. What's wrong with this approach?

Solution Hints

  1. E1: Ratio is 8000:500 = 16:1 (significant). Use weighted loss + sampling. Consider focal loss if results are poor.
  2. E2: Not acceptable - missing 60% of disease cases is dangerous. Increase class weight for disease, lower decision threshold, or use recall-optimized threshold.
  3. E3: Test data should reflect real distribution, never balanced. Weighted evaluation on weighted test gives misleading metrics. Only balance training data.

Coding Exercises

  1. Implement class weighting comparison: Train the same model on MNIST with only 100 samples of digit "9" vs 1000 samples of other digits. Compare accuracy, balanced accuracy, and per-class recall with and without class weights.
  2. Threshold tuning: Train a binary classifier on imbalanced data. Plot precision and recall vs. decision threshold. Find the threshold that maximizes F1 score.
  3. Sampling strategies: Implement and compare three approaches: (a) no balancing, (b) weighted sampling only, (c) weighted sampling + weighted loss. Use balanced accuracy as the evaluation metric.

Exercise Tips

  • Exercise 1: Create imbalanced MNIST by subsampling digit 9. Use sklearn's classification_report to see per-class metrics.
  • Exercise 2: Use model outputs with sigmoid, then test thresholds from 0.1 to 0.9. sklearn.metrics.precision_recall_curve is helpful.
  • Exercise 3: Keep validation set unchanged across all experiments for fair comparison. Only modify training setup.

In the next chapter, we'll explore Backpropagation and Automatic Differentiation—the mathematical foundation that makes neural network training possible.