Skip to content

Commit 6022dcc

Browse files
author
Kevin Chang
committed
test'
1 parent f7c2993 commit 6022dcc

File tree

1 file changed

+0
-144
lines changed

1 file changed

+0
-144
lines changed

bindsnet/models/models.py

Lines changed: 0 additions & 144 deletions
Original file line numberDiff line numberDiff line change
@@ -549,150 +549,6 @@ def __init__(
549549
self.add_connection(recurrent_conn, source="Y", target="Y")
550550

551551

552-
import snntorch as snn
553-
554-
class FFSNN(nn.Module):
555-
# language=rst
556-
"""
557-
A simple feedforward Spiking Neural Network (SNN) using snntorch,
558-
designed for use with the ForwardForwardPipeline.
559-
It consists of a sequence of Linear layers followed by Leaky Integrate-and-Fire
560-
(LIF) spiking neuron layers.
561-
"""
562-
563-
def __init__(
564-
self,
565-
input_size: int, # This should be 794 (image_features + num_classes)
566-
hidden_sizes: List[int], # e.g., [500, 500]
567-
output_size: Optional[int] = None, # If the last FF layer is also the output layer for classification
568-
beta: Union[float, torch.Tensor] = 0.9, # Decay rate for snn.Leaky neurons
569-
threshold: float = 1.0, # Firing threshold for snn.Leaky neurons
570-
reset_mechanism: str = "subtract", # "subtract", "zero", or "none"
571-
# Add other snn.Leaky parameters if needed, e.g., spike_grad
572-
) -> None:
573-
# language=rst
574-
"""
575-
Constructor for FFSNN.
576-
577-
:param input_size: Number of input features (after encoding and label embedding).
578-
:param hidden_sizes: A list of integers, where each integer is the number of
579-
neurons in a hidden layer.
580-
:param output_size: Optional. Number of neurons in the final layer if it's
581-
distinct or specifically for output. If None, the last
582-
size in hidden_sizes is considered the final FF layer.
583-
:param beta: Membrane potential decay rate for Leaky neurons.
584-
:param threshold: Firing threshold for Leaky neurons.
585-
:param reset_mechanism: Reset mechanism for Leaky neurons after a spike.
586-
"""
587-
super().__init__()
588-
589-
self.input_size = input_size
590-
self.hidden_sizes = hidden_sizes
591-
self.output_size = output_size
592-
self.beta = beta
593-
self.threshold = threshold
594-
self.reset_mechanism = reset_mechanism
595-
596-
self.fc_layers = nn.ModuleList()
597-
self.snn_layers = nn.ModuleList()
598-
self._ff_layer_pairs_info = []
599-
600-
current_dim = self.input_size # Starts at 794
601-
for i, hidden_dim in enumerate(self.hidden_sizes):
602-
linear_layer = nn.Linear(current_dim, hidden_dim) # Layer 1: 794 -> 500
603-
# Layer 2: 500 -> 500
604-
self.fc_layers.append(linear_layer)
605-
606-
snn_layer = snn.Leaky(
607-
beta=self.beta,
608-
threshold=self.threshold,
609-
reset_mechanism=self.reset_mechanism,
610-
# output_shape=[hidden_dim] # Optional: snntorch can infer this
611-
)
612-
self.snn_layers.append(snn_layer)
613-
self._ff_layer_pairs_info.append((linear_layer, snn_layer))
614-
current_dim = hidden_dim # Update current_dim for the *next* layer's input
615-
616-
# If there's an output_size for a final classifier (not typical for pure FF layers)
617-
if self.output_size is not None:
618-
self.fc_out = nn.Linear(current_dim, self.output_size)
619-
# Potentially another SNN layer if output is spiking
620-
# self.snn_out = snn.Leaky(...)
621-
# self._ff_layer_pairs_info.append((self.fc_out, self.snn_out)) # If FF applies here too
622-
623-
def forward(self, x_batch_time: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
624-
# language=rst
625-
"""
626-
Defines the forward pass of the SNN over time.
627-
This method might be used if the network is called directly with time-series data.
628-
However, the ForwardForwardPipeline._run_snn_batch currently iterates
629-
through self.network_sequence modules per time step.
630-
631-
:param x_batch_time: Input tensor with shape [batch_size, time_steps, num_features].
632-
:return: Final layer output spikes and a list of hidden states (membrane potentials)
633-
from all spiking layers.
634-
"""
635-
# Initialize hidden states for all spiking layers in the sequence
636-
# This assumes they are snn.Leaky and support init_leaky() or similar
637-
# Or, more generally, that they initialize if mem is None on first call.
638-
639-
spiking_layer_modules = [info['spiking'] for info in self._ff_layer_pairs_info]
640-
# mem_states = [layer.init_leaky() for layer in spiking_layer_modules] # This creates new states
641-
# For snntorch, typically pass None for initial state, layer handles it.
642-
643-
# The pipeline's _run_snn_batch actually handles the time loop and state passing.
644-
# This forward method is more for standalone use or if the pipeline changes.
645-
# If you want this model to be directly callable with (B, T, F) input and manage its own time loop:
646-
647-
batch_size = x_batch_time.shape[0]
648-
649-
# Initialize states for each spiking layer for this batch
650-
# This is tricky because snn.Leaky.init_hidden() doesn't take batch_size.
651-
# State initialization is usually handled by passing None to the layer's forward method
652-
# for the first time step, and it initializes based on the input batch size.
653-
654-
# Placeholder: The pipeline's _run_snn_batch is the primary runner.
655-
# This forward method would need a more elaborate state management if used directly.
656-
# For now, let's make it compatible with how _run_snn_batch works if it were to call this.
657-
658-
# If this 'forward' is to be used, it should mirror _run_snn_batch's logic:
659-
# Initialize all spiking layer states to None
660-
spiking_layer_states = {module: None for module in self.network_sequence if isinstance(module, snn.SpikingNeuron)}
661-
662-
# Record outputs if needed (e.g., for a final classification layer not part of FF)
663-
# final_spk_rec = [] # If you want to record output spikes over time
664-
665-
for t in range(x_batch_time.shape[1]): # Iterate over time
666-
x_t = x_batch_time[:, t, :]
667-
layer_input = x_t
668-
669-
current_module_idx = 0
670-
for module in self.network_sequence:
671-
if isinstance(module, snn.SpikingNeuron):
672-
spk_out, new_mem = module(layer_input, spiking_layer_states.get(module))
673-
spiking_layer_states[module] = new_mem
674-
layer_input = spk_out
675-
else: # nn.Linear
676-
layer_input = module(layer_input)
677-
# After passing through all layers for time step t, layer_input is the output of the last layer
678-
# final_spk_rec.append(layer_input)
679-
680-
# return torch.stack(final_spk_rec, dim=1), [state for state in spiking_layer_states.values()]
681-
return layer_input, [spiking_layer_states[info['spiking']] for info in self._ff_layer_pairs_info]
682-
683-
684-
def get_ff_layer_pairs(self) -> List[Tuple[nn.Linear, snn.SpikingNeuron]]:
685-
# language=rst
686-
"""
687-
Returns the list of (Linear, SpikingNeuron) pairs for Forward-Forward training.
688-
"""
689-
# If _ff_layer_pairs_info contains tuples, return them directly
690-
return self._ff_layer_pairs_info
691-
692-
# Alternative: If you want to be explicit about the structure
693-
# return [(pair[0], pair[1]) for pair in self._ff_layer_pairs_info]
694-
695-
696552
class FFSNN_BindsNET(Network):
697553
# language=rst
698554
"""

0 commit comments

Comments
 (0)