Skip to content

Commit 941a77f

Browse files
cajchristian authored and PicoCentauri committed
Adding validate_data calls and tags for CUR, FPS, PCovCUR, and PCovFPS
1 parent a3ab23b commit 941a77f

File tree

1 file changed: 47 additions and 15 deletions

src/skmatter/_selection.py

Lines changed: 47 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -83,8 +83,13 @@
8383
from scipy.sparse.linalg import eigsh
8484
from sklearn.base import BaseEstimator, MetaEstimatorMixin
8585
from sklearn.feature_selection._base import SelectorMixin
86-
from sklearn.utils import check_array, check_random_state, check_X_y, safe_mask
87-
from sklearn.utils.validation import FLOAT_DTYPES, as_float_array, check_is_fitted
86+
from sklearn.utils import check_random_state, safe_mask
87+
from sklearn.utils.validation import (
88+
FLOAT_DTYPES,
89+
as_float_array,
90+
check_is_fitted,
91+
validate_data,
92+
)
8893

8994
from .utils import (
9095
X_orthogonalizer,
@@ -157,11 +162,6 @@ def __init__(
157162
self.n_to_select = n_to_select
158163
self.score_threshold = score_threshold
159164
self.score_threshold_type = score_threshold_type
160-
if self.score_threshold_type not in ["relative", "absolute"]:
161-
raise ValueError(
162-
"invalid score_threshold_type, expected one of 'relative' or 'absolute'"
163-
)
164-
165165
self.full = full
166166
self.progress_bar = progress_bar
167167
self.random_state = random_state
@@ -184,6 +184,11 @@ def fit(self, X, y=None, warm_start=False):
184184
-------
185185
self : object
186186
"""
187+
if self.score_threshold_type not in ["relative", "absolute"]:
188+
raise ValueError(
189+
"invalid score_threshold_type, expected one of 'relative' or 'absolute'"
190+
)
191+
187192
if self.selection_type == "feature":
188193
self._axis = 1
189194
elif self.selection_type == "sample":
@@ -205,7 +210,7 @@ def fit(self, X, y=None, warm_start=False):
205210

206211
if hasattr(self, "mixing") or y is not None:
207212
X, y = self._validate_data(X, y, **params)
208-
X, y = check_X_y(X, y, multi_output=True)
213+
X, y = validate_data(self, X, y, multi_output=True)
209214

210215
if len(y.shape) == 1:
211216
# force y to have multi_output 2D format even when it's 1D, since
@@ -214,7 +219,7 @@ def fit(self, X, y=None, warm_start=False):
214219
y = y.reshape((len(y), 1))
215220

216221
else:
217-
X = check_array(X, **params)
222+
X = validate_data(self, X, **params)
218223

219224
if self.full and self.score_threshold is not None:
220225
raise ValueError(
@@ -308,7 +313,7 @@ def transform(self, X, y=None):
308313

309314
mask = self.get_support()
310315

311-
X = check_array(X)
316+
X = validate_data(self, X, reset=False)
312317

313318
if len(X.shape) == 1:
314319
if self._axis == 0:
@@ -486,6 +491,11 @@ def _more_tags(self):
486491
"requires_y": False,
487492
}
488493

494+
def __sklearn_tags__(self):
495+
tags = super().__sklearn_tags__()
496+
tags.target_tags.required = False
497+
return tags
498+
489499

490500
class _CUR(GreedySelector):
491501
"""Transformer that performs Greedy Selection by choosing features
@@ -560,6 +570,9 @@ def score(self, X, y=None):
560570
score : numpy.ndarray of (n_to_select_from_)
561571
:math:`\pi` importance for the given samples or features
562572
"""
573+
574+
X, y = validate_data(self, X, y, reset=False)
575+
563576
return self.pi_
564577

565578
def _init_greedy_search(self, X, y, n_to_select):
@@ -734,6 +747,9 @@ def score(self, X, y=None):
734747
score : numpy.ndarray of (n_to_select_from_)
735748
:math:`\pi` importance for the given samples or features
736749
"""
750+
751+
X, y = validate_data(self, X, y, reset=False)
752+
737753
return self.pi_
738754

739755
def _init_greedy_search(self, X, y, n_to_select):
@@ -927,6 +943,9 @@ def score(self, X, y=None):
927943
-------
928944
hausdorff : Hausdorff distances
929945
"""
946+
947+
X, y = validate_data(self, X, y, reset=False)
948+
930949
return self.hausdorff_
931950

932951
def get_distance(self):
@@ -1048,11 +1067,6 @@ def __init__(
10481067
full=False,
10491068
random_state=0,
10501069
):
1051-
if mixing == 1.0:
1052-
raise ValueError(
1053-
"Mixing = 1.0 corresponds to traditional FPS."
1054-
"Please use the FPS class."
1055-
)
10561070

10571071
self.mixing = mixing
10581072
self.initialize = initialize
@@ -1067,6 +1081,16 @@ def __init__(
10671081
random_state=random_state,
10681082
)
10691083

1084+
def fit(self, X, y=None, warm_start=False):
1085+
1086+
if self.mixing == 1.0:
1087+
raise ValueError(
1088+
"Mixing = 1.0 corresponds to traditional FPS."
1089+
"Please use the FPS class."
1090+
)
1091+
1092+
return super().fit(X, y)
1093+
10701094
def score(self, X, y=None):
10711095
"""Returns the Hausdorff distances of all samples to previous selections.
10721096
@@ -1083,6 +1107,9 @@ def score(self, X, y=None):
10831107
-------
10841108
hausdorff : Hausdorff distances
10851109
"""
1110+
1111+
X, y = validate_data(self, X, y, reset=False)
1112+
10861113
return self.hausdorff_
10871114

10881115
def get_distance(self):
@@ -1159,3 +1186,8 @@ def _more_tags(self):
11591186
return {
11601187
"requires_y": True,
11611188
}
1189+
1190+
def __sklearn_tags__(self):
1191+
tags = super().__sklearn_tags__()
1192+
tags.target_tags.required = True
1193+
return tags

0 commit comments

Comments
 (0)