@@ -955,23 +955,38 @@ class ArctangentSurrogateFeature(AbstractFeature):
955955
956956 def __init__ (
957957 self ,
958+ name : str ,
958959 spike_threshold : float = 1.0 ,
959960 alpha : float = 2.0 ,
960961 dt : float = 1.0 ,
961962 reset_mechanism : str = "subtract" ,
963+ value : Union [torch .Tensor , float , int ] = None ,
964+ range : Optional [Union [list , tuple ]] = None ,
962965 ** kwargs
963966 ):
964967 """
965968 Initialize arctangent surrogate feature.
966969
967970 Args:
971+ name: Name of the feature
968972 spike_threshold: Voltage threshold for spike generation
969973 alpha: Steepness parameter for surrogate gradient (higher = steeper)
970974 dt: Integration time step
971975 reset_mechanism: Post-spike reset ("subtract", "zero", "none")
976+ value: Initial membrane potential values (optional)
977+ range: Range of acceptable values for membrane potential
972978 **kwargs: Additional arguments for AbstractFeature
973979 """
974- super ().__init__ (** kwargs )
980+ # Set default range if not provided
981+ if range is None :
982+ range = [- 10.0 , 10.0 ] # Reasonable range for membrane potentials
983+
984+ super ().__init__ (
985+ name = name ,
986+ value = value ,
987+ range = range ,
988+ ** kwargs
989+ )
975990
976991 self .spike_threshold = spike_threshold
977992 self .alpha = alpha
@@ -983,26 +998,32 @@ def __init__(
983998 self .batch_size = None
984999 self .target_size = None
9851000 self .initialized = False
1001+ self .connection = None
9861002
987- def compute (
988- self ,
989- connection : 'MulticompartmentConnection' ,
990- source_s : torch .Tensor ,
991- ** kwargs
992- ) -> torch .Tensor :
1003+ def compute (self , conn_spikes ) -> torch .Tensor :
9931004 """
9941005 Compute forward pass with arctangent surrogate gradients.
9951006
9961007 Args:
997- connection: Parent MulticompartmentConnection
998- source_s: Source layer spikes [batch_size, source_neurons]
999- **kwargs: Additional arguments
1008+ conn_spikes: Connection spikes tensor [batch_size, source_neurons * target_neurons]
10001009
10011010 Returns:
1002- Target spikes with differentiable surrogate gradients [batch_size, target_neurons]
1011+ Target spikes with differentiable surrogate gradients [batch_size, source_neurons * target_neurons]
10031012 """
1004- # Step 1: Compute synaptic input
1005- synaptic_input = torch .mm (source_s .float (), connection .w )
1013+ # Ensure connection is available
1014+ if self .connection is None :
1015+ raise RuntimeError ("ArctangentSurrogateFeature not properly initialized. Call prime_feature first." )
1016+
1017+ # Reshape conn_spikes to [batch_size, source_neurons, target_neurons]
1018+ batch_size = conn_spikes .size (0 )
1019+ source_n = self .connection .source .n
1020+ target_n = self .connection .target .n
1021+
1022+ # Reshape connection spikes to matrix form
1023+ conn_spikes_matrix = conn_spikes .view (batch_size , source_n , target_n )
1024+
1025+ # Step 1: Compute synaptic input (sum over source neurons)
1026+ synaptic_input = conn_spikes_matrix .sum (dim = 1 ) # [batch_size, target_neurons]
10061027
10071028 # Step 2: Initialize membrane potential if needed
10081029 if not self .initialized :
@@ -1021,7 +1042,15 @@ def compute(
10211042 # Step 5: Apply reset mechanism
10221043 self ._apply_reset (spikes )
10231044
1024- return spikes
1045+ # Step 6: Broadcast spikes back to connection format
1046+ # Each target spike affects all connections to that target
1047+ spikes_broadcast = spikes .unsqueeze (1 ).expand (batch_size , source_n , target_n )
1048+
1049+ # Apply spikes to incoming connections
1050+ output_spikes = conn_spikes_matrix * spikes_broadcast
1051+
1052+ # Reshape back to original format
1053+ return output_spikes .view (batch_size , source_n * target_n )
10251054
10261055 def arctangent_surrogate_spike (
10271056 self ,
@@ -1048,7 +1077,13 @@ def arctangent_surrogate_spike(
10481077 def _initialize_state (self , reference_tensor : torch .Tensor ):
10491078 """Initialize state tensors based on input dimensions."""
10501079 self .batch_size , self .target_size = reference_tensor .shape
1051- self .v_membrane = torch .zeros_like (reference_tensor )
1080+ # Initialize membrane potential to match batch size and target neurons
1081+ if self .v_membrane is None :
1082+ self .v_membrane = torch .zeros_like (reference_tensor )
1083+ else :
1084+ # Expand existing membrane potential to match batch size
1085+ if self .v_membrane .size (0 ) != self .batch_size :
1086+ self .v_membrane = self .v_membrane .expand (self .batch_size , - 1 )
10521087 self .initialized = True
10531088
10541089 def _apply_reset (self , spikes : torch .Tensor ):
@@ -1067,10 +1102,39 @@ def _apply_reset(self, spikes: torch.Tensor):
10671102
10681103 def reset_state_variables (self ):
10691104 """Reset all internal state variables."""
1105+ super ().reset_state_variables ()
10701106 self .v_membrane = None
10711107 self .batch_size = None
10721108 self .target_size = None
10731109 self .initialized = False
1110+
1111+ def prime_feature (self , connection , device , ** kwargs ) -> None :
1112+ """
1113+ Prime the feature for use in a connection.
1114+
1115+ Args:
1116+ connection: Parent connection object
1117+ device: Device to run on
1118+ **kwargs: Additional arguments
1119+ """
1120+ # Store connection reference
1121+ self .connection = connection
1122+
1123+ # Call parent prime_feature
1124+ super ().prime_feature (connection , device , ** kwargs )
1125+
1126+ def initialize_value (self ):
1127+ """
1128+ Initialize default membrane potential values.
1129+
1130+ Returns:
1131+ Zero membrane potentials for all target neurons
1132+ """
1133+ if self .connection is None :
1134+ raise RuntimeError ("Connection not set. Call prime_feature first." )
1135+
1136+ # Initialize with zeros - membrane potentials start at rest
1137+ return torch .zeros (1 , self .connection .target .n )
10741138
10751139 def get_membrane_potential (self ) -> Optional [torch .Tensor ]:
10761140 """Get current membrane potential."""
@@ -1101,6 +1165,7 @@ def __repr__(self) -> str:
11011165 """String representation."""
11021166 return (
11031167 f"ArctangentSurrogateFeature("
1168+ f"name='{ self .name } ', "
11041169 f"spike_threshold={ self .spike_threshold } , "
11051170 f"alpha={ self .alpha } , "
11061171 f"dt={ self .dt } , "
@@ -1168,11 +1233,61 @@ def backward(ctx, grad_output):
11681233 # threshold and alpha gradients are None (not optimized)
11691234 return grad_input , None , None
11701235
1236+ class GoodnessScore (AbstractSubFeature ):
1237+ """
1238+ SubFeature to compute the goodness score (sum of spikes over time) for all layers in a BindsNET network.
1239+ """
11711240
1241+ def __init__ (
1242+ self ,
1243+ name : str ,
1244+ parent_feature : AbstractFeature = None ,
1245+ network = None , # <-- Add this argument
1246+ time : int = 250 ,
1247+ input_layer : str = "X" ,
1248+ ) -> None :
1249+ super ().__init__ (name , parent_feature )
1250+ self .time = time
1251+ self .input_layer = input_layer
1252+ self .network = network # <-- Store the network
1253+
1254+ def compute (self , sample : torch .Tensor ) -> dict :
1255+ # Use self.network if provided, else fall back to parent_feature's connection
1256+ if self .network is not None :
1257+ network = self .network
1258+ else :
1259+ if not hasattr (self .parent , "connection" ) or self .parent .connection is None :
1260+ raise RuntimeError ("Parent feature must have a valid connection attribute." )
1261+ if not hasattr (self .parent .connection , "network" ) or self .parent .connection .network is None :
1262+ raise RuntimeError ("Connection must have a valid network attribute." )
1263+ network = self .parent .connection .network
1264+
1265+ network .reset_state_variables ()
1266+ inputs = {self .input_layer : sample .unsqueeze (0 ) if sample .dim () == 1 else sample }
1267+ spike_record = {layer_name : [] for layer_name in network .layers }
1268+
1269+ for t in range (self .time ):
1270+ network .run (inputs , time = 1 )
1271+ for layer_name , layer in network .layers .items ():
1272+ spike_record [layer_name ].append (layer .s .clone ().detach ())
1273+
1274+ goodness_per_layer = {}
1275+ for layer_name , spikes_list in spike_record .items ():
1276+ spikes = torch .stack (spikes_list , dim = 0 )
1277+ goodness = spikes .sum (dim = 0 ).sum (dim = 0 )
1278+ goodness_per_layer [layer_name ] = goodness
1279+
1280+ total_goodness = sum ([v .sum () for v in goodness_per_layer .values ()])
1281+ goodness_per_layer ["total_goodness" ] = total_goodness
1282+
1283+ return goodness_per_layer
1284+
1285+
11721286# Helper function for easy creation
11731287def create_arctangent_surrogate_connection (
11741288 source ,
11751289 target ,
1290+ name : str = "arctangent_surrogate" ,
11761291 w : Optional [torch .Tensor ] = None ,
11771292 spike_threshold : float = 1.0 ,
11781293 alpha : float = 2.0 ,
@@ -1186,6 +1301,7 @@ def create_arctangent_surrogate_connection(
11861301 Args:
11871302 source: Source population
11881303 target: Target population
1304+ name: Name for the surrogate feature
11891305 w: Weight matrix (initialized randomly if None)
11901306 spike_threshold: Spike threshold
11911307 alpha: Surrogate gradient steepness
@@ -1204,6 +1320,7 @@ def create_arctangent_surrogate_connection(
12041320
12051321 # Create arctangent surrogate feature
12061322 surrogate_feature = ArctangentSurrogateFeature (
1323+ name = name ,
12071324 spike_threshold = spike_threshold ,
12081325 alpha = alpha ,
12091326 dt = dt ,
@@ -1215,7 +1332,7 @@ def create_arctangent_surrogate_connection(
12151332 source = source ,
12161333 target = target ,
12171334 w = w ,
1218- features = [surrogate_feature ],
1335+ pipeline = [surrogate_feature ],
12191336 ** mcc_kwargs
12201337 )
12211338
0 commit comments