# Source code for greatx.nn.layers.container

import inspect

import torch.nn as nn

class Sequential(nn.Sequential):
    """A modified :class:`torch.nn.Sequential` which can accept multiple
    inputs.

    Layers are applied in order. The positional input at index
    :obj:`loc` (the feature input :obj:`x`) is threaded through every
    layer; the other positional inputs and all keyword arguments are
    forwarded unchanged:

    * a layer whose ``forward`` takes exactly one parameter receives only
      the current feature; if that feature is a :class:`tuple`, the layer
      is applied to each element and a tuple is produced;
    * any other layer receives all positional inputs (with the feature
      slot replaced by the current output) plus all keyword arguments.

    Parameters
    ----------
    loc : int, optional
        the location of feature input :obj:`x` among the positional
        inputs, by default 0. Negative values count from the end, as in
        ordinary Python indexing.

    Example
    -------
    >>> import torch
    >>> from greatx.nn.layers import Sequential, GCNConv
    >>> edge_index = torch.LongTensor([[1, 2], [3, 4]])  # size [2, M]
    >>> x = torch.randn(5, 20)
    >>> conv1 = GCNConv(20, 50)
    >>> conv2 = GCNConv(50, 5)
    >>> dropout1 = torch.nn.Dropout(0.5)
    >>> dropout2 = torch.nn.Dropout(0.6)

    >>> # Case 1: standard usage
    >>> sequential = Sequential(dropout1, conv1, dropout2, conv2)
    >>> sequential(x, edge_index).shape
    torch.Size([5, 5])

    >>> # which is equivalent to:
    >>> h1 = dropout1(x)
    >>> h2 = conv1(h1, edge_index)
    >>> h3 = dropout2(h2)
    >>> h4 = conv2(h3, edge_index)

    >>> # Case 2: with keyword argument
    >>> edge_weight = torch.ones(edge_index.size(1))  # one weight per edge
    >>> sequential(x, edge_index, edge_weight=edge_weight).shape
    torch.Size([5, 5])

    >>> # which is equivalent to:
    >>> h1 = dropout1(x)
    >>> h2 = conv1(h1, edge_index, edge_weight=edge_weight)
    >>> h3 = dropout2(h2)
    >>> h4 = conv2(h3, edge_index, edge_weight=edge_weight)

    Note
    ----
    * :obj:`loc` must point at the feature input :obj:`x`, which is the
      only input that walks through the whole stack of layers.
    * Keyword arguments are passed to every layer whose ``forward`` takes
      more than one parameter, so all of those layers must accept them.
    * Layers added or replaced after construction (e.g. via
      :meth:`add_module` or item assignment) are picked up on the next
      call.
    """

    def __init__(self, *args, loc: int = 0):
        super().__init__(*args)
        self.loc = loc
        self._refresh_para_required()

    def _refresh_para_required(self):
        """Recompute and cache the ``forward`` parameters of every layer.

        Also records which layer objects the cache was built for, so
        that ``_get_para_required`` can detect later edits.
        """
        modules = tuple(self)
        para_required = []
        for module in modules:
            assert hasattr(module, "forward"), module
            para_required.append(inspect.signature(module.forward).parameters)
        self._para_required = para_required
        # Identity snapshot of the layers the cache corresponds to.
        self._para_modules = modules
        return para_required

    def _get_para_required(self):
        """Return the per-layer ``forward`` parameters, rebuilding the cache
        if layers were added, removed or replaced since it was built."""
        modules = tuple(self)
        cached = getattr(self, '_para_modules', None)
        if (cached is None or len(cached) != len(modules)
                or any(a is not b for a, b in zip(cached, modules))):
            return self._refresh_para_required()
        return self._para_required

    def forward(self, *inputs, **kwargs):
        """Run ``inputs`` through all layers and return the last output.

        Raises
        ------
        IndexError
            if :obj:`loc` does not index into the positional ``inputs``.
        """
        num_inputs = len(inputs)
        loc = self.loc
        if not -num_inputs <= loc < num_inputs:
            raise IndexError(f"`loc`={loc} is out of range for "
                             f"{num_inputs} positional input(s)")
        # Normalise a negative `loc` so the splice below replaces the
        # right slot (``inputs[loc + 1:]`` would be wrong for loc=-1).
        loc %= num_inputs

        output = inputs[loc]
        for module, para_required in zip(self, self._get_para_required()):
            if len(para_required) == 1:
                # Single-argument layer (e.g. dropout): feed only the feature.
                feat = inputs[loc]
                if isinstance(feat, tuple):
                    output = tuple(module(item) for item in feat)
                else:
                    output = module(feat)
            else:
                output = module(*inputs, **kwargs)
            # The new output becomes the feature input of the next layer.
            inputs = inputs[:loc] + (output,) + inputs[loc + 1:]
        return output

    def reset_parameters(self):
        """Call ``reset_parameters()`` on every layer that defines it."""
        for layer in self:
            if hasattr(layer, 'reset_parameters'):
                layer.reset_parameters()