Skip to content

Commit 409f057

Browse files
TST add test to check max_delay nan can split two time series (#110)
* TST add test to check max_delay nan can split two time series
1 parent afeabc9 commit 409f057

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

examples/plot_narx_msa.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def plot_prediction(ax, t, y_true, y_pred, title):
158158
# and test data are from different measurement sessions. The plot shows that the
159159
# prediction performance of the NARX on test data has been largely improved.
160160

161-
u_all = np.r_[u_train, [[np.nan]]*max_delay, u_test]
162-
y_all = np.r_[y_train, [np.nan]*max_delay, y_test]
161+
u_all = np.r_[u_train, [[np.nan]] * max_delay, u_test]
162+
y_all = np.r_[y_train, [np.nan] * max_delay, y_test]
163163
narx_model = make_narx(
164164
X=u_all,
165165
y=y_all,

tests/test_narx.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
NARX,
1111
fd2tp,
1212
make_narx,
13+
make_poly_features,
1314
make_poly_ids,
15+
make_time_shift_features,
1416
make_time_shift_ids,
1517
print_narx,
1618
tp2fd,
@@ -534,3 +536,43 @@ def test_make_narx_refine_print(capsys):
534536
)
535537
captured = capsys.readouterr()
536538
assert "No. of iterations: " in captured.out
539+
540+
541+
@pytest.mark.parametrize("max_delay", [1, 3, 7, 10])
542+
def test_nan_split(max_delay):
543+
n_sessions = 10
544+
n_samples_per_session = 100
545+
X = np.random.rand(n_samples_per_session, 2)
546+
y = np.random.rand(n_samples_per_session, 2)
547+
for _ in range(n_sessions - 1):
548+
X = np.r_[
549+
X,
550+
[[np.nan, np.nan]] * max_delay,
551+
np.random.rand(n_samples_per_session, 2),
552+
]
553+
y = np.r_[
554+
y,
555+
[[np.nan, np.nan]] * max_delay,
556+
np.random.rand(n_samples_per_session, 2),
557+
]
558+
narx = make_narx(
559+
X,
560+
y,
561+
n_terms_to_select=10,
562+
max_delay=max_delay,
563+
poly_degree=3,
564+
verbose=0,
565+
).fit(
566+
X,
567+
y,
568+
)
569+
570+
xy_hstack = np.c_[X, y]
571+
time_shift_ids, poly_ids = fd2tp(narx.feat_ids_, narx.delay_ids_)
572+
time_shift_vars = make_time_shift_features(xy_hstack, time_shift_ids)
573+
poly_terms = make_poly_features(time_shift_vars, poly_ids)
574+
poly_terms_masked, y_masked = mask_missing_values(poly_terms, y)
575+
assert poly_terms_masked.shape[0] == y_masked.shape[0]
576+
assert poly_terms_masked.shape[0] == n_sessions * (
577+
n_samples_per_session - narx.max_delay_
578+
)

0 commit comments

Comments
 (0)