diff --git a/src/neuroconv/tools/spikeinterface/spikeinterface.py b/src/neuroconv/tools/spikeinterface/spikeinterface.py index 270a8470b..743a46cd2 100644 --- a/src/neuroconv/tools/spikeinterface/spikeinterface.py +++ b/src/neuroconv/tools/spikeinterface/spikeinterface.py @@ -333,7 +333,7 @@ def _add_recording_segment_to_nwbfile( recording_t_start = timestamps[0] else: rate = recording.get_sampling_frequency() - recording_t_start = recording._recording_segments[segment_index].t_start or 0 + recording_t_start = _get_recording_segment_start_time(recording=recording, segment_index=segment_index) if rate: starting_time = float(recording_t_start) @@ -1353,7 +1353,7 @@ def _add_time_series_segment_to_nwbfile( recording_t_start = timestamps[0] else: rate = recording.get_sampling_frequency() - recording_t_start = recording._recording_segments[segment_index].t_start or 0 + recording_t_start = _get_recording_segment_start_time(recording=recording, segment_index=segment_index) if rate: starting_time = float(recording_t_start) @@ -2828,3 +2828,31 @@ def _stub_recording(recording: BaseRecording, *, stub_samples: int = 100) -> Bas recording_stubbed = AppendSegmentRecording(recording_list=recording_segments_stubbed) return recording_stubbed + + +def _get_recording_segment_start_time(recording: BaseRecording, segment_index: int) -> float: + if hasattr(recording, "get_start_time"): + start_time = recording.get_start_time(segment_index=segment_index) + return 0.0 if start_time is None else float(start_time) + + segments = None + if hasattr(recording, "segments"): + segments = recording.segments + elif hasattr(recording, "_recording_segments"): + segments = recording._recording_segments + elif hasattr(recording, "_segments"): + segments = recording._segments + + if segments is None: + return 0.0 + + segment = segments[segment_index] + if hasattr(segment, "get_start_time"): + start_time = segment.get_start_time() + return 0.0 if start_time is None else float(start_time) + if hasattr(segment, "t_start"): + return 0.0 if segment.t_start is None else float(segment.t_start) + if hasattr(segment, "_t_start"): + return 0.0 if segment._t_start is None else float(segment._t_start) + + return 0.0 diff --git a/tests/test_modalities/test_ecephys/test_ecephys_interfaces.py b/tests/test_modalities/test_ecephys/test_ecephys_interfaces.py index 352956e58..2b85c5013 100644 --- a/tests/test_modalities/test_ecephys/test_ecephys_interfaces.py +++ b/tests/test_modalities/test_ecephys/test_ecephys_interfaces.py @@ -216,9 +216,6 @@ def test_stub_with_starting_time(self, setup_interface): interface = MockRecordingInterface(durations=[1.0]) recording = interface.recording_extractor - # TODO Remove the following line once Spikeinterface 0.102.4 or higher is released - # See https://github.com/SpikeInterface/spikeinterface/pull/3940 - recording._recording_segments[0].t_start = 0.0 recording.shift_times(2.0) interface.create_nwbfile(stub_test=True) diff --git a/tests/test_modalities/test_ecephys/test_tools_spikeinterface.py b/tests/test_modalities/test_ecephys/test_tools_spikeinterface.py index 20aa8a8d5..f9d224516 100644 --- a/tests/test_modalities/test_ecephys/test_tools_spikeinterface.py +++ b/tests/test_modalities/test_ecephys/test_tools_spikeinterface.py @@ -32,6 +32,9 @@ add_sorting_analyzer_to_nwbfile, add_sorting_to_nwbfile, ) +from neuroconv.tools.spikeinterface.spikeinterface import ( + _get_recording_segment_start_time, +) from neuroconv.tools.spikeinterface.spikeinterfacerecordingdatachunkiterator import ( SpikeInterfaceRecordingDataChunkIterator, ) @@ -39,6 +42,24 @@ testing_session_time = datetime.now().astimezone() +class _LegacyRecordingSegment: + def __init__(self, t_start): + self.t_start = t_start + + +class _LegacyRecording: + def __init__(self, t_start): + self._recording_segments = [_LegacyRecordingSegment(t_start=t_start)] + + +def test_get_recording_segment_start_time_legacy_fallback(): + recording = _LegacyRecording(t_start=1.5) + + start_time = _get_recording_segment_start_time(recording=recording, segment_index=0) + + assert start_time == 1.5 + + class TestAddElectricalSeriesWriting(unittest.TestCase): @classmethod def setUpClass(cls): @@ -67,6 +88,20 @@ def test_default_values(self): expected_data = self.test_recording_extractor.get_traces(segment_index=0) np.testing.assert_array_almost_equal(expected_data, extracted_data) + def test_shifted_recording_uses_starting_time(self): + recording = generate_recording( + sampling_frequency=self.sampling_frequency, + num_channels=self.num_channels, + durations=self.durations, + ) + recording.shift_times(2.0) + + add_recording_to_nwbfile(recording=recording, nwbfile=self.nwbfile, iterator_type=None) + + electrical_series = self.nwbfile.acquisition["ElectricalSeriesRaw"] + assert electrical_series.starting_time == 2.0 + assert electrical_series.rate == self.sampling_frequency + def test_write_as_lfp(self): write_as = "lfp" add_recording_to_nwbfile( @@ -1143,6 +1178,18 @@ def test_default_values(self): expected_data = recording.get_traces(segment_index=0) np.testing.assert_array_almost_equal(expected_data, extracted_data) + def test_shifted_recording_uses_starting_time(self): + recording = generate_recording(sampling_frequency=1.0, num_channels=3, durations=[3.0]) + recording.shift_times(2.0) + + nwbfile = mock_NWBFile() + + add_recording_as_time_series_to_nwbfile(recording=recording, nwbfile=nwbfile, iterator_type=None) + + time_series = nwbfile.acquisition["TimeSeries"] + assert time_series.starting_time == 2.0 + assert time_series.rate == 1.0 + def test_metadata_key(self): """Test that metadata_key is used to look up metadata.""" # Create a recording object for testing @@ -2422,9 +2469,6 @@ def test_analyzer_channel_sliced(self): def test_stub_recording_with_t_start(): """Test that the _stub recording functionality does not fail when it has a start time. See issue #1355""" recording = generate_recording(durations=[1.0]) - # TODO Remove the following line once Spikeinterface 0.102.4 or higher is released - # See https://github.com/SpikeInterface/spikeinterface/pull/3940 - recording._recording_segments[0].t_start = 0.0 recording.shift_times(2.0) _stub_recording(recording=recording)