Skip to content

Commit e4b61ac

Browse files
authored
Feat/plotly support (#2985)
1 parent 892c16d commit e4b61ac

File tree

9 files changed

+1366
-226
lines changed

9 files changed

+1366
-226
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,16 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1111

1212
**Improved**
1313

14+
- Added `TimeSeries.plotly()` method for interactive time series visualization using Plotly backend. [#2977](https://github.com/unit8co/darts/pull/2977) by [Dustin Brunner](https://github.com/brunnedu).
15+
- Provides interactive plotting with zoom, pan, hover tooltips, and legend interactions
16+
- Maintains API consistency with the existing `plot()` method for easy adoption
17+
- Supports deterministic and stochastic, univariate and multivariate series
18+
- Allows overlaying multiple series on the same figure via the `fig` parameter
19+
- Customizable trace styling via `**kwargs`
20+
- Includes automatic downsampling for large series (configurable via `downsample_threshold` parameter) to avoid crashes when plotting large series
21+
- Integrates seamlessly with `plotting.use_darts_style` which now affects both `TimeSeries.plot()` and `TimeSeries.plotly()`
22+
- Plotly remains an optional dependency and can be installed with `pip install plotly`
23+
1424
**Fixed**
1525

1626
- Fixed bug in `StaticCovariatesTransformer` where one-hot encoded column names were incorrectly assigned when the order of columns specified in `cols_cat` differed from the actual data column order. This caused silent data corruption where column names combined wrong feature names with wrong category values (e.g., `City_US` instead of `Country_US`). [#2989](https://github.com/unit8co/darts/pull/2989) by [Dustin Brunner](https://github.com/brunnedu).

darts/config.py

Lines changed: 113 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
**Plotting Options**
2323
2424
- ``plotting.use_darts_style`` : bool (default: False)
25-
Whether to apply Darts' custom matplotlib plotting style. When True, Darts will configure
26-
matplotlib with a custom style optimized for time series visualization. When False, matplotlib's
27-
default or user-configured style will be used. Changes to this option take effect immediately.
25+
Whether to apply Darts' custom plotting style to both matplotlib and plotly. When True, Darts will
26+
configure both backends with a custom style optimized for time series visualization. When False,
27+
the default or user-configured styles will be used. Changes to this option take effect immediately.
2828
2929
Examples
3030
========
@@ -50,6 +50,17 @@
5050

5151
logger = get_logger(__name__)
5252

53+
# Darts color palette used for both matplotlib and plotly plotting
54+
_DARTS_COLORS = [
55+
"#000000",
56+
"#003dfd",
57+
"#b512b8",
58+
"#11a9ba",
59+
"#0d780f",
60+
"#f77f07",
61+
"#ba0f0f",
62+
]
63+
5364

5465
class _Option:
5566
"""Internal class representing a single configuration option."""
@@ -117,10 +128,9 @@ def __init__(self):
117128
plotting_use_darts_style = _Option(
118129
key="plotting.use_darts_style",
119130
default_value=False,
120-
description="Whether to apply Darts' custom matplotlib plotting style. "
121-
"When True, Darts will configure matplotlib with a custom style optimized for "
122-
"time series visualization. When False, matplotlib's default or user-configured "
123-
"style will be used.",
131+
description="Whether to apply Darts' custom plotting style to both matplotlib and plotly. "
132+
"When True, Darts will configure both backends with a custom style optimized for "
133+
"time series visualization. When False, the default or user-configured styles will be used.",
124134
validator=self._validate_bool,
125135
callback=self._on_plotting_style_change,
126136
)
@@ -134,7 +144,10 @@ def __init__(self):
134144
]
135145
}
136146
# remember if user applied Darts style
137-
self._mpl_style_applied = False
147+
self._darts_plotting_style_applied = False
148+
# store original templates to restore later
149+
self._original_plotly_template = None
150+
self._original_mpl_params = None
138151

139152
@staticmethod
140153
def _validate_positive_int(value: Any):
@@ -150,43 +163,101 @@ def _validate_bool(value: Any):
150163

151164
def _on_plotting_style_change(self, value: bool) -> None:
152165
"""Callback for when plotting.use_darts_style changes."""
166+
# matplotlib
153167
import matplotlib as mpl
154168
from matplotlib import cycler
155169

156-
if not value and self._mpl_style_applied:
157-
# restore default options
158-
mpl.rcParams.update(mpl.rcParamsDefault)
159-
self._mpl_style_applied = False
160-
return
161-
162-
# apply Darts plotting style
163-
colors = cycler(
164-
color=["black", "003DFD", "b512b8", "11a9ba", "0d780f", "f77f07", "ba0f0f"]
165-
)
166-
167-
u8plots_mplstyle = {
168-
"font.family": "sans serif",
169-
"axes.edgecolor": "black",
170-
"axes.grid": True,
171-
"axes.labelcolor": "#333333",
172-
"axes.labelweight": 600,
173-
"axes.linewidth": 1,
174-
"axes.prop_cycle": colors,
175-
"axes.spines.top": False,
176-
"axes.spines.right": False,
177-
"axes.spines.bottom": False,
178-
"axes.spines.left": False,
179-
"grid.color": "#dedede",
180-
"legend.frameon": False,
181-
"lines.linewidth": 1.3,
182-
"xtick.color": "#333333",
183-
"xtick.labelsize": "small",
184-
"ytick.color": "#333333",
185-
"ytick.labelsize": "small",
186-
"xtick.bottom": False,
187-
}
188-
mpl.rcParams.update(u8plots_mplstyle)
189-
self._mpl_style_applied = True
170+
if value:
171+
# store current matplotlib params before applying darts style
172+
if not self._darts_plotting_style_applied:
173+
self._original_mpl_params = mpl.rcParams.copy()
174+
175+
# apply Darts plotting style to matplotlib
176+
colors = cycler(color=_DARTS_COLORS)
177+
u8plots_mplstyle = {
178+
"font.family": "sans serif",
179+
"axes.edgecolor": "black",
180+
"axes.grid": True,
181+
"axes.labelcolor": "#333333",
182+
"axes.labelweight": 600,
183+
"axes.linewidth": 1,
184+
"axes.prop_cycle": colors,
185+
"axes.spines.top": False,
186+
"axes.spines.right": False,
187+
"axes.spines.bottom": False,
188+
"axes.spines.left": False,
189+
"grid.color": "#dedede",
190+
"legend.frameon": False,
191+
"lines.linewidth": 1.3,
192+
"xtick.color": "#333333",
193+
"xtick.labelsize": "small",
194+
"ytick.color": "#333333",
195+
"ytick.labelsize": "small",
196+
"xtick.bottom": False,
197+
}
198+
mpl.rcParams.update(u8plots_mplstyle)
199+
else:
200+
# restore previous matplotlib options
201+
if self._original_mpl_params is not None:
202+
mpl.rcParams.update(self._original_mpl_params)
203+
self._original_mpl_params = None
204+
205+
# plotly
206+
try:
207+
import plotly.graph_objects as go
208+
import plotly.io as pio
209+
210+
if value:
211+
# store existing default to restore later
212+
if not self._darts_plotting_style_applied:
213+
self._original_plotly_template = pio.templates.default
214+
215+
# assign the darts plotly template directly
216+
pio.templates.default = go.layout.Template(
217+
layout=go.Layout(
218+
font=dict(family="Arial, sans-serif", size=14, color="black"),
219+
paper_bgcolor="white",
220+
plot_bgcolor="white",
221+
colorway=_DARTS_COLORS,
222+
showlegend=True,
223+
legend=dict(
224+
bgcolor="rgba(255, 255, 255, 0.8)",
225+
x=1,
226+
y=1,
227+
yanchor="top",
228+
xanchor="right",
229+
font=dict(size=14),
230+
borderwidth=0,
231+
),
232+
xaxis=dict(
233+
showline=True,
234+
linecolor="#dedede",
235+
showgrid=False,
236+
title=dict(font=dict(size=16, color="black")),
237+
),
238+
yaxis=dict(
239+
showline=False,
240+
showgrid=True,
241+
gridcolor="#dedede",
242+
gridwidth=1,
243+
zeroline=True,
244+
zerolinecolor="#dedede",
245+
),
246+
margin=dict(l=50, r=50, t=50, b=50),
247+
),
248+
data=dict(scatter=[go.Scatter(line=dict(width=3))]),
249+
)
250+
else:
251+
# restore the previous default
252+
if self._original_plotly_template is not None:
253+
pio.templates.default = self._original_plotly_template
254+
self._original_plotly_template = None
255+
except ImportError:
256+
# plotly not available, skip plotly configuration
257+
pass
258+
259+
# update the state tracker
260+
self._darts_plotting_style_applied = value
190261

191262
def _find_option(self, pattern: str, check_unique: bool = False) -> list[_Option]:
192263
"""Find options matching a pattern (supports both exact match and prefix match)."""

darts/tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,14 @@
103103
logger.warning("Polars not installed - Some tests will be skipped.")
104104
POLARS_AVAILABLE = False
105105

106+
try:
107+
import plotly # noqa: F401
108+
109+
PLOTLY_AVAILABLE = True
110+
except ImportError:
111+
logger.warning("Plotly not installed - Some tests will be skipped.")
112+
PLOTLY_AVAILABLE = False
113+
106114
tfm_kwargs: dict[str, Any] = {
107115
"pl_trainer_kwargs": {
108116
"accelerator": "cpu",

darts/tests/test_config.py

Lines changed: 47 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
reset_option,
1313
set_option,
1414
)
15+
from darts.tests.conftest import PLOTLY_AVAILABLE
16+
17+
if PLOTLY_AVAILABLE:
18+
import plotly.io as pio
1519

1620

1721
@pytest.fixture(scope="function", autouse=True)
@@ -233,8 +237,48 @@ def test_plotting_no_user_style_override_on_reset(self):
233237
reset_option("plotting.use_darts_style")
234238
assert mpl.rcParams == user_style
235239

236-
# Setting Darts style will override this behavior
240+
# Setting Darts style and resetting should restore user's style
237241
set_option("plotting.use_darts_style", True)
238242
reset_option("plotting.use_darts_style")
239-
assert mpl.rcParams != user_style
240-
assert mpl.rcParams == mpl.rcParamsDefault
243+
assert mpl.rcParams == user_style
244+
245+
@pytest.mark.skipif(not PLOTLY_AVAILABLE, reason="requires plotly")
246+
def test_plotting_style_callback_plotly(self):
247+
"""Test that changing plotting.use_darts_style updates plotly template."""
248+
# Store original template
249+
original_template = pio.templates.default
250+
251+
# Apply Darts style
252+
set_option("plotting.use_darts_style", True)
253+
assert pio.templates.default != original_template
254+
darts_template = pio.templates.default
255+
256+
# Remove Darts style
257+
set_option("plotting.use_darts_style", False)
258+
assert pio.templates.default == original_template
259+
260+
# Re-apply Darts style
261+
set_option("plotting.use_darts_style", True)
262+
assert pio.templates.default == darts_template
263+
264+
# Reset style
265+
reset_option("plotting.use_darts_style")
266+
assert pio.templates.default == original_template
267+
268+
@pytest.mark.skipif(not PLOTLY_AVAILABLE, reason="requires plotly")
269+
def test_plotting_no_user_template_override_on_reset_plotly(self):
270+
"""Test that resetting plotting.use_darts_style preserves user's plotly template."""
271+
# Set a custom template
272+
pio.templates.default = "plotly_dark"
273+
user_template = pio.templates.default
274+
assert user_template == "plotly_dark"
275+
276+
# Resetting style should not override user's template
277+
reset_option("plotting.use_darts_style")
278+
assert pio.templates.default == user_template
279+
280+
# Setting Darts style and resetting should restore user's template
281+
set_option("plotting.use_darts_style", True)
282+
assert pio.templates.default != user_template
283+
reset_option("plotting.use_darts_style")
284+
assert pio.templates.default == user_template

darts/tests/test_timeseries_plot.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -211,17 +211,6 @@ def test_plot_stochastic_params(self, mock_show, config):
211211
plt.show()
212212
plt.close()
213213

214-
@patch("matplotlib.pyplot.show")
215-
@pytest.mark.parametrize("config", ["dt", "ri"])
216-
def test_plot_multiple_series(self, mock_show, config):
217-
index_type = config
218-
series1 = getattr(self, f"series_{index_type}_d")
219-
series2 = getattr(self, f"series_{index_type}_p")
220-
series1.plot()
221-
series2.plot()
222-
plt.show()
223-
plt.close()
224-
225214
@patch("matplotlib.pyplot.show")
226215
@pytest.mark.parametrize("config", ["dt", "ri"])
227216
def test_plot_deterministic_and_stochastic(self, mock_show, config):

0 commit comments

Comments
 (0)