from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.data import Data
from greatx.training import Trainer
class SATTrainer(Trainer):
    """Custom trainer for :class:`greatx.nn.models.supervised.SAT`

    Parameters
    ----------
    model : nn.Module
        the model used for training
    device : Union[str, torch.device], optional
        the device used for training, by default 'cpu'
    cfg : other keyword arguments, such as `lr` and `weight_decay`.

    Note
    ----
    :class:`greatx.training.SATTrainer` accepts the following
    additional arguments:

    * :obj:`eps_U`: scale of perturbation on eigenvectors
    * :obj:`eps_V`: scale of perturbation on eigenvalues
    * :obj:`lambda_U`: trade-off parameter for the eigenvector-specific loss
    * :obj:`lambda_V`: trade-off parameter for the eigenvalue-specific loss
    """

    def train_step(self, data: Data, mask: Optional[Tensor] = None) -> dict:
        """One-step training on the inputs.

        Parameters
        ----------
        data : Data
            the training data. Must expose node features ``x``, labels
            ``y``, and the spectral components ``U`` (eigenvectors) and
            ``V`` (eigenvalues) consumed by the SAT model.
        mask : Optional[Tensor]
            the mask of training nodes.

        Returns
        -------
        dict
            the output logs, including `loss` and `acc`, etc.
        """
        model = self.model
        self.callbacks.on_train_batch_begin(0)
        model.train()

        # ===============================================================
        eps_U = self.cfg.get("eps_U", 0.1)
        eps_V = self.cfg.get("eps_V", 0.1)
        # The documented option names are `lambda_U`/`lambda_V`; also
        # accept the legacy `lamb_U`/`lamb_V` keys for backward
        # compatibility with existing configs.
        lamb_U = self.cfg.get("lambda_U", self.cfg.get("lamb_U", 0.5))
        lamb_V = self.cfg.get("lambda_V", self.cfg.get("lamb_V", 0.5))

        data = data.to(self.device)
        y = data.y.squeeze()
        U, V = data.U, data.V

        # Clean forward pass, tracking gradients w.r.t. the spectral
        # inputs so adversarial perturbations can be derived from them.
        U.requires_grad_()
        V.requires_grad_()
        out = model(data.x, U, V)
        if mask is not None:
            out = out[mask]
            y = y[mask]
        loss = F.cross_entropy(out, y)

        # retain_graph=True: the graph is reused by loss.backward() below.
        U_grad, V_grad = torch.autograd.grad(loss, [U, V],
                                             retain_graph=True)
        U.requires_grad_(False)
        V.requires_grad_(False)

        # L2-normalized perturbation directions scaled by eps_U / eps_V.
        # NOTE(review): assumes non-zero gradient norms; an exactly-zero
        # gradient would produce NaNs here — confirm upstream guarantees.
        U_grad = eps_U * U_grad / torch.norm(U_grad, 2)
        V_grad = eps_V * V_grad / torch.norm(V_grad, 2)

        # Forward passes with perturbed eigenvectors / eigenvalues.
        out_U = model(data.x, U + U_grad, V)
        out_V = model(data.x, U, V + V_grad)
        if mask is not None:
            out_U = out_U[mask]
            out_V = out_V[mask]

        # Total loss = clean loss + weighted adversarial regularizers.
        loss += lamb_U * F.cross_entropy(out_U, y) \
            + lamb_V * F.cross_entropy(out_V, y)
        # ===============================================================

        loss.backward()
        self.callbacks.on_train_batch_end(0)
        return dict(loss=loss.item(),
                    acc=out.argmax(1).eq(y).float().mean().item())