# Source code for greatx.attack.targeted.ig_attack

from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import grad
from torch_geometric.data import Data
from tqdm.auto import tqdm

from greatx.attack.targeted.targeted_attacker import TargetedAttacker
from greatx.nn.models.surrogate import Surrogate
from greatx.utils import singleton_filter


class IGAttack(TargetedAttacker, Surrogate):
    r"""Implementation of `IG-FGSM` attack from the:
    `"Adversarial Examples on Graph Data: Deep Insights
    into Attack and Defense"
    <https://arxiv.org/abs/1903.01610>`_ paper (IJCAI'19)

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device of the attack running on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be
        :obj:`__class__.__name__`, by default None
    kwargs : additional arguments of :class:`greatx.attack.Attacker`,

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T
        import os.path as osp

        dataset = GraphDataset(root='.', name='Cora',
                               transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ... # train your surrogate model

        from greatx.attack.targeted import IGAttack
        attacker = IGAttack(data)
        attacker.setup_surrogate(surrogate_model)
        attacker.reset()
        # attacking target node `1` with default budget set as node degree
        attacker.attack(target=1)

        attacker.reset()
        # attacking target node `1` with budget set as 1
        attacker.attack(target=1, num_budgets=1)

        attacker.data() # get attacked graph

        attacker.edge_flips() # get edge flips after attack

        attacker.added_edges() # get added edges after attack

        attacker.removed_edges() # get removed edges after attack

    Note
    ----
    * Please remember to call :meth:`reset` before each attack.
    """

    # IGAttack can conduct feature attack
    _allow_feature_attack: bool = True

    def __init__(self, data: Data, device: str = "cpu",
                 seed: Optional[int] = None, name: Optional[str] = None,
                 **kwargs):
        super().__init__(data=data, device=device, seed=seed, name=name,
                         **kwargs)

        num_nodes, num_feats = self.num_nodes, self.num_feats
        self.nodes_set = set(range(num_nodes))
        self.feats_list = list(range(num_feats))
        # Dense adjacency matrix, built once at construction time.
        self.adj = self.get_dense_adj()

    def attack(self, target, *, target_label=None, num_budgets=None, steps=20,
               direct_attack=True, structure_attack=True,
               feature_attack=False, disable=False):
        """Perturb the graph around ``target`` using importance scores.

        The top ``self.num_budgets`` candidates (edges, feature entries,
        or both pooled together) are selected by score and flipped.

        Parameters
        ----------
        target :
            the target node; forwarded to the parent ``attack``
        target_label : optional
            label of the target node; forwarded to the parent ``attack``
        num_budgets : optional
            attack budget; forwarded to the parent ``attack``
        steps : int, optional
            number of interpolation steps used when accumulating
            gradients (``steps + 1`` points are evaluated for features),
            by default 20
        direct_attack : bool, optional
            forwarded to the parent ``attack``; candidates are built from
            the target itself when ``self.direct_attack`` is true and
            from its neighbors otherwise, by default True
        structure_attack : bool, optional
            whether edge flips are considered, by default True
        feature_attack : bool, optional
            whether feature flips are considered, by default False
        disable : bool, optional
            whether to disable the progress bar, by default False

        Returns
        -------
        IGAttack
            the attacker itself
        """
        super().attack(target, target_label, num_budgets=num_budgets,
                       direct_attack=direct_attack,
                       structure_attack=structure_attack,
                       feature_attack=feature_attack)

        if feature_attack:
            self._check_feature_matrix_binary()

        target_label = self.target_label.view(-1)

        if structure_attack:
            candidate_edges = self.get_candidate_edges()
            link_importance, edge_indicator = self.get_link_importance(
                candidate_edges, steps, target, target_label,
                disable=disable)

        if feature_attack:
            candidate_feats = self.get_candidate_features()
            feature_importance, feat_indicator = self.get_feature_importance(
                candidate_feats, steps, target, target_label,
                disable=disable)

        if structure_attack and not feature_attack:
            indices = torch.topk(link_importance, k=self.num_budgets).indices
            self._flip_selected_edges(candidate_edges[indices],
                                      edge_indicator[indices])

        elif feature_attack and not structure_attack:
            indices = torch.topk(feature_importance,
                                 k=self.num_budgets).indices
            self._flip_selected_features(candidate_feats[indices],
                                         feat_indicator[indices])

        else:
            # both attacks are conducted
            importance = torch.cat([link_importance, feature_importance])
            indices = torch.topk(importance, k=self.num_budgets).indices
            boundary = link_importance.size(0)

            # Positions [0, boundary) of the concatenated scores are edge
            # candidates and positions [boundary, ...) are feature
            # candidates. The feature side must use `>=`: position
            # `boundary` is the first feature candidate and would
            # otherwise be dropped by both masks.
            link_indices = indices[indices < boundary]
            feat_indices = indices[indices >= boundary] - boundary

            self._flip_selected_edges(candidate_edges[link_indices],
                                      edge_indicator[link_indices])
            self._flip_selected_features(candidate_feats[feat_indices],
                                         feat_indicator[feat_indices])

        return self

    def _flip_selected_edges(self, edges, indicator):
        """Add edges whose ``indicator`` is False, then remove the rest."""
        for u, v in edges[~indicator].tolist():
            self.add_edge(u, v)

        for u, v in edges[indicator].tolist():
            self.remove_edge(u, v)

    def _flip_selected_features(self, entries, indicator):
        """Remove entries whose ``indicator`` is True, then add the rest."""
        for u, v in entries[indicator].tolist():
            self.remove_feat(u, v)

        for u, v in entries[~indicator].tolist():
            self.add_feat(u, v)

    def get_candidate_edges(self):
        """Build the candidate ``(row, col)`` node pairs for edge flips.

        For a direct attack every pair ``(target, v)`` with
        ``v != target`` is a candidate. Otherwise, for every influencer
        ``infl`` taken from ``self.adjacency_matrix[target].indices``,
        every pair ``(infl, v)`` with ``v`` different from both the target
        and ``infl`` is a candidate. When ``self._allow_singleton`` is
        false the pairs are passed through :func:`singleton_filter`.

        Returns
        -------
        torch.Tensor
            a two-column ``torch.long`` tensor of candidate pairs
            on ``self.device``
        """
        target = self.target
        num_nodes = self.num_nodes
        others = set(range(num_nodes)) - {target}

        if self.direct_attack:
            influencers = [target]
            row = np.repeat(influencers, num_nodes - 1)
            col = list(others)
        else:
            influencers = self.adjacency_matrix[target].indices
            # Each influencer is paired with every node except the
            # target and itself, hence `num_nodes - 2` repetitions.
            row = np.repeat(influencers, num_nodes - 2)
            col = np.hstack(
                [list(others - {infl}) for infl in influencers])

        candidate_edges = np.stack([row, col], axis=1)

        if not self._allow_singleton:
            candidate_edges = singleton_filter(candidate_edges,
                                               self.adjacency_matrix)

        return torch.as_tensor(candidate_edges, dtype=torch.long,
                               device=self.device)

    def get_candidate_features(self):
        """Build the candidate ``(node, feature)`` entries for feature flips.

        For a direct attack the candidates are all feature columns of the
        target node; otherwise they are all feature columns of every node
        in ``self.adjacency_matrix[target].indices``.

        Returns
        -------
        torch.Tensor
            a two-column ``torch.long`` tensor of ``(node, feature)``
            index pairs on ``self.device``
        """
        num_feats = self.num_feats
        target = self.target

        if self.direct_attack:
            candidate_feats = np.column_stack(
                (np.tile(target, num_feats), self.feats_list))
        else:
            influencers = self.adjacency_matrix[target].indices
            # `np.vstack` replaces the deprecated alias `np.row_stack`.
            candidate_feats = np.vstack([
                np.column_stack((np.tile(infl, num_feats), self.feats_list))
                for infl in influencers
            ])

        return torch.as_tensor(candidate_feats, dtype=torch.long,
                               device=self.device)

    def get_feature_importance(self, candidates, steps, target, target_label,
                               disable=False):
        """Accumulate loss gradients along feature interpolation paths.

        Two baselines are derived from ``self.feat``: one with every
        candidate entry set to 1 and one with every candidate entry set
        to 0. For ``steps + 1`` evenly spaced ``alpha`` values in
        ``[0, 1]`` the features are interpolated between each baseline and
        ``self.feat``, and the gradient of the surrogate loss with respect
        to the interpolated features is summed: the zero-baseline path
        scores entries that are currently positive, the one-baseline path
        scores the remaining entries.

        Parameters
        ----------
        candidates : torch.Tensor
            two-column index tensor of ``(node, feature)`` candidates
        steps : int
            number of interpolation intervals
        target :
            index used to select the surrogate output row
        target_label : torch.Tensor
            label passed to the cross-entropy loss
        disable : bool, optional
            whether to disable the progress bar, by default False

        Returns
        -------
        tuple
            ``(integrated_grads, feat_indicator)`` where
            ``integrated_grads[i]`` is the summed gradient of
            candidate ``i`` and ``feat_indicator[i]`` is True when that
            entry of ``self.feat`` is currently greater than zero
        """
        adj = self.adj
        feat = self.feat
        mask = (candidates[:, 0], candidates[:, 1])

        baseline_add = feat.clone()
        baseline_add[mask] = 1.0
        baseline_remove = feat.clone()
        baseline_remove[mask] = 0.0
        feat_indicator = feat[mask] > 0

        features = candidates[feat_indicator]
        non_features = candidates[~feat_indicator]

        feat_gradients = feat.new_zeros(features.size(0))
        non_feat_gradients = feat.new_zeros(non_features.size(0))

        # The path directions do not depend on alpha; compute them once.
        remove_diff = feat - baseline_remove
        add_diff = baseline_add - feat

        for alpha in tqdm(torch.linspace(0., 1.0, steps + 1),
                          desc='Computing feature importance...',
                          disable=disable):
            # Compute integrated gradients for removing features
            feat_step = baseline_remove + alpha * remove_diff
            feat_step.requires_grad_()
            gradients = self.compute_feature_gradients(
                feat_step, adj, target, target_label)
            feat_gradients += gradients[features[:, 0], features[:, 1]]

            # Compute integrated gradients for adding features
            feat_step = baseline_add - alpha * add_diff
            feat_step.requires_grad_()
            gradients = self.compute_feature_gradients(
                feat_step, adj, target, target_label)
            non_feat_gradients += gradients[non_features[:, 0],
                                            non_features[:, 1]]

        # Scatter both partial results back into candidate order.
        integrated_grads = feat.new_zeros(feat_indicator.size(0))
        integrated_grads[feat_indicator] = feat_gradients
        integrated_grads[~feat_indicator] = non_feat_gradients

        return integrated_grads, feat_indicator

    def compute_structure_gradients(self, feat, adj_step, target,
                                    target_label):
        """Return the gradient of the surrogate loss w.r.t. ``adj_step``.

        The loss is the cross-entropy between the surrogate output at
        ``target`` (divided by ``self.tau``) and ``target_label``.
        ``adj_step`` must require gradients.
        """
        logit = self.surrogate(feat, adj_step)[target].view(1, -1) / self.tau
        loss = F.cross_entropy(logit, target_label)
        return grad(loss, adj_step, create_graph=False)[0]

    def compute_feature_gradients(self, feat_step, adj, target, target_label):
        """Return the gradient of the surrogate loss w.r.t. ``feat_step``.

        The loss is the cross-entropy between the surrogate output at
        ``target`` (divided by ``self.tau``) and ``target_label``.
        ``feat_step`` must require gradients.
        """
        logit = self.surrogate(feat_step, adj)[target].view(1, -1) / self.tau
        loss = F.cross_entropy(logit, target_label)
        return grad(loss, feat_step, create_graph=False)[0]