1313from sklearn .linear_model import LinearRegression
1414from sklearn .utils import check_array , check_consistent_length , column_or_1d
1515from sklearn .utils ._param_validation import Interval , StrOptions , validate_params
16- from sklearn .utils .validation import check_is_fitted
16+ from sklearn .utils .validation import (
17+ _check_sample_weight ,
18+ check_is_fitted ,
19+ validate_data ,
20+ )
1721
1822from ._fastcan import FastCan
1923from ._refine import refine
@@ -52,7 +56,7 @@ def make_time_shift_features(X, ids):
5256 [5., 3., 4.],
5357 [7., 5., 6.]])
5458 """
55- X = check_array (X , ensure_2d = True , dtype = float , force_all_finite = "allow-nan" )
59+ X = check_array (X , ensure_2d = True , dtype = float , ensure_all_finite = "allow-nan" )
5660 ids = check_array (ids , ensure_2d = True , dtype = int )
5761 n_samples = X .shape [0 ]
5862 n_outputs = ids .shape [0 ]
@@ -177,7 +181,7 @@ def make_poly_features(X, ids):
177181 [ 1., 5., 25., 6.],
178182 [ 1., 7., 49., 8.]])
179183 """
180- X = check_array (X , ensure_2d = True , dtype = float , force_all_finite = "allow-nan" )
184+ X = check_array (X , ensure_2d = True , dtype = float , ensure_all_finite = "allow-nan" )
181185 ids = check_array (ids , ensure_2d = True , dtype = int )
182186 n_samples = X .shape [0 ]
183187 n_outputs , degree = ids .shape
@@ -269,7 +273,7 @@ def _mask_missing_value(*arr):
269273 return tuple ([x [mask_nomissing ] for x in arr ])
270274
271275
272- class Narx (BaseEstimator , RegressorMixin ):
276+ class Narx (RegressorMixin , BaseEstimator ):
273277 """Nonlinear Autoregressive eXogenous model.
274278 For example, a (polynomial) Narx model is like
275279 y(t) = y(t-1)*u(t-1) + u(t-1)^2 + u(t-2) + 1.5
@@ -374,10 +378,11 @@ def __init__(
374378 @validate_params (
375379 {
376380 "coef_init" : [None , StrOptions ({"one_step_ahead" }), "array-like" ],
381+ "sample_weight" : ["array-like" , None ],
377382 },
378383 prefer_skip_nested_validation = True ,
379384 )
380- def fit (self , X , y , coef_init = None , ** params ):
385+ def fit (self , X , y , sample_weight = None , coef_init = None , ** params ):
381386 """
382387 Fit narx model.
383388
@@ -389,6 +394,10 @@ def fit(self, X, y, coef_init=None, **params):
389394 y : array-like of shape (n_samples,)
390395 Target values. Will be cast to X's dtype if necessary.
391396
397+ sample_weight : array-like of shape (n_samples,), default=None
398+ Individual weights for each sample, which are used for a One-Step-Ahead
399+ Narx.
400+
392401 coef_init : array-like of shape (n_terms,), default=None
393402 The initial values of coefficients and intercept for optimization.
394403 When `coef_init` is None, the model will be a One-Step-Ahead Narx.
@@ -410,9 +419,15 @@ def fit(self, X, y, coef_init=None, **params):
410419 self : object
411420 Fitted Estimator.
412421 """
413- X = self ._validate_data (X , dtype = float , force_all_finite = "allow-nan" )
422+ X = validate_data (
423+ self ,
424+ X ,
425+ dtype = float ,
426+ ensure_all_finite = "allow-nan" ,
427+ )
414428 y = column_or_1d (y , dtype = float , warn = True )
415429 check_consistent_length (X , y )
430+ sample_weight = _check_sample_weight (sample_weight , X )
416431
417432 if self .time_shift_ids is None :
418433 self .time_shift_ids_ = make_time_shift_ids (
@@ -467,9 +482,11 @@ def fit(self, X, y, coef_init=None, **params):
467482 time_shift_vars = make_time_shift_features (xy_hstack , self .time_shift_ids_ )
468483 poly_terms = make_poly_features (time_shift_vars , self .poly_ids_ )
469484 # Remove missing values
470- poly_terms_masked , y_masked = _mask_missing_value (poly_terms , y )
485+ poly_terms_masked , y_masked , sample_weight_masked = _mask_missing_value (
486+ poly_terms , y , sample_weight
487+ )
471488
472- osa_narx .fit (poly_terms_masked , y_masked )
489+ osa_narx .fit (poly_terms_masked , y_masked , sample_weight_masked )
473490 if coef_init is None :
474491 self .coef_ = osa_narx .coef_
475492 self .intercept_ = osa_narx .intercept_
@@ -545,7 +562,7 @@ def _predict(expression, X, y_init, coef, intercept, max_delay):
545562 else :
546563 y_hat [k ] = expression (X , y_hat , coef , intercept , k )
547564 if np .any (y_hat [k ] > 1e20 ):
548- y_hat [k :] = np . inf
565+ y_hat [k :] = 1e20
549566 return y_hat
550567 return y_hat
551568
@@ -564,7 +581,7 @@ def _residual(
564581
565582 y_masked , y_hat_masked = _mask_missing_value (y , y_hat )
566583
567- return ( y_masked - y_hat_masked ). flatten ()
584+ return y_masked - y_hat_masked
568585
569586 @validate_params (
570587 {
@@ -591,7 +608,7 @@ def predict(self, X, y_init=None):
591608 """
592609 check_is_fitted (self )
593610
594- X = self . _validate_data ( X , reset = False , force_all_finite = "allow-nan" )
611+ X = validate_data ( self , X , reset = False , ensure_all_finite = "allow-nan" )
595612 if y_init is None :
596613 y_init = np .zeros (self .max_delay_ )
597614 else :
@@ -613,8 +630,10 @@ def predict(self, X, y_init=None):
613630 self .max_delay_ ,
614631 )
615632
616- def _more_tags (self ):
617- return {"allow_nan" : True }
633+ def __sklearn_tags__ (self ):
634+ tags = super ().__sklearn_tags__ ()
635+ tags .input_tags .allow_nan = True
636+ return tags
618637
619638
620639@validate_params (
@@ -718,14 +737,13 @@ def _get_term_str(term_id):
718737 ],
719738 "include_zero_delay" : [None , "array-like" ],
720739 "static_indices" : [None , "array-like" ],
721- "eta" : ["boolean" ],
722- "verbose" : ["verbose" ],
723- "drop" : [
740+ "refine_verbose" : ["verbose" ],
741+ "refine_drop" : [
724742 None ,
725743 Interval (Integral , 1 , None , closed = "left" ),
726744 StrOptions ({"all" }),
727745 ],
728- "max_iter " : [
746+ "refine_max_iter " : [
729747 None ,
730748 Interval (Integral , 1 , None , closed = "left" ),
731749 ],
@@ -741,10 +759,10 @@ def make_narx(
741759 * ,
742760 include_zero_delay = None ,
743761 static_indices = None ,
744- eta = False ,
745- verbose = 1 ,
746- drop = None ,
747- max_iter = None ,
762+ refine_verbose = 1 ,
763+ refine_drop = None ,
764+ refine_max_iter = None ,
765+ ** params ,
748766):
749767 """Find `time_shift_ids` and `poly_ids` for a Narx model.
750768
@@ -775,19 +793,20 @@ def make_narx(
775793 If the corresponding include_zero_delay of the static features is False, the
776794 static feature will be excluded from candidate features.
777795
778- eta : bool , default=False
779- Whether to use eta-cosine method .
796+ refine_verbose : int , default=1
797+ The verbosity level of refine .
780798
781- verbose : int, default=1
782- The verbosity level.
783-
784- drop : int or "all", default=None
799+ refine_drop : int or "all", default=None
785800 The number of the selected features dropped for the subsequent
786801 reselection. If `refine_drop` is None, no refining will be performed.
787802
788- max_iter : int, default=None
803+ refine_max_iter : int, default=None
789804 The maximum number of valid iterations in the refining process.
790805
806+ **params : dict
807+ Keyword arguments passed to
808+ `fastcan.FastCan`.
809+
791810 Returns
792811 -------
793812 narx : Narx
@@ -818,7 +837,8 @@ def make_narx(
818837 ... static_indices=[1],
819838 ... eta=True,
820839 ... verbose=0,
821- ... drop=1)
840+ ... refine_verbose=0,
841+ ... refine_drop=1)
822842 >>> print(f"{mean_squared_error(y, narx.fit(X, y).predict(X)):.4f}")
823843 0.0289
824844 >>> print_narx(narx)
@@ -830,20 +850,22 @@ def make_narx(
830850 | X[k-1,0]*X[k-3,0] | 1.999 |
831851 | X[k-2,0]*X[k-0,1] | 1.527 |
832852 """
833- X = check_array (X , dtype = float , ensure_2d = True , force_all_finite = "allow-nan" )
853+ X = check_array (X , dtype = float , ensure_2d = True , ensure_all_finite = "allow-nan" )
834854 y = column_or_1d (y , dtype = float )
835855 check_consistent_length (X , y )
836856
837857 xy_hstack = np .c_ [X , y ]
838858 n_features = X .shape [1 ]
839859
840860 if include_zero_delay is None :
841- include_zero_delay = [True ] * n_features + [False ]
861+ _include_zero_delay = [True ] * n_features + [False ]
862+ else :
863+ _include_zero_delay = include_zero_delay + [False ]
842864
843865 time_shift_ids_all = make_time_shift_ids (
844866 n_features = xy_hstack .shape [1 ],
845867 max_delay = max_delay ,
846- include_zero_delay = include_zero_delay ,
868+ include_zero_delay = _include_zero_delay ,
847869 )
848870
849871 time_shift_ids_all = np .delete (
@@ -867,11 +889,12 @@ def make_narx(
867889
868890 csf = FastCan (
869891 n_features_to_select ,
870- eta = eta ,
871- verbose = 0 ,
892+ ** params ,
872893 ).fit (poly_terms_masked , y_masked )
873- if drop is not None :
874- indices , _ = refine (csf , drop = drop , max_iter = max_iter , verbose = verbose )
894+ if refine_drop is not None :
895+ indices , _ = refine (
896+ csf , drop = refine_drop , max_iter = refine_max_iter , verbose = refine_verbose
897+ )
875898 support = np .zeros (shape = csf .n_features_in_ , dtype = bool )
876899 support [indices ] = True
877900 else :
0 commit comments