Skip to content

Commit 8021f8c

Browse files
pre-commit-ci[bot]Reovirus
authored andcommitted
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent cdb443b commit 8021f8c

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

src/scanpy/tools/_score_genes.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44

55
from typing import TYPE_CHECKING
66

7+
import numba
78
import numpy as np
89
import pandas as pd
9-
import numba
1010

1111
from .. import logging as logg
12-
from .._compat import CSBase, old_positionals, njit
12+
from .._compat import CSBase, njit, old_positionals
1313
from .._utils import _check_use_raw, is_backed_type
1414
from ..get import _get_obs_rep
1515

1616
if TYPE_CHECKING:
1717
from collections.abc import Callable, Generator, Sequence
18-
from typing import Literal, Tuple, Any
18+
from typing import Any, Literal
1919

2020
from anndata import AnnData
2121
from numpy.typing import DTypeLike, NDArray
@@ -29,32 +29,38 @@
2929
_GetSubset = Callable[[_StrIdx], np.ndarray | CSBase]
3030

3131

32-
def _get_mean_columns(data: NDArray[Any], indicies: NDArray[np.int32], shape: Tuple) -> NDArray[np.float64]:
32+
def _get_mean_columns(
33+
data: NDArray[Any], indicies: NDArray[np.int32], shape: tuple
34+
) -> NDArray[np.float64]:
3335
sums = np.zeros(shape[1], dtype=np.float64)
3436
counts = np.repeat(float(shape[0]), shape[1])
3537
for data_index in numba.prange(len(data)):
3638
if np.isnan(data[data_index]):
3739
counts[indicies[data_index]] -= 1.0
3840
continue
3941
sums[indicies[data_index]] += data[data_index]
40-
#if we have row column nans return nan (not inf)
42+
# if we have row column nans return nan (not inf)
4143
counts[counts == 0.0] = np.nan
42-
return sums/counts
44+
return sums / counts
45+
4346

44-
4547
@njit
46-
def _get_mean_rows(data: NDArray[Any], indptr: NDArray[np.int32], shape: Tuple) -> NDArray[np.float64]:
48+
def _get_mean_rows(
49+
data: NDArray[Any], indptr: NDArray[np.int32], shape: tuple
50+
) -> NDArray[np.float64]:
4751
sums = np.zeros(shape[0], dtype=np.float64)
4852
counts = np.repeat(float(shape[1]), shape[0])
4953
for cur_row_index in numba.prange(shape[0]):
50-
for data_index in numba.prange(indptr[cur_row_index], indptr[cur_row_index + 1]):
54+
for data_index in numba.prange(
55+
indptr[cur_row_index], indptr[cur_row_index + 1]
56+
):
5157
if np.isnan(data[data_index]):
5258
counts[cur_row_index] -= 1.0
5359
continue
5460
sums[cur_row_index] += data[data_index]
55-
#if we have row from nans return nan (not inf)
61+
# if we have row from nans return nan (not inf)
5662
counts[counts == 0.0] = np.nan
57-
return sums/counts
63+
return sums / counts
5864

5965

6066
@njit
@@ -64,7 +70,7 @@ def _sparse_nanmean(X: CSBase, axis: Literal[0, 1]) -> NDArray[np.float64]:
6470
msg = "X must be a compressed sparse matrix"
6571
raise TypeError(msg)
6672

67-
if axis==1:
73+
if axis == 1:
6874
return _get_mean_rows(X.data, X.indptr, X.shape)
6975
else:
7076
return _get_mean_columns(X.data, X.indices, X.shape)

0 commit comments

Comments
 (0)