44
55from typing import TYPE_CHECKING
66
7+ import numba
78import numpy as np
89import pandas as pd
9- import numba
1010
1111from .. import logging as logg
12- from .._compat import CSBase , old_positionals , njit
12+ from .._compat import CSBase , njit , old_positionals
1313from .._utils import _check_use_raw , is_backed_type
1414from ..get import _get_obs_rep
1515
1616if TYPE_CHECKING :
1717 from collections .abc import Callable , Generator , Sequence
18- from typing import Literal , Tuple , Any
18+ from typing import Any , Literal
1919
2020 from anndata import AnnData
2121 from numpy .typing import DTypeLike , NDArray
2929 _GetSubset = Callable [[_StrIdx ], np .ndarray | CSBase ]
3030
3131
32- def _get_mean_columns (data : NDArray [Any ], indicies : NDArray [np .int32 ], shape : Tuple ) -> NDArray [np .float64 ]:
32+ def _get_mean_columns (
33+ data : NDArray [Any ], indicies : NDArray [np .int32 ], shape : tuple
34+ ) -> NDArray [np .float64 ]:
3335 sums = np .zeros (shape [1 ], dtype = np .float64 )
3436 counts = np .repeat (float (shape [0 ]), shape [1 ])
3537 for data_index in numba .prange (len (data )):
3638 if np .isnan (data [data_index ]):
3739 counts [indicies [data_index ]] -= 1.0
3840 continue
3941 sums [indicies [data_index ]] += data [data_index ]
40- #if we have row column nans return nan (not inf)
42+ # if we have row column nans return nan (not inf)
4143 counts [counts == 0.0 ] = np .nan
42- return sums / counts
44+ return sums / counts
45+
4346
44-
4547@njit
46- def _get_mean_rows (data : NDArray [Any ], indptr : NDArray [np .int32 ], shape : Tuple ) -> NDArray [np .float64 ]:
48+ def _get_mean_rows (
49+ data : NDArray [Any ], indptr : NDArray [np .int32 ], shape : tuple
50+ ) -> NDArray [np .float64 ]:
4751 sums = np .zeros (shape [0 ], dtype = np .float64 )
4852 counts = np .repeat (float (shape [1 ]), shape [0 ])
4953 for cur_row_index in numba .prange (shape [0 ]):
50- for data_index in numba .prange (indptr [cur_row_index ], indptr [cur_row_index + 1 ]):
54+ for data_index in numba .prange (
55+ indptr [cur_row_index ], indptr [cur_row_index + 1 ]
56+ ):
5157 if np .isnan (data [data_index ]):
5258 counts [cur_row_index ] -= 1.0
5359 continue
5460 sums [cur_row_index ] += data [data_index ]
55- #if we have row from nans return nan (not inf)
61+ # if we have row from nans return nan (not inf)
5662 counts [counts == 0.0 ] = np .nan
57- return sums / counts
63+ return sums / counts
5864
5965
6066@njit
@@ -64,7 +70,7 @@ def _sparse_nanmean(X: CSBase, axis: Literal[0, 1]) -> NDArray[np.float64]:
6470 msg = "X must be a compressed sparse matrix"
6571 raise TypeError (msg )
6672
67- if axis == 1 :
73+ if axis == 1 :
6874 return _get_mean_rows (X .data , X .indptr , X .shape )
6975 else :
7076 return _get_mean_columns (X .data , X .indices , X .shape )
0 commit comments