Source code for greatx.nn.models.supervised.jknet

from typing import List

import torch.nn as nn
from torch_geometric.nn import JumpingKnowledge

from greatx.functional import spmm
from greatx.nn.layers import GCNConv, Sequential, activations
from greatx.nn.layers.gcn_conv import make_gcn_norm
from greatx.utils import wrapper


class JKNet(nn.Module):
    r"""Implementation of Graph Convolutional Network with Jumping Knowledge
    (JKNet) from the `"Representation Learning on Graphs with Jumping
    Knowledge Networks" <https://arxiv.org/abs/1806.03536>`_ paper (ICML'18)

    Parameters
    ----------
    in_channels : int,
        the input dimensions of model
    out_channels : int,
        the output dimensions of model
    hids : List[int], optional
        the number of hidden units for each hidden layer,
        by default [16, 16, 16]
    acts : List[str], optional
        the activation function for each hidden layer,
        by default ['relu', 'relu', 'relu']
    dropout : float, optional
        the dropout ratio of model, by default 0.5
    mode : str, optional
        the mode of jumping knowledge, including 'cat', 'lstm', and 'max',
        by default 'cat'
    bias : bool, optional
        whether to use bias in the layers, by default True
    bn : bool, optional
        whether to use :class:`BatchNorm1d` after the convolution layer,
        by default False

    Note
    ----
    To accept a different graph as inputs, please call :meth:`cache_clear`
    first to clear cached results.

    Examples
    --------
    >>> # JKNet with five hidden layers
    >>> model = JKNet(100, 10, hids=[16]*5)
    """
    @wrapper
    def __init__(self, in_channels: int, out_channels: int,
                 hids: List[int] = [16] * 3, acts: List[str] = ['relu'] * 3,
                 dropout: float = 0.5, mode: str = 'cat', bn: bool = False,
                 bias: bool = True):
        super().__init__()
        self.mode = mode

        num_JK_layers = len(list(hids)) - 1  # number of JK layers
        if num_JK_layers <= 1 or len(set(hids)) != 1:
            raise ValueError("The number of hidden layers "
                             "should be greater than 2 and the "
                             "hidden units must be equal.")

        conv = []
        assert len(hids) == len(acts)
        for hid, act in zip(hids, acts):
            block = []
            block.append(nn.Dropout(dropout))
            block.append(GCNConv(in_channels, hid, bias=bias))
            if bn:
                block.append(nn.BatchNorm1d(hid))
            block.append(activations.get(act))
            conv.append(Sequential(*block))
            in_channels = hid

        # `loc=1` specifies the location of features.
        self.conv = Sequential(*conv)
        assert len(conv) == num_JK_layers + 1

        if self.mode == 'lstm':
            self.jump = JumpingKnowledge(mode, hid, num_JK_layers)
        else:
            self.jump = JumpingKnowledge(mode)

        if self.mode == 'cat':
            hid = hid * (num_JK_layers + 1)
        self.mlp = nn.Linear(hid, out_channels, bias=bias)
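
    # -----------------------------------------------------------------------
    # Illustrative shape sketch (not part of the original source): with the
    # default ``hids=[16] * 3`` there are three stacked layer outputs of
    # width 16. In 'cat' mode the jumping-knowledge module concatenates
    # them, so the final linear classifier receives 3 * 16 = 48 features,
    # while 'max' and 'lstm' keep the width at 16.
    #
    # >>> JKNet(100, 10, hids=[16] * 3, mode='cat').mlp
    # Linear(in_features=48, out_features=10, bias=True)
    # >>> JKNet(100, 10, hids=[16] * 3, mode='max').mlp
    # Linear(in_features=16, out_features=10, bias=True)
    # -----------------------------------------------------------------------
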
    def reset_parameters(self):
        self.conv.reset_parameters()
        if self.mode == 'lstm':
            # The LSTM-based aggregator lives inside the
            # :class:`JumpingKnowledge` module.
            self.jump.reset_parameters()
        self.mlp.reset_parameters()
    def forward(self, x, edge_index, edge_weight=None):
        xs = []
        for conv in self.conv:
            x = conv(x, edge_index, edge_weight)
            xs.append(x)

        # Aggregate the outputs of all layers with jumping knowledge.
        x = self.jump(xs)

        # One extra propagation step with GCN normalization before the
        # final linear classifier.
        edge_index, edge_weight = make_gcn_norm(edge_index, edge_weight,
                                                num_nodes=x.size(0),
                                                dtype=x.dtype,
                                                add_self_loops=True)
        out = spmm(x, edge_index, edge_weight)
        return self.mlp(out)
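

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module):
# it fabricates a small random graph and runs a single forward pass to show
# the expected input/output shapes. The node and edge counts below are
# arbitrary; a real workload would use a dataset and a training loop.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch

    num_nodes, num_features, num_classes = 50, 100, 10
    x = torch.randn(num_nodes, num_features)
    # 200 random directed edges, purely for demonstration.
    edge_index = torch.randint(0, num_nodes, (2, 200))

    model = JKNet(num_features, num_classes, hids=[16] * 3, mode='cat')
    logits = model(x, edge_index)
    print(logits.shape)  # torch.Size([50, 10])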