from typing import Optional, Tuple, Union
import torch
from torch import Tensor
from torch.nn import Module
# Public API of this module.
__all__ = ['Surrogate']
# `classes` is an alias bound to the very same list object as `__all__`.
classes = __all__
class Surrogate(Module):
    """Base class for attackers or defenders that require
    a surrogate model for estimating labels or computing
    gradient information.

    Parameters
    ----------
    device : str, optional
        the device on which the surrogate model is placed,
        by default "cpu"
    """

    # Flag denoting whether :meth:`setup_surrogate` has completed; the
    # class-level default is shadowed on the instance once it is set up.
    _is_setup = False

    def __init__(self, device: str = "cpu"):
        super().__init__()
        self.device = torch.device(device)

    def setup_surrogate(
        self,
        surrogate: Module,
        *,
        tau: float = 1.0,
        freeze: bool = True,
        required: Optional[Union[type, Tuple[type, ...]]] = None,
    ) -> "Surrogate":
        """Method used to initialize the (trained) surrogate model.

        The surrogate is switched to eval mode, its caches are reset,
        it is moved to :attr:`device`, and (optionally) its parameters
        are frozen.

        Parameters
        ----------
        surrogate : Module
            the input surrogate module
        tau : float, optional
            temperature used for softmax activation, by default 1.0
        freeze : bool, optional
            whether to freeze the model's parameters to save time,
            by default True
        required : Optional[Union[type, Tuple[type, ...]]], optional
            the class (or tuple of classes) the surrogate model must be
            an instance of, by default None (no restriction)

        Returns
        -------
        Surrogate
            the class itself

        Raises
        ------
        RuntimeError
            if the surrogate model is not an instance of
            :class:`torch.nn.Module`
        RuntimeError
            if the surrogate model is not an instance of :obj:`required`
        """
        if not isinstance(surrogate, Module):
            raise RuntimeError(
                "The surrogate model must be an instance of `torch.nn.Module`."
            )
        if required is not None and not isinstance(surrogate, required):
            raise RuntimeError(
                f"The surrogate model is required to be `{required}`, "
                f"but got `{surrogate.__class__.__name__}`.")

        surrogate.eval()

        # Reset any caches held by the surrogate or its submodules,
        # presumably so later forward passes recompute instead of
        # reusing results cached for an earlier input.
        if hasattr(surrogate, 'cache_clear'):
            surrogate.cache_clear()
        for layer in surrogate.modules():
            if hasattr(layer, 'cached'):
                layer.cached = False

        self.surrogate = surrogate.to(self.device)
        self.tau = tau
        if freeze:
            self.freeze_surrogate()
        self._is_setup = True
        return self

    def clip_grad(
        self,
        grad: Tensor,
        grad_clip: Optional[float],
    ) -> Tensor:
        """Clip ``grad`` so that its L2 norm does not exceed ``grad_clip``.

        The norm is computed over all elements of ``grad``.

        Parameters
        ----------
        grad : Tensor
            the input gradients to clip; when clipping applies, this
            tensor is rescaled **in place**
        grad_clip : Optional[float]
            the maximum allowed L2 norm of the gradients; if None,
            ``grad`` is returned unchanged

        Returns
        -------
        Tensor
            the clipped gradients (the same tensor object as ``grad``)

        Raises
        ------
        ValueError
            if ``grad_clip`` is negative (it would flip the gradient's sign)
        """
        if grad_clip is None:
            return grad
        if grad_clip < 0:
            raise ValueError(
                f"`grad_clip` must be non-negative, but got {grad_clip}.")
        grad_len_sq = grad.square().sum()
        # Compare squared norms so no sqrt is needed when nothing is clipped.
        if grad_len_sq > grad_clip * grad_clip:
            grad *= grad_clip / grad_len_sq.sqrt()
        return grad

    def estimate_self_training_labels(
            self, nodes: Optional[Tensor] = None) -> Tensor:
        """Estimate the labels of nodes using the trained surrogate model.

        The surrogate is applied to ``self.feat``, ``self.edge_index`` and
        ``self.edge_weight``; these attributes are not set by this class and
        must be assigned before calling this method. The forward pass runs
        under :func:`torch.no_grad` since only the argmax is returned.

        Parameters
        ----------
        nodes : Optional[Tensor], optional
            the input nodes, if None, it would be all nodes in the graph,
            by default None

        Returns
        -------
        Tensor
            the labels of the input nodes.
        """
        # No gradients are needed for argmax labels; avoid building an
        # autograd graph when the surrogate's parameters are not frozen.
        with torch.no_grad():
            output = self.surrogate(self.feat, self.edge_index,
                                    self.edge_weight)
        if nodes is not None:
            output = output[nodes]
        return output.argmax(-1)

    def freeze_surrogate(self) -> "Surrogate":
        """Freeze the parameters of the surrogate model.

        Returns
        -------
        Surrogate
            the class itself
        """
        self.surrogate.requires_grad_(False)
        return self

    def defrozen_surrogate(self) -> "Surrogate":
        """Unfreeze the parameters of the surrogate model.

        Returns
        -------
        Surrogate
            the class itself
        """
        self.surrogate.requires_grad_(True)
        return self