Skip to content

Commit 1629b06

Browse files
committed
iter
1 parent ef735f4 commit 1629b06

File tree

4 files changed

+61
-22
lines changed

4 files changed

+61
-22
lines changed

imblearn/ensemble/_easy_ensemble.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,12 +346,19 @@ def base_estimator_(self):
346346

347347
def _get_estimator(self):
348348
if self.estimator is None:
349-
return AdaBoostClassifier(algorithm="SAMME")
349+
if parse_version("1.4") <= sklearn_version < parse_version("1.6"):
350+
return AdaBoostClassifier(algorithm="SAMME")
351+
else:
352+
return AdaBoostClassifier()
350353
return self.estimator
351354

352355
# TODO: remove when minimum supported version of scikit-learn is 1.6
353356
@available_if(check_version_package("sklearn", "<", "1.6"))
354357
def _more_tags(self):
355-
# This code should not be called for scikit-learn >= 1.6
356-
# Therefore, get_tags corresponds to _safe_tags that returns a dict
357-
return {"allow_nan": get_tags(self._get_estimator(), "allow_nan")}
358+
return {"allow_nan": get_tags(self._get_estimator())["allow_nan"]}
359+
360+
@available_if(check_version_package("sklearn", ">=", "1.6"))
361+
def __sklearn_tags__(self):
362+
tags = super().__sklearn_tags__()
363+
tags.input_tags.allow_nan = get_tags(self._get_estimator()).input_tags.allow_nan
364+
return tags

imblearn/ensemble/_forest.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numbers
77
from copy import deepcopy
8+
from dataclasses import is_dataclass
89
from warnings import warn
910

1011
import numpy as np
@@ -36,7 +37,7 @@
3637
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
3738
from ..utils._param_validation import Hidden, Interval, StrOptions
3839
from ..utils._validation import check_sampling_strategy
39-
from ..utils.fixes import _fit_context, check_version_package, validate_data
40+
from ..utils.fixes import _fit_context, check_version_package, get_tags, validate_data
4041
from ._common import _random_forest_classifier_parameter_constraints
4142

4243
MAX_INT = np.iinfo(np.int32).max
@@ -78,7 +79,7 @@ def _local_parallel_build_trees(
7879
"bootstrap": bootstrap,
7980
}
8081

81-
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
82+
if sklearn_version >= parse_version("1.4"):
8283
# TODO: remove when the minimum supported version of scikit-learn is 1.4
8384
# support for missing values
8485
params_parallel_build_trees["missing_values_in_feature_mask"] = (
@@ -475,7 +476,7 @@ def __init__(
475476
"max_samples": max_samples,
476477
}
477478
# TODO: remove when the minimum supported version of scikit-learn is 1.4
478-
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
479+
if sklearn_version >= parse_version("1.4"):
479480
# use scikit-learn support for monotonic constraints
480481
params_random_forest["monotonic_cst"] = monotonic_cst
481482
else:
@@ -595,12 +596,12 @@ def fit(self, X, y, sample_weight=None):
595596
if issparse(y):
596597
raise ValueError("sparse multilabel-indicator for y is not supported.")
597598

598-
# TODO: remove when the minimum supported version of scipy will be 1.4
599-
# Support for missing values
600-
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
601-
ensure_all_finite = False
599+
# TODO (1.6): simplify because we will only have dataclass tags
600+
tags = get_tags(self)
601+
if is_dataclass(tags):
602+
ensure_all_finite = not tags.input_tags.allow_nan
602603
else:
603-
ensure_all_finite = False
604+
ensure_all_finite = not tags.get("allow_nan", False)
604605

605606
X, y = validate_data(
606607
self,
@@ -884,4 +885,13 @@ def _compute_oob_predictions(self, X, y):
884885

885886
@available_if(check_version_package("sklearn", "<", "1.6"))
886887
def _more_tags(self):
887-
return {"multioutput": False, "multilabel": False}
888+
allow_nan = sklearn_version >= parse_version("1.4")
889+
return {"multioutput": False, "multilabel": False, "allow_nan": allow_nan}
890+
891+
@available_if(check_version_package("sklearn", ">=", "1.6"))
892+
def __sklearn_tags__(self):
893+
tags = super().__sklearn_tags__()
894+
tags.target_tags.multi_output = False
895+
tags.classifier_tags.multi_label = False
896+
tags.input_tags.allow_nan = sklearn_version >= parse_version("1.4")
897+
return tags

imblearn/ensemble/_weight_boosting.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from sklearn.tree import DecisionTreeClassifier
1111
from sklearn.utils import _safe_indexing
1212
from sklearn.utils.fixes import parse_version
13+
from sklearn.utils.metaestimators import available_if
1314
from sklearn.utils.validation import has_fit_parameter
1415

1516
from ..base import _ParamsValidationMixin
@@ -18,8 +19,8 @@
1819
from ..under_sampling.base import BaseUnderSampler
1920
from ..utils import Substitution, check_target_type
2021
from ..utils._docstring import _random_state_docstring
21-
from ..utils._param_validation import Interval, StrOptions
22-
from ..utils.fixes import _fit_context
22+
from ..utils._param_validation import Hidden, Interval, StrOptions
23+
from ..utils.fixes import _fit_context, check_version_package
2324
from ._common import _adaboost_classifier_parameter_constraints
2425

2526
sklearn_version = parse_version(sklearn.__version__)
@@ -58,16 +59,18 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
5859
``learning_rate``. There is a trade-off between ``learning_rate`` and
5960
``n_estimators``.
6061
61-
algorithm : {{'SAMME', 'SAMME.R'}}, default='SAMME.R'
62+
algorithm : {{'SAMME', 'SAMME.R'}}, default='deprecated'
6263
If 'SAMME.R' then use the SAMME.R real boosting algorithm.
6364
``base_estimator`` must support calculation of class probabilities.
6465
If 'SAMME' then use the SAMME discrete boosting algorithm.
6566
The SAMME.R algorithm typically converges faster than SAMME,
6667
achieving a lower test error with fewer boosting iterations.
6768
6869
.. deprecated:: 0.12
69-
`"SAMME.R"` is deprecated and will be removed in version 0.14.
70-
'"SAMME"' will become the default.
70+
`algorithm` is deprecated in 0.12 and will be removed in 0.14.
71+
Depending on the `scikit-learn` version, the "SAMME.R" algorithm might not
72+
be available. Refer to the documentation of
73+
:class:`~sklearn.ensemble.AdaBoostClassifier` for more information.
7174
7275
{sampling_strategy}
7376
@@ -109,7 +112,7 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
109112
ensemble.
110113
111114
feature_importances_ : ndarray of shape (n_features,)
112-
The feature importances if supported by the ``base_estimator``.
115+
The feature importances if supported by the ``estimator``.
113116
114117
n_features_in_ : int
115118
Number of features in the input dataset.
@@ -167,6 +170,10 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
167170

168171
_parameter_constraints.update(
169172
{
173+
"algorithm": [
174+
StrOptions({"SAMME", "SAMME.R"}),
175+
Hidden(StrOptions({"deprecated"})),
176+
],
170177
"sampling_strategy": [
171178
Interval(numbers.Real, 0, 1, closed="right"),
172179
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
@@ -186,17 +193,17 @@ def __init__(
186193
*,
187194
n_estimators=50,
188195
learning_rate=1.0,
189-
algorithm="SAMME.R",
196+
algorithm="deprecated",
190197
sampling_strategy="auto",
191198
replacement=False,
192199
random_state=None,
193200
):
194201
super().__init__(
195202
n_estimators=n_estimators,
196203
learning_rate=learning_rate,
197-
algorithm=algorithm,
198204
random_state=random_state,
199205
)
206+
self.algorithm = algorithm
200207
self.estimator = estimator
201208
self.sampling_strategy = sampling_strategy
202209
self.replacement = replacement
@@ -394,3 +401,7 @@ def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
394401
sample_weight *= np.exp(estimator_weight * incorrect * (sample_weight > 0))
395402

396403
return sample_weight, estimator_weight, estimator_error
404+
405+
@available_if(check_version_package("sklearn", ">=", "1.6"))
406+
def _boost(self, iboost, X, y, sample_weight, random_state):
407+
return self._boost_discrete(iboost, X, y, sample_weight, random_state)

imblearn/utils/_test_common/instance_generator.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,16 @@
1010

1111
from sklearn import clone, config_context
1212
from sklearn.linear_model import LogisticRegression
13+
from sklearn.tree import DecisionTreeClassifier
1314
from sklearn.exceptions import SkipTestWarning
1415
from sklearn.utils._testing import SkipTest
1516

1617
from imblearn.combine import SMOTEENN, SMOTETomek
17-
from imblearn.ensemble import BalancedBaggingClassifier, BalancedRandomForestClassifier
18+
from imblearn.ensemble import (
19+
BalancedBaggingClassifier,
20+
BalancedRandomForestClassifier,
21+
EasyEnsembleClassifier,
22+
)
1823
from imblearn.over_sampling import (
1924
ADASYN,
2025
BorderlineSMOTE,
@@ -42,6 +47,12 @@
4247
# estimator
4348
BalancedBaggingClassifier: dict(random_state=42),
4449
BalancedRandomForestClassifier: dict(random_state=42),
50+
EasyEnsembleClassifier: [
51+
# AdaBoostClassifier does not allow nan values
52+
dict(random_state=42),
53+
# DecisionTreeClassifier allows nan values
54+
dict(estimator=DecisionTreeClassifier(random_state=42), random_state=42),
55+
],
4556
Pipeline: dict(
4657
steps=[("sampler", RandomUnderSampler()), ("logistic", LogisticRegression())]
4758
),

0 commit comments

Comments
 (0)