Skip to content

Commit 307ee5f

Browse files
Merge pull request #354 from BindsNET/hananel
Hananel
2 parents fb8fb08 + 3ed2da5 commit 307ee5f

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

bindsnet/analysis/plotting.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -468,15 +468,15 @@ def plot_voltages(
468468

469469
if time is None:
470470
for key in voltages.keys():
471-
time = (0, voltages[key].size(-1))
471+
time = (0, voltages[key].size(0))
472472
break
473473

474474
if n_neurons is None:
475475
n_neurons = {}
476476

477477
for key, val in voltages.items():
478478
if key not in n_neurons.keys():
479-
n_neurons[key] = (0, val.size(0))
479+
n_neurons[key] = (0, val.size(1))
480480

481481
if not ims:
482482
fig, axes = plt.subplots(n_subplots, 1, figsize=figsize)
@@ -538,8 +538,8 @@ def plot_voltages(
538538
v[1]
539539
.cpu()
540540
.numpy()[
541+
time[0]: time[1],
541542
n_neurons[v[0]][0] : n_neurons[v[0]][1],
542-
time[0] : time[1],
543543
]
544544
)
545545
)
@@ -557,8 +557,8 @@ def plot_voltages(
557557
v[1]
558558
.cpu()
559559
.numpy()[
560+
time[0]: time[1],
560561
n_neurons[v[0]][0] : n_neurons[v[0]][1],
561-
time[0] : time[1],
562562
]
563563
.T,
564564
cmap=cmap,
@@ -621,7 +621,8 @@ def plot_voltages(
621621
v[1]
622622
.cpu()
623623
.numpy()[
624-
n_neurons[v[0]][0] : n_neurons[v[0]][1], time[0] : time[1]
624+
time[0]: time[1],
625+
n_neurons[v[0]][0]: n_neurons[v[0]][1],
625626
]
626627
)
627628
if thresholds is not None and thresholds[v[0]].size() == torch.Size(
@@ -635,7 +636,8 @@ def plot_voltages(
635636
v[1]
636637
.cpu()
637638
.numpy()[
638-
n_neurons[v[0]][0] : n_neurons[v[0]][1], time[0] : time[1]
639+
time[0]: time[1],
640+
n_neurons[v[0]][0]: n_neurons[v[0]][1],
639641
]
640642
.T,
641643
cmap=cmap,

examples/mnist/supervised_mnist.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,9 @@
118118
# Sequence of accuracy estimates.
119119
accuracy = {"all": [], "proportion": []}
120120

121+
# Labels to determine neuron assignments and spike proportions and estimate accuracy
122+
labels = torch.empty(update_interval)
123+
121124
spikes = {}
122125
for layer in set(network.layers):
123126
spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time)
@@ -154,10 +157,10 @@
154157

155158
# Compute network accuracy according to available classification strategies.
156159
accuracy["all"].append(
157-
100 * torch.sum(label.long() == all_activity_pred).item() / update_interval
160+
100 * torch.sum(labels.long() == all_activity_pred).item() / update_interval
158161
)
159162
accuracy["proportion"].append(
160-
100 * torch.sum(label.long() == proportion_pred).item() / update_interval
163+
100 * torch.sum(labels.long() == proportion_pred).item() / update_interval
161164
)
162165

163166
print(
@@ -174,13 +177,16 @@
174177
)
175178

176179
# Assign labels to excitatory layer neurons.
177-
assignments, proportions, rates = assign_labels(spike_record, label, 10, rates)
180+
assignments, proportions, rates = assign_labels(spike_record, labels, 10, rates)
181+
182+
#Add the current label to the list of labels for this update_interval
183+
labels[i % update_interval] = label[0]
178184

179185
# Run the network on the input.
180186
choice = np.random.choice(int(n_neurons / 10), size=n_clamp, replace=False)
181187
clamp = {"Ae": per_class * label.long() + torch.Tensor(choice).long()}
182-
inputs = {"X": image.view(time, 1, 28, 28)}
183-
network.run(inputs=inputs, time=time, clamp=clamp)
188+
inputs = {"X": image.view(time, 1, 1, 28, 28)}
189+
network.run(inpts=inputs, time=time, clamp=clamp)
184190

185191
# Get voltage recording.
186192
exc_voltages = exc_voltage_monitor.get("v")
@@ -203,7 +209,7 @@
203209
image.sum(1).view(28, 28), inpt, label=label, axes=inpt_axes, ims=inpt_ims
204210
)
205211
spike_ims, spike_axes = plot_spikes(
206-
{layer: spikes[layer].get("s") for layer in spikes},
212+
{layer: spikes[layer].get("s").view(time, 1, -1) for layer in spikes},
207213
ims=spike_ims,
208214
axes=spike_axes,
209215
)

0 commit comments

Comments
 (0)