Skip to content

Commit bf98467

Browse files
committed
clean up tests
1 parent 38c4142 commit bf98467

File tree

3 files changed

+176
-52
lines changed

3 files changed

+176
-52
lines changed

src/rapids_singlecell/pertpy_gpu/_distance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,6 @@ def contrast_distances(
620620

621621
def __repr__(self) -> str:
622622
"""String representation of Distance object."""
623-
if self.layer_key:
623+
if self.layer_key is not None:
624624
return f"Distance(metric='{self.metric}', layer_key='{self.layer_key}')"
625625
return f"Distance(metric='{self.metric}', obsm_key='{self.obsm_key}')"

src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ def _get_embedding(self, adata: AnnData) -> cp.ndarray:
525525
526526
Preserves the input dtype (float32 or float64) for precision control.
527527
"""
528-
if self.layer_key:
528+
if self.layer_key is not None:
529529
data = adata.layers[self.layer_key]
530530
else:
531531
data = adata.obsm[self.obsm_key]

tests/pertpy/test_distances.py

Lines changed: 174 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -368,84 +368,208 @@ def test_contrast_distances_matches_compute_distance(
368368

369369
d = EDistanceMetric(obsm_key="X_pca")
370370

371-
contrasts = {
372-
"drugA_vs_ctrl_T": (
373-
{"treatment": "drugA", "celltype": "T"},
374-
{"treatment": "ctrl", "celltype": "T"},
375-
),
376-
"drugA_vs_ctrl_B": (
377-
{"treatment": "drugA", "celltype": "B"},
378-
{"treatment": "ctrl", "celltype": "B"},
379-
),
380-
}
371+
contrasts = Distance.create_contrasts(
372+
contrast_adata,
373+
groupby="treatment",
374+
selected_group="ctrl",
375+
split_by="celltype",
376+
)
381377

382378
result = d.contrast_distances(contrast_adata, contrasts=contrasts)
383379

384-
assert isinstance(result, pd.Series)
385-
assert result.name == "edistance"
386-
assert len(result) == 2
380+
assert isinstance(result, pd.DataFrame)
381+
assert "edistance" in result.columns
382+
assert len(result) == len(contrasts)
387383

388384
# Verify each contrast against compute_distance
389-
for name, (cond_a, cond_b) in contrasts.items():
390-
mask_a = np.ones(len(contrast_adata), dtype=bool)
391-
mask_b = np.ones(len(contrast_adata), dtype=bool)
392-
for col, val in cond_a.items():
393-
mask_a &= contrast_adata.obs[col].values == val
394-
for col, val in cond_b.items():
395-
mask_b &= contrast_adata.obs[col].values == val
396-
397-
X = contrast_adata.obsm["X_pca"][mask_a]
398-
Y = contrast_adata.obsm["X_pca"][mask_b]
385+
for _, row in result.iterrows():
386+
mask_target = (contrast_adata.obs["treatment"].values == row["treatment"]) & (
387+
contrast_adata.obs["celltype"].values == row["celltype"]
388+
)
389+
mask_ref = (contrast_adata.obs["treatment"].values == row["reference"]) & (
390+
contrast_adata.obs["celltype"].values == row["celltype"]
391+
)
392+
393+
X = contrast_adata.obsm["X_pca"][mask_target]
394+
Y = contrast_adata.obsm["X_pca"][mask_ref]
399395
expected = d.compute_distance(X, Y)
400396

401397
np.testing.assert_allclose(
402-
result[name],
398+
row["edistance"],
403399
expected,
404400
atol=1e-6,
405-
err_msg=f"Contrast {name} mismatch",
401+
err_msg=f"Contrast {row['treatment']} vs {row['reference']} "
402+
f"in {row['celltype']} mismatch",
406403
)
407404

408405

409406
def test_contrast_distances_shared_condition(contrast_adata: AnnData) -> None:
410407
"""Test that contrasts sharing a condition (e.g. same control) work."""
408+
distance = Distance(metric="edistance")
409+
410+
contrasts = Distance.create_contrasts(
411+
contrast_adata,
412+
groupby="treatment",
413+
selected_group="ctrl",
414+
split_by="celltype",
415+
)
416+
417+
result = distance.contrast_distances(contrast_adata, contrasts=contrasts)
418+
419+
assert isinstance(result, pd.DataFrame)
420+
assert "edistance" in result.columns
421+
# All distances should be finite
422+
assert np.all(np.isfinite(result["edistance"].values))
423+
424+
425+
def test_contrast_distances_self_distance_zero(contrast_adata: AnnData) -> None:
426+
"""Test that self-distance (same group vs itself) is zero."""
427+
distance = Distance(metric="edistance")
428+
429+
# Manually create a contrast where target == reference
430+
contrasts = pd.DataFrame(
431+
{
432+
"treatment": ["ctrl"],
433+
"reference": ["ctrl"],
434+
"celltype": ["T"],
435+
}
436+
)
437+
438+
result = distance.contrast_distances(contrast_adata, contrasts=contrasts)
439+
assert result["edistance"].iloc[0] == pytest.approx(0.0, abs=1e-7)
440+
441+
442+
def test_contrast_distances_no_split(contrast_adata: AnnData) -> None:
443+
"""Test contrast_distances without split_by columns."""
444+
distance = Distance(metric="edistance")
445+
446+
contrasts = Distance.create_contrasts(
447+
contrast_adata,
448+
groupby="treatment",
449+
selected_group="ctrl",
450+
)
451+
452+
result = distance.contrast_distances(contrast_adata, contrasts=contrasts)
453+
454+
assert isinstance(result, pd.DataFrame)
455+
assert "edistance" in result.columns
456+
assert len(result) == 1 # only drugA vs ctrl
457+
assert np.all(np.isfinite(result["edistance"].values))
458+
459+
460+
def test_contrast_distances_filtered(contrast_adata: AnnData) -> None:
461+
"""Test that filtering a contrasts DataFrame before computing works."""
411462
from rapids_singlecell.pertpy_gpu._metrics._edistance import EDistanceMetric
412463

413464
d = EDistanceMetric(obsm_key="X_pca")
465+
distance = Distance(metric="edistance")
414466

415-
# Both contrasts share the ctrl_T condition
416-
contrasts = {
417-
"drugA_vs_ctrl_T": (
418-
{"treatment": "drugA", "celltype": "T"},
419-
{"treatment": "ctrl", "celltype": "T"},
420-
),
421-
"ctrl_T_self": (
422-
{"treatment": "ctrl", "celltype": "T"},
423-
{"treatment": "ctrl", "celltype": "T"},
424-
),
425-
}
467+
# Create full contrasts, then drop one celltype
468+
contrasts = Distance.create_contrasts(
469+
contrast_adata,
470+
groupby="treatment",
471+
selected_group="ctrl",
472+
split_by="celltype",
473+
)
474+
assert len(contrasts) == 2 # drugA-T, drugA-B
426475

427-
result = d.contrast_distances(contrast_adata, contrasts=contrasts)
476+
# Keep only celltype == "T"
477+
filtered = contrasts[contrasts["celltype"] == "T"].reset_index(drop=True)
478+
assert len(filtered) == 1
479+
480+
result = distance.contrast_distances(contrast_adata, contrasts=filtered)
481+
482+
assert isinstance(result, pd.DataFrame)
483+
assert len(result) == 1
484+
assert result["celltype"].iloc[0] == "T"
485+
486+
# Verify the distance matches compute_distance
487+
mask_target = (contrast_adata.obs["treatment"].values == "drugA") & (
488+
contrast_adata.obs["celltype"].values == "T"
489+
)
490+
mask_ref = (contrast_adata.obs["treatment"].values == "ctrl") & (
491+
contrast_adata.obs["celltype"].values == "T"
492+
)
493+
expected = d.compute_distance(
494+
contrast_adata.obsm["X_pca"][mask_target],
495+
contrast_adata.obsm["X_pca"][mask_ref],
496+
)
497+
np.testing.assert_allclose(result["edistance"].iloc[0], expected, atol=1e-6)
428498

429-
# Self-distance should be 0
430-
assert result["ctrl_T_self"] == pytest.approx(0.0, abs=1e-7)
499+
# Also verify it is consistent with the full (unfiltered) result
500+
full_result = distance.contrast_distances(contrast_adata, contrasts=contrasts)
501+
assert len(full_result) == 2
431502

503+
# The T-cell row should match between filtered and full
504+
full_t = full_result[full_result["celltype"] == "T"]["edistance"].iloc[0]
505+
np.testing.assert_allclose(result["edistance"].iloc[0], full_t, atol=1e-10)
506+
507+
508+
def test_contrast_distances_two_split_by() -> None:
509+
"""Test contrast_distances with two split_by columns."""
510+
rng = np.random.default_rng(42)
511+
n = 10
512+
cpu_emb = rng.normal(size=(n * 6, 5)).astype(np.float32)
513+
obs = pd.DataFrame(
514+
{
515+
"treatment": pd.Categorical(
516+
["ctrl"] * n
517+
+ ["drugA"] * n
518+
+ ["ctrl"] * n
519+
+ ["drugA"] * n
520+
+ ["ctrl"] * n
521+
+ ["drugA"] * n
522+
),
523+
"celltype": pd.Categorical(["T"] * n * 2 + ["B"] * n * 2 + ["T"] * n * 2),
524+
"batch": pd.Categorical(["b1"] * n * 4 + ["b2"] * n * 2),
525+
}
526+
)
527+
adata = AnnData(cpu_emb.copy(), obs=obs)
528+
adata.obsm["X_pca"] = cp.asarray(cpu_emb, dtype=cp.float32)
432529

433-
def test_contrast_distances_empty_condition(contrast_adata: AnnData) -> None:
434-
"""Test that a condition matching no cells is handled."""
435530
from rapids_singlecell.pertpy_gpu._metrics._edistance import EDistanceMetric
436531

437532
d = EDistanceMetric(obsm_key="X_pca")
533+
distance = Distance(metric="edistance")
438534

439-
contrasts = {
440-
"nonexistent": (
441-
{"treatment": "drugX", "celltype": "T"},
442-
{"treatment": "ctrl", "celltype": "T"},
443-
),
444-
}
535+
contrasts = Distance.create_contrasts(
536+
adata,
537+
groupby="treatment",
538+
selected_group="ctrl",
539+
split_by=["celltype", "batch"],
540+
)
445541

446-
# Should not crash — group will have 0 cells
447-
result = d.contrast_distances(contrast_adata, contrasts=contrasts)
448-
assert isinstance(result, pd.Series)
542+
assert "celltype" in contrasts.columns
543+
assert "batch" in contrasts.columns
544+
545+
result = distance.contrast_distances(adata, contrasts=contrasts)
546+
547+
assert isinstance(result, pd.DataFrame)
548+
assert "edistance" in result.columns
549+
550+
# Verify each contrast against compute_distance
551+
for _, row in result.iterrows():
552+
mask_target = (
553+
(adata.obs["treatment"].values == row["treatment"])
554+
& (adata.obs["celltype"].values == row["celltype"])
555+
& (adata.obs["batch"].values == row["batch"])
556+
)
557+
mask_ref = (
558+
(adata.obs["treatment"].values == row["reference"])
559+
& (adata.obs["celltype"].values == row["celltype"])
560+
& (adata.obs["batch"].values == row["batch"])
561+
)
562+
563+
X = adata.obsm["X_pca"][mask_target]
564+
Y = adata.obsm["X_pca"][mask_ref]
565+
expected = d.compute_distance(X, Y)
566+
567+
np.testing.assert_allclose(
568+
row["edistance"],
569+
expected,
570+
rtol=1e-5,
571+
atol=1e-5,
572+
)
449573

450574

451575
# ============================================================================

0 commit comments

Comments
 (0)