5
5
# License: MIT
6
6
7
7
import copy
8
- import inspect
9
8
import numbers
10
9
import warnings
11
10
15
14
from sklearn .ensemble import BaggingClassifier
16
15
from sklearn .ensemble ._bagging import _parallel_decision_function
17
16
from sklearn .ensemble ._base import _partition_estimators
17
+ from sklearn .exceptions import NotFittedError
18
18
from sklearn .tree import DecisionTreeClassifier
19
19
from sklearn .utils import parse_version
20
20
from sklearn .utils .validation import check_is_fitted
@@ -121,30 +121,13 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
121
121
122
122
.. versionadded:: 0.8
123
123
124
- base_estimator : estimator object, default=None
125
- The base estimator to fit on random subsets of the dataset.
126
- If None, then the base estimator is a decision tree.
127
-
128
- .. deprecated:: 0.10
129
- `base_estimator` was renamed to `estimator` in version 0.10 and
130
- will be removed in 0.12.
131
-
132
124
Attributes
133
125
----------
134
126
estimator_ : estimator
135
127
The base estimator from which the ensemble is grown.
136
128
137
129
.. versionadded:: 0.10
138
130
139
- base_estimator_ : estimator
140
- The base estimator from which the ensemble is grown.
141
-
142
- .. deprecated:: 1.2
143
- `base_estimator_` is deprecated in `scikit-learn` 1.2 and will be
144
- removed in 1.4. Use `estimator_` instead. When the minimum version
145
- of `scikit-learn` supported by `imbalanced-learn` will reach 1.4,
146
- this attribute will be removed.
147
-
148
131
n_features_ : int
149
132
The number of features when `fit` is performed.
150
133
@@ -266,7 +249,7 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
266
249
"""
267
250
268
251
# make a deepcopy to not modify the original dictionary
269
- if sklearn_version >= parse_version ("1.3 " ):
252
+ if sklearn_version >= parse_version ("1.4 " ):
270
253
_parameter_constraints = copy .deepcopy (BaggingClassifier ._parameter_constraints )
271
254
else :
272
255
_parameter_constraints = copy .deepcopy (_bagging_parameter_constraints )
@@ -283,6 +266,9 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
283
266
"sampler" : [HasMethods (["fit_resample" ]), None ],
284
267
}
285
268
)
269
+ # TODO: remove when minimum supported version of scikit-learn is 1.4
270
+ if "base_estimator" in _parameter_constraints :
271
+ del _parameter_constraints ["base_estimator" ]
286
272
287
273
def __init__ (
288
274
self ,
@@ -301,18 +287,8 @@ def __init__(
301
287
random_state = None ,
302
288
verbose = 0 ,
303
289
sampler = None ,
304
- base_estimator = "deprecated" ,
305
290
):
306
- # TODO: remove when supporting scikit-learn>=1.2
307
- bagging_classifier_signature = inspect .signature (super ().__init__ )
308
- estimator_params = {"base_estimator" : base_estimator }
309
- if "estimator" in bagging_classifier_signature .parameters :
310
- estimator_params ["estimator" ] = estimator
311
- else :
312
- self .estimator = estimator
313
-
314
291
super ().__init__ (
315
- ** estimator_params ,
316
292
n_estimators = n_estimators ,
317
293
max_samples = max_samples ,
318
294
max_features = max_features ,
@@ -324,6 +300,7 @@ def __init__(
324
300
random_state = random_state ,
325
301
verbose = verbose ,
326
302
)
303
+ self .estimator = estimator
327
304
self .sampling_strategy = sampling_strategy
328
305
self .replacement = replacement
329
306
self .sampler = sampler
@@ -349,42 +326,17 @@ def _validate_y(self, y):
349
326
def _validate_estimator (self , default = DecisionTreeClassifier ()):
350
327
"""Check the estimator and the n_estimator attribute, set the
351
328
`estimator_` attribute."""
352
- if self .estimator is not None and (
353
- self .base_estimator not in [None , "deprecated" ]
354
- ):
355
- raise ValueError (
356
- "Both `estimator` and `base_estimator` were set. Only set `estimator`."
357
- )
358
-
359
329
if self .estimator is not None :
360
- base_estimator = clone (self .estimator )
361
- elif self .base_estimator not in [None , "deprecated" ]:
362
- warnings .warn (
363
- "`base_estimator` was renamed to `estimator` in version 0.10 and "
364
- "will be removed in 0.12." ,
365
- FutureWarning ,
366
- )
367
- base_estimator = clone (self .base_estimator )
330
+ estimator = clone (self .estimator )
368
331
else :
369
- base_estimator = clone (default )
332
+ estimator = clone (default )
370
333
371
334
if self .sampler_ ._sampling_type != "bypass" :
372
335
self .sampler_ .set_params (sampling_strategy = self ._sampling_strategy )
373
336
374
- self ._estimator = Pipeline (
375
- [("sampler" , self .sampler_ ), ("classifier" , base_estimator )]
337
+ self .estimator_ = Pipeline (
338
+ [("sampler" , self .sampler_ ), ("classifier" , estimator )]
376
339
)
377
- try :
378
- # scikit-learn < 1.2
379
- self .base_estimator_ = self ._estimator
380
- except AttributeError :
381
- pass
382
-
383
- # TODO: remove when supporting scikit-learn>=1.4
384
- @property
385
- def estimator_ (self ):
386
- """Estimator used to grow the ensemble."""
387
- return self ._estimator
388
340
389
341
# TODO: remove when supporting scikit-learn>=1.2
390
342
@property
@@ -483,6 +435,22 @@ def decision_function(self, X):
483
435
484
436
return decisions
485
437
438
+ @property
439
+ def base_estimator_ (self ):
440
+ """Attribute for older sklearn version compatibility."""
441
+ error = AttributeError (
442
+ f"{ self .__class__ .__name__ } object has no attribute 'base_estimator_'."
443
+ )
444
+ if sklearn_version < parse_version ("1.2" ):
445
+ # The base class require to have the attribute defined. For scikit-learn
446
+ # > 1.2, we are going to raise an error.
447
+ try :
448
+ check_is_fitted (self )
449
+ return self .estimator_
450
+ except NotFittedError :
451
+ raise error
452
+ raise error
453
+
486
454
def _more_tags (self ):
487
455
tags = super ()._more_tags ()
488
456
tags_key = "_xfail_checks"
0 commit comments