Source code for secmlt.data.lp_uniform_sampling
"""Implementation of Lp uniform sampling."""
import torch
from secmlt.adv.evasion.perturbation_models import LpPerturbationModels
from secmlt.data.distributions import GeneralizedNormal
from torch.distributions.exponential import Exponential
[docs]
class LpUniformSampling:
    """
    Uniform sampling from the unit Lp ball.

    This class provides a method for sampling uniformly from the
    unit Lp ball, where Lp is a norm defined by a parameter `p`.
    The class supports the L0, L1, L2, and Linf perturbation models.
    For finite p (e.g. L1, L2) the sampling method follows:
    https://arxiv.org/abs/math/0503650

    Attributes
    ----------
    p : str
        The norm to use for sampling. Must be one of 'l0', 'l1', 'l2', 'linf'.
    """

    def __init__(self, p: str = LpPerturbationModels.L2) -> None:
        """
        Initialize the LpUniformSampling object.

        Parameters
        ----------
        p : str, optional
            The norm to use for sampling, one of the LpPerturbationModels
            values (L0, L1, L2 or LINF). Default is LpPerturbationModels.L2.
        """
        self.p = p

    def sample_like(self, x: torch.Tensor) -> torch.Tensor:
        """
        Sample from the unit Lp ball with the same shape as a given tensor.

        The first dimension of `x` is treated as the number of samples and
        all remaining dimensions are flattened into the sample dimension
        (so `x` is assumed to have at least two dimensions).

        Parameters
        ----------
        x : torch.Tensor
            The input tensor whose shape is used to determine the shape of the samples.

        Returns
        -------
        torch.Tensor
            A tensor of samples from the unit Lp ball, with the
            same shape as the input tensor `x`, placed on the same
            device as `x`.
        """
        num_samples, dim = x.flatten(1).shape
        samples = self.sample(num_samples, dim).reshape(x.shape)
        # Samples are drawn on the default (CPU) device; move them next to
        # `x` so callers can combine them with `x` directly.
        return samples.to(x.device)

    def sample(self, num_samples: int = 1, dim: int = 2) -> torch.Tensor:
        """
        Sample uniformly from the unit Lp ball.

        This method generates a specified number of samples
        from the unit Lp ball, where Lp is a norm defined by the `p` parameter.
        For finite p, the algorithm of Barthe et al. is used: with g_i i.i.d.
        of density proportional to exp(-|t|^p) and e ~ Exp(1) independent,
        g / (sum_i |g_i|^p + e)^(1/p) is uniform in the unit Lp ball.

        Parameters
        ----------
        num_samples : int
            The number of samples to generate.
        dim : int
            The dimension of the samples.

        Returns
        -------
        torch.Tensor
            A tensor of samples from the unit Lp ball, with shape `(num_samples, dim)`.
        """
        shape = torch.Size((num_samples, dim))
        _p = LpPerturbationModels.get_p(self.p)
        if self.p == LpPerturbationModels.LINF:
            # Uniform in [-1, 1)^dim, i.e. the unit Linf ball.
            ball = 2 * torch.rand(size=shape) - 1
        elif self.p == LpPerturbationModels.L0:
            # NOTE(review): torch.rand draws from [0, 1), so .sign() yields an
            # all-ones tensor (0 only if a draw is exactly 0.0). Kept as-is to
            # preserve existing behavior -- confirm this is the intended L0 init.
            ball = torch.rand(size=shape).sign()
        else:
            # The generalized normal must use the SAME p as the normalizer,
            # otherwise the result lies in the ball but is not uniform.
            g = self._generalized_normal(shape, _p)
            e = Exponential(rate=1).sample(sample_shape=(num_samples,))
            d = ((torch.abs(g) ** _p).sum(-1) + e) ** (1 / _p)
            ball = g / d.unsqueeze(-1)
        return ball

    @staticmethod
    def _generalized_normal(shape: torch.Size, p: float) -> torch.Tensor:
        """
        Draw i.i.d. samples with density proportional to exp(-|t|^p).

        If t has this density then |t|^p ~ Gamma(1/p, 1); a uniformly random
        sign makes the samples symmetric around zero.

        Parameters
        ----------
        shape : torch.Size
            Shape of the returned tensor.
        p : float
            Shape parameter; must be finite and positive.

        Returns
        -------
        torch.Tensor
            Floating-point tensor of the requested shape.
        """
        gamma = torch.distributions.Gamma(
            torch.tensor(1.0 / p),
            torch.tensor(1.0),
        )
        magnitude = gamma.sample(shape) ** (1.0 / p)
        sign = torch.randint(0, 2, shape) * 2 - 1
        return sign * magnitude