1313 import plotly .graph_objects as go
1414
1515
16+ def _get_yaxis_title (fig : go .Figure ) -> str :
17+ """Extract the primary y-axis title text from a figure.
18+
19+ Args:
20+ fig: A Plotly figure.
21+
22+ Returns:
23+ The y-axis title text, or empty string if not set.
24+ """
25+ try :
26+ return fig .layout .yaxis .title .text or ""
27+ except AttributeError :
28+ return ""
29+
30+
31+ def _ensure_legend_visibility (
32+ combined : go .Figure ,
33+ source_figs : list [go .Figure ],
34+ trace_slices : list [slice ],
35+ ) -> None :
36+ """Fix legend visibility on a combined figure.
37+
38+ Handles three problems that arise when combining Plotly Express figures:
39+
40+ 1. **Unnamed traces** — PX sets ``name=""`` on single-trace (no color)
41+ figures. We derive a name from each source figure's y-axis title.
42+ 2. **Hidden named traces** — PX sets ``showlegend=False`` on single-trace
43+ figures. We ensure at least one trace per ``legendgroup`` (or each
44+ ungrouped named trace) has ``showlegend=True``.
45+ 3. **Duplicate legend entries** — when two source figures share the same
46+ ``legendgroup`` names, we deduplicate so only the first trace per
47+ group shows in the legend.
48+
49+ Args:
50+ combined: The combined Plotly figure (mutated in place).
51+ source_figs: The original source figures, in trace order.
52+ trace_slices: Slices into ``combined.data`` for each source figure.
53+ """
54+ from collections import defaultdict
55+
56+ # --- Step 1: label unnamed traces from source y-axis titles -----------
57+ labels = [_get_yaxis_title (f ) for f in source_figs ]
58+
59+ # If all labels are the same, disambiguate
60+ unique_labels = {lb for lb in labels if lb }
61+ if len (unique_labels ) == 1 :
62+ labels = [f"{ labels [0 ]} ({ i + 1 } )" for i in range (len (labels ))]
63+
64+ for label , sl in zip (labels , trace_slices , strict = False ):
65+ if not label :
66+ continue
67+ for trace in combined .data [sl ]:
68+ if not getattr (trace , "name" , None ):
69+ trace .name = label
70+ trace .legendgroup = label
71+
72+ # --- Step 2 & 3: fix showlegend per legendgroup -----------------------
73+ grouped : dict [str , list [Any ]] = defaultdict (list )
74+ ungrouped : list [Any ] = []
75+
76+ for trace in combined .data :
77+ lg = getattr (trace , "legendgroup" , None ) or ""
78+ if lg :
79+ grouped [lg ].append (trace )
80+ else :
81+ ungrouped .append (trace )
82+
83+ for traces in grouped .values ():
84+ has_visible = False
85+ for t in traces :
86+ if has_visible :
87+ # Deduplicate: only first keeps showlegend
88+ t .showlegend = False
89+ elif getattr (t , "name" , None ):
90+ t .showlegend = True
91+ has_visible = True
92+
93+ # Ungrouped traces with a name should show in the legend
94+ for trace in ungrouped :
95+ if getattr (trace , "name" , None ):
96+ trace .showlegend = True
97+
98+ # --- Step 4: propagate style properties to animation frame traces ------
99+ # When Plotly animates, frame trace data overwrites fig.data properties.
100+ # PX frame traces carry name="", showlegend=False and default colors,
101+ # discarding any styling the user applied via update_traces() before
102+ # combining. Propagate display properties from fig.data into every frame.
103+ _STYLE_ATTRS = ("name" , "legendgroup" , "showlegend" , "marker" , "line" , "opacity" )
104+ for frame in combined .frames or []:
105+ for i , frame_trace in enumerate (frame .data ):
106+ if i < len (combined .data ):
107+ src = combined .data [i ]
108+ for attr in _STYLE_ATTRS :
109+ src_val = getattr (src , attr , None )
110+ if src_val is not None :
111+ setattr (frame_trace , attr , src_val )
112+
113+
114+ def _fix_animation_axis_ranges (fig : go .Figure ) -> None :
115+ """Set axis ranges to encompass data across all animation frames.
116+
117+ Plotly.js computes autorange from ``fig.data`` only and does not
118+ recalculate during animation. When different frames have very different
119+ data ranges (e.g. population of Brazil vs China), values can go off-screen.
120+ This function computes the global min/max for each axis across all frames
121+ and sets explicit ranges on the layout.
122+
123+ Only numeric axes are handled; categorical/date axes are left to autorange.
124+
125+ Args:
126+ fig: A Plotly figure with animation frames (mutated in place).
127+ """
128+ import numpy as np
129+
130+ if not fig .frames :
131+ return
132+
133+ from collections import defaultdict
134+
135+ # Collect numeric y-values per axis across all traces (fig.data + frames)
136+ y_by_axis : dict [str , list [float ]] = defaultdict (list )
137+ x_by_axis : dict [str , list [float ]] = defaultdict (list )
138+
139+ for trace in _iter_all_traces (fig ):
140+ yaxis = getattr (trace , "yaxis" , None ) or "y"
141+ xaxis = getattr (trace , "xaxis" , None ) or "x"
142+
143+ y = getattr (trace , "y" , None )
144+ if y is not None :
145+ try :
146+ arr = np .asarray (y , dtype = float )
147+ finite = arr [np .isfinite (arr )]
148+ if len (finite ):
149+ y_by_axis [yaxis ].extend (finite .tolist ())
150+ except (ValueError , TypeError ):
151+ pass # Non-numeric (categorical) — skip
152+
153+ x = getattr (trace , "x" , None )
154+ if x is not None :
155+ try :
156+ arr = np .asarray (x , dtype = float )
157+ finite = arr [np .isfinite (arr )]
158+ if len (finite ):
159+ x_by_axis [xaxis ].extend (finite .tolist ())
160+ except (ValueError , TypeError ):
161+ pass
162+
163+ # Apply ranges to layout
164+ for axis_ref , values in y_by_axis .items ():
165+ if not values :
166+ continue
167+ lo , hi = min (values ), max (values )
168+ pad = (hi - lo ) * 0.05 or 1 # 5% padding
169+ layout_prop = "yaxis" if axis_ref == "y" else f"yaxis{ axis_ref [1 :]} "
170+ fig .layout [layout_prop ].range = [lo - pad , hi + pad ]
171+
172+ for axis_ref , values in x_by_axis .items ():
173+ if not values :
174+ continue
175+ lo , hi = min (values ), max (values )
176+ pad = (hi - lo ) * 0.05 or 1
177+ layout_prop = "xaxis" if axis_ref == "x" else f"xaxis{ axis_ref [1 :]} "
178+ fig .layout [layout_prop ].range = [lo - pad , hi + pad ]
179+
180+
16181def _iter_all_traces (fig : go .Figure ) -> Iterator [Any ]:
17182 """Iterate over all traces in a figure, including animation frames.
18183
@@ -194,17 +359,11 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure:
194359 _validate_compatible_structure (base , overlay )
195360 _validate_animation_compatibility (base , overlay )
196361
197- # Create new figure with base's layout
198- combined = go .Figure (layout = copy .deepcopy (base .layout ))
199-
200- # Add all traces from base
201- for trace in base .data :
202- combined .add_trace (copy .deepcopy (trace ))
203-
204- # Add all traces from overlays
362+ # Create new figure with base's layout and all traces
363+ all_traces = [copy .deepcopy (t ) for t in base .data ]
205364 for overlay in overlays :
206- for trace in overlay .data :
207- combined . add_trace ( copy .deepcopy (trace ))
365+ all_traces . extend ( copy . deepcopy ( t ) for t in overlay .data )
366+ combined = go . Figure ( data = all_traces , layout = copy .deepcopy (base . layout ))
208367
209368 # Handle animation frames
210369 if base .frames :
@@ -213,6 +372,17 @@ def overlay(base: go.Figure, *overlays: go.Figure) -> go.Figure:
213372 merged_frames = _merge_frames (base , list (overlays ), base_trace_count , overlay_trace_counts )
214373 combined .frames = merged_frames
215374
375+ # Build trace slices for legend fix
376+ source_figs = [base , * overlays ]
377+ slices : list [slice ] = []
378+ offset = 0
379+ for fig in source_figs :
380+ n = len (fig .data )
381+ slices .append (slice (offset , offset + n ))
382+ offset += n
383+
384+ _ensure_legend_visibility (combined , source_figs , slices )
385+ _fix_animation_axis_ranges (combined )
216386 return combined
217387
218388
@@ -315,19 +485,15 @@ def add_secondary_y(
315485 rightmost_x = max (x_for_y .values (), key = lambda x : int (x [1 :]) if x != "x" else 1 )
316486 rightmost_primary_y = next (y for y , x in x_for_y .items () if x == rightmost_x )
317487
318- # Create new figure with base's layout
319- combined = go .Figure (layout = copy .deepcopy (base .layout ))
320-
321- # Add all traces from base (primary y-axis)
322- for trace in base .data :
323- combined .add_trace (copy .deepcopy (trace ))
324-
325- # Add all traces from secondary, remapped to secondary y-axes
488+ # Build all traces: base (primary) + secondary (remapped to secondary y-axes)
489+ all_traces = [copy .deepcopy (t ) for t in base .data ]
326490 for trace in secondary .data :
327491 trace_copy = copy .deepcopy (trace )
328492 original_yaxis = getattr (trace_copy , "yaxis" , None ) or "y"
329493 trace_copy .yaxis = y_mapping [original_yaxis ]
330- combined .add_trace (trace_copy )
494+ all_traces .append (trace_copy )
495+
496+ combined = go .Figure (data = all_traces , layout = copy .deepcopy (base .layout ))
331497
332498 # Get the rightmost secondary y-axis name for linking
333499 rightmost_secondary_y = y_mapping [rightmost_primary_y ]
@@ -368,6 +534,14 @@ def add_secondary_y(
368534 merged_frames = _merge_secondary_y_frames (base , secondary , y_mapping )
369535 combined .frames = merged_frames
370536
537+ base_n = len (base .data )
538+ sec_n = len (secondary .data )
539+ _ensure_legend_visibility (
540+ combined ,
541+ [base , secondary ],
542+ [slice (0 , base_n ), slice (base_n , base_n + sec_n )],
543+ )
544+ _fix_animation_axis_ranges (combined )
371545 return combined
372546
373547
0 commit comments