from typing import Optional, Tuple, Union

import torch
from torch import Tensor
# ``torch_cluster`` is an optional dependency; only ``drop_path`` needs it.
try:
    import torch_cluster  # noqa
    random_walk = torch.ops.torch_cluster.random_walk
except ImportError:
    # Bind BOTH names so availability checks elsewhere
    # (``torch_cluster is None`` / ``random_walk is None``) work instead of
    # failing with a NameError.
    torch_cluster = None
    random_walk = None
from torch_geometric.utils import degree, sort_edge_index, subgraph
from torch_geometric.utils.num_nodes import maybe_num_nodes
def drop_edge(edge_index: Tensor, edge_weight: Optional[Tensor] = None,
              p: float = 0.5,
              training: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
    r"""DropEdge: randomly remove edges, each one independently with
    probability :obj:`p`, from the `"DropEdge: Towards Deep Graph
    Convolutional Networks on Node Classification"
    <https://arxiv.org/abs/1907.10903>`_ paper (ICLR'20)

    Parameters
    ----------
    edge_index : torch.Tensor
        the input edge index (edges are indexed along dimension 1)
    edge_weight : Optional[Tensor], optional
        the input edge weight, one entry per edge, by default None
    p : float, optional
        the probability of dropping each edge, by default 0.5
    training : bool, optional
        whether the model is in training mode; the inputs are
        returned unchanged if :obj:`training=False`, by default True

    Returns
    -------
    Tuple[Tensor, Optional[Tensor]]
        the output edge index and edge weight (:obj:`None` if no edge
        weight was given)

    Raises
    ------
    ValueError
        if :obj:`p` is out of range [0, 1]

    Example
    -------
    .. code-block:: python

        import torch
        from greatx.functional import drop_edge
        edge_index = torch.LongTensor([[1, 2], [3, 4]])
        drop_edge(edge_index, p=0.5)

    See also
    --------
    :class:`greatx.nn.layers.DropEdge`
    """
    if p < 0. or p > 1.:
        raise ValueError(f'Dropout probability has to be between 0 and 1 '
                         f'(got {p})')

    # Nothing to drop: hand back the very same input objects.
    if not training or not p:
        return edge_index, edge_weight

    num_edges = edge_index.size(1)
    # One Bernoulli(p) draw per edge; True marks an edge to drop.
    drop_mask = torch.bernoulli(
        torch.full((num_edges, ), p, dtype=torch.float32,
                   device=edge_index.device)).to(torch.bool)
    keep_mask = ~drop_mask

    edge_index = edge_index[:, keep_mask]
    if edge_weight is not None:
        edge_weight = edge_weight[keep_mask]
    return edge_index, edge_weight
def drop_node(
        edge_index: Tensor, edge_weight: Optional[Tensor] = None,
        p: float = 0.5, training: bool = True,
        num_nodes: Optional[int] = None) -> Tuple[Tensor, Optional[Tensor]]:
    """DropNode: randomly remove nodes, each one independently with
    probability :obj:`p`, together with all of their incident edges,
    from the `"Graph Contrastive Learning with Augmentations"
    <https://arxiv.org/abs/2010.13902>`_ paper (NeurIPS'20)

    Parameters
    ----------
    edge_index : torch.Tensor
        the input edge index
    edge_weight : Optional[Tensor], optional
        the input edge weight, by default None
    p : float, optional
        the probability of dropping each node, by default 0.5
    training : bool, optional
        whether the model is in training mode; the inputs are
        returned unchanged if :obj:`training=False`, by default True
    num_nodes : Optional[int], optional
        number of nodes in the graph; inferred from :obj:`edge_index`
        if not given, by default None

    Returns
    -------
    Tuple[Tensor, Optional[Tensor]]
        the output edge index and edge weight

    Raises
    ------
    ValueError
        if :obj:`p` is out of range [0, 1]

    Example
    -------
    .. code-block:: python

        import torch
        from greatx.functional import drop_node
        edge_index = torch.LongTensor([[1, 2], [3, 4]])
        drop_node(edge_index, p=0.5)

    See also
    --------
    :class:`greatx.nn.layers.DropNode`
    """
    if p < 0. or p > 1.:
        raise ValueError(f'Dropout probability has to be between 0 and 1 '
                         f'(got {p})')

    # Nothing to drop: hand back the very same input objects.
    if not training or not p:
        return edge_index, edge_weight

    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    nodes = torch.arange(num_nodes, dtype=torch.long,
                         device=edge_index.device)
    # One Bernoulli(1 - p) draw per node; True marks a node to keep.
    keep_mask = torch.bernoulli(
        torch.full((num_nodes, ), 1 - p, dtype=torch.float32,
                   device=edge_index.device)).to(torch.bool)
    subset = nodes[keep_mask]
    # Forward ``num_nodes`` explicitly: a caller-supplied value may exceed
    # ``edge_index.max() + 1`` (isolated nodes), and kept ids can then lie
    # beyond the node count ``subgraph`` would infer from ``edge_index``.
    return subgraph(subset, edge_index, edge_weight, num_nodes=num_nodes)
def drop_path(edge_index: Tensor, edge_weight: Optional[Tensor] = None,
              p: Union[float, Tensor] = 0.5, walks_per_node: int = 1,
              walk_length: int = 3, num_nodes: Optional[int] = None,
              start: str = 'node', is_sorted: bool = False,
              training: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
    """DropPath: a structured form of :class:`greatx.functional.drop_edge`
    that removes every edge traversed by random walks, from the
    `"MaskGAE: Masked Graph Modeling Meets
    Graph Autoencoders" <https://arxiv.org/abs/2205.10053>`_
    paper (arXiv'22)

    Parameters
    ----------
    edge_index : torch.Tensor
        the input edge index
    edge_weight : Optional[Tensor], optional
        the input edge weight, by default None
    p : Union[float, Tensor], optional
        If :obj:`p` is a float value - the sampling ratio of root
        nodes used to start random walks (see :obj:`start`).
        If :obj:`p` is an integer :class:`torch.Tensor` - a set of
        custom root nodes; :obj:`start` is then ignored.
        By default, :obj:`p=0.5`.
    walks_per_node : int, optional
        number of walks started from each root node, by default 1
    walk_length : int, optional
        length of each random walk, by default 3
    num_nodes : int, optional
        number of total nodes in the graph, by default None
    start : string, optional
        how root nodes are sampled when :obj:`p` is a float:
        ``'node'`` picks ``round(num_nodes * p)`` distinct nodes,
        ``'edge'`` takes the source node of each edge with
        probability :obj:`p`; by default 'node'
    is_sorted : bool, optional
        whether the input :obj:`edge_index` is already sorted, so the
        sorting step can be skipped, by default False
    training : bool, optional
        whether the model is in training mode; the inputs are
        returned unchanged if :obj:`training=False`, by default True

    Returns
    -------
    Tuple[Tensor, Optional[Tensor]]
        the output edge index and edge weight. When edges are actually
        dropped and :obj:`is_sorted=False`, the surviving edges are
        returned in sorted order.

    Raises
    ------
    ImportError
        if :class:`torch_cluster` is not installed and random walks
        have to be performed.
    ValueError
        if :obj:`p` is a float out of range [0, 1], a floating-point
        Tensor, or neither a number nor a Tensor
    ValueError
        if :obj:`start` is neither ``'node'`` nor ``'edge'``

    Example
    -------
    .. code-block:: python

        import torch
        from greatx.functional import drop_path
        edge_index = torch.LongTensor([[1, 2], [3, 4]])
        drop_path(edge_index, p=0.5)
        drop_path(edge_index, p=torch.tensor([1, 2]))  # root nodes

    See also
    --------
    :class:`greatx.nn.layers.DropPath`
    """
    if not training:
        return edge_index, edge_weight

    device = edge_index.device
    ratio = 0.0
    if isinstance(p, Tensor):
        # Caller-supplied root nodes.
        if p.is_floating_point():
            raise ValueError('Root nodes passed as `p` must be an integer '
                             f'tensor of node indices (got {p.dtype})')
        root_nodes = p.to(device=device, dtype=torch.long).view(-1)
        nothing_to_drop = root_nodes.numel() == 0
    else:
        try:
            ratio = float(p)
        except (TypeError, ValueError):
            raise ValueError('`p` must be a float in [0, 1] or a Tensor of '
                             f'root nodes (got {type(p).__name__})') from None
        if ratio < 0. or ratio > 1.:
            raise ValueError(f'Sample probability has to be between 0 and 1 '
                             f'(got {p})')
        root_nodes = None
        nothing_to_drop = ratio == 0.0

    if start not in ('node', 'edge'):
        raise ValueError(f"`start` must be 'node' or 'edge' (got {start!r})")

    # Keep the (edge_index, edge_weight) contract on the no-op path.
    if nothing_to_drop:
        return edge_index, edge_weight

    if random_walk is None:
        raise ImportError('`drop_path` requires `torch_cluster`, '
                          'which is not installed.')

    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    if not is_sorted:
        # Only forward ``edge_weight`` when present so the weightless call
        # returns a bare edge index. NOTE(review): passing ``None``
        # explicitly yields an ``(edge_index, None)`` tuple in some PyG
        # versions -- verify against the pinned PyG version.
        if edge_weight is None:
            edge_index = sort_edge_index(edge_index, num_nodes=num_nodes)
        else:
            edge_index, edge_weight = sort_edge_index(
                edge_index, edge_weight, num_nodes=num_nodes)

    row, col = edge_index
    if root_nodes is not None:
        start_nodes = root_nodes.repeat(walks_per_node)
    elif start == 'edge':
        # Each edge's source node becomes a root with probability p.
        sample_mask = torch.rand(row.size(0), device=device) <= ratio
        start_nodes = row[sample_mask].repeat(walks_per_node)
    else:
        # round(num_nodes * p) distinct nodes chosen uniformly at random.
        num_roots = round(num_nodes * ratio)
        start_nodes = torch.randperm(num_nodes, device=device)[:num_roots]
        start_nodes = start_nodes.repeat(walks_per_node)

    # CSR row pointer built from the per-row degree of the sorted graph.
    deg = degree(row, num_nodes=num_nodes, dtype=torch.long)
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])

    # The trailing 1.0, 1.0 are presumably node2vec's p/q (unbiased
    # walks) -- see torch_cluster's random_walk docs.
    _, e_id = random_walk(rowptr, col, start_nodes, walk_length, 1.0, 1.0)
    e_id = e_id[e_id != -1].view(-1)  # filter illegal edges

    edge_mask = torch.ones(edge_index.size(1), dtype=torch.bool,
                           device=device)
    edge_mask[e_id] = False
    if edge_weight is not None:
        edge_weight = edge_weight[edge_mask]
    return edge_index[:, edge_mask], edge_weight