Skip to content

Commit 729c4a1

Browse files
committed
fix the scaler
1 parent 1a0872d commit 729c4a1

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

pts/model/causal_deepar/causal_deepar_network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,8 +177,8 @@ def unroll_encoder(
177177
# scale is computed on the context length last units of the past target
178178
# scale shape is (batch_size, 1, *target_shape)
179179
_, scale = self.scaler(
180-
past_target[:, self.context_length :, ...],
181-
past_observed_values[:, self.context_length :, ...],
180+
past_target[:, -self.context_length :, ...],
181+
past_observed_values[:, -self.context_length :, ...],
182182
)
183183

184184
_, control_scale = self.control_scaler(

pts/model/deepar/deepar_network.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ def unroll_encoder(
160160
# scale is computed on the context length last units of the past target
161161
# scale shape is (batch_size, 1, *target_shape)
162162
_, scale = self.scaler(
163-
past_target[:, self.context_length :, ...],
164-
past_observed_values[:, self.context_length :, ...],
163+
past_target[:, -self.context_length :, ...],
164+
past_observed_values[:, -self.context_length :, ...],
165165
)
166166

167167
# (batch_size, num_features)

0 commit comments

Comments
 (0)