diff --git a/.gitignore b/.gitignore
index dafbf664..9a6001e3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -52,3 +52,4 @@
 node_modules/
 .mypy_cache
 .ruff_cache
+uv.lock
diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py
index 2fb48350..cc94d6f6 100644
--- a/src/spatialdata/__init__.py
+++ b/src/spatialdata/__init__.py
@@ -40,6 +40,7 @@
     "deepcopy",
     "sanitize_table",
     "sanitize_name",
+    "settings",
 ]
 
 from spatialdata import dataloader, datasets, models, transformations
@@ -70,3 +71,4 @@
 from spatialdata._io.format import SpatialDataFormatType
 from spatialdata._io.io_zarr import read_zarr
 from spatialdata._utils import get_pyramid_levels, unpad_raster
+from spatialdata.config import settings
diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py
index d16ea378..b251548a 100644
--- a/src/spatialdata/_core/spatialdata.py
+++ b/src/spatialdata/_core/spatialdata.py
@@ -1110,6 +1110,7 @@ def write(
         consolidate_metadata: bool = True,
         update_sdata_path: bool = True,
         sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
+        shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
     ) -> None:
         """
         Write the `SpatialData` object to a Zarr store.
@@ -1154,6 +1155,9 @@
             unspecified, the element formats will be set to the latest element format compatible with the
             specified SpatialData container format. All the formats and relationships between them are defined in
             `spatialdata._io.format.py`.
+        shapes_geometry_encoding
+            Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet`
+            for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`.
         """
         from spatialdata._io._utils import _resolve_zarr_store
         from spatialdata._io.format import _parse_formats
@@ -1179,6 +1183,7 @@
                     element_name=element_name,
                     overwrite=False,
                     parsed_formats=parsed,
+                    shapes_geometry_encoding=shapes_geometry_encoding,
                 )
 
         if self.path != file_path and update_sdata_path:
@@ -1195,6 +1200,7 @@ def _write_element(
         element_name: str,
         overwrite: bool,
         parsed_formats: dict[str, SpatialDataFormatType] | None = None,
+        shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
     ) -> None:
         from spatialdata._io.io_zarr import _get_groups_for_element
@@ -1247,6 +1253,7 @@
                 shapes=element,
                 group=element_group,
                 element_format=parsed_formats["shapes"],
+                geometry_encoding=shapes_geometry_encoding,
             )
         elif element_type == "tables":
             write_table(
@@ -1263,6 +1270,7 @@ def write_element(
         element_name: str | list[str],
         overwrite: bool = False,
         sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None,
+        shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
     ) -> None:
         """
         Write a single element, or a list of elements, to the Zarr store used for backing.
@@ -1278,6 +1286,9 @@
         sdata_formats
             It is recommended to leave this parameter equal to `None`. See more details in the documentation of
             `SpatialData.write()`.
+        shapes_geometry_encoding
+            Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet`
+            for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`.
 
         Notes
         -----
@@ -1291,7 +1302,12 @@ def write_element(
         if isinstance(element_name, list):
             for name in element_name:
                 assert isinstance(name, str)
-                self.write_element(name, overwrite=overwrite, sdata_formats=sdata_formats)
+                self.write_element(
+                    name,
+                    overwrite=overwrite,
+                    sdata_formats=sdata_formats,
+                    shapes_geometry_encoding=shapes_geometry_encoding,
+                )
             return
 
         check_valid_name(element_name)
@@ -1325,6 +1341,7 @@
             element_name=element_name,
             overwrite=overwrite,
             parsed_formats=parsed_formats,
+            shapes_geometry_encoding=shapes_geometry_encoding,
         )
         # After every write, metadata should be consolidated, otherwise this can lead to IO problems like when deleting.
         if self.has_consolidated_metadata():
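The hunks above only thread the new argument through the public API; the fallback to the global setting happens in `write_shapes` below. A minimal usage sketch of the new parameter (the output paths are hypothetical; `blobs()` is the example dataset from `spatialdata.datasets`):

    import spatialdata as sd

    sdata = sd.datasets.blobs()
    # An explicit per-call encoding takes precedence over the global default
    sdata.write("blobs_geoarrow.zarr", shapes_geometry_encoding="geoarrow")
    # With the argument omitted (None), the writer falls back to
    # spatialdata.settings.shapes_geometry_encoding, which defaults to "WKB"
    sdata.write("blobs_wkb.zarr")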
diff --git a/src/spatialdata/_io/io_shapes.py b/src/spatialdata/_io/io_shapes.py
index 8e6d4a60..65cb099a 100644
--- a/src/spatialdata/_io/io_shapes.py
+++ b/src/spatialdata/_io/io_shapes.py
@@ -1,5 +1,5 @@
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal
 
 import numpy as np
 import zarr
@@ -70,6 +70,7 @@ def write_shapes(
     group: zarr.Group,
     group_type: str = "ngff:shapes",
     element_format: Format = CurrentShapesFormat(),
+    geometry_encoding: Literal["WKB", "geoarrow"] | None = None,
 ) -> None:
     """Write shapes to spatialdata zarr store.
@@ -86,7 +87,15 @@
         The type of the element.
     element_format
         The format of the shapes element used to store it.
+    geometry_encoding
+        Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for
+        details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`.
     """
+    from spatialdata.config import settings
+
+    if geometry_encoding is None:
+        geometry_encoding = settings.shapes_geometry_encoding
+
     axes = get_axes_names(shapes)
     transformations = _get_transformations(shapes)
     if transformations is None:
@@ -94,7 +103,7 @@
     if isinstance(element_format, ShapesFormatV01):
         attrs = _write_shapes_v01(shapes, group, element_format)
     elif isinstance(element_format, ShapesFormatV02 | ShapesFormatV03):
-        attrs = _write_shapes_v02_v03(shapes, group, element_format)
+        attrs = _write_shapes_v02_v03(shapes, group, element_format, geometry_encoding=geometry_encoding)
     else:
         raise ValueError(f"Unsupported format version {element_format.version}. Please update the spatialdata library.")
@@ -139,7 +148,9 @@
     return attrs
 
 
-def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_format: Format) -> Any:
+def _write_shapes_v02_v03(
+    shapes: GeoDataFrame, group: zarr.Group, element_format: Format, geometry_encoding: Literal["WKB", "geoarrow"]
+) -> Any:
     """Write shapes to spatialdata zarr store using format ShapesFormatV02 or ShapesFormatV03.
 
     Parameters
     ----------
@@ -150,6 +161,9 @@
         The zarr group in the 'shapes' zarr group to write the shapes element to.
     element_format
         The format of the shapes element used to store it.
+    geometry_encoding
+        Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for
+        details.
     """
     from spatialdata.models._utils import TRANSFORM_KEY
@@ -159,7 +173,7 @@
     # Temporarily remove transformations from attrs to avoid serialization issues
     transforms = shapes.attrs[TRANSFORM_KEY]
     del shapes.attrs[TRANSFORM_KEY]
-    shapes.to_parquet(path)
+    shapes.to_parquet(path, geometry_encoding=geometry_encoding)
     shapes.attrs[TRANSFORM_KEY] = transforms
 
     attrs = element_format.attrs_to_dict(shapes.attrs)
""" from spatialdata.models._utils import TRANSFORM_KEY @@ -159,7 +173,7 @@ def _write_shapes_v02_v03(shapes: GeoDataFrame, group: zarr.Group, element_forma # Temporarily remove transformations from attrs to avoid serialization issues transforms = shapes.attrs[TRANSFORM_KEY] del shapes.attrs[TRANSFORM_KEY] - shapes.to_parquet(path) + shapes.to_parquet(path, geometry_encoding=geometry_encoding) shapes.attrs[TRANSFORM_KEY] = transforms attrs = element_format.attrs_to_dict(shapes.attrs) diff --git a/src/spatialdata/config.py b/src/spatialdata/config.py index 309f20e4..dab848b3 100644 --- a/src/spatialdata/config.py +++ b/src/spatialdata/config.py @@ -1,4 +1,28 @@ -# chunk sizes bigger than this value (bytes) can trigger a compression error -# https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276 -# so if we detect this during parsing/validation we raise a warning -LARGE_CHUNK_THRESHOLD_BYTES = 2147483647 +from dataclasses import dataclass +from typing import Literal + + +@dataclass +class Settings: + """Global settings for spatialdata. + + Attributes + ---------- + shapes_geometry_encoding + Default geometry encoding for GeoParquet files when writing shapes. + Can be "WKB" (Well-Known Binary) or "geoarrow". + See :meth:`geopandas.GeoDataFrame.to_parquet` for details. + large_chunk_threshold_bytes + Chunk sizes bigger than this value (bytes) can trigger a compression error. + See https://github.com/scverse/spatialdata/issues/812#issuecomment-2559380276 + If detected during parsing/validation, a warning is raised. + """ + + shapes_geometry_encoding: Literal["WKB", "geoarrow"] = "WKB" + large_chunk_threshold_bytes: int = 2147483647 + + +settings = Settings() + +# Backwards compatibility alias +LARGE_CHUNK_THRESHOLD_BYTES = settings.large_chunk_threshold_bytes diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index 92acbec4..ddda2a61 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -35,7 +35,7 @@ from spatialdata._logging import logger from spatialdata._types import ArrayLike from spatialdata._utils import _check_match_length_channels_c_dim -from spatialdata.config import LARGE_CHUNK_THRESHOLD_BYTES +from spatialdata.config import settings from spatialdata.models import C, X, Y, Z, get_axes_names from spatialdata.models._utils import ( DEFAULT_COORDINATE_SYSTEM, @@ -315,9 +315,9 @@ def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None: return n_elems = np.array(list(max_per_dimension.values())).prod().item() usage = n_elems * data.dtype.itemsize - if usage > LARGE_CHUNK_THRESHOLD_BYTES: + if usage > settings.large_chunk_threshold_bytes: warnings.warn( - f"Detected chunks larger than: {usage} > {LARGE_CHUNK_THRESHOLD_BYTES} bytes. " + f"Detected chunks larger than: {usage} > {settings.large_chunk_threshold_bytes} bytes. " "This can lead to low " "performance and memory issues downstream, and sometimes cause compression errors when writing " "(https://github.com/scverse/spatialdata/issues/812#issuecomment-2575983527). 
diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py
index 92acbec4..ddda2a61 100644
--- a/src/spatialdata/models/models.py
+++ b/src/spatialdata/models/models.py
@@ -35,7 +35,7 @@
 from spatialdata._logging import logger
 from spatialdata._types import ArrayLike
 from spatialdata._utils import _check_match_length_channels_c_dim
-from spatialdata.config import LARGE_CHUNK_THRESHOLD_BYTES
+from spatialdata.config import settings
 from spatialdata.models import C, X, Y, Z, get_axes_names
 from spatialdata.models._utils import (
     DEFAULT_COORDINATE_SYSTEM,
@@ -315,9 +315,9 @@ def _check_chunk_size_not_too_large(self, data: DataArray | DataTree) -> None:
             return
         n_elems = np.array(list(max_per_dimension.values())).prod().item()
         usage = n_elems * data.dtype.itemsize
-        if usage > LARGE_CHUNK_THRESHOLD_BYTES:
+        if usage > settings.large_chunk_threshold_bytes:
             warnings.warn(
-                f"Detected chunks larger than: {usage} > {LARGE_CHUNK_THRESHOLD_BYTES} bytes. "
+                f"Detected chunks larger than: {usage} > {settings.large_chunk_threshold_bytes} bytes. "
                 "This can lead to low "
                 "performance and memory issues downstream, and sometimes cause compression errors when writing "
                 "(https://github.com/scverse/spatialdata/issues/812#issuecomment-2575983527). Please consider using"
@@ -327,7 +327,7 @@
                 "2) Multiscale representations can be achieved by using the `scale_factors` argument in the "
                 "`parse()` function.\n"
                 "You can suppress this warning by increasing the value of "
-                "`spatialdata.config.LARGE_CHUNK_THRESHOLD_BYTES`.",
+                "`spatialdata.settings.large_chunk_threshold_bytes`.",
                 UserWarning,
                 stacklevel=2,
             )
diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py
index 11855a22..7ecd7420 100644
--- a/tests/io/test_readwrite.py
+++ b/tests/io/test_readwrite.py
@@ -3,17 +3,21 @@
 import tempfile
 from collections.abc import Callable
 from pathlib import Path
-from typing import Any
+from typing import Any, Literal
 
 import dask.dataframe as dd
 import numpy as np
+import pandas as pd
+import pyarrow.parquet as pq
 import pytest
 import zarr
 from anndata import AnnData
 from numpy.random import default_rng
+from shapely import MultiPolygon, Polygon
 from upath import UPath
 from zarr.errors import GroupNotFoundError
 
+import spatialdata.config
 from spatialdata import SpatialData, deepcopy, read_zarr
 from spatialdata._core.validation import ValidationError
 from spatialdata._io._utils import _are_directories_identical, get_dask_backing_files
@@ -74,20 +78,90 @@
         sdata = SpatialData.read(tmpdir)
         assert_spatial_data_objects_are_identical(labels, sdata)
 
+    @pytest.mark.parametrize("geometry_encoding", ["WKB", "geoarrow"])
     def test_shapes(
         self,
         tmp_path: str,
         shapes: SpatialData,
         sdata_container_format: SpatialDataContainerFormatType,
+        geometry_encoding: Literal["WKB", "geoarrow"],
     ) -> None:
         tmpdir = Path(tmp_path) / "tmp.zarr"
 
         # check the index is correctly written and then read
         shapes["circles"].index = np.arange(1, len(shapes["circles"]) + 1)
-        shapes.write(tmpdir, sdata_formats=sdata_container_format)
+        # add a mixed Polygon + MultiPolygon element
+        shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]])
+
+        shapes.write(tmpdir, sdata_formats=sdata_container_format, shapes_geometry_encoding=geometry_encoding)
         sdata = SpatialData.read(tmpdir)
-        assert_spatial_data_objects_are_identical(shapes, sdata)
+
+        if geometry_encoding == "WKB":
+            assert_spatial_data_objects_are_identical(shapes, sdata)
+        else:
+            # convert each Polygon to a MultiPolygon
+            mixed_multipolygon = shapes["mixed"].assign(
+                geometry=lambda df: df.geometry.apply(lambda g: MultiPolygon([g]) if isinstance(g, Polygon) else g)
+            )
+            assert sdata["mixed"].equals(mixed_multipolygon)
+            assert not sdata["mixed"].equals(shapes["mixed"])
+
+            del shapes["mixed"]
+            del sdata["mixed"]
+            assert_spatial_data_objects_are_identical(shapes, sdata)
+
+    @pytest.mark.parametrize("geometry_encoding", ["WKB", "geoarrow"])
+    def test_shapes_geometry_encoding_write_element(
+        self,
+        tmp_path: str,
+        shapes: SpatialData,
+        sdata_container_format: SpatialDataContainerFormatType,
+        geometry_encoding: Literal["WKB", "geoarrow"],
+    ) -> None:
+        """Test shapes geometry encoding with write_element() and global settings."""
+        tmpdir = Path(tmp_path) / "tmp.zarr"
+
+        # First write an empty SpatialData to create the zarr store
+        empty_sdata = SpatialData()
+        empty_sdata.write(tmpdir, sdata_formats=sdata_container_format)
+
+        shapes["mixed"] = pd.concat([shapes["poly"], shapes["multipoly"]])
+
+        # Add shapes to the empty sdata
+        for shape_name in shapes.shapes:
+            empty_sdata[shape_name] = shapes[shape_name]
+
+        # Store original setting and set global encoding
+        original_encoding = spatialdata.config.settings.shapes_geometry_encoding
+        try:
+            spatialdata.config.settings.shapes_geometry_encoding = geometry_encoding
+
+            # Write each shape element - should use global setting
+            for shape_name in shapes.shapes:
+                empty_sdata.write_element(shape_name, sdata_formats=sdata_container_format)
+
+                # Verify the encoding metadata in the parquet file
+                parquet_file = tmpdir / "shapes" / shape_name / "shapes.parquet"
+                with pq.ParquetFile(parquet_file) as pf:
+                    md = pf.metadata
+                    d = json.loads(md.metadata[b"geo"].decode("utf-8"))
+                    found_encoding = d["columns"]["geometry"]["encoding"]
+                    if geometry_encoding == "WKB":
+                        expected_encoding = "WKB"
+                    elif shape_name == "circles":
+                        expected_encoding = "point"
+                    elif shape_name == "poly":
+                        expected_encoding = "polygon"
+                    elif shape_name in ["multipoly", "mixed"]:
+                        expected_encoding = "multipolygon"
+                    else:
+                        raise ValueError(
+                            f"Uncovered case for shape_name: {shape_name}, found encoding: {found_encoding}."
+                        )
+                    assert found_encoding == expected_encoding
+        finally:
+            spatialdata.config.settings.shapes_geometry_encoding = original_encoding
 
     def test_points(
         self,
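The geoarrow branch of `test_shapes` pins down the round-trip caveat: geoarrow uses typed geometry columns, so a mixed Polygon/MultiPolygon column is written with the `multipolygon` encoding and reads back with every Polygon promoted to a MultiPolygon, whereas WKB preserves the original geometry types. A standalone sketch of what should be the same behavior at the geopandas level (an assumption extrapolated from the test above; requires a geopandas version where `to_parquet` accepts `geometry_encoding`):

    import geopandas as gpd
    from shapely.geometry import MultiPolygon, Polygon

    square = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
    gdf = gpd.GeoDataFrame(geometry=[square, MultiPolygon([square])])

    gdf.to_parquet("mixed_wkb.parquet", geometry_encoding="WKB")
    gdf.to_parquet("mixed_geoarrow.parquet", geometry_encoding="geoarrow")

    # WKB round-trips the mixed column unchanged ...
    assert gpd.read_parquet("mixed_wkb.parquet").geom_type.tolist() == ["Polygon", "MultiPolygon"]
    # ... while geoarrow promotes everything to MultiPolygon
    assert gpd.read_parquet("mixed_geoarrow.parquet").geom_type.tolist() == ["MultiPolygon", "MultiPolygon"]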