Skip to content

Commit 7d81afc

Browse files
committed
update kernel and tests
1 parent 8550671 commit 7d81afc

File tree

7 files changed

+1246
-599
lines changed

7 files changed

+1246
-599
lines changed

src/rapids_singlecell/pertpy_gpu/_distance.py

Lines changed: 18 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,8 @@
33
from typing import TYPE_CHECKING, Literal, NamedTuple
44

55
if TYPE_CHECKING:
6+
from collections.abc import Sequence
7+
68
import cupy as cp
79
import numpy as np
810
import pandas as pd
@@ -20,50 +22,33 @@ class Distance:
2022
"""
2123
GPU-accelerated distance computation between groups of cells.
2224
23-
This class provides an extensible framework for computing various distance metrics
24-
between cell groups in single-cell data, with GPU acceleration via CuPy.
25-
26-
The API is designed to be compatible with pertpy's Distance class.
25+
API compatible with pertpy's Distance class.
2726
2827
Parameters
2928
----------
3029
metric : str
31-
Distance metric to use. Currently supported:
32-
- 'edistance': Energy distance (GPU-accelerated)
30+
Distance metric. Currently supported: 'edistance' (energy distance).
3331
layer_key : str | None
34-
Name of the counts layer containing raw counts to calculate distances for.
35-
Mutually exclusive with 'obsm_key'. If None, the parameter is ignored.
32+
Key in adata.layers for cell data. Mutually exclusive with 'obsm_key'.
3633
obsm_key : str | None
3734
Key in adata.obsm for embeddings. Mutually exclusive with 'layer_key'.
38-
Defaults to 'X_pca' if neither layer_key nor obsm_key is specified.
39-
kernel : str
40-
Kernel strategy: 'auto' or 'manual'.
41-
- 'auto': Dynamically choose optimal blocks_per_pair (default)
42-
- 'manual': Use the specified blocks_per_pair directly
43-
blocks_per_pair : int
44-
Number of blocks per pair (default: 32). For 'auto', this is the maximum.
45-
Higher values increase parallelism but add atomic overhead.
35+
Defaults to 'X_pca' if neither is specified.
4636
4737
Examples
4838
--------
4939
>>> import rapids_singlecell as rsc
5040
>>> distance = rsc.ptg.Distance(metric='edistance')
5141
>>> result = distance.pairwise(adata, groupby='perturbation')
52-
>>> print(result.distances)
5342
54-
# Direct computation on arrays (pertpy-compatible API)
55-
>>> X = adata.obsm["X_pca"][adata.obs["group"] == "A"]
56-
>>> Y = adata.obsm["X_pca"][adata.obs["group"] == "B"]
57-
>>> d = distance(X, Y) # Returns energy distance as float
43+
>>> # Direct computation on arrays
44+
>>> d = distance(X, Y)
5845
"""
5946

6047
def __init__(
6148
self,
6249
metric: Literal["edistance"] = "edistance",
6350
layer_key: str | None = None,
6451
obsm_key: str | None = None,
65-
kernel: Literal["auto", "manual"] = "auto",
66-
blocks_per_pair: int = 32,
6752
):
6853
"""Initialize Distance calculator with specified metric."""
6954
if layer_key and obsm_key:
@@ -77,8 +62,6 @@ def __init__(
7762
self.metric = metric
7863
self.layer_key = layer_key
7964
self.obsm_key = obsm_key
80-
self.kernel = kernel
81-
self.blocks_per_pair = blocks_per_pair
8265
self._metric_impl = None
8366
self._initialize_metric()
8467

@@ -92,8 +75,6 @@ def _initialize_metric(self):
9275
self._metric_impl = EDistanceMetric(
9376
layer_key=self.layer_key,
9477
obsm_key=self.obsm_key,
95-
kernel=self.kernel,
96-
blocks_per_pair=self.blocks_per_pair,
9778
)
9879
else:
9980
raise ValueError(
@@ -140,11 +121,10 @@ def pairwise(
140121
adata: AnnData,
141122
groupby: str,
142123
*,
143-
groups: list[str] | None = None,
124+
groups: Sequence[str] | None = None,
144125
bootstrap: bool = False,
145126
n_bootstrap: int = 100,
146127
random_state: int = 0,
147-
inplace: bool = False,
148128
):
149129
"""
150130
Compute pairwise distances between all cell groups.
@@ -155,28 +135,25 @@ def pairwise(
155135
Annotated data matrix
156136
groupby : str
157137
Key in adata.obs for grouping cells
158-
groups : list[str] | None
138+
groups : Sequence[str] | None
159139
Specific groups to compute (if None, use all)
160140
bootstrap : bool
161141
Whether to compute bootstrap variance estimates
162142
n_bootstrap : int
163143
Number of bootstrap iterations (if bootstrap=True)
164144
random_state : int
165145
Random seed for reproducibility
166-
inplace : bool
167-
Whether to store results in adata.uns
168146
169147
Returns
170148
-------
171149
result
172-
Result object containing distances and optional variance DataFrames.
173-
The exact type depends on the metric used.
150+
DataFrame with pairwise distances. If bootstrap=True, returns
151+
tuple of (distances, distances_var) DataFrames.
174152
175153
Examples
176154
--------
177155
>>> distance = Distance(metric='edistance')
178156
>>> result = distance.pairwise(adata, groupby='condition')
179-
>>> print(result.distances)
180157
"""
181158
return self._metric_impl.pairwise(
182159
adata=adata,
@@ -185,7 +162,6 @@ def pairwise(
185162
bootstrap=bootstrap,
186163
n_bootstrap=n_bootstrap,
187164
random_state=random_state,
188-
inplace=inplace,
189165
)
190166

191167
def onesided_distances(
@@ -194,11 +170,11 @@ def onesided_distances(
194170
groupby: str,
195171
selected_group: str,
196172
*,
197-
groups: list[str] | None = None,
173+
groups: Sequence[str] | None = None,
198174
bootstrap: bool = False,
199175
n_bootstrap: int = 100,
200176
random_state: int = 0,
201-
) -> pd.Series:
177+
) -> pd.Series | tuple[pd.Series, pd.Series]:
202178
"""
203179
Compute distances from one selected group to all other groups.
204180
@@ -210,7 +186,7 @@ def onesided_distances(
210186
Key in adata.obs for grouping cells
211187
selected_group : str
212188
Reference group to compute distances from
213-
groups : list[str] | None
189+
groups : Sequence[str] | None
214190
Specific groups to compute distances to (if None, use all)
215191
bootstrap : bool
216192
Whether to compute bootstrap variance estimates
@@ -221,16 +197,16 @@ def onesided_distances(
221197
222198
Returns
223199
-------
224-
distances : pd.Series
225-
Series containing distances from selected_group to all other groups
200+
distances : pd.Series | tuple[pd.Series, pd.Series]
201+
Series containing distances from selected_group to all other groups.
202+
If bootstrap=True, returns tuple of (distances, distances_var).
226203
227204
Examples
228205
--------
229206
>>> distance = Distance(metric='edistance')
230207
>>> distances = distance.onesided_distances(
231208
... adata, groupby='condition', selected_group='control'
232209
... )
233-
>>> print(distances)
234210
"""
235211
if not hasattr(self._metric_impl, "onesided_distances"):
236212
raise NotImplementedError(

src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py

Lines changed: 7 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,8 @@
44
from typing import TYPE_CHECKING
55

66
if TYPE_CHECKING:
7+
from collections.abc import Sequence
8+
79
from anndata import AnnData
810

911

@@ -30,58 +32,50 @@ def pairwise(
3032
adata: AnnData,
3133
groupby: str,
3234
*,
33-
groups: list[str] | None = None,
35+
groups: Sequence[str] | None = None,
3436
bootstrap: bool = False,
3537
n_bootstrap: int = 100,
3638
random_state: int = 0,
37-
inplace: bool = False,
3839
):
3940
"""
4041
Compute pairwise distances between all cell groups.
4142
42-
This method must be implemented by all metric subclasses.
43-
4443
Parameters
4544
----------
4645
adata : AnnData
4746
Annotated data matrix
4847
groupby : str
4948
Key in adata.obs for grouping cells
50-
groups : list[str] | None
49+
groups : Sequence[str] | None
5150
Specific groups to compute (if None, use all)
5251
bootstrap : bool
5352
Whether to compute bootstrap variance estimates
5453
n_bootstrap : int
5554
Number of bootstrap iterations (if bootstrap=True)
5655
random_state : int
5756
Random seed for reproducibility
58-
inplace : bool
59-
Whether to store results in adata.uns
6057
6158
Returns
6259
-------
6360
result
6461
Result object containing distances and optional variance information.
65-
The exact type depends on the specific metric implementation.
6662
"""
67-
pass
63+
...
6864

6965
def onesided_distances(
7066
self,
7167
adata: AnnData,
7268
groupby: str,
7369
selected_group: str,
7470
*,
75-
groups: list[str] | None = None,
71+
groups: Sequence[str] | None = None,
7672
bootstrap: bool = False,
7773
n_bootstrap: int = 100,
7874
random_state: int = 0,
7975
):
8076
"""
8177
Compute distances from one selected group to all other groups.
8278
83-
This method is optional and may not be implemented by all metrics.
84-
8579
Parameters
8680
----------
8781
adata : AnnData
@@ -90,7 +84,7 @@ def onesided_distances(
9084
Key in adata.obs for grouping cells
9185
selected_group : str
9286
Reference group to compute distances from
93-
groups : list[str] | None
87+
groups : Sequence[str] | None
9488
Specific groups to compute distances to (if None, use all)
9589
bootstrap : bool
9690
Whether to compute bootstrap variance estimates
@@ -121,8 +115,6 @@ def bootstrap(
121115
"""
122116
Compute bootstrap mean and variance for distance between two specific groups.
123117
124-
This method is optional and may not be implemented by all metrics.
125-
126118
Parameters
127119
----------
128120
adata : AnnData

0 commit comments

Comments
 (0)