graphwar.training
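Trainer | A simple trainer to train graph neural network models conveniently.
get_trainer | Get the default trainer for a model, given by name (str) or as a model from graphwar.nn.models.
RobustGCNTrainer | Custom trainer for graphwar.nn.models.RobustGCN.
SimPGCNTrainer | Custom trainer for graphwar.nn.models.SimPGCN.
SATTrainer | Custom trainer for graphwar.nn.models.SAT.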
- class Trainer(model: Module, device: Union[str, device] = 'cpu', **cfg)
A simple trainer to train graph neural network models conveniently.
- Parameters
model (nn.Module) – the model used for training
device (Union[str, torch.device], optional) – the device used for training, by default ‘cpu’
cfg – other keyword arguments, such as lr and weight_decay
Example
>>> from graphwar.training import Trainer
>>> model = ...  # your model
>>> trainer = Trainer(model, device='cuda')
>>> data  # PyG-like data, e.g., Cora
Data(x=[2485, 1433], edge_index=[2, 10138], y=[2485])
>>> # simple training
>>> trainer.fit({'data': data, 'mask': train_mask})
>>> # train with model picking
>>> from graphwar.training import ModelCheckpoint
>>> cb = ModelCheckpoint('my_ckpt', monitor='val_acc')
>>> trainer.fit({'data': data, 'mask': train_mask}, {'data': data, 'mask': val_mask}, callbacks=[cb])
>>> # get training logs
>>> history = trainer.model.history
>>> trainer.evaluate({'data': data, 'mask': data.test_mask})  # evaluation
>>> predict = trainer.predict({'data': data, 'mask': your_mask})  # prediction
- fit(train_inputs, val_inputs: Optional[dict] = None, callbacks: Optional[Callback] = None, verbose: Optional[int] = 1, epochs: int = 100) → Trainer
Simple training method for the trainer's model.
- Parameters
train_inputs (dict-like or custom inputs) – training data, passed to train_step.
val_inputs (Optional[dict], optional) – validation data, by default None
callbacks (Optional[Callback], optional) – callbacks used for training, see graphwar.training.callbacks, by default None
verbose (Optional[int], optional) – verbosity during training, one of None, 1, 2, 3, or 4, by default 1
epochs (int, optional) – number of training epochs, by default 100
Example
>>> # simple training
>>> trainer.fit({'data': data, 'mask': train_mask})
>>> # train with model picking
>>> from graphwar.training import ModelCheckpoint
>>> cb = ModelCheckpoint('my_ckpt', monitor='val_acc')
>>> trainer.fit({'data': data, 'mask': train_mask}, {'data': data, 'mask': val_mask}, callbacks=[cb])
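This reference does not show the loop inside fit. As a rough mental model only, the plain-Python sketch below assumes that fit calls train_step once per epoch, calls test_step on val_inputs when they are given, and prefixes the validation logs with val_, which would explain why the example above can monitor 'val_acc'. MiniTrainer, EpochCallback and the placeholder step bodies are hypothetical; real graphwar callbacks are Callback objects rather than plain functions.

```python
from typing import Callable, Dict, List, Optional

# Hypothetical callback type: called with (epoch, logs) after every epoch.
EpochCallback = Callable[[int, Dict[str, float]], None]


class MiniTrainer:
    """Hypothetical stand-in for Trainer.fit; not graphwar's implementation."""

    def __init__(self) -> None:
        self.history: List[Dict[str, float]] = []

    def train_step(self, inputs: dict) -> Dict[str, float]:
        # Placeholder: a real trainer would run forward, backward and an optimizer step.
        return {"loss": 1.0 / (len(self.history) + 1), "acc": 0.5}

    def test_step(self, inputs: dict) -> Dict[str, float]:
        # Placeholder: a real trainer would evaluate without gradient tracking.
        return {"loss": 0.9, "acc": 0.6}

    def fit(
        self,
        train_inputs: dict,
        val_inputs: Optional[dict] = None,
        callbacks: Optional[List[EpochCallback]] = None,
        epochs: int = 100,
    ) -> "MiniTrainer":
        callbacks = callbacks or []
        for epoch in range(epochs):
            logs = dict(self.train_step(train_inputs))
            if val_inputs is not None:
                # Assumed convention: validation logs get a "val_" prefix,
                # which is what a monitor such as "val_acc" would refer to.
                for key, value in self.test_step(val_inputs).items():
                    logs[f"val_{key}"] = value
            self.history.append(logs)
            for cb in callbacks:
                cb(epoch, logs)
        return self  # returning self mirrors fit's documented Trainer return type


# Usage: three epochs with validation and a callback that records val_acc.
seen: List[float] = []
trainer = MiniTrainer().fit(
    {}, {}, callbacks=[lambda epoch, logs: seen.append(logs["val_acc"])], epochs=3
)
```

Returning self from fit keeps calls chainable, which matches the documented Trainer return type.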
- train_step(inputs: dict) → dict
Performs one training step on the inputs.
- Parameters
inputs (dict-like or custom inputs) – the training data.
- Returns
the output logs, including loss and acc, etc.
- Return type
dict
- evaluate(inputs: dict, verbose: Optional[int] = 1) → BunchDict
Simple evaluation method for the trainer's model.
- Parameters
inputs (dict-like or custom inputs) – test data, passed to test_step.
verbose (Optional[int], optional) – verbosity during evaluation, by default 1
- Returns
the dict-like output logs
- Return type
BunchDict
Example
>>> trainer.evaluate({'data': data, 'mask': data.test_mask}) # evaluation
- test_step(inputs: dict) → dict
Performs one evaluation step on the inputs.
- Parameters
inputs (dict-like or custom inputs) – the testing data.
- Returns
the output logs, including loss and acc, etc.
- Return type
dict
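test_step (like train_step) returns logs such as loss and acc, but this reference does not say how acc is computed or how the mask is applied. A common convention for node classification, assumed here, is accuracy over the nodes selected by the mask, taking the arg-max of each row of logits as the prediction. masked_accuracy below is a hypothetical plain-Python helper, not part of graphwar.

```python
from typing import Sequence


def masked_accuracy(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    mask: Sequence[bool],
) -> float:
    """Share of masked nodes whose arg-max class equals the label (hypothetical helper)."""
    correct = 0
    total = 0
    for row, label, keep in zip(logits, labels, mask):
        if not keep:
            continue  # nodes outside the mask do not count towards the metric
        pred = max(range(len(row)), key=row.__getitem__)  # arg-max over classes
        correct += int(pred == label)
        total += 1
    return correct / total if total else 0.0


# Four nodes, two classes; the last node is excluded by the mask.
logits = [[2.0, 0.1], [0.3, 1.5], [1.0, 0.2], [0.1, 0.9]]
labels = [0, 1, 1, 1]
mask = [True, True, True, False]
acc = masked_accuracy(logits, labels, mask)  # 2 of the 3 masked nodes are correct
```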
- predict_step(inputs: dict) → Tensor
Performs one prediction step on the inputs.
- Parameters
inputs (dict-like or custom inputs) – the prediction data.
- Returns
the output prediction.
- Return type
Tensor
- predict(inputs: dict, transform: Callable = Softmax(dim=-1)) → Tensor
- Parameters
inputs (dict-like or custom inputs) – prediction data, passed to predict_step.
transform (Callable) – callable applied to the output predictions, by default Softmax(dim=-1)
Example
>>> predict = trainer.predict({'data': data, 'mask': mask_or_not_given}) # prediction
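By default, predict applies Softmax(dim=-1) to the output, so each row holds class probabilities that sum to 1. The plain-Python softmax_last_dim below (a hypothetical helper, not part of graphwar or PyTorch) mimics that default transform on a list of logit rows, subtracting each row's maximum before exponentiating for numerical stability.

```python
import math
from typing import List, Sequence


def softmax_last_dim(rows: Sequence[Sequence[float]]) -> List[List[float]]:
    """Row-wise softmax, i.e. softmax over the last dimension of a 2-D input."""
    out = []
    for row in rows:
        m = max(row)  # subtract the row max so exp() cannot overflow
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


# Large logits are handled safely thanks to the max subtraction.
probs = softmax_last_dim([[1.0, 2.0, 3.0], [1000.0, 1000.0, 0.0]])
```

Passing an identity function as transform would instead return the raw outputs of predict_step.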
- property model
- get_trainer(model: Union[str, Module]) → Trainer
Get the default trainer for a model, given by name (str) or as a model from graphwar.nn.models.
- Parameters
model (Union[str, torch.nn.Module]) – the model to be trained, given as its name or as a model from graphwar.nn.models
- Return type
the custom trainer class for the model, or the default graphwar.training.Trainer if no custom trainer exists
Examples
>>> import graphwar
>>> graphwar.training.get_trainer('GCN')
graphwar.training.trainer.Trainer
>>> from graphwar.nn.models import GCN
>>> graphwar.training.get_trainer(GCN)
graphwar.training.trainer.Trainer
>>> # by default, it returns `graphwar.training.Trainer`
>>> graphwar.training.get_trainer('unimplemented_model')
graphwar.training.trainer.Trainer
>>> graphwar.training.get_trainer('RobustGCN')
graphwar.training.robustgcn_trainer.RobustGCNTrainer
>>> # it is case-sensitive
>>> graphwar.training.get_trainer('robustGCN')
graphwar.training.trainer.Trainer
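The examples are consistent with a simple name-based lookup: an exact, case-sensitive match on the model's name selects a custom trainer, and anything else falls back to Trainer. The sketch below reproduces that behaviour with a plain dictionary; the registry, the placeholder classes and lookup_trainer are hypothetical and only mirror the documented outputs, not graphwar's actual code.

```python
from typing import Dict, Type, Union


class Trainer:  # placeholder for graphwar.training.Trainer
    pass


class RobustGCNTrainer(Trainer):  # placeholder custom trainer
    pass


class GCN:  # placeholder model class
    pass


# Hypothetical registry: model name -> custom trainer class.
_CUSTOM_TRAINERS: Dict[str, Type[Trainer]] = {"RobustGCN": RobustGCNTrainer}


def lookup_trainer(model: Union[str, type, object]) -> Type[Trainer]:
    if isinstance(model, str):
        name = model
    elif isinstance(model, type):
        name = model.__name__  # a model class such as GCN
    else:
        name = type(model).__name__  # a model instance
    # dict lookup is case-sensitive, so 'robustGCN' does not match 'RobustGCN'
    return _CUSTOM_TRAINERS.get(name, Trainer)
```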
- class RobustGCNTrainer(model: Module, device: Union[str, device] = 'cpu', **cfg)
Custom trainer for graphwar.nn.models.RobustGCN
- Parameters
model (nn.Module) – the model used for training
device (Union[str, torch.device], optional) – the device used for training, by default ‘cpu’
cfg – other keyword arguments, such as lr and weight_decay
Note
graphwar.training.RobustGCNTrainer accepts the following additional argument:
kl: trade-off parameter for the KL-divergence loss
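The note only says that kl weights a KL loss. Assuming, as in the RobustGCN paper, that the model produces a Gaussian (mean and variance) per node and that the regulariser is its KL divergence from a standard normal, the total loss would be the classification loss plus kl times that divergence. The sketch below uses the closed-form KL for diagonal Gaussians in plain Python; gaussian_kl_to_standard_normal and total_loss are hypothetical names, and graphwar's exact reduction (sum or mean over nodes) may differ.

```python
import math
from typing import Sequence


def gaussian_kl_to_standard_normal(mean: Sequence[float], var: Sequence[float]) -> float:
    """KL(N(mean, diag(var)) || N(0, I)), summed over dimensions; var must be > 0."""
    return 0.5 * sum(v + m * m - 1.0 - math.log(v) for m, v in zip(mean, var))


def total_loss(ce_loss: float, mean: Sequence[float], var: Sequence[float], kl: float) -> float:
    # kl is the trade-off weight documented for RobustGCNTrainer (assumed additive form)
    return ce_loss + kl * gaussian_kl_to_standard_normal(mean, var)
```

A standard normal has zero divergence from itself, so with mean 0 and variance 1 the regulariser vanishes and only the classification loss remains.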
- class SimPGCNTrainer(model: Module, device: Union[str, device] = 'cpu', **cfg)
Custom trainer for graphwar.nn.models.SimPGCN
- Parameters
model (nn.Module) – the model used for training
device (Union[str, torch.device], optional) – the device used for training, by default ‘cpu’
cfg – other keyword arguments, such as lr and weight_decay
Note
graphwar.training.SimPGCNTrainer accepts the following additional argument:
lambda_: trade-off parameter for the regression loss
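Reading lambda_ as a weight on an auxiliary term, the training objective presumably has the additive form below, where L_cls is the node-classification loss and L_reg the regression loss. Both symbols are introduced here for illustration; this reference does not define the regression loss itself.

```latex
\mathcal{L} = \mathcal{L}_{\text{cls}} + \lambda \, \mathcal{L}_{\text{reg}}, \qquad \lambda = \texttt{lambda\_}
```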
- class SATTrainer(model: Module, device: Union[str, device] = 'cpu', **cfg)
Custom trainer for graphwar.nn.models.SAT
- Parameters
model (nn.Module) – the model used for training
device (Union[str, torch.device], optional) – the device used for training, by default ‘cpu’
cfg – other keyword arguments, such as lr and weight_decay
Note
graphwar.training.SATTrainer accepts the following additional arguments:
eps_U: scale of the perturbation on eigenvectors
eps_V: scale of the perturbation on eigenvalues
lambda_U: trade-off parameter for the eigenvector-specific loss
lambda_V: trade-off parameter for the eigenvalue-specific loss
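Under the same additive reading, eps_U and eps_V would scale perturbations while lambda_U and lambda_V would weight the corresponding losses. A sketch of that structure, assuming additive perturbations of the eigenvectors U and eigenvalues Λ of the graph spectrum; the perturbation directions Δ_U, Δ_Λ and the loss names are illustrative assumptions, not graphwar's definitions:

```latex
\begin{aligned}
\tilde{U} &= U + \epsilon_U \, \Delta_U, \qquad \tilde{\Lambda} = \Lambda + \epsilon_V \, \Delta_\Lambda, \\
\mathcal{L} &= \mathcal{L}_{\text{cls}} + \lambda_U \, \mathcal{L}_U(\tilde{U}) + \lambda_V \, \mathcal{L}_V(\tilde{\Lambda}).
\end{aligned}
```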