@@ -375,12 +375,10 @@ def forward(self, x: torch.Tensor) -> None:
375375 :param x: Inputs to the layer.
376376 """
377377 # Integrate input voltages.
378- self .v += (self .refrac_count = = 0 ).float () * x
378+ self .v += (self .refrac_count < = 0 ).float () * x
379379
380380 # Decrement refractory counters.
381- self .refrac_count = (self .refrac_count > 0 ).float () * (
382- self .refrac_count - self .dt
383- )
381+ self .refrac_count -= self .dt
384382
385383 # Check for spiking neurons.
386384 self .s = self .v >= self .thresh
@@ -509,16 +507,16 @@ def forward(self, x: torch.Tensor) -> None:
509507 self .v = self .decay * (self .v - self .rest ) + self .rest
510508
511509 # Integrate inputs.
512- self .v += (self .refrac_count == 0 ).float () * x
513-
510+ x .masked_fill_ (self .refrac_count > 0 , 0.0 ) # OPTIM 2
514511 # Decrement refractory counters.
515- self .refrac_count = ( self .refrac_count > 0 ). float () * (
516- self . refrac_count - self . dt
517- )
512+ self .refrac_count -= self .dt # OPTIM 1
513+
514+ self . v += x # interlaced
518515
519516 # Check for spiking neurons.
520517 self .s = self .v >= self .thresh
521518
519+
522520 # Refractoriness and voltage reset.
523521 self .refrac_count .masked_fill_ (self .s , self .refrac )
524522 self .v .masked_fill_ (self .s , self .reset )
@@ -653,13 +651,11 @@ def forward(self, x: torch.Tensor) -> None:
653651 self .i *= self .i_decay
654652
655653 # Decrement refractory counters.
656- self .refrac_count = (self .refrac_count > 0 ).float () * (
657- self .refrac_count - self .dt
658- )
654+ self .refrac_count -= self .dt
659655
660656 # Integrate inputs.
661657 self .i += x
662- self .v += (self .refrac_count = = 0 ).float () * self .i
658+ self .v += (self .refrac_count < = 0 ).float () * self .i
663659
664660 # Check for spiking neurons.
665661 self .s = self .v >= self .thresh
@@ -776,7 +772,7 @@ def __init__(
776772 "tc_decay" , torch .tensor (tc_decay )
777773 ) # Time constant of neuron voltage decay.
778774 self .register_buffer (
779- "decay" , torch .empty_like (self .tc_decay )
775+ "decay" , torch .empty_like (self .tc_decay , dtype = torch . float32 )
780776 ) # Set in compute_decays.
781777 self .register_buffer (
782778 "theta_plus" , torch .tensor (theta_plus )
@@ -808,12 +804,10 @@ def forward(self, x: torch.Tensor) -> None:
808804 self .theta *= self .theta_decay
809805
810806 # Integrate inputs.
811- self .v += (self .refrac_count = = 0 ).float () * x
807+ self .v += (self .refrac_count < = 0 ).float () * x
812808
813809 # Decrement refractory counters.
814- self .refrac_count = (self .refrac_count > 0 ).float () * (
815- self .refrac_count - self .dt
816- )
810+ self .refrac_count -= self .dt
817811
818812 # Check for spiking neurons.
819813 self .s = self .v >= self .thresh + self .theta
@@ -965,12 +959,10 @@ def forward(self, x: torch.Tensor) -> None:
965959 self .theta *= self .theta_decay
966960
967961 # Integrate inputs.
968- self .v += (self .refrac_count = = 0 ).float () * x
962+ self .v += (self .refrac_count < = 0 ).float () * x
969963
970964 # Decrement refractory counters.
971- self .refrac_count = (self .refrac_count > 0 ).float () * (
972- self .refrac_count - self .dt
973- )
965+ self .refrac_count -= self .dt
974966
975967 # Check for spiking neurons.
976968 self .s = self .v >= self .thresh + self .theta
@@ -1298,17 +1290,15 @@ def forward(self, x: torch.Tensor) -> None:
12981290 self .v = self .decay * (self .v - self .rest ) + self .rest
12991291
13001292 # Integrate inputs.
1301- self .v += (self .refrac_count = = 0 ).float () * self .eps_0 * x
1293+ self .v += (self .refrac_count < = 0 ).float () * self .eps_0 * x
13021294
13031295 # Compute (instantaneous) probabilities of spiking, clamp between 0 and 1 using exponentials.
13041296 # Also known as 'escape noise', this simulates nearby neurons.
13051297 self .rho = self .rho_0 * torch .exp ((self .v - self .thresh ) / self .d_thresh )
13061298 self .s_prob = 1.0 - torch .exp (- self .rho * self .dt )
13071299
13081300 # Decrement refractory counters.
1309- self .refrac_count = (self .refrac_count > 0 ).float () * (
1310- self .refrac_count - self .dt
1311- )
1301+ self .refrac_count -= self .dt
13121302
13131303 # Check for spiking neurons (spike when probability > some random number).
13141304 self .s = torch .rand_like (self .s_prob ) < self .s_prob
0 commit comments