Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from dask.dataframe import Scalar, read_parquet
from geopandas import GeoDataFrame
from shapely import MultiPolygon, Polygon
from upath import UPath
from xarray import DataArray, DataTree
from zarr.errors import GroupNotFoundError

Expand Down Expand Up @@ -1810,15 +1811,17 @@ def tables(self, tables: dict[str, AnnData]) -> None:

@staticmethod
def read(
file_path: Path | str, selection: tuple[str] | None = None, reconsolidate_metadata: bool = False
file_path: str | Path | UPath | zarr.Group,
selection: tuple[str] | None = None,
reconsolidate_metadata: bool = False,
) -> SpatialData:
"""
Read a SpatialData object from a Zarr storage (on-disk or remote).

Parameters
----------
file_path
The path or URL to the Zarr storage.
The path, URL, or zarr.Group to the Zarr storage.
selection
The elements to read (images, labels, points, shapes, table). If None, all elements are read.
reconsolidate_metadata
Expand Down
4 changes: 2 additions & 2 deletions src/spatialdata/_io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,8 +470,8 @@ def _resolve_zarr_store(
if isinstance(path, zarr.Group):
# if the input is a zarr.Group, wrap it with a store
if isinstance(path.store, LocalStore):
# create a simple FSStore if the store is a LocalStore with just the path
return FsspecStore(os.path.join(path.store.path, path.path), **kwargs)
store_path = UPath(path.store.root) / path.path
return LocalStore(store_path.path)
if isinstance(path.store, FsspecStore):
# if the store within the zarr.Group is an FSStore, return it
# but extend the path of the store with that of the zarr.Group
Expand Down
7 changes: 4 additions & 3 deletions src/spatialdata/_io/io_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from geopandas import GeoDataFrame
from ome_zarr.format import Format
from pyarrow import ArrowInvalid
from upath import UPath
from zarr.errors import ArrayNotFoundError

from spatialdata._core.spatialdata import SpatialData
Expand Down Expand Up @@ -120,7 +121,7 @@ def get_raster_format_for_read(


def read_zarr(
store: str | Path,
store: str | Path | UPath | zarr.Group,
selection: None | tuple[str] = None,
on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR,
) -> SpatialData:
Expand All @@ -130,7 +131,7 @@ def read_zarr(
Parameters
----------
store
Path to the zarr store (on-disk or remote).
Path, URL, or zarr.Group to the zarr store (on-disk or remote).

selection
List of elements to read from the zarr store (images, labels, points, shapes, table). If None, all elements are
Expand Down Expand Up @@ -228,7 +229,7 @@ def read_zarr(
tables=tables,
attrs=attrs,
)
sdata.path = Path(store)
sdata.path = resolved_store.root
return sdata


Expand Down
28 changes: 28 additions & 0 deletions tests/io/test_readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import zarr
from anndata import AnnData
from numpy.random import default_rng
from upath import UPath
from zarr.errors import GroupNotFoundError

from spatialdata import SpatialData, deepcopy, read_zarr
Expand Down Expand Up @@ -963,3 +964,30 @@ def test_can_read_sdata_with_reconsolidation(full_sdata, sdata_container_format:

new_sdata = SpatialData.read(path, reconsolidate_metadata=True)
assert_spatial_data_objects_are_identical(full_sdata, new_sdata)


def test_read_sdata(tmp_path: Path, points: SpatialData) -> None:
sdata_path = tmp_path / "sdata.zarr"
points.write(sdata_path)

# path as Path
sdata_from_path = SpatialData.read(sdata_path)
assert sdata_from_path.path == sdata_path

# path as str
sdata_from_str = SpatialData.read(str(sdata_path))
assert sdata_from_str.path == sdata_path

# path as UPath
sdata_from_upath = SpatialData.read(UPath(sdata_path))
assert sdata_from_upath.path == sdata_path

# path as zarr Group
zarr_group = zarr.open_group(sdata_path, mode="r")
sdata_from_zarr_group = SpatialData.read(zarr_group)
assert sdata_from_zarr_group.path == sdata_path

# Assert all read methods produce identical SpatialData objects
assert_spatial_data_objects_are_identical(sdata_from_path, sdata_from_str)
assert_spatial_data_objects_are_identical(sdata_from_path, sdata_from_upath)
assert_spatial_data_objects_are_identical(sdata_from_path, sdata_from_zarr_group)
Loading