diff --git a/docs/release-notes/3764.feature.md b/docs/release-notes/3764.feature.md new file mode 100644 index 0000000000..445454d736 --- /dev/null +++ b/docs/release-notes/3764.feature.md @@ -0,0 +1 @@ +{func}`scanpy.pl.dotplot` now supports a `group_cmaps` parameter for custom per-group coloring. {smaller}`R Baber` diff --git a/src/scanpy/plotting/_dotplot.py b/src/scanpy/plotting/_dotplot.py index 8c15574a1d..eb3da5f260 100644 --- a/src/scanpy/plotting/_dotplot.py +++ b/src/scanpy/plotting/_dotplot.py @@ -162,6 +162,7 @@ def __init__( # noqa: PLR0913 vmax: float | None = None, vcenter: float | None = None, norm: Normalize | None = None, + group_cmaps: Mapping[str, str] | None = None, **kwds, ) -> None: BasePlot.__init__( @@ -188,13 +189,59 @@ def __init__( # noqa: PLR0913 **kwds, ) + # Set default style parameters + self.cmap = self.DEFAULT_COLORMAP + self.dot_max = self.DEFAULT_DOT_MAX + self.dot_min = self.DEFAULT_DOT_MIN + self.smallest_dot = self.DEFAULT_SMALLEST_DOT + self.largest_dot = self.DEFAULT_LARGEST_DOT + self.color_on = self.DEFAULT_COLOR_ON + self.size_exponent = self.DEFAULT_SIZE_EXPONENT + self.grid = False + self.plot_x_padding = self.DEFAULT_PLOT_X_PADDING + self.plot_y_padding = self.DEFAULT_PLOT_Y_PADDING + + self.dot_edge_color = self.DEFAULT_DOT_EDGECOLOR + self.dot_edge_lw = self.DEFAULT_DOT_EDGELW + + # set legend defaults + self.color_legend_title = self.DEFAULT_COLOR_LEGEND_TITLE + self.size_title = self.DEFAULT_SIZE_LEGEND_TITLE + self.legends_width = self.DEFAULT_LEGENDS_WIDTH + self.show_size_legend = True + self.show_colorbar = True + + # Store parameters needed by helper methods and prepare the dot data. + self.standard_scale = standard_scale + self.expression_cutoff = expression_cutoff + self.mean_only_expressed = mean_only_expressed + self.group_cmaps = group_cmaps + + self.dot_color_df, self.dot_size_df = self._prepare_dot_data( + dot_color_df, dot_size_df + ) + + # If group_cmaps is used, validate that all plotted groups have a defined colormap. + if self.group_cmaps is not None: + plotted_groups = set(self.dot_color_df.index) + defined_groups = set(self.group_cmaps.keys()) + missing_groups = plotted_groups - defined_groups + if missing_groups: + msg = ( + "The following groups are in the plot data but are missing from the `group_cmaps` dictionary. " + f"Please define a colormap for them: {sorted(missing_groups)}" + ) + raise ValueError(msg) + + def _prepare_dot_data(self, dot_color_df, dot_size_df): + """Calculate the dataframes for dot size and color.""" # for if category defined by groupby (if any) compute for each var_name # 1. the fraction of cells in the category having a value >expression_cutoff # 2. the mean value over the category # 1. compute fraction of cells having value > expression_cutoff # transform obs_tidy into boolean matrix using the expression_cutoff - obs_bool = self.obs_tidy > expression_cutoff + obs_bool = self.obs_tidy > self.expression_cutoff # compute the sum per group which in the boolean matrix this is the number # of values >expression_cutoff, and divide the result by the total number of @@ -207,7 +254,7 @@ def __init__( # noqa: PLR0913 if dot_color_df is None: # 2. compute mean expression value value - if mean_only_expressed: + if self.mean_only_expressed: dot_color_df = ( self.obs_tidy.mask(~obs_bool) .groupby(level=0, observed=True) @@ -217,24 +264,25 @@ def __init__( # noqa: PLR0913 else: dot_color_df = self.obs_tidy.groupby(level=0, observed=True).mean() - if standard_scale == "group": + if self.standard_scale == "group": dot_color_df = dot_color_df.sub(dot_color_df.min(1), axis=0) dot_color_df = dot_color_df.div(dot_color_df.max(1), axis=0).fillna(0) - elif standard_scale == "var": + elif self.standard_scale == "var": dot_color_df -= dot_color_df.min(0) dot_color_df = (dot_color_df / dot_color_df.max(0)).fillna(0) - elif standard_scale is None: + elif self.standard_scale is None: pass else: logg.warning("Unknown type for standard_scale, ignored") else: # check that both matrices have the same shape if dot_color_df.shape != dot_size_df.shape: - logg.error( - "the given dot_color_df data frame has a different shape than " + msg = ( + "The given dot_color_df data frame has a different shape than " "the data frame used for the dot size. Both data frames need " - "to have the same index and columns" + "to have the same index and columns." ) + raise ValueError(msg) # Because genes (columns) can be duplicated (e.g. when the # same gene is reported as marker gene in two clusters) @@ -255,35 +303,16 @@ def __init__( # noqa: PLR0913 # using the order from the doc_size_df dot_color_df = dot_color_df.loc[dot_size_df.index][dot_size_df.columns] - self.dot_color_df, self.dot_size_df = ( + dot_color_df, dot_size_df = ( df.loc[ - categories_order if categories_order is not None else self.categories + self.categories_order + if self.categories_order is not None + else self.categories ] for df in (dot_color_df, dot_size_df) ) - self.standard_scale = standard_scale - - # Set default style parameters - self.cmap = self.DEFAULT_COLORMAP - self.dot_max = self.DEFAULT_DOT_MAX - self.dot_min = self.DEFAULT_DOT_MIN - self.smallest_dot = self.DEFAULT_SMALLEST_DOT - self.largest_dot = self.DEFAULT_LARGEST_DOT - self.color_on = self.DEFAULT_COLOR_ON - self.size_exponent = self.DEFAULT_SIZE_EXPONENT - self.grid = False - self.plot_x_padding = self.DEFAULT_PLOT_X_PADDING - self.plot_y_padding = self.DEFAULT_PLOT_Y_PADDING - self.dot_edge_color = self.DEFAULT_DOT_EDGECOLOR - self.dot_edge_lw = self.DEFAULT_DOT_EDGELW - - # set legend defaults - self.color_legend_title = self.DEFAULT_COLOR_LEGEND_TITLE - self.size_title = self.DEFAULT_SIZE_LEGEND_TITLE - self.legends_width = self.DEFAULT_LEGENDS_WIDTH - self.show_size_legend = True - self.show_colorbar = True + return dot_color_df, dot_size_df @old_positionals( "cmap", @@ -542,12 +571,28 @@ def _plot_legend(self, legend_ax, return_ax_dict, normalize): # third row: spacer to avoid color and size legend titles to overlap # fourth row: colorbar + # Define base heights for legend components as a fraction of figure height cbar_legend_height = self.min_figure_height * 0.08 size_legend_height = self.min_figure_height * 0.27 spacer_height = self.min_figure_height * 0.3 + # If group_cmaps is used, dynamically calculate the total height needed for all colorbars + if self.group_cmaps is not None: + per_cbar_height = ( + self.min_figure_height * 0.12 + ) # Use a slightly larger height for better spacing + n_cbars = len(self.dot_color_df.index) + cbar_legend_height = per_cbar_height * n_cbars + + # Calculate the height of the top spacer to push content down + top_spacer_height = ( + self.height - size_legend_height - cbar_legend_height - spacer_height + ) + top_spacer_height = max(top_spacer_height, 0) # prevent negative height + + # Create the 4-row GridSpec for the legend area height_ratios = [ - self.height - size_legend_height - cbar_legend_height - spacer_height, + top_spacer_height, size_legend_height, spacer_height, cbar_legend_height, @@ -555,17 +600,74 @@ def _plot_legend(self, legend_ax, return_ax_dict, normalize): fig, legend_gs = make_grid_spec( legend_ax, nrows=4, ncols=1, height_ratios=height_ratios ) + # Hide the frame of the main legend container axis for a cleaner look + legend_ax.set_axis_off() + # Plot size legend into the second row of the grid if self.show_size_legend: size_legend_ax = fig.add_subplot(legend_gs[1]) self._plot_size_legend(size_legend_ax) return_ax_dict["size_legend_ax"] = size_legend_ax + # Plot colorbar(s) into the fourth row of the grid if self.show_colorbar: - color_legend_ax = fig.add_subplot(legend_gs[3]) + if self.group_cmaps is None: + color_legend_ax = fig.add_subplot(legend_gs[3]) + self._plot_colorbar(color_legend_ax, normalize) + return_ax_dict["color_legend_ax"] = color_legend_ax + else: + self._plot_stacked_colorbars(fig, legend_gs[3], normalize) + return_ax_dict["color_legend_ax"] = legend_ax + + def _plot_stacked_colorbars(self, fig, colorbar_area_spec, normalize): + """Plot the stacked colorbars legend when using group_cmaps.""" + import matplotlib as mpl + import matplotlib.colorbar + from matplotlib.cm import ScalarMappable + + plotted_groups = self.dot_color_df.index + groups_to_plot = list(plotted_groups) + n_cbars = len(groups_to_plot) + + # Create a sub-grid just for the colorbars + # Create an empty column to keep colorbars at 3/4 of legend width (1.5 like default with dp.legend_width = 2.0) + colorbar_gs = colorbar_area_spec.subgridspec( + n_cbars, 2, hspace=0.6, width_ratios=[3, 1] + ) + + # Create a dedicated normalizer for the legend + vmin = self.dot_color_df.values.min() + vmax = self.dot_color_df.values.max() + legend_norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + + for i, group_name in enumerate(groups_to_plot): + ax = fig.add_subplot( + colorbar_gs[i, 0] + ) # Place the colorbar Axes in the first, wider column + cmap = colormaps.get_cmap(self.group_cmaps[group_name]) + mappable = ScalarMappable(norm=legend_norm, cmap=cmap) + + cb = matplotlib.colorbar.Colorbar( + ax, mappable=mappable, orientation="horizontal" + ) + cb.ax.xaxis.set_tick_params(labelsize="small") + + ax.text( + 1.1, + 0.5, + group_name, + ha="left", + va="center", + transform=ax.transAxes, + fontsize="small", + ) + + if i == 0: + cb.ax.set_title(self.color_legend_title, fontsize="small") - self._plot_colorbar(color_legend_ax, normalize) - return_ax_dict["color_legend_ax"] = color_legend_ax + if i < n_cbars - 1: + cb.ax.xaxis.set_ticklabels([]) + cb.ax.xaxis.set_ticks([]) def _mainplot(self, ax: Axes): # work on a copy of the dataframes. This is to avoid changes @@ -591,7 +693,9 @@ def _mainplot(self, ax: Axes): _size_df, _color_df, ax, + are_axes_swapped=self.are_axes_swapped, cmap=self.cmap, + group_cmaps=self.group_cmaps, color_on=self.color_on, dot_max=self.dot_max, dot_min=self.dot_min, @@ -621,6 +725,8 @@ def _dotplot( # noqa: PLR0912, PLR0913, PLR0915 dot_ax: Axes, *, cmap: Colormap | str | None, + group_cmaps: Mapping[str, str] | None, + are_axes_swapped: bool, color_on: Literal["dot", "square"], dot_max: float | None, dot_min: float | None, @@ -736,47 +842,83 @@ def _dotplot( # noqa: PLR0912, PLR0913, PLR0915 size = size * (largest_dot - smallest_dot) + smallest_dot normalize = check_colornorm(vmin, vmax, vcenter, norm) - if color_on == "square": - if edge_color is None: - from seaborn.utils import relative_luminance - - # use either black or white for the edge color - # depending on the luminance of the background - # square color - edge_color = [] - for color_value in cmap(normalize(mean_flat)): - lum = relative_luminance(color_value) - edge_color.append(".15" if lum > 0.408 else "w") - - edge_lw = 1.5 if edge_lw is None else edge_lw - - # first make a heatmap similar to `sc.pl.matrixplot` - # (squares with the asigned colormap). Circles will be plotted - # on top - dot_ax.pcolor(dot_color.values, cmap=cmap, norm=normalize) - for axis in ["top", "bottom", "left", "right"]: - dot_ax.spines[axis].set_linewidth(1.5) - kwds = fix_kwds( - kwds, - s=size, - linewidth=edge_lw, - facecolor="none", - edgecolor=edge_color, - ) - dot_ax.scatter(x, y, **kwds) + if group_cmaps is None: + # Plotting logic for single colormap + if color_on == "square": + if edge_color is None: + from seaborn.utils import relative_luminance + + # use either black or white for the edge color + # depending on the luminance of the background + # square color + edge_color = [] + for color_value in cmap(normalize(mean_flat)): + lum = relative_luminance(color_value) + edge_color.append(".15" if lum > 0.408 else "w") + + edge_lw = 1.5 if edge_lw is None else edge_lw + + # first make a heatmap similar to `sc.pl.matrixplot` + # (squares with the asigned colormap). Circles will be plotted + # on top + dot_ax.pcolor(dot_color.values, cmap=cmap, norm=normalize) + for axis in ["top", "bottom", "left", "right"]: + dot_ax.spines[axis].set_linewidth(1.5) + # Create a temporary kwargs dict for this group's scatter call + # to avoid modifying the original kwds dictionary within the loop. + kwds_scatter = fix_kwds( + kwds, + s=size, + linewidth=edge_lw, + facecolor="none", + edgecolor=edge_color, + ) + dot_ax.scatter(x, y, **kwds_scatter) + else: + edge_color = "none" if edge_color is None else edge_color + edge_lw = 0.0 if edge_lw is None else edge_lw + color = cmap(normalize(mean_flat)) + kwds_scatter = fix_kwds( + kwds, + s=size, + color=color, + linewidth=edge_lw, + edgecolor=edge_color, + ) + dot_ax.scatter(x, y, **kwds_scatter) else: - edge_color = "none" if edge_color is None else edge_color - edge_lw = 0.0 if edge_lw is None else edge_lw - - color = cmap(normalize(mean_flat)) - kwds = fix_kwds( - kwds, - s=size, - color=color, - linewidth=edge_lw, - edgecolor=edge_color, - ) - dot_ax.scatter(x, y, **kwds) + # Plotting logic for group-specific colormaps + groups_iter = dot_color.columns if are_axes_swapped else dot_color.index + n_vars = dot_color.shape[0] if are_axes_swapped else dot_color.shape[1] + n_groups = len(groups_iter) + + # Here we loop through each group and plot it with its own cmap + for group_idx, group_name in enumerate(groups_iter): + group_cmap_name = group_cmaps[group_name] + group_cmap = colormaps.get_cmap(group_cmap_name) + + # Slice the flattened data arrays correctly depending on orientation + if not are_axes_swapped: + # Slicing data for a whole row + indices = slice(group_idx * n_vars, (group_idx + 1) * n_vars) + else: + # Slicing data for a whole column + indices = slice(group_idx, None, n_groups) + + x_group = x[indices] + y_group = y[indices] + size_group = size[indices] + mean_group = mean_flat[indices] + + color = group_cmap(normalize(mean_group)) + kwds_scatter = fix_kwds( + kwds, + s=size_group, + color=color, + linewidth=edge_lw, + edgecolor=edge_color, + ) + dot_ax.scatter(x_group, y_group, **kwds_scatter) y_ticks = np.arange(dot_color.shape[0]) + 0.5 dot_ax.set_yticks(y_ticks) @@ -875,6 +1017,7 @@ def dotplot( # noqa: PLR0913 norm: Normalize | None = None, # Style parameters cmap: Colormap | str | None = DotPlot.DEFAULT_COLORMAP, + group_cmaps: Mapping[str, str] | None = None, dot_max: float | None = DotPlot.DEFAULT_DOT_MAX, dot_min: float | None = DotPlot.DEFAULT_DOT_MIN, smallest_dot: float = DotPlot.DEFAULT_SMALLEST_DOT, @@ -913,6 +1056,11 @@ def dotplot( # noqa: PLR0913 mean_only_expressed If True, gene expression is averaged only over the cells expressing the given genes. + group_cmaps + A mapping of group names to colormap names, e.g. + `{{'T-cell': 'Blues', 'B-cell': 'Reds'}}`. This allows for specifying a + different colormap for each group. If used, all groups in the plot + must have a colormap defined in this mapping. dot_max If ``None``, the maximum dot size is set to the maximum fraction value found (e.g. 0.6). If given, the value should be a number between 0 and 1. @@ -1000,6 +1148,7 @@ def dotplot( # noqa: PLR0913 var_group_rotation=var_group_rotation, layer=layer, dot_color_df=dot_color_df, + group_cmaps=group_cmaps, ax=ax, vmin=vmin, vmax=vmax, @@ -1019,7 +1168,9 @@ def dotplot( # noqa: PLR0913 dot_min=dot_min, smallest_dot=smallest_dot, dot_edge_lw=kwds.pop("linewidth", _empty), - ).legend(colorbar_title=colorbar_title, size_title=size_title) + ).legend( + colorbar_title=colorbar_title, size_title=size_title, width=2.0 + ) # Width 2.0 to avoid size legend circles to overlap if return_fig: return dp diff --git a/tests/_images/dotplot/expected.png b/tests/_images/dotplot/expected.png index 9c4b822369..028c0255b6 100644 Binary files a/tests/_images/dotplot/expected.png and b/tests/_images/dotplot/expected.png differ diff --git a/tests/_images/dotplot2/expected.png b/tests/_images/dotplot2/expected.png index ea85317b98..cb3a29cfac 100644 Binary files a/tests/_images/dotplot2/expected.png and b/tests/_images/dotplot2/expected.png differ diff --git a/tests/_images/dotplot3/expected.png b/tests/_images/dotplot3/expected.png index 93b42e24ce..3927b84aba 100644 Binary files a/tests/_images/dotplot3/expected.png and b/tests/_images/dotplot3/expected.png differ diff --git a/tests/_images/dotplot_dict/expected.png b/tests/_images/dotplot_dict/expected.png index d805ea94db..fb64d418fd 100644 Binary files a/tests/_images/dotplot_dict/expected.png and b/tests/_images/dotplot_dict/expected.png differ diff --git a/tests/_images/dotplot_gene_symbols/expected.png b/tests/_images/dotplot_gene_symbols/expected.png index 1f5c4e0c2f..7cc2f9f78b 100644 Binary files a/tests/_images/dotplot_gene_symbols/expected.png and b/tests/_images/dotplot_gene_symbols/expected.png differ diff --git a/tests/_images/dotplot_group_cmaps/expected.png b/tests/_images/dotplot_group_cmaps/expected.png new file mode 100644 index 0000000000..e26ac7ca52 Binary files /dev/null and b/tests/_images/dotplot_group_cmaps/expected.png differ diff --git a/tests/_images/dotplot_group_cmaps_swap_axes/expected.png b/tests/_images/dotplot_group_cmaps_swap_axes/expected.png new file mode 100644 index 0000000000..042bc0fbf8 Binary files /dev/null and b/tests/_images/dotplot_group_cmaps_swap_axes/expected.png differ diff --git a/tests/_images/dotplot_groupby_index/expected.png b/tests/_images/dotplot_groupby_index/expected.png index 57f5962ee2..85150912f1 100644 Binary files a/tests/_images/dotplot_groupby_index/expected.png and b/tests/_images/dotplot_groupby_index/expected.png differ diff --git a/tests/_images/dotplot_groupby_list_catorder/expected.png b/tests/_images/dotplot_groupby_list_catorder/expected.png index fd6c3453a0..4548e0c55a 100644 Binary files a/tests/_images/dotplot_groupby_list_catorder/expected.png and b/tests/_images/dotplot_groupby_list_catorder/expected.png differ diff --git a/tests/_images/dotplot_std_scale_group/expected.png b/tests/_images/dotplot_std_scale_group/expected.png index 72572a40db..29c8513fdb 100644 Binary files a/tests/_images/dotplot_std_scale_group/expected.png and b/tests/_images/dotplot_std_scale_group/expected.png differ diff --git a/tests/_images/dotplot_std_scale_var/expected.png b/tests/_images/dotplot_std_scale_var/expected.png index 3f00af2704..854de66ad4 100644 Binary files a/tests/_images/dotplot_std_scale_var/expected.png and b/tests/_images/dotplot_std_scale_var/expected.png differ diff --git a/tests/_images/dotplot_totals/expected.png b/tests/_images/dotplot_totals/expected.png index a7d1269c94..1899a26325 100644 Binary files a/tests/_images/dotplot_totals/expected.png and b/tests/_images/dotplot_totals/expected.png differ diff --git a/tests/_images/multiple_plots/expected.png b/tests/_images/multiple_plots/expected.png index f0857c4721..2b514ccc30 100644 Binary files a/tests/_images/multiple_plots/expected.png and b/tests/_images/multiple_plots/expected.png differ diff --git a/tests/_images/ranked_genes_dotplot/expected.png b/tests/_images/ranked_genes_dotplot/expected.png index d6c97610f7..bad71a0858 100644 Binary files a/tests/_images/ranked_genes_dotplot/expected.png and b/tests/_images/ranked_genes_dotplot/expected.png differ diff --git a/tests/test_plotting.py b/tests/test_plotting.py index 3de26b6af3..785c5e57ae 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -1821,3 +1821,56 @@ def test_violin_scale_warning(monkeypatch): monkeypatch.setattr(sc.pl.StackedViolin, "DEFAULT_SCALE", "count", raising=False) with pytest.warns(FutureWarning, match="Don’t set DEFAULT_SCALE"): sc.pl.StackedViolin(adata, adata.var_names[:3], groupby="louvain") + + +params_dotplot_group_cmaps = [ + pytest.param("dotplot_group_cmaps", False, id="default"), + pytest.param("dotplot_group_cmaps_swap_axes", True, id="swap_axes"), +] + + +@pytest.mark.parametrize(("name", "swap_axes"), params_dotplot_group_cmaps) +def test_dotplot_group_cmaps(image_comparer, name, swap_axes): + """Check group_cmaps parameter with custom color maps per group.""" + save_and_compare_images = partial(image_comparer, ROOT, tol=15) + + adata = pbmc68k_reduced() + + markers = ["SERPINB1", "IGFBP7", "GNLY", "IFITM1", "IMP3", "UBALD2", "LTB", "CLPP"] + + group_cmaps = { + "CD14+ Monocyte": "Greys", + "Dendritic": "Purples", + "CD8+ Cytotoxic T": "Reds", + "CD8+/CD45RA+ Naive Cytotoxic": "Greens", + "CD4+/CD45RA+/CD25- Naive T": "Oranges", + "CD4+/CD25 T Reg": "Blues", + "CD4+/CD45RO+ Memory": "hot", + "CD19+ B": "cool", + "CD56+ NK": "winter", + "CD34+": "copper", + } + + sc.pl.dotplot( + adata, + markers, + groupby="bulk_labels", + group_cmaps=group_cmaps, + dendrogram=True, + swap_axes=swap_axes, + show=False, + ) + save_and_compare_images(name) + + +def test_dotplot_group_cmaps_raises_error(): + """Check that a ValueError is raised for missing groups in group_cmaps.""" + adata = pbmc68k_reduced() + markers = ["CD79A"] + # Intentionally incomplete dictionary to trigger the error + group_cmaps = {"CD19+ B": "Blues"} + + with pytest.raises(ValueError, match="missing from the `group_cmaps` dictionary"): + sc.pl.dotplot( + adata, markers, groupby="bulk_labels", group_cmaps=group_cmaps, show=False + )