Skip to content

Commit 3f5a586

Browse files
author
Kevin Chang
committed
finished
1 parent 6f14e48 commit 3f5a586

File tree

7 files changed

+659
-677
lines changed

7 files changed

+659
-677
lines changed

bindsnet/network/topology.py

Lines changed: 0 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,133 +2052,3 @@ def reset_state_variables(self) -> None:
20522052

20532053

20542054

2055-
# designed for ANN
2056-
# the weights are not a true connection
2057-
# gradient
2058-
2059-
class ForwardForwardConnection(AbstractConnection):
2060-
"""
2061-
Connection class specifically designed for Forward-Forward training with arctangent surrogate gradients.
2062-
"""
2063-
2064-
def __init__(
2065-
self,
2066-
source: Nodes,
2067-
target: Nodes,
2068-
nu: Optional[Union[float, Sequence[float], Sequence[torch.Tensor]]] = None,
2069-
weight_decay: float = 0.0,
2070-
spike_threshold: float = 1.0,
2071-
alpha: float = 2.0, # α parameter for arctangent surrogate
2072-
**kwargs,
2073-
) -> None:
2074-
super().__init__(source, target, nu, weight_decay, **kwargs)
2075-
2076-
# Initialize weights with gradient support
2077-
w = kwargs.get("w", None)
2078-
if w is None:
2079-
if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
2080-
w = torch.clamp(torch.randn(source.n, target.n) * 0.1, self.wmin, self.wmax)
2081-
else:
2082-
w = self.wmin + (torch.randn(source.n, target.n) * 0.1) * (self.wmax - self.wmin)
2083-
else:
2084-
if (self.wmin == -np.inf).any() or (self.wmax == np.inf).any():
2085-
w = torch.clamp(w, self.wmin, self.wmax)
2086-
2087-
# CRITICAL: Enable gradients for Forward-Forward training
2088-
self.w = Parameter(w, requires_grad=True)
2089-
2090-
# Surrogate gradient parameters
2091-
self.spike_threshold = spike_threshold
2092-
self.alpha = alpha
2093-
2094-
# Track membrane potential for surrogate gradients
2095-
self.membrane_potential = None
2096-
2097-
def atan_surrogate_spike(self, x: torch.Tensor) -> torch.Tensor:
2098-
"""
2099-
Arctangent surrogate gradient function.
2100-
2101-
Forward pass: Heaviside step function shifted by threshold
2102-
Backward pass: Gradient of shifted arc-tan function with parameter α
2103-
"""
2104-
class AtanSurrogate(torch.autograd.Function):
2105-
@staticmethod
2106-
def forward(ctx, input, threshold, alpha):
2107-
ctx.save_for_backward(input)
2108-
ctx.threshold = threshold
2109-
ctx.alpha = alpha
2110-
# Forward: Heaviside step function shifted by threshold
2111-
return (input > threshold).float()
2112-
2113-
@staticmethod
2114-
def backward(ctx, grad_output):
2115-
input, = ctx.saved_tensors
2116-
grad_input = grad_output.clone()
2117-
# Backward: Gradient of shifted arc-tan function
2118-
# surrogate = 1 / (α * |input - threshold| + 1)
2119-
surrogate_grad = 1.0 / (ctx.alpha * torch.abs(input - ctx.threshold) + 1.0)
2120-
return grad_input * surrogate_grad, None, None
2121-
2122-
return AtanSurrogate.apply(x, self.spike_threshold, self.alpha)
2123-
2124-
def compute_with_surrogate(self, s: torch.Tensor) -> torch.Tensor:
2125-
"""
2126-
Compute pre-activations with arctangent surrogate gradients.
2127-
2128-
:param s: Incoming spikes [batch_size, source_neurons]
2129-
:return: Output spikes with surrogate gradients [batch_size, target_neurons]
2130-
"""
2131-
batch_size = s.shape[0]
2132-
2133-
# Initialize membrane potential if needed
2134-
if self.membrane_potential is None or self.membrane_potential.shape != (batch_size, self.target.n):
2135-
self.membrane_potential = torch.zeros(batch_size, self.target.n, device=s.device)
2136-
2137-
# Synaptic input: spikes @ weights
2138-
synaptic_input = torch.mm(s.float(), self.w)
2139-
2140-
# Simple LIF dynamics with decay (you can customize this)
2141-
decay_factor = 0.9 # Can be made configurable
2142-
self.membrane_potential = decay_factor * self.membrane_potential + synaptic_input
2143-
2144-
# Generate spikes with arctangent surrogate gradients
2145-
spikes = self.atan_surrogate_spike(self.membrane_potential)
2146-
2147-
# Reset mechanism: subtract threshold from membrane potential where spikes occurred
2148-
self.membrane_potential = self.membrane_potential - spikes * self.spike_threshold
2149-
2150-
return spikes
2151-
2152-
def compute(self, s: torch.Tensor) -> torch.Tensor:
2153-
"""
2154-
Standard compute method (calls compute_with_surrogate for FF training).
2155-
"""
2156-
return self.compute_with_surrogate(s)
2157-
2158-
def reset_membrane_potential(self):
2159-
"""Reset membrane potential (call between samples/batches)."""
2160-
self.membrane_potential = None
2161-
2162-
def update(self, **kwargs) -> None:
2163-
"""
2164-
Override standard BindsNET update - FF uses PyTorch optimizers.
2165-
"""
2166-
# Forward-Forward training uses PyTorch optimizers for weight updates
2167-
# So we don't need the standard BindsNET learning rule updates
2168-
pass
2169-
2170-
def normalize(self) -> None:
2171-
"""
2172-
Normalize weights so each target neuron has sum of connection weights equal to self.norm.
2173-
"""
2174-
if self.norm is not None:
2175-
w_abs_sum = self.w.abs().sum(0).unsqueeze(0)
2176-
w_abs_sum[w_abs_sum == 0] = 1.0
2177-
self.w.data *= self.norm / w_abs_sum
2178-
2179-
def reset_state_variables(self) -> None:
2180-
"""
2181-
Contains resetting logic for the connection.
2182-
"""
2183-
super().reset_state_variables()
2184-
self.reset_membrane_potential()

bindsnet/network/topology_features.py

Lines changed: 133 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -955,23 +955,38 @@ class ArctangentSurrogateFeature(AbstractFeature):
955955

956956
def __init__(
957957
self,
958+
name: str,
958959
spike_threshold: float = 1.0,
959960
alpha: float = 2.0,
960961
dt: float = 1.0,
961962
reset_mechanism: str = "subtract",
963+
value: Union[torch.Tensor, float, int] = None,
964+
range: Optional[Union[list, tuple]] = None,
962965
**kwargs
963966
):
964967
"""
965968
Initialize arctangent surrogate feature.
966969
967970
Args:
971+
name: Name of the feature
968972
spike_threshold: Voltage threshold for spike generation
969973
alpha: Steepness parameter for surrogate gradient (higher = steeper)
970974
dt: Integration time step
971975
reset_mechanism: Post-spike reset ("subtract", "zero", "none")
976+
value: Initial membrane potential values (optional)
977+
range: Range of acceptable values for membrane potential
972978
**kwargs: Additional arguments for AbstractFeature
973979
"""
974-
super().__init__(**kwargs)
980+
# Set default range if not provided
981+
if range is None:
982+
range = [-10.0, 10.0] # Reasonable range for membrane potentials
983+
984+
super().__init__(
985+
name=name,
986+
value=value,
987+
range=range,
988+
**kwargs
989+
)
975990

976991
self.spike_threshold = spike_threshold
977992
self.alpha = alpha
@@ -983,26 +998,32 @@ def __init__(
983998
self.batch_size = None
984999
self.target_size = None
9851000
self.initialized = False
1001+
self.connection = None
9861002

987-
def compute(
988-
self,
989-
connection: 'MulticompartmentConnection',
990-
source_s: torch.Tensor,
991-
**kwargs
992-
) -> torch.Tensor:
1003+
def compute(self, conn_spikes) -> torch.Tensor:
9931004
"""
9941005
Compute forward pass with arctangent surrogate gradients.
9951006
9961007
Args:
997-
connection: Parent MulticompartmentConnection
998-
source_s: Source layer spikes [batch_size, source_neurons]
999-
**kwargs: Additional arguments
1008+
conn_spikes: Connection spikes tensor [batch_size, source_neurons * target_neurons]
10001009
10011010
Returns:
1002-
Target spikes with differentiable surrogate gradients [batch_size, target_neurons]
1011+
Target spikes with differentiable surrogate gradients [batch_size, source_neurons * target_neurons]
10031012
"""
1004-
# Step 1: Compute synaptic input
1005-
synaptic_input = torch.mm(source_s.float(), connection.w)
1013+
# Ensure connection is available
1014+
if self.connection is None:
1015+
raise RuntimeError("ArctangentSurrogateFeature not properly initialized. Call prime_feature first.")
1016+
1017+
# Reshape conn_spikes to [batch_size, source_neurons, target_neurons]
1018+
batch_size = conn_spikes.size(0)
1019+
source_n = self.connection.source.n
1020+
target_n = self.connection.target.n
1021+
1022+
# Reshape connection spikes to matrix form
1023+
conn_spikes_matrix = conn_spikes.view(batch_size, source_n, target_n)
1024+
1025+
# Step 1: Compute synaptic input (sum over source neurons)
1026+
synaptic_input = conn_spikes_matrix.sum(dim=1) # [batch_size, target_neurons]
10061027

10071028
# Step 2: Initialize membrane potential if needed
10081029
if not self.initialized:
@@ -1021,7 +1042,15 @@ def compute(
10211042
# Step 5: Apply reset mechanism
10221043
self._apply_reset(spikes)
10231044

1024-
return spikes
1045+
# Step 6: Broadcast spikes back to connection format
1046+
# Each target spike affects all connections to that target
1047+
spikes_broadcast = spikes.unsqueeze(1).expand(batch_size, source_n, target_n)
1048+
1049+
# Apply spikes to incoming connections
1050+
output_spikes = conn_spikes_matrix * spikes_broadcast
1051+
1052+
# Reshape back to original format
1053+
return output_spikes.view(batch_size, source_n * target_n)
10251054

10261055
def arctangent_surrogate_spike(
10271056
self,
@@ -1048,7 +1077,13 @@ def arctangent_surrogate_spike(
10481077
def _initialize_state(self, reference_tensor: torch.Tensor):
10491078
"""Initialize state tensors based on input dimensions."""
10501079
self.batch_size, self.target_size = reference_tensor.shape
1051-
self.v_membrane = torch.zeros_like(reference_tensor)
1080+
# Initialize membrane potential to match batch size and target neurons
1081+
if self.v_membrane is None:
1082+
self.v_membrane = torch.zeros_like(reference_tensor)
1083+
else:
1084+
# Expand existing membrane potential to match batch size
1085+
if self.v_membrane.size(0) != self.batch_size:
1086+
self.v_membrane = self.v_membrane.expand(self.batch_size, -1)
10521087
self.initialized = True
10531088

10541089
def _apply_reset(self, spikes: torch.Tensor):
@@ -1067,10 +1102,39 @@ def _apply_reset(self, spikes: torch.Tensor):
10671102

10681103
def reset_state_variables(self):
10691104
"""Reset all internal state variables."""
1105+
super().reset_state_variables()
10701106
self.v_membrane = None
10711107
self.batch_size = None
10721108
self.target_size = None
10731109
self.initialized = False
1110+
1111+
def prime_feature(self, connection, device, **kwargs) -> None:
1112+
"""
1113+
Prime the feature for use in a connection.
1114+
1115+
Args:
1116+
connection: Parent connection object
1117+
device: Device to run on
1118+
**kwargs: Additional arguments
1119+
"""
1120+
# Store connection reference
1121+
self.connection = connection
1122+
1123+
# Call parent prime_feature
1124+
super().prime_feature(connection, device, **kwargs)
1125+
1126+
def initialize_value(self):
1127+
"""
1128+
Initialize default membrane potential values.
1129+
1130+
Returns:
1131+
Zero membrane potentials for all target neurons
1132+
"""
1133+
if self.connection is None:
1134+
raise RuntimeError("Connection not set. Call prime_feature first.")
1135+
1136+
# Initialize with zeros - membrane potentials start at rest
1137+
return torch.zeros(1, self.connection.target.n)
10741138

10751139
def get_membrane_potential(self) -> Optional[torch.Tensor]:
10761140
"""Get current membrane potential."""
@@ -1101,6 +1165,7 @@ def __repr__(self) -> str:
11011165
"""String representation."""
11021166
return (
11031167
f"ArctangentSurrogateFeature("
1168+
f"name='{self.name}', "
11041169
f"spike_threshold={self.spike_threshold}, "
11051170
f"alpha={self.alpha}, "
11061171
f"dt={self.dt}, "
@@ -1168,11 +1233,61 @@ def backward(ctx, grad_output):
11681233
# threshold and alpha gradients are None (not optimized)
11691234
return grad_input, None, None
11701235

1236+
class GoodnessScore(AbstractSubFeature):
1237+
"""
1238+
SubFeature to compute the goodness score (sum of spikes over time) for all layers in a BindsNET network.
1239+
"""
11711240

1241+
def __init__(
1242+
self,
1243+
name: str,
1244+
parent_feature: AbstractFeature = None,
1245+
network=None, # <-- Add this argument
1246+
time: int = 250,
1247+
input_layer: str = "X",
1248+
) -> None:
1249+
super().__init__(name, parent_feature)
1250+
self.time = time
1251+
self.input_layer = input_layer
1252+
self.network = network # <-- Store the network
1253+
1254+
def compute(self, sample: torch.Tensor) -> dict:
1255+
# Use self.network if provided, else fall back to parent_feature's connection
1256+
if self.network is not None:
1257+
network = self.network
1258+
else:
1259+
if not hasattr(self.parent, "connection") or self.parent.connection is None:
1260+
raise RuntimeError("Parent feature must have a valid connection attribute.")
1261+
if not hasattr(self.parent.connection, "network") or self.parent.connection.network is None:
1262+
raise RuntimeError("Connection must have a valid network attribute.")
1263+
network = self.parent.connection.network
1264+
1265+
network.reset_state_variables()
1266+
inputs = {self.input_layer: sample.unsqueeze(0) if sample.dim() == 1 else sample}
1267+
spike_record = {layer_name: [] for layer_name in network.layers}
1268+
1269+
for t in range(self.time):
1270+
network.run(inputs, time=1)
1271+
for layer_name, layer in network.layers.items():
1272+
spike_record[layer_name].append(layer.s.clone().detach())
1273+
1274+
goodness_per_layer = {}
1275+
for layer_name, spikes_list in spike_record.items():
1276+
spikes = torch.stack(spikes_list, dim=0)
1277+
goodness = spikes.sum(dim=0).sum(dim=0)
1278+
goodness_per_layer[layer_name] = goodness
1279+
1280+
total_goodness = sum([v.sum() for v in goodness_per_layer.values()])
1281+
goodness_per_layer["total_goodness"] = total_goodness
1282+
1283+
return goodness_per_layer
1284+
1285+
11721286
# Helper function for easy creation
11731287
def create_arctangent_surrogate_connection(
11741288
source,
11751289
target,
1290+
name: str = "arctangent_surrogate",
11761291
w: Optional[torch.Tensor] = None,
11771292
spike_threshold: float = 1.0,
11781293
alpha: float = 2.0,
@@ -1186,6 +1301,7 @@ def create_arctangent_surrogate_connection(
11861301
Args:
11871302
source: Source population
11881303
target: Target population
1304+
name: Name for the surrogate feature
11891305
w: Weight matrix (initialized randomly if None)
11901306
spike_threshold: Spike threshold
11911307
alpha: Surrogate gradient steepness
@@ -1204,6 +1320,7 @@ def create_arctangent_surrogate_connection(
12041320

12051321
# Create arctangent surrogate feature
12061322
surrogate_feature = ArctangentSurrogateFeature(
1323+
name=name,
12071324
spike_threshold=spike_threshold,
12081325
alpha=alpha,
12091326
dt=dt,
@@ -1215,7 +1332,7 @@ def create_arctangent_surrogate_connection(
12151332
source=source,
12161333
target=target,
12171334
w=w,
1218-
features=[surrogate_feature],
1335+
pipeline=[surrogate_feature],
12191336
**mcc_kwargs
12201337
)
12211338

0 commit comments

Comments
 (0)