Source code for secmlt.optimization.random_perturb

"""Random perturbations in Lp balls."""

from abc import ABC, abstractmethod

import torch
from secmlt.adv.evasion.perturbation_models import LpPerturbationModels
from secmlt.data.lp_uniform_sampling import LpUniformSampling
from secmlt.optimization.constraints import (
    L0Constraint,
    L1Constraint,
    L2Constraint,
    LInfConstraint,
    LpConstraint,
)


class RandomPerturbBase(ABC):
    """Class implementing the random perturbations in Lp balls."""

    def __init__(self, epsilon: float) -> None:
        """
        Create random perturbation object.

        Parameters
        ----------
        epsilon : float
            Constraint radius.
        """
        self.epsilon = epsilon

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get the perturbations for the given samples.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb.

        Returns
        -------
        torch.Tensor
            Perturbations (to apply) to the given samples, projected onto the
            ball of radius ``epsilon`` centered at zero defined by the
            subclass constraint.
        """
        # ``get_perturb`` returns perturbed *samples* (x + noise). Subtract x
        # so the zero-centered projection acts on the perturbation itself,
        # not on the samples (projecting x + noise onto a ball around zero
        # would just clamp the samples and lose the perturbation).
        perturbations = self.get_perturb(x) - x
        constraint = self._constraint(
            radius=self.epsilon,
            center=torch.zeros_like(perturbations),
        )
        return constraint.project(perturbations)

    @abstractmethod
    def get_perturb(self, x: torch.Tensor) -> torch.Tensor:
        """
        Generate randomly perturbed samples for the Lp norm.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb.

        Returns
        -------
        torch.Tensor
            Perturbed samples, i.e. ``x`` plus random noise.
        """
        ...

    @property
    @abstractmethod
    def _constraint(self) -> "type[LpConstraint]":
        """Constraint class, instantiated as ``cls(radius=..., center=...)``."""
        ...
class RandomPerturbLinf(RandomPerturbBase):
    """Random perturbation generator for the Linf norm."""

    def get_perturb(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add Linf-sampled noise, scaled by ``epsilon``, to the samples.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb.

        Returns
        -------
        torch.Tensor
            Perturbed samples.
        """
        sampler = LpUniformSampling(p=LpPerturbationModels.LINF)
        noise = sampler.sample_like(x)
        scaled_noise = noise * self.epsilon
        return x + scaled_noise

    @property
    def _constraint(self) -> type[LInfConstraint]:
        # The constraint class itself; the base class instantiates it.
        return LInfConstraint
class RandomPerturbL1(RandomPerturbBase):
    """Random perturbation generator for the L1 norm."""

    def get_perturb(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add L1-sampled noise, scaled by ``epsilon``, to the samples.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb.

        Returns
        -------
        torch.Tensor
            Perturbed samples.
        """
        sampler = LpUniformSampling(p=LpPerturbationModels.L1)
        noise = sampler.sample_like(x)
        scaled_noise = noise * self.epsilon
        return x + scaled_noise

    @property
    def _constraint(self) -> type[L1Constraint]:
        # The constraint class itself; the base class instantiates it.
        return L1Constraint
class RandomPerturbL2(RandomPerturbBase):
    """Random perturbation generator for the L2 norm."""

    def get_perturb(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add L2-sampled noise, scaled by ``epsilon``, to the samples.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb.

        Returns
        -------
        torch.Tensor
            Perturbed samples.
        """
        sampler = LpUniformSampling(p=LpPerturbationModels.L2)
        noise = sampler.sample_like(x)
        scaled_noise = noise * self.epsilon
        return x + scaled_noise

    @property
    def _constraint(self) -> type[L2Constraint]:
        # The constraint class itself; the base class instantiates it.
        return L2Constraint
class RandomPerturbL0(RandomPerturbBase):
    """Random perturbation generator for the L0 norm."""

    def get_perturb(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add L0-sampled noise, scaled by ``epsilon``, to the samples.

        Parameters
        ----------
        x : torch.Tensor
            Input samples to perturb.

        Returns
        -------
        torch.Tensor
            Perturbed samples.
        """
        sampler = LpUniformSampling(p=LpPerturbationModels.L0)
        noise = sampler.sample_like(x)
        scaled_noise = noise * self.epsilon
        return x + scaled_noise

    @property
    def _constraint(self) -> type[L0Constraint]:
        # The constraint class itself; the base class instantiates it.
        return L0Constraint
class RandomPerturb:
    """Factory that builds the random perturbation object for an Lp norm."""

    def __new__(cls, p: str, epsilon: float) -> RandomPerturbBase:
        """
        Build the random perturbation object matching the requested norm.

        Parameters
        ----------
        p : str
            p-norm used for the random perturbation shape.
        epsilon : float
            Radius of the random perturbation constraint.

        Returns
        -------
        RandomPerturbBase
            Random perturbation object.

        Raises
        ------
        ValueError
            Raises ValueError if the norm is not in 0, 1, 2, inf.
        """
        # Instantiating RandomPerturb returns an instance of one of these
        # classes, not a RandomPerturb.
        perturbation_classes = {
            LpPerturbationModels.L0: RandomPerturbL0,
            LpPerturbationModels.L1: RandomPerturbL1,
            LpPerturbationModels.L2: RandomPerturbL2,
            LpPerturbationModels.LINF: RandomPerturbLinf,
        }
        if p not in perturbation_classes:
            msg = "Perturbation model not available."
            raise ValueError(msg)
        return perturbation_classes[p](epsilon=epsilon)