Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions src/neuroconv/tools/spikeinterface/spikeinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def _add_recording_segment_to_nwbfile(
recording_t_start = timestamps[0]
else:
rate = recording.get_sampling_frequency()
recording_t_start = recording._recording_segments[segment_index].t_start or 0
recording_t_start = _get_recording_segment_start_time(recording=recording, segment_index=segment_index)

if rate:
starting_time = float(recording_t_start)
Expand Down Expand Up @@ -1353,7 +1353,7 @@ def _add_time_series_segment_to_nwbfile(
recording_t_start = timestamps[0]
else:
rate = recording.get_sampling_frequency()
recording_t_start = recording._recording_segments[segment_index].t_start or 0
recording_t_start = _get_recording_segment_start_time(recording=recording, segment_index=segment_index)

if rate:
starting_time = float(recording_t_start)
Expand Down Expand Up @@ -2828,3 +2828,31 @@ def _stub_recording(recording: BaseRecording, *, stub_samples: int = 100) -> Bas
recording_stubbed = AppendSegmentRecording(recording_list=recording_segments_stubbed)

return recording_stubbed


def _get_recording_segment_start_time(recording: BaseRecording, segment_index: int) -> float:
if hasattr(recording, "get_start_time"):
start_time = recording.get_start_time(segment_index=segment_index)
return 0.0 if start_time is None else float(start_time)

segments = None
if hasattr(recording, "segments"):
segments = recording.segments
elif hasattr(recording, "_recording_segments"):
segments = recording._recording_segments
elif hasattr(recording, "_segments"):
segments = recording._segments

if segments is None:
return 0.0

segment = segments[segment_index]
if hasattr(segment, "get_start_time"):
start_time = segment.get_start_time()
return 0.0 if start_time is None else float(start_time)
if hasattr(segment, "t_start"):
Comment on lines +2849 to +2853
Copy link

Copilot AI Apr 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

segment.get_start_time() may also return None (depending on spikeinterface version/segment implementation). This path currently does float(segment.get_start_time()) and will raise TypeError for recordings without an explicit start time. Handle a None return by defaulting to 0.0 (consistent with the t_start/_t_start branches).

Copilot uses AI. Check for mistakes.
return 0.0 if segment.t_start is None else float(segment.t_start)
if hasattr(segment, "_t_start"):
return 0.0 if segment._t_start is None else float(segment._t_start)

return 0.0
3 changes: 0 additions & 3 deletions tests/test_modalities/test_ecephys/test_ecephys_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,6 @@ def test_stub_with_starting_time(self, setup_interface):
interface = MockRecordingInterface(durations=[1.0])

recording = interface.recording_extractor
# TODO Remove the following line once Spikeinterface 0.102.4 or higher is released
# See https://github.com/SpikeInterface/spikeinterface/pull/3940
recording._recording_segments[0].t_start = 0.0
recording.shift_times(2.0)

interface.create_nwbfile(stub_test=True)
Expand Down
50 changes: 47 additions & 3 deletions tests/test_modalities/test_ecephys/test_tools_spikeinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,34 @@
add_sorting_analyzer_to_nwbfile,
add_sorting_to_nwbfile,
)
from neuroconv.tools.spikeinterface.spikeinterface import (
_get_recording_segment_start_time,
)
from neuroconv.tools.spikeinterface.spikeinterfacerecordingdatachunkiterator import (
SpikeInterfaceRecordingDataChunkIterator,
)

testing_session_time = datetime.now().astimezone()


class _LegacyRecordingSegment:
def __init__(self, t_start):
self.t_start = t_start


class _LegacyRecording:
def __init__(self, t_start):
self._recording_segments = [_LegacyRecordingSegment(t_start=t_start)]


def test_get_recording_segment_start_time_legacy_fallback():
recording = _LegacyRecording(t_start=1.5)

start_time = _get_recording_segment_start_time(recording=recording, segment_index=0)

assert start_time == 1.5


class TestAddElectricalSeriesWriting(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -67,6 +88,20 @@ def test_default_values(self):
expected_data = self.test_recording_extractor.get_traces(segment_index=0)
np.testing.assert_array_almost_equal(expected_data, extracted_data)

def test_shifted_recording_uses_starting_time(self):
recording = generate_recording(
sampling_frequency=self.sampling_frequency,
num_channels=self.num_channels,
durations=self.durations,
)
recording.shift_times(2.0)

add_recording_to_nwbfile(recording=recording, nwbfile=self.nwbfile, iterator_type=None)

electrical_series = self.nwbfile.acquisition["ElectricalSeriesRaw"]
assert electrical_series.starting_time == 2.0
assert electrical_series.rate == self.sampling_frequency

def test_write_as_lfp(self):
write_as = "lfp"
add_recording_to_nwbfile(
Expand Down Expand Up @@ -1143,6 +1178,18 @@ def test_default_values(self):
expected_data = recording.get_traces(segment_index=0)
np.testing.assert_array_almost_equal(expected_data, extracted_data)

def test_shifted_recording_uses_starting_time(self):
recording = generate_recording(sampling_frequency=1.0, num_channels=3, durations=[3.0])
recording.shift_times(2.0)

nwbfile = mock_NWBFile()

add_recording_as_time_series_to_nwbfile(recording=recording, nwbfile=nwbfile, iterator_type=None)

time_series = nwbfile.acquisition["TimeSeries"]
assert time_series.starting_time == 2.0
assert time_series.rate == 1.0

def test_metadata_key(self):
"""Test that metadata_key is used to look up metadata."""
# Create a recording object for testing
Expand Down Expand Up @@ -2422,9 +2469,6 @@ def test_analyzer_channel_sliced(self):
def test_stub_recording_with_t_start():
"""Test that the _stub recording functionality does not fail when it has a start time. See issue #1355"""
recording = generate_recording(durations=[1.0])
# TODO Remove the following line once Spikeinterface 0.102.4 or higher is released
# See https://github.com/SpikeInterface/spikeinterface/pull/3940
recording._recording_segments[0].t_start = 0.0
recording.shift_times(2.0)

_stub_recording(recording=recording)
Expand Down
Loading