Skip to content

Commit 7044d72

Browse files
committed
fix subsetting
1 parent c9540c9 commit 7044d72

File tree

2 files changed

+142
-104
lines changed

2 files changed

+142
-104
lines changed

src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from abc import ABC, abstractmethod
44
from typing import TYPE_CHECKING
55

6+
import cupy as cp
7+
68
from rapids_singlecell._utils import parse_device_ids
79

810
if TYPE_CHECKING:
@@ -23,6 +25,8 @@ class BaseMetric(ABC):
2325
2426
Parameters
2527
----------
28+
layer_key
29+
Key in adata.layers for cell data. Mutually exclusive with obsm_key.
2630
obsm_key
2731
Key in adata.obsm for embeddings (default: 'X_pca')
2832
@@ -35,10 +39,34 @@ class BaseMetric(ABC):
3539

3640
supports_multi_gpu: bool = False
3741

38-
def __init__(self, obsm_key: str = "X_pca"):
39-
"""Initialize base metric with obsm_key."""
42+
def __init__(
43+
self,
44+
layer_key: str | None = None,
45+
obsm_key: str | None = "X_pca",
46+
):
47+
"""Initialize base metric."""
48+
if layer_key is not None and obsm_key is not None:
49+
raise ValueError(
50+
"Cannot use 'layer_key' and 'obsm_key' at the same time. "
51+
"Please provide only one of the two keys."
52+
)
53+
self.layer_key = layer_key
4054
self.obsm_key = obsm_key
4155

56+
def _get_embedding(self, adata: AnnData) -> cp.ndarray:
57+
"""Get embedding from adata using layer_key or obsm_key.
58+
59+
Preserves the input dtype (float32 or float64) for precision control.
60+
"""
61+
if self.layer_key is not None:
62+
data = adata.layers[self.layer_key]
63+
else:
64+
data = adata.obsm[self.obsm_key]
65+
66+
if isinstance(data, cp.ndarray):
67+
return data
68+
return cp.asarray(data)
69+
4270
@abstractmethod
4371
def pairwise(
4472
self,

0 commit comments

Comments
 (0)