"""Wrappers for PyTorch models."""
import torch
from secmlt.models.base_model import BaseModel
from secmlt.models.data_processing.data_processing import DataProcessing
from secmlt.models.pytorch.base_pytorch_trainer import BasePyTorchTrainer
from torch.utils.data import DataLoader
class BasePytorchClassifier(BaseModel):
    """
    Wrapper for PyTorch classifier.

    Wraps a ``torch.nn.Module`` behind the ``BaseModel`` interface. The
    preprocessing/postprocessing objects are forwarded to ``BaseModel``; the
    forward pass of the wrapped module is performed in ``_decision_function``.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        preprocessing: DataProcessing = None,
        postprocessing: DataProcessing = None,
        trainer: BasePyTorchTrainer = None,
    ) -> None:
        """
        Create wrapped PyTorch classifier.

        Parameters
        ----------
        model : torch.nn.Module
            PyTorch model.
        preprocessing : DataProcessing, optional
            Preprocessing to apply before the forward, by default None.
        postprocessing : DataProcessing, optional
            Postprocessing to apply after the forward, by default None.
        trainer : BasePyTorchTrainer, optional
            Trainer object to train the model, by default None.
        """
        super().__init__(preprocessing=preprocessing, postprocessing=postprocessing)
        self._model: torch.nn.Module = model
        self._trainer = trainer

    @property
    def model(self) -> torch.nn.Module:
        """
        Get the wrapped instance of PyTorch model.

        Returns
        -------
        torch.nn.Module
            Wrapped PyTorch model.
        """
        return self._model

    def _get_device(self) -> torch.device:
        """
        Get the device on which the wrapped model lives.

        Returns
        -------
        torch.device
            Device of the model's first parameter; if the model has no
            parameters, the device of its first buffer; if it has neither
            (a stateless module), CPU.
        """
        # ``next`` with a default avoids leaking ``StopIteration`` for
        # modules without parameters (e.g. ``torch.nn.Flatten``).
        param = next(self._model.parameters(), None)
        if param is not None:
            return param.device
        buffer = next(self._model.buffers(), None)
        if buffer is not None:
            return buffer.device
        return torch.device("cpu")

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Return the predicted class for the given samples.

        Parameters
        ----------
        x : torch.Tensor
            Input samples.

        Returns
        -------
        torch.Tensor
            Predicted class for the samples, i.e. the index of the largest
            score along the last dimension of ``decision_function(x)``.
        """
        scores = self.decision_function(x)
        return torch.argmax(scores, dim=-1)

    def _decision_function(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute decision function of the model.

        Parameters
        ----------
        x : torch.Tensor
            Input samples. They are moved to the model's device first.

        Returns
        -------
        torch.Tensor
            Output scores from the model.
        """
        x = x.to(device=self._get_device())
        return self._model(x)

    def gradient(self, x: torch.Tensor, y: int) -> torch.Tensor:
        """
        Compute batch gradients of class y w.r.t. x.

        The scores of class ``y`` are summed over the batch before
        differentiating, so each sample receives the gradient of its own
        class-``y`` score (assuming samples do not interact in the forward).

        Parameters
        ----------
        x : torch.Tensor
            Input samples. Not modified; a detached copy is differentiated.
        y : int
            Class label. Scores are indexed as ``[:, y]``, so the model
            output is expected to be at least 2-D with the batch first.

        Returns
        -------
        torch.Tensor
            Gradient of class y w.r.t. input x, or None if the output does
            not depend on x.
        """
        # Detach first so the copy is a fresh leaf tensor: if ``x`` already
        # requires grad (e.g. inside an attack loop), ``x.clone()`` alone is
        # a non-leaf whose ``.grad`` is never populated.
        x = x.detach().clone().requires_grad_(True)
        # Track gradients even if the caller runs under ``torch.no_grad()``.
        with torch.enable_grad():
            output = self.decision_function(x)
            output = output[:, y].sum()
            # ``autograd.grad`` returns only d(output)/dx and, unlike
            # ``backward()``, does not accumulate ``.grad`` into the model's
            # parameters. ``allow_unused`` yields None (as ``x.grad`` would)
            # when the output is independent of ``x``.
            (grad,) = torch.autograd.grad(output, x, allow_unused=True)
        return grad

    def train(self, dataloader: DataLoader) -> torch.nn.Module:
        """
        Train the model with given dataloader, if trainer is set.

        Parameters
        ----------
        dataloader : DataLoader
            Training PyTorch dataloader to use for training.

        Returns
        -------
        torch.nn.Module
            Trained PyTorch model, as returned by the trainer.

        Raises
        ------
        ValueError
            Raises ValueError if the trainer is not set.
        """
        if self._trainer is None:
            msg = "Cannot train without a trainer."
            raise ValueError(msg)
        return self._trainer.train(self._model, dataloader)