Skip to content

Commit cdb443b

Browse files
committed
Add types
1 parent a36b33c commit cdb443b

File tree

1 file changed

+7
-21
lines changed

1 file changed

+7
-21
lines changed

src/scanpy/tools/_score_genes.py

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
if TYPE_CHECKING:
1717
from collections.abc import Callable, Generator, Sequence
18-
from typing import Literal, Tuple
18+
from typing import Literal, Tuple, Any
1919

2020
from anndata import AnnData
2121
from numpy.typing import DTypeLike, NDArray
@@ -29,7 +29,7 @@
2929
_GetSubset = Callable[[_StrIdx], np.ndarray | CSBase]
3030

3131

32-
def _get_mean_columns(data, indicies: NDArray[np.int32], shape: Tuple) -> NDArray[np.float64]:
32+
def _get_mean_columns(data: NDArray[Any], indicies: NDArray[np.int32], shape: Tuple) -> NDArray[np.float64]:
3333
sums = np.zeros(shape[1], dtype=np.float64)
3434
counts = np.repeat(float(shape[0]), shape[1])
3535
for data_index in numba.prange(len(data)):
@@ -43,7 +43,7 @@ def _get_mean_columns(data, indicies: NDArray[np.int32], shape: Tuple) -> NDArra
4343

4444

4545
@njit
46-
def _get_mean_rows(data, indptr: NDArray[np.int32], shape: Tuple) -> NDArray[np.float64]:
46+
def _get_mean_rows(data: NDArray[Any], indptr: NDArray[np.int32], shape: Tuple) -> NDArray[np.float64]:
4747
sums = np.zeros(shape[0], dtype=np.float64)
4848
counts = np.repeat(float(shape[1]), shape[0])
4949
for cur_row_index in numba.prange(shape[0]):
@@ -64,24 +64,10 @@ def _sparse_nanmean(X: CSBase, axis: Literal[0, 1]) -> NDArray[np.float64]:
6464
msg = "X must be a compressed sparse matrix"
6565
raise TypeError(msg)
6666

67-
Z = X.copy()
68-
69-
# count the number of nonzero elements (include nans) per row/column (dep. on axis)
70-
nonzeros_and_nones = Z.count_nonzero(axis=axis)
71-
72-
# just sum the data withput nan
73-
Z.data[np.isnan(Z.data)] = 0
74-
Z.eliminate_zeros()
75-
s = Z.sum(axis, dtype="float64")
76-
77-
# Z.count_nonzero(axis=axis) is now non-zero not-nan elements in X
78-
# diff between nonzeros_and_nones and curr nonzero is nans
79-
n_elements = (
80-
Z.shape[axis] - (nonzeros_and_nones - Z.count_nonzero(axis=axis))
81-
).reshape(s.shape, copy=False)
82-
m = s / n_elements
83-
84-
return m
67+
if axis==1:
68+
return _get_mean_rows(X.data, X.indptr, X.shape)
69+
else:
70+
return _get_mean_columns(X.data, X.indices, X.shape)
8571

8672

8773
@old_positionals(

0 commit comments

Comments
 (0)