Skip to content

Commit d8d0e80

Browse files
committed
no rng warning
1 parent 65fb583 commit d8d0e80

File tree

2 files changed

+13
-33
lines changed

2 files changed

+13
-33
lines changed

src/scanpy/preprocessing/_pca/__init__.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,7 @@
1111
from ..._docs import doc_rng
1212
from ..._settings import settings
1313
from ..._utils import _doc_params, _empty, get_literal_vals, is_backed_type
14-
from ..._utils.random import (
15-
_accepts_legacy_random_state,
16-
_legacy_random_state,
17-
_LegacyRng,
18-
)
14+
from ..._utils.random import _accepts_legacy_random_state, _legacy_random_state
1915
from ...get import _check_mask, _get_obs_rep
2016
from .._docs import doc_mask_var_hvg
2117
from ._compat import _pca_compat_sparse
@@ -207,7 +203,6 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915
207203
"""
208204
logg_start = logg.info("computing PCA")
209205
rng = np.random.default_rng(rng)
210-
rng_is_default = isinstance(rng, _LegacyRng) and rng.arg == 0
211206
if (layer is not None or obsm is not None) and chunked:
212207
# Current chunking implementation relies on pca being called on X
213208
msg = "Cannot use `layer`/`obsm` and `chunked` at the same time."
@@ -247,7 +242,7 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915
247242
raise NotImplementedError(msg)
248243

249244
if chunked:
250-
if not zero_center or not rng_is_default or svd_solver not in {None, "arpack"}:
245+
if not zero_center or svd_solver not in {None, "arpack"}:
251246
logg.debug("Ignoring zero_center, rng, svd_solver")
252247

253248
incremental_pca_kwargs = dict()
@@ -296,14 +291,6 @@ def pca( # noqa: PLR0912, PLR0913, PLR0915
296291
elif isinstance(x._meta, CSBase) or svd_solver == "covariance_eigh":
297292
from ._dask import PCAEighDask
298293

299-
if not rng_is_default:
300-
dbg = (
301-
f"random_state={_legacy_random_state(rng)!r}"
302-
if isinstance(rng, _LegacyRng)
303-
else f"rng={rng!r}"
304-
)
305-
msg = f"Ignoring {dbg} when using a sparse dask array"
306-
warn(msg, UserWarning)
307294
if svd_solver not in {None, "covariance_eigh"}:
308295
msg = f"Ignoring {svd_solver=} when using a sparse dask array"
309296
warn(msg, UserWarning)

tests/test_pca.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import warnings
4-
from contextlib import ExitStack, nullcontext
4+
from contextlib import nullcontext
55
from typing import TYPE_CHECKING, Literal
66

77
import numpy as np
@@ -242,19 +242,17 @@ def test_pca_transform_randomized(array_type):
242242
a_pca_abs = np.abs(A_pca)
243243

244244
if isinstance(adata.X, DaskArray) and isinstance(adata.X._meta, CSBase):
245-
patterns = (
246-
r"Ignoring random_state=14 when using a sparse dask array",
247-
r"Ignoring svd_solver='randomized' when using a sparse dask array",
245+
ctx = pytest.warns(
246+
UserWarning,
247+
match=r"Ignoring svd_solver='randomized' when using a sparse dask array",
248248
)
249249
elif isinstance(adata.X, CSBase):
250-
patterns = [r"Ignoring.*'randomized"]
250+
ctx = pytest.warns(UserWarning, match=r"Ignoring.*'randomized")
251251
else:
252-
patterns = []
252+
ctx = nullcontext()
253253

254254
warnings.filterwarnings("error")
255-
with ExitStack() as stack:
256-
for pat in patterns:
257-
stack.enter_context(pytest.warns(UserWarning, match=pat))
255+
with ctx:
258256
sc.pp.pca(
259257
adata,
260258
n_comps=4,
@@ -339,15 +337,10 @@ def test_pca_reproducible(
339337
pbmc = pbmc3k_normalized()
340338
pbmc.X = array_type(pbmc.X)
341339

342-
with (
343-
pytest.warns(UserWarning, match=rf"Ignoring {rng_arg}=.*sparse dask array")
344-
if isinstance(pbmc.X, DaskArray) and isinstance(pbmc.X._meta, CSBase)
345-
else nullcontext()
346-
):
347-
a, b, c = (
348-
sc.pp.pca(pbmc, copy=True, dtype=np.float64, **{rng_arg: seed})
349-
for seed in (42, 42, 0)
350-
)
340+
a, b, c = (
341+
sc.pp.pca(pbmc, copy=True, dtype=np.float64, **{rng_arg: seed})
342+
for seed in (42, 42, 0)
343+
)
351344

352345
with subtests.test("reproducible"):
353346
assert_equal(a, b)

0 commit comments

Comments
 (0)