Skip to content

Commit 29f708c

Browse files
committed
iter
1 parent e384fa2 commit 29f708c

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

44 files changed

+1933
-3625
lines changed

imblearn/base.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,30 +12,11 @@
1212
from sklearn.utils.multiclass import check_classification_targets
1313

1414
from .utils import check_sampling_strategy, check_target_type
15-
from .utils._param_validation import validate_parameter_constraints
15+
from .utils._sklearn_compat import _fit_context, validate_data
1616
from .utils._validation import ArraysTransformer
1717

1818

19-
class _ParamsValidationMixin:
20-
"""Mixin class to validate parameters."""
21-
22-
def _validate_params(self):
23-
"""Validate types and values of constructor parameters.
24-
25-
The expected type and values must be defined in the `_parameter_constraints`
26-
class attribute, which is a dictionary `param_name: list of constraints`. See
27-
the docstring of `validate_parameter_constraints` for a description of the
28-
accepted constraints.
29-
"""
30-
if hasattr(self, "_parameter_constraints"):
31-
validate_parameter_constraints(
32-
self._parameter_constraints,
33-
self.get_params(deep=False),
34-
caller_name=self.__class__.__name__,
35-
)
36-
37-
38-
class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
19+
class SamplerMixin(metaclass=ABCMeta):
3920
"""Mixin class for samplers with abstract method.
4021
4122
Warning: This class should not be used directly. Use the derive classes
@@ -44,6 +25,7 @@ class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
4425

4526
_estimator_type = "sampler"
4627

28+
@_fit_context(prefer_skip_nested_validation=True)
4729
def fit(self, X, y):
4830
"""Check inputs and statistics of the sampler.
4931
@@ -133,7 +115,7 @@ def _fit_resample(self, X, y):
133115
pass
134116

135117

136-
class BaseSampler(SamplerMixin, OneToOneFeatureMixin):
118+
class BaseSampler(SamplerMixin, OneToOneFeatureMixin, BaseEstimator):
137119
"""Base class for sampling algorithms.
138120
139121
Warning: This class should not be used directly. Use the derive classes
@@ -147,7 +129,7 @@ def _check_X_y(self, X, y, accept_sparse=None):
147129
if accept_sparse is None:
148130
accept_sparse = ["csr", "csc"]
149131
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
150-
X, y = self._validate_data(X, y, reset=True, accept_sparse=accept_sparse)
132+
X, y = validate_data(self, X=X, y=y, reset=True, accept_sparse=accept_sparse)
151133
return X, y, binarize_y
152134

153135
def fit(self, X, y):
@@ -199,6 +181,24 @@ def fit_resample(self, X, y):
199181
def _more_tags(self):
200182
return {"X_types": ["2darray", "sparse", "dataframe"]}
201183

184+
def __sklearn_tags__(self):
185+
from .utils._sklearn_compat import TargetTags
186+
from .utils._tags import Tags, SamplerTags, InputTags
187+
188+
tags = Tags(
189+
estimator_type="sampler",
190+
target_tags=TargetTags(required=True),
191+
transformer_tags=None,
192+
regressor_tags=None,
193+
classifier_tags=None,
194+
sampler_tags=SamplerTags(),
195+
)
196+
tags.input_tags = InputTags()
197+
tags.input_tags.two_d_array = True
198+
tags.input_tags.sparse = True
199+
tags.input_tags.dataframe = True
200+
return tags
201+
202202

203203
def _identity(X, y):
204204
return X, y

imblearn/datasets/_imbalance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from ..under_sampling import RandomUnderSampler
1212
from ..utils import check_sampling_strategy
13-
from ..utils._param_validation import validate_params
13+
from ..utils._sklearn_compat import validate_params
1414

1515

1616
@validate_params(

imblearn/datasets/_zenodo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
from sklearn.datasets import get_data_home
5656
from sklearn.utils import Bunch, check_random_state
5757

58-
from ..utils._param_validation import validate_params
58+
from ..utils._sklearn_compat import validate_params
5959

6060
URL = "https://zenodo.org/record/61452/files/benchmark-imbalanced-learn.tar.gz"
6161
PRE_FILENAME = "x"

imblearn/ensemble/_bagging.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@
1818
from sklearn.utils.metaestimators import available_if
1919
from sklearn.utils.parallel import Parallel, delayed
2020
from sklearn.utils.validation import check_is_fitted
21+
from sklearn.utils._param_validation import HasMethods, Interval, StrOptions
2122

22-
from ..base import _ParamsValidationMixin
2323
from ..pipeline import Pipeline
2424
from ..under_sampling import RandomUnderSampler
2525
from ..under_sampling.base import BaseUnderSampler
2626
from ..utils import Substitution, check_sampling_strategy, check_target_type
2727
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
28-
from ..utils._param_validation import HasMethods, Interval, StrOptions
29-
from ..utils.fixes import _fit_context
28+
from ..utils._sklearn_compat import _fit_context, validate_data
3029
from ._common import _bagging_parameter_constraints, _estimator_has
3130

3231
sklearn_version = parse_version(sklearn.__version__)
@@ -37,7 +36,7 @@
3736
n_jobs=_n_jobs_docstring,
3837
random_state=_random_state_docstring,
3938
)
40-
class BalancedBaggingClassifier(_ParamsValidationMixin, BaggingClassifier):
39+
class BalancedBaggingClassifier(BaggingClassifier):
4140
"""A Bagging classifier with additional balancing.
4241
4342
This implementation of Bagging is similar to the scikit-learn
@@ -382,11 +381,12 @@ def decision_function(self, X):
382381
check_is_fitted(self)
383382

384383
# Check data
385-
X = self._validate_data(
386-
X,
384+
X = validate_data(
385+
self,
386+
X=X,
387387
accept_sparse=["csr", "csc"],
388388
dtype=None,
389-
force_all_finite=False,
389+
ensure_all_finite=False,
390390
reset=False,
391391
)
392392

@@ -425,3 +425,7 @@ def _more_tags(self):
425425
else:
426426
tags[tags_key] = {failing_test: reason}
427427
return tags
428+
429+
def __sklearn_tags__(self):
430+
tags = super().__sklearn_tags__()
431+
return tags

imblearn/ensemble/_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from sklearn.tree._criterion import Criterion
44

5-
from ..utils._param_validation import (
5+
from sklearn.utils._param_validation import (
66
HasMethods,
77
Hidden,
88
Interval,

imblearn/ensemble/_easy_ensemble.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@
1919
from sklearn.utils.metaestimators import available_if
2020
from sklearn.utils.parallel import Parallel, delayed
2121
from sklearn.utils.validation import check_is_fitted
22+
from sklearn.utils._param_validation import Interval, StrOptions
2223

23-
from ..base import _ParamsValidationMixin
2424
from ..pipeline import Pipeline
2525
from ..under_sampling import RandomUnderSampler
2626
from ..under_sampling.base import BaseUnderSampler
2727
from ..utils import Substitution, check_sampling_strategy, check_target_type
2828
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
29-
from ..utils._param_validation import Interval, StrOptions
30-
from ..utils.fixes import _fit_context
29+
from ..utils._sklearn_compat import _fit_context, get_tags, validate_data
3130
from ._common import _bagging_parameter_constraints, _estimator_has
3231

3332
MAX_INT = np.iinfo(np.int32).max
@@ -39,7 +38,7 @@
3938
n_jobs=_n_jobs_docstring,
4039
random_state=_random_state_docstring,
4140
)
42-
class EasyEnsembleClassifier(_ParamsValidationMixin, BaggingClassifier):
41+
class EasyEnsembleClassifier(BaggingClassifier):
4342
"""Bag of balanced boosted learners also known as EasyEnsemble.
4443
4544
This algorithm is known as EasyEnsemble [1]_. The classifier is an
@@ -311,11 +310,12 @@ def decision_function(self, X):
311310
check_is_fitted(self)
312311

313312
# Check data
314-
X = self._validate_data(
315-
X,
313+
X = validate_data(
314+
self,
315+
X=X,
316316
accept_sparse=["csr", "csc"],
317317
dtype=None,
318-
force_all_finite=False,
318+
ensure_all_finite=False,
319319
reset=False,
320320
)
321321

@@ -346,9 +346,17 @@ def base_estimator_(self):
346346

347347
def _get_estimator(self):
348348
if self.estimator is None:
349-
return AdaBoostClassifier(algorithm="SAMME")
349+
if parse_version("1.4") <= sklearn_version < parse_version("1.6"):
350+
return AdaBoostClassifier(algorithm="SAMME")
351+
else:
352+
return AdaBoostClassifier()
350353
return self.estimator
351354

352355
# TODO: remove when minimum supported version of scikit-learn is 1.5
353356
def _more_tags(self):
354357
return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}
358+
359+
def __sklearn_tags__(self):
360+
tags = super().__sklearn_tags__()
361+
tags.input_tags.allow_nan = get_tags(self._get_estimator()).input_tags.allow_nan
362+
return tags

imblearn/ensemble/_forest.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,18 @@
2727
from sklearn.utils.multiclass import type_of_target
2828
from sklearn.utils.parallel import Parallel, delayed
2929
from sklearn.utils.validation import _check_sample_weight
30+
from sklearn.utils._param_validation import Hidden, Interval, StrOptions
3031

31-
from ..base import _ParamsValidationMixin
3232
from ..pipeline import make_pipeline
3333
from ..under_sampling import RandomUnderSampler
3434
from ..utils import Substitution
3535
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
36-
from ..utils._param_validation import Hidden, Interval, StrOptions
36+
from ..utils._sklearn_compat import _fit_context, validate_data
3737
from ..utils._validation import check_sampling_strategy
38-
from ..utils.fixes import _fit_context
3938
from ._common import _random_forest_classifier_parameter_constraints
4039

4140
MAX_INT = np.iinfo(np.int32).max
42-
sklearn_version = parse_version(sklearn.__version__)
41+
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
4342

4443

4544
def _local_parallel_build_trees(
@@ -77,7 +76,7 @@ def _local_parallel_build_trees(
7776
"bootstrap": bootstrap,
7877
}
7978

80-
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
79+
if sklearn_version >= parse_version("1.4"):
8180
# TODO: remove when the minimum supported version of scikit-learn will be 1.4
8281
# support for missing values
8382
params_parallel_build_trees["missing_values_in_feature_mask"] = (
@@ -93,7 +92,7 @@ def _local_parallel_build_trees(
9392
n_jobs=_n_jobs_docstring,
9493
random_state=_random_state_docstring,
9594
)
96-
class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassifier):
95+
class BalancedRandomForestClassifier(RandomForestClassifier):
9796
"""A balanced random forest classifier.
9897
9998
A balanced random forest differs from a classical random forest by the
@@ -474,7 +473,7 @@ def __init__(
474473
"max_samples": max_samples,
475474
}
476475
# TODO: remove when the minimum supported version of scikit-learn will be 1.4
477-
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
476+
if sklearn_version >= parse_version("1.4"):
478477
# use scikit-learn support for monotonic constraints
479478
params_random_forest["monotonic_cst"] = monotonic_cst
480479
else:
@@ -596,22 +595,23 @@ def fit(self, X, y, sample_weight=None):
596595

597596
# TODO: remove when the minimum supported version of scipy will be 1.4
598597
# Support for missing values
599-
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
600-
force_all_finite = False
598+
if sklearn_version >= parse_version("1.4"):
599+
ensure_all_finite = False
601600
else:
602-
force_all_finite = True
601+
ensure_all_finite = True
603602

604-
X, y = self._validate_data(
605-
X,
606-
y,
603+
X, y = validate_data(
604+
self,
605+
X=X,
606+
y=y,
607607
multi_output=True,
608608
accept_sparse="csc",
609609
dtype=DTYPE,
610-
force_all_finite=force_all_finite,
610+
ensure_all_finite=ensure_all_finite,
611611
)
612612

613613
# TODO: remove when the minimum supported version of scikit-learn will be 1.4
614-
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
614+
if sklearn_version >= parse_version("1.4"):
615615
# _compute_missing_values_in_feature_mask checks if X has missing values and
616616
# will raise an error if the underlying tree base estimator can't handle
617617
# missing values. Only the criterion is required to determine if the tree
@@ -882,3 +882,10 @@ def _compute_oob_predictions(self, X, y):
882882

883883
def _more_tags(self):
884884
return {"multioutput": False, "multilabel": False}
885+
886+
def __sklearn_tags__(self):
887+
tags = super().__sklearn_tags__()
888+
tags.target_tags.multi_output = False
889+
tags.classifier_tags.multi_label = False
890+
tags.input_tags.allow_nan = sklearn_version >= parse_version("1.4")
891+
return tags

imblearn/ensemble/_weight_boosting.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import copy
22
import numbers
3+
import warnings
34
from copy import deepcopy
45

56
import numpy as np
@@ -11,15 +12,14 @@
1112
from sklearn.utils import _safe_indexing
1213
from sklearn.utils.fixes import parse_version
1314
from sklearn.utils.validation import has_fit_parameter
15+
from sklearn.utils._param_validation import Hidden, Interval, StrOptions
1416

15-
from ..base import _ParamsValidationMixin
1617
from ..pipeline import make_pipeline
1718
from ..under_sampling import RandomUnderSampler
1819
from ..under_sampling.base import BaseUnderSampler
1920
from ..utils import Substitution, check_target_type
2021
from ..utils._docstring import _random_state_docstring
21-
from ..utils._param_validation import Interval, StrOptions
22-
from ..utils.fixes import _fit_context
22+
from ..utils._sklearn_compat import _fit_context
2323
from ._common import _adaboost_classifier_parameter_constraints
2424

2525
sklearn_version = parse_version(sklearn.__version__)
@@ -29,7 +29,7 @@
2929
sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
3030
random_state=_random_state_docstring,
3131
)
32-
class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
32+
class RUSBoostClassifier(AdaBoostClassifier):
3333
"""Random under-sampling integrated in the learning of AdaBoost.
3434
3535
During learning, the problem of class balancing is alleviated by random
@@ -167,6 +167,10 @@ class RUSBoostClassifier(_ParamsValidationMixin, AdaBoostClassifier):
167167

168168
_parameter_constraints.update(
169169
{
170+
"algorithm": [
171+
StrOptions({"SAMME", "SAMME.R"}),
172+
Hidden(StrOptions({"deprecated"})),
173+
],
170174
"sampling_strategy": [
171175
Interval(numbers.Real, 0, 1, closed="right"),
172176
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
@@ -186,17 +190,17 @@ def __init__(
186190
*,
187191
n_estimators=50,
188192
learning_rate=1.0,
189-
algorithm="SAMME.R",
193+
algorithm="deprecated",
190194
sampling_strategy="auto",
191195
replacement=False,
192196
random_state=None,
193197
):
194198
super().__init__(
195199
n_estimators=n_estimators,
196200
learning_rate=learning_rate,
197-
algorithm=algorithm,
198201
random_state=random_state,
199202
)
203+
self.algorithm = algorithm
200204
self.estimator = estimator
201205
self.sampling_strategy = sampling_strategy
202206
self.replacement = replacement
@@ -394,3 +398,16 @@ def _boost_discrete(self, iboost, X, y, sample_weight, random_state):
394398
sample_weight *= np.exp(estimator_weight * incorrect * (sample_weight > 0))
395399

396400
return sample_weight, estimator_weight, estimator_error
401+
402+
def _boost(self, iboost, X, y, sample_weight, random_state):
403+
if self.algorithm != "deprecated":
404+
warnings.warn(
405+
"`algorithm` parameter is deprecated in 0.12 and will be removed in "
406+
"0.14. In the future, the SAMME algorithm will always be used.",
407+
FutureWarning,
408+
)
409+
if self.algorithm == "SAMME.R":
410+
return self._boost_real(iboost, X, y, sample_weight, random_state)
411+
412+
else: # elif self.algorithm == "SAMME":
413+
return self._boost_discrete(iboost, X, y, sample_weight, random_state)

0 commit comments

Comments (0)