|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from typing import TYPE_CHECKING, Literal |
| 4 | + |
| 5 | +if TYPE_CHECKING: |
| 6 | + import pandas as pd |
| 7 | + from anndata import AnnData |
| 8 | + |
| 9 | + |
class Distance:
    """
    GPU-accelerated distance computation between groups of cells.

    This class provides an extensible framework for computing various distance
    metrics between cell groups in single-cell data, with GPU acceleration via
    CuPy. It is a thin facade: every public method forwards its arguments to a
    metric-specific implementation object selected at construction time.

    Parameters
    ----------
    metric : str
        Distance metric to use. Currently supported:
        - 'edistance': Energy distance (GPU-accelerated)
    obsm_key : str
        Key in adata.obsm for embeddings (default: 'X_pca')

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported metrics.

    Examples
    --------
    >>> import rapids_singlecell as rsc
    >>> distance = rsc.ptg.Distance(metric='edistance')
    >>> result = distance.pairwise(adata, groupby='perturbation')
    >>> print(result.distances)
    """

    def __init__(
        self,
        metric: Literal["edistance"] = "edistance",
        obsm_key: str = "X_pca",
    ) -> None:
        """Initialize Distance calculator with specified metric."""
        self.metric = metric
        self.obsm_key = obsm_key
        # Populated by _initialize_metric(); stays None only if that call raises.
        self._metric_impl = None
        self._initialize_metric()

    def _initialize_metric(self) -> None:
        """
        Instantiate the metric implementation matching ``self.metric``.

        Raises
        ------
        ValueError
            If ``self.metric`` does not name a supported metric.
        """
        if self.metric == "edistance":
            # Imported at call time rather than at module top; presumably to
            # defer loading the GPU backend until a Distance is built.
            from rapids_singlecell.pertpy_gpu._metrics._edistance_metric import (
                EDistanceMetric,
            )

            self._metric_impl = EDistanceMetric(obsm_key=self.obsm_key)
        else:
            raise ValueError(
                f"Unknown metric: {self.metric}. Supported metrics: ['edistance']"
            )

    def pairwise(
        self,
        adata: AnnData,
        groupby: str,
        *,
        groups: list[str] | None = None,
        bootstrap: bool = False,
        n_bootstrap: int = 100,
        random_state: int = 0,
        inplace: bool = False,
    ):
        """
        Compute pairwise distances between all cell groups.

        All arguments are forwarded unchanged to the metric implementation's
        ``pairwise`` method.

        Parameters
        ----------
        adata : AnnData
            Annotated data matrix
        groupby : str
            Key in adata.obs for grouping cells
        groups : list[str] | None
            Specific groups to compute (if None, use all)
        bootstrap : bool
            Whether to compute bootstrap variance estimates
        n_bootstrap : int
            Number of bootstrap iterations (if bootstrap=True)
        random_state : int
            Random seed for reproducibility
        inplace : bool
            Whether to store results in adata.uns

        Returns
        -------
        result
            Result object containing distances and optional variance DataFrames.
            The exact type depends on the metric used.

        Examples
        --------
        >>> distance = Distance(metric='edistance')
        >>> result = distance.pairwise(adata, groupby='condition')
        >>> print(result.distances)
        """
        return self._metric_impl.pairwise(
            adata=adata,
            groupby=groupby,
            groups=groups,
            bootstrap=bootstrap,
            n_bootstrap=n_bootstrap,
            random_state=random_state,
            inplace=inplace,
        )

    def onesided_distances(
        self,
        adata: AnnData,
        groupby: str,
        selected_group: str,
        *,
        groups: list[str] | None = None,
        bootstrap: bool = False,
        n_bootstrap: int = 100,
        random_state: int = 0,
    ) -> pd.Series:
        """
        Compute distances from one selected group to all other groups.

        All arguments are forwarded unchanged to the metric implementation's
        ``onesided_distances`` method.

        Parameters
        ----------
        adata : AnnData
            Annotated data matrix
        groupby : str
            Key in adata.obs for grouping cells
        selected_group : str
            Reference group to compute distances from
        groups : list[str] | None
            Specific groups to compute distances to (if None, use all)
        bootstrap : bool
            Whether to compute bootstrap variance estimates
        n_bootstrap : int
            Number of bootstrap iterations (if bootstrap=True)
        random_state : int
            Random seed for reproducibility

        Returns
        -------
        distances : pd.Series
            Series containing distances from selected_group to all other groups

        Raises
        ------
        NotImplementedError
            If the active metric implementation has no ``onesided_distances``
            method.

        Examples
        --------
        >>> distance = Distance(metric='edistance')
        >>> distances = distance.onesided_distances(
        ...     adata, groupby='condition', selected_group='control'
        ... )
        >>> print(distances)
        """
        # Not every metric implementation is required to offer this operation.
        if not hasattr(self._metric_impl, "onesided_distances"):
            raise NotImplementedError(
                f"Metric '{self.metric}' does not support onesided_distances"
            )
        return self._metric_impl.onesided_distances(
            adata=adata,
            groupby=groupby,
            selected_group=selected_group,
            groups=groups,
            bootstrap=bootstrap,
            n_bootstrap=n_bootstrap,
            random_state=random_state,
        )

    def bootstrap(
        self,
        adata: AnnData,
        groupby: str,
        group_a: str,
        group_b: str,
        *,
        n_bootstrap: int = 100,
        random_state: int = 0,
    ) -> tuple[float, float]:
        """
        Compute bootstrap mean and variance for distance between two specific groups.

        All arguments are forwarded unchanged to the metric implementation's
        ``bootstrap`` method.

        Parameters
        ----------
        adata : AnnData
            Annotated data matrix
        groupby : str
            Key in adata.obs for grouping cells
        group_a : str
            First group name
        group_b : str
            Second group name
        n_bootstrap : int
            Number of bootstrap iterations
        random_state : int
            Random seed for reproducibility

        Returns
        -------
        mean : float
            Bootstrap mean distance
        variance : float
            Bootstrap variance

        Raises
        ------
        NotImplementedError
            If the active metric implementation has no ``bootstrap`` method.

        Examples
        --------
        >>> distance = Distance(metric='edistance')
        >>> mean, var = distance.bootstrap(
        ...     adata, groupby='condition', group_a='treated', group_b='control'
        ... )
        >>> print(f"Distance: {mean:.3f} ± {var**0.5:.3f}")
        """
        # Not every metric implementation is required to offer this operation.
        if not hasattr(self._metric_impl, "bootstrap"):
            raise NotImplementedError(
                f"Metric '{self.metric}' does not support bootstrap"
            )
        return self._metric_impl.bootstrap(
            adata=adata,
            groupby=groupby,
            group_a=group_a,
            group_b=group_b,
            n_bootstrap=n_bootstrap,
            random_state=random_state,
        )

    def __repr__(self) -> str:
        """Return a constructor-style representation of this object."""
        # !r quotes and escapes the values correctly (including embedded
        # quotes); type(self).__name__ keeps the output right for subclasses.
        return (
            f"{type(self).__name__}"
            f"(metric={self.metric!r}, obsm_key={self.obsm_key!r})"
        )
0 commit comments