Skip to content

Commit 159bb5f

Browse files
authored
Merge branch 'main' into feature/sq-benchmark
2 parents 0df4dba + 32789ae commit 159bb5f

File tree

2 files changed

+82
-122
lines changed

2 files changed

+82
-122
lines changed

src/squidpy/gr/_ppatterns.py

Lines changed: 79 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@
33
from __future__ import annotations
44

55
from collections.abc import Iterable, Sequence
6+
from importlib.util import find_spec
67
from itertools import chain
78
from typing import TYPE_CHECKING, Any, Literal
89

910
import numba.types as nt
1011
import numpy as np
1112
import pandas as pd
1213
from anndata import AnnData
13-
from numba import njit
14+
from numba import njit, prange
1415
from numpy.random import default_rng
1516
from scanpy import logging as logg
1617
from scanpy.get import _get_obs_rep
@@ -266,85 +267,82 @@ def _score_helper(
266267
return score_perms
267268

268269

269-
@njit(
270-
ft[:, :, :](tt(it[:], 2), ft[:, :], it[:], ft[:], bl),
271-
parallel=False,
272-
fastmath=True,
273-
)
270+
@njit(parallel=True, fastmath=True, cache=True)
274271
def _occur_count(
275-
clust: tuple[NDArrayA, NDArrayA],
276-
pw_dist: NDArrayA,
277-
labs_unique: NDArrayA,
278-
interval: NDArrayA,
279-
same_split: bool,
272+
spatial_x: NDArrayA, spatial_y: NDArrayA, thresholds: NDArrayA, label_idx: NDArrayA, n: int, k: int, l_val: int
280273
) -> NDArrayA:
281-
num = labs_unique.shape[0]
282-
out = np.zeros((num, num, interval.shape[0] - 1), dtype=ft)
283-
284-
for idx in range(interval.shape[0] - 1):
285-
co_occur = np.zeros((num, num), dtype=ft)
286-
probs_con = np.zeros((num, num), dtype=ft)
287-
288-
thres_max = interval[idx + 1]
289-
clust_x, clust_y = clust
290-
291-
# Modified to compute co-occurrence probability ratio over increasing radii sizes as opposed to discrete interval bins
292-
# Need pw_dist > 0 to avoid counting a cell with itself as co-occurrence
293-
idx_x, idx_y = np.nonzero((pw_dist <= thres_max) & (pw_dist > 0))
294-
x = clust_x[idx_x]
295-
y = clust_y[idx_y]
296-
# Treat computing co-occurrence using the same split and different splits differently
297-
# Pairwise distance matrix for between the same split is symmetric and therefore only needs to be counted once
298-
for i, j in zip(x, y): # noqa: B905 # cannot use strict=False because of numba
299-
co_occur[i, j] += 1
300-
if not same_split:
301-
co_occur[j, i] += 1
302-
303-
# Prevent divison by zero errors when we have low cell counts/small intervals
304-
probs_matrix = co_occur / np.sum(co_occur) if np.sum(co_occur) != 0 else np.zeros((num, num), dtype=ft)
305-
probs = np.sum(probs_matrix, axis=0)
306-
307-
for c in labs_unique:
308-
probs_conditional = (
309-
co_occur[c] / np.sum(co_occur[c]) if np.sum(co_occur[c]) != 0 else np.zeros(num, dtype=ft)
310-
)
311-
probs_con[c, :] = np.zeros(num, dtype=ft)
312-
for i in range(num):
313-
if probs[i] == 0:
314-
probs_con[c, i] = 0
315-
else:
316-
probs_con[c, i] = probs_conditional[i] / probs[i]
317-
318-
out[:, :, idx] = probs_con
319-
320-
return out
321-
322-
323-
def _co_occurrence_helper(
324-
idx_splits: Iterable[tuple[int, int]],
325-
spatial_splits: Sequence[NDArrayA],
326-
labs_splits: Sequence[NDArrayA],
327-
labs_unique: NDArrayA,
328-
interval: NDArrayA,
329-
queue: SigQueue | None = None,
330-
) -> pd.DataFrame:
331-
out_lst = []
332-
for t in idx_splits:
333-
idx_x, idx_y = t
334-
labs_x = labs_splits[idx_x]
335-
labs_y = labs_splits[idx_y]
336-
dist = pairwise_distances(spatial_splits[idx_x], spatial_splits[idx_y])
274+
# Allocate a 2D array to store a flat local result per point.
275+
k2 = k * k
276+
local_results = np.zeros((n, l_val * k2), dtype=np.int32)
337277

338-
out = _occur_count((labs_x, labs_y), dist, labs_unique, interval, idx_x == idx_y)
339-
out_lst.append(out)
278+
for i in prange(n):
279+
for j in range(n):
280+
if i == j:
281+
continue
282+
dx = spatial_x[i] - spatial_x[j]
283+
dy = spatial_y[i] - spatial_y[j]
284+
d2 = dx * dx + dy * dy
340285

341-
if queue is not None:
342-
queue.put(Signal.UPDATE)
286+
pair = label_idx[i] * k + label_idx[j] # fixed in r–loop
287+
base = pair * l_val # first cell for that pair
343288

344-
if queue is not None:
345-
queue.put(Signal.FINISH)
289+
for r in range(l_val):
290+
if d2 <= thresholds[r]:
291+
local_results[i][base + r] += 1
346292

347-
return out_lst
293+
# reduction and reshape stay the same
294+
result_flat = local_results.sum(axis=0)
295+
result: NDArrayA = result_flat.reshape(k, k, l_val)
296+
297+
return result
298+
299+
300+
@njit(parallel=True, fastmath=True, cache=True)
301+
def _co_occurrence_helper(v_x: NDArrayA, v_y: NDArrayA, v_radium: NDArrayA, labs: NDArrayA) -> NDArrayA:
302+
"""
303+
Fast co-occurrence probability computation using the new numba-accelerated counting.
304+
305+
Parameters
306+
----------
307+
v_x : np.ndarray, float64
308+
x–coordinates.
309+
v_y : np.ndarray, float64
310+
y–coordinates.
311+
v_radium : np.ndarray, float64
312+
Distance thresholds (in ascending order).
313+
labs : np.ndarray
314+
Cluster labels (as integers).
315+
316+
Returns
317+
-------
318+
occ_prob : np.ndarray
319+
A 3D array of shape (k, k, len(v_radium)-1) containing the co-occurrence probabilities.
320+
labs_unique : np.ndarray
321+
Array of unique labels.
322+
"""
323+
n = len(v_x)
324+
labs_unique = np.unique(labs)
325+
k = len(labs_unique)
326+
# l_val is the number of bins; here we assume the thresholds come from v_radium[1:].
327+
l_val = len(v_radium) - 1
328+
# Compute squared thresholds from the interval (skip the first value)
329+
thresholds = (v_radium[1:]) ** 2
330+
331+
# Compute co-occurence counts.
332+
counts = _occur_count(v_x, v_y, thresholds, labs, n, k, l_val)
333+
334+
occ_prob = np.zeros((k, k, l_val), dtype=np.float64)
335+
row_sums = counts.sum(axis=0)
336+
totals = row_sums.sum(axis=0)
337+
338+
for r in prange(l_val):
339+
probs = row_sums[:, r] / totals[r]
340+
for c in range(k):
341+
for i in range(k):
342+
if probs[i] != 0.0 and row_sums[c, r] != 0.0:
343+
occ_prob[i, c, r] = (counts[c, i, r] / row_sums[c, r]) / probs[i]
344+
345+
return occ_prob
348346

349347

350348
@d.dedent
@@ -387,18 +385,16 @@ def co_occurrence(
387385
- :attr:`anndata.AnnData.uns` ``['{cluster_key}_co_occurrence']['interval']`` - the distance thresholds
388386
computed at ``interval``.
389387
"""
388+
390389
if isinstance(adata, SpatialData):
391390
adata = adata.table
392391
_assert_categorical_obs(adata, key=cluster_key)
393392
_assert_spatial_basis(adata, key=spatial_key)
394393

395394
spatial = adata.obsm[spatial_key].astype(fp)
396395
original_clust = adata.obs[cluster_key]
397-
398-
# annotate cluster idx
399396
clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)}
400397
labs = np.array([clust_map[c] for c in original_clust], dtype=ip)
401-
labs_unique = np.array(list(clust_map.values()), dtype=ip)
402398

403399
# create intervals thresholds
404400
if isinstance(interval, int):
@@ -409,57 +405,21 @@ def co_occurrence(
409405
if len(interval) <= 1:
410406
raise ValueError(f"Expected interval to be of length `>= 2`, found `{len(interval)}`.")
411407

412-
n_obs = spatial.shape[0]
413-
if n_splits is None:
414-
size_arr = (n_obs**2 * spatial.itemsize) / 1024 / 1024 # calc expected mem usage
415-
n_splits = 1
416-
if size_arr > 2000:
417-
while (n_obs / n_splits) > 2048:
418-
n_splits += 1
419-
logg.warning(
420-
f"`n_splits` was automatically set to `{n_splits}` to "
421-
f"prevent `{n_obs}x{n_obs}` distance matrix from being created"
422-
)
423-
n_splits = int(max(min(n_splits, n_obs), 1))
424-
425-
# split array and labels
426-
spatial_splits = tuple(s for s in np.array_split(spatial, n_splits, axis=0) if len(s))
427-
labs_splits = tuple(s for s in np.array_split(labs, n_splits, axis=0) if len(s))
428-
# create idx array including unique combinations and self-comparison
429-
x, y = np.triu_indices_from(np.empty((n_splits, n_splits)))
430-
idx_splits = list(zip(x, y, strict=False))
408+
spatial_x = spatial[:, 0]
409+
spatial_y = spatial[:, 1]
431410

432-
n_jobs = _get_n_cores(n_jobs)
411+
# Compute co-occurrence probabilities using the fast numba routine.
412+
out = _co_occurrence_helper(spatial_x, spatial_y, interval, labs)
433413
start = logg.info(
434-
f"Calculating co-occurrence probabilities for `{len(interval)}` intervals "
435-
f"`{len(idx_splits)}` split combinations using `{n_jobs}` core(s)"
436-
)
437-
438-
out_lst = parallelize(
439-
_co_occurrence_helper,
440-
collection=idx_splits,
441-
extractor=chain.from_iterable,
442-
n_jobs=n_jobs,
443-
backend=backend,
444-
show_progress_bar=show_progress_bar,
445-
)(
446-
spatial_splits=spatial_splits,
447-
labs_splits=labs_splits,
448-
labs_unique=labs_unique,
449-
interval=interval,
414+
f"Calculating co-occurrence probabilities for `{len(interval)}` intervals using `{n_jobs}` core(s) and `{n_splits}` splits"
450415
)
451-
out = list(out_lst)[0] if len(idx_splits) == 1 else sum(list(out_lst)) / len(idx_splits)
452416

453417
if copy:
454418
logg.info("Finish", time=start)
455419
return out, interval
456420

457421
_save_data(
458-
adata,
459-
attr="uns",
460-
key=Key.uns.co_occurrence(cluster_key),
461-
data={"occ": out, "interval": interval},
462-
time=start,
422+
adata, attr="uns", key=Key.uns.co_occurrence(cluster_key), data={"occ": out, "interval": interval}, time=start
463423
)
464424

465425

src/squidpy/im/_io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _assert_dims_present(dims: tuple[str, ...], include_z: bool = True) -> None:
2525

2626
# modification of `skimage`'s `pil_to_ndarray`:
2727
# https://github.com/scikit-image/scikit-image/blob/main/skimage/io/_plugins/pil_plugin.py#L55
28-
def _infer_shape_dtype(fname: str) -> tuple[tuple[int, ...], np.dtype]: # type: ignore[type-arg]
28+
def _infer_shape_dtype(fname: str) -> tuple[tuple[int, ...], np.dtype]:
2929
def _palette_is_grayscale(pil_image: Image.Image) -> bool:
3030
# get palette as an array with R, G, B columns
3131
palette = np.asarray(pil_image.getpalette()).reshape((256, 3))
@@ -81,7 +81,7 @@ def _palette_is_grayscale(pil_image: Image.Image) -> bool:
8181
raise ValueError(f"Unable to infer image dtype for image mode `{image.mode}`.")
8282

8383

84-
def _get_image_shape_dtype(fname: str) -> tuple[tuple[int, ...], np.dtype]: # type: ignore[type-arg]
84+
def _get_image_shape_dtype(fname: str) -> tuple[tuple[int, ...], np.dtype]:
8585
try:
8686
return _infer_shape_dtype(fname)
8787
except Image.UnidentifiedImageError as e:
@@ -101,7 +101,7 @@ def _get_image_shape_dtype(fname: str) -> tuple[tuple[int, ...], np.dtype]: # t
101101
def _infer_dimensions(
102102
obj: NDArrayA | xr.DataArray | str,
103103
infer_dimensions: InferDimensions | tuple[str, ...] = InferDimensions.DEFAULT,
104-
) -> tuple[tuple[int, ...], tuple[str, ...], np.dtype, tuple[int, ...]]: # type: ignore[type-arg]
104+
) -> tuple[tuple[int, ...], tuple[str, ...], np.dtype, tuple[int, ...]]:
105105
"""
106106
Infer dimension names of an array.
107107

0 commit comments

Comments
 (0)