|
10 | 10 | NARX, |
11 | 11 | fd2tp, |
12 | 12 | make_narx, |
| 13 | + make_poly_features, |
13 | 14 | make_poly_ids, |
| 15 | + make_time_shift_features, |
14 | 16 | make_time_shift_ids, |
15 | 17 | print_narx, |
16 | 18 | tp2fd, |
@@ -534,3 +536,43 @@ def test_make_narx_refine_print(capsys): |
534 | 536 | ) |
535 | 537 | captured = capsys.readouterr() |
536 | 538 | assert "No. of iterations: " in captured.out |
| 539 | + |
| 540 | + |
@pytest.mark.parametrize("max_delay", [1, 3, 7, 10])
def test_nan_split(max_delay):
    """Sessions separated by NaN rows are split correctly by the mask.

    Builds multi-session data where consecutive sessions are separated by
    ``max_delay`` all-NaN rows, fits a NARX model, then checks that masking
    the polynomial terms drops exactly ``max_delay_`` leading samples from
    every session.
    """
    n_sessions = 10
    n_samples_per_session = 100
    nan_gap = np.full((max_delay, 2), np.nan)

    # Assemble X and y as: session, NaN gap, session, NaN gap, ...
    x_parts = [np.random.rand(n_samples_per_session, 2)]
    y_parts = [np.random.rand(n_samples_per_session, 2)]
    for _ in range(n_sessions - 1):
        x_parts.extend([nan_gap, np.random.rand(n_samples_per_session, 2)])
        y_parts.extend([nan_gap, np.random.rand(n_samples_per_session, 2)])
    X = np.vstack(x_parts)
    y = np.vstack(y_parts)

    narx = make_narx(
        X,
        y,
        n_terms_to_select=10,
        max_delay=max_delay,
        poly_degree=3,
        verbose=0,
    ).fit(
        X,
        y,
    )

    # Rebuild the design matrix from the fitted term ids and verify that the
    # NaN mask removes max_delay_ samples at the start of each session.
    stacked = np.c_[X, y]
    shift_ids, poly_ids = fd2tp(narx.feat_ids_, narx.delay_ids_)
    shifted_vars = make_time_shift_features(stacked, shift_ids)
    terms = make_poly_features(shifted_vars, poly_ids)
    terms_masked, y_masked = mask_missing_values(terms, y)
    assert terms_masked.shape[0] == y_masked.shape[0]
    expected_rows = n_sessions * (n_samples_per_session - narx.max_delay_)
    assert terms_masked.shape[0] == expected_rows
0 commit comments