diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index bed33ff1..9ddea9c4 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -185,7 +185,7 @@ def parse( # if there are no dims in the data, use the model's dims or provided dims elif isinstance(data, np.ndarray | DaskArray): if not isinstance(data, DaskArray): # numpy -> dask - data = from_array(data) + data = from_array(data.data) if dims is None: dims = cls.dims.dims else: @@ -239,6 +239,10 @@ def parse( chunks=chunks, ) _parse_transformations(data, parsed_transform) + else: + # Chunk single scale images + if chunks is not None: + data = data.chunk(chunks=chunks) cls()._check_chunk_size_not_too_large(data) # recompute coordinates for (multiscale) spatial image return compute_coordinates(data) diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 2ed108b7..7c7087b8 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -195,6 +195,45 @@ def test_raster_schema( with pytest.raises(ValueError): model.parse(image, **kwargs) + @pytest.mark.parametrize( + "model,chunks,expected", + [ + (Labels2DModel, None, (10, 10)), + (Labels2DModel, 5, (5, 5)), + (Labels2DModel, (5, 5), (5, 5)), + (Labels2DModel, {"x": 5, "y": 5}, (5, 5)), + (Labels3DModel, None, (1, 10, 10)), + (Labels3DModel, 5, (1, 5, 5)), + (Labels3DModel, (1, 5, 5), (1, 5, 5)), + (Labels3DModel, {"z": 1, "x": 5, "y": 5}, (1, 5, 5)), + (Image2DModel, None, (1, 10, 10)), # Image2D Models always have a c dimension + (Image2DModel, 5, (1, 5, 5)), + (Image2DModel, (1, 5, 5), (1, 5, 5)), + (Image2DModel, {"c": 1, "x": 5, "y": 5}, (1, 5, 5)), + (Image3DModel, None, (1, 1, 10, 10)), # Image3D models have z in addition, so 4 total dimensions + (Image3DModel, 5, (1, 1, 5, 5)), + (Image3DModel, (1, 1, 5, 5), (1, 1, 5, 5)), + ( + Image3DModel, + {"c": 1, "z": 1, "x": 5, "y": 5}, + (1, 1, 5, 5), + ), + ], + ) + def test_raster_models_parse_with_chunks_parameter(self, model, chunks, expected): + dims = np.array(model.dims.dims).tolist() + n_dims = len(dims) + + image: ArrayLike = np.arange(100).reshape((10, 10)) + if n_dims == 3: + image = np.expand_dims(image, axis=0) + + if n_dims == 4: + image = np.expand_dims(image, axis=(0, 1)) + + x = model.parse(image, chunks=chunks) + assert x.data.chunksize == expected + @pytest.mark.parametrize("model", [Labels2DModel, Labels3DModel]) def test_labels_model_with_multiscales(self, model): # Passing "scale_factors" should generate multiscales with a "method" appropriate for labels