Skip to content

Commit 4a7c4bb

Browse files
MNT padding nan in time shift (#102)
* MNT padding nan in time shift * TST test mask_missing_values when len(arrays)==1
1 parent 5584e0c commit 4a7c4bb

File tree

6 files changed

+30
-17
lines changed

6 files changed

+30
-17
lines changed

doc/narx.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ No matter how the discontinuity is caused, :class:`NARX` treats the discontinuou
5454
>>> import numpy as np
5555
>>> x0 = np.zeros((3, 2)) # First measurement session has 3 samples with 2 features
5656
>>> x1 = np.ones((5, 2)) # Second measurement session has 5 samples with 2 features
57-
>>> u = np.r_[x0, [[np.nan, np.nan]], x1] # Insert np.nan to break the two measurement sessions
57+
>>> max_delay = 2 # Assume the maximum delay for NARX model is 2
58+
>>> u = np.r_[x0, [[np.nan, np.nan]]*max_delay, x1] # Insert (at least max_delay number of) np.nan to break the two measurement sessions
5859

5960
It is important to break the different measurement sessions by np.nan, because otherwise,
6061
the model will assume the time interval between the two measurement sessions is the same as the time interval within a session.

examples/plot_narx.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,10 +108,14 @@
108108
# whose :math:`X` is the nonlinear terms and :math:`y` is the output signal.
109109

110110
from fastcan import FastCan
111+
from fastcan.utils import mask_missing_values
112+
113+
# Mask out missing values caused by time-shifting
114+
poly_terms_masked, y_masked = mask_missing_values(poly_terms, y)
111115

112116
selector = FastCan(
113117
n_features_to_select=4, # 4 terms should be selected
114-
).fit(poly_terms, y)
118+
).fit(poly_terms_masked, y_masked)
115119

116120
support = selector.get_support()
117121
selected_poly_ids = poly_ids[support]

examples/plot_narx_msa.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def auto_duffing_equation(y, t):
8383
n_samples = 1000
8484

8585
rng = np.random.default_rng(12345)
86-
e_train = rng.normal(0, 0.0002, n_samples)
87-
e_test = rng.normal(0, 0.0002, n_samples)
86+
e_train = rng.normal(0, 0.0004, n_samples)
87+
e_test = rng.normal(0, 0.0004, n_samples)
8888
t = np.linspace(0, dur, n_samples)
8989

9090
sol = odeint(duffing_equation, [0.6, 0.8], t)
@@ -153,12 +153,13 @@ def plot_prediction(ax, t, y_true, y_pred, title):
153153
# The plot above shows that the NARX model cannot capture the dynamics at
154154
# the left equilibrium shown in the phase portraits. To improve the performance, let us
155155
# combine the training and test data for model training to include the dynamics of both
156-
# equilibria. Here, we need to insert `np.nan` to indicate the model that training data
156+
# equilibria. Here, we need to insert (at least max_delay number of) `np.nan` to
157+
# indicate the model that training data
157158
# and test data are from different measurement sessions. The plot shows that the
158159
# prediction performance of the NARX on test data has been largely improved.
159160

160-
u_all = np.r_[u_train, [[np.nan]], u_test]
161-
y_all = np.r_[y_train, [np.nan], y_test]
161+
u_all = np.r_[u_train, [[np.nan]]*max_delay, u_test]
162+
y_all = np.r_[y_train, [np.nan]*max_delay, y_test]
162163
narx_model = make_narx(
163164
X=u_all,
164165
y=y_all,

fastcan/narx.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def make_time_shift_features(X, ids):
6262
>>> X = [[1, 2], [3, 4], [5, 6], [7, 8]]
6363
>>> ids = [[0, 0], [0, 1], [1, 1]]
6464
>>> make_time_shift_features(X, ids)
65-
array([[1., 1., 2.],
66-
[3., 1., 2.],
67-
[5., 3., 4.],
68-
[7., 5., 6.]])
65+
array([[ 1., nan, nan],
66+
[ 3., 1., 2.],
67+
[ 5., 3., 4.],
68+
[ 7., 5., 6.]])
6969
"""
7070
X = check_array(X, ensure_2d=True, dtype=float, ensure_all_finite="allow-nan")
7171
ids = check_array(ids, ensure_2d=True, dtype=int)
@@ -74,7 +74,7 @@ def make_time_shift_features(X, ids):
7474
out = np.zeros([n_samples, n_outputs])
7575
for i, id_temp in enumerate(ids):
7676
out[:, i] = np.r_[
77-
np.full(id_temp[1], X[0, id_temp[0]]),
77+
np.full(id_temp[1], np.nan),
7878
X[: -id_temp[1] or None, id_temp[0]],
7979
]
8080

@@ -1413,11 +1413,11 @@ def make_narx(
14131413
>>> print_narx(narx)
14141414
| yid | Term | Coef |
14151415
=======================================
1416-
| 0 | Intercept | 1.054 |
1417-
| 0 | y_hat[k-1,0] | 0.483 |
1418-
| 0 | X[k,0]*X[k,0] | 0.307 |
1419-
| 0 | X[k-1,0]*X[k-3,0] | 1.999 |
1420-
| 0 | X[k-2,0]*X[k,1] | 1.527 |
1416+
| 0 | Intercept | 1.050 |
1417+
| 0 | y_hat[k-1,0] | 0.484 |
1418+
| 0 | X[k,0]*X[k,0] | 0.306 |
1419+
| 0 | X[k-1,0]*X[k-3,0] | 2.000 |
1420+
| 0 | X[k-2,0]*X[k,1] | 1.528 |
14211421
"""
14221422
X = check_array(X, dtype=float, ensure_2d=True, ensure_all_finite="allow-nan")
14231423
y = check_array(y, dtype=float, ensure_2d=False, ensure_all_finite="allow-nan")

fastcan/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,4 +168,7 @@ def mask_missing_values(*arrays, return_mask=False):
168168
mask_valid = np.all(np.isfinite(np.c_[arrays]), axis=1)
169169
if return_mask:
170170
return mask_valid
171+
masked_arrays = [_safe_indexing(x, mask_valid) for x in arrays]
172+
if len(masked_arrays) == 1:
173+
return masked_arrays[0]
171174
return [_safe_indexing(x, mask_valid) for x in arrays]

tests/test_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ def test_mask_missing():
6666
b = rng.random((100, 2))
6767
c = rng.random(100)
6868
a[10, 0] = np.nan
69+
a_masked = mask_missing_values(a)
70+
mask_valid = mask_missing_values(a, return_mask=True)
71+
assert a_masked.shape == (99, 10)
72+
assert_array_equal(actual=a_masked, desired=a[mask_valid])
6973
b[20, 1] = np.nan
7074
c[30] = np.nan
7175
a_masked, b_masked, c_mask = mask_missing_values(a, b, c)

0 commit comments

Comments
 (0)