Skip to content

Commit 4481e5c

Browse files
authored
Merge pull request #1 from quadbio/refactor/evaluation
Move evaluation to mixin class and add method to score expression transfer.
2 parents c8db19c + 8f3681a commit 4481e5c

File tree

5 files changed

+389
-260
lines changed

5 files changed

+389
-260
lines changed

src/cellmapper/cellmapper.py

Lines changed: 3 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,23 @@
11
"""k-NN based mapping of labels, embeddings, and expression values."""
22

33
import gc
4-
from pathlib import Path
54
from typing import Any, Literal
65

76
import anndata as ad
8-
import matplotlib.pyplot as plt
97
import numpy as np
108
import pandas as pd
119
import scanpy as sc
1210
from anndata import AnnData
13-
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix, issparse
14-
from scipy.stats import pearsonr, spearmanr
15-
from sklearn.metrics import (
16-
ConfusionMatrixDisplay,
17-
accuracy_score,
18-
classification_report,
19-
f1_score,
20-
precision_score,
21-
recall_score,
22-
)
11+
from scipy.sparse import coo_matrix, csc_matrix, csr_matrix
2312
from sklearn.preprocessing import OneHotEncoder
2413

14+
from cellmapper.evaluate import CellMapperEvaluationMixin
2515
from cellmapper.logging import logger
2616

2717
from .knn import Neighbors
2818

2919

30-
class CellMapper:
20+
class CellMapper(CellMapperEvaluationMixin):
3121
"""Mapping of labels, embeddings, and expression values between reference and query datasets."""
3222

3323
def __init__(self, ref: AnnData, query: AnnData) -> None:
@@ -431,186 +421,6 @@ def transfer_expression(self, layer_key: str) -> None:
431421
self.query.n_vars,
432422
)
433423

434-
def evaluate_label_transfer(
    self,
    label_key: str,
    confidence_cutoff: float = 0.0,
    zero_division: int | Literal["warn"] = 0,
) -> None:
    """
    Evaluate label transfer using a k-NN classifier.

    Parameters
    ----------
    label_key
        Key in .obs storing ground-truth cell type annotations.
    confidence_cutoff
        Minimum confidence score required to include a cell in the evaluation.
    zero_division
        How to handle zero divisions in sklearn metrics computation.

    Returns
    -------
    Nothing, but updates the following attributes:
    label_transfer_metrics
        Dictionary containing accuracy, precision, recall, F1 scores, and excluded fraction.
    label_transfer_report
        DataFrame with the per-class sklearn classification report.
    """
    # Check if the labels have been transferred
    if self.prediction_postfix is None or self.confidence_postfix is None:
        raise ValueError("Label transfer has not been performed. Call transfer_labels() first.")

    # Extract ground-truth and predicted labels; cells without ground truth are dropped.
    y_true = self.query.obs[label_key].dropna()
    y_pred = self.query.obs.loc[y_true.index, f"{label_key}_{self.prediction_postfix}"]
    confidence = self.query.obs.loc[y_true.index, f"{label_key}_{self.confidence_postfix}"]

    # Apply confidence cutoff and record the fraction of cells excluded by it.
    valid_indices = confidence >= confidence_cutoff
    y_true = y_true[valid_indices]
    y_pred = y_pred[valid_indices]
    excluded_fraction = 1 - valid_indices.mean()

    # Compute classification metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average="weighted", zero_division=zero_division)
    recall = recall_score(y_true, y_pred, average="weighted", zero_division=zero_division)
    f1_weighted = f1_score(y_true, y_pred, average="weighted", zero_division=zero_division)
    f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=zero_division)

    # Log and store results
    self.label_transfer_metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_weighted": f1_weighted,
        "f1_macro": f1_macro,
        "excluded_fraction": excluded_fraction,
    }
    logger.info(
        "Accuracy: %.4f, Precision: %.4f, Recall: %.4f, Weighted F1-Score: %.4f, Macro F1-Score: %.4f, Excluded Fraction: %.4f",
        accuracy,
        precision,
        recall,
        f1_weighted,
        f1_macro,
        excluded_fraction,
    )

    # Save a detailed per-class classification report as a DataFrame.
    report = classification_report(y_true, y_pred, output_dict=True, zero_division=zero_division)
    self.label_transfer_report = pd.DataFrame(report).transpose()
def plot_confusion_matrix(
    self, label_key: str, figsize=(10, 8), cmap="viridis", save: str | Path | None = None, **kwargs
) -> None:
    """
    Plot the confusion matrix as a heatmap using sklearn's ConfusionMatrixDisplay.

    Parameters
    ----------
    label_key
        Key in .obs storing ground-truth cell type annotations.
    figsize
        Size of the figure (width, height). Default is (10, 8).
    cmap
        Colormap to use for the heatmap. Default is "viridis".
    save
        Optional path to save the figure to. If None, the figure is not saved.
    **kwargs
        Additional keyword arguments to pass to ConfusionMatrixDisplay.
    """
    # Check if the labels have been transferred
    if self.prediction_postfix is None or self.confidence_postfix is None:
        raise ValueError("Label transfer has not been performed. Call transfer_labels() first.")

    # Extract true and predicted labels. Use the configured prediction postfix
    # rather than a hard-coded "_pred" suffix, so this stays consistent with
    # evaluate_label_transfer() and works for any postfix set by transfer_labels().
    y_true = self.query.obs[label_key].dropna()
    y_pred = self.query.obs.loc[y_true.index, f"{label_key}_{self.prediction_postfix}"]

    # Plot confusion matrix using sklearn's ConfusionMatrixDisplay
    _, ax = plt.subplots(1, 1, figsize=figsize)
    ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, xticks_rotation="vertical", ax=ax, **kwargs)
    plt.title("Confusion Matrix")

    if save:
        plt.savefig(save, bbox_inches="tight")
def evaluate_expression_transfer(
    self,
    layer_key: str = "X",
    method: Literal["pearson", "spearman"] = "pearson",
) -> None:
    """
    Evaluate the agreement between imputed and original expression in the query dataset.

    Parameters
    ----------
    layer_key
        Key in `self.query.layers` to use as the original expression. Use "X" to use `self.query.X`.
    method
        Method to quantify agreement: "pearson" or "spearman" (average per-gene
        correlation across shared genes).

    Returns
    -------
    Nothing, but updates the following attributes:
    expression_transfer_metrics
        Dictionary containing the average correlation and number of genes used for the evaluation.

    Raises
    ------
    ValueError
        If imputation has not been run, no genes are shared, or no gene yields a
        valid (non-NaN) correlation.
    NotImplementedError
        If `method` is not one of the supported correlation methods.
    """
    if self.query_imputed is None:
        raise ValueError("Imputed query data not found. Run transfer_expression() first.")

    # Fail fast on unsupported methods, before any subsetting/densification work.
    corr_funcs = {"pearson": pearsonr, "spearman": spearmanr}
    if method not in corr_funcs:
        raise NotImplementedError(f"Method '{method}' is not implemented.")
    corr_func = corr_funcs[method]

    # Find shared genes
    shared_genes = list(self.query_imputed.var_names.intersection(self.query.var_names))
    if len(shared_genes) == 0:
        raise ValueError("No shared genes between query_imputed and query.")

    # Subset both matrices to the shared genes, in the same gene order.
    imputed_x = self.query_imputed[:, shared_genes].X
    if layer_key == "X":
        original_x = self.query[:, shared_genes].X
    else:
        original_x = self.query[:, shared_genes].layers[layer_key]

    # Convert to dense if sparse, since per-column correlations need dense vectors.
    if issparse(imputed_x):
        imputed_x = imputed_x.toarray()
    if issparse(original_x):
        original_x = original_x.toarray()

    # Compute per-gene correlation between original and imputed expression.
    corrs = np.array([corr_func(original_x[:, i], imputed_x[:, i])[0] for i in range(imputed_x.shape[1])])

    # Store per-gene correlation in query_imputed.var, only for shared genes
    self.query_imputed.var[f"metric_{method}"] = None
    self.query_imputed.var.loc[shared_genes, f"metric_{method}"] = corrs

    # Compute average correlation, ignoring NaN values (e.g. constant genes).
    valid_corrs = corrs[~np.isnan(corrs)]
    if valid_corrs.size == 0:
        raise ValueError("No valid genes for correlation calculation.")
    avg_corr = float(np.mean(valid_corrs))

    # Store metrics in dict
    self.expression_transfer_metrics = {
        "method": method,
        "average_correlation": avg_corr,
        "n_genes": len(shared_genes),
        "n_valid_genes": int(valid_corrs.size),
    }
    logger.info(
        "Expression transfer evaluation (%s): average correlation = %.4f (n_genes=%d, n_valid_genes=%d)",
        method,
        avg_corr,
        len(shared_genes),
        int(valid_corrs.size),
    )
614424
def fit(
615425
self,
616426
obs_keys: str | list[str] | None = None,

0 commit comments

Comments
 (0)