Skip to content

Commit cc7dea9

Browse files
committed
Fixed TFT transform
1 parent 1a94965 commit cc7dea9

File tree

3 files changed

+28
-19
lines changed

3 files changed

+28
-19
lines changed

pts/model/estimator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from torch.utils import data
99
from torch.utils.data import DataLoader
1010

11+
from gluonts.env import env
1112
from gluonts.core.component import validated
1213
from gluonts.dataset.common import Dataset
1314
from gluonts.model.estimator import Estimator
1415
from gluonts.torch.model.predictor import PyTorchPredictor
1516
from gluonts.transform import SelectFields, Transformation
17+
from gluonts.support.util import maybe_len
1618

1719
from pts import Trainer
1820
from pts.model import get_module_forward_input_names
@@ -101,7 +103,9 @@ def train_model(
101103
trained_net = self.create_training_network(self.trainer.device)
102104

103105
input_names = get_module_forward_input_names(trained_net)
104-
training_instance_splitter = self.create_instance_splitter("training")
106+
107+
with env._let(max_idle_transforms=maybe_len(training_data) or 0):
108+
training_instance_splitter = self.create_instance_splitter("training")
105109
training_iter_dataset = TransformedIterableDataset(
106110
dataset=training_data,
107111
transform=transformation
@@ -124,7 +128,8 @@ def train_model(
124128

125129
validation_data_loader = None
126130
if validation_data is not None:
127-
validation_instance_splitter = self.create_instance_splitter("validation")
131+
with env._let(max_idle_transforms=maybe_len(validation_data) or 0):
132+
validation_instance_splitter = self.create_instance_splitter("validation")
128133
validation_iter_dataset = TransformedIterableDataset(
129134
dataset=validation_data,
130135
transform=transformation

pts/model/tft/tft_estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import torch
6+
67
from gluonts.core.component import validated
78
from gluonts.dataset.field_names import FieldName
89
from gluonts.model.forecast_generator import QuantileForecastGenerator
@@ -30,6 +31,7 @@
3031
from pts import Trainer
3132
from pts.model import PyTorchEstimator
3233
from pts.model.utils import get_module_forward_input_names
34+
3335
from .tft_network import (
3436
TemporalFusionTransformerPredictionNetwork,
3537
TemporalFusionTransformerTrainingNetwork,

pts/model/tft/tft_transform.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
shift_timestamp,
2626
target_transformation_length,
2727
)
28+
from gluonts.transform.sampler import InstanceSampler
2829

2930

3031
class BroadcastTo(MapTransformation):
@@ -54,7 +55,7 @@ class TFTInstanceSplitter(InstanceSplitter):
5455
@validated()
5556
def __init__(
5657
self,
57-
instance_sampler,
58+
instance_sampler: InstanceSampler,
5859
past_length: int,
5960
future_length: int,
6061
target_field: str = FieldName.TARGET,
@@ -64,29 +65,30 @@ def __init__(
6465
observed_value_field: str = FieldName.OBSERVED_VALUES,
6566
lead_time: int = 0,
6667
output_NTC: bool = True,
67-
time_series_fields: Optional[List[str]] = None,
68-
past_time_series_fields: Optional[List[str]] = None,
68+
time_series_fields: List[str] = [],
69+
past_time_series_fields: List[str] = [],
6970
dummy_value: float = 0.0,
7071
) -> None:
7172

73+
super().__init__(
74+
target_field=target_field,
75+
is_pad_field=is_pad_field,
76+
start_field=start_field,
77+
forecast_start_field=forecast_start_field,
78+
instance_sampler=instance_sampler,
79+
past_length=past_length,
80+
future_length=future_length,
81+
lead_time=lead_time,
82+
output_NTC=output_NTC,
83+
time_series_fields=time_series_fields,
84+
dummy_value=dummy_value,
85+
)
86+
7287
assert past_length > 0, "The value of `past_length` should be > 0"
7388
assert future_length > 0, "The value of `future_length` should be > 0"
7489

75-
self.instance_sampler = instance_sampler
76-
self.past_length = past_length
77-
self.future_length = future_length
78-
self.lead_time = lead_time
79-
self.output_NTC = output_NTC
80-
self.dummy_value = dummy_value
81-
82-
self.target_field = target_field
83-
self.is_pad_field = is_pad_field
84-
self.start_field = start_field
85-
self.forecast_start_field = forecast_start_field
8690
self.observed_value_field = observed_value_field
87-
88-
self.ts_fields = time_series_fields or []
89-
self.past_ts_fields = past_time_series_fields or []
91+
self.past_ts_fields = past_time_series_fields
9092

9193
def flatmap_transform(self, data: DataEntry, is_train: bool) -> Iterator[DataEntry]:
9294
pl = self.future_length

0 commit comments

Comments
 (0)