Skip to content

Commit ff3f823

Browse files
committed
add docs and fix test with smaller functions
1 parent 7d81afc commit ff3f823

File tree

5 files changed

+94
-21
lines changed

5 files changed

+94
-21
lines changed

docs/api/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import rapids_singlecell as rsc
1212
1313
scanpy_gpu
1414
squidpy_gpu
15+
pertpy_gpu
1516
decoupler_gpu
1617
get
1718
```

docs/api/pertpy_gpu.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# pertpy-GPU: `ptg`
2+
3+
{mod}`pertpy` contains tools for perturbation analysis. {mod}`rapids_singlecell.ptg` accelerates some of these methods.
4+
5+
## Distance
6+
7+
```{eval-rst}
8+
.. module:: rapids_singlecell.ptg
9+
.. currentmodule:: rapids_singlecell
10+
11+
.. autosummary::
12+
:toctree: generated
13+
14+
ptg.Distance
15+
16+
.. autoclass:: rapids_singlecell.ptg.Distance
17+
:members:
18+
:undoc-members:
19+
:show-inheritance:
20+
```

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from pathlib import Path, PurePosixPath
99
from typing import TYPE_CHECKING
1010
import anndata # noqa
11+
import fast_array_utils # noqa
1112

1213
if TYPE_CHECKING:
1314
from sphinx.application import Sphinx

docs/release-notes/0.13.6.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
### 0.13.6 {small}`the-future`
2+
3+
```{rubric} Features
4+
```
5+
* Adds `ptg.Distance` for GPU-accelerated distance computation between groups of cells, compatible with pertpy's Distance API {pr}`454` {smaller}`S Dicks & S Özleyen`

tests/pertpy/test_distances.py

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -711,14 +711,30 @@ def test_distance_default_obsm_key() -> None:
711711
# ============================================================================
712712

713713

714-
def test_float64_matches_float32_results(small_adata: AnnData) -> None:
714+
def test_float64_matches_float32_results() -> None:
715715
"""Test that float64 and float32 produce similar results (within float32 precision)."""
716-
adata_f64 = small_adata.copy()
717-
adata_f64.obsm["X_pca"] = small_adata.obsm["X_pca"].astype(cp.float64)
716+
# Use small dataset to avoid GPU resource exhaustion with float64
717+
rng = np.random.default_rng(0)
718+
n_groups = 3
719+
cells_per_group = 4
720+
n_features = 5
721+
total_cells = n_groups * cells_per_group
722+
723+
cpu_embedding = rng.normal(size=(total_cells, n_features)).astype(np.float32)
724+
groups = [f"g{idx}" for idx in range(n_groups) for _ in range(cells_per_group)]
725+
obs = pd.DataFrame(
726+
{"group": pd.Categorical(groups, categories=[f"g{i}" for i in range(n_groups)])}
727+
)
728+
729+
adata_f32 = AnnData(cpu_embedding, obs=obs)
730+
adata_f32.obsm["X_pca"] = cp.asarray(cpu_embedding, dtype=cp.float32)
731+
732+
adata_f64 = AnnData(cpu_embedding, obs=obs.copy())
733+
adata_f64.obsm["X_pca"] = cp.asarray(cpu_embedding, dtype=cp.float64)
718734

719735
distance = Distance(metric="edistance")
720736

721-
result_f32 = distance.pairwise(small_adata, groupby="group")
737+
result_f32 = distance.pairwise(adata_f32, groupby="group")
722738
result_f64 = distance.pairwise(adata_f64, groupby="group")
723739

724740
np.testing.assert_allclose(
@@ -772,9 +788,10 @@ def test_bootstrap_different_feature_counts(n_features: int) -> None:
772788
def test_pairwise_correctness_parametrized(dtype) -> None:
773789
"""Parametrized test for pairwise correctness across dtypes."""
774790
rng = np.random.default_rng(42)
775-
n_groups = 4
776-
cells_per_group = 15
777-
n_features = 20
791+
# Use smaller sizes for float64 to avoid GPU resource exhaustion
792+
n_groups = 3
793+
cells_per_group = 4 if dtype == np.float64 else 15
794+
n_features = 5 if dtype == np.float64 else 20
778795
total_cells = n_groups * cells_per_group
779796

780797
cpu_embedding = rng.normal(size=(total_cells, n_features)).astype(dtype)
@@ -793,7 +810,7 @@ def test_pairwise_correctness_parametrized(dtype) -> None:
793810
rtol = 1e-10 if dtype == np.float64 else 1e-5
794811
atol = 1e-12 if dtype == np.float64 else 1e-6
795812

796-
for g1, g2 in [("g0", "g1"), ("g1", "g3"), ("g2", "g3")]:
813+
for g1, g2 in [("g0", "g1"), ("g0", "g2"), ("g1", "g2")]:
797814
X = cpu_embedding[np.array(groups) == g1]
798815
Y = cpu_embedding[np.array(groups) == g2]
799816
expected = _compute_energy_distance_cpu(X, Y)
@@ -807,30 +824,59 @@ def test_pairwise_correctness_parametrized(dtype) -> None:
807824
)
808825

809826

810-
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
811827
@pytest.mark.parametrize("n_features", [50, 400])
812-
def test_correctness_dtype_and_features(dtype, n_features) -> None:
813-
"""Test correctness across dtypes and feature counts."""
828+
def test_correctness_float32_features(n_features) -> None:
829+
"""Test correctness across feature counts with float32."""
814830
rng = np.random.default_rng(42)
815831
n_groups = 3
816832
cells_per_group = 10
817833
total_cells = n_groups * cells_per_group
818834

819-
cpu_embedding = rng.normal(size=(total_cells, n_features)).astype(dtype)
835+
cpu_embedding = rng.normal(size=(total_cells, n_features)).astype(np.float32)
820836
groups = [f"g{idx}" for idx in range(n_groups) for _ in range(cells_per_group)]
821837
obs = pd.DataFrame(
822838
{"group": pd.Categorical(groups, categories=[f"g{i}" for i in range(n_groups)])}
823839
)
824840

825-
adata = AnnData(cpu_embedding.copy().astype(np.float32), obs=obs)
826-
adata.obsm["X_pca"] = cp.asarray(cpu_embedding, dtype=dtype)
841+
adata = AnnData(cpu_embedding.copy(), obs=obs)
842+
adata.obsm["X_pca"] = cp.asarray(cpu_embedding, dtype=np.float32)
827843

828844
distance = Distance(metric="edistance")
829845
result_df = distance.pairwise(adata, groupby="group")
830846

831-
# Check correctness
832-
rtol = 1e-10 if dtype == np.float64 else 1e-4
833-
atol = 1e-12 if dtype == np.float64 else 1e-5
847+
X = cpu_embedding[np.array(groups) == "g0"]
848+
Y = cpu_embedding[np.array(groups) == "g1"]
849+
expected = _compute_energy_distance_cpu(X, Y)
850+
actual = result_df.loc["g0", "g1"]
851+
852+
np.testing.assert_allclose(
853+
actual,
854+
expected,
855+
rtol=1e-4,
856+
atol=1e-5,
857+
err_msg=f"n_features={n_features} mismatch",
858+
)
859+
860+
861+
def test_correctness_float64_small() -> None:
862+
"""Test correctness with float64 using small data to avoid GPU resource exhaustion."""
863+
rng = np.random.default_rng(42)
864+
n_groups = 3
865+
cells_per_group = 4
866+
n_features = 5
867+
total_cells = n_groups * cells_per_group
868+
869+
cpu_embedding = rng.normal(size=(total_cells, n_features)).astype(np.float64)
870+
groups = [f"g{idx}" for idx in range(n_groups) for _ in range(cells_per_group)]
871+
obs = pd.DataFrame(
872+
{"group": pd.Categorical(groups, categories=[f"g{i}" for i in range(n_groups)])}
873+
)
874+
875+
adata = AnnData(cpu_embedding.copy().astype(np.float32), obs=obs)
876+
adata.obsm["X_pca"] = cp.asarray(cpu_embedding, dtype=np.float64)
877+
878+
distance = Distance(metric="edistance")
879+
result_df = distance.pairwise(adata, groupby="group")
834880

835881
X = cpu_embedding[np.array(groups) == "g0"]
836882
Y = cpu_embedding[np.array(groups) == "g1"]
@@ -840,9 +886,9 @@ def test_correctness_dtype_and_features(dtype, n_features) -> None:
840886
np.testing.assert_allclose(
841887
actual,
842888
expected,
843-
rtol=rtol,
844-
atol=atol,
845-
err_msg=f"dtype={dtype.__name__}, n_features={n_features} mismatch",
889+
rtol=1e-10,
890+
atol=1e-12,
891+
err_msg="float64 mismatch",
846892
)
847893

848894

@@ -928,7 +974,7 @@ def pertpy_adata() -> AnnData:
928974
929975
This uses the same data pertpy uses in their tests.
930976
"""
931-
import pertpy as pt
977+
pt = pytest.importorskip("pertpy")
932978
import scanpy as sc
933979

934980
adata = pt.dt.distance_example()

0 commit comments

Comments
 (0)