Skip to content

work-around to allow batch_size>1 for ann_to_snn() 2D CNN #537

@XiYuan68

Description

@XiYuan68

Hi BindsNET team and users, I've been using ann_to_snn() for a while. I noticed that for MNIST data, only a converted MLP-SNN (input_shape=(784,)) allows batch_size>1. For a 2D CNN-SNN (input_shape=(1, 28, 28)), batch_size has to be 1.

Here is my 2D CNN structure:

class MNISTCNN(nn.Module):
    """Small 2D CNN for MNIST, used as the source ANN for ``ann_to_snn()``.

    Architecture:
    conv(1->32, 5x5, pad 2) -> ReLU -> 2x2 max-pool ->
    conv(32->64, 5x5, pad 2) -> ReLU -> 2x2 max-pool -> flatten ->
    linear(7*7*64 -> 1024) -> linear(1024 -> 10).
    """

    def __init__(self):
        super().__init__()
        # Shape of one sample (C, H, W); read as ``ann.input_shape`` when the
        # network is converted.
        self.input_shape = (1, 28, 28)
        # Keep this registration order: it fixes the order in which the
        # parameters are randomly initialised and the order of child modules.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5,
                               stride=1, padding=2, bias=True)
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5,
                               stride=1, padding=2, bias=True)
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0)
        # The two 2x2 pools shrink 28x28 to 7x7, with 64 channels.
        self.fc1 = nn.Linear(in_features=7 * 7 * 64, out_features=1024, bias=True)
        self.fc2 = nn.Linear(in_features=1024, out_features=10)

    def forward(self, x):
        """Map a batch of images shaped (N, 1, 28, 28) to 10 outputs each."""
        features = self.maxpool1(F.relu(self.conv1(x)))
        features = self.maxpool2(F.relu(self.conv2(features)))
        # Flatten every dimension except the batch dimension.
        flat = features.reshape([features.shape[0], -1])
        # NOTE(review): there is no activation between fc1 and fc2, so in the
        # ANN they compose into one affine map. Spiking neurons in the
        # converted SNN presumably cannot represent negative activations, so
        # the SNN may not match this ANN exactly. Adding F.relu here would
        # require retraining.
        return self.fc2(self.fc1(flat))

And here is my converted SNN:

class MNISTSNN(nn.Module):
    """Spiking counterpart of an MNIST ANN, built with bindsnet's ``ann_to_snn()``.

    Inputs are encoded with ``PoissonEncoder``, perturbed by ``SpikeNoise`` and
    simulated for ``n_step`` steps. Spikes of the converted network's last
    layer are recorded by a ``Monitor``.

    NOTE(review): in this version ``forward()`` resets the network without a
    batch size. For a 2D CNN with batch_size > 1 this triggers the
    MaxPool2dConnection shape error shown in this report.
    """
    def __init__(self,
                 ann: MNISTMLP,
                 n_step: int,
                 density: float, 
                 noise: float,
                 decode: bool = True, 
                 device: torch.device = DEVICE):
        """Convert ``ann`` to an SNN and attach a spike monitor to its last layer.

        :param ann: Trained ANN exposing an ``input_shape`` attribute.
            NOTE(review): annotated as ``MNISTMLP``, although this report
            converts the MNISTCNN above. Confirm the intended type.
        :param n_step: Number of simulation steps. It is also passed to the
            encoder and to the output Monitor.
        :param density: Forwarded to ``PoissonEncoder``; its meaning is defined there.
        :param noise: Forwarded to ``SpikeNoise``; its meaning is defined there.
        :param decode: If True, ``forward()`` returns spike counts summed over
            time; otherwise it returns boolean spike trains.
        :param device: Device for the calibration data, the SNN and the monitor.
        """
        super(MNISTSNN, self).__init__()
        self.encoder = PoissonEncoder(n_step, density)
        self.noise = SpikeNoise(noise)
        # Calibration data passed to ann_to_snn(): 20% of the MNIST training
        # set, concatenated along dim 0. NOTE(review): presumably used for
        # data-based weight normalisation inside ann_to_snn(). Confirm.
        loader = get_loader('mnist', subset='train', proportion=0.2, num_workers=0)
        data = torch.cat([i[0].to(device) for i in loader])
        self.snn = ann_to_snn(ann, ann.input_shape, data).to(device)
        # Treat the last layer in insertion order as the output layer.
        layer_output = list(self.snn.layers.keys())[-1]
        self.monitor_output = Monitor(self.snn.layers[layer_output], ['s'], 
                                      n_step, device=device)
        self.snn.add_monitor(self.monitor_output, 'output')
        self.decode = decode
        self.device = device

    def forward(self, x):
        """Simulate the SNN on one batch and return its output spikes.

        :param x: Batch passed to the Poisson encoder. NOTE(review): the
            encoded tensor is assumed to be [time, batch, *input_shape].
            Confirm.
        :return: [batch_size, n_choice] spike counts if ``decode`` is set,
            otherwise a boolean [time, batch_size, n_choice] spike tensor.
        """

        x = self.encoder(x)
        x = self.noise(x)

        # Reset without a batch size. According to this report,
        # MaxPool2dConnection then re-creates firing_rates with batch
        # dimension 1, so its compute() fails once batch_size > 1.
        self.snn.reset_state_variables()
        # Move the network onto the target device before simulating,
        # presumably so that state re-created by the reset ends up there too.
        self.snn = self.snn.to(self.device)

        self.snn.run({'Input': x}, self.encoder.n_step)
        # [time, batch_size, n_choice]
        y = self.snn.monitors['output'].get('s')
        if self.decode:
            # Spike count per output neuron, summed over the time axis.
            y = torch.sum(y, axis=0)
        else:
            y = y.to(torch.bool)
        
        
        return y

I found that bindsnet.network.topology.MaxPool2dConnection is causing the trouble. When I try to run a spike train of shape [100, 32, 1, 28, 28] (length=100, batch_size=32), the following error is raised:

  File "/home/chenxiyuan/projects/project/model.py", line 167, in forward
    self.snn.run({'Input': x}, self.encoder.n_step)

  File "/home/chenxiyuan/.local/lib/python3.9/site-packages/bindsnet/network/network.py", line 360, in run
    current_inputs.update(self._get_inputs())

  File "/home/chenxiyuan/.local/lib/python3.9/site-packages/bindsnet/network/network.py", line 245, in _get_inputs
    inputs[c[1]] += self.connections[c].compute(source.s)

  File "/home/chenxiyuan/.local/lib/python3.9/site-packages/bindsnet/network/topology.py", line 463, in compute
    self.firing_rates += s.float().squeeze()

RuntimeError: output with shape [1, 32, 28, 28] doesn't match the broadcast shape [32, 32, 28, 28]

It turns out that when I call self.snn.reset_state_variables(), bindsnet.network.topology.MaxPool2dConnection.reset_state_variables() always sets self.firing_rates inside the connection to shape [1, 32, 28, 28]. With batch_size=32 it needs to be [32, 32, 28, 28].

So here is a small work-around that resets the firing_rates tensor with a given batch_size:

For bindsnet.network.topology.MaxPool2dConnection.reset_state_variables(): here

    def reset_state_variables(self, batch_size=1) -> None:
        # language=rst
        """
        Contains resetting logic for the connection.

        Re-allocates ``firing_rates`` as zeros with shape
        ``[batch_size, *source.s.shape[1:]]``, so that ``compute()`` can
        accumulate spikes for a whole batch instead of a single sample.

        :param batch_size: Number of samples in the next simulated batch.
        """
        super().reset_state_variables()

        # Keep the non-batch dimensions of the source layer; only the batch
        # dimension comes from the caller.
        shape = [batch_size] + list(self.source.s.shape)[1:]
        # compute() adds the source layer's spikes into firing_rates in place,
        # so allocate on the same device as ``source.s``. A bare
        # torch.zeros(shape) lands on the default device, which breaks a
        # network that lives on a GPU.
        self.firing_rates = torch.zeros(shape, device=self.source.s.device)

For bindsnet.network.network.Network.reset_state_variables(): here

    def reset_state_variables(self, batch_size=1) -> None:
        # language=rst
        """
        Reset the state variables of every layer, connection and monitor.

        :param batch_size: Forwarded only to ``MaxPool2dConnection`` objects,
            whose ``reset_state_variables`` accepts a batch size in this
            work-around. All other connections are reset without arguments.
        """
        for layer in self.layers.values():
            layer.reset_state_variables()

        for conn in self.connections.values():
            # Only the patched MaxPool2dConnection reset takes a batch size.
            extra_args = (batch_size,) if isinstance(conn, MaxPool2dConnection) else ()
            conn.reset_state_variables(*extra_args)

        for monitor in self.monitors.values():
            monitor.reset_state_variables()

Finally, the converted SNN looks like this. The batch size is passed via self.snn.reset_state_variables(x.shape[1]):

class MNISTSNN(nn.Module):
    """Spiking counterpart of an MNIST ANN, built with bindsnet's ``ann_to_snn()``.

    Inputs are encoded with ``PoissonEncoder``, perturbed by ``SpikeNoise`` and
    simulated for ``n_step`` steps. Spikes of the converted network's last
    layer are recorded by a ``Monitor``.

    This version passes the batch size to ``reset_state_variables()``. That
    requires the patched ``Network`` / ``MaxPool2dConnection`` reset methods
    described in this report.
    """
    def __init__(self,
                 ann: MNISTMLP,
                 n_step: int,
                 density: float, 
                 noise: float,
                 decode: bool = True, 
                 device: torch.device = DEVICE):
        """Convert ``ann`` to an SNN and attach a spike monitor to its last layer.

        :param ann: Trained ANN exposing an ``input_shape`` attribute.
            NOTE(review): annotated as ``MNISTMLP``, although this report
            converts the MNISTCNN above. Confirm the intended type.
        :param n_step: Number of simulation steps. It is also passed to the
            encoder and to the output Monitor.
        :param density: Forwarded to ``PoissonEncoder``; its meaning is defined there.
        :param noise: Forwarded to ``SpikeNoise``; its meaning is defined there.
        :param decode: If True, ``forward()`` returns spike counts summed over
            time; otherwise it returns boolean spike trains.
        :param device: Device for the calibration data, the SNN and the monitor.
        """
        super(MNISTSNN, self).__init__()
        self.encoder = PoissonEncoder(n_step, density)
        self.noise = SpikeNoise(noise)
        # Calibration data passed to ann_to_snn(): 20% of the MNIST training
        # set, concatenated along dim 0. NOTE(review): presumably used for
        # data-based weight normalisation inside ann_to_snn(). Confirm.
        loader = get_loader('mnist', subset='train', proportion=0.2, num_workers=0)
        data = torch.cat([i[0].to(device) for i in loader])
        self.snn = ann_to_snn(ann, ann.input_shape, data).to(device)
        # Treat the last layer in insertion order as the output layer.
        layer_output = list(self.snn.layers.keys())[-1]
        self.monitor_output = Monitor(self.snn.layers[layer_output], ['s'], 
                                      n_step, device=device)
        self.snn.add_monitor(self.monitor_output, 'output')
        self.decode = decode
        self.device = device

    def forward(self, x):
        """Simulate the SNN on one batch and return its output spikes.

        :param x: Batch passed to the Poisson encoder. NOTE(review): the
            encoded tensor is assumed to be [time, batch, *input_shape], so
            dim 1 is the batch size. Confirm.
        :return: [batch_size, n_choice] spike counts if ``decode`` is set,
            otherwise a boolean [time, batch_size, n_choice] spike tensor.
        """

        x = self.encoder(x)
        x = self.noise(x)

        # Pass the batch size (dim 1 of the encoded input), so the patched
        # Network.reset_state_variables() can re-allocate
        # MaxPool2dConnection.firing_rates with a matching batch dimension.
        self.snn.reset_state_variables(x.shape[1])
        # Move the network onto the target device before simulating,
        # presumably so that state re-created by the reset ends up there too.
        self.snn = self.snn.to(self.device)

        self.snn.run({'Input': x}, self.encoder.n_step)
        # Recorded output spikes: [time, batch_size, n_choice].
        y = self.snn.monitors['output'].get('s')
        if self.decode:
            # Spike count per output neuron, summed over the time axis.
            y = torch.sum(y, axis=0)
        else:
            y = y.to(torch.bool)
        
        
        return y

I have only tested this work-around with this model on bindsnet==0.3.0, where it gives a 10x speed-up with batch_size=32. I hope this post helps the BindsNET team add this feature in a future version. I also hope it helps anyone who wants batch_size>1 with ann_to_snn() 2D CNNs.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions