# Source code for secmlt.models.pytorch.base_pytorch_trainer

"""PyTorch model trainers."""

from typing import Optional

import torch.nn
from secmlt.models.base_trainer import BaseTrainer
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader


class BasePyTorchTrainer(BaseTrainer):
    """Trainer for PyTorch models."""

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        epochs: int = 5,
        loss: Optional[torch.nn.Module] = None,
        scheduler: Optional[_LRScheduler] = None,
    ) -> None:
        """
        Create PyTorch trainer.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer to use for training the model.
        epochs : int, optional
            Number of epochs, by default 5.
        loss : torch.nn.Module, optional
            Loss to minimize, by default None, in which case
            torch.nn.CrossEntropyLoss is used.
        scheduler : _LRScheduler, optional
            Scheduler for the optimizer, by default None. When given, it is
            stepped once at the end of every epoch.
        """
        self._epochs = epochs
        self._optimizer = optimizer
        self._loss = loss if loss is not None else torch.nn.CrossEntropyLoss()
        self._scheduler = scheduler

    def train(
        self,
        model: torch.nn.Module,
        dataloader: DataLoader,
    ) -> torch.nn.Module:
        """
        Train model with given loader.

        Parameters
        ----------
        model : torch.nn.Module
            Pytorch model to be trained. It is switched to training mode.
        dataloader : DataLoader
            Train data loader, expected to yield ``(inputs, targets)`` pairs.

        Returns
        -------
        torch.nn.Module
            Trained model.

        Raises
        ------
        ValueError
            If the model has no parameters, so its device cannot be inferred.
        """
        first_param = next(model.parameters(), None)
        if first_param is None:
            # A bare StopIteration would escape here otherwise (and become a
            # RuntimeError if called from inside a generator).
            msg = "Cannot train a model without parameters."
            raise ValueError(msg)
        # Batches are moved to the device of the first parameter; this assumes
        # all of the model's parameters live on a single device.
        device = first_param.device
        model = model.train()
        for _ in range(self._epochs):
            for x, y in dataloader:
                x, y = x.to(device), y.to(device)
                self._optimizer.zero_grad()
                outputs = model(x)
                batch_loss = self._loss(outputs, y)
                batch_loss.backward()
                self._optimizer.step()
            # Learning-rate schedule advances once per epoch, after all
            # optimizer steps of that epoch.
            if self._scheduler is not None:
                self._scheduler.step()
        return model