from typing import Optional, Tuple
import torch
from torch import Tensor, nn
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import degree
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_sparse import SparseTensor, mul
from greatx.functional import spmm
from greatx.nn.layers.gcn_conv import make_gcn_norm, make_self_loops
def get_inc(edge_index: Adj, num_nodes: Optional[int] = None) -> SparseTensor:
"""Compute the incident matrix
"""
device = edge_index.device
if torch.is_tensor(edge_index):
row_index, col_index = edge_index
num_nodes = maybe_num_nodes(edge_index, num_nodes)
else:
row_index = edge_index.storage.row()
col_index = edge_index.storage.col()
num_nodes = edge_index.sizes()[1]
    mask = row_index > col_index  # drop duplicate edges and self-loops
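    # NOTE: this assumes an undirected graph stored with both directions
    # (i -> j and j -> i); keeping only ``row > col`` drops the mirrored
    # copy of each edge as well as any self-loop (row == col).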
row_index = row_index[mask]
col_index = col_index[mask]
num_edges = row_index.numel()
row = torch.cat([
torch.arange(num_edges, device=device),
torch.arange(num_edges, device=device)
])
col = torch.cat([row_index, col_index])
value = torch.cat([
torch.ones(num_edges, device=device),
-torch.ones(num_edges, device=device)
])
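    # each edge contributes two entries to its row of the incident matrix:
    # +1 for one endpoint (row_index) and -1 for the other (col_index)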
inc_mat = SparseTensor(row=row, rowptr=None, col=col, value=value,
sparse_sizes=(num_edges, num_nodes))
return inc_mat
def inc_norm(inc: SparseTensor, edge_index: Adj,
num_nodes: Optional[int] = None) -> SparseTensor:
"""Normalize the incident matrix
"""
if torch.is_tensor(edge_index):
deg = degree(edge_index[0], num_nodes=num_nodes,
dtype=torch.float).clamp(min=1)
else:
deg = edge_index.sum(1).clamp(min=1)
deg_inv_sqrt = deg.pow(-0.5)
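    # scaling column j by deg(j)^{-1/2} makes ``inc.t() @ inc`` (approximately)
    # the symmetrically normalized graph Laplacian, matching the GCN-style
    # normalization used for message passing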
inc = mul(inc, deg_inv_sqrt.view(1, -1)) # col-wise
return inc
class ElasticConv(nn.Module):
r"""
The ElasticGNN operator from the `"Elastic Graph Neural
Networks" <https://arxiv.org/abs/2107.06996>`_
paper (ICML'21)
Parameters
----------
K : int, optional
the number of propagation steps, by default 3
lambda_amp : float, optional
trade-off of adaptive message passing, by default 0.1
    normalize : bool, optional
        whether to compute symmetric (GCN) normalization
        coefficients on the fly, by default True
add_self_loops : bool, optional
whether to add self-loops to the input graph, by default True
lambda1 : float, optional
trade-off hyperparameter, by default 3
lambda2 : float, optional
trade-off hyperparameter, by default 3
    L21 : bool, optional
        whether to use row-wise projection
        onto the l2 ball of radius λ1, by default True
cached : bool, optional
whether to cache the incident matrix, by default True
See also
--------
:class:`greatx.nn.models.supervised.ElasticGNN`
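
    Example
    -------
    A minimal usage sketch (toy edge list and random features,
    both purely illustrative)::

        >>> import torch
        >>> conv = ElasticConv(K=3, lambda1=3., lambda2=3.)
        >>> x = torch.randn(4, 16)
        >>> edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
        ...                            [1, 0, 2, 1, 3, 2]])
        >>> conv(x, edge_index).shape
        torch.Size([4, 16])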
"""
    # cached (incident matrix, edge_index, edge_weight) tuple
    _cached: Optional[Tuple[SparseTensor, Adj, OptTensor]] = None
def __init__(self, K: int = 3, lambda_amp: float = 0.1,
normalize: bool = True, add_self_loops: bool = True,
lambda1: float = 3., lambda2: float = 3., L21: bool = True,
cached: bool = True):
super().__init__()
self.K = K
self.lambda_amp = lambda_amp
self.add_self_loops = add_self_loops
self.normalize = normalize
self.lambda1 = lambda1
self.lambda2 = lambda2
self.L21 = L21
self.cached = cached
    def reset_parameters(self):
self.cache_clear()
    def cache_clear(self):
"""Clear cached inputs or intermediate results."""
self._cached = None
return self
    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
""""""
cache = self._cached
if cache is None:
if self.add_self_loops:
# NOTE: we do not support Dense adjacency matrix here
edge_index, edge_weight = make_self_loops(
edge_index, edge_weight, num_nodes=x.size(0))
if self.normalize:
# NOTE: we do not support Dense adjacency matrix here
edge_index, edge_weight = make_gcn_norm(
edge_index, edge_weight, num_nodes=x.size(0),
dtype=x.dtype, add_self_loops=False)
            # build the (unnormalized) incident matrix from the graph connectivity
inc_mat = get_inc(edge_index, num_nodes=x.size(0))
            # normalize the incident matrix by node degrees
inc_mat = inc_norm(inc_mat, edge_index, num_nodes=x.size(0))
if self.cached:
self._cached = (inc_mat, edge_index, edge_weight)
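            # ``init_z`` is the initial edge-wise dual variable consumed by
            # ``emp_forward``; it has one row per row of the incident matrix.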
self.init_z = x.new_zeros((inc_mat.sizes()[0], x.size()[-1]))
else:
inc_mat, edge_index, edge_weight = self._cached
return self.emp_forward(x, inc_mat, edge_index, edge_weight)
    def emp_forward(self, x: Tensor, inc_mat: SparseTensor, edge_index: Adj,
                    edge_weight: OptTensor = None) -> Tensor:
lambda1 = self.lambda1
lambda2 = self.lambda2
gamma = 1 / (1 + lambda2)
beta = 1 / (2 * gamma)
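        # Our reading of the propagation loop below: each of the K steps first
        # takes a smoothing step toward the aggregated features (y), then
        # updates the edge-wise dual variable z through the incident matrix and
        # projects it (L1 or L21), and finally maps z back to the node space.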
hh = x
if lambda1:
z = self.init_z
for k in range(self.K):
if lambda2:
out = spmm(x, edge_index, edge_weight)
y = gamma * hh + (1 - gamma) * out
else:
y = gamma * hh + (1 - gamma) * x # y = x - gamma * (x - hh)
if lambda1:
x_bar = y - gamma * (inc_mat.t() @ z)
z_bar = z + beta * (inc_mat @ x_bar)
if self.L21:
z = self.L21_projection(z_bar, lambda_=lambda1)
else:
z = self.L1_projection(z_bar, lambda_=lambda1)
x = y - gamma * (inc_mat.t() @ z)
else:
x = y # z=0
return x
    def L1_projection(self, x: Tensor, lambda_: float) -> Tensor:
"""component-wise projection onto the l∞ ball of radius λ1."""
return torch.clamp(x, min=-lambda_, max=lambda_)
    def L21_projection(self, x: Tensor, lambda_: float) -> Tensor:
        """Row-wise projection onto the l2 ball of radius λ1."""
row_norm = torch.norm(x, p=2, dim=1)
scale = torch.clamp(row_norm, max=lambda_)
index = row_norm > 0
        scale[index] = scale[index] / row_norm[index]  # avoid division by zero
return scale.unsqueeze(1) * x
def __repr__(self) -> str:
return f"{self.__class__.__name__}(K={self.K})"