33from abc import ABC , abstractmethod
44from typing import TYPE_CHECKING
55
6+ import cupy as cp
7+
68from rapids_singlecell ._utils import parse_device_ids
79
810if TYPE_CHECKING :
@@ -23,6 +25,8 @@ class BaseMetric(ABC):
2325
2426 Parameters
2527 ----------
28+ layer_key
29+ Key in adata.layers for cell data. Mutually exclusive with obsm_key.
2630 obsm_key
2731 Key in adata.obsm for embeddings (default: 'X_pca')
2832
@@ -35,10 +39,34 @@ class BaseMetric(ABC):
3539
3640 supports_multi_gpu : bool = False
3741
38- def __init__ (self , obsm_key : str = "X_pca" ):
39- """Initialize base metric with obsm_key."""
42+ def __init__ (
43+ self ,
44+ layer_key : str | None = None ,
45+ obsm_key : str | None = "X_pca" ,
46+ ):
47+ """Initialize base metric."""
48+ if layer_key is not None and obsm_key is not None :
49+ raise ValueError (
50+ "Cannot use 'layer_key' and 'obsm_key' at the same time. "
51+ "Please provide only one of the two keys."
52+ )
53+ self .layer_key = layer_key
4054 self .obsm_key = obsm_key
4155
56+ def _get_embedding (self , adata : AnnData ) -> cp .ndarray :
57+ """Get embedding from adata using layer_key or obsm_key.
58+
59+ Preserves the input dtype (float32 or float64) for precision control.
60+ """
61+ if self .layer_key is not None :
62+ data = adata .layers [self .layer_key ]
63+ else :
64+ data = adata .obsm [self .obsm_key ]
65+
66+ if isinstance (data , cp .ndarray ):
67+ return data
68+ return cp .asarray (data )
69+
4270 @abstractmethod
4371 def pairwise (
4472 self ,
0 commit comments