# Source code for greatx.utils.ego_graph

from collections import namedtuple
from functools import lru_cache
from typing import Union

import numpy as np
import scipy.sparse as sp

# Result container returned by :func:`ego_graph`: ``nodes`` holds the node ids
# of the subgraph and ``edges`` its ``[2, M]`` edge index.
ego_graph_nodes_edges = namedtuple('ego_graph', ['nodes', 'edges'])

__all__ = ['ego_graph']


def ego_graph(adj_matrix: sp.csr_matrix, targets: Union[int, list],
              hops: int = 1) -> ego_graph_nodes_edges:
    """Return the subgraph of all nodes within ``hops`` steps of ``targets``.

    A breadth-first search is started from the target nodes and expanded
    ``hops`` times. Edges met during the expansion are collected on the fly;
    edges connecting two nodes of the outermost level are collected afterwards
    by the compiled helper returned by ``get_numbafn``.

    Parameters
    ----------
    adj_matrix : sp.csr_matrix
        A Scipy sparse adjacency matrix representing a graph. Any sparse
        format is accepted; it is converted to CSR without copying when it
        already is CSR.
    targets : Union[int, list]
        Center nodes: a single node, a list of nodes or a NumPy array of
        nodes. The caller's object is never modified.
    hops : int, optional
        Include all neighbors of distance <= ``hops`` from ``targets``,
        by default 1.

    Returns
    -------
    NamedTuple(nodes, edges)
        nodes : shape [N]
            The target nodes followed by the discovered nodes in
            breadth-first order.
        edges : shape [2, M]
            The edges of the subgraph. An edge is skipped when its reverse
            direction has already been recorded. When no edge is found an
            empty integer array of shape ``(2, 0)`` is returned.

    Raises
    ------
    AssertionError
        If ``adj_matrix`` is not a Scipy sparse matrix.

    Note
    ----
    This is a faster implementation of :class:`networkx.ego_graph`
    based on scipy sparse matrix and numba.

    See Also
    --------
    :class:`networkx.ego_graph`
    :class:`torch_geometric.utils.k_hop_subgraph`
    """
    assert sp.issparse(adj_matrix), "adj_matrix must be a scipy sparse matrix"
    adj_matrix = adj_matrix.tocsr(copy=False)

    # Normalise the seeds into a fresh Python list: it doubles as the BFS
    # queue and grows while the search runs.
    if np.ndim(targets) == 0:
        targets = [targets]
    elif isinstance(targets, np.ndarray):
        targets = targets.tolist()
    else:
        targets = list(targets)

    indices = adj_matrix.indices
    indptr = adj_matrix.indptr

    # Insertion-ordered mapping (head, tail) -> level at which it was found.
    edges = {}
    num_nodes = adj_matrix.shape[0]
    # seen[i] is the BFS level of node i, or -1 while it is still unvisited.
    seen = np.full(num_nodes, -1, dtype=np.int64)
    seen[targets] = 0

    start = 0
    for level in range(hops):
        # Only expand the nodes that were queued before this level started.
        end = len(targets)
        while start < end:
            head = targets[start]
            for u in indices[indptr[head]:indptr[head + 1]]:
                if seen[u] < 0:
                    targets.append(u)
                    seen[u] = level + 1
                # Record the edge unless its reverse is already stored.
                if (u, head) not in edges:
                    edges[(head, u)] = level + 1
            start += 1

    # Nodes that were queued but never expanded form the outermost level;
    # the edges among them have not been visited by the loop above.
    frontier = targets[start:]
    if frontier:
        # Fetch the compiled helper lazily so that numba is only required
        # when there actually is a last level to process.
        fn = get_numbafn()
        remaining = list(fn(indices, indptr, np.array(frontier), seen, hops))
    else:
        remaining = []

    edge_list = list(edges) + remaining
    if edge_list:
        edge_index = np.asarray(edge_list).T
    else:
        # Keep the documented [2, M] layout even when there are no edges.
        edge_index = np.empty((2, 0), dtype=np.int64)

    return ego_graph_nodes_edges(nodes=np.asarray(targets), edges=edge_index)
def get_numbafn(): from numba import njit, types from numba.typed import Dict @njit def _get_remaining_edges(indices: np.ndarray, indptr: np.ndarray, last_level: np.ndarray, seen: np.ndarray, hops: int) -> list: edges = [] mapping = Dict.empty( key_type=types.int64, value_type=types.int64, ) for u in last_level: nbrs = indices[indptr[u]:indptr[u + 1]] nbrs = nbrs[seen[nbrs] == hops] mapping[u] = 1 for v in nbrs: if v not in mapping: edges.append((u, v)) return edges return _get_remaining_edges