 from __future__ import annotations

 from collections.abc import Iterable, Sequence
+from importlib.util import find_spec
 from itertools import chain
 from typing import TYPE_CHECKING, Any, Literal

 import numba.types as nt
 import numpy as np
 import pandas as pd
 from anndata import AnnData
-from numba import njit
+from numba import njit, prange
 from numpy.random import default_rng
 from scanpy import logging as logg
 from scanpy.get import _get_obs_rep
@@ -266,85 +267,82 @@ def _score_helper(
     return score_perms


-@njit(
-    ft[:, :, :](tt(it[:], 2), ft[:, :], it[:], ft[:], bl),
-    parallel=False,
-    fastmath=True,
-)
+@njit(parallel=True, fastmath=True, cache=True)
 def _occur_count(
-    clust: tuple[NDArrayA, NDArrayA],
-    pw_dist: NDArrayA,
-    labs_unique: NDArrayA,
-    interval: NDArrayA,
-    same_split: bool,
+    spatial_x: NDArrayA, spatial_y: NDArrayA, thresholds: NDArrayA, label_idx: NDArrayA, n: int, k: int, l_val: int
 ) -> NDArrayA:
-    num = labs_unique.shape[0]
-    out = np.zeros((num, num, interval.shape[0] - 1), dtype=ft)
-
-    for idx in range(interval.shape[0] - 1):
-        co_occur = np.zeros((num, num), dtype=ft)
-        probs_con = np.zeros((num, num), dtype=ft)
-
-        thres_max = interval[idx + 1]
-        clust_x, clust_y = clust
-
-        # Modified to compute co-occurrence probability ratio over increasing radii sizes as opposed to discrete interval bins
-        # Need pw_dist > 0 to avoid counting a cell with itself as co-occurrence
-        idx_x, idx_y = np.nonzero((pw_dist <= thres_max) & (pw_dist > 0))
-        x = clust_x[idx_x]
-        y = clust_y[idx_y]
-        # Treat computing co-occurrence using the same split and different splits differently
-        # Pairwise distance matrix for between the same split is symmetric and therefore only needs to be counted once
-        for i, j in zip(x, y):  # noqa: B905 # cannot use strict=False because of numba
-            co_occur[i, j] += 1
-            if not same_split:
-                co_occur[j, i] += 1
-
-        # Prevent division by zero errors when we have low cell counts/small intervals
-        probs_matrix = co_occur / np.sum(co_occur) if np.sum(co_occur) != 0 else np.zeros((num, num), dtype=ft)
-        probs = np.sum(probs_matrix, axis=0)
-
-        for c in labs_unique:
-            probs_conditional = (
-                co_occur[c] / np.sum(co_occur[c]) if np.sum(co_occur[c]) != 0 else np.zeros(num, dtype=ft)
-            )
-            probs_con[c, :] = np.zeros(num, dtype=ft)
-            for i in range(num):
-                if probs[i] == 0:
-                    probs_con[c, i] = 0
-                else:
-                    probs_con[c, i] = probs_conditional[i] / probs[i]
-
-        out[:, :, idx] = probs_con
-
-    return out
-
-
-def _co_occurrence_helper(
-    idx_splits: Iterable[tuple[int, int]],
-    spatial_splits: Sequence[NDArrayA],
-    labs_splits: Sequence[NDArrayA],
-    labs_unique: NDArrayA,
-    interval: NDArrayA,
-    queue: SigQueue | None = None,
-) -> pd.DataFrame:
-    out_lst = []
-    for t in idx_splits:
-        idx_x, idx_y = t
-        labs_x = labs_splits[idx_x]
-        labs_y = labs_splits[idx_y]
-        dist = pairwise_distances(spatial_splits[idx_x], spatial_splits[idx_y])
-
-        out = _occur_count((labs_x, labs_y), dist, labs_unique, interval, idx_x == idx_y)
-        out_lst.append(out)
-
-        if queue is not None:
-            queue.put(Signal.UPDATE)
-
-    if queue is not None:
-        queue.put(Signal.FINISH)
-
-    return out_lst
+    # Allocate a 2D array to store a flat local result per point, so each
+    # `prange` iteration writes only to its own row.
+    k2 = k * k
+    local_results = np.zeros((n, l_val * k2), dtype=np.int32)
+
+    for i in prange(n):
+        for j in range(n):
+            if i == j:
+                continue
+            dx = spatial_x[i] - spatial_x[j]
+            dy = spatial_y[i] - spatial_y[j]
+            d2 = dx * dx + dy * dy
+
+            pair = label_idx[i] * k + label_idx[j]  # fixed in the r-loop
+            base = pair * l_val  # first cell for that pair
+
+            for r in range(l_val):
+                if d2 <= thresholds[r]:
+                    local_results[i, base + r] += 1
+
+    # Reduce over points and reshape the flat counts to (k, k, l_val).
+    result_flat = local_results.sum(axis=0)
+    result: NDArrayA = result_flat.reshape(k, k, l_val)
+
+    return result
+
+
+@njit(parallel=True, fastmath=True, cache=True)
+def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs: NDArrayA) -> NDArrayA:
+    """
+    Fast co-occurrence probability computation using the numba-accelerated counting kernel.
+
+    Parameters
+    ----------
+    v_x : np.ndarray, float64
+        x-coordinates.
+    v_y : np.ndarray, float64
+        y-coordinates.
+    v_radium : np.ndarray, float64
+        Distance thresholds (in ascending order).
+    labs : np.ndarray
+        Cluster labels (as integers).
+
+    Returns
+    -------
+    occ_prob : np.ndarray
+        A 3D array of shape ``(k, k, len(v_radium) - 1)`` containing the co-occurrence probabilities.
+    """
+    n = len(v_x)
+    labs_unique = np.unique(labs)
+    k = len(labs_unique)
+    # l_val is the number of bins; the thresholds come from v_radium[1:].
+    l_val = len(v_radium) - 1
+    # Compare squared distances against squared thresholds (skip the first value).
+    thresholds = (v_radium[1:]) ** 2
+
+    # Compute co-occurrence counts.
+    counts = _occur_count(v_x, v_y, thresholds, labs, n, k, l_val)
+
+    occ_prob = np.zeros((k, k, l_val), dtype=np.float64)
+    row_sums = counts.sum(axis=0)
+    totals = row_sums.sum(axis=0)
+
+    for r in prange(l_val):
+        probs = row_sums[:, r] / totals[r]
+        for c in range(k):
+            for i in range(k):
+                if probs[i] != 0.0 and row_sums[c, r] != 0.0:
+                    occ_prob[i, c, r] = (counts[c, i, r] / row_sums[c, r]) / probs[i]
+
+    return occ_prob


 @d.dedent
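For reviewers, here is a plain-NumPy reference of what the new counting kernel computes; it is a minimal sketch, not part of the commit, and the helper name `_occur_count_reference` is hypothetical. `counts[a, b, r]` is the number of ordered pairs `(i, j)`, `i != j`, whose labels are `(a, b)` and whose squared distance is at most `thresholds[r]`; `_co_occurrence_helper` then normalizes these counts into the ratio of the conditional probability `P(label = i | label = c, d <= r)` to the unconditional `P(label = i | d <= r)`.

```python
import numpy as np

def _occur_count_reference(spatial_x, spatial_y, thresholds, label_idx, k):
    """Vectorized NumPy equivalent of the numba kernel (sketch only).

    `thresholds` holds *squared* radii, matching `(v_radium[1:]) ** 2` above.
    """
    l_val = len(thresholds)
    counts = np.zeros((k, k, l_val), dtype=np.int64)
    # Squared pairwise distances; masking the diagonal mirrors `if i == j: continue`.
    d2 = (spatial_x[:, None] - spatial_x) ** 2 + (spatial_y[:, None] - spatial_y) ** 2
    np.fill_diagonal(d2, np.inf)
    for r in range(l_val):
        ii, jj = np.nonzero(d2 <= thresholds[r])
        # Accumulate with np.add.at, since (label, label) index pairs repeat.
        np.add.at(counts[:, :, r], (label_idx[ii], label_idx[jj]), 1)
    return counts
```

On small random inputs this should agree exactly with `_occur_count`, up to the int32/int64 dtype difference.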
@@ -387,18 +385,16 @@ def co_occurrence(
         - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds
           computed at ``interval``.
     """
+
     if isinstance(adata, SpatialData):
         adata = adata.table
     _assert_categorical_obs(adata, key=cluster_key)
     _assert_spatial_basis(adata, key=spatial_key)

     spatial = adata.obsm[spatial_key].astype(fp)
     original_clust = adata.obs[cluster_key]
-
-    # annotate cluster idx
     clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)}
     labs = np.array([clust_map[c] for c in original_clust], dtype=ip)
-    labs_unique = np.array(list(clust_map.values()), dtype=ip)

     # create intervals thresholds
     if isinstance(interval, int):
@@ -409,57 +405,21 @@
     if len(interval) <= 1:
         raise ValueError(f"Expected interval to be of length `>= 2`, found `{len(interval)}`.")

-    n_obs = spatial.shape[0]
-    if n_splits is None:
-        size_arr = (n_obs**2 * spatial.itemsize) / 1024 / 1024  # calc expected mem usage
-        n_splits = 1
-        if size_arr > 2000:
-            while (n_obs / n_splits) > 2048:
-                n_splits += 1
-            logg.warning(
-                f"`n_splits` was automatically set to `{n_splits}` to "
-                f"prevent `{n_obs}x{n_obs}` distance matrix from being created"
-            )
-    n_splits = int(max(min(n_splits, n_obs), 1))
-
-    # split array and labels
-    spatial_splits = tuple(s for s in np.array_split(spatial, n_splits, axis=0) if len(s))
-    labs_splits = tuple(s for s in np.array_split(labs, n_splits, axis=0) if len(s))
-    # create idx array including unique combinations and self-comparison
-    x, y = np.triu_indices_from(np.empty((n_splits, n_splits)))
-    idx_splits = list(zip(x, y, strict=False))
+    spatial_x = spatial[:, 0]
+    spatial_y = spatial[:, 1]

-    n_jobs = _get_n_cores(n_jobs)
+    # Compute co-occurrence probabilities using the fast numba routine.
+    out = _co_occurrence_helper(spatial_x, spatial_y, interval, labs)
     start = logg.info(
-        f"Calculating co-occurrence probabilities for `{len(interval)}` intervals "
-        f"`{len(idx_splits)}` split combinations using `{n_jobs}` core(s)"
-    )
-
-    out_lst = parallelize(
-        _co_occurrence_helper,
-        collection=idx_splits,
-        extractor=chain.from_iterable,
-        n_jobs=n_jobs,
-        backend=backend,
-        show_progress_bar=show_progress_bar,
-    )(
-        spatial_splits=spatial_splits,
-        labs_splits=labs_splits,
-        labs_unique=labs_unique,
-        interval=interval,
+        f"Calculating co-occurrence probabilities for `{len(interval)}` intervals using `{n_jobs}` core(s) and `{n_splits}` splits"
     )
-    out = list(out_lst)[0] if len(idx_splits) == 1 else sum(list(out_lst)) / len(idx_splits)

     if copy:
         logg.info("Finish", time=start)
         return out, interval

     _save_data(
-        adata,
-        attr="uns",
-        key=Key.uns.co_occurrence(cluster_key),
-        data={"occ": out, "interval": interval},
-        time=start,
+        adata, attr="uns", key=Key.uns.co_occurrence(cluster_key), data={"occ": out, "interval": interval}, time=start
     )
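A short usage sketch of the public API after this change; `sq.datasets.imc()` and the `"cell type"` key are simply a convenient example from squidpy's documentation, not something this commit requires:

```python
import squidpy as sq

adata = sq.datasets.imc()  # any AnnData with spatial coordinates in .obsm["spatial"]
occ, interval = sq.gr.co_occurrence(
    adata,
    cluster_key="cell type",  # categorical column in .obs
    interval=50,              # number of distance thresholds
    copy=True,                # return (occ, interval) instead of writing to .uns
)
# occ[i, c, r]: co-occurrence probability ratio of category i around category c
# within the r-th distance threshold; shape is (k, k, len(interval) - 1).
print(occ.shape, interval.shape)
```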