# Source code for greatx.attack.flip_attacker

import warnings
from copy import copy
from functools import lru_cache
from typing import Optional

import numpy as np
import torch
from torch import Tensor
from torch_geometric.data import Data

from greatx.attack.attacker import Attacker
from greatx.utils import BunchDict, add_edges, remove_edges


class FlipAttacker(Attacker):
    """Adversarial attacker for graph data by flipping edges.

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device of the attack running on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be
        :obj:`__class__.__name__`, by default None
    kwargs : additional arguments of :class:`greatx.attack.Attacker`,

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Note
    ----
    :class:`greatx.attack.FlipAttacker` is a base class for
    graph modification attacks (GMA).
    """
    def reset(self) -> "FlipAttacker":
        """Reset attacker. This method must be called before attack."""
        super().reset()
        # ``data`` is wrapped by ``lru_cache``; drop the cached perturbed
        # graph so later calls rebuild it from the fresh perturbation sets.
        self.data.cache_clear()
        self._removed_edges = {}
        self._added_edges = {}
        self._removed_feats = {}
        self._added_feats = {}
        # Working copy of node degrees, updated as edges are flipped.
        # NOTE(review): assumes the base class set ``self._degree``
        # (a tensor) — defined outside this file.
        self.degree = self._degree.clone()
        return self

    def remove_edge(self, u: int, v: int, it: Optional[int] = None):
        """Remove an edge from the graph.

        Parameters
        ----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge
        it : Optional[int], optional
            The iteration that indicates the order of the edge
            being removed, by default None
        """
        if not self._allow_singleton:
            # Only warn; the edge is removed regardless.
            is_singleton_u = self.degree[u] <= 1
            is_singleton_v = self.degree[v] <= 1
            if is_singleton_u or is_singleton_v:
                warnings.warn(
                    f"You are trying to remove an edge ({u}-{v}) "
                    "that would result in singleton nodes. "
                    "If the behavior is not intended, "
                    "please make sure you have set "
                    "`attacker.set_allow_singleton(False)` "
                    "or check your algorithm.", UserWarning)
        self._removed_edges[(u, v)] = it
        self.degree[u] -= 1
        self.degree[v] -= 1

    def add_edge(self, u: int, v: int, it: Optional[int] = None):
        """Add one edge to the graph.

        Parameters
        ----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge
        it : Optional[int], optional
            The iteration that indicates the order of the edge
            being added, by default None
        """
        self._added_edges[(u, v)] = it
        self.degree[u] += 1
        self.degree[v] += 1

    def removed_edges(self) -> Optional[Tensor]:
        """Get all the edges to be removed."""
        edges = self._removed_edges
        if edges is None or len(edges) == 0:
            return None
        if torch.is_tensor(edges):
            return edges.to(self.device)
        if isinstance(edges, dict):
            # dict keys are (u, v) pairs; insertion order preserves the
            # order in which edges were removed.
            edges = list(edges.keys())
        # Transpose to the PyG convention: shape [2, num_edges].
        removed = torch.tensor(
            np.asarray(edges, dtype="int64").T, device=self.device)
        return removed

    def added_edges(self) -> Optional[Tensor]:
        """Get all the edges to be added."""
        edges = self._added_edges
        if edges is None or len(edges) == 0:
            return None
        if torch.is_tensor(edges):
            return edges.to(self.device)
        if isinstance(edges, dict):
            edges = list(edges.keys())
        # Transpose to the PyG convention: shape [2, num_edges].
        return torch.tensor(
            np.asarray(edges, dtype="int64").T, device=self.device)

    def edge_flips(self, frac: float = 1.0) -> BunchDict:
        """Get all the edges to be flipped, including edges to be
        added and removed.

        Parameters
        ----------
        frac : float, optional
            the fraction of edge perturbations, i.e., how many
            perturbed edges are used to construct the perturbed graph.
            by default 1.0

        Example
        -------
        >>> # Get the edge flips
        >>> attacker.edge_flips()

        >>> # Get the edge flips, with
        >>> # specifying frac
        >>> attacker.edge_flips(frac=0.5)
        """
        assert 0 <= frac <= 1
        added = self.added_edges()
        if added is not None:
            # Keep only the first ``frac`` of perturbations (edges are
            # ordered by when they were flipped).
            added = added[:, :round(added.size(1) * frac)]
        removed = self.removed_edges()
        if removed is not None:
            removed = removed[:, :round(removed.size(1) * frac)]
        _all = cat(added, removed, dim=1)
        return BunchDict(added=added, removed=removed, all=_all)

    def remove_feat(self, u: int, v: int, it: Optional[int] = None):
        """Remove the feature in a dimension `v` for a node `u`.
        That is, set a dimension of the specific node to zero.

        Parameters
        ----------
        u : int
            the node whose features are to be removed
        v : int
            the dimension of the feature to be removed
        it : Optional[int], optional
            The iteration that indicates the order of the features
            being removed, by default None
        """
        self._removed_feats[(u, v)] = it

    def add_feat(self, u: int, v: int, it: Optional[int] = None):
        """Add the feature in a dimension `v` for a node `u`.
        That is, set a dimension of the specific node to one.

        Parameters
        ----------
        u : int
            the node whose features are to be added
        v : int
            the dimension of the feature to be added
        it : Optional[int], optional
            The iteration that indicates the order of the features
            being added, by default None
        """
        self._added_feats[(u, v)] = it

    def removed_feats(self) -> Optional[Tensor]:
        """Get all the features to be removed."""
        feats = self._removed_feats
        if feats is None or len(feats) == 0:
            return None
        if isinstance(feats, dict):
            feats = list(feats.keys())
        if torch.is_tensor(feats):
            return feats.to(self.device)
        # Transpose to shape [2, num_flips]: row 0 nodes, row 1 dimensions.
        return torch.tensor(
            np.asarray(feats, dtype="int64").T, device=self.device)

    def added_feats(self) -> Optional[Tensor]:
        """Get all the features to be added."""
        feats = self._added_feats
        if feats is None or len(feats) == 0:
            return None
        if isinstance(feats, dict):
            feats = list(feats.keys())
        if torch.is_tensor(feats):
            return feats.to(self.device)
        # Transpose to shape [2, num_flips]: row 0 nodes, row 1 dimensions.
        return torch.tensor(
            np.asarray(feats, dtype="int64").T, device=self.device)

    def feat_flips(self, frac: float = 1.0) -> BunchDict:
        """Get all the features to be flipped, including features
        to be added and removed.

        Parameters
        ----------
        frac : float, optional
            the fraction of feature perturbations, i.e., how many
            perturbed features are used to construct the perturbed graph.
            by default 1.0

        Example
        -------
        >>> # Get the feature flips
        >>> attacker.feat_flips()

        >>> # Get the feature flips, with
        >>> # specifying frac
        >>> attacker.feat_flips(frac=0.5)
        """
        assert 0 <= frac <= 1
        added = self.added_feats()
        if added is not None:
            added = added[:, :round(added.size(1) * frac)]
        removed = self.removed_feats()
        if removed is not None:
            removed = removed[:, :round(removed.size(1) * frac)]
        _all = cat(added, removed, dim=1)
        return BunchDict(added=added, removed=removed, all=_all)

    @lru_cache(maxsize=1)
    def data(
        self,
        edge_ratio: float = 1.0,
        feat_ratio: float = 1.0,
        coalesce: bool = True,
        symmetric: bool = True,
    ) -> Data:
        """Get the attacked graph denoted by PyG-like data instance.

        Note that this method uses LRU cache for efficiency, the
        computation is only executed at the first call if the input
        parameters were the same. The cache is cleared by :meth:`reset`.

        Parameters
        ----------
        edge_ratio : float, optional
            the fraction of edge perturbations, i.e., how many
            perturbed edges are used to construct the perturbed graph.
            by default 1.0
        feat_ratio : float, optional
            the fraction of feature perturbations, i.e., how many
            perturbed features are used to construct the perturbed graph.
            by default 1.0
        coalesce : bool, optional
            whether to coalesce the output edges.
        symmetric : bool, optional
            whether the output graph is symmetric, by default True

        Example
        -------
        >>> # Get the perturbed graph, including
        >>> # edge flips and feature flips
        >>> attacker.data()

        >>> # Get the perturbed graph, with
        >>> # specifying edge_ratio
        >>> attacker.data(edge_ratio=0.5)

        >>> # Get the perturbed graph, with
        >>> # specifying feat_ratio
        >>> attacker.data(feat_ratio=0.5)

        Returns
        -------
        Data
            the attacked graph denoted by PyG-like data instance
        """
        # Shallow copy so the original data object is never mutated.
        data = copy(self.ori_data)
        edge_index = data.edge_index
        edge_weight = data.edge_weight
        assert edge_weight is None, 'weighted graph is not supported now.'

        edge_flips = self.edge_flips(frac=edge_ratio)
        removed = edge_flips['removed']
        if removed is not None:
            edge_index = remove_edges(edge_index, removed,
                                      symmetric=symmetric)
        added = edge_flips['added']
        if added is not None:
            edge_index = add_edges(edge_index, added, symmetric=symmetric,
                                   coalesce=coalesce)
        data.edge_index = edge_index
        if edge_weight is not None:
            data.edge_weight = edge_weight

        if self.feature_attack:
            feat = self.feat.detach().clone()
            feat_flips = self.feat_flips(frac=feat_ratio)
            removed = feat_flips['removed']
            if removed is not None:
                # Row 0 indexes nodes, row 1 indexes feature dimensions.
                feat[removed[0], removed[1]] = 0.
            added = feat_flips['added']
            if added is not None:
                feat[added[0], added[1]] = 1.
            data.x = feat
        return data

    def set_allow_singleton(self, state: bool):
        """Set whether the attacked graph allow singleton node, i.e.,
        zero degree nodes.

        Parameters
        ----------
        state : bool
            the flag to set

        Example
        -------
        >>> attacker.set_allow_singleton(True)
        """
        self._allow_singleton = state

    def is_singleton_edge(self, u: int, v: int) -> bool:
        """Check if the edge is an singleton edge that, if removed,
        would result in a singleton node in the graph.

        Parameters
        ----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge

        Return
        ------
        bool:
            `True` if the edge is an singleton edge, otherwise `False`.

        Note
        ----
        Please make sure the edge is the one being removed.
        """
        threshold = 1
        # threshold = 2 if the graph has selfloop before
        # otherwise threshold = 1
        if not self._allow_singleton and (self.degree[u] <= threshold
                                          or self.degree[v] <= threshold):
            return True
        return False
def cat(a, b, dim=1):
    """Concatenate two optional tensors along ``dim``.

    If either argument is ``None`` the other one is returned
    unchanged (which may itself be ``None``); otherwise the two
    tensors are concatenated with :func:`torch.cat`.
    """
    if a is None or b is None:
        return a if b is None else b
    return torch.cat([a, b], dim=dim)