Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/spikeinterface/benchmark/benchmark_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.template_tools import get_template_extremum_channel


class ClusteringBenchmark(Benchmark):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset
from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.template_tools import get_template_extremum_channel

from pathlib import Path

Expand All @@ -33,7 +31,8 @@ def test_benchmark_clustering(create_cache_folder):

# sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False)
# sorting_analyzer.compute(["random_spikes", "templates"])
extremum_channel_inds = get_template_extremum_channel(gt_analyzer, outputs="index")
extremum_channel_inds = gt_analyzer.get_main_channels(outputs="index", with_dict=True)

spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
peaks[dataset] = spikes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset
from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.template_tools import get_template_extremum_channel


@pytest.mark.skip()
Expand All @@ -30,7 +29,7 @@ def test_benchmark_peak_detection(create_cache_folder):
sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs)
sorting_analyzer.compute("random_spikes")
sorting_analyzer.compute("templates", **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index")
extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
peaks[dataset] = spikes

Expand Down
7 changes: 6 additions & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ class BaseSorting(BaseExtractor):
"""
Abstract class representing several segment several units and relative spiketrains.
"""
_main_properties = [
"main_channel_index",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"main_channel_index",
"main_channel_id",

Use channel_id instead

]

def __init__(self, sampling_frequency: float, unit_ids: list):
BaseExtractor.__init__(self, unit_ids)
Expand Down Expand Up @@ -786,6 +789,7 @@ def _compute_and_cache_spike_vector(self) -> None:
self._cached_spike_vector = spikes
self._cached_spike_vector_segment_slices = segment_slices

# TODO sam : change extremum_channel_inds to main_channel_index with vector
def to_spike_vector(
self,
concatenated=True,
Expand All @@ -806,7 +810,8 @@ def to_spike_vector(
extremum_channel_inds : None or dict, default: None
If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index".
This can be convinient for computing spikes postion after sorter.
This dict can be computed with `get_template_extremum_channel(we, outputs="index")`
This dict can be given by analyzer.get_main_channels(outputs="index", with_dict=True)

use_cache : bool, default: True
When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`).
This caching only occurs when extremum_channel_inds=None.
Expand Down
4 changes: 4 additions & 0 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2444,6 +2444,10 @@ def generate_ground_truth_recording(
**generate_templates_kwargs,
)
sorting.set_property("gt_unit_locations", unit_locations)
distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :], axis=2)
main_channel_index = np.argmin(distances, axis=1)
sorting.set_property("main_channel_index", main_channel_index)

else:
assert templates.shape[0] == num_units

Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea
return (local_peaks,)


# TODO sam replace extremum_channels_indices by main_channel_index

# this is not implemented yet this will be done in separted PR
class SpikeRetriever(PeakSource):
"""
Expand Down
Loading
Loading