
Commit 6f14e48

Author: Kevin Chang
Commit message: test
1 parent 6022dcc commit 6f14e48

File tree

5 files changed: +507 -210 lines


bindsnet/learning/MCC_learning.py

Lines changed: 302 additions & 31 deletions
@@ -725,42 +725,313 @@ def reset_state_variables(self) -> None:
         self.eligibility_trace.zero_()
         return
 
-class MyBackpropVariant(MCC_LearningRule):
-    def __init__(self, connection, feature_value, **kwargs):
-        super().__init__(connection=connection, feature_value=feature_value, **kwargs)
-        # Potentially initialize other parameters specific to your variant
-        self.update = self._custom_connection_update
-
-    def _custom_connection_update(self, **kwargs) -> None:
-        # Assume 'error_signal' for the target layer is passed in kwargs
-        # Assume 'surrogate_grad_target' for target neuron activations is available or computed
-        # Assume 'source_activity' (e.g., spikes or trace) is from self.source
-
-        if "error_signal" not in kwargs:
-            return  # Or handle missing error
-
-        error_signal = kwargs["error_signal"]  # This would be specific to target neurons
-
-        # This is highly conceptual and depends on your specific variant:
-        # 1. Get pre-synaptic activity (e.g., self.source.s or self.source.x)
-        # 2. The 'error_signal' would correspond to the error at the post-synaptic (target) neurons
-        # 3. Compute weight updates, e.g., delta_w = learning_rate * error_signal * pre_synaptic_activity
-        # (This is a simplification; SNN backprop is more complex)
-
-        # Example: (very abstract, actual SNN backprop is more involved)
-        # Assume error_signal is shaped for target neurons, source_s for source neurons
-        # update_matrix = torch.outer(error_signal, self.source.s.float().mean(dim=0))  # Simplified
-        # self.feature_value += self.nu[0] * update_matrix * self.connection.dt
-
-        # Actual implementation would depend on the precise math of your variant
-        # (e.g., using surrogate derivatives of target neuron potentials, etc.)
-
-        # Call the parent's update for decay, clamping, etc.
-        super().update()
-
-    def reset_state_variables(self) -> None:
-        # Reset any internal states if your rule has them
-        pass
-
+# Remove the MyBackpropVariant class and replace with:
+
+class ForwardForwardMCCLearning(MCC_LearningRule):
+    """
+    Forward-Forward learning rule for MulticompartmentConnection.
+
+    This MCC learning rule wrapper integrates the Forward-Forward algorithm
+    with the MulticompartmentConnection architecture, enabling layer-wise
+    learning without backpropagation through time.
+
+    The learning rule works by:
+    1. Computing goodness scores from target layer activity
+    2. Collecting positive and negative sample statistics
+    3. Applying contrastive weight updates based on Forward-Forward loss
+    """
+
+    def __init__(
+        self,
+        alpha_loss: float = 0.6,
+        goodness_fn: str = "mean_squared",
+        nu: float = 0.001,
+        momentum: float = 0.0,
+        weight_decay: float = 0.0,
+        **kwargs
+    ):
+        """
+        Initialize Forward-Forward MCC learning rule.
+
+        Args:
+            alpha_loss: Forward-Forward loss threshold parameter
+            goodness_fn: Goodness score computation method ("mean_squared", "sum_squared")
+            nu: Learning rate for weight updates
+            momentum: Momentum factor for weight updates
+            weight_decay: Weight decay factor for regularization
+            **kwargs: Additional arguments passed to parent MCC_LearningRule
+        """
+        super().__init__(nu=nu, **kwargs)
+
+        self.alpha_loss = alpha_loss
+        self.goodness_fn = goodness_fn
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+
+        # State tracking for Forward-Forward learning
+        self.positive_goodness = None
+        self.negative_goodness = None
+        self.positive_activations = None
+        self.negative_activations = None
+
+        # Momentum state
+        self.velocity = None
+
+        # Sample type tracking
+        self.current_sample_type = None
+        self.samples_processed = 0
+
+    def update(
+        self,
+        connection: 'MulticompartmentConnection',
+        source_s: torch.Tensor,
+        target_s: torch.Tensor,
+        **kwargs
+    ) -> None:
+        """
+        Perform Forward-Forward learning update.
+
+        This method is called by MCC during each simulation step. It accumulates
+        statistics for positive and negative samples, then applies contrastive
+        updates when both sample types are available.
+
+        Args:
+            connection: Parent MulticompartmentConnection
+            source_s: Source layer spikes [batch_size, source_neurons]
+            target_s: Target layer spikes [batch_size, target_neurons]
+            **kwargs: Additional arguments including 'sample_type'
+        """
+        # Check if learning is enabled
+        if not connection.w.requires_grad:
+            return
+
+        # Get sample type from kwargs
+        sample_type = kwargs.get('sample_type', self.current_sample_type)
+        if sample_type is None:
+            # Default to positive for backward compatibility
+            sample_type = "positive"
+
+        # Compute goodness score for current batch
+        current_goodness = self._compute_goodness(target_s)
+
+        # Store activations and goodness based on sample type
+        if sample_type == "positive":
+            self.positive_goodness = current_goodness.detach()
+            self.positive_activations = {
+                'source': source_s.detach(),
+                'target': target_s.detach()
+            }
+
+        elif sample_type == "negative":
+            self.negative_goodness = current_goodness.detach()
+            self.negative_activations = {
+                'source': source_s.detach(),
+                'target': target_s.detach()
+            }
+
+        else:
+            raise ValueError(f"Invalid sample_type: {sample_type}. Must be 'positive' or 'negative'")
+
+        self.samples_processed += 1
+
+        # Apply contrastive update if we have both positive and negative samples
+        if (self.positive_goodness is not None and
+                self.negative_goodness is not None and
+                self.positive_activations is not None and
+                self.negative_activations is not None):
+
+            self._apply_forward_forward_update(connection)
+            self._reset_accumulated_data()
+
+    def _compute_goodness(self, target_activity: torch.Tensor) -> torch.Tensor:
+        """
+        Compute Forward-Forward goodness score from target layer activity.
+
+        Args:
+            target_activity: Target neuron spikes [batch_size, neurons]
+
+        Returns:
+            Goodness scores [batch_size]
+        """
+        if self.goodness_fn == "mean_squared":
+            # Mean squared activity across neurons (original FF paper)
+            goodness = torch.mean(target_activity ** 2, dim=1)
+
+        elif self.goodness_fn == "sum_squared":
+            # Sum of squared activity across neurons
+            goodness = torch.sum(target_activity ** 2, dim=1)
+
+        else:
+            raise ValueError(f"Unknown goodness function: {self.goodness_fn}")
+
+        return goodness
+
+    def _apply_forward_forward_update(self, connection: 'MulticompartmentConnection'):
+        """
+        Apply Forward-Forward contrastive weight update.
+
+        The update follows the Forward-Forward principle:
+        - Strengthen weights that increase goodness for positive samples
+        - Weaken weights that increase goodness for negative samples
+
+        Args:
+            connection: Parent MulticompartmentConnection
+        """
+        # Get weight tensor
+        w = connection.w
+
+        # Compute Forward-Forward loss (for monitoring)
+        ff_loss = self._compute_ff_loss(self.positive_goodness, self.negative_goodness)
+
+        # Compute weight update based on activity correlations
+        pos_source = self.positive_activations['source']
+        pos_target = self.positive_activations['target']
+        neg_source = self.negative_activations['source']
+        neg_target = self.negative_activations['target']
+
+        # Positive update: strengthen weights for positive samples
+        # ΔW_pos = η * s_pos^T * t_pos / batch_size
+        delta_w_pos = torch.mm(pos_source.t(), pos_target) / pos_source.shape[0]
+
+        # Negative update: weaken weights for negative samples
+        # ΔW_neg = -η * s_neg^T * t_neg / batch_size
+        delta_w_neg = -torch.mm(neg_source.t(), neg_target) / neg_source.shape[0]
+
+        # Combined Forward-Forward update
+        delta_w = self.nu * (delta_w_pos + delta_w_neg)
+
+        # Add weight decay if specified
+        if self.weight_decay > 0:
+            delta_w = delta_w - self.weight_decay * w
+
+        # Apply momentum if specified
+        if self.momentum > 0:
+            if self.velocity is None:
+                self.velocity = torch.zeros_like(w)
+
+            self.velocity = self.momentum * self.velocity + delta_w
+            delta_w = self.velocity
+
+        # Apply weight update
+        with torch.no_grad():
+            w.add_(delta_w)
+
+        # Apply weight constraints if they exist
+        self._apply_weight_constraints(connection)
+
+    def _compute_ff_loss(
+        self,
+        goodness_pos: torch.Tensor,
+        goodness_neg: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute Forward-Forward contrastive loss for monitoring.
+
+        L = log(1 + exp(-g_pos + α)) + log(1 + exp(g_neg - α))
+
+        Args:
+            goodness_pos: Goodness scores for positive samples
+            goodness_neg: Goodness scores for negative samples
+
+        Returns:
+            Forward-Forward loss (scalar)
+        """
+        # Positive loss: encourage high goodness for positive samples
+        loss_pos = torch.log(1 + torch.exp(-goodness_pos + self.alpha_loss))
+
+        # Negative loss: encourage low goodness for negative samples
+        loss_neg = torch.log(1 + torch.exp(goodness_neg - self.alpha_loss))
+
+        # Return mean loss across batch
+        total_loss = loss_pos + loss_neg
+        return torch.mean(total_loss)
+
+    def _apply_weight_constraints(self, connection: 'MulticompartmentConnection'):
+        """
+        Apply weight constraints (bounds, normalization) if specified.
+
+        Args:
+            connection: Parent connection with constraint parameters
+        """
+        w = connection.w
+
+        # Apply weight bounds if specified
+        if hasattr(connection, 'wmin') and hasattr(connection, 'wmax'):
+            with torch.no_grad():
+                w.clamp_(connection.wmin, connection.wmax)
+
+        # Apply normalization if specified
+        if hasattr(connection, 'norm') and connection.norm is not None:
+            with torch.no_grad():
+                if connection.norm == "l2":
+                    # L2 normalize each output neuron's weights
+                    w.div_(w.norm(dim=0, keepdim=True) + 1e-8)
+                elif connection.norm == "l1":
+                    # L1 normalize each output neuron's weights
+                    w.div_(w.abs().sum(dim=0, keepdim=True) + 1e-8)
+
+    def _reset_accumulated_data(self):
+        """Reset accumulated positive and negative sample data."""
+        self.positive_goodness = None
+        self.negative_goodness = None
+        self.positive_activations = None
+        self.negative_activations = None
+
+    def set_sample_type(self, sample_type: str):
+        """
+        Set the current sample type for subsequent updates.
+
+        Args:
+            sample_type: Either "positive" or "negative"
+        """
+        if sample_type not in ["positive", "negative"]:
+            raise ValueError(f"Invalid sample_type: {sample_type}")
+
+        self.current_sample_type = sample_type
+
+    def get_goodness_scores(self) -> dict:
+        """Get current goodness scores for positive and negative samples."""
+        return {
+            'positive_goodness': self.positive_goodness,
+            'negative_goodness': self.negative_goodness
+        }
+
+    def get_ff_loss(self) -> torch.Tensor:
+        """Compute and return current Forward-Forward loss if data available."""
+        if self.positive_goodness is not None and self.negative_goodness is not None:
+            return self._compute_ff_loss(self.positive_goodness, self.negative_goodness)
+        else:
+            return torch.tensor(0.0)
+
+    def reset_state(self):
+        """Reset all learning rule state."""
+        self._reset_accumulated_data()
+        self.velocity = None
+        self.current_sample_type = None
+        self.samples_processed = 0
+
+    def get_learning_stats(self) -> dict:
+        """Get learning rule statistics and configuration."""
+        return {
+            'learning_rule_type': 'ForwardForwardMCCLearning',
+            'alpha_loss': self.alpha_loss,
+            'goodness_fn': self.goodness_fn,
+            'learning_rate': self.nu,
+            'momentum': self.momentum,
+            'weight_decay': self.weight_decay,
+            'samples_processed': self.samples_processed,
+            'current_sample_type': self.current_sample_type,
+            'has_positive_data': self.positive_goodness is not None,
+            'has_negative_data': self.negative_goodness is not None
+        }
+
+    def __repr__(self):
+        """String representation of the learning rule."""
+        return (
+            f"ForwardForwardMCCLearning("
+            f"nu={self.nu}, "
+            f"alpha_loss={self.alpha_loss}, "
+            f"goodness_fn='{self.goodness_fn}', "
+            f"momentum={self.momentum}, "
+            f"weight_decay={self.weight_decay})"
+        )
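To make the math in the new ForwardForwardMCCLearning rule easier to check, here is a minimal standalone sketch of the same computations (goodness score, Forward-Forward loss, and contrastive weight update) written against plain PyTorch. It is not part of the commit; the tensor shapes and variable names are made up for illustration.

import torch

torch.manual_seed(0)

# Illustrative sizes and hyperparameters (assumptions, not from the commit).
batch_size, n_source, n_target = 4, 10, 5
alpha_loss, nu = 0.6, 0.001

# Fake 0/1 spike activity for one positive and one negative pass.
pos_source = torch.randint(0, 2, (batch_size, n_source)).float()
pos_target = torch.randint(0, 2, (batch_size, n_target)).float()
neg_source = torch.randint(0, 2, (batch_size, n_source)).float()
neg_target = torch.randint(0, 2, (batch_size, n_target)).float()

# Goodness per sample: mean squared target activity (the "mean_squared" option).
g_pos = torch.mean(pos_target ** 2, dim=1)
g_neg = torch.mean(neg_target ** 2, dim=1)

# Forward-Forward loss: L = log(1 + exp(-g_pos + alpha)) + log(1 + exp(g_neg - alpha)).
ff_loss = torch.mean(
    torch.log(1 + torch.exp(-g_pos + alpha_loss))
    + torch.log(1 + torch.exp(g_neg - alpha_loss))
)

# Contrastive update: strengthen source/target correlations from the positive
# pass, weaken those from the negative pass, each averaged over the batch.
delta_w = nu * (
    torch.mm(pos_source.t(), pos_target) / batch_size
    - torch.mm(neg_source.t(), neg_target) / batch_size
)

w = torch.zeros(n_source, n_target)
w += delta_w  # stands in for connection.w.add_(delta_w)

print(f"FF loss: {ff_loss.item():.4f}  mean |dW|: {delta_w.abs().mean().item():.6f}")

Note that the weight update itself never uses ff_loss; as in the class above, the loss is only a monitoring signal, while the update is a purely local contrastive Hebbian step.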

bindsnet/network/topology.py

Lines changed: 6 additions & 1 deletion
@@ -2051,6 +2051,11 @@ def reset_state_variables(self) -> None:
         super().reset_state_variables()
 
 
+
+# Designed for ANN-style use.
+# The weights here are not a true connection;
+# they are updated via (surrogate) gradients.
+
 class ForwardForwardConnection(AbstractConnection):
     """
     Connection class specifically designed for Forward-Forward training with arctangent surrogate gradients.
@@ -2176,4 +2181,4 @@ def reset_state_variables(self) -> None:
         Contains resetting logic for the connection.
         """
         super().reset_state_variables()
-        self.reset_membrane_potential()
+        self.reset_membrane_potential()
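The ForwardForwardConnection docstring refers to arctangent surrogate gradients, but the connection body lies outside this hunk. For reference, a generic arctangent surrogate spike function in PyTorch looks roughly like the sketch below; the class name ATanSurrogateSpike, the alpha steepness parameter, and the threshold value are illustrative assumptions, not code added by this commit.

import math

import torch


class ATanSurrogateSpike(torch.autograd.Function):
    """Heaviside spike in the forward pass, arctangent-shaped gradient in the backward pass."""

    @staticmethod
    def forward(ctx, v_minus_thresh: torch.Tensor, alpha: float) -> torch.Tensor:
        ctx.save_for_backward(v_minus_thresh)
        ctx.alpha = alpha
        # Hard threshold: spike wherever the membrane potential exceeds threshold.
        return (v_minus_thresh > 0).float()

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (v_minus_thresh,) = ctx.saved_tensors
        alpha = ctx.alpha
        # Derivative of (1/pi) * arctan(pi * alpha * x) + 1/2, a smooth surrogate
        # for the non-differentiable step function.
        surrogate = alpha / (1 + (math.pi * alpha * v_minus_thresh) ** 2)
        return grad_output * surrogate, None  # no gradient w.r.t. alpha


# Usage sketch: differentiate spikes with respect to the membrane potential.
v = torch.randn(8, requires_grad=True)
spikes = ATanSurrogateSpike.apply(v - 0.5, 2.0)
spikes.sum().backward()
print(v.grad)

In the Forward-Forward setting, a surrogate of this kind lets a goodness-based objective differentiate through a single layer's spiking nonlinearity without unrolling the simulation through time.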
