# Source code for greatx.attack.injection.adv_injection

from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.autograd import grad
from tqdm.auto import tqdm

from greatx.attack.injection.injection_attacker import InjectionAttacker
from greatx.nn.models.surrogate import Surrogate


class AdvInjection(InjectionAttacker, Surrogate):
    r"""2nd place solution of KDD CUP 2020
    "Adversarial attack and defense" challenge.

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                               transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ... # train your surrogate model

        from greatx.attack.injection import AdvInjection
        attacker = AdvInjection(data)
        attacker.setup_surrogate(surrogate_model)

        attacker.reset()
        # injecting 10 nodes for continuous features
        attacker.attack(10, feat_limits=(0, 1))

        attacker.reset()
        # injecting 10 nodes for binary features
        attacker.attack(10, feat_budgets=10)

        attacker.data() # get attacked graph

        attacker.injected_nodes() # get injected nodes after attack

        attacker.injected_edges() # get injected edges after attack

        attacker.injected_feats() # get injected features after attack

    Note
    ----
    * Please remember to call :meth:`reset` before each attack.
    """

    def attack(self, num_budgets: Union[int, float], *,
               targets: Optional[Tensor] = None,
               interconnection: bool = False, lr: float = 0.1,
               num_edges_global: Optional[int] = None,
               num_edges_local: Optional[int] = None,
               feat_limits: Optional[Union[tuple, dict]] = None,
               feat_budgets: Optional[int] = None,
               disable: bool = False) -> "AdvInjection":
        """Inject nodes one at a time, guided by surrogate gradients.

        For every injected node, the gradient of the surrogate loss on the
        target nodes is taken w.r.t. (a) the weights of candidate edges
        between the new node and the candidate nodes and (b) the features
        of all nodes injected so far. The candidate edges with the largest
        gradients are kept, and the injected features are rebuilt from the
        feature gradient.

        Parameters
        ----------
        num_budgets : Union[int, float]
            Forwarded to the parent ``attack``; afterwards
            ``self.num_budgets`` nodes are injected.
        targets : Optional[Tensor], optional
            Forwarded to the parent ``attack``; afterwards ``self.targets``
            provides the initial candidate nodes and the nodes on which
            the loss is evaluated.
        interconnection : bool, optional
            If True, each injected node is added to the candidate list so
            that later injected nodes may connect to it, by default False.
        lr : float, optional
            Scale applied to the signed feature gradient in the
            continuous-feature branch, by default 0.1.
        num_edges_global : Optional[int], optional
            Forwarded to the parent ``attack``; not used directly here.
        num_edges_local : Optional[int], optional
            Forwarded to the parent ``attack``; afterwards
            ``self.num_edges_local`` is the number of edges kept for each
            injected node.
        feat_limits : Optional[Union[tuple, dict]], optional
            Forwarded to the parent ``attack``; afterwards
            ``self.feat_limits`` is unpacked as ``(min, max)`` bounds.
        feat_budgets : Optional[int], optional
            Forwarded to the parent ``attack``; when ``self.feat_budgets``
            is not None, that many features per injected node are set
            to 1 and the rest to 0.
        disable : bool, optional
            Whether to disable the progress bar, by default False.

        Returns
        -------
        AdvInjection
            The attacker itself.
        """
        super().attack(num_budgets, targets=targets,
                       num_edges_global=num_edges_global,
                       num_edges_local=num_edges_local,
                       feat_limits=feat_limits, feat_budgets=feat_budgets)

        candidate_nodes = self.targets.tolist()

        edge_index = self.edge_index
        edge_weight = self.edge_weight
        feat = self.feat

        if edge_weight is None:
            edge_weight = feat.new_ones(edge_index.size(1))

        feat_min, feat_max = self.feat_limits
        # Magnitude used to scale the signed gradient (kept under its own
        # name so the ``feat_limits`` argument is not shadowed).
        feat_scale = max(abs(feat_min), feat_max)
        feat_budgets = self.feat_budgets
        injected_feats = None

        for injected_node in tqdm(
                range(self.num_nodes, self.num_nodes + self.num_budgets),
                desc="Injecting nodes...", disable=disable):
            # Candidate edges: the new node paired with every candidate.
            dst = edge_index.new_tensor(candidate_nodes)
            src = torch.full_like(dst, injected_node)
            injected_edge_index = torch.stack([src, dst], dim=0)

            # Zero-weight candidate edges: only their gradient is needed.
            injected_edge_weight = edge_weight.new_zeros(
                injected_edge_index.size(1)).requires_grad_()

            injected_feat = feat.new_zeros(1, self.num_feats)
            if injected_feats is None:
                injected_feats = injected_feat.requires_grad_()
            else:
                # Detach the previous features so that autograd graphs of
                # earlier iterations are not chained together; the
                # gradient w.r.t. the new tensor is unaffected.
                injected_feats = torch.cat(
                    [injected_feats.detach(), injected_feat],
                    dim=0).requires_grad_()

            edge_grad, feat_grad = self.compute_gradients(
                feat, edge_index, edge_weight, injected_feats,
                injected_edge_index, injected_edge_weight,
                targets=self.targets, target_labels=self.target_labels)

            # Never request more edges than there are candidates;
            # ``torch.topk`` raises when ``k`` exceeds the input size.
            num_edges = min(self.num_edges_local, edge_grad.numel())
            topk_edges = torch.topk(edge_grad, k=num_edges).indices
            injected_edge_index = injected_edge_index[:, topk_edges]

            self.inject_node(injected_node)
            self.inject_edges(injected_edge_index)

            with torch.no_grad():
                # Commit the selected edges, in both directions, with
                # unit weight.
                edge_index = torch.cat([
                    edge_index, injected_edge_index,
                    injected_edge_index.flip(0)
                ], dim=1)
                edge_weight = torch.cat([
                    edge_weight,
                    edge_weight.new_ones(injected_edge_index.size(1) * 2)
                ], dim=0)

                if feat_budgets is not None:
                    # Binary features: switch on the ``feat_budgets``
                    # entries with the largest gradient in each row.
                    topk = torch.topk(feat_grad, k=feat_budgets, dim=1)
                    injected_feats.data.fill_(0.)
                    injected_feats.data.scatter_(1, topk.indices, 1.0)
                else:
                    # Continuous features: signed-gradient step, then
                    # clip to the allowed range.
                    # NOTE(review): the clamp is assumed to belong to this
                    # branch only -- confirm against upstream.
                    injected_feats.data = feat_scale * lr * feat_grad.sign()
                    injected_feats.data.clamp_(min=feat_min, max=feat_max)

            if interconnection:
                candidate_nodes.append(injected_node)

        self._injected_feats = injected_feats.data

        return self

    def compute_gradients(self, x, edge_index, edge_weight, injected_feats,
                          injected_edge_index, injected_edge_weight, targets,
                          target_labels):
        """Return loss gradients w.r.t. candidate edge weights and features.

        The injected features are appended to ``x`` and the candidate
        edges are appended to the graph in both directions, sharing
        ``injected_edge_weight`` between the two directions. The surrogate
        output at ``targets`` is divided by ``self.tau`` and scored with
        cross entropy against ``target_labels``.

        Parameters
        ----------
        x : Tensor
            Node features of the current graph.
        edge_index : Tensor
            Edges of the current graph.
        edge_weight : Tensor
            Edge weights of the current graph.
        injected_feats : Tensor
            Features of the injected nodes; must require gradients.
        injected_edge_index : Tensor
            Candidate edges of the node being injected.
        injected_edge_weight : Tensor
            Weights of the candidate edges; must require gradients.
        targets : Tensor
            Nodes whose surrogate outputs enter the loss.
        target_labels : Tensor
            Labels used in the cross-entropy loss for ``targets``.

        Returns
        -------
        tuple of Tensor
            Gradients of the loss w.r.t. ``injected_edge_weight`` and
            ``injected_feats``, in that order.
        """
        x = torch.cat([x, injected_feats], dim=0)
        edge_index = torch.cat(
            [edge_index, injected_edge_index,
             injected_edge_index.flip(0)], dim=1)
        # ``repeat(2)`` lines the weights up with the forward edges
        # followed by the flipped edges appended above.
        edge_weight = torch.cat(
            [edge_weight, injected_edge_weight.repeat(2)], dim=0)

        logit = self.surrogate(x, edge_index, edge_weight)[targets] / self.tau
        loss = F.cross_entropy(logit, target_labels)

        return grad(loss, [injected_edge_weight, injected_feats],
                    create_graph=False)