from copy import copy
from typing import Union
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.utils import degree
from greatx.nn.models.surrogate import Surrogate
from greatx.utils import remove_edges
class UniversalDefense(torch.nn.Module):
    r"""Base class for graph universal defense
    from the `"Graph Universal Adversarial Defense"
    <https://arxiv.org/abs/2204.09803>`_ paper (arXiv'22)

    Parameters
    ----------
    device : str, optional
        the device where the method running on, by default "cpu"
    """

    def __init__(self, device: str = "cpu"):
        super().__init__()
        self.device = torch.device(device)
        # Nodes ranked by "anchor" score (most defensive first).
        # Subclasses must assign a permutation of all node indices here.
        self._anchors = None

    def forward(self, data: Data, target_nodes: Union[int, Tensor],
                k: int = 50, symmetric: bool = True) -> Data:
        """Return the defended graph with defensive perturbation
        performed on.

        Parameters
        ----------
        data : a graph represented as PyG-like data instance
            the graph where the defensive perturbation performed on
        target_nodes : Union[int, Tensor]
            the target nodes where the defensive perturbation performed on
        k : int
            the number of anchor nodes in the defensive perturbation,
            by default 50
        symmetric : bool
            Determine whether the resulting graph is forcibly symmetric,
            by default True

        Returns
        -------
        Data: PyG-like data
            the defended graph with defensive perturbation
            performed on the target nodes
        """
        # Shallow copy so the caller's Data object keeps its original edges.
        data = copy(data)
        data.edge_index = remove_edges(data.edge_index,
                                       self.removed_edges(target_nodes, k),
                                       symmetric=symmetric)
        return data

    def removed_edges(self, target_nodes: Union[int, Tensor],
                      k: int = 50) -> Tensor:
        """Return edges to remove with the defensive perturbation performed
        on the target nodes

        Parameters
        ----------
        target_nodes : Union[int, Tensor]
            the target nodes where the defensive perturbation
            performed on
        k : int
            the number of anchor nodes in the defensive perturbation,
            by default 50

        Returns
        -------
        Tensor, shape [2, num_targets * k]
            the edges to remove with the defensive perturbation performed
            on the target nodes
        """
        row = torch.as_tensor(target_nodes, device=self.device).view(-1)
        col = self.anchors(k)
        # Pair every target node with every anchor node:
        # row: [t0]*k + [t1]*k + ...; col: anchors tiled once per target.
        row, col = row.repeat_interleave(k), col.repeat(row.size(0))
        return torch.stack([row, col], dim=0)

    def anchors(self, k: int = 50) -> Tensor:
        """Return the top-k anchor nodes

        Parameters
        ----------
        k : int, optional
            the number of anchor nodes in the defensive perturbation,
            by default 50

        Returns
        -------
        Tensor
            the top-k anchor nodes

        Raises
        ------
        ValueError
            if ``k`` is not positive
        """
        # Explicit check instead of `assert`: asserts vanish under `python -O`.
        if k <= 0:
            raise ValueError(f"`k` must be a positive integer, got {k}.")
        return self._anchors[:k].to(self.device)

    def patch(self, k=50) -> Tensor:
        """Return the universal patch of the defensive perturbation

        Parameters
        ----------
        k : int, optional
            the number of anchor nodes in the defensive perturbation,
            by default 50

        Returns
        -------
        Tensor
            the 0-1 (boolean) universal patch where 1
            denotes the edges to be removed.
        """
        # Fix: the original read `self.num_nodes`, which is never assigned
        # anywhere in this class hierarchy and raised AttributeError.
        # `_anchors` is a permutation of all node indices (argsort/randperm
        # in the subclasses), so its length equals the number of nodes.
        num_nodes = self._anchors.size(0)
        _patch = torch.zeros(num_nodes, dtype=torch.bool,
                             device=self.device)
        _patch[self.anchors(k=k)] = True
        return _patch
class GUARD(UniversalDefense, Surrogate):
    r"""Implementation of Graph Universal Adversarial
    Defense (GUARD) from the `"Graph Universal
    Adversarial Defense"
    <https://arxiv.org/abs/2204.09803>`_ paper (arXiv'22)

    Parameters
    ----------
    data : Data
        the PyG-like input data
    alpha : float, optional
        the scale factor for node degree, by default 2
    batch_size : int, optional
        the batch size for computing node influence, by default 512
    device : str, optional
        the device where the method running on, by default "cpu"

    Example
    -------
    .. code-block:: python

        surrogate = GCN(num_features, num_classes, bias=False, acts=None)
        surrogate_trainer = Trainer(surrogate, device=device)
        ckp = ModelCheckpoint('guard.pth', monitor='val_acc')
        trainer.fit(data, mask=(splits.train_nodes,
                    splits.val_nodes), callbacks=[ckp])
        trainer.evaluate(data, splits.test_nodes)
        guard = GUARD(data, device=device)
        guard.setup_surrogate(surrogate, data.y[splits.train_nodes])
        target_node = 1
        perturbed_data = ... # Other PyG-like Data
        guard(perturbed_data, target_node, k=50)
    """

    def __init__(self, data: Data, alpha: float = 2, batch_size: int = 512,
                 device: str = "cpu"):
        super().__init__(device=device)
        self.data = data
        self.alpha = alpha
        self.batch_size = batch_size
        # Per-node influence score; filled in by `setup_surrogate`.
        self.influence_score = None
        # Degree of each node w.r.t. the source row of edge_index.
        self.deg = degree(data.edge_index[0], num_nodes=data.num_nodes,
                          dtype=torch.float)

    @torch.no_grad()
    def setup_surrogate(self, surrogate: torch.nn.Module,
                        victim_labels: Tensor) -> "GUARD":
        """Set up the surrogate model and compute the anchor ranking.

        Collapses the (linear) surrogate into a single weight matrix,
        scores each node's influence on the victim labels, normalizes by
        node degree, and stores the resulting ranking in ``self._anchors``.

        Parameters
        ----------
        surrogate : torch.nn.Module
            the (trained) surrogate model; must be an SGC or GCN instance
        victim_labels : Tensor
            the labels of the victim (e.g. training) nodes

        Returns
        -------
        GUARD
            the method itself, for chaining
        """
        from greatx.nn.models.supervised import GCN, SGC

        Surrogate.setup_surrogate(self, surrogate=surrogate, freeze=True,
                                  required=(SGC, GCN))
        # Fold all (2-D) weight matrices of the surrogate into one matrix;
        # 1-D parameters (biases) are skipped.
        W = None
        for para in self.surrogate.parameters():
            if para.ndim == 1:
                continue
            W = para.detach() if W is None else para.detach() @ W
        # Project node features through the collapsed weights: [N, C] logits.
        W = self.data.x.to(self.device) @ W.t()
        d = self.deg.clamp(min=1).to(self.device)

        # Accumulate the summed logit mass on the victim classes in batches
        # to bound peak memory.
        loader = DataLoader(victim_labels, pin_memory=False,
                            batch_size=self.batch_size, shuffle=False)
        w_max = W.max(1).values
        influence = 0.  # fixed misspelling: was `incluence`
        for y in loader:
            influence += W[:, y].sum(1)
        # Node importance: margin to the best class, scaled down by degree.
        influence = (w_max - influence / victim_labels.size(0)) / \
            d.pow(self.alpha)
        self._anchors = torch.argsort(influence, descending=True)
        self.influence_score = influence
        return self
class DegreeGUARD(UniversalDefense):
    r"""Degree-based variant of Graph Universal Defense
    from the `"Graph Universal Adversarial Defense"
    <https://arxiv.org/abs/2204.09803>`_ paper (arXiv'22).

    Anchor nodes are simply the nodes ranked by degree.

    Parameters
    ----------
    data : Data
        the PyG-like input data
    descending : bool, optional
        whether the degree of chosen nodes are in descending order,
        by default False
    device : str, optional
        the device where the method running on, by default "cpu"

    Example
    -------
    .. code-block:: python

        data = ... # PyG-like Data
        guard = DegreeGUARD(data))
        target_node = 1
        perturbed_data = ... # Other PyG-like Data
        guard(perturbed_data, target_node, k=50)
    """

    def __init__(self, data: Data, descending: bool = False,
                 device: str = "cpu"):
        super().__init__(device=device)
        # Rank nodes by their degree (low-degree first unless `descending`);
        # this ranking serves as the anchor order for the defense.
        node_degrees = degree(data.edge_index[0], num_nodes=data.num_nodes,
                              dtype=torch.float)
        self._anchors = torch.argsort(node_degrees, descending=descending)
class RandomGUARD(UniversalDefense):
    r"""Random-choice variant of Graph Universal Defense
    from the `"Graph Universal Adversarial Defense"
    <https://arxiv.org/abs/2204.09803>`_ paper (arXiv'22).

    Anchor nodes are drawn as a random permutation of all nodes.

    Parameters
    ----------
    data : Data
        the PyG-like input data
    device : str, optional
        the device where the method running on, by default "cpu"

    Example
    -------
    .. code-block:: python

        data = ... # PyG-like Data
        guard = RandomGUARD(data)
        target_node = 1
        perturbed_data = ... # Other PyG-like Data
        guard(perturbed_data, target_node, k=50)
    """

    def __init__(self, data: Data, device: str = "cpu"):
        super().__init__(device=device)
        # A uniformly random node ordering acts as the anchor ranking,
        # so `anchors(k)` returns k nodes chosen at random.
        num_nodes = data.num_nodes
        self._anchors = torch.randperm(num_nodes, device=self.device)