Skip to content

Commit f7ff796

Browse files
committed
FEAT add jac for narx
1 parent 357b22f commit f7ff796

File tree

7 files changed

+881
-337
lines changed

7 files changed

+881
-337
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
steps:
2323
- uses: actions/checkout@v4
2424
- name: Build wheels
25-
uses: pypa/[email protected].1
25+
uses: pypa/[email protected].2
2626
env:
2727
CIBW_BUILD: cp3*-*
2828
CIBW_SKIP: pp* *i686* *musllinux* *-macosx_universal2 *-manylinux_ppc64le *-manylinux_s390x

examples/plot_narx.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,14 @@
125125
# In the printed NARX model, it is found that :class:`FastCan` selects the correct
126126
# terms and the coefficients are close to the true values.
127127

128-
from fastcan.narx import NARX, print_narx
128+
from fastcan.narx import NARX, _pt2fd, print_narx
129+
130+
# Convert poly_ids and time_shift_ids to feat_ids and delay_ids
131+
feat_ids, delay_ids = _pt2fd(selected_poly_ids, time_shift_ids)
129132

130133
narx_model = NARX(
131-
time_shift_ids=time_shift_ids,
132-
poly_ids=selected_poly_ids,
134+
feat_ids=feat_ids,
135+
delay_ids=delay_ids,
133136
)
134137

135138
narx_model.fit(X, y)

examples/plot_narx_msa.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
# Nonlinear system
1616
# ----------------
1717
#
18-
# `Duffing equation <https://en.wikipedia.org/wiki/Duffing_equation>` is used to
18+
# `Duffing equation <https://en.wikipedia.org/wiki/Duffing_equation>`_ is used to
1919
# generate simulated data. The mathematical model is given by
2020
#
2121
# .. math::
@@ -130,7 +130,7 @@ def plot_prediction(ax, t, y_true, y_pred, title):
130130
y_train_osa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
131131
y_test_osa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])
132132

133-
narx_model.fit(u_train, y_train, coef_init="one_step_ahead", method="Nelder-Mead")
133+
narx_model.fit(u_train, y_train, coef_init="one_step_ahead")
134134
y_train_msa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
135135
y_test_msa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])
136136

@@ -169,7 +169,7 @@ def plot_prediction(ax, t, y_true, y_pred, title):
169169
y_train_osa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
170170
y_test_osa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])
171171

172-
narx_model.fit(u_all, y_all, coef_init="one_step_ahead", method="Nelder-Mead")
172+
narx_model.fit(u_all, y_all, coef_init="one_step_ahead")
173173
y_train_msa_pred = narx_model.predict(u_train, y_init=y_train[:max_delay])
174174
y_test_msa_pred = narx_model.predict(u_test, y_init=y_test[:max_delay])
175175

0 commit comments

Comments
 (0)