Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
6dd7269
Try to remove get_template_extremum_channel()
samuelgarcia Feb 6, 2026
eede722
Put main_channel_peak_sign and main_channel_peak_mode in analyzer set…
samuelgarcia Feb 6, 2026
e7a32b5
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Feb 17, 2026
7b6a109
continue the tedious refactoring
samuelgarcia Feb 17, 2026
99da9a1
Merge branch 'main' of github.com:SpikeInterface/spikeinterface into …
samuelgarcia Mar 25, 2026
84e8d37
wip peak sign remove
samuelgarcia Mar 25, 2026
413a907
Merge remote-tracking branch 'origin/main' into less_peak_sign_more_m…
chrishalcrow Jun 23, 2026
91206ac
Change Template's default for get_main_channels output format
chrishalcrow Jun 23, 2026
6e00b92
get most tests passing
chrishalcrow Jun 23, 2026
a51a24f
Make main_channel_index a property
chrishalcrow Jun 23, 2026
0b81eb0
get waveform extractor working
chrishalcrow Jun 23, 2026
49b5199
fix 3d channel stuff
chrishalcrow Jun 23, 2026
72e2ced
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 23, 2026
01e2b22
2d 3d channel locations fix
chrishalcrow Jun 23, 2026
1895255
fix more bugs
chrishalcrow Jun 23, 2026
978d490
more bug fixes
chrishalcrow Jun 24, 2026
c09f8cb
deprecation warnings and return_to_uv None in template functions
chrishalcrow Jun 24, 2026
2df4187
change main_channel_peak_sign to peak_sign
chrishalcrow Jun 24, 2026
94b4372
good grouping for make_sorting_analyzer in tests
chrishalcrow Jun 24, 2026
b61bebf
respond to Erick
chrishalcrow Jun 25, 2026
a391558
oups
chrishalcrow Jun 25, 2026
61b790b
Remove peak_sign and peak_mode from compute_sparsity (were not being
chrishalcrow Jun 25, 2026
f84f1f7
Merge branch 'main' into less_peak_sign_more_main_channel
chrishalcrow Jun 25, 2026
ec6fac3
Remove more peak_amplitude from compute_sparsity
chrishalcrow Jun 25, 2026
cf0a327
Updates after Alessio/Sam discussion
chrishalcrow Jun 29, 2026
a474c03
bug fixes and internal sorters backwards compat
chrishalcrow Jun 29, 2026
15b0d42
oups
chrishalcrow Jun 29, 2026
64460c4
skip test_output_values test for now
chrishalcrow Jun 29, 2026
347609c
make tests work with int unit_ids
chrishalcrow Jun 30, 2026
5faa024
Merge branch 'main' into less_peak_sign_more_main_channel
chrishalcrow Jun 30, 2026
d277c10
get metrics tests to pass
chrishalcrow Jun 30, 2026
a24e030
sortingcomponents and curation tests
chrishalcrow Jun 30, 2026
8a5044b
self review
chrishalcrow Jul 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/spikeinterface/benchmark/benchmark_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from spikeinterface.core.job_tools import fix_job_kwargs, split_job_kwargs
from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.template_tools import get_template_extremum_channel


class ClusteringBenchmark(Benchmark):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@

from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset
from spikeinterface.benchmark.benchmark_clustering import ClusteringStudy
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.template_tools import get_template_extremum_channel

from pathlib import Path

Expand All @@ -33,7 +31,8 @@ def test_benchmark_clustering(create_cache_folder):

# sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False)
# sorting_analyzer.compute(["random_spikes", "templates"])
extremum_channel_inds = get_template_extremum_channel(gt_analyzer, outputs="index")
extremum_channel_inds = gt_analyzer.get_main_channels(outputs="index", with_dict=True)

spikes = gt_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
peaks[dataset] = spikes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from spikeinterface.benchmark.tests.common_benchmark_testing import make_dataset
from spikeinterface.benchmark.benchmark_peak_detection import PeakDetectionStudy
from spikeinterface.core.sortinganalyzer import create_sorting_analyzer
from spikeinterface.core.template_tools import get_template_extremum_channel


@pytest.mark.skip()
Expand All @@ -30,7 +29,7 @@ def test_benchmark_peak_detection(create_cache_folder):
sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="memory", sparse=False, **job_kwargs)
sorting_analyzer.compute("random_spikes")
sorting_analyzer.compute("templates", **job_kwargs)
extremum_channel_inds = get_template_extremum_channel(sorting_analyzer, outputs="index")
extremum_channel_inds = sorting_analyzer.get_main_channels(outputs="index", with_dict=True)
spikes = gt_sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)
peaks[dataset] = spikes

Expand Down
9 changes: 6 additions & 3 deletions src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,9 +166,12 @@
# template tools
from .template_tools import (
get_template_amplitudes,
get_template_extremum_channel,
get_template_extremum_channel_peak_shift,
get_template_extremum_amplitude,
get_template_main_channel_peak_shift,
get_template_main_channel_amplitude,
# this is not needed anymore
get_template_extremum_channel, # keep for backward compatibility can be removed in 0.106
get_template_extremum_channel_peak_shift, # keep for backward compatibility can be removed in 0.106
get_template_extremum_amplitude, # keep for backward compatibility can be removed in 0.106
)


Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/analyzer_extension_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ class ComputeTemplates(AnalyzerExtension):

extension_name = "templates"
depend_on = ["random_spikes|waveforms"]
need_recording = True
need_recording = False

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you have waveforms, you do not need the Recording. Not sure how to deal with this.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could do as we do with optional dependencies, i.e., having a function that evaluates if this is needed on the fly based on othere available extensions?

use_nodepipeline = False
need_job_kwargs = True
need_backward_compatibility_on_load = True
Expand Down
8 changes: 7 additions & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ class BaseSorting(BaseExtractor):
Abstract class representing several segment several units and relative spiketrains.
"""

_main_properties = [
"main_channel_id",
]

def __init__(self, sampling_frequency: float, unit_ids: list):
BaseExtractor.__init__(self, unit_ids)
self._sampling_frequency = float(sampling_frequency)
Expand Down Expand Up @@ -913,6 +917,7 @@ def _compute_and_cache_spike_vector(self) -> None:
self._cached_spike_vector = spikes
self._cached_spike_vector_segment_slices = segment_slices

# TODO sam : change extremum_channel_inds to main_channel_index with vector
def to_spike_vector(
self,
concatenated=True,
Expand All @@ -933,7 +938,8 @@ def to_spike_vector(
extremum_channel_inds : None or dict, default: None
If a dictionnary of unit_id to channel_ind is given then an extra field "channel_index".
This can be convinient for computing spikes postion after sorter.
This dict can be computed with `get_template_extremum_channel(we, outputs="index")`
This dict can be given by analyzer.get_main_channels(outputs="index", with_dict=True)
use_cache : bool, default: True
When True the spikes vector is cached as an attribute of the object (`_cached_spike_vector`).
This caching only occurs when extremum_channel_inds=None.
Expand Down
38 changes: 30 additions & 8 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,9 @@ def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=No
"""
rng = np.random.default_rng(seed)

other_ids = np.arange(np.max(sorting.unit_ids) + 1, np.max(sorting.unit_ids) + num + 1)
other_ids = np.arange(
np.max(sorting.unit_ids.astype(int)) + 1, np.max(sorting.unit_ids.astype(int)) + num + 1
).astype(sorting.unit_ids.dtype)
shifts = rng.integers(low=-max_shift, high=max_shift, size=num)

shifts[shifts == 0] += max_shift
Expand Down Expand Up @@ -1007,12 +1009,20 @@ def inject_some_split_units(sorting, split_ids: list, num_split=2, output_ids=Fa
The dictionary with the split unit_ids. Returned only if output_ids is True.
"""
unit_ids = sorting.unit_ids
assert unit_ids.dtype.kind == "i"

m = np.max(unit_ids) + 1
unit_ids_integer = unit_ids.dtype.kind == "i"
if not unit_ids_integer:
assert np.char.isdigit(unit_ids.astype(str)).all(), "Not all Unit IDs are parsable as integers"

integer_unit_ids = unit_ids.astype(int)

m = np.max(integer_unit_ids) + 1
other_ids = {}
for unit_id in split_ids:
other_ids[unit_id] = np.arange(m, m + num_split, dtype=unit_ids.dtype)
new_ids = np.arange(m, m + num_split).astype(int)
if not unit_ids_integer:
new_ids = [str(new_id) for new_id in new_ids]
other_ids[unit_id] = new_ids
m += num_split

rng = np.random.default_rng(seed)
Expand Down Expand Up @@ -2410,6 +2420,9 @@ def generate_ground_truth_recording(
else:
num_channels = probe.get_contact_count()

nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)

if templates is None:
channel_locations = probe.contact_positions
unit_locations = generate_unit_locations(
Expand All @@ -2427,8 +2440,18 @@ def generate_ground_truth_recording(
**generate_templates_kwargs,
)
sorting.set_property("gt_unit_locations", unit_locations)
distances = np.linalg.norm(unit_locations[:, np.newaxis, :2] - channel_locations[np.newaxis, :, :2], axis=2)
main_channel_indices = np.argmin(distances, axis=1)

else:
assert templates.shape[0] == num_units
from .template_tools import _get_main_channel_from_template_array

main_channel_indices = _get_main_channel_from_template_array(
templates, peak_mode="extremum", peak_sign="both", nbefore=nbefore
)

assert (nbefore + nafter) == templates.shape[1]

if templates.ndim == 3:
upsample_vector = None
Expand All @@ -2437,10 +2460,6 @@ def generate_ground_truth_recording(
upsample_factor = templates.shape[3]
upsample_vector = rng.integers(0, upsample_factor, size=num_spikes)

nbefore = ms_to_samples(ms_before, sampling_frequency)
nafter = ms_to_samples(ms_after, sampling_frequency)
assert (nbefore + nafter) == templates.shape[1]

# construct recording
from spikeinterface.generation.noise_tools import NoiseGeneratorRecording

Expand All @@ -2466,6 +2485,9 @@ def generate_ground_truth_recording(
recording.set_channel_gains(1.0)
recording.set_channel_offsets(0.0)

main_channel_ids = recording.channel_ids[main_channel_indices]
sorting.set_property("main_channel_id", main_channel_ids)

recording.name = "GroundTruthRecording"
sorting.name = "GroundTruthSorting"

Expand Down
3 changes: 3 additions & 0 deletions src/spikeinterface/core/node_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def compute(self, traces, start_frame, end_frame, segment_index, max_margin, pea
return (local_peaks,)


# TODO sam replace extremum_channels_indices by main_channel_index


# this is not implemented yet this will be done in separted PR
class SpikeRetriever(PeakSource):
"""
Expand Down
7 changes: 7 additions & 0 deletions src/spikeinterface/core/sorting_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ def set_properties_after_merging(
default_missing_values = BaseExtractor.default_missing_property_values

for key in prop_keys:

parent_values = sorting_pre_merge.get_property(key)

# propagate keep values
Expand All @@ -511,6 +512,12 @@ def set_properties_after_merging(
if same_property_values:
# and new values only if they are all similar
new_values[new_index] = merge_values[0]
elif (not same_property_values) and key == "main_channel_id":
# Main channel id is special. For now, if there is a disagreement, we take the value of the unit
# with the most spikes. TODO: overwrite this for analyzer if templates exist.
num_spikes_per_unit = sorting_pre_merge.count_num_spikes_per_unit(unit_ids=merge_group)
max_unit_index = np.argmax(num_spikes_per_unit.values())
new_values[new_index] = merge_values[max_unit_index]
else:
if parent_values.dtype.kind not in default_missing_values:
# if the property doesn't have a default missing value and it is not the same
Expand Down
Loading
Loading