# Source code for secmlt.manipulations.manipulation

"""Manipulations for perturbing input samples."""

from abc import ABC, abstractmethod

import torch
from secmlt.optimization.constraints import Constraint


class Manipulation(ABC):
    """
    Abstract class for manipulations.

    A manipulation combines input samples ``x`` with a perturbation
    ``delta``. Perturbation constraints are applied to ``delta`` before the
    manipulation, and domain constraints are applied to the manipulated
    samples after it. Subclasses implement :meth:`_apply_manipulation`.
    """

    def __init__(
        self,
        domain_constraints: list[Constraint],
        perturbation_constraints: list[Constraint],
    ) -> None:
        """
        Create manipulation object.

        Parameters
        ----------
        domain_constraints : list[Constraint]
            Constraints for the domain bounds (x_adv).
        perturbation_constraints : list[Constraint]
            Constraints for the perturbation (delta).
        """
        # The lists are stored by reference (not copied).
        self._domain_constraints = domain_constraints
        self._perturbation_constraints = perturbation_constraints

    @property
    def domain_constraints(self) -> list[Constraint]:
        """
        Get the domain constraints for the manipulation.

        Returns
        -------
        list[Constraint]
            List of domain constraints for the manipulation.
        """
        return self._domain_constraints

    @domain_constraints.setter
    def domain_constraints(self, domain_constraints: list[Constraint]) -> None:
        """Replace the domain constraints for the manipulation."""
        self._domain_constraints = domain_constraints

    @property
    def perturbation_constraints(self) -> list[Constraint]:
        """
        Get the perturbation constraints for the manipulation.

        Returns
        -------
        list[Constraint]
            List of perturbation constraints for the manipulation.
        """
        return self._perturbation_constraints

    @perturbation_constraints.setter
    def perturbation_constraints(
        self, perturbation_constraints: list[Constraint]
    ) -> None:
        """Replace the perturbation constraints for the manipulation."""
        self._perturbation_constraints = perturbation_constraints

    def _apply_domain_constraints(self, x: torch.Tensor) -> torch.Tensor:
        """Apply every domain constraint to ``x``, in list order."""
        # Each constraint receives the output of the previous one.
        for constraint in self.domain_constraints:
            x = constraint(x)
        return x

    def _apply_perturbation_constraints(self, delta: torch.Tensor) -> torch.Tensor:
        """Apply every perturbation constraint to ``delta``, in list order."""
        # Each constraint receives the output of the previous one.
        for constraint in self.perturbation_constraints:
            delta = constraint(delta)
        return delta

    @abstractmethod
    def _apply_manipulation(
        self,
        x: torch.Tensor,
        delta: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the manipulation.

        Parameters
        ----------
        x : torch.Tensor
            Input samples.
        delta : torch.Tensor
            Manipulation to apply.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Perturbed samples and the perturbation. ``__call__`` unpacks
            this result into exactly two values.
        """
        ...

    def __call__(
        self,
        x: torch.Tensor,
        delta: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Apply the manipulation to the input data.

        Note that ``delta.data`` of the tensor passed in is overwritten with
        its constrained version, so the caller's ``delta`` is modified.

        Parameters
        ----------
        x : torch.Tensor
            Input data.
        delta : torch.Tensor
            Perturbation to apply.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            Perturbed data and perturbation after the application of
            constraints.
        """
        # Constraints are applied through ``.data`` so the projection is
        # written into the tensors without being recorded by autograd.
        delta.data = self._apply_perturbation_constraints(delta.data)
        x_adv, delta = self._apply_manipulation(x, delta)
        x_adv.data = self._apply_domain_constraints(x_adv.data)
        return x_adv, delta
class AdditiveManipulation(Manipulation):
    """
    Additive manipulation for input data.

    The manipulated samples are computed as ``x + delta``. The perturbation
    itself is handed back unchanged alongside them.
    """

    def _apply_manipulation(
        self,
        x: torch.Tensor,
        delta: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Shift the input by the perturbation; delta is returned as-is.
        perturbed = x + delta
        return perturbed, delta