"""Edge-modification utilities for graphs (greatx.utils.modification)."""

import copy

import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.utils import coalesce as coalesce_edges
from torch_geometric.utils import (
    from_scipy_sparse_matrix,
    sort_edge_index,
    to_scipy_sparse_matrix,
)


def add_edges(edge_index: Tensor, edges_to_add: Tensor,
              symmetric: bool = True, coalesce: bool = True,
              sort_edges: bool = True) -> Tensor:
    """Add edges to the graph denoted as :obj:`edge_index`.

    Parameters
    ----------
    edge_index : torch.Tensor
        the graph instance where edges will be added to.
    edges_to_add : torch.Tensor
        shape [2, M], the edges to be added into the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it will also append the reversed edges into the graph.
    coalesce : bool
        whether to coalesce the output edges.
    sort_edges : bool
        whether to sort the output edges.

    Returns
    -------
    Tensor
        the graph instance :obj:`edge_index` with edges added.

    Note
    ----
    If :obj:`edges_to_add` is empty, :obj:`edge_index` is returned
    unchanged, i.e., it is neither coalesced nor sorted.
    """
    if edges_to_add.size(1) == 0:
        return edge_index

    if symmetric:
        edges_to_add = torch.cat([edges_to_add, edges_to_add.flip(0)], dim=1)

    # Match dtype and device of the existing edges so they can be concatenated.
    edges_to_add = edges_to_add.to(edge_index)
    edge_index = torch.cat([edge_index, edges_to_add], dim=1)

    # `coalesce_edges` is the aliased PyG `coalesce`; the alias avoids a clash
    # with the boolean `coalesce` parameter.
    if coalesce:
        edge_index = coalesce_edges(edge_index)

    if sort_edges:
        edge_index = sort_edge_index(edge_index)

    return edge_index
def remove_edges(edge_index: Tensor, edges_to_remove: Tensor,
                 symmetric: bool = True) -> Tensor:
    """Remove edges from the graph denoted as :obj:`edge_index`.

    Parameters
    ----------
    edge_index : torch.Tensor
        the graph instance where edges will be removed from.
    edges_to_remove : torch.Tensor
        shape [2, M], the edges to be removed from the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it will also remove the reversed edges from the graph.

    Returns
    -------
    Tensor
        the graph instance :obj:`edge_index` with edges removed,
        sorted and placed on the same device as the input. If there
        is nothing to remove (or nothing to remove from), the input
        :obj:`edge_index` is returned unchanged.
    """
    # Nothing to remove, or an empty graph to remove from. The second check
    # also prevents `edge_index.max()` below from raising on an empty tensor.
    if edges_to_remove.size(1) == 0 or edge_index.size(1) == 0:
        return edge_index

    device = edge_index.device
    if symmetric:
        edges_to_remove = torch.cat(
            [edges_to_remove, edges_to_remove.flip(0)], dim=1)
    edges_to_remove = edges_to_remove.to(edge_index)

    # The adjacency matrix must be large enough to index every node that
    # appears in either tensor (removal targets may reference unseen nodes).
    num_nodes = max(edge_index.max().item(),
                    edges_to_remove.max().item()) + 1

    # COO -> CSR merges duplicate entries; LIL allows cheap element writes.
    adj_matrix = to_scipy_sparse_matrix(
        edge_index, num_nodes=num_nodes).tocsr().tolil()
    row, col = edges_to_remove.cpu().numpy()
    adj_matrix[(row, col)] = 0

    # Zeroed entries are still stored explicitly; drop them before
    # converting back to an edge list.
    adj_matrix = adj_matrix.tocsr()
    adj_matrix.eliminate_zeros()

    edge_index, _ = from_scipy_sparse_matrix(adj_matrix)
    edge_index = sort_edge_index(edge_index)
    return edge_index.to(device)
def flip_edges(edge_index: Tensor, edges_to_flip: Tensor,
               symmetric: bool = True) -> Tensor:
    """Flip edges from the graph denoted as :obj:`edge_index`.

    Each listed edge is removed if it is present and added if it is
    absent. Repeated entries in :obj:`edges_to_flip` flip an edge once.

    Parameters
    ----------
    edge_index : torch.Tensor
        the graph instance where edges will be flipped from.
        May be empty, in which case all edges are added.
    edges_to_flip : torch.Tensor
        shape [2, M], the edges to be flipped in the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it will also flip the reversed edges from the graph.

    Returns
    -------
    Tensor
        the graph instance :obj:`edge_index` with edges flipped,
        sorted and placed on the same device as the input.
    """
    if edges_to_flip.size(1) == 0:
        return edge_index

    device = edge_index.device
    if symmetric:
        edges_to_flip = torch.cat(
            [edges_to_flip, edges_to_flip.flip(0)], dim=1)
    edges_to_flip = edges_to_flip.to(edge_index)

    # Flipping on an empty graph is valid (every edge is added), but
    # `max()` raises on an empty tensor, so only consult `edge_index`
    # when it actually holds edges.
    num_nodes = edges_to_flip.max().item() + 1
    if edge_index.numel() > 0:
        num_nodes = max(num_nodes, edge_index.max().item() + 1)

    # COO -> CSR merges duplicate edges (their values are summed).
    adj_matrix = to_scipy_sparse_matrix(edge_index,
                                        num_nodes=num_nodes).tocsr()
    row, col = edges_to_flip.cpu().numpy()

    # Normalize the current state of each requested entry to 1 (present)
    # or 0 (absent); summed duplicates may exceed 1.
    present = adj_matrix[(row, col)].A
    present[present > 0.] = 1.
    present[present < 0.] = 0.

    adj_matrix = adj_matrix.tolil(copy=True)
    adj_matrix[(row, col)] = 1. - present

    # Removed edges remain as explicit zeros; drop them.
    adj_matrix = adj_matrix.tocsr()
    adj_matrix.eliminate_zeros()

    edge_index, _ = from_scipy_sparse_matrix(adj_matrix)
    edge_index = sort_edge_index(edge_index)
    return edge_index.to(device)
def flip_graph(data: Data, edges_to_flip: Tensor,
               symmetric: bool = True) -> Data:
    """Flip edges from the graph denoted as :obj:`data`.

    Parameters
    ----------
    data : Data
        the graph instance where edges will be flipped from.
    edges_to_flip : torch.Tensor
        shape [2, M], the edges to be flipped in the graph.
    symmetric : bool
        whether the output graph is symmetric, if True,
        it will also flip the reversed edges from the graph.

    Returns
    -------
    Data
        a shallow copy (:func:`copy.copy`) of :obj:`data` with
        edges flipped.

    Note
    ----
    We currently don't support a weighted graph and this function will
    automatically set :attr:`edge_weight` and :attr:`adj_t` as :obj:`None`.
    """
    # Work on a shallow copy; the attribute reassignments below are meant
    # to leave the caller's object untouched (tensors themselves are shared).
    data = copy.copy(data)
    data.edge_index = flip_edges(data.edge_index, edges_to_flip,
                                 symmetric=symmetric)
    # The previous per-edge weights and cached sparse adjacency no longer
    # correspond to the new `edge_index`, so discard them.
    data.edge_weight = None
    data.adj_t = None
    return data