greatx.training

Trainer

A simple trainer to train graph neural network models conveniently.

UnspuervisedTrainer

Custom trainer for unsupervised models, similar to greatx.training.Trainer but using only the unsupervised loss defined in the model.loss() method.

get_trainer

Get the default trainer for a model, given either as a str (the model name) or as a model in greatx.nn.models.

callbacks

SATTrainer

Custom trainer for greatx.nn.models.supervised.SAT

class Trainer(model: Module, device: Union[str, device] = 'cpu', **cfg)[source]

A simple trainer to train graph neural network models conveniently.

Parameters:
  • model (nn.Module) – the model used for training

  • device (Union[str, torch.device], optional) – the device used for training, by default 'cpu'

  • cfg – other keyword arguments, such as lr and weight_decay.

Example

>>> from greatx.training import Trainer
>>> model = ... # your model
>>> trainer = Trainer(model, device='cuda')
>>> data # PyG-like data, e.g., Cora
Data(x=[2485, 1433], edge_index=[2, 10138], y=[2485])
>>> # simple training
>>> trainer.fit(data, data.train_mask)
>>> # train with model selection (checkpoint on val_acc)
>>> from greatx.training import ModelCheckpoint
>>> cb = ModelCheckpoint('my_ckpt', monitor='val_acc')
>>> trainer.fit(data, mask=(data.train_mask,
...                         data.val_mask), callbacks=[cb])
>>> # get training logs
>>> history = trainer.model.history
>>> trainer.evaluate(data, your_test_mask) # evaluation
>>> predict = trainer.predict(data, your_mask) # prediction
supervised = True
fit(data: Union[Data, Tuple[Data, Data]], mask: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, callbacks: Optional[Callback] = None, verbose: Optional[int] = 1, epochs: int = 100, prefix: str = 'val') Trainer[source]

Simple training method designed for the model.

Parameters:
  • data (Union[Data, Tuple[Data, Data]]) – An instance or a tuple of torch_geometric.data.Data denoting the graph. They are used for train_step and val_step, respectively.

  • mask (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]) – node masks used for training and validation.

  • callbacks (Optional[Callback], optional) – callbacks used for training, see greatx.training.callbacks, by default None

  • verbose (Optional[int], optional) – verbosity during training, can be: None, 1, 2, 3, 4, by default 1

  • epochs (int, optional) – training epochs, by default 100

  • prefix (str, optional) – prefix for validation metrics, by default 'val'

Example

>>> # simple training
>>> trainer.fit(data, data.train_mask)
>>> # train with model selection (checkpoint on val_acc)
>>> from greatx.training import ModelCheckpoint
>>> cb = ModelCheckpoint('my_ckpt', monitor='val_acc')
>>> trainer.fit(data, mask=(data.train_mask,
...                         data.val_mask), callbacks=[cb])
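
The mask argument accepts either a single mask (training only) or a (train_mask, val_mask) pair. As a minimal sketch of that convention, and not greatx's actual implementation, the hypothetical helper split_masks below normalizes both forms. Plain Python lists stand in for boolean tensors so the snippet runs without PyTorch.

```python
from typing import List, Optional, Tuple, Union

Mask = List[bool]  # stand-in for a boolean torch.Tensor node mask


def split_masks(
    mask: Optional[Union[Mask, Tuple[Mask, Mask]]],
) -> Tuple[Optional[Mask], Optional[Mask]]:
    """Return (train_mask, val_mask); val_mask is None without validation.

    Hypothetical helper for illustration only, not part of greatx.
    """
    if isinstance(mask, tuple):
        if len(mask) != 2:
            raise ValueError("expected a (train_mask, val_mask) pair")
        train_mask, val_mask = mask
        return train_mask, val_mask
    # A single mask (or None) means training only, with no validation step.
    return mask, None


train_mask = [True, True, False, False]
val_mask = [False, False, True, False]

only_train = split_masks(train_mask)            # (train_mask, None)
with_val = split_masks((train_mask, val_mask))  # (train_mask, val_mask)
```
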
train_step(data: Data, mask: Optional[Tensor] = None) dict[source]

One-step training on the inputs.

Parameters:
  • data (Data) – the training data.

  • mask (Optional[Tensor]) – the mask of training nodes.

Returns:

the output logs, including loss and acc, etc.

Return type:

dict

evaluate(data: Data, mask: Optional[Tensor] = None, verbose: Optional[int] = 1) BunchDict[source]

Simple evaluation step for :attr:model

Parameters:
  • data (Data) – the testing data used for test_step().

  • mask (Optional[Tensor]) – the mask of testing nodes used for test_step().

  • verbose (Optional[int], optional) – verbosity during evaluation, by default 1

Returns:

the dict-like output logs

Return type:

BunchDict

Example

>>> trainer.evaluate(data, data.test_mask) # evaluation
test_step(data: Data, mask: Optional[Tensor] = None) dict[source]

One-step evaluation on the inputs.

Parameters:
  • data (Data) – the testing data.

  • mask (Optional[Tensor]) – the mask of testing nodes.

Returns:

the output logs, including loss and acc, etc.

Return type:

dict
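
To make the returned logs concrete, the sketch below computes a loss and an accuracy over only the masked nodes and returns them as a dict. It is an illustration in plain Python under stated assumptions: the function name masked_logs is hypothetical, cross-entropy is assumed as the loss, and the real test_step operates on tensors and may report additional metrics.

```python
import math
from typing import Dict, List, Optional


def masked_logs(
    logits: List[List[float]],
    labels: List[int],
    mask: Optional[List[bool]] = None,
) -> Dict[str, float]:
    """Average cross-entropy loss and accuracy over the masked nodes.

    Hypothetical helper for illustration only, not part of greatx.
    """
    if mask is None:
        mask = [True] * len(labels)  # no mask: use every node
    total_loss, correct, count = 0.0, 0, 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue
        # Numerically stable log-sum-exp for the softmax normalizer.
        top = max(row)
        log_z = top + math.log(sum(math.exp(v - top) for v in row))
        total_loss += log_z - row[label]          # -log softmax(row)[label]
        correct += int(row.index(top) == label)   # argmax prediction
        count += 1
    if count == 0:
        raise ValueError("mask selects no nodes")
    return {"loss": total_loss / count, "acc": correct / count}


logits = [[2.0, 0.0], [0.0, 3.0], [1.0, 0.5]]
labels = [0, 1, 1]
logs = masked_logs(logits, labels, mask=[True, True, False])
```
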

predict_step(data: Data, mask: Optional[Tensor] = None) Tensor[source]

One-step prediction on the inputs.

Parameters:
  • data (Data) – the prediction data.

  • mask (Optional[Tensor]) – the mask of prediction nodes.

Returns:

the output prediction.

Return type:

Tensor

predict(data: Data, mask: Optional[Tensor] = None, transform: Callable = Softmax(dim=-1)) Tensor[source]
Parameters:
  • data (Data) – the prediction data used for predict_step().

  • mask (Optional[Tensor]) – the mask of prediction nodes used for predict_step().

  • transform (Callable) – Callable function applied on output predictions.

Example

>>> predict = trainer.predict(data, mask) # prediction
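
The default transform, Softmax(dim=-1), turns each node's raw outputs into probabilities that sum to 1. The following plain-Python sketch shows that computation and how a different callable changes the result. The names softmax and apply_transform are hypothetical and only mirror the idea of the transform parameter; the library itself works on tensors.

```python
import math
from typing import Callable, List

Row = List[float]


def softmax(row: Row) -> Row:
    """Softmax over one row of logits (the last dimension)."""
    top = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def apply_transform(
    logits: List[Row],
    transform: Callable[[Row], Row] = softmax,
) -> List[Row]:
    """Apply `transform` to each node's raw output (illustration only)."""
    return [transform(row) for row in logits]


logits = [[1.0, 1.0], [0.0, math.log(3.0)]]
probs = apply_transform(logits)                         # rows sum to 1
raw = apply_transform(logits, transform=lambda r: r)    # identity: raw logits
```

Passing an identity callable, as in the last line, is one way to keep the raw model outputs in this sketch.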
config_optimizer() Optimizer[source]
reset_optimizer(lr: Optional[float] = None, weight_decay: Optional[float] = None) Trainer[source]

Reset the optimizer with given learning rate and weight decay parameters.

Parameters:
  • lr (Optional[float], optional) – the learning rate. If None, will use the default learning rate 0.01, by default None

  • weight_decay (Optional[float], optional) – the weight decay factor. If None, will use the default factor 5e-5, by default None

Returns:

the trainer itself

Return type:

Trainer

Example

>>> # Reset optimizer and use the default learning rate
>>> # and weight decay
>>> trainer.reset_optimizer()
>>> # Reset optimizer and use a learning rate of 0.1
>>> trainer.reset_optimizer(lr=0.1)
>>> # Reset optimizer and use a weight decay of 0.001
>>> trainer.reset_optimizer(weight_decay=0.001)
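
The documented behavior (None falls back to a default, and the trainer itself is returned) can be sketched without PyTorch. TinyTrainer below is a hypothetical stand-in that only records the optimizer settings; the defaults are taken from the parameter descriptions above, and the real method rebuilds an actual optimizer.

```python
from typing import Dict, Optional


class TinyTrainer:
    """Hypothetical stand-in that only tracks optimizer settings."""

    DEFAULT_LR = 0.01            # default learning rate stated in the docs
    DEFAULT_WEIGHT_DECAY = 5e-5  # default weight decay stated in the docs

    def __init__(self) -> None:
        self.optimizer_cfg: Dict[str, float] = {}
        self.reset_optimizer()

    def reset_optimizer(
        self,
        lr: Optional[float] = None,
        weight_decay: Optional[float] = None,
    ) -> "TinyTrainer":
        # As documented: a None argument falls back to the default value.
        self.optimizer_cfg = {
            "lr": self.DEFAULT_LR if lr is None else lr,
            "weight_decay": (
                self.DEFAULT_WEIGHT_DECAY if weight_decay is None else weight_decay
            ),
        }
        return self  # returning the trainer allows call chaining


trainer = TinyTrainer().reset_optimizer(lr=0.1)
```
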
config_scheduler(optimizer: Optimizer)[source]
config_callbacks(verbose, epochs, callbacks=None) CallbackList[source]
property model: Optional[Module]
cache_clear() Trainer[source]

Clear cached inputs or intermediate results of the model.

extra_repr() str[source]
class UnspuervisedTrainer(model: Module, device: Union[str, device] = 'cpu', **cfg)[source]

Custom trainer for unsupervised models, similar to greatx.training.Trainer but using only the unsupervised loss defined in the model.loss() method.

supervised = False
test(*args, **kwargs)[source]
predict(*args, **kwargs)[source]
Parameters:
  • data (Data) – the prediction data used for predict_step().

  • mask (Optional[Tensor]) – the mask of prediction nodes used for predict_step().

  • transform (Callable) – Callable function applied on output predictions.

Example

>>> predict = trainer.predict(data, mask) # prediction
get_trainer(model: Union[str, Module]) Trainer[source]

Get the default trainer for a model, given either as a str (the model name) or as a model in greatx.nn.models.

Parameters:

model (Union[str, torch.nn.Module]) – the model to be trained

Returns:

the default trainer class for the given model

Examples

>>> import greatx
>>> greatx.training.get_trainer('GCN')
greatx.training.trainer.Trainer
>>> from greatx.nn.models import GCN
>>> greatx.training.get_trainer(GCN)
greatx.training.trainer.Trainer
>>> # by default, it returns `greatx.training.Trainer`
>>> greatx.training.get_trainer('unimplemented_model')
greatx.training.trainer.Trainer
>>> greatx.training.get_trainer('RobustGCN')
greatx.training.robustgcn_trainer.RobustGCNTrainer
>>> # it is case-sensitive
>>> greatx.training.get_trainer('robustGCN')
greatx.training.trainer.Trainer

Note

This method does not support unsupervised models: it never returns greatx.training.UnspuervisedTrainer, and falls back to greatx.training.Trainer for them as well.
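
The lookup behavior in the examples above (exact, case-sensitive name match with a fallback to the default trainer) can be sketched as a small registry. Everything in this snippet is a stand-in: the classes are empty placeholders, and the _TRAINERS dict and lookup_trainer function are hypothetical, not how greatx necessarily resolves trainers internally.

```python
from typing import Dict, Type, Union


class Trainer:
    """Stand-in for the default trainer."""


class RobustGCNTrainer(Trainer):
    """Stand-in for a model-specific trainer."""


class GCN:
    """Stand-in model class."""


class RobustGCN:
    """Stand-in model class."""


# Hypothetical registry: model name -> trainer class.
_TRAINERS: Dict[str, Type[Trainer]] = {"RobustGCN": RobustGCNTrainer}


def lookup_trainer(model: Union[str, type, object]) -> Type[Trainer]:
    """Resolve a trainer class from a name, a model class, or an instance."""
    if isinstance(model, str):
        name = model
    elif isinstance(model, type):
        name = model.__name__
    else:
        name = type(model).__name__
    # Exact, case-sensitive match; anything unknown gets the default trainer.
    return _TRAINERS.get(name, Trainer)
```

In this sketch, lookup_trainer('robustGCN') falls through to Trainer because the key differs in case, which mirrors the case-sensitivity shown in the examples.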

class SATTrainer(model: Module, device: Union[str, device] = 'cpu', **cfg)[source]

Custom trainer for greatx.nn.models.supervised.SAT

Parameters:
  • model (nn.Module) – the model used for training

  • device (Union[str, torch.device], optional) – the device used for training, by default 'cpu'

  • cfg – other keyword arguments, such as lr and weight_decay.

Note

greatx.training.SATTrainer accepts the following additional arguments:

  • eps_U: scale of perturbation on eigenvectors

  • eps_V: scale of perturbation on eigenvalues

  • lambda_U: trade-off parameter for the eigenvector-specific loss

  • lambda_V: trade-off parameter for the eigenvalue-specific loss

train_step(data: Data, mask: Optional[Tensor] = None) dict[source]

One-step training on the inputs.

Parameters:
  • data (Data) – the training data.

  • mask (Optional[Tensor]) – the mask of training nodes.

Returns:

the output logs, including loss and acc, etc.

Return type:

dict