Source code for DynapSEtorch.networks

import torch
import torch.nn as nn
from DynapSEtorch.model import AdexLIF, ADM


[docs]class DelayChain(nn.Module): """Delay chain network""" def __init__(self, n_channels, n_pool, n_out): super(DelayChain, self).__init__() self.n_pool = n_pool self.n_channels = n_channels self.adm_encoder = ADM(n_channels, 1.0, 1.0, 0) self.pool_layer = nn.ModuleList() for _ in range(n_pool): pool = AdexLIF(n_channels * 2, [0, n_channels * 2, 0, 0]) pool.weight_ampa.data *= torch.eye(n_channels * 2) self.pool_layer.append(pool) self.readout = AdexLIF(n_out, [0, n_pool * n_channels * 2, 0, 0]) def reset(self): self.adm_encoder.reset() for layer in self.pool_layer: layer.reset() self.readout.reset()
[docs] def forward(self, input): in_spikes, _, _ = self.adm_encoder(input) out_spikes = [] for pool in self.pool_layer: if pool.state is None: pool.state = pool.init_state(in_spikes) s_o = pool(input_ampa=in_spikes) out_spikes.append(s_o) in_spikes = s_o pool_spikes = torch.stack(out_spikes, dim=1) if self.readout.state is None: self.readout.state = self.readout.init_state(pool_spikes) ro_spikes = self.readout( input_ampa=pool_spikes.view(-1, self.n_pool * self.n_channels * 2) ) return ro_spikes, pool_spikes
[docs]class EIBalancedNetwork(nn.Module): """EI-balanced network""" def __init__(self, n_in, n_class, ex_per_class=1): super(EIBalancedNetwork, self).__init__() self.n_in = n_in self.n_class = n_class self.ex_per_class = ex_per_class self.ex_layer = AdexLIF(ex_per_class * n_class, [n_in, 0, 0, n_class]) self.in_layer = AdexLIF(n_class, [ex_per_class * n_class, 0, 0, n_class]) self.out_in = None def reset(self): self.ex_layer.reset() self.in_layer.reset() self.out_in = None
[docs] def forward(self, input): if self.out_in is None: self.out_in = torch.zeros(input.shape[0], self.n_class, device=input.device) if self.ex_layer.state is None: self.ex_layer.state = self.ex_layer.init_state(input) out_ex = self.ex_layer(input_nmda=input, input_gaba_b=self.out_in) if self.out_in.state is None: self.out_in.state = self.out_in.init_state(input) self.out_in = self.in_layer(input_nmda=out_ex, input_gaba_b=self.out_in) return out_ex