Source code for secmlt.adv.poisoning.backdoor

"""Simple backdoor attack in PyTorch."""

from collections.abc import Callable
from typing import Union

import torch
from secmlt.adv.poisoning.base_data_poisoning import PoisoningDatasetPyTorch
from torch.utils.data import Dataset


class BackdoorDatasetPyTorch(PoisoningDatasetPyTorch):
    """Dataset class for adding triggers for backdoor attacks.

    Thin specialization of ``PoisoningDatasetPyTorch``: the data transform is
    the user-supplied ``data_manipulation_func`` and the label transform is a
    function that always returns the fixed ``trigger_label``. Selection of the
    samples to poison (``portion`` / ``poisoned_indexes``) is forwarded
    unchanged to the base class.
    """

    def __init__(
        self,
        dataset: Dataset,
        data_manipulation_func: Callable,
        trigger_label: int = 0,
        portion: float | None = None,
        poisoned_indexes: list[int] | torch.Tensor | None = None,
    ) -> None:
        """
        Create the backdoored dataset.

        Parameters
        ----------
        dataset : torch.utils.data.Dataset
            PyTorch dataset.
        data_manipulation_func : Callable
            Function to manipulate the data and add the backdoor.
        trigger_label : int, optional
            Label to associate with the backdoored data (default 0).
        portion : float | None, optional
            Portion of samples on which the backdoor will be injected
            (default None). Alternative to ``poisoned_indexes``; how a
            ``None`` value is resolved is decided by
            ``PoisoningDatasetPyTorch``.
        poisoned_indexes : list[int] | torch.Tensor | None, optional
            Specific indexes of samples to perturb (default None).
            Alternative to ``portion``.
        """
        # The closure captures ``trigger_label`` so that every poisoned
        # sample is relabelled to the same target class, ignoring its
        # original label (the argument is deliberately discarded).
        super().__init__(
            dataset=dataset,
            data_manipulation_func=data_manipulation_func,
            label_manipulation_func=lambda _: trigger_label,
            portion=portion,
            poisoned_indexes=poisoned_indexes,
        )