From a28c7c905e40bd748edd6b6f772988c72f8154dd Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 30 Jun 2025 15:17:25 +0200 Subject: [PATCH 01/11] init --- src/spatialdata/_core/query/masking.py | 113 +++++++++++++++++++++++++ tests/core/query/test_masking.py | 39 +++++++++ 2 files changed, 152 insertions(+) create mode 100644 src/spatialdata/_core/query/masking.py create mode 100644 tests/core/query/test_masking.py diff --git a/src/spatialdata/_core/query/masking.py b/src/spatialdata/_core/query/masking.py new file mode 100644 index 000000000..451e7d4ed --- /dev/null +++ b/src/spatialdata/_core/query/masking.py @@ -0,0 +1,113 @@ +import numpy as np +import xarray as xr +from functools import partial + +from spatialdata.models import Labels2DModel, ShapesModel +from spatialdata.models.models import DataTree + + + +def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: + # Use apply_ufunc for efficient processing + # Create a copy to avoid modifying read-only array + result = block.copy() + result[np.isin(result, ids_to_remove)] = 0 + return result + + +def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: + processed = xr.apply_ufunc( + partial(_mask_block, ids_to_remove=ids_to_remove), + image, + input_core_dims=[["y", "x"]], + output_core_dims=[["y", "x"]], + vectorize=True, + dask="parallelized", + output_dtypes=[image.dtype], + dataset_fill_value=0, + dask_gufunc_kwargs={"allow_rechunk": True}, + ) + + # Force computation to ensure the changes are materialized + computed_result = processed.compute() + + # Create a new DataArray to ensure persistence + result = xr.DataArray( + data=computed_result.data, + coords=image.coords, + dims=image.dims, + attrs=image.attrs.copy(), # Preserve all attributes + ) + + return result + + +def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float]]: + scales = list(labels_element.keys()) + + # Calculate relative scale factors between consecutive scales + scale_factors = [] + for i in range(len(scales) - 1): + y_size_current = labels_element[scales[i]].image.shape[0] + x_size_current = labels_element[scales[i]].image.shape[1] + y_size_next = labels_element[scales[i + 1]].image.shape[0] + x_size_next = labels_element[scales[i + 1]].image.shape[1] + y_factor = y_size_current / y_size_next + x_factor = x_size_current / x_size_next + + scale_factors.append((y_factor, x_factor)) + + return scale_factors + + +def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list[str]) -> ShapesModel: + """ + Filter a ShapesModel by instance ids. + + Parameters + ---------- + element + The ShapesModel to filter. + ids_to_remove + The instance ids to remove. + + Returns + ------- + The filtered ShapesModel. + """ + return element[~element.index.isin(ids_to_remove)] + + +def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: list[int]) -> Labels2DModel: + """ + Filter a Labels2DModel by instance ids. + + This function works for both DataArray and DataTree and sets the + instance ids to zero. + + Parameters + ---------- + element + The Labels2DModel to filter. + ids_to_remove + The instance ids to remove. + + Returns + ------- + The filtered Labels2DModel. + """ + if isinstance(element, xr.DataArray): + return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove)) + + if isinstance(element, DataTree): + # we extract the info to just reconstruct + # the DataTree after filtering the max scale + max_scale = list(element.keys())[0] + scale_factors = _get_scale_factors(element) + scale_factors = [int(sf[0]) for sf in scale_factors] + + return Labels2DModel.parse( + data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove), + scale_factors=scale_factors, + ) + raise ValueError(f"Unknown element type: {type(element)}") diff --git a/tests/core/query/test_masking.py b/tests/core/query/test_masking.py new file mode 100644 index 000000000..89f4db526 --- /dev/null +++ b/tests/core/query/test_masking.py @@ -0,0 +1,39 @@ +import numpy as np +import anndata as ad + +from spatialdata._core.query.masking import filter_labels2dmodel_by_instance_ids, filter_shapesmodel_by_instance_ids +from spatialdata.datasets import blobs_annotating_element + + +def test_filter_labels2dmodel_by_instance_ids(): + sdata = blobs_annotating_element("blobs_labels") + labels_element = sdata["blobs_labels"] + all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() + filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) + + # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 + filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - {0,} + preserved_ids = np.unique(labels_element.data.compute()) + assert filtered_ids == (set(all_instance_ids) - {2, 3}) + # check if there is modification of the original labels + assert set(preserved_ids) == set(all_instance_ids) | {0} + + sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" + sdata.tables["table"].obs.region = "blobs_multiscale_labels" + labels_element = sdata["blobs_multiscale_labels"] + filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) + + for scale in labels_element: + filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - {0,} + preserved_ids = np.unique(labels_element[scale].image.compute()) + assert filtered_ids == (set(all_instance_ids) - {2, 3}) + # check if there is modification of the original labels + assert set(preserved_ids) == set(all_instance_ids) | {0} + + +def test_filter_shapesmodel_by_instance_ids(): + sdata = blobs_annotating_element("blobs_circles") + shapes_element = sdata["blobs_circles"] + filtered_shapes_element = filter_shapesmodel_by_instance_ids(shapes_element, [2, 3]) + + assert set(filtered_shapes_element.index.tolist()) == {0, 1, 4} From 225d593146efd81be25e505231f6450c47bd254c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 Jun 2025 13:22:41 +0000 Subject: [PATCH 02/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spatialdata/_core/query/masking.py | 4 ++-- tests/core/query/test_masking.py | 9 ++++++--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/spatialdata/_core/query/masking.py b/src/spatialdata/_core/query/masking.py index 451e7d4ed..32510b03a 100644 --- a/src/spatialdata/_core/query/masking.py +++ b/src/spatialdata/_core/query/masking.py @@ -1,12 +1,12 @@ +from functools import partial + import numpy as np import xarray as xr -from functools import partial from spatialdata.models import Labels2DModel, ShapesModel from spatialdata.models.models import DataTree - def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: # Use apply_ufunc for efficient processing # Create a copy to avoid modifying read-only array diff --git a/tests/core/query/test_masking.py b/tests/core/query/test_masking.py index 89f4db526..370c153cc 100644 --- a/tests/core/query/test_masking.py +++ b/tests/core/query/test_masking.py @@ -1,5 +1,4 @@ import numpy as np -import anndata as ad from spatialdata._core.query.masking import filter_labels2dmodel_by_instance_ids, filter_shapesmodel_by_instance_ids from spatialdata.datasets import blobs_annotating_element @@ -12,7 +11,9 @@ def test_filter_labels2dmodel_by_instance_ids(): filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 - filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - {0,} + filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - { + 0, + } preserved_ids = np.unique(labels_element.data.compute()) assert filtered_ids == (set(all_instance_ids) - {2, 3}) # check if there is modification of the original labels @@ -24,7 +25,9 @@ def test_filter_labels2dmodel_by_instance_ids(): filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) for scale in labels_element: - filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - {0,} + filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - { + 0, + } preserved_ids = np.unique(labels_element[scale].image.compute()) assert filtered_ids == (set(all_instance_ids) - {2, 3}) # check if there is modification of the original labels From e549b4b4725f9ab73eaead533e27ca320491b9c4 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 30 Jun 2025 15:45:47 +0200 Subject: [PATCH 03/11] fix mypy linterrors --- src/spatialdata/_core/query/masking.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/spatialdata/_core/query/masking.py b/src/spatialdata/_core/query/masking.py index 32510b03a..0605ad2e7 100644 --- a/src/spatialdata/_core/query/masking.py +++ b/src/spatialdata/_core/query/masking.py @@ -2,9 +2,11 @@ import numpy as np import xarray as xr +from geopandas import GeoDataFrame +from xarray.core.dataarray import DataArray +from xarray.core.datatree import DataTree from spatialdata.models import Labels2DModel, ShapesModel -from spatialdata.models.models import DataTree def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: @@ -32,15 +34,13 @@ def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list computed_result = processed.compute() # Create a new DataArray to ensure persistence - result = xr.DataArray( + return xr.DataArray( data=computed_result.data, coords=image.coords, dims=image.dims, attrs=image.attrs.copy(), # Preserve all attributes ) - return result - def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float]]: scales = list(labels_element.keys()) @@ -60,7 +60,7 @@ def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float return scale_factors -def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list[str]) -> ShapesModel: +def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list[str]) -> GeoDataFrame: """ Filter a ShapesModel by instance ids. @@ -75,10 +75,11 @@ def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list ------- The filtered ShapesModel. """ - return element[~element.index.isin(ids_to_remove)] + element2: GeoDataFrame = element[~element.index.isin(ids_to_remove)] # type: ignore[index, attr-defined] + return ShapesModel.parse(element2) -def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: list[int]) -> Labels2DModel: +def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: list[int]) -> DataArray | DataTree: """ Filter a Labels2DModel by instance ids. @@ -103,8 +104,8 @@ def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: # we extract the info to just reconstruct # the DataTree after filtering the max scale max_scale = list(element.keys())[0] - scale_factors = _get_scale_factors(element) - scale_factors = [int(sf[0]) for sf in scale_factors] + scale_factors_temp = _get_scale_factors(element) + scale_factors = [int(sf[0]) for sf in scale_factors_temp] return Labels2DModel.parse( data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove), From 2aad72b23c3ccbd4bcc852b28bdf866032d8d1d1 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 3 Jul 2025 16:33:21 +0200 Subject: [PATCH 04/11] update the location and the design --- src/spatialdata/__init__.py | 2 + src/spatialdata/_core/query/masking.py | 114 ---------------- .../_core/query/relational_query.py | 124 ++++++++++++++++++ tests/core/query/test_masking.py | 42 ------ ...tional_query_subset_sdata_by_table_mask.py | 67 ++++++++++ 5 files changed, 193 insertions(+), 156 deletions(-) delete mode 100644 src/spatialdata/_core/query/masking.py delete mode 100644 tests/core/query/test_masking.py create mode 100644 tests/core/query/test_relational_query_subset_sdata_by_table_mask.py diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 0b68391ad..724975bf6 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -55,6 +55,7 @@ "deepcopy", "sanitize_table", "sanitize_name", + "subset_sdata_by_table_mask", ] from spatialdata import dataloader, datasets, models, transformations @@ -78,6 +79,7 @@ match_element_to_table, match_sdata_to_table, match_table_to_element, + subset_sdata_by_table_mask, ) from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query from spatialdata._core.spatialdata import SpatialData diff --git a/src/spatialdata/_core/query/masking.py b/src/spatialdata/_core/query/masking.py deleted file mode 100644 index 0605ad2e7..000000000 --- a/src/spatialdata/_core/query/masking.py +++ /dev/null @@ -1,114 +0,0 @@ -from functools import partial - -import numpy as np -import xarray as xr -from geopandas import GeoDataFrame -from xarray.core.dataarray import DataArray -from xarray.core.datatree import DataTree - -from spatialdata.models import Labels2DModel, ShapesModel - - -def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: - # Use apply_ufunc for efficient processing - # Create a copy to avoid modifying read-only array - result = block.copy() - result[np.isin(result, ids_to_remove)] = 0 - return result - - -def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: - processed = xr.apply_ufunc( - partial(_mask_block, ids_to_remove=ids_to_remove), - image, - input_core_dims=[["y", "x"]], - output_core_dims=[["y", "x"]], - vectorize=True, - dask="parallelized", - output_dtypes=[image.dtype], - dataset_fill_value=0, - dask_gufunc_kwargs={"allow_rechunk": True}, - ) - - # Force computation to ensure the changes are materialized - computed_result = processed.compute() - - # Create a new DataArray to ensure persistence - return xr.DataArray( - data=computed_result.data, - coords=image.coords, - dims=image.dims, - attrs=image.attrs.copy(), # Preserve all attributes - ) - - -def _get_scale_factors(labels_element: Labels2DModel) -> list[tuple[float, float]]: - scales = list(labels_element.keys()) - - # Calculate relative scale factors between consecutive scales - scale_factors = [] - for i in range(len(scales) - 1): - y_size_current = labels_element[scales[i]].image.shape[0] - x_size_current = labels_element[scales[i]].image.shape[1] - y_size_next = labels_element[scales[i + 1]].image.shape[0] - x_size_next = labels_element[scales[i + 1]].image.shape[1] - y_factor = y_size_current / y_size_next - x_factor = x_size_current / x_size_next - - scale_factors.append((y_factor, x_factor)) - - return scale_factors - - -def filter_shapesmodel_by_instance_ids(element: ShapesModel, ids_to_remove: list[str]) -> GeoDataFrame: - """ - Filter a ShapesModel by instance ids. - - Parameters - ---------- - element - The ShapesModel to filter. - ids_to_remove - The instance ids to remove. - - Returns - ------- - The filtered ShapesModel. - """ - element2: GeoDataFrame = element[~element.index.isin(ids_to_remove)] # type: ignore[index, attr-defined] - return ShapesModel.parse(element2) - - -def filter_labels2dmodel_by_instance_ids(element: Labels2DModel, ids_to_remove: list[int]) -> DataArray | DataTree: - """ - Filter a Labels2DModel by instance ids. - - This function works for both DataArray and DataTree and sets the - instance ids to zero. - - Parameters - ---------- - element - The Labels2DModel to filter. - ids_to_remove - The instance ids to remove. - - Returns - ------- - The filtered Labels2DModel. - """ - if isinstance(element, xr.DataArray): - return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove)) - - if isinstance(element, DataTree): - # we extract the info to just reconstruct - # the DataTree after filtering the max scale - max_scale = list(element.keys())[0] - scale_factors_temp = _get_scale_factors(element) - scale_factors = [int(sf[0]) for sf in scale_factors_temp] - - return Labels2DModel.parse( - data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove), - scale_factors=scale_factors, - ) - raise ValueError(f"Unknown element type: {type(element)}") diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index b84d43c1b..f82572d38 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -11,9 +11,11 @@ import dask.array as da import numpy as np import pandas as pd +import xarray as xr from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame +from numpy.typing import NDArray from xarray import DataArray, DataTree from spatialdata._core.spatialdata import SpatialData @@ -1019,3 +1021,125 @@ def get_values( return df raise ValueError(f"Unknown origin {origin}") + + +def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: + # Use apply_ufunc for efficient processing + # Create a copy to avoid modifying read-only array + result = block.copy() + result[np.isin(result, ids_to_remove)] = 0 + return result + + +def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: + processed = xr.apply_ufunc( + partial(_mask_block, ids_to_remove=ids_to_remove), + image, + input_core_dims=[["y", "x"]], + output_core_dims=[["y", "x"]], + vectorize=True, + dask="parallelized", + output_dtypes=[image.dtype], + dataset_fill_value=0, + dask_gufunc_kwargs={"allow_rechunk": True}, + ) + + # Force computation to ensure the changes are materialized + computed_result = processed.compute() + + # Create a new DataArray to ensure persistence + return xr.DataArray( + data=computed_result.data, + coords=image.coords, + dims=image.dims, + attrs=image.attrs.copy(), # Preserve all attributes + ) + + +def _get_scale_factors(labels_element: DataTree) -> list[tuple[float, float]]: + scales = list(labels_element.keys()) + + # Calculate relative scale factors between consecutive scales + scale_factors = [] + for i in range(len(scales) - 1): + y_size_current = labels_element[scales[i]].image.shape[0] + x_size_current = labels_element[scales[i]].image.shape[1] + y_size_next = labels_element[scales[i + 1]].image.shape[0] + x_size_next = labels_element[scales[i + 1]].image.shape[1] + y_factor = y_size_current / y_size_next + x_factor = x_size_current / x_size_next + + scale_factors.append((y_factor, x_factor)) + + return scale_factors + + +@singledispatch +def _filter_by_instance_ids(element: Any, ids_to_remove: list[str], instance_key: str) -> Any: + raise NotImplementedError(f"Filtering by instance ids is not implemented for {element}") + + +@_filter_by_instance_ids.register(GeoDataFrame) +def _(element: GeoDataFrame, ids_to_remove: list[str], instance_key: str) -> GeoDataFrame: + return element[~element.index.isin(ids_to_remove)] + + +@_filter_by_instance_ids.register(DaskDataFrame) +def _(element: DaskDataFrame, ids_to_remove: list[str], instance_key: str) -> DaskDataFrame: + return element[~element[instance_key].isin(ids_to_remove)] + + +@_filter_by_instance_ids.register(DataArray) +def _(element: DataArray, ids_to_remove: list[int], instance_key: str) -> DataArray: + del instance_key + return Labels2DModel.parse(_set_instance_ids_in_labels_to_zero(element, ids_to_remove)) + + +@_filter_by_instance_ids.register(DataTree) +def _(element: DataArray | DataTree, ids_to_remove: list[int], instance_key: str) -> xr.DataArray | xr.DataTree: + # we extract the info to just reconstruct + # the DataTree after filtering the max scale + max_scale = list(element.keys())[0] + scale_factors_temp = _get_scale_factors(element) + scale_factors = [int(sf[0]) for sf in scale_factors_temp] + + return Labels2DModel.parse( + data=_set_instance_ids_in_labels_to_zero(element[max_scale].image, ids_to_remove), + scale_factors=scale_factors, + ) + + +def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArray[np.bool_]) -> SpatialData: + """ + Subset a SpatialData object by a table and a mask. + + Parameters + ---------- + sdata + The SpatialData object to subset. + table_name + The name of the table to apply the mask to. + mask + Boolean mask to apply to the table. + + Returns + ------- + The subsetted SpatialData object. + """ + table = sdata.tables.get(table_name) + if table is None: + raise ValueError(f"Table {table_name} not found in SpatialData object.") + + subset_table = table[mask] + _, _, instance_key = get_table_keys(subset_table) + annotated_regions = SpatialData.get_annotated_regions(table) + removed_instance_ids = list(np.unique(table.obs[instance_key][~mask])) + + filtered_elements = {} + for reg in annotated_regions: + elem = sdata.get(reg) + model = get_model(elem) + if model in [Labels2DModel, PointsModel, ShapesModel]: + filtered_elements[reg] = _filter_by_instance_ids(elem, removed_instance_ids, instance_key) + + return SpatialData.init_from_elements(filtered_elements | {table_name: subset_table}) diff --git a/tests/core/query/test_masking.py b/tests/core/query/test_masking.py deleted file mode 100644 index 370c153cc..000000000 --- a/tests/core/query/test_masking.py +++ /dev/null @@ -1,42 +0,0 @@ -import numpy as np - -from spatialdata._core.query.masking import filter_labels2dmodel_by_instance_ids, filter_shapesmodel_by_instance_ids -from spatialdata.datasets import blobs_annotating_element - - -def test_filter_labels2dmodel_by_instance_ids(): - sdata = blobs_annotating_element("blobs_labels") - labels_element = sdata["blobs_labels"] - all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() - filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) - - # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 - filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - { - 0, - } - preserved_ids = np.unique(labels_element.data.compute()) - assert filtered_ids == (set(all_instance_ids) - {2, 3}) - # check if there is modification of the original labels - assert set(preserved_ids) == set(all_instance_ids) | {0} - - sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" - sdata.tables["table"].obs.region = "blobs_multiscale_labels" - labels_element = sdata["blobs_multiscale_labels"] - filtered_labels_element = filter_labels2dmodel_by_instance_ids(labels_element, [2, 3]) - - for scale in labels_element: - filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - { - 0, - } - preserved_ids = np.unique(labels_element[scale].image.compute()) - assert filtered_ids == (set(all_instance_ids) - {2, 3}) - # check if there is modification of the original labels - assert set(preserved_ids) == set(all_instance_ids) | {0} - - -def test_filter_shapesmodel_by_instance_ids(): - sdata = blobs_annotating_element("blobs_circles") - shapes_element = sdata["blobs_circles"] - filtered_shapes_element = filter_shapesmodel_by_instance_ids(shapes_element, [2, 3]) - - assert set(filtered_shapes_element.index.tolist()) == {0, 1, 4} diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py new file mode 100644 index 000000000..bc7e0967e --- /dev/null +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -0,0 +1,67 @@ +import numpy as np + +from spatialdata._core.query.relational_query import _filter_by_instance_ids +from spatialdata.datasets import blobs_annotating_element +from spatialdata import concatenate, subset_sdata_by_table_mask + + +def test_filter_labels2dmodel_by_instance_ids(): + sdata = blobs_annotating_element("blobs_labels") + labels_element = sdata["blobs_labels"] + all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() + filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id") + + # because 0 is the background, we expect the filtered ids to be the instance ids that are not 0 + filtered_ids = set(np.unique(filtered_labels_element.data.compute())) - { + 0, + } + preserved_ids = np.unique(labels_element.data.compute()) + assert filtered_ids == (set(all_instance_ids) - {2, 3}) + # check if there is modification of the original labels + assert set(preserved_ids) == set(all_instance_ids) | {0} + + sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels" + sdata.tables["table"].obs.region = "blobs_multiscale_labels" + labels_element = sdata["blobs_multiscale_labels"] + filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id") + + for scale in labels_element: + filtered_ids = set(np.unique(filtered_labels_element[scale].image.compute())) - { + 0, + } + preserved_ids = np.unique(labels_element[scale].image.compute()) + assert filtered_ids == (set(all_instance_ids) - {2, 3}) + # check if there is modification of the original labels + assert set(preserved_ids) == set(all_instance_ids) | {0} + + +def test_subset_sdata_by_table_mask(): + sdata = concatenate( + { + "labels": blobs_annotating_element("blobs_labels"), + "shapes": blobs_annotating_element("blobs_circles"), + "points": blobs_annotating_element("blobs_points"), + "multiscale_labels": blobs_annotating_element("blobs_multiscale_labels"), + }, + concatenate_tables=True, + ) + third_elems = sdata.tables["table"].obs["instance_id"] == 3 + subset_sdata = subset_sdata_by_table_mask(sdata, "table", third_elems) + + assert set(subset_sdata.labels.keys()) == {"blobs_labels-labels", "blobs_multiscale_labels-multiscale_labels"} + assert set(subset_sdata.points.keys()) == {"blobs_points-points"} + assert set(subset_sdata.shapes.keys()) == {"blobs_circles-shapes"} + + labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_labels-labels"].data.compute())) - {0} + assert labels_remaining_ids == {3} + + for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]: + ms_labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute())) - {0} + assert ms_labels_remaining_ids == {3} + + points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]['instance_id'].compute())) - {0} + assert points_remaining_ids == {3} + + shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} + assert shapes_remaining_ids == {3} + From ef7405790cca44e7b95f479661290ce13d9b62ff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Jul 2025 14:33:39 +0000 Subject: [PATCH 05/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../test_relational_query_subset_sdata_by_table_mask.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index bc7e0967e..810e2f737 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -1,8 +1,8 @@ import numpy as np +from spatialdata import concatenate, subset_sdata_by_table_mask from spatialdata._core.query.relational_query import _filter_by_instance_ids from spatialdata.datasets import blobs_annotating_element -from spatialdata import concatenate, subset_sdata_by_table_mask def test_filter_labels2dmodel_by_instance_ids(): @@ -56,12 +56,13 @@ def test_subset_sdata_by_table_mask(): assert labels_remaining_ids == {3} for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]: - ms_labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute())) - {0} + ms_labels_remaining_ids = set( + np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute()) + ) - {0} assert ms_labels_remaining_ids == {3} - points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]['instance_id'].compute())) - {0} + points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]["instance_id"].compute())) - {0} assert points_remaining_ids == {3} shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} assert shapes_remaining_ids == {3} - From d6e22cbe9ce1eff0f752b744c5be814b4b724b57 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 3 Jul 2025 16:41:30 +0200 Subject: [PATCH 06/11] update docs --- src/spatialdata/_core/query/relational_query.py | 13 ++++++++++--- ...t_relational_query_subset_sdata_by_table_mask.py | 9 +++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index f82572d38..8914d1daf 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1110,8 +1110,15 @@ def _(element: DataArray | DataTree, ids_to_remove: list[int], instance_key: str def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArray[np.bool_]) -> SpatialData: - """ - Subset a SpatialData object by a table and a mask. + """Subset the annotated elements of a SpatialData object by a table and a mask. + + The mask is applied to the table and the annotated elements are subsetted + by the instance ids in the table. + This function returns a new SpatialData object with the subsetted elements. + Elements that are not annotated by the table are not included in the returned SpatialData object. + The element models that are + supported are :class:`spatialdata.models.Labels2DModel`, + :class:`spatialdata.models.PointsModel`, and :class:`spatialdata.models.ShapesModel`. Parameters ---------- @@ -1120,7 +1127,7 @@ def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArra table_name The name of the table to apply the mask to. mask - Boolean mask to apply to the table. + Boolean mask to apply to the table which is the same length as the number of rows in the table. Returns ------- diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index bc7e0967e..810e2f737 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -1,8 +1,8 @@ import numpy as np +from spatialdata import concatenate, subset_sdata_by_table_mask from spatialdata._core.query.relational_query import _filter_by_instance_ids from spatialdata.datasets import blobs_annotating_element -from spatialdata import concatenate, subset_sdata_by_table_mask def test_filter_labels2dmodel_by_instance_ids(): @@ -56,12 +56,13 @@ def test_subset_sdata_by_table_mask(): assert labels_remaining_ids == {3} for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]: - ms_labels_remaining_ids = set(np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute())) - {0} + ms_labels_remaining_ids = set( + np.unique(subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute()) + ) - {0} assert ms_labels_remaining_ids == {3} - points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]['instance_id'].compute())) - {0} + points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]["instance_id"].compute())) - {0} assert points_remaining_ids == {3} shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} assert shapes_remaining_ids == {3} - From 80d95a2710b7bdd5e6c020048864e64802810b0b Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 3 Jul 2025 17:06:35 +0200 Subject: [PATCH 07/11] make coverage 100/100 because why not --- ...st_relational_query_subset_sdata_by_table_mask.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index 810e2f737..46d76eaef 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -1,4 +1,5 @@ import numpy as np +import pytest from spatialdata import concatenate, subset_sdata_by_table_mask from spatialdata._core.query.relational_query import _filter_by_instance_ids @@ -66,3 +67,14 @@ def test_subset_sdata_by_table_mask(): shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} assert shapes_remaining_ids == {3} + + +def test_subset_sdata_by_table_mask_with_no_annotated_elements(): + with pytest.raises(ValueError, match="Table table_not_found not found in SpatialData object."): + sdata = blobs_annotating_element("blobs_labels") + _ = subset_sdata_by_table_mask(sdata, "table_not_found", sdata.tables["table"].obs["instance_id"] == 3) + + +def test_filter_by_instance_ids_fails_for_unsupported_element_models(): + with pytest.raises(NotImplementedError, match="Filtering by instance ids is not implemented for"): + _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id") \ No newline at end of file From 44386052a31ec1d612db8b5ada3b3188f155947e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Jul 2025 15:06:53 +0000 Subject: [PATCH 08/11] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../query/test_relational_query_subset_sdata_by_table_mask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index 46d76eaef..d634f645d 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -77,4 +77,4 @@ def test_subset_sdata_by_table_mask_with_no_annotated_elements(): def test_filter_by_instance_ids_fails_for_unsupported_element_models(): with pytest.raises(NotImplementedError, match="Filtering by instance ids is not implemented for"): - _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id") \ No newline at end of file + _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id") From 4c927ee86ca7e60e86262e9050d57ac40575401f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 10 Jul 2025 17:48:36 +0200 Subject: [PATCH 09/11] fixed type annotation --- src/spatialdata/_core/query/relational_query.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 8914d1daf..503329559 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1096,9 +1096,10 @@ def _(element: DataArray, ids_to_remove: list[int], instance_key: str) -> DataAr @_filter_by_instance_ids.register(DataTree) -def _(element: DataArray | DataTree, ids_to_remove: list[int], instance_key: str) -> xr.DataArray | xr.DataTree: +def _(element: DataTree, ids_to_remove: list[int], instance_key: str) -> DataTree: # we extract the info to just reconstruct # the DataTree after filtering the max scale + del instance_key max_scale = list(element.keys())[0] scale_factors_temp = _get_scale_factors(element) scale_factors = [int(sf[0]) for sf in scale_factors_temp] From e9e0da2c62bad3e1dd59688765427aa03595bca9 Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Thu, 10 Jul 2025 17:54:56 +0200 Subject: [PATCH 10/11] dont compute eagerly use. delete other instance key for consistency --- src/spatialdata/_core/query/relational_query.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 503329559..ba6c1eb2f 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1044,12 +1044,9 @@ def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list dask_gufunc_kwargs={"allow_rechunk": True}, ) - # Force computation to ensure the changes are materialized - computed_result = processed.compute() - # Create a new DataArray to ensure persistence return xr.DataArray( - data=computed_result.data, + data=processed.data, coords=image.coords, dims=image.dims, attrs=image.attrs.copy(), # Preserve all attributes @@ -1081,6 +1078,7 @@ def _filter_by_instance_ids(element: Any, ids_to_remove: list[str], instance_key @_filter_by_instance_ids.register(GeoDataFrame) def _(element: GeoDataFrame, ids_to_remove: list[str], instance_key: str) -> GeoDataFrame: + del instance_key return element[~element.index.isin(ids_to_remove)] From 7534c91b4181f6c5e0feb8f0cf3c1dfa3ed5bd1f Mon Sep 17 00:00:00 2001 From: selmanozleyen Date: Mon, 14 Jul 2025 15:40:11 +0200 Subject: [PATCH 11/11] update the tests and make sure we use match_element_to_table --- src/spatialdata/_core/query/relational_query.py | 17 +++++------------ ...lational_query_subset_sdata_by_table_mask.py | 10 +++++----- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index ba6c1eb2f..e570caec1 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1076,17 +1076,6 @@ def _filter_by_instance_ids(element: Any, ids_to_remove: list[str], instance_key raise NotImplementedError(f"Filtering by instance ids is not implemented for {element}") -@_filter_by_instance_ids.register(GeoDataFrame) -def _(element: GeoDataFrame, ids_to_remove: list[str], instance_key: str) -> GeoDataFrame: - del instance_key - return element[~element.index.isin(ids_to_remove)] - - -@_filter_by_instance_ids.register(DaskDataFrame) -def _(element: DaskDataFrame, ids_to_remove: list[str], instance_key: str) -> DaskDataFrame: - return element[~element[instance_key].isin(ids_to_remove)] - - @_filter_by_instance_ids.register(DataArray) def _(element: DataArray, ids_to_remove: list[int], instance_key: str) -> DataArray: del instance_key @@ -1137,6 +1126,7 @@ def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArra raise ValueError(f"Table {table_name} not found in SpatialData object.") subset_table = table[mask] + sdata.tables[table_name] = subset_table _, _, instance_key = get_table_keys(subset_table) annotated_regions = SpatialData.get_annotated_regions(table) removed_instance_ids = list(np.unique(table.obs[instance_key][~mask])) @@ -1145,7 +1135,10 @@ def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArra for reg in annotated_regions: elem = sdata.get(reg) model = get_model(elem) - if model in [Labels2DModel, PointsModel, ShapesModel]: + if model is Labels2DModel: filtered_elements[reg] = _filter_by_instance_ids(elem, removed_instance_ids, instance_key) + elif model in [PointsModel, ShapesModel]: + element_dict, _ = match_element_to_table(sdata, element_name=reg, table_name=table_name) + filtered_elements[reg] = element_dict[reg] return SpatialData.init_from_elements(filtered_elements | {table_name: subset_table}) diff --git a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py index d634f645d..629f58bb6 100644 --- a/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py +++ b/tests/core/query/test_relational_query_subset_sdata_by_table_mask.py @@ -6,7 +6,7 @@ from spatialdata.datasets import blobs_annotating_element -def test_filter_labels2dmodel_by_instance_ids(): +def test_filter_labels2dmodel_by_instance_ids() -> None: sdata = blobs_annotating_element("blobs_labels") labels_element = sdata["blobs_labels"] all_instance_ids = sdata.tables["table"].obs["instance_id"].unique() @@ -36,7 +36,7 @@ def test_filter_labels2dmodel_by_instance_ids(): assert set(preserved_ids) == set(all_instance_ids) | {0} -def test_subset_sdata_by_table_mask(): +def test_subset_sdata_by_table_mask() -> None: sdata = concatenate( { "labels": blobs_annotating_element("blobs_labels"), @@ -62,19 +62,19 @@ def test_subset_sdata_by_table_mask(): ) - {0} assert ms_labels_remaining_ids == {3} - points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"]["instance_id"].compute())) - {0} + points_remaining_ids = set(np.unique(subset_sdata.points["blobs_points-points"].index)) - {0} assert points_remaining_ids == {3} shapes_remaining_ids = set(np.unique(subset_sdata.shapes["blobs_circles-shapes"].index)) - {0} assert shapes_remaining_ids == {3} -def test_subset_sdata_by_table_mask_with_no_annotated_elements(): +def test_subset_sdata_by_table_mask_with_no_annotated_elements() -> None: with pytest.raises(ValueError, match="Table table_not_found not found in SpatialData object."): sdata = blobs_annotating_element("blobs_labels") _ = subset_sdata_by_table_mask(sdata, "table_not_found", sdata.tables["table"].obs["instance_id"] == 3) -def test_filter_by_instance_ids_fails_for_unsupported_element_models(): +def test_filter_by_instance_ids_fails_for_unsupported_element_models() -> None: with pytest.raises(NotImplementedError, match="Filtering by instance ids is not implemented for"): _filter_by_instance_ids([1, 1, 1, 2], [1], "instance_id")