Skip to content

Commit 65d7f3f

Browse files
FEAT add auto-regressive model in NARX (#121)
1 parent ec7594f commit 65d7f3f

File tree

2 files changed

+56
-5
lines changed

2 files changed

+56
-5
lines changed

fastcan/narx.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,7 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params):
669669
670670
Parameters
671671
----------
672-
X : {array-like, sparse matrix} of shape (n_samples, `n_features_in_`)
672+
X : {array-like, sparse matrix} of shape (n_samples, `n_features_in_`) or None
673673
Training data.
674674
675675
y : array-like of shape (n_samples,) or (n_samples, `n_outputs_`)
@@ -700,7 +700,9 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params):
700700
self : object
701701
Fitted Estimator.
702702
"""
703-
check_X_params = dict(dtype=float, order="C", ensure_all_finite="allow-nan")
703+
check_X_params = dict(
704+
dtype=float, order="C", ensure_all_finite="allow-nan", ensure_min_features=0
705+
)
704706
check_y_params = dict(
705707
ensure_2d=False, dtype=float, order="C", ensure_all_finite="allow-nan"
706708
)
@@ -717,7 +719,10 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params):
717719
n_samples, n_features = X.shape
718720

719721
if self.feat_ids is None:
720-
feat_ids_ = make_poly_ids(n_features, 1) - 1
722+
if n_features == 0:
723+
feat_ids_ = make_poly_ids(self.n_outputs_, 1) - 1
724+
else:
725+
feat_ids_ = make_poly_ids(n_features, 1) - 1
721726
else:
722727
feat_ids_ = self.feat_ids
723728

@@ -1152,6 +1157,7 @@ def predict(self, X, y_init=None):
11521157
order="C",
11531158
reset=False,
11541159
ensure_all_finite="allow-nan",
1160+
ensure_min_features=0,
11551161
)
11561162
if y_init is None:
11571163
y_init = np.zeros((self.max_delay_, self.n_outputs_))
@@ -1419,7 +1425,13 @@ def make_narx(
14191425
| 0 | X[k-1,0]*X[k-3,0] | 2.000 |
14201426
| 0 | X[k-2,0]*X[k,1] | 1.528 |
14211427
"""
1422-
X = check_array(X, dtype=float, ensure_2d=True, ensure_all_finite="allow-nan")
1428+
X = check_array(
1429+
X,
1430+
dtype=float,
1431+
ensure_2d=True,
1432+
ensure_all_finite="allow-nan",
1433+
ensure_min_features=0,
1434+
)
14231435
y = check_array(y, dtype=float, ensure_2d=False, ensure_all_finite="allow-nan")
14241436
check_consistent_length(X, y)
14251437
if y.ndim == 1:

tests/test_narx.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,12 @@
2121

2222

2323
def test_narx_is_sklearn_estimator():
24+
# Skip 0 feature check for NARX, as AR models have no features
25+
expected_failures = {
26+
"check_estimators_empty_data_messages": ("NARX can handle 0 feature."),
27+
}
2428
with pytest.warns(UserWarning, match="output_ids got"):
25-
check_estimator(NARX())
29+
check_estimator(NARX(), expected_failed_checks=expected_failures)
2630

2731

2832
def test_poly_ids():
@@ -576,3 +580,38 @@ def test_nan_split(max_delay):
576580
assert poly_terms_masked.shape[0] == n_sessions * (
577581
n_samples_per_session - narx.max_delay_
578582
)
583+
584+
585+
def test_default_narx_handles_zero_features():
586+
"""Check that default NARX handles X with 0 features without error."""
587+
X = np.empty((10, 0))
588+
y = np.random.rand(10, 1)
589+
NARX().fit(X, y)
590+
591+
592+
def test_auto_reg():
593+
"""Test auto-regression with NARX"""
594+
rng = np.random.default_rng(12345)
595+
n_samples = 100
596+
max_delay = 2
597+
e0 = rng.normal(0, 0.01, n_samples)
598+
e1 = rng.normal(0, 0.01, n_samples)
599+
y0 = np.ones(n_samples + max_delay)
600+
y1 = np.ones(n_samples + max_delay)
601+
for i in range(max_delay, n_samples + max_delay):
602+
y0[i] = 0.5 * y0[i - 1] + 0.8 * y1[i - 1] + 1
603+
y1[i] = 0.6 * y1[i - 1] - 0.2 * y0[i - 1] * y1[i - 2] + 0.5
604+
y = np.c_[y0[max_delay:] + e0, y1[max_delay:] + e1]
605+
X = np.empty((n_samples, 0)) # No features, only auto-regression
606+
607+
model = make_narx(
608+
X,
609+
y,
610+
n_terms_to_select=2,
611+
max_delay=max_delay,
612+
poly_degree=2,
613+
verbose=0,
614+
)
615+
model.fit(X, y)
616+
y_pred = model.predict(X, y_init=y[: model.max_delay_])
617+
assert r2_score(y, model.predict(X, y_init=y)) > 0.5

0 commit comments

Comments
 (0)