Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/release-notes/0.15.0.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
### 0.15.0 {small}`the-future`

```{rubric} Features
```
* Allow multiple control groups in ``onesided_distances`` for computing energy distances against several references in a single kernel launch {pr}`601` {smaller}`S Dicks`

```{rubric} Bug fixes
```
* Fix multi-GPU ``cudaErrorLaunchFailure`` during cross-device result aggregation when using RMM without pool allocation for very large datasets {pr}`594` {smaller}`S Dicks`
2 changes: 1 addition & 1 deletion src/rapids_singlecell/pertpy_gpu/_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def onesided_distances(
self,
adata: AnnData,
groupby: str,
selected_group: str,
selected_group: Sequence[str] | str,
*,
groups: Sequence[str] | None = None,
bootstrap: bool = False,
Expand Down
10 changes: 6 additions & 4 deletions src/rapids_singlecell/pertpy_gpu/_metrics/_base_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def onesided_distances(
self,
adata: AnnData,
groupby: str,
selected_group: str,
selected_group: str | Sequence[str],
*,
groups: Sequence[str] | None = None,
bootstrap: bool = False,
Expand All @@ -96,7 +96,7 @@ def onesided_distances(
multi_gpu: bool | list[int] | str | None = None,
):
"""
Compute distances from one selected group to all other groups.
Compute distances from selected reference group(s) to all other groups.

Parameters
----------
Expand All @@ -105,7 +105,8 @@ def onesided_distances(
groupby
Key in adata.obs for grouping cells
selected_group
Reference group to compute distances from
Reference group(s) to compute distances from. Can be a single
group name or a sequence of group names.
groups
Specific groups to compute distances to (if None, use all)
bootstrap
Expand All @@ -125,7 +126,8 @@ def onesided_distances(
Returns
-------
distances
Distance values from selected_group to other groups
DataFrame with distances from selected_group(s) to other groups.
If bootstrap=True, returns tuple of (distances, distances_var).
"""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement onesided_distances"
Expand Down
138 changes: 81 additions & 57 deletions src/rapids_singlecell/pertpy_gpu/_metrics/_edistance.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,16 +155,16 @@ def onesided_distances(
self,
adata: AnnData,
groupby: str,
selected_group: str,
selected_group: str | Sequence[str],
*,
groups: Sequence[str] | None = None,
bootstrap: bool = False,
n_bootstrap: int = 100,
random_state: int = 0,
multi_gpu: bool | list[int] | str | None = None,
) -> pd.Series | tuple[pd.Series, pd.Series]:
) -> pd.DataFrame | tuple[pd.DataFrame, pd.DataFrame]:
"""
Compute energy distances from one selected group to all other groups.
Compute energy distances from selected reference group(s) to all other groups.

Parameters
----------
Expand All @@ -173,7 +173,8 @@ def onesided_distances(
groupby
Key in adata.obs for grouping cells
selected_group
Reference group to compute distances from
Reference group(s) to compute distances from. Can be a single
group name or a sequence of group names for multiple controls.
groups
Specific groups to compute distances to (if None, use all)
bootstrap
Expand All @@ -193,27 +194,34 @@ def onesided_distances(
Returns
-------
distances
Series containing distances from selected_group to all other groups.
DataFrame with groups as index and selected_group(s) as columns.
If bootstrap=True, returns tuple of (distances, distances_var).
"""
_assert_categorical_obs(adata, key=groupby)

# Normalize selected_group to a list
if isinstance(selected_group, str):
selected_groups = [selected_group]
else:
selected_groups = list(selected_group)

embedding = self._get_embedding(adata)
original_groups = adata.obs[groupby]
group_map = {v: i for i, v in enumerate(original_groups.cat.categories.values)}

if selected_group not in group_map:
raise ValueError(
f"Selected group '{selected_group}' not found in groupby '{groupby}'"
)
for sg in selected_groups:
if sg not in group_map:
raise ValueError(
f"Selected group '{sg}' not found in groupby '{groupby}'"
)

group_labels = cp.array([group_map[c] for c in original_groups], dtype=cp.int32)
k = len(group_map)
cat_offsets, cell_indices = _create_category_index_mapping(group_labels, k)

all_groups = list(original_groups.cat.categories.values)
groups_list = all_groups if groups is None else list(groups)
selected_idx = group_map[selected_group]
selected_indices = [group_map[sg] for sg in selected_groups]

device_ids = parse_device_ids(multi_gpu=multi_gpu)

Expand All @@ -223,63 +231,69 @@ def onesided_distances(
cat_offsets=cat_offsets,
cell_indices=cell_indices,
k=k,
selected_idx=selected_idx,
selected_indices=selected_indices,
n_bootstrap=n_bootstrap,
random_state=random_state,
device_ids=device_ids,
)

# Compute energy distances: e[s,b] = 2*d[s,b] - d[s,s] - d[b,b]
diag_means = cp.diag(onesided_means)
row = onesided_means[selected_idx, :]
edistances = 2 * row - diag_means[selected_idx] - diag_means
edistances[selected_idx] = 0.0

# Variance: var[s,b] = 4*var[s,b] + var[s,s] + var[b,b]
diag_vars = cp.diag(onesided_vars)
ed_vars = (
4 * onesided_vars[selected_idx, :] + diag_vars[selected_idx] + diag_vars
)
ed_vars[selected_idx] = 0.0

series_name = f"edistance to {selected_group}"
distances = pd.Series(edistances.get(), index=all_groups, name=series_name)
# Compute energy distances for each control:
# e[s,b] = 2*d[s,b] - d[s,s] - d[b,b]
ed_cols = {}
var_cols = {}
for sg, si in zip(selected_groups, selected_indices):
ed_row = 2 * onesided_means[si, :] - diag_means[si] - diag_means
ed_row[si] = 0.0
ed_cols[sg] = ed_row.get()

var_row = 4 * onesided_vars[si, :] + diag_vars[si] + diag_vars
var_row[si] = 0.0
var_cols[sg] = var_row.get()

distances = pd.DataFrame(ed_cols, index=all_groups)
distances.index.name = groupby
variances = pd.Series(ed_vars.get(), index=all_groups, name=series_name)
distances.columns.name = "selected_group"

variances = pd.DataFrame(var_cols, index=all_groups)
variances.index.name = groupby
variances.columns.name = "selected_group"

if groups_list != all_groups:
distances = distances.loc[groups_list]
variances = variances.loc[groups_list]

return distances, variances

# Non-bootstrap path: compute onesided means directly
# Non-bootstrap path
onesided_means = self._onesided_means(
embedding,
cat_offsets,
cell_indices,
k,
selected_idx=selected_idx,
selected_indices=selected_indices,
device_ids=device_ids,
)

# Compute energy distances: e[s,b] = 2*d[s,b] - d[s,s] - d[b,b]
# Compute energy distances for each control:
# e[s,b] = 2*d[s,b] - d[s,s] - d[b,b]
diag = cp.diag(onesided_means)
edistances = 2 * onesided_means[selected_idx, :] - diag[selected_idx] - diag
edistances[selected_idx] = 0.0 # Self-distance is 0
ed_cols = {}
for sg, si in zip(selected_groups, selected_indices):
ed_row = 2 * onesided_means[si, :] - diag[si] - diag
ed_row[si] = 0.0
ed_cols[sg] = ed_row.get()

# Create Series with pertpy-compatible name
series = pd.Series(
edistances.get(), index=all_groups, name=f"edistance to {selected_group}"
)
series.index.name = groupby
df = pd.DataFrame(ed_cols, index=all_groups)
df.index.name = groupby
df.columns.name = "selected_group"

# Filter to requested groups if needed
if groups_list != all_groups:
series = series.loc[groups_list]
df = df.loc[groups_list]

return series
return df

def bootstrap(
self,
Expand Down Expand Up @@ -681,16 +695,16 @@ def _onesided_means(
cell_indices: cp.ndarray,
k: int,
*,
selected_idx: int,
selected_indices: list[int],
device_ids: list[int],
) -> cp.ndarray:
"""Compute mean distances from selected group to all groups.
"""Compute mean distances from selected group(s) to all groups.

Splits pairs across specified GPUs and aggregates results on GPU 0.

Computes:
- d[selected_idx, i] for all i (cross-distances)
- d[i, i] for all i (self-distances needed for energy distance formula)
- d[s, i] for each s in selected_indices, for all i (cross-distances)
- d[i, i] for non-selected groups (self-distances for energy distance)

Parameters
----------
Expand All @@ -702,8 +716,8 @@ def _onesided_means(
Cell indices on GPU 0
k
Number of groups
selected_idx
Index of the selected group
selected_indices
Indices of the selected (control) groups
device_ids
List of GPU device IDs to use

Expand All @@ -718,18 +732,28 @@ def _onesided_means(
# Get group sizes
group_sizes = cp.diff(cat_offsets).astype(cp.int64)

# Build pairs for onesided computation
all_indices = cp.arange(k, dtype=cp.int32)
cross_left = cp.full(k, selected_idx, dtype=cp.int32)
cross_right = all_indices
# Build pairs for onesided computation.
# The kernel symmetrizes: for pair (a,b) it writes to both
# sums[a,b] and sums[b,a]. So we must avoid having both (i,j)
# and (j,i) in the pair list to prevent double-counting.
selected_set = set(selected_indices)

# Collect unique pairs as a set of (min, max) tuples
pair_set: set[tuple[int, int]] = set()

# Cross pairs: (s, i) for each selected s and all i
for si in selected_indices:
for i in range(k):
pair_set.add((min(si, i), max(si, i)))

# Diagonal pairs for other groups with >= 2 cells
mask = (all_indices != selected_idx) & (group_sizes >= 2)
other_diag = all_indices[mask]
# Diagonal pairs for non-selected groups with >= 2 cells
for i in range(k):
if i not in selected_set and int(group_sizes[i]) >= 2:
pair_set.add((i, i))

# Combine all pairs
pair_left = cp.concatenate([cross_left, other_diag])
pair_right = cp.concatenate([cross_right, other_diag])
pairs = sorted(pair_set)
pair_left = cp.array([p[0] for p in pairs], dtype=cp.int32)
pair_right = cp.array([p[1] for p in pairs], dtype=cp.int32)
num_pairs = len(pair_left)

if num_pairs == 0:
Expand Down Expand Up @@ -913,7 +937,7 @@ def _onesided_means_bootstrap(
cat_offsets: cp.ndarray,
cell_indices: cp.ndarray,
k: int,
selected_idx: int,
selected_indices: list[int],
n_bootstrap: int,
random_state: int,
device_ids: list[int],
Expand All @@ -932,8 +956,8 @@ def _onesided_means_bootstrap(
Cell indices on GPU 0
k
Number of groups
selected_idx
Index of the selected group
selected_indices
Indices of the selected (control) groups
n_bootstrap
Number of bootstrap iterations
random_state
Expand Down Expand Up @@ -966,7 +990,7 @@ def _onesided_means_bootstrap(
cat_offsets=boot_cat_offsets,
cell_indices=boot_cell_indices,
k=k,
selected_idx=selected_idx,
selected_indices=selected_indices,
device_ids=device_ids,
)
all_results.append(onesided_means.get())
Expand Down
Loading