From 727a676cb8bccc0679b7516d88c73d01fdc346ec Mon Sep 17 00:00:00 2001 From: SIKAI ZHANG <34108862+MatthewSZhang@users.noreply.github.com> Date: Thu, 10 Apr 2025 16:41:12 +0800 Subject: [PATCH] MNT output proper constant when some outputs only have intercepts --- fastcan/narx.py | 8 +- pixi.lock | 234 ++++++++++++++++----------------------------- tests/test_narx.py | 44 +++++++-- 3 files changed, 124 insertions(+), 162 deletions(-) diff --git a/fastcan/narx.py b/fastcan/narx.py index f5299f7..e65ebd9 100644 --- a/fastcan/narx.py +++ b/fastcan/narx.py @@ -760,7 +760,8 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params): warnings.warn( f"output_ids got {self.output_ids_}, which does not " f"contain all values from 0 to {self.n_outputs_ - 1}." - "The predicted outputs for the missing values will be 0.", + "The prediction for the missing outputs will be a constant" + "(i.e., intercept).", UserWarning, ) @@ -783,6 +784,7 @@ def fit(self, X, y, sample_weight=None, coef_init=None, **params): for i in range(self.n_outputs_): output_i_mask = self.output_ids_ == i if np.sum(output_i_mask) == 0: + intercept[i] = np.mean(y_masked[:, i]) continue osa_narx.fit( poly_terms_masked[:, output_i_mask], @@ -974,8 +976,8 @@ def _update_dydx( dydx[k, y_ids, x_ids] = terms # Update dynamic terms of Jacobian - cfd = np.zeros((n_y, n_y, max_delay), dtype=float) - if max_delay > 0: + if max_delay > 0 and grad_yyd_ids.size > 0: + cfd = np.zeros((n_y, n_y, max_delay), dtype=float) _update_cfd( X, y_hat, diff --git a/pixi.lock b/pixi.lock index 87d8e2c..761afa4 100644 --- a/pixi.lock +++ b/pixi.lock @@ -329,7 +329,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_events-0.12.0-pyh29332c3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_server-2.15.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_server_terminals-0.5.3-pyhd8ed1ab_1.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.3.6-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab_pygments-0.3.0-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab_server-2.27.3-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/keyutils-1.6.1-h166bdaf_0.tar.bz2 @@ -393,15 +393,15 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/meson-python-0.17.1-pyh70fd9c4_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mistune-3.1.3-pyh29332c3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/munkres-1.1.4-pyh9f0ad1d_0.tar.bz2 - - conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-common-9.0.1-h266115a_6.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-9.0.1-he0572af_6.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-common-9.2.0-h266115a_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-9.2.0-he0572af_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/nbclient-0.10.2-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/nbconvert-core-7.16.6-pyh29332c3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/nbformat-5.10.4-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h2d0b736_3.conda - conda: https://conda.anaconda.org/conda-forge/noarch/nest-asyncio-1.6.0-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/ninja-1.12.1-h297d8ca_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-7.3.3-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-7.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-shim-0.2.4-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/numpy-2.2.4-py313h17eae1a_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/openjpeg-2.5.3-h5fbd93e_0.conda @@ -446,7 +446,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/linux-64/pyyaml-6.0.2-py313h8060acc_2.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/pyzmq-26.4.0-py313h8e95178_0.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/qhull-2020.2-h434a139_5.conda - - conda: https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.9.0-h6441bc3_0.conda + - conda: https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.9.0-h6441bc3_1.conda - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8c095d6_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/referencing-0.36.2-pyh29332c3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/requests-2.32.3-pyhd8ed1ab_1.conda @@ -611,7 +611,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_events-0.12.0-pyh29332c3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_server-2.15.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_server_terminals-0.5.3-pyhd8ed1ab_1.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.3.6-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab_pygments-0.3.0-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab_server-2.27.3-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/kiwisolver-1.4.7-py313h0c4e38b_0.conda @@ -669,7 +669,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-64/ncurses-6.5-h0622a9a_3.conda - conda: https://conda.anaconda.org/conda-forge/noarch/nest-asyncio-1.6.0-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/ninja-1.12.1-h3c5361c_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-7.3.3-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-7.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-shim-0.2.4-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/numpy-2.2.4-py313hc518a0f_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/openjpeg-2.5.3-h7fd6d84_0.conda @@ -857,7 +857,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_events-0.12.0-pyh29332c3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_server-2.15.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_server_terminals-0.5.3-pyhd8ed1ab_1.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.3.6-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab_pygments-0.3.0-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab_server-2.27.3-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/kiwisolver-1.4.7-py313hf9c7212_0.conda @@ -915,7 +915,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ncurses-6.5-h5e97a16_3.conda - conda: https://conda.anaconda.org/conda-forge/noarch/nest-asyncio-1.6.0-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ninja-1.12.1-h420ef59_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-7.3.3-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-7.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-shim-0.2.4-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/numpy-2.2.4-py313h41a2e72_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-arm64/openjpeg-2.5.3-h8a3d83b_0.conda @@ -1097,7 +1097,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_events-0.12.0-pyh29332c3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_server-2.15.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyter_server_terminals-0.5.3-pyhd8ed1ab_1.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.3.6-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab_pygments-0.3.0-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab_server-2.27.3-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/kiwisolver-1.4.7-py313h1ec8472_0.conda @@ -1147,7 +1147,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/nbformat-5.10.4-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/nest-asyncio-1.6.0-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/ninja-1.12.1-hc790b64_0.conda - - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-7.3.3-pyhd8ed1ab_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-7.4.0-pyhd8ed1ab_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-shim-0.2.4-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/win-64/numpy-2.2.4-py313hefb8edb_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/openjpeg-2.5.3-h4d64b90_0.conda @@ -1191,7 +1191,7 @@ environments: - conda: https://conda.anaconda.org/conda-forge/win-64/pyyaml-6.0.2-py313hb4c8b1a_2.conda - conda: https://conda.anaconda.org/conda-forge/win-64/pyzmq-26.4.0-py313h2100fd5_0.conda - conda: https://conda.anaconda.org/conda-forge/win-64/qhull-2020.2-hc790b64_5.conda - - conda: https://conda.anaconda.org/conda-forge/win-64/qt6-main-6.9.0-h83cda92_0.conda + - conda: https://conda.anaconda.org/conda-forge/win-64/qt6-main-6.9.0-h83cda92_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/referencing-0.36.2-pyh29332c3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/requests-2.32.3-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/rfc3339-validator-0.1.4-pyhd8ed1ab_1.conda @@ -1299,31 +1299,32 @@ environments: - conda: https://conda.anaconda.org/conda-forge/noarch/typing_extensions-4.13.1-pyh29332c3_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tzdata-2025b-h78e105d_0.conda osx-64: - - conda: https://conda.anaconda.org/conda-forge/osx-64/black-25.1.0-py312hb401068_0.conda + - conda: https://conda.anaconda.org/conda-forge/noarch/black-25.1.0-pyh866005b_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/bzip2-1.0.8-hfdf4475_7.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/ca-certificates-2025.1.31-h8857fd0_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/click-8.1.8-pyh707e725_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-64/cython-3.0.12-py312hdfbeeba_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/cython-3.0.12-py313h9efc8c2_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/cython-lint-0.16.6-pyhff2d567_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libcxx-20.1.2-hf95d169_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libexpat-2.7.0-h240833e_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libffi-3.4.6-h281671d_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/liblzma-5.8.1-hd471939_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/libmpdec-4.0.0-hfdf4475_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libsqlite-3.49.1-hdb6dae5_2.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/libzlib-1.3.1-hd23fc13_2.conda - - conda: https://conda.anaconda.org/conda-forge/osx-64/mypy-1.15.0-py312h01d7ebd_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/mypy-1.15.0-py313h63b0ddb_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/mypy_extensions-1.0.0-pyha770c72_1.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/ncurses-6.5-h0622a9a_3.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/openssl-3.5.0-hc426f3f_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/packaging-24.2-pyhd8ed1ab_2.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pathspec-0.12.1-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/platformdirs-4.3.7-pyh29332c3_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-64/psutil-7.0.0-py312h01d7ebd_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/psutil-7.0.0-py313h63b0ddb_0.conda - conda: https://conda.anaconda.org/conda-forge/noarch/pycodestyle-2.13.0-pyhd8ed1ab_0.conda - - conda: https://conda.anaconda.org/conda-forge/osx-64/python-3.12.9-h9ccd52b_1_cpython.conda - - conda: https://conda.anaconda.org/conda-forge/osx-64/python_abi-3.12-6_cp312.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/python-3.13.2-h534c281_101_cp313.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/python_abi-3.13-6_cp313.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/readline-8.2-h7cca4af_2.conda - - conda: https://conda.anaconda.org/conda-forge/osx-64/ruff-0.11.4-py312h60e8e2e_0.conda + - conda: https://conda.anaconda.org/conda-forge/osx-64/ruff-0.11.4-py313h837c616_0.conda - conda: https://conda.anaconda.org/conda-forge/osx-64/tk-8.6.13-h1abcd95_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tokenize-rt-6.1.0-pyhd8ed1ab_1.conda - conda: https://conda.anaconda.org/conda-forge/noarch/tomli-2.2.1-pyhd8ed1ab_1.conda @@ -1787,21 +1788,6 @@ packages: license_family: MIT size: 172678 timestamp: 1742502887437 -- conda: https://conda.anaconda.org/conda-forge/osx-64/black-25.1.0-py312hb401068_0.conda - sha256: e937f18e36e23ecf0ec9ab89fc3ef5263308e88b645c4278fe8807fd95bef4c1 - md5: d37d5213fcf23a33d946e40937578a02 - depends: - - click >=8.0.0 - - mypy_extensions >=0.4.3 - - packaging >=22.0 - - pathspec >=0.9 - - platformdirs >=2 - - python >=3.12,<3.13.0a0 - - python_abi 3.12.* *_cp312 - license: MIT - license_family: MIT - size: 393484 - timestamp: 1738616259890 - conda: https://conda.anaconda.org/conda-forge/noarch/bleach-6.2.0-pyh29332c3_4.conda sha256: a05971bb80cca50ce9977aad3f7fc053e54ea7d5321523efc7b9a6e12901d3cd md5: f0b4c8e370446ef89797608d60a564b3 @@ -2807,18 +2793,6 @@ packages: - pkg:pypi/cython?source=hash-mapping size: 3766349 timestamp: 1739228643862 -- conda: https://conda.anaconda.org/conda-forge/osx-64/cython-3.0.12-py312hdfbeeba_0.conda - sha256: a186d286aedb2230dcdcaf2a8602c098112eaacdf9d8af39da2a474950bf1b98 - md5: 5801a15eece1bd00c7f6dc0c68640a9f - depends: - - __osx >=10.13 - - libcxx >=18 - - python >=3.12,<3.13.0a0 - - python_abi 3.12.* *_cp312 - license: Apache-2.0 - license_family: APACHE - size: 3455381 - timestamp: 1739228540351 - conda: https://conda.anaconda.org/conda-forge/osx-64/cython-3.0.12-py313h9efc8c2_0.conda sha256: 132d6e81a95c042210f33c3d24f03d52632738434b3ea48cfb184a26684d365e md5: ddace7cae5c3073c031ad08ef01881da @@ -4081,9 +4055,9 @@ packages: - pkg:pypi/jupyter-server-terminals?source=hash-mapping size: 19711 timestamp: 1733428049134 -- conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.3.6-pyhd8ed1ab_0.conda - sha256: cf10c9b4158c4ef2796fde546f2bbe45f43c1402a0c2a175939ebbb308846ada - md5: 8b91a10c966aa65b9ad1a2702e6ef121 +- conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab-4.4.0-pyhd8ed1ab_0.conda + sha256: 4d225d094d1e5a8e95c2bde0f9c9bbc5aac52d9abf7fd597dd7af0f467b44347 + md5: 2da6a5e2c788a1b1998b24c50a18572a depends: - async-lru >=1.0.0 - httpx >=0.25.0 @@ -4097,7 +4071,7 @@ packages: - notebook-shim >=0.2 - packaging - python >=3.9 - - setuptools >=40.8.0 + - setuptools >=41.1.0 - tomli >=1.2.2 - tornado >=6.2.0 - traitlets @@ -4105,8 +4079,8 @@ packages: license_family: BSD purls: - pkg:pypi/jupyterlab?source=compressed-mapping - size: 7641308 - timestamp: 1741964212957 + size: 8402007 + timestamp: 1743711353203 - conda: https://conda.anaconda.org/conda-forge/noarch/jupyterlab_pygments-0.3.0-pyhd8ed1ab_2.conda sha256: dc24b900742fdaf1e077d9a3458fd865711de80bca95fe3c6d46610c532c6ef0 md5: fd312693df06da3578383232528c468d @@ -6551,20 +6525,20 @@ packages: license_family: MIT size: 17058016 timestamp: 1738767732637 -- conda: https://conda.anaconda.org/conda-forge/osx-64/mypy-1.15.0-py312h01d7ebd_0.conda - sha256: 38132c4b5de6686965f21b51a1656438e83b2a53d6f50e9589e73fb57a43dd49 - md5: 0251bb4d6702b729b06fd5c7918e9242 +- conda: https://conda.anaconda.org/conda-forge/osx-64/mypy-1.15.0-py313h63b0ddb_0.conda + sha256: ec50dc7be70eff5008d73b4bd29fba72e02e499e9b60060a49ece4c1e12a9d55 + md5: e9dc60a2c2c62f4d2e24f61603f00bdc depends: - __osx >=10.13 - mypy_extensions >=1.0.0 - psutil >=4.0 - - python >=3.12,<3.13.0a0 - - python_abi 3.12.* *_cp312 + - python >=3.13,<3.14.0a0 + - python_abi 3.13.* *_cp313 - typing_extensions >=4.1.0 license: MIT license_family: MIT - size: 12384787 - timestamp: 1738768017667 + size: 11022410 + timestamp: 1738768159908 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/mypy-1.15.0-py313h90d716c_0.conda sha256: 4dc7a5a30017c742c204311afd078c639ca434b7f44835dfba789a5fb972ea6c md5: d01a9742c8e3c425d3c3d5e412a43872 @@ -6605,9 +6579,9 @@ packages: license_family: MIT size: 10854 timestamp: 1733230986902 -- conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-common-9.0.1-h266115a_6.conda - sha256: 9c2e3f9e9883e4b8d7e9e6abf7b235dc00bdcd5ef66640a360464a9f5756294d - md5: 94116b69829e90b72d566e64421e1bff +- conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-common-9.2.0-h266115a_0.conda + sha256: 571b6a2bffaf186ab92cdb06852fc5b6b5b7c6605de2b397fb13cfb0bb05c375 + md5: db22a0962c953e81a2a679ecb1fc6027 depends: - __glibc >=2.17,<3.0.a0 - libgcc >=13 @@ -6616,24 +6590,24 @@ packages: license: GPL-2.0-or-later license_family: GPL purls: [] - size: 616215 - timestamp: 1744124836761 -- conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-9.0.1-he0572af_6.conda - sha256: 274467a602944d12722f757f660ad034de6f5f5d7d2ea1b913ef6fd836c1b8ce - md5: 9802ae6d20982f42c0f5d69008988763 + size: 653477 + timestamp: 1743939199519 +- conda: https://conda.anaconda.org/conda-forge/linux-64/mysql-libs-9.2.0-he0572af_0.conda + sha256: 41cd870c04961591eabe7a43283d2bbc80a382e007f766edb8396ffd2bdfa418 + md5: 93340b072c393d23c4700a1d40565dca depends: - __glibc >=2.17,<3.0.a0 - libgcc >=13 - libstdcxx >=13 - libzlib >=1.3.1,<2.0a0 - - mysql-common 9.0.1 h266115a_6 + - mysql-common 9.2.0 h266115a_0 - openssl >=3.4.1,<4.0a0 - zstd >=1.5.7,<1.6.0a0 license: GPL-2.0-or-later license_family: GPL purls: [] - size: 1369369 - timestamp: 1744124916632 + size: 1371585 + timestamp: 1743939293417 - conda: https://conda.anaconda.org/conda-forge/noarch/nbclient-0.10.2-pyhd8ed1ab_0.conda sha256: a20cff739d66c2f89f413e4ba4c6f6b59c50d5c30b5f0d840c13e8c9c2df9135 md5: 6bb0d77277061742744176ab555b723c @@ -6778,13 +6752,12 @@ packages: purls: [] size: 285150 timestamp: 1715441052517 -- conda: https://conda.anaconda.org/conda-forge/noarch/notebook-7.3.3-pyhd8ed1ab_0.conda - sha256: 5086c70ff352a72b9d47fcf73d37a1be583cf5b416c9729295a9b3710330d781 - md5: 3b04a08fc654590f45e0a713982f898b +- conda: https://conda.anaconda.org/conda-forge/noarch/notebook-7.4.0-pyhd8ed1ab_0.conda + sha256: d3f70987bc1e1a20b122726a49a24e5e6f09d00c9862bb399cd1682cd59a1e1e + md5: 7e82caa4495c513bcfb33a159e1222d4 depends: - - importlib_resources >=5.0 - jupyter_server >=2.4.0,<3 - - jupyterlab >=4.3.6,<4.4 + - jupyterlab >=4.4.0,<4.5 - jupyterlab_server >=2.27.1,<3 - notebook-shim >=0.2,<0.3 - python >=3.9 @@ -6793,8 +6766,8 @@ packages: license_family: BSD purls: - pkg:pypi/notebook?source=hash-mapping - size: 9705127 - timestamp: 1741968301453 + size: 10473675 + timestamp: 1744236307330 - conda: https://conda.anaconda.org/conda-forge/noarch/notebook-shim-0.2.4-pyhd8ed1ab_1.conda sha256: 7b920e46b9f7a2d2aa6434222e5c8d739021dbc5cc75f32d124a8191d86f9056 md5: e7f89ea5f7ea9401642758ff50a2d9c1 @@ -7403,17 +7376,6 @@ packages: - pkg:pypi/psutil?source=hash-mapping size: 475101 timestamp: 1740663284505 -- conda: https://conda.anaconda.org/conda-forge/osx-64/psutil-7.0.0-py312h01d7ebd_0.conda - sha256: bdfa40a1ef3a80c3bec425a5ed507ebda2bdebce2a19bccb000db9d5c931750c - md5: fcad6b89f4f7faa999fa4d887eab14ba - depends: - - __osx >=10.13 - - python >=3.12,<3.13.0a0 - - python_abi 3.12.* *_cp312 - license: BSD-3-Clause - license_family: BSD - size: 473946 - timestamp: 1740663466925 - conda: https://conda.anaconda.org/conda-forge/osx-64/psutil-7.0.0-py313h63b0ddb_0.conda sha256: b117f61eaf3d5fb640d773c3021f222c833a69c2ac123d7f4b028b3e5d638dd4 md5: 2c8969aaee2cf24bc8931f5fc36cccfd @@ -7807,28 +7769,6 @@ packages: size: 33233150 timestamp: 1739803603242 python_site_packages_path: lib/python3.13/site-packages -- conda: https://conda.anaconda.org/conda-forge/osx-64/python-3.12.9-h9ccd52b_1_cpython.conda - build_number: 1 - sha256: c394f7068a714cad7853992f18292bb34c6d99fe7c21025664b05069c86b9450 - md5: b878567b6b749f993dbdbc2834115bc3 - depends: - - __osx >=10.13 - - bzip2 >=1.0.8,<2.0a0 - - libexpat >=2.6.4,<3.0a0 - - libffi >=3.4,<4.0a0 - - liblzma >=5.6.4,<6.0a0 - - libsqlite >=3.49.1,<4.0a0 - - libzlib >=1.3.1,<2.0a0 - - ncurses >=6.5,<7.0a0 - - openssl >=3.4.1,<4.0a0 - - readline >=8.2,<9.0a0 - - tk >=8.6.13,<8.7.0a0 - - tzdata - constrains: - - python_abi 3.12.* *_cp312 - license: Python-2.0 - size: 13833024 - timestamp: 1741129416409 - conda: https://conda.anaconda.org/conda-forge/osx-64/python-3.13.2-h2267d90_1_cp313t.conda build_number: 1 sha256: 95abaeed4b827aa209f89a2fa18e219c89b913a82510f0fbe729ef2e04a68b7d @@ -8072,16 +8012,6 @@ packages: purls: [] size: 6858 timestamp: 1743483201023 -- conda: https://conda.anaconda.org/conda-forge/osx-64/python_abi-3.12-6_cp312.conda - build_number: 6 - sha256: abbe800dc60cfe459be68a4f3fd946b09ae573c586efb3396ee48634ee3723ad - md5: e6096b1328952bbe07342f8927940ea9 - constrains: - - python 3.12.* *_cpython - license: BSD-3-Clause - license_family: BSD - size: 6929 - timestamp: 1743483235505 - conda: https://conda.anaconda.org/conda-forge/osx-64/python_abi-3.13-5_cp313t.conda build_number: 5 sha256: a96553de64be6441400e88c2c6ad7123d91cbcea4898b5966a653163f30d9f55 @@ -8359,9 +8289,9 @@ packages: purls: [] size: 1377020 timestamp: 1720814433486 -- conda: https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.9.0-h6441bc3_0.conda - sha256: 5f04a118f6f4124bf0f3a7b7ca2510954860c764db8c25a62bcfa91f51693073 - md5: d3df16592e15a3f833cfc4d19ae58677 +- conda: https://conda.anaconda.org/conda-forge/linux-64/qt6-main-6.9.0-h6441bc3_1.conda + sha256: 0485df0e29daf02023b98b0d7f5f4f97bd23650582d8c64d80f22cdb1ad01676 + md5: 4029a8dcb1d97ea241dbe5abfda1fad6 depends: - __glibc >=2.17,<3.0.a0 - alsa-lib >=1.2.13,<1.3.0a0 @@ -8370,19 +8300,19 @@ packages: - fontconfig >=2.15.0,<3.0a0 - fonts-conda-ecosystem - freetype >=2.13.3,<3.0a0 - - harfbuzz >=11.0.0,<12.0a0 + - harfbuzz >=11.0.1 - icu >=75.1,<76.0a0 - krb5 >=1.21.3,<1.22.0a0 - - libclang-cpp20.1 >=20.1.1,<20.2.0a0 - - libclang13 >=20.1.1 + - libclang-cpp20.1 >=20.1.2,<20.2.0a0 + - libclang13 >=20.1.2 - libcups >=2.3.3,<2.4.0a0 - libdrm >=2.4.124,<2.5.0a0 - libegl >=1.7.0,<2.0a0 - libgcc >=13 - libgl >=1.7.0,<2.0a0 - - libglib >=2.84.0,<3.0a0 + - libglib >=2.84.1,<3.0a0 - libjpeg-turbo >=3.0.0,<4.0a0 - - libllvm20 >=20.1.1,<20.2.0a0 + - libllvm20 >=20.1.2,<20.2.0a0 - libpng >=1.6.47,<1.7.0a0 - libpq >=17.4,<18.0a0 - libsqlite >=3.49.1,<4.0a0 @@ -8393,8 +8323,8 @@ packages: - libxkbcommon >=1.8.1,<2.0a0 - libxml2 >=2.13.7,<2.14.0a0 - libzlib >=1.3.1,<2.0a0 - - mysql-libs >=9.0.1,<9.1.0a0 - - openssl >=3.4.1,<4.0a0 + - mysql-libs >=9.2.0,<9.3.0a0 + - openssl >=3.5.0,<4.0a0 - pcre2 >=10.44,<10.45.0a0 - wayland >=1.23.1,<2.0a0 - xcb-util >=0.4.1,<0.5.0a0 @@ -8419,25 +8349,25 @@ packages: license: LGPL-3.0-only license_family: LGPL purls: [] - size: 51884819 - timestamp: 1743632133306 -- conda: https://conda.anaconda.org/conda-forge/win-64/qt6-main-6.9.0-h83cda92_0.conda - sha256: 84ff37de3c72a612dfbf9b317d5501231a87582f2edf4959ff706b84b4aa9246 - md5: d92e5a0de3263315551d54d5574f5193 + size: 51522155 + timestamp: 1744201848686 +- conda: https://conda.anaconda.org/conda-forge/win-64/qt6-main-6.9.0-h83cda92_1.conda + sha256: 40bb84f5898e60dd7ee27a504c0403ea5dae514ce0638b763bb00ff4d73ab611 + md5: 412f970fc305449b6ee626fe9c6638a8 depends: - double-conversion >=3.3.1,<3.4.0a0 - - harfbuzz >=11.0.0,<12.0a0 + - harfbuzz >=11.0.1 - icu >=75.1,<76.0a0 - krb5 >=1.21.3,<1.22.0a0 - - libclang13 >=20.1.1 - - libglib >=2.84.0,<3.0a0 + - libclang13 >=20.1.2 + - libglib >=2.84.1,<3.0a0 - libjpeg-turbo >=3.0.0,<4.0a0 - libpng >=1.6.47,<1.7.0a0 - libsqlite >=3.49.1,<4.0a0 - libtiff >=4.7.0,<4.8.0a0 - libwebp-base >=1.5.0,<2.0a0 - libzlib >=1.3.1,<2.0a0 - - openssl >=3.4.1,<4.0a0 + - openssl >=3.5.0,<4.0a0 - pcre2 >=10.44,<10.45.0a0 - ucrt >=10.0.20348.0 - vc >=14.3,<15 @@ -8448,8 +8378,8 @@ packages: license: LGPL-3.0-only license_family: LGPL purls: [] - size: 94992566 - timestamp: 1743635306726 + size: 94780444 + timestamp: 1744204470975 - conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8c095d6_2.conda sha256: 2d6d0c026902561ed77cd646b5021aef2d4db22e57a5b0178dfc669231e06d2c md5: 283b96675859b20a825f8fa30f311446 @@ -8626,20 +8556,20 @@ packages: license_family: MIT size: 9001038 timestamp: 1743819292575 -- conda: https://conda.anaconda.org/conda-forge/osx-64/ruff-0.11.4-py312h60e8e2e_0.conda - sha256: 6f2a267f074d65a8b198971c0f55648be9d5a36e9b77cc7f201b7011125d20a1 - md5: 5ac88d878ce5ec735befa59dce14d0bf +- conda: https://conda.anaconda.org/conda-forge/osx-64/ruff-0.11.4-py313h837c616_0.conda + sha256: 6eeb0acc7868ac46186e8317154df5403d19687665114fced5aae45c522e8d00 + md5: 164585e6a4b12fd29517c940602c265d depends: - __osx >=10.13 - libcxx >=18 - - python >=3.12,<3.13.0a0 - - python_abi 3.12.* *_cp312 + - python >=3.13,<3.14.0a0 + - python_abi 3.13.* *_cp313 constrains: - __osx >=10.13 license: MIT license_family: MIT - size: 8403708 - timestamp: 1743820004397 + size: 8406832 + timestamp: 1743819610560 - conda: https://conda.anaconda.org/conda-forge/osx-arm64/ruff-0.11.4-py313hd3a9b03_0.conda sha256: 7b7fe29220adb5b40210956dd25e05b77ce3e1a3eb89e5c867724d74eb4e8d01 md5: 4543a37cc5e13195d08611d1a154c38f @@ -8722,7 +8652,7 @@ packages: - ruff>=0.11.0 ; extra == 'tests' - black>=24.3.0 ; extra == 'tests' - mypy>=1.15 ; extra == 'tests' - - pyamg>=5.0.0 ; extra == 'tests' + - pyamg>=4.2.1 ; extra == 'tests' - polars>=0.20.30 ; extra == 'tests' - pyarrow>=12.0.0 ; extra == 'tests' - numpydoc>=1.2.0 ; extra == 'tests' @@ -8783,7 +8713,7 @@ packages: - ruff>=0.11.0 ; extra == 'tests' - black>=24.3.0 ; extra == 'tests' - mypy>=1.15 ; extra == 'tests' - - pyamg>=5.0.0 ; extra == 'tests' + - pyamg>=4.2.1 ; extra == 'tests' - polars>=0.20.30 ; extra == 'tests' - pyarrow>=12.0.0 ; extra == 'tests' - numpydoc>=1.2.0 ; extra == 'tests' @@ -8844,7 +8774,7 @@ packages: - ruff>=0.11.0 ; extra == 'tests' - black>=24.3.0 ; extra == 'tests' - mypy>=1.15 ; extra == 'tests' - - pyamg>=5.0.0 ; extra == 'tests' + - pyamg>=4.2.1 ; extra == 'tests' - polars>=0.20.30 ; extra == 'tests' - pyarrow>=12.0.0 ; extra == 'tests' - numpydoc>=1.2.0 ; extra == 'tests' @@ -8905,7 +8835,7 @@ packages: - ruff>=0.11.0 ; extra == 'tests' - black>=24.3.0 ; extra == 'tests' - mypy>=1.15 ; extra == 'tests' - - pyamg>=5.0.0 ; extra == 'tests' + - pyamg>=4.2.1 ; extra == 'tests' - polars>=0.20.30 ; extra == 'tests' - pyarrow>=12.0.0 ; extra == 'tests' - numpydoc>=1.2.0 ; extra == 'tests' diff --git a/tests/test_narx.py b/tests/test_narx.py index 685bcda..c8139cf 100644 --- a/tests/test_narx.py +++ b/tests/test_narx.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from numpy.testing import assert_array_equal +from numpy.testing import assert_almost_equal, assert_array_equal from sklearn.metrics import r2_score from sklearn.utils.estimator_checks import check_estimator @@ -243,17 +243,48 @@ def test_narx(nan, multi_output): ).fit(X, y) -def test_mulit_output_warn_error(): +def test_mulit_output_warn(): + X = np.random.rand(10, 2) + y = np.random.rand(10, 2) + for i in range(2): + if i == 0: + # X only, grad does not have dynamic part + time_shift_ids = np.array([[0, 1], [1, 1]]) + poly_ids = np.array([[1, 1], [2, 2]]) + else: + time_shift_ids = np.array([[0, 0], [1, 1], [2, 1]]) + poly_ids = np.array([[1, 1], [2, 2], [0, 3]]) + feat_ids, delay_ids = tp2fd(time_shift_ids, poly_ids) + + with pytest.warns(UserWarning, match="output_ids got"): + narx = NARX(feat_ids=feat_ids, delay_ids=delay_ids) + narx.fit(X, y) + y_pred = narx.predict(X) + assert_almost_equal(np.std(y_pred[narx.max_delay_:, 1] - np.mean(y[:, 1])), 0.0) + + X_nan = np.copy(X) + y_nan = np.copy(y) + X_nan[4, 0] = np.nan + y_nan[4, 1] = np.nan + for coef_init in [None, "one_step_ahead"]: + with pytest.warns(UserWarning, match="output_ids got"): + y_pred = narx.fit(X_nan, y_nan, coef_init=coef_init).predict(X_nan) + y_nan_masked, y_pred_masked = _mask_missing_value(y_nan, y_pred) + assert_almost_equal( + np.std( + y_pred_masked[y_pred_masked[:, 0]!=0, 1] -\ + np.mean(y_nan_masked[:, 1]) + ), + 0.0, + ) + +def test_mulit_output_error(): X = np.random.rand(10, 2) y = np.random.rand(10, 2) time_shift_ids = np.array([[0, 1], [1, 1]]) poly_ids = np.array([[1, 1], [2, 2]]) feat_ids, delay_ids = tp2fd(time_shift_ids, poly_ids) - with pytest.warns(UserWarning, match="output_ids got"): - narx = NARX(feat_ids=feat_ids, delay_ids=delay_ids) - narx.fit(X, y) - with pytest.raises(ValueError, match="The length of output_ids should"): narx = NARX( feat_ids=feat_ids, @@ -281,7 +312,6 @@ def test_mulit_output_warn_error(): narx.predict(X, y_init=[1, 1, 1]) - def test_sample_weight(): rng = np.random.default_rng(12345) n_samples = 100