Skip to content

Commit b8291ca

Browse files
committed
Draft implementation to evaluate epxression transfer
1 parent 5834764 commit b8291ca

File tree

1 file changed

+65
-1
lines changed

1 file changed

+65
-1
lines changed

src/cellmapper/cellmapper.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import pandas as pd
1111
import scanpy as sc
1212
from anndata import AnnData
13-
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix
13+
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix, issparse
14+
from scipy.stats import pearsonr
1415
from sklearn.metrics import (
1516
ConfusionMatrixDisplay,
1617
accuracy_score,
@@ -525,6 +526,69 @@ def plot_confusion_matrix(
525526
if save:
526527
plt.savefig(save, bbox_inches="tight")
527528

529+
def evaluate_expression_transfer(
530+
self,
531+
layer_key: str = "X",
532+
method: str = "pearson",
533+
) -> float:
534+
"""
535+
Evaluate the agreement between imputed and original expression in the query dataset.
536+
537+
Parameters
538+
----------
539+
layer_key
540+
Key in `self.query.layers` to use as the original expression. Use "X" to use `self.query.X`.
541+
method
542+
Method to quantify agreement. Currently supported: "pearson" (average Pearson correlation across shared genes).
543+
544+
Returns
545+
-------
546+
score : float
547+
The average agreement score across all shared genes.
548+
"""
549+
if self.query_imputed is None:
550+
raise ValueError("Imputed query data not found. Run transfer_expression() first.")
551+
552+
# Find shared genes
553+
shared_genes = list(self.query_imputed.var_names.intersection(self.query.var_names))
554+
if len(shared_genes) == 0:
555+
raise ValueError("No shared genes between query_imputed and query.")
556+
557+
# Subset to shared genes using adata[:, shared_genes].X
558+
imputed_x = self.query_imputed[:, shared_genes].X
559+
if layer_key == "X":
560+
original_x = self.query[:, shared_genes].X
561+
else:
562+
original_x = self.query[:, shared_genes].layers[layer_key]
563+
564+
# Convert to dense if sparse
565+
if issparse(imputed_x):
566+
imputed_x = imputed_x.toarray()
567+
if issparse(original_x):
568+
original_x = original_x.toarray()
569+
570+
if method == "pearson":
571+
# Compute Pearson correlation for each gene (column-wise)
572+
corrs = np.full(imputed_x.shape[1], np.nan)
573+
for i in range(imputed_x.shape[1]):
574+
x = np.asarray(original_x[:, i]).ravel()
575+
y = np.asarray(imputed_x[:, i]).ravel()
576+
if np.std(x) == 0 or np.std(y) == 0:
577+
continue # skip constant genes
578+
corrs[i] = pearsonr(x, y)[0]
579+
580+
# Store per-gene correlation in query_imputed.var, only for shared genes
581+
self.query_imputed.var[f"metric_{method}"] = None
582+
self.query_imputed.var.loc[shared_genes, f"metric_{method}"] = corrs
583+
584+
# Return average correlation (ignoring NaNs)
585+
valid_corrs = corrs[~np.isnan(corrs)]
586+
if valid_corrs.size == 0:
587+
raise ValueError("No valid genes for correlation calculation.")
588+
return float(np.mean(valid_corrs))
589+
else:
590+
raise NotImplementedError(f"Method '{method}' is not implemented.")
591+
528592
def fit(
529593
self,
530594
obs_keys: str | list[str] | None = None,

0 commit comments

Comments
 (0)