33from __future__ import annotations
44
55from collections .abc import Iterable , Sequence
6+ from importlib .util import find_spec
67from itertools import chain
78from typing import TYPE_CHECKING , Any , Literal
89
910import numba .types as nt
1011import numpy as np
1112import pandas as pd
1213from anndata import AnnData
13- from numba import njit
14+ from numba import njit , prange
1415from numpy .random import default_rng
1516from scanpy import logging as logg
1617from scanpy .get import _get_obs_rep
@@ -266,85 +267,82 @@ def _score_helper(
266267 return score_perms
267268
268269
@njit(parallel=True, fastmath=True, cache=True)
def _occur_count(
    spatial_x: NDArrayA,
    spatial_y: NDArrayA,
    thresholds: NDArrayA,
    label_idx: NDArrayA,
    n: int,
    k: int,
    l_val: int,
) -> NDArrayA:
    """
    Count, for every ordered label pair, the point pairs lying within each radius.

    Parameters
    ----------
    spatial_x
        x-coordinates of the ``n`` points.
    spatial_y
        y-coordinates of the ``n`` points.
    thresholds
        Squared distance thresholds; entries ``0 .. l_val - 1`` are used.
    label_idx
        Integer label of each point. Every value must lie in ``[0, k)``.
    n
        Number of points.
    k
        Number of labels.
    l_val
        Number of thresholds.

    Returns
    -------
    Integer array of shape ``(k, k, l_val)``. Entry ``[a, b, r]`` is the number of ordered pairs
    ``(i, j)`` with ``i != j``, ``label_idx[i] == a``, ``label_idx[j] == b`` and squared distance
    ``<= thresholds[r]``.

    Raises
    ------
    ValueError
        If any entry of ``label_idx`` lies outside ``[0, k)``.
    """
    # Compiled code performs no bounds checking, so an out-of-range label would
    # silently write outside the buffers allocated below.
    for i in range(n):
        if label_idx[i] < 0 or label_idx[i] >= k:
            raise ValueError("`label_idx` values must lie in `[0, k)`.")

    # Point ``i`` only contributes to pairs whose first label is ``label_idx[i]``,
    # so its private scratch row needs ``k * l_val`` cells, not ``k * k * l_val``.
    # One row per point keeps the parallel loop free of write conflicts.
    kl = k * l_val
    local_results = np.zeros((n, kl), dtype=np.int32)

    for i in prange(n):
        for j in range(n):
            if i == j:
                continue
            dx = spatial_x[i] - spatial_x[j]
            dy = spatial_y[i] - spatial_y[j]
            d2 = dx * dx + dy * dy

            base = label_idx[j] * l_val  # first cell for the second label of the pair

            for r in range(l_val):
                if d2 <= thresholds[r]:
                    local_results[i, base + r] += 1

    # Serial reduction: fold each point's row into the slab belonging to its own label.
    # A 64-bit accumulator is used because the totals can exceed the int32 range.
    result_flat = np.zeros(k * kl, dtype=np.int64)
    for i in range(n):
        offset = label_idx[i] * kl
        for m in range(kl):
            result_flat[offset + m] += local_results[i, m]

    return result_flat.reshape(k, k, l_val)
298+
299+
@njit(parallel=True, fastmath=True, cache=True)
def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs: NDArrayA) -> NDArrayA:
    """
    Fast co-occurrence probability computation using the numba-accelerated counting.

    Parameters
    ----------
    v_x : np.ndarray, float64
        x-coordinates.
    v_y : np.ndarray, float64
        y-coordinates.
    v_radium : np.ndarray, float64
        Distance thresholds (in ascending order). The first value is skipped;
        the remaining ones are squared and used as radii.
    labs : np.ndarray
        Cluster labels (as non-negative integers).

    Returns
    -------
    occ_prob : np.ndarray
        A 3D float64 array of shape ``(k, k, len(v_radium) - 1)`` containing the co-occurrence
        probability ratios, where ``k`` is the largest label plus one. Entries whose
        denominator would be zero are left at ``0``.
    """
    n = len(v_x)
    # Size the label axis by the largest label rather than by the number of distinct
    # labels: the counting kernel indexes by raw label value, so a gap in the labels
    # (e.g. a category without observations) would otherwise index out of bounds.
    k = 0
    if len(labs) > 0:
        k = int(labs.max()) + 1
    # l_val is the number of bins; the thresholds come from v_radium[1:].
    l_val = len(v_radium) - 1
    # Compare squared distances against squared thresholds to avoid square roots.
    thresholds = (v_radium[1:]) ** 2

    # Compute co-occurrence counts.
    counts = _occur_count(v_x, v_y, thresholds, labs, n, k, l_val)

    occ_prob = np.zeros((k, k, l_val), dtype=np.float64)
    row_sums = counts.sum(axis=0)
    totals = row_sums.sum(axis=0)

    for r in prange(l_val):
        # No pair falls within this radius: keep the zeros instead of evaluating 0 / 0.
        if totals[r] != 0:
            probs = row_sums[:, r] / totals[r]
            for c in range(k):
                for i in range(k):
                    if probs[i] != 0.0 and row_sums[c, r] != 0.0:
                        occ_prob[i, c, r] = (counts[c, i, r] / row_sums[c, r]) / probs[i]

    return occ_prob
348346
349347
350348@d .dedent
@@ -387,18 +385,16 @@ def co_occurrence(
387385 - :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds
388386 computed at ``interval``.
389387 """
388+
390389 if isinstance (adata , SpatialData ):
391390 adata = adata .table
392391 _assert_categorical_obs (adata , key = cluster_key )
393392 _assert_spatial_basis (adata , key = spatial_key )
394393
395394 spatial = adata .obsm [spatial_key ].astype (fp )
396395 original_clust = adata .obs [cluster_key ]
397-
398- # annotate cluster idx
399396 clust_map = {v : i for i , v in enumerate (original_clust .cat .categories .values )}
400397 labs = np .array ([clust_map [c ] for c in original_clust ], dtype = ip )
401- labs_unique = np .array (list (clust_map .values ()), dtype = ip )
402398
403399 # create intervals thresholds
404400 if isinstance (interval , int ):
@@ -409,57 +405,21 @@ def co_occurrence(
409405 if len (interval ) <= 1 :
410406 raise ValueError (f"Expected interval to be of length `>= 2`, found `{ len (interval )} `." )
411407
412- n_obs = spatial .shape [0 ]
413- if n_splits is None :
414- size_arr = (n_obs ** 2 * spatial .itemsize ) / 1024 / 1024 # calc expected mem usage
415- n_splits = 1
416- if size_arr > 2000 :
417- while (n_obs / n_splits ) > 2048 :
418- n_splits += 1
419- logg .warning (
420- f"`n_splits` was automatically set to `{ n_splits } ` to "
421- f"prevent `{ n_obs } x{ n_obs } ` distance matrix from being created"
422- )
423- n_splits = int (max (min (n_splits , n_obs ), 1 ))
424-
425- # split array and labels
426- spatial_splits = tuple (s for s in np .array_split (spatial , n_splits , axis = 0 ) if len (s ))
427- labs_splits = tuple (s for s in np .array_split (labs , n_splits , axis = 0 ) if len (s ))
428- # create idx array including unique combinations and self-comparison
429- x , y = np .triu_indices_from (np .empty ((n_splits , n_splits )))
430- idx_splits = list (zip (x , y , strict = False ))
408+ spatial_x = spatial [:, 0 ]
409+ spatial_y = spatial [:, 1 ]
431410
432- n_jobs = _get_n_cores (n_jobs )
411+ # Compute co-occurrence probabilities using the fast numba routine.
412+ out = _co_occurrence_helper (spatial_x , spatial_y , interval , labs )
433413 start = logg .info (
434- f"Calculating co-occurrence probabilities for `{ len (interval )} ` intervals "
435- f"`{ len (idx_splits )} ` split combinations using `{ n_jobs } ` core(s)"
436- )
437-
438- out_lst = parallelize (
439- _co_occurrence_helper ,
440- collection = idx_splits ,
441- extractor = chain .from_iterable ,
442- n_jobs = n_jobs ,
443- backend = backend ,
444- show_progress_bar = show_progress_bar ,
445- )(
446- spatial_splits = spatial_splits ,
447- labs_splits = labs_splits ,
448- labs_unique = labs_unique ,
449- interval = interval ,
414+ f"Calculating co-occurrence probabilities for `{ len (interval )} ` intervals using `{ n_jobs } ` core(s) and `{ n_splits } ` splits"
450415 )
451- out = list (out_lst )[0 ] if len (idx_splits ) == 1 else sum (list (out_lst )) / len (idx_splits )
452416
453417 if copy :
454418 logg .info ("Finish" , time = start )
455419 return out , interval
456420
457421 _save_data (
458- adata ,
459- attr = "uns" ,
460- key = Key .uns .co_occurrence (cluster_key ),
461- data = {"occ" : out , "interval" : interval },
462- time = start ,
422+ adata , attr = "uns" , key = Key .uns .co_occurrence (cluster_key ), data = {"occ" : out , "interval" : interval }, time = start
463423 )
464424
465425
0 commit comments