import torch
import torch.nn as nn
from torch.nn import init
from DynapSEtorch.surrogate import fast_sigmoid, triangular, step
from collections import namedtuple
# --- Unit multipliers --------------------------------------------------------
# Each name is the factor that scales a value expressed in that unit to the
# corresponding base unit (ampere, volt, farad, second), e.g. ``5 * ms`` is
# 0.005.

# Current
amp = 1
mA = 1e-3
uA = 1e-6
nA = 1e-9
pA = 1e-12
# Voltage
volt = 1
mV = 1e-3
uV = 1e-6
nV = 1e-9
# Capacitance
farad = 1
mF = 1e-3
uF = 1e-6
nF = 1e-9
pF = 1e-12
# Time
second = 1
ms = 1e-3
us = 1e-6
ns = 1e-9
ps = 1e-12

# --- Substrate constants (registered as buffers by AdexLIF) ------------------
kappa_n = 0.75  # Subthreshold slope factor (n-type transistor)
kappa_p = 0.66  # Subthreshold slope factor (p-type transistor)
Ut = 25.0 * mV  # Thermal voltage
I0 = 1 * pA  # Dark current
class ADM(nn.Module):
    """Adaptive Delta Modulation (ADM) module.

    Converts an analog signal into UP and DOWN spikes using the Adaptive
    Delta Modulation scheme: an internal reference level (``DC_Voltage``)
    tracks the input, an UP spike is emitted when the input exceeds the
    reference by ``threshold_up`` and a DOWN spike when it falls below the
    reference by ``threshold_down``. After each spike the reference is moved
    by the corresponding threshold and the channel is held refractory for
    ``refractory`` calls.

    Args:
        N: Number of input channels; the concatenated output of ``forward``
            has ``2 * N`` channels.
        threshold_up: Step added to the reference on an UP spike.
        threshold_down: Step subtracted from the reference on a DOWN spike.
        refractory: Number of ``forward`` calls a channel stays silent after
            spiking.
    """

    def __init__(
        self, N: int, threshold_up: float, threshold_down: float, refractory: int
    ):
        super().__init__()
        self.refractory = nn.Parameter(
            torch.tensor(refractory).float(), requires_grad=True
        )
        self.threshold_up = nn.Parameter(
            self._as_float_tensor(threshold_up), requires_grad=True
        )
        self.threshold_down = nn.Parameter(
            self._as_float_tensor(threshold_down), requires_grad=True
        )
        self.N = N
        self.reset()

    @staticmethod
    def _as_float_tensor(value):
        """Return ``value`` as a tensor, promoting non-float dtypes to float.

        Only floating-point tensors may require gradients, so an integer
        threshold (e.g. ``1``) would otherwise make ``nn.Parameter`` raise.
        Floating-point inputs keep their dtype.
        """
        tensor = torch.tensor(value)
        if not tensor.is_floating_point():
            tensor = tensor.float()
        return tensor

    def reset(self):
        """Forget the refractory counters and the tracked reference level."""
        self.refrac = None
        self.DC_Voltage = None

    def reconstruct(self, spikes, initial_value=0):
        """Reconstruct an analog signal from UP and DOWN spikes.

        Every time an UP/DOWN spike is received, the reconstructed signal is
        incremented/decremented by the UP/DOWN threshold.

        Args:
            spikes: 3-D tensor indexed as ``[batch, time, channel]`` whose
                last axis holds the UP channels followed by the DOWN channels
                (the layout produced by concatenating ``forward`` outputs).
            initial_value: Value of the reconstruction at the first time step.

        Returns:
            Tensor of shape ``(batch, time, channels // 2)`` on the same
            device as ``spikes``.
        """
        n_channels = spikes.shape[2] // 2
        # Allocate on the spikes' device so GPU inputs do not hit a
        # CPU/GPU mismatch when the thresholds are applied.
        reconstructed = torch.zeros(
            spikes.shape[0], spikes.shape[1], n_channels, device=spikes.device
        )
        if spikes.shape[1] == 0:
            # Nothing to integrate; avoid indexing time step 0.
            return reconstructed
        reconstructed[:, 0, :] = initial_value
        for t in range(1, spikes.shape[1]):
            spikes_p = spikes[:, t, :n_channels]
            spikes_n = spikes[:, t, n_channels:]
            reconstructed[:, t] = (
                reconstructed[:, t - 1]
                + self.threshold_up * spikes_p
                - self.threshold_down * spikes_n
            )
        return reconstructed

    def forward(self, input_signal):
        """Encode one time step of ``input_signal`` into UP/DOWN spikes.

        The first call after ``reset`` only latches the input as the
        reference level and emits no spikes.

        Args:
            input_signal: Tensor for the current time step; presumably shaped
                ``(batch, N)`` given the ``(batch, 2 * N)`` zero output built
                on the first call - TODO confirm against callers.

        Returns:
            Tuple ``(output, output_p, output_n)`` where ``output`` is the
            UP and DOWN spikes concatenated along dim 1.
        """
        if self.DC_Voltage is None:
            output = torch.zeros(
                input_signal.shape[0], self.N * 2, device=input_signal.device
            )
            output_p = torch.zeros_like(input_signal)
            output_n = torch.zeros_like(input_signal)
            self.refrac = torch.zeros_like(input_signal)
            # Clone: the reference level is updated in place on later calls
            # and must not alias (and silently modify) the caller's tensor.
            self.DC_Voltage = input_signal.clone()
        else:
            # Count down the channels that are still refractory.
            self.refrac[self.refrac > 0] -= 1
            # UP spikes: input rose above reference + threshold_up.
            output_p = fast_sigmoid(
                input_signal - (self.DC_Voltage + self.threshold_up)
            ) * (self.refrac == 0)
            self.refrac[output_p.bool()] = self.refractory
            self.DC_Voltage += output_p * self.threshold_up
            # DOWN spikes: input fell below reference - threshold_down.
            output_n = fast_sigmoid(
                (self.DC_Voltage - self.threshold_down) - input_signal
            ) * (self.refrac == 0)
            self.refrac[output_n.bool()] = self.refractory
            self.DC_Voltage -= output_n * self.threshold_down
            output = torch.cat([output_p.float(), output_n.float()], dim=1)
        return output, output_p, output_n
class LIF(nn.Module):
    """Layer of leaky integrate-and-fire neurons behind a linear projection.

    Each output neuron low-pass filters the projected input with its own
    decay factor ``alpha = exp(-dt / tau)``, where ``tau`` is sampled per
    neuron from a Gamma distribution with mean ``tau`` and clamped to
    ``[3, 100]``. The membrane is forced to zero on the step after a spike.

    Args:
        n_in: Number of input features.
        n_out: Number of neurons.
        thr: Firing threshold on the membrane value.
        tau: Mean membrane time constant, in the same unit as ``dt``.
        dt: Simulation time step.
    """

    LIFState = namedtuple("LIFState", ["V", "S"])

    def __init__(self, n_in, n_out, thr, tau, dt):
        super().__init__()
        self.dt = dt
        self.n_in = n_in
        self.n_out = n_out
        self.base_layer = nn.Linear(n_in, n_out, bias=False)
        # Heterogeneous time constants: Gamma(shape=3, rate=3/tau) has mean tau.
        distribution = torch.distributions.gamma.Gamma(3, 3 / tau)
        tau = distribution.rsample((1, n_out)).clamp(3, 100)
        # Decay factor must lie in (0, 1): exp(-dt / tau). A positive exponent
        # gives alpha > 1, which makes the membrane grow without bound and
        # flips the sign of the (1 - alpha) input gain.
        self.register_buffer("alpha", torch.exp(-dt / tau).float())
        self.register_buffer("thr", torch.tensor(thr).float())
        self.reset()

    def reset(self):
        """Drop the stored state; it is re-created on the next ``forward``."""
        self.state = None

    def init_state(self, input):
        """Create a zero state matching the batch size and device of ``input``.

        Returns:
            The new ``LIFState`` (also stored on ``self.state``).
        """
        self.state = self.LIFState(
            V=torch.zeros(input.shape[0], self.n_out, device=input.device),
            S=torch.zeros(input.shape[0], self.n_out, device=input.device),
        )
        return self.state

    def forward(self, input):
        """Advance the layer by one time step and return the output spikes.

        Args:
            input: Tensor whose first axis is the batch and whose last axis
                has ``n_in`` features.

        Returns:
            Spike tensor ``S`` produced by ``fast_sigmoid(V - thr)``.
        """
        if self.state is None:
            self.init_state(input)
        V = self.state.V
        S = self.state.S
        # Leaky integration; multiplying by (1 - S) resets neurons that
        # spiked on the previous step (detached so the reset carries no
        # gradient).
        V = (self.alpha * V + (1 - self.alpha) * self.base_layer(input)) * (
            1 - S.detach()
        )
        S = fast_sigmoid(V - self.thr)
        self.state = self.LIFState(V=V, S=S)
        return S
class AdexLIF(nn.Module):
    """Current-mode adaptive exponential integrate-and-fire neuron population.

    All state variables are currents: the soma membrane current, the
    spike-frequency adaptation (AHP) current, and one output current per
    synapse type (NMDA, AMPA, GABA_A, GABA_B). Each call to ``forward``
    advances every current by one Euler step of size ``self.dt`` and returns
    the spikes emitted in that step.

    Args:
        num_neurons: Number of neurons in the population.
        input_per_synapse: Number of presynaptic inputs for, in order, the
            NMDA, AMPA, GABA_A and GABA_B synapses. Each entry is the row
            count of the corresponding ``weight_*`` matrix.
    """

    AdexLIFState = namedtuple(
        "AdexLIFState",
        [
            "Isoma_mem",
            "Isoma_ahp",
            "refractory",
            "Inmda",
            "Iampa",
            "Igaba_a",
            "Igaba_b",
        ],
    )

    # Tuple default instead of a list: avoids a mutable default argument while
    # still accepting any indexable sequence from callers.
    def __init__(self, num_neurons=1, input_per_synapse=(1, 1, 1, 1)):
        super().__init__()
        self.num_neurons = num_neurons

        # SUBSTRATE ###########################################################
        self.register_buffer(
            "kn", torch.tensor(kappa_n)
        )  # Subthreshold slope factor for nFETs
        self.register_buffer(
            "kp", torch.tensor(kappa_p)
        )  # Subthreshold slope factor for pFETs
        self.register_buffer("Ut", torch.tensor(Ut))  # Thermal voltage
        self.register_buffer("I0", torch.tensor(I0))  # Dark current

        # SCALING FACTORS (each equal to Ig/Itau) #############################
        self.register_buffer("alpha_soma", torch.tensor(4))
        self.register_buffer("alpha_ahp", torch.tensor(4))
        self.register_buffer("alpha_nmda", torch.tensor(4))
        self.register_buffer("alpha_ampa", torch.tensor(4))
        self.register_buffer("alpha_gaba_a", torch.tensor(4))
        self.register_buffer("alpha_gaba_b", torch.tensor(4))

        # Neuron parameters ###################################################
        # SOMA
        self.Isoma_mem_init = 1.1 * I0
        self.register_buffer("Csoma_mem", torch.tensor(2 * pF))  # Membrane capacitance
        self.register_buffer("Isoma_dpi_tau", torch.tensor(5 * I0))  # Leakage current
        # Spiking threshold; kept as a frozen Parameter rather than a buffer.
        self.Isoma_th = torch.nn.Parameter(
            torch.ones(1) * 2000 * I0, requires_grad=False
        )
        self.register_buffer("Isoma_reset", torch.tensor(1.2 * I0))  # Reset current
        self.register_buffer(
            "Isoma_const", torch.tensor(I0)
        )  # Additional input current similar to constant current injection
        self.register_buffer("soma_refP", torch.tensor(5 * ms))  # Refractory period

        # ADAPTATION
        self.register_buffer(
            "Csoma_ahp", torch.tensor(4 * pF)
        )  # Spike-frequency adaptation capacitance
        self.register_buffer(
            "Isoma_ahp_tau", torch.tensor(2 * I0)
        )  # Leakage current for spike-frequency adaptation
        self.register_buffer("Isoma_ahp_g", torch.tensor(0))  # AHP gain current
        self.register_buffer(
            "Isoma_ahp_w", torch.tensor(1 * I0)
        )  # AHP jump height, on post

        # POSITIVE FEEDBACK
        self.register_buffer(
            "Isoma_pfb_gain", torch.tensor(100 * I0)
        )  # Positive feedback gain
        self.register_buffer(
            "Isoma_pfb_th", torch.tensor(1000 * I0)
        )  # Positive feedback activation threshold
        self.register_buffer(
            "Isoma_pfb_norm", torch.tensor(20 * I0)
        )  # Positive feedback normalization current

        # Synapse parameters ##################################################
        # SLOW_EXC, NMDA
        self.Inmda_init = I0  # Output current initial value
        self.register_buffer("Cnmda", torch.tensor(2 * pF))  # Synapse's capacitance
        self.register_buffer(
            "Inmda_tau", torch.tensor(2 * I0)
        )  # Leakage current, i.e. how much current is constantly leaked away (time-constant)
        self.register_buffer(
            "Inmda_w0", self.mismatch(100)
        )  # Base synaptic weight, to convert unitless weight (set in synapse) to current
        self.register_buffer(
            "Inmda_thr", torch.tensor(I0)
        )  # NMDA voltage-gating threshold

        # FAST_EXC, AMPA
        self.Iampa_init = I0  # Output current initial value
        self.register_buffer(
            "Campa", torch.tensor(2 * pF)
        )  # Synaptic capacitance, fixed at layout time (see chip for details)
        self.register_buffer(
            "Iampa_tau", torch.tensor(20 * I0)
        )  # Synaptic time constant current, the time constant is inversely proportional to I_tau
        self.register_buffer(
            "Iampa_w0", self.mismatch(100)
        )  # Base synaptic weight current which can be scaled by the .weight parameter

        # SLOW_INH, GABA_B, subtractive
        self.Igaba_b_init = I0  # Output current initial value
        self.register_buffer(
            "Cgaba_b", torch.tensor(2 * pF)
        )  # Synaptic capacitance, fixed at layout time (see chip for details)
        self.register_buffer(
            "Igaba_b_tau", torch.tensor(5 * I0)
        )  # Synaptic time constant current, the time constant is inversely proportional to I_tau
        self.register_buffer(
            "Igaba_b_w0", self.mismatch(100)
        )  # Base synaptic weight current which can be scaled by the .weight parameter

        # FAST_INH, GABA_A, shunting, a mixture of subtractive and divisive
        self.Igaba_a_init = I0  # Output current initial value
        self.register_buffer(
            "Cgaba_a", torch.tensor(2 * pF)
        )  # Synaptic capacitance, fixed at layout time (see chip for details)
        self.register_buffer(
            "Igaba_a_tau", torch.tensor(5 * I0)
        )  # Synaptic time constant current, the time constant is inversely proportional to I_tau
        self.register_buffer(
            "Igaba_a_w0", self.mismatch(100)
        )  # Base synaptic weight current which can be scaled by the .weight parameter

        # Trainable unitless weights, one (inputs x neurons) matrix per synapse.
        self.weight_nmda = torch.nn.Parameter(
            torch.ones(input_per_synapse[0], num_neurons), requires_grad=True
        )
        self.weight_ampa = torch.nn.Parameter(
            torch.ones(input_per_synapse[1], num_neurons), requires_grad=True
        )
        self.weight_gaba_a = torch.nn.Parameter(
            torch.ones(input_per_synapse[2], num_neurons), requires_grad=True
        )
        self.weight_gaba_b = torch.nn.Parameter(
            torch.ones(input_per_synapse[3], num_neurons), requires_grad=True
        )

        self.dt = 1 * ms
        self.reset()

    def mismatch(self, initial):
        """Return per-neuron base currents with up to +10 % random mismatch.

        Draws ``initial * (1 + 0.1 * u)`` with ``u`` uniform in ``[0, 1)``
        for every neuron, floors the value at ``I0`` and scales it by ``I0``.

        Returns:
            Tensor of shape ``(1, num_neurons)``.
        """
        return (
            torch.maximum(
                torch.tensor(I0),
                initial + initial * 0.1 * torch.rand(1, self.num_neurons),
            )
            * I0
        )

    def reset(self):
        """Drop the stored state; ``forward`` re-creates it when given input."""
        self.state = None

    def init_state(self, input):
        """Create the initial state for the batch size and device of ``input``.

        Membrane and synaptic currents start at their ``*_init`` values; the
        adaptation current and the refractory counter start at zero.

        Returns:
            The new ``AdexLIFState`` (also stored on ``self.state``).
        """
        shape = (input.shape[0], self.num_neurons)
        device = input.device

        def constant(value):
            return torch.full(shape, value, device=device)

        self.state = self.AdexLIFState(
            Isoma_mem=constant(self.Isoma_mem_init),
            Isoma_ahp=torch.zeros(shape, device=device),
            refractory=torch.zeros(shape, device=device),
            Inmda=constant(self.Inmda_init),
            Iampa=constant(self.Iampa_init),
            Igaba_a=constant(self.Igaba_a_init),
            Igaba_b=constant(self.Igaba_b_init),
        )
        return self.state

    def detach(self):
        """Cut the autograd history of every state tensor.

        Intended for truncated back-propagation through time. Does nothing
        when no state has been created yet.
        """
        if self.state is None:
            return
        self.state = self.AdexLIFState(*(value.detach() for value in self.state))

    def _synapse_step(self, Isyn, Csyn, Itau, alpha, Iw0, weight, spikes, kappa):
        """Compute one synapse's derivative and apply its incoming spikes.

        Args:
            Isyn: Current synaptic output current.
            Csyn: Synaptic capacitance.
            Itau: Leakage (time-constant) current.
            alpha: Gain-to-leak ratio Ig/Itau.
            Iw0: Base weight current.
            weight: Unitless weight matrix (inputs x neurons).
            spikes: Presynaptic input for this step, or ``None``.
            kappa: Mean subthreshold slope factor.

        Returns:
            Tuple ``(Isyn, dIsyn)``: the current after adding the weighted
            input (unchanged when ``spikes`` is ``None``) and the time
            derivative evaluated on the current *before* the input was added.
        """
        active = Isyn.detach() > self.I0
        low_current = self.I0 * (Isyn.detach() <= self.I0)
        Ig = alpha * Itau  # Synapse gain expressed in terms of its tau current
        Ig_shunt = Ig * active + low_current  # Shunt g current if Isyn goes to I0
        Itau_shunt = Itau * active + low_current  # Shunt tau current if Isyn goes to I0
        tau = Csyn * self.Ut / (kappa * Itau_shunt)  # Synaptic time-constant
        dIsyn = (-Isyn - Ig_shunt + 2 * low_current) / (
            tau * ((Ig_shunt / Isyn) + 1)
        )
        if spikes is not None:
            Isyn = Isyn + Iw0 * alpha * (spikes @ weight)
        return Isyn, dIsyn

    def forward(
        self, input_nmda=None, input_ampa=None, input_gaba_a=None, input_gaba_b=None
    ):
        """Advance the population by one time step.

        Args:
            input_nmda: Presynaptic input to the NMDA synapses, or ``None``.
            input_ampa: Presynaptic input to the AMPA synapses, or ``None``.
            input_gaba_a: Presynaptic input to the GABA_A synapses, or ``None``.
            input_gaba_b: Presynaptic input to the GABA_B synapses, or ``None``.
                Each input is matrix-multiplied with its ``weight_*`` matrix,
                so its last axis must match that matrix's row count.

        Returns:
            Spike tensor produced by ``fast_sigmoid(Isoma_mem - Isoma_th)``.

        Raises:
            RuntimeError: If no state exists and every input is ``None``, so
                the batch size and device cannot be inferred.
        """
        if self.state is None:
            reference = next(
                (
                    x
                    for x in (input_nmda, input_ampa, input_gaba_a, input_gaba_b)
                    if x is not None
                ),
                None,
            )
            if reference is None:
                raise RuntimeError(
                    "AdexLIF has no state and no input to infer one from; "
                    "call init_state() first."
                )
            self.init_state(reference)

        ##### GET STATES VALUES #####
        ## Soma states
        Isoma_mem = self.state.Isoma_mem
        Isoma_ahp = self.state.Isoma_ahp
        refractory = self.state.refractory
        ## Synapses states
        Inmda = self.state.Inmda
        Iampa = self.state.Iampa
        Igaba_a = self.state.Igaba_a
        Igaba_b = self.state.Igaba_b

        Isoma_mem_clip = torch.clip(Isoma_mem.clone(), self.I0, 1)
        kappa = (self.kn + self.kp) / 2

        ## Input calculation
        Inmda_dp = Inmda.clone() / (
            1 + self.Inmda_thr / Isoma_mem_clip
        )  # Voltage gating differential pair block
        Iin_clip = torch.clip(Inmda_dp + Iampa - Igaba_b + self.Isoma_const, self.I0, 1)

        ##### SOMA CALCULATION #####
        ## Isoma_sum components calculation
        mem_active = Isoma_mem.detach() > self.I0
        low_current_mem = self.I0 * (Isoma_mem.detach() <= self.I0)
        Isoma_pfb = self.Isoma_pfb_gain / (
            1 + torch.exp(-(Isoma_mem - self.Isoma_pfb_th) / self.Isoma_pfb_norm)
        )
        Isoma_pfb_shunt = Isoma_pfb * mem_active + low_current_mem
        Isoma_ahp_shunt = Isoma_ahp.clone() * mem_active + low_current_mem
        Igaba_a_shunt = Igaba_a.clone() * mem_active + low_current_mem
        Isoma_dpi_tau_shunt = self.Isoma_dpi_tau * mem_active + low_current_mem
        Isoma_dpi_g_shunt = (
            self.alpha_soma * Isoma_dpi_tau_shunt * mem_active + low_current_mem
        )
        Isoma_sum = (
            Isoma_dpi_tau_shunt.detach()
            + Isoma_ahp_shunt.detach()
            + Igaba_a_shunt.detach()
            - Isoma_pfb_shunt.detach()
            - low_current_mem
        )

        ## Adaptation current
        ahp_active = Isoma_ahp.detach() > self.I0
        low_current_ahp = self.I0 * (Isoma_ahp.detach() <= self.I0)
        Isoma_ahp_tau_shunt = self.Isoma_ahp_tau * ahp_active + low_current_ahp
        Isoma_ahp_g_shunt = (
            self.alpha_ahp * Isoma_ahp_tau_shunt * ahp_active + low_current_ahp
        )
        tau_soma_ahp = (self.Csoma_ahp * self.Ut) / (kappa * Isoma_ahp_tau_shunt)
        dIsoma_ahp = (-Isoma_ahp_g_shunt - Isoma_ahp + 2 * low_current_ahp) / (
            tau_soma_ahp * (1 + (Isoma_ahp_g_shunt / Isoma_ahp_shunt))
        )

        ## Isoma calculation
        tau_soma = (self.Csoma_mem * self.Ut) / (kappa * Isoma_dpi_tau_shunt)
        dIsoma_mem = (
            self.alpha_soma * (Iin_clip - Isoma_sum)
            - (Isoma_sum - low_current_mem)
            * Isoma_mem_clip.detach()
            / Isoma_dpi_tau_shunt.detach()
        ) / (
            tau_soma.detach()
            * (1 + (Isoma_dpi_g_shunt.detach() / Isoma_mem_clip.detach()))
        )

        ##### SYNAPSES #####
        # The soma input above was computed from the synaptic currents of the
        # previous step; the incoming spikes are only added here.
        Inmda, dInmda = self._synapse_step(
            Inmda,
            self.Cnmda,
            self.Inmda_tau,
            self.alpha_nmda,
            self.Inmda_w0,
            self.weight_nmda,
            input_nmda,
            kappa,
        )
        Iampa, dIampa = self._synapse_step(
            Iampa,
            self.Campa,
            self.Iampa_tau,
            self.alpha_ampa,
            self.Iampa_w0,
            self.weight_ampa,
            input_ampa,
            kappa,
        )
        Igaba_b, dIgaba_b = self._synapse_step(
            Igaba_b,
            self.Cgaba_b,
            self.Igaba_b_tau,
            self.alpha_gaba_b,
            self.Igaba_b_w0,
            self.weight_gaba_b,
            input_gaba_b,
            kappa,
        )
        Igaba_a, dIgaba_a = self._synapse_step(
            Igaba_a,
            self.Cgaba_a,
            self.Igaba_a_tau,
            self.alpha_gaba_a,
            self.Igaba_a_w0,
            self.weight_gaba_a,
            input_gaba_a,
            kappa,
        )

        ## Euler updates
        # Out-of-place on purpose: the previous state tensors may have been
        # saved by autograd (e.g. as divisors above), and mutating them in
        # place would invalidate the backward pass.
        refractory = refractory - (refractory > 0).float()
        Isoma_mem = Isoma_mem + self.dt * dIsoma_mem * (refractory <= 0)
        Isoma_ahp = Isoma_ahp + self.dt * dIsoma_ahp
        Inmda = Inmda + self.dt * dInmda
        Iampa = Iampa + self.dt * dIampa
        Igaba_a = Igaba_a + self.dt * dIgaba_a
        Igaba_b = Igaba_b + self.dt * dIgaba_b

        ## Fire
        firing = fast_sigmoid(Isoma_mem - self.Isoma_th)
        # Round the step count before truncating: in float32, 5 ms / 1 ms can
        # evaluate to 4.9999995, which .long() would turn into 4 steps.
        refractory_steps = torch.round(self.soma_refP / self.dt)
        refractory = refractory + (firing * refractory_steps).long()
        Isoma_ahp = Isoma_ahp + (self.Isoma_ahp_w * self.alpha_ahp) * firing
        Isoma_mem = self.Isoma_reset * firing + Isoma_mem * (1 - firing)

        ## Save states
        self.state = self.AdexLIFState(
            Isoma_mem=Isoma_mem,
            Isoma_ahp=Isoma_ahp,
            refractory=refractory,
            Inmda=Inmda,
            Iampa=Iampa,
            Igaba_a=Igaba_a,
            Igaba_b=Igaba_b,
        )
        return firing