import inspect
import torch.nn as nn
class Sequential(nn.Sequential):
    """A modified :class:`torch.nn.Sequential` which can accept multiple inputs.

    Layers whose ``forward`` takes exactly one parameter only receive the
    feature input :obj:`x`; every other layer receives all positional
    inputs (with :obj:`x` replaced by the latest output) plus the keyword
    arguments.

    Parameters
    ----------
    loc : int, optional
        the location of feature input :obj:`x` among the positional
        inputs of :meth:`forward`, by default 0. Negative values index
        from the end.

    Example
    -------
    >>> import torch
    >>> from greatx.nn.layers import Sequential, GCNConv
    >>> edge_index = torch.LongTensor([[1, 2], [3,4]]) # size [2, M]
    >>> x = torch.randn(5, 20)
    >>> conv1 = GCNConv(20, 50)
    >>> conv2 = GCNConv(50, 5)
    >>> dropout1 = torch.nn.Dropout(0.5)
    >>> dropout2 = torch.nn.Dropout(0.6)

    >>> # Case 1: standard usage
    >>> sequential = Sequential(dropout1, conv1, dropout2, conv2)
    >>> sequential(x, edge_index)
    tensor([[ 0.6738, -0.9032, -0.9628,  0.0670,  0.0252],
            [ 0.4909, -1.2430, -0.6029,  0.0510,  0.2107],
            [ 0.6338, -0.2760, -0.9112, -0.3197,  0.2689],
            [ 0.4909, -1.2430, -0.6029,  0.0510,  0.2107],
            [ 0.3876, -0.6385, -0.5521, -0.2753,  0.6713]], grad_fn=<AddBackward0>)

    >>> # which is equivalent to:
    >>> h1 = dropout1(x)
    >>> h2 = conv1(h1, edge_index)
    >>> h3 = dropout2(h2)
    >>> h4 = conv2(h3, edge_index)

    >>> # Case 2: with keyword argument
    >>> sequential(x, edge_index, edge_weight=torch.ones(20))
    tensor([[ 0.6738, -0.9032, -0.9628,  0.0670,  0.0252],
            [ 0.4909, -1.2430, -0.6029,  0.0510,  0.2107],
            [ 0.6338, -0.2760, -0.9112, -0.3197,  0.2689],
            [ 0.4909, -1.2430, -0.6029,  0.0510,  0.2107],
            [ 0.3876, -0.6385, -0.5521, -0.2753,  0.6713]], grad_fn=<AddBackward0>)

    >>> # which is equivalent to:
    >>> h1 = dropout1(x)
    >>> h2 = conv1(h1, edge_index, edge_weight=torch.ones(20))
    >>> h3 = dropout2(h2)
    >>> h4 = conv2(h3, edge_index, edge_weight=torch.ones(20))

    Note
    ----
    * The argument :obj:`loc` must be specified as the location of
      feature input :obj:`x`, which would walk through the whole layers.
    * The usage of keyword argument must be matched with that of the
      layers with more than one arguments required.
    * The ``forward`` signatures are inspected once, in ``__init__``;
      layers added afterwards are not included in that record.
    """

    def __init__(self, *args, loc: int = 0) -> None:
        super().__init__(*args)
        self.loc = loc

        # Record each layer's ``forward`` parameters up front so that
        # ``forward`` can decide, per layer, whether to pass only the
        # feature input or the full argument list.
        para_required = []
        for module in self:
            assert hasattr(module, "forward"), module
            para_required.append(
                inspect.signature(module.forward).parameters)
        self._para_required = para_required

    def forward(self, *inputs, **kwargs):
        """Run the layers in order, threading the feature input through.

        Parameters
        ----------
        *inputs
            positional inputs; the one at position :obj:`loc` is the
            feature input and is replaced by each layer's output in turn.
        **kwargs
            keyword arguments forwarded to every layer whose ``forward``
            declares more than one parameter.

        Returns
        -------
        The output of the last layer, or the feature input itself when
        the container holds no layers.

        Raises
        ------
        IndexError
            if :obj:`loc` does not index into :obj:`inputs`.
        """
        num_inputs = len(inputs)
        loc = self.loc
        if not -num_inputs <= loc < num_inputs:
            raise IndexError(
                f"`loc`={loc} is out of range for {num_inputs} "
                "positional input(s)")
        if loc < 0:
            # Normalise so that the slicing below (``inputs[loc + 1:]``)
            # stays correct; with loc == -1 it would otherwise become
            # ``inputs[0:]`` and duplicate every input.
            loc += num_inputs

        output = inputs[loc]
        for module, para_required in zip(self, self._para_required):
            if len(para_required) == 1:
                # Single-parameter layer: feed it the feature input only.
                feat = inputs[loc]
                if isinstance(feat, tuple):
                    # A tuple of features is mapped element-wise.
                    output = tuple(module(item) for item in feat)
                else:
                    output = module(feat)
            else:
                output = module(*inputs, **kwargs)
            # Substitute the fresh output for the feature input so the
            # next layer sees it in the same position.
            inputs = inputs[:loc] + (output,) + inputs[loc + 1:]
        return output

    def reset_parameters(self) -> None:
        """Call ``reset_parameters`` on every layer that defines it."""
        for layer in self:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()