
Commit 6b6e4b5

Refactor code for improved readability and consistency in tensor handling across multiple files
1 parent 988cb91 commit 6b6e4b5

8 files changed, +100 -82 lines changed

bindsnet/evaluation/evaluation.py

Lines changed: 3 additions & 1 deletion
@@ -44,7 +44,9 @@ def assign_labels(
             indices = torch.nonzero(labels == i).view(-1)

             # Compute average firing rates for this label.
-            selected_spikes = torch.index_select(spikes, dim=0, index=torch.tensor(indices))
+            selected_spikes = torch.index_select(
+                spikes, dim=0, index=torch.tensor(indices)
+            )
             rates[:, i] = alpha * rates[:, i] + (
                 torch.sum(selected_spikes, 0) / n_labeled
             )
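
For context, the rewrapped call above is the per-label gather-and-average step in assign_labels. A minimal, self-contained sketch of the same pattern; the shapes, alpha, and random data are illustrative and not taken from the repository:

import torch

n_samples, n_neurons, n_labels = 6, 4, 2
spikes = torch.rand(n_samples, n_neurons)   # summed spike counts per sample and neuron
labels = torch.tensor([0, 1, 0, 1, 0, 1])   # class label of each sample
rates = torch.zeros(n_neurons, n_labels)    # running per-label firing-rate estimate
alpha = 0.5                                 # smoothing factor for the running average

for i in range(n_labels):
    # Rows of `spikes` belonging to label i.
    indices = torch.nonzero(labels == i).view(-1)
    selected_spikes = torch.index_select(spikes, dim=0, index=indices)
    # Blend the new per-neuron average into the running estimate.
    rates[:, i] = alpha * rates[:, i] + torch.sum(selected_spikes, 0) / indices.numel()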

bindsnet/learning/MCC_learning.py

Lines changed: 6 additions & 8 deletions
@@ -103,7 +103,9 @@ def update(self, **kwargs) -> None:
             self, NoOp
         ):
             if self.feature_value.is_sparse:
-                self.feature_value = self.feature_value.to_dense().clamp_(self.min, self.max).to_sparse()
+                self.feature_value = (
+                    self.feature_value.to_dense().clamp_(self.min, self.max).to_sparse()
+                )
             else:
                 self.feature_value.clamp_(self.min, self.max)

@@ -252,8 +254,7 @@ def _connection_update(self, **kwargs) -> None:
         else:
             if self.feature_value.is_sparse:
                 self.feature_value -= (
-                    torch.bmm(source_s, target_x)
-                    * self.connection.dt
+                    torch.bmm(source_s, target_x) * self.connection.dt
                 ).to_sparse()
             else:
                 self.feature_value -= (
@@ -289,8 +290,7 @@ def _connection_update(self, **kwargs) -> None:
         else:
             if self.feature_value.is_sparse:
                 self.feature_value += (
-                    torch.bmm(source_x, target_s)
-                    * self.connection.dt
+                    torch.bmm(source_x, target_s) * self.connection.dt
                 ).to_sparse()
             else:
                 self.feature_value += (
@@ -524,9 +524,7 @@ def _connection_update(self, **kwargs) -> None:
         ) % self.average_update

         if self.continues_update or self.average_buffer_index == 0:
-            update = self.nu[0] * torch.mean(
-                self.average_buffer, dim=0
-            )
+            update = self.nu[0] * torch.mean(self.average_buffer, dim=0)
             if self.feature_value.is_sparse:
                 update = update.to_sparse()
             self.feature_value += update
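
The sparse branches in these hunks share one workaround: the update or clamp is computed on a dense view and the result is converted back with to_sparse(), apparently to stay within the limited op coverage of sparse COO tensors. A small standalone sketch of the clamp pattern, with made-up values and bounds:

import torch

# Illustrative sparse feature value and bounds; not the library's actual state.
feature_value = torch.tensor([[1.8, 0.0], [-0.4, 0.6]]).to_sparse()
w_min, w_max = 0.0, 1.0

# Mirror the pattern above: densify, clamp in place, then re-sparsify.
feature_value = (
    feature_value.to_dense().clamp_(w_min, w_max).to_sparse()
)
print(feature_value.to_dense())  # tensor([[1.0, 0.0], [0.0, 0.6]])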

bindsnet/models/models.py

Lines changed: 10 additions & 20 deletions
@@ -10,7 +10,11 @@
 from bindsnet.learning.MCC_learning import PostPre as MMCPostPre
 from bindsnet.network import Network
 from bindsnet.network.nodes import DiehlAndCookNodes, Input, LIFNodes
-from bindsnet.network.topology import Connection, LocalConnection, MulticompartmentConnection
+from bindsnet.network.topology import (
+    Connection,
+    LocalConnection,
+    MulticompartmentConnection,
+)
 from bindsnet.network.topology_features import Weight


@@ -182,17 +186,17 @@ def __init__(
             device=device,
             pipeline=[
                 Weight(
-                    'weight',
+                    "weight",
                     w,
                     range=[wmin, wmax],
                     norm=norm,
                     reduction=reduction,
                     nu=nu,
                     learning_rule=MMCPostPre,
                     sparse=sparse,
-                    batch_size=batch_size
+                    batch_size=batch_size,
                 )
-            ]
+            ],
         )
         w = self.exc * torch.diag(torch.ones(self.n_neurons))
         if sparse:
@@ -201,14 +205,7 @@ def __init__(
             source=exc_layer,
             target=inh_layer,
             device=device,
-            pipeline=[
-                Weight(
-                    'weight',
-                    w,
-                    range=[0, self.exc],
-                    sparse=sparse
-                )
-            ]
+            pipeline=[Weight("weight", w, range=[0, self.exc], sparse=sparse)],
         )
         w = -self.inh * (
             torch.ones(self.n_neurons, self.n_neurons)
@@ -220,14 +217,7 @@ def __init__(
             source=inh_layer,
             target=exc_layer,
             device=device,
-            pipeline=[
-                Weight(
-                    'weight',
-                    w,
-                    range=[-self.inh, 0],
-                    sparse=sparse
-                )
-            ]
+            pipeline=[Weight("weight", w, range=[-self.inh, 0], sparse=sparse)],
        )

         # Add to network
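
The collapsed pipeline=[Weight(...)] calls wire fixed excitatory-to-inhibitory and inhibitory-to-excitatory weights. A rough sketch of the two weight matrices being passed in; sizes and constants are illustrative, the surrounding MulticompartmentConnection setup is omitted, and the off-diagonal form of the inhibitory matrix is inferred from the usual Diehl & Cook layout rather than shown in full in this hunk:

import torch

n_neurons, exc, inh = 5, 22.5, 17.5

# Excitatory -> inhibitory: one-to-one positive connections on the diagonal.
w_exc_inh = exc * torch.diag(torch.ones(n_neurons))

# Inhibitory -> excitatory: all-to-all negative connections, no self-connections
# (assumed completion of the truncated expression above).
w_inh_exc = -inh * (
    torch.ones(n_neurons, n_neurons) - torch.diag(torch.ones(n_neurons))
)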

bindsnet/network/monitors.py

Lines changed: 4 additions & 4 deletions
@@ -45,7 +45,7 @@ def __init__(
         time: Optional[int] = None,
         batch_size: int = 1,
         device: str = "cpu",
-        sparse: Optional[bool] = False
+        sparse: Optional[bool] = False,
     ):
         # language=rst
         """
@@ -100,9 +100,9 @@ def record(self) -> None:
         for v in self.state_vars:
             data = getattr(self.obj, v).unsqueeze(0)
             # self.recording[v].append(data.detach().clone().to(self.device))
-            record = torch.empty_like(data, device=self.device, requires_grad=False).copy_(
-                data, non_blocking=True
-            )
+            record = torch.empty_like(
+                data, device=self.device, requires_grad=False
+            ).copy_(data, non_blocking=True)
             if self.sparse:
                 record = record.to_sparse()
             self.recording[v].append(record)
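
The rewrapped record assignment keeps the monitor's existing behaviour: allocate a fresh, gradient-free buffer on the monitor's device, fill it with a non-blocking copy of the state variable, and optionally store it sparsely. A self-contained sketch with placeholder data and flags:

import torch

device = "cpu"               # stand-in for the monitor's device
use_sparse = True            # stand-in for the monitor's sparse flag
data = torch.rand(1, 2, 3)   # e.g. a state variable with a leading time axis

# Allocate a detached buffer and copy into it without blocking the host.
record = torch.empty_like(
    data, device=device, requires_grad=False
).copy_(data, non_blocking=True)

if use_sparse:
    record = record.to_sparse()   # store the snapshot as a sparse COO tensor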

bindsnet/network/topology_features.py

Lines changed: 44 additions & 32 deletions
@@ -134,8 +134,9 @@ def __init__(
                 self.value = self.value.unsqueeze(0).repeat(self.batch_size, 1, 1)

             self.value = self.value.to_sparse()
-            assert not getattr(self, 'enforce_polarity', False), \
-                "enforce_polarity isn't supported for sparse tensors"
+            assert not getattr(
+                self, "enforce_polarity", False
+            ), "enforce_polarity isn't supported for sparse tensors"

     @abstractmethod
     def reset_state_variables(self) -> None:
@@ -179,9 +180,15 @@ def prime_feature(self, connection, device, **kwargs) -> None:
         # Check if values/norms are the correct shape
         if isinstance(self.value, torch.Tensor):
             if self.sparse:
-                assert tuple(self.value.shape[1:]) == (connection.source.n, connection.target.n)
+                assert tuple(self.value.shape[1:]) == (
+                    connection.source.n,
+                    connection.target.n,
+                )
             else:
-                assert tuple(self.value.shape) == (connection.source.n, connection.target.n)
+                assert tuple(self.value.shape) == (
+                    connection.source.n,
+                    connection.target.n,
+                )

         if self.norm is not None and isinstance(self.norm, torch.Tensor):
             assert self.norm.shape[0] == connection.target.n
@@ -325,10 +332,12 @@ def assert_feature_in_range(self):

     def assert_valid_shape(self, source_shape, target_shape, f):
         # Multidimensional feat
-        if (not self.sparse and len(f.shape) > 1) or (self.sparse and len(f.shape[1:]) > 1):
+        if (not self.sparse and len(f.shape) > 1) or (
+            self.sparse and len(f.shape[1:]) > 1
+        ):
             if self.sparse:
                 f_shape = f.shape[1:]
-                expected = ('batch_size', source_shape, target_shape)
+                expected = ("batch_size", source_shape, target_shape)
             else:
                 f_shape = f.shape
                 expected = (source_shape, target_shape)
@@ -352,7 +361,7 @@ def __init__(
         decay: float = 0.0,
         parent_feature=None,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -386,19 +395,15 @@ def __init__(
             decay=decay,
             parent_feature=parent_feature,
             sparse=sparse,
-            batch_size=batch_size
+            batch_size=batch_size,
         )

     def sparse_bernoulli(self):
         values = torch.bernoulli(self.value.values())
         mask = values != 0
         indices = self.value.indices()[:, mask]
         non_zero = values[mask]
-        return torch.sparse_coo_tensor(
-            indices,
-            non_zero,
-            self.value.size()
-        )
+        return torch.sparse_coo_tensor(indices, non_zero, self.value.size())

     def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
         if self.sparse:
@@ -448,7 +453,7 @@ def __init__(
         name: str,
         value: Union[torch.Tensor, float, int] = None,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -472,12 +477,7 @@ def __init__(
         # Send boolean to tensor (priming wont work if it's not a tensor)
         value = torch.tensor(value)

-        super().__init__(
-            name=name,
-            value=value,
-            sparse=sparse,
-            batch_size=batch_size
-        )
+        super().__init__(name=name, value=value, sparse=sparse, batch_size=batch_size)

     def compute(self, conn_spikes) -> torch.Tensor:
         return conn_spikes * self.value
@@ -561,7 +561,7 @@ def __init__(
         enforce_polarity: Optional[bool] = False,
         decay: float = 0.0,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -596,7 +596,7 @@ def __init__(
             reduction=reduction,
             decay=decay,
             sparse=sparse,
-            batch_size=batch_size
+            batch_size=batch_size,
         )

     def reset_state_variables(self) -> None:
@@ -651,7 +651,7 @@ def __init__(
         range: Optional[Sequence[float]] = None,
         norm: Optional[Union[torch.Tensor, float, int]] = None,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -671,7 +671,7 @@ def __init__(
             range=[-torch.inf, +torch.inf] if range is None else range,
             norm=norm,
             sparse=sparse,
-            batch_size=batch_size
+            batch_size=batch_size,
         )

     def reset_state_variables(self) -> None:
@@ -697,7 +697,7 @@ def __init__(
         value: Union[torch.Tensor, float, int] = None,
         range: Optional[Sequence[float]] = None,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -708,7 +708,9 @@ def __init__(
         :param batch_size: Mini-batch size.
         """

-        super().__init__(name=name, value=value, range=range, sparse=sparse, batch_size=batch_size)
+        super().__init__(
+            name=name, value=value, range=range, sparse=sparse, batch_size=batch_size
+        )

     def reset_state_variables(self) -> None:
         pass
@@ -738,7 +740,7 @@ def __init__(
         degrade_function: callable = None,
         parent_feature: Optional[AbstractFeature] = None,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -754,7 +756,13 @@ def __init__(
         """

         # Note: parent_feature will override value. See abstract constructor
-        super().__init__(name=name, value=value, parent_feature=parent_feature, sparse=sparse, batch_size=batch_size)
+        super().__init__(
+            name=name,
+            value=value,
+            parent_feature=parent_feature,
+            sparse=sparse,
+            batch_size=batch_size,
+        )

         self.degrade_function = degrade_function

@@ -774,7 +782,7 @@ def __init__(
         const_update_rate: float = 0.1,
         const_decay: float = 0.001,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -833,7 +841,9 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
             flat_conn_spikes = conn_spikes.to_dense().flatten()
         else:
             flat_conn_spikes = conn_spikes.flatten()
-        self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = flat_conn_spikes
+        self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = (
+            flat_conn_spikes
+        )
         self.counter += 1

         # Update the masks
@@ -872,7 +882,7 @@ def __init__(
         const_update_rate: float = 0.1,
         const_decay: float = 0.01,
         sparse: Optional[bool] = False,
-        batch_size: int = 1
+        batch_size: int = 1,
     ) -> None:
         # language=rst
         """
@@ -931,7 +941,9 @@ def compute(self, conn_spikes) -> Union[torch.Tensor, float, int]:
             flat_conn_spikes = conn_spikes.to_dense().flatten()
         else:
             flat_conn_spikes = conn_spikes.flatten()
-        self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = flat_conn_spikes
+        self.spike_buffer[:, self.counter % self.spike_buffer.shape[1]] = (
+            flat_conn_spikes
+        )
        self.counter += 1

         # Update the masks
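
Most of the hunks in this file are trailing-comma and line-wrapping changes, but the collapsed sparse_coo_tensor call in sparse_bernoulli is worth spelling out: it samples a Bernoulli outcome for every stored probability and rebuilds a sparse tensor containing only the surviving entries. A standalone sketch of that sampling step; the input tensor is illustrative:

import torch

# Sparse tensor of per-synapse spike probabilities (illustrative values).
value = torch.tensor([[0.9, 0.0], [0.0, 0.2]]).to_sparse()

values = torch.bernoulli(value.values())   # sample 0/1 for each stored probability
mask = values != 0                         # keep only the entries that fired
indices = value.indices()[:, mask]
non_zero = values[mask]
sampled = torch.sparse_coo_tensor(indices, non_zero, value.size())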
