Skip to content

Commit 3d1623f

Browse files
committed
Backport PR #3864: fix: various DPT plot fixes
1 parent 3a9a0ec commit 3d1623f

File tree

7 files changed

+54
-21
lines changed

7 files changed

+54
-21
lines changed

docs/release-notes/3864.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix {func}`~scanpy.pl.dpt_groups_pseudotime` {smaller}`P Angerer`

src/scanpy/plotting/_tools/__init__.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,8 @@ def dpt_groups_pseudotime(
279279
show: bool | None = None,
280280
save: bool | str | None = None,
281281
marker: str | Sequence[str] = ".",
282-
):
282+
return_fig: bool = False,
283+
) -> Figure | None:
283284
"""Plot groups and pseudotime.
284285
285286
Parameters
@@ -292,9 +293,9 @@ def dpt_groups_pseudotime(
292293
Marker style. See :mod:`~matplotlib.markers` for details.
293294
294295
"""
295-
_, (ax_grp, ax_ord) = plt.subplots(2, 1)
296+
fig, (ax_grp, ax_ord) = plt.subplots(2, 1)
296297
timeseries_subplot(
297-
adata.obs["dpt_groups"].cat.codes,
298+
adata.obs["dpt_groups"].cat.codes.to_numpy(),
298299
time=adata.obs["dpt_order"].values,
299300
color=np.asarray(adata.obs["dpt_groups"]),
300301
highlights_x=adata.uns["dpt_changepoints"],
@@ -321,6 +322,8 @@ def dpt_groups_pseudotime(
321322
marker=marker,
322323
)
323324
savefig_or_show("dpt_groups_pseudotime", save=save, show=show)
325+
if return_fig:
326+
return fig
324327

325328

326329
@old_positionals(

src/scanpy/plotting/_utils.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import warnings
44
from collections.abc import Callable, Mapping, Sequence
5+
from itertools import cycle, islice
56
from typing import TYPE_CHECKING, Literal, TypedDict, overload
67

78
import numpy as np
@@ -153,41 +154,42 @@ def timeseries_subplot( # noqa: PLR0912, PLR0913
153154
X with n columns, color is of length n.
154155
155156
"""
156-
if color is not None:
157-
use_color_map = isinstance(color[0], float | np.floating)
157+
use_color_map = color is not None and isinstance(color[0], float | np.floating)
158158
palette = default_palette(palette)
159159
x_range = np.arange(X.shape[0]) if time is None else time
160160
if X.ndim == 1:
161161
X = X[:, None] # noqa: N806
162162
if X.shape[1] > 1:
163-
colors = palette[: X.shape[1]].by_key()["color"]
163+
colors = islice(cycle(palette.by_key()["color"]), X.shape[1])
164164
subsets = [(x_range, X[:, i]) for i in range(X.shape[1])]
165165
elif use_color_map:
166166
colors = [color]
167167
subsets = [(x_range, X[:, 0])]
168168
else:
169169
levels, _ = np.unique(color, return_inverse=True)
170-
colors = np.array(palette[: len(levels)].by_key()["color"])
170+
colors = islice(cycle(palette.by_key()["color"]), len(levels))
171171
subsets = [(x_range[color == level], X[color == level, :]) for level in levels]
172172

173173
if isinstance(marker, str):
174174
marker = [marker]
175175
if len(marker) != len(subsets) and len(marker) == 1:
176-
marker = [marker[0] for _ in range(len(subsets))]
176+
marker = [marker[0]] * len(subsets)
177+
if not (has_var_names := (len(var_names) > 0)):
178+
var_names = [""] * len(subsets)
177179

178180
if ax is None:
179181
ax = plt.subplot()
180-
for i, (x, y) in enumerate(subsets):
182+
for (x, y), m, c, var_name in zip(subsets, marker, colors, var_names, strict=True):
181183
ax.scatter(
182184
x,
183185
y,
184-
marker=marker[i],
186+
marker=m,
185187
edgecolor="face",
186188
s=rcParams["lines.markersize"],
187-
c=colors[i],
188-
label=var_names[i] if len(var_names) > 0 else "",
189-
cmap=color_map,
189+
c=c,
190+
label=var_name,
190191
rasterized=settings._vector_friendly,
192+
**(dict(cmap=color_map) if use_color_map else {}),
191193
)
192194
ylim = ax.get_ylim()
193195
for h in highlights_x:
@@ -199,7 +201,7 @@ def timeseries_subplot( # noqa: PLR0912, PLR0913
199201
ax.set_ylabel(ylabel)
200202
if yticks is not None:
201203
ax.set_yticks(yticks)
202-
if len(var_names) > 0 and legend:
204+
if has_var_names and legend:
203205
ax.legend(frameon=False)
204206

205207

src/scanpy/tools/_dpt.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,8 @@ def select_segment(self, segs, segs_tips, segs_undecided) -> tuple[int, int]: #
450450
# if we did not normalize, there would be a danger of simply
451451
# assigning the highest score to the longest segment
452452
score = dseg[tips3[2]] / d_seg[tips3[0], tips3[1]]
453-
score = (
454-
len(seg) if self.choose_largest_segment else score
455-
) # simply the number of points
453+
# simply the number of points
454+
score = len(seg) if self.choose_largest_segment else score
456455
logg.debug(
457456
f" group {iseg} score {score} n_points {len(seg)}"
458457
f"{' (too small)' if len(seg) < self.min_group_size else ''}"
10.1 KB
Loading
68.6 KB
Loading

tests/test_plotting.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import warnings
34
from functools import partial
45
from itertools import chain, combinations, repeat
56
from pathlib import Path
@@ -1320,12 +1321,39 @@ def test_scatter_embedding_add_outline_vmin_vmax_norm_ref(tmp_path, check_same_i
13201321
)
13211322

13221323

1323-
def test_timeseries():
1324+
@pytest.fixture(scope="session")
1325+
def pbmc_68k_dpt_session() -> AnnData:
13241326
adata = pbmc68k_reduced()
13251327
sc.pp.neighbors(adata, n_neighbors=5, method="gauss", knn=False)
1326-
sc.tl.diffmap(adata)
1327-
sc.tl.dpt(adata, n_branchings=1, n_dcs=10)
1328-
sc.pl.dpt_timeseries(adata, as_heatmap=True)
1328+
sc.tl.leiden(adata, resolution=0.5, key_added="leiden_0_5", flavor="leidenalg")
1329+
adata.uns["iroot"] = np.flatnonzero(adata.obs["leiden_0_5"] == "0")[0]
1330+
sc.tl.diffmap(adata, n_comps=10)
1331+
with warnings.catch_warnings():
1332+
warnings.filterwarnings(
1333+
"ignore", ".*invalid value encountered in scalar divide"
1334+
)
1335+
sc.tl.dpt(adata, n_branchings=3)
1336+
return adata
1337+
1338+
1339+
@needs.leidenalg
1340+
@needs.igraph
1341+
@pytest.mark.parametrize(
1342+
"func",
1343+
[sc.pl.dpt_groups_pseudotime, sc.pl.dpt_timeseries],
1344+
)
1345+
def test_dpt_plots(
1346+
image_comparer, pbmc_68k_dpt_session: AnnData, func: Callable
1347+
) -> None:
1348+
save_and_compare_images = partial(image_comparer, ROOT, tol=15)
1349+
1350+
adata = pbmc_68k_dpt_session.copy()
1351+
func(
1352+
adata,
1353+
show=False,
1354+
**(dict(as_heatmap=True) if func is sc.pl.dpt_timeseries else {}),
1355+
)
1356+
save_and_compare_images(func.__name__)
13291357

13301358

13311359
def test_scatter_raw(tmp_path):

0 commit comments

Comments
 (0)