Source code for greatx.functional.dropouts

from typing import Optional, Tuple

import torch
from torch import Tensor

try:
    import torch_cluster  # noqa
    random_walk = torch.ops.torch_cluster.random_walk
except ImportError:
    random_walk = None

from torch_geometric.utils import degree, sort_edge_index, subgraph
from torch_geometric.utils.num_nodes import maybe_num_nodes


def drop_edge(edge_index: Tensor, edge_weight: Optional[Tensor] = None,
              p: float = 0.5,
              training: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
    r"""DropEdge: Sampling edges using a uniform distribution
    from the `"DropEdge: Towards Deep Graph Convolutional Networks on
    Node Classification" <https://arxiv.org/abs/1907.10903>`_
    paper (ICLR'20)

    Parameters
    ----------
    edge_index : torch.Tensor
        the input edge index
    edge_weight : Optional[Tensor], optional
        the input edge weight, by default None
    p : float, optional
        the probability of dropping each edge, by default 0.5
    training : bool, optional
        whether the model is in training mode; do nothing if
        :obj:`training=False`, by default True

    Returns
    -------
    Tuple[Tensor, Optional[Tensor]]
        the output edge index and edge weight

    Raises
    ------
    ValueError
        if :obj:`p` is out of the range [0, 1]

    Example
    -------
    .. code-block:: python

        from greatx.functional import drop_edge
        edge_index = torch.LongTensor([[1, 2], [3, 4]])
        drop_edge(edge_index, p=0.5)

    See also
    --------
    :class:`greatx.nn.layers.DropEdge`
    """
    if p < 0. or p > 1.:
        raise ValueError(f'Dropout probability has to be between 0 and 1 '
                         f'(got {p})')
    if not training or not p:
        return edge_index, edge_weight

    num_edges = edge_index.size(1)
    e_ids = torch.arange(num_edges, dtype=torch.long,
                         device=edge_index.device)
    # draw a Bernoulli mask; an edge is dropped with probability p
    mask = torch.full_like(e_ids, p, dtype=torch.float32)
    mask = torch.bernoulli(mask).to(torch.bool)
    edge_index = edge_index[:, ~mask]
    if edge_weight is not None:
        edge_weight = edge_weight[~mask]
    return edge_index, edge_weight
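
# A minimal usage sketch for ``drop_edge`` on a toy graph. The helper name
# ``_example_drop_edge`` and the 4-node cycle below are illustrative
# assumptions, not part of the public API.
def _example_drop_edge() -> Tuple[Tensor, Optional[Tensor]]:
    # a directed 4-node cycle: 0 -> 1 -> 2 -> 3 -> 0
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])
    edge_weight = torch.ones(edge_index.size(1))
    # each edge is dropped independently with probability p=0.5;
    # the surviving edges keep their corresponding weights
    return drop_edge(edge_index, edge_weight, p=0.5)
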
def drop_node(edge_index: Tensor, edge_weight: Optional[Tensor] = None,
              p: float = 0.5, training: bool = True,
              num_nodes: Optional[int] = None
              ) -> Tuple[Tensor, Optional[Tensor]]:
    """DropNode: Sampling nodes using a uniform distribution
    from the `"Graph Contrastive Learning with Augmentations"
    <https://arxiv.org/abs/2010.13902>`_ paper (NeurIPS'20)

    Parameters
    ----------
    edge_index : torch.Tensor
        the input edge index
    edge_weight : Optional[Tensor], optional
        the input edge weight, by default None
    p : float, optional
        the probability of dropping each node, by default 0.5
    training : bool, optional
        whether the model is in training mode; do nothing if
        :obj:`training=False`, by default True
    num_nodes : Optional[int], optional
        number of total nodes in the graph, by default None

    Returns
    -------
    Tuple[Tensor, Optional[Tensor]]
        the output edge index and edge weight

    Raises
    ------
    ValueError
        if :obj:`p` is out of the range [0, 1]

    Example
    -------
    .. code-block:: python

        from greatx.functional import drop_node
        edge_index = torch.LongTensor([[1, 2], [3, 4]])
        drop_node(edge_index, p=0.5)

    See also
    --------
    :class:`greatx.nn.layers.DropNode`
    """
    if p < 0. or p > 1.:
        raise ValueError(f'Dropout probability has to be between 0 and 1 '
                         f'(got {p})')
    if not training or not p:
        return edge_index, edge_weight

    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    nodes = torch.arange(num_nodes, dtype=torch.long,
                         device=edge_index.device)
    # draw a Bernoulli mask; a node is kept with probability 1 - p
    mask = torch.full_like(nodes, 1 - p, dtype=torch.float32)
    mask = torch.bernoulli(mask).to(torch.bool)
    subset = nodes[mask]
    # keep only the edges whose endpoints both survive
    return subgraph(subset, edge_index, edge_weight)
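
# A minimal usage sketch for ``drop_node`` on the same toy graph. The helper
# name ``_example_drop_node`` and the graph below are illustrative
# assumptions.
def _example_drop_node() -> Tuple[Tensor, Optional[Tensor]]:
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])
    edge_weight = torch.ones(edge_index.size(1))
    # each node is kept with probability 1 - p; every edge incident to a
    # dropped node is removed via `torch_geometric.utils.subgraph`
    return drop_node(edge_index, edge_weight, p=0.5, num_nodes=4)
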
def drop_path(edge_index: Tensor, edge_weight: Optional[Tensor] = None,
              p: float = 0.5, walks_per_node: int = 1, walk_length: int = 3,
              num_nodes: Optional[int] = None, start: str = 'node',
              is_sorted: bool = False,
              training: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
    """DropPath: a structured form of :class:`greatx.functional.drop_edge`
    from the `"MaskGAE: Masked Graph Modeling Meets Graph Autoencoders"
    <https://arxiv.org/abs/2205.10053>`_ paper (arXiv'22)

    Parameters
    ----------
    edge_index : torch.Tensor
        the input edge index
    edge_weight : Optional[Tensor], optional
        the input edge weight, by default None
    p : Optional[Union[float, Tensor]], optional
        if :obj:`p` is a float value - the percentage of nodes in the
        graph that are chosen as root nodes to perform random walks;
        if :obj:`p` is a :class:`torch.Tensor` - a set of custom root
        nodes. By default, :obj:`p=0.5`.
    walks_per_node : int, optional
        number of walks per node, by default 1
    walk_length : int, optional
        the length of each random walk, by default 3
    num_nodes : Optional[int], optional
        number of total nodes in the graph, by default None
    start : str, optional
        the type of starting point of the walks, either
        :obj:`'node'` or :obj:`'edge'`, by default 'node'
    is_sorted : bool, optional
        whether the input :obj:`edge_index` is sorted, by default False
    training : bool, optional
        whether the model is in training mode; do nothing if
        :obj:`training=False`, by default True

    Returns
    -------
    Tuple[Tensor, Optional[Tensor]]
        the output edge index and edge weight

    Raises
    ------
    ImportError
        if :class:`torch_cluster` is not installed.
    ValueError
        if :obj:`p` is out of the range [0, 1]
    ValueError
        if :obj:`p` is neither a float value nor a :class:`torch.Tensor`

    Example
    -------
    .. code-block:: python

        from greatx.functional import drop_path

        edge_index = torch.LongTensor([[1, 2], [3, 4]])
        drop_path(edge_index, p=0.5)

        drop_path(edge_index, p=torch.tensor([1, 2]))  # specify root nodes

    See also
    --------
    :class:`greatx.nn.layers.DropPath`
    """
    if random_walk is None:
        raise ImportError("`torch_cluster` is not installed.")

    if not training:
        return edge_index, edge_weight

    if p < 0. or p > 1.:
        raise ValueError(f'Sample probability has to be between 0 and 1 '
                         f'(got {p})')

    assert start in ['node', 'edge']

    num_edges = edge_index.size(1)
    edge_mask = edge_index.new_ones(num_edges, dtype=torch.bool)

    if p == 0.0:
        return edge_index, edge_weight

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if not is_sorted:
        edge_index = sort_edge_index(edge_index, edge_weight,
                                     num_nodes=num_nodes)
        if edge_weight is not None:
            # `sort_edge_index` returns an (edge_index, edge_weight) tuple
            # when an edge attribute is passed in
            edge_index, edge_weight = edge_index

    row, col = edge_index
    if start == 'edge':
        # use the source nodes of a random subset of edges as roots
        sample_mask = torch.rand(row.size(0), device=edge_index.device) <= p
        start = row[sample_mask].repeat(walks_per_node)
    else:
        # use a random subset of nodes as roots
        start = torch.randperm(
            num_nodes,
            device=edge_index.device)[:round(num_nodes * p)].repeat(
                walks_per_node)

    # build a CSR-style row pointer from node degrees for `random_walk`
    deg = degree(row, num_nodes=num_nodes)
    rowptr = row.new_zeros(num_nodes + 1)
    torch.cumsum(deg, 0, out=rowptr[1:])
    n_id, e_id = random_walk(rowptr, col, start, walk_length, 1.0, 1.0)
    e_id = e_id[e_id != -1].view(-1)  # filter illegal edges
    edge_mask[e_id] = False

    if edge_weight is not None:
        edge_weight = edge_weight[edge_mask]

    return edge_index[:, edge_mask], edge_weight
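
# A minimal usage sketch for ``drop_path``; it requires `torch-cluster` to be
# installed. The helper name ``_example_drop_path`` and the toy graph below
# are illustrative assumptions.
def _example_drop_path() -> Tuple[Tensor, Optional[Tensor]]:
    edge_index = torch.tensor([[0, 1, 2, 3],
                               [1, 2, 3, 0]])
    # roughly half of the nodes serve as roots (start='node'); every edge
    # visited by a length-3 random walk from a root is dropped
    return drop_path(edge_index, p=0.5, walk_length=3, start='node')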