-
Notifications
You must be signed in to change notification settings - Fork 71
Filter Operations on Label2DModel and Shape #946
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
a28c7c9
225d593
e549b4b
2aad72b
ef74057
d6e22cb
46c41db
80d95a2
4438605
4c927ee
e9e0da2
7534c91
b4901cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,9 +11,11 @@ | |
import dask.array as da | ||
import numpy as np | ||
import pandas as pd | ||
import xarray as xr | ||
from anndata import AnnData | ||
from dask.dataframe import DataFrame as DaskDataFrame | ||
from geopandas import GeoDataFrame | ||
from numpy.typing import NDArray | ||
from xarray import DataArray, DataTree | ||
|
||
from spatialdata._core.spatialdata import SpatialData | ||
|
@@ -1019,3 +1021,132 @@ def get_values( | |
return df | ||
|
||
raise ValueError(f"Unknown origin {origin}") | ||
|
||
|
||
def _mask_block(block: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray: | ||
# Use apply_ufunc for efficient processing | ||
# Create a copy to avoid modifying read-only array | ||
result = block.copy() | ||
result[np.isin(result, ids_to_remove)] = 0 | ||
return result | ||
|
||
|
||
def _set_instance_ids_in_labels_to_zero(image: xr.DataArray, ids_to_remove: list[int]) -> xr.DataArray:
    """Set every pixel of ``image`` whose label is in ``ids_to_remove`` to 0.

    The masking is applied with ``xr.apply_ufunc`` over the ``("y", "x")`` core
    dimensions and the result is computed eagerly before being returned.

    Parameters
    ----------
    image
        Labels image to filter.
    ids_to_remove
        Label ids to replace with 0.

    Returns
    -------
    A new DataArray holding the computed, masked data together with the
    coordinates, dimensions and a copy of the attributes of ``image``.
    """
    zero_out_ids = partial(_mask_block, ids_to_remove=ids_to_remove)
    masked = xr.apply_ufunc(
        zero_out_ids,
        image,
        input_core_dims=[["y", "x"]],
        output_core_dims=[["y", "x"]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[image.dtype],
        dataset_fill_value=0,
        dask_gufunc_kwargs={"allow_rechunk": True},
    )

    # Evaluate now so that the masked values are materialized in the result.
    materialized = masked.compute()

    # Re-wrap in a fresh DataArray, carrying over coords, dims and all attributes.
    return xr.DataArray(
        data=materialized.data,
        coords=image.coords,
        dims=image.dims,
        attrs=image.attrs.copy(),
    )
|
||
|
||
def _get_scale_factors(labels_element: DataTree) -> list[tuple[float, float]]: | ||
scales = list(labels_element.keys()) | ||
|
||
# Calculate relative scale factors between consecutive scales | ||
scale_factors = [] | ||
for i in range(len(scales) - 1): | ||
y_size_current = labels_element[scales[i]].image.shape[0] | ||
x_size_current = labels_element[scales[i]].image.shape[1] | ||
y_size_next = labels_element[scales[i + 1]].image.shape[0] | ||
x_size_next = labels_element[scales[i + 1]].image.shape[1] | ||
y_factor = y_size_current / y_size_next | ||
x_factor = x_size_current / x_size_next | ||
|
||
scale_factors.append((y_factor, x_factor)) | ||
|
||
return scale_factors | ||
|
||
|
||
@singledispatch | ||
def _filter_by_instance_ids(element: Any, ids_to_remove: list[str], instance_key: str) -> Any: | ||
raise NotImplementedError(f"Filtering by instance ids is not implemented for {element}") | ||
|
||
|
||
@_filter_by_instance_ids.register(GeoDataFrame)
def _(element: GeoDataFrame, ids_to_remove: list[str], instance_key: str) -> GeoDataFrame:
    """Keep only the rows whose index value is not in ``ids_to_remove``.

    ``instance_key`` is not consulted here: rows are matched on the GeoDataFrame index.
    """
    is_removed = element.index.isin(ids_to_remove)
    return element[~is_removed]
|
||
|
||
@_filter_by_instance_ids.register(DaskDataFrame)
def _(element: DaskDataFrame, ids_to_remove: list[str], instance_key: str) -> DaskDataFrame:
    """Keep only the rows whose ``instance_key`` column value is not in ``ids_to_remove``."""
    is_removed = element[instance_key].isin(ids_to_remove)
    return element[~is_removed]
|
||
|
||
@_filter_by_instance_ids.register(DataArray)
def _(element: DataArray, ids_to_remove: list[int], instance_key: str) -> DataArray:
    """Zero out the pixels of a single-scale labels element that belong to ``ids_to_remove``.

    The masked image is passed through ``Labels2DModel.parse`` before being returned.
    """
    # The ids are matched against the pixel values, so no key lookup is needed;
    # deleting the parameter makes explicit that this dispatch does not use it.
    del instance_key
    zeroed = _set_instance_ids_in_labels_to_zero(element, ids_to_remove)
    return Labels2DModel.parse(zeroed)
|
||
|
||
@_filter_by_instance_ids.register(DataTree)
def _(element: DataArray | DataTree, ids_to_remove: list[int], instance_key: str) -> xr.DataArray | xr.DataTree:
    """Zero out the pixels of a multiscale labels element that belong to ``ids_to_remove``.

    Only the first scale of ``element`` is filtered directly. The result is handed
    to ``Labels2DModel.parse`` together with the scale factors derived from the
    original element, so that the multiscale structure is reconstructed from the
    filtered image. ``instance_key`` is not used by this dispatch.
    """
    # The first key is treated as the largest (full-resolution) scale.
    first_scale = list(element.keys())[0]

    # Only the y-axis ratio of each pair of scales is used, truncated to an int.
    scale_factors = [int(y_ratio) for y_ratio, _ in _get_scale_factors(element)]

    filtered = _set_instance_ids_in_labels_to_zero(element[first_scale].image, ids_to_remove)
    return Labels2DModel.parse(
        data=filtered,
        scale_factors=scale_factors,
    )
|
||
|
||
def subset_sdata_by_table_mask(sdata: SpatialData, table_name: str, mask: NDArray[np.bool_]) -> SpatialData:
    """Subset the annotated elements of a SpatialData object by a table and a mask.

    The mask is applied to the table and the annotated elements are subsetted
    by the instance ids in the table.
    This function returns a new SpatialData object with the subsetted elements.
    Elements that are not annotated by the table are not included in the returned SpatialData object.
    The element models that are
    supported are :class:`spatialdata.models.Labels2DModel`,
    :class:`spatialdata.models.PointsModel`, and :class:`spatialdata.models.ShapesModel`.
    Annotated elements of any other model are left out of the result.

    Instance ids are only unique within a region, so the ids to remove are
    determined separately for each annotated region: a row that is masked out
    for one region never removes an instance with the same id from another region.

    Parameters
    ----------
    sdata
        The SpatialData object to subset.
    table_name
        The name of the table to apply the mask to.
    mask
        Boolean mask to apply to the table which is the same length as the number of rows in the table.

    Returns
    -------
    The subsetted SpatialData object.

    Raises
    ------
    ValueError
        If ``table_name`` is not a table of ``sdata``.
    """
    table = sdata.tables.get(table_name)
    if table is None:
        raise ValueError(f"Table {table_name} not found in SpatialData object.")

    subset_table = table[mask]
    _, region_key, instance_key = get_table_keys(table)
    annotated_regions = SpatialData.get_annotated_regions(table)

    # Positional views of the table rows, used to select the dropped rows per region.
    dropped_rows = ~np.asarray(mask, dtype=bool)
    row_regions = table.obs[region_key].to_numpy()
    row_instance_ids = table.obs[instance_key].to_numpy()

    filtered_elements = {}
    for reg in annotated_regions:
        elem = sdata.get(reg)
        model = get_model(elem)
        if model in [Labels2DModel, PointsModel, ShapesModel]:
            # Only rows that annotate this region decide what is removed from it.
            dropped_in_region = dropped_rows & (row_regions == reg)
            removed_instance_ids = list(np.unique(row_instance_ids[dropped_in_region]))
            filtered_elements[reg] = _filter_by_instance_ids(elem, removed_instance_ids, instance_key)

    return SpatialData.init_from_elements(filtered_elements | {table_name: subset_table})
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
from spatialdata import concatenate, subset_sdata_by_table_mask | ||
from spatialdata._core.query.relational_query import _filter_by_instance_ids | ||
from spatialdata.datasets import blobs_annotating_element | ||
|
||
|
||
def test_filter_labels2dmodel_by_instance_ids():
    """Filtering labels removes the requested ids and leaves the input element untouched."""
    background = {0}
    sdata = blobs_annotating_element("blobs_labels")
    labels_element = sdata["blobs_labels"]
    all_instance_ids = sdata.tables["table"].obs["instance_id"].unique()

    # single-scale labels
    filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id")
    # 0 is the background, so it is excluded before comparing with the instance ids
    ids_after_filtering = set(np.unique(filtered_labels_element.data.compute())) - background
    ids_in_original = np.unique(labels_element.data.compute())
    assert ids_after_filtering == (set(all_instance_ids) - {2, 3})
    # the original labels must not have been modified
    assert set(ids_in_original) == set(all_instance_ids) | background

    # multiscale labels
    sdata.tables["table"].uns["spatialdata_attrs"]["region"] = "blobs_multiscale_labels"
    sdata.tables["table"].obs.region = "blobs_multiscale_labels"
    labels_element = sdata["blobs_multiscale_labels"]
    filtered_labels_element = _filter_by_instance_ids(labels_element, [2, 3], "instance_id")

    for scale in labels_element:
        ids_after_filtering = set(np.unique(filtered_labels_element[scale].image.compute())) - background
        ids_in_original = np.unique(labels_element[scale].image.compute())
        assert ids_after_filtering == (set(all_instance_ids) - {2, 3})
        # the original labels must not have been modified
        assert set(ids_in_original) == set(all_instance_ids) | background
Review comment: Use `np.testing` instead. Reply: These are all simple set comparisons; using `np.testing` would require sorting the ids and then comparing the arrays, which seems convoluted. If you still think it is better, I can do it. |
||
|
||
|
||
def test_subset_sdata_by_table_mask():
    """Subsetting by a table mask keeps only the selected instance in every annotated element."""
    sources = {
        "labels": blobs_annotating_element("blobs_labels"),
        "shapes": blobs_annotating_element("blobs_circles"),
        "points": blobs_annotating_element("blobs_points"),
        "multiscale_labels": blobs_annotating_element("blobs_multiscale_labels"),
    }
    sdata = concatenate(sources, concatenate_tables=True)
    third_elems = sdata.tables["table"].obs["instance_id"] == 3
    subset_sdata = subset_sdata_by_table_mask(sdata, "table", third_elems)

    # every annotated element is still present after subsetting
    assert set(subset_sdata.labels.keys()) == {"blobs_labels-labels", "blobs_multiscale_labels-multiscale_labels"}
    assert set(subset_sdata.points.keys()) == {"blobs_points-points"}
    assert set(subset_sdata.shapes.keys()) == {"blobs_circles-shapes"}

    # in each element only instance 3 remains (0 is discarded before comparing)
    labels_values = subset_sdata.labels["blobs_labels-labels"].data.compute()
    labels_remaining_ids = set(np.unique(labels_values)) - {0}
    assert labels_remaining_ids == {3}

    for scale in subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"]:
        scale_values = subset_sdata.labels["blobs_multiscale_labels-multiscale_labels"][scale].image.compute()
        ms_labels_remaining_ids = set(np.unique(scale_values)) - {0}
        assert ms_labels_remaining_ids == {3}

    points_values = subset_sdata.points["blobs_points-points"]["instance_id"].compute()
    points_remaining_ids = set(np.unique(points_values)) - {0}
    assert points_remaining_ids == {3}

    shapes_values = subset_sdata.shapes["blobs_circles-shapes"].index
    shapes_remaining_ids = set(np.unique(shapes_values)) - {0}
    assert shapes_remaining_ids == {3}
|
||
|
||
def test_subset_sdata_by_table_mask_with_no_annotated_elements():
    """Requesting a table that does not exist raises a ``ValueError`` naming the table."""
    # Build the fixture outside of ``pytest.raises`` so that only the call under
    # test can satisfy the expected exception.
    sdata = blobs_annotating_element("blobs_labels")
    mask = sdata.tables["table"].obs["instance_id"] == 3

    with pytest.raises(ValueError, match=r"Table table_not_found not found in SpatialData object\."):
        subset_sdata_by_table_mask(sdata, "table_not_found", mask)
|
||
|
||
def test_filter_by_instance_ids_fails_for_unsupported_element_models():
    """A plain list has no registered dispatch, so the generic fallback must raise."""
    unsupported_element = [1, 1, 1, 2]
    expected_message = "Filtering by instance ids is not implemented for"
    with pytest.raises(NotImplementedError, match=expected_message):
        _filter_by_instance_ids(unsupported_element, [1], "instance_id")
Uh oh!
There was an error while loading. Please reload this page.