Skip to content

Commit 14b3592

Browse files
tibisabaujakubchlapekdennisbader
authored
handle static covariates timeseries (#2996)
* handle static covariates timeseries * update changelog * Update timeseries.py * Update darts/timeseries.py Co-authored-by: Jakub Chłapek <147340544+jakubchlapek@users.noreply.github.com> * minor updates --------- Co-authored-by: Jakub Chłapek <147340544+jakubchlapek@users.noreply.github.com> Co-authored-by: dennisbader <dennis.bader@gmx.ch>
1 parent bc4d747 commit 14b3592

File tree

3 files changed

+194
-8
lines changed

3 files changed

+194
-8
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
1111

1212
**Improved**
1313

14+
- `TimeSeries.to_json()` and `from_json()` now support serialization and deserialization of static covariates, metadata, and hierarchy. The optional parameters in `from_json()` can still be used to override or provide these values if they are not present in the JSON string. [#2996](https://github.com/unit8co/darts/pull/2996) by [Tiberiu Sabau](https://github.com/tibisabau).
1415
- Added new time aggregated metric `autc()` (Area Under Tolerance Curve): The tolerance curve gives the fraction of predicted target values within tolerance bands of the actual target values across a range of tolerances (defined as % of target range). The AUTC is the normalized area under this tolerance curve and given as a score between [0, 1]. Higher scores are better. [#2994](https://github.com/unit8co/darts/pull/2994) by [Jakub Chłapek](https://github.com/jakubchlapek)
1516
- Added new plotting function `darts.utils.statistics.plot_tolerance_curve()` to plot the tolerance curve described above. [#2994](https://github.com/unit8co/darts/pull/2994) by [Jakub Chłapek](https://github.com/jakubchlapek)
1617
- Added `TimeSeries.plotly()` method for interactive time series visualization using Plotly backend. [#2977](https://github.com/unit8co/darts/pull/2977) by [Dustin Brunner](https://github.com/brunnedu).

darts/tests/test_timeseries_static_covariates.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,10 @@ def test_ts_from_x(self, tag, tmpdir_module):
130130
tag, ts, TimeSeries.from_csv(f_csv, time_col="time", **kwargs)
131131
)
132132
self.helper_test_transfer(tag, ts, TimeSeries.from_pickle(f_pkl))
133+
# Test with kwargs (backward compatibility)
133134
self.helper_test_transfer(tag, ts, TimeSeries.from_json(ts_json, **kwargs))
135+
# Test without kwargs (new automatic serialization)
136+
self.helper_test_transfer(tag, ts, TimeSeries.from_json(ts_json))
134137

135138
def test_invalid_metadata(self):
136139
ts = linear_timeseries(length=10)
@@ -965,3 +968,146 @@ def helper_test_transfer_values(tag, ts, ts_new):
965968
)
966969
else: # metadata
967970
assert ts_new.metadata == ts.metadata
971+
972+
973+
class TestTimeSeriesJSONSerialization:
974+
"""Test JSON serialization with static_covariates, metadata, and hierarchy."""
975+
976+
def test_json_with_static_covariates(self):
977+
"""Test that static covariates are preserved in JSON serialization."""
978+
ts = linear_timeseries(length=10)
979+
static_cov = pd.Series([0.0, 1.0], index=["st1", "st2"])
980+
ts = ts.with_static_covariates(static_cov)
981+
982+
# Serialize and deserialize
983+
json_str = ts.to_json()
984+
ts_restored = TimeSeries.from_json(json_str)
985+
986+
# Check that static covariates are preserved
987+
assert ts_restored.static_covariates is not None
988+
assert ts_restored.static_covariates.equals(ts.static_covariates)
989+
990+
def test_json_with_metadata(self):
991+
"""Test that metadata is preserved in JSON serialization."""
992+
ts = linear_timeseries(length=10)
993+
metadata = {"key1": "value1", "key2": 42, "key3": [1, 2, 3]}
994+
ts = ts.with_metadata(metadata)
995+
996+
# Serialize and deserialize
997+
json_str = ts.to_json()
998+
ts_restored = TimeSeries.from_json(json_str)
999+
1000+
# Check that metadata is preserved
1001+
assert ts_restored.metadata is not None
1002+
assert ts_restored.metadata == ts.metadata
1003+
1004+
def test_json_with_hierarchy(self):
1005+
"""Test that hierarchy is preserved in JSON serialization."""
1006+
components = ["total", "a", "b", "ax", "bx"]
1007+
hierarchy = {
1008+
"ax": ["a"],
1009+
"bx": ["b"],
1010+
"a": ["total"],
1011+
"b": ["total"],
1012+
}
1013+
ts = TimeSeries.from_values(
1014+
values=np.random.rand(10, len(components)),
1015+
columns=components,
1016+
hierarchy=hierarchy,
1017+
)
1018+
1019+
# Serialize and deserialize
1020+
json_str = ts.to_json()
1021+
ts_restored = TimeSeries.from_json(json_str)
1022+
1023+
# Check that hierarchy is preserved
1024+
assert ts_restored.hierarchy is not None
1025+
assert ts_restored.hierarchy == ts.hierarchy
1026+
assert ts_restored.top_level_component == ts.top_level_component
1027+
assert ts_restored.bottom_level_components == ts.bottom_level_components
1028+
1029+
def test_json_with_all_attributes(self):
1030+
"""Test JSON serialization with all attributes (static_covariates, metadata, hierarchy)."""
1031+
components = ["total", "a", "b"]
1032+
hierarchy = {"a": ["total"], "b": ["total"]}
1033+
static_cov = pd.DataFrame(
1034+
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
1035+
index=components,
1036+
columns=["sc1", "sc2"],
1037+
)
1038+
metadata = {"description": "test series", "version": 1}
1039+
1040+
ts = TimeSeries.from_values(
1041+
values=np.random.rand(10, len(components)),
1042+
columns=components,
1043+
hierarchy=hierarchy,
1044+
static_covariates=static_cov,
1045+
metadata=metadata,
1046+
)
1047+
1048+
# Serialize and deserialize
1049+
json_str = ts.to_json()
1050+
ts_restored = TimeSeries.from_json(json_str)
1051+
1052+
# Check all attributes are preserved
1053+
assert ts_restored.static_covariates is not None
1054+
assert ts_restored.static_covariates.equals(ts.static_covariates)
1055+
assert ts_restored.metadata == ts.metadata
1056+
assert ts_restored.hierarchy == ts.hierarchy
1057+
assert ts_restored.top_level_component == ts.top_level_component
1058+
assert ts_restored.bottom_level_components == ts.bottom_level_components
1059+
1060+
def test_json_override_attributes(self):
1061+
"""Test that from_json parameters can override JSON-embedded attributes."""
1062+
components = ["total", "a", "b"]
1063+
hierarchy = {"a": ["total"], "b": ["total"]}
1064+
static_cov = pd.DataFrame(
1065+
[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
1066+
index=components,
1067+
columns=["sc1", "sc2"],
1068+
)
1069+
metadata = {"description": "test series", "version": 1}
1070+
1071+
ts = TimeSeries.from_values(
1072+
values=np.random.rand(10, len(components)),
1073+
columns=components,
1074+
hierarchy=hierarchy,
1075+
static_covariates=static_cov,
1076+
metadata=metadata,
1077+
)
1078+
json_str = ts.to_json()
1079+
1080+
# Override with different values
1081+
new_static_cov = pd.DataFrame(
1082+
[[10.0, 11.0], [12.0, 13.0], [14.0, 5.0]],
1083+
index=components,
1084+
columns=["sc_new1", "sc_new2"],
1085+
)
1086+
new_metadata = {"key2": "value2"}
1087+
new_hierarchy = {"total": ["b"], "b": ["a"]}
1088+
ts_restored = TimeSeries.from_json(
1089+
json_str,
1090+
static_covariates=new_static_cov,
1091+
metadata=new_metadata,
1092+
hierarchy=new_hierarchy,
1093+
)
1094+
1095+
# Check that overrides worked
1096+
# When a Series is passed, it becomes a single-row DataFrame with the series index as columns
1097+
assert ts_restored.static_covariates is not None
1098+
assert ts_restored.static_covariates.equals(new_static_cov)
1099+
assert ts_restored.metadata == new_metadata
1100+
assert ts_restored.hierarchy == new_hierarchy
1101+
1102+
def test_json_without_attributes(self):
1103+
"""Test JSON serialization for series without optional attributes."""
1104+
ts = linear_timeseries(length=10)
1105+
1106+
# Serialize and deserialize
1107+
json_str = ts.to_json()
1108+
ts_restored = TimeSeries.from_json(json_str)
1109+
1110+
# Check that optional attributes are None
1111+
assert ts_restored.static_covariates is None
1112+
assert ts_restored.metadata is None
1113+
assert ts_restored.hierarchy is None

darts/timeseries.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
"""
4141

4242
import itertools
43+
import json
4344
import math
4445
import pickle
4546
import re
@@ -1450,6 +1451,10 @@ def from_json(
14501451
14511452
At the moment this only supports deterministic time series (i.e., made of 1 sample).
14521453
1454+
If the JSON string contains static covariates, hierarchy, or metadata, they will be automatically
1455+
loaded. The optional parameters `static_covariates`, `hierarchy`, and `metadata` can be used to
1456+
override or provide these values if they are not present in the JSON string.
1457+
14531458
Parameters
14541459
----------
14551460
json_str
@@ -1462,6 +1467,7 @@ def from_json(
14621467
are globally 'applied' to all components of the TimeSeries. If a multi-row DataFrame, the number of
14631468
rows must match the number of components of the TimeSeries (in this case, the number of columns in
14641469
``value_cols``). This adds control for component-specific static covariates.
1470+
If the JSON string already contains static covariates, this parameter will override them.
14651471
hierarchy
14661472
Optionally, a dictionary describing the grouping(s) of the time series. The keys are component names, and
14671473
for a given component name `c`, the value is a list of component names that `c` "belongs" to. For instance,
@@ -1487,8 +1493,10 @@ def from_json(
14871493
The hierarchy can be used to reconcile forecasts (so that the sums of the forecasts at
14881494
different levels are consistent), see `hierarchical reconciliation
14891495
<https://unit8co.github.io/darts/generated_api/darts.dataprocessing.transformers.reconciliation.html>`__.
1496+
If the JSON string already contains a hierarchy, this parameter will override it.
14901497
metadata
14911498
Optionally, a dictionary with metadata to be added to the TimeSeries.
1499+
If the JSON string already contains metadata, this parameter will override it.
14921500
14931501
Returns
14941502
-------
@@ -1501,16 +1509,33 @@ def from_json(
15011509
>>> json_str = (
15021510
>>> '{"columns":["vals"],"index":["2020-01-01","2020-01-02","2020-01-03"],"data":[[0.0],[1.0],[2.0]]}'
15031511
>>> )
1504-
>>> series = TimeSeries.from_json("data.csv")
1512+
>>> series = TimeSeries.from_json(json_str)
15051513
>>> series.shape
15061514
(3, 1, 1)
15071515
"""
1516+
parsed = json.loads(json_str)
1517+
1518+
static_covariates_ = parsed.pop("static_covariates", None)
1519+
if static_covariates_ is not None and static_covariates is None:
1520+
static_covariates = pd.read_json(
1521+
StringIO(json.dumps(static_covariates_)), orient="split"
1522+
)
1523+
1524+
hierarchy_ = parsed.pop("hierarchy", None)
1525+
if hierarchy is None:
1526+
hierarchy = hierarchy_
1527+
1528+
metadata_ = parsed.pop("metadata", None)
1529+
if metadata is None:
1530+
metadata = metadata_
1531+
1532+
df = pd.read_json(StringIO(json.dumps(parsed)), orient="split")
15081533
return cls.from_dataframe(
1509-
df=pd.read_json(StringIO(json_str), orient="split"),
1534+
df=df,
15101535
static_covariates=static_covariates,
15111536
hierarchy=hierarchy,
15121537
metadata=metadata,
1513-
copy=False, # JSON is immutable, so no need to copy
1538+
copy=False,
15141539
)
15151540

15161541
@classmethod
@@ -4271,17 +4296,31 @@ def to_json(self) -> str:
42714296
42724297
At the moment this function works only on deterministic time series (i.e., made of 1 sample).
42734298
4274-
Notes
4275-
-----
4276-
Static covariates are not returned in the JSON string. When using `TimeSeries.from_json()`, the static
4277-
covariates can be added with input argument `static_covariates`.
4299+
The JSON string includes the series values, time index, component names, as well as static covariates,
4300+
hierarchy, and metadata (if any).
42784301
42794302
Returns
42804303
-------
42814304
str
42824305
A JSON String representing the series
4306+
4307+
See Also
4308+
--------
4309+
TimeSeries.from_json : Create a TimeSeries from a JSON string.
42834310
"""
4284-
return self.to_dataframe().to_json(orient="split", date_format="iso")
4311+
result = json.loads(
4312+
self.to_dataframe().to_json(orient="split", date_format="iso")
4313+
)
4314+
if self.static_covariates is not None:
4315+
result["static_covariates"] = json.loads(
4316+
self.static_covariates.to_json(orient="split")
4317+
)
4318+
if self.hierarchy is not None:
4319+
result["hierarchy"] = self.hierarchy
4320+
if self.metadata is not None:
4321+
result["metadata"] = self.metadata
4322+
4323+
return json.dumps(result)
42854324

42864325
def to_csv(self, *args, **kwargs):
42874326
"""Write the deterministic series to a CSV file.

0 commit comments

Comments
 (0)