From 3dc57290dbde0aeaa5048f2301ee75015a93fe26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Dec 2025 15:43:44 +0100 Subject: [PATCH 1/3] Test IBL extractors tests failing for PI update --- src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 972a8e7bb0..56d01e38cf 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -76,8 +76,8 @@ def test_offsets(self): def test_probe_representation(self): probe = self.recording.get_probe() - expected_probe_representation = "Probe - 384ch - 1shanks" - assert repr(probe) == expected_probe_representation + expected_probe_representation = "Probe - 384ch" + assert expected_probe_representation in repr(probe) def test_property_keys(self): expected_property_keys = [ From 359b68bd234fa06b6e07b8037b72ae64a7801480 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 9 Apr 2026 17:38:35 +0200 Subject: [PATCH 2/3] Implement get_unit_spike_trains function --- src/spikeinterface/core/basesorting.py | 208 +++++++++++++++++- .../core/unitsselectionsorting.py | 23 +- 2 files changed, 220 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index cb68f3d455..4e0e1fed52 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -180,6 +180,7 @@ def get_unit_spike_train( segment_index=segment_index, start_time=start_time, end_time=end_time, + use_cache=use_cache, ) segment_index = self._check_segment_index(segment_index) @@ -212,6 +213,7 @@ def get_unit_spike_train_in_seconds( segment_index: int | None = None, start_time: float | None = None, end_time: float | None = None, + use_cache: bool = True, ) -> np.ndarray: """ Get spike train for a unit in seconds. @@ -236,6 +238,10 @@ def get_unit_spike_train_in_seconds( The start time in seconds for spike train extraction end_time : float or None, default: None The end time in seconds for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, then precompute (or use) the to_reordered_spike_vector using Returns ------- @@ -258,7 +264,7 @@ def get_unit_spike_train_in_seconds( start_frame=start_frame, end_frame=end_frame, return_times=False, - use_cache=True, + use_cache=use_cache, ) spike_times = self.sample_index_to_time(spike_frames, segment_index=segment_index) @@ -288,13 +294,170 @@ def get_unit_spike_train_in_seconds( start_frame=start_frame, end_frame=end_frame, return_times=False, - use_cache=True, + use_cache=use_cache, ) t_start = segment._t_start if segment._t_start is not None else 0 spike_times = spike_frames / self.get_sampling_frequency() return t_start + spike_times + def get_unit_spike_trains( + self, + unit_ids: np.ndarray | list, + segment_index: int | None = None, + start_frame: int | None = None, + end_frame: int | None = None, + return_times: bool = False, + use_cache: bool = True, + ) -> dict[int | str, np.ndarray]: + """Return spike trains for multiple units. + + Parameters + ---------- + unit_ids : np.ndarray | list + Unit ids to retrieve spike trains for + segment_index : int or None, default: None + The segment index to retrieve spike train from. + For multi-segment objects, it is required + start_frame : int or None, default: None + The start frame for spike train extraction + end_frame : int or None, default: None + The end frame for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, then precompute (or use) the to_reordered_spike_vector using + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit ids and values are spike trains (arrays of spike times or frames) + """ + if return_times: + start_time = ( + self.sample_index_to_time(start_frame, segment_index=segment_index) if start_frame is not None else None + ) + end_time = ( + self.sample_index_to_time(end_frame, segment_index=segment_index) if end_frame is not None else None + ) + + return self.get_unit_spike_trains_in_seconds( + unit_ids=unit_ids, + segment_index=segment_index, + start_time=start_time, + end_time=end_time, + use_cache=use_cache, + ) + + segment_index = self._check_segment_index(segment_index) + if use_cache: + # TODO: speed things up + ordered_spike_vector, slices = self.to_reordered_spike_vector( + lexsort=("sample_index", "segment_index", "unit_index"), + return_order=False, + return_slices=True, + ) + unit_indices = self.ids_to_indices(unit_ids) + spike_trains = {} + for unit_index, unit_id in zip(unit_indices, unit_ids): + sl0, sl1 = slices[unit_index, segment_index, :] + spikes = ordered_spike_vector[sl0:sl1] + spike_frames = spikes["sample_index"] + if start_frame is not None: + start = np.searchsorted(spike_frames, start_frame) + spike_frames = spike_frames[start:] + if end_frame is not None: + end = np.searchsorted(spike_frames, end_frame) + spike_frames = spike_frames[:end] + spike_trains[unit_id] = spike_frames + else: + spike_trains = segment.get_unit_spike_trains( + unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame, return_times=return_times + ) + + def get_unit_spike_trains_in_seconds( + self, + unit_ids: np.ndarray | list, + segment_index: int | None = None, + start_time: float | None = None, + end_time: float | None = None, + return_times: bool = False, + use_cache: bool = True, + ) -> dict[int | str, np.ndarray]: + """Return spike trains for multiple units in seconds + + Parameters + ---------- + unit_ids : np.ndarray | list + Unit ids to retrieve spike trains for + segment_index : int or None, default: None + The segment index to retrieve spike train from. + For multi-segment objects, it is required + start_time : float or None, default: None + The start time in seconds for spike train extraction + end_time : float or None, default: None + The end time in seconds for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, then precompute (or use) the to_reordered_spike_vector using + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit ids and values are spike trains (arrays of spike times in seconds) + """ + segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] + + # If sorting has a registered recording, get the frames and get the times from the recording + # Note that this take into account the segment start time of the recording + spike_times = {} + if self.has_recording(): + # Get all the spike times and then slice them + start_frame = None + end_frame = None + spike_train_frames = self.get_unit_spike_trains( + unit_ids=unit_ids, + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + return_times=False, + use_cache=use_cache, + ) + + for unit_id in unit_ids: + spike_frames = self.sample_index_to_time(spike_train_frames[unit_id], segment_index=segment_index) + + # Filter to return only the spikes within the specified time range + if start_time is not None: + spike_frames = spike_frames[spike_frames >= start_time] + if end_time is not None: + spike_frames = spike_frames[spike_frames <= end_time] + + spike_times[unit_id] = spike_frames + + return spike_times + + # If no recording attached and all back to frame-based conversion + # Get spike train in frames and convert to times using traditional method + start_frame = self.time_to_sample_index(start_time, segment_index=segment_index) if start_time else None + end_frame = self.time_to_sample_index(end_time, segment_index=segment_index) if end_time else None + + spike_frames = self.get_unit_spike_trains( + unit_ids=unit_ids, + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + return_times=False, + use_cache=use_cache, + ) + for unit_id in unit_ids: + spike_frames = spike_frames[unit_id] + t_start = segment._t_start if segment._t_start is not None else 0 + spike_times[unit_id] = spike_frames / self.get_sampling_frequency() + return t_start + spike_times + def register_recording(self, recording, check_spike_frames: bool = True): """ Register a recording to the sorting. If the sorting and recording both contain @@ -978,7 +1141,7 @@ def to_reordered_spike_vector( s1 = seg_slices[segment_index + 1] slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1] - elif ("sample_index", "unit_index", "segment_index"): + elif lexsort == ("sample_index", "unit_index", "segment_index"): slices = np.zeros((num_segments, num_units, 2), dtype=np.int64) seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left") for segment_index in range(self.get_num_segments()): @@ -1083,7 +1246,7 @@ def __init__(self, t_start=None): def get_unit_spike_train( self, - unit_id, + unit_id: int | str, start_frame: int | None = None, end_frame: int | None = None, ) -> np.ndarray: @@ -1091,18 +1254,51 @@ def get_unit_spike_train( Parameters ---------- - unit_id + unit_id : int | str + The unit id for which to get the spike train. start_frame : int, default: None + The start frame for the spike train. If None, it is set to the beginning of the segment. end_frame : int, default: None + The end frame for the spike train. If None, it is set to the end of the segment. + Returns ------- np.ndarray - + The spike train for the given unit id and time interval. """ # must be implemented in subclass raise NotImplementedError + def get_unit_spike_trains( + self, + unit_ids: np.ndarray | list, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict[int | str, np.ndarray]: + """Get the spike trains for several units. + Can be implemented in subclass for performance but the default implementation is to call + get_unit_spike_train for each unit_id. + + Parameters + ---------- + unit_ids : numpy.array or list + The unit ids for which to get the spike trains. + start_frame : int, default: None + The start frame for the spike trains. If None, it is set to the beginning of the segment. + end_frame : int, default: None + The end frame for the spike trains. If None, it is set to the end of the segment. + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit_ids and values are the corresponding spike trains. + """ + spike_trains = {} + for unit_id in unit_ids: + spike_trains[unit_id] = self.get_unit_spike_train(unit_id, start_frame=start_frame, end_frame=end_frame) + return spike_trains + class SpikeVectorSortingSegment(BaseSortingSegment): """ diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 59356db976..4e1e81b81f 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -59,11 +59,15 @@ def _compute_and_cache_spike_vector(self) -> None: all_old_unit_ids=self._parent_sorting.unit_ids, all_new_unit_ids=self._unit_ids, ) - # lexsort by segment_index, sample_index, unit_index - sort_indices = np.lexsort( - (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) - ) - self._cached_spike_vector = spike_vector[sort_indices] + # lexsort by segment_index, sample_index, unit_index, only if needed + # (remapping can change the order of unit indices) + if np.diff(self.ids_to_indixes(self._unit_ids)).min() < 0: + sort_indices = np.lexsort( + (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) + ) + self._cached_spike_vector = spike_vector[sort_indices] + else: + self._cached_spike_vector = spike_vector class UnitsSelectionSortingSegment(BaseSortingSegment): @@ -81,3 +85,12 @@ def get_unit_spike_train( unit_id_parent = self._ids_conversion[unit_id] times = self._parent_segment.get_unit_spike_train(unit_id_parent, start_frame, end_frame) return times + + def get_unit_spike_trains( + self, + unit_ids, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict: + unit_ids_parent = [self._ids_conversion[unit_id] for unit_id in unit_ids] + return self._parent_segment.get_unit_spike_trains(unit_ids_parent, start_frame, end_frame) From 85220e5b914534e67298d6b655e6217b0c7a6dac Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 9 Apr 2026 17:58:04 +0200 Subject: [PATCH 3/3] oups --- src/spikeinterface/core/unitsselectionsorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 4e1e81b81f..dbcf2ee7ce 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -61,7 +61,7 @@ def _compute_and_cache_spike_vector(self) -> None: ) # lexsort by segment_index, sample_index, unit_index, only if needed # (remapping can change the order of unit indices) - if np.diff(self.ids_to_indixes(self._unit_ids)).min() < 0: + if np.diff(self.ids_to_indices(self._unit_ids)).min() < 0: sort_indices = np.lexsort( (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) )