Source code for greatx.nn.models.surrogate

from typing import Optional, Tuple, Type, Union

import torch
from torch import Tensor
from torch.nn import Module

# Public names of this module; ``classes`` is kept as an alias of
# ``__all__`` (presumably consumed by the doc tooling — verify).
__all__ = classes = ['Surrogate']


class Surrogate(Module):
    """Base class for attackers or defenders that require a surrogate
    model for estimating labels or computing gradient information.

    Parameters
    ----------
    device : str, optional
        the device of a model to use for, by default "cpu"
    """
    # Flag denoting whether the surrogate model has been properly set up
    # via :meth:`setup_surrogate`.
    _is_setup = False

    def __init__(self, device: str = "cpu"):
        super().__init__()
        self.device = torch.device(device)

    def setup_surrogate(
        self,
        surrogate: Module,
        *,
        tau: float = 1.0,
        freeze: bool = True,
        required: Optional[Union[Type[Module],
                                 Tuple[Type[Module], ...]]] = None,
    ) -> "Surrogate":
        """Method used to initialize the (trained) surrogate model.

        Parameters
        ----------
        surrogate : Module
            the input surrogate module
        tau : float, optional
            temperature used for softmax activation, by default 1.0
        freeze : bool, optional
            whether to freeze the model's parameters to save time,
            by default True
        required : Optional[Union[Type[Module], Tuple[Type[Module], ...]]], optional
            which class(es) of the surrogate model are required
            (passed to :func:`isinstance`), by default None

        Returns
        -------
        Surrogate
            the class itself

        Raises
        ------
        RuntimeError
            if the surrogate model is not an instance of
            :class:`torch.nn.Module`
        RuntimeError
            if the surrogate model is not an instance of :obj:`required`
        """
        if not isinstance(surrogate, Module):
            raise RuntimeError(
                "The surrogate model must be an instance of `torch.nn.Module`."
            )

        if required is not None and not isinstance(surrogate, required):
            raise RuntimeError(
                f"The surrogate model is required to be `{required}`, "
                f"but got `{surrogate.__class__.__name__}`.")

        surrogate.eval()
        # Drop any model-level cached results so stale values from a
        # previous graph are not reused.
        if hasattr(surrogate, 'cache_clear'):
            surrogate.cache_clear()

        # Disable per-layer caching (e.g. cached propagation matrices);
        # the graph structure may change between calls.
        for layer in surrogate.modules():
            if hasattr(layer, 'cached'):
                layer.cached = False

        self.surrogate = surrogate.to(self.device)
        self.tau = tau

        if freeze:
            self.freeze_surrogate()

        self._is_setup = True

        return self

    def clip_grad(
        self,
        grad: Tensor,
        grad_clip: Optional[float],
    ) -> Tensor:
        """Gradient clipping function.

        Scales ``grad`` down so that its global L2 norm does not exceed
        ``grad_clip``. NOTE: the scaling is performed **in place** on the
        input tensor, which is also returned.

        Parameters
        ----------
        grad : Tensor
            the input gradients to clip
        grad_clip : Optional[float]
            the clipping number (maximum L2 norm) of the gradients;
            if None, the gradients are returned unchanged

        Returns
        -------
        Tensor
            the clipped gradients
        """
        if grad_clip is not None:
            grad_len_sq = grad.square().sum()
            # Only rescale when the squared norm exceeds the threshold.
            if grad_len_sq > grad_clip * grad_clip:
                grad *= grad_clip / grad_len_sq.sqrt()
        return grad

    def estimate_self_training_labels(
            self, nodes: Optional[Tensor] = None) -> Tensor:
        """Estimate the labels of nodes using the trained surrogate model.

        NOTE(review): this relies on ``self.feat``, ``self.edge_index``
        and ``self.edge_weight`` being set by a subclass beforehand —
        they are not defined in this base class; confirm against callers.

        Parameters
        ----------
        nodes : Optional[Tensor], optional
            the input nodes, if None, it would be all nodes in the graph,
            by default None

        Returns
        -------
        Tensor
            the labels of the input nodes.
        """
        self_training_labels = self.surrogate(self.feat, self.edge_index,
                                              self.edge_weight)
        if nodes is not None:
            self_training_labels = self_training_labels[nodes]
        # Hard labels via argmax over the class dimension.
        return self_training_labels.argmax(-1)

    def freeze_surrogate(self) -> "Surrogate":
        """Freeze the parameters of the surrogate model.

        Returns
        -------
        Surrogate
            the class itself
        """
        for para in self.surrogate.parameters():
            para.requires_grad_(False)
        return self

    def defrozen_surrogate(self) -> "Surrogate":
        """Unfreeze the parameters of the surrogate model.

        Returns
        -------
        Surrogate
            the class itself
        """
        for para in self.surrogate.parameters():
            para.requires_grad_(True)
        return self