Skip to content

Commit 9d8a6d6

Browse files
committed
Implement DeepAR V2 with improved scaling and distribution handling
1 parent ea75590 commit 9d8a6d6

File tree

7 files changed

+684
-5
lines changed

7 files changed

+684
-5
lines changed

docs/source/tutorials/ptf_V2_example.ipynb

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1763,11 +1763,13 @@
17631763
"provenance": []
17641764
},
17651765
"kernelspec": {
1766-
"display_name": "Python 3",
1766+
"display_name": ".venv",
1767+
"language": "python",
17671768
"name": "python3"
17681769
},
17691770
"language_info": {
1770-
"name": "python"
1771+
"name": "python",
1772+
"version": "3.12.3"
17711773
}
17721774
},
17731775
"nbformat": 4,

docs/source/tutorials/stallion.ipynb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
"source": [
4242
"import warnings\n",
4343
"\n",
44-
"\n",
4544
"warnings.filterwarnings(\"ignore\") # avoid printing out absolute paths"
4645
]
4746
},
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""
2+
Packages container for DeepAR model.
3+
"""
4+
5+
from pytorch_forecasting.base._base_pkg import Base_pkg
6+
7+
8+
class DeepAR_pkg_v2(Base_pkg):
9+
"""DeepAR package container."""
10+
11+
_tags = {
12+
"info:name": "DeepAR",
13+
"info:compute": 3,
14+
"authors": ["jdb78"],
15+
"capability:exogenous": True,
16+
"capability:multivariate": True,
17+
"capability:pred_int": True,
18+
"capability:flexible_history_length": True,
19+
"capability:cold_start": False,
20+
}
21+
22+
@classmethod
23+
def get_cls(cls):
24+
"""Get model class."""
25+
from pytorch_forecasting.models.deepar._deepar_v2 import DeepAR
26+
27+
return DeepAR
28+
29+
@classmethod
30+
def get_datamodule_cls(cls):
31+
"""Get the underlying DataModule class."""
32+
from pytorch_forecasting.data.data_module import (
33+
EncoderDecoderTimeSeriesDataModule,
34+
)
35+
36+
return EncoderDecoderTimeSeriesDataModule
37+
38+
@classmethod
39+
def get_base_test_params(cls):
40+
"""Return testing parameter settings for the trainer."""
41+
return [
42+
{},
43+
dict(
44+
cell_type="GRU",
45+
hidden_size=16,
46+
rnn_layers=2,
47+
),
48+
]
49+
50+
@classmethod
51+
def get_test_train_params(cls):
52+
"""Return testing parameter settings for the trainer.
53+
54+
Returns
55+
-------
56+
params : dict or list of dict, default = {}
57+
Parameters to create testing instances of the class
58+
Each dict are parameters to construct an "interesting" test instance, i.e.,
59+
`MyClass(**params)` or `MyClass(**params[i])` creates a valid test instance.
60+
`create_test_instance` uses the first (or only) dictionary in `params`
61+
"""
62+
from pytorch_forecasting.metrics import NormalDistributionLoss
63+
64+
params = [
65+
dict(
66+
loss=NormalDistributionLoss(),
67+
),
68+
dict(
69+
loss=NormalDistributionLoss(),
70+
cell_type="GRU",
71+
hidden_size=16,
72+
rnn_layers=2,
73+
),
74+
dict(
75+
loss=NormalDistributionLoss(),
76+
hidden_size=32,
77+
rnn_layers=3,
78+
dropout=0.2,
79+
),
80+
dict(
81+
loss=NormalDistributionLoss(),
82+
hidden_size=20,
83+
datamodule_cfg=dict(
84+
max_encoder_length=7,
85+
max_prediction_length=5,
86+
),
87+
),
88+
dict(
89+
loss=NormalDistributionLoss(),
90+
hidden_size=16,
91+
n_validation_samples=50,
92+
n_plotting_samples=25,
93+
),
94+
dict(
95+
loss=NormalDistributionLoss(),
96+
hidden_size=10,
97+
rnn_layers=1,
98+
dropout=0.0,
99+
datamodule_cfg=dict(
100+
max_encoder_length=3,
101+
max_prediction_length=2,
102+
),
103+
),
104+
]
105+
106+
default_dm_cfg = {
107+
"max_encoder_length": 4,
108+
"max_prediction_length": 3,
109+
}
110+
111+
for param in params:
112+
current_dm_cfg = param.get("datamodule_cfg", {})
113+
merged_dm_cfg = default_dm_cfg.copy()
114+
merged_dm_cfg.update(current_dm_cfg)
115+
param["datamodule_cfg"] = merged_dm_cfg
116+
117+
return params
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""DeepAR: Probabilistic forecasting with autoregressive recurrent networks."""
22

3+
from pytorch_forecasting.models.deepar.__deepar_pkg_v2 import DeepAR_pkg_v2
34
from pytorch_forecasting.models.deepar._deepar import DeepAR
45
from pytorch_forecasting.models.deepar._deepar_pkg import DeepAR_pkg
6+
from pytorch_forecasting.models.deepar._deepar_v2 import DeepAR as DeepAR_v2
57

6-
__all__ = ["DeepAR", "DeepAR_pkg"]
8+
__all__ = ["DeepAR", "DeepAR_v2", "DeepAR_pkg", "DeepAR_pkg_v2"]

0 commit comments

Comments
 (0)