Skip to content

Commit 1142720

Browse files
authored
Merge pull request #646 from apdavison/issue645
Fix for #645
2 parents df49d69 + 3dad39e commit 1142720

File tree

4 files changed

+159
-7
lines changed

4 files changed

+159
-7
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
"""
2+
A demonstration of the use of callbacks to update the spike times in a SpikeSourceArray.
3+
4+
Usage: update_spike_source_array.py [-h] [--plot-figure] simulator
5+
6+
positional arguments:
7+
simulator neuron, nest, brian or another backend simulator
8+
9+
optional arguments:
10+
-h, --help show this help message and exit
11+
--plot-figure Plot the simulation results to a file.
12+
"""
13+
14+
import numpy as np
15+
from pyNN.utility import get_simulator, normalized_filename, ProgressBar
16+
from pyNN.utility.plotting import Figure, Panel
17+
from pyNN.parameters import Sequence
18+
19+
sim, options = get_simulator(("--plot-figure", "Plot the simulation results to a file.",
20+
{"action": "store_true"}))
21+
22+
rate_increment = 20
23+
interval = 200
24+
25+
26+
class SetRate(object):
27+
"""
28+
A callback which changes the firing rate of a population of spike
29+
sources at a fixed interval.
30+
"""
31+
32+
def __init__(self, population, rate_generator, update_interval=20.0):
33+
assert isinstance(population.celltype, sim.SpikeSourceArray)
34+
self.population = population
35+
self.update_interval = update_interval
36+
self.rate_generator = rate_generator
37+
38+
def __call__(self, t):
39+
try:
40+
rate = next(rate_generator)
41+
if rate > 0:
42+
isi = 1000.0/rate
43+
times = t + np.arange(0, self.update_interval, isi)
44+
# here each neuron fires with the same isi,
45+
# but there is a phase offset between neurons
46+
spike_times = [
47+
Sequence(times + phase * isi)
48+
for phase in self.population.annotations["phase"]
49+
]
50+
else:
51+
spike_times = []
52+
self.population.set(spike_times=spike_times)
53+
except StopIteration:
54+
pass
55+
return t + self.update_interval
56+
57+
58+
class MyProgressBar(object):
59+
"""
60+
A callback which draws a progress bar in the terminal.
61+
"""
62+
63+
def __init__(self, interval, t_stop):
64+
self.interval = interval
65+
self.t_stop = t_stop
66+
self.pb = ProgressBar(width=int(t_stop / interval), char=".")
67+
68+
def __call__(self, t):
69+
self.pb(t / self.t_stop)
70+
return t + self.interval
71+
72+
73+
sim.setup()
74+
75+
76+
# === Create a population of poisson processes ===============================
77+
78+
p = sim.Population(50, sim.SpikeSourceArray())
79+
p.annotate(phase=np.random.uniform(0, 1, size=p.size))
80+
p.record('spikes')
81+
82+
83+
# === Run the simulation, with two callback functions ========================
84+
85+
rate_generator = iter(range(0, 100, rate_increment))
86+
sim.run(1000, callbacks=[MyProgressBar(10.0, 1000.0),
87+
SetRate(p, rate_generator, interval)])
88+
89+
90+
# === Retrieve recorded data, and count the spikes in each interval ==========
91+
92+
data = p.get_data().segments[0]
93+
94+
all_spikes = np.hstack([st.magnitude for st in data.spiketrains])
95+
spike_counts = [((all_spikes >= x) & (all_spikes < x + interval)).sum()
96+
for x in range(0, 1000, interval)]
97+
expected_spike_counts = [p.size * rate * interval / 1000.0
98+
for rate in range(0, 100, rate_increment)]
99+
100+
print("\nActual spike counts: {}".format(spike_counts))
101+
print("Expected mean spike counts: {}".format(expected_spike_counts))
102+
103+
if options.plot_figure:
104+
Figure(
105+
Panel(data.spiketrains, xlabel="Time (ms)", xticks=True, markersize=0.5),
106+
title="Incrementally updated SpikeSourceArrays",
107+
annotations="Simulated with %s" % options.simulator.upper()
108+
).save(normalized_filename("Results", "update_spike_source_array", "png", options.simulator))
109+
110+
sim.end()

pyNN/neuron/cells.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,10 +581,12 @@ class VectorSpikeSource(hclass(h.VecStim)):
581581
parameter_names = ('spike_times',)
582582

583583
def __init__(self, spike_times=[]):
584+
self.recording = False
584585
self.spike_times = spike_times
585586
self.source = self
586587
self.source_section = None
587588
self.rec = None
589+
self._recorded_spikes = numpy.array([])
588590

589591
def _set_spike_times(self, spike_times):
590592
# spike_times should be a Sequence object
@@ -595,18 +597,32 @@ def _set_spike_times(self, spike_times):
595597
if numpy.any(spike_times.value[:-1] > spike_times.value[1:]):
596598
raise errors.InvalidParameterValueError("Spike times given to SpikeSourceArray must be in increasing order")
597599
self.play(self._spike_times)
600+
if self.recording:
601+
self._recorded_spikes = numpy.hstack((self._recorded_spikes, spike_times.value))
598602

599603
def _get_spike_times(self):
600604
return self._spike_times
601605

602606
spike_times = property(fget=_get_spike_times,
603607
fset=_set_spike_times)
604608

609+
@property
610+
def recording(self):
611+
return self._recording
612+
613+
@recording.setter
614+
def recording(self, value):
615+
self._recording = value
616+
if value:
617+
# when we turn recording on, the cell may already have had its spike times assigned
618+
self._recorded_spikes = numpy.hstack((self._recorded_spikes, self.spike_times))
619+
620+
def get_recorded_spike_times(self):
621+
return self._recorded_spikes
622+
605623
def clear_past_spikes(self):
606624
"""If previous recordings are cleared, need to remove spikes from before the current time."""
607-
end = self._spike_times.indwhere(">", h.t)
608-
if end > 0:
609-
self._spike_times.remove(0, end - 1) # range is inclusive
625+
self._recorded_spikes = self._recorded_spikes[self._recorded_spikes > h.t]
610626

611627

612628
class ArtificialCell(object):

pyNN/neuron/recording.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ def _record(self, variable, new_ids, sampling_interval=None):
2424
for id in new_ids:
2525
if id._cell.rec is not None:
2626
id._cell.rec.record(id._cell.spike_times)
27+
else: # SpikeSourceArray
28+
id._cell.recording = True
2729
else:
2830
self.sampling_interval = sampling_interval or self._simulator.state.dt
2931
for id in new_ids:
@@ -97,7 +99,10 @@ def _get_spiketimes(self, id):
9799
if hasattr(id, "__len__"):
98100
all_spiketimes = {}
99101
for cell_id in id:
100-
spikes = numpy.array(cell_id._cell.spike_times)
102+
if cell_id._cell.rec is None: # SpikeSourceArray
103+
spikes = cell_id._cell.get_recorded_spike_times()
104+
else:
105+
spikes = numpy.array(cell_id._cell.spike_times)
101106
all_spiketimes[cell_id] = spikes[spikes <= simulator.state.t + 1e-9]
102107
return all_spiketimes
103108
else:

test/system/scenarios/test_cell_types.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
have_scipy = True
88
except ImportError:
99
have_scipy = False
10+
from numpy.testing import assert_array_equal
1011
import quantities as pq
1112
from nose.tools import assert_greater, assert_less, assert_raises
13+
from pyNN.parameters import Sequence
1214
from pyNN.errors import InvalidParameterValueError
1315

1416
from .registry import register
@@ -31,7 +33,7 @@ def test_EIF_cond_alpha_isfa_ista(sim, plot_figure=False):
3133
plt.plot(expected_spike_times, -40 * numpy.ones_like(expected_spike_times), "ro")
3234
plt.savefig("test_EIF_cond_alpha_isfa_ista_%s.png" % sim.__name__)
3335
diff = (data.spiketrains[0].rescale(pq.ms).magnitude - expected_spike_times) / expected_spike_times
34-
assert abs(diff).max() < 0.01, abs(diff).max()
36+
assert abs(diff).max() < 0.01, abs(diff).max()
3537
sim.end()
3638
return data
3739
test_EIF_cond_alpha_isfa_ista.__test__ = False
@@ -262,7 +264,6 @@ def test_SpikeSourcePoissonRefractory(sim, plot_figure=False):
262264
test_SpikeSourcePoissonRefractory.__test__ = False
263265

264266

265-
266267
@register()
267268
def issue511(sim):
268269
"""Giving SpikeSourceArray an array of non-ordered spike times should produce an InvalidParameterValueError error"""
@@ -271,6 +272,25 @@ def issue511(sim):
271272
assert_raises(InvalidParameterValueError, sim.Population, 2, celltype)
272273

273274

275+
@register()
276+
def test_update_SpikeSourceArray(sim, plot_figure=False):
277+
sim.setup()
278+
sources = sim.Population(2, sim.SpikeSourceArray(spike_times=[]))
279+
sources.record('spikes')
280+
sim.run(10.0)
281+
sources.set(spike_times=[
282+
Sequence([12, 15, 18]),
283+
Sequence([17, 19])
284+
])
285+
sim.run(10.0)
286+
sources.set(spike_times=[
287+
Sequence([22, 25]),
288+
Sequence([23, 27, 29])
289+
])
290+
sim.run(10.0)
291+
data = sources.get_data().segments[0].spiketrains
292+
assert_array_equal(data[0].magnitude, numpy.array([12, 15, 18, 22, 25]))
293+
test_update_SpikeSourceArray.__test__ = False
274294

275295
# todo: add test of Izhikevich model
276296

@@ -286,4 +306,5 @@ def issue511(sim):
286306
test_SpikeSourcePoisson(sim, plot_figure=args.plot_figure)
287307
test_SpikeSourceGamma(sim, plot_figure=args.plot_figure)
288308
test_SpikeSourcePoissonRefractory(sim, plot_figure=args.plot_figure)
289-
issue511(sim)
309+
issue511(sim)
310+
test_update_SpikeSourceArray(sim, plot_figure=args.plot_figure)

0 commit comments

Comments
 (0)