Skip to content

Commit c9540c9

Browse files
committed
add multi select groups
1 parent bf98467 commit c9540c9

File tree

2 files changed

+150
-35
lines changed

2 files changed

+150
-35
lines changed

src/rapids_singlecell/pertpy_gpu/_distance.py

Lines changed: 55 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ def bootstrap(
355355
def create_contrasts(
356356
adata: AnnData,
357357
groupby: str,
358-
selected_group: str,
358+
selected_group: str | Sequence[str],
359359
*,
360360
groups: Sequence[str] | None = None,
361361
split_by: str | Sequence[str] | None = None,
@@ -382,7 +382,10 @@ def create_contrasts(
382382
Column in ``adata.obs`` whose levels are compared against
383383
``selected_group``
384384
selected_group
385-
The reference (control) value in the ``groupby`` column
385+
The reference (control) value(s) in the ``groupby`` column.
386+
When a sequence is passed, each target is compared against
387+
every reference, producing one row per (target, reference)
388+
combination.
386389
groups
387390
Specific groups to include. If None, all non-reference groups
388391
are included.
@@ -405,6 +408,12 @@ def create_contrasts(
405408
... adata, groupby="target_gene", selected_group="Non_target"
406409
... )
407410
411+
>>> # Multiple references
412+
>>> contrasts = Distance.create_contrasts(
413+
... adata, groupby="target_gene",
414+
... selected_group=["Non_target", "Scramble"],
415+
... )
416+
408417
>>> # Stratified by celltype
409418
>>> contrasts = Distance.create_contrasts(
410419
... adata, groupby="target_gene", selected_group="Non_target",
@@ -425,10 +434,16 @@ def create_contrasts(
425434
"""
426435
import pandas as pd
427436

428-
if selected_group not in adata.obs[groupby].values:
429-
raise ValueError(
430-
f"Reference '{selected_group}' not found in column '{groupby}'"
431-
)
437+
# Normalize to list
438+
if isinstance(selected_group, str):
439+
selected_groups = [selected_group]
440+
else:
441+
selected_groups = list(selected_group)
442+
443+
obs_values = set(adata.obs[groupby].values)
444+
for sg in selected_groups:
445+
if sg not in obs_values:
446+
raise ValueError(f"Reference '{sg}' not found in column '{groupby}'")
432447

433448
if split_by is None:
434449
split_cols: list[str] = []
@@ -438,41 +453,46 @@ def create_contrasts(
438453
split_cols = list(split_by)
439454

440455
allowed_groups = set(groups) if groups is not None else None
456+
selected_set = set(selected_groups)
441457
all_cols = [groupby, *split_cols]
442458

443-
if split_cols:
444-
# Get all existing (groupby, *split) combinations in one pass
445-
existing = adata.obs[all_cols].drop_duplicates().reset_index(drop=True)
459+
parts: list[pd.DataFrame] = []
460+
for sg in selected_groups:
461+
if split_cols:
462+
existing = adata.obs[all_cols].drop_duplicates().reset_index(drop=True)
446463

447-
# Find which splits have the reference
448-
ref_rows = existing[existing[groupby] == selected_group]
449-
if len(ref_rows) == 0:
450-
df = pd.DataFrame(columns=all_cols)
451-
else:
452-
# Inner join: keep only targets in splits that have reference
464+
ref_rows = existing[existing[groupby] == sg]
465+
if len(ref_rows) == 0:
466+
continue
453467
ref_splits = ref_rows[split_cols]
454-
targets = existing[existing[groupby] != selected_group]
468+
targets = existing[~existing[groupby].isin(selected_set)]
455469
if allowed_groups is not None:
456470
targets = targets[targets[groupby].isin(allowed_groups)]
457-
df = targets.merge(ref_splits, on=split_cols, how="inner")
458-
df = (
459-
df[all_cols]
460-
.sort_values([*split_cols, groupby])
461-
.reset_index(drop=True)
462-
)
463-
else:
464-
# No split — just all non-reference levels of groupby
465-
targets = adata.obs[groupby].unique()
466-
targets = [
467-
t
468-
for t in targets
469-
if t != selected_group
470-
and (allowed_groups is None or t in allowed_groups)
471-
]
472-
df = pd.DataFrame({groupby: targets})
473-
474-
# Insert reference column right after groupby
475-
df.insert(1, "reference", selected_group)
471+
matched = targets.merge(ref_splits, on=split_cols, how="inner")
472+
if len(matched) == 0:
473+
continue
474+
matched = matched[all_cols].copy()
475+
else:
476+
target_vals = [
477+
t
478+
for t in adata.obs[groupby].unique()
479+
if t not in selected_set
480+
and (allowed_groups is None or t in allowed_groups)
481+
]
482+
if not target_vals:
483+
continue
484+
matched = pd.DataFrame({groupby: target_vals})
485+
486+
matched.insert(1, "reference", sg)
487+
parts.append(matched)
488+
489+
if not parts:
490+
cols = [groupby, "reference", *split_cols]
491+
return pd.DataFrame(columns=cols)
492+
493+
df = pd.concat(parts, ignore_index=True)
494+
sort_cols = ["reference", *split_cols, groupby]
495+
df = df.sort_values(sort_cols).reset_index(drop=True)
476496

477497
return df
478498

tests/pertpy/test_distances.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,101 @@ def test_contrast_distances_no_split(contrast_adata: AnnData) -> None:
457457
assert np.all(np.isfinite(result["edistance"].values))
458458

459459

460+
def test_contrast_distances_multiple_references() -> None:
461+
"""Test create_contrasts with multiple reference groups."""
462+
rng = np.random.default_rng(42)
463+
n = 10
464+
cpu_emb = rng.normal(size=(n * 6, 5)).astype(np.float32)
465+
obs = pd.DataFrame(
466+
{
467+
"treatment": pd.Categorical(
468+
["ref1"] * n
469+
+ ["ref2"] * n
470+
+ ["drugA"] * n
471+
+ ["drugB"] * n
472+
+ ["ref1"] * n
473+
+ ["drugA"] * n
474+
),
475+
"celltype": pd.Categorical(["T"] * n * 4 + ["B"] * n * 2),
476+
}
477+
)
478+
adata = AnnData(cpu_emb.copy(), obs=obs)
479+
adata.obsm["X_pca"] = cp.asarray(cpu_emb, dtype=cp.float32)
480+
481+
from rapids_singlecell.pertpy_gpu._metrics._edistance import EDistanceMetric
482+
483+
d = EDistanceMetric(obsm_key="X_pca")
484+
distance = Distance(metric="edistance")
485+
486+
# Two references
487+
contrasts = Distance.create_contrasts(
488+
adata,
489+
groupby="treatment",
490+
selected_group=["ref1", "ref2"],
491+
split_by="celltype",
492+
)
493+
494+
# References should not appear as targets
495+
assert "ref1" not in contrasts["treatment"].values
496+
assert "ref2" not in contrasts["treatment"].values
497+
498+
# Both references should appear in the reference column
499+
assert "ref1" in contrasts["reference"].values
500+
assert "ref2" in contrasts["reference"].values
501+
502+
result = distance.contrast_distances(adata, contrasts=contrasts)
503+
assert "edistance" in result.columns
504+
505+
# Verify each row against compute_distance
506+
for _, row in result.iterrows():
507+
mask_target = (adata.obs["treatment"].values == row["treatment"]) & (
508+
adata.obs["celltype"].values == row["celltype"]
509+
)
510+
mask_ref = (adata.obs["treatment"].values == row["reference"]) & (
511+
adata.obs["celltype"].values == row["celltype"]
512+
)
513+
X = adata.obsm["X_pca"][mask_target]
514+
Y = adata.obsm["X_pca"][mask_ref]
515+
516+
if len(X) == 0 or len(Y) == 0:
517+
continue
518+
expected = d.compute_distance(X, Y)
519+
np.testing.assert_allclose(row["edistance"], expected, rtol=1e-5, atol=1e-5)
520+
521+
522+
def test_contrast_distances_multiple_references_no_split() -> None:
523+
"""Test create_contrasts with multiple references and no split_by."""
524+
rng = np.random.default_rng(42)
525+
n = 15
526+
cpu_emb = rng.normal(size=(n * 4, 5)).astype(np.float32)
527+
obs = pd.DataFrame(
528+
{
529+
"treatment": pd.Categorical(
530+
["ref1"] * n + ["ref2"] * n + ["drugA"] * n + ["drugB"] * n
531+
),
532+
}
533+
)
534+
adata = AnnData(cpu_emb.copy(), obs=obs)
535+
adata.obsm["X_pca"] = cp.asarray(cpu_emb, dtype=cp.float32)
536+
537+
distance = Distance(metric="edistance")
538+
539+
contrasts = Distance.create_contrasts(
540+
adata,
541+
groupby="treatment",
542+
selected_group=["ref1", "ref2"],
543+
)
544+
545+
# 2 targets x 2 references = 4 rows
546+
assert len(contrasts) == 4
547+
assert set(contrasts["treatment"].values) == {"drugA", "drugB"}
548+
assert set(contrasts["reference"].values) == {"ref1", "ref2"}
549+
550+
result = distance.contrast_distances(adata, contrasts=contrasts)
551+
assert len(result) == 4
552+
assert np.all(np.isfinite(result["edistance"].values))
553+
554+
460555
def test_contrast_distances_filtered(contrast_adata: AnnData) -> None:
461556
"""Test that filtering a contrasts DataFrame before computing works."""
462557
from rapids_singlecell.pertpy_gpu._metrics._edistance import EDistanceMetric

0 commit comments

Comments
 (0)