Source code for secmlt.adv.evasion.foolbox_attacks.foolbox_base

"""Generic wrapper for Foolbox evasion attacks."""

from typing import Literal

import torch
from foolbox.attacks.base import Attack
from foolbox.criteria import Misclassification, TargetedMisclassification
from foolbox.models.pytorch import PyTorchModel
from secmlt.adv.evasion.base_evasion_attack import TRACKER_TYPE, BaseEvasionAttack
from secmlt.models.base_model import BaseModel
from secmlt.models.pytorch.base_pytorch_nn import BasePytorchClassifier


class BaseFoolboxEvasionAttack(BaseEvasionAttack):
    """Generic wrapper for Foolbox evasion attacks."""

    def __init__(
        self,
        foolbox_attack: type[Attack],
        epsilon: float = torch.inf,
        y_target: int | None = None,
        lb: float = 0.0,
        ub: float = 1.0,
        trackers: type[TRACKER_TYPE] | None = None,
    ) -> None:
        """
        Wrap Foolbox attacks.

        Parameters
        ----------
        foolbox_attack : type[Attack]
            Foolbox attack class to wrap.
        epsilon : float, optional
            Perturbation constraint, by default torch.inf.
        y_target : int | None, optional
            Target label for the attack, None if untargeted, by default None.
        lb : float, optional
            Lower bound of the input space, by default 0.0.
        ub : float, optional
            Upper bound of the input space, by default 1.0.
        trackers : type[TRACKER_TYPE] | None, optional
            Trackers for the attack (not supported by Foolbox), by default None.
        """
        self.foolbox_attack = foolbox_attack
        self.lb = lb
        self.ub = ub
        self.epsilon = epsilon
        self.y_target = y_target
        self.trackers = trackers
        super().__init__()

    @classmethod
    def _trackers_allowed(cls) -> Literal[False]:
        # Foolbox attacks run their own internal loop, so secmlt trackers
        # cannot be attached to them.
        return False

    def _run(
        self,
        model: BaseModel,
        samples: torch.Tensor,
        labels: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        if not isinstance(model, BasePytorchClassifier):
            msg = "Model type not supported."
            raise NotImplementedError(msg)
        device = model._get_device()
        samples = samples.to(device)
        labels = labels.to(device)
        # wrap the PyTorch module so that Foolbox can query it within bounds
        foolbox_model = PyTorchModel(model.model, (self.lb, self.ub), device=device)
        if self.y_target is None:
            criterion = Misclassification(labels)
        else:
            # build one target label per sample for the targeted criterion
            target = (torch.zeros_like(labels) + self.y_target).type(labels.dtype)
            target = target.to(device)
            criterion = TargetedMisclassification(target)

        _, advx, _ = self.foolbox_attack(
            model=foolbox_model,
            inputs=samples,
            criterion=criterion,
            epsilons=self.epsilon,
        )

        # foolbox deals only with additive perturbations
        delta = advx - samples
        return advx, delta
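As a usage sketch (not part of this module), the snippet below subclasses BaseFoolboxEvasionAttack around Foolbox's LinfPGD attack. The class name and constructor parameters (num_steps, step_size) are illustrative assumptions; note that _run calls the stored attack directly with the wrapped model, inputs, criterion, and epsilon, so an instantiated Foolbox attack is passed to the base constructor here.

# Illustrative sketch: an L-inf PGD evasion attack built on the wrapper above.
# FoolboxLinfPGDExample, num_steps and step_size are example names, not part
# of secmlt.
from foolbox.attacks import LinfPGD


class FoolboxLinfPGDExample(BaseFoolboxEvasionAttack):
    """Hypothetical L-inf PGD attack built on BaseFoolboxEvasionAttack."""

    def __init__(
        self,
        epsilon: float,
        num_steps: int = 10,
        step_size: float = 0.05,
        y_target: int | None = None,
    ) -> None:
        # instantiate the Foolbox attack; _run will invoke it as
        # attack(model=..., inputs=..., criterion=..., epsilons=epsilon)
        attack = LinfPGD(
            abs_stepsize=step_size,
            steps=num_steps,
            random_start=True,
        )
        super().__init__(
            foolbox_attack=attack,
            epsilon=epsilon,
            y_target=y_target,
            lb=0.0,
            ub=1.0,
        )

An instance of such a subclass can then be applied to a BasePytorchClassifier through the calling interface defined by BaseEvasionAttack (not shown on this page).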