1
1
from __future__ import annotations
2
2
3
+ import contextlib
3
4
import warnings
4
5
from typing import Any , Literal
5
6
@@ -148,19 +149,19 @@ def calculate_niche(
148
149
"If you haven't computed a spatial neighborhood graph yet, use `sq.gr.spatial_neighbors`."
149
150
)
150
151
152
+ result_columns = _get_result_columns (
153
+ flavor = flavor ,
154
+ resolutions = resolutions ,
155
+ library_key = None ,
156
+ libraries = None ,
157
+ )
158
+
151
159
if library_key is not None :
152
160
if library_key not in adata .obs .columns :
153
161
raise KeyError (f"'{ library_key } ' not found in `adata.obs`." )
154
162
155
163
logg .info (f"Stratifying by library_key '{ library_key } '" )
156
164
157
- result_columns = _get_result_columns (
158
- flavor = flavor ,
159
- resolutions = resolutions ,
160
- library_key = None ,
161
- libraries = None ,
162
- )
163
-
164
165
for col in result_columns :
165
166
adata .obs [col ] = "not_a_niche"
166
167
@@ -226,15 +227,25 @@ def calculate_niche(
226
227
spatial_connectivities_key ,
227
228
)
228
229
229
- if inplace :
230
- if isinstance (data , SpatialData ):
231
- data .tables [table_key ] = adata
232
- else :
233
- for col in adata .obs .columns :
234
- if "niche" in col and col not in orig_adata .obs .columns :
235
- orig_adata .obs [col ] = adata .obs [col ]
230
+ if not inplace :
231
+ return adata
232
+ # For SpatialData, update the table directly
233
+ if isinstance (data , SpatialData ):
234
+ data .tables [table_key ] = adata
235
+ else :
236
+ # For AnnData, copy results back to original object
237
+ for col in result_columns :
238
+ if col in orig_adata .obs .columns :
239
+ logg .info (f"Overwriting existing column '{ col } '" )
240
+ with contextlib .suppress (KeyError ):
241
+ del orig_adata .obs [col ]
242
+ if f"{ col } _colors" in orig_adata .uns .keys ():
243
+ with contextlib .suppress (KeyError ):
244
+ del orig_adata .uns [f"{ col } _colors" ]
245
+
246
+ orig_adata .obs [col ] = adata .obs [col ]
236
247
237
- return adata if not inplace else orig_adata
248
+ return None
238
249
239
250
240
251
def _get_result_columns (
@@ -345,17 +356,32 @@ def _get_nhood_profile_niches(
345
356
if n_hop_weights is None :
346
357
n_hop_weights = [1 ] * distance
347
358
# if weights are provided, start with applying weight to the original neighborhood profile
348
- else :
349
- nhood_profile = n_hop_weights [0 ] * nhood_profile
359
+ elif len (n_hop_weights ) < distance :
360
+ # Extend weights if too few provided
361
+ n_hop_weights = n_hop_weights + [n_hop_weights [- 1 ]] * (distance - len (n_hop_weights ))
362
+ logg .debug (f"Extended weights to match distance: { n_hop_weights } " )
363
+
364
+ # Apply first weight to base profile
365
+ weighted_profile = n_hop_weights [0 ] * nhood_profile
366
+
367
+ # Calculate higher-order hop profiles
368
+ n_hop_adjacency_matrix = adata_masked .obsp [spatial_connectivities_key ].copy ()
369
+
350
370
# get n_hop neighbor adjacency matrices by multiplying the original adjacency matrix with itself n times and get corresponding neighborhood profiles.
351
- for n_hop in range (distance - 1 ):
371
+ for n_hop in range (1 , distance ):
372
+ logg .debug (f"Calculating { n_hop + 1 } -hop neighbors" )
373
+ # Multiply adjacency matrix by itself to get n+1 hop adjacency
352
374
n_hop_adjacency_matrix = n_hop_adjacency_matrix @ adata_masked .obsp [spatial_connectivities_key ]
353
375
matrix = n_hop_adjacency_matrix .tocoo ()
354
- nhood_profile += n_hop_weights [n_hop + 1 ] * _calculate_neighborhood_profile (
355
- adata_masked , groups , matrix , abs_nhood
356
- )
376
+
377
+ # Calculate and add weighted profile
378
+ hop_profile = _calculate_neighborhood_profile (adata_masked , groups , matrix , abs_nhood )
379
+ weighted_profile += n_hop_weights [n_hop ] * hop_profile
380
+
357
381
if not abs_nhood :
358
- nhood_profile = nhood_profile / sum (n_hop_weights )
382
+ weighted_profile = weighted_profile / sum (n_hop_weights )
383
+
384
+ nhood_profile = weighted_profile
359
385
360
386
# create AnnData object from neighborhood profile to perform scanpy functions
361
387
adata_neighborhood = ad .AnnData (X = nhood_profile )
@@ -376,25 +402,38 @@ def _get_nhood_profile_niches(
376
402
377
403
# For each resolution, apply leiden on neighborhood profile. Each cluster label equals to a niche label
378
404
for res in resolutions :
405
+ niche_key = f"nhood_niche_res={ res } "
406
+
407
+ if niche_key in adata_masked .obs .columns :
408
+ del adata_masked .obs [niche_key ]
409
+
410
+ if f"{ niche_key } _colors" in adata_masked .uns .keys ():
411
+ del adata_masked .uns [f"{ niche_key } _colors" ]
412
+ # print(adata_masked.obs[niche_key])
413
+
379
414
sc .tl .leiden (
380
415
adata_neighborhood ,
381
416
resolution = res ,
382
- key_added = f"nhood_niche_res= { res } " ,
417
+ key_added = niche_key ,
383
418
)
384
- adata_masked .obs [f"nhood_niche_res={ res } " ] = adata_masked .obs .index .map (
385
- adata_neighborhood .obs [f"nhood_niche_res={ res } " ]
386
- ).fillna ("not_a_niche" )
419
+
420
+ adata_masked .obs [niche_key ] = "not_a_niche"
421
+
422
+ neighborhood_clusters = dict (zip (adata_neighborhood .obs .index , adata_neighborhood .obs [niche_key ], strict = False ))
423
+
424
+ mask_indices = adata_masked .obs .index
425
+ adata_masked .obs .loc [mask_indices , niche_key ] = [
426
+ neighborhood_clusters .get (idx , "not_a_niche" ) for idx in mask_indices
427
+ ]
387
428
388
429
# filter niches with n_cells < min_niche_size
389
430
if min_niche_size is not None :
390
- counts_by_niche = adata_masked .obs [f"nhood_niche_res= { res } " ].value_counts ()
431
+ counts_by_niche = adata_masked .obs [niche_key ].value_counts ()
391
432
to_filter = counts_by_niche [counts_by_niche < min_niche_size ].index
392
- adata_masked .obs [f"nhood_niche_res= { res } " ] = adata_masked .obs [f"nhood_niche_res= { res } " ].apply (
433
+ adata_masked .obs [niche_key ] = adata_masked .obs [niche_key ].apply (
393
434
lambda x , to_filter = to_filter : "not_a_niche" if x in to_filter else x
394
435
)
395
- adata .obs [f"nhood_niche_res={ res } " ] = adata .obs .index .map (
396
- adata_masked .obs [f"nhood_niche_res={ res } " ]
397
- ).fillna ("not_a_niche" )
436
+ adata_masked .obs [niche_key ] = adata_masked .obs .index .map (adata_masked .obs [niche_key ]).fillna ("not_a_niche" )
398
437
399
438
return
400
439
0 commit comments