-
Notifications
You must be signed in to change notification settings - Fork 34
Fix recording segment start time retrieval for spikeinterface>=0.104.1
#1712
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -333,7 +333,7 @@ def _add_recording_segment_to_nwbfile( | |
| recording_t_start = timestamps[0] | ||
| else: | ||
| rate = recording.get_sampling_frequency() | ||
| recording_t_start = recording._recording_segments[segment_index].t_start or 0 | ||
| recording_t_start = _get_recording_segment_start_time(recording=recording, segment_index=segment_index) | ||
|
|
||
| if rate: | ||
| starting_time = float(recording_t_start) | ||
|
|
@@ -1353,7 +1353,7 @@ def _add_time_series_segment_to_nwbfile( | |
| recording_t_start = timestamps[0] | ||
| else: | ||
| rate = recording.get_sampling_frequency() | ||
| recording_t_start = recording._recording_segments[segment_index].t_start or 0 | ||
| recording_t_start = _get_recording_segment_start_time(recording=recording, segment_index=segment_index) | ||
|
|
||
| if rate: | ||
| starting_time = float(recording_t_start) | ||
|
|
@@ -2828,3 +2828,29 @@ def _stub_recording(recording: BaseRecording, *, stub_samples: int = 100) -> Bas | |
| recording_stubbed = AppendSegmentRecording(recording_list=recording_segments_stubbed) | ||
|
|
||
| return recording_stubbed | ||
|
|
||
|
|
||
| def _get_recording_segment_start_time(recording: BaseRecording, segment_index: int) -> float: | ||
| if hasattr(recording, "get_start_time"): | ||
| return float(recording.get_start_time(segment_index=segment_index)) | ||
|
|
||
| segments = None | ||
| if hasattr(recording, "segments"): | ||
| segments = recording.segments | ||
| elif hasattr(recording, "_recording_segments"): | ||
| segments = recording._recording_segments | ||
| elif hasattr(recording, "_segments"): | ||
| segments = recording._segments | ||
|
|
||
| if segments is None: | ||
| return 0.0 | ||
|
|
||
| segment = segments[segment_index] | ||
| if hasattr(segment, "get_start_time"): | ||
| return float(segment.get_start_time()) | ||
| if hasattr(segment, "t_start"): | ||
|
Comment on lines
+2849
to
+2853
|
||
| return 0.0 if segment.t_start is None else float(segment.t_start) | ||
| if hasattr(segment, "_t_start"): | ||
| return 0.0 if segment._t_start is None else float(segment._t_start) | ||
|
|
||
| return 0.0 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,13 +32,26 @@ | |
| add_sorting_analyzer_to_nwbfile, | ||
| add_sorting_to_nwbfile, | ||
| ) | ||
| from neuroconv.tools.spikeinterface.spikeinterface import ( | ||
| _get_recording_segment_start_time, | ||
| ) | ||
| from neuroconv.tools.spikeinterface.spikeinterfacerecordingdatachunkiterator import ( | ||
| SpikeInterfaceRecordingDataChunkIterator, | ||
| ) | ||
|
|
||
| testing_session_time = datetime.now().astimezone() | ||
|
|
||
|
|
||
| class _LegacyRecordingSegment: | ||
| def __init__(self, t_start): | ||
| self.t_start = t_start | ||
|
|
||
|
|
||
| class _LegacyRecording: | ||
| def __init__(self, t_start): | ||
| self._recording_segments = [_LegacyRecordingSegment(t_start=t_start)] | ||
|
|
||
|
|
||
| class TestAddElectricalSeriesWriting(unittest.TestCase): | ||
| @classmethod | ||
| def setUpClass(cls): | ||
|
|
@@ -67,6 +80,28 @@ def test_default_values(self): | |
| expected_data = self.test_recording_extractor.get_traces(segment_index=0) | ||
| np.testing.assert_array_almost_equal(expected_data, extracted_data) | ||
|
|
||
| def test_shifted_recording_uses_starting_time(self): | ||
| recording = generate_recording( | ||
| sampling_frequency=self.sampling_frequency, | ||
| num_channels=self.num_channels, | ||
| durations=self.durations, | ||
| ) | ||
| recording.shift_times(2.0) | ||
|
|
||
| add_recording_to_nwbfile(recording=recording, nwbfile=self.nwbfile, iterator_type=None) | ||
|
|
||
| electrical_series = self.nwbfile.acquisition["ElectricalSeriesRaw"] | ||
| assert electrical_series.starting_time == 2.0 | ||
| assert electrical_series.rate == self.sampling_frequency | ||
|
|
||
|
|
||
| def test_get_recording_segment_start_time_legacy_fallback(): | ||
| recording = _LegacyRecording(t_start=1.5) | ||
|
|
||
| start_time = _get_recording_segment_start_time(recording=recording, segment_index=0) | ||
|
|
||
| assert start_time == 1.5 | ||
|
|
||
| def test_write_as_lfp(self): | ||
| write_as = "lfp" | ||
|
||
| add_recording_to_nwbfile( | ||
|
|
@@ -1143,6 +1178,18 @@ def test_default_values(self): | |
| expected_data = recording.get_traces(segment_index=0) | ||
| np.testing.assert_array_almost_equal(expected_data, extracted_data) | ||
|
|
||
| def test_shifted_recording_uses_starting_time(self): | ||
| recording = generate_recording(sampling_frequency=1.0, num_channels=3, durations=[3.0]) | ||
| recording.shift_times(2.0) | ||
|
|
||
| nwbfile = mock_NWBFile() | ||
|
|
||
| add_recording_as_time_series_to_nwbfile(recording=recording, nwbfile=nwbfile, iterator_type=None) | ||
|
|
||
| time_series = nwbfile.acquisition["TimeSeries"] | ||
| assert time_series.starting_time == 2.0 | ||
| assert time_series.rate == 1.0 | ||
|
|
||
| def test_metadata_key(self): | ||
| """Test that metadata_key is used to look up metadata.""" | ||
| # Create a recording object for testing | ||
|
|
@@ -2422,9 +2469,6 @@ def test_analyzer_channel_sliced(self): | |
| def test_stub_recording_with_t_start(): | ||
| """Test that the _stub recording functionality does not fail when it has a start time. See issue #1355""" | ||
| recording = generate_recording(durations=[1.0]) | ||
| # TODO Remove the following line once Spikeinterface 0.102.4 or higher is released | ||
| # See https://github.com/SpikeInterface/spikeinterface/pull/3940 | ||
| recording._recording_segments[0].t_start = 0.0 | ||
| recording.shift_times(2.0) | ||
|
|
||
| _stub_recording(recording=recording) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
recording.get_start_time(...)can legitimately returnNonewhen the recording has no explicit start time. Callingfloat(None)will raiseTypeErrorand reintroduce the original failure mode that... or 0previously avoided. TreatNoneas0.0before converting to float.