Skip to content

Commit 6e1ed0d

Browse files
committed
Enable spearman correlation
1 parent f50080b commit 6e1ed0d

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

src/cellmapper/cellmapper.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import scanpy as sc
1212
from anndata import AnnData
1313
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix, issparse
14-
from scipy.stats import pearsonr
14+
from scipy.stats import pearsonr, spearmanr
1515
from sklearn.metrics import (
1616
ConfusionMatrixDisplay,
1717
accuracy_score,
@@ -536,7 +536,7 @@ def plot_confusion_matrix(
536536
def evaluate_expression_transfer(
537537
self,
538538
layer_key: str = "X",
539-
method: str = "pearson",
539+
method: Literal["pearson", "spearman"] = "pearson",
540540
) -> None:
541541
"""
542542
Evaluate the agreement between imputed and original expression in the query dataset.
@@ -575,23 +575,25 @@ def evaluate_expression_transfer(
575575
if issparse(original_x):
576576
original_x = original_x.toarray()
577577

578-
if method == "pearson":
579-
# Compute Pearson correlation for each gene (column-wise)
580-
corrs = np.full(imputed_x.shape[1], np.nan)
581-
for i in range(imputed_x.shape[1]):
582-
x = np.asarray(original_x[:, i]).ravel()
583-
y = np.asarray(imputed_x[:, i]).ravel()
584-
if np.std(x) == 0 or np.std(y) == 0:
585-
continue # skip constant genes
586-
corrs[i] = pearsonr(x, y)[0]
578+
if method == "pearson" or method == "spearman":
579+
if method == "pearson":
580+
corr_func = pearsonr
581+
elif method == "spearman":
582+
corr_func = spearmanr
583+
584+
# Compute per-gene correlation
585+
corrs = np.array([corr_func(original_x[:, i], imputed_x[:, i])[0] for i in range(imputed_x.shape[1])])
587586

588587
# Store per-gene correlation in query_imputed.var, only for shared genes
589588
self.query_imputed.var[f"metric_{method}"] = None
590589
self.query_imputed.var.loc[shared_genes, f"metric_{method}"] = corrs
590+
591+
# Compute average correlation, ignoring NaN values
591592
valid_corrs = corrs[~np.isnan(corrs)]
592593
if valid_corrs.size == 0:
593594
raise ValueError("No valid genes for correlation calculation.")
594595
avg_corr = float(np.mean(valid_corrs))
596+
595597
# Store metrics in dict
596598
self.expression_transfer_metrics = {
597599
"method": method,

0 commit comments

Comments
 (0)