Skip to content

Commit e42f6a5

Browse files
committed
iter
1 parent 09570a3 commit e42f6a5

File tree

4 files changed

+35
-136
lines changed

4 files changed

+35
-136
lines changed

imblearn/utils/estimator_checks.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from functools import partial
1313

1414
import numpy as np
15-
import pytest
1615
import sklearn
1716
from scipy import sparse
1817
from sklearn.base import clone, is_classifier, is_regressor
@@ -30,7 +29,6 @@
3029
SkipTest,
3130
assert_allclose,
3231
assert_array_equal,
33-
assert_raises_regex,
3432
raises,
3533
set_random_state,
3634
)
@@ -62,11 +60,6 @@ def sample_dataset_generator():
6260
return X, y
6361

6462

65-
@pytest.fixture(name="sample_dataset_generator")
66-
def sample_dataset_generator_fixture():
67-
return sample_dataset_generator()
68-
69-
7063
def _set_checking_parameters(estimator):
7164
params = estimator.get_params()
7265
name = estimator.__class__.__name__
@@ -166,6 +159,7 @@ def parametrize_with_checks(estimators):
166159
... def test_sklearn_compatible_estimator(estimator, check):
167160
... check(estimator)
168161
"""
162+
import pytest
169163

170164
def checks_generator():
171165
for estimator in estimators:
@@ -185,24 +179,14 @@ def check_target_type(name, estimator_orig):
185179
X = np.random.random((20, 2))
186180
y = np.linspace(0, 1, 20)
187181
msg = "Unknown label type:"
188-
assert_raises_regex(
189-
ValueError,
190-
msg,
191-
estimator.fit_resample,
192-
X,
193-
y,
194-
)
182+
with raises(ValueError, err_msg=msg):
183+
estimator.fit_resample(X, y)
195184
# if the target is multilabel then we should raise an error
196185
rng = np.random.RandomState(42)
197186
y = rng.randint(2, size=(20, 3))
198187
msg = "Multilabel and multioutput targets are not supported."
199-
assert_raises_regex(
200-
ValueError,
201-
msg,
202-
estimator.fit_resample,
203-
X,
204-
y,
205-
)
188+
with raises(ValueError, err_msg=msg):
189+
estimator.fit_resample(X, y)
206190

207191

208192
def check_samplers_one_label(name, sampler_orig):
@@ -303,7 +287,12 @@ def check_samplers_sparse(name, sampler_orig):
303287

304288

305289
def check_samplers_pandas_sparse(name, sampler_orig):
306-
pd = pytest.importorskip("pandas")
290+
try:
291+
import pandas as pd
292+
except ImportError:
293+
raise SkipTest(
294+
"pandas is not installed: not checking column name consistency for pandas"
295+
)
307296
sampler = clone(sampler_orig)
308297
# Check that the samplers handle pandas dataframe and pandas series
309298
X, y = sample_dataset_generator()
@@ -331,7 +320,12 @@ def check_samplers_pandas_sparse(name, sampler_orig):
331320

332321

333322
def check_samplers_pandas(name, sampler_orig):
334-
pd = pytest.importorskip("pandas")
323+
try:
324+
import pandas as pd
325+
except ImportError:
326+
raise SkipTest(
327+
"pandas is not installed: not checking column name consistency for pandas"
328+
)
335329
sampler = clone(sampler_orig)
336330
# Check that the samplers handle pandas dataframe and pandas series
337331
X, y = sample_dataset_generator()
@@ -451,14 +445,19 @@ def check_classifier_on_multilabel_or_multioutput_targets(name, estimator_orig):
451445
estimator = clone(estimator_orig)
452446
X, y = make_multilabel_classification(n_samples=30)
453447
msg = "Multilabel and multioutput targets are not supported."
454-
with pytest.raises(ValueError, match=msg):
448+
with raises(ValueError, match=msg):
455449
estimator.fit(X, y)
456450

457451

458452
def check_classifiers_with_encoded_labels(name, classifier_orig):
459453
# Non-regression test for #709
460454
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/709
461-
pd = pytest.importorskip("pandas")
455+
try:
456+
import pandas as pd
457+
except ImportError:
458+
raise SkipTest(
459+
"pandas is not installed: not checking column name consistency for pandas"
460+
)
462461
classifier = clone(classifier_orig)
463462
iris = load_iris(as_frame=True)
464463
df, y = iris.data, iris.target

imblearn/utils/tests/test_estimator_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def test_check_samplers_nan():
9898

9999

100100
mapping_estimator_error = {
101-
"BaseBadSampler": (AssertionError, "ValueError not raised by fit"),
101+
"BaseBadSampler": (AssertionError, None),
102102
"SamplerSingleClass": (AssertionError, "Sampler can't balance when only"),
103103
"NotFittedSampler": (AssertionError, "No fitted attribute"),
104104
"NoAcceptingSparseSampler": (TypeError, "dense data is required"),

0 commit comments

Comments
 (0)