Source code for greatx.nn.layers.elastic_conv

from typing import Optional

import torch
from torch import Tensor, nn
from torch_geometric.typing import Adj, OptTensor
from torch_geometric.utils import degree
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_sparse import SparseTensor, mul

from greatx.functional import spmm
from greatx.nn.layers.gcn_conv import make_gcn_norm, make_self_loops


def get_inc(edge_index: Adj, num_nodes: Optional[int] = None) -> SparseTensor:
    """Compute the incident matrix
    """
    device = edge_index.device
    if torch.is_tensor(edge_index):
        row_index, col_index = edge_index
        num_nodes = maybe_num_nodes(edge_index, num_nodes)
    else:
        row_index = edge_index.storage.row()
        col_index = edge_index.storage.col()
        num_nodes = edge_index.sizes()[1]

    mask = row_index > col_index  # remove duplicate edges and self-loops

    row_index = row_index[mask]
    col_index = col_index[mask]
    num_edges = row_index.numel()

    row = torch.cat([
        torch.arange(num_edges, device=device),
        torch.arange(num_edges, device=device)
    ])
    col = torch.cat([row_index, col_index])
    value = torch.cat([
        torch.ones(num_edges, device=device),
        -torch.ones(num_edges, device=device)
    ])
    inc_mat = SparseTensor(row=row, rowptr=None, col=col, value=value,
                           sparse_sizes=(num_edges, num_nodes))
    return inc_mat
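

# Illustrative sketch, not part of the original module: build the oriented
# incidence matrix of a tiny 3-node path graph 0 - 1 - 2. The helper name
# ``_demo_get_inc`` is hypothetical and only added for illustration.
def _demo_get_inc() -> SparseTensor:
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    inc = get_inc(edge_index, num_nodes=3)
    # two undirected edges over three nodes -> a 2 x 3 incidence matrix,
    # e.g. inc.to_dense()[0] == [-1., 1., 0.]
    return inc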


def inc_norm(inc: SparseTensor, edge_index: Adj,
             num_nodes: Optional[int] = None) -> SparseTensor:
    """Normalize the incident matrix
    """

    if torch.is_tensor(edge_index):
        deg = degree(edge_index[0], num_nodes=num_nodes,
                     dtype=torch.float).clamp(min=1)
    else:
        deg = edge_index.sum(1).clamp(min=1)

    deg_inv_sqrt = deg.pow(-0.5)
    # scale each column j of the incidence matrix by deg(j)^(-1/2)
    inc = mul(inc, deg_inv_sqrt.view(1, -1))
    return inc
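

# Illustrative sketch, not part of the original module: normalize the
# incidence matrix of the same 3-node path graph. Each column j is scaled by
# deg(j)^(-1/2), so entries touching the degree-2 center node shrink by a
# factor of 1/sqrt(2). The helper name ``_demo_inc_norm`` is hypothetical.
def _demo_inc_norm() -> SparseTensor:
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    inc = inc_norm(get_inc(edge_index, num_nodes=3), edge_index, num_nodes=3)
    # inc.to_dense() is approximately
    # [[-1.0000,  0.7071,  0.0000],
    #  [ 0.0000, -0.7071,  1.0000]]
    return inc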


class ElasticConv(nn.Module):
    r"""The ElasticGNN operator from the `"Elastic Graph Neural Networks"
    <https://arxiv.org/abs/2107.06996>`_ paper (ICML'21)

    Parameters
    ----------
    K : int, optional
        the number of propagation steps, by default 3
    lambda_amp : float, optional
        trade-off of adaptive message passing, by default 0.1
    normalize : bool, optional
        whether to add self-loops and compute symmetric normalization
        coefficients on the fly, by default True
    add_self_loops : bool, optional
        whether to add self-loops to the input graph, by default True
    lambda1 : float, optional
        trade-off hyperparameter for the sparse (l1/l21) smoothing term,
        by default 3
    lambda2 : float, optional
        trade-off hyperparameter for the l2 smoothing term, by default 3
    L21 : bool, optional
        whether to use row-wise projection onto the l2 ball of radius
        λ1, by default True
    cached : bool, optional
        whether to cache the incidence matrix, by default True

    See also
    --------
    :class:`greatx.nn.models.supervised.ElasticGNN`
    """
    _cached: Optional[SparseTensor] = None  # cached incidence matrix

    def __init__(self, K: int = 3, lambda_amp: float = 0.1,
                 normalize: bool = True, add_self_loops: bool = True,
                 lambda1: float = 3., lambda2: float = 3., L21: bool = True,
                 cached: bool = True):
        super().__init__()
        self.K = K
        self.lambda_amp = lambda_amp
        self.add_self_loops = add_self_loops
        self.normalize = normalize
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.L21 = L21
        self.cached = cached

    def reset_parameters(self):
        self.cache_clear()

    def cache_clear(self):
        """Clear cached inputs or intermediate results."""
        self._cached = None
        return self

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """"""
        cache = self._cached
        if cache is None:
            if self.add_self_loops:
                # NOTE: we do not support Dense adjacency matrix here
                edge_index, edge_weight = make_self_loops(
                    edge_index, edge_weight, num_nodes=x.size(0))

            if self.normalize:
                # NOTE: we do not support Dense adjacency matrix here
                edge_index, edge_weight = make_gcn_norm(
                    edge_index, edge_weight, num_nodes=x.size(0),
                    dtype=x.dtype, add_self_loops=False)

            # compute the incidence matrix
            inc_mat = get_inc(edge_index, num_nodes=x.size(0))
            # normalize the incidence matrix
            inc_mat = inc_norm(inc_mat, edge_index, num_nodes=x.size(0))

            if self.cached:
                self._cached = (inc_mat, edge_index, edge_weight)
            self.init_z = x.new_zeros((inc_mat.sizes()[0], x.size()[-1]))
        else:
            inc_mat, edge_index, edge_weight = self._cached

        return self.emp_forward(x, inc_mat, edge_index, edge_weight)

    def emp_forward(self, x: Tensor, inc_mat: SparseTensor, edge_index: Adj,
                    edge_weight: OptTensor = None) -> Tensor:
        lambda1 = self.lambda1
        lambda2 = self.lambda2

        gamma = 1 / (1 + lambda2)
        beta = 1 / (2 * gamma)

        hh = x

        if lambda1:
            z = self.init_z

        for k in range(self.K):
            if lambda2:
                out = spmm(x, edge_index, edge_weight)
                y = gamma * hh + (1 - gamma) * out
            else:
                y = gamma * hh + (1 - gamma) * x  # y = x - gamma * (x - hh)

            if lambda1:
                x_bar = y - gamma * (inc_mat.t() @ z)
                z_bar = z + beta * (inc_mat @ x_bar)
                if self.L21:
                    z = self.L21_projection(z_bar, lambda_=lambda1)
                else:
                    z = self.L1_projection(z_bar, lambda_=lambda1)
                x = y - gamma * (inc_mat.t() @ z)
            else:
                x = y  # z=0

        return x
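
    # Reading note (added for clarity; not in the original source): each of
    # the K iterations in ``emp_forward`` roughly alternates
    #   (i)  an l2 smoothing step
    #           y = gamma * hh + (1 - gamma) * A_hat @ x,
    #        with gamma = 1 / (1 + lambda2), pulling x toward both the input
    #        features hh and the normalized neighborhood average, and
    #   (ii) a primal-dual step on the sparse smoothing term, where the dual
    #        variable z is kept feasible by projecting it onto an l∞ ball
    #        (L1 penalty) or a row-wise l2 ball (L21 penalty) of radius
    #        lambda1.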

    def L1_projection(self, x: Tensor, lambda_: float) -> Tensor:
        """Component-wise projection onto the l∞ ball of radius λ1."""
        return torch.clamp(x, min=-lambda_, max=lambda_)

    def L21_projection(self, x: Tensor, lambda_: float) -> Tensor:
        """Row-wise projection onto the l2 ball of radius λ1."""
        row_norm = torch.norm(x, p=2, dim=1)
        scale = torch.clamp(row_norm, max=lambda_)
        index = row_norm > 0
        scale[index] = scale[index] / row_norm[index]  # avoid division by zero
        return scale.unsqueeze(1) * x

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(K={self.K})"
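

# Minimal usage sketch, not part of the original module, assuming greatx and
# its torch_geometric / torch_sparse dependencies are installed: propagate
# random 4-dimensional node features over the 3-node path graph.
if __name__ == "__main__":
    conv = ElasticConv(K=3, lambda1=3., lambda2=3.)
    x = torch.randn(3, 4)
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    out = conv(x, edge_index)
    print(out.shape)  # torch.Size([3, 4])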