77from bindsnet .network .monitors import AbstractMonitor
88from bindsnet .network .nodes import CSRMNodes , Nodes
99from bindsnet .network .topology import AbstractConnection
10+ from torch .multiprocessing .spawn import spawn
1011
1112
1213def load (file_name : str , map_location : str = "cpu" , learning : bool = None ) -> "Network" :
@@ -28,6 +29,52 @@ def load(file_name: str, map_location: str = "cpu", learning: bool = None) -> "N
2829 return network
2930
3031
32+ def update_layer (_ , self , l , t ):
33+ # Update each layer of nodes.
34+ if l in self .inputs :
35+ if l in self .current_inputs :
36+ self .current_inputs [l ] += self .inputs [l ][t ]
37+ else :
38+ self .current_inputs [l ] = self .inputs [l ][t ]
39+
40+ if self .one_step :
41+ # Get input to this layer (one-step mode).
42+ self .current_inputs .update (self ._get_inputs (layers = [l ]))
43+
44+ # Inject voltage to neurons.
45+ inject_v = self .injects_v .get (l , None )
46+ if inject_v is not None :
47+ if inject_v .ndimension () == 1 :
48+ self .layers [l ].v += inject_v
49+ else :
50+ self .layers [l ].v += inject_v [t ]
51+
52+ if l in self .current_inputs :
53+ self .layers [l ].forward (x = self .current_inputs [l ])
54+ else :
55+ self .layers [l ].forward (
56+ x = torch .zeros (
57+ self .layers [l ].s .shape , device = self .layers [l ].s .device
58+ )
59+ )
60+
61+ # Clamp neurons to spike.
62+ clamp = self .clamps .get (l , None )
63+ if clamp is not None :
64+ if clamp .ndimension () == 1 :
65+ self .layers [l ].s [:, clamp ] = 1
66+ else :
67+ self .layers [l ].s [:, clamp [t ]] = 1
68+
69+ # Clamp neurons not to spike.
70+ unclamp = self .unclamps .get (l , None )
71+ if unclamp is not None :
72+ if unclamp .ndimension () == 1 :
73+ self .layers [l ].s [:, unclamp ] = 0
74+ else :
75+ self .layers [l ].s [:, unclamp [t ]] = 0
76+
77+
3178class Network (torch .nn .Module ):
3279 # language=rst
3380 """
@@ -383,50 +430,23 @@ def run(
383430 if not one_step :
384431 current_inputs .update (self ._get_inputs ())
385432
386- for l in self .layers :
387- # Update each layer of nodes.
388- if l in inputs :
389- if l in current_inputs :
390- current_inputs [l ] += inputs [l ][t ]
391- else :
392- current_inputs [l ] = inputs [l ][t ]
393-
394- if one_step :
395- # Get input to this layer (one-step mode).
396- current_inputs .update (self ._get_inputs (layers = [l ]))
433+ processes = []
434+ self .inputs = inputs
435+ self .current_inputs = current_inputs
436+ self .one_step = one_step
437+ self .injects_v = injects_v
438+ self .unclamps = unclamps
439+ self .clamps = clamps
397440
398- # Inject voltage to neurons.
399- inject_v = injects_v .get (l , None )
400- if inject_v is not None :
401- if inject_v .ndimension () == 1 :
402- self .layers [l ].v += inject_v
403- else :
404- self .layers [l ].v += inject_v [t ]
405-
406- if l in current_inputs :
407- self .layers [l ].forward (x = current_inputs [l ])
408- else :
409- self .layers [l ].forward (
410- x = torch .zeros (
411- self .layers [l ].s .shape , device = self .layers [l ].s .device
412- )
413- )
441+ for l in self .layers :
442+ processes .append (
443+ spawn (update_layer , args = (self , l , t ), join = False )
444+ )
414445
415- # Clamp neurons to spike.
416- clamp = clamps .get (l , None )
417- if clamp is not None :
418- if clamp .ndimension () == 1 :
419- self .layers [l ].s [:, clamp ] = 1
420- else :
421- self .layers [l ].s [:, clamp [t ]] = 1
446+ for p in processes :
447+ p .join ()
422448
423- # Clamp neurons not to spike.
424- unclamp = unclamps .get (l , None )
425- if unclamp is not None :
426- if unclamp .ndimension () == 1 :
427- self .layers [l ].s [:, unclamp ] = 0
428- else :
429- self .layers [l ].s [:, unclamp [t ]] = 0
449+ print (t )
430450
431451 for c in self .connections :
432452 flad_m = False
0 commit comments