from typing import Optional
import torch
from torch import Tensor, nn
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.utils import add_self_loops
from greatx.functional import spmm
from greatx.nn.layers.gcn_conv import dense_add_self_loops, dense_gcn_norm
class SATConv(nn.Module):
    r"""The spectral adversarial training operator
    from the `"Spectral Adversarial Training for Robust Graph Neural Network"
    <https://arxiv.org/>`_ paper (arXiv'22)

    Parameters
    ----------
    in_channels : int
        dimensions of input samples
    out_channels : int
        dimensions of output samples
    add_self_loops : bool, optional
        whether to add self-loops to the input graph, by default True
    normalize : bool, optional
        whether to compute symmetric normalization
        coefficients on the fly, by default True
    bias : bool, optional
        whether to use bias in the layers, by default False

    Note
    ----
    For the inputs :obj:`x`, :obj:`U`, and :obj:`V`,
    our implementation supports:

    (1) :obj:`U` is :class:`torch.LongTensor`,
    denoting edge indices with shape :obj:`[2, M]`;

    (2) :obj:`U` is :class:`torch.FloatTensor` and :obj:`V` is :obj:`None`,
    denoting dense matrix with shape :obj:`[N, N]`;

    (3) :obj:`U` and :obj:`V` are :class:`torch.FloatTensor`,
    denoting eigenvector and corresponding eigenvalues.

    :obj:`add_self_loops` and :obj:`normalize` only take effect in
    cases (1) and (2); they are ignored in case (3).

    See also
    --------
    :class:`greatx.nn.models.supervised.SAT`
    """
    def __init__(self, in_channels: int, out_channels: int,
                 add_self_loops: bool = True, normalize: bool = True,
                 bias: bool = False):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        # The linear map never carries its own bias; the optional bias is
        # added after propagation in `forward`.
        self.lin = Linear(in_channels, out_channels, bias=False,
                          weight_initializer='glorot')

        if bias:
            # Allocated uninitialised here; `reset_parameters` below gives
            # it a well-defined starting value.
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the linear weight and zero the bias (if any)."""
        self.lin.reset_parameters()
        if self.bias is not None:
            # Without this the bias would keep whatever happened to be in
            # the freshly allocated memory.
            nn.init.zeros_(self.bias)

    def forward(self, x: Tensor, U: Tensor,
                V: Optional[Tensor] = None) -> Tensor:
        """Transform :obj:`x` linearly, then propagate it.

        Parameters
        ----------
        x : Tensor
            node features; its first dimension is used as the number
            of nodes
        U : Tensor
            edge indices (``torch.long``), a dense adjacency matrix
            (when :obj:`V` is None), or eigenvectors (when :obj:`V` is given)
        V : Optional[Tensor], optional
            edge weights when :obj:`U` holds edge indices, otherwise the
            eigenvalues paired with :obj:`U`, by default None

        Returns
        -------
        Tensor
            the propagated features, plus the bias when one is used
        """
        # NOTE: torch_sparse.SparseTensor is not supported
        x = self.lin(x)

        if isinstance(U, Tensor) and U.dtype == torch.long:
            # Case (1): sparse graph given as edge indices (+ weights).
            edge_index, edge_weight = U, V
            num_nodes = x.size(0)

            if self.add_self_loops:
                edge_index, edge_weight = add_self_loops(
                    edge_index, edge_weight, num_nodes=num_nodes)

            if self.normalize:
                # Self-loops were already handled above, hence
                # `add_self_loops=False` here.
                edge_index, edge_weight = gcn_norm(  # yapf: disable
                    edge_index, edge_weight, num_nodes, False,
                    add_self_loops=False, dtype=x.dtype)

            x = spmm(x, edge_index, edge_weight)
        elif V is None:
            # Case (2): dense adjacency matrix.
            adj = U

            if self.add_self_loops:
                adj = dense_add_self_loops(adj)

            if self.normalize:
                adj = dense_gcn_norm(adj, add_self_loops=False)

            x = adj @ x
        else:
            # Case (3): eigen-decomposition, i.e. U diag(V) U^T x computed
            # without materialising the N x N product.
            x = (U * V) @ (U.t() @ x)

        if self.bias is not None:
            x = x + self.bias

        return x

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')
# class SATConv(nn.Module):
# def __init__(self, in_channels: int, out_channels: int,
# K: int = 5,
# alpha: float = 0.1,
# add_self_loops: bool = True,
# normalize: bool = True,
# bias: bool = False):
# super().__init__()
# self.in_channels = in_channels
# self.out_channels = out_channels
# self.K = K
# self.alpha = alpha
# self.add_self_loops = add_self_loops
# self.normalize = normalize
# self.lin = Linear(in_channels, out_channels, bias=False,
# weight_initializer='glorot')
# if bias:
# self.bias = nn.Parameter(torch.Tensor(out_channels))
# else:
# self.register_parameter('bias', None)
# self.reset_parameters()
# def reset_parameters(self):
# self.lin.reset_parameters()
# def forward(self, x: Tensor, U: Tensor, V: Optional[Tensor] = None):
# is_edge_like = is_edge_index(U)
# x = self.lin(x)
# if is_edge_like:
# edge_index, edge_weight = U, V
# if self.normalize:
# edge_index, edge_weight = gcn_norm( # yapf: disable
# edge_index, edge_weight, x.size(0), False,
# self.add_self_loops, dtype=x.dtype)
# x_out = self.alpha*x
# for _ in range(self.K):
# x = spmm(x, edge_index, edge_weight)
# x_out = x_out + (1 - self.alpha)/self.K * x
# elif V is None:
# adj = U
# if self.normalize:
# adj = dense_gcn_norm(adj, add_self_loops=self.add_self_loops)
# x_in = x
# x_out = torch.zeros_like(x)
# for _ in range(self.K):
# x = adj @ x
# x_out += (1 - self.alpha) * x
# x_out /= self.K
# x_out += self.alpha * x_in
# else:
# V_out = 0.
# V_pow = 1.
# for _ in range(self.K):
# V_pow = V_pow * V
# V_out = V_out + (1 - self.alpha) / self.K * V_pow
# x_out = (U * V_out) @ (U.t() @ x) + self.alpha * x
# if self.bias is not None:
# x_out += self.bias
# return x_out
# def __repr__(self) -> str:
# return (f'{self.__class__.__name__}({self.in_channels}, '
# f'{self.out_channels}, K={self.K}, alpha={self.alpha})')