|
11 | 11 | import scanpy as sc |
12 | 12 | from anndata import AnnData |
13 | 13 | from scipy.sparse import coo_matrix, csc_matrix, csr_matrix, issparse |
14 | | -from scipy.stats import pearsonr |
| 14 | +from scipy.stats import pearsonr, spearmanr |
15 | 15 | from sklearn.metrics import ( |
16 | 16 | ConfusionMatrixDisplay, |
17 | 17 | accuracy_score, |
@@ -536,7 +536,7 @@ def plot_confusion_matrix( |
536 | 536 | def evaluate_expression_transfer( |
537 | 537 | self, |
538 | 538 | layer_key: str = "X", |
539 | | - method: str = "pearson", |
| 539 | + method: Literal["pearson", "spearman"] = "pearson", |
540 | 540 | ) -> None: |
541 | 541 | """ |
542 | 542 | Evaluate the agreement between imputed and original expression in the query dataset. |
@@ -575,23 +575,25 @@ def evaluate_expression_transfer( |
575 | 575 | if issparse(original_x): |
576 | 576 | original_x = original_x.toarray() |
577 | 577 |
|
578 | | - if method == "pearson": |
579 | | - # Compute Pearson correlation for each gene (column-wise) |
580 | | - corrs = np.full(imputed_x.shape[1], np.nan) |
581 | | - for i in range(imputed_x.shape[1]): |
582 | | - x = np.asarray(original_x[:, i]).ravel() |
583 | | - y = np.asarray(imputed_x[:, i]).ravel() |
584 | | - if np.std(x) == 0 or np.std(y) == 0: |
585 | | - continue # skip constant genes |
586 | | - corrs[i] = pearsonr(x, y)[0] |
| 578 | + if method == "pearson" or method == "spearman": |
| 579 | + if method == "pearson": |
| 580 | + corr_func = pearsonr |
| 581 | + elif method == "spearman": |
| 582 | + corr_func = spearmanr |
| 583 | + |
| 584 | + # Compute per-gene correlation |
| 585 | + corrs = np.array([corr_func(original_x[:, i], imputed_x[:, i])[0] for i in range(imputed_x.shape[1])]) |
587 | 586 |
|
588 | 587 | # Store per-gene correlation in query_imputed.var, only for shared genes |
589 | 588 | self.query_imputed.var[f"metric_{method}"] = None |
590 | 589 | self.query_imputed.var.loc[shared_genes, f"metric_{method}"] = corrs |
| 590 | + |
| 591 | + # Compute average correlation, ignoring NaN values |
591 | 592 | valid_corrs = corrs[~np.isnan(corrs)] |
592 | 593 | if valid_corrs.size == 0: |
593 | 594 | raise ValueError("No valid genes for correlation calculation.") |
594 | 595 | avg_corr = float(np.mean(valid_corrs)) |
| 596 | + |
595 | 597 | # Store metrics in dict |
596 | 598 | self.expression_transfer_metrics = { |
597 | 599 | "method": method, |
|
0 commit comments