diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 3a683c661..74993acb3 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -53,6 +53,7 @@ "deepcopy", "sanitize_table", "sanitize_name", + "subset_sdata_by_table_mask", ] from spatialdata import dataloader, datasets, models, transformations @@ -76,6 +77,7 @@ match_element_to_table, match_sdata_to_table, match_table_to_element, + subset_sdata_by_table_mask, ) from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query from spatialdata._core.spatialdata import SpatialData diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 0803158ca..ca8c53c48 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -9,9 +9,11 @@ import dask.array as da import numpy as np import pandas as pd +import xarray as xr from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame +from numpy.typing import NDArray from xarray import DataArray, DataTree from spatialdata._core.spatialdata import SpatialData @@ -1017,3 +1019,124 @@ def get_values( return df raise ValueError(f"Unknown origin {origin}") + + +def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: + # Use apply_ufunc for efficient processing + # Create a copy to avoid modifying read-only array + result = block.copy() + result[np.isin(result, ids_to_remove)] = 0 + return result + + +def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: + processed = xr.apply_ufunc( + partial(_mask_block, ids_to_remove=ids_to_remove), + image, + input_core_dims=[["y", "x"]], + output_core_dims=[["y", "x"]], + vectorize=True, + dask="parallelized", + output_dtypes=[image.dtype], + dataset_fill_value=0, + dask_gufunc_kwargs={"allow_rechunk": True}, + ) + + # Create a new DataArray to ensure persistence + return xr.DataArray( + data=processed.data, + coords=image.coords, + dims=image.dims, + attrs=image.attrs.copy(), # Preserve all attributes + ) + + +def _get_scale_factors(labels_element: DataTree) -> list[tuple[float, float]]: + scales = list(labels_element.keys()) + + # Calculate relative scale factors between consecutive scales + scale_factors = [] + for i in range(len(scales) - 1): + y_size_current = labels_element[scales[i]].image.shape[0] + x_size_current = labels_element[scales[i]].image.shape[1] + y_size_next = labels_element[scales[i + 1]].image.shape[0] + x_size_next = labels_element[scales[i + 1]].image.shape[1] + y_factor = y_size_current / y_size_next + x_factor = x_size_current / x_size_next + + scale_factors.append((y_factor, x_factor)) + + return scale_factors + + +@singledispatch +def _filter_by_instance_ids(element: Any, ids_to_remove: list[str], instance_key: str) -> Any: + raise NotImplementedError(f"Filtering by instance ids is not implemented for {element}") + + +@_filter_by_instance_ids.register(DataArray) +def _(element: DataArray, ids_to_remove: list[int], instance_key: str) -> DataArray: + del instance_key + return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove)) + + +@_filter_by_instance_ids.register(DataTree) +def _(element: DataTree, ids_to_remove: list[int], instance_key: str) -> DataTree: + # we extract the info to just reconstruct + # the DataTree after filtering the max scale + del instance_key + max_scale = list(element.keys())[0] + scale_factors_temp = _get_scale_factors(element) + scale_factors = [int(sf[0]) for sf in scale_factors_temp] + + return Labels2DModel.parse( + data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove), + scale_factors=scale_factors, + ) + + +def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArray[np.bool_]) -> SpatialData: + """Subset the annotated elements of a SpatialData object by a table and a mask. + + The mask is applied to the table and the annotated elements are subsetted + by the instance ids in the table. + This function returns a new SpatialData object with the subsetted elements. + Elements that are not annotated by the table are not included in the returned SpatialData object. + The element models that are + supported are :class:`spatialdata.models.Labels2DModel`, + :class:`spatialdata.models.PointsModel`, and :class:`spatialdata.models.ShapesModel`. + + Parameters + ---------- + sdata + The SpatialData object to subset. + table_name + The name of the table to apply the mask to. + mask + Boolean mask to apply to the table which is the same length as the number of rows in the table. + + Returns + ------- + The subsetted SpatialData object. + """ + table = sdata.tables.get(table_name) + if table is None: + raise ValueError(f"Table {table_name} not found in SpatialData object.") + + subset_table = table[mask] + sdata.tables[table_name] = subset_table + _, _, instance_key = get_table_keys(subset_table) + annotated_regions = SpatialData.get_annotated_regions(table) + removed_instance_ids = list(np.unique(table.obs[instance_key][~mask])) + + filtered_elements = {} + for reg in annotated_regions: + elem = sdata.get(reg) + model = get_model(elem) + if model is Labels2DModel: + filtered_elements[reg] = _filter_by_instance_ids(elem, removed_instance_ids, instance_key) + elif model in [PointsModel, ShapesModel]: + element_dict, _ = match_element_to_table(sdata, element_name=reg, table_name=table_name) + filtered_elements[reg] = element_dict[reg] + + return SpatialData.init_from_elements(filtered_elements | {table_name: subset_table}) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py new file mode 100644 index 000000000..629f58bb6 --- /dev/null +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -0,0 +1,80 @@ +import numpy as np +import pytest + +from spatialdata import concatenate, subset_sdata_by_table_mask +from spatialdata._core.query.relational_query import _filter_by_instance_ids +from spatialdata.datasets import blobs_annotating_element + + +def test_filter_labels2dmodel_by_instance_ids() -> None: + sdata = blobs_annotating_element("blobs_labels") + labels_element = sdata["blobs_labels"] + all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() + filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id") + + # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 + filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - { + 0, + } + preserved_ids = np.unique(labels_element.data.compute()) + assert filtered_ids == (set(all_instance_ids) - {2, 3}) + # check if there is modification of the original labels + assert set(preserved_ids) == set(all_instance_ids) | {0} + + sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" + sdata.tables["table"].obs.region = "blobs_multiscale_labels" + labels_element = sdata["blobs_multiscale_labels"] + filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id") + + for scale in labels_element: + filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - { + 0, + } + preserved_ids = np.unique(labels_element[scale].image.compute()) + assert filtered_ids == (set(all_instance_ids) - {2, 3}) + # check if there is modification of the original labels + assert set(preserved_ids) == set(all_instance_ids) | {0} + + +def test_subset_sdata_by_table_mask() -> None: + sdata = concatenate( + { + "labels": blobs_annotating_element("blobs_labels"), + "shapes": blobs_annotating_element("blobs_circles"), + "points": blobs_annotating_element("blobs_points"), + "multiscale_labels": blobs_annotating_element("blobs_multiscale_labels"), + }, + concatenate_tables=True, + ) + third_elems = sdata.tables["table"].obs["instance_id"] == 3 + subset_sdata = subset_sdata_by_table_mask(sdata, "table", third_elems) + + assert set(subset_sdata.labels.keys()) == {"blobs_labels-labels", "blobs_multiscale_labels-multiscale_labels"} + assert set(subset_sdata.points.keys()) == {"blobs_points-points"} + assert set(subset_sdata.shapes.keys()) == {"blobs_circles-shapes"} + + labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_labels-labels"].data.compute())) - {0} + assert labels_remaining_ids == {3} + + for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]: + ms_labels_remaining_ids = set( + np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute()) + ) - {0} + assert ms_labels_remaining_ids == {3} + + points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"].index)) - {0} + assert points_remaining_ids == {3} + + shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} + assert shapes_remaining_ids == {3} + + +def test_subset_sdata_by_table_mask_with_no_annotated_elements() -> None: + with pytest.raises(ValueError, match="Table table_not_found not found in SpatialData object."): + sdata = blobs_annotating_element("blobs_labels") + _ = subset_sdata_by_table_mask(sdata, "table_not_found", sdata.tables["table"].obs["instance_id"] == 3) + + +def test_filter_by_instance_ids_fails_for_unsupported_element_models() -> None: + with pytest.raises(NotImplementedError, match="Filtering by instance ids is not implemented for"): + _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id")