8383from scipy .sparse .linalg import eigsh
8484from sklearn .base import BaseEstimator , MetaEstimatorMixin
8585from sklearn .feature_selection ._base import SelectorMixin
86- from sklearn .utils import check_array , check_random_state , check_X_y , safe_mask
87- from sklearn .utils .validation import FLOAT_DTYPES , as_float_array , check_is_fitted
86+ from sklearn .utils import check_random_state , safe_mask
87+ from sklearn .utils .validation import (
88+ FLOAT_DTYPES ,
89+ as_float_array ,
90+ check_is_fitted ,
91+ validate_data ,
92+ )
8893
8994from .utils import (
9095 X_orthogonalizer ,
@@ -157,11 +162,6 @@ def __init__(
157162 self .n_to_select = n_to_select
158163 self .score_threshold = score_threshold
159164 self .score_threshold_type = score_threshold_type
160- if self .score_threshold_type not in ["relative" , "absolute" ]:
161- raise ValueError (
162- "invalid score_threshold_type, expected one of 'relative' or 'absolute'"
163- )
164-
165165 self .full = full
166166 self .progress_bar = progress_bar
167167 self .random_state = random_state
@@ -184,6 +184,11 @@ def fit(self, X, y=None, warm_start=False):
184184 -------
185185 self : object
186186 """
187+ if self .score_threshold_type not in ["relative" , "absolute" ]:
188+ raise ValueError (
189+ "invalid score_threshold_type, expected one of 'relative' or 'absolute'"
190+ )
191+
187192 if self .selection_type == "feature" :
188193 self ._axis = 1
189194 elif self .selection_type == "sample" :
@@ -205,7 +210,7 @@ def fit(self, X, y=None, warm_start=False):
205210
206211 if hasattr (self , "mixing" ) or y is not None :
207212 X , y = self ._validate_data (X , y , ** params )
208- X , y = check_X_y ( X , y , multi_output = True )
213+ X , y = validate_data ( self , X , y , multi_output = True )
209214
210215 if len (y .shape ) == 1 :
211216 # force y to have multi_output 2D format even when it's 1D, since
@@ -214,7 +219,7 @@ def fit(self, X, y=None, warm_start=False):
214219 y = y .reshape ((len (y ), 1 ))
215220
216221 else :
217- X = check_array ( X , ** params )
222+ X = validate_data ( self , X , ** params )
218223
219224 if self .full and self .score_threshold is not None :
220225 raise ValueError (
@@ -308,7 +313,7 @@ def transform(self, X, y=None):
308313
309314 mask = self .get_support ()
310315
311- X = check_array ( X )
316+ X = validate_data ( self , X , reset = False )
312317
313318 if len (X .shape ) == 1 :
314319 if self ._axis == 0 :
@@ -486,6 +491,11 @@ def _more_tags(self):
486491 "requires_y" : False ,
487492 }
488493
494+ def __sklearn_tags__ (self ):
495+ tags = super ().__sklearn_tags__ ()
496+ tags .target_tags .required = False
497+ return tags
498+
489499
490500class _CUR (GreedySelector ):
491501 """Transformer that performs Greedy Selection by choosing features
@@ -560,6 +570,9 @@ def score(self, X, y=None):
560570 score : numpy.ndarray of (n_to_select_from_)
561571 :math:`\pi` importance for the given samples or features
562572 """
573+
574+ X , y = validate_data (self , X , y , reset = False )
575+
563576 return self .pi_
564577
565578 def _init_greedy_search (self , X , y , n_to_select ):
@@ -734,6 +747,9 @@ def score(self, X, y=None):
734747 score : numpy.ndarray of (n_to_select_from_)
735748 :math:`\pi` importance for the given samples or features
736749 """
750+
751+ X , y = validate_data (self , X , y , reset = False )
752+
737753 return self .pi_
738754
739755 def _init_greedy_search (self , X , y , n_to_select ):
@@ -927,6 +943,9 @@ def score(self, X, y=None):
927943 -------
928944 hausdorff : Hausdorff distances
929945 """
946+
947+ X , y = validate_data (self , X , y , reset = False )
948+
930949 return self .hausdorff_
931950
932951 def get_distance (self ):
@@ -1048,11 +1067,6 @@ def __init__(
10481067 full = False ,
10491068 random_state = 0 ,
10501069 ):
1051- if mixing == 1.0 :
1052- raise ValueError (
1053- "Mixing = 1.0 corresponds to traditional FPS."
1054- "Please use the FPS class."
1055- )
10561070
10571071 self .mixing = mixing
10581072 self .initialize = initialize
@@ -1067,6 +1081,16 @@ def __init__(
10671081 random_state = random_state ,
10681082 )
10691083
1084+ def fit (self , X , y = None , warm_start = False ):
1085+
1086+ if self .mixing == 1.0 :
1087+ raise ValueError (
1088+ "Mixing = 1.0 corresponds to traditional FPS."
1089+ "Please use the FPS class."
1090+ )
1091+
1092+ return super ().fit (X , y )
1093+
10701094 def score (self , X , y = None ):
10711095 """Returns the Hausdorff distances of all samples to previous selections.
10721096
@@ -1083,6 +1107,9 @@ def score(self, X, y=None):
10831107 -------
10841108 hausdorff : Hausdorff distances
10851109 """
1110+
1111+ X , y = validate_data (self , X , y , reset = False )
1112+
10861113 return self .hausdorff_
10871114
10881115 def get_distance (self ):
@@ -1159,3 +1186,8 @@ def _more_tags(self):
11591186 return {
11601187 "requires_y" : True ,
11611188 }
1189+
1190+ def __sklearn_tags__ (self ):
1191+ tags = super ().__sklearn_tags__ ()
1192+ tags .target_tags .required = True
1193+ return tags
0 commit comments