Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3864.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix {func}`~scanpy.pl.dpt_groups_pseudotime` {smaller}`P Angerer`
9 changes: 6 additions & 3 deletions src/scanpy/plotting/_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,8 @@ def dpt_groups_pseudotime(
show: bool | None = None,
save: bool | str | None = None,
marker: str | Sequence[str] = ".",
):
return_fig: bool = False,
) -> Figure | None:
"""Plot groups and pseudotime.

Parameters
Expand All @@ -292,9 +293,9 @@ def dpt_groups_pseudotime(
Marker style. See :mod:`~matplotlib.markers` for details.

"""
_, (ax_grp, ax_ord) = plt.subplots(2, 1)
fig, (ax_grp, ax_ord) = plt.subplots(2, 1)
timeseries_subplot(
adata.obs["dpt_groups"].cat.codes,
adata.obs["dpt_groups"].cat.codes.to_numpy(),
time=adata.obs["dpt_order"].values,
color=np.asarray(adata.obs["dpt_groups"]),
highlights_x=adata.uns["dpt_changepoints"],
Expand All @@ -321,6 +322,8 @@ def dpt_groups_pseudotime(
marker=marker,
)
savefig_or_show("dpt_groups_pseudotime", save=save, show=show)
if return_fig:
return fig


@old_positionals(
Expand Down
24 changes: 13 additions & 11 deletions src/scanpy/plotting/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import warnings
from collections.abc import Callable, Mapping, Sequence
from itertools import cycle, islice
from typing import TYPE_CHECKING, Literal, TypedDict, overload

import numpy as np
Expand Down Expand Up @@ -153,41 +154,42 @@ def timeseries_subplot( # noqa: PLR0912, PLR0913
X with n columns, color is of length n.

"""
if color is not None:
use_color_map = isinstance(color[0], float | np.floating)
use_color_map = color is not None and isinstance(color[0], float | np.floating)
palette = default_palette(palette)
x_range = np.arange(X.shape[0]) if time is None else time
if X.ndim == 1:
X = X[:, None] # noqa: N806
if X.shape[1] > 1:
colors = palette[: X.shape[1]].by_key()["color"]
colors = islice(cycle(palette.by_key()["color"]), X.shape[1])
subsets = [(x_range, X[:, i]) for i in range(X.shape[1])]
elif use_color_map:
colors = [color]
subsets = [(x_range, X[:, 0])]
else:
levels, _ = np.unique(color, return_inverse=True)
colors = np.array(palette[: len(levels)].by_key()["color"])
colors = islice(cycle(palette.by_key()["color"]), len(levels))
subsets = [(x_range[color == level], X[color == level, :]) for level in levels]

if isinstance(marker, str):
marker = [marker]
if len(marker) != len(subsets) and len(marker) == 1:
marker = [marker[0] for _ in range(len(subsets))]
marker = [marker[0]] * len(subsets)
if not (has_var_names := (len(var_names) > 0)):
var_names = [""] * len(subsets)

if ax is None:
ax = plt.subplot()
for i, (x, y) in enumerate(subsets):
for (x, y), m, c, var_name in zip(subsets, marker, colors, var_names, strict=True):
ax.scatter(
x,
y,
marker=marker[i],
marker=m,
edgecolor="face",
s=rcParams["lines.markersize"],
c=colors[i],
label=var_names[i] if len(var_names) > 0 else "",
cmap=color_map,
c=c,
label=var_name,
rasterized=settings._vector_friendly,
**(dict(cmap=color_map) if use_color_map else {}),
)
ylim = ax.get_ylim()
for h in highlights_x:
Expand All @@ -199,7 +201,7 @@ def timeseries_subplot( # noqa: PLR0912, PLR0913
ax.set_ylabel(ylabel)
if yticks is not None:
ax.set_yticks(yticks)
if len(var_names) > 0 and legend:
if has_var_names and legend:
ax.legend(frameon=False)


Expand Down
5 changes: 2 additions & 3 deletions src/scanpy/tools/_dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,8 @@ def select_segment(self, segs, segs_tips, segs_undecided) -> tuple[int, int]: #
# if we did not normalize, there would be a danger of simply
# assigning the highest score to the longest segment
score = dseg[tips3[2]] / d_seg[tips3[0], tips3[1]]
score = (
len(seg) if self.choose_largest_segment else score
) # simply the number of points
# simply the number of points
score = len(seg) if self.choose_largest_segment else score
logg.debug(
f" group {iseg} score {score} n_points {len(seg)}"
f"{' (too small)' if len(seg) < self.min_group_size else ''}"
Expand Down
Binary file added tests/_images/dpt_groups_pseudotime/expected.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tests/_images/dpt_timeseries/expected.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
35 changes: 31 additions & 4 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,12 +1339,39 @@ def test_scatter_embedding_add_outline_vmin_vmax_norm_ref(tmp_path, check_same_i
)


def test_timeseries():
@pytest.fixture(scope="session")
def pbmc_68k_dpt_session() -> AnnData:
adata = pbmc68k_reduced()
sc.pp.neighbors(adata, n_neighbors=5, method="gauss", knn=False)
sc.tl.diffmap(adata)
sc.tl.dpt(adata, n_branchings=1, n_dcs=10)
sc.pl.dpt_timeseries(adata, as_heatmap=True, show=False)
sc.tl.leiden(adata, resolution=0.5, key_added="leiden_0_5", flavor="leidenalg")
adata.uns["iroot"] = np.flatnonzero(adata.obs["leiden_0_5"] == "0")[0]
sc.tl.diffmap(adata, n_comps=10)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", ".*invalid value encountered in scalar divide"
)
sc.tl.dpt(adata, n_branchings=3)
return adata


@needs.leidenalg
@needs.igraph
@pytest.mark.parametrize(
"func",
[sc.pl.dpt_groups_pseudotime, sc.pl.dpt_timeseries],
)
def test_dpt_plots(
image_comparer, pbmc_68k_dpt_session: AnnData, func: Callable
) -> None:
save_and_compare_images = partial(image_comparer, ROOT, tol=15)

adata = pbmc_68k_dpt_session.copy()
func(
adata,
show=False,
**(dict(as_heatmap=True) if func is sc.pl.dpt_timeseries else {}),
)
save_and_compare_images(func.__name__)


def test_scatter_raw(tmp_path):
Expand Down
Loading