Skip to content

Commit 39f0954

Browse files
authored
compute graph base on spatialdata regions and add mask graph based on polygon (#842)
* add tests * pin up spatialdata * fix test * add test for spatialdata graph * address comments, pass shapely polygon and not geodataframe to mask_graph * add note * add release note and add mask_graph to API * add shapely to intersphinx
1 parent 4a632d6 commit 39f0954

File tree

8 files changed

+438
-16
lines changed

8 files changed

+438
-16
lines changed

docs/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ Graph
1414
:toctree: api
1515

1616
gr.spatial_neighbors
17+
gr.mask_graph
1718
gr.nhood_enrichment
1819
gr.co_occurrence
1920
gr.centrality_scores

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@
9494
omnipath=("https://omnipath.readthedocs.io/en/latest", None),
9595
napari=("https://napari.org/", None),
9696
spatialdata=("https://spatialdata.scverse.org/en/latest", None),
97+
shapely=("https://shapely.readthedocs.io/en/stable", None),
9798
)
9899

99100
# Add any paths that contain templates here, relative to this directory.

docs/release/notes-dev.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,9 @@
11
Squidpy dev (the-future)
22
========================
3+
4+
Features
5+
--------
6+
- Now :func:`squidpy.gr.spatial_graph` can also be used on :class:`spatialdata.SpatialData` objects.
7+
- Add :func:`squidpy.gr.mask_graph` to mask a spatial graph based on :class:`shapely.Polygon` or :class:`shapely.MultiPolygon`
8+
`@giovp <https://github.com/giovp>`__
9+
`#842 <https://github.com/scverse/squidpy/pull/842>`__

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ dependencies = [
7272
"validators>=0.18.2",
7373
"xarray>=0.16.1",
7474
"zarr>=2.6.1",
75-
"spatialdata",
75+
"spatialdata>=0.2.0",
7676
]
7777

7878
[project.optional-dependencies]

src/squidpy/gr/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from __future__ import annotations
44

5-
from squidpy.gr._build import spatial_neighbors
5+
from squidpy.gr._build import mask_graph, spatial_neighbors
66
from squidpy.gr._ligrec import ligrec
77
from squidpy.gr._nhood import centrality_scores, interaction_matrix, nhood_enrichment
88
from squidpy.gr._ppatterns import co_occurrence, spatial_autocorr

src/squidpy/gr/_build.py

Lines changed: 242 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,14 @@
66
from collections.abc import Iterable # noqa: F401
77
from functools import partial
88
from itertools import chain
9+
from typing import Any
910

11+
import geopandas as gpd
1012
import numpy as np
13+
import pandas as pd
1114
from anndata import AnnData
1215
from anndata.utils import make_index_unique
16+
from geopandas import GeoDataFrame
1317
from numba import njit
1418
from scanpy import logging as logg
1519
from scipy.sparse import (
@@ -20,9 +24,28 @@
2024
spmatrix,
2125
)
2226
from scipy.spatial import Delaunay
27+
from shapely import LineString, MultiPolygon, Point, Polygon, distance
2328
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
2429
from sklearn.neighbors import NearestNeighbors
2530
from spatialdata import SpatialData
31+
from spatialdata._core.centroids import get_centroids
32+
from spatialdata._core.query.relational_query import (
33+
get_element_instances,
34+
match_element_to_table,
35+
)
36+
from spatialdata.models import SpatialElement, get_table_keys
37+
from spatialdata.models.models import (
38+
Image2DModel,
39+
Image3DModel,
40+
Labels2DModel,
41+
Labels3DModel,
42+
PointsModel,
43+
RasterSchema,
44+
ShapesModel,
45+
TableModel,
46+
get_axes_names,
47+
get_model,
48+
)
2649

2750
from squidpy._constants._constants import CoordType, Transform
2851
from squidpy._constants._pkg_constants import Key
@@ -43,6 +66,8 @@
4366
def spatial_neighbors(
4467
adata: AnnData | SpatialData,
4568
spatial_key: str = Key.obsm.spatial,
69+
elements_to_coordinate_systems: dict[str, str] | None = None,
70+
table_key: str | None = None,
4671
library_key: str | None = None,
4772
coord_type: str | CoordType | None = None,
4873
n_neighs: int = 6,
@@ -62,6 +87,17 @@ def spatial_neighbors(
6287
----------
6388
%(adata)s
6489
%(spatial_key)s
90+
If `adata` is a :class:`spatialdata.SpatialData`, the coordinates of the centroids will be stored in the
91+
`adata` with this key.
92+
elements_to_coordinate_systems
93+
A dictionary mapping element names of the SpatialData object to coordinate systems.
94+
The elements can be either Shapes or Labels. For compatibility, the spatialdata table must annotate
95+
all regions keys. Must not be `None` if `adata` is a :class:`spatialdata.SpatialData`.
96+
table_key
97+
Key in :attr:`spatialdata.SpatialData.tables` where the spatialdata table is stored. Must not be `None` if
98+
`adata` is a :class:`spatialdata.SpatialData`.
99+
mask_polygon
100+
The Polygon or MultiPolygon element.
65101
%(library_key)s
66102
coord_type
67103
Type of coordinate system. Valid options are:
@@ -109,7 +145,54 @@ def spatial_neighbors(
109145
- :attr:`anndata.AnnData.uns` ``['{{key_added}}']`` - :class:`dict` containing parameters.
110146
"""
111147
if isinstance(adata, SpatialData):
112-
adata = adata.table
148+
assert (
149+
elements_to_coordinate_systems is not None
150+
), "Since `adata` is a :class:`spatialdata.SpatialData`, `elements_to_coordinate_systems` must not be `None`."
151+
assert (
152+
table_key is not None
153+
), "Since `adata` is a :class:`spatialdata.SpatialData`, `table_key` must not be `None`."
154+
elements, table = match_element_to_table(adata, list(elements_to_coordinate_systems), table_key)
155+
assert table.obs_names.equals(
156+
adata.tables[table_key].obs_names
157+
), "The spatialdata table must annotate all elements keys. Some elements are missing, please check the `elements_to_coordinate_systems` dictionary."
158+
regions, region_key, instance_key = get_table_keys(adata.tables[table_key])
159+
regions = [regions] if isinstance(regions, str) else regions
160+
ordered_regions_in_table = adata.tables[table_key].obs[region_key].unique()
161+
162+
# TODO: remove this after https://github.com/scverse/spatialdata/issues/614
163+
remove_centroids = {}
164+
elem_instances = []
165+
for e in regions:
166+
schema = get_model(elements[e])
167+
element_instances = get_element_instances(elements[e]).to_series()
168+
if np.isin(0, element_instances.values) and (schema in (Labels2DModel, Labels3DModel)):
169+
element_instances = element_instances.drop(index=0)
170+
remove_centroids[e] = True
171+
else:
172+
remove_centroids[e] = False
173+
elem_instances.append(element_instances)
174+
175+
element_instances = pd.concat(elem_instances)
176+
if (not np.all(element_instances.values == adata.tables[table_key].obs[instance_key].values)) or (
177+
not np.all(ordered_regions_in_table == regions)
178+
):
179+
raise ValueError(
180+
"The spatialdata table must annotate all elements keys. Some elements are missing or not ordered correctly, please check the `elements_to_coordinate_systems` dictionary."
181+
)
182+
centroids = []
183+
for region_ in ordered_regions_in_table:
184+
cs = elements_to_coordinate_systems[region_]
185+
centroid = get_centroids(adata[region_], coordinate_system=cs)[["x", "y"]].compute()
186+
187+
# TODO: remove this after https://github.com/scverse/spatialdata/issues/614
188+
if remove_centroids[region_]:
189+
centroid = centroid[1:].copy()
190+
centroids.append(centroid)
191+
192+
adata.tables[table_key].obsm[spatial_key] = np.concatenate(centroids)
193+
adata = adata.tables[table_key]
194+
library_key = region_key
195+
113196
_assert_positive(n_rings, name="n_rings")
114197
_assert_positive(n_neighs, name="n_neighs")
115198
_assert_spatial_basis(adata, spatial_key)
@@ -167,7 +250,12 @@ def spatial_neighbors(
167250
neighbors_dict = {
168251
"connectivities_key": conns_key,
169252
"distances_key": dists_key,
170-
"params": {"n_neighbors": n_neighs, "coord_type": coord_type.v, "radius": radius, "transform": transform.v},
253+
"params": {
254+
"n_neighbors": n_neighs,
255+
"coord_type": coord_type.v,
256+
"radius": radius,
257+
"transform": transform.v,
258+
},
171259
}
172260

173261
if copy:
@@ -194,10 +282,21 @@ def _spatial_neighbor(
194282
with warnings.catch_warnings():
195283
warnings.simplefilter("ignore", SparseEfficiencyWarning)
196284
if coord_type == CoordType.GRID:
197-
Adj, Dst = _build_grid(coords, n_neighs=n_neighs, n_rings=n_rings, delaunay=delaunay, set_diag=set_diag)
285+
Adj, Dst = _build_grid(
286+
coords,
287+
n_neighs=n_neighs,
288+
n_rings=n_rings,
289+
delaunay=delaunay,
290+
set_diag=set_diag,
291+
)
198292
elif coord_type == CoordType.GENERIC:
199293
Adj, Dst = _build_connectivity(
200-
coords, n_neighs=n_neighs, radius=radius, delaunay=delaunay, return_distance=True, set_diag=set_diag
294+
coords,
295+
n_neighs=n_neighs,
296+
radius=radius,
297+
delaunay=delaunay,
298+
return_distance=True,
299+
set_diag=set_diag,
201300
)
202301
else:
203302
raise NotImplementedError(f"Coordinate type `{coord_type}` is not yet implemented.")
@@ -233,7 +332,11 @@ def _spatial_neighbor(
233332

234333

235334
def _build_grid(
236-
coords: NDArrayA, n_neighs: int, n_rings: int, delaunay: bool = False, set_diag: bool = False
335+
coords: NDArrayA,
336+
n_neighs: int,
337+
n_rings: int,
338+
delaunay: bool = False,
339+
set_diag: bool = False,
237340
) -> tuple[csr_matrix, csr_matrix]:
238341
if n_rings > 1:
239342
Adj: csr_matrix = _build_connectivity(
@@ -258,7 +361,13 @@ def _build_grid(
258361
Dst = Adj.copy()
259362
Adj.data[:] = 1.0
260363
else:
261-
Adj = _build_connectivity(coords, n_neighs=n_neighs, neigh_correct=True, delaunay=delaunay, set_diag=set_diag)
364+
Adj = _build_connectivity(
365+
coords,
366+
n_neighs=n_neighs,
367+
neigh_correct=True,
368+
delaunay=delaunay,
369+
set_diag=set_diag,
370+
)
262371
Dst = Adj.copy()
263372

264373
Dst.setdiag(0.0)
@@ -302,14 +411,21 @@ def _build_connectivity(
302411
if neigh_correct:
303412
dist_cutoff = np.median(dists) * 1.3 # there's a small amount of sway
304413
mask = dists < dist_cutoff
305-
row_indices, col_indices, dists = row_indices[mask], col_indices[mask], dists[mask]
414+
row_indices, col_indices, dists = (
415+
row_indices[mask],
416+
col_indices[mask],
417+
dists[mask],
418+
)
306419
else:
307420
dists, col_indices = tree.radius_neighbors()
308421
row_indices = np.repeat(np.arange(N), [len(x) for x in col_indices])
309422
dists = np.concatenate(dists)
310423
col_indices = np.concatenate(col_indices)
311424

312-
Adj = csr_matrix((np.ones_like(row_indices, dtype=np.float64), (row_indices, col_indices)), shape=(N, N))
425+
Adj = csr_matrix(
426+
(np.ones_like(row_indices, dtype=np.float64), (row_indices, col_indices)),
427+
shape=(N, N),
428+
)
313429
if return_distance:
314430
Dst = csr_matrix((dists, (row_indices, col_indices)), shape=(N, N))
315431

@@ -349,3 +465,121 @@ def _transform_a_spectral(a: spmatrix) -> spmatrix:
349465

350466
def _transform_a_cosine(a: spmatrix) -> spmatrix:
351467
return cosine_similarity(a, dense_output=False)
468+
469+
470+
@d.dedent
471+
def mask_graph(
472+
sdata: SpatialData,
473+
table_key: str,
474+
polygon_mask: Polygon | MultiPolygon,
475+
negative_mask: bool = False,
476+
spatial_key: str = Key.obsm.spatial,
477+
key_added: str = "mask",
478+
copy: bool = False,
479+
) -> SpatialData:
480+
"""
481+
Mask the graph based on a polygon mask.
482+
483+
Given a spatial graph stored in :attr:`anndata.AnnData.obsp` ``['{{key_added}}_{{spatial_key}}_connectivities']`` and spatial coordinates stored in :attr:`anndata.AnnData.obsp` ``['{{spatial_key}}']``, it maskes the graph so that only edges fully contained in the polygons are kept.
484+
485+
Parameters
486+
----------
487+
sdata
488+
The spatial data object.
489+
table_key:
490+
The key of the table containing the spatial data.
491+
polygon_mask
492+
The :class:`shapely.Polygon` or :class:`shapely.MultiPolygon` to be used as mask.
493+
negative_mask
494+
Whether to keep the edges within the polygon mask or outside.
495+
Note that when ``negative_mask = True``, only the edges fully contained in the polygon are removed.
496+
If edges are partially contained in the polygon, they are kept.
497+
%(spatial_key)s
498+
key_added
499+
Key which controls where the results are saved if ``copy = False``.
500+
%(copy)s
501+
502+
Returns
503+
-------
504+
If ``copy = True``, returns a :class:`tuple` with the masked spatial connectivities and masked distances matrices.
505+
506+
Otherwise, modifies the ``adata`` with the following keys:
507+
508+
- :attr:`anndata.AnnData.obsp` ``['{{key_added}}_{{spatial_key}}_connectivities']`` - the spatial connectivities.
509+
- :attr:`anndata.AnnData.obsp` ``['{{key_added}}_{{spatial_key}}_distances']`` - the spatial distances.
510+
- :attr:`anndata.AnnData.uns` ``['{{key_added}}_{{spatial_key}}']`` - :class:`dict` containing parameters.
511+
512+
Notes
513+
-----
514+
The `polygon_mask` must be in the same `coordinate_systems` of the spatial graph, but no check is performed to assess this.
515+
"""
516+
# we could add this to arg, but I don't see use case for now
517+
neighs_key = Key.uns.spatial_neighs(spatial_key)
518+
conns_key = Key.obsp.spatial_conn(spatial_key)
519+
dists_key = Key.obsp.spatial_dist(spatial_key)
520+
521+
# check polygon type
522+
if not isinstance(polygon_mask, (Polygon, MultiPolygon)):
523+
raise ValueError(f"`polygon_mask` should be of type `Polygon` or `MultiPolygon`, got {type(polygon_mask)}")
524+
525+
# get elements
526+
table = sdata.tables[table_key]
527+
coords = table.obsm[spatial_key]
528+
Adj = table.obsp[conns_key]
529+
Dst = table.obsp[dists_key]
530+
531+
# convert edges to lines
532+
lines_coords, idx_out = _get_lines_coords(Adj.indices, Adj.indptr, coords)
533+
lines_coords, idx_out = np.array(lines_coords), np.array(idx_out)
534+
lines_df = gpd.GeoDataFrame(geometry=list(map(LineString, lines_coords)))
535+
536+
# check that lines overlap with the polygon
537+
filt_lines = lines_df.geometry.within(polygon_mask).values
538+
539+
# ~ within index, and set that to 0
540+
if not negative_mask:
541+
# keep only the lines that are within the polygon
542+
filt_lines = ~filt_lines
543+
filt_idx_out = idx_out[filt_lines]
544+
545+
# filter connectivities
546+
Adj[filt_idx_out[:, 0], filt_idx_out[:, 1]] = 0
547+
Adj.eliminate_zeros()
548+
549+
# filter_distances
550+
Dst[filt_idx_out[:, 0], filt_idx_out[:, 1]] = 0
551+
Dst.eliminate_zeros()
552+
553+
mask_conns_key = f"{key_added}_{conns_key}"
554+
mask_dists_key = f"{key_added}_{dists_key}"
555+
mask_neighs_key = f"{key_added}_{neighs_key}"
556+
557+
neighbors_dict = {
558+
"connectivities_key": mask_conns_key,
559+
"distances_key": mask_dists_key,
560+
"unfiltered_graph_key": conns_key,
561+
"params": {
562+
"negative_mask": negative_mask,
563+
"table_key": table_key,
564+
},
565+
}
566+
567+
if copy:
568+
return Adj, Dst
569+
570+
# save back to spatialdata
571+
_save_data(table, attr="obsp", key=mask_conns_key, data=Adj)
572+
_save_data(table, attr="obsp", key=mask_dists_key, data=Dst, prefix=False)
573+
_save_data(table, attr="uns", key=mask_neighs_key, data=neighbors_dict, prefix=False)
574+
575+
576+
@njit
577+
def _get_lines_coords(indices: NDArrayA, indptr: NDArrayA, coords: NDArrayA) -> tuple[list[Any], list[Any]]:
578+
lines = []
579+
idx_out = []
580+
for i in range(len(indptr) - 1):
581+
ixs = indices[indptr[i] : indptr[i + 1]]
582+
for ix in ixs:
583+
lines.append([coords[i], coords[ix]])
584+
idx_out.append((i, ix))
585+
return lines, idx_out

0 commit comments

Comments
 (0)