|
6 | 6 |
|
7 | 7 | import numpy as np |
8 | 8 | import pandas as pd |
| 9 | +import numba |
9 | 10 |
|
10 | 11 | from .. import logging as logg |
11 | | -from .._compat import CSBase, old_positionals |
| 12 | +from .._compat import CSBase, old_positionals, njit |
12 | 13 | from .._utils import _check_use_raw, is_backed_type |
13 | 14 | from ..get import _get_obs_rep |
14 | 15 |
|
15 | 16 | if TYPE_CHECKING: |
16 | 17 | from collections.abc import Callable, Generator, Sequence |
17 | | - from typing import Literal |
| 18 | + from typing import Literal, Tuple |
18 | 19 |
|
19 | 20 | from anndata import AnnData |
20 | 21 | from numpy.typing import DTypeLike, NDArray |
|
28 | 29 | _GetSubset = Callable[[_StrIdx], np.ndarray | CSBase] |
29 | 30 |
|
30 | 31 |
|
| 32 | +def _get_mean_columns(data, indicies: NDArray[np.int32], shape: Tuple) -> NDArray[np.float64]: |
| 33 | + sums = np.zeros(shape[1], dtype=np.float64) |
| 34 | + counts = np.repeat(float(shape[0]), shape[1]) |
| 35 | + for data_index in numba.prange(len(data)): |
| 36 | + if np.isnan(data[data_index]): |
| 37 | + counts[indicies[data_index]] -= 1.0 |
| 38 | + continue |
| 39 | + sums[indicies[data_index]] += data[data_index] |
| 40 | + #if we have row column nans return nan (not inf) |
| 41 | + counts[counts == 0.0] = np.nan |
| 42 | + return sums/counts |
| 43 | + |
| 44 | + |
@njit
def _get_mean_rows(
    data, indptr: NDArray[np.int32], shape: tuple[int, int]
) -> NDArray[np.float64]:
    """Compute NaN-ignoring row means from sparse value/row-pointer buffers.

    Every row starts with ``shape[1]`` observations, so entries that are not
    stored (implicit zeros) count as valid zeros. Each stored NaN removes one
    observation from its row and contributes nothing to the sum.

    Parameters
    ----------
    data
        Stored values; the entries of row ``i`` are
        ``data[indptr[i]:indptr[i + 1]]``.
    indptr
        Row pointer array with at least ``shape[0] + 1`` entries.
    shape
        ``(n_rows, n_cols)`` of the matrix the buffers describe.

    Returns
    -------
    Array of length ``shape[0]`` holding the mean of each row with NaNs
    ignored. A row whose observations are all NaN yields ``nan``.
    """
    n_rows = shape[0]
    n_cols = float(shape[1])
    means = np.empty(n_rows, dtype=np.float64)
    # Rows are independent, so only the outer loop is a prange. Numba runs
    # any prange nested inside another prange as a plain serial loop, so the
    # inner loop is written as an ordinary range.
    for row in numba.prange(n_rows):
        total = 0.0
        n_valid = n_cols
        for pos in range(indptr[row], indptr[row + 1]):
            value = data[pos]
            if np.isnan(value):
                n_valid -= 1.0
            else:
                total += value
        # A row with no valid observation must yield nan; the guard also
        # keeps the scalar division from raising on a zero count.
        means[row] = total / n_valid if n_valid != 0.0 else np.nan
    return means
| 58 | + |
| 59 | + |
| 60 | +@njit |
31 | 61 | def _sparse_nanmean(X: CSBase, axis: Literal[0, 1]) -> NDArray[np.float64]: |
32 | 62 | """np.nanmean equivalent for sparse matrices.""" |
33 | 63 | if not isinstance(X, CSBase): |
|
0 commit comments