@@ -549,150 +549,6 @@ def __init__(
549549        self .add_connection (recurrent_conn , source = "Y" , target = "Y" )
550550
551551
552- import  snntorch  as  snn 
553- 
554- class  FFSNN (nn .Module ):
555-     # language=rst 
556-     """ 
557-     A simple feedforward Spiking Neural Network (SNN) using snntorch, 
558-     designed for use with the ForwardForwardPipeline. 
559-     It consists of a sequence of Linear layers followed by Leaky Integrate-and-Fire 
560-     (LIF) spiking neuron layers. 
561-     """ 
562- 
563-     def  __init__ (
564-         self ,
565-         input_size : int , # This should be 794 (image_features + num_classes) 
566-         hidden_sizes : List [int ], # e.g., [500, 500] 
567-         output_size : Optional [int ] =  None , # If the last FF layer is also the output layer for classification 
568-         beta : Union [float , torch .Tensor ] =  0.9 ,  # Decay rate for snn.Leaky neurons 
569-         threshold : float  =  1.0 ,  # Firing threshold for snn.Leaky neurons 
570-         reset_mechanism : str  =  "subtract" ,  # "subtract", "zero", or "none" 
571-         # Add other snn.Leaky parameters if needed, e.g., spike_grad 
572-     ) ->  None :
573-         # language=rst 
574-         """ 
575-         Constructor for FFSNN. 
576- 
577-         :param input_size: Number of input features (after encoding and label embedding). 
578-         :param hidden_sizes: A list of integers, where each integer is the number of 
579-                              neurons in a hidden layer. 
580-         :param output_size: Optional. Number of neurons in the final layer if it's 
581-                             distinct or specifically for output. If None, the last 
582-                             size in hidden_sizes is considered the final FF layer. 
583-         :param beta: Membrane potential decay rate for Leaky neurons. 
584-         :param threshold: Firing threshold for Leaky neurons. 
585-         :param reset_mechanism: Reset mechanism for Leaky neurons after a spike. 
586-         """ 
587-         super ().__init__ ()
588- 
589-         self .input_size  =  input_size 
590-         self .hidden_sizes  =  hidden_sizes 
591-         self .output_size  =  output_size 
592-         self .beta  =  beta 
593-         self .threshold  =  threshold 
594-         self .reset_mechanism  =  reset_mechanism 
595- 
596-         self .fc_layers  =  nn .ModuleList ()
597-         self .snn_layers  =  nn .ModuleList ()
598-         self ._ff_layer_pairs_info  =  []
599- 
600-         current_dim  =  self .input_size  # Starts at 794 
601-         for  i , hidden_dim  in  enumerate (self .hidden_sizes ):
602-             linear_layer  =  nn .Linear (current_dim , hidden_dim ) # Layer 1: 794 -> 500 
603-                                                               # Layer 2: 500 -> 500 
604-             self .fc_layers .append (linear_layer )
605- 
606-             snn_layer  =  snn .Leaky (
607-                 beta = self .beta , 
608-                 threshold = self .threshold , 
609-                 reset_mechanism = self .reset_mechanism ,
610-                 # output_shape=[hidden_dim] # Optional: snntorch can infer this 
611-             )
612-             self .snn_layers .append (snn_layer )
613-             self ._ff_layer_pairs_info .append ((linear_layer , snn_layer ))
614-             current_dim  =  hidden_dim  # Update current_dim for the *next* layer's input 
615-         
616-         # If there's an output_size for a final classifier (not typical for pure FF layers) 
617-         if  self .output_size  is  not   None :
618-             self .fc_out  =  nn .Linear (current_dim , self .output_size )
619-             # Potentially another SNN layer if output is spiking 
620-             # self.snn_out = snn.Leaky(...) 
621-             # self._ff_layer_pairs_info.append((self.fc_out, self.snn_out)) # If FF applies here too 
622- 
623-     def  forward (self , x_batch_time : torch .Tensor ) ->  Tuple [torch .Tensor , List [torch .Tensor ]]:
624-         # language=rst 
625-         """ 
626-         Defines the forward pass of the SNN over time. 
627-         This method might be used if the network is called directly with time-series data. 
628-         However, the ForwardForwardPipeline._run_snn_batch currently iterates 
629-         through self.network_sequence modules per time step. 
630- 
631-         :param x_batch_time: Input tensor with shape [batch_size, time_steps, num_features]. 
632-         :return: Final layer output spikes and a list of hidden states (membrane potentials) 
633-                  from all spiking layers. 
634-         """ 
635-         # Initialize hidden states for all spiking layers in the sequence 
636-         # This assumes they are snn.Leaky and support init_leaky() or similar 
637-         # Or, more generally, that they initialize if mem is None on first call. 
638-         
639-         spiking_layer_modules  =  [info ['spiking' ] for  info  in  self ._ff_layer_pairs_info ]
640-         # mem_states = [layer.init_leaky() for layer in spiking_layer_modules] # This creates new states 
641-         # For snntorch, typically pass None for initial state, layer handles it. 
642-         
643-         # The pipeline's _run_snn_batch actually handles the time loop and state passing. 
644-         # This forward method is more for standalone use or if the pipeline changes. 
645-         # If you want this model to be directly callable with (B, T, F) input and manage its own time loop: 
646- 
647-         batch_size  =  x_batch_time .shape [0 ]
648-         
649-         # Initialize states for each spiking layer for this batch 
650-         # This is tricky because snn.Leaky.init_hidden() doesn't take batch_size. 
651-         # State initialization is usually handled by passing None to the layer's forward method 
652-         # for the first time step, and it initializes based on the input batch size. 
653-         
654-         # Placeholder: The pipeline's _run_snn_batch is the primary runner. 
655-         # This forward method would need a more elaborate state management if used directly. 
656-         # For now, let's make it compatible with how _run_snn_batch works if it were to call this. 
657-         
658-         # If this 'forward' is to be used, it should mirror _run_snn_batch's logic: 
659-         # Initialize all spiking layer states to None 
660-         spiking_layer_states  =  {module : None  for  module  in  self .network_sequence  if  isinstance (module , snn .SpikingNeuron )}
661-         
662-         # Record outputs if needed (e.g., for a final classification layer not part of FF) 
663-         # final_spk_rec = [] # If you want to record output spikes over time 
664- 
665-         for  t  in  range (x_batch_time .shape [1 ]): # Iterate over time 
666-             x_t  =  x_batch_time [:, t , :]
667-             layer_input  =  x_t 
668-             
669-             current_module_idx  =  0 
670-             for  module  in  self .network_sequence :
671-                 if  isinstance (module , snn .SpikingNeuron ):
672-                     spk_out , new_mem  =  module (layer_input , spiking_layer_states .get (module ))
673-                     spiking_layer_states [module ] =  new_mem 
674-                     layer_input  =  spk_out 
675-                 else : # nn.Linear 
676-                     layer_input  =  module (layer_input )
677-             # After passing through all layers for time step t, layer_input is the output of the last layer 
678-             # final_spk_rec.append(layer_input)  
679-         
680-         # return torch.stack(final_spk_rec, dim=1), [state for state in spiking_layer_states.values()] 
681-         return  layer_input , [spiking_layer_states [info ['spiking' ]] for  info  in  self ._ff_layer_pairs_info ]
682- 
683- 
684-     def  get_ff_layer_pairs (self ) ->  List [Tuple [nn .Linear , snn .SpikingNeuron ]]:
685-         # language=rst 
686-         """ 
687-         Returns the list of (Linear, SpikingNeuron) pairs for Forward-Forward training. 
688-         """ 
689-         # If _ff_layer_pairs_info contains tuples, return them directly 
690-         return  self ._ff_layer_pairs_info 
691-         
692-         # Alternative: If you want to be explicit about the structure 
693-         # return [(pair[0], pair[1]) for pair in self._ff_layer_pairs_info] 
694- 
695- 
696552class  FFSNN_BindsNET (Network ):
697553    # language=rst 
698554    """ 
0 commit comments