from collections import namedtuple
import numpy as np
import torch
import torch.nn.functional as F
from torch.autograd import grad
from tqdm.auto import tqdm
from greatx.attack.targeted.targeted_attacker import TargetedAttacker
from greatx.nn.models.surrogate import Surrogate
from greatx.utils import ego_graph
# Lightweight container for the target node's extracted ego subgraph:
#   edge_index      : all edges fed to the surrogate (candidate non-edges,
#                     existing edges, both reversed, plus self-loops)
#   sub_edges       : existing edges within the ego graph
#   non_edges       : candidate (currently absent) edges that could be added
#   edge_weight     : differentiable weights of the existing edges (init 1)
#   non_edge_weight : differentiable weights of the candidate edges (init 0)
#   selfloop_weight : fixed weights for the self-loop edges (always 1)
SubGraph = namedtuple('SubGraph', [
    'edge_index', 'sub_edges', 'non_edges', 'edge_weight', 'non_edge_weight',
    'selfloop_weight'
])
class SGAttack(TargetedAttacker, Surrogate):
    r"""Implementation of `SGA` attack from the:
    `"Adversarial Attack on Large Scale Graph"
    <https://arxiv.org/abs/2009.03488>`_ paper (TKDE'21)

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device of the attack running on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be :obj:`__class__.__name__`,
        by default None
    kwargs : additional arguments of :class:`greatx.attack.Attacker`,

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                               transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ...  # train your surrogate model

        from greatx.attack.targeted import SGAttack
        attacker = SGAttack(data)
        attacker.setup_surrogate(surrogate_model)
        attacker.reset()

        # attacking target node `1` with default budget set as node degree
        attacker.attack(target=1)

        # attacking target node `1` with budget set as 1
        attacker.attack(target=1, num_budgets=1)

        attacker.data()  # get attacked graph
        attacker.edge_flips()  # get edge flips after attack
        attacker.added_edges()  # get added edges after attack
        attacker.removed_edges()  # get removed edges after attack

    Note
    ----
    * `SGAttack` is a scalable attack that can be applied to large scale graph
    * Please remember to call :meth:`reset` before each attack.
    """

    # SGAttack cannot ensure that there is no singleton node after attacks,
    # so singleton nodes are explicitly allowed.
    _allow_singleton = True
[docs] @torch.no_grad()
def setup_surrogate(
self,
surrogate: torch.nn.Module,
*,
tau: float = 5.0,
freeze: bool = True,
):
Surrogate.setup_surrogate(self, surrogate=surrogate, tau=tau,
freeze=freeze)
self.logits = self.surrogate(self.feat, self.edge_index,
self.edge_weight).cpu()
return self
[docs] def set_normalize(self, state):
# TODO: this is incorrect for models
# with `normalize=False` by default
for layer in self.surrogate.modules():
if hasattr(layer, 'normalize'):
layer.normalize = state
if hasattr(layer, 'add_self_loops'):
layer.add_self_loops = state
[docs] def strongest_wrong_class(self, target, target_label):
logit = self.logits[target].clone()
logit[target_label] = -1e4
return logit.argmax()
[docs] def get_subgraph(self, target, target_label, best_wrong_label):
sub_nodes, sub_edges = ego_graph(self.adjacency_matrix, int(target),
self.K)
if sub_edges.size == 0:
raise RuntimeError(
f"The target node {int(target)} is a singleton node.")
sub_nodes = torch.as_tensor(sub_nodes, dtype=torch.long,
device=self.device)
sub_edges = torch.as_tensor(sub_edges, dtype=torch.long,
device=self.device)
attacker_nodes = torch.where(
self.label == best_wrong_label)[0].cpu().numpy()
neighbors = self.adjacency_matrix[target].indices
influencers = [target]
attacker_nodes = np.setdiff1d(attacker_nodes, neighbors)
subgraph = self.subgraph_processing(sub_nodes, sub_edges, influencers,
attacker_nodes)
if self.direct_attack:
influencers = [target]
num_attackers = self.num_budgets + 1
else:
influencers = neighbors
num_attackers = 3
attacker_nodes = self.get_top_attackers(subgraph, target, target_label,
best_wrong_label,
num_attackers=num_attackers)
subgraph = self.subgraph_processing(sub_nodes, sub_edges, influencers,
attacker_nodes)
return subgraph
[docs] def get_top_attackers(self, subgraph, target, target_label,
best_wrong_label, num_attackers):
non_edge_grad, _ = self.compute_gradients(subgraph, target,
target_label,
best_wrong_label)
_, index = torch.topk(non_edge_grad, k=min(num_attackers,
non_edge_grad.size(0)),
sorted=False)
attacker_nodes = subgraph.non_edges[1][index]
return attacker_nodes.tolist()
[docs] def subgraph_processing(self, sub_nodes, sub_edges, influencers,
attacker_nodes):
row = np.repeat(influencers, len(attacker_nodes))
col = np.tile(attacker_nodes, len(influencers))
non_edges = np.row_stack([row, col])
if not self.direct_attack: # indirect attack
mask = self.adjacency_matrix[non_edges[0], non_edges[1]].A1 == 0
non_edges = non_edges[:, mask]
non_edges = torch.as_tensor(non_edges, dtype=torch.long,
device=self.device)
attacker_nodes = torch.as_tensor(attacker_nodes, dtype=torch.long,
device=self.device)
selfloop = torch.unique(torch.cat([sub_nodes, attacker_nodes]))
edge_index = torch.cat([
non_edges, sub_edges,
non_edges.flip(0),
sub_edges.flip(0),
selfloop.repeat((2, 1))
], dim=1)
edge_weight = torch.ones(sub_edges.size(1),
device=self.device).requires_grad_()
non_edge_weight = torch.zeros(non_edges.size(1),
device=self.device).requires_grad_()
selfloop_weight = torch.ones(selfloop.size(0), device=self.device)
subgraph = SubGraph(
edge_index=edge_index,
sub_edges=sub_edges,
non_edges=non_edges,
edge_weight=edge_weight,
non_edge_weight=non_edge_weight,
selfloop_weight=selfloop_weight,
)
return subgraph
[docs] def attack(self, target, *, K: int = 2, target_label=None,
num_budgets=None, direct_attack=True, structure_attack=True,
feature_attack=False, disable=False):
super().attack(target, target_label, num_budgets=num_budgets,
direct_attack=direct_attack,
structure_attack=structure_attack,
feature_attack=feature_attack)
self.set_normalize(False)
self.K = K
target_label = self.target_label.view(-1)
best_wrong_label = self.strongest_wrong_class(target,
target_label).view(-1)
best_wrong_label = best_wrong_label.to(self.device)
subgraph = self.get_subgraph(target, target_label, best_wrong_label)
if not direct_attack:
condition1 = subgraph.sub_edges[0] != target
condition2 = subgraph.sub_edges[1] != target
mask = torch.logical_and(condition1, condition2).float()
for it in tqdm(range(self.num_budgets), desc='Peturbing graph...',
disable=disable):
non_edge_grad, edge_grad = self.compute_gradients(
subgraph, target, target_label, best_wrong_label)
with torch.no_grad():
edge_grad *= -2 * subgraph.edge_weight + 1
if not direct_attack:
edge_grad *= mask
non_edge_grad *= -2 * subgraph.non_edge_weight + 1
max_edge_grad, max_edge_idx = torch.max(edge_grad, dim=0)
max_non_edge_grad, max_non_edge_idx = torch.max(
non_edge_grad, dim=0)
if max_edge_grad > max_non_edge_grad:
# remove one edge
subgraph.edge_weight.data[max_edge_idx].fill_(0.)
u, v = subgraph.sub_edges[:, max_edge_idx].tolist()
self.remove_edge(u, v, it)
else:
# add one edge
subgraph.non_edge_weight.data[max_non_edge_idx].fill_(1.)
u, v = subgraph.non_edges[:, max_non_edge_idx].tolist()
self.add_edge(u, v, it)
self.set_normalize(True)
return self
[docs] def compute_gradients(self, subgraph, target, target_label,
best_wrong_label):
edge_weight = torch.cat([
subgraph.non_edge_weight, subgraph.edge_weight,
subgraph.non_edge_weight, subgraph.edge_weight,
subgraph.selfloop_weight
], dim=0)
row, col = subgraph.edge_index
norm = (self.degree + 1.).pow(-0.5)
edge_weight = norm[row] * edge_weight * norm[col]
logit = self.surrogate(self.feat, subgraph.edge_index, edge_weight)
logit = logit[target].view(1, -1) / self.tau
logit = F.log_softmax(logit, dim=1)
loss = F.nll_loss(logit, target_label) - \
F.nll_loss(logit, best_wrong_label)
return grad(loss, [subgraph.non_edge_weight, subgraph.edge_weight],
create_graph=False)