1212
1313from .. import logging as logg
1414from .._compat import old_positionals
15+ from ..get import _get_obs_rep
1516
1617if TYPE_CHECKING :
1718 from collections .abc import Sequence
1819 from typing import Literal
1920
2021 from anndata import AnnData
21- from numpy .typing import NDArray
22+ from numpy .typing import DTypeLike , NDArray
2223 from scipy .sparse import csc_matrix , csr_matrix
2324
2425 from .._utils import AnyRandom
@@ -143,20 +144,16 @@ def score_genes(
143144 # Basically we need to compare genes against random genes in a matched
144145 # interval of expression.
145146
146- _adata = adata .raw if use_raw else adata
147- _adata_subset = (
148- _adata [:, gene_pool ] if len (gene_pool ) < len (_adata .var_names ) else _adata
149- )
150- # average expression of genes
151- if issparse (_adata_subset .X ):
152- obs_avg = pd .Series (
153- np .array (_sparse_nanmean (_adata_subset .X , axis = 0 )).flatten (),
154- index = gene_pool ,
155- )
156- else :
157- obs_avg = pd .Series (np .nanmean (_adata_subset .X , axis = 0 ), index = gene_pool )
147+ def get_subset (genes : pd .Index [str ]):
148+ x = _get_obs_rep (adata , use_raw = use_raw )
149+ if len (genes ) == len (var_names ):
150+ return x
151+ idx = var_names .get_indexer (genes )
152+ return x [:, idx ]
158153
159- # Sometimes (and I don't know how) missing data may be there, with nansfor
154+ # average expression of genes
155+ obs_avg = pd .Series (_nan_means (get_subset (gene_pool ), axis = 0 ), index = gene_pool )
156+ # Sometimes (and I don’t know how) missing data may be there, with NaNs for missing entries
160157 obs_avg = obs_avg [np .isfinite (obs_avg )]
161158
162159 n_items = int (np .round (len (obs_avg ) / (n_bins - 1 )))
@@ -170,19 +167,11 @@ def score_genes(
170167 r_genes = r_genes .to_series ().sample (ctrl_size ).index
171168 control_genes = control_genes .union (r_genes .difference (gene_list ))
172169
173- X_list = _adata [:, gene_list ].X
174- if issparse (X_list ):
175- X_list = np .array (_sparse_nanmean (X_list , axis = 1 )).flatten ()
176- else :
177- X_list = np .nanmean (X_list , axis = 1 , dtype = "float64" )
178-
179- X_control = _adata [:, control_genes ].X
180- if issparse (X_control ):
181- X_control = np .array (_sparse_nanmean (X_control , axis = 1 )).flatten ()
182- else :
183- X_control = np .nanmean (X_control , axis = 1 , dtype = "float64" )
184-
185- score = X_list - X_control
170+ means_list , means_control = (
171+ _nan_means (get_subset (genes ), axis = 1 , dtype = "float64" )
172+ for genes in (gene_list , control_genes )
173+ )
174+ score = means_list - means_control
186175
187176 adata .obs [score_name ] = pd .Series (
188177 np .array (score ).ravel (), index = adata .obs_names , dtype = "float64"
@@ -200,6 +189,14 @@ def score_genes(
200189 return adata if copy else None
201190
202191
192+ def _nan_means (
193+ x , * , axis : Literal [0 , 1 ], dtype : DTypeLike | None = None
194+ ) -> NDArray [np .float64 ]:
195+ if issparse (x ):
196+ return np .array (_sparse_nanmean (x , axis = axis )).flatten ()
197+ return np .nanmean (x , axis = axis , dtype = dtype )
198+
199+
203200@old_positionals ("s_genes" , "g2m_genes" , "copy" )
204201def score_genes_cell_cycle (
205202 adata : AnnData ,
@@ -253,25 +250,15 @@ def score_genes_cell_cycle(
253250
254251 adata = adata .copy () if copy else adata
255252 ctrl_size = min (len (s_genes ), len (g2m_genes ))
256- # add s-score
257- score_genes (
258- adata , gene_list = s_genes , score_name = "S_score" , ctrl_size = ctrl_size , ** kwargs
259- )
260- # add g2m-score
261- score_genes (
262- adata ,
263- gene_list = g2m_genes ,
264- score_name = "G2M_score" ,
265- ctrl_size = ctrl_size ,
266- ** kwargs ,
267- )
253+ for genes , name in [(s_genes , "S_score" ), (g2m_genes , "G2M_score" )]:
254+ score_genes (adata , genes , score_name = name , ctrl_size = ctrl_size , ** kwargs )
268255 scores = adata .obs [["S_score" , "G2M_score" ]]
269256
270257 # default phase is S
271258 phase = pd .Series ("S" , index = scores .index )
272259
273260 # if G2M is higher than S, it's G2M
274- phase [scores . G2M_score > scores . S_score ] = "G2M"
261+ phase [scores [ " G2M_score" ] > scores [ " S_score" ] ] = "G2M"
275262
276263 # if all scores are negative, it's G1...
277264 phase [np .all (scores < 0 , axis = 1 )] = "G1"
0 commit comments