Source code for secmlt.adv.evasion.modular_attack

"""Implementation of modular iterative attacks with customizable components."""

from functools import partial
from typing import Literal, Union

import torch.nn
from secmlt.adv.evasion.base_evasion_attack import BaseEvasionAttack
from secmlt.adv.evasion.perturbation_models import LpPerturbationModels
from secmlt.manipulations.manipulation import Manipulation
from secmlt.models.base_model import BaseModel
from secmlt.optimization.constraints import Constraint
from secmlt.optimization.gradient_processing import GradientProcessing
from secmlt.optimization.initializer import Initializer
from secmlt.optimization.optimizer_factory import OptimizerFactory
from secmlt.trackers.trackers import Tracker
from secmlt.utils.tensor_utils import atleast_kd
from torch.nn import CrossEntropyLoss
from torch.optim import Optimizer

# String identifiers for loss functions. Only CE_LOSS has an entry in the
# LOSS_FUNCTIONS registry below; LOGIT_LOSS is declared but not registered here.
CE_LOSS = "ce_loss"
LOGIT_LOSS = "logit_loss"

# Registry mapping a loss identifier to the loss class that is instantiated
# when the attack receives its loss function as a string.
LOSS_FUNCTIONS: dict[str, type[torch.nn.Module]] = {CE_LOSS: CrossEntropyLoss}


class ModularEvasionAttackFixedEps(BaseEvasionAttack):
    """Modular evasion attack for fixed-epsilon attacks."""

    def __init__(
        self,
        y_target: int | None,
        num_steps: int,
        step_size: float,
        loss_function: str | torch.nn.Module,
        optimizer_cls: str | partial[Optimizer],
        manipulation_function: Manipulation,
        initializer: Initializer,
        gradient_processing: GradientProcessing,
        trackers: list[Tracker] | Tracker | None = None,
    ) -> None:
        """
        Create modular evasion attack.

        Parameters
        ----------
        y_target : int | None
            Target label for the attack, None for untargeted.
        num_steps : int
            Number of iterations for the attack.
        step_size : float
            Attack step size.
        loss_function : str | torch.nn.Module
            Loss function used by the attack. When given as a string it must
            be a key of ``LOSS_FUNCTIONS`` and is instantiated with
            ``reduction="none"``. The loss is minimized for targeted attacks
            and its negation is minimized for untargeted attacks.
        optimizer_cls : str | partial[Optimizer]
            Algorithm for solving the attack optimization problem. A string
            is resolved through ``OptimizerFactory.create_from_name``.
        manipulation_function : Manipulation
            Manipulation function to perturb the inputs.
        initializer : Initializer
            Initialization for the perturbation delta.
        gradient_processing : GradientProcessing
            Gradient transformation function.
        trackers : list[Tracker] | Tracker | None, optional
            Trackers for logging, by default None.

        Raises
        ------
        ValueError
            Raises ValueError if the loss is not in allowed list of loss
            functions.
        """
        self.y_target = y_target
        self.num_steps = num_steps
        self.step_size = step_size
        self.trackers = trackers

        if isinstance(loss_function, str):
            if loss_function not in LOSS_FUNCTIONS:
                # List the accepted *names* (the dict keys): those are what
                # the caller has to pass, not the loss classes themselves.
                msg = f"Loss function not found. Use one among {list(LOSS_FUNCTIONS)}"
                raise ValueError(msg)
            loss_function = LOSS_FUNCTIONS[loss_function](reduction="none")
        self.loss_function = loss_function

        if isinstance(optimizer_cls, str):
            optimizer_cls = OptimizerFactory.create_from_name(
                optimizer_cls,
                lr=step_size,
            )
        self.optimizer_cls = optimizer_cls

        self._manipulation_function = manipulation_function
        self.initializer = initializer
        self.gradient_processing = gradient_processing

        super().__init__()

    @property
    def manipulation_function(self) -> Manipulation:
        """
        Get the manipulation function for the attack.

        Returns
        -------
        Manipulation
            The manipulation function used in the attack.
        """
        return self._manipulation_function

    @manipulation_function.setter
    def manipulation_function(self, manipulation_function: Manipulation) -> None:
        """
        Set the manipulation function for the attack.

        Parameters
        ----------
        manipulation_function : Manipulation
            The manipulation function to be used in the attack.
        """
        self._manipulation_function = manipulation_function

    @classmethod
    def get_perturbation_models(cls) -> set[str]:
        """
        Check if a given perturbation model is implemented.

        Returns
        -------
        set[str]
            Set of perturbation models available for this attack.
        """
        return {
            LpPerturbationModels.L1,
            LpPerturbationModels.L2,
            LpPerturbationModels.LINF,
        }

    @classmethod
    def _trackers_allowed(cls) -> Literal[True]:
        """Report that this attack supports trackers."""
        return True

    def _init_perturbation_constraints(self) -> list[Constraint]:
        """Build the perturbation constraints; not implemented in this class."""
        msg = "Must be implemented accordingly"
        raise NotImplementedError(msg)

    def _create_optimizer(self, delta: torch.Tensor, **kwargs) -> Optimizer:
        """Instantiate the optimizer over ``delta`` with the attack step size."""
        return self.optimizer_cls([delta], lr=self.step_size, **kwargs)

    def forward_loss(
        self, model: BaseModel, x: torch.Tensor, target: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the forward for the loss function.

        Parameters
        ----------
        model : BaseModel
            Model used by the attack run.
        x : torch.Tensor
            Input sample.
        target : torch.Tensor
            Target for computing the loss.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Output scores and loss.
        """
        scores = model.decision_function(x)
        # The scores may live on a different device than the targets.
        target = target.to(scores.device)
        losses = self.loss_function(scores, target)
        return scores, losses

    def _run(
        self,
        model: BaseModel,
        samples: torch.Tensor,
        labels: torch.Tensor,
        init_deltas: torch.Tensor | None = None,
        optim_kwargs: dict | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Run the iterative attack on a batch of samples.

        Parameters
        ----------
        model : BaseModel
            Model to attack.
        samples : torch.Tensor
            Input samples to perturb.
        labels : torch.Tensor
            Labels of the samples; used as loss targets for untargeted
            attacks.
        init_deltas : torch.Tensor | None, optional
            Initial perturbation. When given it replaces the initializer and
            is copied, so the caller's tensor is left untouched.
        optim_kwargs : dict | None, optional
            Extra keyword arguments forwarded to the optimizer constructor.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Adversarial samples built from the best perturbation, and the
            best perturbation itself.
        """
        if optim_kwargs is None:
            optim_kwargs = {}

        targeted = self.y_target is not None
        # Targeted: minimize the loss w.r.t. the target label.
        # Untargeted: minimize the negated loss w.r.t. the true label.
        multiplier = 1 if targeted else -1
        target = (
            torch.zeros_like(labels) + self.y_target if targeted else labels
        ).type(labels.dtype)

        if init_deltas is not None:
            # Clone: the optimizer updates delta in place, and a plain
            # ``.data`` view would share storage with the caller's tensor.
            delta = init_deltas.detach().clone()
        elif isinstance(self.initializer, BaseEvasionAttack):
            _, delta = self.initializer._run(model, samples, target)
        else:
            delta = self.initializer(samples.data)
        delta.requires_grad = True

        optimizer = self._create_optimizer(delta, **optim_kwargs)
        x_adv, delta = self.manipulation_function(samples, delta)
        x_adv.data, delta.data = self.manipulation_function(samples.data, delta.data)

        # Per-sample bookkeeping of the lowest loss seen so far (kept on CPU).
        best_losses = torch.full((samples.shape[0],), torch.inf)
        best_delta = torch.zeros_like(samples)

        # The constructor accepts a single Tracker as well as a list.
        # NOTE(review): the base class may already normalize this - confirm.
        trackers = self.trackers
        if isinstance(trackers, Tracker):
            trackers = [trackers]

        for i in range(self.num_steps):
            scores, losses = self.forward_loss(model=model, x=x_adv, target=target)
            losses *= multiplier
            loss = losses.sum()
            optimizer.zero_grad()
            loss.backward()
            grad_before_processing = delta.grad.data
            delta.grad.data = self.gradient_processing(delta.grad.data)
            optimizer.step()
            x_adv.data, delta.data = self.manipulation_function(
                samples.data,
                delta.data,
            )
            if trackers is not None:
                for tracker in trackers:
                    tracker.track(
                        i,
                        losses.detach().cpu().data,
                        scores.detach().cpu().data,
                        x_adv.detach().cpu().data,
                        delta.detach().cpu().data,
                        grad_before_processing.detach().cpu().data,
                    )

            # Keep, per sample, the perturbation with the lowest (minimized)
            # loss.
            # NOTE(review): ``losses`` was computed before the optimizer step
            # while ``delta`` is the value after it - confirm this pairing is
            # intended.
            step_losses = losses.detach().cpu()
            improved = step_losses < best_losses
            best_delta.data = torch.where(
                # The mask is built on CPU; move it next to delta before use.
                atleast_kd(improved, len(samples.shape)).to(delta.device),
                delta.data,
                best_delta.data,
            )
            best_losses.data = torch.where(
                improved,
                step_losses,
                best_losses.data,
            )

        x_adv, _ = self.manipulation_function(samples.data, best_delta.data)
        return x_adv, best_delta