|
1 | 1 | """k-NN based mapping of labels, embeddings, and expression values.""" |
2 | 2 |
|
3 | 3 | import gc |
4 | | -from pathlib import Path |
5 | 4 | from typing import Any, Literal |
6 | 5 |
|
7 | 6 | import anndata as ad |
8 | | -import matplotlib.pyplot as plt |
9 | 7 | import numpy as np |
10 | 8 | import pandas as pd |
11 | 9 | import scanpy as sc |
12 | 10 | from anndata import AnnData |
13 | | -from scipy.sparse import coo_matrix, csc_matrix, csr_matrix, issparse |
14 | | -from scipy.stats import pearsonr, spearmanr |
15 | | -from sklearn.metrics import ( |
16 | | - ConfusionMatrixDisplay, |
17 | | - accuracy_score, |
18 | | - classification_report, |
19 | | - f1_score, |
20 | | - precision_score, |
21 | | - recall_score, |
22 | | -) |
| 11 | +from scipy.sparse import coo_matrix, csc_matrix, csr_matrix |
23 | 12 | from sklearn.preprocessing import OneHotEncoder |
24 | 13 |
|
| 14 | +from cellmapper.evaluate import CellMapperEvaluationMixin |
25 | 15 | from cellmapper.logging import logger |
26 | 16 |
|
27 | 17 | from .knn import Neighbors |
28 | 18 |
|
29 | 19 |
|
30 | | -class CellMapper: |
| 20 | +class CellMapper(CellMapperEvaluationMixin): |
31 | 21 | """Mapping of labels, embeddings, and expression values between reference and query datasets.""" |
32 | 22 |
|
33 | 23 | def __init__(self, ref: AnnData, query: AnnData) -> None: |
@@ -431,186 +421,6 @@ def transfer_expression(self, layer_key: str) -> None: |
431 | 421 | self.query.n_vars, |
432 | 422 | ) |
433 | 423 |
|
434 | | - def evaluate_label_transfer( |
435 | | - self, |
436 | | - label_key: str, |
437 | | - confidence_cutoff: float = 0.0, |
438 | | - zero_division: int | Literal["warn"] = 0, |
439 | | - ) -> None: |
440 | | - """ |
441 | | - Evaluate label transfer using a k-NN classifier. |
442 | | -
|
443 | | - Parameters |
444 | | - ---------- |
445 | | - label_key |
446 | | - Key in .obs storing ground-truth cell type annotations. |
447 | | - confidence_cutoff |
448 | | - Minimum confidence score required to include a cell in the evaluation. |
449 | | - zero_divisions |
450 | | - How to handle zero divisions in sklearn metrics comptuation. |
451 | | -
|
452 | | - Returns |
453 | | - ------- |
454 | | - Nothing, but updates the following attributes: |
455 | | - label_transfer_metrics |
456 | | - Dictionary containing accuracy, precision, recall, F1 scores, and excluded fraction. |
457 | | - """ |
458 | | - # Check if the labels have been transferred |
459 | | - if self.prediction_postfix is None or self.confidence_postfix is None: |
460 | | - raise ValueError("Label transfer has not been performed. Call transfer_labels() first.") |
461 | | - |
462 | | - # Extract ground-truth and predicted labels |
463 | | - y_true = self.query.obs[label_key].dropna() |
464 | | - y_pred = self.query.obs.loc[y_true.index, f"{label_key}_{self.prediction_postfix}"] |
465 | | - confidence = self.query.obs.loc[y_true.index, f"{label_key}_{self.confidence_postfix}"] |
466 | | - |
467 | | - # Apply confidence cutoff |
468 | | - valid_indices = confidence >= confidence_cutoff |
469 | | - y_true = y_true[valid_indices] |
470 | | - y_pred = y_pred[valid_indices] |
471 | | - excluded_fraction = 1 - valid_indices.mean() |
472 | | - |
473 | | - # Compute classification metrics |
474 | | - accuracy = accuracy_score(y_true, y_pred) |
475 | | - precision = precision_score(y_true, y_pred, average="weighted", zero_division=zero_division) |
476 | | - recall = recall_score(y_true, y_pred, average="weighted", zero_division=zero_division) |
477 | | - f1_weighted = f1_score(y_true, y_pred, average="weighted", zero_division=zero_division) |
478 | | - f1_macro = f1_score(y_true, y_pred, average="macro", zero_division=zero_division) |
479 | | - |
480 | | - # Log and store results |
481 | | - self.label_transfer_metrics = { |
482 | | - "accuracy": accuracy, |
483 | | - "precision": precision, |
484 | | - "recall": recall, |
485 | | - "f1_weighted": f1_weighted, |
486 | | - "f1_macro": f1_macro, |
487 | | - "excluded_fraction": excluded_fraction, |
488 | | - } |
489 | | - logger.info( |
490 | | - "Accuracy: %.4f, Precision: %.4f, Recall: %.4f, Weighted F1-Score: %.4f, Macro F1-Score: %.4f, Excluded Fraction: %.4f", |
491 | | - accuracy, |
492 | | - precision, |
493 | | - recall, |
494 | | - f1_weighted, |
495 | | - f1_macro, |
496 | | - excluded_fraction, |
497 | | - ) |
498 | | - |
499 | | - # Optional: Save a detailed classification report |
500 | | - report = classification_report(y_true, y_pred, output_dict=True, zero_division=zero_division) |
501 | | - self.label_transfer_report = pd.DataFrame(report).transpose() |
502 | | - |
def plot_confusion_matrix(
    self, label_key: str, figsize=(10, 8), cmap="viridis", save: str | Path | None = None, **kwargs
) -> None:
    """
    Plot the confusion matrix as a heatmap using sklearn's ConfusionMatrixDisplay.

    Parameters
    ----------
    label_key
        Key in .obs storing ground-truth cell type annotations.
    figsize
        Size of the figure (width, height). Default is (10, 8).
    cmap
        Colormap to use for the heatmap. Default is "viridis".
    save
        Optional path; if given, the figure is written there with a tight bounding box.
    **kwargs
        Additional keyword arguments to pass to ConfusionMatrixDisplay.from_predictions.

    Raises
    ------
    ValueError
        If labels have not been transferred yet.
    """
    # Check if the labels have been transferred
    if self.prediction_postfix is None or self.confidence_postfix is None:
        raise ValueError("Label transfer has not been performed. Call transfer_labels() first.")

    # Extract true and predicted labels. The prediction column is built from the
    # configured postfix (not a hard-coded "_pred") so it matches the column
    # naming used by evaluate_label_transfer.
    y_true = self.query.obs[label_key].dropna()
    y_pred = self.query.obs.loc[y_true.index, f"{label_key}_{self.prediction_postfix}"]

    # Draw on an explicit figure/axes pair so the title and the saved file refer
    # to this plot rather than to pyplot's implicit current figure.
    fig, ax = plt.subplots(1, 1, figsize=figsize)
    ConfusionMatrixDisplay.from_predictions(y_true, y_pred, cmap=cmap, xticks_rotation="vertical", ax=ax, **kwargs)
    ax.set_title("Confusion Matrix")

    if save:
        fig.savefig(save, bbox_inches="tight")
535 | | - |
def evaluate_expression_transfer(
    self,
    layer_key: str = "X",
    method: Literal["pearson", "spearman"] = "pearson",
) -> None:
    """
    Evaluate the agreement between imputed and original expression in the query dataset.

    Parameters
    ----------
    layer_key
        Key in `self.query.layers` to use as the original expression. Use "X" to use `self.query.X`.
    method
        Method to quantify agreement. Supported: "pearson" and "spearman"
        (per-gene correlation across cells, averaged over shared genes).

    Returns
    -------
    Nothing, but updates the following attributes:
    expression_transfer_metrics
        Dictionary containing the method, the average correlation, the number of shared
        genes, and the number of genes with a valid (non-NaN) correlation.
    query_imputed.var["metric_<method>"]
        Per-gene correlation for shared genes; NaN for all other genes.

    Raises
    ------
    ValueError
        If no imputed data is available, no genes are shared, or every per-gene
        correlation is NaN.
    NotImplementedError
        If `method` is not one of the supported methods.
    """
    if self.query_imputed is None:
        raise ValueError("Imputed query data not found. Run transfer_expression() first.")

    # Find shared genes
    shared_genes = list(self.query_imputed.var_names.intersection(self.query.var_names))
    if not shared_genes:
        raise ValueError("No shared genes between query_imputed and query.")

    # Resolve the correlation function before slicing/densifying, so an
    # unsupported method fails without doing the expensive work first.
    corr_funcs = {"pearson": pearsonr, "spearman": spearmanr}
    if method not in corr_funcs:
        raise NotImplementedError(f"Method '{method}' is not implemented.")
    corr_func = corr_funcs[method]

    # Subset both matrices to the shared genes, in the same gene order.
    imputed_x = self.query_imputed[:, shared_genes].X
    if layer_key == "X":
        original_x = self.query[:, shared_genes].X
    else:
        original_x = self.query[:, shared_genes].layers[layer_key]

    # Convert to dense if sparse (column-wise access below needs plain arrays).
    if issparse(imputed_x):
        imputed_x = imputed_x.toarray()
    if issparse(original_x):
        original_x = original_x.toarray()

    # Compute per-gene correlation; index 0 of the result is the statistic.
    corrs = np.array([corr_func(original_x[:, i], imputed_x[:, i])[0] for i in range(imputed_x.shape[1])])

    # Store per-gene correlation in query_imputed.var, only for shared genes.
    # Initialise with NaN (not None) so the column stays numeric.
    metric_key = f"metric_{method}"
    self.query_imputed.var[metric_key] = np.nan
    self.query_imputed.var.loc[shared_genes, metric_key] = corrs

    # Compute average correlation, ignoring NaN values
    valid_corrs = corrs[~np.isnan(corrs)]
    if valid_corrs.size == 0:
        raise ValueError("No valid genes for correlation calculation.")
    avg_corr = float(np.mean(valid_corrs))

    # Store metrics in dict
    self.expression_transfer_metrics = {
        "method": method,
        "average_correlation": avg_corr,
        "n_genes": len(shared_genes),
        "n_valid_genes": int(valid_corrs.size),
    }
    logger.info(
        "Expression transfer evaluation (%s): average correlation = %.4f (n_genes=%d, n_valid_genes=%d)",
        method,
        avg_corr,
        len(shared_genes),
        int(valid_corrs.size),
    )
613 | | - |
614 | 424 | def fit( |
615 | 425 | self, |
616 | 426 | obs_keys: str | list[str] | None = None, |
|
0 commit comments