import warnings
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.utils import add_self_loops, degree, remove_self_loops
from tqdm.auto import tqdm
from greatx.attack.backdoor.backdoor_attacker import BackdoorAttacker
from greatx.functional import spmm
from greatx.nn.models.surrogate import Surrogate
class LGCBackdoor(BackdoorAttacker):
    r"""Implementation of `LGCB` attack from the:
    `"Neighboring Backdoor Attacks on Graph Convolutional Network"
    <https://arxiv.org/abs/2201.06202>`_ paper (arXiv'22)

    How the trigger is built:

    * The surrogate's weight matrices are collapsed into one linear map
      ``W``; bias terms are ignored.
    * For every feature ``f`` the score ``sum_c (W[f, c] - W[f, t])`` is
      computed, where ``t`` is the target class.
    * The ``num_budgets`` features with the smallest scores are set to 1
      in a binary trigger vector.

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                               transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ... # train your surrogate model

        from greatx.attack.backdoor import LGCBackdoor
        attacker = LGCBackdoor(data)
        attacker.setup_surrogate(surrogate_model)
        attacker.reset()
        attacker.attack(num_budgets=50, target_class=0)

        attacker.data() # get attacked graph

        attacker.trigger() # get trigger node

    Note
    ----
    * Please remember to call :meth:`reset` before each attack.
    """

    @torch.no_grad()
    def setup_surrogate(self, surrogate: nn.Module) -> "LGCBackdoor":
        """Collapse the surrogate's weight matrices into one linear map.

        Every parameter that is not 1-D is treated as a weight matrix.
        These are chained in the order given by
        ``surrogate.parameters()``, i.e. ``W = W_k @ ... @ W_2 @ W_1``.
        1-D parameters (biases) are skipped with a warning.

        Parameters
        ----------
        surrogate : nn.Module
            The trained surrogate model.
            NOTE(review): assumes weights are stored as
            ``(out_features, in_features)``, as ``nn.Linear`` does.

        Returns
        -------
        LGCBackdoor
            ``self``. The method sets:

            * ``self.W``: the transposed product. Under the assumption
              above it has shape ``(num_feats, num_classes)``.
            * ``self.num_classes``: the last dimension of ``self.W``.
        """
        W = None
        for para in surrogate.parameters():
            if para.ndim == 1:
                warnings.warn("The surrogate model has `bias` term, "
                              "which is ignored and the model itself "
                              f"may not be a perfect choice for {self.name}.")
                continue
            if W is None:
                W = para.detach()
            else:
                # left-multiply so the first layer stays rightmost
                W = para.detach() @ W

        assert W is not None
        self.W = W.t()
        self.num_classes = self.W.size(-1)
        return self

    def attack(self, num_budgets: Union[int, float], target_class: int,
               disable: bool = False) -> "LGCBackdoor":
        """Build a binary feature trigger for ``target_class``.

        Parameters
        ----------
        num_budgets : Union[int, float]
            Trigger budget, forwarded to the base-class ``attack``.
            Presumably the base class resolves it into
            ``self.num_budgets``, which is used below.
        target_class : int
            Class the trigger should push predictions towards. It must be
            smaller than ``self.num_classes``.
        disable : bool, optional
            Unused here; kept so the signature matches
            :class:`FGBackdoor`.

        Returns
        -------
        LGCBackdoor
            ``self``, with ``self._trigger`` set to a vector of length
            ``self.num_feats`` that is 1 at the selected features and 0
            elsewhere.
        """
        super().attack(num_budgets, target_class)
        assert target_class < self.num_classes

        feat_perturbations = self.get_feat_perturbations(
            self.W, target_class, self.num_budgets)

        trigger = self.feat.new_zeros(self.num_feats)
        trigger[feat_perturbations] = 1.

        self._trigger = trigger
        return self

    @staticmethod
    def get_feat_perturbations(W: Tensor, target_class: int,
                               num_budgets: int) -> Tensor:
        """Return the indices of the ``num_budgets`` features to switch on.

        Each feature ``f`` is scored as
        ``sum_c (W[f, c] - W[f, target_class])``. The features with the
        smallest scores are returned: these are the ones whose weight
        towards ``target_class`` is largest relative to the other classes.

        Parameters
        ----------
        W : Tensor
            2-D matrix indexed as ``W[feature, class]``.
        target_class : int
            Column of ``W`` that corresponds to the target class.
        num_budgets : int
            Number of feature indices to return (``k`` for ``topk``).

        Returns
        -------
        Tensor
            Indices of the selected features.
        """
        D = W - W[:, target_class].view(-1, 1)
        D = D.sum(1)
        _, indices = torch.topk(D, k=num_budgets, largest=False)
        return indices
class FGBackdoor(BackdoorAttacker, Surrogate):
    r"""Implementation of `GB-FGSM` attack from the:
    `"Neighboring Backdoor Attacks on Graph Convolutional Network"
    <https://arxiv.org/abs/2201.06202>`_ paper (arXiv'22)

    How the trigger is built:

    * A two-layer surrogate is used.
    * The trigger starts as an all-zero feature vector.
    * Each budget step does one greedy update: it sets to 1 the unchosen
      feature whose gradient most decreases the cross-entropy towards
      the target class.

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                               transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ... # train your surrogate model

        from greatx.attack.backdoor import FGBackdoor
        attacker = FGBackdoor(data)
        attacker.setup_surrogate(surrogate_model)
        attacker.reset()
        attacker.attack(num_budgets=50, target_class=0)

        attacker.data() # get attacked graph

        attacker.trigger() # get trigger node

    Note
    ----
    * Please remember to call :meth:`reset` before each attack.
    """

    def setup_surrogate(self, surrogate: nn.Module, *,
                        tau: float = 1.0) -> "FGBackdoor":
        """Register a two-layer surrogate and cache its weight matrices.

        Parameters
        ----------
        surrogate : nn.Module
            Trained surrogate model. It must have exactly two parameters
            that are not 1-D (weight matrices). 1-D parameters (biases)
            are ignored with a warning.
        tau : float, optional
            Temperature forwarded to ``Surrogate.setup_surrogate``, which
            presumably stores it as ``self.tau``. :meth:`attack` divides
            the logits by ``self.tau``. Defaults to 1.0.

        Returns
        -------
        FGBackdoor
            ``self``, with ``self.w1`` and ``self.w2`` holding the
            transposed weights and ``self.num_classes`` set from the last
            dimension of ``self.w2``.
        """
        Surrogate.setup_surrogate(self, surrogate=surrogate, tau=tau,
                                  freeze=True)
        W = []
        for para in self.surrogate.parameters():
            if para.ndim == 1:
                warnings.warn("The surrogate model has `bias` term, "
                              "which is ignored and the model itself "
                              f"may not be a perfect choice for {self.name}.")
            else:
                W.append(para.detach().t())

        assert len(W) == 2
        self.w1, self.w2 = W
        self.num_classes = W[-1].size(-1)
        return self

    def attack(self, num_budgets: Union[int, float], target_class: int,
               disable: bool = False) -> "FGBackdoor":
        """Greedily build a binary feature trigger using gradients.

        Parameters
        ----------
        num_budgets : Union[int, float]
            Trigger budget, forwarded to ``BackdoorAttacker.attack``.
            The loop below runs ``self.num_budgets`` times, one feature
            per iteration.
        target_class : int
            Class every node should be pushed towards. It must be smaller
            than ``self.num_classes``.
        disable : bool, optional
            If ``True``, hide the ``tqdm`` progress bar.

        Returns
        -------
        FGBackdoor
            ``self``, with ``self._trigger`` set to the detached 0/1
            trigger vector of length ``self.num_feats``.
        """
        super().attack(num_budgets, target_class)
        assert target_class < self.num_classes

        N = self.num_nodes
        feat = self.feat

        trigger = feat.new_zeros(self.num_feats).requires_grad_()
        # every original node is pushed towards the same target class
        target_labels = torch.full((N, ), target_class, dtype=torch.long,
                                   device=self.device)

        (edge_index, edge_weight_with_trigger, edge_index_with_self_loop,
         edge_weight, trigger_edge_index, trigger_edge_weight,
         augmented_edge_index,
         augmented_edge_weight) = get_backdoor_edges(self.edge_index, N)

        for _ in tqdm(range(self.num_budgets),
                      desc="Updating trigger using gradients...",
                      disable=disable):
            # rows N..2N-1 are the trigger nodes, one copy per original node
            aug_feat = torch.cat([feat, trigger.repeat(N, 1)], dim=0)
            feat1 = aug_feat @ self.w1
            # NOTE(review): h1 is not passed through relu, unlike h1_aug;
            # confirm this matches the intended approximation.
            h1 = spmm(feat1, edge_index_with_self_loop, edge_weight)
            h1_aug = spmm(feat1, augmented_edge_index,
                          augmented_edge_weight).relu()
            h = spmm(h1_aug @ self.w2, trigger_edge_index,
                     trigger_edge_weight)
            h += spmm(h1 @ self.w2, edge_index, edge_weight_with_trigger)
            h = h[:N] / self.tau
            loss = F.cross_entropy(h, target_labels)

            gradients = torch.autograd.grad(-loss, trigger)[0]
            # Exclude features already in the trigger. Masking with -inf
            # (instead of multiplying by ``1 - trigger``) guarantees a new
            # feature is selected even if all remaining gradients are < 0.
            gradients = gradients.masked_fill(trigger.detach().bool(),
                                              float('-inf'))
            with torch.no_grad():
                trigger[gradients.argmax()] = 1.0

        self._trigger = trigger.detach()
        return self
def get_backdoor_edges(edge_index: Tensor, N: int) -> Tuple:
    """Build the edge sets and normalized edge weights used by FGBackdoor.

    Node layout: ``0..N-1`` are the original nodes. Each original node
    ``i`` gets its own trigger node ``N + i``, connected to ``i`` in both
    directions, so the augmented graph has ``2 * N`` nodes.

    Parameters
    ----------
    edge_index : Tensor
        ``(2, E)`` edge index of the original graph. Existing self loops
        are removed first.
    N : int
        Number of nodes in the original graph.

    Returns
    -------
    Tuple
        An 8-tuple, in this order:

        1. ``edge_index``: original edges without self loops.
        2. ``edge_weight_with_trigger``: weights for ``edge_index``,
           ``d_aug^-1/2[src] * d^-1/2[dst]``.
        3. ``edge_index_with_self_loop``: original edges plus one self
           loop per original node.
        4. ``edge_weight``: weights for ``edge_index_with_self_loop``,
           ``d^-1/2[src] * d^-1/2[dst]``.
        5. ``trigger_edge_index``: trigger<->victim edges in both
           directions, plus self loops on all ``2 * N`` nodes.
        6. ``trigger_edge_weight``: ``d_aug^-1/2`` products for
           ``trigger_edge_index``.
        7. ``augmented_edge_index``: ``edge_index`` concatenated with
           ``trigger_edge_index``.
        8. ``augmented_edge_weight``: ``edge_weight_with_trigger``
           concatenated with ``trigger_edge_weight``.

    Degrees used for normalization:

    * ``d`` is the degree in the original graph without self loops.
    * ``d_aug`` is ``d + 1`` for original nodes (the trigger edge) and 2
      for trigger nodes (the victim edge plus the self loop).
    * Isolated nodes get weight 0 instead of ``inf``.

    NOTE(review): self loops are counted in the degree of trigger nodes
    but not of original nodes; confirm this is intended.
    """
    device = edge_index.device
    influence_nodes = torch.arange(N, device=device)

    N_all = N + influence_nodes.size(0)
    trigger_nodes = torch.arange(N, N_all, device=device)

    # 1. edge index of original graph (without self loops)
    edge_index, _ = remove_self_loops(edge_index)

    # 2. edge index of original graph (with self loops). ``num_nodes`` is
    # passed explicitly so trailing isolated nodes also get a self loop.
    edge_index_with_self_loop, _ = add_self_loops(edge_index, num_nodes=N)

    # 3. edge index of trigger nodes connected to victim nodes, with
    # self loops on every node of the augmented graph
    trigger_edge_index = torch.stack([trigger_nodes, influence_nodes], dim=0)
    diag_index = torch.arange(N_all, device=device).repeat(2, 1)
    trigger_edge_index = torch.cat(
        [trigger_edge_index, trigger_edge_index[[1, 0]], diag_index], dim=1)

    # 4. all edge index with trigger nodes
    augmented_edge_index = torch.cat([edge_index, trigger_edge_index], dim=1)

    d = degree(edge_index[0], num_nodes=N, dtype=torch.float)
    d_augmented = d.clone()
    d_augmented[influence_nodes] += 1.
    # explicit float dtype: an integer fill value would otherwise infer long
    d_augmented = torch.cat([
        d_augmented,
        torch.full(trigger_nodes.size(), 2., dtype=d.dtype, device=device)
    ])

    d_pow = d.pow(-0.5)
    # isolated nodes have degree 0 -> inf; zero them to avoid NaN weights
    d_pow.masked_fill_(torch.isinf(d_pow), 0.)
    d_augmented_pow = d_augmented.pow(-0.5)

    edge_weight = d_pow[edge_index_with_self_loop[0]] * \
        d_pow[edge_index_with_self_loop[1]]
    edge_weight_with_trigger = d_augmented_pow[edge_index[0]] * d_pow[
        edge_index[1]]
    trigger_edge_weight = d_augmented_pow[
        trigger_edge_index[0]] * d_augmented_pow[trigger_edge_index[1]]
    augmented_edge_weight = torch.cat(
        [edge_weight_with_trigger, trigger_edge_weight], dim=0)

    return (edge_index, edge_weight_with_trigger, edge_index_with_self_loop,
            edge_weight, trigger_edge_index, trigger_edge_weight,
            augmented_edge_index, augmented_edge_weight)