graphwar.training

Trainer

A simple trainer to train graph neural network models conveniently.

get_trainer

Get the default trainer for a model, given either its name (str) or a model from graphwar.nn.models

RobustGCNTrainer

Custom trainer for graphwar.nn.models.RobustGCN

SimPGCNTrainer

Custom trainer for graphwar.nn.models.SimPGCN

SATTrainer

Custom trainer for graphwar.nn.models.SAT

class Trainer(model: Module, device: Union[str, device] = 'cpu', **cfg)

A simple trainer to train graph neural network models conveniently.

Parameters
  • model (nn.Module) – the model used for training

  • device (Union[str, torch.device], optional) – the device used for training, by default ‘cpu’

  • cfg – other keyword arguments, such as lr and weight_decay

Example

>>> from graphwar.training import Trainer
>>> model = ... # your model
>>> trainer = Trainer(model, device='cuda')
>>> data # PyG-like data, e.g., Cora
Data(x=[2485, 1433], edge_index=[2, 10138], y=[2485])
>>> # simple training
>>> trainer.fit({'data': data, 'mask': train_mask})
>>> # train with model picking
>>> from graphwar.training import ModelCheckpoint
>>> cb = ModelCheckpoint('my_ckpt', monitor='val_acc')
>>> trainer.fit({'data': data, 'mask': train_mask},
...             {'data': data, 'mask': val_mask}, callbacks=[cb])
>>> # get training logs
>>> history = trainer.model.history
>>> trainer.evaluate({'data': data, 'mask': data.test_mask}) # evaluation
>>> predict = trainer.predict({'data': data, 'mask': your_mask}) # prediction
fit(train_inputs, val_inputs: Optional[dict] = None, callbacks: Optional[Callback] = None, verbose: Optional[int] = 1, epochs: int = 100) -> Trainer

Simple training method for the wrapped model.

Parameters
  • train_inputs (dict-like or custom inputs) – training data, used for train_step.

  • val_inputs (Optional[dict], optional) – validation data, by default None.

  • callbacks (Optional[Callback], optional) – callbacks used for training, see graphwar.training.callbacks, by default None

  • verbose (Optional[int], optional) – verbosity during training, can be: None, 1, 2, 3, 4, by default 1

  • epochs (int, optional) – training epochs, by default 100

Example

>>> # simple training
>>> trainer.fit({'data': data, 'mask': train_mask})
>>> # train with model picking
>>> from graphwar.training import ModelCheckpoint
>>> cb = ModelCheckpoint('my_ckpt', monitor='val_acc')
>>> trainer.fit({'data': data, 'mask': train_mask},
...             {'data': data, 'mask': val_mask}, callbacks=[cb])
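The model-picking idea in the second example can be pictured with the plain-Python sketch below. It is an illustration, not graphwar's implementation: fit_sketch, train_step, val_step and get_state are hypothetical names, validation metrics are assumed to be prefixed with val_ (so monitor='val_acc' means validation accuracy), a larger monitored value is assumed to be better, and the metric values are made up.

```python
import copy
from typing import Callable, Dict, List, Optional, Tuple

Logs = Dict[str, float]


def fit_sketch(
    train_step: Callable[[], Logs],
    val_step: Optional[Callable[[], Logs]],
    get_state: Callable[[], dict],
    epochs: int = 100,
    monitor: str = "val_acc",
) -> Tuple[List[Logs], Optional[dict]]:
    """Hypothetical training loop that keeps the state with the best `monitor` value."""
    history: List[Logs] = []
    best_score = float("-inf")
    best_state: Optional[dict] = None
    for _ in range(epochs):
        logs = dict(train_step())
        if val_step is not None:
            # Validation metrics get a "val_" prefix, e.g. "acc" -> "val_acc".
            logs.update({f"val_{k}": v for k, v in val_step().items()})
        history.append(logs)
        score = logs.get(monitor)
        # Assumption: higher is better (a "max" mode), which suits accuracy.
        if score is not None and score > best_score:
            best_score = score
            # Copy so later training updates do not overwrite the saved checkpoint.
            best_state = copy.deepcopy(get_state())
    return history, best_state


# Toy usage: the "model state" is just a dict recording the current epoch.
state = {"epoch": -1}
val_accs = iter([0.5, 0.8, 0.7])


def train_step() -> Logs:
    state["epoch"] += 1
    return {"loss": 1.0 / (state["epoch"] + 1), "acc": 0.9}


def val_step() -> Logs:
    return {"acc": next(val_accs)}


history, best = fit_sketch(train_step, val_step, lambda: state, epochs=3)
print(best)  # {'epoch': 1}
```

In graphwar itself the same effect is requested declaratively, by passing a ModelCheckpoint callback to fit as in the example above.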
train_step(inputs: dict) -> dict

One-step training on the inputs.

Parameters

inputs (dict-like or custom inputs) – the training data.

Returns

the output logs, including loss and acc, etc.

Return type

dict
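For node classification, the loss and acc entries in these logs can be understood as cross-entropy and accuracy over the masked nodes. The sketch below computes both in plain Python; step_logs is a hypothetical helper, and reading "loss" as mean cross-entropy is our assumption rather than something this page states. test_step would report the same kind of logs, just without updating the model.

```python
import math
from typing import Dict, Sequence


def step_logs(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    mask: Sequence[bool],
) -> Dict[str, float]:
    """Mean cross-entropy and accuracy over the nodes where mask is True (needs at least one)."""
    total_loss, correct, count = 0.0, 0, 0
    for row, label, selected in zip(logits, labels, mask):
        if not selected:
            continue
        # Cross-entropy = log-sum-exp(row) - row[label], with a max shift for stability.
        shift = max(row)
        log_z = shift + math.log(sum(math.exp(v - shift) for v in row))
        total_loss += log_z - row[label]
        prediction = max(range(len(row)), key=row.__getitem__)  # argmax over classes
        correct += int(prediction == label)
        count += 1
    return {"loss": total_loss / count, "acc": correct / count}


logits = [[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
labels = [0, 0, 1]
train_mask = [True, True, False]  # the third node is not used for training
logs = step_logs(logits, labels, train_mask)
print(logs["acc"])  # 0.5
```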

evaluate(inputs: dict, verbose: Optional[int] = 1) -> BunchDict

Simple evaluation step for the wrapped model.

Parameters
  • inputs (dict-like or custom inputs) – test data, used for test_step.

  • verbose (Optional[int], optional) – verbosity during evaluation, by default 1

Returns

the dict-like output logs

Return type

BunchDict

Example

>>> trainer.evaluate({'data': data, 'mask': data.test_mask}) # evaluation
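evaluate returns a BunchDict. Assuming that, as the name suggests (compare scikit-learn's Bunch), it is a dict whose keys can also be read as attributes, a minimal stand-in looks like the class below. BunchDictSketch is hypothetical and the numbers are placeholders, not evaluation results.

```python
from typing import Any


class BunchDictSketch(dict):
    """Hypothetical dict subclass whose keys can also be used as attributes."""

    def __getattr__(self, key: str) -> Any:
        # Only called when normal attribute lookup fails, so dict methods still work.
        try:
            return self[key]
        except KeyError as exc:
            raise AttributeError(key) from exc

    def __setattr__(self, key: str, value: Any) -> None:
        # Attribute assignment writes a dict entry instead of an instance attribute.
        self[key] = value


# Placeholder values, not real evaluation results.
logs = BunchDictSketch(loss=0.5, acc=0.75)
print(logs.acc, logs["loss"])  # 0.75 0.5
```

Subclassing dict keeps the object usable anywhere a plain dict is expected, for example when logging or serializing results.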
test_step(inputs: dict) -> dict

One-step evaluation on the inputs.

Parameters

inputs (dict-like or custom inputs) – the test data.

Returns

the output logs, including loss and acc, etc.

Return type

dict

predict_step(inputs: dict) -> Tensor

One-step prediction on the inputs.

Parameters

inputs (dict-like or custom inputs) – the prediction data.

Returns

the output prediction.

Return type

Tensor

predict(inputs: dict, transform: Callable = Softmax(dim=-1)) -> Tensor

Make predictions on the inputs and apply transform to the output.

Parameters
  • inputs (dict-like or custom inputs) – prediction data, used for predict_step

  • transform (Callable) – callable applied to the output predictions, by default Softmax(dim=-1)

Example

>>> predict = trainer.predict({'data': data, 'mask': mask_or_not_given}) # prediction
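What the default transform does can be shown without PyTorch: Softmax(dim=-1) turns each row of logits into class probabilities. The sketch below is a plain-Python analogue; predict_sketch and softmax are hypothetical helpers, and treating the mask as a row filter (all rows when no mask is given) is an assumption that mirrors the "mask_or_not_given" example.

```python
import math
from typing import Callable, List, Optional, Sequence

Row = List[float]


def softmax(row: Sequence[float]) -> Row:
    """Softmax of one row of logits, the per-row analogue of Softmax(dim=-1)."""
    shift = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - shift) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def predict_sketch(
    logits: Sequence[Sequence[float]],
    mask: Optional[Sequence[bool]] = None,
    transform: Callable[[Sequence[float]], Row] = softmax,
) -> List[Row]:
    """Hypothetical helper: keep masked rows (all rows if mask is None), then transform each."""
    rows = list(logits) if mask is None else [r for r, m in zip(logits, mask) if m]
    return [transform(r) for r in rows]


print(predict_sketch([[0.0, 0.0], [3.0, 1.0]], mask=[True, False]))  # [[0.5, 0.5]]
```

Passing an identity callable as transform (for example, transform=lambda x: x with the real API) returns the raw outputs instead of probabilities.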
config_optimizer() -> Optimizer
config_scheduler(optimizer: Optimizer)
config_callbacks(verbose, epochs, callbacks=None) -> CallbackList
property model
cache_clear() -> Trainer

Clear cached inputs or intermediate results of the model.
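This page only says that cache_clear() clears cached inputs or intermediate results. As a plain-Python illustration of why such a method exists, the hypothetical class below memoizes an expensive computation: a cache keeps returning old results after the underlying input changes (for instance, presumably, after the graph is modified), and clearing it forces recomputation. Returning self mirrors the documented Trainer return type, which allows chaining.

```python
from typing import Callable, Dict, Generic, Hashable, TypeVar

T = TypeVar("T")


class CachedComputation(Generic[T]):
    """Hypothetical stand-in for a model that caches an expensive intermediate result."""

    def __init__(self, compute: Callable[[Hashable], T]) -> None:
        self._compute = compute
        self._cache: Dict[Hashable, T] = {}
        self.num_computations = 0

    def __call__(self, key: Hashable) -> T:
        if key not in self._cache:  # compute only on a cache miss
            self.num_computations += 1
            self._cache[key] = self._compute(key)
        return self._cache[key]

    def cache_clear(self) -> "CachedComputation[T]":
        """Drop all cached results; returning self allows call chaining."""
        self._cache.clear()
        return self


step = CachedComputation(lambda key: f"normalized({key})")
step("graph")
step("graph")                  # served from the cache
step.cache_clear()("graph")    # recomputed after clearing
print(step.num_computations)   # 2
```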

extra_repr() -> str
get_trainer(model: Union[str, Module]) -> Trainer

Get the default trainer for a model, given either its name (str) or a model from graphwar.nn.models.

Parameters

model (Union[str, torch.nn.Module]) – the model to be trained, or its name (case-sensitive)

Returns

the custom trainer class for the model if one exists, otherwise the default graphwar.training.Trainer

Examples

>>> import graphwar
>>> graphwar.training.get_trainer('GCN')
graphwar.training.trainer.Trainer
>>> from graphwar.nn.models import GCN
>>> graphwar.training.get_trainer(GCN)
graphwar.training.trainer.Trainer
>>> # by default, it returns `graphwar.training.Trainer`
>>> graphwar.training.get_trainer('unimplemented_model')
graphwar.training.trainer.Trainer
>>> graphwar.training.get_trainer('RobustGCN')
graphwar.training.robustgcn_trainer.RobustGCNTrainer
>>> # it is case-sensitive
>>> graphwar.training.get_trainer('robustGCN')
graphwar.training.trainer.Trainer
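The examples above suggest three behaviors: lookup by the model's name, an exact case-sensitive match, and a fallback to the default Trainer. A dictionary keyed by class name reproduces all three. The sketch below uses hypothetical stub classes and a hypothetical registry; it illustrates the documented behavior and is not graphwar's lookup code.

```python
from typing import Dict, Type, Union


class TrainerStub:
    """Placeholder for the default trainer."""


class RobustGCNTrainerStub(TrainerStub):
    """Placeholder for a model-specific trainer."""


class RobustGCN:
    """Placeholder model class; only its name matters here."""


# Hypothetical registry keyed by the exact model class name.
TRAINERS: Dict[str, Type[TrainerStub]] = {"RobustGCN": RobustGCNTrainerStub}


def get_trainer_sketch(model: Union[str, type, object]) -> Type[TrainerStub]:
    if isinstance(model, str):
        name = model
    elif isinstance(model, type):
        name = model.__name__          # a model class, e.g. RobustGCN
    else:
        name = type(model).__name__    # a model instance
    # A dict lookup is an exact, case-sensitive match; unknown names get the default.
    return TRAINERS.get(name, TrainerStub)


print(get_trainer_sketch("RobustGCN").__name__)   # RobustGCNTrainerStub
print(get_trainer_sketch("robustGCN").__name__)   # TrainerStub
```

Because get_trainer returns a class rather than an instance (as the examples show), the result still has to be constructed, e.g. get_trainer('RobustGCN')(model, device='cuda').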
class RobustGCNTrainer(model: Module, device: Union[str, device] = 'cpu', **cfg)

Custom trainer for graphwar.nn.models.RobustGCN

Parameters
  • model (nn.Module) – the model used for training

  • device (Union[str, torch.device], optional) – the device used for training, by default ‘cpu’

  • cfg – other keyword arguments, such as lr and weight_decay

Note

graphwar.training.RobustGCNTrainer accepts the following additional arguments:

  • kl: trade-off parameter for the KL-divergence loss
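A trade-off parameter of this kind usually weights a regularizer added to the task loss. The sketch below assumes, as in the RobustGCN paper's Gaussian formulation, that the regularizer is the KL divergence between a node's Gaussian hidden representation N(mean, diag(var)) and a standard normal prior, and that the total loss is cross-entropy plus kl times that term. The function names are hypothetical, and a real implementation would aggregate the term over many nodes.

```python
import math
from typing import Sequence


def gaussian_kl(mean: Sequence[float], var: Sequence[float], eps: float = 1e-8) -> float:
    """KL( N(mean, diag(var)) || N(0, I) ) for one node's hidden representation."""
    # eps guards against log(0) when a variance is exactly zero.
    return 0.5 * sum(v + m * m - 1.0 - math.log(v + eps) for m, v in zip(mean, var))


def robustgcn_loss_sketch(
    ce_loss: float, mean: Sequence[float], var: Sequence[float], kl: float
) -> float:
    """Task loss plus the KL term weighted by the trade-off parameter `kl`."""
    return ce_loss + kl * gaussian_kl(mean, var)


print(round(gaussian_kl([1.0], [1.0]), 6))                          # 0.5
print(round(robustgcn_loss_sketch(1.0, [1.0], [1.0], kl=0.1), 6))   # 1.05
```

With kl=0 the extra term vanishes and training reduces to plain cross-entropy; larger values pull the hidden distributions toward the prior.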

train_step(inputs: dict) -> dict

One-step training on the inputs.

Parameters

inputs (dict) – the training data.

Returns

the output logs, including loss and acc, etc.

Return type

dict

class SimPGCNTrainer(model: Module, device: Union[str, device] = 'cpu', **cfg)

Custom trainer for graphwar.nn.models.SimPGCN

Parameters
  • model (nn.Module) – the model used for training

  • device (Union[str, torch.device], optional) – the device used for training, by default ‘cpu’

  • cfg – other keyword arguments, such as lr and weight_decay

Note

graphwar.training.SimPGCNTrainer accepts the following additional arguments:

  • lambda_: trade-off parameter for regression loss
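The trailing underscore in lambda_ is needed because lambda is a reserved keyword in Python. Reading "trade-off parameter" as a weighted sum of the classification loss and the regression loss gives the sketch below; mean squared error stands in for the regression loss, the function names are hypothetical, and the targets are placeholder values, so this is an illustration rather than SimPGCNTrainer's code.

```python
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error between predictions and regression targets."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def simpgcn_loss_sketch(
    ce_loss: float, pred: Sequence[float], target: Sequence[float], lambda_: float
) -> float:
    """Classification loss plus the lambda_-weighted regression loss (hypothetical)."""
    return ce_loss + lambda_ * mse(pred, target)


print(simpgcn_loss_sketch(1.0, [0.5, 1.0], [0.0, 1.0], lambda_=2.0))  # 1.25
```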

train_step(inputs: dict) -> dict

One-step training on the inputs.

Parameters

inputs (dict) – the training data.

Returns

the output logs, including loss and acc, etc.

Return type

dict

class SATTrainer(model: Module, device: Union[str, device] = 'cpu', **cfg)

Custom trainer for graphwar.nn.models.SAT

Parameters
  • model (nn.Module) – the model used for training

  • device (Union[str, torch.device], optional) – the device used for training, by default ‘cpu’

  • cfg – other keyword arguments, such as lr and weight_decay

Note

graphwar.training.SATTrainer accepts the following additional arguments:

  • eps_U: scale of perturbation on eigenvectors

  • eps_V: scale of perturbation on eigenvalues

  • lambda_U: trade-off parameter for the eigenvector-specific loss

  • lambda_V: trade-off parameter for the eigenvalue-specific loss
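The four SAT arguments pair up: eps_U and eps_V control how large the perturbations of the eigenvectors and eigenvalues are, and lambda_U and lambda_V weight the corresponding losses. The sketch below assumes, in the style of gradient-based adversarial training, that a perturbation is a direction rescaled to L2 norm eps and that the total loss is a weighted sum. Both the helper names and these formulas are illustrations, not SATTrainer's implementation; eigenvectors (a matrix) could be handled the same way after flattening.

```python
import math
from typing import List, Sequence


def scaled_perturbation(direction: Sequence[float], eps: float) -> List[float]:
    """Rescale `direction` to L2 norm `eps`; a zero direction gives a zero perturbation."""
    norm = math.sqrt(sum(d * d for d in direction))
    if norm == 0.0:
        return [0.0] * len(direction)
    return [eps * d / norm for d in direction]


def sat_loss_sketch(
    clean_loss: float, loss_U: float, loss_V: float, lambda_U: float, lambda_V: float
) -> float:
    """Clean loss plus eigenvector- and eigenvalue-specific losses, each with its own weight."""
    return clean_loss + lambda_U * loss_U + lambda_V * loss_V


eigenvalues = [1.0, 2.0]
direction = [3.0, 4.0]  # e.g. a gradient of the loss with respect to the eigenvalues
delta_V = scaled_perturbation(direction, eps=0.5)  # plays the role of eps_V here
perturbed = [v + d for v, d in zip(eigenvalues, delta_V)]
print(delta_V)  # [0.3, 0.4]
```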

train_step(inputs: dict) -> dict

One-step training on the inputs.

Parameters

inputs (dict) – the training data.

Returns

the output logs, including loss and acc, etc.

Return type

dict