|
118 | 118 | # Sequence of accuracy estimates. |
119 | 119 | accuracy = {"all": [], "proportion": []} |
120 | 120 |
|
| 121 | +# Labels to determine neuron assignments and spike proportions and estimate accuracy |
| 122 | +labels = torch.empty(update_interval) |
| 123 | + |
121 | 124 | spikes = {} |
122 | 125 | for layer in set(network.layers): |
123 | 126 | spikes[layer] = Monitor(network.layers[layer], state_vars=["s"], time=time) |
|
154 | 157 |
|
155 | 158 | # Compute network accuracy according to available classification strategies. |
156 | 159 | accuracy["all"].append( |
157 | | - 100 * torch.sum(label.long() == all_activity_pred).item() / update_interval |
| 160 | + 100 * torch.sum(labels.long() == all_activity_pred).item() / update_interval |
158 | 161 | ) |
159 | 162 | accuracy["proportion"].append( |
160 | | - 100 * torch.sum(label.long() == proportion_pred).item() / update_interval |
| 163 | + 100 * torch.sum(labels.long() == proportion_pred).item() / update_interval |
161 | 164 | ) |
162 | 165 |
|
163 | 166 | print( |
|
174 | 177 | ) |
175 | 178 |
|
176 | 179 | # Assign labels to excitatory layer neurons. |
177 | | - assignments, proportions, rates = assign_labels(spike_record, label, 10, rates) |
| 180 | + assignments, proportions, rates = assign_labels(spike_record, labels, 10, rates) |
| 181 | + |
| 182 | + #Add the current label to the list of labels for this update_interval |
| 183 | + labels[i % update_interval] = label[0] |
178 | 184 |
|
179 | 185 | # Run the network on the input. |
180 | 186 | choice = np.random.choice(int(n_neurons / 10), size=n_clamp, replace=False) |
181 | 187 | clamp = {"Ae": per_class * label.long() + torch.Tensor(choice).long()} |
182 | | - inputs = {"X": image.view(time, 1, 28, 28)} |
183 | | - network.run(inputs=inputs, time=time, clamp=clamp) |
| 188 | + inputs = {"X": image.view(time, 1, 1, 28, 28)} |
| 189 | + network.run(inpts=inputs, time=time, clamp=clamp) |
184 | 190 |
|
185 | 191 | # Get voltage recording. |
186 | 192 | exc_voltages = exc_voltage_monitor.get("v") |
|
203 | 209 | image.sum(1).view(28, 28), inpt, label=label, axes=inpt_axes, ims=inpt_ims |
204 | 210 | ) |
205 | 211 | spike_ims, spike_axes = plot_spikes( |
206 | | - {layer: spikes[layer].get("s") for layer in spikes}, |
| 212 | + {layer: spikes[layer].get("s").view(time, 1, -1) for layer in spikes}, |
207 | 213 | ims=spike_ims, |
208 | 214 | axes=spike_axes, |
209 | 215 | ) |
|
0 commit comments