Source code for greatx.training.sat_trainer

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor
from torch_geometric.data import Data

from greatx.training import Trainer


class SATTrainer(Trainer):
    """Custom trainer for :class:`greatx.nn.models.supervised.SAT`

    Parameters
    ----------
    model : nn.Module
        the model used for training
    device : Union[str, torch.device], optional
        the device used for training, by default 'cpu'
    cfg : other keyword arguments, such as `lr` and `weight_decay`.

    Note
    ----
    :class:`greatx.training.SATTrainer` accepts the following
    additional arguments:

    * :obj:`eps_U`: scale of the perturbation on the eigenvectors
    * :obj:`eps_V`: scale of the perturbation on the eigenvalues
    * :obj:`lamb_U`: trade-off parameter for the eigenvector-specific loss
    * :obj:`lamb_V`: trade-off parameter for the eigenvalue-specific loss
    """
    def train_step(self, data: Data, mask: Optional[Tensor] = None) -> dict:
        """One-step training on the inputs.

        Parameters
        ----------
        data : Data
            the training data.
        mask : Optional[Tensor]
            the mask of training nodes.

        Returns
        -------
        dict
            the output logs, including `loss` and `acc`, etc.
        """
        model = self.model
        self.callbacks.on_train_batch_begin(0)
        model.train()

        # ===============================================================
        eps_U = self.cfg.get("eps_U", 0.1)
        eps_V = self.cfg.get("eps_V", 0.1)
        lamb_U = self.cfg.get("lamb_U", 0.5)
        lamb_V = self.cfg.get("lamb_V", 0.5)

        data = data.to(self.device)
        y = data.y.squeeze()

        # `U` (eigenvectors) and `V` (eigenvalues) come from a
        # precomputed eigen-decomposition stored on the data object;
        # gradients w.r.t. both are needed to craft the perturbations.
        U, V = data.U, data.V
        U.requires_grad_()
        V.requires_grad_()

        # Clean forward pass and supervised loss
        out = model(data.x, U, V)
        if mask is not None:
            out = out[mask]
            y = y[mask]
        loss = F.cross_entropy(out, y)

        # `retain_graph=True` keeps the graph alive for the final
        # `loss.backward()` below.
        U_grad, V_grad = torch.autograd.grad(loss, [U, V],
                                             retain_graph=True)

        U.requires_grad_(False)
        V.requires_grad_(False)

        # L2-normalized gradient perturbations, scaled by eps_U / eps_V
        U_grad = eps_U * U_grad / torch.norm(U_grad, 2)
        V_grad = eps_V * V_grad / torch.norm(V_grad, 2)

        # Forward passes on the perturbed eigenvectors / eigenvalues
        out_U = model(data.x, U + U_grad, V)
        out_V = model(data.x, U, V + V_grad)

        if mask is not None:
            out_U = out_U[mask]
            out_V = out_V[mask]

        # Total loss: clean loss plus the two adversarial terms
        loss += lamb_U * F.cross_entropy(out_U, y) + \
            lamb_V * F.cross_entropy(out_V, y)
        # ===============================================================

        loss.backward()
        self.callbacks.on_train_batch_end(0)

        return dict(loss=loss.item(),
                    acc=out.argmax(1).eq(y).float().mean().item())
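
For context, `train_step` implements the SAT objective: the clean cross-entropy
plus two FGM-style adversarial terms obtained from normalized gradient
perturbations of the eigenvectors and eigenvalues. A minimal usage sketch
follows. The trainer itself only assumes `data.U` / `data.V` and the
`eps_*` / `lamb_*` keyword arguments shown above; everything else here (the
dense eigen-decomposition, the truncation `k`, the `SAT` constructor
arguments, and the `fit` call) is an illustrative assumption, not taken from
this file:

import torch
from torch_geometric.datasets import Planetoid

from greatx.nn.models import SAT
from greatx.training import SATTrainer

dataset = Planetoid(root='data', name='Cora')
data = dataset[0]

# Assumption: SAT consumes a truncated eigen-decomposition of the
# symmetrically normalized adjacency, stored as `data.U` / `data.V`.
adj = torch.zeros(data.num_nodes, data.num_nodes)
adj[data.edge_index[0], data.edge_index[1]] = 1.0
deg = adj.sum(1).clamp(min=1).pow(-0.5)
norm_adj = deg.view(-1, 1) * adj * deg.view(1, -1)
V, U = torch.linalg.eigh(norm_adj)  # eigenvalues, eigenvectors (ascending)
k = 35                              # illustrative truncation
data.V, data.U = V[-k:], U[:, -k:]  # keep the k largest components

model = SAT(dataset.num_features, dataset.num_classes)
trainer = SATTrainer(model, device='cpu',
                     eps_U=0.1, eps_V=0.1, lamb_U=0.5, lamb_V=0.5)
trainer.fit(data, mask=data.train_mask)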