Commit 0aec2ef

feat: add obsm parameter to normalize_total and pca (#3863)
1 parent 5766024 commit 0aec2ef

18 files changed, +211 -158 lines changed


docs/release-notes/3863.fix.md

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+Allow operating on :attr:`~anndata.AnnData.obsm` arrays in {func}`scanpy.pp.normalize_total` and {func}`scanpy.pp.pca` {smaller}`P Angerer`

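The two public functions named in the note gain the `obsm=` keyword that the diffs below thread through the internal helpers. A rough usage sketch (the toy data and the `alt_counts` key are invented for illustration; it assumes the new keyword simply redirects both functions to the named `.obsm` array):

```python
import anndata as ad
import numpy as np
import scanpy as sc

# Toy AnnData: 100 cells × 50 genes, plus a cell-aligned count matrix in .obsm
adata = ad.AnnData(np.random.poisson(1.0, size=(100, 50)).astype(np.float32))
adata.obsm["alt_counts"] = np.random.poisson(1.0, size=(100, 20)).astype(np.float32)

# Operate on the .obsm array instead of adata.X
sc.pp.normalize_total(adata, target_sum=1e4, obsm="alt_counts")
sc.pp.pca(adata, n_comps=10, obsm="alt_counts")
```
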
src/scanpy/_utils/__init__.py

Lines changed: 46 additions & 50 deletions
@@ -9,7 +9,6 @@
import importlib.util
import inspect
import re
-import sys
import warnings
from contextlib import suppress
from enum import Enum
@@ -64,6 +63,37 @@
_ForT = TypeVar("_ForT", bound=Callable | type)


+__all__ = [
+    "AssoResult",
+    "Empty",
+    "NeighborsView",
+    "_choose_graph",
+    "_doc_params",
+    "_empty",
+    "_resolve_axis",
+    "annotate_doc_types",
+    "axis_mul_or_truediv",
+    "axis_nnz",
+    "check_array_function_arguments",
+    "check_nonnegative_integers",
+    "check_presence_download",
+    "check_use_raw",
+    "compute_association_matrix_of_groups",
+    "descend_classes_and_funcs",
+    "ensure_igraph",
+    "get_literal_vals",
+    "indent",
+    "is_backed_type",
+    "is_backed_type",
+    "raise_not_implemented_error_if_backed_type",
+    "renamed_arg",
+    "sanitize_anndata",
+    "select_groups",
+    "update_params",
+    "warn_once",
+]
+
+
LegacyUnionType = type(Union[int, str])  # noqa: UP007


@@ -88,7 +118,7 @@ def ensure_igraph() -> None:
        raise ImportError(msg)


-def getdoc(c_or_f: Callable | type) -> str | None:
+def _getdoc(c_or_f: Callable | type) -> str | None:
    if getattr(c_or_f, "__doc__", None) is None:
        return None
    doc = inspect.getdoc(c_or_f)
@@ -142,7 +172,7 @@ def wrapper(*args, **kwargs):
    return decorator


-def _import_name(full_name: str) -> Any:
+def import_name(full_name: str) -> Any:
    from importlib import import_module

    parts = full_name.split(".")
@@ -197,7 +227,7 @@ def descend_classes_and_funcs(mod: ModuleType, root: str, encountered=None):
def annotate_doc_types(mod: ModuleType, root: str):
    for c_or_f in descend_classes_and_funcs(mod, root):
        with suppress(AttributeError):
-            c_or_f.getdoc = partial(getdoc, c_or_f)
+            c_or_f.getdoc = partial(_getdoc, c_or_f)


_leading_whitespace_re = re.compile("(^[ ]*)(?:[^ \n])", re.MULTILINE)
@@ -227,7 +257,7 @@ def dec(obj: _ForT) -> _ForT:
    return dec


-def _check_array_function_arguments(**kwargs):
+def check_array_function_arguments(**kwargs):
    """Check for invalid arguments when an array is passed.

    Helper for functions that work on either AnnData objects or array-likes.
@@ -239,7 +269,7 @@ def _check_array_function_arguments(**kwargs):
        raise TypeError(msg)


-def _check_use_raw(
+def check_use_raw(
    adata: AnnData,
    use_raw: None | bool,  # noqa: FBT001
    *,
@@ -540,14 +570,14 @@ def get_literal_vals(typ: UnionType | Any) -> KeysView[Any]:
Scaling_T = TypeVar("Scaling_T", DaskArray, np.ndarray)


-def broadcast_axis(divisor: Scaling_T, axis: Literal[0, 1]) -> Scaling_T:
+def _broadcast_axis(divisor: Scaling_T, axis: Literal[0, 1]) -> Scaling_T:
    divisor = np.ravel(divisor)
    if axis:
        return divisor[None, :]
    return divisor[:, None]


-def check_op(op):
+def _check_op(op) -> None:
    if op not in {truediv, mul}:
        msg = f"{op} not one of truediv or mul"
        raise ValueError(msg)
@@ -564,8 +594,8 @@ def axis_mul_or_truediv(
    allow_divide_by_zero: bool = True,
    out: ArrayLike | None = None,
) -> np.ndarray:
-    check_op(op)
-    scaling_array = broadcast_axis(scaling_array, axis)
+    _check_op(op)
+    scaling_array = _broadcast_axis(scaling_array, axis)
    if op is mul:
        return np.multiply(x, scaling_array, out=out)
    if not allow_divide_by_zero:
@@ -584,7 +614,7 @@ def _(
    allow_divide_by_zero: bool = True,
    out: CSBase | None = None,
) -> CSBase:
-    check_op(op)
+    _check_op(op)
    if out is not None and x.data is not out.data:
        msg = "`out` argument provided but not equal to X. This behavior is not supported for sparse matrix scaling."
        raise ValueError(msg)
@@ -621,7 +651,7 @@ def new_data_op(x):
    ).T


-def make_axis_chunks(
+def _make_axis_chunks(
    x: DaskArray, axis: Literal[0, 1]
) -> tuple[tuple[int], tuple[int]]:
    if axis == 0:
@@ -640,14 +670,14 @@ def _(
    allow_divide_by_zero: bool = True,
    out: None = None,
) -> DaskArray:
-    check_op(op)
+    _check_op(op)
    if out is not None:
        msg = "`out` is not `None`. Do not do in-place modifications on dask arrays."
        raise TypeError(msg)

    import dask.array as da

-    scaling_array = broadcast_axis(scaling_array, axis)
+    scaling_array = _broadcast_axis(scaling_array, axis)
    row_scale = axis == 0
    column_scale = axis == 1

@@ -668,11 +698,11 @@ def _(
        warnings.warn(
            "Rechunking scaling_array in user operation", UserWarning, stacklevel=3
        )
-        scaling_array = scaling_array.rechunk(make_axis_chunks(x, axis))
+        scaling_array = scaling_array.rechunk(_make_axis_chunks(x, axis))
    else:
        scaling_array = da.from_array(
            scaling_array,
-            chunks=make_axis_chunks(x, axis),
+            chunks=_make_axis_chunks(x, axis),
        )
    return da.map_blocks(
        axis_mul_or_truediv,
@@ -802,27 +832,6 @@ def select_groups(
    return groups_order_subset, groups_masks_obs


-def warn_with_traceback(  # noqa: PLR0917
-    message, category, filename, lineno, file=None, line=None
-) -> None:
-    """Get full tracebacks when warning is raised by setting.
-
-    warnings.showwarning = warn_with_traceback
-
-    See Also
-    --------
-    https://stackoverflow.com/questions/22373927/get-traceback-of-warnings
-
-    """
-    import traceback
-
-    traceback.print_stack()
-    log = (  # noqa: F841  # TODO Does this need fixing?
-        file if hasattr(file, "write") else sys.stderr
-    )
-    settings.write(warnings.formatwarning(message, category, filename, lineno, line))
-
-
def warn_once(msg: str, category: type[Warning], stacklevel: int = 0) -> None:
    warnings.warn(msg, category, stacklevel=stacklevel + 1)
    # You'd think `'once'` works, but it doesn't at the repl and in notebooks
@@ -837,19 +846,6 @@ def check_presence_download(filename: Path, backup_url):
        _download(backup_url, filename)


-def lazy_import(full_name):
-    """Import a module in a way that it’s only executed on member access."""
-    try:
-        return sys.modules[full_name]
-    except KeyError:
-        spec = importlib.util.find_spec(full_name)
-        module = importlib.util.module_from_spec(spec)
-        loader = importlib.util.LazyLoader(spec.loader)
-        # Make module with proper locking and get it inserted into sys.modules.
-        loader.exec_module(module)
-        return module
-
-
# --------------------------------------------------------------------------------
# Neighbors
# --------------------------------------------------------------------------------

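Most of the changes above are rename-only (helpers gain or lose a leading underscore to match the new `__all__`), but `_broadcast_axis` is worth a second look: it reshapes a 1-D scaling vector so plain NumPy broadcasting applies it per row or per column. A standalone sketch of that broadcasting rule (not scanpy's own code, just the same idea):

```python
import numpy as np


def broadcast_axis(divisor: np.ndarray, axis: int) -> np.ndarray:
    # Same trick as the (now private) scanpy helper: flatten, then insert the
    # axis that should broadcast. axis=1 -> shape (1, n) scales columns,
    # axis=0 -> shape (n, 1) scales rows.
    divisor = np.ravel(divisor)
    return divisor[None, :] if axis else divisor[:, None]


x = np.arange(6, dtype=float).reshape(2, 3)
row_sums = x.sum(axis=1)                       # one value per row: [3., 12.]
x_scaled = x / broadcast_axis(row_sums, axis=0)
print(x_scaled.sum(axis=1))                    # [1. 1.] (each row now sums to 1)
```
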
src/scanpy/experimental/pp/_normalization.py

Lines changed: 4 additions & 3 deletions
@@ -96,6 +96,7 @@ def normalize_pearson_residuals(
    clip: float | None = None,
    check_values: bool = True,
    layer: str | None = None,
+    obsm: str | None = None,
    inplace: bool = True,
    copy: bool = False,
) -> AnnData | dict[str, np.ndarray] | None:
@@ -138,8 +139,8 @@ def normalize_pearson_residuals(
        adata = adata.copy()

    view_to_actual(adata)
-    x = _get_obs_rep(adata, layer=layer)
-    computed_on = layer if layer else "adata.X"
+    x = _get_obs_rep(adata, layer=layer, obsm=obsm)
+    computed_on = layer or obsm or "adata.X"

    msg = f"computing analytic Pearson residuals on {computed_on}"
    start = logg.info(msg)
@@ -148,7 +149,7 @@ def normalize_pearson_residuals(
    settings_dict = dict(theta=theta, clip=clip, computed_on=computed_on)

    if inplace:
-        _set_obs_rep(adata, residuals, layer=layer)
+        _set_obs_rep(adata, residuals, layer=layer, obsm=obsm)
        adata.uns["pearson_residuals_normalization"] = settings_dict
    else:
        results_dict = dict(X=residuals, **settings_dict)

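The same `obsm` pass-through reaches the experimental Pearson-residual normalization shown above: `computed_on` now records the `.obsm` key, and with `inplace=True` (the default) the residuals are written back into that slot. A toy sketch (data and key name invented for illustration):

```python
import anndata as ad
import numpy as np
import scanpy as sc

adata = ad.AnnData(np.random.poisson(1.0, size=(60, 30)).astype(np.float32))
adata.obsm["alt_counts"] = np.random.poisson(1.0, size=(60, 10)).astype(np.float32)

# Overwrites adata.obsm["alt_counts"] with the residuals
sc.experimental.pp.normalize_pearson_residuals(adata, obsm="alt_counts")
print(adata.uns["pearson_residuals_normalization"]["computed_on"])  # "alt_counts"
```
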
src/scanpy/get/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,13 +6,15 @@
from .get import (
    _check_mask,
    _get_obs_rep,
+    _ObsRep,
    _set_obs_rep,
    obs_df,
    rank_genes_groups_df,
    var_df,
)

__all__ = [
+    "_ObsRep",
    "_check_mask",
    "_get_obs_rep",
    "_set_obs_rep",

src/scanpy/get/get.py

Lines changed: 33 additions & 31 deletions
@@ -2,7 +2,7 @@

from __future__ import annotations

-from typing import TYPE_CHECKING, TypeVar
+from typing import TYPE_CHECKING, TypedDict, TypeVar

import numpy as np
import pandas as pd
@@ -13,7 +13,7 @@

if TYPE_CHECKING:
    from collections.abc import Collection, Iterable
-    from typing import Any, Literal
+    from typing import Any, Literal, Unpack

    from anndata._core.sparse_dataset import BaseCompressedSparseDataset
    from anndata._core.views import ArrayView
@@ -399,43 +399,45 @@ def var_df(
    return df


+class _ObsRep(TypedDict, total=False):
+    use_raw: bool
+    layer: str | None
+    obsm: str | None
+    obsp: str | None
+
+
def _get_obs_rep(
-    adata: AnnData,
-    *,
-    use_raw: bool = False,
-    layer: str | None = None,
-    obsm: str | None = None,
-    obsp: str | None = None,
+    adata: AnnData, **choices: Unpack[_ObsRep]
) -> (
    np.ndarray | CSBase | pd.DataFrame | ArrayView | BaseCompressedSparseDataset | None
):
    """Choose array aligned with obs annotation."""
    # https://github.com/scverse/scanpy/issues/1546
-    if not isinstance(use_raw, bool):
+    if not isinstance(use_raw := choices.get("use_raw", False), bool):
        msg = f"use_raw expected to be bool, was {type(use_raw)}."
        raise TypeError(msg)
-
-    is_layer = layer is not None
-    is_raw = use_raw is not False
-    is_obsm = obsm is not None
-    is_obsp = obsp is not None
-    choices_made = sum((is_layer, is_raw, is_obsm, is_obsp))
-    assert choices_made in {0, 1}
-    if choices_made == 0:
-        return adata.X
-    if is_layer:
-        return adata.layers[layer]
-    if use_raw:
-        return adata.raw.X
-    if is_obsm:
-        return adata.obsm[obsm]
-    if is_obsp:
-        return adata.obsp[obsp]
-    msg = (
-        "That was unexpected. Please report this bug at:\n\n\t"
-        "https://github.com/scverse/scanpy/issues"
-    )
-    raise AssertionError(msg)
+    assert choices.keys() <= {"layer", "use_raw", "obsm", "obsp"}
+
+    # we do this here so the `case _` branch knows which ones are valid for the
+    # respective calling function. E.g. `_get_obs_rep(adata, layer="a", obsm="b")`
+    # will say that “Only one of `layer` or `obsm` can be specified.”
+    match [(k, v) for k, v in choices.items() if v not in {None, False}]:
+        case []:
+            return adata.X
+        # can’t use {"key": v} as match expression, since they allow additional entries
+        case [("layer", layer)]:
+            return adata.layers[layer]
+        case [("use_raw", True)]:
+            return adata.raw.X
+        case [("obsm", obsm)]:
+            return adata.obsm[obsm]
+        case [("obsp", obsp)]:
+            return adata.obsp[obsp]
+        case _:
+            valid = [f"`{k}`" for k in choices]
+            valid[-1] = f"or {valid[-1]}"
+            msg = f"Only one of {', '.join(valid)} can be specified."
+            raise ValueError(msg)


def _set_obs_rep(

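The rewritten `_get_obs_rep` combines a `TypedDict` with `Unpack` (Python ≥ 3.11) so type checkers still see the individual keyword arguments, while the structural `match` over the explicitly passed choices lets the error message name only the keywords the caller actually used. A self-contained sketch of that pattern with generic names (not scanpy's code):

```python
from typing import TypedDict, Unpack


class _Choices(TypedDict, total=False):
    layer: str | None
    obsm: str | None


def pick_one(data: dict[str, list[int]], **choices: Unpack[_Choices]) -> list[int]:
    # Keep only the choices the caller actually set, then require at most one.
    match [(k, v) for k, v in choices.items() if v is not None]:
        case []:
            return data["X"]  # default representation
        case [("layer", layer)]:
            return data[f"layer:{layer}"]
        case [("obsm", obsm)]:
            return data[f"obsm:{obsm}"]
        case _:
            valid = [f"`{k}`" for k in choices]
            valid[-1] = f"or {valid[-1]}"
            msg = f"Only one of {', '.join(valid)} can be specified."
            raise ValueError(msg)


data = {"X": [1, 2], "layer:a": [3, 4], "obsm:b": [5, 6]}
print(pick_one(data, layer="a"))       # [3, 4]
# pick_one(data, layer="a", obsm="b")  # ValueError: only one choice allowed
```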