6
6
from collections .abc import Iterable # noqa: F401
7
7
from functools import partial
8
8
from itertools import chain
9
+ from typing import Any
9
10
11
+ import geopandas as gpd
10
12
import numpy as np
13
+ import pandas as pd
11
14
from anndata import AnnData
12
15
from anndata .utils import make_index_unique
16
+ from geopandas import GeoDataFrame
13
17
from numba import njit
14
18
from scanpy import logging as logg
15
19
from scipy .sparse import (
20
24
spmatrix ,
21
25
)
22
26
from scipy .spatial import Delaunay
27
+ from shapely import LineString , MultiPolygon , Point , Polygon , distance
23
28
from sklearn .metrics .pairwise import cosine_similarity , euclidean_distances
24
29
from sklearn .neighbors import NearestNeighbors
25
30
from spatialdata import SpatialData
31
+ from spatialdata ._core .centroids import get_centroids
32
+ from spatialdata ._core .query .relational_query import (
33
+ get_element_instances ,
34
+ match_element_to_table ,
35
+ )
36
+ from spatialdata .models import SpatialElement , get_table_keys
37
+ from spatialdata .models .models import (
38
+ Image2DModel ,
39
+ Image3DModel ,
40
+ Labels2DModel ,
41
+ Labels3DModel ,
42
+ PointsModel ,
43
+ RasterSchema ,
44
+ ShapesModel ,
45
+ TableModel ,
46
+ get_axes_names ,
47
+ get_model ,
48
+ )
26
49
27
50
from squidpy ._constants ._constants import CoordType , Transform
28
51
from squidpy ._constants ._pkg_constants import Key
43
66
def spatial_neighbors (
44
67
adata : AnnData | SpatialData ,
45
68
spatial_key : str = Key .obsm .spatial ,
69
+ elements_to_coordinate_systems : dict [str , str ] | None = None ,
70
+ table_key : str | None = None ,
46
71
library_key : str | None = None ,
47
72
coord_type : str | CoordType | None = None ,
48
73
n_neighs : int = 6 ,
@@ -62,6 +87,17 @@ def spatial_neighbors(
62
87
----------
63
88
%(adata)s
64
89
%(spatial_key)s
90
+ If `adata` is a :class:`spatialdata.SpatialData`, the coordinates of the centroids will be stored in the
91
+ `adata` with this key.
92
+ elements_to_coordinate_systems
93
+ A dictionary mapping element names of the SpatialData object to coordinate systems.
94
+ The elements can be either Shapes or Labels. For compatibility, the spatialdata table must annotate
95
+ all regions keys. Must not be `None` if `adata` is a :class:`spatialdata.SpatialData`.
96
+ table_key
97
+ Key in :attr:`spatialdata.SpatialData.tables` where the spatialdata table is stored. Must not be `None` if
98
+ `adata` is a :class:`spatialdata.SpatialData`.
99
+ mask_polygon
100
+ The Polygon or MultiPolygon element.
65
101
%(library_key)s
66
102
coord_type
67
103
Type of coordinate system. Valid options are:
@@ -109,7 +145,54 @@ def spatial_neighbors(
109
145
- :attr:`anndata.AnnData.uns` ``['{{key_added}}']`` - :class:`dict` containing parameters.
110
146
"""
111
147
if isinstance (adata , SpatialData ):
112
- adata = adata .table
148
+ assert (
149
+ elements_to_coordinate_systems is not None
150
+ ), "Since `adata` is a :class:`spatialdata.SpatialData`, `elements_to_coordinate_systems` must not be `None`."
151
+ assert (
152
+ table_key is not None
153
+ ), "Since `adata` is a :class:`spatialdata.SpatialData`, `table_key` must not be `None`."
154
+ elements , table = match_element_to_table (adata , list (elements_to_coordinate_systems ), table_key )
155
+ assert table .obs_names .equals (
156
+ adata .tables [table_key ].obs_names
157
+ ), "The spatialdata table must annotate all elements keys. Some elements are missing, please check the `elements_to_coordinate_systems` dictionary."
158
+ regions , region_key , instance_key = get_table_keys (adata .tables [table_key ])
159
+ regions = [regions ] if isinstance (regions , str ) else regions
160
+ ordered_regions_in_table = adata .tables [table_key ].obs [region_key ].unique ()
161
+
162
+ # TODO: remove this after https://github.com/scverse/spatialdata/issues/614
163
+ remove_centroids = {}
164
+ elem_instances = []
165
+ for e in regions :
166
+ schema = get_model (elements [e ])
167
+ element_instances = get_element_instances (elements [e ]).to_series ()
168
+ if np .isin (0 , element_instances .values ) and (schema in (Labels2DModel , Labels3DModel )):
169
+ element_instances = element_instances .drop (index = 0 )
170
+ remove_centroids [e ] = True
171
+ else :
172
+ remove_centroids [e ] = False
173
+ elem_instances .append (element_instances )
174
+
175
+ element_instances = pd .concat (elem_instances )
176
+ if (not np .all (element_instances .values == adata .tables [table_key ].obs [instance_key ].values )) or (
177
+ not np .all (ordered_regions_in_table == regions )
178
+ ):
179
+ raise ValueError (
180
+ "The spatialdata table must annotate all elements keys. Some elements are missing or not ordered correctly, please check the `elements_to_coordinate_systems` dictionary."
181
+ )
182
+ centroids = []
183
+ for region_ in ordered_regions_in_table :
184
+ cs = elements_to_coordinate_systems [region_ ]
185
+ centroid = get_centroids (adata [region_ ], coordinate_system = cs )[["x" , "y" ]].compute ()
186
+
187
+ # TODO: remove this after https://github.com/scverse/spatialdata/issues/614
188
+ if remove_centroids [region_ ]:
189
+ centroid = centroid [1 :].copy ()
190
+ centroids .append (centroid )
191
+
192
+ adata .tables [table_key ].obsm [spatial_key ] = np .concatenate (centroids )
193
+ adata = adata .tables [table_key ]
194
+ library_key = region_key
195
+
113
196
_assert_positive (n_rings , name = "n_rings" )
114
197
_assert_positive (n_neighs , name = "n_neighs" )
115
198
_assert_spatial_basis (adata , spatial_key )
@@ -167,7 +250,12 @@ def spatial_neighbors(
167
250
neighbors_dict = {
168
251
"connectivities_key" : conns_key ,
169
252
"distances_key" : dists_key ,
170
- "params" : {"n_neighbors" : n_neighs , "coord_type" : coord_type .v , "radius" : radius , "transform" : transform .v },
253
+ "params" : {
254
+ "n_neighbors" : n_neighs ,
255
+ "coord_type" : coord_type .v ,
256
+ "radius" : radius ,
257
+ "transform" : transform .v ,
258
+ },
171
259
}
172
260
173
261
if copy :
@@ -194,10 +282,21 @@ def _spatial_neighbor(
194
282
with warnings .catch_warnings ():
195
283
warnings .simplefilter ("ignore" , SparseEfficiencyWarning )
196
284
if coord_type == CoordType .GRID :
197
- Adj , Dst = _build_grid (coords , n_neighs = n_neighs , n_rings = n_rings , delaunay = delaunay , set_diag = set_diag )
285
+ Adj , Dst = _build_grid (
286
+ coords ,
287
+ n_neighs = n_neighs ,
288
+ n_rings = n_rings ,
289
+ delaunay = delaunay ,
290
+ set_diag = set_diag ,
291
+ )
198
292
elif coord_type == CoordType .GENERIC :
199
293
Adj , Dst = _build_connectivity (
200
- coords , n_neighs = n_neighs , radius = radius , delaunay = delaunay , return_distance = True , set_diag = set_diag
294
+ coords ,
295
+ n_neighs = n_neighs ,
296
+ radius = radius ,
297
+ delaunay = delaunay ,
298
+ return_distance = True ,
299
+ set_diag = set_diag ,
201
300
)
202
301
else :
203
302
raise NotImplementedError (f"Coordinate type `{ coord_type } ` is not yet implemented." )
@@ -233,7 +332,11 @@ def _spatial_neighbor(
233
332
234
333
235
334
def _build_grid (
236
- coords : NDArrayA , n_neighs : int , n_rings : int , delaunay : bool = False , set_diag : bool = False
335
+ coords : NDArrayA ,
336
+ n_neighs : int ,
337
+ n_rings : int ,
338
+ delaunay : bool = False ,
339
+ set_diag : bool = False ,
237
340
) -> tuple [csr_matrix , csr_matrix ]:
238
341
if n_rings > 1 :
239
342
Adj : csr_matrix = _build_connectivity (
@@ -258,7 +361,13 @@ def _build_grid(
258
361
Dst = Adj .copy ()
259
362
Adj .data [:] = 1.0
260
363
else :
261
- Adj = _build_connectivity (coords , n_neighs = n_neighs , neigh_correct = True , delaunay = delaunay , set_diag = set_diag )
364
+ Adj = _build_connectivity (
365
+ coords ,
366
+ n_neighs = n_neighs ,
367
+ neigh_correct = True ,
368
+ delaunay = delaunay ,
369
+ set_diag = set_diag ,
370
+ )
262
371
Dst = Adj .copy ()
263
372
264
373
Dst .setdiag (0.0 )
@@ -302,14 +411,21 @@ def _build_connectivity(
302
411
if neigh_correct :
303
412
dist_cutoff = np .median (dists ) * 1.3 # there's a small amount of sway
304
413
mask = dists < dist_cutoff
305
- row_indices , col_indices , dists = row_indices [mask ], col_indices [mask ], dists [mask ]
414
+ row_indices , col_indices , dists = (
415
+ row_indices [mask ],
416
+ col_indices [mask ],
417
+ dists [mask ],
418
+ )
306
419
else :
307
420
dists , col_indices = tree .radius_neighbors ()
308
421
row_indices = np .repeat (np .arange (N ), [len (x ) for x in col_indices ])
309
422
dists = np .concatenate (dists )
310
423
col_indices = np .concatenate (col_indices )
311
424
312
- Adj = csr_matrix ((np .ones_like (row_indices , dtype = np .float64 ), (row_indices , col_indices )), shape = (N , N ))
425
+ Adj = csr_matrix (
426
+ (np .ones_like (row_indices , dtype = np .float64 ), (row_indices , col_indices )),
427
+ shape = (N , N ),
428
+ )
313
429
if return_distance :
314
430
Dst = csr_matrix ((dists , (row_indices , col_indices )), shape = (N , N ))
315
431
@@ -349,3 +465,121 @@ def _transform_a_spectral(a: spmatrix) -> spmatrix:
349
465
350
466
def _transform_a_cosine (a : spmatrix ) -> spmatrix :
351
467
return cosine_similarity (a , dense_output = False )
468
+
469
+
470
+ @d .dedent
471
+ def mask_graph (
472
+ sdata : SpatialData ,
473
+ table_key : str ,
474
+ polygon_mask : Polygon | MultiPolygon ,
475
+ negative_mask : bool = False ,
476
+ spatial_key : str = Key .obsm .spatial ,
477
+ key_added : str = "mask" ,
478
+ copy : bool = False ,
479
+ ) -> SpatialData :
480
+ """
481
+ Mask the graph based on a polygon mask.
482
+
483
+ Given a spatial graph stored in :attr:`anndata.AnnData.obsp` ``['{{key_added}}_{{spatial_key}}_connectivities']`` and spatial coordinates stored in :attr:`anndata.AnnData.obsp` ``['{{spatial_key}}']``, it maskes the graph so that only edges fully contained in the polygons are kept.
484
+
485
+ Parameters
486
+ ----------
487
+ sdata
488
+ The spatial data object.
489
+ table_key:
490
+ The key of the table containing the spatial data.
491
+ polygon_mask
492
+ The :class:`shapely.Polygon` or :class:`shapely.MultiPolygon` to be used as mask.
493
+ negative_mask
494
+ Whether to keep the edges within the polygon mask or outside.
495
+ Note that when ``negative_mask = True``, only the edges fully contained in the polygon are removed.
496
+ If edges are partially contained in the polygon, they are kept.
497
+ %(spatial_key)s
498
+ key_added
499
+ Key which controls where the results are saved if ``copy = False``.
500
+ %(copy)s
501
+
502
+ Returns
503
+ -------
504
+ If ``copy = True``, returns a :class:`tuple` with the masked spatial connectivities and masked distances matrices.
505
+
506
+ Otherwise, modifies the ``adata`` with the following keys:
507
+
508
+ - :attr:`anndata.AnnData.obsp` ``['{{key_added}}_{{spatial_key}}_connectivities']`` - the spatial connectivities.
509
+ - :attr:`anndata.AnnData.obsp` ``['{{key_added}}_{{spatial_key}}_distances']`` - the spatial distances.
510
+ - :attr:`anndata.AnnData.uns` ``['{{key_added}}_{{spatial_key}}']`` - :class:`dict` containing parameters.
511
+
512
+ Notes
513
+ -----
514
+ The `polygon_mask` must be in the same `coordinate_systems` of the spatial graph, but no check is performed to assess this.
515
+ """
516
+ # we could add this to arg, but I don't see use case for now
517
+ neighs_key = Key .uns .spatial_neighs (spatial_key )
518
+ conns_key = Key .obsp .spatial_conn (spatial_key )
519
+ dists_key = Key .obsp .spatial_dist (spatial_key )
520
+
521
+ # check polygon type
522
+ if not isinstance (polygon_mask , (Polygon , MultiPolygon )):
523
+ raise ValueError (f"`polygon_mask` should be of type `Polygon` or `MultiPolygon`, got { type (polygon_mask )} " )
524
+
525
+ # get elements
526
+ table = sdata .tables [table_key ]
527
+ coords = table .obsm [spatial_key ]
528
+ Adj = table .obsp [conns_key ]
529
+ Dst = table .obsp [dists_key ]
530
+
531
+ # convert edges to lines
532
+ lines_coords , idx_out = _get_lines_coords (Adj .indices , Adj .indptr , coords )
533
+ lines_coords , idx_out = np .array (lines_coords ), np .array (idx_out )
534
+ lines_df = gpd .GeoDataFrame (geometry = list (map (LineString , lines_coords )))
535
+
536
+ # check that lines overlap with the polygon
537
+ filt_lines = lines_df .geometry .within (polygon_mask ).values
538
+
539
+ # ~ within index, and set that to 0
540
+ if not negative_mask :
541
+ # keep only the lines that are within the polygon
542
+ filt_lines = ~ filt_lines
543
+ filt_idx_out = idx_out [filt_lines ]
544
+
545
+ # filter connectivities
546
+ Adj [filt_idx_out [:, 0 ], filt_idx_out [:, 1 ]] = 0
547
+ Adj .eliminate_zeros ()
548
+
549
+ # filter_distances
550
+ Dst [filt_idx_out [:, 0 ], filt_idx_out [:, 1 ]] = 0
551
+ Dst .eliminate_zeros ()
552
+
553
+ mask_conns_key = f"{ key_added } _{ conns_key } "
554
+ mask_dists_key = f"{ key_added } _{ dists_key } "
555
+ mask_neighs_key = f"{ key_added } _{ neighs_key } "
556
+
557
+ neighbors_dict = {
558
+ "connectivities_key" : mask_conns_key ,
559
+ "distances_key" : mask_dists_key ,
560
+ "unfiltered_graph_key" : conns_key ,
561
+ "params" : {
562
+ "negative_mask" : negative_mask ,
563
+ "table_key" : table_key ,
564
+ },
565
+ }
566
+
567
+ if copy :
568
+ return Adj , Dst
569
+
570
+ # save back to spatialdata
571
+ _save_data (table , attr = "obsp" , key = mask_conns_key , data = Adj )
572
+ _save_data (table , attr = "obsp" , key = mask_dists_key , data = Dst , prefix = False )
573
+ _save_data (table , attr = "uns" , key = mask_neighs_key , data = neighbors_dict , prefix = False )
574
+
575
+
576
+ @njit
577
+ def _get_lines_coords (indices : NDArrayA , indptr : NDArrayA , coords : NDArrayA ) -> tuple [list [Any ], list [Any ]]:
578
+ lines = []
579
+ idx_out = []
580
+ for i in range (len (indptr ) - 1 ):
581
+ ixs = indices [indptr [i ] : indptr [i + 1 ]]
582
+ for ix in ixs :
583
+ lines .append ([coords [i ], coords [ix ]])
584
+ idx_out .append ((i , ix ))
585
+ return lines , idx_out
0 commit comments