Skip to content

Commit 1ae7d9e

Browse files
Merge pull request #26 from MatthewSZhang/sklearn1.6
2 parents 0e6bbbc + 56f15bd commit 1ae7d9e

File tree

6 files changed

+1882
-2917
lines changed

6 files changed

+1882
-2917
lines changed

fastcan/_fastcan.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.utils import check_array, check_consistent_length
1313
from sklearn.utils._openmp_helpers import _openmp_effective_n_threads
1414
from sklearn.utils._param_validation import Interval
15-
from sklearn.utils.validation import check_is_fitted
15+
from sklearn.utils.validation import check_is_fitted, validate_data
1616

1717
from ._cancorr_fast import _forward_search # type: ignore
1818

@@ -162,7 +162,8 @@ def fit(self, X, y):
162162
"dtype": float,
163163
"force_writeable": True,
164164
}
165-
X, y = self._validate_data(
165+
X, y = validate_data(
166+
self,
166167
X=X,
167168
y=y,
168169
multi_output=True,

fastcan/_narx.py

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
from sklearn.linear_model import LinearRegression
1414
from sklearn.utils import check_array, check_consistent_length, column_or_1d
1515
from sklearn.utils._param_validation import Interval, StrOptions, validate_params
16-
from sklearn.utils.validation import check_is_fitted
16+
from sklearn.utils.validation import (
17+
_check_sample_weight,
18+
check_is_fitted,
19+
validate_data,
20+
)
1721

1822
from ._fastcan import FastCan
1923
from ._refine import refine
@@ -52,7 +56,7 @@ def make_time_shift_features(X, ids):
5256
[5., 3., 4.],
5357
[7., 5., 6.]])
5458
"""
55-
X = check_array(X, ensure_2d=True, dtype=float, force_all_finite="allow-nan")
59+
X = check_array(X, ensure_2d=True, dtype=float, ensure_all_finite="allow-nan")
5660
ids = check_array(ids, ensure_2d=True, dtype=int)
5761
n_samples = X.shape[0]
5862
n_outputs = ids.shape[0]
@@ -177,7 +181,7 @@ def make_poly_features(X, ids):
177181
[ 1., 5., 25., 6.],
178182
[ 1., 7., 49., 8.]])
179183
"""
180-
X = check_array(X, ensure_2d=True, dtype=float, force_all_finite="allow-nan")
184+
X = check_array(X, ensure_2d=True, dtype=float, ensure_all_finite="allow-nan")
181185
ids = check_array(ids, ensure_2d=True, dtype=int)
182186
n_samples = X.shape[0]
183187
n_outputs, degree = ids.shape
@@ -269,7 +273,7 @@ def _mask_missing_value(*arr):
269273
return tuple([x[mask_nomissing] for x in arr])
270274

271275

272-
class Narx(BaseEstimator, RegressorMixin):
276+
class Narx(RegressorMixin, BaseEstimator):
273277
"""Nonlinear Autoregressive eXogenous model.
274278
For example, a (polynomial) Narx model is like
275279
y(t) = y(t-1)*u(t-1) + u(t-1)^2 + u(t-2) + 1.5
@@ -374,10 +378,11 @@ def __init__(
374378
@validate_params(
375379
{
376380
"coef_init": [None, StrOptions({"one_step_ahead"}), "array-like"],
381+
"sample_weight": ["array-like", None],
377382
},
378383
prefer_skip_nested_validation=True,
379384
)
380-
def fit(self, X, y, coef_init=None, **params):
385+
def fit(self, X, y, sample_weight=None, coef_init=None, **params):
381386
"""
382387
Fit narx model.
383388
@@ -389,6 +394,10 @@ def fit(self, X, y, coef_init=None, **params):
389394
y : array-like of shape (n_samples,)
390395
Target values. Will be cast to X's dtype if necessary.
391396
397+
sample_weight : array-like of shape (n_samples,), default=None
398+
Individual weights for each sample, which are used for a One-Step-Ahead
399+
Narx.
400+
392401
coef_init : array-like of shape (n_terms,), default=None
393402
The initial values of coefficients and intercept for optimization.
394403
When `coef_init` is None, the model will be a One-Step-Ahead Narx.
@@ -410,9 +419,15 @@ def fit(self, X, y, coef_init=None, **params):
410419
self : object
411420
Fitted Estimator.
412421
"""
413-
X = self._validate_data(X, dtype=float, force_all_finite="allow-nan")
422+
X = validate_data(
423+
self,
424+
X,
425+
dtype=float,
426+
ensure_all_finite="allow-nan",
427+
)
414428
y = column_or_1d(y, dtype=float, warn=True)
415429
check_consistent_length(X, y)
430+
sample_weight = _check_sample_weight(sample_weight, X)
416431

417432
if self.time_shift_ids is None:
418433
self.time_shift_ids_ = make_time_shift_ids(
@@ -467,9 +482,11 @@ def fit(self, X, y, coef_init=None, **params):
467482
time_shift_vars = make_time_shift_features(xy_hstack, self.time_shift_ids_)
468483
poly_terms = make_poly_features(time_shift_vars, self.poly_ids_)
469484
# Remove missing values
470-
poly_terms_masked, y_masked = _mask_missing_value(poly_terms, y)
485+
poly_terms_masked, y_masked, sample_weight_masked = _mask_missing_value(
486+
poly_terms, y, sample_weight
487+
)
471488

472-
osa_narx.fit(poly_terms_masked, y_masked)
489+
osa_narx.fit(poly_terms_masked, y_masked, sample_weight_masked)
473490
if coef_init is None:
474491
self.coef_ = osa_narx.coef_
475492
self.intercept_ = osa_narx.intercept_
@@ -545,7 +562,7 @@ def _predict(expression, X, y_init, coef, intercept, max_delay):
545562
else:
546563
y_hat[k] = expression(X, y_hat, coef, intercept, k)
547564
if np.any(y_hat[k] > 1e20):
548-
y_hat[k:] = np.inf
565+
y_hat[k:] = 1e20
549566
return y_hat
550567
return y_hat
551568

@@ -564,7 +581,7 @@ def _residual(
564581

565582
y_masked, y_hat_masked = _mask_missing_value(y, y_hat)
566583

567-
return (y_masked - y_hat_masked).flatten()
584+
return y_masked - y_hat_masked
568585

569586
@validate_params(
570587
{
@@ -591,7 +608,7 @@ def predict(self, X, y_init=None):
591608
"""
592609
check_is_fitted(self)
593610

594-
X = self._validate_data(X, reset=False, force_all_finite="allow-nan")
611+
X = validate_data(self, X, reset=False, ensure_all_finite="allow-nan")
595612
if y_init is None:
596613
y_init = np.zeros(self.max_delay_)
597614
else:
@@ -613,8 +630,10 @@ def predict(self, X, y_init=None):
613630
self.max_delay_,
614631
)
615632

616-
def _more_tags(self):
617-
return {"allow_nan": True}
633+
def __sklearn_tags__(self):
634+
tags = super().__sklearn_tags__()
635+
tags.input_tags.allow_nan = True
636+
return tags
618637

619638

620639
@validate_params(
@@ -718,14 +737,13 @@ def _get_term_str(term_id):
718737
],
719738
"include_zero_delay": [None, "array-like"],
720739
"static_indices": [None, "array-like"],
721-
"eta": ["boolean"],
722-
"verbose": ["verbose"],
723-
"drop": [
740+
"refine_verbose": ["verbose"],
741+
"refine_drop": [
724742
None,
725743
Interval(Integral, 1, None, closed="left"),
726744
StrOptions({"all"}),
727745
],
728-
"max_iter": [
746+
"refine_max_iter": [
729747
None,
730748
Interval(Integral, 1, None, closed="left"),
731749
],
@@ -741,10 +759,10 @@ def make_narx(
741759
*,
742760
include_zero_delay=None,
743761
static_indices=None,
744-
eta=False,
745-
verbose=1,
746-
drop=None,
747-
max_iter=None,
762+
refine_verbose=1,
763+
refine_drop=None,
764+
refine_max_iter=None,
765+
**params,
748766
):
749767
"""Find `time_shift_ids` and `poly_ids` for a Narx model.
750768
@@ -775,19 +793,20 @@ def make_narx(
775793
If the corresponding include_zero_delay of the static features is False, the
776794
static feature will be excluded from candidate features.
777795
778-
eta : bool, default=False
779-
Whether to use eta-cosine method.
796+
refine_verbose : int, default=1
797+
The verbosity level of refine.
780798
781-
verbose : int, default=1
782-
The verbosity level.
783-
784-
drop : int or "all", default=None
799+
refine_drop : int or "all", default=None
785800
The number of the selected features dropped for the subsequent
786801
reselection. If `refine_drop` is None, no refining will be performed.
787802
788-
max_iter : int, default=None
803+
refine_max_iter : int, default=None
789804
The maximum number of valid iterations in the refining process.
790805
806+
**params : dict
807+
Keyword arguments passed to
808+
`fastcan.FastCan`.
809+
791810
Returns
792811
-------
793812
narx : Narx
@@ -818,7 +837,8 @@ def make_narx(
818837
... static_indices=[1],
819838
... eta=True,
820839
... verbose=0,
821-
... drop=1)
840+
... refine_verbose=0,
841+
... refine_drop=1)
822842
>>> print(f"{mean_squared_error(y, narx.fit(X, y).predict(X)):.4f}")
823843
0.0289
824844
>>> print_narx(narx)
@@ -830,20 +850,22 @@ def make_narx(
830850
| X[k-1,0]*X[k-3,0] | 1.999 |
831851
| X[k-2,0]*X[k-0,1] | 1.527 |
832852
"""
833-
X = check_array(X, dtype=float, ensure_2d=True, force_all_finite="allow-nan")
853+
X = check_array(X, dtype=float, ensure_2d=True, ensure_all_finite="allow-nan")
834854
y = column_or_1d(y, dtype=float)
835855
check_consistent_length(X, y)
836856

837857
xy_hstack = np.c_[X, y]
838858
n_features = X.shape[1]
839859

840860
if include_zero_delay is None:
841-
include_zero_delay = [True] * n_features + [False]
861+
_include_zero_delay = [True] * n_features + [False]
862+
else:
863+
_include_zero_delay = include_zero_delay + [False]
842864

843865
time_shift_ids_all = make_time_shift_ids(
844866
n_features=xy_hstack.shape[1],
845867
max_delay=max_delay,
846-
include_zero_delay=include_zero_delay,
868+
include_zero_delay=_include_zero_delay,
847869
)
848870

849871
time_shift_ids_all = np.delete(
@@ -867,11 +889,12 @@ def make_narx(
867889

868890
csf = FastCan(
869891
n_features_to_select,
870-
eta=eta,
871-
verbose=0,
892+
**params,
872893
).fit(poly_terms_masked, y_masked)
873-
if drop is not None:
874-
indices, _ = refine(csf, drop=drop, max_iter=max_iter, verbose=verbose)
894+
if refine_drop is not None:
895+
indices, _ = refine(
896+
csf, drop=refine_drop, max_iter=refine_max_iter, verbose=refine_verbose
897+
)
875898
support = np.zeros(shape=csf.n_features_in_, dtype=bool)
876899
support[indices] = True
877900
else:

meson.build

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
project(
22
'fastcan',
33
'c', 'cython',
4-
version: '0.2.7',
4+
version: '0.3.0',
55
license: 'MIT',
66
meson_version: '>= 1.1.0',
77
default_options: [

0 commit comments

Comments
 (0)