Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 75 additions & 54 deletions src/scanpy/plotting/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import math
from collections import OrderedDict
from collections.abc import Collection, Mapping, Sequence
from itertools import pairwise, product
Expand Down Expand Up @@ -51,7 +52,6 @@
from matplotlib.axes import Axes
from matplotlib.colors import Colormap, ListedColormap, Normalize
from numpy.typing import NDArray
from seaborn import FacetGrid
from seaborn.matrix import ClusterGrid

from .._utils import Empty
Expand Down Expand Up @@ -753,13 +753,14 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915
xlabel: str = "",
ylabel: str | Sequence[str] | None = None,
rotation: float | None = None,
ncols: int | None = None,
show: bool | None = None,
ax: Axes | None = None,
# deprecated
save: bool | str | None = None,
scale: DensityNorm | Empty = _empty,
**kwds,
) -> Axes | FacetGrid | None:
) -> Axes | Sequence[Axes] | None:
"""Violin plot.

Wraps :func:`seaborn.violinplot` for :class:`~anndata.AnnData`.
Expand Down Expand Up @@ -802,17 +803,21 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915
Label of the x axis. Defaults to `groupby` if `rotation` is `None`,
otherwise, no label is shown.
ylabel
Label of the y axis. If `None` and `groupby` is `None`, defaults
to `'value'`. If `None` and `groubpy` is not `None`, defaults to `keys`.
Label of the y axis.
rotation
Rotation of xtick labels.
ncols
Number of columns for arranging multiple plots.
If `None`, all panels are placed in a single row.
{show_save_ax}
**kwds
Are passed to :func:`~seaborn.violinplot`.

Returns
-------
A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
Axes or list of Axes
If `show=False`, returns the `matplotlib` Axes object(s) used for
plotting. If `show=True`, returns `None`.

Examples
--------
Expand Down Expand Up @@ -873,13 +878,19 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915
del scale

if isinstance(ylabel, str | NoneType):
ylabel = [ylabel] * (1 if groupby is None else len(keys))
ylabel = "" if ylabel is None else ylabel
if groupby is None and multi_panel:
ylabel = [ylabel] * len(keys)
else:
ylabel = [ylabel] * (1 if groupby is None else len(keys))

if groupby is None:
if len(ylabel) != 1:
msg = f"Expected number of y-labels to be `1`, found `{len(ylabel)}`."
expected = len(keys) if multi_panel else 1
if len(ylabel) != expected:
msg = f"Expected {expected} y-labels, got {len(ylabel)}."
raise ValueError(msg)
elif len(ylabel) != len(keys):
msg = f"Expected number of y-labels to be `{len(keys)}`, found `{len(ylabel)}`."
msg = f"Expected {len(keys)} y-labels, got {len(ylabel)}."
raise ValueError(msg)

if groupby is not None:
Expand Down Expand Up @@ -911,56 +922,62 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915
x = groupby
ys = keys

if multi_panel and groupby is None and len(ys) == 1:
# This is a quick and dirty way for adapting scales across several
# keys if groupby is None.
y = ys[0]

g: sns.axisgrid.FacetGrid = sns.catplot(
y=y,
data=obs_tidy,
kind="violin",
density_norm=density_norm,
col=x,
col_order=keys,
sharey=False,
cut=0,
inner=None,
**kwds,
)
# set default violin parameters
kwds.setdefault("cut", 0)
kwds.setdefault("inner")

if stripplot:
grouped_df = obs_tidy.groupby(x, observed=True)
for ax_id, key in zip(range(g.axes.shape[1]), keys, strict=True):
sns.stripplot(
y=y,
data=grouped_df.get_group(key),
jitter=jitter,
size=size,
color="black",
ax=g.axes[0, ax_id],
)
if log:
g.set(yscale="log")
g.set_titles(col_template="{col_name}").set_xlabels("")
if rotation is not None:
for ax_base in g.axes[0]:
ax_base.tick_params(axis="x", labelrotation=rotation)
else:
# set by default the violin plot cut=0 to limit the extend
# of the violin plot (see stacked_violin code) for more info.
kwds.setdefault("cut", 0)
kwds.setdefault("inner")
if ax is None:
panels = keys if multi_panel else ["x"] if groupby is None else keys

if ax is None:
if ncols is not None and len(panels) > 1:
n_panels = len(panels)
n_rows = math.ceil(n_panels / ncols)
_fig, axs = plt.subplots(n_rows, ncols)
axs = axs.flatten()[:n_panels]
else:
axs, _, _, _ = setup_axes(
ax,
panels=["x"] if groupby is None else keys,
panels=panels,
show_ticks=True,
right_margin=0.3,
)
else:
axs = [ax]
else:
axs = [ax]

if len(axs) > 1:
axs[0].figure.subplots_adjust(hspace=0.5, wspace=0.4)

if groupby is None and multi_panel:
for ax_base, key, ylab in zip(axs, keys, ylabel, strict=True):
sns.violinplot(
y=key,
data=obs_df,
orient="vertical",
density_norm=density_norm,
ax=ax_base,
**kwds,
)

if stripplot:
sns.stripplot(
y=key,
data=obs_df,
jitter=jitter,
color="black",
size=size,
ax=ax_base,
)

ax_base.set_xlabel("")
ax_base.set_title(str(key).replace("_", " "))
if ylab is not None:
ax_base.set_ylabel(ylab)
if log:
ax_base.set_yscale("log")
if rotation is not None:
ax_base.tick_params(axis="x", labelrotation=rotation)

else:
for ax_base, y, ylab in zip(axs, ys, ylabel, strict=True):
sns.violinplot(
x=x,
Expand All @@ -972,6 +989,7 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915
ax=ax_base,
**kwds,
)

if stripplot:
sns.stripplot(
x=x,
Expand All @@ -983,6 +1001,11 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915
size=size,
ax=ax_base,
)

if multi_panel or groupby is not None:
ax_base.set_title(str(y).replace("_", " "))
else:
ax_base.set_title("")
if xlabel == "" and groupby is not None and rotation is None:
xlabel = groupby.replace("_", " ")
ax_base.set_xlabel(xlabel)
Expand All @@ -996,8 +1019,6 @@ def violin( # noqa: PLR0912, PLR0913, PLR0915
_utils.savefig_or_show("violin", show=show, save=save)
if show:
return None
if multi_panel and groupby is None and len(ys) == 1:
return g
if len(axs) == 1:
return axs[0]
return axs
Expand Down
Loading