-
Notifications
You must be signed in to change notification settings - Fork 88
feat(centroids): add return_area + SpatialData persist to get_centroids #1150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
cb6ced4
18c054e
7dd7d33
c218e90
1e78c54
8077186
460ec76
5f5fdd0
2091cbe
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,24 +1,33 @@ | ||
| from __future__ import annotations | ||
|
|
||
| from functools import singledispatch | ||
| from typing import Literal | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
| import xarray as xr | ||
| from anndata import AnnData | ||
| from dask.dataframe import DataFrame as DaskDataFrame | ||
| from geopandas import GeoDataFrame | ||
| from shapely import MultiPolygon, Point, Polygon | ||
| from xarray import DataArray, DataTree | ||
|
|
||
| from spatialdata._core.operations.transform import transform | ||
| from spatialdata.models import get_axes_names | ||
| from spatialdata._core.spatialdata import SpatialData | ||
| from spatialdata._utils import _affine_matrix_multiplication | ||
| from spatialdata.models import get_axes_names, get_table_keys | ||
| from spatialdata.models._utils import SpatialElement | ||
| from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, get_model | ||
| from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, ShapesModel, get_model | ||
| from spatialdata.transformations.operations import get_transformation | ||
| from spatialdata.transformations.transformations import BaseTransformation | ||
| from spatialdata.transformations.transformations import BaseTransformation, Identity | ||
|
|
||
| BoundingBoxDescription = dict[str, tuple[float, float]] | ||
|
|
||
| PersistAs = Literal["Points", "adata"] | ||
| # squidpy-style storage keys for persist_as="adata". | ||
| _SPATIAL_KEY = "spatial" | ||
| _AREA_KEY = "area" | ||
|
|
||
|
|
||
| def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> None: | ||
| d = get_transformation(e, get_all=True) | ||
|
|
@@ -29,37 +38,92 @@ def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> No | |
| ) | ||
|
|
||
|
|
||
| def _validate_persist_args(persist_as: str, coordinate_system: str | None, *, allow_adata: bool) -> None: | ||
| if persist_as not in ("Points", "adata"): | ||
| raise ValueError(f"`persist_as` must be 'Points' or 'adata', got {persist_as!r}.") | ||
| if persist_as == "adata" and not allow_adata: | ||
| raise ValueError( | ||
| "persist_as='adata' writes centroids into the element's annotating table, which needs the " | ||
| "`SpatialData` object: call `get_centroids(sdata, element_name, ..., persist_as='adata')`. " | ||
| "To get the centroids as a standalone element instead, use persist_as='Points'." | ||
| ) | ||
| # ``coordinate_system=None`` means "intrinsic coordinates, do not transform". An intrinsic Points | ||
| # element is ill-defined (Points always carry a coordinate system), so intrinsic coords are only | ||
| # meaningful when writing into a table (persist_as='adata'). | ||
| if coordinate_system is None and persist_as != "adata": | ||
| raise ValueError("`coordinate_system=None` (intrinsic coordinates) is only supported with persist_as='adata'.") | ||
|
|
||
|
|
||
| def _transform_centroid_coords( | ||
| xy: np.ndarray, axes: list[str], e: SpatialElement, coordinate_system: str | None | ||
| ) -> np.ndarray: | ||
| """Apply the element's affine to centroid coords in-memory; ``None``/``Identity`` pass through. | ||
|
|
||
| ``axes`` is the column order of ``xy`` (e.g. ``["x", "y"]``). | ||
|
Comment on lines
+57
to
+62
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This already works for any axis names and order, but the function makes it look like it only works for x and y. Better to rename it. |
||
| """ | ||
| if coordinate_system is None: | ||
| return xy | ||
| t = get_transformation(e, coordinate_system) | ||
| assert isinstance(t, BaseTransformation) | ||
| if isinstance(t, Identity): | ||
| return xy | ||
| matrix = t.to_affine_matrix(input_axes=tuple(axes), output_axes=tuple(axes)) | ||
| return _affine_matrix_multiplication(matrix, xy) | ||
|
Comment on lines
+70
to
+71
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably the matrix multiplication is slower than just a transposition, but if this is a bottleneck we'll find out later during profiling. Let's keep as is. |
||
|
|
||
|
|
||
| @singledispatch | ||
| def get_centroids( | ||
| e: SpatialElement, | ||
| coordinate_system: str = "global", | ||
| e: SpatialElement | SpatialData, | ||
| coordinate_system: str | None = "global", | ||
| return_background: bool = False, | ||
| ) -> DaskDataFrame: | ||
| return_area: bool = False, | ||
| persist_as: PersistAs = "Points", | ||
| ) -> DaskDataFrame | AnnData | None: | ||
| """ | ||
| Get the centroids of the geometries contained in a SpatialElement, as a new Points element. | ||
| Get the centroids of the geometries contained in a SpatialElement. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| e | ||
| The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported. | ||
| The SpatialElement (points, shapes — circles, polygons and multipolygons — or labels), or a | ||
| :class:`~spatialdata.SpatialData` object. When a ``SpatialData`` is passed, the second | ||
| positional argument is the name of the element to measure (see the ``SpatialData`` overload). | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not a big fan of this (positional argument depending on the single dispatch argument), but it works.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's keep as is. |
||
| coordinate_system | ||
| The coordinate system in which the centroids are computed. | ||
| The coordinate system in which the centroids are computed. ``None`` returns the intrinsic | ||
| coordinates without applying any transformation (only supported with ``persist_as="adata"``). | ||
| return_background | ||
| If True, the centroid of the background label (0) is included in the output. | ||
| If True, the centroid of the background label (0) is included in the output (labels only). | ||
| return_area | ||
| If True, also return the per-instance area: the pixel/voxel count for labels and the geometric | ||
| area for shapes (``pi * r**2`` for circles). Not supported for points (raises). With | ||
| ``persist_as="Points"`` the area is added as a feature column of the returned Points element. | ||
|
Comment on lines
+96
to
+99
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is a design problem: the area is always returned untransformed (=in the intrinsic coordinate system). I think we should be consistent: if we transform the centroids (and we do this) we should transform the area. Otherwise we should not transform the centroids (or we should not return the area unless we have an identity/intrinsic). I like to return the area, so I would consider always requiring
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This requires to make a decision regarding the design, I don't have a solution in mind. |
||
| persist_as | ||
| ``"Points"`` (default) returns the centroids as a new Points element, transformed into | ||
| ``coordinate_system``. ``"adata"`` writes the centroids (and area) into the element's | ||
| annotating table and is only available through the :class:`~spatialdata.SpatialData` overload, | ||
| which can resolve that table. | ||
|
|
||
| Returns | ||
| ------- | ||
| A Points element (``persist_as="Points"``). With ``persist_as="adata"`` (``SpatialData`` overload), | ||
| ``None`` when written in place, or the new ``AnnData`` table when ``inplace=False``. | ||
|
|
||
| Notes | ||
| ----- | ||
| For :class:`~shapely.Multipolygon`s, the centroids are the average of the centroids of the polygons that constitute | ||
| each :class:`~shapely.Multipolygon`. | ||
| each :class:`~shapely.Multipolygon`. For multiscale labels the centroids are computed on the full-resolution | ||
| ``scale0`` level. | ||
| """ | ||
| raise ValueError(f"The object type {type(e)} is not supported.") | ||
|
|
||
|
|
||
| def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame: | ||
| def _get_centroids_for_labels(xdata: xr.DataArray, return_area: bool = False) -> pd.DataFrame: | ||
| """ | ||
| Compute centroids for all labels in a DataArray in a single O(n_voxels) pass. | ||
|
|
||
| Works for any number of spatial dimensions (2D and 3D labels). | ||
| Works for any number of spatial dimensions (2D and 3D labels). When ``return_area`` is True, an | ||
| ``area`` column (the per-label pixel/voxel count) is added; it is already computed for the | ||
| centroids, so this is free. | ||
| """ | ||
| arr = xdata.data.compute() | ||
| axes = list(xdata.dims) | ||
|
|
@@ -77,66 +141,212 @@ def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame: | |
| coord_sums = np.bincount(flat_inverse, weights=grid.ravel().astype(float)) | ||
| data[ax] = coord_sums / counts # counts > 0 by construction (unique guarantees this) | ||
|
|
||
| return pd.DataFrame(data, index=label_ids) | ||
| df = pd.DataFrame(data, index=label_ids) | ||
| if return_area: | ||
| df["area"] = counts.astype(float) | ||
| return df | ||
|
|
||
|
|
||
| @get_centroids.register(DataArray) | ||
| @get_centroids.register(DataTree) | ||
| def _( | ||
| e: DataArray | DataTree, | ||
| coordinate_system: str = "global", | ||
| return_background: bool = False, | ||
| ) -> DaskDataFrame: | ||
| """Get the centroids of a Labels element (2D or 3D).""" | ||
| model = get_model(e) | ||
| if model not in [Labels2DModel, Labels3DModel]: | ||
| raise ValueError("Expected a `Labels` element. Found an `Image` instead.") | ||
| _validate_coordinate_system(e, coordinate_system) | ||
|
|
||
| if isinstance(e, DataTree): | ||
| assert len(e["scale0"]) == 1 | ||
| e = next(iter(e["scale0"].values())) | ||
|
|
||
| df = _get_centroids_for_labels(e) | ||
| if not return_background and 0 in df.index: | ||
| df = df.drop(index=0) # drop the background label | ||
| t = get_transformation(e, coordinate_system) | ||
| centroids = PointsModel.parse(df, transformations={coordinate_system: t}) | ||
| return transform(centroids, to_coordinate_system=coordinate_system) | ||
|
|
||
|
|
||
| @get_centroids.register(GeoDataFrame) | ||
| def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame: | ||
| """Get the centroids of a Shapes element (circles or polygons/multipolygons).""" | ||
| _validate_coordinate_system(e, coordinate_system) | ||
| t = get_transformation(e, coordinate_system) | ||
| assert isinstance(t, BaseTransformation) | ||
| # separate points from (multi-)polygons | ||
| def _get_centroids_for_shapes(e: GeoDataFrame, return_area: bool) -> tuple[pd.DataFrame, np.ndarray | None]: | ||
| """Intrinsic per-shape centroids (``x, y`` columns indexed by the element's index) and optional area.""" | ||
| first_geometry = e["geometry"].iloc[0] | ||
| if isinstance(first_geometry, Point): | ||
| xy = e.geometry.get_coordinates().values | ||
| # shapely .area is 0 for circles (Point geometry); the radius column carries the size. | ||
| area = np.pi * np.asarray(e["radius"], dtype=float) ** 2 if return_area else None | ||
| else: | ||
| assert isinstance(first_geometry, Polygon | MultiPolygon), ( | ||
| f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or" | ||
| f" Polygons/MultiPolygons. Found {type(first_geometry)} instead." | ||
| ) | ||
| xy = e.centroid.get_coordinates().values | ||
| area = e.geometry.area.to_numpy() if return_area else None | ||
| xy_df = pd.DataFrame(xy, columns=["x", "y"], index=e.index.copy()) | ||
| points = PointsModel.parse(xy_df, transformations={coordinate_system: t}) | ||
| return xy_df, area | ||
|
|
||
|
|
||
| def _intrinsic_centroid_frame( | ||
| element: SpatialElement, return_background: bool, return_area: bool | ||
| ) -> tuple[pd.DataFrame, np.ndarray | None, SpatialElement]: | ||
| """Per-instance intrinsic centroids (coordinate columns, indexed by instance id), optional area. | ||
|
|
||
| Also returns the element the centroids live on (for labels, the ``scale0`` level of a multiscale | ||
| raster), which carries the transformation to apply downstream. | ||
| """ | ||
| model = get_model(element) | ||
| if model in (Labels2DModel, Labels3DModel): | ||
| raster = next(iter(element["scale0"].values())) if isinstance(element, DataTree) else element | ||
| df = _get_centroids_for_labels(raster, return_area=return_area) | ||
| if not return_background and 0 in df.index: | ||
| df = df.drop(index=0) # drop the background label (its area, if any, goes with it) | ||
| area = df.pop("area").to_numpy() if return_area else None | ||
| return df, area, raster | ||
| if model is ShapesModel: | ||
| xy_df, area = _get_centroids_for_shapes(element, return_area) | ||
| return xy_df, area, element | ||
| if model is PointsModel: | ||
| if return_area: | ||
| raise ValueError("`return_area` is not supported for points elements (points have no area).") | ||
| axes = get_axes_names(element) | ||
| assert axes in [("x", "y"), ("x", "y", "z")] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should turn this assert into an `if ...: raise ...` check. |
||
| return element[list(axes)].compute(), None, element | ||
| raise ValueError( | ||
| f"Centroids are not supported for elements modeled by {model.__name__}; expected a Labels, Shapes or Points element." | ||
| ) | ||
|
|
||
|
|
||
def _points_from_centroids(
    df: pd.DataFrame, area: np.ndarray | None, e: SpatialElement, coordinate_system: str
) -> DaskDataFrame:
    """Turn intrinsic centroids into a Points element expressed in ``coordinate_system``.

    Parameters
    ----------
    df
        Intrinsic centroid coordinates, one row per instance.
    area
        Optional per-instance area; when given it is attached as a float ``area`` column.
    e
        The element whose transformation to ``coordinate_system`` is applied to the centroids.
    coordinate_system
        The target coordinate system.

    Returns
    -------
    The parsed Points element after being transformed to ``coordinate_system``.
    """
    if area is None:
        frame = df
    else:
        # ``assign`` returns a new frame, so the caller's ``df`` is left untouched.
        frame = df.assign(area=np.asarray(area, dtype=float))
    transformation = get_transformation(e, coordinate_system)
    assert isinstance(transformation, BaseTransformation)
    parsed = PointsModel.parse(frame, transformations={coordinate_system: transformation})
    return transform(parsed, to_coordinate_system=coordinate_system)
|
|
||
|
|
||
| @get_centroids.register(DataArray) | ||
| @get_centroids.register(DataTree) | ||
| @get_centroids.register(GeoDataFrame) | ||
| @get_centroids.register(DaskDataFrame) | ||
| def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame: | ||
| """Get the centroids of a Points element.""" | ||
| def _( | ||
| e: SpatialElement, | ||
| coordinate_system: str | None = "global", | ||
| return_background: bool = False, | ||
| return_area: bool = False, | ||
| persist_as: PersistAs = "Points", | ||
| ) -> DaskDataFrame: | ||
| """Get the centroids of a Labels, Shapes or Points element.""" | ||
| _validate_persist_args(persist_as, coordinate_system, allow_adata=False) | ||
| assert coordinate_system is not None # guaranteed by _validate_persist_args (allow_adata=False) | ||
| _validate_coordinate_system(e, coordinate_system) | ||
| axes = get_axes_names(e) | ||
| assert axes in [("x", "y"), ("x", "y", "z")] | ||
| coords = e[list(axes)].compute() | ||
| t = get_transformation(e, coordinate_system) | ||
| assert isinstance(t, BaseTransformation) | ||
| centroids = PointsModel.parse(coords, transformations={coordinate_system: t}) | ||
| return transform(centroids, to_coordinate_system=coordinate_system) | ||
| df, area, raster = _intrinsic_centroid_frame(e, return_background, return_area) | ||
| return _points_from_centroids(df, area, raster, coordinate_system) | ||
|
|
||
|
|
||
def _resolve_annotating_table(sdata: SpatialData, element_name: str, table_name: str | None) -> str:
    """Return the name of the table that centroids of ``element_name`` should be written into.

    An explicit ``table_name`` is only checked for presence in ``sdata.tables``. Otherwise the
    element must be annotated by exactly one table, and that table's name is returned.

    Raises
    ------
    KeyError
        If ``table_name`` is given but is not in ``sdata.tables``.
    ValueError
        If ``table_name`` is ``None`` and the element is annotated by no table or by several.
    """
    from spatialdata._core.query.relational_query import get_element_annotators

    if table_name is None:
        # Sorted so the "multiple tables" message lists the candidates in a stable order.
        candidates = sorted(get_element_annotators(sdata, element_name))
        if len(candidates) == 1:
            return candidates[0]
        if candidates:
            raise ValueError(
                f"Element {element_name!r} is annotated by multiple tables ({', '.join(candidates)}); "
                f"pass `table_name=` to choose one."
            )
        raise ValueError(
            f"Element {element_name!r} has no annotating table to write centroids into. Use "
            f"persist_as='Points' to get the centroids as a Points element instead, or annotate the "
            f"element with a table first."
        )

    if table_name in sdata.tables:
        return table_name
    raise KeyError(f"Table {table_name!r} not found in `sdata.tables`.")
|
|
||
|
|
||
| def _write_centroids_into_table( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This function seems to do the job, but I think there is a big design problem. Nothing stops the user from having a table annotating multiple elements and where:
I have some ideas but not a clear solution in mind. This requires some work on defining how the Spatial AnnData specification works.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually the first is not a problem since the returned centroids are always
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The second point requires a design decision. |
||
| table: AnnData, | ||
| element_name: str, | ||
| centroids: pd.DataFrame, | ||
| area: np.ndarray | None, | ||
| ) -> None: | ||
| """Write centroids into ``obsm["spatial"]`` and area into ``obs["area"]`` at the element's rows. | ||
|
|
||
| Only the table rows annotating ``element_name`` are touched (a table may annotate several | ||
| elements); instances annotated but absent from the element are written as NaN. | ||
| """ | ||
| if not centroids.index.is_unique: | ||
| raise ValueError(f"Cannot persist centroids for {element_name!r}: its instance index has duplicate values.") | ||
| _, region_key, instance_key = get_table_keys(table) | ||
| mask = (table.obs[region_key].astype(str) == str(element_name)).to_numpy() | ||
| if not mask.any(): | ||
| raise ValueError(f"The resolved table does not annotate element {element_name!r} (no matching rows).") | ||
|
|
||
| # Map each annotated instance to its centroid row (-1 where absent -> NaN). A *total* miss means the | ||
| # instance ids never align with the element index (mismatched dtype, or no shared instances). | ||
| keys = table.obs[instance_key].to_numpy()[mask] | ||
| idx = centroids.index.get_indexer(keys) | ||
| if (idx == -1).all(): | ||
| raise ValueError( | ||
| f"No instance id annotating {element_name!r} is present in the element; check the table's " | ||
| f"`{instance_key}` values and dtype." | ||
| ) | ||
| hit = idx != -1 | ||
|
|
||
| def _scatter(values: np.ndarray) -> np.ndarray: | ||
| """Gather ``values`` (ordered like ``centroids``) onto the masked rows, NaN where absent.""" | ||
| out = np.full((len(idx), *values.shape[1:]), np.nan) | ||
| out[hit] = values[idx[hit]] | ||
| return out | ||
|
|
||
| ndim = centroids.shape[1] | ||
| spatial = np.full((table.n_obs, ndim), np.nan) | ||
| existing = table.obsm.get(_SPATIAL_KEY) | ||
| if existing is not None: | ||
| existing = np.asarray(existing) | ||
| if existing.shape == (table.n_obs, ndim): | ||
| spatial = existing.astype(float, copy=True) # preserve other regions' coordinates | ||
| elif existing.shape[0] == table.n_obs: | ||
| raise ValueError( | ||
| f"Existing obsm['{_SPATIAL_KEY}'] {existing.shape} is incompatible with {ndim}-D centroids for " | ||
| f"{element_name!r}; refusing to overwrite other regions. Persist with persist_as='Points' instead." | ||
| ) | ||
| spatial[mask] = _scatter(centroids.to_numpy(dtype=float)) | ||
| table.obsm[_SPATIAL_KEY] = spatial | ||
|
|
||
| if area is not None: | ||
| col = np.full(table.n_obs, np.nan) | ||
| if _AREA_KEY in table.obs: | ||
| col = table.obs[_AREA_KEY].to_numpy(dtype=float).copy() | ||
| col[mask] = _scatter(np.asarray(area, dtype=float)) | ||
| table.obs[_AREA_KEY] = col | ||
|
|
||
|
|
||
| @get_centroids.register(SpatialData) | ||
| def _get_centroids_sdata( | ||
| e: SpatialData, | ||
| element_name: str, | ||
| coordinate_system: str | None = "global", | ||
| return_background: bool = False, | ||
| return_area: bool = False, | ||
| persist_as: PersistAs = "Points", | ||
| table_name: str | None = None, | ||
| inplace: bool = True, | ||
| ) -> DaskDataFrame | AnnData | None: | ||
| """Get the centroids of ``element_name``, or (``persist_as="adata"``) write them into its annotating table. | ||
|
|
||
| With ``persist_as="adata"`` the centroids go into ``obsm["spatial"]`` (and area into ``obs["area"]``) of the | ||
| resolved annotating table (``table_name=`` disambiguates). ``inplace=True`` (default) mutates that table and | ||
| returns ``None``; ``inplace=False`` writes into a copy of *only that table* and returns the new ``AnnData``, | ||
| leaving ``e`` untouched. ``persist_as="Points"`` behaves like calling :func:`get_centroids` on the element. | ||
| """ | ||
| _validate_persist_args(persist_as, coordinate_system, allow_adata=True) | ||
| element = e[element_name] | ||
|
|
||
| if persist_as == "Points": | ||
| return get_centroids( | ||
| element, | ||
| coordinate_system=coordinate_system, | ||
| return_background=return_background, | ||
| return_area=return_area, | ||
| ) | ||
|
|
||
| # persist_as == "adata": resolve the annotating table and write the centroids into it. | ||
| if coordinate_system is not None: | ||
| _validate_coordinate_system(element, coordinate_system) | ||
| table_name = _resolve_annotating_table(e, element_name, table_name) | ||
| df, area, raster = _intrinsic_centroid_frame(element, return_background, return_area) | ||
| coord_cols = sorted(df.columns) # canonical x, y[, z] (squidpy obsm["spatial"] order) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Important point to put in the docstring (and later in the specs).
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (that the order is always x, y[, z]) |
||
| coords = _transform_centroid_coords(df[coord_cols].to_numpy(), coord_cols, raster, coordinate_system) | ||
| centroids = pd.DataFrame(coords, columns=coord_cols, index=df.index) | ||
|
|
||
| table = e.tables[table_name] if inplace else e.tables[table_name].copy() | ||
| _write_centroids_into_table(table, element_name, centroids, area) | ||
| return None if inplace else table | ||
|
|
||
|
|
||
| ## | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should go in
`models.py` inside `TableModel`, since it's basically an addition to the file format.