diff --git a/docs/release-notes/3864.fix.md b/docs/release-notes/3864.fix.md new file mode 100644 index 0000000000..d6a7bfabcf --- /dev/null +++ b/docs/release-notes/3864.fix.md @@ -0,0 +1 @@ +Fix {func}`~scanpy.pl.dpt_groups_pseudotime` {smaller}`P Angerer` diff --git a/src/scanpy/plotting/_tools/__init__.py b/src/scanpy/plotting/_tools/__init__.py index 12efc3c2ae..df09481c8d 100644 --- a/src/scanpy/plotting/_tools/__init__.py +++ b/src/scanpy/plotting/_tools/__init__.py @@ -279,7 +279,8 @@ def dpt_groups_pseudotime( show: bool | None = None, save: bool | str | None = None, marker: str | Sequence[str] = ".", -): + return_fig: bool = False, +) -> Figure | None: """Plot groups and pseudotime. Parameters @@ -292,9 +293,9 @@ def dpt_groups_pseudotime( Marker style. See :mod:`~matplotlib.markers` for details. """ - _, (ax_grp, ax_ord) = plt.subplots(2, 1) + fig, (ax_grp, ax_ord) = plt.subplots(2, 1) timeseries_subplot( - adata.obs["dpt_groups"].cat.codes, + adata.obs["dpt_groups"].cat.codes.to_numpy(), time=adata.obs["dpt_order"].values, color=np.asarray(adata.obs["dpt_groups"]), highlights_x=adata.uns["dpt_changepoints"], @@ -321,6 +322,8 @@ def dpt_groups_pseudotime( marker=marker, ) savefig_or_show("dpt_groups_pseudotime", save=save, show=show) + if return_fig: + return fig @old_positionals( diff --git a/src/scanpy/plotting/_utils.py b/src/scanpy/plotting/_utils.py index f09568c2a0..23b7c1835a 100644 --- a/src/scanpy/plotting/_utils.py +++ b/src/scanpy/plotting/_utils.py @@ -2,6 +2,7 @@ import warnings from collections.abc import Callable, Mapping, Sequence +from itertools import cycle, islice from typing import TYPE_CHECKING, Literal, TypedDict, overload import numpy as np @@ -153,41 +154,42 @@ def timeseries_subplot( # noqa: PLR0912, PLR0913 X with n columns, color is of length n. """ - if color is not None: - use_color_map = isinstance(color[0], float | np.floating) + use_color_map = color is not None and isinstance(color[0], float | np.floating) palette = default_palette(palette) x_range = np.arange(X.shape[0]) if time is None else time if X.ndim == 1: X = X[:, None] # noqa: N806 if X.shape[1] > 1: - colors = palette[: X.shape[1]].by_key()["color"] + colors = islice(cycle(palette.by_key()["color"]), X.shape[1]) subsets = [(x_range, X[:, i]) for i in range(X.shape[1])] elif use_color_map: colors = [color] subsets = [(x_range, X[:, 0])] else: levels, _ = np.unique(color, return_inverse=True) - colors = np.array(palette[: len(levels)].by_key()["color"]) + colors = islice(cycle(palette.by_key()["color"]), len(levels)) subsets = [(x_range[color == level], X[color == level, :]) for level in levels] if isinstance(marker, str): marker = [marker] if len(marker) != len(subsets) and len(marker) == 1: - marker = [marker[0] for _ in range(len(subsets))] + marker = [marker[0]] * len(subsets) + if not (has_var_names := (len(var_names) > 0)): + var_names = [""] * len(subsets) if ax is None: ax = plt.subplot() - for i, (x, y) in enumerate(subsets): + for (x, y), m, c, var_name in zip(subsets, marker, colors, var_names, strict=True): ax.scatter( x, y, - marker=marker[i], + marker=m, edgecolor="face", s=rcParams["lines.markersize"], - c=colors[i], - label=var_names[i] if len(var_names) > 0 else "", - cmap=color_map, + c=c, + label=var_name, rasterized=settings._vector_friendly, + **(dict(cmap=color_map) if use_color_map else {}), ) ylim = ax.get_ylim() for h in highlights_x: @@ -199,7 +201,7 @@ def timeseries_subplot( # noqa: PLR0912, PLR0913 ax.set_ylabel(ylabel) if yticks is not None: ax.set_yticks(yticks) - if len(var_names) > 0 and legend: + if has_var_names and legend: ax.legend(frameon=False) diff --git a/src/scanpy/tools/_dpt.py b/src/scanpy/tools/_dpt.py index 9d3e26acf4..1a654a2cd6 100644 --- a/src/scanpy/tools/_dpt.py +++ b/src/scanpy/tools/_dpt.py @@ -450,9 +450,8 @@ def select_segment(self, segs, segs_tips, segs_undecided) -> tuple[int, int]: # # if we did not normalize, there would be a danger of simply # assigning the highest score to the longest segment score = dseg[tips3[2]] / d_seg[tips3[0], tips3[1]] - score = ( - len(seg) if self.choose_largest_segment else score - ) # simply the number of points + # simply the number of points + score = len(seg) if self.choose_largest_segment else score logg.debug( f" group {iseg} score {score} n_points {len(seg)}" f"{' (too small)' if len(seg) < self.min_group_size else ''}" diff --git a/tests/_images/dpt_groups_pseudotime/expected.png b/tests/_images/dpt_groups_pseudotime/expected.png new file mode 100644 index 0000000000..c68a64c73f Binary files /dev/null and b/tests/_images/dpt_groups_pseudotime/expected.png differ diff --git a/tests/_images/dpt_timeseries/expected.png b/tests/_images/dpt_timeseries/expected.png new file mode 100644 index 0000000000..80a3852f9c Binary files /dev/null and b/tests/_images/dpt_timeseries/expected.png differ diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 0095812362..8acc2fa743 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -1339,12 +1339,39 @@ def test_scatter_embedding_add_outline_vmin_vmax_norm_ref(tmp_path, check_same_i ) -def test_timeseries(): +@pytest.fixture(scope="session") +def pbmc_68k_dpt_session() -> AnnData: adata = pbmc68k_reduced() sc.pp.neighbors(adata, n_neighbors=5, method="gauss", knn=False) - sc.tl.diffmap(adata) - sc.tl.dpt(adata, n_branchings=1, n_dcs=10) - sc.pl.dpt_timeseries(adata, as_heatmap=True, show=False) + sc.tl.leiden(adata, resolution=0.5, key_added="leiden_0_5", flavor="leidenalg") + adata.uns["iroot"] = np.flatnonzero(adata.obs["leiden_0_5"] == "0")[0] + sc.tl.diffmap(adata, n_comps=10) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", ".*invalid value encountered in scalar divide" + ) + sc.tl.dpt(adata, n_branchings=3) + return adata + + +@needs.leidenalg +@needs.igraph +@pytest.mark.parametrize( + "func", + [sc.pl.dpt_groups_pseudotime, sc.pl.dpt_timeseries], +) +def test_dpt_plots( + image_comparer, pbmc_68k_dpt_session: AnnData, func: Callable +) -> None: + save_and_compare_images = partial(image_comparer, ROOT, tol=15) + + adata = pbmc_68k_dpt_session.copy() + func( + adata, + show=False, + **(dict(as_heatmap=True) if func is sc.pl.dpt_timeseries else {}), + ) + save_and_compare_images(func.__name__) def test_scatter_raw(tmp_path):