from typing import List

import torch.nn as nn
from torch_geometric.nn import JumpingKnowledge

from greatx.functional import spmm
from greatx.nn.layers import GCNConv, Sequential, activations
from greatx.nn.layers.gcn_conv import make_gcn_norm
from greatx.utils import wrapper


class JKNet(nn.Module):
r"""Implementation of Graph Convolution Network with
Jumping knowledge (JKNet) from
the `"Representation Learning on Graphs with
Jumping Knowledge Networks"
<https://arxiv.org/abs/1806.03536>`_ paper (ICML'18)

    Parameters
----------
    in_channels : int
        the input dimensions of the model
    out_channels : int
        the output dimensions of the model
    hids : List[int], optional
        the number of hidden units for each hidden layer,
        by default [16, 16, 16]
    acts : List[str], optional
        the activation function for each hidden layer,
        by default ['relu', 'relu', 'relu']
    dropout : float, optional
        the dropout ratio of the model, by default 0.5
    mode : str, optional
        the mode of jumping knowledge, one of 'cat',
        'lstm', and 'max', by default 'cat'
    bias : bool, optional
        whether to use bias in the layers, by default True
    bn : bool, optional
        whether to use :class:`BatchNorm1d` after the convolution layer,
        by default False

    Note
    ----
    To accept a different graph as input, please call :meth:`cache_clear`
    first to clear the cached results.

    Examples
--------
>>> # JKNet with five hidden layers
>>> model = JKNet(100, 10, hids=[16]*5)
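    >>> # A minimal forward pass on a toy graph; the random features
    >>> # and edges below are illustrative stand-ins
    >>> import torch
    >>> x = torch.randn(4, 100)
    >>> edge_index = torch.tensor([[0, 1, 2, 3],
    ...                            [1, 0, 3, 2]])
    >>> model(x, edge_index).shape
    torch.Size([4, 10])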
"""

    @wrapper
def __init__(self, in_channels: int, out_channels: int,
hids: List[int] = [16] * 3, acts: List[str] = ['relu'] * 3,
dropout: float = 0.5, mode: str = 'cat', bn: bool = False,
bias: bool = True):
super().__init__()
self.mode = mode
        num_JK_layers = len(hids) - 1  # number of JK layers
        if num_JK_layers <= 1 or len(set(hids)) != 1:
            raise ValueError("The number of hidden layers "
                             "must be greater than 2 and all "
                             "hidden units must be equal.")
conv = []
        assert len(hids) == len(acts), "`hids` and `acts` must match in length"
for hid, act in zip(hids, acts):
block = []
block.append(nn.Dropout(dropout))
block.append(GCNConv(in_channels, hid, bias=bias))
if bn:
block.append(nn.BatchNorm1d(hid))
block.append(activations.get(act))
conv.append(Sequential(*block))
in_channels = hid

        # `Sequential` serves only as a module container here; each
        # block is applied one by one in :meth:`forward`.
        self.conv = Sequential(*conv)
assert len(conv) == num_JK_layers + 1
if self.mode == 'lstm':
self.jump = JumpingKnowledge(mode, hid, num_JK_layers)
else:
self.jump = JumpingKnowledge(mode)
        if self.mode == 'cat':
            # 'cat' concatenates the representations of all layers
            hid = hid * (num_JK_layers + 1)
self.mlp = nn.Linear(hid, out_channels, bias=bias)

    def reset_parameters(self):
        self.conv.reset_parameters()
        # the LSTM/attention parameters of the 'lstm' mode live inside
        # the :class:`JumpingKnowledge` module, which resets them itself
        self.jump.reset_parameters()
        self.mlp.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        # collect the intermediate representation of every layer
        xs = []
        for conv in self.conv:
            x = conv(x, edge_index, edge_weight)
            xs.append(x)
        # aggregate all layer representations via jumping knowledge
        x = self.jump(xs)
        # perform one final propagation with symmetric GCN normalization
        edge_index, edge_weight = make_gcn_norm(edge_index, edge_weight,
                                                num_nodes=x.size(0),
                                                dtype=x.dtype,
                                                add_self_loops=True)
        out = spmm(x, edge_index, edge_weight)
        return self.mlp(out)
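

if __name__ == "__main__":
    # A minimal smoke test -- an illustrative sketch, not part of the
    # library API. The toy graph (a directed 6-cycle) and the dimensions
    # below are arbitrary stand-ins; the test only checks that every
    # jumping knowledge mode produces logits of the expected shape.
    import torch

    x = torch.randn(6, 32)  # 6 nodes with 32-dimensional features
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [1, 2, 3, 4, 5, 0]])

    for mode in ('cat', 'max', 'lstm'):
        model = JKNet(32, 7, mode=mode)  # default hids=[16, 16, 16]
        out = model(x, edge_index)
        assert out.shape == (6, 7), (mode, out.shape)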