@@ -35,8 +35,6 @@ def _jensen_shannon_divergence(p: np.ndarray, q: np.ndarray) -> float:
3535 q = np .clip (q , 0 , None )
3636 if p .sum () == 0 or q .sum () == 0 :
3737 return np .nan
38- p = p / p .sum ()
39- q = q / q .sum ()
4038 return jensenshannon (p , q , base = 10 )
4139
4240
@@ -226,7 +224,11 @@ def evaluate_expression_transfer(
226224
227225 # Helper to compute metrics for a given mask of cells
228226 def compute_metrics (mask ):
229- return np .array ([metric_func (original_x [mask , i ], imputed_x [mask , i ]) for i in range (imputed_x .shape [1 ])])
227+ # Explicitly return as float32 to match DataFrame's dtype
228+ return np .array (
229+ [metric_func (original_x [mask , i ], imputed_x [mask , i ]) for i in range (imputed_x .shape [1 ])],
230+ dtype = np .float32 ,
231+ )
230232
231233 # Compute metrics for all cells
232234 overall_mask = np .ones (original_x .shape [0 ], dtype = bool )
@@ -273,7 +275,9 @@ def _get_aligned_expression_arrays(self, layer_key: str) -> tuple[np.ndarray, np
273275 imputed_x, original_x, shared_genes
274276 """
275277 if self .query_imputed is None :
276- raise ValueError ("Imputed query data not found. Run transfer_expression() first." )
278+ raise ValueError (
279+ "Imputed query data not found. Either run transfer_expression() first or set query_imputed manually."
280+ )
277281 shared_genes = list (self .query_imputed .var_names .intersection (self .query .var_names ))
278282 if len (shared_genes ) == 0 :
279283 raise ValueError ("No shared genes between query_imputed and query." )
@@ -376,8 +380,8 @@ def estimate_presence_score(
376380 groupby
377381 Column in self.query.obs to group query cells by (e.g., cell type, batch). If None, computes a single score for all query cells.
378382 key_added
379- Key to store the presence score: always writes the score across all query cells to self.ref .obs[key_added].
380- If groupby is not None, also writes per-group scores as a DataFrame to self.ref .obsm[key_added].
383+ Key to store the presence score: always writes the score across all query cells to self.reference .obs[key_added].
384+ If groupby is not None, also writes per-group scores as a DataFrame to self.reference .obsm[key_added].
381385 log
382386 Whether to apply log1p transformation to the scores.
383387 percentile
@@ -387,30 +391,30 @@ def estimate_presence_score(
387391 raise ValueError ("Neighbors must be computed before estimating presence scores." )
388392
389393 conn = self .knn .yx .knn_graph_connectivities ()
390- ref_names = self .ref .obs_names
394+ reference_names = self .reference .obs_names
391395
392396 # Always compute and post-process the overall score (all query cells)
393397 scores_all = np .array (conn .sum (axis = 0 )).flatten ()
394- df_all = pd .DataFrame ({"all" : scores_all }, index = ref_names )
398+ df_all = pd .DataFrame ({"all" : scores_all }, index = reference_names )
395399 df_all_processed = process_presence_scores (df_all , log = log , percentile = percentile )
396- self .ref .obs [key_added ] = df_all_processed ["all" ]
397- logger .info ("Presence score across all query cells computed and stored in `ref .obs['%s']`" , key_added )
400+ self .reference .obs [key_added ] = df_all_processed ["all" ]
401+ logger .info ("Presence score across all query cells computed and stored in `reference .obs['%s']`" , key_added )
398402
399403 # If groupby, also compute and post-process per-group scores
400404 if groupby is not None :
401405 group_labels = self .query .obs [groupby ]
402406 groups = group_labels .unique ()
403- score_matrix = np .zeros ((len (ref_names ), len (groups )), dtype = np .float32 )
407+ score_matrix = np .zeros ((len (reference_names ), len (groups )), dtype = np .float32 )
404408 for i , group in enumerate (groups ):
405409 mask = group_labels == group
406410 group_conn = conn [mask .values , :]
407411 score_matrix [:, i ] = np .array (group_conn .sum (axis = 0 )).flatten ()
408- df_groups = pd .DataFrame (score_matrix , index = ref_names , columns = groups )
412+ df_groups = pd .DataFrame (score_matrix , index = reference_names , columns = groups )
409413 df_groups_processed = process_presence_scores (df_groups , log = log , percentile = percentile )
410- self .ref .obsm [key_added ] = df_groups_processed
414+ self .reference .obsm [key_added ] = df_groups_processed
411415
412416 logger .info (
413- "Presence scores per group defined in `query.obs['%s']` computed and stored in `ref .obsm['%s']`" ,
417+ "Presence scores per group defined in `query.obs['%s']` computed and stored in `reference .obsm['%s']`" ,
414418 groupby ,
415419 key_added ,
416420 )
0 commit comments