Skip to content

Commit fdbd29f

Browse files
FBumannclaude
andcommitted
fix: legend visibility, style persistence, and axis ranges in combined figures
Combined figures (overlay, add_secondary_y) had three issues: 1. Legends disappeared because Plotly Express sets showlegend=False on single-trace figures. Now unnamed traces get names derived from the source figure's y-axis title, and showlegend is fixed per legendgroup. 2. Colors and styles were lost during animation because frame traces carried PX defaults. Now marker, line, opacity and legend properties are propagated from fig.data into all animation frame traces. 3. Axis ranges were computed from fig.data only, so frames with different data ranges went off-screen during animation. Now global min/max is computed across all frames and set explicitly on the layout. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent fbca665 commit fdbd29f

2 files changed

Lines changed: 332 additions & 19 deletions

File tree

tests/test_figures.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,3 +616,142 @@ def test_secondary_not_modified(self) -> None:
616616

617617
# Secondary traces should still use original yaxis
618618
assert secondary.data[0].yaxis == original_yaxis
619+
620+
621+
class TestLegendVisibility:
622+
"""Tests that combined figures preserve legend visibility."""
623+
624+
def test_overlay_single_trace_figures_with_names(self) -> None:
625+
"""Overlay of named single-trace figures shows legend."""
626+
da1 = xr.DataArray([1, 2, 3], dims=["x"], name="a")
627+
da2 = xr.DataArray([4, 5, 6], dims=["x"], name="b")
628+
629+
fig1 = xpx(da1).line()
630+
fig1.update_traces(name="Series A")
631+
fig2 = xpx(da2).line()
632+
fig2.update_traces(name="Series B")
633+
634+
combined = overlay(fig1, fig2)
635+
636+
assert combined.data[0].showlegend is True
637+
assert combined.data[1].showlegend is True
638+
639+
def test_overlay_unnamed_traces_get_yaxis_title(self) -> None:
640+
"""Overlay of unnamed traces derives names from y-axis titles."""
641+
da1 = xr.DataArray([1, 2, 3], dims=["x"], name="Temperature")
642+
da2 = xr.DataArray([4, 5, 6], dims=["x"], name="Pressure")
643+
644+
fig1 = xpx(da1).line()
645+
fig2 = xpx(da2).line()
646+
647+
combined = overlay(fig1, fig2)
648+
649+
# Names derived from y-axis titles (DataArray names)
650+
assert combined.data[0].name == "Temperature"
651+
assert combined.data[1].name == "Pressure"
652+
assert combined.data[0].showlegend is True
653+
assert combined.data[1].showlegend is True
654+
655+
def test_overlay_same_name_disambiguated(self) -> None:
656+
"""Overlay of figures with same y-axis title gets numeric suffix."""
657+
da1 = xr.DataArray([1, 2, 3], dims=["x"], name="value")
658+
da2 = xr.DataArray([4, 5, 6], dims=["x"], name="value")
659+
660+
fig1 = xpx(da1).line()
661+
fig2 = xpx(da2).line()
662+
663+
combined = overlay(fig1, fig2)
664+
665+
assert combined.data[0].name == "value (1)"
666+
assert combined.data[1].name == "value (2)"
667+
668+
def test_overlay_multi_trace_deduplicates_legend(self) -> None:
669+
"""Overlay of multi-trace figures deduplicates shared legendgroups."""
670+
da = xr.DataArray(
671+
np.random.rand(10, 3),
672+
dims=["x", "cat"],
673+
coords={"cat": ["A", "B", "C"]},
674+
)
675+
fig1 = xpx(da).area()
676+
fig2 = xpx(da).line()
677+
678+
combined = overlay(fig1, fig2)
679+
680+
# First occurrence of each legendgroup should show, duplicates hidden
681+
from collections import defaultdict
682+
683+
groups: dict[str, list[bool]] = defaultdict(list)
684+
for trace in combined.data:
685+
lg = trace.legendgroup
686+
groups[lg].append(trace.showlegend is True)
687+
688+
for lg, flags in groups.items():
689+
assert flags.count(True) == 1, f"legendgroup {lg!r} has {flags.count(True)} visible"
690+
691+
def test_add_secondary_y_single_trace_with_names(self) -> None:
692+
"""add_secondary_y of named single-trace figures shows legend."""
693+
da1 = xr.DataArray([1, 2, 3], dims=["x"], name="temp")
694+
da2 = xr.DataArray([100, 200, 300], dims=["x"], name="precip")
695+
696+
fig1 = xpx(da1).line()
697+
fig1.update_traces(name="Temperature")
698+
fig2 = xpx(da2).bar()
699+
fig2.update_traces(name="Precipitation")
700+
701+
combined = add_secondary_y(fig1, fig2)
702+
703+
assert combined.data[0].showlegend is True
704+
assert combined.data[1].showlegend is True
705+
706+
def test_overlay_faceted_legendgroup_dedup(self) -> None:
707+
"""Faceted overlay keeps only one showlegend=True per legendgroup."""
708+
da = xr.DataArray(
709+
np.random.rand(10, 2, 2),
710+
dims=["x", "cat", "facet"],
711+
coords={"cat": ["A", "B"], "facet": ["left", "right"]},
712+
)
713+
fig1 = xpx(da).area(facet_col="facet")
714+
fig2 = xpx(da).line(facet_col="facet")
715+
716+
combined = overlay(fig1, fig2)
717+
718+
# Check each legendgroup has at least one showlegend=True
719+
from collections import defaultdict
720+
721+
groups: dict[str, list[bool]] = defaultdict(list)
722+
for trace in combined.data:
723+
lg = trace.legendgroup or ""
724+
if lg:
725+
groups[lg].append(trace.showlegend is True)
726+
727+
for lg, flags in groups.items():
728+
assert any(flags), f"legendgroup {lg!r} has no showlegend=True trace"
729+
730+
def test_overlay_animation_frames_preserve_style(self) -> None:
731+
"""Animation frame traces keep legend and color from fig.data."""
732+
da = xr.DataArray(
733+
np.random.rand(10, 3),
734+
dims=["x", "time"],
735+
coords={"time": [0, 1, 2]},
736+
name="Population",
737+
)
738+
da_smooth = da.rolling(x=3, center=True).mean()
739+
da_smooth.name = "Smoothed"
740+
741+
fig1 = xpx(da).bar(animation_frame="time")
742+
fig1.update_traces(marker={"color": "steelblue"})
743+
fig2 = xpx(da_smooth).line(animation_frame="time")
744+
fig2.update_traces(line={"color": "red"})
745+
746+
combined = overlay(fig1, fig2)
747+
748+
for frame in combined.frames:
749+
for i, ft in enumerate(frame.data):
750+
src = combined.data[i]
751+
assert ft.name == src.name
752+
assert ft.showlegend == src.showlegend
753+
assert ft.legendgroup == src.legendgroup
754+
# Bar trace should keep steelblue
755+
assert frame.data[0].marker.color == "steelblue"
756+
# Line trace should keep red
757+
assert frame.data[1].line.color == "red"

xarray_plotly/figures.py

Lines changed: 193 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,171 @@
1313
import plotly.graph_objects as go
1414

1515

16+
def _get_yaxis_title(fig: go.Figure) -> str:
17+
"""Extract the primary y-axis title text from a figure.
18+
19+
Args:
20+
fig: A Plotly figure.
21+
22+
Returns:
23+
The y-axis title text, or empty string if not set.
24+
"""
25+
try:
26+
return fig.layout.yaxis.title.text or ""
27+
except AttributeError:
28+
return ""
29+
30+
31+
def _ensure_legend_visibility(
32+
combined: go.Figure,
33+
source_figs: list[go.Figure],
34+
trace_slices: list[slice],
35+
) -> None:
36+
"""Fix legend visibility on a combined figure.
37+
38+
Handles three problems that arise when combining Plotly Express figures:
39+
40+
1. **Unnamed traces** — PX sets ``name=""`` on single-trace (no color)
41+
figures. We derive a name from each source figure's y-axis title.
42+
2. **Hidden named traces** — PX sets ``showlegend=False`` on single-trace
43+
figures. We ensure at least one trace per ``legendgroup`` (or each
44+
ungrouped named trace) has ``showlegend=True``.
45+
3. **Duplicate legend entries** — when two source figures share the same
46+
``legendgroup`` names, we deduplicate so only the first trace per
47+
group shows in the legend.
48+
49+
Args:
50+
combined: The combined Plotly figure (mutated in place).
51+
source_figs: The original source figures, in trace order.
52+
trace_slices: Slices into ``combined.data`` for each source figure.
53+
"""
54+
from collections import defaultdict
55+
56+
# --- Step 1: label unnamed traces from source y-axis titles -----------
57+
labels = [_get_yaxis_title(f) for f in source_figs]
58+
59+
# If all labels are the same, disambiguate
60+
unique_labels = {lb for lb in labels if lb}
61+
if len(unique_labels) == 1:
62+
labels = [f"{labels[0]} ({i + 1})" for i in range(len(labels))]
63+
64+
for label, sl in zip(labels, trace_slices, strict=False):
65+
if not label:
66+
continue
67+
for trace in combined.data[sl]:
68+
if not getattr(trace, "name", None):
69+
trace.name = label
70+
trace.legendgroup = label
71+
72+
# --- Step 2 & 3: fix showlegend per legendgroup -----------------------
73+
grouped: dict[str, list[Any]] = defaultdict(list)
74+
ungrouped: list[Any] = []
75+
76+
for trace in combined.data:
77+
lg = getattr(trace, "legendgroup", None) or ""
78+
if lg:
79+
grouped[lg].append(trace)
80+
else:
81+
ungrouped.append(trace)
82+
83+
for traces in grouped.values():
84+
has_visible = False
85+
for t in traces:
86+
if has_visible:
87+
# Deduplicate: only first keeps showlegend
88+
t.showlegend = False
89+
elif getattr(t, "name", None):
90+
t.showlegend = True
91+
has_visible = True
92+
93+
# Ungrouped traces with a name should show in the legend
94+
for trace in ungrouped:
95+
if getattr(trace, "name", None):
96+
trace.showlegend = True
97+
98+
# --- Step 4: propagate style properties to animation frame traces ------
99+
# When Plotly animates, frame trace data overwrites fig.data properties.
100+
# PX frame traces carry name="", showlegend=False and default colors,
101+
# discarding any styling the user applied via update_traces() before
102+
# combining. Propagate display properties from fig.data into every frame.
103+
_STYLE_ATTRS = ("name", "legendgroup", "showlegend", "marker", "line", "opacity")
104+
for frame in combined.frames or []:
105+
for i, frame_trace in enumerate(frame.data):
106+
if i < len(combined.data):
107+
src = combined.data[i]
108+
for attr in _STYLE_ATTRS:
109+
src_val = getattr(src, attr, None)
110+
if src_val is not None:
111+
setattr(frame_trace, attr, src_val)
112+
113+
114+
def _fix_animation_axis_ranges(fig: go.Figure) -> None:
115+
"""Set axis ranges to encompass data across all animation frames.
116+
117+
Plotly.js computes autorange from ``fig.data`` only and does not
118+
recalculate during animation. When different frames have very different
119+
data ranges (e.g. population of Brazil vs China), values can go off-screen.
120+
This function computes the global min/max for each axis across all frames
121+
and sets explicit ranges on the layout.
122+
123+
Only numeric axes are handled; categorical/date axes are left to autorange.
124+
125+
Args:
126+
fig: A Plotly figure with animation frames (mutated in place).
127+
"""
128+
import numpy as np
129+
130+
if not fig.frames:
131+
return
132+
133+
from collections import defaultdict
134+
135+
# Collect numeric y-values per axis across all traces (fig.data + frames)
136+
y_by_axis: dict[str, list[float]] = defaultdict(list)
137+
x_by_axis: dict[str, list[float]] = defaultdict(list)
138+
139+
for trace in _iter_all_traces(fig):
140+
yaxis = getattr(trace, "yaxis", None) or "y"
141+
xaxis = getattr(trace, "xaxis", None) or "x"
142+
143+
y = getattr(trace, "y", None)
144+
if y is not None:
145+
try:
146+
arr = np.asarray(y, dtype=float)
147+
finite = arr[np.isfinite(arr)]
148+
if len(finite):
149+
y_by_axis[yaxis].extend(finite.tolist())
150+
except (ValueError, TypeError):
151+
pass # Non-numeric (categorical) — skip
152+
153+
x = getattr(trace, "x", None)
154+
if x is not None:
155+
try:
156+
arr = np.asarray(x, dtype=float)
157+
finite = arr[np.isfinite(arr)]
158+
if len(finite):
159+
x_by_axis[xaxis].extend(finite.tolist())
160+
except (ValueError, TypeError):
161+
pass
162+
163+
# Apply ranges to layout
164+
for axis_ref, values in y_by_axis.items():
165+
if not values:
166+
continue
167+
lo, hi = min(values), max(values)
168+
pad = (hi - lo) * 0.05 or 1 # 5% padding
169+
layout_prop = "yaxis" if axis_ref == "y" else f"yaxis{axis_ref[1:]}"
170+
fig.layout[layout_prop].range = [lo - pad, hi + pad]
171+
172+
for axis_ref, values in x_by_axis.items():
173+
if not values:
174+
continue
175+
lo, hi = min(values), max(values)
176+
pad = (hi - lo) * 0.05 or 1
177+
layout_prop = "xaxis" if axis_ref == "x" else f"xaxis{axis_ref[1:]}"
178+
fig.layout[layout_prop].range = [lo - pad, hi + pad]
179+
180+
16181
def _iter_all_traces(fig: go.Figure) -> Iterator[Any]:
17182
"""Iterate over all traces in a figure, including animation frames.
18183
@@ -194,17 +359,11 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure:
194359
_validate_compatible_structure(base, overlay)
195360
_validate_animation_compatibility(base, overlay)
196361

197-
# Create new figure with base's layout
198-
combined = go.Figure(layout=copy.deepcopy(base.layout))
199-
200-
# Add all traces from base
201-
for trace in base.data:
202-
combined.add_trace(copy.deepcopy(trace))
203-
204-
# Add all traces from overlays
362+
# Create new figure with base's layout and all traces
363+
all_traces = [copy.deepcopy(t) for t in base.data]
205364
for overlay in overlays:
206-
for trace in overlay.data:
207-
combined.add_trace(copy.deepcopy(trace))
365+
all_traces.extend(copy.deepcopy(t) for t in overlay.data)
366+
combined = go.Figure(data=all_traces, layout=copy.deepcopy(base.layout))
208367

209368
# Handle animation frames
210369
if base.frames:
@@ -213,6 +372,17 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure:
213372
merged_frames = _merge_frames(base, list(overlays), base_trace_count, overlay_trace_counts)
214373
combined.frames = merged_frames
215374

375+
# Build trace slices for legend fix
376+
source_figs = [base, *overlays]
377+
slices: list[slice] = []
378+
offset = 0
379+
for fig in source_figs:
380+
n = len(fig.data)
381+
slices.append(slice(offset, offset + n))
382+
offset += n
383+
384+
_ensure_legend_visibility(combined, source_figs, slices)
385+
_fix_animation_axis_ranges(combined)
216386
return combined
217387

218388

@@ -315,19 +485,15 @@ def add_secondary_y(
315485
rightmost_x = max(x_for_y.values(), key=lambda x: int(x[1:]) if x != "x" else 1)
316486
rightmost_primary_y = next(y for y, x in x_for_y.items() if x == rightmost_x)
317487

318-
# Create new figure with base's layout
319-
combined = go.Figure(layout=copy.deepcopy(base.layout))
320-
321-
# Add all traces from base (primary y-axis)
322-
for trace in base.data:
323-
combined.add_trace(copy.deepcopy(trace))
324-
325-
# Add all traces from secondary, remapped to secondary y-axes
488+
# Build all traces: base (primary) + secondary (remapped to secondary y-axes)
489+
all_traces = [copy.deepcopy(t) for t in base.data]
326490
for trace in secondary.data:
327491
trace_copy = copy.deepcopy(trace)
328492
original_yaxis = getattr(trace_copy, "yaxis", None) or "y"
329493
trace_copy.yaxis = y_mapping[original_yaxis]
330-
combined.add_trace(trace_copy)
494+
all_traces.append(trace_copy)
495+
496+
combined = go.Figure(data=all_traces, layout=copy.deepcopy(base.layout))
331497

332498
# Get the rightmost secondary y-axis name for linking
333499
rightmost_secondary_y = y_mapping[rightmost_primary_y]
@@ -368,6 +534,14 @@ def add_secondary_y(
368534
merged_frames = _merge_secondary_y_frames(base, secondary, y_mapping)
369535
combined.frames = merged_frames
370536

537+
base_n = len(base.data)
538+
sec_n = len(secondary.data)
539+
_ensure_legend_visibility(
540+
combined,
541+
[base, secondary],
542+
[slice(0, base_n), slice(base_n, base_n + sec_n)],
543+
)
544+
_fix_animation_axis_ranges(combined)
371545
return combined
372546

373547

0 commit comments

Comments
 (0)