# Source code for greatx.attack.untargeted.ig_attack

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.autograd import grad
from torch_geometric.data import Data
from tqdm.auto import tqdm

from greatx.attack.untargeted.untargeted_attacker import UntargetedAttacker
from greatx.nn.models.surrogate import Surrogate
from greatx.utils import singleton_mask


class IGAttack(UntargetedAttacker, Surrogate):
    r"""Implementation of `IG-FGSM` attack from the:
    `"Adversarial Examples on Graph Data: Deep Insights into Attack and
    Defense" <https://arxiv.org/abs/1903.01610>`_ paper (IJCAI'19)

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device of the attack running on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be
        :obj:`__class__.__name__`, by default None
    kwargs : additional arguments of :class:`greatx.attack.Attacker`,

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                               transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ...  # train your surrogate model

        from greatx.attack.untargeted import IGAttack
        attacker = IGAttack(data)
        attacker.setup_surrogate(surrogate_model)
        attacker.reset()
        attacker.attack(0.05)  # attack with 0.05% of edge perturbations
        attacker.data()  # get attacked graph

        attacker.edge_flips()  # get edge flips after attack
        attacker.added_edges()  # get added edges after attack
        attacker.removed_edges()  # get removed edges after attack

    Note
    ----
    * In the paper, `IG-FGSM` attack was implemented for targeted attack,
      we adapt the codes for the non-targeted attack here.
    * Please remember to call :meth:`reset` before each attack.
    """

    # IGAttack can conduct feature attack
    _allow_feature_attack: bool = True

    def __init__(self, data: Data, device: str = "cpu",
                 seed: Optional[int] = None, name: Optional[str] = None,
                 **kwargs):
        super().__init__(data=data, device=device, seed=seed, name=name,
                         **kwargs)

        num_nodes, num_feats = self.num_nodes, self.num_feats
        self.nodes_set = set(range(num_nodes))
        self.feats_list = list(range(num_feats))
        # Dense adjacency matrix used for integrated-gradient computation
        self.adj = self.get_dense_adj()

    def setup_surrogate(self, surrogate: torch.nn.Module,
                        victim_nodes: Tensor,
                        victim_labels: Optional[Tensor] = None, *,
                        tau: float = 1.0) -> "IGAttack":
        """Set up the surrogate model and the victim nodes/labels.

        Parameters
        ----------
        surrogate : torch.nn.Module
            the (trained) surrogate model; it is frozen here
        victim_nodes : Tensor
            the victim nodes to attack; a boolean mask is converted to
            node indices
        victim_labels : Optional[Tensor], optional
            the labels of the victim nodes, by default taken from
            :obj:`self.label`
        tau : float, optional
            temperature used to scale the surrogate logits, by default 1.0

        Returns
        -------
        IGAttack
            the attacker itself (fluent interface)
        """
        Surrogate.setup_surrogate(self, surrogate=surrogate, tau=tau,
                                  freeze=True)
        if victim_nodes.dtype == torch.bool:
            victim_nodes = victim_nodes.nonzero().view(-1)
        self.victim_nodes = victim_nodes.to(self.device)
        if victim_labels is None:
            victim_labels = self.label[victim_nodes]
        self.victim_labels = victim_labels.to(self.device)
        return self

    def attack(self, num_budgets=0.05, *, steps=20, structure_attack=True,
               feature_attack=False, disable=False) -> "IGAttack":
        """Perform the IG-FGSM attack with the given budget.

        Scores every candidate edge/feature flip by its integrated
        gradient and greedily applies the top-``num_budgets`` flips.

        Parameters
        ----------
        num_budgets : int or float, optional
            budget of perturbations; a float is interpreted as a ratio
            of the number of edges, by default 0.05
        steps : int, optional
            number of steps used to approximate the integrated
            gradients, by default 20
        structure_attack : bool, optional
            whether to flip edges, by default True
        feature_attack : bool, optional
            whether to flip (binary) node features, by default False
        disable : bool, optional
            whether to disable the progress bar, by default False
        """
        super().attack(num_budgets=num_budgets,
                       structure_attack=structure_attack,
                       feature_attack=feature_attack)

        if structure_attack:
            link_importance = self.get_link_importance(
                steps, self.victim_nodes, self.victim_labels,
                disable=disable)
            adj_score = self.structure_score(self.adj, link_importance)

        if feature_attack:
            # Feature flipping assumes a binary feature matrix
            self._check_feature_matrix_binary()
            feature_importance = self.get_feature_importance(
                steps, self.victim_nodes, self.victim_labels,
                disable=disable)
            feat_score = self.feature_score(self.feat, feature_importance)

        if structure_attack and not feature_attack:
            indices = torch.topk(adj_score, k=self.num_budgets).indices
            # Scores are over the flattened dense adjacency; recover (u, v)
            for it, index in enumerate(indices.tolist()):
                u, v = divmod(index, self.num_nodes)
                edge_weight = self.adj[u, v].data.item()
                if edge_weight > 0:
                    self.remove_edge(u, v, it)
                else:
                    self.add_edge(u, v, it)

        elif feature_attack and not structure_attack:
            indices = torch.topk(feat_score, k=self.num_budgets).indices
            for it, index in enumerate(indices.tolist()):
                u, v = divmod(index, self.num_feats)
                feat_weight = self.feat[u, v].data.item()
                if feat_weight > 0:
                    self.remove_feat(u, v, it)
                else:
                    self.add_feat(u, v, it)

        else:
            # both attacks are conducted: rank edge and feature flips
            # jointly in a single concatenated score vector
            score = torch.cat([adj_score, feat_score])
            indices = torch.topk(score, k=self.num_budgets).indices
            boundary = adj_score.size(0)
            for it, index in enumerate(indices.tolist()):
                if index < boundary:
                    u, v = divmod(index, self.num_nodes)
                    edge_weight = self.adj[u, v].data.item()
                    if edge_weight > 0:
                        self.remove_edge(u, v, it)
                    else:
                        self.add_edge(u, v, it)
                else:
                    u, v = divmod(index - boundary, self.num_feats)
                    feat_weight = self.feat[u, v].data.item()
                    if feat_weight > 0:
                        self.remove_feat(u, v, it)
                    else:
                        self.add_feat(u, v, it)

        return self

    # NOTE(review): this method is called by :meth:`attack` but was missing
    # from the scraped source; reconstructed by exact symmetry with
    # :meth:`get_feature_importance` below — confirm against upstream.
    def get_link_importance(self, steps, victim_nodes, victim_labels,
                            disable=False) -> Tensor:
        """Approximate the integrated gradients of the loss w.r.t. each
        entry of the dense adjacency matrix.

        Gradients are accumulated along two straight-line paths: from the
        all-zeros baseline (edge removal) and to the all-ones baseline
        (edge addition).
        """
        adj = self.adj
        feat = self.feat

        baseline_add = torch.ones_like(adj)
        baseline_remove = torch.zeros_like(adj)
        gradients = torch.zeros_like(adj)

        for alpha in tqdm(torch.linspace(0., 1.0, steps + 1),
                          desc='Computing link importance...',
                          disable=disable):
            # Compute integrated gradients for removing edges
            adj_diff = adj - baseline_remove
            adj_step = baseline_remove + alpha * adj_diff
            adj_step.requires_grad_()
            gradients += self.compute_structure_gradients(
                adj_step, feat, victim_nodes, victim_labels)

            # Compute integrated gradients for adding edges
            adj_diff = baseline_add - adj
            adj_step = baseline_add - alpha * adj_diff
            adj_step.requires_grad_()
            gradients += self.compute_structure_gradients(
                adj_step, feat, victim_nodes, victim_labels)

        return gradients

    def get_feature_importance(self, steps, victim_nodes, victim_labels,
                               disable=False) -> Tensor:
        """Approximate the integrated gradients of the loss w.r.t. each
        entry of the (binary) feature matrix.

        Gradients are accumulated along two straight-line paths: from the
        all-zeros baseline (feature removal) and to the all-ones baseline
        (feature addition).
        """
        adj = self.adj
        feat = self.feat

        baseline_add = torch.ones_like(feat)
        baseline_remove = torch.zeros_like(feat)
        gradients = torch.zeros_like(feat)

        for alpha in tqdm(torch.linspace(0., 1.0, steps + 1),
                          desc='Computing feature importance...',
                          disable=disable):
            # Compute integrated gradients for removing features
            feat_diff = feat - baseline_remove
            feat_step = baseline_remove + alpha * feat_diff
            feat_step.requires_grad_()
            gradients += self.compute_feature_gradients(
                adj, feat_step, victim_nodes, victim_labels)

            # Compute integrated gradients for adding features
            feat_diff = baseline_add - feat
            feat_step = baseline_add - alpha * feat_diff
            feat_step.requires_grad_()
            gradients += self.compute_feature_gradients(
                adj, feat_step, victim_nodes, victim_labels)

        return gradients

    def structure_score(self, adj: Tensor, adj_grad: Tensor) -> Tensor:
        """Convert adjacency gradients into flip scores.

        ``(1 - 2 * adj)`` flips the gradient sign for existing edges so
        that a large score always means "this flip increases the loss".
        Only the upper triangle is kept (undirected graph, no self-loops).
        """
        adj_grad = adj_grad + adj_grad.t()
        score = adj_grad * (1 - 2 * adj)
        # Shift so all scores are non-negative before masking
        score -= score.min()
        score = torch.triu(score, diagonal=1)
        if not self._allow_singleton:
            # Set entries to 0 that could lead to singleton nodes.
            score *= singleton_mask(adj)
        return score.view(-1)

    def feature_score(self, feat: Tensor, feat_grad: Tensor) -> Tensor:
        """Convert feature gradients into flip scores (see
        :meth:`structure_score` for the sign-flip rationale)."""
        score = feat_grad * (1 - 2 * feat)
        score -= score.min()
        return score.view(-1)

    def compute_structure_gradients(self, adj_step, feat, victim_nodes,
                                    victim_labels) -> Tensor:
        """Gradient of the victim cross-entropy loss w.r.t. ``adj_step``."""
        logit = self.surrogate(feat, adj_step)[victim_nodes] / self.tau
        loss = F.cross_entropy(logit, victim_labels)
        return grad(loss, adj_step, create_graph=False)[0]

    def compute_feature_gradients(self, adj, feat_step, victim_nodes,
                                  victim_labels) -> Tensor:
        """Gradient of the victim cross-entropy loss w.r.t. ``feat_step``."""
        logit = self.surrogate(feat_step, adj)[victim_nodes] / self.tau
        loss = F.cross_entropy(logit, victim_labels)
        return grad(loss, feat_step, create_graph=False)[0]