1 change: 1 addition & 0 deletions docs/release-notes/0.14.1.md
@@ -7,6 +7,7 @@
* Adds `base` parameter to `pp.log1p` and `exclude_highly_expressed`/`max_fraction` parameters to `pp.normalize_total`, matching scanpy's API {pr}`565` {smaller}`S Dicks`
* Adds Dask support and categorical covariate support to `pp.regress_out` {pr}`567` {smaller}`S Dicks`
* Adds `wilcoxon_binned` method to `tl.rank_genes_groups` — a histogram-based approximate Wilcoxon rank-sum test that supports Dask arrays {pr}`568` {smaller}`S Dicks`
* Refactors `tl.rank_genes_groups` internals to use categorical integer codes instead of boolean mask matrices {pr}`570` {smaller}`S Dicks`

```{rubric} Performance
```
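The new `wilcoxon_binned` method from {pr}`568` is reachable through the public `tl.rank_genes_groups` API. A minimal usage sketch, assuming a log1p-normalized `adata` and a `"louvain"` key in `adata.obs` (both placeholders; `method`, `n_bins`, `chunk_size`, and `bin_range` are the parameters this diff actually shows):

```python
import rapids_singlecell as rsc

# Histogram-based approximate Wilcoxon rank-sum test; unlike
# method="wilcoxon", this path also accepts Dask-backed matrices.
rsc.tl.rank_genes_groups(
    adata,                     # placeholder AnnData
    groupby="louvain",         # placeholder cluster key in adata.obs
    method="wilcoxon_binned",
    n_bins=1024,               # assumed value; None uses the default
)
```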
77 changes: 32 additions & 45 deletions src/rapids_singlecell/tools/_rank_genes_groups/_core.py
@@ -16,7 +16,7 @@
from ._utils import EPS, _select_groups, _select_top_n

if TYPE_CHECKING:
from collections.abc import Generator, Iterable
from collections.abc import Iterable

from anndata import AnnData
from numpy.typing import NDArray
@@ -42,40 +42,31 @@ def __init__(
) -> None:
# Handle groups parameter
if groups == "all" or groups is None:
groups_order: Literal["all"] | list[str] = "all"
selected: list | None = None
elif isinstance(groups, str | int):
msg = "Specify a sequence of groups"
raise ValueError(msg)
else:
groups_order = list(groups)
if isinstance(groups_order[0], int):
groups_order = [str(n) for n in groups_order]
if reference != "rest" and reference not in set(groups_order):
groups_order += [reference]
selected = list(groups)
if len(selected) > 0 and isinstance(selected[0], int):
selected = [str(n) for n in selected]
if reference != "rest" and reference not in set(selected):
selected.append(reference)

self.labels = pd.Series(adata.obs[groupby]).reset_index(drop=True)
all_categories = self.labels.cat.categories

if reference != "rest" and reference not in set(self.labels.cat.categories):
cats = self.labels.cat.categories.tolist()
if reference != "rest" and str(reference) not in {
str(c) for c in all_categories
}:
cats = all_categories.tolist()
msg = f"reference = {reference} needs to be one of groupby = {cats}."
raise ValueError(msg)

self.groups_order, self.groups_masks_obs = _select_groups(
self.labels, groups_order
self.groups_order, self.group_codes, self.group_sizes = _select_groups(
self.labels, selected
)

# Validate singlet groups
invalid_groups = set()
for name, mask in zip(self.groups_order, self.groups_masks_obs, strict=True):
if np.count_nonzero(mask) < 2:
invalid_groups.add(str(name))
if invalid_groups:
msg = (
f"Could not calculate statistics for groups {', '.join(invalid_groups)} "
"since they only contain one sample."
)
raise ValueError(msg)

# Get data matrix
if layer is not None:
if use_raw is True:
@@ -105,7 +96,7 @@

self.ireference = None
if reference != "rest":
self.ireference = np.where(self.groups_order == reference)[0][0]
self.ireference = int(np.where(self.groups_order == str(reference))[0][0])

# Set up expm1 function based on log base
self.is_log1p = "log1p" in adata.uns
@@ -115,10 +106,6 @@
else:
self.expm1_func = np.expm1

# For logreg
self.grouping_mask = self.labels.isin(pd.Series(self.groups_order))
self.grouping = self.labels.loc[self.grouping_mask]

# For basic stats
self.comp_pts = comp_pts
self.means: np.ndarray | None = None
@@ -190,8 +177,8 @@ def _basic_stats(self) -> None:
else:
result = agg.count_mean_var_dense(dof=1)

# Map results to selected groups order
cat_names = list(self.labels.cat.categories)
# Map Aggregate results (indexed by original category) to selected groups
cat_names = [str(c) for c in self.labels.cat.categories]

means = np.zeros((n_groups, n_genes), dtype=np.float64)
vars_ = np.zeros((n_groups, n_genes), dtype=np.float64)
@@ -204,8 +191,10 @@
means[idx] = cp.asnumpy(result["mean"][cat_idx])
vars_[idx] = cp.asnumpy(result["var"][cat_idx])
if self.comp_pts:
n_cells = self.groups_masks_obs[idx].sum()
pts[idx] = cp.asnumpy(result["count_nonzero"][cat_idx]) / n_cells
pts[idx] = (
cp.asnumpy(result["count_nonzero"][cat_idx])
/ self.group_sizes[idx]
)

self.means = means
# Clip tiny negative variances to 0 (floating-point precision artifacts)
@@ -222,7 +211,7 @@ def _basic_stats(self) -> None:
else None
)

n_cells_per_group = np.array([mask.sum() for mask in self.groups_masks_obs])
n_cells_per_group = self.group_sizes
total_sum = self.means * n_cells_per_group[:, None]
total_sum_all = total_sum.sum(axis=0)

@@ -367,15 +356,15 @@ def _accumulate_chunk_stats_with_ref(

def t_test(
self, method: Literal["t-test", "t-test_overestim_var"]
) -> Generator[tuple[int, NDArray, NDArray], None, None]:
) -> list[tuple[int, NDArray, NDArray]]:
"""Compute t-test statistics using Welch's t-test."""
from ._ttest import t_test

return t_test(self, method)

def wilcoxon(
self, *, tie_correct: bool, chunk_size: int | None = None
) -> Generator[tuple[int, NDArray, NDArray], None, None]:
) -> list[tuple[int, NDArray, NDArray]]:
"""Compute Wilcoxon rank-sum test statistics."""
from ._wilcoxon import wilcoxon

@@ -387,15 +376,15 @@ def wilcoxon_binned(
n_bins: int | None = None,
chunk_size: int | None = None,
bin_range: Literal["log1p", "auto"] | None = None,
) -> Generator[tuple[int, NDArray, NDArray], None, None]:
"""Histogram-based approximate Wilcoxon rank-sum test (one-vs-rest)."""
) -> list[tuple[int, NDArray, NDArray]]:
"""Histogram-based approximate Wilcoxon rank-sum test."""
from ._wilcoxon_binned import wilcoxon_binned

return wilcoxon_binned(
self, n_bins=n_bins, chunk_size=chunk_size, bin_range=bin_range
)

def logreg(self, **kwds) -> Generator[tuple[int, NDArray, None], None, None]:
def logreg(self, **kwds) -> list[tuple[int, NDArray, None]]:
"""Compute logistic regression scores."""
from ._logreg import logreg

@@ -422,20 +411,18 @@ def compute_statistics(
}:
self.X = X_to_GPU(self.X)
if method in {"t-test", "t-test_overestim_var"}:
generate_test_results = self.t_test(method)
test_results = self.t_test(method)
elif method == "wilcoxon":
if isinstance(self.X, DaskArray):
msg = "Wilcoxon test is not supported for Dask arrays. Please convert your data to CuPy arrays."
raise ValueError(msg)
generate_test_results = self.wilcoxon(
tie_correct=tie_correct, chunk_size=chunk_size
)
test_results = self.wilcoxon(tie_correct=tie_correct, chunk_size=chunk_size)
elif method == "wilcoxon_binned":
generate_test_results = self.wilcoxon_binned(
test_results = self.wilcoxon_binned(
n_bins=n_bins, chunk_size=chunk_size, bin_range=bin_range
)
elif method == "logreg":
generate_test_results = self.logreg(**kwds)
test_results = self.logreg(**kwds)
else:
assert_never(method)

@@ -444,7 +431,7 @@ def compute_statistics(
# Collect all stats data first to avoid DataFrame fragmentation
stats_data: dict[tuple[str, str], np.ndarray] = {}

for group_index, scores, pvals in generate_test_results:
for group_index, scores, pvals in test_results:
group_name = str(self.groups_order[group_index])

if n_genes_user is not None:
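To make the mask-to-codes refactor concrete: judging from how `group_codes` is consumed (`group_codes < n_groups` in `_logreg.py`, `group_sizes[idx]` in `_basic_stats`), `_select_groups` now appears to assign each cell the positional index of its selected group and a sentinel code for everything else. A rough stand-in under that assumption (hypothetical, not the actual implementation):

```python
import numpy as np
import pandas as pd

def select_groups_sketch(labels: pd.Series, selected: list[str] | None):
    """Hypothetical stand-in for _select_groups, for illustration only.

    Cells in a selected group get that group's positional code
    0..n_groups-1; all remaining cells share the sentinel n_groups,
    so `group_codes < n_groups` recovers exactly the cells under test.
    """
    cats = [str(c) for c in labels.cat.categories]
    groups_order = np.asarray(cats if selected is None else selected)
    n_groups = len(groups_order)
    code_of = {g: i for i, g in enumerate(groups_order)}
    group_codes = np.fromiter(
        (code_of.get(str(v), n_groups) for v in labels), dtype=np.int32
    )
    group_sizes = np.bincount(group_codes, minlength=n_groups + 1)[:n_groups]
    return groups_order, group_codes, group_sizes
```

One int32 code per cell replaces an `(n_groups, n_obs)` boolean matrix, which is the saving the release-note entry for {pr}`570` alludes to.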
28 changes: 13 additions & 15 deletions src/rapids_singlecell/tools/_rank_genes_groups/_logreg.py
@@ -3,30 +3,25 @@
from typing import TYPE_CHECKING

import cupy as cp
import numpy as np

from rapids_singlecell._compat import DaskArray, _meta_dense

if TYPE_CHECKING:
from collections.abc import Generator

from numpy.typing import NDArray

from ._core import _RankGenes


def logreg(rg: _RankGenes, **kwds) -> Generator[tuple[int, NDArray, None], None, None]:
def logreg(rg: _RankGenes, **kwds) -> list[tuple[int, NDArray, None]]:
"""Compute logistic regression scores."""
if len(rg.groups_order) == 1:
msg = "Cannot perform logistic regression on a single cluster."
raise ValueError(msg)

X = rg.X[rg.grouping_mask.values, :]

grouping_logreg = rg.grouping.cat.codes.to_numpy().astype(X.dtype)
uniques = np.unique(grouping_logreg)
for idx, cat in enumerate(uniques):
grouping_logreg[np.where(grouping_logreg == cat)] = idx
n_groups = len(rg.groups_order)
selected = rg.group_codes < n_groups
X = rg.X[selected, :]
grouping_logreg = rg.group_codes[selected].astype(X.dtype)

if isinstance(X, DaskArray):
import dask.array as da
@@ -43,16 +38,19 @@ def logreg(rg: _RankGenes, **kwds) -> Generator[tuple[int, NDArray, None], None,
clf = LogisticRegression(**kwds)
clf.fit(X, grouping_logreg)
scores_all = cp.array(clf.coef_)
if len(rg.groups_order) == scores_all.shape[1]:
if n_groups == scores_all.shape[1]:
scores_all = scores_all.T

for igroup, _group in enumerate(rg.groups_order):
if len(rg.groups_order) <= 2:
results: list[tuple[int, NDArray, None]] = []
for igroup in range(n_groups):
if n_groups <= 2:
scores = scores_all[0].get()
else:
scores = scores_all[igroup].get()

yield igroup, scores, None
results.append((igroup, scores, None))

if len(rg.groups_order) <= 2:
if n_groups <= 2:
break

return results
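Design note on the `n_groups <= 2` branches: binary logistic regression in scikit-learn (and cuML, whose `coef_` follows the same shape convention this code's `scores_all.shape[1]` check relies on) stores a single coefficient row, so both groups share one score vector and the loop stops after the first group. A quick illustration using scikit-learn as a stand-in:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = rng.normal(size=(100, 5))

# Binary labels: one shared coefficient row
binary = LogisticRegression().fit(X, rng.integers(0, 2, size=100))
print(binary.coef_.shape)  # (1, 5)

# Multiclass labels: one row per class
multi = LogisticRegression().fit(X, rng.integers(0, 3, size=100))
print(multi.coef_.shape)  # (3, 5)
```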
20 changes: 11 additions & 9 deletions src/rapids_singlecell/tools/_rank_genes_groups/_ttest.py
@@ -5,33 +5,33 @@
import numpy as np

if TYPE_CHECKING:
from collections.abc import Generator

from numpy.typing import NDArray

from ._core import _RankGenes


def t_test(
rg: _RankGenes, method: Literal["t-test", "t-test_overestim_var"]
) -> Generator[tuple[int, NDArray, NDArray], None, None]:
) -> list[tuple[int, NDArray, NDArray]]:
"""Compute t-test statistics using Welch's t-test."""
from scipy import stats

rg._basic_stats()

for group_index, (mask_obs, mean_group, var_group) in enumerate(
zip(rg.groups_masks_obs, rg.means, rg.vars, strict=True)
):
results: list[tuple[int, NDArray, NDArray]] = []

for group_index in range(len(rg.groups_order)):
if rg.ireference is not None and group_index == rg.ireference:
continue

ns_group = np.count_nonzero(mask_obs)
mean_group = rg.means[group_index]
var_group = rg.vars[group_index]
ns_group = int(rg.group_sizes[group_index])

if rg.ireference is not None:
mean_rest = rg.means[rg.ireference]
var_rest = rg.vars[rg.ireference]
ns_other = np.count_nonzero(rg.groups_masks_obs[rg.ireference])
ns_other = int(rg.group_sizes[rg.ireference])
else:
mean_rest = rg.means_rest[group_index]
var_rest = rg.vars_rest[group_index]
@@ -62,4 +62,6 @@ def t_test(
scores[np.isnan(scores)] = 0
pvals[np.isnan(pvals)] = 1

yield group_index, scores, pvals
results.append((group_index, scores, pvals))

return results
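For reference, the statistic assembled here from per-group summary stats is Welch's unequal-variances t-test with Welch–Satterthwaite degrees of freedom. A self-contained sketch of what the loop computes per group (not the module's exact code; in scanpy's formulation, `t-test_overestim_var` additionally substitutes the group's own size for `n_r`, inflating the rest-variance term):

```python
import numpy as np
from scipy import stats

def welch_t(mean_g, var_g, n_g, mean_r, var_r, n_r):
    """Welch's t scores and two-sided p-values from summary statistics."""
    with np.errstate(invalid="ignore", divide="ignore"):
        se2 = var_g / n_g + var_r / n_r          # squared standard error
        scores = (mean_g - mean_r) / np.sqrt(se2)
        # Welch–Satterthwaite degrees of freedom
        dof = se2**2 / (
            (var_g / n_g) ** 2 / (n_g - 1) + (var_r / n_r) ** 2 / (n_r - 1)
        )
    pvals = 2 * stats.t.sf(np.abs(scores), dof)
    scores = np.where(np.isnan(scores), 0.0, scores)  # constant genes
    pvals = np.where(np.isnan(pvals), 1.0, pvals)
    return scores, pvals
```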