@@ -51,16 +51,22 @@ def _set_checking_parameters(estimator):
51
51
52
52
53
53
def _yield_sampler_checks (sampler ):
54
+ tags = sampler ._get_tags ()
54
55
yield check_target_type
55
56
yield check_samplers_one_label
56
57
yield check_samplers_fit
57
58
yield check_samplers_fit_resample
58
59
yield check_samplers_sampling_strategy_fit_resample
59
- yield check_samplers_sparse
60
- yield check_samplers_pandas
60
+ if "sparse" in tags ["X_types" ]:
61
+ yield check_samplers_sparse
62
+ if "dataframe" in tags ["X_types" ]:
63
+ yield check_samplers_pandas
61
64
yield check_samplers_list
62
65
yield check_samplers_multiclass_ova
63
66
yield check_samplers_preserve_dtype
67
+ # we don't filter samplers based on their tag here because we want to make
68
+ # sure that the fitted attribute does not exist if the tag is not
69
+ # stipulated
64
70
yield check_samplers_sample_indices
65
71
yield check_samplers_2d_target
66
72
@@ -75,7 +81,8 @@ def _yield_all_checks(estimator):
75
81
tags = estimator ._get_tags ()
76
82
if tags ["_skip_test" ]:
77
83
warnings .warn (
78
- f"Explicit SKIP via _skip_test tag for estimator { name } ." , SkipTestWarning ,
84
+ f"Explicit SKIP via _skip_test tag for estimator { name } ." ,
85
+ SkipTestWarning ,
79
86
)
80
87
return
81
88
# trigger our checks if this is a SamplerMixin
@@ -116,6 +123,7 @@ def parametrize_with_checks(estimators):
116
123
... def test_sklearn_compatible_estimator(estimator, check):
117
124
... check(estimator)
118
125
"""
126
+
119
127
def checks_generator ():
120
128
for estimator in estimators :
121
129
name = type (estimator ).__name__
@@ -124,9 +132,7 @@ def checks_generator():
124
132
yield _maybe_mark_xfail (estimator , check , pytest )
125
133
126
134
return pytest .mark .parametrize (
127
- "estimator, check" ,
128
- checks_generator (),
129
- ids = _get_check_estimator_ids
135
+ "estimator, check" , checks_generator (), ids = _get_check_estimator_ids
130
136
)
131
137
132
138
@@ -137,14 +143,22 @@ def check_target_type(name, estimator_orig):
137
143
y = np .linspace (0 , 1 , 20 )
138
144
msg = "Unknown label type: 'continuous'"
139
145
assert_raises_regex (
140
- ValueError , msg , estimator .fit_resample , X , y ,
146
+ ValueError ,
147
+ msg ,
148
+ estimator .fit_resample ,
149
+ X ,
150
+ y ,
141
151
)
142
152
# if the target is multilabel then we should raise an error
143
153
rng = np .random .RandomState (42 )
144
154
y = rng .randint (2 , size = (20 , 3 ))
145
155
msg = "Multilabel and multioutput targets are not supported."
146
156
assert_raises_regex (
147
- ValueError , msg , estimator .fit_resample , X , y ,
157
+ ValueError ,
158
+ msg ,
159
+ estimator .fit_resample ,
160
+ X ,
161
+ y ,
148
162
)
149
163
150
164
@@ -385,9 +399,7 @@ def check_samplers_sample_indices(name, sampler_orig):
385
399
assert not hasattr (sampler , "sample_indices_" )
386
400
387
401
388
- def check_classifier_on_multilabel_or_multioutput_targets (
389
- name , estimator_orig
390
- ):
402
+ def check_classifier_on_multilabel_or_multioutput_targets (name , estimator_orig ):
391
403
estimator = clone (estimator_orig )
392
404
X , y = make_multilabel_classification (n_samples = 30 )
393
405
msg = "Multilabel and multioutput targets are not supported."
0 commit comments