Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/operations.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ Operations on `SpatialData` objects.
.. autofunction:: match_element_to_table
.. autofunction:: match_table_to_element
.. autofunction:: match_sdata_to_table
.. autofunction:: filter_by_table_query
.. autofunction:: concatenate
.. autofunction:: transform
.. autofunction:: rasterize
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
"datatree": ("https://datatree.readthedocs.io/en/latest/", None),
"dask": ("https://docs.dask.org/en/latest/", None),
"shapely": ("https://shapely.readthedocs.io/en/stable", None),
"annsel": ("https://annsel.readthedocs.io/en/latest/", None),
}


Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ license = {file = "LICENSE"}
readme = "README.md"
dependencies = [
"anndata>=0.9.1",
"annsel>=0.1.1",
"click",
"dask-image",
"dask>=2024.4.1,<=2024.11.2",
Expand Down
2 changes: 2 additions & 0 deletions src/spatialdata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"match_element_to_table",
"match_table_to_element",
"match_sdata_to_table",
"filter_by_table_query",
"SpatialData",
"get_extent",
"get_centroids",
Expand Down Expand Up @@ -71,6 +72,7 @@
from spatialdata._core.operations.vectorize import to_circles, to_polygons
from spatialdata._core.query._utils import get_bounding_box_corners
from spatialdata._core.query.relational_query import (
filter_by_table_query,
get_element_annotators,
get_element_instances,
get_values,
Expand Down
62 changes: 62 additions & 0 deletions src/spatialdata/_core/query/relational_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from annsel.core.typing import Predicates
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from xarray import DataArray, DataTree
Expand Down Expand Up @@ -823,6 +824,67 @@ def match_sdata_to_table(
return SpatialData.init_from_elements(filtered_elements | {table_name: filtered_table})


def filter_by_table_query(
sdata: SpatialData,
table_name: str,
filter_tables: bool = True,
element_names: list[str] | None = None,
obs_expr: Predicates | None = None,
var_expr: Predicates | None = None,
x_expr: Predicates | None = None,
obs_names_expr: Predicates | None = None,
var_names_expr: Predicates | None = None,
layer: str | None = None,
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
) -> SpatialData:
"""Filter the SpatialData object based on a set of table queries. (:class:`anndata.AnnData`.

Parameters
----------
sdata:
The SpatialData object to filter.
table_name
The name of the table to filter the SpatialData object by.
filter_tables
If True (default), the table is filtered to only contain rows that are annotating regions
contained within the element_names.
element_names
The names of the elements to filter the SpatialData object by.
obs_expr
A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs` by.
var_expr
A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var` by.
x_expr
A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.X` by.
obs_names_expr
A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs_names` by.
var_names_expr
A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var_names` by.
layer
The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`.
how
The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".

Returns
-------
The filtered SpatialData object.

Notes
-----
You can also use :func:`spatialdata.SpatialData.filter_by_table_query` with the convenience that `sdata` is the
current SpatialData object.
"""
sdata_subset: SpatialData = (
sdata.subset(element_names=element_names, filter_tables=filter_tables) if element_names else sdata
)

filtered_table: AnnData = sdata_subset.tables[table_name].an.filter(
obs=obs_expr, var=var_expr, x=x_expr, obs_names=obs_names_expr, var_names=var_names_expr, layer=layer
)

return match_sdata_to_table(sdata=sdata_subset, table_name=table_name, table=filtered_table, how=how)


@dataclass
class _ValueOrigin:
origin: str
Expand Down
64 changes: 64 additions & 0 deletions src/spatialdata/_core/spatialdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pandas as pd
import zarr
from anndata import AnnData
from annsel.core.typing import Predicates
from dask.dataframe import DataFrame as DaskDataFrame
from dask.dataframe import read_parquet
from dask.delayed import Delayed
Expand Down Expand Up @@ -2455,6 +2456,69 @@ def attrs(self, value: Mapping[Any, Any]) -> None:
else:
self._attrs = dict(value)

def filter_by_table_query(
self,
table_name: str,
filter_tables: bool = True,
element_names: list[str] | None = None,
obs_expr: Predicates | None = None,
var_expr: Predicates | None = None,
x_expr: Predicates | None = None,
obs_names_expr: Predicates | None = None,
var_names_expr: Predicates | None = None,
layer: str | None = None,
how: Literal["left", "left_exclusive", "inner", "right", "right_exclusive"] = "right",
) -> SpatialData:
"""Filter the SpatialData object based on a set of table queries. (:class:`anndata.AnnData`.

Parameters
----------
table_name
The name of the table to filter the SpatialData object by.
filter_tables
If True (default), the table is filtered to only contain rows that are annotating regions
contained within the element_names.
element_names
The names of the elements to filter the SpatialData object by.
obs_expr
A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs` by.
var_expr
A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var` by.
x_expr
A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.X` by.
obs_names_expr
A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.obs_names` by.
var_names_expr
A Predicate or an iterable of Predicates to filter :attr:`anndata.AnnData.var_names` by.
layer
The layer of the :class:`anndata.AnnData` to filter the SpatialData object by, only used with `x_expr`.
how
The type of join to perform. See :func:`spatialdata.join_spatialelement_table`. Default is "right".

Returns
-------
The filtered SpatialData object.

Notes
-----
You can also use :func:`query.relational_query.filter_by_table_query`.
"""
from spatialdata._core.query.relational_query import filter_by_table_query

return filter_by_table_query(
self,
table_name,
filter_tables,
element_names,
obs_expr,
var_expr,
x_expr,
obs_names_expr,
var_names_expr,
layer,
how,
)


class QueryManager:
"""Perform queries on SpatialData objects."""
Expand Down
136 changes: 136 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,3 +494,139 @@ def adata_labels() -> AnnData:
"tensor_copy": rng.integers(0, blobs.shape[0], size=(n_obs_labels, 2)),
}
return generate_adata(n_var, obs_labels, obsm_labels, uns_labels)


@pytest.fixture()
def complex_sdata() -> SpatialData:
"""
Create a complex SpatialData object with multiple data types for comprehensive testing.

Contains:
- Images (2D and 3D)
- Labels (2D and 3D)
- Shapes (polygons and circles)
- Points
- Multiple tables with different annotations
- Categorical and numerical values in both obs and var

Returns
-------
SpatialData
A complex SpatialData object for testing.
"""
# Get basic components using existing functions
images = _get_images()
labels = _get_labels()
shapes = _get_shapes()
points = _get_points()

# Create tables with enhanced var data
n_var = 10

# Table 1: Basic table annotating labels2d
obs1 = pd.DataFrame(
{
"region": pd.Categorical(["labels2d"] * 50),
"instance_id": range(1, 51), # Skip background (0)
"cell_type": pd.Categorical(RNG.choice(["T cell", "B cell", "Macrophage"], size=50)),
"size": RNG.uniform(10, 100, size=50),
}
)

var1 = pd.DataFrame(
{
"feature_type": pd.Categorical(["gene", "protein", "gene", "protein", "gene"] * 2),
"importance": RNG.uniform(0, 10, size=n_var),
"is_marker": RNG.choice([True, False], size=n_var),
},
index=[f"feature_{i}" for i in range(n_var)],
)

X1 = RNG.normal(size=(50, n_var))
uns1 = {
"spatialdata_attrs": {
"region": "labels2d",
"region_key": "region",
"instance_key": "instance_id",
}
}

table1 = AnnData(X=X1, obs=obs1, var=var1, uns=uns1)

# Table 2: Annotating both polygons and circles from shapes
n_polygons = len(shapes["poly"])
n_circles = len(shapes["circles"])
total_items = n_polygons + n_circles

obs2 = pd.DataFrame(
{
"region": pd.Categorical(["poly"] * n_polygons + ["circles"] * n_circles),
"instance_id": np.concatenate([range(n_polygons), range(n_circles)]),
"category": pd.Categorical(RNG.choice(["A", "B", "C"], size=total_items)),
"value": RNG.normal(size=total_items),
"count": RNG.poisson(10, size=total_items),
}
)

var2 = pd.DataFrame(
{
"feature_type": pd.Categorical(
["feature_type1", "feature_type2", "feature_type1", "feature_type2", "feature_type1"] * 2
),
"score": RNG.exponential(2, size=n_var),
"detected": RNG.choice([True, False], p=[0.7, 0.3], size=n_var),
},
index=[f"metric_{i}" for i in range(n_var)],
)

X2 = RNG.normal(size=(total_items, n_var))
uns2 = {
"spatialdata_attrs": {
"region": ["poly", "circles"],
"region_key": "region",
"instance_key": "instance_id",
}
}

table2 = AnnData(X=X2, obs=obs2, var=var2, uns=uns2)

# Table 3: Orphan table not annotating any elements
obs3 = pd.DataFrame(
{
"cluster": pd.Categorical(RNG.choice(["cluster_1", "cluster_2", "cluster_3"], size=40)),
"sample": pd.Categorical(["sample_A"] * 20 + ["sample_B"] * 20),
"qc_pass": RNG.choice([True, False], p=[0.8, 0.2], size=40),
}
)

var3 = pd.DataFrame(
{
"feature_type": pd.Categorical(["gene", "protein", "gene", "protein", "gene"] * 2),
"mean_expression": RNG.uniform(0, 20, size=n_var),
"variance": RNG.gamma(2, 2, size=n_var),
},
index=[f"feature_{i}" for i in range(n_var)],
)

X3 = RNG.normal(size=(40, n_var))
table3 = AnnData(X=X3, obs=obs3, var=var3)

# Create additional coordinate system in one of the shapes for testing
# Modified copy of circles with an additional coordinate system
circles_alt_coords = shapes["circles"].copy()
circles_alt_coords["coordinate_system"] = "alt_system"

# Add everything to a SpatialData object
sdata = SpatialData(
images=images,
labels=labels,
shapes={**shapes, "circles_alt_coords": circles_alt_coords},
points=points,
tables={"labels_table": table1, "shapes_table": table2, "orphan_table": table3},
)

# Add layers to tables for testing layer-specific operations
sdata.tables["labels_table"].layers["scaled"] = sdata.tables["labels_table"].X * 2
sdata.tables["labels_table"].layers["log"] = np.log1p(np.abs(sdata.tables["labels_table"].X))

return sdata
Loading
Loading