# Source code for greatx.nn.layers.spiking_gcn_conv

from typing import Optional

from torch import Tensor, nn
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptTensor

from greatx.functional import spmm
from greatx.nn.layers import LIF, PoissonEncoder
from greatx.nn.layers.gcn_conv import make_gcn_norm, make_self_loops


class SpikingGCNonv(nn.Module):
    r"""The spiking graph convolutional operator from
    the `"Spiking Graph Convolutional Networks"
    <https://arxiv.org/abs/2205.02767>`_ paper (IJCAI'22)

    Parameters
    ----------
    in_channels : int
        dimensions of input samples
    out_channels : int
        dimensions of output samples
    K : int
        the number of propagation steps, by default 1
    T : int
        the number of time steps, by default 20
    tau : float
        the :math:`\tau` in LIF neuron, by default 1.0
    v_threshold : float
        the threshold :math:`V_{th}` in LIF neuron, by default 1.0
    v_reset : float
        the reset level :math:`V_{reset}` in LIF neuron, by default 0
    cached : bool, optional
        whether the layer will cache the computation of
        :math:`(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}}
        \mathbf{\hat{D}}^{-1/2})^K \mathbf{X}` on first execution,
        and will use the cached version for further executions,
        by default False
    add_self_loops : bool, optional
        whether to add self-loops to the input graph, by default True
    normalize : bool, optional
        whether to compute symmetric normalization
        coefficients on the fly, by default True
    bias : bool, optional
        whether to use bias in the layers, by default True

    Raises
    ------
    ValueError
        if :obj:`T` is not positive

    Note
    ----
    Different from that in :class:`torch_geometric`,
    for the input :obj:`edge_index`, our implementation supports
    :obj:`torch.FloatTensor`, :obj:`torch.LongTensor`
    and :obj:`torch_sparse.SparseTensor`.

    See also
    --------
    :class:`greatx.nn.models.supervised.SpikingGCN`
    """

    # Propagated node features kept between calls when ``cached=True``;
    # ``None`` means nothing is cached.
    _cached_x: Optional[Tensor]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        K: int = 1,
        T: int = 20,
        tau: float = 1.0,
        v_threshold: float = 1.0,
        v_reset: float = 0.,
        cached: bool = False,
        add_self_loops: bool = True,
        normalize: bool = True,
        bias: bool = True,
    ):
        super().__init__()

        # Validate explicitly instead of with ``assert``: assertions are
        # stripped under ``python -O`` and ``forward`` divides by ``T``.
        if T <= 0:
            raise ValueError(f"'T' must be a positive integer, but got {T}.")

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.K = K
        self.T = T
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize

        self._cached_x = None

        self.lin = Linear(in_channels, out_channels, bias=bias,
                          weight_initializer='glorot')
        self.encoder = PoissonEncoder()
        self.neuron = LIF(tau=tau, v_threshold=v_threshold, v_reset=v_reset)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize the linear layer, drop the cache and
        reset the neuron."""
        self.lin.reset_parameters()
        self.cache_clear()
        self.neuron.reset()

    def cache_clear(self) -> "SpikingGCNonv":
        """Clear cached inputs or intermediate results."""
        self._cached_x = None
        return self

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Propagate :obj:`x` over the graph :obj:`K` times, then run
        :obj:`T` encode -> linear -> neuron steps and average the
        neuron outputs.

        Parameters
        ----------
        x : Tensor
            node features; the first dimension is used as the
            number of nodes
        edge_index : Adj
            the graph structure passed to the propagation helpers
        edge_weight : OptTensor, optional
            edge weights, by default None

        Returns
        -------
        Tensor
            the sum of the neuron outputs over all time steps
            divided by :obj:`T`

        Note
        ----
        Once a cached result exists, :obj:`x`, :obj:`edge_index` and
        :obj:`edge_weight` are ignored until :meth:`cache_clear`
        is called.
        """
        cache = self._cached_x
        if cache is None:
            num_nodes = x.size(0)

            if self.add_self_loops:
                edge_index, edge_weight = make_self_loops(
                    edge_index, edge_weight, num_nodes=num_nodes)

            if self.normalize:
                # Self-loops were already handled above, so the
                # normalization must not add them a second time.
                edge_index, edge_weight = make_gcn_norm(
                    edge_index, edge_weight, num_nodes=num_nodes,
                    dtype=x.dtype, add_self_loops=False)

            for _ in range(self.K):
                x = spmm(x, edge_index, edge_weight)

            if self.cached:
                self._cached_x = x
        else:
            # Detach so reuse of the cache does not backpropagate
            # through the propagation of an earlier call.
            x = cache.detach()

        spikes = 0.
        try:
            # `x` is re-encoded at every time step.
            for _ in range(self.T):
                out = self.lin(self.encoder(x))
                spikes += self.neuron(out)
        finally:
            # Always reset the neuron, even if a step raised, so that
            # no state from this call carries over into the next one.
            self.neuron.reset()

        return spikes / self.T

    def __repr__(self) -> str:
        """Return ``ClassName(in_channels, out_channels, K=K)``."""
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, K={self.K})')