Skip to content

Commit 1fa624d

Browse files
committed
support multivaraite
1 parent a7afdc5 commit 1fa624d

File tree

1 file changed

+25
-2
lines changed

1 file changed

+25
-2
lines changed

orion/primitives/timesfm.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
"""
1010

1111
import timesfm as tf
12+
import numpy as np
1213

1314

1415
class TimesFM:
@@ -53,10 +54,32 @@ def predict(self, X):
5354
Args:
5455
X (ndarray):
5556
input timeseries.
57+
target (int):
58+
index of target column in multivariate case. Default to 0.
5659
Return:
5760
ndarray:
5861
forecasted timeseries.
5962
"""
6063
frequency_input = [self.freq for _ in range(len(X))]
61-
y_hat, _ = self.model.forecast(X[:, :, 0], freq=frequency_input)
62-
return y_hat[:, 0]
64+
d = X.shape[-1] #number of variables
65+
if d == 1: #univariate
66+
y_hat, _ = self.model.forecast(X[:, :, 0], freq=frequency_input)
67+
return y_hat[:, 0]
68+
69+
else: #multivariate
70+
#Extend the x_reg to future values
71+
X_reg = X[:, :, 1:d]
72+
m, n, k = X_reg.shape
73+
X_reg_new = np.zeros((m, n+1, k))
74+
X_reg_new[:, :-1, :] = X_reg
75+
X_reg_new[:-1, -1, :] = X_reg[1:, 0, :]
76+
77+
78+
x_reg = {str(i): X_reg_new[:, :, i] for i in range(k)}
79+
y_hat, _ = self.model.forecast_with_covariates(
80+
inputs=X[:, :, 0],
81+
dynamic_numerical_covariates=x_reg,
82+
freq=frequency_input,
83+
)
84+
return np.concatenate(y_hat)
85+

0 commit comments

Comments
 (0)