Source code for greatx.attack.untargeted.metattack

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.autograd import grad
from torch.nn import init
from tqdm.auto import tqdm

from greatx.attack.untargeted.untargeted_attacker import UntargetedAttacker
from greatx.nn.layers.gcn_conv import dense_gcn_norm
from greatx.nn.models import GCN
from greatx.nn.models.surrogate import Surrogate
from greatx.utils import singleton_mask


class Metattack(UntargetedAttacker, Surrogate):
    r"""Implementation of `Metattack` attack from the:
    `"Adversarial Attacks on Graph Neural Networks via Meta Learning"
    <https://arxiv.org/abs/1902.08412>`_ paper (ICLR'19)

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device of the attack running on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be
        :obj:`__class__.__name__`, by default None
    kwargs : additional arguments of :class:`greatx.attack.Attacker`

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                               transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ...  # train your surrogate model

        from greatx.attack.untargeted import Metattack
        attacker = Metattack(data)
        attacker.setup_surrogate(surrogate_model,
                                 labeled_nodes=...,
                                 unlabeled_nodes=...)
        attacker.reset()
        attacker.attack(0.05)  # attack with 5% of edge perturbations
        attacker.data()  # get attacked graph
        attacker.edge_flips()  # get edge flips after attack
        attacker.added_edges()  # get added edges after attack
        attacker.removed_edges()  # get removed edges after attack

    Note
    ----
    * Please remember to call :meth:`reset` before each attack.
    """

    # Metattack can also conduct feature attack
    _allow_feature_attack: bool = True
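    # The attack solves a bilevel problem: an inner loop trains the
    # surrogate weights on the perturbed graph, and the outer step
    # differentiates the resulting training loss w.r.t. the graph
    # itself (the "meta-gradient"). A self-contained toy sketch of
    # this idea in plain PyTorch is given after the class definition.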
    def setup_surrogate(self, surrogate: torch.nn.Module,
                        labeled_nodes: Tensor,
                        unlabeled_nodes: Tensor, lr: float = 0.1,
                        epochs: int = 100, momentum: float = 0.9,
                        lambda_: float = 0., *, tau: float = 1.0):

        if lambda_ not in (0., 0.5, 1.):
            raise ValueError(
                "Invalid argument `lambda_`, allowed values "
                "[0: (meta-self), 1: (meta-train), 0.5: (meta-both)].")

        Surrogate.setup_surrogate(self, surrogate=surrogate, tau=tau,
                                  freeze=False, required=GCN)

        if labeled_nodes.dtype == torch.bool:
            labeled_nodes = labeled_nodes.nonzero().view(-1)
        labeled_nodes = labeled_nodes.to(self.device)

        if unlabeled_nodes.dtype == torch.bool:
            unlabeled_nodes = unlabeled_nodes.nonzero().view(-1)
        unlabeled_nodes = unlabeled_nodes.to(self.device)

        self.labeled_nodes = labeled_nodes
        self.unlabeled_nodes = unlabeled_nodes

        self.y_train = self.label[labeled_nodes]
        self.y_self_train = self.estimate_self_training_labels(
            unlabeled_nodes)
        self.adj = self.get_dense_adj()

        weights = []
        w_velocities = []

        for para in self.surrogate.parameters():
            if para.ndim == 2:
                para = para.t()
                weights.append(torch.zeros_like(para, requires_grad=True))
                w_velocities.append(torch.zeros_like(para))
            else:
                # we do not consider bias terms for simplicity
                pass

        self.weights, self.w_velocities = weights, w_velocities

        self.epochs = epochs
        self.lr = lr
        self.momentum = momentum
        self.lambda_ = lambda_
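
    # A minimal usage sketch for `setup_surrogate`. The masks here are
    # assumptions for illustration; any boolean mask (or index tensor)
    # over the nodes works, since boolean masks are converted to index
    # tensors via `nonzero()` above:
    #
    #   attacker.setup_surrogate(surrogate_model,
    #                            labeled_nodes=data.train_mask,
    #                            unlabeled_nodes=data.test_mask)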
    def reset(self):
        super().reset()
        self.adj_changes = torch.zeros_like(self.adj)
        self.feat_changes = torch.zeros_like(self.feat)
        return self
    def get_perturbed_adj(self, adj_changes=None):
        if adj_changes is None:
            adj_changes = self.adj_changes
        adj_changes_triu = torch.triu(adj_changes, diagonal=1)
        adj_changes_symm = self.clip(adj_changes_triu + adj_changes_triu.t())
        modified_adj = adj_changes_symm + self.adj
        return modified_adj
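
    # Only the strict upper triangle of `adj_changes` is free; it is
    # mirrored so that a single scalar controls each undirected edge.
    # A toy example (values are assumptions for illustration):
    #
    #   changes = torch.zeros(3, 3)
    #   changes[0, 1] = 1.0                    # flip edge (0, 1)
    #   triu = torch.triu(changes, diagonal=1)
    #   symm = (triu + triu.t()).clamp(-1., 1.)
    #   # symm[0, 1] == symm[1, 0] == 1.0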
    def get_perturbed_feat(self, feat_changes=None):
        if feat_changes is None:
            feat_changes = self.feat_changes
        return self.feat + self.clip(feat_changes)
    def clip(self, matrix):
        clipped_matrix = torch.clamp(matrix, -1., 1.)
        return clipped_matrix
    def reset_parameters(self):
        for w, wv in zip(self.weights, self.w_velocities):
            init.xavier_uniform_(w)
            init.zeros_(wv)

        for i in range(len(self.weights)):
            self.weights[i] = self.weights[i].detach().requires_grad_()
            self.w_velocities[i] = self.w_velocities[i].detach()
    def forward(self, adj, x):
        """Dense GCN forward pass using the unrolled weights."""
        h = x
        for w in self.weights[:-1]:
            h = adj @ (h @ w)
            h = h.relu()
        return adj @ (h @ self.weights[-1])
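
    # `forward` is a dense, bias-free GCN applied with the *unrolled*
    # weights from `inner_train`, not the surrogate's own parameters.
    # With a normalized adjacency `adj` of shape [N, N] and features
    # `x` of shape [N, F], each layer computes adj @ (h @ W), e.g.:
    #
    #   h = relu(adj @ x @ W1)   # hidden layer(s)
    #   out = adj @ h @ W2       # final logits, shape [N, C]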
    def inner_train(self, adj, feat):
        self.reset_parameters()

        for _ in range(self.epochs):
            out = self(adj, feat)
            loss = F.cross_entropy(out[self.labeled_nodes], self.y_train)
            grads = torch.autograd.grad(loss, self.weights,
                                        create_graph=True)

            self.w_velocities = [
                self.momentum * v + g
                for v, g in zip(self.w_velocities, grads)
            ]

            self.weights = [
                w - self.lr * v
                for w, v in zip(self.weights, self.w_velocities)
            ]
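
    # `inner_train` is plain SGD with momentum, but written with
    # `torch.autograd.grad(..., create_graph=True)` so every weight
    # update stays on the autograd tape. After `epochs` steps the
    # final weights are a differentiable function of `adj` and `feat`,
    # which is what makes the meta-gradient in `compute_gradients`
    # possible.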
    def attack(self, num_budgets=0.05, *, structure_attack=True,
               feature_attack=False, disable=False):

        super().attack(num_budgets=num_budgets,
                       structure_attack=structure_attack,
                       feature_attack=feature_attack)

        if feature_attack:
            self._check_feature_matrix_binary()

        adj_changes = self.adj_changes
        feat_changes = self.feat_changes
        modified_adj = self.adj
        modified_feat = self.feat

        adj_changes.requires_grad_(bool(structure_attack))
        feat_changes.requires_grad_(bool(feature_attack))

        num_nodes, num_feats = self.num_nodes, self.num_feats

        for it in tqdm(range(self.num_budgets), desc='Perturbing graph...',
                       disable=disable):

            if structure_attack:
                modified_adj = self.get_perturbed_adj(adj_changes)

            if feature_attack:
                modified_feat = self.get_perturbed_feat(feat_changes)

            adj_norm = dense_gcn_norm(modified_adj)
            self.inner_train(adj_norm, modified_feat)

            adj_grad, feat_grad = self.compute_gradients(
                adj_norm, modified_feat)

            adj_grad_score = modified_adj.new_zeros(1)
            feat_grad_score = modified_feat.new_zeros(1)

            with torch.no_grad():
                if structure_attack:
                    adj_grad_score = self.structure_score(
                        modified_adj, adj_grad)

                if feature_attack:
                    feat_grad_score = self.feature_score(
                        modified_feat, feat_grad)

                adj_max, adj_argmax = torch.max(adj_grad_score, dim=0)
                feat_max, feat_argmax = torch.max(feat_grad_score, dim=0)

                if adj_max >= feat_max:
                    u, v = divmod(adj_argmax.item(), num_nodes)
                    edge_weight = modified_adj[u, v].data.item()

                    adj_changes[u, v].data.fill_(1 - 2 * edge_weight)
                    adj_changes[v, u].data.fill_(1 - 2 * edge_weight)

                    if edge_weight > 0:
                        self.remove_edge(u, v, it)
                    else:
                        self.add_edge(u, v, it)
                else:
                    u, v = divmod(feat_argmax.item(), num_feats)
                    feat_weight = modified_feat[u, v].data.item()

                    feat_changes[u, v].data.fill_(1 - 2 * feat_weight)

                    if feat_weight > 0:
                        self.remove_feat(u, v, it)
                    else:
                        self.add_feat(u, v, it)

        return self
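
    # `num_budgets` may be given as a ratio of the number of edges
    # (e.g. 0.05 for a 5% budget); it is resolved to an integer count
    # by the `super().attack(...)` call above. Each iteration then
    # greedily applies the single edge or feature flip with the
    # highest meta-gradient score.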
    def structure_score(self, modified_adj, adj_grad):
        score = adj_grad * (1 - 2 * modified_adj)
        score -= score.min()
        score = torch.triu(score, diagonal=1)
        if not self._allow_singleton:
            # Set entries to 0 that could lead to singleton nodes.
            score *= singleton_mask(modified_adj)
        return score.view(-1)
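
    # The flattened score maps back to an edge via `divmod` in `attack`.
    # A toy sketch (names assumed for illustration):
    #
    #   score = self.structure_score(modified_adj, adj_grad)  # [N * N]
    #   idx = score.argmax().item()
    #   u, v = divmod(idx, num_nodes)  # row/column of the best flip
    #
    # The factor `(1 - 2 * A)` flips the gradient's sign on existing
    # edges, so a high score means "add this edge" when A[u, v] == 0
    # and "remove it" when A[u, v] == 1.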
    def feature_score(self, modified_feat, feat_grad):
        score = feat_grad * (1 - 2 * modified_feat)
        score -= score.min()
        return score.view(-1)
    def compute_gradients(self, modified_adj, modified_feat):
        logit = self(modified_adj, modified_feat) / self.tau

        if self.lambda_ == 1:
            loss = F.cross_entropy(logit[self.labeled_nodes], self.y_train)
        elif self.lambda_ == 0.:
            loss = F.cross_entropy(logit[self.unlabeled_nodes],
                                   self.y_self_train)
        else:
            loss_labeled = F.cross_entropy(logit[self.labeled_nodes],
                                           self.y_train)
            loss_unlabeled = F.cross_entropy(logit[self.unlabeled_nodes],
                                             self.y_self_train)
            loss = self.lambda_ * loss_labeled + \
                (1 - self.lambda_) * loss_unlabeled

        if self.structure_attack and self.feature_attack:
            return grad(loss, [self.adj_changes, self.feat_changes])

        if self.structure_attack:
            return grad(loss, self.adj_changes)[0], None

        if self.feature_attack:
            return None, grad(loss, self.feat_changes)[0]
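

# ---------------------------------------------------------------------------
# Minimal, self-contained sketch of the meta-gradient idea used above,
# written in plain PyTorch (no greatx required). All shapes, sizes, and
# hyperparameters below are toy assumptions for illustration only.
if __name__ == "__main__":
    torch.manual_seed(0)

    n, d, c = 6, 4, 2  # toy graph: 6 nodes, 4 features, 2 classes
    adj = (torch.rand(n, n) > 0.5).float()
    adj = torch.triu(adj, diagonal=1)
    adj = adj + adj.t()  # symmetric, no self-loops
    x = torch.rand(n, d)
    y = torch.randint(0, c, (n,))

    # Perturbation variable, analogous to `self.adj_changes`.
    adj_changes = torch.zeros_like(adj, requires_grad=True)
    modified_adj = adj + adj_changes

    # "Inner training": a few differentiable gradient steps on a linear
    # surrogate f(A, X) = A X W, keeping the graph (create_graph=True)
    # so the trained weights remain a function of `adj_changes`.
    w = torch.empty(d, c, requires_grad=True)
    init.xavier_uniform_(w)
    for _ in range(10):
        loss = F.cross_entropy(modified_adj @ x @ w, y)
        g, = torch.autograd.grad(loss, w, create_graph=True)
        w = w - 0.1 * g

    # Meta-gradient: differentiate the post-training loss w.r.t. the
    # adjacency perturbations themselves.
    meta_loss = F.cross_entropy(modified_adj @ x @ w, y)
    adj_grad, = torch.autograd.grad(meta_loss, adj_changes)
    print(adj_grad.shape)  # torch.Size([6, 6])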