"""Random perturbations in Lp balls."""
from abc import ABC, abstractmethod
import torch
from secmlt.adv.evasion.perturbation_models import LpPerturbationModels
from secmlt.data.lp_uniform_sampling import LpUniformSampling
from secmlt.optimization.constraints import (
L0Constraint,
L1Constraint,
L2Constraint,
LInfConstraint,
LpConstraint,
)
[docs]
class RandomPerturbBase(ABC):
    """
    Base class for random perturbations in Lp balls.

    Subclasses provide the raw random perturbation (``get_perturb``) and the
    constraint class (``_constraint``) used to project it.
    """

    def __init__(self, epsilon: float) -> None:
        """
        Create random perturbation object.

        Parameters
        ----------
        epsilon : float
            Constraint radius.
        """
        self.epsilon = epsilon

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get the perturbations for the given samples.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb.

        Returns
        -------
        torch.Tensor
            Perturbations (to apply) to the given samples, i.e. the output
            of ``get_perturb`` projected by the constraint built from
            ``_constraint`` with radius ``epsilon`` and a zero center.
        """
        perturbations = self.get_perturb(x)
        # The constraint is centered at zero (not at ``x``): the projected
        # quantity is treated as the perturbation, not the perturbed sample.
        constraint = self._constraint(
            radius=self.epsilon,
            center=torch.zeros_like(perturbations),
        )
        return constraint.project(perturbations)

    @abstractmethod
    def get_perturb(self, x: torch.Tensor) -> torch.Tensor:
        """
        Generate random perturbation for the Lp norm.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb.

        Returns
        -------
        torch.Tensor
            Random perturbation, before projection.

        Notes
        -----
        NOTE(review): ``__call__`` projects the returned value onto a
        constraint centered at zero, so implementations are expected to
        return the perturbation itself rather than ``x`` plus the
        perturbation -- confirm against the concrete subclasses.
        """
        ...

    # Declared as an abstract *property*: ``__call__`` reads it as an
    # attribute and then calls the result with ``radius=`` and ``center=``,
    # so it must evaluate to a constraint class, not to a constraint instance.
    @property
    @abstractmethod
    def _constraint(self) -> "type[LpConstraint]":
        """Constraint class used by ``__call__`` to project perturbations."""
        ...
[docs]
class RandomPerturbLinf(RandomPerturbBase):
    """Random perturbations for the Linf norm."""

    def get_perturb(self, x: torch.Tensor) -> torch.Tensor:
        """
        Generate random perturbation for the Linf norm.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb; forwarded to
            ``LpUniformSampling.sample_like``.

        Returns
        -------
        torch.Tensor
            Random perturbation: the Linf sample drawn by
            ``LpUniformSampling`` multiplied by ``epsilon``.
        """
        sampler = LpUniformSampling(p=LpPerturbationModels.LINF)
        # Return only the perturbation. ``RandomPerturbBase.__call__``
        # projects this value onto a constraint centered at zero, so adding
        # ``x`` here would make the clean sample part of what is projected.
        return sampler.sample_like(x) * self.epsilon

    @property
    def _constraint(self) -> type[LInfConstraint]:
        """Constraint class used to project the Linf perturbation."""
        return LInfConstraint
[docs]
class RandomPerturbL1(RandomPerturbBase):
    """Random perturbations for the L1 norm."""

    def get_perturb(self, x: torch.Tensor) -> torch.Tensor:
        """
        Generate random perturbation for the L1 norm.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb; forwarded to
            ``LpUniformSampling.sample_like``.

        Returns
        -------
        torch.Tensor
            Random perturbation: the L1 sample drawn by
            ``LpUniformSampling`` multiplied by ``epsilon``.
        """
        sampler = LpUniformSampling(p=LpPerturbationModels.L1)
        # Return only the perturbation. ``RandomPerturbBase.__call__``
        # projects this value onto a constraint centered at zero, so adding
        # ``x`` here would make the clean sample part of what is projected.
        return sampler.sample_like(x) * self.epsilon

    @property
    def _constraint(self) -> type[L1Constraint]:
        """Constraint class used to project the L1 perturbation."""
        return L1Constraint
[docs]
class RandomPerturbL2(RandomPerturbBase):
    """Random perturbations for the L2 norm."""

    def get_perturb(self, x: torch.Tensor) -> torch.Tensor:
        """
        Generate random perturbation for the L2 norm.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb; forwarded to
            ``LpUniformSampling.sample_like``.

        Returns
        -------
        torch.Tensor
            Random perturbation: the L2 sample drawn by
            ``LpUniformSampling`` multiplied by ``epsilon``.
        """
        sampler = LpUniformSampling(p=LpPerturbationModels.L2)
        # Return only the perturbation. ``RandomPerturbBase.__call__``
        # projects this value onto a constraint centered at zero, so adding
        # ``x`` here would make the clean sample part of what is projected.
        return sampler.sample_like(x) * self.epsilon

    @property
    def _constraint(self) -> type[L2Constraint]:
        """Constraint class used to project the L2 perturbation."""
        return L2Constraint
[docs]
class RandomPerturbL0(RandomPerturbBase):
    """Random perturbations for the L0 norm."""

    def get_perturb(self, x: torch.Tensor) -> torch.Tensor:
        """
        Generate random perturbation for the L0 norm.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb; forwarded to
            ``LpUniformSampling.sample_like``.

        Returns
        -------
        torch.Tensor
            Random perturbation: the L0 sample drawn by
            ``LpUniformSampling`` multiplied by ``epsilon``.
        """
        sampler = LpUniformSampling(p=LpPerturbationModels.L0)
        # Return only the perturbation. ``RandomPerturbBase.__call__``
        # projects this value onto a constraint centered at zero, so adding
        # ``x`` here would make the clean sample part of what is projected.
        return sampler.sample_like(x) * self.epsilon

    @property
    def _constraint(self) -> type[L0Constraint]:
        """Constraint class used to project the L0 perturbation."""
        return L0Constraint
[docs]
class RandomPerturb:
    """Factory selecting the random perturbation class for a given norm."""

    def __new__(cls, p: str, epsilon: float) -> RandomPerturbBase:
        """
        Build the random perturbation object matching the requested norm.

        Parameters
        ----------
        p : str
            Perturbation model (one of the ``LpPerturbationModels`` entries
            L0, L1, L2, LINF) selecting the perturbation class.
        epsilon : float
            Radius forwarded to the selected perturbation class.

        Returns
        -------
        RandomPerturbBase
            Instance of the perturbation class registered for ``p``.

        Raises
        ------
        ValueError
            If ``p`` is not one of the supported perturbation models.
        """
        # Lookup table from perturbation model to its implementing class.
        perturb_by_norm = {
            LpPerturbationModels.L0: RandomPerturbL0,
            LpPerturbationModels.L1: RandomPerturbL1,
            LpPerturbationModels.L2: RandomPerturbL2,
            LpPerturbationModels.LINF: RandomPerturbLinf,
        }
        perturb_cls = perturb_by_norm.get(p)
        if perturb_cls is None:
            msg = "Perturbation model not available."
            raise ValueError(msg)
        return perturb_cls(epsilon=epsilon)