Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 202 additions & 6 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def get_unit_spike_train(
segment_index=segment_index,
start_time=start_time,
end_time=end_time,
use_cache=use_cache,
)

segment_index = self._check_segment_index(segment_index)
Expand Down Expand Up @@ -212,6 +213,7 @@ def get_unit_spike_train_in_seconds(
segment_index: int | None = None,
start_time: float | None = None,
end_time: float | None = None,
use_cache: bool = True,
) -> np.ndarray:
"""
Get spike train for a unit in seconds.
Expand All @@ -236,6 +238,10 @@ def get_unit_spike_train_in_seconds(
The start time in seconds for spike train extraction
end_time : float or None, default: None
The end time in seconds for spike train extraction
return_times : bool, default: False
If True, returns spike times in seconds instead of frames
use_cache : bool, default: True
If True, then precompute (or use) the to_reordered_spike_vector using

Returns
-------
Expand All @@ -258,7 +264,7 @@ def get_unit_spike_train_in_seconds(
start_frame=start_frame,
end_frame=end_frame,
return_times=False,
use_cache=True,
use_cache=use_cache,
)

spike_times = self.sample_index_to_time(spike_frames, segment_index=segment_index)
Expand Down Expand Up @@ -288,13 +294,170 @@ def get_unit_spike_train_in_seconds(
start_frame=start_frame,
end_frame=end_frame,
return_times=False,
use_cache=True,
use_cache=use_cache,
)

t_start = segment._t_start if segment._t_start is not None else 0
spike_times = spike_frames / self.get_sampling_frequency()
return t_start + spike_times

def get_unit_spike_trains(
self,
unit_ids: np.ndarray | list,
segment_index: int | None = None,
start_frame: int | None = None,
end_frame: int | None = None,
return_times: bool = False,
use_cache: bool = True,
) -> dict[int | str, np.ndarray]:
"""Return spike trains for multiple units.

Parameters
----------
unit_ids : np.ndarray | list
Unit ids to retrieve spike trains for
segment_index : int or None, default: None
The segment index to retrieve spike train from.
For multi-segment objects, it is required
start_frame : int or None, default: None
The start frame for spike train extraction
end_frame : int or None, default: None
The end frame for spike train extraction
return_times : bool, default: False
If True, returns spike times in seconds instead of frames
use_cache : bool, default: True
If True, then precompute (or use) the to_reordered_spike_vector using

Returns
-------
dict[int | str, np.ndarray]
A dictionary where keys are unit ids and values are spike trains (arrays of spike times or frames)
"""
if return_times:
start_time = (
self.sample_index_to_time(start_frame, segment_index=segment_index) if start_frame is not None else None
)
end_time = (
self.sample_index_to_time(end_frame, segment_index=segment_index) if end_frame is not None else None
)

return self.get_unit_spike_trains_in_seconds(
unit_ids=unit_ids,
segment_index=segment_index,
start_time=start_time,
end_time=end_time,
use_cache=use_cache,
)

segment_index = self._check_segment_index(segment_index)
if use_cache:
# TODO: speed things up
ordered_spike_vector, slices = self.to_reordered_spike_vector(
lexsort=("sample_index", "segment_index", "unit_index"),
return_order=False,
return_slices=True,
)
unit_indices = self.ids_to_indices(unit_ids)
spike_trains = {}
for unit_index, unit_id in zip(unit_indices, unit_ids):
sl0, sl1 = slices[unit_index, segment_index, :]
spikes = ordered_spike_vector[sl0:sl1]
spike_frames = spikes["sample_index"]
if start_frame is not None:
start = np.searchsorted(spike_frames, start_frame)
spike_frames = spike_frames[start:]
if end_frame is not None:
end = np.searchsorted(spike_frames, end_frame)
spike_frames = spike_frames[:end]
spike_trains[unit_id] = spike_frames
else:
spike_trains = segment.get_unit_spike_trains(
unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame, return_times=return_times
)

def get_unit_spike_trains_in_seconds(
self,
unit_ids: np.ndarray | list,
segment_index: int | None = None,
start_time: float | None = None,
end_time: float | None = None,
return_times: bool = False,
use_cache: bool = True,
) -> dict[int | str, np.ndarray]:
"""Return spike trains for multiple units in seconds

Parameters
----------
unit_ids : np.ndarray | list
Unit ids to retrieve spike trains for
segment_index : int or None, default: None
The segment index to retrieve spike train from.
For multi-segment objects, it is required
start_time : float or None, default: None
The start time in seconds for spike train extraction
end_time : float or None, default: None
The end time in seconds for spike train extraction
return_times : bool, default: False
If True, returns spike times in seconds instead of frames
use_cache : bool, default: True
If True, then precompute (or use) the to_reordered_spike_vector using

Returns
-------
dict[int | str, np.ndarray]
A dictionary where keys are unit ids and values are spike trains (arrays of spike times in seconds)
"""
segment_index = self._check_segment_index(segment_index)
segment = self.segments[segment_index]

# If sorting has a registered recording, get the frames and get the times from the recording
# Note that this take into account the segment start time of the recording
spike_times = {}
if self.has_recording():
# Get all the spike times and then slice them
start_frame = None
end_frame = None
spike_train_frames = self.get_unit_spike_trains(
unit_ids=unit_ids,
segment_index=segment_index,
start_frame=start_frame,
end_frame=end_frame,
return_times=False,
use_cache=use_cache,
)

for unit_id in unit_ids:
spike_frames = self.sample_index_to_time(spike_train_frames[unit_id], segment_index=segment_index)

# Filter to return only the spikes within the specified time range
if start_time is not None:
spike_frames = spike_frames[spike_frames >= start_time]
if end_time is not None:
spike_frames = spike_frames[spike_frames <= end_time]

spike_times[unit_id] = spike_frames

return spike_times

# If no recording attached and all back to frame-based conversion
# Get spike train in frames and convert to times using traditional method
start_frame = self.time_to_sample_index(start_time, segment_index=segment_index) if start_time else None
end_frame = self.time_to_sample_index(end_time, segment_index=segment_index) if end_time else None

spike_frames = self.get_unit_spike_trains(
unit_ids=unit_ids,
segment_index=segment_index,
start_frame=start_frame,
end_frame=end_frame,
return_times=False,
use_cache=use_cache,
)
for unit_id in unit_ids:
spike_frames = spike_frames[unit_id]
t_start = segment._t_start if segment._t_start is not None else 0
spike_times[unit_id] = spike_frames / self.get_sampling_frequency()
return t_start + spike_times

def register_recording(self, recording, check_spike_frames: bool = True):
"""
Register a recording to the sorting. If the sorting and recording both contain
Expand Down Expand Up @@ -978,7 +1141,7 @@ def to_reordered_spike_vector(
s1 = seg_slices[segment_index + 1]
slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1]

elif ("sample_index", "unit_index", "segment_index"):
elif lexsort == ("sample_index", "unit_index", "segment_index"):
slices = np.zeros((num_segments, num_units, 2), dtype=np.int64)
seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left")
for segment_index in range(self.get_num_segments()):
Expand Down Expand Up @@ -1083,26 +1246,59 @@ def __init__(self, t_start=None):

def get_unit_spike_train(
self,
unit_id,
unit_id: int | str,
start_frame: int | None = None,
end_frame: int | None = None,
) -> np.ndarray:
"""Get the spike train for a unit.

Parameters
----------
unit_id
unit_id : int | str
The unit id for which to get the spike train.
start_frame : int, default: None
The start frame for the spike train. If None, it is set to the beginning of the segment.
end_frame : int, default: None
The end frame for the spike train. If None, it is set to the end of the segment.


Returns
-------
np.ndarray

The spike train for the given unit id and time interval.
"""
# must be implemented in subclass
raise NotImplementedError

def get_unit_spike_trains(
self,
unit_ids: np.ndarray | list,
start_frame: int | None = None,
end_frame: int | None = None,
) -> dict[int | str, np.ndarray]:
"""Get the spike trains for several units.
Can be implemented in subclass for performance but the default implementation is to call
get_unit_spike_train for each unit_id.

Parameters
----------
unit_ids : numpy.array or list
The unit ids for which to get the spike trains.
start_frame : int, default: None
The start frame for the spike trains. If None, it is set to the beginning of the segment.
end_frame : int, default: None
The end frame for the spike trains. If None, it is set to the end of the segment.

Returns
-------
dict[int | str, np.ndarray]
A dictionary where keys are unit_ids and values are the corresponding spike trains.
"""
spike_trains = {}
for unit_id in unit_ids:
spike_trains[unit_id] = self.get_unit_spike_train(unit_id, start_frame=start_frame, end_frame=end_frame)
return spike_trains


class SpikeVectorSortingSegment(BaseSortingSegment):
"""
Expand Down
23 changes: 18 additions & 5 deletions src/spikeinterface/core/unitsselectionsorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,15 @@ def _compute_and_cache_spike_vector(self) -> None:
all_old_unit_ids=self._parent_sorting.unit_ids,
all_new_unit_ids=self._unit_ids,
)
# lexsort by segment_index, sample_index, unit_index
sort_indices = np.lexsort(
(spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"])
)
self._cached_spike_vector = spike_vector[sort_indices]
# lexsort by segment_index, sample_index, unit_index, only if needed
# (remapping can change the order of unit indices)
if np.diff(self.ids_to_indices(self._unit_ids)).min() < 0:
sort_indices = np.lexsort(
(spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"])
)
self._cached_spike_vector = spike_vector[sort_indices]
else:
self._cached_spike_vector = spike_vector


class UnitsSelectionSortingSegment(BaseSortingSegment):
Expand All @@ -81,3 +85,12 @@ def get_unit_spike_train(
unit_id_parent = self._ids_conversion[unit_id]
times = self._parent_segment.get_unit_spike_train(unit_id_parent, start_frame, end_frame)
return times

def get_unit_spike_trains(
self,
unit_ids,
start_frame: int | None = None,
end_frame: int | None = None,
) -> dict:
unit_ids_parent = [self._ids_conversion[unit_id] for unit_id in unit_ids]
return self._parent_segment.get_unit_spike_trains(unit_ids_parent, start_frame, end_frame)
Loading