Skip to content

Commit a36b33c

Browse files
committed
rewrite logics with numba (for scipy <1.15.0)
1 parent 2882948 commit a36b33c

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

src/scanpy/tools/_score_genes.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66

77
import numpy as np
88
import pandas as pd
9+
import numba
910

1011
from .. import logging as logg
11-
from .._compat import CSBase, old_positionals
12+
from .._compat import CSBase, old_positionals, njit
1213
from .._utils import _check_use_raw, is_backed_type
1314
from ..get import _get_obs_rep
1415

1516
if TYPE_CHECKING:
1617
from collections.abc import Callable, Generator, Sequence
17-
from typing import Literal
18+
from typing import Literal, Tuple
1819

1920
from anndata import AnnData
2021
from numpy.typing import DTypeLike, NDArray
@@ -28,6 +29,35 @@
2829
_GetSubset = Callable[[_StrIdx], np.ndarray | CSBase]
2930

3031

32+
def _get_mean_columns(data, indicies: NDArray[np.int32], shape: Tuple) -> NDArray[np.float64]:
33+
sums = np.zeros(shape[1], dtype=np.float64)
34+
counts = np.repeat(float(shape[0]), shape[1])
35+
for data_index in numba.prange(len(data)):
36+
if np.isnan(data[data_index]):
37+
counts[indicies[data_index]] -= 1.0
38+
continue
39+
sums[indicies[data_index]] += data[data_index]
40+
#if we have row column nans return nan (not inf)
41+
counts[counts == 0.0] = np.nan
42+
return sums/counts
43+
44+
45+
@njit
46+
def _get_mean_rows(data, indptr: NDArray[np.int32], shape: Tuple) -> NDArray[np.float64]:
47+
sums = np.zeros(shape[0], dtype=np.float64)
48+
counts = np.repeat(float(shape[1]), shape[0])
49+
for cur_row_index in numba.prange(shape[0]):
50+
for data_index in numba.prange(indptr[cur_row_index], indptr[cur_row_index + 1]):
51+
if np.isnan(data[data_index]):
52+
counts[cur_row_index] -= 1.0
53+
continue
54+
sums[cur_row_index] += data[data_index]
55+
#if we have row from nans return nan (not inf)
56+
counts[counts == 0.0] = np.nan
57+
return sums/counts
58+
59+
60+
@njit
3161
def _sparse_nanmean(X: CSBase, axis: Literal[0, 1]) -> NDArray[np.float64]:
3262
"""np.nanmean equivalent for sparse matrices."""
3363
if not isinstance(X, CSBase):

0 commit comments

Comments
 (0)