# Source code for secmlt.models.pytorch.early_stopping_pytorch_trainer

"""PyTorch model trainers with early stopping."""

import copy

import torch.nn
from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader


class EarlyStoppingPyTorchTrainer(BasePyTorchTrainer):
    """Trainer for PyTorch models with early stopping."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        epochs: int = 5,
        loss: torch.nn.Module = None,
        scheduler: _LRScheduler = None,
    ) -> None:
        """
        Create PyTorch trainer.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer to use for training the model.
        epochs : int, optional
            Maximum number of training epochs, by default 5.
        loss : torch.nn.Module, optional
            Loss to minimize, by default None, in which case
            ``torch.nn.CrossEntropyLoss`` is used.
        scheduler : _LRScheduler, optional
            Scheduler for the optimizer, by default None.
        """
        super().__init__(optimizer, epochs, loss, scheduler)
        # Assigned here so the methods of this class can rely on these
        # attributes regardless of what the base initializer stores.
        self._epochs = epochs
        self._optimizer = optimizer
        self._loss = loss if loss is not None else torch.nn.CrossEntropyLoss()
        self._scheduler = scheduler

    def fit(
        self,
        model: torch.nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        patience: int,
    ) -> torch.nn.Module:
        """
        Train model with given loaders and early stopping.

        After every training epoch the model is evaluated on ``val_loader``.
        Training stops early once the validation loss has failed to improve
        for ``patience`` consecutive epochs. At the end, the weights from the
        epoch with the lowest validation loss are loaded back into the model.

        Parameters
        ----------
        model : torch.nn.Module
            Pytorch model to be trained.
        train_loader : DataLoader
            Train data loader.
        val_loader : DataLoader
            Validation data loader.
        patience : int
            Number of consecutive non-improving epochs to wait before
            stopping early.

        Returns
        -------
        torch.nn.Module
            The trained model, loaded with the weights that achieved the
            lowest validation loss. If no epoch ran or the loss never
            improved, the model is returned as-is.
        """
        best_loss = float("inf")
        best_state = None
        patience_counter = 0
        for _ in range(self._epochs):
            model = self.train(model, train_loader)
            val_loss = self.validate(model, val_loader)
            if val_loss < best_loss:
                best_loss = val_loss
                # Snapshot the weights: training updates ``model`` in place,
                # so merely keeping a reference would not preserve them.
                best_state = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    break
        if best_state is not None:
            model.load_state_dict(best_state)
        return model

    def train(self, model: torch.nn.Module, dataloader: DataLoader) -> torch.nn.Module:
        """
        Train model for one epoch with given loader.

        Parameters
        ----------
        model : torch.nn.Module
            Pytorch model to be trained.
        dataloader : DataLoader
            Train data loader.

        Returns
        -------
        torch.nn.Module
            Trained model.
        """
        # Batches are moved to the device holding the model's parameters.
        device = next(model.parameters()).device
        model = model.train()
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            self._optimizer.zero_grad()
            outputs = model(x)
            loss = self._loss(outputs, y)
            loss.backward()
            self._optimizer.step()
        # One scheduler step per epoch, after the optimizer steps.
        if self._scheduler is not None:
            self._scheduler.step()
        return model

    def validate(self, model: torch.nn.Module, dataloader: DataLoader) -> float:
        """
        Validate model with given loader.

        Parameters
        ----------
        model : torch.nn.Module
            Pytorch model to be validated.
        dataloader : DataLoader
            Validation data loader.

        Returns
        -------
        float
            Validation loss of the model, averaged over the batches of
            ``dataloader``.

        Raises
        ------
        ValueError
            If ``dataloader`` yields no batches.
        """
        running_loss = 0.0
        num_batches = 0
        device = next(model.parameters()).device
        model = model.eval()
        with torch.no_grad():
            for x, y in dataloader:
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                running_loss += self._loss(outputs, y).item()
                num_batches += 1
        if num_batches == 0:
            raise ValueError("Validation data loader yielded no batches.")
        return running_loss / num_batches