Source code for greatx.attack.untargeted.pgd_attack

import math
from typing import Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.autograd import grad
from tqdm.auto import tqdm

from greatx.attack.untargeted.untargeted_attacker import UntargetedAttacker
from greatx.nn.models.surrogate import Surrogate


class PGD:
    """Base class for :class:`PGDAttack`."""
    # PGDAttack cannot ensure that there are no singleton nodes after attacks.
    _allow_singleton: bool = True

    def attack(
        self,
        num_budgets: int,
        victim_nodes: Tensor,
        victim_labels: Tensor,
        base_lr: float = 0.1,
        grad_clip: Optional[float] = None,
        epochs: int = 200,
        ce_loss: bool = False,
        sample_epochs: int = 20,
        disable: bool = False,
    ) -> "PGD":

        if ce_loss:
            self.loss_fn = cross_entropy_loss
        else:
            self.loss_fn = margin_loss

        perturbations = self.perturbations

        for epoch in tqdm(range(epochs), desc='PGD training...',
                          disable=disable):
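            # Diminishing step size for the gradient-ascent update below:
            # scaled by the budget and decayed as 1 / sqrt(epoch + 1).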
            lr = base_lr * num_budgets / math.sqrt(epoch + 1)
            gradients = self.compute_gradients(perturbations, victim_nodes,
                                               victim_labels)

            gradients = self.clip_grad(gradients, grad_clip)

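            # Ascent step followed by projection back onto the feasible set
            # {P : 0 <= P <= 1, P.sum() <= budget}. If clamping alone stays
            # within the budget, use it; otherwise bisect for a shift ``mu``
            # such that the shifted, clamped perturbations exhaust the budget.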
            with torch.no_grad():
                perturbations.data.add_(lr * gradients)
                if perturbations.clamp(0, 1).sum() <= num_budgets:
                    perturbations.clamp_(0, 1)
                else:
                    top = perturbations.max().item()
                    bot = (perturbations.min() - 1).clamp_min(0).item()
                    mu = (top + bot) / 2
                    while (top - bot) / 2 > 1e-5:
                        used_budget = (perturbations - mu).clamp(0, 1).sum()
                        if used_budget == num_budgets:
                            break
                        elif used_budget > num_budgets:
                            bot = mu
                        else:
                            top = mu
                        mu = (top + bot) / 2
                    perturbations.sub_(mu).clamp_(0, 1)

        best_loss = -np.inf
        best_pert = None

        perturbations.detach_()
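        # Interpret the relaxed perturbations as edge-flip probabilities:
        # draw Bernoulli samples and keep the best-scoring sample that
        # satisfies the budget constraint.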
        for it in tqdm(range(sample_epochs), desc='Bernoulli sampling...',
                       disable=disable):
            sampled = perturbations.bernoulli()
            if sampled.count_nonzero() <= num_budgets:
                loss = self.compute_loss(symmetric(sampled), victim_nodes,
                                         victim_labels)
                if best_loss < loss:
                    best_loss = loss
                    best_pert = sampled

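        # Translate the sampled binary matrix into concrete edge flips:
        # a nonzero entry on an existing edge removes it, otherwise the
        # edge is added.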
        assert best_pert is not None, \
            "no sampled perturbation satisfied the budget"
        row, col = torch.where(best_pert > 0.)
        for it, (u, v) in enumerate(zip(row.tolist(), col.tolist())):
            if self.adj[u, v] > 0:
                self.remove_edge(u, v, it)
            else:
                self.add_edge(u, v, it)

        return self

    def compute_loss(
        self,
        perturbations: Tensor,
        victim_nodes: Tensor,
        victim_labels: Tensor,
    ) -> Tensor:
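        # Relaxed edge flip: entries of ``adj`` move towards ``1 - adj``
        # as the corresponding perturbation approaches 1.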
        adj = self.adj + perturbations * (1 - 2 * self.adj)
        logit = self.surrogate(self.feat, adj)[victim_nodes]
        if self.tau != 1:
            logit /= self.tau
        loss = self.loss_fn(logit, victim_labels)
        return loss

    def compute_gradients(
        self,
        perturbations: Tensor,
        victim_nodes: Tensor,
        victim_labels: Tensor,
    ) -> Tensor:
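        # Chain rule in two steps: first the gradient of the loss w.r.t.
        # the symmetrized perturbations, then backpropagated through
        # ``symmetric`` to the upper-triangular parameters.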
        pert_sym = symmetric(perturbations)
        grad_outputs = grad(
            self.compute_loss(pert_sym, victim_nodes, victim_labels), pert_sym)
        return grad(pert_sym, perturbations, grad_outputs=grad_outputs[0])[0]


class PGDAttack(UntargetedAttacker, PGD, Surrogate):
    r"""Implementation of the `PGD` attack from the
    `"Topology Attack and Defense for Graph Neural Networks:
    An Optimization Perspective" <https://arxiv.org/abs/1906.04214>`_
    paper (IJCAI'19).

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device the attack runs on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker; if None, it defaults to
        :obj:`__class__.__name__`, by default None
    kwargs : additional arguments of :class:`greatx.attack.Attacker`

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Example
    -------
    .. code-block:: python

        from greatx.dataset import GraphDataset
        import torch_geometric.transforms as T

        dataset = GraphDataset(root='.', name='Cora',
                               transform=T.LargestConnectedComponents())
        data = dataset[0]

        surrogate_model = ...  # train your surrogate model

        from greatx.attack.untargeted import PGDAttack
        attacker = PGDAttack(data)
        attacker.setup_surrogate(surrogate_model,
                                 victim_nodes=test_nodes)
        attacker.reset()
        attacker.attack(0.05)  # attack with 5% of edge perturbations

        attacker.data()  # get attacked graph
        attacker.edge_flips()  # get edge flips after attack
        attacker.added_edges()  # get added edges after attack
        attacker.removed_edges()  # get removed edges after attack

    Note
    ----
    * Please remember to call :meth:`reset` before each attack.
    """
    # PGDAttack cannot ensure that there are no singleton nodes after attacks.
    _allow_singleton: bool = True

    def setup_surrogate(
        self,
        surrogate: torch.nn.Module,
        victim_nodes: Tensor,
        ground_truth: bool = False,
        *,
        tau: float = 1.0,
        freeze: bool = True,
    ) -> "PGDAttack":
        """Setup the surrogate model for adversarial attack.

        Parameters
        ----------
        surrogate : torch.nn.Module
            the surrogate model
        victim_nodes : Tensor
            the set of victim nodes
        ground_truth : bool, optional
            whether to use ground-truth labels for victim nodes;
            if False, the node labels are estimated by the
            surrogate model, by default False
        tau : float, optional
            the temperature of softmax activation, by default 1.0
        freeze : bool, optional
            whether to freeze the surrogate model to avoid gradient
            accumulation, by default True

        Returns
        -------
        PGDAttack
            the attacker itself
        """
        Surrogate.setup_surrogate(self, surrogate=surrogate, tau=tau,
                                  freeze=freeze)

        if victim_nodes.dtype == torch.bool:
            victim_nodes = victim_nodes.nonzero().view(-1)
        self.victim_nodes = victim_nodes.to(self.device)

        if ground_truth:
            self.victim_labels = self.label[victim_nodes]
        else:
            self.victim_labels = self.estimate_self_training_labels(
                victim_nodes)

        self.adj = self.get_dense_adj()
        return self

    def reset(self) -> "PGDAttack":
        super().reset()
        self.perturbations = torch.zeros_like(self.adj).requires_grad_()
        return self

    def attack(
        self,
        num_budgets: Union[int, float] = 0.05,
        *,
        base_lr: float = 0.1,
        grad_clip: Optional[float] = None,
        epochs: int = 200,
        ce_loss: bool = False,
        sample_epochs: int = 20,
        structure_attack: bool = True,
        feature_attack: bool = False,
        disable: bool = False,
    ) -> "PGDAttack":
        """Adversarial attack method for
        "Projected gradient descent attack (PGD)".

        Parameters
        ----------
        num_budgets : Union[int, float], optional
            the number of attack budgets; could be a float (ratio)
            or an int (number), by default 0.05
        base_lr : float, optional
            the base learning rate for PGD training, by default 0.1
        grad_clip : float, optional
            gradient clipping for the computed gradients,
            by default None
        epochs : int, optional
            the number of epochs for PGD training, by default 200
        ce_loss : bool, optional
            whether to use cross-entropy loss (True) or margin loss
            (False), by default False
        sample_epochs : int, optional
            the number of sampling epochs for the learned
            perturbations, by default 20
        structure_attack : bool, optional
            whether to conduct structure attack, i.e., modify the
            graph structure (edges), by default True
        feature_attack : bool, optional
            whether to conduct feature attack, i.e., modify the node
            features; N/A for this method, by default False
        disable : bool, optional
            whether to disable the tqdm progress bar, by default False

        Returns
        -------
        PGDAttack
            the attacker itself
        """
        super().attack(num_budgets=num_budgets,
                       structure_attack=structure_attack,
                       feature_attack=feature_attack)

        return PGD.attack(
            self,
            self.num_budgets,
            victim_nodes=self.victim_nodes,
            victim_labels=self.victim_labels,
            base_lr=base_lr,
            grad_clip=grad_clip,
            epochs=epochs,
            ce_loss=ce_loss,
            sample_epochs=sample_epochs,
            disable=disable,
        )


def symmetric(x: Tensor) -> Tensor:
    x = x.triu(diagonal=1)
    return x + x.T


def margin_loss(logit: Tensor, y_true: Tensor) -> Tensor:
    all_nodes = torch.arange(y_true.size(0), device=y_true.device)
    # Get the scores of the true classes.
    scores_true = logit[all_nodes, y_true]
    # Get the highest scores when not considering the true classes.
    scores_mod = logit.clone()
    scores_mod[all_nodes, y_true] = -np.inf
    scores_pred_excl_true = scores_mod.amax(dim=-1)
    return -(scores_true - scores_pred_excl_true).tanh().mean()


def cross_entropy_loss(logit: Tensor, y_true: Tensor) -> Tensor:
    return F.cross_entropy(logit, y_true)
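

# ---------------------------------------------------------------------------
# Illustration only -- not part of the original module. A minimal sketch of
# the bisection-based budget projection used in ``PGD.attack``, isolated so
# it can be tried on a toy tensor. The helper name ``project_onto_budget``
# is made up for this example.
def project_onto_budget(pert: Tensor, budget: int, eps: float = 1e-5) -> Tensor:
    """Project ``pert`` onto {P : 0 <= P <= 1, P.sum() <= budget}."""
    if pert.clamp(0, 1).sum() <= budget:
        # Clamping alone stays within the budget; no shift needed.
        return pert.clamp(0, 1)
    top = pert.max().item()
    bot = (pert.min() - 1).clamp_min(0).item()
    mu = (top + bot) / 2
    while (top - bot) / 2 > eps:
        used = (pert - mu).clamp(0, 1).sum()
        if used == budget:
            break
        elif used > budget:
            bot = mu  # too much budget used: shift entries further down
        else:
            top = mu  # too little budget used: shift entries less
        mu = (top + bot) / 2
    return (pert - mu).clamp(0, 1)


# Example usage:
#   >>> torch.manual_seed(0)
#   >>> pert = 2 * torch.rand(5, 5)
#   >>> project_onto_budget(pert, budget=4).sum()  # ~4 (within tolerance)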