Source code for greatx.utils.split_data

from typing import Optional, Tuple

import torch
from sklearn.model_selection import train_test_split
from torch import Tensor

from greatx.utils import BunchDict


def split_nodes(labels: Tensor, *, train: float = 0.1, test: float = 0.8,
                val: float = 0.1,
                random_state: Optional[int] = None) -> BunchDict:
    """Randomly split a set of nodes labeled with :obj:`labels`.

    Parameters
    ----------
    labels : torch.Tensor
        the labels of the nodes.
    train : float, optional
        the percentage of the training set, by default 0.1
    test : float, optional
        the percentage of the test set, by default 0.8
    val : float, optional
        the percentage of the validation set, by default 0.1
    random_state : Optional[int], optional
        random seed for the random number generator, by default None

    Returns
    -------
    BunchDict with the following items:

    * train_nodes: torch.Tensor with Size [train * num_nodes]
        The indices of the training nodes
    * val_nodes: torch.Tensor with Size [val * num_nodes]
        The indices of the validation nodes
    * test_nodes: torch.Tensor with Size [test * num_nodes]
        The indices of the test nodes
    """
    val = 0. if val is None else val
    assert train + val + test <= 1.0

    train_nodes, val_nodes, test_nodes = train_val_test_split_tabular(
        labels.shape[0], train=train, val=val, test=test, stratify=labels,
        random_state=random_state)

    return BunchDict(
        dict(train_nodes=train_nodes, val_nodes=val_nodes,
             test_nodes=test_nodes))
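
# Usage sketch (illustrative only, not part of the original module): splitting a
# toy label tensor with the default 10%/10%/80% ratios. `BunchDict` exposes its
# items as attributes, so the split indices can be read back directly. This
# assumes each class appears often enough for stratified splitting to succeed.
#
#     import torch
#
#     labels = torch.randint(0, 4, (100,))            # toy labels, 4 classes
#     splits = split_nodes(labels, random_state=42)
#     train, val, test = splits.train_nodes, splits.val_nodes, splits.test_nodes
#     # sizes follow the ratios: ~10 / ~10 / ~80 nodes
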
def split_nodes_by_classes(labels: torch.Tensor, n_per_class: int = 20,
                           random_state: Optional[int] = None) -> BunchDict:
    """Randomly split the training data by the number of nodes per class.

    Parameters
    ----------
    labels : torch.Tensor [num_nodes]
        The class labels
    n_per_class : int
        Number of samples per class
    random_state : Optional[int]
        Random seed

    Returns
    -------
    BunchDict with the following items:

    * train_nodes: torch.Tensor with Size [n_per_class * num_classes]
        The indices of the training nodes
    * val_nodes: torch.Tensor with Size [n_per_class * num_classes]
        The indices of the validation nodes
    * test_nodes: torch.Tensor with Size [num_nodes - 2 * n_per_class * num_classes]
        The indices of the test nodes
    """
    if random_state is not None:
        torch.manual_seed(random_state)

    num_classes = labels.max() + 1
    split_train, split_val = [], []

    # Draw `n_per_class` training nodes and `n_per_class` validation nodes
    # from a random permutation of each class.
    for c in range(num_classes):
        perm = (labels == c).nonzero().view(-1)
        perm = perm[torch.randperm(perm.size(0))]
        split_train.append(perm[:n_per_class])
        split_val.append(perm[n_per_class:2 * n_per_class])

    split_train = torch.cat(split_train)
    split_train = split_train[torch.randperm(split_train.size(0))]
    split_val = torch.cat(split_val)
    split_val = split_val[torch.randperm(split_val.size(0))]

    assert split_train.size(0) == split_val.size(
        0) == n_per_class * num_classes

    # All remaining nodes form the test set.
    mask = torch.ones_like(labels).bool()
    mask[split_train] = False
    mask[split_val] = False
    split_test = torch.arange(labels.size(0), device=labels.device)[mask]

    return BunchDict(
        dict(train_nodes=split_train, val_nodes=split_val,
             test_nodes=split_test))
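
# Usage sketch (illustrative only, not part of the original module): a per-class
# split of a toy label tensor. With 4 classes and `n_per_class=5` this draws
# 20 training and 20 validation nodes and leaves the rest as test nodes; it
# assumes every class has at least 2 * n_per_class members.
#
#     import torch
#
#     labels = torch.randint(0, 4, (100,))
#     splits = split_nodes_by_classes(labels, n_per_class=5, random_state=42)
#     assert splits.train_nodes.size(0) == splits.val_nodes.size(0) == 20
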
def train_val_test_split_tabular(
        N: int, *, train: float = 0.1, test: float = 0.8, val: float = 0.1,
        stratify: Optional[Tensor] = None,
        random_state: Optional[int] = None) -> Tuple:

    idx = torch.arange(N)
    idx_train, idx_test = train_test_split(idx, random_state=random_state,
                                           train_size=train + val,
                                           test_size=test, stratify=stratify)
    if val:
        if stratify is not None:
            stratify = stratify[idx_train]
        idx_train, idx_val = train_test_split(
            idx_train, random_state=random_state,
            train_size=train / (train + val), stratify=stratify)
    else:
        idx_val = None

    return idx_train, idx_val, idx_test
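
# Usage sketch for the helper above (illustrative only, not part of the original
# module): splitting 10 row indices without stratification. When `val` is 0 the
# second element of the returned tuple is None.
#
#     idx_train, idx_val, idx_test = train_val_test_split_tabular(
#         10, train=0.2, val=0.0, test=0.8, random_state=0)
#     assert idx_val is None and len(idx_train) == 2 and len(idx_test) == 8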