@@ -725,42 +725,313 @@ def reset_state_variables(self) -> None:
725725        self .eligibility_trace .zero_ ()
726726        return 
727727
728+ # Remove the MyBackpropVariant class and replace with: 
728729
729- 
730- class  MyBackpropVariant (MCC_LearningRule ):
731-     def  __init__ (self , connection , feature_value , ** kwargs ):
732-         super ().__init__ (connection = connection , feature_value = feature_value , ** kwargs )
733-         # Potentially initialize other parameters specific to your variant 
734-         self .update  =  self ._custom_connection_update 
735- 
736-     def  _custom_connection_update (self , ** kwargs ) ->  None :
737-         # Assume 'error_signal' for the target layer is passed in kwargs 
738-         # Assume 'surrogate_grad_target' for target neuron activations is available or computed 
739-         # Assume 'source_activity' (e.g., spikes or trace) is from self.source 
730+ class  ForwardForwardMCCLearning (MCC_LearningRule ):
731+     """ 
732+     Forward-Forward learning rule for MulticompartmentConnection. 
733+      
734+     This MCC learning rule wrapper integrates the Forward-Forward algorithm 
735+     with the MulticompartmentConnection architecture, enabling layer-wise 
736+     learning without backpropagation through time. 
737+      
738+     The learning rule works by: 
739+     1. Computing goodness scores from target layer activity 
740+     2. Collecting positive and negative sample statistics 
741+     3. Applying contrastive weight updates based on Forward-Forward loss 
742+     """ 
743+     
744+     def  __init__ (
745+         self ,
746+         alpha_loss : float  =  0.6 ,
747+         goodness_fn : str  =  "mean_squared" ,
748+         nu : float  =  0.001 ,
749+         momentum : float  =  0.0 ,
750+         weight_decay : float  =  0.0 ,
751+         ** kwargs 
752+     ):
753+         """ 
754+         Initialize Forward-Forward MCC learning rule. 
740755         
741-         if  "error_signal"  not  in   kwargs :
742-             return  # Or handle missing error 
756+         Args: 
757+             alpha_loss: Forward-Forward loss threshold parameter 
758+             goodness_fn: Goodness score computation method ("mean_squared", "sum_squared") 
759+             nu: Learning rate for weight updates 
760+             momentum: Momentum factor for weight updates 
761+             weight_decay: Weight decay factor for regularization 
762+             **kwargs: Additional arguments passed to parent MCC_LearningRule 
763+         """ 
764+         super ().__init__ (nu = nu , ** kwargs )
743765
744-         error_signal  =  kwargs ["error_signal" ] # This would be specific to target neurons 
766+         self .alpha_loss  =  alpha_loss 
767+         self .goodness_fn  =  goodness_fn 
768+         self .momentum  =  momentum 
769+         self .weight_decay  =  weight_decay 
745770
746-         # This is highly conceptual and depends on your specific variant:  
747-         # 1. Get pre-synaptic activity (e.g.,  self.source.s or self.source.x) 
748-         # 2. The 'error_signal' would correspond to the error at the post-synaptic (target) neurons 
749-         # 3. Compute weight updates, e.g., delta_w = learning_rate * error_signal * pre_synaptic_activity 
750-         #    (This is a simplification; SNN backprop is more complex) 
771+         # State tracking for Forward-Forward learning  
772+         self .positive_goodness   =   None 
773+         self . negative_goodness   =   None 
774+         self . positive_activations   =   None 
775+         self . negative_activations   =   None 
751776
752-         # Example: (very abstract, actual SNN backprop is more involved) 
753-         # Assume error_signal is shaped for target neurons, source_s for source neurons 
754-         # update_matrix = torch.outer(error_signal, self.source.s.float().mean(dim=0)) # Simplified 
755-         # self.feature_value += self.nu[0] * update_matrix * self.connection.dt  
777+         # Momentum state 
778+         self .velocity  =  None 
756779
757-         # Actual implementation would depend on the precise math of your variant 
758-         # (e.g., using surrogate derivatives of target neuron potentials, etc.) 
780+         # Sample type tracking 
781+         self .current_sample_type  =  None 
782+         self .samples_processed  =  0 
759783
760-         # Call the parent's update for decay, clamping, etc. 
761-         super ().update () 
784+     def  update (
785+         self ,
786+         connection : 'MulticompartmentConnection' ,
787+         source_s : torch .Tensor ,
788+         target_s : torch .Tensor ,
789+         ** kwargs 
790+     ) ->  None :
791+         """ 
792+         Perform Forward-Forward learning update. 
762793         
763-     def  reset_state_variables (self ) ->  None :
764-         # Reset any internal states if your rule has them 
765-         pass 
766- 
794+         This method is called by MCC during each simulation step. It accumulates 
795+         statistics for positive and negative samples, then applies contrastive 
796+         updates when both sample types are available. 
797+          
798+         Args: 
799+             connection: Parent MulticompartmentConnection 
800+             source_s: Source layer spikes [batch_size, source_neurons] 
801+             target_s: Target layer spikes [batch_size, target_neurons] 
802+             **kwargs: Additional arguments including 'sample_type' 
803+         """ 
804+         # Check if learning is enabled 
805+         if  not  connection .w .requires_grad :
806+             return 
807+         
808+         # Get sample type from kwargs 
809+         sample_type  =  kwargs .get ('sample_type' , self .current_sample_type )
810+         if  sample_type  is  None :
811+             # Default to positive for backward compatibility 
812+             sample_type  =  "positive" 
813+         
814+         # Compute goodness score for current batch 
815+         current_goodness  =  self ._compute_goodness (target_s )
816+         
817+         # Store activations and goodness based on sample type 
818+         if  sample_type  ==  "positive" :
819+             self .positive_goodness  =  current_goodness .detach ()
820+             self .positive_activations  =  {
821+                 'source' : source_s .detach (),
822+                 'target' : target_s .detach ()
823+             }
824+             
825+         elif  sample_type  ==  "negative" :
826+             self .negative_goodness  =  current_goodness .detach ()
827+             self .negative_activations  =  {
828+                 'source' : source_s .detach (),
829+                 'target' : target_s .detach ()
830+             }
831+             
832+         else :
833+             raise  ValueError (f"Invalid sample_type: { sample_type }  . Must be 'positive' or 'negative'" )
834+         
835+         self .samples_processed  +=  1 
836+         
837+         # Apply contrastive update if we have both positive and negative samples 
838+         if  (self .positive_goodness  is  not   None  and  
839+             self .negative_goodness  is  not   None  and 
840+             self .positive_activations  is  not   None  and 
841+             self .negative_activations  is  not   None ):
842+             
843+             self ._apply_forward_forward_update (connection )
844+             self ._reset_accumulated_data ()
845+     
846+     def  _compute_goodness (self , target_activity : torch .Tensor ) ->  torch .Tensor :
847+         """ 
848+         Compute Forward-Forward goodness score from target layer activity. 
849+          
850+         Args: 
851+             target_activity: Target neuron spikes [batch_size, neurons] 
852+              
853+         Returns: 
854+             Goodness scores [batch_size] 
855+         """ 
856+         if  self .goodness_fn  ==  "mean_squared" :
857+             # Mean squared activity across neurons (original FF paper) 
858+             goodness  =  torch .mean (target_activity  **  2 , dim = 1 )
859+             
860+         elif  self .goodness_fn  ==  "sum_squared" :
861+             # Sum of squared activity across neurons 
862+             goodness  =  torch .sum (target_activity  **  2 , dim = 1 )
863+             
864+         else :
865+             raise  ValueError (f"Unknown goodness function: { self .goodness_fn }  " )
866+         
867+         return  goodness 
868+     
869+     def  _apply_forward_forward_update (self , connection : 'MulticompartmentConnection' ):
870+         """ 
871+         Apply Forward-Forward contrastive weight update. 
872+          
873+         The update follows the Forward-Forward principle: 
874+         - Strengthen weights that increase goodness for positive samples 
875+         - Weaken weights that increase goodness for negative samples 
876+          
877+         Args: 
878+             connection: Parent MulticompartmentConnection 
879+         """ 
880+         # Get weight tensor 
881+         w  =  connection .w 
882+         
883+         # Compute Forward-Forward loss (for monitoring) 
884+         ff_loss  =  self ._compute_ff_loss (self .positive_goodness , self .negative_goodness )
885+         
886+         # Compute weight update based on activity correlations 
887+         pos_source  =  self .positive_activations ['source' ]
888+         pos_target  =  self .positive_activations ['target' ]
889+         neg_source  =  self .negative_activations ['source' ]
890+         neg_target  =  self .negative_activations ['target' ]
891+         
892+         # Positive update: strengthen weights for positive samples 
893+         # ΔW_pos = η * s_pos^T * t_pos / batch_size 
894+         delta_w_pos  =  torch .mm (pos_source .t (), pos_target ) /  pos_source .shape [0 ]
895+         
896+         # Negative update: weaken weights for negative samples   
897+         # ΔW_neg = -η * s_neg^T * t_neg / batch_size 
898+         delta_w_neg  =  - torch .mm (neg_source .t (), neg_target ) /  neg_source .shape [0 ]
899+         
900+         # Combined Forward-Forward update 
901+         delta_w  =  self .nu  *  (delta_w_pos  +  delta_w_neg )
902+         
903+         # Add weight decay if specified 
904+         if  self .weight_decay  >  0 :
905+             delta_w  =  delta_w  -  self .weight_decay  *  w 
906+         
907+         # Apply momentum if specified 
908+         if  self .momentum  >  0 :
909+             if  self .velocity  is  None :
910+                 self .velocity  =  torch .zeros_like (w )
911+             
912+             self .velocity  =  self .momentum  *  self .velocity  +  delta_w 
913+             delta_w  =  self .velocity 
914+         
915+         # Apply weight update 
916+         with  torch .no_grad ():
917+             w .add_ (delta_w )
918+         
919+         # Apply weight constraints if they exist 
920+         self ._apply_weight_constraints (connection )
921+     
922+     def  _compute_ff_loss (
923+         self ,
924+         goodness_pos : torch .Tensor ,
925+         goodness_neg : torch .Tensor 
926+     ) ->  torch .Tensor :
927+         """ 
928+         Compute Forward-Forward contrastive loss for monitoring. 
929+          
930+         L = log(1 + exp(-g_pos + α)) + log(1 + exp(g_neg - α)) 
931+          
932+         Args: 
933+             goodness_pos: Goodness scores for positive samples 
934+             goodness_neg: Goodness scores for negative samples 
935+              
936+         Returns: 
937+             Forward-Forward loss (scalar) 
938+         """ 
939+         # Positive loss: encourage high goodness for positive samples 
940+         loss_pos  =  torch .log (1  +  torch .exp (- goodness_pos  +  self .alpha_loss ))
941+         
942+         # Negative loss: encourage low goodness for negative samples 
943+         loss_neg  =  torch .log (1  +  torch .exp (goodness_neg  -  self .alpha_loss ))
944+         
945+         # Return mean loss across batch 
946+         total_loss  =  loss_pos  +  loss_neg 
947+         return  torch .mean (total_loss )
948+     
949+     def  _apply_weight_constraints (self , connection : 'MulticompartmentConnection' ):
950+         """ 
951+         Apply weight constraints (bounds, normalization) if specified. 
952+          
953+         Args: 
954+             connection: Parent connection with constraint parameters 
955+         """ 
956+         w  =  connection .w 
957+         
958+         # Apply weight bounds if specified 
959+         if  hasattr (connection , 'wmin' ) and  hasattr (connection , 'wmax' ):
960+             with  torch .no_grad ():
961+                 w .clamp_ (connection .wmin , connection .wmax )
962+         
963+         # Apply normalization if specified 
964+         if  hasattr (connection , 'norm' ) and  connection .norm  is  not   None :
965+             with  torch .no_grad ():
966+                 if  connection .norm  ==  "l2" :
967+                     # L2 normalize each output neuron's weights 
968+                     w .div_ (w .norm (dim = 0 , keepdim = True ) +  1e-8 )
969+                 elif  connection .norm  ==  "l1" :
970+                     # L1 normalize each output neuron's weights 
971+                     w .div_ (w .abs ().sum (dim = 0 , keepdim = True ) +  1e-8 )
972+     
973+     def  _reset_accumulated_data (self ):
974+         """Reset accumulated positive and negative sample data.""" 
975+         self .positive_goodness  =  None 
976+         self .negative_goodness  =  None 
977+         self .positive_activations  =  None 
978+         self .negative_activations  =  None 
979+     
980+     def  set_sample_type (self , sample_type : str ):
981+         """ 
982+         Set the current sample type for subsequent updates. 
983+          
984+         Args: 
985+             sample_type: Either "positive" or "negative" 
986+         """ 
987+         if  sample_type  not  in   ["positive" , "negative" ]:
988+             raise  ValueError (f"Invalid sample_type: { sample_type }  " )
989+         
990+         self .current_sample_type  =  sample_type 
991+     
992+     def  get_goodness_scores (self ) ->  dict :
993+         """Get current goodness scores for positive and negative samples.""" 
994+         return  {
995+             'positive_goodness' : self .positive_goodness ,
996+             'negative_goodness' : self .negative_goodness 
997+         }
998+     
999+     def  get_ff_loss (self ) ->  torch .Tensor :
1000+         """Compute and return current Forward-Forward loss if data available.""" 
1001+         if  self .positive_goodness  is  not   None  and  self .negative_goodness  is  not   None :
1002+             return  self ._compute_ff_loss (self .positive_goodness , self .negative_goodness )
1003+         else :
1004+             return  torch .tensor (0.0 )
1005+     
1006+     def  reset_state (self ):
1007+         """Reset all learning rule state.""" 
1008+         self ._reset_accumulated_data ()
1009+         self .velocity  =  None 
1010+         self .current_sample_type  =  None 
1011+         self .samples_processed  =  0 
1012+     
1013+     def  get_learning_stats (self ) ->  dict :
1014+         """Get learning rule statistics and configuration.""" 
1015+         return  {
1016+             'learning_rule_type' : 'ForwardForwardMCCLearning' ,
1017+             'alpha_loss' : self .alpha_loss ,
1018+             'goodness_fn' : self .goodness_fn ,
1019+             'learning_rate' : self .nu ,
1020+             'momentum' : self .momentum ,
1021+             'weight_decay' : self .weight_decay ,
1022+             'samples_processed' : self .samples_processed ,
1023+             'current_sample_type' : self .current_sample_type ,
1024+             'has_positive_data' : self .positive_goodness  is  not   None ,
1025+             'has_negative_data' : self .negative_goodness  is  not   None 
1026+         }
1027+     
1028+     def  __repr__ (self ):
1029+         """String representation of the learning rule.""" 
1030+         return  (
1031+             f"ForwardForwardMCCLearning(" 
1032+             f"nu={ self .nu }  , " 
1033+             f"alpha_loss={ self .alpha_loss }  , " 
1034+             f"goodness_fn='{ self .goodness_fn }  ', " 
1035+             f"momentum={ self .momentum }  , " 
1036+             f"weight_decay={ self .weight_decay }  )" 
1037+         )
0 commit comments