from math import pi
from typing import Callable
import torch
import torch.nn as nn
from torch import Tensor
def heaviside(x: Tensor) -> Tensor:
    """Unit step function: 1.0 where ``x >= 0`` and 0.0 elsewhere.

    The result is always cast to ``torch.float32``, whatever the dtype
    of ``x``.
    """
    is_active = x >= 0
    return is_active.to(torch.float32)
class BaseSpike(torch.autograd.Function):
    """Base spiking function: Heaviside forward, surrogate backward.

    ``forward`` emits a spike (1.0) wherever ``x >= 0`` and stores ``x`` and
    ``alpha`` so that subclasses can evaluate their surrogate derivative in
    ``backward``. Subclasses must override ``backward``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Both are needed later by the subclass's surrogate derivative.
        ctx.save_for_backward(x, alpha)
        spikes = heaviside(x)
        return spikes

    @staticmethod
    def backward(ctx, grad_output):
        # The surrogate derivative is defined by each concrete subclass.
        raise NotImplementedError
class SuperSpike(BaseSpike):
    """Spike function with the SuperSpike surrogate gradient.

    Reference: "SuperSpike: Supervised Learning in Multilayer Spiking Neural
    Networks", Zenke et al. 2018.

    Surrogate derivative: ``1 / (1 + alpha * |x|) ** 2``.

    Design choices:
    (1) peak height of 1 ("The Remarkable Robustness of Surrogate
        Gradient...", Zenke et al. 2021);
    (2) alpha scaled by 10 ("Training Deep Spiking Neural Networks",
        Ledinauskas et al. 2020).
    """

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        denom = 1 + alpha * x.abs()
        surrogate_grad = 1 / denom**2
        # alpha is treated as a hyperparameter: no gradient flows to it.
        return grad_output * surrogate_grad, None
class MultiGaussSpike(BaseSpike):
    """Spike function with the multi-Gaussian surrogate gradient.

    Reference: "Accurate and efficient time-domain classification",
    Yin et al. 2021.

    Surrogate derivative::

        1.15 * N(x; 0, alpha) - 0.15 * N(x; alpha, 6 alpha)
                              - 0.15 * N(x; -alpha, 6 alpha)

    where ``N(x; mu, sigma)`` is the Gaussian density.

    Design choices:
    - Hyperparameters determined through grid search (Yin et al. 2021)
    """

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        # A CPU 0-d tensor combines with tensors on any device.
        origin = torch.tensor(0.0)
        wide_sigma = 6 * alpha
        center = gaussian(x, origin, alpha)
        right_lobe = gaussian(x, alpha, wide_sigma)
        left_lobe = gaussian(x, -alpha, wide_sigma)
        surrogate_grad = 1.15 * center - 0.15 * right_lobe - 0.15 * left_lobe
        return grad_output * surrogate_grad, None
def gaussian(x: Tensor, mu: Tensor, sigma: Tensor) -> Tensor:
    """Evaluate the normal density with mean ``mu`` and std ``sigma`` at ``x``.

    All arguments are combined with ordinary tensor broadcasting.
    """
    offset = x - mu
    exponent = -(offset * offset) / (2 * sigma * sigma)
    normalizer = sigma * torch.sqrt(2 * torch.tensor(pi))
    return torch.exp(exponent) / normalizer
class TriangleSpike(BaseSpike):
    """Spike function with a triangular surrogate gradient, as in
    Bellec et al. 2020.

    Surrogate derivative: ``max(0, 1 - alpha * |x|)`` -- peak 1 at ``x = 0``,
    falling linearly to zero at ``|x| = 1 / alpha``.
    """

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        ramp = 1 - alpha * x.abs()
        surrogate_grad = torch.nn.functional.relu(ramp)
        return grad_output * surrogate_grad, None
class ArctanSpike(BaseSpike):
    """Spike function with an arctan-derivative-shaped surrogate gradient.

    Featured in Fang et al. 2020/2021. This implementation uses the
    height-1 form ``1 / (1 + alpha * x**2)``.
    """

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        denom = 1 + alpha * x * x
        surrogate_grad = 1 / denom
        return grad_output * surrogate_grad, None
class SigmoidSpike(BaseSpike):
    """Spike function whose surrogate gradient is the derivative of
    ``sigmoid(alpha * x)``, i.e. ``alpha * s * (1 - s)`` with
    ``s = sigmoid(alpha * x)``.
    """

    @staticmethod
    def backward(ctx, grad_output):
        x, alpha = ctx.saved_tensors
        s = torch.sigmoid(x * alpha)
        surrogate_grad = (1. - s) * s * alpha
        return grad_output * surrogate_grad, None
def superspike(x, thresh=torch.tensor(1.0), alpha=torch.tensor(10.0)):
    """Spike where ``x >= thresh``; backward uses the SuperSpike surrogate."""
    shifted = x - thresh
    return SuperSpike.apply(shifted, alpha)
def mgspike(x, thresh=torch.tensor(1.0), alpha=torch.tensor(0.5)):
    """Spike where ``x >= thresh``; backward uses the multi-Gaussian surrogate."""
    shifted = x - thresh
    return MultiGaussSpike.apply(shifted, alpha)
def sigmoidspike(x, thresh=torch.tensor(1.0), alpha=torch.tensor(1.0)):
    """Spike where ``x >= thresh``; backward uses the sigmoid surrogate."""
    shifted = x - thresh
    return SigmoidSpike.apply(shifted, alpha)
def trianglespike(x, thresh=torch.tensor(1.0), alpha=torch.tensor(1.0)):
    """Spike where ``x >= thresh``; backward uses the triangular surrogate."""
    shifted = x - thresh
    return TriangleSpike.apply(shifted, alpha)
def arctanspike(x, thresh=torch.tensor(1.0), alpha=torch.tensor(10.0)):
    """Spike where ``x >= thresh``; backward uses the arctan surrogate."""
    shifted = x - thresh
    return ArctanSpike.apply(shifted, alpha)
# Registry of surrogate spike functions, keyed by the names accepted by
# ``get_surrogate`` (and hence by the neurons' ``surrogate`` argument).
SURROGATE = dict(
    sigmoid=sigmoidspike,
    triangle=trianglespike,
    arctan=arctanspike,
    mg=mgspike,
    super=superspike,
)
class PoissonEncoder(nn.Module):
    """Rate-code intensities as Bernoulli spikes (one time step per call).

    Each element of ``x`` is used as a firing probability: an independent
    sample ``u ~ U[0, 1)`` is drawn per element, and a spike (1) is emitted
    where ``u < x``. The module has no parameters or state.
    """

    def forward(self, x: Tensor) -> Tensor:
        """Sample spikes for one time step.

        Parameters
        ----------
        x : Tensor
            firing probabilities; presumably in ``[0, 1]`` (values outside
            saturate to never/always firing). Must be floating point, as
            required by ``torch.rand_like``.

        Returns
        -------
        Tensor
            0/1 spikes with the same shape, dtype and device as ``x``.
        """
        # Strict ``<`` gives P(spike) == x exactly for u ~ U[0, 1): zero
        # intensity can never fire (``<=`` fired whenever u == 0.0) and unit
        # intensity always fires.
        out_spike = torch.rand_like(x).lt(x).to(x)
        return out_spike
def get_surrogate(name: str) -> Callable:
    """Look up a surrogate spike function by name.

    Parameters
    ----------
    name : str
        a key of ``SURROGATE`` (e.g. ``'sigmoid'``, ``'super'``).

    Returns
    -------
    Callable
        the spike function registered under ``name``.

    Raises
    ------
    ValueError
        if ``name`` is not a registered surrogate.
    """
    # Validate with an explicit exception rather than ``assert``: asserts are
    # stripped under ``python -O``, and callers should get one consistent
    # exception type (ValueError) for a bad name.
    surrogate = SURROGATE.get(name, None)
    if surrogate is None:
        raise ValueError(
            f"Unsupported surrogate function {name}; "
            f"expected one of: {', '.join(SURROGATE)}")
    return surrogate
class IF(nn.Module):
    r"""The Integrate-and-Fire (IF) neuron for spiking neural networks.

    Each call to :meth:`forward` is one time step:

    1. the input is added to the membrane potential (no leak);
    2. a spike is emitted where the potential reaches the current
       threshold ``v_th``;
    3. spiking units are hard-reset to ``v_reset``;
    4. the threshold is updated as
       ``v_th <- gamma * spike + thresh_decay * v_th``.

    With the defaults (``gamma=0``, ``thresh_decay=1``) the threshold stays
    at ``v_threshold``.

    Parameters
    ----------
    v_threshold : float, optional
        the initial threshold for emitting a spike, by default 1.0
    v_reset : float, optional
        the reset level for neuron, by default 0.
    alpha : float, optional
        the smooth factor for surrogate function, by default 1.0
    gamma : float, optional
        the threshold increment :math:`\gamma` added on every spike,
        by default 0.
    thresh_decay : float, optional
        the multiplicative factor applied to the threshold at every step,
        by default 1.0
    surrogate : str, optional
        the surrogate function for training spiking neurons,
        could be one of (:obj:`'sigmoid'`, :obj:`'triangle'`,
        :obj:`'arctan'`, :obj:`'mg'`, and :obj:`'super'`),
        by default 'sigmoid'
    """

    def __init__(
        self,
        v_threshold: float = 1.0,
        v_reset: float = 0.,
        alpha: float = 1.0,
        gamma: float = 0.,
        thresh_decay: float = 1.0,
        surrogate: str = 'sigmoid',
    ):
        super().__init__()
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.gamma = gamma
        self.thresh_decay = thresh_decay
        self.surrogate = get_surrogate(surrogate)
        self.register_buffer("alpha", torch.as_tensor(alpha,
                                                      dtype=torch.float))
        self.reset()

    def reset(self):
        """Reset neuron states (membrane potential and threshold)."""
        self.v = 0.
        self.v_th = self.v_threshold

    def forward(self, dv: Tensor) -> Tensor:
        """Advance the neuron by one time step.

        Parameters
        ----------
        dv : Tensor
            input for this step; added directly to the membrane potential.

        Returns
        -------
        Tensor
            the spikes (0./1.) emitted at this step.
        """
        # 1. charge -- out of place (as in LIF/PLIF) so the stored potential
        #    may broadcast up to a larger input shape and is never mutated.
        self.v = self.v + dv
        # 2. fire against the adaptive threshold so that gamma/thresh_decay
        #    take effect (consistent with LIF/PLIF).
        spike = self.surrogate(self.v, self.v_th, self.alpha)
        # 3. hard reset where a spike occurred
        self.v = (1 - spike) * self.v + spike * self.v_reset
        # 4. threshold update
        # NOTE(review): with thresh_decay < 1 the threshold decays toward 0,
        # not back to v_threshold -- confirm this is intended.
        self.v_th = self.gamma * spike + self.v_th * self.thresh_decay
        return spike
class LIF(nn.Module):
    r"""The Leaky Integrate-and-Fire (LIF) neuron for spiking neural networks.

    Each call to :meth:`forward` is one time step:

    1. charge: ``v <- v + (dv - (v - v_reset)) / tau``;
    2. fire where ``v`` reaches the current threshold ``v_th``;
    3. hard-reset spiking units to ``v_reset``;
    4. threshold update: ``v_th <- gamma * spike + thresh_decay * v_th``.

    Parameters
    ----------
    v_threshold : float, optional
        the initial threshold for emitting a spike, by default 1.0
    v_reset : float, optional
        the reset level for neuron, by default 0.
    tau : float, optional
        the leaky factor :math:`\tau` for LIF-based neuron, by default 1.0.
        With ``tau = 1`` the charge step sets ``v`` to ``dv + v_reset``
        (no memory of the previous potential).
    alpha : float, optional
        the smooth factor for surrogate function, by default 1.0
    gamma : float, optional
        the threshold increment :math:`\gamma` added on every spike,
        by default 0.
    thresh_decay : float, optional
        the multiplicative factor applied to the threshold at every step,
        by default 1.0
    surrogate : str, optional
        the surrogate function for training spiking neurons,
        could be one of (:obj:`'sigmoid'`, :obj:`'triangle'`,
        :obj:`'arctan'`, :obj:`'mg'`, and :obj:`'super'`),
        by default 'sigmoid'
    """

    def __init__(
        self,
        v_threshold: float = 1.0,
        v_reset: float = 0.,
        tau: float = 1.0,
        alpha: float = 1.0,
        gamma: float = 0.,
        thresh_decay: float = 1.0,
        surrogate: str = 'sigmoid',
    ):
        super().__init__()
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.gamma = gamma
        self.thresh_decay = thresh_decay
        self.surrogate = get_surrogate(surrogate)
        # Non-trainable constants; stored as buffers so they follow
        # .to(device) and appear in the state dict (tau before alpha).
        for buf_name, value in (("tau", tau), ("alpha", alpha)):
            self.register_buffer(buf_name,
                                 torch.as_tensor(value, dtype=torch.float))
        self.reset()

    def reset(self):
        """Reset neuron states (membrane potential and threshold)."""
        self.v, self.v_th = 0., self.v_threshold

    def forward(self, dv: Tensor) -> Tensor:
        """Advance the neuron by one time step and return its spikes."""
        # charge: leak toward v_reset while integrating the input
        leak = self.v - self.v_reset
        self.v = self.v + (dv - leak) / self.tau
        # fire against the current (adaptive) threshold
        spike = self.surrogate(self.v, self.v_th, self.alpha)
        # hard reset where a spike occurred
        kept = (1 - spike) * self.v
        self.v = kept + spike * self.v_reset
        # threshold update
        # NOTE(review): with thresh_decay < 1 the threshold decays toward 0,
        # not back to v_threshold -- confirm this is intended.
        self.v_th = self.gamma * spike + self.v_th * self.thresh_decay
        return spike
class PLIF(nn.Module):
    r"""The Parametric Leaky Integrate-and-Fire (PLIF) neuron for spiking
    neural networks. It differs from :class:`LIF` with a trainable
    :math:`\tau` (registered as an ``nn.Parameter``).

    Each call to :meth:`forward` is one time step:

    1. charge: ``v <- v + (dv - (v - v_reset)) / tau``;
    2. fire where ``v`` reaches the current threshold ``v_th``;
    3. hard-reset spiking units to ``v_reset``;
    4. threshold update: ``v_th <- gamma * spike + thresh_decay * v_th``.

    Parameters
    ----------
    v_threshold : float, optional
        the initial threshold for emitting a spike, by default 1.0
    v_reset : float, optional
        the reset level for neuron, by default 0.
    tau : float, optional
        the initial leaky factor :math:`\tau`, by default 1.0.
        NOTE(review): ``tau`` is not constrained during training; values
        near or below zero would make the charge step diverge.
    alpha : float, optional
        the smooth factor for surrogate function, by default 1.0
    gamma : float, optional
        the threshold increment :math:`\gamma` added on every spike,
        by default 0.
    thresh_decay : float, optional
        the multiplicative factor applied to the threshold at every step,
        by default 1.0
    surrogate : str, optional
        the surrogate function for training spiking neurons,
        could be one of (:obj:`'sigmoid'`, :obj:`'triangle'`,
        :obj:`'arctan'`, :obj:`'mg'`, and :obj:`'super'`),
        by default 'sigmoid'
    """

    def __init__(
        self,
        v_threshold: float = 1.0,
        v_reset: float = 0.,
        tau: float = 1.0,
        alpha: float = 1.0,
        gamma: float = 0.,
        thresh_decay: float = 1.0,
        surrogate: str = 'sigmoid',
    ):
        super().__init__()
        self.v_threshold = v_threshold
        self.v_reset = v_reset
        self.gamma = gamma
        self.thresh_decay = thresh_decay
        self.surrogate = get_surrogate(surrogate)
        tau_init = torch.as_tensor(tau, dtype=torch.float)
        self.register_parameter("tau", nn.Parameter(tau_init))
        alpha_buf = torch.as_tensor(alpha, dtype=torch.float)
        self.register_buffer("alpha", alpha_buf)
        self.reset()

    def reset(self):
        """Reset neuron states (membrane potential and threshold)."""
        self.v, self.v_th = 0., self.v_threshold

    def forward(self, dv: Tensor) -> Tensor:
        """Advance the neuron by one time step and return its spikes."""
        # charge: leak toward v_reset while integrating the input; gradient
        # reaches the trainable tau through this division.
        leak = self.v - self.v_reset
        self.v = self.v + (dv - leak) / self.tau
        # fire against the current (adaptive) threshold
        spike = self.surrogate(self.v, self.v_th, self.alpha)
        # hard reset where a spike occurred
        kept = (1 - spike) * self.v
        self.v = kept + spike * self.v_reset
        # threshold update
        # NOTE(review): with thresh_decay < 1 the threshold decays toward 0,
        # not back to v_threshold -- confirm this is intended.
        self.v_th = self.gamma * spike + self.v_th * self.thresh_decay
        return spike