Skip to content

Commit c2a8ca5

Browse files
FEAT add max-candidates in make_narx (#181)
* FEAT add max-candidates in make_narx
1 parent 8bf95b4 commit c2a8ca5

File tree

7 files changed

+137
-23
lines changed

7 files changed

+137
-23
lines changed

.github/workflows/emscripten.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jobs:
99
steps:
1010
- uses: actions/checkout@v5
1111
- name: Build WASM wheel
12-
uses: pypa/[email protected].0
12+
uses: pypa/[email protected].1
1313
env:
1414
CIBW_PLATFORM: pyodide
1515
- name: Upload package

.github/workflows/wheel.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ jobs:
3333
steps:
3434
- uses: actions/checkout@v5
3535
- name: Build wheels
36-
uses: pypa/[email protected].0
36+
uses: pypa/[email protected].1
3737
env:
3838
CIBW_SKIP: "*_i686 *_ppc64le *_s390x *_universal2 *-musllinux_* cp314t*"
3939
CIBW_PROJECT_REQUIRES_PYTHON: ">=3.10"

fastcan/narx/_feature.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# SPDX-License-Identifier: MIT
77

88
import math
9+
import warnings
910
from itertools import combinations_with_replacement
1011
from numbers import Integral
1112

@@ -198,12 +199,16 @@ def make_poly_features(X, ids):
198199
None,
199200
Interval(Integral, 1, None, closed="left"),
200201
],
202+
"max_poly": [None, Interval(Integral, 1, None, closed="left")],
203+
"random_state": ["random_state"],
201204
},
202205
prefer_skip_nested_validation=True,
203206
)
204207
def make_poly_ids(
205208
n_features=1,
206209
degree=1,
210+
max_poly=None,
211+
random_state=None,
207212
):
208213
"""Generate ids for polynomial features.
209214
(variable_index, variable_index, ...)
@@ -217,6 +222,15 @@ def make_poly_ids(
217222
degree : int, default=1
218223
The maximum degree of polynomial features.
219224
225+
max_poly : int, default=None
226+
Maximum number of ids of polynomial features to generate.
227+
Randomly selected by reservoir sampling.
228+
If None, all possible ids are returned.
229+
230+
random_state : int or RandomState instance, default=None
231+
Used when `max_poly` is not None to subsample ids of polynomial features.
232+
See :term:`Glossary <random_state>` for details.
233+
220234
Returns
221235
-------
222236
ids : array-like of shape (n_outputs, degree)
@@ -236,29 +250,45 @@ def make_poly_ids(
236250
[1, 2, 2],
237251
[2, 2, 2]])
238252
"""
239-
n_outputs = math.comb(n_features + degree, degree) - 1
240-
if n_outputs > np.iinfo(np.intp).max:
253+
n_total = math.comb(n_features + degree, degree) - 1
254+
if n_total > np.iinfo(np.intp).max:
241255
msg = (
242-
"The output that would result from the current configuration would"
243-
f" have {n_outputs} features which is too large to be"
244-
f" indexed by {np.intp().dtype.name}."
256+
"The current configuration would "
257+
f"result in {n_total} features which is too large to be "
258+
f"indexed by {np.intp().dtype.name}."
245259
)
246260
raise ValueError(msg)
247-
248-
ids = np.array(
249-
list(
250-
combinations_with_replacement(
251-
range(n_features + 1),
252-
degree,
253-
)
261+
if n_total > 10_000_000:
262+
warnings.warn(
263+
"Total number of polynomial features is larger than 10,000,000! "
264+
f"The current configuration would result in {n_total} features. "
265+
"This may take a while.",
266+
UserWarning,
267+
)
268+
if max_poly is not None and max_poly < n_total:
269+
# reservoir sampling
270+
rng = np.random.default_rng(random_state)
271+
reservoir = []
272+
for i, comb in enumerate(
273+
combinations_with_replacement(range(n_features + 1), degree)
274+
):
275+
if i < max_poly:
276+
reservoir.append(comb)
277+
else:
278+
j = rng.integers(0, i + 1)
279+
if j < max_poly:
280+
reservoir[j] = comb
281+
ids = np.array(reservoir)
282+
else:
283+
ids = np.array(
284+
list(combinations_with_replacement(range(n_features + 1), degree))
254285
)
255-
)
256286

257287
const_id = np.where((ids == 0).all(axis=1))
258288
return np.delete(ids, const_id, 0) # remove the constant feature
259289

260290

261-
def _valiate_time_shift_poly_ids(
291+
def _validate_time_shift_poly_ids(
262292
time_shift_ids, poly_ids, n_samples=None, n_features=None, n_outputs=None
263293
):
264294
if n_samples is None:
@@ -496,7 +526,7 @@ def tp2fd(time_shift_ids, poly_ids):
496526
[[-1 1]
497527
[ 2 3]]
498528
"""
499-
_time_shift_ids, _poly_ids = _valiate_time_shift_poly_ids(
529+
_time_shift_ids, _poly_ids = _validate_time_shift_poly_ids(
500530
time_shift_ids,
501531
poly_ids,
502532
)

fastcan/narx/_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,11 @@
77

88
import numpy as np
99
from scipy.stats import rankdata
10-
from sklearn.utils import check_array, check_consistent_length, column_or_1d
10+
from sklearn.utils import (
11+
check_array,
12+
check_consistent_length,
13+
column_or_1d,
14+
)
1115
from sklearn.utils._param_validation import Interval, StrOptions, validate_params
1216
from sklearn.utils.validation import check_is_fitted
1317

@@ -132,6 +136,8 @@ def _get_term_str(term_feat_ids, term_delay_ids):
132136
Interval(Integral, 1, None, closed="left"),
133137
],
134138
"fit_intercept": ["boolean"],
139+
"max_candidates": [None, Interval(Integral, 1, None, closed="left")],
140+
"random_state": ["random_state"],
135141
"include_zero_delay": [None, "array-like"],
136142
"static_indices": [None, "array-like"],
137143
"refine_verbose": ["verbose"],
@@ -155,6 +161,8 @@ def make_narx(
155161
poly_degree=1,
156162
*,
157163
fit_intercept=True,
164+
max_candidates=None,
165+
random_state=None,
158166
include_zero_delay=None,
159167
static_indices=None,
160168
refine_verbose=1,
@@ -186,6 +194,15 @@ def make_narx(
186194
fit_intercept : bool, default=True
187195
Whether to fit the intercept. If set to False, intercept will be zeros.
188196
197+
max_candidates : int, default=None
198+
Maximum number of candidate polynomial terms retained before selection.
199+
Randomly selected by reservoir sampling.
200+
If None, all candidates are considered.
201+
202+
random_state : int or RandomState instance, default=None
203+
Used when `max_candidates` is not None to subsample candidate terms.
204+
See :term:`Glossary <random_state>` for details.
205+
189206
include_zero_delay : {None, array-like} of shape (n_features,) default=None
190207
Whether to include the original (zero-delay) features.
191208
@@ -306,6 +323,8 @@ def make_narx(
306323
poly_ids_all = make_poly_ids(
307324
time_shift_ids_all.shape[0],
308325
poly_degree,
326+
max_poly=max_candidates,
327+
random_state=random_state,
309328
)
310329
poly_terms = make_poly_features(time_shift_vars, poly_ids_all)
311330

fastcan/narx/tests/test_narx.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,18 @@ def test_narx_is_sklearn_estimator():
2929
check_estimator(NARX(), expected_failed_checks=expected_failures)
3030

3131

32-
def test_poly_ids():
33-
with pytest.raises(ValueError, match=r"The output that would result from the .*"):
32+
def test_poly_ids(monkeypatch):
33+
with pytest.raises(ValueError, match=r"The current configuration would .*"):
3434
make_poly_ids(10, 1000)
3535

36+
# Mock combinations_with_replacement to avoid heavy computation
37+
monkeypatch.setattr(
38+
"fastcan.narx._feature.combinations_with_replacement",
39+
lambda *args, **kwargs: iter([[0, 0]]),
40+
)
41+
with pytest.warns(UserWarning, match=r"Total number of polynomial features .*"):
42+
make_poly_ids(18, 10)
43+
3644

3745
def test_time_ids():
3846
with pytest.raises(ValueError, match=r"The length of `include_zero_delay`.*"):
@@ -553,6 +561,57 @@ def test_make_narx_refine_print(capsys):
553561
assert "No. of iterations: " in captured.out
554562

555563

564+
def test_make_narx_max_candidates():
565+
"""Test max_candidates and random_state in make_narx."""
566+
rng = np.random.default_rng(12345)
567+
X = rng.random((100, 2))
568+
y = rng.random((100, 1))
569+
max_delay = 3
570+
poly_degree = 10
571+
n_terms_to_select = 5
572+
max_candidates = 20
573+
574+
# With the same random_state, the results should be identical
575+
narx1 = make_narx(
576+
X,
577+
y,
578+
n_terms_to_select=n_terms_to_select,
579+
max_delay=max_delay,
580+
poly_degree=poly_degree,
581+
max_candidates=max_candidates,
582+
random_state=123,
583+
verbose=0,
584+
)
585+
narx2 = make_narx(
586+
X,
587+
y,
588+
n_terms_to_select=n_terms_to_select,
589+
max_delay=max_delay,
590+
poly_degree=poly_degree,
591+
max_candidates=max_candidates,
592+
random_state=123,
593+
verbose=0,
594+
)
595+
assert_array_equal(narx1.feat_ids, narx2.feat_ids)
596+
assert_array_equal(narx1.delay_ids, narx2.delay_ids)
597+
598+
# With different random_state, the results should be different
599+
narx3 = make_narx(
600+
X,
601+
y,
602+
n_terms_to_select=n_terms_to_select,
603+
max_delay=max_delay,
604+
poly_degree=poly_degree,
605+
max_candidates=max_candidates,
606+
random_state=456,
607+
verbose=0,
608+
)
609+
assert not np.array_equal(narx1.feat_ids, narx3.feat_ids)
610+
611+
# Check if number of selected terms is correct
612+
assert narx1.feat_ids.shape[0] == n_terms_to_select
613+
614+
556615
@pytest.mark.parametrize("max_delay", [1, 3, 7, 10])
557616
def test_nan_split(max_delay):
558617
n_sessions = 10

pixi.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ asv-publish = { cmd = "python -m asv publish", cwd = "asv_benchmarks" }
143143
asv-preview = { cmd = "python -m asv preview", cwd = "asv_benchmarks", depends-on = ["asv-publish"] }
144144

145145
[tool.pixi.feature.test.tasks]
146-
test = "pytest ./tests ./fastcan/narx/tests"
147-
test-coverage = { cmd = "rm -rf .coverage && pytest --cov-report {{ FMT }} --cov={{ PACKAGE }} .", args = [{ arg = "PACKAGE", default = "fastcan" }, { arg = "FMT", default = "html" }] }
146+
test = "pytest"
147+
test-coverage = { cmd = "rm -rf .coverage && pytest --cov-report {{ FMT }} --cov={{ PACKAGE }}", args = [{ arg = "FMT", default = "html" }, { arg = "PACKAGE", default = "fastcan" }] }
148148

149149
[tool.pixi.feature.build.tasks]
150150
build-wheel = "rm -rf dist && python -m build -wnx -Cinstall-args=--tags=runtime,python-runtime,devel"
@@ -192,6 +192,12 @@ static = { features = ["static"], no-default-feature = true }
192192
nogil = { features = ["nogil"], no-default-feature = true }
193193
wasm = { features = ["wasm"], no-default-feature = true }
194194

195+
[tool.pytest.ini_options]
196+
testpaths = [
197+
"./tests",
198+
"./fastcan/narx/tests",
199+
]
200+
195201
[tool.coverage.run]
196202
omit = ["**/tests/*"]
197203

0 commit comments

Comments
 (0)