Skip to content

Commit c7a1838

Browse files
authored
MAINT compatibility sklearn 1.4 (#1058)
* MAINT compatibility sklearn 1.4 * iter * fix * doc * compat numpydoc * update changelog * fix
1 parent 0a659af commit c7a1838

17 files changed

+166
-352
lines changed

azure-pipelines.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,10 @@ jobs:
115115
ne(variables['Build.Reason'], 'Schedule')
116116
)
117117
matrix:
118-
py38_conda_forge_openblas_ubuntu_1804:
118+
py39_conda_forge_openblas_ubuntu_1804:
119119
DISTRIB: 'conda'
120120
CONDA_CHANNEL: 'conda-forge'
121-
PYTHON_VERSION: '3.8'
121+
PYTHON_VERSION: '3.9'
122122
BLAS: 'openblas'
123123
COVERAGE: 'false'
124124

@@ -188,7 +188,7 @@ jobs:
188188
pylatest_conda_tensorflow:
189189
DISTRIB: 'conda-latest-tensorflow'
190190
CONDA_CHANNEL: 'conda-forge'
191-
PYTHON_VERSION: '3.8'
191+
PYTHON_VERSION: '3.9'
192192
TEST_DOCS: 'true'
193193
TEST_DOCSTRINGS: 'true'
194194
CHECK_WARNINGS: 'true'
@@ -214,7 +214,7 @@ jobs:
214214
pylatest_conda_keras:
215215
DISTRIB: 'conda-latest-keras'
216216
CONDA_CHANNEL: 'conda-forge'
217-
PYTHON_VERSION: '3.8'
217+
PYTHON_VERSION: '3.9'
218218
TEST_DOCS: 'true'
219219
TEST_DOCSTRINGS: 'true'
220220
CHECK_WARNINGS: 'true'
@@ -301,7 +301,7 @@ jobs:
301301
py38_conda_forge_mkl:
302302
DISTRIB: 'conda'
303303
CONDA_CHANNEL: 'conda-forge'
304-
PYTHON_VERSION: '3.8'
304+
PYTHON_VERSION: '3.10'
305305
CHECK_WARNINGS: 'true'
306306
PYTHON_ARCH: '64'
307307
PYTEST_VERSION: '*'

doc/ensemble.rst

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ data set, this classifier will favor the majority classes::
3333
>>> from sklearn.ensemble import BaggingClassifier
3434
>>> from sklearn.tree import DecisionTreeClassifier
3535
>>> X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
36-
>>> bc = BaggingClassifier(base_estimator=DecisionTreeClassifier(),
37-
... random_state=0)
36+
>>> bc = BaggingClassifier(DecisionTreeClassifier(), random_state=0)
3837
>>> bc.fit(X_train, y_train) #doctest:
3938
BaggingClassifier(...)
4039
>>> y_pred = bc.predict(X_test)
@@ -50,7 +49,7 @@ sampling is controlled by the parameter `sampler` or the two parameters
5049
:class:`~imblearn.under_sampling.RandomUnderSampler`::
5150

5251
>>> from imblearn.ensemble import BalancedBaggingClassifier
53-
>>> bbc = BalancedBaggingClassifier(base_estimator=DecisionTreeClassifier(),
52+
>>> bbc = BalancedBaggingClassifier(DecisionTreeClassifier(),
5453
... sampling_strategy='auto',
5554
... replacement=False,
5655
... random_state=0)

doc/whats_new/v0.12.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,13 @@ Compatibility
2323

2424
- :class:`~imblearn.ensemble.BalancedRandomForestClassifier` now supports missing values
2525
and monotonic constraints if scikit-learn >= 1.4 is installed.
26+
2627
- :class:`~imblearn.pipeline.Pipeline` supports metadata routing if scikit-learn >= 1.4
2728
is installed.
2829

30+
- Compatibility with scikit-learn 1.4.
31+
:pr:`1058` by :user:`Guillaume Lemaitre <glemaitre>`.
32+
2933
Deprecations
3034
............
3135

imblearn/ensemble/_bagging.py

Lines changed: 26 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# License: MIT
66

77
import copy
8-
import inspect
98
import numbers
109
import warnings
1110

@@ -15,6 +14,7 @@
1514
from sklearn.ensemble import BaggingClassifier
1615
from sklearn.ensemble._bagging import _parallel_decision_function
1716
from sklearn.ensemble._base import _partition_estimators
17+
from sklearn.exceptions import NotFittedError
1818
from sklearn.tree import DecisionTreeClassifier
1919
from sklearn.utils import parse_version
2020
from sklearn.utils.validation import check_is_fitted
@@ -121,30 +121,13 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
121121
122122
.. versionadded:: 0.8
123123
124-
base_estimator : estimator object, default=None
125-
The base estimator to fit on random subsets of the dataset.
126-
If None, then the base estimator is a decision tree.
127-
128-
.. deprecated:: 0.10
129-
`base_estimator` was renamed to `estimator` in version 0.10 and
130-
will be removed in 0.12.
131-
132124
Attributes
133125
----------
134126
estimator_ : estimator
135127
The base estimator from which the ensemble is grown.
136128
137129
.. versionadded:: 0.10
138130
139-
base_estimator_ : estimator
140-
The base estimator from which the ensemble is grown.
141-
142-
.. deprecated:: 1.2
143-
`base_estimator_` is deprecated in `scikit-learn` 1.2 and will be
144-
removed in 1.4. Use `estimator_` instead. When the minimum version
145-
of `scikit-learn` supported by `imbalanced-learn` will reach 1.4,
146-
this attribute will be removed.
147-
148131
n_features_ : int
149132
The number of features when `fit` is performed.
150133
@@ -266,7 +249,7 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
266249
"""
267250

268251
# make a deepcopy to not modify the original dictionary
269-
if sklearn_version >= parse_version("1.3"):
252+
if sklearn_version >= parse_version("1.4"):
270253
_parameter_constraints = copy.deepcopy(BaggingClassifier._parameter_constraints)
271254
else:
272255
_parameter_constraints = copy.deepcopy(_bagging_parameter_constraints)
@@ -283,6 +266,9 @@ class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
283266
"sampler": [HasMethods(["fit_resample"]), None],
284267
}
285268
)
269+
# TODO: remove when minimum supported version of scikit-learn is 1.4
270+
if "base_estimator" in _parameter_constraints:
271+
del _parameter_constraints["base_estimator"]
286272

287273
def __init__(
288274
self,
@@ -301,18 +287,8 @@ def __init__(
301287
random_state=None,
302288
verbose=0,
303289
sampler=None,
304-
base_estimator="deprecated",
305290
):
306-
# TODO: remove when supporting scikit-learn>=1.2
307-
bagging_classifier_signature = inspect.signature(super().__init__)
308-
estimator_params = {"base_estimator": base_estimator}
309-
if "estimator" in bagging_classifier_signature.parameters:
310-
estimator_params["estimator"] = estimator
311-
else:
312-
self.estimator = estimator
313-
314291
super().__init__(
315-
**estimator_params,
316292
n_estimators=n_estimators,
317293
max_samples=max_samples,
318294
max_features=max_features,
@@ -324,6 +300,7 @@ def __init__(
324300
random_state=random_state,
325301
verbose=verbose,
326302
)
303+
self.estimator = estimator
327304
self.sampling_strategy = sampling_strategy
328305
self.replacement = replacement
329306
self.sampler = sampler
@@ -349,42 +326,17 @@ def _validate_y(self, y):
349326
def _validate_estimator(self, default=DecisionTreeClassifier()):
350327
"""Check the estimator and the n_estimator attribute, set the
351328
`estimator_` attribute."""
352-
if self.estimator is not None and (
353-
self.base_estimator not in [None, "deprecated"]
354-
):
355-
raise ValueError(
356-
"Both `estimator` and `base_estimator` were set. Only set `estimator`."
357-
)
358-
359329
if self.estimator is not None:
360-
base_estimator = clone(self.estimator)
361-
elif self.base_estimator not in [None, "deprecated"]:
362-
warnings.warn(
363-
"`base_estimator` was renamed to `estimator` in version 0.10 and "
364-
"will be removed in 0.12.",
365-
FutureWarning,
366-
)
367-
base_estimator = clone(self.base_estimator)
330+
estimator = clone(self.estimator)
368331
else:
369-
base_estimator = clone(default)
332+
estimator = clone(default)
370333

371334
if self.sampler_._sampling_type != "bypass":
372335
self.sampler_.set_params(sampling_strategy=self._sampling_strategy)
373336

374-
self._estimator = Pipeline(
375-
[("sampler", self.sampler_), ("classifier", base_estimator)]
337+
self.estimator_ = Pipeline(
338+
[("sampler", self.sampler_), ("classifier", estimator)]
376339
)
377-
try:
378-
# scikit-learn < 1.2
379-
self.base_estimator_ = self._estimator
380-
except AttributeError:
381-
pass
382-
383-
# TODO: remove when supporting scikit-learn>=1.4
384-
@property
385-
def estimator_(self):
386-
"""Estimator used to grow the ensemble."""
387-
return self._estimator
388340

389341
# TODO: remove when supporting scikit-learn>=1.2
390342
@property
@@ -483,6 +435,22 @@ def decision_function(self, X):
483435

484436
return decisions
485437

438+
@property
439+
def base_estimator_(self):
440+
"""Attribute for older sklearn version compatibility."""
441+
error = AttributeError(
442+
f"{self.__class__.__name__} object has no attribute 'base_estimator_'."
443+
)
444+
if sklearn_version < parse_version("1.2"):
445+
# The base class requires this attribute to be defined. For scikit-learn
446+
# >= 1.2, we are going to raise an error.
447+
try:
448+
check_is_fitted(self)
449+
return self.estimator_
450+
except NotFittedError:
451+
raise error
452+
raise error
453+
486454
def _more_tags(self):
487455
tags = super()._more_tags()
488456
tags_key = "_xfail_checks"

0 commit comments

Comments
 (0)