Skip to content

Commit f50080b

Browse files
committed
Make metric computation more flexible
1 parent b8291ca commit f50080b

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

src/cellmapper/cellmapper.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def __init__(self, ref: AnnData, query: AnnData) -> None:
5454
self.confidence_postfix: str | None = None
5555
self.only_yx: bool | None = None
5656
self.query_imputed: AnnData | None = None
57+
self.expression_transfer_metrics: dict[str, Any] | None = None
5758

5859
def __repr__(self):
5960
"""Return a concise string representation of the CellMapper object."""
@@ -447,6 +448,12 @@ def evaluate_label_transfer(
447448
Minimum confidence score required to include a cell in the evaluation.
448449
zero_divisions
449450
How to handle zero divisions in sklearn metrics comptuation.
451+
452+
Returns
453+
-------
454+
Nothing, but updates the following attributes:
455+
label_transfer_metrics
456+
Dictionary containing accuracy, precision, recall, F1 scores, and excluded fraction.
450457
"""
451458
# Check if the labels have been transferred
452459
if self.prediction_postfix is None or self.confidence_postfix is None:
@@ -530,7 +537,7 @@ def evaluate_expression_transfer(
530537
self,
531538
layer_key: str = "X",
532539
method: str = "pearson",
533-
) -> float:
540+
) -> None:
534541
"""
535542
Evaluate the agreement between imputed and original expression in the query dataset.
536543
@@ -543,8 +550,9 @@ def evaluate_expression_transfer(
543550
544551
Returns
545552
-------
546-
score : float
547-
The average agreement score across all shared genes.
553+
Nothing, but updates the following attributes:
554+
expression_transfer_metrics
555+
Dictionary containing the average correlation and number of genes used for the evaluation.
548556
"""
549557
if self.query_imputed is None:
550558
raise ValueError("Imputed query data not found. Run transfer_expression() first.")
@@ -580,12 +588,24 @@ def evaluate_expression_transfer(
580588
# Store per-gene correlation in query_imputed.var, only for shared genes
581589
self.query_imputed.var[f"metric_{method}"] = None
582590
self.query_imputed.var.loc[shared_genes, f"metric_{method}"] = corrs
583-
584-
# Return average correlation (ignoring NaNs)
585591
valid_corrs = corrs[~np.isnan(corrs)]
586592
if valid_corrs.size == 0:
587593
raise ValueError("No valid genes for correlation calculation.")
588-
return float(np.mean(valid_corrs))
594+
avg_corr = float(np.mean(valid_corrs))
595+
# Store metrics in dict
596+
self.expression_transfer_metrics = {
597+
"method": method,
598+
"average_correlation": avg_corr,
599+
"n_genes": len(shared_genes),
600+
"n_valid_genes": int(valid_corrs.size),
601+
}
602+
logger.info(
603+
"Expression transfer evaluation (%s): average correlation = %.4f (n_genes=%d, n_valid_genes=%d)",
604+
method,
605+
avg_corr,
606+
len(shared_genes),
607+
int(valid_corrs.size),
608+
)
589609
else:
590610
raise NotImplementedError(f"Method '{method}' is not implemented.")
591611

0 commit comments

Comments
 (0)