diff --git a/examples/plot_narx_msa.py b/examples/plot_narx_msa.py
index 737de11..de1c852 100644
--- a/examples/plot_narx_msa.py
+++ b/examples/plot_narx_msa.py
@@ -158,8 +158,8 @@ def plot_prediction(ax, t, y_true, y_pred, title):
 # and test data are from different measurement sessions. The plot shows that the
 # prediction performance of the NARX on test data has been largely improved.
 
-u_all = np.r_[u_train, [[np.nan]]*max_delay, u_test]
-y_all = np.r_[y_train, [np.nan]*max_delay, y_test]
+u_all = np.r_[u_train, [[np.nan]] * max_delay, u_test]
+y_all = np.r_[y_train, [np.nan] * max_delay, y_test]
 narx_model = make_narx(
     X=u_all,
     y=y_all,
diff --git a/tests/test_narx.py b/tests/test_narx.py
index 26e827f..38bca0c 100644
--- a/tests/test_narx.py
+++ b/tests/test_narx.py
@@ -10,7 +10,9 @@
     NARX,
     fd2tp,
     make_narx,
+    make_poly_features,
     make_poly_ids,
+    make_time_shift_features,
     make_time_shift_ids,
     print_narx,
     tp2fd,
@@ -534,3 +536,43 @@ def test_make_narx_refine_print(capsys):
     )
     captured = capsys.readouterr()
     assert "No. of iterations: " in captured.out
+
+
+@pytest.mark.parametrize("max_delay", [1, 3, 7, 10])
+def test_nan_split(max_delay):
+    n_sessions = 10
+    n_samples_per_session = 100
+    X = np.random.rand(n_samples_per_session, 2)
+    y = np.random.rand(n_samples_per_session, 2)
+    for _ in range(n_sessions - 1):
+        X = np.r_[
+            X,
+            [[np.nan, np.nan]] * max_delay,
+            np.random.rand(n_samples_per_session, 2),
+        ]
+        y = np.r_[
+            y,
+            [[np.nan, np.nan]] * max_delay,
+            np.random.rand(n_samples_per_session, 2),
+        ]
+    narx = make_narx(
+        X,
+        y,
+        n_terms_to_select=10,
+        max_delay=max_delay,
+        poly_degree=3,
+        verbose=0,
+    ).fit(
+        X,
+        y,
+    )
+
+    xy_hstack = np.c_[X, y]
+    time_shift_ids, poly_ids = fd2tp(narx.feat_ids_, narx.delay_ids_)
+    time_shift_vars = make_time_shift_features(xy_hstack, time_shift_ids)
+    poly_terms = make_poly_features(time_shift_vars, poly_ids)
+    poly_terms_masked, y_masked = mask_missing_values(poly_terms, y)
+    assert poly_terms_masked.shape[0] == y_masked.shape[0]
+    assert poly_terms_masked.shape[0] == n_sessions * (
+        n_samples_per_session - narx.max_delay_
+    )