diff --git a/.github/scripts/create_probe_compat_fixtures.py b/.github/scripts/create_probe_compat_fixtures.py new file mode 100644 index 0000000000..52f6746584 --- /dev/null +++ b/.github/scripts/create_probe_compat_fixtures.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python +""" +Creates probe compatibility fixtures using the *currently installed* spikeinterface. + +Run this script with spikeinterface==0.104.* installed to produce the fixture +files consumed by test_probe_backward_compat.py: + + python create_probe_compat_fixtures.py [output_dir] + +If output_dir is omitted, fixtures are written to ./probe_compat_fixtures. + +Note: we use `in_place=True` since a bug (fixed in #4300) prevented probes_info to be properly +saved as annotations in the probegroup when using `in_place=False` in spikeinterface 0.104.*. +""" + +import sys +import shutil +import numpy as np +from pathlib import Path + +import spikeinterface + +print(f"Creating fixtures with spikeinterface {spikeinterface.__version__}") + +OUTPUT_DIR = Path(sys.argv[1]) if len(sys.argv) > 1 else Path("probe_compat_fixtures") +if OUTPUT_DIR.exists(): + shutil.rmtree(OUTPUT_DIR) +OUTPUT_DIR.mkdir(parents=True) + +from probeinterface import generate_linear_probe, ProbeGroup +from spikeinterface.core import NumpyRecording + +# ----------------------------------------------------------------------- +# Fixture 1: single probe, sequential device_channel_indices +# ----------------------------------------------------------------------- +n = 8 +probe = generate_linear_probe(num_elec=n, ypitch=20.0) +probe.annotate(name="test_probe", manufacturer="test_vendor") +probe.set_contact_ids([f"e{i}" for i in range(n)]) +probe.set_device_channel_indices(np.arange(n)) +probe.create_auto_shape() + +traces = np.arange(1000 * n, dtype="int16").reshape(1000, n) +rec_single = NumpyRecording([traces], sampling_frequency=30000.0) +rec_single.set_probe(probe, in_place=True) + +rec_single_bin = rec_single.save(folder=str(OUTPUT_DIR / "single_probe_binary")) +rec_single_zarr = rec_single.save(folder=str(OUTPUT_DIR / "single_probe.zarr"), format="zarr") +rec_single_bin.dump_to_json(str(OUTPUT_DIR / "single_probe.json")) +rec_single_bin.dump_to_pickle(str(OUTPUT_DIR / "single_probe.pkl")) + +# ----------------------------------------------------------------------- +# Fixture 2: two probes with per-probe name/manufacturer +# ----------------------------------------------------------------------- +n_A, n_B = 8, 8 +probe_A = generate_linear_probe(num_elec=n_A, ypitch=20.0) +probe_A.move([0.0, 0.0]) +probe_A.annotate(name="probe_A", manufacturer="vendor_X") +probe_A.set_contact_ids([f"a{i}" for i in range(n_A)]) +probe_A.set_device_channel_indices(np.arange(n_A)) +probe_A.create_auto_shape() + +probe_B = generate_linear_probe(num_elec=n_B, ypitch=20.0) +probe_B.move([500.0, 0.0]) +probe_B.annotate(name="probe_B", manufacturer="vendor_Y") +probe_B.set_contact_ids([f"b{i}" for i in range(n_B)]) +probe_B.set_device_channel_indices(np.arange(n_A, n_A + n_B)) +probe_B.create_auto_shape() + +pg = ProbeGroup() +pg.add_probe(probe_A) +pg.add_probe(probe_B) + +n_total = n_A + n_B +traces2 = np.arange(1000 * n_total, dtype="int16").reshape(1000, n_total) +rec_two = NumpyRecording([traces2], sampling_frequency=30000.0) +rec_two.set_probegroup(pg, in_place=True) + +rec_two_bin = rec_two.save(folder=str(OUTPUT_DIR / "two_probe_binary")) +rec_two_zarr = rec_two.save(folder=str(OUTPUT_DIR / "two_probe.zarr"), format="zarr") +rec_two_bin.dump_to_json(str(OUTPUT_DIR / "two_probe.json")) +rec_two_bin.dump_to_pickle(str(OUTPUT_DIR / "two_probe.pkl")) + +# ----------------------------------------------------------------------- +# Fixture 3: probe with shuffled device_channel_indices +# Verifies that the channel-reordering logic is preserved across versions. +# ----------------------------------------------------------------------- +n = 8 +probe_sh = generate_linear_probe(num_elec=n, ypitch=20.0) +probe_sh.annotate(name="shuffled_probe", manufacturer="shuffle_vendor") +shuffled_dci = np.array([3, 0, 7, 1, 5, 2, 6, 4]) # permutation of 0..7 +probe_sh.set_device_channel_indices(shuffled_dci) + +# traces[:, j] corresponds to recording channel j, which after set_probe +# is mapped to the contact whose dci equals j. +traces3 = np.arange(1000 * n, dtype="int16").reshape(1000, n) +rec_sh = NumpyRecording([traces3], sampling_frequency=30000.0) +rec_sh.set_probe(probe_sh, in_place=True) + +rec_sh_bin = rec_sh.save(folder=str(OUTPUT_DIR / "shuffled_probe_binary")) +rec_sh_zarr = rec_sh.save(folder=str(OUTPUT_DIR / "shuffled_probe.zarr"), format="zarr") +rec_sh_bin.dump_to_json(str(OUTPUT_DIR / "shuffled_probe.json")) +rec_sh_bin.dump_to_pickle(str(OUTPUT_DIR / "shuffled_probe.pkl")) + +print(f"Fixtures written to: {OUTPUT_DIR.resolve()}") + +# ----------------------------------------------------------------------- +# Fixture 4: two probes with interleaved device_channel_indices +# ----------------------------------------------------------------------- +n = 8 +probe_A = generate_linear_probe(num_elec=n, ypitch=20.0) +probe_A.move([0.0, 0.0]) +probe_A.annotate(name="probe_A", manufacturer="vendor_X") +probe_A.set_contact_ids([f"a{i}" for i in range(n)]) +probe_A.set_device_channel_indices(np.arange(0, 2 * n, 2)) # even indices +probe_A.create_auto_shape() + +probe_B = generate_linear_probe(num_elec=n, ypitch=20.0) +probe_B.move([500.0, 0.0]) +probe_B.annotate(name="probe_B", manufacturer="vendor_Y") +probe_B.set_contact_ids([f"b{i}" for i in range(n)]) +probe_B.set_device_channel_indices(np.arange(1, 2 * n, 2)) # odd indices +probe_B.create_auto_shape() + +pg = ProbeGroup() +pg.add_probe(probe_A) +pg.add_probe(probe_B) + +n_total = 2 * n +traces2 = np.arange(1000 * n_total, dtype="int16").reshape(1000, n_total) +rec_two_inter = NumpyRecording([traces2], sampling_frequency=30000.0) +rec_two_inter.set_probegroup(pg, in_place=True) + +rec_two_inter_bin = rec_two_inter.save(folder=str(OUTPUT_DIR / "two_probe_interleaved_binary")) +rec_two_inter_zarr = rec_two_inter.save(folder=str(OUTPUT_DIR / "two_probe_interleaved.zarr"), format="zarr") +rec_two_inter_bin.dump_to_json(str(OUTPUT_DIR / "two_probe_interleaved.json")) +rec_two_inter_bin.dump_to_pickle(str(OUTPUT_DIR / "two_probe_interleaved.pkl")) diff --git a/.github/workflows/probe_backward_compat.yml b/.github/workflows/probe_backward_compat.yml new file mode 100644 index 0000000000..3ac134be72 --- /dev/null +++ b/.github/workflows/probe_backward_compat.yml @@ -0,0 +1,58 @@ +name: Probe backward compatibility + +on: + pull_request: + types: [synchronize, opened, reopened] + branches: + - main + paths: + - 'src/spikeinterface/core/**' + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + probe-backward-compat: + name: Probe compat (SI ${{ matrix.si-version }} → current) + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + si-version: + - '0.102.*' + - '0.103.*' + - '0.104.*' + env: + SI_PROBE_COMPAT_FIXTURES_DIR: ${{ github.workspace }}/probe_compat_fixtures + + steps: + - name: Check out code + uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.11' + + - name: Set up uv + uses: astral-sh/setup-uv@v7 + with: + python-version: '3.11' + enable-cache: false + + # Step 1: install the OLD release and create fixtures. + # The fixture script uses the old in_place=False default (returns a new recording), + # saves to binary folder + JSON, and writes a known probe name/manufacturer/contact_ids. + - name: Install spikeinterface ${{ matrix.si-version }} to create fixtures + run: uv pip install --system "spikeinterface[core]==${{ matrix.si-version }}" + + - name: Create compatibility fixtures with old version + run: python .github/scripts/create_probe_compat_fixtures.py "$SI_PROBE_COMPAT_FIXTURES_DIR" + + # Step 2: install the NEW version from this PR source and run the load tests. + - name: Install new spikeinterface from source + run: uv pip install --system -e . --group test-core + + - name: Run backward compatibility tests + run: pytest src/spikeinterface/core/tests/test_probe_backward_compat.py -v diff --git a/doc/api.rst b/doc/api.rst index fc55017606..1bc8156aef 100755 --- a/doc/api.rst +++ b/doc/api.rst @@ -16,6 +16,12 @@ spikeinterface.core .. automethod:: BaseRecording.dump_to_json .. automethod:: BaseRecording.dump_to_pickle .. automethod:: BaseRecording.remove_channels + .. automethod:: BaseRecording.set_probe + .. automethod:: BaseRecording.set_probegroup + .. automethod:: BaseRecording.reset_probe + .. automethod:: BaseRecording.select_channels_with_probe + .. automethod:: BaseRecording.select_channels_with_probegroup + .. automethod:: BaseRecording.split_by .. autoclass:: BaseSorting :members: .. automethod:: BaseSorting.save @@ -25,6 +31,8 @@ spikeinterface.core .. automethod:: BaseSorting.dump .. automethod:: BaseSorting.dump_to_json .. automethod:: BaseSorting.dump_to_pickle + .. automethod:: BaseSorting.split_by + .. automethod:: BaseSorting.register_recording .. autoclass:: BaseSnippets :members: .. automethod:: BaseSnippets.save diff --git a/doc/get_started/quickstart.rst b/doc/get_started/quickstart.rst index b5d3c2a985..3fbd113ba7 100644 --- a/doc/get_started/quickstart.rst +++ b/doc/get_started/quickstart.rst @@ -198,8 +198,9 @@ to set it *manually*. If your recording does not have a ``Probe``, you can set it using -``set_probe``. Note: ``set_probe`` creates a copy of the recording with -the new probe, rather than modifying the existing recording in place. +``set_probe``. Note: ``set_probe`` modifies the recording in place. To +get a new recording object with a subset of channels attached to a probe, +use ``select_channels_with_probe``. There is more information `here `__. diff --git a/doc/modules/core.rst b/doc/modules/core.rst index 681542368b..6ea3d25eb6 100644 --- a/doc/modules/core.rst +++ b/doc/modules/core.rst @@ -522,12 +522,18 @@ The probe has 4 shanks, which can be loaded as separate groups (and spike sorted # add wiring probe.wiring_to_device('ASSY-156>RHD2164') - # set probe - recording_w_probe = recording.set_probe(probe) - # set probe with group info and return a new recording object - recording_w_probe = recording.set_probe(probe, group_mode="by_shank") - # set probe in place, ie, modify the current recording - recording.set_probe(probe, group_mode="by_shank", in_place=True) + # set probe (modifies the recording in place) + recording.set_probe(probe) + # set probe with group info derived from shank ids (in place) + recording.set_probe(probe, group_mode="by_shank") + + # to get a *new* recording without modifying the original, use select_channels_with_probe + recording_w_probe = recording.select_channels_with_probe(probe) + recording_w_probe = recording.select_channels_with_probe(probe, group_mode="by_shank") + + # multi-probe recordings use set_probegroup / select_channels_with_probegroup + recording.set_probegroup(probegroup) + recording_w_probegroup = recording.select_channels_with_probegroup(probegroup) # retrieve probe probe_from_recording = recording.get_probe() diff --git a/examples/forhowto/plot_working_with_tetrodes.py b/examples/forhowto/plot_working_with_tetrodes.py index 0c652a5186..547e9deae1 100644 --- a/examples/forhowto/plot_working_with_tetrodes.py +++ b/examples/forhowto/plot_working_with_tetrodes.py @@ -62,15 +62,15 @@ # We can now attach the :code:`tetrode_group` to our recording. To check if this worked, we'll # plot the probe map -recording_with_probe = recording.set_probegroup(tetrode_group) -plot_probe_map(recording_with_probe) +recording.set_probegroup(tetrode_group) +plot_probe_map(recording) ############################################################################## # Looks good! Now that the recording is aware of the probe geometry, we can # begin a standard spike sorting pipeline. First, we can apply preprocessing. # Note that we apply this preprocessing on the entire bundle of tetrodes. -preprocessed_recording = spre.bandpass_filter(recording_with_probe) +preprocessed_recording = spre.bandpass_filter(recording) ############################################################################## # WARNING: a very common preprocessing step is to apply a common median diff --git a/examples/get_started/quickstart.py b/examples/get_started/quickstart.py index 2481f8569f..75d5c8d63a 100644 --- a/examples/get_started/quickstart.py +++ b/examples/get_started/quickstart.py @@ -137,8 +137,8 @@ # - # If your recording does not have a `Probe`, you can set it using `set_probe`. -# Note: `set_probe` creates a copy of the recording with the new probe, -# rather than modifying the existing recording in place. +# Note: `set_probe` modifies the recording in place. To get a new recording +# object with a subset of channels attached to a probe, use `select_channels_with_probe`. # There is more information [here](https://spikeinterface.readthedocs.io/en/latest/modules_gallery/core/plot_3_handle_probe_info.html). # Using the `spikeinterface.preprocessing` module, you can perform preprocessing on the recordings. diff --git a/examples/tutorials/core/plot_1_recording_extractor.py b/examples/tutorials/core/plot_1_recording_extractor.py index e3bfda5855..477ba165b6 100644 --- a/examples/tutorials/core/plot_1_recording_extractor.py +++ b/examples/tutorials/core/plot_1_recording_extractor.py @@ -70,7 +70,7 @@ probe.set_device_channel_indices(np.arange(7)) # then we need to actually set the probe to the recording object -recording = recording.set_probe(probe) +recording.set_probe(probe) plot_probe(probe) ############################################################################## diff --git a/examples/tutorials/core/plot_3_handle_probe_info.py b/examples/tutorials/core/plot_3_handle_probe_info.py index deff58ebb7..28d2af655a 100644 --- a/examples/tutorials/core/plot_3_handle_probe_info.py +++ b/examples/tutorials/core/plot_3_handle_probe_info.py @@ -43,8 +43,8 @@ print(other_probe) other_probe.set_device_channel_indices(np.arange(32)) -recording_2_shanks = recording.set_probe(other_probe, group_mode="by_shank") -plot_probe(recording_2_shanks.get_probe()) +recording.set_probe(other_probe, group_mode="by_shank") +plot_probe(recording.get_probe()) ############################################################################### # Now let's check what we have loaded. The :code:`group_mode='by_shank'` automatically @@ -53,11 +53,11 @@ # We can access this information either as a dict with :code:`outputs='dict'` (default) # or as a list of recordings with :code:`outputs='list'`. -print(recording_2_shanks) -print(f'\nGroup Property: {recording_2_shanks.get_property("group")}\n') +print(recording) +print(f'\nGroup Property: {recording.get_property("group")}\n') # Here we split as a dict -sub_recording_dict = recording_2_shanks.split_by(property="group", outputs='dict') +sub_recording_dict = recording.split_by(property="group", outputs='dict') # Then we can pull out the individual sub-recordings sub_rec0 = sub_recording_dict[0] diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index fcbafdb6bf..8be962ec6f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -220,6 +220,14 @@ def id_to_index(self, id) -> int: return ind def annotate(self, **new_annotations) -> None: + """Adds annotations. + + Parameters + ---------- + **new_annotations : dict + Key-value pairs of annotations to add. If an annotation key already exists, + it will be overwritten. + """ self._annotations.update(new_annotations) def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> None: @@ -243,6 +251,24 @@ def set_annotation(self, annotation_key: str, value: Any, overwrite=False) -> No else: raise ValueError(f"{annotation_key} is already an annotation key. Use 'overwrite=True' to overwrite it") + def delete_annotation(self, annotation_key: str) -> None: + """Deletes existing annotation. + + Parameters + ---------- + annotation_key : str + The annotation key to delete + + Raises + ------ + ValueError + If the annotation key does not exist + """ + if annotation_key in self._annotations.keys(): + del self._annotations[annotation_key] + else: + raise ValueError(f"{annotation_key} is not an annotation key") + def get_preferred_mp_context(self): """ Get the preferred context for multiprocessing. @@ -441,6 +467,15 @@ def copy_metadata( if self._preferred_mp_context is not None: other._preferred_mp_context = self._preferred_mp_context + if not only_main: + self._extra_metadata_copy(other) + + def _extra_metadata_copy(self, other: "BaseExtractor") -> None: + """ + This is a hook to copy extra metadata that is not in the annotations/properties dict. + """ + pass + def to_dict( self, include_annotations: bool = False, @@ -574,6 +609,8 @@ def to_dict( folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to) dump_dict["folder_metadata"] = str(folder_metadata) + self._extra_metadata_to_dict(dump_dict) + return dump_dict @staticmethod @@ -610,8 +647,6 @@ def load_metadata_from_folder(self, folder_metadata): # hack to load probe for recording folder_metadata = Path(folder_metadata) - self._extra_metadata_from_folder(folder_metadata) - # load properties prop_folder = folder_metadata / "properties" if prop_folder.is_dir(): @@ -621,6 +656,8 @@ def load_metadata_from_folder(self, folder_metadata): key = prop_file.stem self.set_property(key, values) + self._extra_metadata_from_folder(folder_metadata) + def save_metadata_to_folder(self, folder_metadata): self._extra_metadata_to_folder(folder_metadata) @@ -862,6 +899,14 @@ def _extra_metadata_to_folder(self, folder): # This implemented in BaseRecording for probe pass + def _extra_metadata_from_dict(self, dump_dict): + # This implemented in BaseRecording for probe + pass + + def _extra_metadata_to_dict(self, dump_dict): + # This implemented in BaseRecording for probe + pass + def save(self, **kwargs) -> "BaseExtractor": """ Save a SpikeInterface object. @@ -997,10 +1042,10 @@ def save_to_folder( else: warnings.warn("The extractor is not serializable to file. The provenance will not be saved.") - self.save_metadata_to_folder(folder) - # save data (done the subclass) + self.save_metadata_to_folder(folder) cached = self._save(folder=folder, verbose=verbose, **save_kwargs) + cached.load_metadata_from_folder(folder) # copy properties/ self.copy_metadata(cached) @@ -1145,8 +1190,8 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor": assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class" is_old_version = not _check_same_version(class_name, dic["version"]) - if is_old_version and hasattr(extractor_class, "_handle_backward_compatibility"): - new_kwargs = extractor_class._handle_backward_compatibility(new_kwargs, dic) + if is_old_version and hasattr(extractor_class, "_handle_kwargs_backward_compatibility"): + new_kwargs = extractor_class._handle_kwargs_backward_compatibility(new_kwargs, dic) # Initialize the extractor extractor = extractor_class(**new_kwargs) @@ -1155,6 +1200,10 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor": for k, v in dic["properties"].items(): extractor.set_property(k, v) + extractor._extra_metadata_from_dict(dic) + if hasattr(extractor, "_handle_extractor_backward_compatibility"): + extractor._handle_extractor_backward_compatibility() + return extractor diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index b0f75930d3..c61d602026 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -20,7 +20,6 @@ class BaseRecording(BaseRecordingSnippets, TimeSeries): _main_annotations = BaseRecordingSnippets._main_annotations + ["is_filtered"] _main_properties = [ "group", - "location", "gain_to_uV", "offset_to_uV", "gain_to_physical_unit", @@ -324,6 +323,8 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): if format == "binary": from .time_series_tools import write_binary + from .binaryrecordingextractor import BinaryRecordingExtractor + from .binaryfolder import BinaryFolderRecording folder = kwargs["folder"] file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] @@ -332,8 +333,6 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): write_binary(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) - from .binaryrecordingextractor import BinaryRecordingExtractor - # This is created so it can be saved as json because the `BinaryFolderRecording` requires it loading # See the __init__ of `BinaryFolderRecording` @@ -351,9 +350,6 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): offset_to_uV=self.get_channel_offsets(), ) binary_rec.dump(folder / "binary.json", relative_to=folder) - - from .binaryfolder import BinaryFolderRecording - cached = BinaryFolderRecording(folder_path=folder) # timestamps are not saved in binary, so we have to set them explicitly @@ -389,7 +385,6 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): self, zarr_path, storage_options, verbose=verbose, **kwargs, **job_kwargs ) cached = ZarrRecordingExtractor(zarr_path, storage_options) - # timestamps are saved and restored in zarr, so no need to set them explicitly elif format == "nwb": @@ -399,18 +394,11 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: - probegroup = self.get_probegroup() - cached.set_probegroup(probegroup) - return cached def _extra_metadata_from_folder(self, folder): # load probe - folder = Path(folder) - if (folder / "probe.json").is_file(): - probegroup = read_probeinterface(folder / "probe.json") - self.set_probegroup(probegroup, in_place=True) + super()._extra_metadata_from_folder(folder) # load time vector if any for segment_index, rs in enumerate(self.segments): @@ -420,10 +408,7 @@ def _extra_metadata_from_folder(self, folder): rs.time_vector = time_vector def _extra_metadata_to_folder(self, folder): - # save probe - if self.get_property("contact_vector") is not None: - probegroup = self.get_probegroup() - write_probeinterface(folder / "probe.json", probegroup) + super()._extra_metadata_to_folder(folder) # save time vector if any for segment_index, rs in enumerate(self.segments): diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 58e91ec35c..5459203df6 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,11 +1,12 @@ from pathlib import Path - +from typing import Literal +import warnings import numpy as np from probeinterface import Probe, ProbeGroup, write_probeinterface, read_probeinterface, select_axes from .base import BaseExtractor -from .recording_tools import check_probe_do_not_overlap +from .recording_tools import _set_group_property_based_on_probegroup, check_probe_do_not_overlap from warnings import warn @@ -19,6 +20,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = float(sampling_frequency) self._dtype = np.dtype(dtype) + self._probegroup = None @property def channel_ids(self): @@ -51,198 +53,258 @@ def has_scaleable_traces(self) -> bool: return True def has_probe(self) -> bool: - return "contact_vector" in self.get_property_keys() + # probe group is saved and loaded to binary/zarr, so we don't need to check for legacy "contact_vector" property + return self._probegroup is not None + + def has_3d_probe(self) -> bool: + if self.has_probe(): + probe = self.get_probegroup().probes[0] + return probe.ndim == 3 + else: + return False def has_channel_location(self) -> bool: - return self.has_probe() or "location" in self.get_property_keys() + return self.has_probe() def is_filtered(self): # the is_filtered is handle with annotation return self._annotations.get("is_filtered", False) - def set_probe(self, probe, group_mode="auto", in_place=False): + def remove_probe(self): + """ + Removes probe information + """ + self._probegroup = None + + def set_probe( + self, + probe: Probe, + group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto", + in_place: bool | None = None, + ) -> None: + """ + Attach a Probe object to a recording. + + Parameters + ---------- + probe: Probe + The probe to be attached to the recording + group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" + How to add the "group" property. + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks + and two sides are present. + in_place: (deprecated) bool | None, default: None + Deprecated argument to indicate whether to modify the recording in place + or return a new recording. The function is always in place now. + Use the `recording.select_channels_with_probegroup()` method instead of `in_place=False` + to return a new recording with a channel selection to match the probe/probegroup. + + Notes + ----- + Internally, this will construct a ProbeGroup with the probe and call `set_probegroup()`. + """ + assert isinstance(probe, Probe), "The input must be a Probe object" + probegroup = ProbeGroup() + probegroup.add_probe(probe) + # TODO: remove return in 0.106.0 after removing in_place argument + return self.set_probegroup(probegroup, group_mode=group_mode, in_place=in_place) + + def set_probegroup( + self, + probegroup: ProbeGroup, + group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto", + in_place: bool | None = None, + check_overlap: bool = True, + ) -> None: """ - Attach a list of Probe object to a recording. + Attach a ProbeGroup or dict to a recording. + For this Probe.device_channel_indices is used to link contacts to recording channels. + After removing unconnected contacts, the number of connected contacts must match the + number of channels in the recording. If this is not the case, use the `recording.select_with_probegroup()` + method instead to return a new recording with a channel selection to match the probe/probegroup. + + Note: The probe order of the probegroup is not kept. Channel ids are re-ordered to match the channel_ids of the recording. Parameters ---------- - probe_or_probegroup: Probe, list of Probe, or ProbeGroup + probe_or_probegroup: ProbeGroup, or dict The probe(s) to be attached to the recording group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" How to add the "group" property. "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: bool - False by default. - Useful internally when extractor do self.set_probegroup(probe) + in_place: (deprecated) bool | None, default: None + Deprecated argument to indicate whether to modify the recording in place + or return a new recording. The function is always in place now. + Use the `recording.select_channels_with_probegroup()` method instead of `in_place=False` + to return a new recording with a channel selection to match the probe/probegroup. + check_overlap: bool, default: True + If True, check that the probes in the probegroup do not overlap in space. + This should be set to False when aggregating recordings whose probes share + the same physical space (e.g. channels split by group from a single probe), + where contact positions are unique but probe bounding boxes may overlap. + """ + if in_place is not None: + warnings.warn( + "The 'in_place' argument is deprecated and will be removed in version 0.106.0. " + "The `set_probe/probegroup()` are now always in place; please remove the in_place argument.", + FutureWarning, + stacklevel=2, + ) + if not in_place: + return self.select_channels_with_probegroup(probegroup, group_mode=group_mode) + + if check_overlap and len(probegroup.probes) > 0: + check_probe_do_not_overlap(probegroup.probes) + + probegroup_sorted = self._get_probegroup_based_on_device_channel_indices(probegroup) + + if probegroup_sorted.get_contact_count() != self.get_num_channels(): + raise ValueError( + "The probe/probegroup must have the same number of connected contacts " + f"as the number of channels as the recording, but the probe has {probegroup.get_contact_count()} " + f"connected channels and the recording has {self.get_num_channels()} channels. " + "Use the `recording.select_channels_with_probegroup()` method instead to return a new recording with " + "a channel selection to match the probe/probegroup." + ) + + device_channel_indices = probegroup_sorted.get_global_device_channel_indices()["device_channel_indices"] + if not np.array_equal(device_channel_indices, np.arange(self.get_num_channels())): + raise ValueError( + "`device_channel_indices` is wrong! " + "It should contain only values [0...n-1] after ordering, " + f"but they are: {device_channel_indices}" + ) + + # probegroup_sorted.set_global_device_channel_indices(np.arange(probegroup_sorted.get_contact_count())) + self._probegroup = probegroup_sorted + + # Handle and set channel groups + _set_group_property_based_on_probegroup(self, probegroup_sorted, group_mode=group_mode) + + def select_channels_with_probe( + self, probe: Probe, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto" + ) -> "BaseRecordingSnippets": + """ + Returns a new recording with channels selected based on the probe. + + Parameters + ---------- + probe: Probe + The probe to be used for channel selection + group_mode: "auto" | "by_probe" | "by_shank" | + "by_side", default: "auto" + How to add the "group" property. + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. Returns ------- sub_recording: BaseRecording A view of the recording (ChannelSlice or clone or itself) """ - assert isinstance(probe, Probe), "must give Probe" + assert isinstance(probe, Probe), "The input must be a Probe object" probegroup = ProbeGroup() probegroup.add_probe(probe) - return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) - - def set_probegroup(self, probegroup, group_mode="auto", in_place=False): - return self._set_probes(probegroup, group_mode=group_mode, in_place=in_place) + return self.select_channels_with_probegroup(probegroup, group_mode=group_mode) - def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): + def select_channels_with_probegroup( + self, probegroup: ProbeGroup, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] = "auto" + ) -> "BaseRecordingSnippets": """ - Attach a list of Probe objects to a recording. - For this Probe.device_channel_indices is used to link contacts to recording channels. - If some contacts of the Probe are not connected (device_channel_indices=-1) - then the recording is "sliced" and only connected channel are kept. - - The probe order is not kept. Channel ids are re-ordered to match the channel_ids of the recording. - + Selects channels based on the given ProbeGroup and returns a new recording with the selected channels. Parameters ---------- - probe_or_probegroup: Probe, list of Probe, or ProbeGroup - The probe(s) to be attached to the recording - group_mode: "auto" | "by_probe" | "by_shank" | "by_side", default: "auto" + probegroup: ProbeGroup + The probegroup to be used for channel selection + group_mode: "auto" | "by_probe" | "by_shank" | + "by_side", default: "auto" How to add the "group" property. - "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks and two sides are present. - in_place: bool - False by default. - Useful internally when extractor do self.set_probegroup(probe) + "auto" is the best splitting possible that can be all at once when multiple probes, multiple shanks + and two sides are present. Returns ------- sub_recording: BaseRecording A view of the recording (ChannelSlice or clone or itself) """ - assert group_mode in ( - "auto", - "by_probe", - "by_shank", - "by_side", - ), "'group_mode' can be 'auto' 'by_probe' 'by_shank' or 'by_side'" - - # handle several input possibilities - if isinstance(probe_or_probegroup, Probe): - probegroup = ProbeGroup() - probegroup.add_probe(probe_or_probegroup) - elif isinstance(probe_or_probegroup, ProbeGroup): - probegroup = probe_or_probegroup - elif isinstance(probe_or_probegroup, list): - assert all([isinstance(e, Probe) for e in probe_or_probegroup]) - probegroup = ProbeGroup() - for probe in probe_or_probegroup: - probegroup.add_probe(probe) + probegroup_sorted = self._get_probegroup_based_on_device_channel_indices(probegroup) + if probegroup_sorted.get_contact_count() > 0: + sorted_dci = probegroup_sorted.get_global_device_channel_indices()["device_channel_indices"] + new_channel_ids = self.channel_ids[sorted_dci] + probegroup_sorted.set_global_device_channel_indices(np.arange(len(new_channel_ids))) + if np.array_equal(new_channel_ids, self.channel_ids): + sub_recording = self.clone() + else: + sub_recording = self.select_channels(new_channel_ids) + sub_recording._probegroup = probegroup_sorted + _set_group_property_based_on_probegroup(sub_recording, probegroup_sorted, group_mode=group_mode) else: - raise ValueError("must give Probe or ProbeGroup or list of Probe") + sub_recording = self.select_channels([]) # empty recording + sub_recording._probegroup = ProbeGroup() # empty probegroup + return sub_recording - # check that the probe do not overlap - num_probes = len(probegroup.probes) - if num_probes > 1: - check_probe_do_not_overlap(probegroup.probes) + def _get_probegroup_based_on_device_channel_indices(self, probegroup: ProbeGroup) -> ProbeGroup: + """ + Returns a new probegroup sorted based on their device_channel_indices. + This is useful to ensure that the probes are ordered correctly when attached to a recording. + Also checks that the device_channel_indices are consistent with the recording channel count and + contacts are unique across probes in the probegroup. + + Parameters + ---------- + probegroup : ProbeGroup + The probegroup to be sorted. + + Returns + ------- + ProbeGroup + The sorted probegroup. + """ + if not isinstance(probegroup, ProbeGroup): + raise ValueError("The input must be a ProbeGroup or dict") - # handle not connected channels assert all( probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # this is a vector with complex fileds (dataframe like) that handle all contact attr - probe_as_numpy_array = probegroup.to_numpy(complete=True) - - # keep only connected contact ( != -1) - keep = probe_as_numpy_array["device_channel_indices"] >= 0 - if np.any(~keep): - warn("The given probes have unconnected contacts: they are removed") - - probe_as_numpy_array = probe_as_numpy_array[keep] - - device_channel_indices = probe_as_numpy_array["device_channel_indices"] - order = np.argsort(device_channel_indices) - device_channel_indices = device_channel_indices[order] - - # check TODO: Where did this came from? - number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) - if number_of_device_channel_indices >= self.get_num_channels(): - error_msg = ( - f"The given Probe either has 'device_channel_indices' that does not match channel count \n" - f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" - f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" - f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" - f"device_channel_indices are the following: {device_channel_indices} \n" - f"recording channels are the following: {self.get_channel_ids()} \n" - ) - raise ValueError(error_msg) - - new_channel_ids = self.get_channel_ids()[device_channel_indices] - probe_as_numpy_array = probe_as_numpy_array[order] - probe_as_numpy_array["device_channel_indices"] = np.arange(probe_as_numpy_array.size, dtype="int64") - - # create recording : channel slice or clone or self - if in_place: - if not np.array_equal(new_channel_ids, self.get_channel_ids()): - raise Exception("set_probe(inplace=True) must have all channel indices") - sub_recording = self - else: - if np.array_equal(new_channel_ids, self.get_channel_ids()): - sub_recording = self.clone() + # Remove unconnected contacts and slice the probe group accordingly + device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] + keep_indices = np.flatnonzero(device_channel_indices >= 0) + if len(keep_indices) < len(device_channel_indices): + if len(keep_indices) == 0: + device_channel_indices = np.array([], dtype="int64") else: - sub_recording = self.select_channels(new_channel_ids) - - # create a vector that handle all contacts in property - sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) - - # planar_contour is saved in annotations - for probe_index, probe in enumerate(probegroup.probes): - contour = probe.probe_planar_contour - if contour is not None: - sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) - - # duplicate positions to "locations" property - ndim = probegroup.ndim - locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") - for i, dim in enumerate(["x", "y", "z"][:ndim]): - locations[:, i] = probe_as_numpy_array[dim] - sub_recording.set_property("location", locations, ids=None) - - # handle groups - has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields - has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields - if group_mode == "auto": - group_keys = ["probe_index"] - if has_shank_id: - group_keys += ["shank_ids"] - if has_contact_side: - group_keys += ["contact_sides"] - elif group_mode == "by_probe": - group_keys = ["probe_index"] - elif group_mode == "by_shank": - assert has_shank_id, "shank_ids is None in probe, you cannot group by shank" - group_keys = ["probe_index", "shank_ids"] - elif group_mode == "by_side": - assert has_contact_side, "contact_sides is None in probe, you cannot group by side" - if has_shank_id: - group_keys = ["probe_index", "shank_ids", "contact_sides"] - else: - group_keys = ["probe_index", "contact_sides"] - groups = np.zeros(probe_as_numpy_array.size, dtype="int64") - unique_keys = np.unique(probe_as_numpy_array[group_keys]) - for group, a in enumerate(unique_keys): - mask = np.ones(probe_as_numpy_array.size, dtype=bool) - for k in group_keys: - mask &= probe_as_numpy_array[k] == a[k] - groups[mask] = group - sub_recording.set_property("group", groups, ids=None) - - # add probe annotations to recording - probes_info = [] - for probe in probegroup.probes: - probes_info.append(probe.annotations) - sub_recording.annotate(probes_info=probes_info) + probegroup = probegroup.get_slice(keep_indices) + device_channel_indices = device_channel_indices[keep_indices] + + if len(device_channel_indices) > 0: + # Check consistency of device_channel_indices with the recording channel count + number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) + if number_of_device_channel_indices >= self.get_num_channels(): + error_msg = ( + f"The given Probe either has 'device_channel_indices' that does not match channel count \n" + f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" + f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" + f"If using all channels remember that python is 0-indexed so max device_channel_index should be {self.get_num_channels() - 1} \n" + f"device_channel_indices are the following: {device_channel_indices} \n" + f"recording channels are the following: {self.get_channel_ids()} \n" + ) + raise ValueError(error_msg) + # Now slice the probe using the device channel indices to match the recording channel_ids + order = np.argsort(device_channel_indices) + probegroup = probegroup.get_slice(order) + else: + warn( + "No connected channels in the probegroup! " + "The probegroup will be attached but no channel will be selected." + ) + probegroup = ProbeGroup() # empty probegroup - return sub_recording + return probegroup def get_probe(self): probes = self.get_probes() - assert len(probes) == 1, "there are several probe use .get_probes() or get_probegroup()" + assert len(probes) == 1, "There are several probe use .get_probes() or get_probegroup()" return probes[0] def get_probes(self): @@ -250,43 +312,92 @@ def get_probes(self): return probegroup.probes def get_probegroup(self): - arr = self.get_property("contact_vector") - if arr is None: - positions = self.get_property("location") - if positions is None: - raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") - else: - warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") - probe = self.create_dummy_probe_from_locations(positions) - # probe.create_auto_shape() - probegroup = ProbeGroup() - probegroup.add_probe(probe) - else: - probegroup = ProbeGroup.from_numpy(arr) + if self._probegroup is None: + raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + return self._probegroup - if "probes_info" in self.get_annotation_keys(): - probes_info = self.get_annotation("probes_info") - for probe, probe_info in zip(probegroup.probes, probes_info): - probe.annotations = probe_info - - for probe_index, probe in enumerate(probegroup.probes): - contour = self.get_annotation(f"probe_{probe_index}_planar_contour") - if contour is not None: - probe.set_planar_contour(contour) - return probegroup + def _extra_metadata_copy(self, other): + if self._probegroup is not None: + other._probegroup = self._probegroup.copy() def _extra_metadata_from_folder(self, folder): - # load probe + # load probe from folder + # Note: we don't need any fix for legacy probegroups, since the + # set_probegroup() method will handle the device_channel_indices + # sorting and global contact order folder = Path(folder) - if (folder / "probe.json").is_file(): - probegroup = read_probeinterface(folder / "probe.json") - self.set_probegroup(probegroup, in_place=True) + probe_file = folder / "probegroup.json" + legacy_probe_file = folder / "probe.json" + if probe_file.is_file(): + probegroup = read_probeinterface(probe_file) + self.set_probegroup(probegroup) + elif legacy_probe_file.is_file(): + probegroup = read_probeinterface(legacy_probe_file) + self.set_probegroup(probegroup) + + # remove "contact_vector" property if present as it is not needed anymore + if "contact_vector" in self.get_property_keys(): + self.delete_property("contact_vector") def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() - write_probeinterface(folder / "probe.json", probegroup) + write_probeinterface(folder / "probegroup.json", probegroup) + + def _extra_metadata_from_dict(self, dump_dict): + # load probe and hanlde backward-compatibility with legacy "contact_vector"/"location" property + if "probegroup" in dump_dict: + # this is for SI>=0.105.0 + probegroup = dump_dict["probegroup"] + self._probegroup = ProbeGroup.from_dict(probegroup) + + def _extra_metadata_to_dict(self, dump_dict): + # save probe + if self.has_probe(): + probegroup = self.get_probegroup() + dump_dict["probegroup"] = probegroup.to_dict() + + def _handle_extractor_backward_compatibility(self): + """ + This handles backward compatibility for recordings that were saved with older versions of spikeinterface. + + Options: + + 1. "contact_vector" property: This was used in versions < 0.105.0 to store probe information, when saved to + pickle + 2. "location" property: This was used in versions < 0.105.0 to store probe information, when saved to JSON + (no contact_vector saved) + 3. probe annotation: probe annotations and contours were saved as recording properties in versions < 0.105.0, + but now they are saved in the probegroup. This method will copy the annotations and the contour to the probes + in the the probegroup and remove the annotations from the recording. + """ + if self._probegroup is None: + check_for_probes_info = False + if "contact_vector" in self.get_property_keys(): + # this is for SI<0.105.0 and from pickle + contact_vector = self.get_property("contact_vector") + probegroup = ProbeGroup.from_numpy(contact_vector=contact_vector) + self._probegroup = probegroup + check_for_probes_info = True + elif "location" in self.get_property_keys(): + # this is for SI<0.105.0 and from JSON (no contact_vector saved) + locations = self.get_property("location") + self.set_dummy_probe_from_locations(locations) + check_for_probes_info = True + + if check_for_probes_info: + for i, probe in enumerate(self._probegroup.probes): + if "probes_info" in self._annotations: + probe_dict = self._annotations["probes_info"][i] + probe.annotations.update(probe_dict) + if f"probe_{i}_planar_contour" in self._annotations: + contour = self.get_annotation(f"probe_{i}_planar_contour") + if contour is not None: + probe.set_planar_contour(contour) + self.delete_annotation(f"probe_{i}_planar_contour") + if "probes_info" in self._annotations: + self._annotations.pop("probes_info") def create_dummy_probe_from_locations(self, locations, shape="circle", shape_params={"radius": 1}, axes="xy"): """ @@ -330,51 +441,55 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params ---------- locations : np.array Array with channel locations (num_channels, ndim) [ndim can be 2 or 3] - shape : str, default: default: "circle" + shape : str, default: "circle" Electrode shapes shape_params : dict, default: {"radius": 1} Shape parameters axes : "xy" | "yz" | "xz", default: "xy" If ndim is 3, indicates the axes that define the plane of the electrodes """ - probe = self.create_dummy_probe_from_locations(locations, shape=shape, shape_params=shape_params, axes=axes) - self.set_probe(probe, in_place=True) + probe = self.create_dummy_probe_from_locations( + np.array(locations), shape=shape, shape_params=shape_params, axes=axes + ) + self.set_probe(probe) def set_channel_locations(self, locations, channel_ids=None): - if self.get_property("contact_vector") is not None: - raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") - self.set_property("location", locations, ids=channel_ids) + warnings.warn( + ( + "set_channel_locations() is deprecated and will be removed in version 0.106.0. " + "If you want to set probe information, use `set_dummy_probe_from_locations()`." + ), + DeprecationWarning, + stacklevel=2, + ) + self.set_dummy_probe_from_locations(locations, axes="xy") def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray: if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - # here we bypass the probe reconstruction so this works both for probe and probegroup - ndim = len(axes) - all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") - for i, dim in enumerate(axes): - all_positions[:, i] = contact_vector[dim] - positions = all_positions[channel_indices] - return positions - else: - locations = self.get_property("location") - if locations is None: - raise Exception("There are no channel locations") - locations = np.asarray(locations)[channel_indices] - return select_axes(locations, axes) + if not self.has_probe(): + raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") + probegroup = self.get_probegroup() + contact_positions = probegroup.get_global_contact_positions() + return select_axes(contact_positions, axes)[channel_indices] - def has_3d_locations(self) -> bool: - return self.get_property("location").shape[1] == 3 + def is_probe_3d(self) -> bool: + if not self.has_probe(): + raise ValueError("is_probe_3d() needs a probe to be attached to the recording") + probegroup = self.get_probegroup() + return probegroup.ndim == 3 def clear_channel_locations(self, channel_ids=None): - if channel_ids is None: - n = self.get_num_channel() - else: - n = len(channel_ids) - locations = np.zeros((n, 2)) * np.nan - self.set_property("location", locations, ids=channel_ids) + warnings.warn( + ( + "clear_channel_locations() is deprecated and will be removed in version 0.106.0. " + "If you want to remove probe information, use `reset_probe()`." + ), + DeprecationWarning, + stacklevel=2, + ) + self.remove_probe() def set_channel_groups(self, groups, channel_ids=None): if "probes" in self._annotations: @@ -429,12 +544,12 @@ def planarize(self, axes: str = "xy"): BaseRecording The recording with 2D positions """ - assert self.has_3d_locations, "The 'planarize' function needs a recording with 3d locations" + assert self.has_3d_probe(), "The 'planarize' function needs a recording with 3d locations" assert len(axes) == 2, "You need to specify 2 dimensions (e.g. 'xy', 'zy')" probe2d = self.get_probe().to_2d(axes=axes) recording2d = self.clone() - recording2d.set_probe(probe2d, in_place=True) + recording2d.set_probe(probe2d) return recording2d diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index b56a093ccc..a1b0563186 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -11,7 +11,7 @@ class BaseSnippets(BaseRecordingSnippets): Abstract class representing several multichannel snippets. """ - _main_properties = ["group", "location", "gain_to_uV", "offset_to_uV"] + _main_properties = ["group", "gain_to_uV", "offset_to_uV"] _main_features = [] def __init__(self, sampling_frequency: float, nbefore: int | None, snippet_len: int, channel_ids: list, dtype): @@ -259,7 +259,7 @@ def _save(self, format="npy", **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) diff --git a/src/spikeinterface/core/binaryfolder.py b/src/spikeinterface/core/binaryfolder.py index 4b9d7b7d09..e9986193a3 100644 --- a/src/spikeinterface/core/binaryfolder.py +++ b/src/spikeinterface/core/binaryfolder.py @@ -3,6 +3,8 @@ import numpy as np +from probeinterface import read_probeinterface + from .binaryrecordingextractor import BinaryRecordingExtractor from .core_tools import define_function_from_class, make_paths_absolute @@ -39,8 +41,57 @@ def __init__(self, folder_path): BinaryRecordingExtractor.__init__(self, **d["kwargs"]) - folder_metadata = folder_path - self.load_metadata_from_folder(folder_metadata) + # Load properties + prop_folder = folder_path / "properties" + if prop_folder.is_dir(): + for prop_file in prop_folder.iterdir(): + if prop_file.suffix == ".npy": + values = np.load(prop_file, allow_pickle=True) + key = prop_file.stem + if key == "contact_vector": + continue + self.set_property(key, values) + + # Load the probegroup + probe_file = folder_path / "probegroup.json" + # In spikeinterface version < 0.105.0, the probegroup was saved in a file called probe.json + legacy_probe_file = folder_path / "probe.json" + probegroup = None + if probe_file.is_file(): + # This is the new version: the probegroup is already ordered correctly + probegroup = read_probeinterface(probe_file) + elif legacy_probe_file.is_file(): + probegroup = read_probeinterface(legacy_probe_file) + order = np.argsort(probegroup.to_numpy(complete=True)["device_channel_indices"]) + if not np.array_equal(order, np.arange(len(order))): + # In spikeinterface version < 0.105.0, the order was saved in the contact vector, but not + # in the probegroup. We need to check if the order is correct and if not, we need to reorder + # the probegroup to match the channel ids. + probegroup = probegroup.get_slice(order) + + # In some older SI versions, before #4300, the probe annotations were + # saved to the recording annotations as `probes_info`. If this is the + # case, we can copy the annotations to the probegroup and delete the + # `probes_info` from the recording annotations. + si_folder_json = folder_path / "si_folder.json" + if si_folder_json.is_file(): + with open(si_folder_json, "r") as f: + si_folder_dict = json.load(f) + if "annotations" in si_folder_dict: + si_annotations = si_folder_dict["annotations"] + if "probes_info" in si_annotations: + probes_info = si_annotations.pop("probes_info") + for probe, probe_info in zip(probegroup.probes, probes_info): + probe.annotations.update(probe_info) + + if probegroup is not None: + self._probegroup = probegroup + + # Load time vectors if any + for segment_index, rs in enumerate(self.segments): + time_file = folder_path / f"times_cached_seg{segment_index}.npy" + if time_file.is_file(): + rs.time_vector = np.load(time_file, mmap_mode="r") self._kwargs = dict(folder_path=str(Path(folder_path).absolute())) self._bin_kwargs = d["kwargs"] diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 697aab875e..c9ee95c8bb 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -2,6 +2,7 @@ import numpy as np +from probeinterface import ProbeGroup from .baserecording import BaseRecording, BaseRecordingSegment @@ -90,32 +91,34 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record break for prop_name, prop_values in property_dict.items(): - if prop_name == "contact_vector": - # remap device channel indices correctly - prop_values["device_channel_indices"] = np.arange(self.get_num_channels()) self.set_property(key=prop_name, values=prop_values) - # if locations are present, check that they are all different! - if "location" in self.get_property_keys(): - location_tuple = [tuple(loc) for loc in self.get_property("location")] - assert len(set(location_tuple)) == self.get_num_channels(), ( - "Locations are not unique! " "Cannot aggregate recordings!" + # Aggregate probe information + all_probegroups = [rec.get_probegroup() for rec in recording_list if rec.has_probe()] + if len(all_probegroups) == len(recording_list): + # check that contact positions are unique across all recordings + all_positions = [] + for probegroup in all_probegroups: + for probe in probegroup.probes: + all_positions.extend(probe.contact_positions) + assert len(np.unique(np.array(all_positions), axis=0)) == len( + all_positions + ), "Contact positions are not unique! Cannot aggregate recordings." + + # Now make a new probegroup with all probes and set global device channel indices + probegroup_agg = ProbeGroup() + for probegroup in all_probegroups: + for probe in probegroup.probes: + probegroup_agg.add_probe(probe.copy()) + probegroup_agg.set_global_device_channel_indices(np.arange(num_all_channels)) + # contact positions are already checked to be unique above; probe bounding + # boxes may overlap when aggregating channels split from a single probe + self.set_probegroup(probegroup_agg, check_overlap=False) + elif len(all_probegroups) > 0 and len(all_probegroups) < len(recording_list): + raise ValueError( + "Some recordings have probes while others do not. Cannot aggregate recordings with inconsistent probe information." ) - planar_contour_keys = [ - key for recording in recording_list for key in recording.get_annotation_keys() if "planar_contour" in key - ] - if len(planar_contour_keys) > 0: - if all( - k == planar_contour_keys[0] for k in planar_contour_keys - ): # we add the 'planar_contour' annotations only if there is a unique one in the recording_list - planar_contour_key = planar_contour_keys[0] - collect_planar_contours = [] - for rec in recording_list: - collect_planar_contours.append(rec.get_annotation(planar_contour_key)) - if all(np.array_equal(arr, collect_planar_contours[0]) for arr in collect_planar_contours): - self.set_annotation(planar_contour_key, collect_planar_contours[0]) - # finally add segments, we need a channel mapping ch_id = 0 channel_map = {} diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index de693d5c26..8669e3c90c 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -62,10 +62,11 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) self._parent = parent_recording # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if self._parent.has_probe(): + parent_probegroup = self._parent.get_probegroup() + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) + self.set_probegroup(sliced_probegroup) # update dump dict self._kwargs = { @@ -152,10 +153,11 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if self._parent_snippets.has_probe(): + parent_probegroup = self._parent_snippets.get_probegroup() + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(self._channel_ids))) + self.set_probegroup(sliced_probegroup) # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/core_tools.py b/src/spikeinterface/core/core_tools.py index fda08ff1b0..02ee9bd915 100644 --- a/src/spikeinterface/core/core_tools.py +++ b/src/spikeinterface/core/core_tools.py @@ -10,6 +10,7 @@ from collections import namedtuple import inspect +from probeinterface import ProbeGroup import numpy as np @@ -148,6 +149,9 @@ def default(self, obj): if isinstance(obj, Motion): return obj.to_dict() + if isinstance(obj, ProbeGroup): + return obj.to_dict() + # The base-class handles the assertion return super().default(obj) diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 4fa68ebec0..9ca5cb2df9 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -86,7 +86,7 @@ def generate_recording( if ndim == 3: probe = probe.to_3d() probe.set_device_channel_indices(np.arange(num_channels)) - recording.set_probe(probe, in_place=True) + recording.set_probe(probe) recording.name = "SyntheticRecording" @@ -675,7 +675,7 @@ def generate_snippets( if set_probe: probe = recording.get_probe() - snippets = snippets.set_probe(probe) + snippets.set_probe(probe) return snippets, sorting @@ -2462,7 +2462,7 @@ def generate_ground_truth_recording( upsample_vector=upsample_vector, ) recording.annotate(is_filtered=True) - recording.set_probe(probe, in_place=True) + recording.set_probe(probe) recording.set_channel_gains(1.0) recording.set_channel_offsets(0.0) diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py index 74b3ccb56e..d545e17f20 100644 --- a/src/spikeinterface/core/recording_tools.py +++ b/src/spikeinterface/core/recording_tools.py @@ -8,7 +8,8 @@ import numpy as np -from .core_tools import add_suffix, make_shared_array +from probeinterface import ProbeGroup + from .job_tools import ( ensure_chunk_size, divide_segment_into_chunks, @@ -723,6 +724,62 @@ def check_probe_do_not_overlap(probes): raise Exception("Probes are overlapping! Retrieve locations of single probes separately") +def _set_group_property_based_on_probegroup( + recording, probegroup: ProbeGroup, group_mode: Literal["auto", "by_probe", "by_shank", "by_side"] +): + """ + Set the group property for a recording based on a ProbeGroup. + Use "auto" (default) to automatically determine the grouping based on the available + information in the ProbeGroup (default: probe + shank + side if available). + + Parameters + ---------- + recording : BaseRecording + The recording object + probegroup : ProbeGroup + The ProbeGroup object + group_mode : {"auto", "by_probe", "by_shank", "by_side"} + The mode for grouping channels + """ + if not isinstance(probegroup, ProbeGroup): + raise ValueError("`probegroup` must be a ProbeGroup instance.") + assert group_mode in ( + "auto", + "by_probe", + "by_shank", + "by_side", + ), "'group_mode' can be 'auto' 'by_probe' 'by_shank' or 'by_side'" + + probe_array = probegroup.to_numpy(complete=True) + has_shank_id = "shank_ids" in probe_array.dtype.fields + has_contact_side = "contact_sides" in probe_array.dtype.fields + if group_mode == "auto": + group_keys = ["probe_index"] + if has_shank_id: + group_keys += ["shank_ids"] + if has_contact_side: + group_keys += ["contact_sides"] + elif group_mode == "by_probe": + group_keys = ["probe_index"] + elif group_mode == "by_shank": + assert has_shank_id, "shank_ids is None in probe, you cannot group by shank" + group_keys = ["probe_index", "shank_ids"] + elif group_mode == "by_side": + assert has_contact_side, "contact_sides is None in probe, you cannot group by side" + if has_shank_id: + group_keys = ["probe_index", "shank_ids", "contact_sides"] + else: + group_keys = ["probe_index", "contact_sides"] + groups = np.zeros(probe_array.size, dtype="int64") + unique_keys = np.unique(probe_array[group_keys]) + for group, a in enumerate(unique_keys): + mask = np.ones(probe_array.size, dtype=bool) + for k in group_keys: + mask &= probe_array[k] == a[k] + groups[mask] = group + recording.set_property("group", groups, ids=None) + + def get_rec_attributes(recording): """ Construct rec_attributes from recording object @@ -738,8 +795,6 @@ def get_rec_attributes(recording): The rec_attributes dictionary """ properties_to_attrs = deepcopy(recording._properties) - if "contact_vector" in properties_to_attrs: - del properties_to_attrs["contact_vector"] rec_attributes = dict( channel_ids=recording.channel_ids, sampling_frequency=recording.get_sampling_frequency(), diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index b5885598fe..c166194338 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from spikeinterface.core import BaseRecording, BaseSorting, aggregate_channels, aggregate_units from spikeinterface.core.waveform_tools import has_exceeding_spikes -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match +from .recording_tools import get_rec_attributes, do_recording_attributes_match, check_probe_do_not_overlap from .core_tools import ( check_json, retrieve_importing_provenance, @@ -365,7 +365,6 @@ def create( ) # check that multiple probes are non-overlapping all_probes = recording.get_probegroup().probes - check_probe_do_not_overlap(all_probes) if has_exceeding_spikes(sorting=sorting, recording=recording): warnings.warn( diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index bb6db4cb66..1272332db0 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -12,7 +12,13 @@ from probeinterface import Probe, ProbeGroup, generate_linear_probe -from spikeinterface.core import BinaryRecordingExtractor, NumpyRecording, load, get_default_zarr_compressor +from spikeinterface.core import ( + BinaryRecordingExtractor, + NumpyRecording, + load, + get_default_zarr_compressor, + aggregate_channels, +) from spikeinterface.core.base import BaseExtractor from spikeinterface.core.testing import check_recordings_equal @@ -197,16 +203,21 @@ def test_BaseRecording(create_cache_folder): ) probe.create_auto_shape() - rec_p = rec.set_probe(probe, group_mode="auto") - rec_p = rec.set_probe(probe, group_mode="by_shank") - rec_p = rec.set_probe(probe, group_mode="by_probe") + rec_p = rec.select_channels_with_probe(probe, group_mode="auto") + positions2 = rec_p.get_channel_locations() + assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + + rec_p = rec.select_channels_with_probe(probe, group_mode="by_shank") + positions2 = rec_p.get_channel_locations() + assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) + + rec_p = rec.select_channels_with_probe(probe, group_mode="by_probe") positions2 = rec_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) probe2 = rec_p.get_probe() positions3 = probe2.contact_positions assert np.array_equal(positions2, positions3) - assert np.array_equal(probe2.device_channel_indices, [0, 1]) # test save with probe @@ -243,13 +254,13 @@ def test_BaseRecording(create_cache_folder): probe.create_auto_shape() traces = np.zeros((1000, 12), dtype="int16") rec = NumpyRecording([traces], 30000.0) - rec1 = rec.set_probe(probe, group_mode="auto") + rec1 = rec.select_channels_with_probe(probe, group_mode="auto") assert np.unique(rec1.get_property("group")).size == 4 - rec2 = rec.set_probe(probe, group_mode="by_probe") + rec2 = rec.select_channels_with_probe(probe, group_mode="by_probe") assert np.unique(rec2.get_property("group")).size == 1 - rec3 = rec.set_probe(probe, group_mode="by_shank") + rec3 = rec.select_channels_with_probe(probe, group_mode="by_shank") assert np.unique(rec3.get_property("group")).size == 2 - rec4 = rec.set_probe(probe, group_mode="by_side") + rec4 = rec.select_channels_with_probe(probe, group_mode="by_side") assert np.unique(rec4.get_property("group")).size == 4 # set unconnected probe @@ -259,7 +270,7 @@ def test_BaseRecording(create_cache_folder): probe.set_device_channel_indices([-1, -1, -1]) probe.create_auto_shape() - rec_empty_probe = rec.set_probe(probe, group_mode="by_shank") + rec_empty_probe = rec.select_channels_with_probe(probe, group_mode="by_shank") assert rec_empty_probe.channel_ids.size == 0 # test scaling parameters @@ -286,8 +297,9 @@ def test_BaseRecording(create_cache_folder): rec_int16.set_property("offset_to_uV", [0.0] * 5) # Test deprecated return_scaled parameter - traces_float32_old = rec_int16.get_traces(return_scaled=True) # Keep this for testing the deprecation warning - assert traces_float32_old.dtype == "float32" + with pytest.warns(DeprecationWarning, match="`return_scaled` is deprecated"): + traces_float32_old = rec_int16.get_traces(return_scaled=True) # Keep this for testing the deprecation warning + assert traces_float32_old.dtype == "float32" # Test new return_in_uV parameter traces_float32_new = rec_int16.get_traces(return_in_uV=True) @@ -344,7 +356,7 @@ def test_BaseRecording(create_cache_folder): # test 3d probe rec_3d = generate_recording(ndim=3, num_channels=30) - locations_3d = rec_3d.get_property("location") + locations_3d = rec_3d.get_probe().contact_positions locations_xy = rec_3d.get_channel_locations(axes="xy") assert np.allclose(locations_xy, locations_3d[:, [0, 1]]) @@ -413,42 +425,14 @@ def test_json_pickle_equivalence(create_cache_folder): for key, value in data_json.items(): # skip probe info, since pickle keeps some additional information - if key not in ["properties"]: - if isinstance(value, dict): + if key not in ["properties", "probegroup"]: + if isinstance(value, dict) and isinstance(data_pickle[key], dict): for sub_key, sub_value in value.items(): assert np.all(sub_value == data_pickle[key][sub_key]) else: assert np.all(value == data_pickle[key]) -def test_interleaved_probegroups(): - recording = generate_recording(durations=[1.0], num_channels=16) - - probe1 = generate_linear_probe(num_elec=8, ypitch=20.0) - probe2_overlap = probe1.copy() - - probegroup_overlap = ProbeGroup() - probegroup_overlap.add_probe(probe1) - probegroup_overlap.add_probe(probe2_overlap) - probegroup_overlap.set_global_device_channel_indices(np.arange(16)) - - # setting overlapping probes should raise an error - with pytest.raises(Exception): - recording.set_probegroup(probegroup_overlap) - - probe2 = probe1.copy() - probe2.move([100.0, 100.0]) - probegroup = ProbeGroup() - probegroup.add_probe(probe1) - probegroup.add_probe(probe2) - probegroup.set_global_device_channel_indices(np.random.permutation(16)) - - recording.set_probegroup(probegroup) - probegroup_set = recording.get_probegroup() - # check that the probe group is correctly set, by sorting the device channel indices - assert np.array_equal(probegroup_set.get_global_device_channel_indices()["device_channel_indices"], np.arange(16)) - - def test_rename_channels(): recording = generate_recording(durations=[1.0], num_channels=3) renamed_recording = recording.rename_channels(new_channel_ids=["a", "b", "c"]) diff --git a/src/spikeinterface/core/tests/test_basesnippets.py b/src/spikeinterface/core/tests/test_basesnippets.py index 751a03460c..05710b1607 100644 --- a/src/spikeinterface/core/tests/test_basesnippets.py +++ b/src/spikeinterface/core/tests/test_basesnippets.py @@ -142,8 +142,7 @@ def test_BaseSnippets(create_cache_folder): probe.set_device_channel_indices([2, -1, 0]) probe.create_auto_shape() - snippets_p = snippets.set_probe(probe, group_mode="auto") - snippets_p = snippets.set_probe(probe, group_mode="by_probe") + snippets_p = snippets.select_channels_with_probe(probe, group_mode="auto") positions2 = snippets_p.get_channel_locations() assert np.array_equal(positions2, [[0, 30.0], [0.0, 0.0]]) diff --git a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py index d5ba74cfd9..8936e6a650 100644 --- a/src/spikeinterface/core/tests/test_channelsaggregationrecording.py +++ b/src/spikeinterface/core/tests/test_channelsaggregationrecording.py @@ -1,10 +1,23 @@ import numpy as np +from probeinterface import generate_linear_probe from spikeinterface.core import aggregate_channels from spikeinterface.core import generate_recording from spikeinterface.core.testing import check_recordings_equal +def _make_rec_with_named_probe(name, manufacturer, x_shift): + """Helper: single-probe recording with annotated name and manufacturer.""" + probe = generate_linear_probe(num_elec=8, ypitch=20.0) + probe.move([x_shift, 0.0]) + probe.annotate(name=name, manufacturer=manufacturer) + probe.set_device_channel_indices(np.arange(8)) + probe.create_auto_shape() + rec = generate_recording(num_channels=8, durations=[1.0], set_probe=False) + rec.set_probe(probe) + return rec + + def test_channelsaggregationrecording(): num_channels = 3 @@ -262,5 +275,51 @@ def test_channel_aggregation_with_string_dtypes_of_different_size(): assert aggregated_recording.channel_ids.dtype == np.dtype(" +with spikeinterface==0.104.* installed. + +The GH Action workflow probe_backward_compat.yml does this automatically. +Set SI_PROBE_COMPAT_FIXTURES_DIR to point at the fixture directory if running locally. +""" + +import os +import numpy as np +import pytest +from pathlib import Path + +from spikeinterface.core import load + +FIXTURES_DIR = Path(os.environ.get("SI_PROBE_COMPAT_FIXTURES_DIR", "probe_compat_fixtures")) + +pytestmark = pytest.mark.skipif( + not FIXTURES_DIR.exists(), + reason=( + f"Probe compatibility fixtures not found at '{FIXTURES_DIR}'. " + "Run .github/scripts/create_probe_compat_fixtures.py with spikeinterface==0.104.* first, " + "or set SI_PROBE_COMPAT_FIXTURES_DIR to the fixture directory." + ), +) + + +# --------------------------------------------------------------------------- +# Shared assertion helpers +# --------------------------------------------------------------------------- + + +def _check_single_probe(rec): + assert rec.has_probe(), "Recording must have a probe after loading" + assert rec.get_num_channels() == 8 + probes = rec.get_probes() + assert len(probes) == 1 + probe = probes[0] + assert probe.annotations.get("name") == "test_probe" + assert probe.annotations.get("manufacturer") == "test_vendor" + assert list(probe.contact_ids) == [f"e{i}" for i in range(8)] + # After loading, device_channel_indices must be sorted 0..N-1 + assert np.array_equal(probe.device_channel_indices, np.arange(8)) + + +def _check_two_probes(rec): + assert rec.has_probe() + assert rec.get_num_channels() == 16 + probes = rec.get_probes() + assert len(probes) == 2, "Both probes must survive after loading" + probe_names = {p.annotations.get("name") for p in probes} + assert probe_names == {"probe_A", "probe_B"}, "Per-probe names must be preserved" + manufacturers = {p.annotations.get("manufacturer") for p in probes} + assert manufacturers == {"vendor_X", "vendor_Y"}, "Per-probe manufacturers must be preserved" + all_contact_ids = set() + for p in probes: + all_contact_ids.update(p.contact_ids.tolist()) + assert all_contact_ids == {f"a{i}" for i in range(8)} | {f"b{i}" for i in range(8)} + groups = rec.get_property("group") + assert len(np.unique(groups)) == 2, "Each probe must have its own group" + + +def _check_shuffled_probe(rec): + assert rec.has_probe() + assert rec.get_num_channels() == 8 + probe = rec.get_probes()[0] + assert probe.annotations.get("name") == "shuffled_probe" + assert probe.annotations.get("manufacturer") == "shuffle_vendor" + # After the old set_probe sorted contacts by device_channel_indices and + # normalised them, the stored probegroup has dci = 0..7. + assert np.array_equal(probe.device_channel_indices, np.arange(8)) + traces = rec.get_traces(segment_index=0) + assert traces.shape == (1000, 8) + + +# --------------------------------------------------------------------------- +# Binary folder fixtures +# --------------------------------------------------------------------------- + + +def test_single_probe_binary_compat(): + _check_single_probe(load(FIXTURES_DIR / "single_probe_binary")) + + +def test_two_probe_binary_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe_binary")) + + +def test_shuffled_probe_binary_compat(): + _check_shuffled_probe(load(FIXTURES_DIR / "shuffled_probe_binary")) + + +def test_interleaved_probe_binary_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe_interleaved_binary")) + + +# --------------------------------------------------------------------------- +# Zarr dump fixtures +# --------------------------------------------------------------------------- + + +def test_single_probe_zarr_compat(): + _check_single_probe(load(FIXTURES_DIR / "single_probe.zarr")) + + +def test_two_probe_zarr_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe.zarr")) + + +def test_shuffled_probe_zarr_compat(): + _check_shuffled_probe(load(FIXTURES_DIR / "shuffled_probe.zarr")) + + +def test_interleaved_probe_zarr_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe_interleaved.zarr")) + + +# --------------------------------------------------------------------------- +# JSON dump fixtures +# --------------------------------------------------------------------------- + + +def test_single_probe_json_compat(): + _check_single_probe(load(FIXTURES_DIR / "single_probe.json")) + + +def test_two_probe_json_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe.json")) + + +def test_shuffled_probe_json_compat(): + _check_shuffled_probe(load(FIXTURES_DIR / "shuffled_probe.json")) + + +def test_interleaved_probe_json_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe_interleaved.json")) + + +# --------------------------------------------------------------------------- +# Pickle dump fixtures +# --------------------------------------------------------------------------- + + +def test_single_probe_pickle_compat(): + _check_single_probe(load(FIXTURES_DIR / "single_probe.pkl")) + + +def test_two_probe_pickle_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe.pkl")) + + +def test_shuffled_probe_pickle_compat(): + _check_shuffled_probe(load(FIXTURES_DIR / "shuffled_probe.pkl")) + + +def test_interleaved_probe_pickle_compat(): + _check_two_probes(load(FIXTURES_DIR / "two_probe_interleaved.pkl")) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..8b40e3e93f 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -318,7 +318,7 @@ def test_SortingAnalyzer_interleaved_probegroup(dataset): probegroup.add_probe(probe2) probegroup.set_global_device_channel_indices(np.random.permutation(num_channels)) - recording = recording.set_probegroup(probegroup) + recording.set_probegroup(probegroup) sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False) # check that locations are correct diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index bbc797c693..6929148f07 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -177,15 +177,41 @@ def __init__( total_nbytes_stored += nbytes_stored_segment # load probe - probe_dict = self._root.attrs.get("probe", None) + probe_dict = self._root.attrs.get("probegroup", None) + probe_dict_legacy = self._root.attrs.get("probe", None) + probegroup = None if probe_dict is not None: probegroup = ProbeGroup.from_dict(probe_dict) - self.set_probegroup(probegroup, in_place=True) + self._probegroup = probegroup + elif probe_dict_legacy is not None: + probegroup = ProbeGroup.from_dict(probe_dict_legacy) + order = np.argsort(probegroup.to_numpy(complete=True)["device_channel_indices"]) + if not np.array_equal(order, np.arange(len(order))): + # In spikeinterface version < 0.105.0, the order was saved in the contact vector, but not + # in the probegroup. We need to check if the order is correct and if not, we need to reorder + # the probegroup to match the channel ids. + probegroup = probegroup.get_slice(order) + + # In some older SI versions, before #4300, the probe annotations were + # saved to the recording annotations as `probes_info`. If this is the + # case, we can copy the annotations to the probegroup and delete the + # `probes_info` from the recording annotations. + si_annotations = self._root.attrs.get("annotations", {}) + if "probes_info" in si_annotations: + probes_info = si_annotations.pop("probes_info") + for probe, probe_info in zip(probegroup.probes, probes_info): + probe.annotations.update(probe_info) + + if probegroup is not None: + self._probegroup = probegroup # load properties if "properties" in self._root: prop_group = self._root["properties"] for key in prop_group.keys(): + # Skip contact_vector property since it is not used anymore to represent probegroup + if key == "contact_vector": + continue values = self._root["properties"][key] self.set_property(key, values) @@ -548,9 +574,9 @@ def add_recording_to_zarr_group( ) # save probe - if recording.get_property("contact_vector") is not None: + if recording.has_probe(): probegroup = recording.get_probegroup() - zarr_group.attrs["probe"] = check_json(probegroup.to_dict(array_as_list=True)) + zarr_group.attrs["probegroup"] = check_json(probegroup.to_dict(array_as_list=True)) # save time vector if any t_starts = np.zeros(recording.get_num_segments(), dtype="float64") * np.nan diff --git a/src/spikeinterface/extractors/bids.py b/src/spikeinterface/extractors/bids.py index 777bdd914b..3a48084ab9 100644 --- a/src/spikeinterface/extractors/bids.py +++ b/src/spikeinterface/extractors/bids.py @@ -40,7 +40,7 @@ def read_bids(folder_path): rec.annotate(bids_name=bids_name) rec.extra_requirements.extend("pandas") probegroup = _read_probe_group(file_path.parent, bids_name, rec.channel_ids) - rec = rec.set_probegroup(probegroup) + rec.set_probegroup(probegroup) recordings.append(rec) elif file_path.suffix == ".nix": @@ -54,7 +54,7 @@ def read_bids(folder_path): rec = read_nix(file_path, stream_id=stream_id) rec.extra_requirements.extend("pandas") probegroup = _read_probe_group(file_path.parent, bids_name, rec.channel_ids) - rec = rec.set_probegroup(probegroup) + rec.set_probegroup(probegroup) recordings.append(rec) return recordings diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index 2a53b999e3..891cbaee07 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -102,9 +102,9 @@ def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", probe = probeinterface.read_spikeglx(meta_file) if probe.shank_ids is not None: - self.set_probe(probe, in_place=True, group_mode="by_shank") + self.set_probe(probe, group_mode="by_shank") else: - self.set_probe(probe, in_place=True) + self.set_probe(probe) sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) self.set_property("inter_sample_shift", sample_shifts) diff --git a/src/spikeinterface/extractors/iblextractors.py b/src/spikeinterface/extractors/iblextractors.py index 779c36fa23..8a57e40ec3 100644 --- a/src/spikeinterface/extractors/iblextractors.py +++ b/src/spikeinterface/extractors/iblextractors.py @@ -221,9 +221,9 @@ def __init__( probe = probeinterface.read_spikeglx(meta_file) if probe.shank_ids is not None: - self.set_probe(probe, in_place=True, group_mode="by_shank") + self.set_probe(probe, group_mode="by_shank") else: - self.set_probe(probe, in_place=True) + self.set_probe(probe) # set channel properties # sometimes there are missing metadata files on the IBL side diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 8d1fac0c72..b3ccb92cbd 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -70,9 +70,11 @@ def __init__( if electrode_width is not None: probe_kwargs["electrode_width"] = electrode_width probe = probeinterface.read_3brain(file_path, **probe_kwargs) - self.set_probe(probe, in_place=True) - self.set_property("row", self.get_property("contact_vector")["row"]) - self.set_property("col", self.get_property("contact_vector")["col"]) + rows = probe.contact_annotations["row"] + cols = probe.contact_annotations["col"] + self.set_probe(probe) + self.set_property("row", rows) + self.set_property("col", cols) self._kwargs.update( { diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 932ecee106..38e65096c2 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -74,8 +74,9 @@ def __init__( # rec_name auto set by neo rec_name = self.neo_reader.rec_name probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) - self.set_probe(probe, in_place=True) - self.set_property("electrode", self.get_property("contact_vector")["electrode"]) + electrodes = probe.contact_annotations["electrode"] + self.set_probe(probe) + self.set_property("electrode", electrodes) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) @classmethod diff --git a/src/spikeinterface/extractors/neoextractors/mearec.py b/src/spikeinterface/extractors/neoextractors/mearec.py index 7ca82af01e..d4cbe1b0de 100644 --- a/src/spikeinterface/extractors/neoextractors/mearec.py +++ b/src/spikeinterface/extractors/neoextractors/mearec.py @@ -55,7 +55,7 @@ def __init__(self, file_path: str | Path, all_annotations: bool = False, use_nam probe = probeinterface.read_mearec(file_path) probe.annotations["mearec_name"] = str(probe.annotations["mearec_name"]) - self.set_probe(probe, in_place=True) + self.set_probe(probe) self.annotate(is_filtered=True) if hasattr(self.neo_reader._recgen, "gain_to_uV"): diff --git a/src/spikeinterface/extractors/neoextractors/openephys.py b/src/spikeinterface/extractors/neoextractors/openephys.py index 5dc9220aa5..22ae82b117 100644 --- a/src/spikeinterface/extractors/neoextractors/openephys.py +++ b/src/spikeinterface/extractors/neoextractors/openephys.py @@ -331,9 +331,9 @@ def __init__( settings_file=settings_file, stream_name=oe_stream_name ) if probe.shank_ids is not None: - self.set_probe(probe, in_place=True, group_mode="by_shank") + self.set_probe(probe, group_mode="by_shank") else: - self.set_probe(probe, in_place=True) + self.set_probe(probe) # get inter-sample shifts based on the probe information and mux channels sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) if sample_shifts is not None: diff --git a/src/spikeinterface/extractors/neoextractors/spikegadgets.py b/src/spikeinterface/extractors/neoextractors/spikegadgets.py index adc50df12f..da4a66e1f5 100644 --- a/src/spikeinterface/extractors/neoextractors/spikegadgets.py +++ b/src/spikeinterface/extractors/neoextractors/spikegadgets.py @@ -79,7 +79,7 @@ def __init__( if saturation_threshold_uV_probe is not None: saturation_thresholds_uV.append(saturation_threshold_uV_probe) - self.set_probegroup(probegroup, in_place=True) + self.set_probegroup(probegroup) if np.all(sample_shifts != -1): self.set_property("inter_sample_shift", sample_shifts) diff --git a/src/spikeinterface/extractors/neoextractors/spikeglx.py b/src/spikeinterface/extractors/neoextractors/spikeglx.py index 60b1a98be8..41c2b77bfc 100644 --- a/src/spikeinterface/extractors/neoextractors/spikeglx.py +++ b/src/spikeinterface/extractors/neoextractors/spikeglx.py @@ -86,9 +86,9 @@ def __init__( probe = probeinterface.read_spikeglx(ap_meta_filename) if probe.shank_ids is not None: - self.set_probe(probe, in_place=True, group_mode="by_shank") + self.set_probe(probe, group_mode="by_shank") else: - self.set_probe(probe, in_place=True) + self.set_probe(probe) # get inter-sample shifts based on the probe information and mux channels sample_shifts = get_neuropixels_sample_shifts_from_probe(probe) diff --git a/src/spikeinterface/extractors/shybridextractors.py b/src/spikeinterface/extractors/shybridextractors.py index ff08c1a3f3..eca7d46724 100644 --- a/src/spikeinterface/extractors/shybridextractors.py +++ b/src/spikeinterface/extractors/shybridextractors.py @@ -82,7 +82,7 @@ def __init__(self, file_path): # load probe file probegroup = probeinterface.read_prb(params["probe"]) - self.set_probegroup(probegroup, in_place=True) + self.set_probegroup(probegroup) self._kwargs = {"file_path": str(Path(file_path).absolute())} self.extra_requirements.extend(["hybridizer", "pyyaml"]) diff --git a/src/spikeinterface/extractors/sinapsrecordingextractors.py b/src/spikeinterface/extractors/sinapsrecordingextractors.py index 132a01f300..f47a83bc47 100644 --- a/src/spikeinterface/extractors/sinapsrecordingextractors.py +++ b/src/spikeinterface/extractors/sinapsrecordingextractors.py @@ -84,7 +84,7 @@ def __init__(self, file_path: str | Path, stream_name: str = "filt"): if (stream_name == "filt") | (stream_name == "raw"): probe = get_sinaps_probe(probe_type) if probe is not None: - self.set_probe(probe, in_place=True) + self.set_probe(probe) self._kwargs = {"file_path": str(file_path.absolute()), "stream_name": stream_name} @@ -143,7 +143,7 @@ def __init__(self, file_path: str | Path, stream_name: str = "filt"): # set probe probe = get_sinaps_probe(sinaps_info["probe_type"]) if probe is not None: - self.set_probe(probe, in_place=True) + self.set_probe(probe) self._kwargs = {"file_path": str(Path(file_path).absolute()), "stream_name": stream_name} diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index bd0d2184d4..fa9d59ec47 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -31,11 +31,13 @@ def setUpClass(cls): cache_dir=None, ) except: + print("Skipping test due to server being down.") pytest.skip("Skipping test due to server being down.") try: cls.recording = read_ibl_recording(eid=cls.eid, stream_name="probe00.ap", one=cls.one) except requests.exceptions.HTTPError as e: if e.response.status_code == 503: + print("Skipping test due to server being down (HTTP 503).") pytest.skip("Skipping test due to server being down (HTTP 503).") else: raise @@ -84,8 +86,6 @@ def test_property_keys(self): expected_property_keys = [ "gain_to_uV", "offset_to_uV", - "contact_vector", - "location", "group", "shank", "shank_row", @@ -97,6 +97,9 @@ def test_property_keys(self): ] self.assertCountEqual(first=self.recording.get_property_keys(), second=expected_property_keys) + def test_has_probe(self): + assert self.recording.has_probe() is True + def test_trace_shape(self): expected_shape = (21, 384) self.assertTupleEqual(tuple1=self.small_scaled_trace.shape, tuple2=expected_shape) diff --git a/src/spikeinterface/generation/drift_tools.py b/src/spikeinterface/generation/drift_tools.py index 1800138dae..6996800e27 100644 --- a/src/spikeinterface/generation/drift_tools.py +++ b/src/spikeinterface/generation/drift_tools.py @@ -477,7 +477,7 @@ def __init__( ) self.add_recording_segment(recording_segment) - self.set_probe(drifting_templates.probe, in_place=True) + self.set_probe(drifting_templates.probe) # templates are too large, we don't serialize them to JSON self._serializability["json"] = False diff --git a/src/spikeinterface/postprocessing/localization_tools.py b/src/spikeinterface/postprocessing/localization_tools.py index 7975097629..dc83ab9596 100644 --- a/src/spikeinterface/postprocessing/localization_tools.py +++ b/src/spikeinterface/postprocessing/localization_tools.py @@ -117,7 +117,7 @@ def compute_monopolar_triangulation( # if enforce_decrease: # enforce_decrease_shells_data( - # wf_data, best_channels[unit_id], enforce_decrease_radial_parents, in_place=True + # wf_data, best_channels[unit_id], enforce_decrease_radial_parents # ) unit_location[i] = solve_monopolar_triangulation(wf_data, local_contact_locations, max_distance_um, optimizer) diff --git a/src/spikeinterface/preprocessing/average_across_direction.py b/src/spikeinterface/preprocessing/average_across_direction.py index 8d1c4475cd..23e0a1d5ae 100644 --- a/src/spikeinterface/preprocessing/average_across_direction.py +++ b/src/spikeinterface/preprocessing/average_across_direction.py @@ -63,7 +63,7 @@ def __init__( # my geometry channel_locations = np.zeros( (n_pos_unique, parent_channel_locations.shape[1]), - dtype=parent_channel_locations.dtype, + dtype=np.float32, ) # average other dimensions in the geometry other_dim = np.arange(parent_channel_locations.shape[1]) != dim diff --git a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py index c7c37968d9..c8825831b0 100644 --- a/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py +++ b/src/spikeinterface/preprocessing/deepinterpolation/tests/test_deepinterpolation.py @@ -30,7 +30,7 @@ def recording_and_shape(): probe = probeinterface.generate_multi_columns_probe(num_columns=num_cols, num_contact_per_column=num_rows) probe.set_device_channel_indices(np.arange(num_cols * num_rows)) recording = generate_recording(num_channels=num_cols * num_rows, durations=[10.0], sampling_frequency=30000) - recording.set_probe(probe, in_place=True) + recording.set_probe(probe) recording = depth_order(recording) recording = zscore(recording) desired_shape = (num_rows, num_cols) diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index b4ceed886e..6a31286c9a 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -285,7 +285,7 @@ def __init__( self._kwargs.update(filter_kwargs) @classmethod - def _handle_backward_compatibility(cls, old_kwargs, full_dict): + def _handle_kwargs_backward_compatibility(cls, old_kwargs, full_dict): new_kwargs = old_kwargs.copy() is_lfp_case = old_kwargs["freq_min"] < HIGHPASS_ERROR_THRESHOLD_HZ if "ignore_low_freq_error" not in new_kwargs: @@ -354,7 +354,7 @@ def __init__( self._kwargs.update(filter_kwargs) @classmethod - def _handle_backward_compatibility(cls, old_kwargs, full_dict): + def _handle_kwargs_backward_compatibility(cls, old_kwargs, full_dict): new_kwargs = old_kwargs.copy() is_lfp_case = old_kwargs["freq_min"] < HIGHPASS_ERROR_THRESHOLD_HZ if "ignore_low_freq_error" not in new_kwargs: diff --git a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py index 35f398f985..5a0e160f92 100644 --- a/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_detect_bad_channels.py @@ -80,7 +80,7 @@ def test_detect_bad_channels_std_mad(): probe = generate_linear_probe(num_elec=num_channels) probe.set_device_channel_indices(np.arange(num_channels)) - rec.set_probe(probe, in_place=True) + rec.set_probe(probe) bad_channels_std, bad_labels_std = detect_bad_channels(rec, method="std") bad_channels_mad, bad_labels_mad = detect_bad_channels(rec, method="std") @@ -125,7 +125,7 @@ def test_detect_bad_channels_extremes(outside_channels_location): probe = generate_linear_probe(num_elec=num_channels) probe.set_device_channel_indices(np.arange(num_channels)) - rec.set_probe(probe, in_place=True) + rec.set_probe(probe) bad_channel_ids, bad_labels = detect_bad_channels( rec, method="coherence+psd", outside_channels_location=outside_channels_location diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index bfa4d3d9ae..89e8e36cf8 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -118,7 +118,7 @@ def test_highpass_spatial_filter_with_dead_channels(): rec_with_dead = NumpyRecording( traces_list=[traces], sampling_frequency=rec.sampling_frequency, channel_ids=rec.channel_ids ) - rec_with_dead.set_probe(rec.get_probe(), in_place=True) + rec_with_dead.set_probe(rec.get_probe()) filtered = spre.highpass_spatial_filter(rec_with_dead, n_channel_pad=2) result = filtered.get_traces() assert result.shape == traces.shape diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 61996e9036..c79605a110 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -2,6 +2,8 @@ import numpy as np import os +import probeinterface as pi + import spikeinterface as si import spikeinterface.preprocessing as spre import spikeinterface.extractors as se @@ -125,9 +127,12 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan # distribute default probe locations across 4 shanks if set rng = np.random.default_rng(seed=None) - x = rng.choice(shanks, num_channels) - for idx, __ in enumerate(recording._properties["contact_vector"]): - recording._properties["contact_vector"][idx]["x"] = x[idx] + x_new = rng.choice(shanks, num_channels) + probe = recording.get_probe() + new_positions = probe.contact_positions.copy() + new_positions[:, 0] = x_new # column 0 is x + recording._probegroup.probes[0]._contact_positions = new_positions + recording.set_probe(probe) # generate random bad channel locations bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) @@ -161,18 +166,21 @@ def test_output_values(): the non-interpolated channels is also an implicit test these were not accidently changed. """ - recording = generate_recording(num_channels=5, durations=[1]) + recording = generate_recording(num_channels=5, durations=[1], set_probe=False) bad_channel_indexes = np.array([0]) bad_channel_ids = recording.channel_ids[bad_channel_indexes] - new_probe_locs = [ - [5, 7, 3, 5, 5], # 5 channels, a in the center ('bad channel', zero index) - [5, 5, 5, 7, 3], - ] # all others equal distance away. - # Overwrite the probe information with the new locations - for idx, (x, y) in enumerate(zip(*new_probe_locs)): - recording._properties["contact_vector"][idx]["x"] = x - recording._properties["contact_vector"][idx]["y"] = y + probe_locs = np.array( + [ + [5, 7, 3, 5, 5], # 5 channels, a in the center ('bad channel', zero index) + [5, 5, 5, 7, 3], + ] # all others equal distance away. + ).T + # Set the probe information with the new locations + probe = pi.Probe(ndim=2) + probe.set_contacts(positions=probe_locs) + probe.set_device_channel_indices(np.arange(len(probe_locs))) + recording.set_probe(probe) # Run interpolation in SI and check the interpolated channel # 0 is a linear combination of other channels @@ -186,8 +194,7 @@ def test_output_values(): # Shift the last channel position so that it is 4 units, rather than 2 # away. Setting sigma_um = p = 1 allows easy calculation of the expected # weights. - recording._properties["contact_vector"][-1]["x"] = 5 - recording._properties["contact_vector"][-1]["y"] = 9 + recording._probegroup.probes[0]._contact_positions[-1] = [5, 9] expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 45d4809cd8..4854d94dba 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -157,7 +157,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: "The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording." ) else: - if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys(): + if recording.has_probe(): self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1]) else: self.channel_mapping = np.arange(recording.get_num_channels()) diff --git a/src/spikeinterface/sorters/basesorter.py b/src/spikeinterface/sorters/basesorter.py index 09e92c7a6a..0ae0f3645b 100644 --- a/src/spikeinterface/sorters/basesorter.py +++ b/src/spikeinterface/sorters/basesorter.py @@ -105,13 +105,21 @@ def initialize_folder(cls, recording, output_folder, verbose, remove_existing_fo raise ValueError("recording must be a Recording or a Snippets!!") if cls.requires_locations: - locations = recording.get_channel_locations() - if locations is None: + if not recording.has_probe(): raise RuntimeError( "Channel locations are required for this spike sorter. " "Locations can be added to the RecordingExtractor by loading a probe file " "(.prb or .csv) or by setting them manually." ) + # check uniqueness of locations + locations = recording.get_channel_locations() + if len(locations) != len(set(map(tuple, locations))): + raise RuntimeError( + "Channel locations are not unique! " + "Please ensure that each channel has a unique location before running spike sorting. " + "If you have multiple groups with overlapping channel locations, you can use the " + "``run_sorter_by_property`` function to sort each group separately" + ) if output_folder is None: output_folder = cls.sorter_name + "_output" diff --git a/src/spikeinterface/sorters/external/hdsort.py b/src/spikeinterface/sorters/external/hdsort.py index 3daaf85b7a..07d59332d2 100644 --- a/src/spikeinterface/sorters/external/hdsort.py +++ b/src/spikeinterface/sorters/external/hdsort.py @@ -276,8 +276,8 @@ def write_hdsort_input_format(cls, recording, save_path, chunk_memory="500M"): [("electrode", np.int32), ("x", np.float64), ("y", np.float64), ("channel", np.int32)] ) - locations = recording.get_property("location") - assert locations is not None, "'location' property is needed to run HDSort" + assert recording.has_probe(), "The recording must have a probe to run HDSort" + locations = recording.get_channel_locations() with h5py.File(save_path, "w") as f: f.create_group("ephys") diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index f616888166..5698e0e142 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -405,11 +405,12 @@ def __init__( if border_mode == "remove_channels": # change the wiring of the probe - # TODO this is also done in ChannelSliceRecording, this should be done in a common place - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if recording.has_probe(): + probegroup = recording.get_probegroup() + channel_indices = recording.ids_to_indices(channel_ids) + probegroup_sliced = probegroup.get_slice(channel_indices) + probegroup_sliced.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) + self.set_probegroup(probegroup_sliced) # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below diff --git a/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py b/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py index 8840a5a00d..942074ef31 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py +++ b/src/spikeinterface/sortingcomponents/peak_localization/monopolar.py @@ -97,9 +97,7 @@ def compute(self, traces, peaks, waveforms): wf_data = np.abs(wf[self.nbefore]) if self.enforce_decrease_radial_parents is not None: - enforce_decrease_shells_data( - wf_data, peak["channel_index"], self.enforce_decrease_radial_parents, in_place=True - ) + enforce_decrease_shells_data(wf_data, peak["channel_index"], self.enforce_decrease_radial_parents) peak_locations[i] = solve_monopolar_triangulation( wf_data, local_contact_locations, self.max_distance_um, self.optimizer