import copy

import numpy as np
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.utils import coalesce as coalesce_edges
from torch_geometric.utils import (
    from_scipy_sparse_matrix,
    sort_edge_index,
    to_scipy_sparse_matrix,
)
def add_edges(edge_index: Tensor, edges_to_add: Tensor, symmetric: bool = True,
              coalesce: bool = True, sort_edges: bool = True) -> Tensor:
    """Add edges to the graph denoted as :obj:`edge_index`.

    Parameters
    ----------
    edge_index : torch.Tensor
        the graph instance where edges will be added to.
    edges_to_add : torch.Tensor
        shape [2, M], the edges to be added into the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it will also append the reversed edges into the graph.
    coalesce : bool
        whether to coalesce the output edges.
    sort_edges : bool
        whether to sort the output edges.

    Returns
    -------
    Tensor
        the graph instance :obj:`edge_index` with edges added. If
        :obj:`edges_to_add` is empty, :obj:`edge_index` is returned
        unchanged (neither coalesced nor sorted).
    """
    if edges_to_add.size(1) == 0:
        return edge_index
    if symmetric:
        # Append (v, u) for every (u, v) so the result stays symmetric.
        edges_to_add = torch.cat([edges_to_add, edges_to_add.flip(0)], dim=1)
    # Match the dtype and device of ``edge_index`` before concatenating.
    edges_to_add = edges_to_add.to(edge_index)
    edge_index = torch.cat([edge_index, edges_to_add], dim=1)
    if coalesce:
        edge_index = coalesce_edges(edge_index)
    if sort_edges:
        edge_index = sort_edge_index(edge_index)
    return edge_index
def remove_edges(edge_index: Tensor, edges_to_remove: Tensor,
                 symmetric: bool = True) -> Tensor:
    """Remove edges from the graph denoted as :obj:`edge_index`.

    Parameters
    ----------
    edge_index : torch.Tensor
        the graph instance where edges will be removed from.
    edges_to_remove : torch.Tensor
        shape [2, M], the edges to be removed in the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it will also remove the reversed edges from the graph.

    Returns
    -------
    Tensor
        the graph instance :obj:`edge_index` with edges removed, sorted and
        placed on the same device as the input. Because the edges round-trip
        through a scipy sparse adjacency matrix, duplicate edges in
        :obj:`edge_index` are merged into one. If either input is empty,
        :obj:`edge_index` is returned unchanged.
    """
    # Nothing to remove, or nothing to remove from. The second check also
    # guards ``edge_index.max()`` below, which raises on an empty tensor.
    if edges_to_remove.size(1) == 0 or edge_index.size(1) == 0:
        return edge_index
    device = edge_index.device
    if symmetric:
        edges_to_remove = torch.cat(
            [edges_to_remove, edges_to_remove.flip(0)], dim=1)
    edges_to_remove = edges_to_remove.to(edge_index)
    # Size the adjacency to fit every node referenced by either tensor so
    # indexing with ``edges_to_remove`` stays in bounds.
    num_nodes = max(edge_index.max().item(), edges_to_remove.max().item()) + 1
    # LIL format supports element-wise assignment.
    adj_matrix = to_scipy_sparse_matrix(edge_index,
                                        num_nodes=num_nodes).tolil()
    row, col = edges_to_remove.cpu().numpy()
    adj_matrix[row, col] = 0
    adj_matrix = adj_matrix.tocsr()
    # Drop any explicitly stored zeros so they do not come back as edges.
    adj_matrix.eliminate_zeros()
    edge_index, _ = from_scipy_sparse_matrix(adj_matrix)
    edge_index = sort_edge_index(edge_index)
    return edge_index.to(device)
def flip_edges(edge_index: Tensor, edges_to_flip: Tensor,
               symmetric: bool = True) -> Tensor:
    """Flip edges from the graph denoted as :obj:`edge_index`.

    An edge in :obj:`edges_to_flip` that is present in the graph is removed;
    one that is absent is inserted.

    Parameters
    ----------
    edge_index : torch.Tensor
        the graph instance where edges will be flipped from.
    edges_to_flip : torch.Tensor
        shape [2, M], the edges to be flipped in the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it will also flip the reversed edges from the graph.

    Returns
    -------
    Tensor
        the graph instance :obj:`edge_index` with edges flipped, sorted and
        placed on the same device as the input. If :obj:`edges_to_flip` is
        empty, :obj:`edge_index` is returned unchanged.
    """
    if edges_to_flip.size(1) == 0:
        return edge_index
    device = edge_index.device
    if symmetric:
        edges_to_flip = torch.cat(
            [edges_to_flip, edges_to_flip.flip(0)], dim=1)
    edges_to_flip = edges_to_flip.to(edge_index)
    # Take the max over both tensors at once: ``edge_index`` may be empty
    # (flipping edges into an edgeless graph), and ``Tensor.max()`` raises on
    # an empty tensor. ``edges_to_flip`` is non-empty here, so the
    # concatenation is too.
    num_nodes = int(torch.cat([edge_index, edges_to_flip], dim=1).max()) + 1
    adj_matrix = to_scipy_sparse_matrix(edge_index,
                                        num_nodes=num_nodes).tocsr(copy=False)
    row, col = edges_to_flip.cpu().numpy()
    # Current value of each targeted entry (0 when the edge is absent).
    # Flatten to 1-D whether scipy hands back an ``np.matrix`` or an ndarray.
    current = np.asarray(adj_matrix[row, col]).ravel()
    # LIL format supports element-wise assignment.
    adj_matrix = adj_matrix.tolil(copy=True)
    # Existing (positive) entries become 0 (removed); all others become 1.
    adj_matrix[row, col] = np.where(current > 0., 0., 1.)
    adj_matrix = adj_matrix.tocsr(copy=False)
    # Drop the zeroed entries so removed edges do not come back.
    adj_matrix.eliminate_zeros()
    edge_index, _ = from_scipy_sparse_matrix(adj_matrix)
    edge_index = sort_edge_index(edge_index)
    return edge_index.to(device)
def flip_graph(data: Data, edges_to_flip: Tensor,
               symmetric: bool = True) -> Data:
    """Flip edges from the graph denoted as :obj:`data`.

    Parameters
    ----------
    data : Data
        the graph instance where edges will be flipped from.
    edges_to_flip : torch.Tensor
        shape [2, M], the edges to be flipped in the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it will also flip the reversed edges from the graph.

    Returns
    -------
    Data
        a shallow copy (``copy.copy``) of :obj:`data` whose
        :attr:`edge_index` has the edges flipped.

    NOTE
    ----
    We currently don't support a weighted graph and this function will
    automatically set :attr:`edge_weight` and :attr:`adj_t` as :obj:`None`.
    Other per-edge attributes (e.g., :attr:`edge_attr`) are left untouched
    and may no longer line up with the new :attr:`edge_index`.
    """
    # Shallow copy: attributes are reassigned on the copy rather than
    # mutated in place.
    data = copy.copy(data)
    data.edge_index = flip_edges(data.edge_index, edges_to_flip,
                                 symmetric=symmetric)
    # Weights/sparse adjacency describe the old edge set; drop them.
    data.edge_weight = None
    data.adj_t = None
    return data