# Source code for greatx.nn.layers.soft_median_conv
from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import OptTensor
from torch_geometric.utils import add_self_loops, coalesce
from torch_sparse import SparseTensor
try:
from glcore import dimmedian_idx # noqa
except (ModuleNotFoundError, ImportError):
dimmedian_idx = None
class SoftMedianConv(nn.Module):
    r"""The graph convolutional operator with soft
    median aggregation from the `"Robustness of Graph Neural Networks
    at Scale" <https://arxiv.org/abs/2110.14038>`_ paper (NeurIPS'21)

    Parameters
    ----------
    in_channels : int
        dimensions of input samples
    out_channels : int
        dimensions of output samples
    cached : bool, optional
        whether the layer will cache
        the computation of :math:`(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
        \mathbf{\hat{D}}^{-1/2})` and sorted edges on first execution,
        and will use the cached version for further executions,
        by default False
    add_self_loops : bool, optional
        whether to add self-loops to the input graph, by default True
    normalize : bool, optional
        whether to compute symmetric normalization
        coefficients on the fly, by default False
    row_normalize : bool, optional
        whether to perform row-normalization on the fly, by default True
    bias : bool, optional
        whether to use bias in the layers, by default True

    Raises
    ------
    RuntimeWarning
        if module `"glcore" <https://github.com/EdisonLeeeee/glcore>`_
        is not properly installed.

    Note
    ----
    The input edges must be sorted for :meth:`dimmedian_idx`
    from :class:`glcore`

    See also
    --------
    :class:`greatx.nn.models.supervised.SoftMedianGCN`
    """

    # Pre-processed ``(edge_index, edge_weight)`` pair, populated on the
    # first forward pass when ``cached=True``; ``None`` means "not cached".
    _cached_edges: Optional[Tuple[Tensor, Tensor]] = None

    def __init__(self, in_channels: int, out_channels: int,
                 cached: bool = False, add_self_loops: bool = True,
                 normalize: bool = False, row_normalize: bool = True,
                 bias: bool = True):
        super().__init__()

        # The aggregation kernel lives in the optional ``glcore`` extension;
        # the module-level import leaves ``dimmedian_idx`` as ``None`` when
        # it is unavailable, so fail early at construction time.
        if dimmedian_idx is None:
            raise RuntimeWarning(
                "Module 'glcore' is not properly installed, "
                "please refer 'https://github.com/EdisonLeeeee/glcore' "
                "for more information.")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize
        self.row_normalize = row_normalize

        # The linear map carries no bias of its own; the bias is added
        # after aggregation in :meth:`forward`.
        self.lin = Linear(in_channels, out_channels, bias=False,
                          weight_initializer='glorot')

        if bias:
            # Allocated uninitialised; ``reset_parameters`` zeroes it.
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the learnable parameters and drop cached edges.

        The linear layer is reset through its own ``reset_parameters``,
        the bias (when present) is filled with zeros, and any cached
        edges are cleared.
        """
        self.lin.reset_parameters()
        if self.bias is not None:
            nn.init.zeros_(self.bias)
        self.cache_clear()

    def cache_clear(self):
        """Clear cached inputs or intermediate results."""
        self._cached_edges = None
        return self

    def forward(self, x: Tensor, edge_index: Tensor,
                edge_weight: Optional[Tensor] = None) -> Tensor:
        """Apply the linear map followed by soft-median aggregation.

        Parameters
        ----------
        x : Tensor
            the node feature matrix
        edge_index : Tensor
            the edge indices, or a :class:`torch_sparse.SparseTensor`
            (dense adjacency matrices are not supported)
        edge_weight : Optional[Tensor], optional
            the edge weights; all-ones weights are used when None,
            by default None

        Returns
        -------
        Tensor
            the aggregated node features

        Note
        ----
        Once edges have been cached (``cached=True``), the
        ``edge_index`` and ``edge_weight`` arguments are ignored until
        :meth:`cache_clear` is called.
        """
        x = self.lin(x)

        if self._cached_edges is not None:
            edge_index, edge_weight = self._cached_edges
        else:
            # NOTE: we do not support Dense adjacency matrix here
            if isinstance(edge_index, SparseTensor):
                row, col, edge_weight = edge_index.coo()
                edge_index = torch.stack([row, col], dim=0)

            if self.add_self_loops:
                edge_index, edge_weight = add_self_loops(
                    edge_index, edge_weight, num_nodes=x.size(0))

            if self.normalize:
                # Self-loops were already handled above, hence
                # ``add_self_loops=False`` here.
                edge_index, edge_weight = gcn_norm(edge_index, edge_weight,
                                                   x.size(0), improved=False,
                                                   add_self_loops=False,
                                                   dtype=x.dtype)

            if edge_weight is None:
                edge_weight = x.new_ones(edge_index.size(1))

            # Coalesce before aggregation; see the class-level note that
            # ``dimmedian_idx`` requires sorted edges.
            edge_index, edge_weight = coalesce(edge_index, edge_weight)

            # cache edges
            if self.cached:
                self._cached_edges = edge_index, edge_weight

        x = soft_median_reduce(x, edge_index, edge_weight)

        # Normalization and calculation of new embeddings: every row is
        # scaled by the summed weight of the edges indexed by
        # ``edge_index[0]``.
        if self.row_normalize:
            row_sum = edge_weight.new_zeros(x.size(0))
            row_sum.scatter_add_(0, edge_index[0], edge_weight)
            x = row_sum.view(-1, 1) * x

        if self.bias is not None:
            x = x + self.bias

        return x

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels})')
def soft_median_reduce(x: Tensor, edge_index: Tensor,
                       edge_weight: Tensor) -> Tensor:
    """Weighted dimension-wise median aggregation.

    Parameters
    ----------
    x : Tensor
        the two-dimensional node feature matrix
    edge_index : Tensor
        the edge indices, unpacked into a (row, col) pair
    edge_weight : Tensor
        the edge weights, must not be None

    Returns
    -------
    Tensor
        entries of ``x`` gathered at the indices returned by
        ``glcore.dimmedian_idx``, one per feature column
    """
    assert edge_weight is not None
    src, dst = edge_index
    num_nodes, num_feats = x.shape

    # NOTE(review): presumably ``dimmedian_idx`` yields, for every node and
    # feature dimension, the index of the weighted-median neighbor, as an
    # integer tensor of shape (num_nodes, num_feats) -- confirm with glcore.
    picked = dimmedian_idx(x, src, dst, edge_weight, num_nodes)

    # Pair each picked node index with its own feature column so that the
    # advanced indexing below selects one entry per (node, feature) cell.
    feat_idx = torch.arange(num_feats, device=src.device)
    feat_idx = feat_idx.view(1, -1).expand(num_nodes, num_feats)
    return x[picked, feat_idx]