import warnings
from copy import copy
from functools import lru_cache
from typing import Optional
import numpy as np
import torch
from torch import Tensor
from torch_geometric.data import Data
from greatx.attack.attacker import Attacker
from greatx.utils import BunchDict, add_edges, remove_edges
class FlipAttacker(Attacker):
    """Adversarial attacker for graph data by flipping edges
    (and, optionally, binary node features).

    Parameters
    ----------
    data : Data
        PyG-like data denoting the input graph
    device : str, optional
        the device of the attack running on, by default "cpu"
    seed : Optional[int], optional
        the random seed for reproducing the attack, by default None
    name : Optional[str], optional
        name of the attacker, if None, it would be
        :obj:`__class__.__name__`, by default None
    kwargs : additional arguments of :class:`greatx.attack.Attacker`

    Raises
    ------
    TypeError
        unexpected keyword argument in :obj:`kwargs`

    Note
    ----
    :class:`greatx.attack.FlipAttacker` is a base class
    for graph modification attacks (GMA).
    """

    def reset(self) -> "FlipAttacker":
        """Reset attacker. This method must be called before attack.

        Clears the cached result of :meth:`data`, discards every
        recorded edge/feature flip and restores :attr:`degree` from a
        clone of :attr:`_degree`.

        Returns
        -------
        FlipAttacker
            the attacker itself, so calls can be chained
        """
        super().reset()
        self.data.cache_clear()
        # Flips are recorded as {(u, v): iteration}. Dicts keep insertion
        # order, which `edge_flips`/`feat_flips` rely on when they keep
        # only a leading fraction of the perturbations.
        self._removed_edges = {}
        self._added_edges = {}
        self._removed_feats = {}
        self._added_feats = {}
        # Clone so that the degree updates done by add_edge/remove_edge
        # never touch `self._degree`.
        self.degree = self._degree.clone()
        return self

    def remove_edge(self, u: int, v: int, it: Optional[int] = None):
        """Remove an edge from the graph.

        The edge is only recorded here; the perturbed graph is built by
        :meth:`data`. The degrees of ``u`` and ``v`` are decremented.
        If singleton nodes are not allowed and the removal would leave
        ``u`` or ``v`` with no edges, a :class:`UserWarning` is emitted,
        but the edge is still removed.

        Parameters
        ----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge
        it : Optional[int], optional
            The iteration that indicates the order of
            the edge being removed, by default None
        """
        if not self._allow_singleton:
            is_singleton_u = self.degree[u] <= 1
            is_singleton_v = self.degree[v] <= 1
            if is_singleton_u or is_singleton_v:
                # This branch only runs when singletons are disallowed,
                # so the actionable advice is to allow them (if intended)
                # or to fix the calling algorithm.
                warnings.warn(
                    f"You are trying to remove an edge ({u}-{v}) "
                    "that would result in singleton nodes. "
                    "If the behavior is intended, please call "
                    "`attacker.set_allow_singleton(True)`; "
                    "otherwise, check your algorithm.", UserWarning)
        self._removed_edges[(u, v)] = it
        self.degree[u] -= 1
        self.degree[v] -= 1
        # A new flip makes any previously cached perturbed graph stale.
        self.data.cache_clear()

    def add_edge(self, u: int, v: int, it: Optional[int] = None):
        """Add one edge to the graph.

        The edge is only recorded here; the perturbed graph is built by
        :meth:`data`. The degrees of ``u`` and ``v`` are incremented.

        Parameters
        ----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge
        it : Optional[int], optional
            The iteration that indicates the order of
            the edge being added, by default None
        """
        self._added_edges[(u, v)] = it
        self.degree[u] += 1
        self.degree[v] += 1
        # A new flip makes any previously cached perturbed graph stale.
        self.data.cache_clear()

    def _flips_to_tensor(self, flips) -> Optional[Tensor]:
        """Convert recorded flips into an index tensor on :attr:`device`.

        ``flips`` is either a dict keyed by ``(u, v)`` pairs (as filled by
        the ``add_*``/``remove_*`` methods), a sequence of such pairs, or
        an already-built tensor, which is only moved to :attr:`device`.
        Pairs are transposed into a ``(2, num_flips)`` int64 tensor.
        Returns ``None`` when nothing is recorded.
        """
        if flips is None or len(flips) == 0:
            return None
        if torch.is_tensor(flips):
            return flips.to(self.device)
        if isinstance(flips, dict):
            flips = list(flips.keys())
        return torch.tensor(np.asarray(flips, dtype="int64").T,
                            device=self.device)

    @staticmethod
    def _take_fraction(flips: Optional[Tensor],
                       frac: float) -> Optional[Tensor]:
        """Keep the leading ``round(num_flips * frac)`` columns of
        ``flips``, or return ``None`` if ``flips`` is ``None``."""
        if flips is None:
            return None
        return flips[:, :round(flips.size(1) * frac)]

    def removed_edges(self) -> Optional[Tensor]:
        """Get all the edges to be removed.

        Returns
        -------
        Optional[Tensor]
            the removed edges as a ``(2, num_removed)`` index tensor on
            :attr:`device` (row 0: source nodes, row 1: destination
            nodes), or ``None`` if no edge was removed
        """
        return self._flips_to_tensor(self._removed_edges)

    def added_edges(self) -> Optional[Tensor]:
        """Get all the edges to be added.

        Returns
        -------
        Optional[Tensor]
            the added edges as a ``(2, num_added)`` index tensor on
            :attr:`device` (row 0: source nodes, row 1: destination
            nodes), or ``None`` if no edge was added
        """
        return self._flips_to_tensor(self._added_edges)

    def edge_flips(self, frac: float = 1.0) -> BunchDict:
        """Get all the edges to be flipped, including edges
        to be added and removed.

        Parameters
        ----------
        frac : float, optional
            the fraction of edge perturbations, i.e.,
            how many perturbed edges are used to
            construct the perturbed graph.
            by default 1.0

        Returns
        -------
        BunchDict
            with keys ``added``, ``removed`` and ``all`` (the former two
            concatenated along dim 1); each entry may be ``None``

        Example
        -------
        >>> # Get the edge flips
        >>> attacker.edge_flips()
        >>> # Get the edge flips, with
        >>> # specifying frac
        >>> attacker.edge_flips(frac=0.5)
        """
        assert 0 <= frac <= 1, f"`frac` must be in [0, 1], got {frac}."
        added = self._take_fraction(self.added_edges(), frac)
        removed = self._take_fraction(self.removed_edges(), frac)
        _all = cat(added, removed, dim=1)
        return BunchDict(added=added, removed=removed, all=_all)

    def remove_feat(self, u: int, v: int, it: Optional[int] = None):
        """Remove the feature in a dimension `v` from a node `u`.
        That is, set a dimension of the specific node to zero
        (applied when :meth:`data` builds the perturbed graph).

        Parameters
        ----------
        u : int
            the node whose features are to be removed
        v : int
            the dimension of the feature to be removed
        it : Optional[int], optional
            The iteration that indicates the order
            of the features being removed, by default None
        """
        self._removed_feats[(u, v)] = it
        # A new flip makes any previously cached perturbed graph stale.
        self.data.cache_clear()

    def add_feat(self, u: int, v: int, it: Optional[int] = None):
        """Add the feature in a dimension `v` to a node `u`.
        That is, set a dimension of the specific node to one
        (applied when :meth:`data` builds the perturbed graph).

        Parameters
        ----------
        u : int
            the node whose features are to be added
        v : int
            the dimension of the feature to be added
        it : Optional[int], optional
            The iteration that indicates the order
            of the features being added, by default None
        """
        self._added_feats[(u, v)] = it
        # A new flip makes any previously cached perturbed graph stale.
        self.data.cache_clear()

    def removed_feats(self) -> Optional[Tensor]:
        """Get all the features to be removed.

        Returns
        -------
        Optional[Tensor]
            a ``(2, num_removed)`` index tensor on :attr:`device`
            (row 0: nodes, row 1: feature dimensions), or ``None`` if no
            feature was removed
        """
        return self._flips_to_tensor(self._removed_feats)

    def added_feats(self) -> Optional[Tensor]:
        """Get all the features to be added.

        Returns
        -------
        Optional[Tensor]
            a ``(2, num_added)`` index tensor on :attr:`device`
            (row 0: nodes, row 1: feature dimensions), or ``None`` if no
            feature was added
        """
        return self._flips_to_tensor(self._added_feats)

    def feat_flips(self, frac: float = 1.0) -> BunchDict:
        """Get all the features to be flipped, including features
        to be added and removed.

        Parameters
        ----------
        frac : float, optional
            the fraction of feature perturbations, i.e.,
            how many perturbed features are used to
            construct the perturbed graph.
            by default 1.0

        Returns
        -------
        BunchDict
            with keys ``added``, ``removed`` and ``all`` (the former two
            concatenated along dim 1); each entry may be ``None``

        Example
        -------
        >>> # Get the feature flips
        >>> attacker.feat_flips()
        >>> # Get the feature flips, with
        >>> # specifying frac
        >>> attacker.feat_flips(frac=0.5)
        """
        assert 0 <= frac <= 1, f"`frac` must be in [0, 1], got {frac}."
        added = self._take_fraction(self.added_feats(), frac)
        removed = self._take_fraction(self.removed_feats(), frac)
        _all = cat(added, removed, dim=1)
        return BunchDict(added=added, removed=removed, all=_all)

    # NOTE: lru_cache on a method keys on `self` and keeps the cached
    # instance alive; maxsize=1 bounds this to a single entry. The cache
    # is cleared by `reset` and by every add_*/remove_* recorder.
    @lru_cache(maxsize=1)
    def data(
        self,
        edge_ratio: float = 1.0,
        feat_ratio: float = 1.0,
        coalesce: bool = True,
        symmetric: bool = True,
    ) -> Data:
        """Get the attacked graph denoted by
        PyG-like data instance. Note that this method
        uses LRU cache for efficiency, the computation is
        only executed at the first call if the input parameters
        were the same and no flip was recorded in between.
        Repeated calls return the *same* object, so mutating the
        result also mutates the cached value.

        Parameters
        ----------
        edge_ratio : float, optional
            the fraction of edge perturbations, i.e.,
            how many perturbed edges are used to
            construct the perturbed graph.
            by default 1.0
        feat_ratio : float, optional
            the fraction of feature perturbations, i.e.,
            how many perturbed features are used to
            construct the perturbed graph.
            by default 1.0
        coalesce : bool, optional
            whether to coalesce the output edges.
        symmetric : bool, optional
            whether the output graph is symmetric, by default True

        Example
        -------
        >>> # Get the perturbed graph, including
        >>> # edge flips and feature flips
        >>> attacker.data()
        >>> # Get the perturbed graph, with
        >>> # specifying edge_ratio
        >>> attacker.data(edge_ratio=0.5)
        >>> # Get the perturbed graph, with
        >>> # specifying feat_ratio
        >>> attacker.data(feat_ratio=0.5)

        Returns
        -------
        Data
            the attacked graph denoted by PyG-like data instance
        """
        # Shallow copy: attributes re-assigned below do not leak back
        # into `self.ori_data`.
        data = copy(self.ori_data)
        edge_index = data.edge_index
        edge_weight = data.edge_weight
        assert edge_weight is None, 'weighted graph is not supported now.'

        edge_flips = self.edge_flips(frac=edge_ratio)
        removed = edge_flips['removed']
        if removed is not None:
            edge_index = remove_edges(edge_index, removed,
                                      symmetric=symmetric)
        added = edge_flips['added']
        if added is not None:
            edge_index = add_edges(edge_index, added, symmetric=symmetric,
                                   coalesce=coalesce)
        data.edge_index = edge_index

        if self.feature_attack:
            # Clone so the attacker's own feature matrix stays untouched.
            feat = self.feat.detach().clone()
            feat_flips = self.feat_flips(frac=feat_ratio)
            removed = feat_flips['removed']
            if removed is not None:
                feat[removed[0], removed[1]] = 0.
            added = feat_flips['added']
            if added is not None:
                feat[added[0], added[1]] = 1.
            data.x = feat

        return data

    def set_allow_singleton(self, state: bool):
        """Set whether the attacked graph allow singleton node, i.e.,
        zero degree nodes.

        Parameters
        ----------
        state : bool
            the flag to set

        Example
        -------
        >>> attacker.set_allow_singleton(True)
        """
        self._allow_singleton = state

    def is_singleton_edge(self, u: int, v: int) -> bool:
        """Check if the edge is a singleton edge that, if removed,
        would result in a singleton node in the graph.

        Parameters
        -----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge

        Return
        ------
        bool: `True` if the edge is a singleton edge, otherwise `False`.
        Always `False` when singleton nodes are allowed.

        Note
        ----
        Please make sure the edge is the one being removed.
        """
        threshold = 1
        # threshold = 2 if the graph has selfloop before
        # otherwise threshold = 1
        if not self._allow_singleton and (self.degree[u] <= threshold
                                          or self.degree[v] <= threshold):
            return True
        return False

    def is_legal_edge(self, u: int, v: int) -> bool:
        """Check whether the edge (u,v) is legal.
        An edge (u,v) is legal if u!=v and edge (u,v) is
        not selected before.

        Parameters
        -----------
        u : int
            The source node of the edge
        v : int
            The destination node of the edge

        Returns
        -------
        bool: :obj:`True` if the u!=v and edge (u,v), (v,u) is not selected,
        otherwise :obj:`False`.
        """
        _removed_edges = self._removed_edges
        _added_edges = self._added_edges
        # Both orientations are checked, so an undirected edge cannot be
        # flipped twice.
        return all((u != v, (u, v) not in _removed_edges, (v, u)
                    not in _removed_edges, (u, v) not in _added_edges, (v, u)
                    not in _added_edges))
def cat(a, b, dim=1):
    """Concatenate two optional tensors along ``dim``.

    ``None`` stands for "nothing to concatenate": if either operand is
    ``None`` the other one is returned unchanged (so two ``None`` inputs
    yield ``None``). Only when both are present is ``torch.cat`` called.
    """
    if a is not None and b is not None:
        return torch.cat([a, b], dim=dim)
    return b if a is None else a