import math
import torch
from torch import Tensor, nn
class TensorGCNConv(nn.Module):
    r"""The robust tensor graph convolutional operator from
    the `"Robust Tensor Graph Convolutional Networks
    via T-SVD based Graph Augmentation"
    <https://dl.acm.org/doi/abs/10.1145/3534678.3539436>`_ paper (KDD'22)

    The layer computes ``fft_product(adjs, fft_product(x, weight)) + bias``,
    where :meth:`fft_product` multiplies two 3-D tensors slice-by-slice in
    the Fourier domain of their last axis.

    Parameters
    ----------
    in_channels : int
        dimensions of input samples
    out_channels : int
        dimensions of output samples
    num_nodes : int
        number of input nodes
    num_channels : int
        number of input channels (adjacency matrices)
    bias : bool, optional
        whether to use bias in the layers, by default True

    See also
    --------
    :class:`greatx.nn.models.supervised.RTGCN`
    """

    def __init__(self, in_channels: int, out_channels: int, num_nodes: int,
                 num_channels: int, bias: bool = True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_channels = num_channels

        # Allocated uninitialised; values are filled by `reset_parameters`.
        self.weight = nn.Parameter(
            torch.empty(in_channels, out_channels, num_channels))
        if bias:
            self.bias = nn.Parameter(
                torch.empty(num_nodes, out_channels, num_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weight`` (and ``bias`` if present) uniformly in
        ``[-stdv, stdv]`` with ``stdv = 1 / sqrt(out_channels)``."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -stdv, stdv)

    def forward(self, x: Tensor, adjs: Tensor) -> Tensor:
        """Apply the tensor graph convolution.

        Parameters
        ----------
        x : Tensor
            node features, either 3-D with shape
            ``(num_nodes, in_channels, num_channels)`` or 2-D with shape
            ``(num_nodes, in_channels)``; a 2-D input is replicated along
            a new last axis, once per slice of ``adjs``.
        adjs : Tensor
            3-D stack of adjacency matrices whose last axis indexes the
            channels, i.e. ``(num_nodes, num_nodes, num_channels)``.

        Returns
        -------
        Tensor
            output of shape ``(num_nodes, out_channels, num_channels)``.
        """
        if x.ndim == 2:
            # Share the same feature matrix across every adjacency channel.
            x = x.repeat(adjs.size(-1), 1, 1).permute(1, 2, 0)

        x = self.fft_product(x, self.weight)
        out = self.fft_product(adjs, x)

        if self.bias is not None:
            # Out-of-place: `out` is a real view of the complex ifft buffer,
            # so avoid mutating that intermediate result.
            out = out + self.bias
        return out

    @staticmethod
    def fft_product(X: Tensor, Y: Tensor) -> Tensor:
        """Multiply two 3-D tensors in the Fourier domain of the last axis.

        Both inputs are transformed with an FFT along their last axis,
        multiplied as matrices independently for every index of that axis
        (``'ijk,jrk->irk'``), and transformed back; only the real part of
        the result is returned.

        Parameters
        ----------
        X : Tensor
            left operand of shape ``(i, j, k)``.
        Y : Tensor
            right operand of shape ``(j, r, k)``.

        Returns
        -------
        Tensor
            real tensor of shape ``(i, r, k)``.
        """
        X = torch.fft.fft(X)
        Y = torch.fft.fft(Y)
        Z = torch.fft.ifft(torch.einsum('ijk,jrk->irk', X, Y))
        return Z.real

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(({self.in_channels}, '
                f'{self.num_channels}), '
                f'({self.out_channels}, {self.num_channels}))')
class TensorLinear(nn.Module):
    r"""The tensor linear operator from
    the `"Robust Tensor Graph Convolutional Networks
    via T-SVD based Graph Augmentation"
    <https://dl.acm.org/doi/abs/10.1145/3534678.3539436>`_ paper (KDD'22)

    The layer collapses the last (channel) axis of a 3-D input by a learned
    weighted sum over the channels and optionally adds a per-node bias.

    Parameters
    ----------
    in_channels : int
        dimensions of input samples
    num_nodes : int
        number of input nodes
    num_channels : int
        number of input channels (size of the axis that is collapsed)
    bias : bool, optional
        whether to use bias in the layers, by default True

    See also
    --------
    :class:`greatx.nn.models.supervised.RTGCN`
    """

    def __init__(
        self,
        in_channels: int,
        num_nodes: int,
        num_channels: int,
        bias: bool = True,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.num_channels = num_channels

        # Allocated uninitialised; values are filled by `reset_parameters`.
        self.weight = nn.Parameter(torch.empty(num_channels, 1))
        if bias:
            self.bias = nn.Parameter(torch.empty(num_nodes, in_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weight`` (and ``bias`` if present) uniformly in
        ``[-stdv, stdv]``."""
        # NOTE(review): the weight has shape (num_channels, 1), so size(1)
        # is always 1 and stdv is always 1.0. Possibly size(0) was intended;
        # kept as-is to preserve the existing initialisation - confirm.
        stdv = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -stdv, stdv)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -stdv, stdv)

    def forward(self, x: Tensor) -> Tensor:
        """Collapse the channel axis of ``x``.

        Parameters
        ----------
        x : Tensor
            3-D input of shape ``(num_nodes, in_channels, num_channels)``.

        Returns
        -------
        Tensor
            output of shape ``(num_nodes, in_channels)``.
        """
        # Squeeze only the contracted axis: a bare ``squeeze()`` would also
        # drop the node/feature axes when either has size 1, which changes
        # the output rank and breaks the bias addition.
        out = torch.einsum('ijk,kr->ijr', x, self.weight).squeeze(-1)
        if self.bias is not None:
            out = out + self.bias
        return out