Source code for secmlt.models.base_trainer

"""Model trainers."""

from abc import ABCMeta, abstractmethod

from secmlt.models.base_model import BaseModel
from torch.utils.data import DataLoader


[docs] class BaseTrainer(metaclass=ABCMeta): """Abstract class for model trainers."""
[docs] @abstractmethod def train(self, model: BaseModel, dataloader: DataLoader) -> BaseModel: """ Train a model with the given dataloader. Parameters ---------- model : BaseModel Model to train. dataloader : DataLoader Training dataloader. Returns ------- BaseModel The trained model. """ ...