Skip to content

Commit c713cdd

Browse files
Merge pull request #384 from SimonInParis/optim
faster LIF forward
2 parents 4006a18 + 3af105b commit c713cdd

File tree

2 files changed: +22 additions, −30 deletions

bindsnet/learning/learning.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
)
1414
from ..utils import im2col_indices
1515

16-
1716
class LearningRule(ABC):
1817
# language=rst
1918
"""
@@ -56,9 +55,12 @@ def __init__(
5655

5756
# Parameter update reduction across minibatch dimension.
5857
if reduction is None:
59-
reduction = torch.mean
60-
61-
self.reduction = reduction
58+
if self.source.batch_size == 1:
59+
self.reduction = torch.squeeze
60+
else:
61+
self.reduction = torch.sum
62+
else:
63+
self.reduction = reduction
6264

6365
# Weight decay.
6466
self.weight_decay = weight_decay

bindsnet/network/nodes.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -375,12 +375,10 @@ def forward(self, x: torch.Tensor) -> None:
375375
:param x: Inputs to the layer.
376376
"""
377377
# Integrate input voltages.
378-
self.v += (self.refrac_count == 0).float() * x
378+
self.v += (self.refrac_count <= 0).float() * x
379379

380380
# Decrement refractory counters.
381-
self.refrac_count = (self.refrac_count > 0).float() * (
382-
self.refrac_count - self.dt
383-
)
381+
self.refrac_count -= self.dt
384382

385383
# Check for spiking neurons.
386384
self.s = self.v >= self.thresh
@@ -509,16 +507,16 @@ def forward(self, x: torch.Tensor) -> None:
509507
self.v = self.decay * (self.v - self.rest) + self.rest
510508

511509
# Integrate inputs.
512-
self.v += (self.refrac_count == 0).float() * x
513-
510+
x.masked_fill_(self.refrac_count > 0, 0.0) # OPTIM 2
514511
# Decrement refractory counters.
515-
self.refrac_count = (self.refrac_count > 0).float() * (
516-
self.refrac_count - self.dt
517-
)
512+
self.refrac_count -= self.dt # OPTIM 1
513+
514+
self.v += x # interlaced
518515

519516
# Check for spiking neurons.
520517
self.s = self.v >= self.thresh
521518

519+
522520
# Refractoriness and voltage reset.
523521
self.refrac_count.masked_fill_(self.s, self.refrac)
524522
self.v.masked_fill_(self.s, self.reset)
@@ -653,13 +651,11 @@ def forward(self, x: torch.Tensor) -> None:
653651
self.i *= self.i_decay
654652

655653
# Decrement refractory counters.
656-
self.refrac_count = (self.refrac_count > 0).float() * (
657-
self.refrac_count - self.dt
658-
)
654+
self.refrac_count -= self.dt
659655

660656
# Integrate inputs.
661657
self.i += x
662-
self.v += (self.refrac_count == 0).float() * self.i
658+
self.v += (self.refrac_count <= 0).float() * self.i
663659

664660
# Check for spiking neurons.
665661
self.s = self.v >= self.thresh
@@ -776,7 +772,7 @@ def __init__(
776772
"tc_decay", torch.tensor(tc_decay)
777773
) # Time constant of neuron voltage decay.
778774
self.register_buffer(
779-
"decay", torch.empty_like(self.tc_decay)
775+
"decay", torch.empty_like(self.tc_decay, dtype=torch.float32)
780776
) # Set in compute_decays.
781777
self.register_buffer(
782778
"theta_plus", torch.tensor(theta_plus)
@@ -808,12 +804,10 @@ def forward(self, x: torch.Tensor) -> None:
808804
self.theta *= self.theta_decay
809805

810806
# Integrate inputs.
811-
self.v += (self.refrac_count == 0).float() * x
807+
self.v += (self.refrac_count <= 0).float() * x
812808

813809
# Decrement refractory counters.
814-
self.refrac_count = (self.refrac_count > 0).float() * (
815-
self.refrac_count - self.dt
816-
)
810+
self.refrac_count -= self.dt
817811

818812
# Check for spiking neurons.
819813
self.s = self.v >= self.thresh + self.theta
@@ -965,12 +959,10 @@ def forward(self, x: torch.Tensor) -> None:
965959
self.theta *= self.theta_decay
966960

967961
# Integrate inputs.
968-
self.v += (self.refrac_count == 0).float() * x
962+
self.v += (self.refrac_count <= 0).float() * x
969963

970964
# Decrement refractory counters.
971-
self.refrac_count = (self.refrac_count > 0).float() * (
972-
self.refrac_count - self.dt
973-
)
965+
self.refrac_count -= self.dt
974966

975967
# Check for spiking neurons.
976968
self.s = self.v >= self.thresh + self.theta
@@ -1298,17 +1290,15 @@ def forward(self, x: torch.Tensor) -> None:
12981290
self.v = self.decay * (self.v - self.rest) + self.rest
12991291

13001292
# Integrate inputs.
1301-
self.v += (self.refrac_count == 0).float() * self.eps_0 * x
1293+
self.v += (self.refrac_count <= 0).float() * self.eps_0 * x
13021294

13031295
# Compute (instantaneous) probabilities of spiking, clamp between 0 and 1 using exponentials.
13041296
# Also known as 'escape noise', this simulates nearby neurons.
13051297
self.rho = self.rho_0 * torch.exp((self.v - self.thresh) / self.d_thresh)
13061298
self.s_prob = 1.0 - torch.exp(-self.rho * self.dt)
13071299

13081300
# Decrement refractory counters.
1309-
self.refrac_count = (self.refrac_count > 0).float() * (
1310-
self.refrac_count - self.dt
1311-
)
1301+
self.refrac_count -= self.dt
13121302

13131303
# Check for spiking neurons (spike when probability > some random number).
13141304
self.s = torch.rand_like(self.s_prob) < self.s_prob

Comments (0)