diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index f92bc9f5..21bd6c5b 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -16,6 +16,7 @@ from dask.dataframe import Scalar, read_parquet from geopandas import GeoDataFrame from shapely import MultiPolygon, Polygon +from upath import UPath from xarray import DataArray, DataTree from zarr.errors import GroupNotFoundError @@ -1810,7 +1811,9 @@ def tables(self, tables: dict[str, AnnData]) -> None: @staticmethod def read( - file_path: Path | str, selection: tuple[str] | None = None, reconsolidate_metadata: bool = False + file_path: str | Path | UPath | zarr.Group, + selection: tuple[str] | None = None, + reconsolidate_metadata: bool = False, ) -> SpatialData: """ Read a SpatialData object from a Zarr storage (on-disk or remote). @@ -1818,7 +1821,7 @@ def read( Parameters ---------- file_path - The path or URL to the Zarr storage. + The path, URL, or zarr.Group to the Zarr storage. selection The elements to read (images, labels, points, shapes, table). If None, all elements are read. reconsolidate_metadata diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index a8e194a7..b58d6744 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -470,8 +470,8 @@ def _resolve_zarr_store( if isinstance(path, zarr.Group): # if the input is a zarr.Group, wrap it with a store if isinstance(path.store, LocalStore): - # create a simple FSStore if the store is a LocalStore with just the path - return FsspecStore(os.path.join(path.store.path, path.path), **kwargs) + store_path = UPath(path.store.root) / path.path + return LocalStore(store_path.path) if isinstance(path.store, FsspecStore): # if the store within the zarr.Group is an FSStore, return it # but extend the path of the store with that of the zarr.Group diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index d6de1665..98919d61 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -11,6 +11,7 @@ from geopandas import GeoDataFrame from ome_zarr.format import Format from pyarrow import ArrowInvalid +from upath import UPath from zarr.errors import ArrayNotFoundError from spatialdata._core.spatialdata import SpatialData @@ -120,7 +121,7 @@ def get_raster_format_for_read( def read_zarr( - store: str | Path, + store: str | Path | UPath | zarr.Group, selection: None | tuple[str] = None, on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR, ) -> SpatialData: @@ -130,7 +131,7 @@ def read_zarr( Parameters ---------- store - Path to the zarr store (on-disk or remote). + Path, URL, or zarr.Group to the zarr store (on-disk or remote). selection List of elements to read from the zarr store (images, labels, points, shapes, table). If None, all elements are @@ -228,7 +229,7 @@ def read_zarr( tables=tables, attrs=attrs, ) - sdata.path = Path(store) + sdata.path = resolved_store.root return sdata diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 6e948f51..11855a22 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -11,6 +11,7 @@ import zarr from anndata import AnnData from numpy.random import default_rng +from upath import UPath from zarr.errors import GroupNotFoundError from spatialdata import SpatialData, deepcopy, read_zarr @@ -963,3 +964,30 @@ def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format: new_sdata = SpatialData.read(path, reconsolidate_metadata=True) assert_spatial_data_objects_are_identical(full_sdata, new_sdata) + + +def test_read_sdata(tmp_path: Path, points: SpatialData) -> None: + sdata_path = tmp_path / "sdata.zarr" + points.write(sdata_path) + + # path as Path + sdata_from_path = SpatialData.read(sdata_path) + assert sdata_from_path.path == sdata_path + + # path as str + sdata_from_str = SpatialData.read(str(sdata_path)) + assert sdata_from_str.path == sdata_path + + # path as UPath + sdata_from_upath = SpatialData.read(UPath(sdata_path)) + assert sdata_from_upath.path == sdata_path + + # path as zarr Group + zarr_group = zarr.open_group(sdata_path, mode="r") + sdata_from_zarr_group = SpatialData.read(zarr_group) + assert sdata_from_zarr_group.path == sdata_path + + # Assert all read methods produce identical SpatialData objects + assert_spatial_data_objects_are_identical(sdata_from_path, sdata_from_str) + assert_spatial_data_objects_are_identical(sdata_from_path, sdata_from_upath) + assert_spatial_data_objects_are_identical(sdata_from_path, sdata_from_zarr_group)