Source code for greatx.training.get_trainer

from typing import Union

import torch

from greatx import training

Trainer = training.trainer.Trainer
UnspuervisedTrainer = training.unsup_trainer.UnspuervisedTrainer


[docs]def get_trainer(model: Union[str, torch.nn.Module]) -> Trainer: r"""Get the default trainer using str or a model in :class:`greatx.nn.models` Parameters ---------- model : Union[str, torch.nn.Module] the model to be trained Returns ------- Custom trainer or default trainer :class:`greatx.training.Trainer` for the model. Examples -------- >>> import greatx >>> greatx.training.get_trainer('GCN') greatx.training.trainer.Trainer >>> from greatx.nn.models import GCN >>> greatx.training.get_trainer(GCN) greatx.training.trainer.Trainer >>> # by default, it returns `greatx.training.Trainer` >>> greatx.training.get_trainer('unimplemeted_model') greatx.training.trainer.Trainer >>> greatx.training.get_trainer('RobustGCN') greatx.training.robustgcn_trainer.RobustGCNTrainer >>> # it is case-sensitive >>> greatx.training.get_trainer('robustGCN') greatx.training.trainer.Trainer Note ---- Unsupervised models are not supported for thie method to get the :class:`greatx.training.UnsupervisedTrainer`. It will also return the :class:`greatx.training.Trainer` by default. """ default = training.Trainer if isinstance(model, str): class_name = model else: class_name = model.__class__.__name__ trainer = getattr(training, class_name + "Trainer", default) return trainer