Skip to content

Commit b509a16

Browse files
committed
return series for single control
1 parent 8325daa commit b509a16

File tree

3 files changed

+38
-25
lines changed

3 files changed

+38
-25
lines changed

src/rapids_singlecell/pertpy_gpu/_distance.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,12 @@ def onesided_distances(
239239
n_bootstrap: int = 100,
240240
random_state: int = 0,
241241
multi_gpu: bool | list[int] | str | None = None,
242-
) -> pd.Series | tuple[pd.Series, pd.Series]:
242+
) -> (
243+
pd.Series
244+
| pd.DataFrame
245+
| tuple[pd.Series, pd.Series]
246+
| tuple[pd.DataFrame, pd.DataFrame]
247+
):
243248
"""
244249
Compute distances from one selected group to all other groups.
245250

src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,12 @@ def onesided_distances(
162162
n_bootstrap: int = 100,
163163
random_state: int = 0,
164164
multi_gpu: bool | list[int] | str | None = None,
165-
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
165+
) -> (
166+
pd.Series
167+
| pd.DataFrame
168+
| tuple[pd.Series, pd.Series]
169+
| tuple[pd.DataFrame, pd.DataFrame]
170+
):
166171
"""
167172
Compute energy distances from selected reference group(s) to all other groups.
168173
@@ -175,6 +180,8 @@ def onesided_distances(
175180
selected_group
176181
Reference group(s) to compute distances from. Can be a single
177182
group name or a sequence of group names for multiple controls.
183+
When a single string is passed, returns a Series. When a sequence
184+
is passed, returns a DataFrame with one column per control.
178185
groups
179186
Specific groups to compute distances to (if None, use all)
180187
bootstrap
@@ -194,13 +201,14 @@ def onesided_distances(
194201
Returns
195202
-------
196203
distances
197-
DataFrame with groups as index and selected_group(s) as columns.
204+
Series (single control) or DataFrame (multiple controls).
198205
If bootstrap=True, returns tuple of (distances, distances_var).
199206
"""
200207
_assert_categorical_obs(adata, key=groupby)
201208

202-
# Normalize selected_group to a list
203-
if isinstance(selected_group, str):
209+
# Normalize selected_group to a list, track if input was a string
210+
single_control = isinstance(selected_group, str)
211+
if single_control:
204212
selected_groups = [selected_group]
205213
else:
206214
selected_groups = list(selected_group)
@@ -265,6 +273,9 @@ def onesided_distances(
265273
distances = distances.loc[groups_list]
266274
variances = variances.loc[groups_list]
267275

276+
if single_control:
277+
sg = selected_groups[0]
278+
return distances[sg], variances[sg]
268279
return distances, variances
269280

270281
# Non-bootstrap path
@@ -293,6 +304,8 @@ def onesided_distances(
293304
if groups_list != all_groups:
294305
df = df.loc[groups_list]
295306

307+
if single_control:
308+
return df[selected_groups[0]]
296309
return df
297310

298311
def bootstrap(

tests/pertpy/test_distances.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,11 @@ def test_distance_class_onesided_distances(small_adata: AnnData) -> None:
7474
small_adata, groupby="group", selected_group="g0"
7575
)
7676

77-
assert isinstance(result, pd.DataFrame)
77+
assert isinstance(result, pd.Series)
7878
assert len(result) == 3 # 3 groups total
7979
assert "g0" in result.index
8080
assert "g1" in result.index
8181
assert "g2" in result.index
82-
assert list(result.columns) == ["g0"]
8382

8483

8584
def test_distance_class_onesided_matches_pairwise(small_adata: AnnData) -> None:
@@ -96,10 +95,10 @@ def test_distance_class_onesided_matches_pairwise(small_adata: AnnData) -> None:
9695
)
9796
# Should match the row from pairwise matrix
9897
np.testing.assert_allclose(
99-
onesided[group].values, pairwise_df.loc[group].values, atol=1e-5
98+
onesided.values, pairwise_df.loc[group].values, atol=1e-5
10099
)
101100
# Self-distance should be 0
102-
assert onesided.loc[group, group] == pytest.approx(0.0, abs=1e-6)
101+
assert onesided.loc[group] == pytest.approx(0.0, abs=1e-6)
103102

104103

105104
def test_distance_class_onesided_multiple_controls(small_adata: AnnData) -> None:
@@ -149,18 +148,18 @@ def test_distance_class_onesided_bootstrap(small_adata: AnnData) -> None:
149148
assert len(result) == 2
150149
distances, distances_var = result
151150

152-
assert isinstance(distances, pd.DataFrame)
153-
assert isinstance(distances_var, pd.DataFrame)
151+
assert isinstance(distances, pd.Series)
152+
assert isinstance(distances_var, pd.Series)
154153
assert len(distances) == 3
155154
assert len(distances_var) == 3
156155

157156
# Self-distance variance should be 0
158-
assert distances.loc["g0", "g0"] == pytest.approx(0.0, abs=1e-6)
159-
assert distances_var.loc["g0", "g0"] == pytest.approx(0.0, abs=1e-6)
157+
assert distances.loc["g0"] == pytest.approx(0.0, abs=1e-6)
158+
assert distances_var.loc["g0"] == pytest.approx(0.0, abs=1e-6)
160159

161160
# Non-self variances should be positive
162-
assert distances_var.loc["g1", "g0"] > 0
163-
assert distances_var.loc["g2", "g0"] > 0
161+
assert distances_var.loc["g1"] > 0
162+
assert distances_var.loc["g2"] > 0
164163

165164

166165
def test_distance_class_onesided_bootstrap_matches_pairwise(
@@ -189,11 +188,9 @@ def test_distance_class_onesided_bootstrap_matches_pairwise(
189188
)
190189

191190
# Should match the corresponding row from pairwise
191+
np.testing.assert_allclose(onesided.values, pairwise_df.loc["g0"].values, atol=1e-6)
192192
np.testing.assert_allclose(
193-
onesided["g0"].values, pairwise_df.loc["g0"].values, atol=1e-6
194-
)
195-
np.testing.assert_allclose(
196-
onesided_var["g0"].values, pairwise_var_df.loc["g0"].values, atol=1e-6
193+
onesided_var.values, pairwise_var_df.loc["g0"].values, atol=1e-6
197194
)
198195

199196

@@ -329,7 +326,7 @@ def test_onesided_distances_correctness_vs_cpu(small_adata: AnnData) -> None:
329326
else:
330327
expected = _compute_energy_distance_cpu(X, Y)
331328

332-
actual = onesided.loc[target_group, selected_group]
329+
actual = onesided.loc[target_group]
333330
np.testing.assert_allclose(
334331
actual,
335332
expected,
@@ -1106,12 +1103,10 @@ def test_onesided_output_format(small_adata: AnnData) -> None:
11061103
small_adata, groupby="group", selected_group="g0"
11071104
)
11081105

1109-
assert isinstance(result, pd.DataFrame), (
1110-
"onesided_distances should return DataFrame"
1106+
assert isinstance(result, pd.Series), (
1107+
"onesided_distances with single control should return Series"
11111108
)
11121109
assert result.index.name == "group"
1113-
assert result.columns.name == "selected_group"
1114-
assert list(result.columns) == ["g0"]
11151110

11161111

11171112
# ============================================================================
@@ -1414,7 +1409,7 @@ def test_single_cell_mixed_groups() -> None:
14141409
# Test onesided_distances with single-cell selected group
14151410
onesided = distance.onesided_distances(adata, groupby="group", selected_group="g0")
14161411
assert np.all(np.isfinite(onesided.values)), "Onesided distances should be finite"
1417-
assert onesided.loc["g0", "g0"] == 0.0, "Self-distance should be 0"
1412+
assert onesided.loc["g0"] == 0.0, "Self-distance should be 0"
14181413

14191414

14201415
def test_high_dimensional_features() -> None:

0 commit comments

Comments
 (0)