Skip to content

Commit 21aecd9

Browse files
authored
Simplify score_genes (#3097)
1 parent 4f40d68 commit 21aecd9

File tree

4 files changed

+43
-58
lines changed

4 files changed

+43
-58
lines changed

scanpy/_utils/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -255,11 +255,7 @@ def _check_use_raw(adata: AnnData, use_raw: None | bool) -> bool:
255255
"""
256256
if use_raw is not None:
257257
return use_raw
258-
else:
259-
if adata.raw is not None:
260-
return True
261-
else:
262-
return False
258+
return adata.raw is not None
263259

264260

265261
# --------------------------------------------------------------------------------

scanpy/get/get.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -433,22 +433,21 @@ def _get_obs_rep(
433433
is_obsm = obsm is not None
434434
is_obsp = obsp is not None
435435
choices_made = sum((is_layer, is_raw, is_obsm, is_obsp))
436-
assert choices_made <= 1
436+
assert choices_made in {0, 1}
437437
if choices_made == 0:
438438
return adata.X
439-
elif is_layer:
439+
if is_layer:
440440
return adata.layers[layer]
441-
elif use_raw:
441+
if use_raw:
442442
return adata.raw.X
443-
elif is_obsm:
443+
if is_obsm:
444444
return adata.obsm[obsm]
445-
elif is_obsp:
445+
if is_obsp:
446446
return adata.obsp[obsp]
447-
else:
448-
assert False, (
449-
"That was unexpected. Please report this bug at:\n\n\t"
450-
" https://github.com/scverse/scanpy/issues"
451-
)
447+
raise AssertionError(
448+
"That was unexpected. Please report this bug at:\n\n\t"
449+
"https://github.com/scverse/scanpy/issues"
450+
)
452451

453452

454453
def _set_obs_rep(

scanpy/tests/test_score_genes.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
if TYPE_CHECKING:
1616
from typing import Literal
1717

18-
HERE = Path(__file__).parent / Path("_data/")
18+
from numpy.typing import NDArray
1919

2020

21-
def _create_random_gene_names(n_genes, name_length):
21+
HERE = Path(__file__).parent / "_data"
22+
23+
24+
def _create_random_gene_names(n_genes, name_length) -> NDArray[np.str_]:
2225
"""
2326
creates a bunch of random gene names (just CAPS letters)
2427
"""
@@ -68,7 +71,7 @@ def test_score_with_reference():
6871
sc.pp.scale(adata)
6972

7073
sc.tl.score_genes(adata, gene_list=adata.var_names[:100], score_name="Test")
71-
with Path(HERE, "score_genes_reference_paul2015.pkl").open("rb") as file:
74+
with (HERE / "score_genes_reference_paul2015.pkl").open("rb") as file:
7275
reference = pickle.load(file)
7376
# np.testing.assert_allclose(reference, adata.obs["Test"].to_numpy())
7477
np.testing.assert_array_equal(reference, adata.obs["Test"].to_numpy())

scanpy/tools/_score_genes.py

Lines changed: 27 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212

1313
from .. import logging as logg
1414
from .._compat import old_positionals
15+
from ..get import _get_obs_rep
1516

1617
if TYPE_CHECKING:
1718
from collections.abc import Sequence
1819
from typing import Literal
1920

2021
from anndata import AnnData
21-
from numpy.typing import NDArray
22+
from numpy.typing import DTypeLike, NDArray
2223
from scipy.sparse import csc_matrix, csr_matrix
2324

2425
from .._utils import AnyRandom
@@ -143,20 +144,16 @@ def score_genes(
143144
# Basically we need to compare genes against random genes in a matched
144145
# interval of expression.
145146

146-
_adata = adata.raw if use_raw else adata
147-
_adata_subset = (
148-
_adata[:, gene_pool] if len(gene_pool) < len(_adata.var_names) else _adata
149-
)
150-
# average expression of genes
151-
if issparse(_adata_subset.X):
152-
obs_avg = pd.Series(
153-
np.array(_sparse_nanmean(_adata_subset.X, axis=0)).flatten(),
154-
index=gene_pool,
155-
)
156-
else:
157-
obs_avg = pd.Series(np.nanmean(_adata_subset.X, axis=0), index=gene_pool)
147+
def get_subset(genes: pd.Index[str]):
148+
x = _get_obs_rep(adata, use_raw=use_raw)
149+
if len(genes) == len(var_names):
150+
return x
151+
idx = var_names.get_indexer(genes)
152+
return x[:, idx]
158153

159-
# Sometimes (and I don't know how) missing data may be there, with nansfor
154+
# average expression of genes
155+
obs_avg = pd.Series(_nan_means(get_subset(gene_pool), axis=0), index=gene_pool)
156+
# Sometimes (and I don’t know how) missing data may be there, with NaNs for missing entries
160157
obs_avg = obs_avg[np.isfinite(obs_avg)]
161158

162159
n_items = int(np.round(len(obs_avg) / (n_bins - 1)))
@@ -170,19 +167,11 @@ def score_genes(
170167
r_genes = r_genes.to_series().sample(ctrl_size).index
171168
control_genes = control_genes.union(r_genes.difference(gene_list))
172169

173-
X_list = _adata[:, gene_list].X
174-
if issparse(X_list):
175-
X_list = np.array(_sparse_nanmean(X_list, axis=1)).flatten()
176-
else:
177-
X_list = np.nanmean(X_list, axis=1, dtype="float64")
178-
179-
X_control = _adata[:, control_genes].X
180-
if issparse(X_control):
181-
X_control = np.array(_sparse_nanmean(X_control, axis=1)).flatten()
182-
else:
183-
X_control = np.nanmean(X_control, axis=1, dtype="float64")
184-
185-
score = X_list - X_control
170+
means_list, means_control = (
171+
_nan_means(get_subset(genes), axis=1, dtype="float64")
172+
for genes in (gene_list, control_genes)
173+
)
174+
score = means_list - means_control
186175

187176
adata.obs[score_name] = pd.Series(
188177
np.array(score).ravel(), index=adata.obs_names, dtype="float64"
@@ -200,6 +189,14 @@ def score_genes(
200189
return adata if copy else None
201190

202191

192+
def _nan_means(
193+
x, *, axis: Literal[0, 1], dtype: DTypeLike | None = None
194+
) -> NDArray[np.float64]:
195+
if issparse(x):
196+
return np.array(_sparse_nanmean(x, axis=axis)).flatten()
197+
return np.nanmean(x, axis=axis, dtype=dtype)
198+
199+
203200
@old_positionals("s_genes", "g2m_genes", "copy")
204201
def score_genes_cell_cycle(
205202
adata: AnnData,
@@ -253,25 +250,15 @@ def score_genes_cell_cycle(
253250

254251
adata = adata.copy() if copy else adata
255252
ctrl_size = min(len(s_genes), len(g2m_genes))
256-
# add s-score
257-
score_genes(
258-
adata, gene_list=s_genes, score_name="S_score", ctrl_size=ctrl_size, **kwargs
259-
)
260-
# add g2m-score
261-
score_genes(
262-
adata,
263-
gene_list=g2m_genes,
264-
score_name="G2M_score",
265-
ctrl_size=ctrl_size,
266-
**kwargs,
267-
)
253+
for genes, name in [(s_genes, "S_score"), (g2m_genes, "G2M_score")]:
254+
score_genes(adata, genes, score_name=name, ctrl_size=ctrl_size, **kwargs)
268255
scores = adata.obs[["S_score", "G2M_score"]]
269256

270257
# default phase is S
271258
phase = pd.Series("S", index=scores.index)
272259

273260
# if G2M is higher than S, it's G2M
274-
phase[scores.G2M_score > scores.S_score] = "G2M"
261+
phase[scores["G2M_score"] > scores["S_score"]] = "G2M"
275262

276263
# if all scores are negative, it's G1...
277264
phase[np.all(scores < 0, axis=1)] = "G1"

0 commit comments

Comments
 (0)