diff --git a/.github/scripts/generate_zarr_v2_fixtures.py b/.github/scripts/generate_zarr_v2_fixtures.py new file mode 100644 index 0000000000..e555873458 --- /dev/null +++ b/.github/scripts/generate_zarr_v2_fixtures.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +""" +Generate zarr v2 fixtures for backward compatibility tests. + +Run this script with an old spikeinterface version and zarr<3, e.g.: + pip install "spikeinterface==0.104.0" "zarr<3" + python generate_zarr_v2_fixtures.py --output /tmp/zarr_v2_fixtures + +The script saves (besides the intermediate recording_binary/ and sorting_binary/ folders): + - recording.zarr : a small ZarrRecordingExtractor + - sorting.zarr : a small ZarrSortingExtractor (analyzer.zarr is a SortingAnalyzer built from both) + - expected_values.json : key values used to verify correct loading +""" +import argparse +import shutil +import json +from pathlib import Path + +import numpy as np +import zarr + +import spikeinterface as si + + +def main(output_dir: Path) -> None: + print(f"spikeinterface version : {si.__version__}") + print(f"zarr version : {zarr.__version__}") + + # An existing output directory is reused: every save below passes overwrite=True. + output_dir.mkdir(parents=True, exist_ok=True) + + recording, sorting = si.generate_ground_truth_recording(durations=[10, 5],num_channels=32, num_units=10, seed=0) + # save to binary to make them JSON serializable for later expected values extraction + recording = recording.save(folder=output_dir / "recording_binary", overwrite=True) + sorting = sorting.save(folder=output_dir / "sorting_binary", overwrite=True) + # --- save recording --- + recording_path = output_dir / "recording.zarr" + recording_zarr = recording.save(format="zarr", folder=recording_path, overwrite=True) + print(f"Saved recording -> {recording_path}") + + # --- save sorting --- + sorting_path = output_dir / "sorting.zarr" + sorting_zarr = sorting.save(format="zarr", folder=sorting_path, overwrite=True) + print(f"Saved sorting -> {sorting_path}") + + # --- save SortingAnalyzer --- + # Build the analyzer from the zarr-backed objects returned by save() above (nothing is reloaded here), + # presumably so that the analyzer can store serializable provenance. 
+ analyzer_path = output_dir / "analyzer.zarr" + if analyzer_path.is_dir(): + shutil.rmtree(analyzer_path) + analyzer = si.create_sorting_analyzer( + sorting_zarr, recording_zarr, format="zarr", folder=analyzer_path, overwrite=True + ) + analyzer.compute(["random_spikes", "templates"]) + print(f"Saved analyzer -> {analyzer_path}") + + # Read the templates from the analyzer object created above (it is not reloaded from disk) + templates_array = analyzer.get_extension("templates").get_data() + + # --- capture expected values for later assertion --- + expected = { + "spikeinterface_version": si.__version__, + "zarr_version": zarr.__version__, + "recording": { + "num_channels": int(recording.get_num_channels()), + "num_segments": int(recording.get_num_segments()), + "sampling_frequency": float(recording.get_sampling_frequency()), + "num_samples_per_segment": [int(recording.get_num_samples(seg)) for seg in range(recording.get_num_segments())], + "channel_ids": recording.get_channel_ids().tolist(), + "dtype": str(recording.get_dtype()), + # first 10 frames of segment 0 for all channels + "traces_seg0_first10": recording.get_traces(start_frame=0, end_frame=10, segment_index=0).tolist(), + }, + "sorting": { + "num_segments": int(sorting.get_num_segments()), + "sampling_frequency": float(sorting.get_sampling_frequency()), + "unit_ids": sorting.get_unit_ids().tolist(), + "spike_trains_seg0": { + str(uid): sorting.get_unit_spike_train(unit_id=uid, segment_index=0).tolist() + for uid in sorting.unit_ids + }, + }, + "analyzer": { + "num_units": int(analyzer.get_num_units()), + "num_channels": int(analyzer.get_num_channels()), + "templates_shape": list(templates_array.shape), + }, + } + + expected_path = output_dir / "expected_values.json" + with open(expected_path, "w") as f: + json.dump(expected, f, indent=2) + print(f"Saved expected -> {expected_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate zarr v2 fixtures for backward compatibility tests") 
+ parser.add_argument("--output", type=Path, required=True, help="Directory to write fixtures into") + args = parser.parse_args() + main(args.output) diff --git a/.github/workflows/all-tests.yml b/.github/workflows/all-tests.yml index 0d242b759a..41c3f81054 100644 --- a/.github/workflows/all-tests.yml +++ b/.github/workflows/all-tests.yml @@ -24,7 +24,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.13"] # Lower and higher versions we support + python-version: ["3.11", "3.13"] # Lower and higher versions we support os: [macos-latest, windows-latest, ubuntu-latest] steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/deepinterpolation.yml b/.github/workflows/deepinterpolation.yml index be003da742..2e7b8d03eb 100644 --- a/.github/workflows/deepinterpolation.yml +++ b/.github/workflows/deepinterpolation.yml @@ -1,10 +1,8 @@ name: Testing deepinterpolation +# Manual only — deepinterpolation requires Python 3.10, incompatible with 3.11+ required by Zarr 3.0.0+ on: - pull_request: - types: [synchronize, opened, reopened] - branches: - - main + workflow_dispatch: concurrency: # Cancel previous workflows on the same pull request group: ${{ github.workflow }}-${{ github.ref }} diff --git a/.github/workflows/test_containers_docker.yml b/.github/workflows/test_containers_docker.yml index 211db5f775..73a194efb3 100644 --- a/.github/workflows/test_containers_docker.yml +++ b/.github/workflows/test_containers_docker.yml @@ -15,7 +15,7 @@ jobs: - uses: actions/checkout@v4 - uses: actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.11' - name: Python version run: python --version diff --git a/.github/workflows/test_containers_singularity.yml b/.github/workflows/test_containers_singularity.yml index 00941215b1..0554a0060c 100644 --- a/.github/workflows/test_containers_singularity.yml +++ b/.github/workflows/test_containers_singularity.yml @@ -16,7 +16,7 @@ jobs: - uses: actions/checkout@v4 - uses: 
actions/setup-python@v5 with: - python-version: '3.10' + python-version: '3.11' - uses: eWaterCycle/setup-singularity@v7 with: singularity-version: 3.8.7 diff --git a/.github/workflows/test_zarr_compat.yml b/.github/workflows/test_zarr_compat.yml new file mode 100644 index 0000000000..27be8d633d --- /dev/null +++ b/.github/workflows/test_zarr_compat.yml @@ -0,0 +1,47 @@ +name: Test zarr backwards compatibility + +on: + workflow_dispatch: + pull_request: + types: [synchronize, opened, reopened] + branches: + - main + paths: + - "src/spikeinterface/core/zarrextractors.py" + - "src/spikeinterface/core/zarrrecordingextractor.py" + - "src/spikeinterface/core/tests/test_zarr_backwards_compat.py" + - ".github/workflows/test_zarr_compat.yml" + - ".github/scripts/generate_zarr_v2_fixtures.py" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + test-zarr-compat: + name: zarr v2 -> v3 backwards compatibility + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install SI 0.104.0 with zarr v2 + run: pip install "spikeinterface==0.104.0" "zarr<3" + + - name: Generate zarr v2 fixtures + run: python .github/scripts/generate_zarr_v2_fixtures.py --output /tmp/zarr_v2_fixtures + + - name: Install current SI with zarr v3 + run: pip install -e ".[test_core]" + + - name: Check zarr version is v3 + run: python -c "import zarr; v = zarr.__version__; print(f'zarr {v}'); assert int(v.split('.')[0]) >= 3" + + - name: Run backward compatibility tests + env: + ZARR_V2_FIXTURES_PATH: /tmp/zarr_v2_fixtures + run: pytest src/spikeinterface/core/tests/test_zarr_backwards_compat.py -v diff --git a/pyproject.toml b/pyproject.toml index 6a9d57cc64..5e71f930e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = "Python toolkit for analysis, visualization, and comparison of spike sorting output" 
readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.11" classifiers = [ "Programming Language :: Python :: 3 :: Only", "License :: OSI Approved :: MIT License", @@ -24,12 +24,11 @@ dependencies = [ "numpy>=2.0.0;python_version>='3.13'", "threadpoolctl>=3.0.0", "tqdm", - "zarr>=2.18,<3", + "zarr>=3,<4", "neo>=0.14.4", "probeinterface>=0.3.2", "packaging", "pydantic", - "numcodecs<0.16.0", # For supporting zarr < 3 ] [build-system] @@ -64,7 +63,9 @@ changelog = "https://spikeinterface.readthedocs.io/en/latest/whatisnew.html" extractors = [ "MEArec>=1.8", "pynwb>=2.6.0", - "hdmf-zarr>=0.11.0", + # FOR TESTING + "hdmf-zarr @ git+https://github.com/hdmf-dev/hdmf-zarr.git@zarr-v3-migration", + # "hdmf-zarr>=0.11.0", "pyedflib>=0.1.30", "sonpy;python_version<'3.10'", "lxml", # lxml for neuroscope @@ -81,7 +82,9 @@ streaming_extractors = [ "fsspec", "aiohttp", "requests", - "hdmf-zarr>=0.11.0", + # FOR TESTING + "hdmf-zarr @ git+https://github.com/hdmf-dev/hdmf-zarr.git@zarr-v3-migration", + # "hdmf-zarr>=0.11.0", "remfile", "s3fs" ] @@ -127,7 +130,9 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # FOR TESTING: use probeinterface zarrv3 branch + "probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs, @@ -139,7 +144,9 @@ test_extractors = [ "pooch>=1.8.2", "datalad>=1.0.2", # Commenting out for release - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # FOR TESTING: use probeinterface zarrv3 branch + "probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ 
git+https://github.com/NeuralEnsemble/python-neo.git", ] @@ -190,7 +197,9 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # FOR TESTING: use probeinterface zarrv3 branch + "probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs @@ -219,7 +228,9 @@ docs = [ "huggingface_hub", # For automated curation # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # FOR TESTING: use probeinterface zarrv3 branch + "probeinterface @ git+https://github.com/alejoe91/probeinterface.git@zarrv3", + # "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 8d149a7c49..ad34314a2d 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -880,6 +880,8 @@ def save(self, **kwargs) -> "BaseExtractor": * dump_ext: "json" or "pkl", default "json" (if format is "folder") * verbose: if True output is verbose * **save_kwargs: additional kwargs format-dependent and job kwargs for recording + (check `save_to_memory()`, `save_to_folder()`, `save_to_zarr()` for more details on format-dependent + kwargs) {} Returns @@ -899,13 +901,27 @@ def save(self, **kwargs) -> "BaseExtractor": save.__doc__ = save.__doc__.format(_shared_job_kwargs_doc) def save_to_memory(self, sharedmem=True, **save_kwargs) -> "BaseExtractor": + """ + Save the object to memory. 
+ + Parameters + ---------- + sharedmem : bool, default: True + If True, the object is saved to shared memory, allowing it to be accessed by multiple processes without + copying. If False, the object is saved to regular memory, which may involve copying when accessed by + multiple processes. + + Returns + ------- + BaseExtractor + A saved copy of the extractor in memory. + """ save_kwargs.pop("format", None) cached = self._save(format="memory", sharedmem=sharedmem, **save_kwargs) self.copy_metadata(cached) return cached - # TODO rename to saveto_binary_folder def save_to_folder( self, name: str | None = None, @@ -951,8 +967,7 @@ def save_to_folder( If True, an existing folder at the specified path will be deleted before saving. verbose : bool, default: True If True, print information about the cache folder being used. - **save_kwargs - Additional keyword arguments to be passed to the underlying save method. + {} Returns ------- @@ -1017,7 +1032,6 @@ def save_to_zarr( folder=None, overwrite=False, storage_options=None, - channel_chunk_size=None, verbose=True, **save_kwargs, ): @@ -1037,26 +1051,9 @@ def save_to_zarr( storage_options: dict or None, default: None Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. For cloud storage locations, this should not be None (in case of default values, use an empty dict) - channel_chunk_size: int or None, default: None - Channels per chunk (only for BaseRecording) - compressor: numcodecs.Codec or None, default: None - Global compressor. 
If None, Blosc-zstd, level 5, with bit shuffle is used - filters: list[numcodecs.Codec] or None, default: None - Global filters for zarr (global) - compressor_by_dataset: dict or None, default: None - Optional compressor per dataset: - - traces - - times - If None, the global compressor is used - filters_by_dataset: dict or None, default: None - Optional filters per dataset: - - traces - - times - If None, the global filters are used verbose: bool, default: True If True, the output is verbose - auto_cast_uint: bool, default: True - If True, unsigned integers are cast to signed integers to avoid issues with zarr (only for BaseRecording) + {} Returns ------- @@ -1092,7 +1089,6 @@ def save_to_zarr( assert not zarr_path.exists(), f"Path {zarr_path} already exists, choose another name" save_kwargs["zarr_path"] = zarr_path save_kwargs["storage_options"] = storage_options - save_kwargs["channel_chunk_size"] = channel_chunk_size cached = self._save(format="zarr", verbose=verbose, **save_kwargs) cached = read_zarr(zarr_path) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index f23b524271..c6df7bdf6e 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -5,10 +5,10 @@ import numpy as np from probeinterface import read_probeinterface, write_probeinterface -from .base import BaseSegment +from .base import BaseSegment, BaseExtractor from .baserecordingsnippets import BaseRecordingSnippets from .core_tools import convert_bytes_to_str, convert_seconds_to_str -from .job_tools import split_job_kwargs +from .job_tools import split_job_kwargs, _shared_job_kwargs_doc from .recording_tools import write_binary_recording @@ -39,6 +39,41 @@ class BaseRecording(BaseRecordingSnippets): "noise_level_rms_scaled", ] + _save_to_folder_docs_params = """dtype: np.dtype | None, default: None + The dtype to use for saving the binary file. If None, the dtype of the recording is used. 
+""" + _shared_job_kwargs_doc + + _save_to_zarr_docs_params = """ +channel_chunk_size: int | None, default: None + Chunk size for the channel dimension. If None, no chunking is done on the channel dimension. +chunks: tuple | None, default: None + Chunks for the traces dataset. If None, no chunking is done. Note that sharding requires chunking to be specified + and that chunk dimensions need to be smaller than shard dimensions (if shards is not None). + If `chunks` is not None, it needs to be a tuple of length 2 with the chunk size for the time and channel + dimensions respectively and `channel_chunk_size` should not be specified. +shard_factor: int | None, default: None + If specified, the shard size will be set to chunk_size * shard_factor in the first dimension (time), + and to at most the total number of channels in the second dimension. Note that `shard_factor` cannot + be specified together with `shards`. +shards: tuple | None, default: None + Shard size. If None, no sharding is done. Note that shard dimensions need to be larger than + chunk dimensions (if chunks is not None) and that sharding is only done on the first dimension. +compressors: list[numcodecs.Codec] | None, default: None + Global compressor. 
If None, Blosc-zstd, level 5, with bit shuffle is used +filters: list[numcodecs.Codec] | None, default: None + Global filters for zarr (global) +compressors_by_dataset: dict | None, default: None + Optional compressor per dataset: + - traces + - times + If None, the global compressor is used +filters_by_dataset: dict | None, default: None + Optional filters per dataset: + - traces + - times + If None, the global filters are used +""" + _shared_job_kwargs_doc + def __init__(self, sampling_frequency: float, channel_ids: list, dtype): BaseRecordingSnippets.__init__( self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype @@ -581,8 +616,8 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): if format == "binary": folder = kwargs["folder"] - file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] dtype = kwargs.get("dtype", None) or self.get_dtype() + file_paths = [folder / f"traces_cached_seg{i}.raw" for i in range(self.get_num_segments())] t_starts = self._get_t_starts() write_binary_recording(self, file_paths=file_paths, dtype=dtype, verbose=verbose, **job_kwargs) @@ -893,6 +928,12 @@ def astype(self, dtype, round: bool | None = None): return astype(self, dtype=dtype, round=round) +BaseRecording.save_to_folder.__doc__ = BaseExtractor.save_to_folder.__doc__.format( + BaseRecording._save_to_folder_docs_params +) +BaseRecording.save_to_zarr.__doc__ = BaseExtractor.save_to_zarr.__doc__.format(BaseRecording._save_to_zarr_docs_params) + + class BaseRecordingSegment(BaseSegment): """ Abstract class representing a multichannel timeseries, or block of raw ephys traces diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8e16757bcc..6e26508c81 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -610,6 +610,7 @@ def _get_zarr_root(self, mode="r+"): assert mode in ("r+", "a", "r"), "mode 
must be 'r+', 'a' or 'r'" storage_options = self._backend_options.get("storage_options", {}) + zarr_root = super_zarr_open(self.folder, mode=mode, storage_options=storage_options) return zarr_root @@ -633,7 +634,12 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att storage_options = backend_options.get("storage_options", {}) saving_options = backend_options.get("saving_options", {}) - zarr_root = zarr.open(folder, mode="w", storage_options=storage_options) + if not is_path_remote(str(folder)): + storage_options_kwargs = {} + else: + storage_options_kwargs = {"storage_options": storage_options} + # storage options are only forwarded for remote paths, under zarr.open's `storage_options` keyword + zarr_root = zarr.open(folder, mode="w", **storage_options_kwargs) info = dict(version=spikeinterface.__version__, dev_mode=spikeinterface.DEV_MODE, object="SortingAnalyzer") zarr_root.attrs["spikeinterface_info"] = check_json(info) @@ -646,13 +652,8 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att if recording is not None: rec_dict = recording.to_dict(relative_to=relative_to, recursive=True) if recording.check_serializability("json"): - # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.JSON()) - zarr_rec = np.array([check_json(rec_dict)], dtype=object) - zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.JSON()) - elif recording.check_serializability("pickle"): - # zarr_root.create_dataset("recording", data=rec_dict, object_codec=numcodecs.Pickle()) - zarr_rec = np.array([rec_dict], dtype=object) - zarr_root.create_dataset("recording", data=zarr_rec, object_codec=numcodecs.Pickle()) + # In zarr v3, store JSON-serializable data in attributes instead of using object_codec + zarr_root.attrs["recording"] = check_json(rec_dict) else: warnings.warn("The Recording is not serializable! 
The recording link will be lost for future load") else: @@ -662,11 +663,8 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att # sorting provenance sort_dict = sorting.to_dict(relative_to=relative_to, recursive=True) if sorting.check_serializability("json"): - zarr_sort = np.array([check_json(sort_dict)], dtype=object) - zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.JSON()) - elif sorting.check_serializability("pickle"): - zarr_sort = np.array([sort_dict], dtype=object) - zarr_root.create_dataset("sorting_provenance", data=zarr_sort, object_codec=numcodecs.Pickle()) + # In zarr v3, store JSON-serializable data in attributes instead of using object_codec + zarr_root.attrs["sorting_provenance"] = check_json(sort_dict) else: warnings.warn( "The sorting provenance is not serializable! The sorting provenance link will be lost for future load" @@ -687,12 +685,13 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att recording_info.attrs["probegroup"] = check_json(probegroup.to_dict()) if sparsity is not None: - zarr_root.create_dataset("sparsity_mask", data=sparsity.mask, **saving_options) + zarr_root.create_array("sparsity_mask", data=sparsity.mask, **saving_options) add_sorting_to_zarr_group(sorting, zarr_root.create_group("sorting"), **saving_options) recording_info = zarr_root.create_group("extensions") + # consolidate metadata for zarr v3 zarr.consolidate_metadata(zarr_root.store) return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options) @@ -704,6 +703,10 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): backend_options = {} if backend_options is None else backend_options storage_options = backend_options.get("storage_options", {}) + if not is_path_remote(str(folder)): + storage_options_kwargs = {} + else: + storage_options_kwargs = {"storage_options": storage_options} # zarr.open takes the dict under this keyword, not as loose kwargs zarr_root = super_zarr_open(str(folder), mode="r", 
storage_options=storage_options) @@ -712,7 +715,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): # v0.101.0 did not have a consolidate metadata step after computing extensions. # Here we try to consolidate the metadata and throw a warning if it fails. try: - zarr_root_a = zarr.open(str(folder), mode="a", storage_options=storage_options) + zarr_root_a = zarr.open(str(folder), mode="a", **storage_options_kwargs) zarr.consolidate_metadata(zarr_root_a.store) except Exception as e: warnings.warn( @@ -730,9 +733,9 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): # load recording if possible if recording is None: - rec_field = zarr_root.get("recording") - if rec_field is not None: - rec_dict = rec_field[0] + # In zarr v3, recording is stored in attributes + rec_dict = zarr_root.attrs.get("recording", None) + if rec_dict is not None: try: recording = load(rec_dict, base_folder=folder) except: @@ -848,7 +851,7 @@ def set_sorting_property( if key in zarr_root["sorting"]["properties"]: zarr_root["sorting"]["properties"][key][:] = prop_values else: - zarr_root["sorting"]["properties"].create_dataset(name=key, data=prop_values, compressor=None) + zarr_root["sorting"]["properties"].create_array(name=key, data=prop_values, compressors=None) # IMPORTANT: we need to re-consolidate the zarr store! 
zarr.consolidate_metadata(zarr_root.store) @@ -1516,12 +1519,13 @@ def get_sorting_provenance(self): elif self.format == "zarr": zarr_root = self._get_zarr_root(mode="r") sorting_provenance = None - if "sorting_provenance" in zarr_root.keys(): + # In zarr v3, sorting_provenance is stored in attributes + sort_dict = zarr_root.attrs.get("sorting_provenance", None) + if sort_dict is not None: # try-except here is because it's not required to be able # to load the sorting provenance, as the user might have deleted # the original sorting folder try: - sort_dict = zarr_root["sorting_provenance"][0] sorting_provenance = load(sort_dict, base_folder=self.folder) except: pass @@ -1883,8 +1887,14 @@ def get_saved_extension_names(self): elif self.format == "zarr": zarr_root = self._get_zarr_root(mode="r") - if "extensions" in zarr_root.keys(): + # Avoid iterating zarr_root.keys() because legacy v2 stores may contain + # object-dtype arrays (e.g. "recording", "sorting_provenance") that zarr v3 + # cannot parse, causing ValueError on enumeration. 
+ try: extension_group = zarr_root["extensions"] + except KeyError: + extension_group = None + if extension_group is not None: for extension_name in extension_group.keys(): if "params" in extension_group[extension_name].attrs.keys(): saved_extension_names.append(extension_name) @@ -2564,8 +2574,9 @@ def load_data(self): extension_group = self._get_zarr_extension_group(mode="r") for ext_data_name in extension_group.keys(): ext_data_ = extension_group[ext_data_name] - if "dict" in ext_data_.attrs: - ext_data = ext_data_[0] + # In zarr v3, check if it's a group with dict_data attribute + if "dict_data" in ext_data_.attrs: + ext_data = ext_data_.attrs["dict_data"] elif "dataframe" in ext_data_.attrs: import pandas as pd @@ -2650,9 +2661,10 @@ def run(self, save=True, **kwargs): if self.format == "zarr": import zarr - zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root(mode="r+").store) def save(self): + self._reset_extension_folder() self._save_params() self._save_importing_provenance() self._save_run_info() @@ -2661,7 +2673,7 @@ def save(self): if self.format == "zarr": import zarr - zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root().store) + zarr.consolidate_metadata(self.sorting_analyzer._get_zarr_root(mode="r+").store) def _save_data(self): if self.format == "memory": @@ -2709,42 +2721,44 @@ def _save_data(self): extension_group = self._get_zarr_extension_group(mode="r+") # if compression is not externally given, we use the default - if "compressor" not in saving_options: - saving_options["compressor"] = get_default_zarr_compressor() + if "compressors" not in saving_options and "compressor" not in saving_options: + saving_options["compressors"] = get_default_zarr_compressor() + if "compressor" in saving_options: + saving_options["compressors"] = [saving_options["compressor"]] + del saving_options["compressor"] for ext_data_name, ext_data in self.data.items(): if 
ext_data_name in extension_group: del extension_group[ext_data_name] - if isinstance(ext_data, (dict, list)): - ext_data_ = check_json(ext_data) - extension_group.create_dataset( - name=ext_data_name, data=np.array([ext_data_], dtype=object), object_codec=numcodecs.JSON() - ) - extension_group[ext_data_name].attrs["dict"] = True + if isinstance(ext_data, dict): + # In zarr v3, store dict in a subgroup with attributes + dict_group = extension_group.create_group(ext_data_name) + dict_group.attrs["dict_data"] = check_json(ext_data) elif isinstance(ext_data, np.ndarray): - extension_group.create_dataset(name=ext_data_name, data=ext_data, **saving_options) + extension_group.create_array(name=ext_data_name, data=ext_data, **saving_options) elif HAS_PANDAS and isinstance(ext_data, pd.DataFrame): df_group = extension_group.create_group(ext_data_name) # first we save the index indices = ext_data.index.to_numpy() if indices.dtype.kind == "O": indices = indices.astype(str) - df_group.create_dataset(name="index", data=indices) + df_group.create_array(name="index", data=indices) for col in ext_data.columns: col_data = ext_data[col].to_numpy() if col_data.dtype.kind == "O": col_data = col_data.astype(str) - df_group.create_dataset(name=col, data=col_data) + df_group.create_array(name=col, data=col_data) df_group.attrs["dataframe"] = True else: # any object - try: - extension_group.create_dataset( - name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.Pickle() - ) - except: - raise Exception(f"Could not save {ext_data_name} as extension data") - extension_group[ext_data_name].attrs["object"] = True + # try: + # extension_group.create_array( + # name=ext_data_name, data=np.array([ext_data], dtype=object), object_codec=numcodecs.Pickle() + # ) + # except: + # raise Exception(f"Could not save {ext_data_name} as extension data") + # extension_group[ext_data_name].attrs["object"] = True + warnings.warn(f"Data type of {ext_data_name} not supported for 
zarr saving, skipping.") def _reset_extension_folder(self): """ @@ -2822,8 +2836,6 @@ def set_params(self, save=True, **params): def _save_params(self): params_to_save = self.params.copy() - self._reset_extension_folder() - # TODO make sparsity local Result specific # if "sparsity" in params_to_save and params_to_save["sparsity"] is not None: # assert isinstance( diff --git a/src/spikeinterface/core/template.py b/src/spikeinterface/core/template.py index 67ba1179b0..834002278f 100644 --- a/src/spikeinterface/core/template.py +++ b/src/spikeinterface/core/template.py @@ -319,17 +319,19 @@ def add_templates_to_zarr_group(self, zarr_group: "zarr.Group") -> None: """ # Saves one chunk per unit - arrays_chunk = (1, None, None) - zarr_group.create_dataset("templates_array", data=self.templates_array, chunks=arrays_chunk) - zarr_group.create_dataset("channel_ids", data=self.channel_ids) - zarr_group.create_dataset("unit_ids", data=self.unit_ids) + # In zarr v3, chunks must be a full tuple with actual dimensions + num_units, num_samples, num_channels = self.templates_array.shape + arrays_chunk = (1, num_samples, num_channels) + zarr_group.create_array("templates_array", data=self.templates_array, chunks=arrays_chunk) + zarr_group.create_array("channel_ids", data=self.channel_ids) + zarr_group.create_array("unit_ids", data=self.unit_ids) zarr_group.attrs["sampling_frequency"] = self.sampling_frequency zarr_group.attrs["nbefore"] = self.nbefore zarr_group.attrs["is_in_uV"] = self.is_in_uV if self.sparsity_mask is not None: - zarr_group.create_dataset("sparsity_mask", data=self.sparsity_mask) + zarr_group.create_array("sparsity_mask", data=self.sparsity_mask) if self.probe is not None: probe_group = zarr_group.create_group("probe") diff --git a/src/spikeinterface/core/testing_tools.py b/src/spikeinterface/core/testing_tools.py index 899aa3852f..0169d5e50e 100644 --- a/src/spikeinterface/core/testing_tools.py +++ b/src/spikeinterface/core/testing_tools.py @@ -1,7 +1,7 @@ 
import warnings warnings.warn( - "The 'testing_tools' submodule is deprecated. " "Use spikeinterface.core.generate instead", + "The 'testing_tools' submodule is deprecated. Use spikeinterface.core.testing instead", DeprecationWarning, stacklevel=2, ) diff --git a/src/spikeinterface/core/tests/test_analyzer_extension_core.py b/src/spikeinterface/core/tests/test_analyzer_extension_core.py index fd550c729e..3bdbf981d0 100644 --- a/src/spikeinterface/core/tests/test_analyzer_extension_core.py +++ b/src/spikeinterface/core/tests/test_analyzer_extension_core.py @@ -91,7 +91,7 @@ def test_ComputeRandomSpikes(format, sparse, create_cache_folder): print("Checking results") _check_result_extension(sorting_analyzer, "random_spikes", cache_folder) - print("Delering extension") + print("Deleting extension") sorting_analyzer.delete_extension("random_spikes") print("Re-computing random spikes") diff --git a/src/spikeinterface/core/tests/test_baserecording.py b/src/spikeinterface/core/tests/test_baserecording.py index 1ebeb677c6..a96a03f70d 100644 --- a/src/spikeinterface/core/tests/test_baserecording.py +++ b/src/spikeinterface/core/tests/test_baserecording.py @@ -361,7 +361,7 @@ def test_BaseRecording(create_cache_folder): # test save to zarr compressor = get_default_zarr_compressor() - rec_zarr = rec2.save(format="zarr", folder=cache_folder / "recording", compressor=compressor) + rec_zarr = rec2.save(format="zarr", folder=cache_folder / "recording", compressors=compressor) rec_zarr_loaded = load(cache_folder / "recording.zarr") # annotations is False because Zarr adds compression ratios check_recordings_equal(rec2, rec_zarr, return_in_uV=False, check_annotations=False, check_properties=True) @@ -373,7 +373,7 @@ def test_BaseRecording(create_cache_folder): assert rec2.get_annotation(annotation_name) == rec_zarr_loaded.get_annotation(annotation_name) rec_zarr2 = rec2.save( - format="zarr", folder=cache_folder / "recording_channel_chunk", compressor=compressor, 
channel_chunk_size=2 + format="zarr", folder=cache_folder / "recording_channel_chunk", compressors=compressor, channel_chunk_size=2 ) rec_zarr2_loaded = load(cache_folder / "recording_channel_chunk.zarr") diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..2912d4f5a1 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -17,6 +17,7 @@ AnalyzerExtension, _sort_extensions_by_dependency, ) +from spikeinterface.core.zarr_tools import check_compressors_match from spikeinterface.core.analyzer_extension_core import BaseSpikeVectorExtension # to test basespikevectorextension with node pipeline @@ -41,6 +42,8 @@ def get_dataset(): integer_unit_ids = [int(id) for id in sorting.get_unit_ids()] recording = recording.rename_channels(new_channel_ids=integer_channel_ids) + # make sure the recording is serializable + recording = recording.save() sorting = sorting.rename_units(new_unit_ids=integer_unit_ids) return recording, sorting @@ -136,13 +139,12 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): _check_sorting_analyzers(sorting_analyzer, sorting, cache_folder=tmp_path) # check that compression is applied - assert ( - sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressor.codec_id - == default_compressor.codec_id + check_compressors_match( + default_compressor, + sorting_analyzer._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressors[0], ) - assert ( - sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id - == default_compressor.codec_id + check_compressors_match( + default_compressor, sorting_analyzer._get_zarr_root()["extensions"]["templates"]["average"].compressors[0] ) # test select_units see https://github.com/SpikeInterface/spikeinterface/issues/3041 @@ -163,35 +165,34 @@ def 
test_SortingAnalyzer_zarr(tmp_path, dataset): sparsity=None, return_in_uV=False, overwrite=True, - backend_options={"saving_options": {"compressor": None}}, + backend_options={"saving_options": {"compressors": None}}, ) print(sorting_analyzer_no_compression._backend_options) sorting_analyzer_no_compression.compute(["random_spikes", "templates"]) assert ( - sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ - "random_spikes_indices" - ].compressor - is None + len( + sorting_analyzer_no_compression._get_zarr_root()["extensions"]["random_spikes"][ + "random_spikes_indices" + ].compressors + ) + == 0 ) - assert sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressor is None + assert len(sorting_analyzer_no_compression._get_zarr_root()["extensions"]["templates"]["average"].compressors) == 0 # test a different compressor - from numcodecs import LZMA + from zarr.codecs.numcodecs import LZMA lzma_compressor = LZMA() folder = tmp_path / "test_SortingAnalyzer_zarr_lzma.zarr" sorting_analyzer_lzma = sorting_analyzer_no_compression.save_as( - format="zarr", folder=folder, backend_options={"saving_options": {"compressor": lzma_compressor}} + format="zarr", folder=folder, backend_options={"saving_options": {"compressors": lzma_compressor}} ) - assert ( - sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"][ - "random_spikes_indices" - ].compressor.codec_id - == LZMA.codec_id + check_compressors_match( + lzma_compressor, + sorting_analyzer_lzma._get_zarr_root()["extensions"]["random_spikes"]["random_spikes_indices"].compressors[0], ) - assert ( - sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressor.codec_id - == LZMA.codec_id + check_compressors_match( + lzma_compressor, sorting_analyzer_lzma._get_zarr_root()["extensions"]["templates"]["average"].compressors[0] ) # test set_sorting_property diff --git 
a/src/spikeinterface/core/tests/test_zarr_backwards_compat.py b/src/spikeinterface/core/tests/test_zarr_backwards_compat.py new file mode 100644 index 0000000000..49d99ce8eb --- /dev/null +++ b/src/spikeinterface/core/tests/test_zarr_backwards_compat.py @@ -0,0 +1,84 @@ +""" +Tests for zarr format backward compatibility. + +Loads zarr v2 fixtures (generated with spikeinterface==0.104.0 and zarr<3) using +the current spikeinterface version, which uses zarr>=3. + +The fixtures directory is passed via the ZARR_V2_FIXTURES_PATH environment variable. +These tests are skipped when that variable is not set (i.e. in normal CI runs). + +To run locally: + ZARR_V2_FIXTURES_PATH=/tmp/zarr_v2_fixtures pytest test_zarr_backwards_compat.py -v +""" + +import json +import os +from pathlib import Path + +import numpy as np +import pytest + +import spikeinterface as si + +FIXTURES_PATH = os.environ.get("ZARR_V2_FIXTURES_PATH") + +pytestmark = pytest.mark.skipif( + FIXTURES_PATH is None, + reason="ZARR_V2_FIXTURES_PATH environment variable not set", +) + + +@pytest.fixture(scope="module") +def fixtures_dir() -> Path: + return Path(FIXTURES_PATH) + + +@pytest.fixture(scope="module") +def expected(fixtures_dir: Path) -> dict: + with open(fixtures_dir / "expected_values.json") as f: + return json.load(f) + + +def test_load_recording(fixtures_dir, expected): + recording = si.load(fixtures_dir / "recording.zarr") + exp = expected["recording"] + + assert recording.get_num_channels() == exp["num_channels"] + assert recording.get_num_segments() == exp["num_segments"] + assert recording.get_sampling_frequency() == exp["sampling_frequency"] + assert str(recording.get_dtype()) == exp["dtype"] + + for seg in range(recording.get_num_segments()): + assert recording.get_num_samples(seg) == exp["num_samples_per_segment"][seg] + + assert list(recording.get_channel_ids()) == exp["channel_ids"] + + traces = recording.get_traces(start_frame=0, end_frame=10, segment_index=0) + 
np.testing.assert_array_equal(traces, np.array(exp["traces_seg0_first10"])) + + +def test_load_sorting(fixtures_dir, expected): + sorting = si.load(fixtures_dir / "sorting.zarr") + exp = expected["sorting"] + + assert sorting.get_num_segments() == exp["num_segments"] + assert sorting.get_sampling_frequency() == exp["sampling_frequency"] + assert list(sorting.get_unit_ids()) == exp["unit_ids"] + + for uid in sorting.unit_ids: + spike_train = sorting.get_unit_spike_train(unit_id=uid, segment_index=0) + np.testing.assert_array_equal(spike_train, np.array(exp["spike_trains_seg0"][str(uid)])) + + +def test_load_sorting_analyzer(fixtures_dir, expected): + analyzer = si.load(fixtures_dir / "analyzer.zarr") + exp = expected["analyzer"] + + assert analyzer.get_num_units() == exp["num_units"] + assert analyzer.get_num_channels() == exp["num_channels"] + + templates_ext = analyzer.get_extension("templates") + assert templates_ext is not None, "templates extension not found in analyzer" + + templates = templates_ext.get_data() + assert list(templates.shape) == exp["templates_shape"] diff --git a/src/spikeinterface/core/tests/test_zarrextractors.py b/src/spikeinterface/core/tests/test_zarrextractors.py index cc0c60721e..0eedb645c8 100644 --- a/src/spikeinterface/core/tests/test_zarrextractors.py +++ b/src/spikeinterface/core/tests/test_zarrextractors.py @@ -10,50 +10,56 @@ generate_sorting, load, ) -from spikeinterface.core.zarrextractors import add_sorting_to_zarr_group, get_default_zarr_compressor +from spikeinterface.core.testing import check_recordings_equal +from spikeinterface.core.zarr_tools import check_compressors_match +from spikeinterface.core.zarrextractors import ( + add_sorting_to_zarr_group, + get_default_zarr_compressor, +) def test_zarr_compression_options(tmp_path): - from numcodecs import Blosc, Delta, FixedScaleOffset + from zarr.codecs.numcodecs import Delta, FixedScaleOffset + from zarr.codecs import BloscCodec, BloscShuffle recording = 
generate_recording(durations=[2]) recording.set_times(recording.get_times() + 100) # store in root standard normal way # default compressor - defaut_compressor = get_default_zarr_compressor() + default_compressor = get_default_zarr_compressor() # other compressor - other_compressor1 = Blosc(cname="zlib", clevel=3, shuffle=Blosc.NOSHUFFLE) - other_compressor2 = Blosc(cname="blosclz", clevel=8, shuffle=Blosc.AUTOSHUFFLE) + other_compressor1 = BloscCodec(cname="zlib", clevel=3, shuffle=BloscShuffle.noshuffle) + other_compressor2 = BloscCodec(cname="blosclz", clevel=8, shuffle=BloscShuffle.shuffle) # timestamps compressors / filters default_filters = None - other_filters1 = [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype())] + other_filters1 = [FixedScaleOffset(scale=5, offset=2, dtype=recording.get_dtype().str)] other_filters2 = [Delta(dtype="float64")] # default ZarrRecordingExtractor.write_recording(recording, tmp_path / "rec_default.zarr") rec_default = ZarrRecordingExtractor(tmp_path / "rec_default.zarr") - assert rec_default._root["traces_seg0"].compressor == defaut_compressor - assert rec_default._root["traces_seg0"].filters == default_filters - assert rec_default._root["times_seg0"].compressor == defaut_compressor - assert rec_default._root["times_seg0"].filters == default_filters + check_compressors_match(rec_default._root["traces_seg0"].compressors[0], default_compressor) + check_compressors_match(rec_default._root["times_seg0"].compressors[0], default_compressor) + check_compressors_match(rec_default._root["traces_seg0"].filters, default_filters) + check_compressors_match(rec_default._root["times_seg0"].filters, default_filters) # now with other compressor ZarrRecordingExtractor.write_recording( recording, tmp_path / "rec_other.zarr", - compressor=defaut_compressor, + compressors=default_compressor, filters=default_filters, compressor_by_dataset={"traces": other_compressor1, "times": other_compressor2}, filters_by_dataset={"traces": 
other_filters1, "times": other_filters2}, ) rec_other = ZarrRecordingExtractor(tmp_path / "rec_other.zarr") - assert rec_other._root["traces_seg0"].compressor == other_compressor1 - assert rec_other._root["traces_seg0"].filters == other_filters1 - assert rec_other._root["times_seg0"].compressor == other_compressor2 - assert rec_other._root["times_seg0"].filters == other_filters2 + check_compressors_match(rec_other._root["traces_seg0"].compressors[0], other_compressor1) + check_compressors_match(rec_other._root["traces_seg0"].filters, other_filters1) + check_compressors_match(rec_other._root["times_seg0"].compressors[0], other_compressor2) + check_compressors_match(rec_other._root["times_seg0"].filters, other_filters2) def test_ZarrSortingExtractor(tmp_path): @@ -75,6 +81,46 @@ def test_ZarrSortingExtractor(tmp_path): sorting = load(sorting.to_dict()) +def test_sharding_options(tmp_path): + recording = generate_recording(durations=[10], num_channels=20) + folder = tmp_path / "zarr_sharding" + + # explicitly specify chunks and shards + ZarrRecordingExtractor.write_recording(recording, folder, chunks=(1000, 5), shards=(5000, 10), n_jobs=2) + recording_zarr = ZarrRecordingExtractor(folder) + assert recording_zarr._root["traces_seg0"].chunks == (1000, 5) + assert recording_zarr._root["traces_seg0"].shards == (5000, 10) + check_recordings_equal(recording, recording_zarr) + + # specify shard_factor and chunk_size + folder = tmp_path / "zarr_sharding_factor" + ZarrRecordingExtractor.write_recording( + recording, folder, chunk_size=1000, channel_chunk_size=2, shard_factor=5, n_jobs=2 + ) + recording_zarr = ZarrRecordingExtractor(folder) + assert recording_zarr._root["traces_seg0"].chunks == (1000, 2) + assert recording_zarr._root["traces_seg0"].shards == (5000, 10) + check_recordings_equal(recording, recording_zarr) + + # raise error if both shards and shard_factor are provided + with pytest.raises(ValueError): + ZarrRecordingExtractor.write_recording( + recording, folder, 
chunk_size=1000, channel_chunk_size=2, shard_factor=5, shards=(5000, 10), n_jobs=2 + ) + + # raise error if shards is smaller than chunks + with pytest.raises(AssertionError): + ZarrRecordingExtractor.write_recording( + recording, folder, chunk_size=1000, channel_chunk_size=2, shards=(500, 10), n_jobs=2 + ) + + # raise error if shards is not a multiple of chunks + with pytest.raises(AssertionError): + ZarrRecordingExtractor.write_recording( + recording, folder, chunk_size=1000, channel_chunk_size=2, shards=(5500, 10), n_jobs=2 + ) + + if __name__ == "__main__": tmp_path = Path("tmp") test_zarr_compression_options(tmp_path) diff --git a/src/spikeinterface/core/zarr_tools.py b/src/spikeinterface/core/zarr_tools.py new file mode 100644 index 0000000000..d97630535e --- /dev/null +++ b/src/spikeinterface/core/zarr_tools.py @@ -0,0 +1,26 @@ +def check_compressors_match(comp1, comp2, skip_typesize=True): + """ + Check if two compressor objects match. + + Parameters + ---------- + comp1 : zarr.Codec | Tuple[zarr.Codec] + The first compressor object to compare. + comp2 : zarr.Codec | Tuple[zarr.Codec] + The second compressor object to compare. 
+ skip_typesize : bool, optional + Whether to skip the typesize check, default: True + """ + if not isinstance(comp1, (list, tuple)): + assert not isinstance(comp2, list) + comp1 = [comp1] + comp2 = [comp2] + for i in range(len(comp1)): + comp1_dict = comp1[i].to_dict() + comp2_dict = comp2[i].to_dict() + if skip_typesize: + if "typesize" in comp1_dict["configuration"]: + comp1_dict["configuration"].pop("typesize", None) + if "typesize" in comp2_dict["configuration"]: + comp2_dict["configuration"].pop("typesize", None) + assert comp1_dict == comp2_dict, f"Compressor {i} does not match: {comp1_dict} != {comp2_dict}" diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 1ef5d76e5a..ebe1241d3d 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -8,9 +8,10 @@ from .base import minimum_spike_dtype, _get_class_from_string from .baserecording import BaseRecording, BaseRecordingSegment from .basesorting import BaseSorting, SpikeVectorSortingSegment -from .core_tools import define_function_from_class, check_json, retrieve_importing_provenance -from .job_tools import split_job_kwargs -from .core_tools import is_path_remote +from .core_tools import define_function_from_class, check_json, is_path_remote, retrieve_importing_provenance +from .job_tools import split_job_kwargs, fix_job_kwargs, ensure_chunk_size, ChunkRecordingExecutor + +zarr.config.set({"default_zarr_version": 3}) def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: dict | None = None): @@ -36,7 +37,7 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d Returns ------- - root: zarr.hierarchy.Group + root: zarr.Group The zarr root group object Raises @@ -47,11 +48,13 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d import zarr # if mode is append or read/write, we try to open the folder with zarr.open - # since 
zarr.open_consolidated does not support creating new groups/datasets + # In zarr v3, we use use_consolidated parameter instead of open_consolidated if mode in ("a", "r+"): open_funcs = (zarr.open,) + use_consolidated_options = (False,) else: - open_funcs = (zarr.open_consolidated, zarr.open) + open_funcs = (zarr.open,) + use_consolidated_options = (True, False) # if storage_options is None, we try to open the folder with and without anonymous access # if storage_options is not None, we try to open the folder with the given storage options @@ -63,12 +66,14 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d root = None exception = None if is_path_remote(str(folder_path)): - for open_func in open_funcs: + for use_consolidated in use_consolidated_options: if root is not None: break for storage_options in storage_options_to_test: try: - root = open_func(str(folder_path), mode=mode, storage_options=storage_options) + root = zarr.open( + str(folder_path), mode=mode, storage_options=storage_options, use_consolidated=use_consolidated + ) break except Exception as e: exception = e @@ -76,9 +81,9 @@ def super_zarr_open(folder_path: str | Path, mode: str = "r", storage_options: d else: if not Path(folder_path).is_dir(): raise ValueError(f"Folder {folder_path} does not exist") - for open_func in open_funcs: + for use_consolidated in use_consolidated_options: try: - root = open_func(str(folder_path), mode=mode, storage_options=storage_options) + root = zarr.open(str(folder_path), mode=mode, use_consolidated=use_consolidated) break except Exception as e: exception = e @@ -129,7 +134,8 @@ def __init__( assert sampling_frequency is not None, "'sampling_frequency' attiribute not found!" assert num_segments is not None, "'num_segments' attiribute not found!" - channel_ids = np.array(channel_ids) + # zarr returns vlen-utf8 as StringDType (numpy 2.0); convert via list to classic unicode array. 
+ channel_ids = np.array(channel_ids.tolist()) dtype = self._root["traces_seg0"].dtype @@ -167,7 +173,7 @@ def __init__( if load_compression_ratio: nbytes_segment = self._root[trace_name].nbytes - nbytes_stored_segment = self._root[trace_name].nbytes_stored + nbytes_stored_segment = self._root[trace_name].nbytes_stored() if nbytes_stored_segment > 0: cr_by_segment[segment_index] = nbytes_segment / nbytes_stored_segment else: @@ -186,7 +192,11 @@ def __init__( if "properties" in self._root: prop_group = self._root["properties"] for key in prop_group.keys(): - values = self._root["properties"][key] + values = self._root["properties"][key][:] + # zarr returns vlen-utf8 as StringDType (numpy 2.0); convert via list to classic unicode array. + if hasattr(values.dtype, "na_object") or values.dtype.kind == "O": + if values.size > 0 and isinstance(values.tolist()[0], str): + values = np.array(values.tolist()) self.set_property(key, values) # load annotations @@ -289,7 +299,7 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, BaseSorting.__init__(self, sampling_frequency, unit_ids) - spikes = np.zeros(len(spikes_group["sample_index"]), dtype=minimum_spike_dtype) + spikes = np.zeros(spikes_group["sample_index"].shape[0], dtype=minimum_spike_dtype) spikes["sample_index"] = spikes_group["sample_index"][:] spikes["unit_index"] = spikes_group["unit_index"][:] for i, (start, end) in enumerate(segment_slices_list): @@ -305,7 +315,11 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, if "properties" in self._root: prop_group = self._root["properties"] for key in prop_group.keys(): - values = self._root["properties"][key] + values = self._root["properties"][key][:] + # zarr returns vlen-utf8 as StringDType (numpy 2.0); convert via list to classic unicode array. 
+ if hasattr(values.dtype, "na_object") or values.dtype.kind == "O": + if values.size > 0 and isinstance(values.tolist()[0], str): + values = np.array(values.tolist()) self.set_property(key, values) # load annotations @@ -402,12 +416,86 @@ def get_default_zarr_compressor(clevel: int = 5): Blosc.compressor The compressor object that can be used with the save to zarr function """ - from numcodecs import Blosc + from zarr.codecs import BloscCodec, BloscShuffle + + return BloscCodec(cname="zstd", clevel=clevel, shuffle=BloscShuffle.bitshuffle) + + +def build_codec_pipeline(filters=None, compressors=None): + """ + Build zarr v3 codec kwargs from filters and compressors. + + Classifies codecs into the three slots accepted by ``zarr.Group.create_array()``: + 1. ``filters`` — ArrayArrayCodec (e.g. Delta) + 2. ``serializer`` — ArrayBytesCodec (e.g. WavPack, BytesCodec) + 3. ``compressors``— BytesBytesCodec (e.g. BloscCodec, ZstdCodec) - return Blosc(cname="zstd", clevel=clevel, shuffle=Blosc.BITSHUFFLE) + This allows callers to pass an ArrayBytesCodec (e.g. WavPack) as a + compressor and have it placed in the correct serializer slot automatically. + + Parameters + ---------- + filters : ArrayArrayCodec or list of ArrayArrayCodec or None + Codec(s) applied before serialization. + compressors : codec or list of codecs or None + Can be a mix of ArrayBytesCodec (serializer) and BytesBytesCodec + (byte-level compressors). At most one ArrayBytesCodec is allowed. + + Returns + ------- + dict + Keyword arguments to unpack into ``zarr.Group.create_array()``. + Only keys with explicit values are included; omitted keys let zarr + use its defaults. + + Raises + ------ + ValueError + If filters contain non-ArrayArrayCodec instances, if more than one + ArrayBytesCodec is provided, or if an unrecognised codec type is passed. 
+ """ + from zarr.abc.codec import ArrayArrayCodec, ArrayBytesCodec, BytesBytesCodec + + if filters is None: + filters = [] + if not isinstance(filters, (list, tuple)): + filters = [filters] + + if compressors is None: + compressors = [] + if not isinstance(compressors, (list, tuple)): + compressors = [compressors] + + for f in filters: + if not isinstance(f, ArrayArrayCodec): + raise ValueError(f"All filters must be ArrayArrayCodec instances, got {type(f)}") + + serializers = [c for c in compressors if isinstance(c, ArrayBytesCodec)] + byte_compressors = [c for c in compressors if isinstance(c, BytesBytesCodec)] + invalid = [c for c in compressors if not isinstance(c, (ArrayBytesCodec, BytesBytesCodec))] + + if invalid: + raise ValueError( + f"Compressors must be ArrayBytesCodec or BytesBytesCodec instances, got {[type(c) for c in invalid]}" + ) + if len(serializers) > 1: + raise ValueError("Only one ArrayBytesCodec (serializer) is allowed in the codec pipeline.") + codec_kwargs = {} + codec_kwargs["filters"] = filters + codec_kwargs["serializer"] = serializers[0] if len(serializers) == 1 else "auto" + codec_kwargs["compressors"] = byte_compressors + return codec_kwargs -def add_properties_and_annotations(zarr_group: zarr.hierarchy.Group, recording_or_sorting: BaseRecording | BaseSorting): + +def _has_string_fields(dtype: np.dtype) -> bool: + """Return True if dtype is or contains fixed-length unicode (U) sub-fields.""" + if dtype.names: + return any(_has_string_fields(dtype.fields[name][0]) for name in dtype.names) + return dtype.kind == "U" + + +def add_properties_and_annotations(zarr_group: zarr.Group, recording_or_sorting: BaseRecording | BaseSorting): # save properties prop_group = zarr_group.create_group("properties") for key in recording_or_sorting.get_property_keys(): @@ -415,13 +503,26 @@ def add_properties_and_annotations(zarr_group: zarr.hierarchy.Group, recording_o if values.dtype.kind == "O": warnings.warn(f"Property {key} not saved because it is a 
python Object type") continue - prop_group.create_dataset(name=key, data=values, compressor=None) + if values.dtype.names and _has_string_fields(values.dtype): + # Structured arrays with unicode sub-fields have no stable zarr v3 spec; skip them. + # Probe geometry (contact_vector) is already persisted via zarr_group.attrs["probe"]. + warnings.warn( + f"Property '{key}' not saved because it is a structured array with unicode fields, " + "which do not have a stable zarr V3 specification." + ) + continue + # Use variable-length UTF-8 (stable zarr v3 spec) for unicode arrays. + if values.dtype.kind == "U": + arr = prop_group.create_array(name=key, shape=values.shape, dtype=str, compressors=None) + arr[:] = values + else: + prop_group.create_array(name=key, data=values, compressors=None) # save annotations zarr_group.attrs["annotations"] = check_json(recording_or_sorting._annotations) -def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.Group, **kwargs): +def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.Group, **kwargs): """ Add a sorting extractor to a zarr group. 
@@ -429,46 +530,54 @@ def add_sorting_to_zarr_group(sorting: BaseSorting, zarr_group: zarr.hierarchy.G ---------- sorting : BaseSorting The sorting extractor object to be added to the zarr group - zarr_group : zarr.hierarchy.Group + zarr_group : zarr.Group The zarr group kwargs : dict Other arguments passed to the zarr compressor """ - from numcodecs import Delta + from zarr.codecs.numcodecs import Delta num_segments = sorting.get_num_segments() zarr_group.attrs["sampling_frequency"] = float(sorting.sampling_frequency) zarr_group.attrs["num_segments"] = int(num_segments) - zarr_group.create_dataset(name="unit_ids", data=sorting.unit_ids, compressor=None) + zarr_group.create_array(name="unit_ids", data=sorting.unit_ids, compressors=None) - compressor = kwargs.get("compressor", get_default_zarr_compressor()) + compressor = kwargs.get("compressors") or kwargs.get("compressor") + if compressor is None: + compressor = get_default_zarr_compressor() - # save sub fields + # Save sub fields of spikes as separate arrays to allow for more efficient compression and to + # avoid issues with structured arrays with unicode fields in zarr v3. + # The "segment_index" field is saved as "segment_slices" which contains the start and end indices of spikes for + # each segment, to avoid having a large array of segment indices when there are many spikes. 
spikes_group = zarr_group.create_group(name="spikes") spikes = sorting.to_spike_vector() for field in spikes.dtype.fields: if field != "segment_index": - spikes_group.create_dataset( - name=field, - data=spikes[field], - compressor=compressor, - filters=[Delta(dtype=spikes[field].dtype)], - ) + dtype = spikes[field].dtype + spikes_data = spikes[field] + if field == "sample_index": + # Delta filter is very effective for spike times (sample_index) + filters = [Delta(dtype=spikes[field].dtype.str)] + else: + filters = None + codec_kwargs = build_codec_pipeline(filters=filters, compressors=compressor) + spikes_group.create_array(name=field, data=spikes_data, **codec_kwargs) else: segment_slices = [] for segment_index in range(num_segments): i0, i1 = np.searchsorted(spikes["segment_index"], [segment_index, segment_index + 1]) segment_slices.append([i0, i1]) - spikes_group.create_dataset(name="segment_slices", data=segment_slices, compressor=None) + segment_slices = np.array(segment_slices, dtype="int64") + spikes_group.create_array(name="segment_slices", data=segment_slices, compressors=None) add_properties_and_annotations(zarr_group, sorting) # Recording -def add_recording_to_zarr_group( - recording: BaseRecording, zarr_group: zarr.hierarchy.Group, verbose=False, dtype=None, **kwargs -): +def add_recording_to_zarr_group(recording: BaseRecording, zarr_group: zarr.Group, verbose=False, dtype=None, **kwargs): zarr_kwargs, job_kwargs = split_job_kwargs(kwargs) + job_kwargs = fix_job_kwargs(job_kwargs) if recording.check_if_json_serializable(): zarr_group.attrs["provenance"] = check_json(recording.to_dict(recursive=True)) @@ -478,26 +587,67 @@ def add_recording_to_zarr_group( # save data (done the subclass) zarr_group.attrs["sampling_frequency"] = float(recording.get_sampling_frequency()) zarr_group.attrs["num_segments"] = int(recording.get_num_segments()) - zarr_group.create_dataset(name="channel_ids", data=recording.get_channel_ids(), compressor=None) + # Use 
variable-length UTF-8 (stable zarr v3 spec) instead of fixed-length unicode. + channel_ids = recording.get_channel_ids() + arr = zarr_group.create_array(name="channel_ids", data=channel_ids, compressors=None) dataset_paths = [f"traces_seg{i}" for i in range(recording.get_num_segments())] + num_channels = recording.get_num_channels() dtype = recording.get_dtype() if dtype is None else dtype - channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None) - global_compressor = zarr_kwargs.pop("compressor", get_default_zarr_compressor()) + + # Compressors and filters + global_compressor = kwargs.get("compressors") or kwargs.get("compressor") + if global_compressor is None: + global_compressor = get_default_zarr_compressor() compressor_by_dataset = zarr_kwargs.pop("compressor_by_dataset", {}) global_filters = zarr_kwargs.pop("filters", None) filters_by_dataset = zarr_kwargs.pop("filters_by_dataset", {}) - compressor_traces = compressor_by_dataset.get("traces", global_compressor) filters_traces = filters_by_dataset.get("traces", global_filters) + + # Chunking and sharding + chunks = zarr_kwargs.get("chunks", None) + channel_chunk_size = zarr_kwargs.get("channel_chunk_size", None) + shards = zarr_kwargs.get("shards", None) + shard_factor = zarr_kwargs.get("shard_factor", None) + if shards is not None and shard_factor is not None: + raise ValueError("Cannot specify both 'shards' and 'shard_factor' in zarr_kwargs") + if chunks is not None and channel_chunk_size is not None: + raise ValueError("Cannot specify both 'chunks' and 'channel_chunk_size' in zarr_kwargs") + + # If not specified by chunk, we set the chunk size in the first dimension (time) to be the chunk size that we use + # for the job executor, and the chunk size in the second dimension (channels) to be either the provided + # channel_chunk_size or the total number of channels (no chunking in channels). 
+ if chunks is not None: + job_kwargs["chunk_size"] = chunks[0] + else: + chunk_size = ensure_chunk_size(recording, **job_kwargs) + chunks = (chunk_size, channel_chunk_size if channel_chunk_size is not None else num_channels) + + if shards is not None: + assert len(shards) == len(chunks), "Shards and chunks must have the same number of dimensions" + for dim in range(len(chunks)): + assert ( + shards[dim] >= chunks[dim] and shards[dim] % chunks[dim] == 0 + ), "Shard size must be a multiple of chunk size" + # When sharding is used, chunk_size in job_kwargs is used to determine the number of samples per chunk to + # write in each job. Each process will write all chunks in a shard. + job_kwargs["chunk_size"] = shards[0] + elif shard_factor is not None: + # If shard_factor is provided, we set the shard size to be chunk_size * shard_factor in the first dimension (time), + # and to be the at most the total number of channels in the second dimension. + shards = (chunks[0] * shard_factor, min(chunks[1] * shard_factor, num_channels)) + job_kwargs["chunk_size"] = shards[0] + add_traces_to_zarr( recording=recording, zarr_group=zarr_group, dataset_paths=dataset_paths, - compressor=compressor_traces, + compressors=compressor_traces, filters=filters_traces, dtype=dtype, - channel_chunk_size=channel_chunk_size, + chunks=chunks, + shards=shards, verbose=verbose, **job_kwargs, ) @@ -517,17 +667,13 @@ def add_recording_to_zarr_group( filters_times = filters_by_dataset.get("times", global_filters) if time_vector is not None: - _ = zarr_group.create_dataset( - name=f"times_seg{segment_index}", - data=time_vector, - filters=filters_times, - compressor=compressor_times, - ) + codec_kwargs = build_codec_pipeline(filters=filters_times, compressors=compressor_times) + zarr_group.create_array(name=f"times_seg{segment_index}", data=time_vector, **codec_kwargs) elif d["t_start"] is not None: t_starts[segment_index] = d["t_start"] if np.any(~np.isnan(t_starts)): - 
zarr_group.create_dataset(name="t_starts", data=t_starts, compressor=None) + zarr_group.create_array(name="t_starts", data=t_starts, compressors=None) add_properties_and_annotations(zarr_group, recording) @@ -536,9 +682,10 @@ def add_traces_to_zarr( recording, zarr_group, dataset_paths, - channel_chunk_size=None, + chunks=None, + shards=None, dtype=None, - compressor=None, + compressors=None, filters=None, verbose=False, **job_kwargs, @@ -554,11 +701,13 @@ def add_traces_to_zarr( The zarr group to add traces to dataset_paths : list List of paths to traces datasets in the zarr group - channel_chunk_size : int or None, default: None (chunking in time only) + chunks : tuple or None, default: None (chunking in time only) Channels per chunk + shards : tuple or None, default: None + If not None, the shard shape as a tuple (samples_per_shard, channels_per_shard), forwarded to `create_array` dtype : dtype, default: None Type of the saved data - compressor : zarr compressor or None, default: None + compressors : zarr compressor or None, default: None Zarr compressor filters : list, default: None List of zarr filters @@ -566,12 +715,6 @@ def add_traces_to_zarr( If True, output is verbose (when chunks are used) {} """ - from .job_tools import ( - ensure_chunk_size, - fix_job_kwargs, - ChunkRecordingExecutor, - ) - assert dataset_paths is not None, "Provide 'file_path'" if not isinstance(dataset_paths, list): @@ -581,8 +724,7 @@ def add_traces_to_zarr( if dtype is None: dtype = recording.get_dtype() - job_kwargs = fix_job_kwargs(job_kwargs) - chunk_size = ensure_chunk_size(recording, **job_kwargs) + codec_kwargs = build_codec_pipeline(filters=filters, compressors=compressors) # create zarr datasets files zarr_datasets = [] @@ -591,13 +733,9 @@ def add_traces_to_zarr( num_channels = recording.get_num_channels() dset_name = dataset_paths[segment_index] shape = (num_frames, num_channels) - dset = zarr_group.create_dataset( - name=dset_name, - shape=shape, - chunks=(chunk_size, channel_chunk_size), - dtype=dtype, - filters=filters, 
- compressor=compressor, + # In zarr v3, chunks must be a tuple of integers (no None allowed) + dset = zarr_group.create_array( + name=dset_name, shape=shape, chunks=chunks, shards=shards, dtype=dtype, **codec_kwargs ) zarr_datasets.append(dset) # synchronizer=zarr.ThreadSynchronizer()) diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index b89999d088..c5c7639209 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -307,8 +307,8 @@ def _get_backend_from_local_file(file_path: str | Path) -> str: try: import zarr - with zarr.open(file_path, "r") as f: - backend = "zarr" + _ = zarr.open(file_path, mode="r") + backend = "zarr" except: raise RuntimeError(f"{file_path} is not a valid Zarr folder!") else: @@ -333,7 +333,8 @@ def _find_neurodata_type_from_backend(group, path="", result=None, neurodata_typ if result is None: result = [] - for neurodata_name, value in group.items(): + for neurodata_name in group.keys(): + value = group[neurodata_name] # Check if it's a group and if it has the neurodata_type if isinstance(value, group_class): current_path = f"{path}/{neurodata_name}" if path else neurodata_name @@ -1409,7 +1410,8 @@ def _find_timeseries_from_backend(group, path="", result=None, backend="hdf5"): if result is None: result = [] - for name, value in group.items(): + for name in group.keys(): + value = group[name] if isinstance(value, group_class): current_path = f"{path}/{name}" if path else name if value.attrs.get("neurodata_type") == "TimeSeries": diff --git a/src/spikeinterface/preprocessing/tests/test_scaling.py b/src/spikeinterface/preprocessing/tests/test_scaling.py index a19d116b16..27f1de8542 100644 --- a/src/spikeinterface/preprocessing/tests/test_scaling.py +++ b/src/spikeinterface/preprocessing/tests/test_scaling.py @@ -1,6 +1,6 @@ import pytest import numpy as np -from spikeinterface.core.testing_tools import generate_recording 
+from spikeinterface.core.generate import generate_recording from spikeinterface.preprocessing.preprocessing_classes import scale_to_uV, CenterRecording, scale_to_physical_units