Source code for greatx.nn.layers.sat_conv

from typing import Optional

import torch
from torch import Tensor, nn
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import add_self_loops

from greatx.functional import spmm
from greatx.nn.layers.gcn_conv import dense_add_self_loops, dense_gcn_norm


class SATConv(nn.Module):
    r"""The spectral adversarial training operator from the
    `"Spectral Adversarial Training for Robust Graph Neural Network"
    <https://arxiv.org/>`_ paper (arXiv'22)

    Parameters
    ----------
    in_channels : int
        dimensions of input samples
    out_channels : int
        dimensions of output samples
    add_self_loops : bool, optional
        whether to add self-loops to the input graph, by default True
    normalize : bool, optional
        whether to compute symmetric normalization coefficients on the fly,
        by default True
    bias : bool, optional
        whether to use bias in the layers, by default False

    Note
    ----
    For the inputs :obj:`x`, :obj:`U`, and :obj:`V`,
    our implementation supports:

    (1) :obj:`U` is :class:`torch.LongTensor`,
        denoting edge indices with shape :obj:`[2, M]`;
    (2) :obj:`U` is :class:`torch.FloatTensor` and :obj:`V` is :obj:`None`,
        denoting a dense adjacency matrix with shape :obj:`[N, N]`;
    (3) :obj:`U` and :obj:`V` are :class:`torch.FloatTensor`,
        denoting eigenvectors and corresponding eigenvalues.

    See also
    --------
    :class:`greatx.nn.models.supervised.SAT`
    """
    def __init__(self, in_channels: int, out_channels: int,
                 add_self_loops: bool = True, normalize: bool = True,
                 bias: bool = False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        self.lin = Linear(in_channels, out_channels, bias=False,
                          weight_initializer='glorot')

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()
    def reset_parameters(self):
        self.lin.reset_parameters()
        # Zero-initialize the optional bias; `torch.Tensor(out_channels)`
        # allocates uninitialized memory and is never reset otherwise.
        if self.bias is not None:
            nn.init.zeros_(self.bias)
    def forward(self, x: Tensor, U: Tensor, V: Optional[Tensor] = None):
        """"""
        # NOTE: torch_sparse.SparseTensor is not supported
        x = self.lin(x)

        if isinstance(U, Tensor) and U.dtype == torch.long:
            # Case (1): `U` is an edge index of shape [2, M],
            # `V` holds the (optional) edge weights.
            edge_index, edge_weight = U, V
            if self.add_self_loops:
                edge_index, edge_weight = add_self_loops(
                    edge_index, edge_weight, num_nodes=x.size(0))
            if self.normalize:
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, x.size(0), False,
                    add_self_loops=False, dtype=x.dtype)
            x = spmm(x, edge_index, edge_weight)
        elif V is None:
            # Case (2): `U` is a dense adjacency matrix of shape [N, N].
            adj = U
            if self.add_self_loops:
                adj = dense_add_self_loops(adj)
            if self.normalize:
                adj = dense_gcn_norm(adj, add_self_loops=False)
            x = adj @ x
        else:
            # Case (3): `U` holds eigenvectors and `V` the matching
            # eigenvalues; this computes U @ diag(V) @ U^T @ x.
            x = (U * V) @ (U.t() @ x)

        if self.bias is not None:
            x += self.bias

        return x
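    # Sanity-check sketch (not part of the library source; `U_`, `V_`, `x_`
    # and the shapes below are illustrative assumptions): the spectral
    # branch above computes U @ diag(V) @ U.T @ x without materializing
    # the dense [N, N] matrix.
    #
    #   N, k, F = 50, 10, 8
    #   U_ = torch.linalg.qr(torch.randn(N, k)).Q   # orthonormal columns
    #   V_, x_ = torch.randn(k), torch.randn(N, F)
    #   fast = (U_ * V_) @ (U_.t() @ x_)
    #   slow = U_ @ torch.diag(V_) @ U_.t() @ x_
    #   assert torch.allclose(fast, slow, atol=1e-5)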
    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')
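# Usage sketch (illustrative, not part of the library source; the graph
# sizes and tensors below are made-up assumptions) showing the three input
# modes documented in the class docstring:
#
#   conv = SATConv(16, 32)
#   x = torch.randn(100, 16)
#
#   # (1) edge indices of shape [2, M]
#   edge_index = torch.randint(0, 100, (2, 400))
#   out = conv(x, edge_index)
#
#   # (2) dense adjacency matrix of shape [N, N], `V` left as None
#   adj = (torch.rand(100, 100) < 0.05).float()
#   out = conv(x, adj)
#
#   # (3) eigenvectors [N, k] and eigenvalues [k], e.g. from a
#   #     symmetrized adjacency decomposition
#   evals, evecs = torch.linalg.eigh(adj + adj.t())
#   out = conv(x, evecs[:, :10], evals[:10])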
# class SATConv(nn.Module):
#     def __init__(self, in_channels: int, out_channels: int,
#                  K: int = 5,
#                  alpha: float = 0.1,
#                  add_self_loops: bool = True,
#                  normalize: bool = True,
#                  bias: bool = False):
#         super().__init__()
#         self.in_channels = in_channels
#         self.out_channels = out_channels
#         self.K = K
#         self.alpha = alpha
#         self.add_self_loops = add_self_loops
#         self.normalize = normalize

#         self.lin = Linear(in_channels, out_channels, bias=False,
#                           weight_initializer='glorot')

#         if bias:
#             self.bias = nn.Parameter(torch.Tensor(out_channels))
#         else:
#             self.register_parameter('bias', None)

#         self.reset_parameters()

#     def reset_parameters(self):
#         self.lin.reset_parameters()

#     def forward(self, x: Tensor, U: Tensor, V: Optional[Tensor] = None):
#         is_edge_like = is_edge_index(U)
#         x = self.lin(x)

#         if is_edge_like:
#             edge_index, edge_weight = U, V
#             if self.normalize:
#                 edge_index, edge_weight = gcn_norm(  # yapf: disable
#                     edge_index, edge_weight, x.size(0), False,
#                     self.add_self_loops, dtype=x.dtype)
#             x_out = self.alpha * x
#             for _ in range(self.K):
#                 x = spmm(x, edge_index, edge_weight)
#                 x_out = x_out + (1 - self.alpha) / self.K * x
#         elif V is None:
#             adj = U
#             if self.normalize:
#                 adj = dense_gcn_norm(adj, add_self_loops=self.add_self_loops)
#             x_in = x
#             x_out = torch.zeros_like(x)
#             for _ in range(self.K):
#                 x = adj @ x
#                 x_out += (1 - self.alpha) * x
#             x_out /= self.K
#             x_out += self.alpha * x_in
#         else:
#             V_out = 0.
#             V_pow = 1.
#             for _ in range(self.K):
#                 V_pow = V_pow * V
#                 V_out = V_out + (1 - self.alpha) / self.K * V_pow
#             x_out = (U * V_out) @ (U.t() @ x) + self.alpha * x

#         if self.bias is not None:
#             x_out += self.bias

#         return x_out

#     def __repr__(self) -> str:
#         return (f'{self.__class__.__name__}({self.in_channels}, '
#                 f'{self.out_channels}, K={self.K}, alpha={self.alpha})')
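# Note on the commented-out K-step variant above: reading off the loops, all
# three branches realize the same truncated personalized-PageRank-style
# filter,
#
#     out = alpha * x + (1 - alpha) / K * sum_{k=1..K} A^k x,
#
# applied via sparse propagation, dense matrix powers, or elementwise powers
# of the eigenvalues in the spectral basis, respectively.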