diff --git a/docs/api/operations.md b/docs/api/operations.md index 3eb2a5a6c..4e7a902b7 100644 --- a/docs/api/operations.md +++ b/docs/api/operations.md @@ -15,6 +15,7 @@ Operations on `SpatialData` objects. .. autofunction:: match_element_to_table .. autofunction:: match_table_to_element .. autofunction:: match_sdata_to_table +.. autofunction:: filter_by_table_query .. autofunction:: concatenate .. autofunction:: transform .. autofunction:: rasterize diff --git a/docs/conf.py b/docs/conf.py index 6efe4c54a..4c9ff4dce 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -103,6 +103,7 @@ "datatree": ("https://datatree.readthedocs.io/en/latest/", None), "dask": ("https://docs.dask.org/en/latest/", None), "shapely": ("https://shapely.readthedocs.io/en/stable", None), + "annsel": ("https://annsel.readthedocs.io/en/latest/", None), } diff --git a/pyproject.toml b/pyproject.toml index 06f5c95a8..7467c22e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ license = {file = "LICENSE"} readme = "README.md" dependencies = [ "anndata>=0.9.1", + "annsel>=0.1.1", "click", "dask-image", "dask>=2024.4.1,<=2024.11.2", diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 0b68391ad..45ba7058f 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -41,6 +41,7 @@ "match_element_to_table", "match_table_to_element", "match_sdata_to_table", + "filter_by_table_query", "SpatialData", "get_extent", "get_centroids", @@ -71,6 +72,7 @@ from spatialdata._core.operations.vectorize import to_circles, to_polygons from spatialdata._core.query._utils import get_bounding_box_corners from spatialdata._core.query.relational_query import ( + filter_by_table_query, get_element_annotators, get_element_instances, get_values, diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index b84d43c1b..68262960d 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ 
b/src/spatialdata/_core/query/relational_query.py @@ -12,6 +12,7 @@ import numpy as np import pandas as pd from anndata import AnnData +from annsel.core.typing import Predicates from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame from xarray import DataArray, DataTree @@ -823,6 +824,67 @@ def match_sdata_to_table( return SpatialData.init_from_elements(filtered_elements | {table_name: filtered_table}) +def filter_by_table_query( + sdata: SpatialData, + table_name: str, + filter_tables: bool = True, + element_names: list[str] | None = None, + obs_expr: Predicates | None = None, + var_expr: Predicates | None = None, + x_expr: Predicates | None = None, + obs_names_expr: Predicates | None = None, + var_names_expr: Predicates | None = None, + layer: str | None = None, + how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", +) -> SpatialData: + """Filter the SpatialData object based on a set of table queries (see :class:`anndata.AnnData`). + + Parameters + ---------- + sdata + The SpatialData object to filter. + table_name + The name of the table to filter the SpatialData object by. + filter_tables + If True (default), the table is filtered to only contain rows that are annotating regions + contained within the element_names. + element_names + The names of the elements to filter the SpatialData object by. + obs_expr + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs` by. + var_expr + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var` by. + x_expr + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.X` by. + obs_names_expr + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs_names` by. + var_names_expr + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var_names` by.
+ layer + The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`. + how + The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". + + Returns + ------- + The filtered SpatialData object. + + Notes + ----- + You can also use :func:`spatialdata.SpatialData.filter_by_table_query` with the convenience that `sdata` is the + current SpatialData object. + """ + sdata_subset: SpatialData = ( + sdata.subset(element_names=element_names, filter_tables=filter_tables) if element_names else sdata + ) + + filtered_table: AnnData = sdata_subset.tables[table_name].an.filter( + obs=obs_expr, var=var_expr, x=x_expr, obs_names=obs_names_expr, var_names=var_names_expr, layer=layer + ) + + return match_sdata_to_table(sdata=sdata_subset, table_name=table_name, table=filtered_table, how=how) + + @dataclass class _ValueOrigin: origin: str diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 48f6386ca..6c63687ec 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -12,6 +12,7 @@ import pandas as pd import zarr from anndata import AnnData +from annsel.core.typing import Predicates from dask.dataframe import DataFrame as DaskDataFrame from dask.dataframe import read_parquet from dask.delayed import Delayed @@ -2455,6 +2456,69 @@ def attrs(self, value: Mapping[Any, Any]) -> None: else: self._attrs = dict(value) + def filter_by_table_query( + self, + table_name: str, + filter_tables: bool = True, + element_names: list[str] | None = None, + obs_expr: Predicates | None = None, + var_expr: Predicates | None = None, + x_expr: Predicates | None = None, + obs_names_expr: Predicates | None = None, + var_names_expr: Predicates | None = None, + layer: str | None = None, + how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right", + ) -> SpatialData: + """Filter the SpatialData object based on a set of 
table queries (see :class:`anndata.AnnData`). + + Parameters + ---------- + table_name + The name of the table to filter the SpatialData object by. + filter_tables + If True (default), the table is filtered to only contain rows that are annotating regions + contained within the element_names. + element_names + The names of the elements to filter the SpatialData object by. + obs_expr + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs` by. + var_expr + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var` by. + x_expr + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.X` by. + obs_names_expr + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs_names` by. + var_names_expr + A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var_names` by. + layer + The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`. + how + The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right". + + Returns + ------- + The filtered SpatialData object. + + Notes + ----- + You can also use :func:`spatialdata.filter_by_table_query`.
+ """ + from spatialdata._core.query.relational_query import filter_by_table_query + + return filter_by_table_query( + self, + table_name, + filter_tables, + element_names, + obs_expr, + var_expr, + x_expr, + obs_names_expr, + var_names_expr, + layer, + how, + ) + class QueryManager: """Perform queries on SpatialData objects.""" diff --git a/tests/conftest.py b/tests/conftest.py index 211cd312f..3cd7b74ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -494,3 +494,139 @@ def adata_labels() -> AnnData: "tensor_copy": rng.integers(0, blobs.shape[0], size=(n_obs_labels, 2)), } return generate_adata(n_var, obs_labels, obsm_labels, uns_labels) + + +@pytest.fixture() +def complex_sdata() -> SpatialData: + """ + Create a complex SpatialData object with multiple data types for comprehensive testing. + + Contains: + - Images (2D and 3D) + - Labels (2D and 3D) + - Shapes (polygons and circles) + - Points + - Multiple tables with different annotations + - Categorical and numerical values in both obs and var + + Returns + ------- + SpatialData + A complex SpatialData object for testing. 
+ """ + # Get basic components using existing functions + images = _get_images() + labels = _get_labels() + shapes = _get_shapes() + points = _get_points() + + # Create tables with enhanced var data + n_var = 10 + + # Table 1: Basic table annotating labels2d + obs1 = pd.DataFrame( + { + "region": pd.Categorical(["labels2d"] * 50), + "instance_id": range(1, 51), # Skip background (0) + "cell_type": pd.Categorical(RNG.choice(["T cell", "B cell", "Macrophage"], size=50)), + "size": RNG.uniform(10, 100, size=50), + } + ) + + var1 = pd.DataFrame( + { + "feature_type": pd.Categorical(["gene", "protein", "gene", "protein", "gene"] * 2), + "importance": RNG.uniform(0, 10, size=n_var), + "is_marker": RNG.choice([True, False], size=n_var), + }, + index=[f"feature_{i}" for i in range(n_var)], + ) + + X1 = RNG.normal(size=(50, n_var)) + uns1 = { + "spatialdata_attrs": { + "region": "labels2d", + "region_key": "region", + "instance_key": "instance_id", + } + } + + table1 = AnnData(X=X1, obs=obs1, var=var1, uns=uns1) + + # Table 2: Annotating both polygons and circles from shapes + n_polygons = len(shapes["poly"]) + n_circles = len(shapes["circles"]) + total_items = n_polygons + n_circles + + obs2 = pd.DataFrame( + { + "region": pd.Categorical(["poly"] * n_polygons + ["circles"] * n_circles), + "instance_id": np.concatenate([range(n_polygons), range(n_circles)]), + "category": pd.Categorical(RNG.choice(["A", "B", "C"], size=total_items)), + "value": RNG.normal(size=total_items), + "count": RNG.poisson(10, size=total_items), + } + ) + + var2 = pd.DataFrame( + { + "feature_type": pd.Categorical( + ["feature_type1", "feature_type2", "feature_type1", "feature_type2", "feature_type1"] * 2 + ), + "score": RNG.exponential(2, size=n_var), + "detected": RNG.choice([True, False], p=[0.7, 0.3], size=n_var), + }, + index=[f"metric_{i}" for i in range(n_var)], + ) + + X2 = RNG.normal(size=(total_items, n_var)) + uns2 = { + "spatialdata_attrs": { + "region": ["poly", "circles"], + 
"region_key": "region", + "instance_key": "instance_id", + } + } + + table2 = AnnData(X=X2, obs=obs2, var=var2, uns=uns2) + + # Table 3: Orphan table not annotating any elements + obs3 = pd.DataFrame( + { + "cluster": pd.Categorical(RNG.choice(["cluster_1", "cluster_2", "cluster_3"], size=40)), + "sample": pd.Categorical(["sample_A"] * 20 + ["sample_B"] * 20), + "qc_pass": RNG.choice([True, False], p=[0.8, 0.2], size=40), + } + ) + + var3 = pd.DataFrame( + { + "feature_type": pd.Categorical(["gene", "protein", "gene", "protein", "gene"] * 2), + "mean_expression": RNG.uniform(0, 20, size=n_var), + "variance": RNG.gamma(2, 2, size=n_var), + }, + index=[f"feature_{i}" for i in range(n_var)], + ) + + X3 = RNG.normal(size=(40, n_var)) + table3 = AnnData(X=X3, obs=obs3, var=var3) + + # Create additional coordinate system in one of the shapes for testing + # Modified copy of circles with an additional coordinate system + circles_alt_coords = shapes["circles"].copy() + circles_alt_coords["coordinate_system"] = "alt_system" + + # Add everything to a SpatialData object + sdata = SpatialData( + images=images, + labels=labels, + shapes={**shapes, "circles_alt_coords": circles_alt_coords}, + points=points, + tables={"labels_table": table1, "shapes_table": table2, "orphan_table": table3}, + ) + + # Add layers to tables for testing layer-specific operations + sdata.tables["labels_table"].layers["scaled"] = sdata.tables["labels_table"].X * 2 + sdata.tables["labels_table"].layers["log"] = np.log1p(np.abs(sdata.tables["labels_table"].X)) + + return sdata diff --git a/tests/core/query/test_relational_query.py b/tests/core/query/test_relational_query.py index f0b4da7e0..6c1b81cab 100644 --- a/tests/core/query/test_relational_query.py +++ b/tests/core/query/test_relational_query.py @@ -1,3 +1,4 @@ +import annsel as an import numpy as np import pandas as pd import pytest @@ -7,6 +8,7 @@ from spatialdata._core.query.relational_query import ( _locate_value, _ValueOrigin, + 
filter_by_table_query, get_element_annotators, join_spatialelement_table, ) @@ -1052,3 +1054,199 @@ def test_get_element_annotators(full_sdata): full_sdata.tables["another_table"] = another_table names = get_element_annotators(full_sdata, "labels2d") assert names == {"another_table", "table"} + + +def test_filter_by_table_query(complex_sdata): + """Test basic filtering functionality of filter_by_table_query.""" + sdata = complex_sdata + + # Test 1: Basic filtering on categorical obs column + result = filter_by_table_query(sdata=sdata, table_name="labels_table", obs_expr=an.col("cell_type") == "T cell") + + # Check that the table was filtered properly + assert all(result["labels_table"].obs["cell_type"] == "T cell") + # Check that result has fewer rows than original + assert result["labels_table"].n_obs < sdata["labels_table"].n_obs + # Check that labels2d element is still present + assert "labels2d" in result.labels + + # Test 2: Filtering on numerical obs column + result = filter_by_table_query(sdata=sdata, table_name="labels_table", obs_expr=an.col("size") > 50) + + # Check that the table was filtered properly + assert all(result["labels_table"].obs["size"] > 50) + # Check that labels2d element is still present + assert "labels2d" in result.labels + + # Test 3: Filtering with var expressions + result = filter_by_table_query( + sdata=sdata, table_name="shapes_table", var_expr=an.col("feature_type") == "feature_type1" + ) + + # Check that the filtered var dataframe only has 'feature_type1' feature_type + assert all(result["shapes_table"].var["feature_type"] == "feature_type1") + # Check that the filtered var dataframe has fewer rows than the original + assert result["shapes_table"].n_vars < sdata["shapes_table"].n_vars + + # Test 4: Multiple filtering conditions (obs and var) + result = filter_by_table_query( + sdata=sdata, table_name="shapes_table", obs_expr=an.col("category") == "A", var_expr=an.col("score") > 2 + ) + + # Check that both filters were applied + assert
all(result["shapes_table"].obs["category"] == "A") + assert all(result["shapes_table"].var["score"] > 2) + + # Test 5: Using X expressions + result = filter_by_table_query(sdata=sdata, table_name="labels_table", x_expr=an.col("feature_1") > 0.5) + + # Check that the filter was applied to X + assert np.all(result["labels_table"][:, "feature_1"].X > 0.5) + + # Test 6: Using different join types + # Test with inner join + result = filter_by_table_query( + sdata=sdata, table_name="shapes_table", obs_expr=an.col("category") == "A", how="inner" + ) + + # The elements should be filtered to only include instances in the table + assert "poly" in result.shapes + assert "circles" in result.shapes + + # Test with left join + result = filter_by_table_query( + sdata=sdata, table_name="shapes_table", obs_expr=an.col("category") == "A", how="left" + ) + + # Elements should be preserved but table should be filtered + assert "poly" in result.shapes + assert "circles" in result.shapes + assert all(result["shapes_table"].obs["category"] == "A") + + # Test 7: Filtering with specific element_names + result = filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + element_names=["poly"], # Only include poly, not circles + obs_expr=an.col("category") == "A", + ) + + # Only specified elements should be in the result + assert "poly" in result.shapes + assert "circles" not in result.shapes + + # Test 8: Testing orphan table handling + # Call with filter_tables=True (the default) + result = filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + obs_expr=an.col("category") == "A", + filter_tables=True, + ) + + # Orphan table should not be in the result + assert "orphan_table" not in result.tables + + +def test_filter_by_table_query_with_layers(complex_sdata): + """Test filtering using different layers.""" + sdata = complex_sdata + + # Test filtering using a specific layer + result = filter_by_table_query( + sdata=sdata, + table_name="labels_table", +
x_expr=an.col("feature_1") > 1.0, + layer="scaled", # The 'scaled' layer has values 2x the original X + ) + + # Values in the scaled layer's feature_1 column should be > 1.0 + assert np.all(result["labels_table"][:, "feature_1"].layers["scaled"] > 1.0) + + +def test_filter_by_table_query_edge_cases(complex_sdata): + """Test edge cases for filter_by_table_query.""" + sdata = complex_sdata + + # Test 1: Filter by obs_names + result = filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + obs_names_expr=an.obs_names.str.starts_with("0"), # Only rows with index starting with '0' + ) + + # Check that filtered table only has obs names starting with '0' + assert all(str(idx).startswith("0") for idx in result["shapes_table"].obs_names) + + # Test 2: Invalid table name raises KeyError + with pytest.raises(KeyError, match="nonexistent_table"): + filter_by_table_query(sdata=sdata, table_name="nonexistent_table", obs_expr=an.col("category") == "A") + + # Test 3: Invalid column name in expression + with pytest.raises(KeyError): # The exact exception type may vary + filter_by_table_query(sdata=sdata, table_name="shapes_table", obs_expr=an.col("nonexistent_column") == "A") + + # Test 4: Using layer that doesn't exist + with pytest.raises(KeyError): + filter_by_table_query( + sdata=sdata, table_name="labels_table", x_expr=an.col("feature_1") > 0.5, layer="nonexistent_layer" + ) + + # Test 5: Filter by var_names + result = filter_by_table_query( + sdata=sdata, + table_name="labels_table", + var_names_expr=an.var_names.str.contains("feature_[0-4]"), # Only features 0-4 + ) + + # Check that filtered table only has var names matching the pattern + for idx in result["labels_table"].var_names: + var_name = str(idx) + assert var_name.startswith("feature_") and int(var_name.split("_")[1]) < 5 + + # Test 6: Invalid element_names (element doesn't exist) + with pytest.raises(AssertionError, match="elements_dict must not be empty"): + filter_by_table_query( + sdata=sdata, + 
table_name="shapes_table", + element_names=["nonexistent_element"], + obs_expr=an.col("category") == "A", + ) + + # Test 7: Invalid join type raises TypeError + with pytest.raises(TypeError, match="not a valid type of join."): + filter_by_table_query( + sdata=sdata, + table_name="shapes_table", + how="invalid_join_type", # Invalid join type + obs_expr=an.col("category") == "A", + ) + + +def test_filter_by_table_query_complex_combination(complex_sdata): + """Test complex combinations of filters.""" + sdata = complex_sdata + + # Combine multiple filtering criteria + result = sdata.filter_by_table_query( + table_name="shapes_table", + obs_expr=(an.col("category") == "A", an.col("value") > 0), + var_expr=an.col("feature_type") == "feature_type1", + how="inner", + ) + + # Validate the combined filtering results + assert all(result["shapes_table"].obs["category"] == "A") + assert all(result["shapes_table"].obs["value"] > 0) + assert all(result["shapes_table"].var["feature_type"] == "feature_type1") + + # Both elements should be present but filtered + assert "circles" in result.shapes + + # The filtered shapes should only contain the instances from the filtered table + table_instance_ids = set( + zip(result["shapes_table"].obs["region"], result["shapes_table"].obs["instance_id"], strict=True) + ) + if "circles" in result.shapes: + for idx in result["circles"].index: + assert ("circles", idx) in table_instance_ids