Hi BindsNET team and users, I've been using ann_to_snn() for a while. I've noticed that, for MNIST data, only a converted MLP-SNN (input_shape=(784,)) allows batch_size > 1; for a 2D CNN-SNN (input_shape=(1, 28, 28)), batch_size can only be 1.
Here is my 2D CNN structure:
import torch
import torch.nn as nn
import torch.nn.functional as F


class MNISTCNN(nn.Module):
    def __init__(self):
        super(MNISTCNN, self).__init__()
        self.input_shape = (1, 28, 28)
        self.conv1 = nn.Conv2d(1, 32, 5, stride=1, padding=2, bias=True)
        self.maxpool1 = nn.MaxPool2d((2, 2), stride=(2, 2), padding=0)
        self.conv2 = nn.Conv2d(32, 64, 5, stride=1, padding=2, bias=True)
        self.maxpool2 = nn.MaxPool2d((2, 2), stride=(2, 2), padding=0)
        self.fc1 = nn.Linear(7 * 7 * 64, 1024, bias=True)
        self.fc2 = nn.Linear(1024, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.maxpool1(x)
        x = F.relu(self.conv2(x))
        x = self.maxpool2(x)
        x = x.reshape([x.shape[0], -1])
        x = self.fc1(x)
        output = self.fc2(x)
        return output
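For reference, the plain ANN itself handles batched inputs fine; a quick sanity check (illustrative, not from my actual script) could look like this:

# Sanity check (illustrative): the unconverted ANN accepts a batch of 32 images.
ann = MNISTCNN()
out = ann(torch.zeros(32, 1, 28, 28))
print(out.shape)  # torch.Size([32, 10])

The problem only appears after conversion with ann_to_snn().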
And here is my converted SNN:
from bindsnet.conversion import ann_to_snn
from bindsnet.network.monitors import Monitor

# PoissonEncoder, SpikeNoise, get_loader, and DEVICE come from elsewhere in my setup.
class MNISTSNN(nn.Module):
    def __init__(self,
                 ann: MNISTCNN,
                 n_step: int,
                 density: float,
                 noise: float,
                 decode: bool = True,
                 device: torch.device = DEVICE):
        super(MNISTSNN, self).__init__()
        self.encoder = PoissonEncoder(n_step, density)
        self.noise = SpikeNoise(noise)
        loader = get_loader('mnist', subset='train', proportion=0.2, num_workers=0)
        data = torch.cat([i[0].to(device) for i in loader])
        self.snn = ann_to_snn(ann, ann.input_shape, data).to(device)
        layer_output = list(self.snn.layers.keys())[-1]
        self.monitor_output = Monitor(self.snn.layers[layer_output], ['s'],
                                      n_step, device=device)
        self.snn.add_monitor(self.monitor_output, 'output')
        self.decode = decode
        self.device = device

    def forward(self, x):
        x = self.encoder(x)
        x = self.noise(x)
        self.snn.reset_state_variables()
        self.snn = self.snn.to(self.device)
        self.snn.run({'Input': x}, self.encoder.n_step)
        # spike record shape: [time, batch_size, n_choice]
        y = self.snn.monitors['output'].get('s')
        if self.decode:
            y = torch.sum(y, axis=0)
        else:
            y = y.to(torch.bool)
        return y
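As a usage sketch (the density and noise values below are placeholders, not the settings from my runs), the wrapper is called like an ordinary module; with decode=True the spike counts summed over time act as class scores:

# Hypothetical usage with placeholder hyperparameters: batch_size=1 works fine.
model = MNISTSNN(ann=MNISTCNN(), n_step=100, density=0.5, noise=0.0)
x = torch.rand(1, 1, 28, 28).to(model.device)  # a single MNIST-sized image
scores = model(x)                              # [1, 10] spike counts when decode=True
pred = scores.argmax(dim=-1)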
I found that it's bindsnet.network.topology.MaxPool2dConnection causing the trouble. When I try to run a spike train of shape [100, 32, 1, 28, 28] through the network (time steps = 100, batch_size = 32), the following error is raised:
File "/home/chenxiyuan/projects/project/model.py", line 167, in forward
self.snn.run({'Input': x}, self.encoder.n_step)
File "/home/chenxiyuan/.local/lib/python3.9/site-packages/bindsnet/network/network.py", line 360, in run
current_inputs.update(self._get_inputs())
File "/home/chenxiyuan/.local/lib/python3.9/site-packages/bindsnet/network/network.py", line 245, in _get_inputs
inputs[c[1]] += self.connections[c].compute(source.s)
File "/home/chenxiyuan/.local/lib/python3.9/site-packages/bindsnet/network/topology.py", line 463, in compute
self.firing_rates += s.float().squeeze()
RuntimeError: output with shape [1, 32, 28, 28] doesn't match the broadcast shape [32, 32, 28, 28]
It turns out that when I call self.snn.reset_state_variables(), bindsnet.network.topology.MaxPool2dConnection.reset_state_variables() always resets self.firing_rates inside the connection to shape [1, 32, 28, 28] instead of [32, 32, 28, 28], which would match batch_size=32.
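A minimal snippet (illustrative only, reconstructed from the traceback above) shows why the in-place add then fails:

# The reset firing_rates keeps a batch dimension of 1, but the incoming spikes
# have a batch dimension of 32, so the in-place += cannot broadcast.
firing_rates = torch.zeros(1, 32, 28, 28)  # shape after reset_state_variables()
s = torch.ones(32, 32, 28, 28)             # squeezed spikes for batch_size=32
firing_rates += s
# RuntimeError: output with shape [1, 32, 28, 28] doesn't match the broadcast shape [32, 32, 28, 28]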
So here is a little workaround that allows resetting the firing_rates tensor with a given batch_size.
For bindsnet.network.topology.MaxPool2dConnection.reset_state_variables():
def reset_state_variables(self, batch_size=1) -> None:
    # language=rst
    """
    Contains resetting logic for the connection.
    """
    super().reset_state_variables()
    shape = [batch_size] + list(self.source.s.shape)[1:]
    self.firing_rates = torch.zeros(shape)
For bindsnet.network.network.Network.reset_state_variables():
def reset_state_variables(self, batch_size=1) -> None:
    # language=rst
    """
    Reset state variables of objects in network.
    """
    for layer in self.layers:
        self.layers[layer].reset_state_variables()

    for connection in self.connections:
        if isinstance(self.connections[connection], MaxPool2dConnection):
            self.connections[connection].reset_state_variables(batch_size)
        else:
            self.connections[connection].reset_state_variables()

    for monitor in self.monitors:
        self.monitors[monitor].reset_state_variables()
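If you would rather not edit the installed package, the same workaround could be applied as a monkey patch at import time. This is only a sketch of that idea, mirroring the two patched methods above, and it assumes AbstractConnection is the parent class whose reset logic super() would otherwise invoke:

# Hypothetical monkey-patch sketch (not from the original post): apply the
# workaround without modifying the installed bindsnet 0.3.0 sources.
import torch
from bindsnet.network.network import Network
from bindsnet.network.topology import AbstractConnection, MaxPool2dConnection

def _maxpool_reset(self, batch_size=1):
    # Mirrors the patched MaxPool2dConnection.reset_state_variables() above.
    AbstractConnection.reset_state_variables(self)
    shape = [batch_size] + list(self.source.s.shape)[1:]
    self.firing_rates = torch.zeros(shape)

def _network_reset(self, batch_size=1):
    # Mirrors the patched Network.reset_state_variables() above.
    for layer in self.layers:
        self.layers[layer].reset_state_variables()
    for connection in self.connections:
        if isinstance(self.connections[connection], MaxPool2dConnection):
            self.connections[connection].reset_state_variables(batch_size)
        else:
            self.connections[connection].reset_state_variables()
    for monitor in self.monitors:
        self.monitors[monitor].reset_state_variables()

MaxPool2dConnection.reset_state_variables = _maxpool_reset
Network.reset_state_variables = _network_reset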
Finally, the converted SNN looks like this, passing the batch size via self.snn.reset_state_variables(x.shape[1]):
class MNISTSNN(nn.Module):
    def __init__(self,
                 ann: MNISTCNN,
                 n_step: int,
                 density: float,
                 noise: float,
                 decode: bool = True,
                 device: torch.device = DEVICE):
        super(MNISTSNN, self).__init__()
        self.encoder = PoissonEncoder(n_step, density)
        self.noise = SpikeNoise(noise)
        loader = get_loader('mnist', subset='train', proportion=0.2, num_workers=0)
        data = torch.cat([i[0].to(device) for i in loader])
        self.snn = ann_to_snn(ann, ann.input_shape, data).to(device)
        layer_output = list(self.snn.layers.keys())[-1]
        self.monitor_output = Monitor(self.snn.layers[layer_output], ['s'],
                                      n_step, device=device)
        self.snn.add_monitor(self.monitor_output, 'output')
        self.decode = decode
        self.device = device

    def forward(self, x):
        x = self.encoder(x)
        x = self.noise(x)
        self.snn.reset_state_variables(x.shape[1])
        self.snn = self.snn.to(self.device)
        self.snn.run({'Input': x}, self.encoder.n_step)
        y = self.snn.monitors['output'].get('s')
        if self.decode:
            y = torch.sum(y, axis=0)
        else:
            y = y.to(torch.bool)
        return y
I have only tested this workaround with this model on bindsnet==0.3.0, and it gives roughly a 10x speedup when batch_size=32. I hope this post helps the BindsNET team add this feature in a future version, and helps anyone else who wants batch_size>1 with 2D CNNs converted by ann_to_snn().
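For context, a rough way to compare throughput at different batch sizes might look like the sketch below (a hypothetical helper, not the measurement I actually used; numbers will depend on hardware and n_step):

# Illustrative throughput comparison (hypothetical helper, not from the post).
import time
import torch

def images_per_second(model, batch_size, n_batches=10):
    x = torch.rand(n_batches, batch_size, 1, 28, 28, device=model.device)
    start = time.time()
    for i in range(n_batches):
        model(x[i])
    return n_batches * batch_size / (time.time() - start)

# e.g. compare images_per_second(model, 1) against images_per_second(model, 32)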