Source code for greatx.attack.targeted.sg_attack

from collections import namedtuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import grad
from tqdm.auto import tqdm

from greatx.attack.targeted.targeted_attacker import TargetedAttacker
from greatx.nn.models.surrogate import Surrogate
from greatx.utils import ego_graph

# Lightweight container bundling the ego-subgraph structure and the
# differentiable edge weights used by the attack:
#   edge_index      -- full (2, E) index fed to the surrogate: candidate
#                      non-edges, existing edges, both reversed, + self-loops
#   sub_edges       -- existing edges of the K-hop ego graph
#   non_edges       -- candidate (influencer, attacker) edges to be added
#   edge_weight     -- learnable weights for `sub_edges` (start at 1)
#   non_edge_weight -- learnable weights for `non_edges` (start at 0)
#   selfloop_weight -- fixed weights for the self-loop entries
SubGraph = namedtuple('SubGraph', [
    'edge_index', 'sub_edges', 'non_edges', 'edge_weight', 'non_edge_weight',
    'selfloop_weight'
])


class SGAttack(TargetedAttacker, Surrogate):
    r"""Implementation of `SGA` attack from the:
    `"Adversarial Attack on Large Scale Graph"
    <https://arxiv.org/abs/2009.03488>`_ paper (TKDE'21)

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device of the attack running on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be
        :obj:`__class__.__name__`, by default None
    kwargs : additional arguments of :class:`greatx.attack.Attacker`,

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                               transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ...  # train your surrogate model

        from greatx.attack.targeted import SGAttack
        attacker = SGAttack(data)
        attacker.setup_surrogate(surrogate_model)
        attacker.reset()

        # attacking target node `1` with default budget set as node degree
        attacker.attack(target=1)

        # attacking target node `1` with budget set as 1
        attacker.attack(target=1, num_budgets=1)

        attacker.data()  # get attacked graph

        attacker.edge_flips()  # get edge flips after attack

        attacker.added_edges()  # get added edges after attack

        attacker.removed_edges()  # get removed edges after attack

    Note
    ----
    * `SGAttack` is a scalable attack that can be applied to large scale graph
    * Please remember to call :meth:`reset` before each attack.
    """

    # SGAttack cannot ensure that there is not singleton node after attacks.
    _allow_singleton = True
[docs] @torch.no_grad() def setup_surrogate( self, surrogate: torch.nn.Module, *, tau: float = 5.0, freeze: bool = True, ): Surrogate.setup_surrogate(self, surrogate=surrogate, tau=tau, freeze=freeze) self.logits = self.surrogate(self.feat, self.edge_index, self.edge_weight).cpu() return self
[docs] def set_normalize(self, state): # TODO: this is incorrect for models # with `normalize=False` by default for layer in self.surrogate.modules(): if hasattr(layer, 'normalize'): layer.normalize = state if hasattr(layer, 'add_self_loops'): layer.add_self_loops = state
[docs] def strongest_wrong_class(self, target, target_label): logit = self.logits[target].clone() logit[target_label] = -1e4 return logit.argmax()
    def get_subgraph(self, target, target_label, best_wrong_label):
        """Build the attack subgraph: the K-hop ego graph around ``target``
        augmented with candidate attacker nodes of class
        ``best_wrong_label``.

        The subgraph is built twice: a first pass over all candidate
        attacker nodes is used only to rank them by gradient
        (:meth:`get_top_attackers`); the second pass keeps the top-ranked
        attackers only.

        Raises
        ------
        RuntimeError
            if ``target`` is a singleton node (no edges in its ego graph).
        """
        sub_nodes, sub_edges = ego_graph(self.adjacency_matrix, int(target),
                                         self.K)
        if sub_edges.size == 0:
            raise RuntimeError(
                f"The target node {int(target)} is a singleton node.")
        sub_nodes = torch.as_tensor(sub_nodes, dtype=torch.long,
                                    device=self.device)
        sub_edges = torch.as_tensor(sub_edges, dtype=torch.long,
                                    device=self.device)
        # Candidate attackers: all nodes predicted/labeled as the best
        # wrong class, excluding current neighbors of the target.
        attacker_nodes = torch.where(
            self.label == best_wrong_label)[0].cpu().numpy()
        neighbors = self.adjacency_matrix[target].indices

        influencers = [target]
        attacker_nodes = np.setdiff1d(attacker_nodes, neighbors)
        # First pass: subgraph over ALL candidates, used only for ranking.
        subgraph = self.subgraph_processing(sub_nodes, sub_edges,
                                            influencers, attacker_nodes)
        if self.direct_attack:
            # Direct attack: flips are incident to the target itself; one
            # extra candidate gives slack when a flip is rejected.
            influencers = [target]
            num_attackers = self.num_budgets + 1
        else:
            # Indirect attack: flips go through the target's neighbors.
            influencers = neighbors
            num_attackers = 3
        attacker_nodes = self.get_top_attackers(subgraph, target,
                                                target_label,
                                                best_wrong_label,
                                                num_attackers=num_attackers)

        # Second pass: final subgraph restricted to the top attackers.
        subgraph = self.subgraph_processing(sub_nodes, sub_edges,
                                            influencers, attacker_nodes)

        return subgraph
[docs] def get_top_attackers(self, subgraph, target, target_label, best_wrong_label, num_attackers): non_edge_grad, _ = self.compute_gradients(subgraph, target, target_label, best_wrong_label) _, index = torch.topk(non_edge_grad, k=min(num_attackers, non_edge_grad.size(0)), sorted=False) attacker_nodes = subgraph.non_edges[1][index] return attacker_nodes.tolist()
[docs] def subgraph_processing(self, sub_nodes, sub_edges, influencers, attacker_nodes): row = np.repeat(influencers, len(attacker_nodes)) col = np.tile(attacker_nodes, len(influencers)) non_edges = np.row_stack([row, col]) if not self.direct_attack: # indirect attack mask = self.adjacency_matrix[non_edges[0], non_edges[1]].A1 == 0 non_edges = non_edges[:, mask] non_edges = torch.as_tensor(non_edges, dtype=torch.long, device=self.device) attacker_nodes = torch.as_tensor(attacker_nodes, dtype=torch.long, device=self.device) selfloop = torch.unique(torch.cat([sub_nodes, attacker_nodes])) edge_index = torch.cat([ non_edges, sub_edges, non_edges.flip(0), sub_edges.flip(0), selfloop.repeat((2, 1)) ], dim=1) edge_weight = torch.ones(sub_edges.size(1), device=self.device).requires_grad_() non_edge_weight = torch.zeros(non_edges.size(1), device=self.device).requires_grad_() selfloop_weight = torch.ones(selfloop.size(0), device=self.device) subgraph = SubGraph( edge_index=edge_index, sub_edges=sub_edges, non_edges=non_edges, edge_weight=edge_weight, non_edge_weight=non_edge_weight, selfloop_weight=selfloop_weight, ) return subgraph
[docs] def attack(self, target, *, K: int = 2, target_label=None, num_budgets=None, direct_attack=True, structure_attack=True, feature_attack=False, disable=False): super().attack(target, target_label, num_budgets=num_budgets, direct_attack=direct_attack, structure_attack=structure_attack, feature_attack=feature_attack) self.set_normalize(False) self.K = K target_label = self.target_label.view(-1) best_wrong_label = self.strongest_wrong_class(target, target_label).view(-1) best_wrong_label = best_wrong_label.to(self.device) subgraph = self.get_subgraph(target, target_label, best_wrong_label) if not direct_attack: condition1 = subgraph.sub_edges[0] != target condition2 = subgraph.sub_edges[1] != target mask = torch.logical_and(condition1, condition2).float() for it in tqdm(range(self.num_budgets), desc='Peturbing graph...', disable=disable): non_edge_grad, edge_grad = self.compute_gradients( subgraph, target, target_label, best_wrong_label) with torch.no_grad(): edge_grad *= -2 * subgraph.edge_weight + 1 if not direct_attack: edge_grad *= mask non_edge_grad *= -2 * subgraph.non_edge_weight + 1 max_edge_grad, max_edge_idx = torch.max(edge_grad, dim=0) max_non_edge_grad, max_non_edge_idx = torch.max( non_edge_grad, dim=0) if max_edge_grad > max_non_edge_grad: # remove one edge subgraph.edge_weight.data[max_edge_idx].fill_(0.) u, v = subgraph.sub_edges[:, max_edge_idx].tolist() self.remove_edge(u, v, it) else: # add one edge subgraph.non_edge_weight.data[max_non_edge_idx].fill_(1.) u, v = subgraph.non_edges[:, max_non_edge_idx].tolist() self.add_edge(u, v, it) self.set_normalize(True) return self
[docs] def compute_gradients(self, subgraph, target, target_label, best_wrong_label): edge_weight = torch.cat([ subgraph.non_edge_weight, subgraph.edge_weight, subgraph.non_edge_weight, subgraph.edge_weight, subgraph.selfloop_weight ], dim=0) row, col = subgraph.edge_index norm = (self.degree + 1.).pow(-0.5) edge_weight = norm[row] * edge_weight * norm[col] logit = self.surrogate(self.feat, subgraph.edge_index, edge_weight) logit = logit[target].view(1, -1) / self.tau logit = F.log_softmax(logit, dim=1) loss = F.nll_loss(logit, target_label) - \ F.nll_loss(logit, best_wrong_label) return grad(loss, [subgraph.non_edge_weight, subgraph.edge_weight], create_graph=False)