# Source code for greatx.datasets.graph_dataset

import os.path as osp
from typing import Callable, List, Optional

import numpy as np
import torch
from sklearn.preprocessing import LabelEncoder
from torch_geometric.data import Data, InMemoryDataset, download_url
from torch_geometric.utils import remove_self_loops, to_undirected


def load_npz(file_name: str) -> Data:
    """Load a single graph stored as an ``.npz`` archive into a ``Data`` object.

    The archive must contain the entries ``adj_matrix``, ``node_attr`` and
    ``node_label``. ``adj_matrix`` is read as a 0-d object array wrapping a
    SciPy sparse matrix; ``node_attr`` may be either a dense array or a 0-d
    object array wrapping a SciPy sparse matrix.

    Parameters
    ----------
    file_name : str
        Path to the ``.npz`` file.

    Returns
    -------
    Data
        A graph with ``x`` (float node attributes), ``edge_index``
        (symmetrized, self-loops removed) and ``y`` (long node labels).

    Raises
    ------
    KeyError
        If one of the required entries is missing from the archive.
    """
    # NOTE: ``allow_pickle=True`` is needed to restore the pickled sparse
    # matrices, but unpickling can execute arbitrary code -- only load
    # archives that come from a trusted source.
    with np.load(file_name, allow_pickle=True) as loader:
        # Materialize every entry so the archive can be closed right away.
        arrays = dict(loader)

    adj_matrix = arrays['adj_matrix'].item()
    # Symmetrize: keep an edge if it exists in either direction.
    adj_matrix = adj_matrix.maximum(adj_matrix.T)
    num_nodes = adj_matrix.shape[0]

    attr_matrix = arrays['node_attr']
    if attr_matrix.dtype.kind == 'O':
        # Object dtype means a wrapped SciPy sparse matrix; densify it.
        # ``toarray()`` is used instead of the deprecated ``.A`` shorthand,
        # which is no longer available in newer SciPy releases.
        attr_matrix = attr_matrix.item().toarray()

    labels = arrays['node_label']
    if labels.shape[0] != num_nodes:
        # Fewer labels than nodes: pad the remaining nodes with -1.
        padding = np.full((num_nodes - labels.shape[0], ), -1)
        labels = np.hstack([labels, padding])

    # Re-encode labels to consecutive integers when they are not already
    # exactly 0..C-1.
    # NOTE(review): when -1 padding is present this branch is always taken and
    # the padded nodes end up as class 0 (real classes shift up by one) --
    # confirm that this is intended.
    if np.unique(labels).shape[0] != labels.max() + 1:
        labels = LabelEncoder().fit_transform(labels)

    x = torch.from_numpy(attr_matrix).to(torch.float)

    adj_matrix = adj_matrix.tocoo()
    row = torch.from_numpy(adj_matrix.row).to(torch.long)
    col = torch.from_numpy(adj_matrix.col).to(torch.long)
    edge_index = torch.stack([row, col], dim=0)
    edge_index, _ = remove_self_loops(edge_index)
    edge_index = to_undirected(edge_index, num_nodes=x.size(0))

    y = torch.from_numpy(labels).to(torch.long)

    return Data(x=x, edge_index=edge_index, y=y)


# Dataset names accepted by ``GraphDataset`` (matched after lower-casing the
# user-supplied name). Each name is also the stem of the ``.npz`` file that
# is downloaded for it. Kept in alphabetical order.
DATASETS = {
    'acm',
    'amazon_cs',
    'amazon_photo',
    'blogcatalog',
    'citeseer',
    'citeseer_full',
    'coauthor_cs',
    'coauthor_phy',
    'cora',
    'cora_full',
    'cora_ml',
    'dblp',
    'flickr',
    'karate_club',
    'pdn',
    'polblogs',
    'pubmed',
    'uai',
}


class GraphDataset(InMemoryDataset):
    r"""A series of datasets used in GreatX.
    These datasets are stored in :obj:`.npz` format, consisting of
    a single graph.

    Parameters
    ----------
    root : str
        Root directory where the dataset should be saved.
    name : str
        The name of the dataset (case-insensitive).
        See :meth:`available_datasets` for all available datasets.
    transform : Optional[Callable], optional
        A function/transform that takes in an
        :obj:`torch_geometric.data.Data` object and returns a transformed
        version. The data object will be transformed before every access,
        by default None
    pre_transform : Optional[Callable], optional
        A function/transform that takes in an
        :obj:`torch_geometric.data.Data` object and returns a
        transformed version. The data object will be transformed before
        being saved to disk, by default None

    Raises
    ------
    ValueError
        If :obj:`name` is not one of the available datasets.

    Example
    -------
    >>> from greatx.datasets.graph_dataset import GraphDataset
    >>> import torch_geometric.transforms as T
    >>> GraphDataset.available_datasets()  # see all available datasets.
    ['acm', 'amazon_cs', 'amazon_photo', ...]
    >>> dataset = GraphDataset(root='.', name='Cora')
    >>> data = dataset[0]  # there is only one graph

    Note
    ----
    We follow the setting in :obj:`Nettack` from the:
    `"Adversarial Attacks on Neural Networks for Graph Data"
    <https://arxiv.org/abs/1805.07984>`_ paper,
    which considers the largest connected component for each graph.

    For more details of these datasets,
    see https://github.com/EdisonLeeeee/GraphData
    """

    # Download URL template; ``{}`` is filled with the lower-cased name.
    url = ('https://github.com/EdisonLeeeee/GraphData/raw/master/'
           'datasets/{}.npz')

    def __init__(self, root: str, name: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        # ``self.name`` has to exist before the base-class constructor runs:
        # ``raw_dir`` / ``processed_dir`` / ``raw_file_names`` all read it.
        self.name = name.lower()
        if self.name not in DATASETS:
            raise ValueError(
                f'Unknown dataset {name}. \nPlease take a look at '
                '`GraphDataset.available_datasets()` for more information.')
        super().__init__(root, transform, pre_transform)
        # NOTE(review): newer PyTorch releases default ``torch.load`` to
        # ``weights_only=True``, which may refuse pickled ``Data`` objects --
        # confirm against the supported torch / torch_geometric versions.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self) -> str:
        """Directory holding the downloaded ``.npz`` file."""
        return osp.join(self.root, f"GreatX-{self.name}", 'raw')

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed ``data.pt`` file."""
        return osp.join(self.root, f"GreatX-{self.name}", 'processed')

    @property
    def raw_file_names(self) -> str:
        """File name of the raw archive inside :attr:`raw_dir`."""
        return f'{self.name}.npz'

    @property
    def processed_file_names(self) -> str:
        """File name of the processed dataset inside :attr:`processed_dir`."""
        return 'data.pt'

    def download(self):
        """Download the raw ``.npz`` archive into :attr:`raw_dir`."""
        download_url(self.url.format(self.name), self.raw_dir)

    def process(self):
        """Convert the raw archive into a collated dataset and save it."""
        data = load_npz(self.raw_paths[0])
        data = data if self.pre_transform is None else self.pre_transform(data)
        data, slices = self.collate([data])
        torch.save((data, slices), self.processed_paths[0])

    @staticmethod
    def available_datasets() -> List[str]:
        """Return the names of all available datasets, sorted alphabetically.
        """
        # ``DATASETS`` is a set, so its iteration order is not stable between
        # interpreter runs; sorting gives callers a reproducible list.
        return sorted(DATASETS)

    def __repr__(self) -> str:
        return f'GreatX-{self.name.capitalize()}'