Skip to content
Open
322 changes: 266 additions & 56 deletions src/spatialdata/_core/centroids.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
from __future__ import annotations

from functools import singledispatch
from typing import Literal

import numpy as np
import pandas as pd
import xarray as xr
from anndata import AnnData
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from shapely import MultiPolygon, Point, Polygon
from xarray import DataArray, DataTree

from spatialdata._core.operations.transform import transform
from spatialdata.models import get_axes_names
from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import _affine_matrix_multiplication
from spatialdata.models import get_axes_names, get_table_keys
from spatialdata.models._utils import SpatialElement
from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, get_model
from spatialdata.models.models import Labels2DModel, Labels3DModel, PointsModel, ShapesModel, get_model
from spatialdata.transformations.operations import get_transformation
from spatialdata.transformations.transformations import BaseTransformation
from spatialdata.transformations.transformations import BaseTransformation, Identity

BoundingBoxDescription = dict[str, tuple[float, float]]

PersistAs = Literal["Points", "adata"]
# squidpy-style storage keys for persist_as="adata".
_SPATIAL_KEY = "spatial"
_AREA_KEY = "area"
Comment on lines +27 to +29

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should go in models.py inside TableModel since it's basically an addition to the file format.



def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> None:
d = get_transformation(e, get_all=True)
Expand All @@ -29,37 +38,92 @@ def _validate_coordinate_system(e: SpatialElement, coordinate_system: str) -> No
)


def _validate_persist_args(persist_as: str, coordinate_system: str | None, *, allow_adata: bool) -> None:
if persist_as not in ("Points", "adata"):
raise ValueError(f"`persist_as` must be 'Points' or 'adata', got {persist_as!r}.")
if persist_as == "adata" and not allow_adata:
raise ValueError(
"persist_as='adata' writes centroids into the element's annotating table, which needs the "
"`SpatialData` object: call `get_centroids(sdata, element_name, ..., persist_as='adata')`. "
"To get the centroids as a standalone element instead, use persist_as='Points'."
)
# ``coordinate_system=None`` means "intrinsic coordinates, do not transform". An intrinsic Points
# element is ill-defined (Points always carry a coordinate system), so intrinsic coords are only
# meaningful when writing into a table (persist_as='adata').
if coordinate_system is None and persist_as != "adata":
raise ValueError("`coordinate_system=None` (intrinsic coordinates) is only supported with persist_as='adata'.")


def _transform_centroid_coords(
xy: np.ndarray, axes: list[str], e: SpatialElement, coordinate_system: str | None
) -> np.ndarray:
"""Apply the element's affine to centroid coords in-memory; ``None``/``Identity`` pass through.

``axes`` is the column order of ``xy`` (e.g. ``["x", "y"]``).
Comment on lines +57 to +62

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this already works for any axes name and order, but the function makes it looks like it's only working for x and y. Better renaming.

"""
if coordinate_system is None:
return xy
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
if isinstance(t, Identity):
return xy
matrix = t.to_affine_matrix(input_axes=tuple(axes), output_axes=tuple(axes))
return _affine_matrix_multiplication(matrix, xy)
Comment on lines +70 to +71

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably the matrix multiplication is slower than just a transposition, but if this is a bottleneck we'll find out later during profiling. Let's keep as is.



@singledispatch
def get_centroids(
e: SpatialElement,
coordinate_system: str = "global",
e: SpatialElement | SpatialData,
coordinate_system: str | None = "global",
return_background: bool = False,
) -> DaskDataFrame:
return_area: bool = False,
persist_as: PersistAs = "Points",
) -> DaskDataFrame | AnnData | None:
"""
Get the centroids of the geometries contained in a SpatialElement, as a new Points element.
Get the centroids of the geometries contained in a SpatialElement.

Parameters
----------
e
The SpatialElement. Only points, shapes (circles, polygons and multipolygons) and labels are supported.
The SpatialElement (points, shapes — circles, polygons and multipolygons — or labels), or a
:class:`~spatialdata.SpatialData` object. When a ``SpatialData`` is passed, the second
positional argument is the name of the element to measure (see the ``SpatialData`` overload).

@LucaMarconato LucaMarconato Jun 23, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big fan of this (positional argument depending on the single dispatch argument), but it works.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's keep as is.

coordinate_system
The coordinate system in which the centroids are computed.
The coordinate system in which the centroids are computed. ``None`` returns the intrinsic
coordinates without applying any transformation (only supported with ``persist_as="adata"``).
return_background
If True, the centroid of the background label (0) is included in the output.
If True, the centroid of the background label (0) is included in the output (labels only).
return_area
If True, also return the per-instance area: the pixel/voxel count for labels and the geometric
area for shapes (``pi * r**2`` for circles). Not supported for points (raises). With
``persist_as="Points"`` the area is added as a feature column of the returned Points element.
Comment on lines +96 to +99

@LucaMarconato LucaMarconato Jun 23, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a design problem: the area is always returned untransformed (=in the intrinsic coordinate system). I think we should be consistent: if we transform the centroids (and we do this) we should transform the area. Otherwise we should not transform the centroids (or we should not return the area unless we have an identity/intrinsic).

I like to return the area, so I would consider always requiring coordinate_system=None (or mapped to a system where there is an identity) when we also compute the area.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This requires to make a decision regarding the design, I don't have a solution in mind.

persist_as
``"Points"`` (default) returns the centroids as a new Points element, transformed into
``coordinate_system``. ``"adata"`` writes the centroids (and area) into the element's
annotating table and is only available through the :class:`~spatialdata.SpatialData` overload,
which can resolve that table.

Returns
-------
A Points element (``persist_as="Points"``). With ``persist_as="adata"`` (``SpatialData`` overload),
``None`` when written in place, or the new ``AnnData`` table when ``inplace=False``.

Notes
-----
For :class:`~shapely.Multipolygon`s, the centroids are the average of the centroids of the polygons that constitute
each :class:`~shapely.Multipolygon`.
each :class:`~shapely.Multipolygon`. For multiscale labels the centroids are computed on the full-resolution
``scale0`` level.
"""
raise ValueError(f"The object type {type(e)} is not supported.")


def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame:
def _get_centroids_for_labels(xdata: xr.DataArray, return_area: bool = False) -> pd.DataFrame:
"""
Compute centroids for all labels in a DataArray in a single O(n_voxels) pass.

Works for any number of spatial dimensions (2D and 3D labels).
Works for any number of spatial dimensions (2D and 3D labels). When ``return_area`` is True, an
``area`` column (the per-label pixel/voxel count) is added; it is already computed for the
centroids, so this is free.
"""
arr = xdata.data.compute()
axes = list(xdata.dims)
Expand All @@ -77,66 +141,212 @@ def _get_centroids_for_labels(xdata: xr.DataArray) -> pd.DataFrame:
coord_sums = np.bincount(flat_inverse, weights=grid.ravel().astype(float))
data[ax] = coord_sums / counts # counts > 0 by construction (unique guarantees this)

return pd.DataFrame(data, index=label_ids)
df = pd.DataFrame(data, index=label_ids)
if return_area:
df["area"] = counts.astype(float)
return df


@get_centroids.register(DataArray)
@get_centroids.register(DataTree)
def _(
e: DataArray | DataTree,
coordinate_system: str = "global",
return_background: bool = False,
) -> DaskDataFrame:
"""Get the centroids of a Labels element (2D or 3D)."""
model = get_model(e)
if model not in [Labels2DModel, Labels3DModel]:
raise ValueError("Expected a `Labels` element. Found an `Image` instead.")
_validate_coordinate_system(e, coordinate_system)

if isinstance(e, DataTree):
assert len(e["scale0"]) == 1
e = next(iter(e["scale0"].values()))

df = _get_centroids_for_labels(e)
if not return_background and 0 in df.index:
df = df.drop(index=0) # drop the background label
t = get_transformation(e, coordinate_system)
centroids = PointsModel.parse(df, transformations={coordinate_system: t})
return transform(centroids, to_coordinate_system=coordinate_system)


@get_centroids.register(GeoDataFrame)
def _(e: GeoDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
"""Get the centroids of a Shapes element (circles or polygons/multipolygons)."""
_validate_coordinate_system(e, coordinate_system)
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
# separate points from (multi-)polygons
def _get_centroids_for_shapes(e: GeoDataFrame, return_area: bool) -> tuple[pd.DataFrame, np.ndarray | None]:
"""Intrinsic per-shape centroids (``x, y`` columns indexed by the element's index) and optional area."""
first_geometry = e["geometry"].iloc[0]
if isinstance(first_geometry, Point):
xy = e.geometry.get_coordinates().values
# shapely .area is 0 for circles (Point geometry); the radius column carries the size.
area = np.pi * np.asarray(e["radius"], dtype=float) ** 2 if return_area else None
else:
assert isinstance(first_geometry, Polygon | MultiPolygon), (
f"Expected a GeoDataFrame either composed entirely of circles (Points with the `radius` column) or"
f" Polygons/MultiPolygons. Found {type(first_geometry)} instead."
)
xy = e.centroid.get_coordinates().values
area = e.geometry.area.to_numpy() if return_area else None
xy_df = pd.DataFrame(xy, columns=["x", "y"], index=e.index.copy())
points = PointsModel.parse(xy_df, transformations={coordinate_system: t})
return xy_df, area


def _intrinsic_centroid_frame(
element: SpatialElement, return_background: bool, return_area: bool
) -> tuple[pd.DataFrame, np.ndarray | None, SpatialElement]:
"""Per-instance intrinsic centroids (coordinate columns, indexed by instance id), optional area.

Also returns the element the centroids live on (for labels, the ``scale0`` level of a multiscale
raster), which carries the transformation to apply downstream.
"""
model = get_model(element)
if model in (Labels2DModel, Labels3DModel):
raster = next(iter(element["scale0"].values())) if isinstance(element, DataTree) else element
df = _get_centroids_for_labels(raster, return_area=return_area)
if not return_background and 0 in df.index:
df = df.drop(index=0) # drop the background label (its area, if any, goes with it)
area = df.pop("area").to_numpy() if return_area else None
return df, area, raster
if model is ShapesModel:
xy_df, area = _get_centroids_for_shapes(element, return_area)
return xy_df, area, element
if model is PointsModel:
if return_area:
raise ValueError("`return_area` is not supported for points elements (points have no area).")
axes = get_axes_names(element)
assert axes in [("x", "y"), ("x", "y", "z")]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should turn this assert into a if raises

return element[list(axes)].compute(), None, element
raise ValueError(
f"Centroids are not supported for elements modeled by {model.__name__}; expected a Labels, Shapes or Points element."
)


def _points_from_centroids(
df: pd.DataFrame, area: np.ndarray | None, e: SpatialElement, coordinate_system: str
) -> DaskDataFrame:
"""Build a Points element from intrinsic centroids, transformed into ``coordinate_system``."""
out = df.assign(area=np.asarray(area, dtype=float)) if area is not None else df
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
points = PointsModel.parse(out, transformations={coordinate_system: t})
return transform(points, to_coordinate_system=coordinate_system)


@get_centroids.register(DataArray)
@get_centroids.register(DataTree)
@get_centroids.register(GeoDataFrame)
@get_centroids.register(DaskDataFrame)
def _(e: DaskDataFrame, coordinate_system: str = "global") -> DaskDataFrame:
"""Get the centroids of a Points element."""
def _(
e: SpatialElement,
coordinate_system: str | None = "global",
return_background: bool = False,
return_area: bool = False,
persist_as: PersistAs = "Points",
) -> DaskDataFrame:
"""Get the centroids of a Labels, Shapes or Points element."""
_validate_persist_args(persist_as, coordinate_system, allow_adata=False)
assert coordinate_system is not None # guaranteed by _validate_persist_args (allow_adata=False)
_validate_coordinate_system(e, coordinate_system)
axes = get_axes_names(e)
assert axes in [("x", "y"), ("x", "y", "z")]
coords = e[list(axes)].compute()
t = get_transformation(e, coordinate_system)
assert isinstance(t, BaseTransformation)
centroids = PointsModel.parse(coords, transformations={coordinate_system: t})
return transform(centroids, to_coordinate_system=coordinate_system)
df, area, raster = _intrinsic_centroid_frame(e, return_background, return_area)
return _points_from_centroids(df, area, raster, coordinate_system)


def _resolve_annotating_table(sdata: SpatialData, element_name: str, table_name: str | None) -> str:
"""Resolve the single table that annotates ``element_name`` (where centroids are written)."""
from spatialdata._core.query.relational_query import get_element_annotators

if table_name is not None:
if table_name not in sdata.tables:
raise KeyError(f"Table {table_name!r} not found in `sdata.tables`.")
return table_name
annotators = sorted(get_element_annotators(sdata, element_name))
if not annotators:
raise ValueError(
f"Element {element_name!r} has no annotating table to write centroids into. Use "
f"persist_as='Points' to get the centroids as a Points element instead, or annotate the "
f"element with a table first."
)
if len(annotators) > 1:
raise ValueError(
f"Element {element_name!r} is annotated by multiple tables ({', '.join(annotators)}); "
f"pass `table_name=` to choose one."
)
return annotators[0]


def _write_centroids_into_table(

@LucaMarconato LucaMarconato Jun 23, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function seems to do the job, but I think there is a big design problem. Nothing stops the user in having a table annotating multiple elmements and where:

  • is annotating some elements with xy axes and some elements with yx axes -> the information of axes is lost once the centroids are in obsm
  • the get_centroids function gets called on different coordinate systems for different elements -> the information of which coordinate systems the centroids belong to is lost once centroids are in obsm.

I have some ideas but not a clear solution in mind. This requires some work into definiting how the Spatial AnnData specification works.

@LucaMarconato LucaMarconato Jun 23, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually the first is not a problem since the returned centroids are always x, y[, z] (due to the assertion above)

    axes = get_axes_names(element)
    assert axes in [("x", "y"), ("x", "y", "z")]

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The second point requires a design decision.

table: AnnData,
element_name: str,
centroids: pd.DataFrame,
area: np.ndarray | None,
) -> None:
"""Write centroids into ``obsm["spatial"]`` and area into ``obs["area"]`` at the element's rows.

Only the table rows annotating ``element_name`` are touched (a table may annotate several
elements); instances annotated but absent from the element are written as NaN.
"""
if not centroids.index.is_unique:
raise ValueError(f"Cannot persist centroids for {element_name!r}: its instance index has duplicate values.")
_, region_key, instance_key = get_table_keys(table)
mask = (table.obs[region_key].astype(str) == str(element_name)).to_numpy()
if not mask.any():
raise ValueError(f"The resolved table does not annotate element {element_name!r} (no matching rows).")

# Map each annotated instance to its centroid row (-1 where absent -> NaN). A *total* miss means the
# instance ids never align with the element index (mismatched dtype, or no shared instances).
keys = table.obs[instance_key].to_numpy()[mask]
idx = centroids.index.get_indexer(keys)
if (idx == -1).all():
raise ValueError(
f"No instance id annotating {element_name!r} is present in the element; check the table's "
f"`{instance_key}` values and dtype."
)
hit = idx != -1

def _scatter(values: np.ndarray) -> np.ndarray:
"""Gather ``values`` (ordered like ``centroids``) onto the masked rows, NaN where absent."""
out = np.full((len(idx), *values.shape[1:]), np.nan)
out[hit] = values[idx[hit]]
return out

ndim = centroids.shape[1]
spatial = np.full((table.n_obs, ndim), np.nan)
existing = table.obsm.get(_SPATIAL_KEY)
if existing is not None:
existing = np.asarray(existing)
if existing.shape == (table.n_obs, ndim):
spatial = existing.astype(float, copy=True) # preserve other regions' coordinates
elif existing.shape[0] == table.n_obs:
raise ValueError(
f"Existing obsm['{_SPATIAL_KEY}'] {existing.shape} is incompatible with {ndim}-D centroids for "
f"{element_name!r}; refusing to overwrite other regions. Persist with persist_as='Points' instead."
)
spatial[mask] = _scatter(centroids.to_numpy(dtype=float))
table.obsm[_SPATIAL_KEY] = spatial

if area is not None:
col = np.full(table.n_obs, np.nan)
if _AREA_KEY in table.obs:
col = table.obs[_AREA_KEY].to_numpy(dtype=float).copy()
col[mask] = _scatter(np.asarray(area, dtype=float))
table.obs[_AREA_KEY] = col


@get_centroids.register(SpatialData)
def _get_centroids_sdata(
e: SpatialData,
element_name: str,
coordinate_system: str | None = "global",
return_background: bool = False,
return_area: bool = False,
persist_as: PersistAs = "Points",
table_name: str | None = None,
inplace: bool = True,
) -> DaskDataFrame | AnnData | None:
"""Get the centroids of ``element_name``, or (``persist_as="adata"``) write them into its annotating table.

With ``persist_as="adata"`` the centroids go into ``obsm["spatial"]`` (and area into ``obs["area"]``) of the
resolved annotating table (``table_name=`` disambiguates). ``inplace=True`` (default) mutates that table and
returns ``None``; ``inplace=False`` writes into a copy of *only that table* and returns the new ``AnnData``,
leaving ``e`` untouched. ``persist_as="Points"`` behaves like calling :func:`get_centroids` on the element.
"""
_validate_persist_args(persist_as, coordinate_system, allow_adata=True)
element = e[element_name]

if persist_as == "Points":
return get_centroids(
element,
coordinate_system=coordinate_system,
return_background=return_background,
return_area=return_area,
)

# persist_as == "adata": resolve the annotating table and write the centroids into it.
if coordinate_system is not None:
_validate_coordinate_system(element, coordinate_system)
table_name = _resolve_annotating_table(e, element_name, table_name)
df, area, raster = _intrinsic_centroid_frame(element, return_background, return_area)
coord_cols = sorted(df.columns) # canonical x, y[, z] (squidpy obsm["spatial"] order)

@LucaMarconato LucaMarconato Jun 23, 2026

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Important point to put in the docstring (and later in the specs).

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(that the order is always x, y[, z])

coords = _transform_centroid_coords(df[coord_cols].to_numpy(), coord_cols, raster, coordinate_system)
centroids = pd.DataFrame(coords, columns=coord_cols, index=df.index)

table = e.tables[table_name] if inplace else e.tables[table_name].copy()
_write_centroids_into_table(table, element_name, centroids, area)
return None if inplace else table


##
Loading
Loading