1515
1616if TYPE_CHECKING :
1717 from collections .abc import Callable , Generator , Sequence
18- from typing import Literal , Tuple
18+ from typing import Literal , Tuple , Any
1919
2020 from anndata import AnnData
2121 from numpy .typing import DTypeLike , NDArray
2929 _GetSubset = Callable [[_StrIdx ], np .ndarray | CSBase ]
3030
3131
32- def _get_mean_columns (data , indicies : NDArray [np .int32 ], shape : Tuple ) -> NDArray [np .float64 ]:
32+ def _get_mean_columns (data : NDArray [ Any ] , indicies : NDArray [np .int32 ], shape : Tuple ) -> NDArray [np .float64 ]:
3333 sums = np .zeros (shape [1 ], dtype = np .float64 )
3434 counts = np .repeat (float (shape [0 ]), shape [1 ])
3535 for data_index in numba .prange (len (data )):
@@ -43,7 +43,7 @@ def _get_mean_columns(data, indicies: NDArray[np.int32], shape: Tuple) -> NDArra
4343
4444
4545@njit
46- def _get_mean_rows (data , indptr : NDArray [np .int32 ], shape : Tuple ) -> NDArray [np .float64 ]:
46+ def _get_mean_rows (data : NDArray [ Any ] , indptr : NDArray [np .int32 ], shape : Tuple ) -> NDArray [np .float64 ]:
4747 sums = np .zeros (shape [0 ], dtype = np .float64 )
4848 counts = np .repeat (float (shape [1 ]), shape [0 ])
4949 for cur_row_index in numba .prange (shape [0 ]):
@@ -64,24 +64,10 @@ def _sparse_nanmean(X: CSBase, axis: Literal[0, 1]) -> NDArray[np.float64]:
6464 msg = "X must be a compressed sparse matrix"
6565 raise TypeError (msg )
6666
67- Z = X .copy ()
68-
69- # count the number of nonzero elements (include nans) per row/column (dep. on axis)
70- nonzeros_and_nones = Z .count_nonzero (axis = axis )
71-
72- # just sum the data withput nan
73- Z .data [np .isnan (Z .data )] = 0
74- Z .eliminate_zeros ()
75- s = Z .sum (axis , dtype = "float64" )
76-
77- # Z.count_nonzero(axis=axis) is now non-zero not-nan elements in X
78- # diff between nonzeros_and_nones and curr nonzero is nans
79- n_elements = (
80- Z .shape [axis ] - (nonzeros_and_nones - Z .count_nonzero (axis = axis ))
81- ).reshape (s .shape , copy = False )
82- m = s / n_elements
83-
84- return m
67+ if axis == 1 :
68+ return _get_mean_rows (X .data , X .indptr , X .shape )
69+ else :
70+ return _get_mean_columns (X .data , X .indices , X .shape )
8571
8672
8773@old_positionals (
0 commit comments