Skip to content

Commit 8ac80db

Browse files
committed
update
1 parent 1177340 commit 8ac80db

File tree

15 files changed

+1159
-2150
lines changed

15 files changed

+1159
-2150
lines changed
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from __future__ import annotations
22

3-
from ._edistance import pertpy_edistance
3+
from ._distance import Distance
4+
from ._metrics._edistance_metric import EDistanceResult
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Literal
4+
5+
if TYPE_CHECKING:
6+
import pandas as pd
7+
from anndata import AnnData
8+
9+
10+
class Distance:
11+
"""
12+
GPU-accelerated distance computation between groups of cells.
13+
14+
This class provides an extensible framework for computing various distance metrics
15+
between cell groups in single-cell data, with GPU acceleration via CuPy.
16+
17+
Parameters
18+
----------
19+
metric : str
20+
Distance metric to use. Currently supported:
21+
- 'edistance': Energy distance (GPU-accelerated)
22+
obsm_key : str
23+
Key in adata.obsm for embeddings (default: 'X_pca')
24+
25+
Examples
26+
--------
27+
>>> import rapids_singlecell as rsc
28+
>>> distance = rsc.ptg.Distance(metric='edistance')
29+
>>> result = distance.pairwise(adata, groupby='perturbation')
30+
>>> print(result.distances)
31+
"""
32+
33+
def __init__(
34+
self,
35+
metric: Literal["edistance"] = "edistance",
36+
obsm_key: str = "X_pca",
37+
):
38+
"""Initialize Distance calculator with specified metric."""
39+
self.metric = metric
40+
self.obsm_key = obsm_key
41+
self._metric_impl = None
42+
self._initialize_metric()
43+
44+
def _initialize_metric(self):
45+
"""Initialize the metric implementation based on the metric type."""
46+
if self.metric == "edistance":
47+
from rapids_singlecell.pertpy_gpu._metrics._edistance_metric import (
48+
EDistanceMetric,
49+
)
50+
51+
self._metric_impl = EDistanceMetric(obsm_key=self.obsm_key)
52+
else:
53+
raise ValueError(
54+
f"Unknown metric: {self.metric}. Supported metrics: ['edistance']"
55+
)
56+
57+
def pairwise(
58+
self,
59+
adata: AnnData,
60+
groupby: str,
61+
*,
62+
groups: list[str] | None = None,
63+
bootstrap: bool = False,
64+
n_bootstrap: int = 100,
65+
random_state: int = 0,
66+
inplace: bool = False,
67+
):
68+
"""
69+
Compute pairwise distances between all cell groups.
70+
71+
Parameters
72+
----------
73+
adata : AnnData
74+
Annotated data matrix
75+
groupby : str
76+
Key in adata.obs for grouping cells
77+
groups : list[str] | None
78+
Specific groups to compute (if None, use all)
79+
bootstrap : bool
80+
Whether to compute bootstrap variance estimates
81+
n_bootstrap : int
82+
Number of bootstrap iterations (if bootstrap=True)
83+
random_state : int
84+
Random seed for reproducibility
85+
inplace : bool
86+
Whether to store results in adata.uns
87+
88+
Returns
89+
-------
90+
result
91+
Result object containing distances and optional variance DataFrames.
92+
The exact type depends on the metric used.
93+
94+
Examples
95+
--------
96+
>>> distance = Distance(metric='edistance')
97+
>>> result = distance.pairwise(adata, groupby='condition')
98+
>>> print(result.distances)
99+
"""
100+
return self._metric_impl.pairwise(
101+
adata=adata,
102+
groupby=groupby,
103+
groups=groups,
104+
bootstrap=bootstrap,
105+
n_bootstrap=n_bootstrap,
106+
random_state=random_state,
107+
inplace=inplace,
108+
)
109+
110+
def onesided_distances(
111+
self,
112+
adata: AnnData,
113+
groupby: str,
114+
selected_group: str,
115+
*,
116+
groups: list[str] | None = None,
117+
bootstrap: bool = False,
118+
n_bootstrap: int = 100,
119+
random_state: int = 0,
120+
) -> pd.Series:
121+
"""
122+
Compute distances from one selected group to all other groups.
123+
124+
Parameters
125+
----------
126+
adata : AnnData
127+
Annotated data matrix
128+
groupby : str
129+
Key in adata.obs for grouping cells
130+
selected_group : str
131+
Reference group to compute distances from
132+
groups : list[str] | None
133+
Specific groups to compute distances to (if None, use all)
134+
bootstrap : bool
135+
Whether to compute bootstrap variance estimates
136+
n_bootstrap : int
137+
Number of bootstrap iterations (if bootstrap=True)
138+
random_state : int
139+
Random seed for reproducibility
140+
141+
Returns
142+
-------
143+
distances : pd.Series
144+
Series containing distances from selected_group to all other groups
145+
146+
Examples
147+
--------
148+
>>> distance = Distance(metric='edistance')
149+
>>> distances = distance.onesided_distances(
150+
... adata, groupby='condition', selected_group='control'
151+
... )
152+
>>> print(distances)
153+
"""
154+
if not hasattr(self._metric_impl, "onesided_distances"):
155+
raise NotImplementedError(
156+
f"Metric '{self.metric}' does not support onesided_distances"
157+
)
158+
return self._metric_impl.onesided_distances(
159+
adata=adata,
160+
groupby=groupby,
161+
selected_group=selected_group,
162+
groups=groups,
163+
bootstrap=bootstrap,
164+
n_bootstrap=n_bootstrap,
165+
random_state=random_state,
166+
)
167+
168+
def bootstrap(
169+
self,
170+
adata: AnnData,
171+
groupby: str,
172+
group_a: str,
173+
group_b: str,
174+
*,
175+
n_bootstrap: int = 100,
176+
random_state: int = 0,
177+
) -> tuple[float, float]:
178+
"""
179+
Compute bootstrap mean and variance for distance between two specific groups.
180+
181+
Parameters
182+
----------
183+
adata : AnnData
184+
Annotated data matrix
185+
groupby : str
186+
Key in adata.obs for grouping cells
187+
group_a : str
188+
First group name
189+
group_b : str
190+
Second group name
191+
n_bootstrap : int
192+
Number of bootstrap iterations
193+
random_state : int
194+
Random seed for reproducibility
195+
196+
Returns
197+
-------
198+
mean : float
199+
Bootstrap mean distance
200+
variance : float
201+
Bootstrap variance
202+
203+
Examples
204+
--------
205+
>>> distance = Distance(metric='edistance')
206+
>>> mean, var = distance.bootstrap(
207+
... adata, groupby='condition', group_a='treated', group_b='control'
208+
... )
209+
>>> print(f"Distance: {mean:.3f} ± {var**0.5:.3f}")
210+
"""
211+
if not hasattr(self._metric_impl, "bootstrap"):
212+
raise NotImplementedError(
213+
f"Metric '{self.metric}' does not support bootstrap"
214+
)
215+
return self._metric_impl.bootstrap(
216+
adata=adata,
217+
groupby=groupby,
218+
group_a=group_a,
219+
group_b=group_b,
220+
n_bootstrap=n_bootstrap,
221+
random_state=random_state,
222+
)
223+
224+
def __repr__(self) -> str:
225+
"""String representation of Distance object."""
226+
return f"Distance(metric='{self.metric}', obsm_key='{self.obsm_key}')"

0 commit comments

Comments
 (0)