def evaluate_expression_transfer(
    self,
    layer_key: str = "X",
    method: Literal["pearson", "spearman", "js", "rmse"] = "pearson",
    groupby: str | None = None,
) -> None:
    """
    Evaluate the agreement between imputed and original expression in the query dataset, optionally per group.

    These metrics are inspired by Li et al., Nature Methods 2022 (https://www.nature.com/articles/s41592-022-01480-9).

    Parameters
    ----------
    layer_key
        Key in `self.query.layers` to use as the original expression. Use "X" to use `self.query.X`.
    method
        Method to quantify agreement. Supported: "pearson", "spearman", "js", "rmse".
        "jensen-shannon" is accepted as an alias of "js"; results are always stored under the "js" name.
    groupby
        Column in `self.query.obs` to group query cells by (e.g., cell type, batch). If None, only a single
        score across all query cells is computed. Cells whose label is missing (NaN) belong to no group.

    Returns
    -------
    Nothing, but updates the following attributes:
    expression_transfer_metrics
        Dictionary containing the average metric and number of genes used for the evaluation.
    query.var[f"metric_{method}"]
        Per-gene metric values (overall, across all cells).
    query.varm[f"metric_{method}"]
        Per-gene, per-group metric values (if groupby is provided).

    Raises
    ------
    NotImplementedError
        If `method` is not one of the supported metrics.
    KeyError
        If `groupby` is given but is not a column of `self.query.obs`.
    """
    # Normalise the alias so storage keys do not depend on which spelling the caller used.
    if method == "jensen-shannon":
        method = "js"

    # Select the per-gene metric; every metric is called as metric_func(original, imputed).
    if method == "pearson":

        def metric_func(a, b):
            return pearsonr(a, b)[0]

    elif method == "spearman":

        def metric_func(a, b):
            return spearmanr(a, b)[0]

    elif method == "js":
        metric_func = _jensen_shannon_divergence
    elif method == "rmse":
        metric_func = _rmse_zscore
    else:
        raise NotImplementedError(f"Method '{method}' is not implemented.")

    # Fail fast: check the grouping column before any metric is computed or stored.
    if groupby is not None and groupby not in self.query.obs:
        raise KeyError(f"Column '{groupby}' not found in `query.obs`.")

    imputed_x, original_x, shared_genes = self._get_aligned_expression_arrays(layer_key)
    n_genes = imputed_x.shape[1]

    def compute_metrics(rows):
        """Return one metric value per gene, restricted to the cells selected by `rows`."""
        return np.array([metric_func(original_x[rows, i], imputed_x[rows, i]) for i in range(n_genes)])

    # Overall metric across all cells; a full slice selects every row without copying.
    overall_metrics = compute_metrics(slice(None))
    self._store_expression_metric(
        shared_genes,
        overall_metrics,
        method,
    )

    if groupby is not None:
        group_labels = self.query.obs[groupby]
        # NaN never compares equal to itself, so a NaN "group" would select zero cells; drop it.
        groups = group_labels.dropna().unique()

        # Genes that are not shared keep NaN in every group column.
        metrics_df = pd.DataFrame(
            np.full((self.query.n_vars, len(groups)), np.nan, dtype=np.float32),
            index=self.query.var_names,
            columns=groups,
        )

        for group in groups:
            mask = (group_labels == group).to_numpy()
            metrics_df.loc[shared_genes, group] = compute_metrics(mask)
        self.query.varm[f"metric_{method}"] = metrics_df

        logger.info(
            "Metrics per group defined in `query.obs['%s']` computed and stored in `query.varm['%s']`",
            groupby,
            f"metric_{method}",
        )
242253 def _get_aligned_expression_arrays (self , layer_key : str ) -> tuple [np .ndarray , np .ndarray , list [str ]]:
243254 """
244255 Extract and align imputed and original expression arrays for shared genes between query_imputed and query.
def _store_expression_metric(
    self,
    shared_genes: list[str],
    values: np.ndarray,
    method: str,
) -> None:
    """
    Write per-gene expression transfer metrics to the query AnnData object, summarise them, and log the summary.

    Parameters
    ----------
    shared_genes
        Names of the genes the metric was computed for; used as row labels into `self.query.var`.
    values
        1D array holding one metric value per entry of `shared_genes` (e.g., correlation, JSD).
        NaN entries are stored but excluded from the average.
    method
        Name of the method/metric; used for the column name, the summary dict and logging.

    Returns
    -------
    Nothing, but updates the following attributes:
    query.var[f"metric_{method}"]
        Per-gene metric values; genes outside `shared_genes` are NaN.
    expression_transfer_metrics
        Dictionary with keys "method", "average", "n_genes" and "n_valid_genes".

    Raises
    ------
    ValueError
        If every entry of `values` is NaN.
    """
    column = f"metric_{method}"

    # Reset the column to NaN first so genes outside `shared_genes` carry no stale values.
    self.query.var[column] = np.nan
    self.query.var.loc[shared_genes, column] = values

    is_valid = np.logical_not(np.isnan(values))
    n_valid = int(np.count_nonzero(is_valid))
    if n_valid == 0:
        raise ValueError(f"No valid genes for {method} calculation.")

    n_genes = len(shared_genes)
    mean_value = float(values[is_valid].mean())

    summary = {
        "method": method,
        "average": mean_value,
        "n_genes": n_genes,
        "n_valid_genes": n_valid,
    }
    self.expression_transfer_metrics = summary

    logger.info(
        "Expression transfer evaluation (%s): average value = %.4f (n_genes=%d, n_valid_genes=%d)",
        method,
        mean_value,
        n_genes,
        n_valid,
    )
def estimate_presence_score(
    self,
    groupby: str | None = None,
    key_added: str = "presence_score",
    log: bool = False,
    percentile: tuple[float, float] = (1, 99),
) -> None:
    """
    Estimate presence scores for each reference cell based on query-to-reference connectivities.

    Adapted from the HNOCA-tools package: https://github.com/devsystemslab/HNOCA-tools

    Parameters
    ----------
    groupby
        Column in self.query.obs to group query cells by (e.g., cell type, batch). If None, computes a single
        score for all query cells. Cells whose label is missing (NaN) contribute to no group.
    key_added
        Key to store the presence score: always writes the score across all query cells to self.ref.obs[key_added].
        If groupby is not None, also writes per-group scores as a DataFrame to self.ref.obsm[key_added].
    log
        Whether to apply log1p transformation to the scores.
    percentile
        Tuple of (low, high) percentiles for clipping scores before normalization.

    Raises
    ------
    ValueError
        If neighbors have not been computed (`self.knn` or `self.knn.yx` is None).
    KeyError
        If `groupby` is given but is not a column of `self.query.obs`.
    """
    if self.knn is None or self.knn.yx is None:
        raise ValueError("Neighbors must be computed before estimating presence scores.")

    # Fail fast: check the grouping column before anything is written to `ref`.
    if groupby is not None and groupby not in self.query.obs:
        raise KeyError(f"Column '{groupby}' not found in `query.obs`.")

    # As used below: rows are indexed by query cells, columns by reference cells.
    conn = self.knn.yx.knn_graph_connectivities()
    ref_names = self.ref.obs_names

    # Always compute and post-process the overall score (all query cells).
    scores_all = np.array(conn.sum(axis=0)).flatten()
    df_all = pd.DataFrame({"all": scores_all}, index=ref_names)
    df_all_processed = process_presence_scores(df_all, log=log, percentile=percentile)
    # Assign by position: rows are already in `ref_names` order, and index alignment would
    # yield NaN if post-processing returned a frame whose index is no longer `ref_names`.
    self.ref.obs[key_added] = df_all_processed["all"].to_numpy()
    logger.info("Presence score across all query cells computed and stored in `ref.obs['%s']`", key_added)

    # If groupby, also compute and post-process per-group scores.
    if groupby is not None:
        group_labels = self.query.obs[groupby]
        # A NaN label never equals itself, so it would only produce an all-zero junk column.
        groups = group_labels.dropna().unique()
        score_matrix = np.zeros((len(ref_names), len(groups)), dtype=np.float32)
        for i, group in enumerate(groups):
            mask = (group_labels == group).to_numpy()
            score_matrix[:, i] = np.array(conn[mask, :].sum(axis=0)).flatten()
        df_groups = pd.DataFrame(score_matrix, index=ref_names, columns=groups)
        df_groups_processed = process_presence_scores(df_groups, log=log, percentile=percentile)
        # Re-attach the reference cell names; row order is unchanged by post-processing.
        df_groups_processed.index = ref_names
        self.ref.obsm[key_added] = df_groups_processed

        logger.info(
            "Presence scores per group defined in `query.obs['%s']` computed and stored in `ref.obsm['%s']`",
            groupby,
            key_added,
        )
377+
378+
def process_presence_scores(
    scores: pd.DataFrame,
    log: bool = False,
    percentile: tuple[float, float] = (1, 99),
) -> pd.DataFrame:
    """
    Post-process presence scores with log1p, percentile clipping, and min-max normalization.

    Each column is processed independently. The input frame is not modified.

    Parameters
    ----------
    scores
        DataFrame of raw presence scores (rows: reference cells, columns: groups or 'all').
    log
        Whether to apply log1p transformation to the scores.
    percentile
        Tuple of (low, high) percentiles for clipping scores before normalization.
        `(0, 100)` disables clipping.

    Returns
    -------
    pd.DataFrame
        Post-processed presence scores, same shape, row index and columns as the input.
        Columns whose values are all equal are returned as all zeros.
    """
    # Log1p transformation (optional)
    if log:
        scores = np.log1p(scores)

    # Percentile clipping (optional); unpack first so a list such as [0, 100] also disables it.
    low, high = percentile
    if (low, high) != (0, 100):
        scores = scores.apply(
            lambda col: np.clip(col, np.percentile(col, low), np.percentile(col, high)),
            axis=0,
        )

    # Min-max normalization (always)
    def minmax(col):
        min_val, max_val = np.min(col), np.max(col)
        if max_val > min_val:
            return (col - min_val) / (max_val - min_val)
        # Return a Series, not a bare ndarray: if `apply` receives an ndarray it rebuilds
        # the frame without the original row index.
        return pd.Series(np.zeros_like(col.to_numpy()), index=col.index)

    return scores.apply(minmax, axis=0)
0 commit comments