1010
1111import  pytest 
1212from  sklearn .datasets  import  make_classification 
13- from  sklearn .linear_model  import  LogisticRegression 
1413from  sklearn .utils ._testing  import  (
1514    _get_func_name ,
1615    check_docstring_parameters ,
2423
2524import  imblearn 
2625from  imblearn .base  import  is_sampler 
27- from  imblearn .utils ._sklearn_compat  import  _construct_instances 
26+ from  imblearn .under_sampling  import  NearMiss 
27+ from  imblearn .utils ._test_common .instance_generator  import  _tested_estimators 
2828from  imblearn .utils .estimator_checks  import  _set_checking_parameters 
29- from  imblearn .utils .testing  import  all_estimators 
3029
3130# walk_packages() ignores DeprecationWarnings, now we need to ignore 
3231# FutureWarnings 
4342    )
4443
4544# functions to ignore args / docstring of 
46- _DOCSTRING_IGNORES  =  [
47-      "RUSBoostClassifier" ,   # TODO remove after releasing scikit-learn 1.0.1 
48-     "ValueDifferenceMetric" ,
49- ] 
45+ _DOCSTRING_IGNORES  =  ["ValueDifferenceMetric" ] 
46+ _IGNORE_ATTRIBUTES   =  { 
47+     NearMiss : [ "nn_ver3_" ] ,
48+ } 
5049
5150# Methods where y param should be ignored if y=None by default 
5251_METHODS_IGNORE_NONE_Y  =  [
@@ -159,28 +158,19 @@ def test_tabs():
159158        )
160159
161160
162- def  _construct_compose_pipeline_instance (Estimator ):
163-     # Minimal / degenerate instances: only useful to test the docstrings. 
164-     if  Estimator .__name__  ==  "Pipeline" :
165-         return  Estimator (steps = [("clf" , LogisticRegression ())])
166- 
167- 
168- @pytest .mark .parametrize ("name, Estimator" , all_estimators ()) 
169- def  test_fit_docstring_attributes (name , Estimator ):
161+ @pytest .mark .parametrize ("estimator" , list (_tested_estimators ())) 
162+ def  test_fit_docstring_attributes (estimator ):
170163    pytest .importorskip ("numpydoc" )
171164    from  numpydoc  import  docscrape 
172165
166+     Estimator  =  estimator .__class__ 
173167    if  Estimator .__name__  in  _DOCSTRING_IGNORES :
174168        return 
175169
176170    doc  =  docscrape .ClassDoc (Estimator )
177171    attributes  =  doc ["Attributes" ]
178172
179-     if  Estimator .__name__  ==  "Pipeline" :
180-         est  =  _construct_compose_pipeline_instance (Estimator )
181-     else :
182-         est  =  next (_construct_instances (Estimator ))
183-     _set_checking_parameters (est )
173+     _set_checking_parameters (estimator )
184174
185175    X , y  =  make_classification (
186176        n_samples = 20 ,
@@ -190,16 +180,16 @@ def test_fit_docstring_attributes(name, Estimator):
190180        random_state = 2 ,
191181    )
192182
193-     y  =  _enforce_estimator_tags_y (est , y )
194-     X  =  _enforce_estimator_tags_X (est , X )
183+     y  =  _enforce_estimator_tags_y (estimator , y )
184+     X  =  _enforce_estimator_tags_X (estimator , X )
195185
196-     if  "oob_score"  in  est .get_params ():
197-         est .set_params (bootstrap = True , oob_score = True )
186+     if  "oob_score"  in  estimator .get_params ():
187+         estimator .set_params (bootstrap = True , oob_score = True )
198188
199-     if  is_sampler (est ):
200-         est .fit_resample (X , y )
189+     if  is_sampler (estimator ):
190+         estimator .fit_resample (X , y )
201191    else :
202-         est .fit (X , y )
192+         estimator .fit (X , y )
203193
204194    skipped_attributes  =  set (
205195        [
@@ -218,9 +208,11 @@ def test_fit_docstring_attributes(name, Estimator):
218208            continue 
219209        # ignore deprecation warnings 
220210        with  ignore_warnings (category = FutureWarning ):
221-             assert  hasattr (est , attr .name )
211+             if  attr .name  in  _IGNORE_ATTRIBUTES .get (Estimator , []):
212+                 continue 
213+             assert  hasattr (estimator , attr .name )
222214
223-     fit_attr  =  _get_all_fitted_attributes (est )
215+     fit_attr  =  _get_all_fitted_attributes (estimator )
224216    fit_attr_names  =  [attr .name  for  attr  in  attributes ]
225217    undocumented_attrs  =  set (fit_attr ).difference (fit_attr_names )
226218    undocumented_attrs  =  set (undocumented_attrs ).difference (skipped_attributes )
0 commit comments