Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cicd_utils/ridgeplot_examples/_lincoln_weather.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ def main(

fig = ridgeplot(
samples=samples,
labels=[["Min Temperature [F]", "Max Temperature [F]"]] * len(months),
labels=[["Min Temperatures", "Max Temperatures"]] * len(months),
row_labels=months,
legendgroup=True,
colorscale="Inferno",
color_discrete_map=color_discrete_map,
bandwidth=4,
Expand All @@ -46,7 +47,12 @@ def main(
xaxis_gridwidth=2,
yaxis_title="Month",
xaxis_title="Temperature [F]",
showlegend=False,
legend=dict(
yanchor="top",
y=0.99,
xanchor="right",
x=0.99,
),
)

return fig
Expand Down
4 changes: 2 additions & 2 deletions cicd_utils/ridgeplot_examples/_lincoln_weather_red_blue.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
def main() -> go.Figure:
fig = lincoln_weather(
color_discrete_map={
"Min Temperature [F]": "deepskyblue",
"Max Temperature [F]": "orangered",
"Min Temperatures": "deepskyblue",
"Max Temperatures": "orangered",
}
)
return fig
Expand Down
Binary file modified docs/_static/charts/lincoln_weather.webp
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/_static/charts/lincoln_weather_red_blue.webp
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
25 changes: 18 additions & 7 deletions docs/getting_started/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ samples = [
# And finish by styling it up to your liking!
fig = ridgeplot(
samples=samples,
labels=[["Min Temperature [F]", "Max Temperature [F]"]] * len(months),
labels=[["Min Temperatures", "Max Temperatures"]] * len(months),
row_labels=months,
legendgroup=True,
colorscale="Inferno",
bandwidth=4,
kde_points=np.linspace(-40, 110, 400),
Expand All @@ -152,7 +153,12 @@ fig.update_layout(
xaxis_gridwidth=2,
yaxis_title="Month",
xaxis_title="Temperature [F]",
showlegend=False,
legend=dict(
yanchor="top",
y=0.99,
xanchor="right",
x=0.99,
),
)
fig.show()
```
Expand Down Expand Up @@ -229,14 +235,14 @@ Finally, we can pass the {py:paramref}`~ridgeplot.ridgeplot.samples` list to the
```python
fig = ridgeplot(
samples=samples,
labels=[["Min Temperature [F]", "Max Temperature [F]"]] * len(months),
labels=[["Min Temperatures", "Max Temperatures"]] * len(months),
row_labels=months,
legendgroup=True,
colorscale="Inferno",
bandwidth=4,
kde_points=np.linspace(-40, 110, 400),
spacing=0.3,
)

fig.update_layout(
title="Minimum and maximum daily temperatures in Lincoln, NE (2016)",
height=600,
Expand All @@ -248,7 +254,12 @@ fig.update_layout(
xaxis_gridwidth=2,
yaxis_title="Month",
xaxis_title="Temperature [F]",
showlegend=False,
legend=dict(
yanchor="top",
y=0.99,
xanchor="right",
x=0.99,
),
)
fig.show()
```
Expand Down Expand Up @@ -277,8 +288,8 @@ fig = ridgeplot(
# addition of `color_discrete_map`
# ...
color_discrete_map={
"Min Temperature [F]": "deepskyblue",
"Max Temperature [F]": "orangered",
"Min Temperatures": "deepskyblue",
"Max Temperatures": "orangered",
}
# ...
)
Expand Down
6 changes: 5 additions & 1 deletion src/ridgeplot/_figure_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
SolidColormode,
compute_solid_colors,
)
from ridgeplot._obj.legendcontext import LegendContextManager
from ridgeplot._obj.traces import get_trace_cls
from ridgeplot._obj.traces.base import ColoringContext
from ridgeplot._types import (
Expand Down Expand Up @@ -123,6 +124,7 @@ def create_ridgeplot(
trace_labels: LabelsArray | ShallowLabelsArray | None,
trace_types: TraceTypesArray | ShallowTraceTypesArray | TraceType,
row_labels: Collection[str] | None | Literal[False],
legendgroup: bool,
colorscale: ColorScale | Collection[Color] | str | None,
colormode: Literal["fillgradient"] | SolidColormode,
color_discrete_map: dict[str, str] | None,
Expand Down Expand Up @@ -176,6 +178,8 @@ def create_ridgeplot(
# --- Build the figure
# ==============================================================

legend_ctx_manager = LegendContextManager(legendgroup=legendgroup)

interpolation_ctx = InterpolationContext(
densities=densities,
n_rows=n_rows,
Expand Down Expand Up @@ -209,7 +213,7 @@ def create_ridgeplot(
):
trace_drawer = get_trace_cls(trace_type)(
trace=trace,
label=label,
legend_ctx=legend_ctx_manager.get_legend_ctx(label=label),
solid_color=color,
zorder=ith_trace,
y_base=y_base,
Expand Down
60 changes: 60 additions & 0 deletions src/ridgeplot/_obj/legendcontext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import dataclasses
from typing import Any, TypedDict


class Font(TypedDict, total=False):
# plotly/graph_objs/scatter/legendgrouptitle/_font.py
color: str | None
family: str | None
lineposition: str | None
shadow: str | None
size: int | float | None
style: str | None
textcase: str | None
variant: str | None
weight: int | None


class Legendgrouptitle(TypedDict, total=False):
# plotly/graph_objs/scatter/_legendgrouptitle.py
text: str | None
font: Font | None


@dataclasses.dataclass
class LegendContext:
name: str
showlegend: bool
legendgroup: str | int | float | None = None
legendgrouptitle: Legendgrouptitle | None = None

@property
def trace_kwargs(self) -> dict[str, Any]:
return dataclasses.asdict(self)


class LegendContextManager:
def __init__(self, legendgroup: bool) -> None:
super().__init__()
self.legendgroup = legendgroup
self._seen_labels: set[str] = set()

def get_legend_ctx(self, label: str) -> LegendContext:
if not self.legendgroup:
return LegendContext(name=label, showlegend=True)
if label not in self._seen_labels:
self._seen_labels.add(label)
return LegendContext(
name=label,
showlegend=True,
legendgroup=label,
# FIXME: This doesn't seem to work as expected
# legendgrouptitle=Legendgrouptitle(text=label),
Comment on lines +53 to +54
Copy link

Copilot AI Oct 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The commented FIXME indicates that legendgrouptitle functionality is not working as expected. This suggests there may be an issue with the Plotly legend group title feature that should be investigated or documented.

Copilot uses AI. Check for mistakes.

)
return LegendContext(
name=label,
showlegend=False,
legendgroup=label,
)
10 changes: 3 additions & 7 deletions src/ridgeplot/_obj/traces/area.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ridgeplot._color.interpolation import slice_colorscale
from ridgeplot._color.utils import apply_alpha
from ridgeplot._obj.traces.base import DEFAULT_HOVERTEMPLATE, ColoringContext, RidgeplotTrace
from ridgeplot._obj.traces.base import ColoringContext, RidgeplotTrace
from ridgeplot._utils import normalise_min_max


Expand Down Expand Up @@ -68,22 +68,18 @@ def draw(self, fig: go.Figure, coloring_ctx: ColoringContext) -> go.Figure:
hoverinfo="skip",
# z-order (higher z-order means the trace is drawn on top)
zorder=self.zorder,
legendgroup=self.legend_ctx.legendgroup,
)
)
fig.add_trace(
go.Scatter(
x=self.x,
y=[y_i + self.y_base for y_i in self.y],
name=self.label,
fill="tonexty",
mode="lines",
line_width=self.line_width,
**self._common_trace_kwargs,
**self._get_coloring_kwargs(ctx=coloring_ctx),
# Hover information
customdata=[[y_i] for y_i in self.y],
hovertemplate=DEFAULT_HOVERTEMPLATE,
# z-order (higher z-order means the trace is drawn on top)
zorder=self.zorder,
),
)
return fig
9 changes: 2 additions & 7 deletions src/ridgeplot/_obj/traces/bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing_extensions import Any, override

from ridgeplot._color.interpolation import interpolate_color
from ridgeplot._obj.traces.base import DEFAULT_HOVERTEMPLATE, ColoringContext, RidgeplotTrace
from ridgeplot._obj.traces.base import ColoringContext, RidgeplotTrace
from ridgeplot._utils import normalise_min_max


Expand Down Expand Up @@ -42,16 +42,11 @@ def draw(self, fig: go.Figure, coloring_ctx: ColoringContext) -> go.Figure:
go.Bar(
x=self.x,
y=self.y,
name=self.label,
base=self.y_base,
marker_line_width=self.line_width,
width=None, # Plotly automatically picks the right width
**self._get_coloring_kwargs(ctx=coloring_ctx),
# Hover information
customdata=[[y_i] for y_i in self.y],
hovertemplate=DEFAULT_HOVERTEMPLATE,
# z-order (higher z-order means the trace is drawn on top)
zorder=self.zorder,
**self._common_trace_kwargs,
),
)
return fig
22 changes: 18 additions & 4 deletions src/ridgeplot/_obj/traces/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar

from typing_extensions import Literal

if TYPE_CHECKING:
from plotly import graph_objects as go

from ridgeplot._color.interpolation import InterpolationContext
from ridgeplot._obj.legendcontext import LegendContext
from ridgeplot._types import Color, ColorScale, DensityTrace


Expand All @@ -24,7 +25,7 @@
string below, but it's not quite the same... (see '.7~r' as well)
"""

DEFAULT_HOVERTEMPLATE = (
_DEFAULT_HOVERTEMPLATE = (
f"(%{{x:{_D3HF}}}, %{{customdata[0]:{_D3HF}}})"
"<br>"
"<extra>%{fullData.name}</extra>"
Expand Down Expand Up @@ -56,7 +57,7 @@ def __init__(
self,
*, # kw only
trace: DensityTrace,
label: str,
legend_ctx: LegendContext,
solid_color: str,
zorder: int,
# Constant over the trace's row
Expand All @@ -67,13 +68,26 @@ def __init__(
):
super().__init__()
self.x, self.y = zip(*trace, strict=True)
self.label = label
self.legend_ctx = legend_ctx
self.solid_color = solid_color
self.zorder = zorder
self.y_base = y_base
self.line_color: Color = self.solid_color if line_color == "fill-color" else line_color
self.line_width: float = line_width if line_width is not None else self._DEFAULT_LINE_WIDTH

@property
def _common_trace_kwargs(self) -> dict[str, Any]:
"""Return common trace kwargs."""
return dict(
# Legend information
**self.legend_ctx.trace_kwargs,
# Hover information
customdata=[[y_i] for y_i in self.y],
hovertemplate=_DEFAULT_HOVERTEMPLATE,
# z-order (higher z-order means the trace is drawn on top)
zorder=self.zorder,
)

@abstractmethod
def draw(self, fig: go.Figure, coloring_ctx: ColoringContext) -> go.Figure:
raise NotImplementedError
2 changes: 2 additions & 0 deletions src/ridgeplot/_ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def ridgeplot(
trace_type: TraceTypesArray | ShallowTraceTypesArray | TraceType | None = None,
labels: LabelsArray | ShallowLabelsArray | None = None,
row_labels: Collection[str] | None | Literal[False] = None,
legendgroup: bool = False, # TODO: document and rename to smth better!
Copy link

Copilot AI Oct 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TODO comment suggests this parameter needs documentation and a better name. The current name 'legendgroup' and the boolean type don't clearly convey that this enables grouping of legend entries by label name.

Copilot uses AI. Check for mistakes.

# KDE parameters
kernel: str = "gau",
bandwidth: KDEBandwidth = "normal_reference",
Expand Down Expand Up @@ -503,6 +504,7 @@ def ridgeplot(
trace_labels=labels,
trace_types=trace_type,
row_labels=row_labels,
legendgroup=legendgroup,
colorscale=colorscale,
colormode=colormode,
color_discrete_map=color_discrete_map,
Expand Down
2 changes: 1 addition & 1 deletion tests/e2e/artifacts/basic.json

Large diffs are not rendered by default.

Loading
Loading