Source code for secmlt.adv.evasion.base_evasion_attack

"""Base classes for implementing attacks and wrapping backends."""

import importlib.util
from abc import abstractmethod
from typing import Literal

import torch
from secmlt.adv.backends import Backends
from secmlt.models.base_model import BaseModel
from torch.utils.data import DataLoader, TensorDataset

# lazy evaluation to avoid circular imports
TRACKER_TYPE = "secmlt.trackers.tracker.Tracker"


[docs] class BaseEvasionAttackCreator: """Generic creator for attacks."""
[docs] @classmethod def get_implementation(cls, backend: str) -> "BaseEvasionAttack": """ Get the implementation of the attack with the given backend. Parameters ---------- backend : str The backend for the attack. See secmlt.adv.backends for available backends. Returns ------- BaseEvasionAttack Attack implementation. """ implementations = { Backends.FOOLBOX: cls.get_foolbox_implementation, Backends.ADVLIB: cls.get_advlib_implementation, Backends.NATIVE: cls._get_native_implementation, } cls.check_backend_available(backend) return implementations[backend]()
[docs] @classmethod def check_backend_available(cls, backend: str) -> bool: """ Check if a given backend is available for the attack. Parameters ---------- backend : str Backend string. Returns ------- bool True if the given backend is implemented. Raises ------ NotImplementedError Raises NotImplementedError if the requested backend is not in the list of the possible backends (check secmlt.adv.backends). """ if backend in cls.get_backends(): return True msg = "Unsupported or not-implemented backend." raise NotImplementedError(msg)
[docs] @classmethod def get_foolbox_implementation(cls) -> "BaseEvasionAttack": """ Get the Foolbox implementation of the attack. Returns ------- BaseEvasionAttack Foolbox implementation of the attack. Raises ------ ImportError Raises ImportError if Foolbox extra is not installed. """ if importlib.util.find_spec("foolbox", None) is not None: return cls._get_foolbox_implementation() msg = "Foolbox extra not installed." raise ImportError(msg)
@staticmethod def _get_foolbox_implementation() -> "BaseEvasionAttack": msg = "Foolbox implementation not available." raise NotImplementedError(msg)
[docs] @classmethod def get_advlib_implementation(cls) -> "BaseEvasionAttack": """ Get the Adversarial Library implementation of the attack. Returns ------- BaseEvasionAttack Adversarial Library implementation of the attack. Raises ------ ImportError Raises ImportError if Adversarial Library extra is not installed. """ if importlib.util.find_spec("adv_lib", None) is not None: return cls._get_advlib_implementation() msg = "Adversarial Library extra not installed." raise ImportError(msg)
@staticmethod def _get_advlib_implementation() -> "BaseEvasionAttack": msg = "Adversarial Library implementation not available." raise NotImplementedError(msg) @staticmethod def _get_native_implementation() -> "BaseEvasionAttack": msg = "Native implementation not available." raise NotImplementedError(msg)
[docs] @staticmethod @abstractmethod def get_backends() -> set[str]: """ Get the available backends for the given attack. Returns ------- set[str] Set of implemented backends available for the attack. Raises ------ NotImplementedError Raises NotImplementedError if not implemented in the inherited class. """ msg = "Backends should be specified in inherited class." raise NotImplementedError(msg)
[docs] class BaseEvasionAttack: """Base class for evasion attacks."""
[docs] def __call__(self, model: BaseModel, data_loader: DataLoader) -> DataLoader: """ Compute the attack against the model, using the input data. Parameters ---------- model : BaseModel Model to test. data_loader : DataLoader Test dataloader. Returns ------- DataLoader Dataloader with adversarial examples and original labels. """ adversarials = [] original_labels = [] for samples, labels in data_loader: x_adv, _ = self._run(model, samples, labels) adversarials.append(x_adv) original_labels.append(labels) adversarials = torch.vstack(adversarials) original_labels = torch.hstack(original_labels) adversarial_dataset = TensorDataset(adversarials, original_labels) return DataLoader( adversarial_dataset, batch_size=data_loader.batch_size, )
@property def trackers(self) -> list[TRACKER_TYPE] | None: """ Get the trackers set for this attack. Returns ------- list[TRACKER_TYPE] | None Trackers set for the attack, if any. """ return self._trackers @trackers.setter def trackers(self, trackers: list[TRACKER_TYPE] | None = None) -> None: if self._trackers_allowed(): if trackers is not None and not isinstance(trackers, list): trackers = [trackers] self._trackers = trackers elif trackers is not None: msg = "Trackers not implemented for this attack." raise NotImplementedError(msg) @classmethod @abstractmethod def _trackers_allowed(cls) -> Literal[False]: return False
[docs] @classmethod def check_perturbation_model_available(cls, perturbation_model: str) -> bool: """ Check whether the given perturbation model is available for the attack. Parameters ---------- perturbation_model : str A perturbation model. Returns ------- bool True if the attack implements the given perturbation model. Raises ------ NotImplementedError Raises NotImplementedError if not implemented in the inherited class. """ if perturbation_model in cls.get_perturbation_models(): return True msg = "Unsupported or not-implemented perturbation model." raise NotImplementedError(msg)
[docs] @staticmethod @abstractmethod def get_perturbation_models() -> set[str]: """ Check the perturbation models implemented for the given attack. Returns ------- set[str] The set of perturbation models for which the attack is implemented. Raises ------ NotImplementedError Raises NotImplementedError if not implemented in the inherited class. """ msg = "Perturbation models should be specified in inherited class." raise NotImplementedError(msg)
@abstractmethod def _run( self, model: BaseModel, samples: torch.Tensor, labels: torch.Tensor, ) -> torch.Tensor: ...