Skip to content

Commit 456c851

Browse files
committed
support multivaraite
1 parent 1fa624d commit 456c851

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

orion/primitives/timesfm.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,30 @@ def predict(self, X):
5454
Args:
5555
X (ndarray):
5656
input timeseries.
57-
target (int):
57+
target (int):
5858
index of target column in multivariate case. Default to 0.
5959
Return:
6060
ndarray:
6161
forecasted timeseries.
6262
"""
6363
frequency_input = [self.freq for _ in range(len(X))]
64-
d = X.shape[-1] #number of variables
65-
if d == 1: #univariate
64+
d = X.shape[-1]
65+
66+
67+
#univariate
68+
if d == 1:
6669
y_hat, _ = self.model.forecast(X[:, :, 0], freq=frequency_input)
6770
return y_hat[:, 0]
68-
69-
else: #multivariate
70-
#Extend the x_reg to future values
71+
72+
73+
#multivariate
74+
else:
7175
X_reg = X[:, :, 1:d]
7276
m, n, k = X_reg.shape
7377
X_reg_new = np.zeros((m, n+1, k))
7478
X_reg_new[:, :-1, :] = X_reg
7579
X_reg_new[:-1, -1, :] = X_reg[1:, 0, :]
7680

77-
7881
x_reg = {str(i): X_reg_new[:, :, i] for i in range(k)}
7982
y_hat, _ = self.model.forecast_with_covariates(
8083
inputs=X[:, :, 0],

0 commit comments

Comments
 (0)