Skip to content

Commit 36b0195

Browse files
committed
MNT fix same y_init used for different sessions
1 parent 711673e commit 36b0195

File tree

4 files changed

+742
-1124
lines changed

4 files changed

+742
-1124
lines changed

fastcan/narx.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,9 @@ def _expression(self, X, y_hat, coef, intercept, k):
544544
return y_pred
545545

546546
@staticmethod
547-
def _predict(expression, X, y_init, coef, intercept, max_delay):
547+
def _predict(expression, X, y_ref, coef, intercept, max_delay):
548548
n_samples = X.shape[0]
549+
n_ref = len(y_ref)
549550
y_hat = np.zeros(n_samples)
550551
at_init = True
551552
init_k = 0
@@ -559,7 +560,7 @@ def _predict(expression, X, y_init, coef, intercept, max_delay):
559560
at_init = False
560561

561562
if at_init:
562-
y_hat[k] = y_init[k - init_k]
563+
y_hat[k] = y_ref[k % n_ref]
563564
else:
564565
y_hat[k] = expression(X, y_hat, coef, intercept, k)
565566
if np.any(y_hat[k] > 1e20):
@@ -578,7 +579,7 @@ def _residual(
578579
coef = coef_intercept[:-1]
579580
intercept = coef_intercept[-1]
580581

581-
y_hat = NARX._predict(expression, X, y[:max_delay], coef, intercept, max_delay)
582+
y_hat = NARX._predict(expression, X, y, coef, intercept, max_delay)
582583

583584
y_masked, y_hat_masked = _mask_missing_value(y, y_hat)
584585

@@ -599,8 +600,9 @@ def predict(self, X, y_init=None):
599600
X : array-like of shape (n_samples, `n_features_in_`)
600601
Samples.
601602
602-
y_init : array-like of shape (`max_delay`,), default=None
603+
y_init : array-like of shape (`n_init`,), default=None
603604
The initial values for the prediction of y.
605+
At least have one sample.
604606
605607
Returns
606608
-------
@@ -614,11 +616,9 @@ def predict(self, X, y_init=None):
614616
y_init = np.zeros(self.max_delay_)
615617
else:
616618
y_init = column_or_1d(y_init, dtype=float)
617-
if y_init.shape[0] != self.max_delay_:
619+
if y_init.shape[0] < 1:
618620
raise ValueError(
619-
"`y_init` should have the shape of "
620-
"(`max_delay`,), i.e., "
621-
f"({self.max_delay_},), "
621+
"`y_init` should at least have one sample "
622622
f"but got {y_init.shape}."
623623
)
624624

0 commit comments

Comments
 (0)