diff --git a/src/scanpy/plotting/_baseplot_class.py b/src/scanpy/plotting/_baseplot_class.py index b643510509..68ceecbee0 100644 --- a/src/scanpy/plotting/_baseplot_class.py +++ b/src/scanpy/plotting/_baseplot_class.py @@ -2,17 +2,21 @@ from __future__ import annotations +import warnings from collections.abc import Mapping -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, NamedTuple, overload from warnings import warn import numpy as np +import pandas as pd +from anndata import AnnData from matplotlib import colormaps, gridspec from matplotlib import pyplot as plt from .. import logging as logg from .._compat import old_positionals from .._utils import _empty +from ..get._aggregated import aggregate from ._anndata import ( VarGroups, _plot_dendrogram, @@ -23,15 +27,14 @@ from ._utils import check_colornorm, make_grid_spec if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Iterable, Sequence from typing import Literal, Self - import pandas as pd - from anndata import AnnData from matplotlib.axes import Axes from matplotlib.colors import Colormap, Normalize from .._utils import Empty + from ..get._aggregated import AggType from ._utils import ColorLike, _AxesSubplot _VarNames = str | Sequence[str] @@ -145,7 +148,8 @@ def __init__( # noqa: PLR0913 self.var_group_rotation = var_group_rotation self.width, self.height = figsize if figsize is not None else (None, None) - self.categories, self.obs_tidy = _prepare_dataframe( + # still need this as pandas handles this procedure more optimally + self.categories, obs_tidy = _prepare_dataframe( adata, self.var_names, groupby, @@ -155,6 +159,16 @@ def __init__( # noqa: PLR0913 layer=layer, gene_symbols=gene_symbols, ) + # we are going to save a view of adata as we still need it for filtering in dotplot by expression_cutoff and mean_only_expressed + # also AnnData is a little lighter than DataFrame + # and we can replace self.adata as it is used elsewhere + self._group_key = obs_tidy.index.name + self._view = AnnData( + X=obs_tidy.values, + obs={self._group_key: obs_tidy.index}, + var=pd.DataFrame(index=var_names), + ) + if len(self.categories) > self.MAX_NUM_CATEGORIES: warn( f"Over {self.MAX_NUM_CATEGORIES} categories found. " @@ -164,16 +178,16 @@ def __init__( # noqa: PLR0913 ) if categories_order is not None and ( - set(self.obs_tidy.index.categories) != set(categories_order) + set(self.categories) != set(categories_order) ): logg.error( "Please check that the categories given by " "the `order` parameter match the categories that " "want to be reordered.\n\n" "Mismatch: " - f"{set(self.obs_tidy.index.categories).difference(categories_order)}\n\n" + f"{set(self.categories).difference(categories_order)}\n\n" f"Given order categories: {categories_order}\n\n" - f"{groupby} categories: {list(self.obs_tidy.index.categories)}\n" + f"{groupby} categories: {list(self.categories)}\n" ) return @@ -397,10 +411,12 @@ def add_totals( _sort = sort is not None _ascending = sort == "ascending" - counts_df = self.obs_tidy.index.value_counts(sort=_sort, ascending=_ascending) + counts_df = self._view.obs[self._group_key].value_counts( + sort=_sort, ascending=_ascending + ) if _sort: - self.categories_order = counts_df.index + self.categories_order = list(counts_df.index) self.plot_group_extra = { "kind": "group_totals", @@ -411,6 +427,67 @@ def add_totals( } return self + @overload + def _agg_df( + self, func: AggType, mask: np.ndarray | None = None + ) -> pd.DataFrame: ... + + @overload + def _agg_df( + self, func: Iterable[AggType], mask: np.ndarray | None = None + ) -> dict[str, pd.DataFrame]: ... + + def _agg_df( + self, func: AggType | Iterable[AggType], mask: np.ndarray | None = None + ) -> pd.DataFrame | dict[str, pd.DataFrame]: + """Aggregate `self._view` by `self._group_key`. + + Run `func` on X and eturn a DataFrame (or dict of DataFrames) with `index=self.categories`, `columns=self.var_names`. + If `mask` is provided, it should be shape `(n_groups, n_vars)` and will + overwrite view.X before aggregating (useful for dot-cutoff logic). + """ + # make a fresh copy so we never mutate the master view + view = self._view.copy() + if mask is not None: + view.X = mask.astype(view.X.dtype) + + ag = aggregate( + view, + by=self._group_key, + func=func, + axis="obs", + ) + # if single func, return one DataFrame + if isinstance(func, str): + arr = ag.layers[func] + return pd.DataFrame(arr, index=self.categories, columns=self.var_names) + # if multiple, return a dict of DataFrames + out = {} + for f in func: + arr = ag.layers[f] + out[f] = pd.DataFrame(arr, index=self.categories, columns=self.var_names) + return out + + def _scale_df( + self, df: pd.DataFrame, standard_scale: Literal["var", "group", None] = None + ) -> pd.DataFrame: + """Scale `df` based on `standard_scale` parameter.""" + if standard_scale == "obs": + standard_scale = "group" + msg = "`standard_scale='obs'` is deprecated, use `standard_scale='group'` instead" + warnings.warn(msg, FutureWarning, stacklevel=2) + if standard_scale == "group": + df = df.sub(df.min(1), axis=0) + df = df.div(df.max(1), axis=0).fillna(0) + elif standard_scale == "var": + df -= df.min(0) + df = (df / df.max(0)).fillna(0) + elif standard_scale is None: + pass + else: + logg.warning("Unknown type for standard_scale, ignored") + return df + @old_positionals("cmap") def style(self, *, cmap: Colormap | str | None | Empty = _empty) -> Self: r"""Set visual style parameters. diff --git a/src/scanpy/plotting/_dotplot.py b/src/scanpy/plotting/_dotplot.py index 8c15574a1d..4849da0ce8 100644 --- a/src/scanpy/plotting/_dotplot.py +++ b/src/scanpy/plotting/_dotplot.py @@ -11,13 +11,7 @@ from .._utils import _doc_params, _empty from ._baseplot_class import BasePlot, doc_common_groupby_plot_args from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm -from ._utils import ( - _dk, - check_colornorm, - fix_kwds, - make_grid_spec, - savefig_or_show, -) +from ._utils import _dk, check_colornorm, fix_kwds, make_grid_spec, savefig_or_show if TYPE_CHECKING: from collections.abc import Mapping, Sequence @@ -187,46 +181,31 @@ def __init__( # noqa: PLR0913 norm=norm, **kwds, ) - - # for if category defined by groupby (if any) compute for each var_name - # 1. the fraction of cells in the category having a value >expression_cutoff - # 2. the mean value over the category - - # 1. compute fraction of cells having value > expression_cutoff - # transform obs_tidy into boolean matrix using the expression_cutoff - obs_bool = self.obs_tidy > expression_cutoff - - # compute the sum per group which in the boolean matrix this is the number - # of values >expression_cutoff, and divide the result by the total number of - # values in the group (given by `count()`) if dot_size_df is None: - dot_size_df = ( - obs_bool.groupby(level=0, observed=True).sum() - / obs_bool.groupby(level=0, observed=True).count() - ) + if expression_cutoff > 0: + mask = (expression_cutoff < self._view.X).astype(self._view.X.dtype) + dot_size_df = self._agg_df("mean", mask=mask) + else: + df_all = self._agg_df("count_nonzero") + # count_nonzero → raw counts, divide by group sizes + group_sizes = ( + self._view.obs[self._group_key] + .value_counts() + .loc[self.categories] + .values + ) + dot_size_df = df_all.div(group_sizes, axis=0) if dot_color_df is None: - # 2. compute mean expression value value - if mean_only_expressed: - dot_color_df = ( - self.obs_tidy.mask(~obs_bool) - .groupby(level=0, observed=True) - .mean() - .fillna(0) - ) - else: - dot_color_df = self.obs_tidy.groupby(level=0, observed=True).mean() - - if standard_scale == "group": - dot_color_df = dot_color_df.sub(dot_color_df.min(1), axis=0) - dot_color_df = dot_color_df.div(dot_color_df.max(1), axis=0).fillna(0) - elif standard_scale == "var": - dot_color_df -= dot_color_df.min(0) - dot_color_df = (dot_color_df / dot_color_df.max(0)).fillna(0) - elif standard_scale is None: - pass + if mean_only_expressed and expression_cutoff > 0: + mask = expression_cutoff < self._view.X + df_sum = self._agg_df("sum", mask=mask) + expr_counts = dot_size_df.values * group_sizes[:, None] + dot_color_df = df_sum.div(expr_counts).fillna(0) else: - logg.warning("Unknown type for standard_scale, ignored") + dot_color_df = self._agg_df("mean") + + dot_color_df = self._scale_df(dot_color_df, standard_scale) else: # check that both matrices have the same shape if dot_color_df.shape != dot_size_df.shape: @@ -255,12 +234,14 @@ def __init__( # noqa: PLR0913 # using the order from the doc_size_df dot_color_df = dot_color_df.loc[dot_size_df.index][dot_size_df.columns] - self.dot_color_df, self.dot_size_df = ( - df.loc[ - categories_order if categories_order is not None else self.categories - ] - for df in (dot_color_df, dot_size_df) - ) + # reorder rows + self.dot_size_df = dot_size_df.loc[ + categories_order if categories_order is not None else self.categories + ] + self.dot_color_df = dot_color_df.loc[ + categories_order if categories_order is not None else self.categories + ] + self.standard_scale = standard_scale # Set default style parameters diff --git a/src/scanpy/plotting/_matrixplot.py b/src/scanpy/plotting/_matrixplot.py index e767e224cc..b6f30c3089 100644 --- a/src/scanpy/plotting/_matrixplot.py +++ b/src/scanpy/plotting/_matrixplot.py @@ -5,16 +5,11 @@ import numpy as np from matplotlib import colormaps, rcParams -from .. import logging as logg from .._compat import old_positionals from .._settings import settings from .._utils import _doc_params, _empty from ._baseplot_class import BasePlot, doc_common_groupby_plot_args -from ._docs import ( - doc_common_plot_args, - doc_show_save_ax, - doc_vboundnorm, -) +from ._docs import doc_common_plot_args, doc_show_save_ax, doc_vboundnorm from ._utils import _dk, check_colornorm, fix_kwds, savefig_or_show if TYPE_CHECKING: @@ -167,29 +162,13 @@ def __init__( # noqa: PLR0913 ) if values_df is None: - # compute mean value - values_df = ( - self.obs_tidy.groupby(level=0, observed=True) - .mean() - .loc[ - self.categories_order - if self.categories_order is not None - else self.categories - ] - ) + values_df = self._agg_df("mean") + + values_df = self._scale_df(values_df, standard_scale) - if standard_scale == "group": - values_df = values_df.sub(values_df.min(1), axis=0) - values_df = values_df.div(values_df.max(1), axis=0).fillna(0) - elif standard_scale == "var": - values_df -= values_df.min(0) - values_df = (values_df / values_df.max(0)).fillna(0) - elif standard_scale is None: - pass - else: - logg.warning("Unknown type for standard_scale, ignored") - - self.values_df = values_df + self.values_df = values_df.loc[ + categories_order if categories_order is not None else self.categories + ] self.cmap = self.DEFAULT_COLORMAP self.edge_color = self.DEFAULT_EDGE_COLOR diff --git a/src/scanpy/plotting/_stacked_violin.py b/src/scanpy/plotting/_stacked_violin.py index 404605b5a2..bb41b3a2bb 100644 --- a/src/scanpy/plotting/_stacked_violin.py +++ b/src/scanpy/plotting/_stacked_violin.py @@ -9,7 +9,6 @@ from matplotlib.colors import is_color_like from packaging.version import Version -from .. import logging as logg from .._compat import old_positionals from .._settings import settings from .._utils import _doc_params, _empty @@ -225,22 +224,11 @@ def __init__( # noqa: PLR0913 norm=norm, **kwds, ) - - if standard_scale == "obs": - standard_scale = "group" - msg = "`standard_scale='obs'` is deprecated, use `standard_scale='group'` instead" - warnings.warn(msg, FutureWarning, stacklevel=2) - if standard_scale == "group": - self.obs_tidy = self.obs_tidy.sub(self.obs_tidy.min(1), axis=0) - self.obs_tidy = self.obs_tidy.div(self.obs_tidy.max(1), axis=0).fillna(0) - elif standard_scale == "var": - self.obs_tidy -= self.obs_tidy.min(0) - self.obs_tidy = (self.obs_tidy / self.obs_tidy.max(0)).fillna(0) - elif standard_scale is None: - pass - else: - logg.warning("Unknown type for standard_scale, ignored") - + # scale before aggregation + X = self._view.X.astype(float) + X = self._scale_df(X, standard_scale) + # replace view.X with the scaled values (NaNs => 0) + self._view.X = np.nan_to_num(X) # Set default style parameters self.cmap = self.DEFAULT_COLORMAP self.row_palette = self.DEFAULT_ROW_PALETTE @@ -386,22 +374,21 @@ def _mainplot(self, ax: Axes): # work on a copy of the dataframes. This is to avoid changes # on the original data frames after repetitive calls to the # StackedViolin object, for example once with swap_axes and other without - _matrix = self.obs_tidy.copy() + _matrix = pd.DataFrame( + self._view.X, index=self._view.obs[self._group_key], columns=self.var_names + ) if self.var_names_idx_order is not None: _matrix = _matrix.iloc[:, self.var_names_idx_order] # get mean values for color and transform to color values # using colormap - _color_df = ( - _matrix.groupby(level=0, observed=True) - .median() - .loc[ - self.categories_order - if self.categories_order is not None - else self.categories - ] - ) + _color_df = self._agg_df("median").loc[ + self.categories_order + if self.categories_order is not None + else self.categories + ] + if self.are_axes_swapped: _color_df = _color_df.T