1212from functools import partial
1313
1414import numpy as np
15- import pytest
1615import sklearn
1716from scipy import sparse
1817from sklearn .base import clone , is_classifier , is_regressor
3029 SkipTest ,
3130 assert_allclose ,
3231 assert_array_equal ,
33- assert_raises_regex ,
3432 raises ,
3533 set_random_state ,
3634)
@@ -62,11 +60,6 @@ def sample_dataset_generator():
6260 return X , y
6361
6462
65- @pytest .fixture (name = "sample_dataset_generator" )
66- def sample_dataset_generator_fixture ():
67- return sample_dataset_generator ()
68-
69-
7063def _set_checking_parameters (estimator ):
7164 params = estimator .get_params ()
7265 name = estimator .__class__ .__name__
@@ -166,6 +159,7 @@ def parametrize_with_checks(estimators):
166159 ... def test_sklearn_compatible_estimator(estimator, check):
167160 ... check(estimator)
168161 """
162+ import pytest
169163
170164 def checks_generator ():
171165 for estimator in estimators :
@@ -185,24 +179,14 @@ def check_target_type(name, estimator_orig):
185179 X = np .random .random ((20 , 2 ))
186180 y = np .linspace (0 , 1 , 20 )
187181 msg = "Unknown label type:"
188- assert_raises_regex (
189- ValueError ,
190- msg ,
191- estimator .fit_resample ,
192- X ,
193- y ,
194- )
182+ with raises (ValueError , err_msg = msg ):
183+ estimator .fit_resample (X , y )
195184 # if the target is multilabel then we should raise an error
196185 rng = np .random .RandomState (42 )
197186 y = rng .randint (2 , size = (20 , 3 ))
198187 msg = "Multilabel and multioutput targets are not supported."
199- assert_raises_regex (
200- ValueError ,
201- msg ,
202- estimator .fit_resample ,
203- X ,
204- y ,
205- )
188+ with raises (ValueError , err_msg = msg ):
189+ estimator .fit_resample (X , y )
206190
207191
208192def check_samplers_one_label (name , sampler_orig ):
@@ -303,7 +287,12 @@ def check_samplers_sparse(name, sampler_orig):
303287
304288
305289def check_samplers_pandas_sparse (name , sampler_orig ):
306- pd = pytest .importorskip ("pandas" )
290+ try :
291+ import pandas as pd
292+ except ImportError :
293+ raise SkipTest (
294+ "pandas is not installed: not checking column name consistency for pandas"
295+ )
307296 sampler = clone (sampler_orig )
308297 # Check that the samplers handle pandas dataframe and pandas series
309298 X , y = sample_dataset_generator ()
@@ -331,7 +320,12 @@ def check_samplers_pandas_sparse(name, sampler_orig):
331320
332321
333322def check_samplers_pandas (name , sampler_orig ):
334- pd = pytest .importorskip ("pandas" )
323+ try :
324+ import pandas as pd
325+ except ImportError :
326+ raise SkipTest (
327+ "pandas is not installed: not checking column name consistency for pandas"
328+ )
335329 sampler = clone (sampler_orig )
336330 # Check that the samplers handle pandas dataframe and pandas series
337331 X , y = sample_dataset_generator ()
@@ -451,14 +445,19 @@ def check_classifier_on_multilabel_or_multioutput_targets(name, estimator_orig):
451445 estimator = clone (estimator_orig )
452446 X , y = make_multilabel_classification (n_samples = 30 )
453447 msg = "Multilabel and multioutput targets are not supported."
454- with pytest . raises (ValueError , match = msg ):
448+ with raises (ValueError , match = msg ):
455449 estimator .fit (X , y )
456450
457451
458452def check_classifiers_with_encoded_labels (name , classifier_orig ):
459453 # Non-regression test for #709
460454 # https://github.com/scikit-learn-contrib/imbalanced-learn/issues/709
461- pd = pytest .importorskip ("pandas" )
455+ try :
456+ import pandas as pd
457+ except ImportError :
458+ raise SkipTest (
459+ "pandas is not installed: not checking column name consistency for pandas"
460+ )
462461 classifier = clone (classifier_orig )
463462 iris = load_iris (as_frame = True )
464463 df , y = iris .data , iris .target
0 commit comments