"""Manipulations for perturbing input samples."""
from abc import ABC, abstractmethod
import torch
from secmlt.optimization.constraints import Constraint
[docs]
class Manipulation(ABC):
"""Abstract class for manipulations."""
[docs]
def __init__(
self,
domain_constraints: list[Constraint],
perturbation_constraints: list[Constraint],
) -> None:
"""
Create manipulation object.
Parameters
----------
domain_constraints : list[Constraint]
Constraints for the domain bounds (x_adv).
perturbation_constraints : list[Constraint]
Constraints for the perturbation (delta).
"""
self._domain_constraints = domain_constraints
self._perturbation_constraints = perturbation_constraints
@property
def domain_constraints(self) -> list[Constraint]:
"""
Get the domain constraints for the manipulation.
Returns
-------
list[Constraint]
List of domain constraints for the manipulation.
"""
return self._domain_constraints
@domain_constraints.setter
def domain_constraints(self, domain_constraints: list[Constraint]) -> None:
self._domain_constraints = domain_constraints
@property
def perturbation_constraints(self) -> list[Constraint]:
"""
Get the perturbation constraints for the manipulation.
Returns
-------
list[Constraint]
List of perturbation constraints for the manipulation.
"""
return self._perturbation_constraints
@perturbation_constraints.setter
def perturbation_constraints(
self, perturbation_constraints: list[Constraint]
) -> None:
self._perturbation_constraints = perturbation_constraints
def _apply_domain_constraints(self, x: torch.Tensor) -> torch.Tensor:
for constraint in self.domain_constraints:
x = constraint(x)
return x
def _apply_perturbation_constraints(self, delta: torch.Tensor) -> torch.Tensor:
for constraint in self.perturbation_constraints:
delta = constraint(delta)
return delta
[docs]
@abstractmethod
def _apply_manipulation(
self,
x: torch.Tensor,
delta: torch.Tensor,
) -> torch.Tensor:
"""
Apply the manipulation.
Parameters
----------
x : torch.Tensor
Input samples.
delta : torch.Tensor
Manipulation to apply.
Returns
-------
torch.Tensor
Perturbed samples.
"""
...
[docs]
def __call__(
self,
x: torch.Tensor,
delta: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Apply the manipulation to the input data.
Parameters
----------
x : torch.Tensor
Input data.
delta : torch.Tensor
Perturbation to apply.
Returns
-------
tuple[torch.Tensor, torch.Tensor]
Perturbed data and perturbation after the
application of constraints.
"""
delta.data = self._apply_perturbation_constraints(delta.data)
x_adv, delta = self._apply_manipulation(x, delta)
x_adv.data = self._apply_domain_constraints(x_adv.data)
return x_adv, delta
[docs]
class AdditiveManipulation(Manipulation):
"""Additive manipulation for input data."""
def _apply_manipulation(
self,
x: torch.Tensor,
delta: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
return x + delta, delta