# Source code for greatx.attack.backdoor.lgc_backdoor

import warnings
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.utils import add_self_loops, degree, remove_self_loops
from tqdm.auto import tqdm

from greatx.attack.backdoor.backdoor_attacker import BackdoorAttacker
from greatx.functional import spmm
from greatx.nn.models.surrogate import Surrogate


class LGCBackdoor(BackdoorAttacker):
    r"""Implementation of `LGCB` attack from the:
    `"Neighboring Backdoor Attacks on Graph Convolutional Network"
    <https://arxiv.org/abs/2201.06202>`_ paper (arXiv'22)

    The trigger is computed in closed form from the product of the
    surrogate's weight matrices (bias terms are ignored). The
    :obj:`num_budgets` features whose summed margin
    ``sum_c (W[f, c] - W[f, target_class])`` is smallest are set to 1.

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                                transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ... # train your surrogate model

        from greatx.attack.backdoor import LGCBackdoor
        attacker = LGCBackdoor(data)
        attacker.setup_surrogate(surrogate_model)
        attacker.reset()
        attacker.attack(num_budgets=50, target_class=0)

        attacker.data() # get attacked graph

        attacker.trigger() # get trigger node

    Note
    ----
    * Please remember to call :meth:`reset` before each attack.
    * :meth:`setup_surrogate` must be called before :meth:`attack`,
      since it defines ``self.W`` and ``self.num_classes``.
    """

    @torch.no_grad()
    def setup_surrogate(self, surrogate: nn.Module) -> "LGCBackdoor":
        """Collapse the surrogate's weight matrices into one linear map.

        Parameters with ``ndim == 1`` (biases) are skipped with a warning.
        The remaining parameters ``P_1, P_2, ..., P_k`` are multiplied in
        the order yielded by ``surrogate.parameters()`` into
        ``P_k @ ... @ P_1``. Its transpose is stored as ``self.W``.

        Parameters
        ----------
        surrogate : nn.Module
            A trained surrogate model. Presumably a linear model (e.g.
            SGC-like), since only weight matrices are used.

        Returns
        -------
        LGCBackdoor
            The attacker itself, with ``self.W`` and ``self.num_classes``
            (the size of the last dimension of ``self.W``) set.
        """
        W = None
        for para in surrogate.parameters():
            if para.ndim == 1:
                warnings.warn("The surrogate model has `bias` term, "
                              "which is ignored and the model itself "
                              f"may not be a perfect choice for {self.name}.")
                continue
            # Chain the layers: each new weight is applied after the
            # product accumulated so far.
            W = para.detach() if W is None else para.detach() @ W

        assert W is not None, "The surrogate model has no weight matrix."
        self.W = W.t()
        self.num_classes = self.W.size(-1)
        return self

    def attack(self, num_budgets: Union[int, float], target_class: int,
               disable: bool = False) -> "LGCBackdoor":
        """Build the binary trigger feature vector for ``target_class``.

        Parameters
        ----------
        num_budgets : Union[int, float]
            Trigger budget. It is forwarded to
            :meth:`BackdoorAttacker.attack`, which presumably resolves it
            into the integer ``self.num_budgets`` used below.
        target_class : int
            Class the trigger should steer predictions towards. Must be
            smaller than ``self.num_classes``.
        disable : bool, optional
            Unused by this attacker. :class:`FGBackdoor` uses the same
            flag to silence its progress bar.

        Returns
        -------
        LGCBackdoor
            The attacker itself. The trigger is stored in
            ``self._trigger`` with 1 at selected features and 0 elsewhere.
        """
        super().attack(num_budgets, target_class)
        assert target_class < self.num_classes, (
            f"`target_class` ({target_class}) must be smaller than the "
            f"number of classes ({self.num_classes}).")

        feat_perturbations = self.get_feat_perturbations(
            self.W, target_class, self.num_budgets)

        trigger = self.feat.new_zeros(self.num_feats)
        trigger[feat_perturbations] = 1.
        self._trigger = trigger
        return self

    @staticmethod
    def get_feat_perturbations(W: Tensor, target_class: int,
                               num_budgets: int) -> Tensor:
        """Return indices of the ``num_budgets`` most target-favouring rows.

        For every row ``f`` of ``W``, the score is
        ``sum_c (W[f, c] - W[f, target_class])``. The indices of the
        ``num_budgets`` smallest scores are returned, i.e. the rows where
        the target column dominates the other columns the most.

        Parameters
        ----------
        W : Tensor
            2-D weight matrix (rows: features, columns: classes).
        target_class : int
            Column index of the target class.
        num_budgets : int
            Number of indices to return; passed as ``k`` to
            :func:`torch.topk`.

        Returns
        -------
        Tensor
            1-D tensor of selected row indices.
        """
        D = W - W[:, target_class].view(-1, 1)
        D = D.sum(1)
        _, indices = torch.topk(D, k=num_budgets, largest=False)
        return indices
class FGBackdoor(BackdoorAttacker, Surrogate):
    r"""Implementation of `GB-FGSM` attack from the:
    `"Neighboring Backdoor Attacks on Graph Convolutional Network"
    <https://arxiv.org/abs/2201.06202>`_ paper (arXiv'22)

    The trigger is built greedily. At each of the :obj:`num_budgets`
    steps, the not-yet-selected feature with the largest gradient of the
    negative cross-entropy towards :obj:`target_class` is set to 1.

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                                transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ... # train your surrogate model

        from greatx.attack.backdoor import FGBackdoor
        attacker = FGBackdoor(data)
        attacker.setup_surrogate(surrogate_model)
        attacker.reset()
        attacker.attack(num_budgets=50, target_class=0)

        attacker.data() # get attacked graph

        attacker.trigger() # get trigger node

    Note
    ----
    * Please remember to call :meth:`reset` before each attack.
    * :meth:`setup_surrogate` must be called before :meth:`attack`.
    """

    def setup_surrogate(self, surrogate: nn.Module, *,
                        tau: float = 1.0) -> "FGBackdoor":
        """Register a frozen surrogate and extract its two weight matrices.

        Parameters
        ----------
        surrogate : nn.Module
            Trained surrogate with exactly two weight matrices. Bias
            terms are skipped with a warning.
        tau : float, optional
            Temperature forwarded to :meth:`Surrogate.setup_surrogate`.
            The logits in :meth:`attack` are divided by ``self.tau``.

        Returns
        -------
        FGBackdoor
            The attacker itself, with ``self.w1`` and ``self.w2`` (the
            transposed weights) and ``self.num_classes`` set.
        """
        Surrogate.setup_surrogate(self, surrogate=surrogate, tau=tau,
                                  freeze=True)
        W = []
        for para in self.surrogate.parameters():
            if para.ndim == 1:
                warnings.warn("The surrogate model has `bias` term, "
                              "which is ignored and the model itself "
                              f"may not be a perfect choice for {self.name}.")
            else:
                W.append(para.detach().t())

        assert len(W) == 2, (
            f"Expected a surrogate with exactly 2 weight matrices, "
            f"got {len(W)}.")
        self.w1, self.w2 = W
        self.num_classes = W[-1].size(-1)
        return self

    def attack(self, num_budgets: Union[int, float], target_class: int,
               disable: bool = False) -> "FGBackdoor":
        """Greedily build the trigger feature vector for ``target_class``.

        Parameters
        ----------
        num_budgets : Union[int, float]
            Trigger budget. It is forwarded to
            :meth:`BackdoorAttacker.attack`, which presumably resolves it
            into the integer ``self.num_budgets`` (number of greedy steps).
        target_class : int
            Class every node should be pushed towards. Must be smaller
            than ``self.num_classes``.
        disable : bool, optional
            If ``True``, hide the :mod:`tqdm` progress bar.

        Returns
        -------
        FGBackdoor
            The attacker itself. The detached trigger is stored in
            ``self._trigger``.
        """
        super().attack(num_budgets, target_class)
        assert target_class < self.num_classes, (
            f"`target_class` ({target_class}) must be smaller than the "
            f"number of classes ({self.num_classes}).")

        N = self.num_nodes
        feat = self.feat
        trigger = feat.new_zeros(self.num_feats).requires_grad_()
        target_labels = torch.full((N, ), target_class, dtype=torch.long,
                                   device=self.device)

        (edge_index, edge_weight_with_trigger, edge_index_with_self_loop,
         edge_weight, trigger_edge_index, trigger_edge_weight,
         augmented_edge_index,
         augmented_edge_weight) = get_backdoor_edges(self.edge_index, N)

        for _ in tqdm(range(self.num_budgets),
                      desc="Updating trigger using gradients...",
                      disable=disable):
            # Nodes 0..N-1 keep their features; nodes N..2N-1 are the
            # trigger nodes, which all share the same trigger vector.
            aug_feat = torch.cat([feat, trigger.repeat(N, 1)], dim=0)
            feat1 = aug_feat @ self.w1
            h1 = spmm(feat1, edge_index_with_self_loop, edge_weight)
            h1_aug = spmm(feat1, augmented_edge_index,
                          augmented_edge_weight).relu()
            h = spmm(h1_aug @ self.w2, trigger_edge_index,
                     trigger_edge_weight)
            h += spmm(h1 @ self.w2, edge_index, edge_weight_with_trigger)
            h = h[:N] / self.tau

            loss = F.cross_entropy(h, target_labels)
            gradients = torch.autograd.grad(-loss, trigger)[0]
            # Exclude already-selected features with -inf rather than 0.
            # Otherwise, when every remaining gradient is negative, argmax
            # re-picks a selected feature and the step is wasted.
            selected = trigger.detach().bool()
            gradients = gradients.masked_fill(selected, float('-inf'))
            with torch.no_grad():
                trigger[gradients.argmax()] = 1.0

        self._trigger = trigger.detach()
        return self
def get_backdoor_edges(edge_index: Tensor, N: int) -> Tuple:
    """Build the edges and normalized weights of a trigger-augmented graph.

    One trigger node ``N + i`` is attached to every original node ``i``,
    giving ``2 * N`` nodes in total.

    Parameters
    ----------
    edge_index : Tensor
        Edge index of the original graph, shape ``(2, E)``. Self-loops,
        if any, are removed.
    N : int
        Number of nodes in the original graph.

    Returns
    -------
    Tuple
        ``(edge_index, edge_weight_with_trigger, edge_index_with_self_loop,
        edge_weight, trigger_edge_index, trigger_edge_weight,
        augmented_edge_index, augmented_edge_weight)`` where

        * ``edge_index``: original edges without self-loops;
        * ``edge_weight_with_trigger``: weights for ``edge_index``, equal to
          ``d_aug[row]^-1/2 * d[col]^-1/2``;
        * ``edge_index_with_self_loop``: original edges plus one self-loop
          for each of the ``N`` nodes;
        * ``edge_weight``: ``d[row]^-1/2 * d[col]^-1/2`` for those edges;
        * ``trigger_edge_index``: trigger<->victim edges in both directions
          plus a self-loop for each of the ``2 * N`` nodes;
        * ``trigger_edge_weight``: ``d_aug[row]^-1/2 * d_aug[col]^-1/2``;
        * ``augmented_edge_index`` / ``augmented_edge_weight``: the
          concatenation of the original and trigger edges / weights.

        Here ``d`` is the degree of the self-loop-free original graph. For
        original nodes, ``d_aug = d + 1`` (one trigger neighbour). For
        trigger nodes, ``d_aug = 2``. Zero degrees (isolated nodes) get
        weight 0 instead of ``inf``.
        NOTE(review): self-loops are counted in ``d_aug`` of trigger nodes
        but not in ``d`` / ``d_aug`` of original nodes; confirm this
        matches the paper.
    """
    device = edge_index.device
    influence_nodes = torch.arange(N, device=device)
    N_all = N + influence_nodes.size(0)
    trigger_nodes = torch.arange(N, N_all, device=device)

    # 1. edge index of original graph (without selfloops)
    edge_index, _ = remove_self_loops(edge_index)

    # 2. edge index of original graph (with selfloops). Pass `num_nodes`
    # so that trailing isolated nodes also receive a self-loop.
    edge_index_with_self_loop, _ = add_self_loops(edge_index, num_nodes=N)

    # 3. edge index of trigger nodes connected to victim nodes with selfloops
    trigger_edge_index = torch.stack([trigger_nodes, influence_nodes], dim=0)
    diag_index = torch.arange(N_all, device=device).repeat(2, 1)
    trigger_edge_index = torch.cat(
        [trigger_edge_index, trigger_edge_index[[1, 0]], diag_index], dim=1)

    # 4. all edge index with trigger nodes
    augmented_edge_index = torch.cat([edge_index, trigger_edge_index], dim=1)

    d = degree(edge_index[0], num_nodes=N, dtype=torch.float)
    d_augmented = d.clone()
    d_augmented[influence_nodes] += 1.
    # Explicit float dtype: an integer fill value would otherwise infer an
    # integer tensor that is then concatenated with a float one.
    d_augmented = torch.cat([
        d_augmented,
        torch.full(trigger_nodes.size(), 2., dtype=d.dtype, device=device)
    ])

    d_pow = d.pow(-0.5)
    # Isolated nodes have degree 0; avoid propagating `inf`.
    d_pow.masked_fill_(d_pow == float('inf'), 0.)
    d_augmented_pow = d_augmented.pow(-0.5)

    edge_weight = d_pow[edge_index_with_self_loop[0]] * \
        d_pow[edge_index_with_self_loop[1]]
    edge_weight_with_trigger = d_augmented_pow[edge_index[0]] * d_pow[
        edge_index[1]]
    trigger_edge_weight = d_augmented_pow[
        trigger_edge_index[0]] * d_augmented_pow[trigger_edge_index[1]]
    augmented_edge_weight = torch.cat(
        [edge_weight_with_trigger, trigger_edge_weight], dim=0)

    return (edge_index, edge_weight_with_trigger, edge_index_with_self_loop,
            edge_weight, trigger_edge_index, trigger_edge_weight,
            augmented_edge_index, augmented_edge_weight)