33from typing import TYPE_CHECKING , Literal , NamedTuple
44
55if TYPE_CHECKING :
6+ from collections .abc import Sequence
7+
68 import cupy as cp
79 import numpy as np
810 import pandas as pd
@@ -20,50 +22,33 @@ class Distance:
2022 """
2123 GPU-accelerated distance computation between groups of cells.
2224
23- This class provides an extensible framework for computing various distance metrics
24- between cell groups in single-cell data, with GPU acceleration via CuPy.
25-
26- The API is designed to be compatible with pertpy's Distance class.
25+ API compatible with pertpy's Distance class.
2726
2827 Parameters
2928 ----------
3029 metric : str
31- Distance metric to use. Currently supported:
32- - 'edistance': Energy distance (GPU-accelerated)
30+ Distance metric. Currently supported: 'edistance' (energy distance).
3331 layer_key : str | None
34- Name of the counts layer containing raw counts to calculate distances for.
35- Mutually exclusive with 'obsm_key'. If None, the parameter is ignored.
32+ Key in adata.layers for cell data. Mutually exclusive with 'obsm_key'.
3633 obsm_key : str | None
3734 Key in adata.obsm for embeddings. Mutually exclusive with 'layer_key'.
38- Defaults to 'X_pca' if neither layer_key nor obsm_key is specified.
39- kernel : str
40- Kernel strategy: 'auto' or 'manual'.
41- - 'auto': Dynamically choose optimal blocks_per_pair (default)
42- - 'manual': Use the specified blocks_per_pair directly
43- blocks_per_pair : int
44- Number of blocks per pair (default: 32). For 'auto', this is the maximum.
45- Higher values increase parallelism but add atomic overhead.
35+ Defaults to 'X_pca' if neither is specified.
4636
4737 Examples
4838 --------
4939 >>> import rapids_singlecell as rsc
5040 >>> distance = rsc.ptg.Distance(metric='edistance')
5141 >>> result = distance.pairwise(adata, groupby='perturbation')
52- >>> print(result.distances)
5342
54- # Direct computation on arrays (pertpy-compatible API)
55- >>> X = adata.obsm["X_pca"][adata.obs["group"] == "A"]
56- >>> Y = adata.obsm["X_pca"][adata.obs["group"] == "B"]
57- >>> d = distance(X, Y) # Returns energy distance as float
43+ >>> # Direct computation on arrays
44+ >>> d = distance(X, Y)
5845 """
5946
6047 def __init__ (
6148 self ,
6249 metric : Literal ["edistance" ] = "edistance" ,
6350 layer_key : str | None = None ,
6451 obsm_key : str | None = None ,
65- kernel : Literal ["auto" , "manual" ] = "auto" ,
66- blocks_per_pair : int = 32 ,
6752 ):
6853 """Initialize Distance calculator with specified metric."""
6954 if layer_key and obsm_key :
@@ -77,8 +62,6 @@ def __init__(
7762 self .metric = metric
7863 self .layer_key = layer_key
7964 self .obsm_key = obsm_key
80- self .kernel = kernel
81- self .blocks_per_pair = blocks_per_pair
8265 self ._metric_impl = None
8366 self ._initialize_metric ()
8467
@@ -92,8 +75,6 @@ def _initialize_metric(self):
9275 self ._metric_impl = EDistanceMetric (
9376 layer_key = self .layer_key ,
9477 obsm_key = self .obsm_key ,
95- kernel = self .kernel ,
96- blocks_per_pair = self .blocks_per_pair ,
9778 )
9879 else :
9980 raise ValueError (
@@ -140,11 +121,10 @@ def pairwise(
140121 adata : AnnData ,
141122 groupby : str ,
142123 * ,
143- groups : list [str ] | None = None ,
124+ groups : Sequence [str ] | None = None ,
144125 bootstrap : bool = False ,
145126 n_bootstrap : int = 100 ,
146127 random_state : int = 0 ,
147- inplace : bool = False ,
148128 ):
149129 """
150130 Compute pairwise distances between all cell groups.
@@ -155,28 +135,25 @@ def pairwise(
155135 Annotated data matrix
156136 groupby : str
157137 Key in adata.obs for grouping cells
158- groups : list [str] | None
138+ groups : Sequence [str] | None
159139 Specific groups to compute (if None, use all)
160140 bootstrap : bool
161141 Whether to compute bootstrap variance estimates
162142 n_bootstrap : int
163143 Number of bootstrap iterations (if bootstrap=True)
164144 random_state : int
165145 Random seed for reproducibility
166- inplace : bool
167- Whether to store results in adata.uns
168146
169147 Returns
170148 -------
171149 result
172- Result object containing distances and optional variance DataFrames.
173- The exact type depends on the metric used .
150+ DataFrame with pairwise distances. If bootstrap=True, returns
151+ tuple of (distances, distances_var) DataFrames .
174152
175153 Examples
176154 --------
177155 >>> distance = Distance(metric='edistance')
178156 >>> result = distance.pairwise(adata, groupby='condition')
179- >>> print(result.distances)
180157 """
181158 return self ._metric_impl .pairwise (
182159 adata = adata ,
@@ -185,7 +162,6 @@ def pairwise(
185162 bootstrap = bootstrap ,
186163 n_bootstrap = n_bootstrap ,
187164 random_state = random_state ,
188- inplace = inplace ,
189165 )
190166
191167 def onesided_distances (
@@ -194,11 +170,11 @@ def onesided_distances(
194170 groupby : str ,
195171 selected_group : str ,
196172 * ,
197- groups : list [str ] | None = None ,
173+ groups : Sequence [str ] | None = None ,
198174 bootstrap : bool = False ,
199175 n_bootstrap : int = 100 ,
200176 random_state : int = 0 ,
201- ) -> pd .Series :
177+ ) -> pd .Series | tuple [ pd . Series , pd . Series ] :
202178 """
203179 Compute distances from one selected group to all other groups.
204180
@@ -210,7 +186,7 @@ def onesided_distances(
210186 Key in adata.obs for grouping cells
211187 selected_group : str
212188 Reference group to compute distances from
213- groups : list [str] | None
189+ groups : Sequence [str] | None
214190 Specific groups to compute distances to (if None, use all)
215191 bootstrap : bool
216192 Whether to compute bootstrap variance estimates
@@ -221,16 +197,16 @@ def onesided_distances(
221197
222198 Returns
223199 -------
224- distances : pd.Series
225- Series containing distances from selected_group to all other groups
200+ distances : pd.Series | tuple[pd.Series, pd.Series]
201+ Series containing distances from selected_group to all other groups.
202+ If bootstrap=True, returns tuple of (distances, distances_var).
226203
227204 Examples
228205 --------
229206 >>> distance = Distance(metric='edistance')
230207 >>> distances = distance.onesided_distances(
231208 ... adata, groupby='condition', selected_group='control'
232209 ... )
233- >>> print(distances)
234210 """
235211 if not hasattr (self ._metric_impl , "onesided_distances" ):
236212 raise NotImplementedError (
0 commit comments