greatx.training
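- Trainer: A simple trainer to train graph neural network models conveniently.
- UnspuervisedTrainer: Custom trainer for unsupervised models, similar to greatx.training.Trainer.
- get_trainer: Get the default trainer using a str or a model in greatx.nn.models.
- SATTrainer: Custom trainer for greatx.nn.models.supervised.SAT.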
- class Trainer(model: Module, device: Union[str, device] = 'cpu', **cfg)
A simple trainer to train graph neural network models conveniently.
- Parameters:
model (nn.Module) – the model used for training
device (Union[str, torch.device], optional) – the device used for training, by default ‘cpu’
cfg – other keyword arguments, such as lr and weight_decay.
Example
>>> from greatx.training import Trainer
>>> model = ...  # your model
>>> trainer = Trainer(model, device='cuda')
>>> data  # PyG-like data, e.g., Cora
Data(x=[2485, 1433], edge_index=[2, 10138], y=[2485])
>>> # simple training
>>> trainer.fit(data, data.train_mask)
>>> # train with model picking
>>> from greatx.training import ModelCheckpoint
>>> cb = ModelCheckpoint('my_ckpt', monitor='val_acc')
>>> trainer.fit(data, mask=(data.train_mask,
...                         data.val_mask), callbacks=[cb])
>>> # get training logs
>>> history = trainer.model.history
>>> trainer.evaluate(data, your_test_mask)  # evaluation
>>> predict = trainer.predict(data, your_mask)  # prediction
- supervised = True
- fit(data: Union[Data, Tuple[Data, Data]], mask: Optional[Union[Tensor, Tuple[Tensor, Tensor]]] = None, callbacks: Optional[Callback] = None, verbose: Optional[int] = 1, epochs: int = 100, prefix: str = 'val') -> Trainer
Simple training method designed for the trainer's model.
- Parameters:
data (Union[Data, Tuple[Data, Data]]) – an instance or a tuple of torch_geometric.data.Data denoting the graph. If a tuple is given, its two elements are used for train_step and val_step, respectively.
mask (Optional[Union[Tensor, Tuple[Tensor, Tensor]]]) – node masks used for training and validation.
callbacks (Optional[Callback], optional) – callbacks used for training, see greatx.training.callbacks, by default None
verbose (Optional[int], optional) – verbosity during training, can be None, 1, 2, 3, or 4, by default 1
epochs (int, optional) – training epochs, by default 100
prefix (str, optional) – prefix for validation metrics, by default 'val'
Example
>>> # simple training
>>> trainer.fit(data, data.train_mask)
>>> # train with model picking
>>> from greatx.training import ModelCheckpoint
>>> cb = ModelCheckpoint('my_ckpt', monitor='val_acc')
>>> trainer.fit(data, mask=(data.train_mask,
...                         data.val_mask), callbacks=[cb])
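Model picking keeps the model from the epoch with the best monitored validation metric. The pure-Python sketch below illustrates that idea only; fit_with_model_picking, train_step and val_step are hypothetical stand-ins, not greatx APIs, and it assumes validation keys are built by joining the prefix with an underscore, as monitor='val_acc' with prefix='val' suggests.

```python
# Hypothetical sketch of model picking: track the epoch whose monitored
# validation metric is best. Not the greatx implementation.
from typing import Callable, Dict, List, Optional, Tuple


def fit_with_model_picking(
    train_step: Callable[[int], Dict[str, float]],
    val_step: Callable[[int], Dict[str, float]],
    epochs: int = 100,
    monitor: str = "val_acc",
    prefix: str = "val",
) -> Tuple[Optional[int], List[Dict[str, float]]]:
    """Run `epochs` epochs and return (best_epoch, history)."""
    history: List[Dict[str, float]] = []
    best_epoch: Optional[int] = None
    best_value = float("-inf")
    for epoch in range(epochs):
        logs = dict(train_step(epoch))
        # Assumed naming scheme: '<prefix>_<metric>', e.g. 'val_acc'.
        logs.update({f"{prefix}_{k}": v for k, v in val_step(epoch).items()})
        history.append(logs)
        if logs[monitor] > best_value:
            best_value = logs[monitor]
            best_epoch = epoch  # a real callback would save the weights here
    return best_epoch, history


val_accs = [0.50, 0.70, 0.65]
best, hist = fit_with_model_picking(
    train_step=lambda e: {"loss": 1.0 / (e + 1)},
    val_step=lambda e: {"acc": val_accs[e]},
    epochs=3,
)
print(best)  # 1
```

Only the index of the best epoch is recorded here; a checkpoint callback would additionally persist the weights at that point.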
- train_step(data: Data, mask: Optional[Tensor] = None) -> dict
One-step training on the inputs.
- Parameters:
data (Data) – the training data.
mask (Optional[Tensor]) – the mask of training nodes.
- Returns:
the output logs, including loss and acc, etc.
- Return type:
dict
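The step methods return logs such as loss and acc computed on the masked nodes. As a hedged illustration (plain Python, not the greatx code, which works on tensors), the sketch below computes a mean cross-entropy and an accuracy restricted to a boolean node mask; masked_logs is a hypothetical helper.

```python
# Hypothetical sketch: loss/acc logs computed only on masked nodes.
import math
from typing import Dict, Optional, Sequence


def masked_logs(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    mask: Optional[Sequence[bool]] = None,
) -> Dict[str, float]:
    """Return {'loss': mean cross-entropy, 'acc': accuracy} over masked nodes."""
    if mask is None:  # no mask: use every node
        mask = [True] * len(labels)
    idx = [i for i, keep in enumerate(mask) if keep]
    if not idx:
        raise ValueError("mask selects no nodes")
    total_loss, correct = 0.0, 0
    for i in idx:
        row = logits[i]
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total_loss += log_z - row[labels[i]]  # -log softmax(row)[label]
        pred = max(range(len(row)), key=lambda j: row[j])
        correct += int(pred == labels[i])
    return {"loss": total_loss / len(idx), "acc": correct / len(idx)}


logs = masked_logs([[2.0, 0.0], [0.0, 2.0], [1.0, 1.0]],
                   labels=[0, 0, 1], mask=[True, True, False])
print(logs["acc"])  # 0.5
```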
- evaluate(data: Data, mask: Optional[Tensor] = None, verbose: Optional[int] = 1) -> BunchDict
Simple evaluation step for the trainer's model.
- Parameters:
data (Data) – the testing data used for test_step().
mask (Optional[Tensor]) – the mask of testing nodes used for test_step().
verbose (Optional[int], optional) – verbosity during evaluation, by default 1
- Returns:
the dict-like output logs
- Return type:
BunchDict
Example
>>> trainer.evaluate(data, data.test_mask) # evaluation
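evaluate() returns dict-like logs. Assuming BunchDict behaves like a dict whose keys can also be read as attributes (an assumption; its exact behavior is not documented here), the hypothetical AttrDict below shows that access pattern.

```python
# Hypothetical stand-in for dict-like logs with attribute access.
# Assumption: BunchDict works similarly; this is not its implementation.
from typing import Any


class AttrDict(dict):
    """A dict whose keys can also be read and written as attributes."""

    def __getattr__(self, key: str) -> Any:
        try:
            return self[key]
        except KeyError as exc:
            # Raise AttributeError so hasattr()/getattr() behave normally.
            raise AttributeError(key) from exc

    def __setattr__(self, key: str, value: Any) -> None:
        self[key] = value


logs = AttrDict(loss=0.42, acc=0.81)
print(logs.acc, logs["loss"])  # 0.81 0.42
```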
- test_step(data: Data, mask: Optional[Tensor] = None) -> dict
One-step evaluation on the inputs.
- Parameters:
data (Data) – the testing data.
mask (Optional[Tensor]) – the mask of testing nodes.
- Returns:
the output logs, including loss and acc, etc.
- Return type:
dict
- predict_step(data: Data, mask: Optional[Tensor] = None) -> Tensor
One-step prediction on the inputs.
- Parameters:
data (Data) – the prediction data.
mask (Optional[Tensor]) – the mask of prediction nodes.
- Returns:
the output prediction.
- Return type:
Tensor
- predict(data: Data, mask: Optional[Tensor] = None, transform: Callable = Softmax(dim=-1)) -> Tensor
- Parameters:
data (Data) – the prediction data used for predict_step().
mask (Optional[Tensor]) – the mask of prediction nodes used for predict_step().
transform (Callable) – callable function applied on output predictions, by default Softmax(dim=-1)
Example
>>> predict = trainer.predict(data, mask) # prediction
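With the default transform, predictions on the masked nodes are turned into per-class probabilities by a softmax over the last dimension. The plain-Python sketch below mirrors that flow under the assumption that predict selects the masked rows and then applies transform; predict and softmax_last_dim here are hypothetical stand-ins, not the greatx functions.

```python
# Hypothetical sketch: select masked rows, then apply `transform`
# (by default a row-wise softmax, mirroring Softmax(dim=-1)).
import math
from typing import Callable, List, Optional, Sequence

Row = List[float]


def softmax_last_dim(rows: Sequence[Sequence[float]]) -> List[Row]:
    out = []
    for row in rows:
        m = max(row)  # shift by the max for numerical stability
        exps = [math.exp(v - m) for v in row]
        z = sum(exps)
        out.append([e / z for e in exps])
    return out


def predict(
    logits: Sequence[Sequence[float]],
    mask: Optional[Sequence[bool]] = None,
    transform: Callable[[Sequence[Sequence[float]]], List[Row]] = softmax_last_dim,
) -> List[Row]:
    """Return transform(logits[mask]); all nodes are used when mask is None."""
    if mask is None:
        rows = list(logits)
    else:
        rows = [r for r, keep in zip(logits, mask) if keep]
    return transform(rows)


probs = predict([[0.0, 0.0], [1.0, 3.0]], mask=[False, True])
print(len(probs))  # 1
```

Passing a different transform (for example an identity function) returns the raw outputs instead of probabilities.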
- reset_optimizer(lr: Optional[float] = None, weight_decay: Optional[float] = None) -> Trainer
Reset the optimizer with given learning rate and weight decay parameters.
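- Parameters:
lr (Optional[float], optional) – learning rate of the new optimizer; if None, the default learning rate is used, by default None
weight_decay (Optional[float], optional) – weight decay of the new optimizer; if None, the default weight decay is used, by default None
- Returns:
the trainer itself
- Return type:
Trainer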
Example
>>> # Reset optimizer and use default learning rate
>>> # and weight decay
>>> trainer.reset_optimizer()
>>> # Reset optimizer and use learning rate of 0.1
>>> trainer.reset_optimizer(lr=0.1)
>>> # Reset optimizer and use weight decay of 0.001
>>> trainer.reset_optimizer(weight_decay=0.001)
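The examples above rely on two rules: an argument left as None falls back to the default, and the method returns the trainer so calls can be chained. The sketch below models only those rules; ToyTrainer and its default values (lr=0.01, weight_decay=5e-4) are hypothetical, and a real trainer would rebuild a torch.optim optimizer rather than store a dict.

```python
# Hypothetical sketch of the reset_optimizer() contract; defaults are assumed.
from typing import Any, Dict, Optional


class ToyTrainer:
    def __init__(self, **cfg: Any) -> None:
        # Assumed defaults, overridable through **cfg like lr/weight_decay.
        self.cfg: Dict[str, Any] = {"lr": 0.01, "weight_decay": 5e-4, **cfg}
        self.optimizer_params: Dict[str, float] = {}
        self.reset_optimizer()

    def reset_optimizer(
        self, lr: Optional[float] = None, weight_decay: Optional[float] = None
    ) -> "ToyTrainer":
        # None means "use the default from cfg".
        self.optimizer_params = {
            "lr": self.cfg["lr"] if lr is None else lr,
            "weight_decay": (self.cfg["weight_decay"]
                             if weight_decay is None else weight_decay),
        }
        return self  # returning self allows chaining


trainer = ToyTrainer(lr=0.05)
print(trainer.reset_optimizer(weight_decay=0.001).optimizer_params)
# {'lr': 0.05, 'weight_decay': 0.001}
```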
- class UnspuervisedTrainer(model: Module, device: Union[str, device] = 'cpu', **cfg)
Custom trainer for unsupervised models, similar to greatx.training.Trainer, but it only uses the unsupervised loss defined in the model.loss() method.
- supervised = False
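The supervised flag separates the two trainers: the supervised one compares outputs with labels on masked nodes, while the unsupervised one only calls model.loss(). The sketch below illustrates that switch with a hypothetical ToyModel and compute_loss; the signature given to loss() here is an assumption, not the greatx one.

```python
# Hypothetical sketch of supervised vs. unsupervised loss selection.
from typing import List, Optional, Sequence


class ToyModel:
    def __call__(self, xs: Sequence[float]) -> List[float]:
        return [2.0 * x for x in xs]  # stand-in forward pass

    def loss(self, out: Sequence[float]) -> float:
        # Stand-in unsupervised objective that needs no labels.
        return sum(v * v for v in out) / len(out)


def compute_loss(model: ToyModel, xs: Sequence[float], supervised: bool,
                 ys: Optional[Sequence[float]] = None,
                 mask: Optional[Sequence[bool]] = None) -> float:
    out = model(xs)
    if not supervised:
        return model.loss(out)  # labels and masks are ignored
    if ys is None:
        raise ValueError("supervised training needs labels")
    keep = [i for i in range(len(out)) if mask is None or mask[i]]
    return sum((out[i] - ys[i]) ** 2 for i in keep) / len(keep)


model = ToyModel()
print(compute_loss(model, [1.0, 2.0], supervised=False))  # 10.0
```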
- predict(*args, **kwargs)
- Parameters:
data (Data) – the prediction data used for predict_step().
mask (Optional[Tensor]) – the mask of prediction nodes used for predict_step().
transform (Callable) – callable function applied on output predictions.
Example
>>> predict = trainer.predict(data, mask) # prediction
- get_trainer(model: Union[str, Module]) -> Trainer
Get the default trainer using a str or a model in greatx.nn.models.
- Parameters:
model (Union[str, torch.nn.Module]) – the model to be trained
- Returns:
Custom trainer or default trainer greatx.training.Trainer for the model.
Examples
>>> import greatx
>>> greatx.training.get_trainer('GCN')
greatx.training.trainer.Trainer
>>> from greatx.nn.models import GCN
>>> greatx.training.get_trainer(GCN)
greatx.training.trainer.Trainer
>>> # by default, it returns `greatx.training.Trainer`
>>> greatx.training.get_trainer('unimplemented_model')
greatx.training.trainer.Trainer
>>> greatx.training.get_trainer('RobustGCN')
greatx.training.robustgcn_trainer.RobustGCNTrainer
>>> # it is case-sensitive
>>> greatx.training.get_trainer('robustGCN')
greatx.training.trainer.Trainer
Note
This method does not support unsupervised models, i.e., it will not return greatx.training.UnsupervisedTrainer; for such models it also returns greatx.training.Trainer by default.
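The examples above amount to an exact, case-sensitive name lookup with a fallback to the default trainer. The sketch below reproduces only that lookup behavior; the registry contents and the stand-in classes are hypothetical, and handling model instances via type(model) is an assumption.

```python
# Hypothetical sketch of name-based trainer dispatch with a default fallback.
from typing import Dict, Type, Union


class Trainer: ...
class RobustGCNTrainer(Trainer): ...


# Made-up registry; only the lookup behaviour mirrors the examples above.
_REGISTRY: Dict[str, Type[Trainer]] = {"RobustGCN": RobustGCNTrainer}


def get_trainer(model: Union[str, type, object]) -> Type[Trainer]:
    if isinstance(model, str):
        name = model
    elif isinstance(model, type):
        name = model.__name__
    else:  # assumed: a model instance is resolved via its class name
        name = type(model).__name__
    # dict lookup is case-sensitive, so 'robustGCN' falls back to Trainer
    return _REGISTRY.get(name, Trainer)


class RobustGCN: ...  # stand-in model class


print(get_trainer(RobustGCN).__name__)    # RobustGCNTrainer
print(get_trainer("robustGCN").__name__)  # Trainer
```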
- class SATTrainer(model: Module, device: Union[str, device] = 'cpu', **cfg)
Custom trainer for greatx.nn.models.supervised.SAT.
- Parameters:
model (nn.Module) – the model used for training
device (Union[str, torch.device], optional) – the device used for training, by default ‘cpu’
cfg – other keyword arguments, such as lr and weight_decay.
Note
greatx.training.SATTrainer accepts the following additional arguments:
- eps_U: scale of perturbation on eigenvectors
- eps_V: scale of perturbation on eigenvalues
- lambda_U: trade-off parameter for the eigenvector-specific loss
- lambda_V: trade-off parameter for the eigenvalue-specific loss
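Since these extra arguments arrive through **cfg alongside lr and weight_decay, a trainer has to separate them from the generic options. The sketch below shows one way to do that and reads "trade-off parameter" as a weight in a weighted sum of losses. The default values and the weighted-sum objective are assumptions for illustration only, not taken from SATTrainer or the SAT method.

```python
# Hypothetical sketch: split SAT-specific kwargs from the generic cfg and
# weight the eigenvector-/eigenvalue-specific losses (assumed objective).
from typing import Any, Dict, Tuple

SAT_KEYS = ("eps_U", "eps_V", "lambda_U", "lambda_V")


def split_sat_cfg(**cfg: Any) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (sat_args, remaining_cfg); defaults are made up."""
    defaults = {"eps_U": 0.1, "eps_V": 0.1, "lambda_U": 0.5, "lambda_V": 0.5}
    sat = {k: cfg.pop(k, defaults[k]) for k in SAT_KEYS}
    return sat, cfg


def total_loss(clean: float, loss_U: float, loss_V: float,
               sat: Dict[str, float]) -> float:
    # Assumed weighted sum; lambda_* trade off the perturbation-specific terms.
    return clean + sat["lambda_U"] * loss_U + sat["lambda_V"] * loss_V


sat, rest = split_sat_cfg(lr=0.01, eps_U=0.2, lambda_V=1.0)
print(total_loss(1.0, 2.0, 3.0, sat))  # 5.0
```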