Skip to content

Commit 762fa48

Browse files
committed
iter
1 parent c457b4a commit 762fa48

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

imblearn/tests/test_common.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99

1010
import numpy as np
1111
import pytest
12+
import sklearn
1213
from sklearn.base import clone
1314
from sklearn.exceptions import ConvergenceWarning
15+
from sklearn.utils.fixes import parse_version
1416
from sklearn.utils._testing import ignore_warnings
1517
from sklearn.utils.estimator_checks import (
1618
parametrize_with_checks as parametrize_with_checks_sklearn,
@@ -27,9 +29,18 @@
2729
from imblearn.utils.testing import all_estimators
2830
from imblearn.utils._test_common.instance_generator import (
2931
_get_check_estimator_ids,
32+
_get_expected_failed_checks,
3033
_tested_estimators,
3134
)
3235

36+
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
37+
if sklearn_version >= parse_version("1.6"):
38+
kwargs_parametrize_with_checks = {
39+
"expected_failed_checks": _get_expected_failed_checks
40+
}
41+
else:
42+
kwargs_parametrize_with_checks = {}
43+
3344

3445
@pytest.mark.parametrize("name, Estimator", all_estimators())
3546
def test_all_estimator_no_base_class(name, Estimator):
@@ -38,13 +49,17 @@ def test_all_estimator_no_base_class(name, Estimator):
3849
assert not name.lower().startswith("base"), msg
3950

4051

41-
@parametrize_with_checks_sklearn(list(_tested_estimators()))
52+
@parametrize_with_checks_sklearn(
53+
list(_tested_estimators()), **kwargs_parametrize_with_checks
54+
)
4255
def test_estimators_compatibility_sklearn(estimator, check, request):
4356
_set_checking_parameters(estimator)
4457
check(estimator)
4558

4659

47-
@parametrize_with_checks(list(_tested_estimators()))
60+
@parametrize_with_checks(
61+
list(_tested_estimators()), expected_failed_checks=_get_expected_failed_checks
62+
)
4863
def test_estimators_imblearn(estimator, check, request):
4964
# Common tests for estimator instances
5065
with ignore_warnings(

imblearn/utils/_test_common/instance_generator.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
BalancedBaggingClassifier,
2020
BalancedRandomForestClassifier,
2121
EasyEnsembleClassifier,
22+
RUSBoostClassifier,
2223
)
2324
from imblearn.over_sampling import (
2425
ADASYN,
@@ -83,7 +84,13 @@
8384
# same check with multiple instances of the same estimator with different parameters.
8485
# The special key "*" allows to apply the parameters to all checks.
8586
# TODO(devtools): allow third-party developers to pass test specific params to checks
86-
PER_ESTIMATOR_CHECK_PARAMS: dict = {}
87+
PER_ESTIMATOR_CHECK_PARAMS: dict = {
88+
Pipeline: {
89+
"check_classifiers_with_encoded_labels": dict(
90+
sampler__sampling_strategy={"setosa": 20, "virginica": 20}
91+
)
92+
}
93+
}
8794

8895
SKIPPED_ESTIMATORS = [SMOTENC]
8996

@@ -187,3 +194,31 @@ def _yield_instances_for_check(check, estimator_orig):
187194
estimator = clone(estimator_orig)
188195
estimator.set_params(**params)
189196
yield estimator
197+
198+
199+
PER_ESTIMATOR_XFAIL_CHECKS = {
200+
BalancedRandomForestClassifier: {
201+
"check_sample_weight_equivalence": "FIXME",
202+
},
203+
NearMiss: {
204+
"check_samplers_fit_resample": "FIXME",
205+
},
206+
Pipeline: {
207+
"check_dont_overwrite_parameters": (
208+
"Pipeline changes the `steps` parameter, which it shouldn't."
209+
"Therefore this test is x-fail until we fix this."
210+
),
211+
"check_estimators_overwrite_params": (
212+
"Pipeline changes the `steps` parameter, which it shouldn't."
213+
"Therefore this test is x-fail until we fix this."
214+
),
215+
},
216+
RUSBoostClassifier: {
217+
"check_sample_weight_equivalence": "FIXME",
218+
},
219+
}
220+
221+
def _get_expected_failed_checks(estimator):
222+
"""Get the expected failed checks for all estimators in scikit-learn."""
223+
failed_checks = PER_ESTIMATOR_XFAIL_CHECKS.get(type(estimator), {})
224+
return failed_checks

imblearn/utils/estimator_checks.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,6 @@ def check_classifiers_with_encoded_labels(name, classifier_orig):
668668
"virginica": 50,
669669
},
670670
)
671-
classifier.set_params(sampling_strategy={"setosa": 20, "virginica": 20})
672671
classifier.fit(df, y)
673672
assert set(classifier.classes_) == set(y.cat.categories.tolist())
674673
y_pred = classifier.predict(df)

0 commit comments

Comments
 (0)