2525 shift_timestamp ,
2626 target_transformation_length ,
2727)
28+ from gluonts .transform .sampler import InstanceSampler
2829
2930
3031class BroadcastTo (MapTransformation ):
@@ -54,7 +55,7 @@ class TFTInstanceSplitter(InstanceSplitter):
5455 @validated ()
5556 def __init__ (
5657 self ,
57- instance_sampler ,
58+ instance_sampler : InstanceSampler ,
5859 past_length : int ,
5960 future_length : int ,
6061 target_field : str = FieldName .TARGET ,
@@ -64,29 +65,30 @@ def __init__(
6465 observed_value_field : str = FieldName .OBSERVED_VALUES ,
6566 lead_time : int = 0 ,
6667 output_NTC : bool = True ,
67- time_series_fields : Optional [ List [str ]] = None ,
68- past_time_series_fields : Optional [ List [str ]] = None ,
68+ time_series_fields : List [str ] = [] ,
69+ past_time_series_fields : List [str ] = [] ,
6970 dummy_value : float = 0.0 ,
7071 ) -> None :
7172
73+ super ().__init__ (
74+ target_field = target_field ,
75+ is_pad_field = is_pad_field ,
76+ start_field = start_field ,
77+ forecast_start_field = forecast_start_field ,
78+ instance_sampler = instance_sampler ,
79+ past_length = past_length ,
80+ future_length = future_length ,
81+ lead_time = lead_time ,
82+ output_NTC = output_NTC ,
83+ time_series_fields = time_series_fields ,
84+ dummy_value = dummy_value ,
85+ )
86+
7287 assert past_length > 0 , "The value of `past_length` should be > 0"
7388 assert future_length > 0 , "The value of `future_length` should be > 0"
7489
75- self .instance_sampler = instance_sampler
76- self .past_length = past_length
77- self .future_length = future_length
78- self .lead_time = lead_time
79- self .output_NTC = output_NTC
80- self .dummy_value = dummy_value
81-
82- self .target_field = target_field
83- self .is_pad_field = is_pad_field
84- self .start_field = start_field
85- self .forecast_start_field = forecast_start_field
8690 self .observed_value_field = observed_value_field
87-
88- self .ts_fields = time_series_fields or []
89- self .past_ts_fields = past_time_series_fields or []
91+ self .past_ts_fields = past_time_series_fields
9092
9193 def flatmap_transform (self , data : DataEntry , is_train : bool ) -> Iterator [DataEntry ]:
9294 pl = self .future_length
0 commit comments