Skip to content

Commit 1b8cd47

Browse files
authored
ENH accept string labels in classifier (#718)
1 parent f21bdea commit 1b8cd47

File tree

6 files changed

+82
-10
lines changed

6 files changed

+82
-10
lines changed

doc/whats_new/v0.7.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,14 @@ Bug fixes
2626
- Change the default value `min_samples_leaf` to be consistent with
2727
scikit-learn.
2828
:pr:`711` by :user:`zerolfx <zerolfx>`.
29+
30+
Enhancements
31+
............
32+
33+
- The classifiers implemented in imbalanced-learn,
34+
:class:`imblearn.ensemble.BalancedBaggingClassifier`,
35+
:class:`imblearn.ensemble.BalancedRandomForestClassifier`,
36+
:class:`imblearn.ensemble.EasyEnsembleClassifier`, and
37+
:class:`imblearn.ensemble.RUSBoostClassifier`, accept `sampling_strategy`
38+
with the same keys as in `y`, without the need to encode `y` in advance.
39+
:pr:`718` by :user:`Guillaume Lemaitre <glemaitre>`.

imblearn/ensemble/_bagging.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from ..pipeline import Pipeline
1616
from ..under_sampling import RandomUnderSampler
1717
from ..under_sampling.base import BaseUnderSampler
18-
from ..utils import Substitution, check_target_type
18+
from ..utils import Substitution, check_target_type, check_sampling_strategy
1919
from ..utils._docstring import _n_jobs_docstring
2020
from ..utils._docstring import _random_state_docstring
2121

@@ -208,6 +208,19 @@ def __init__(
208208
self.sampling_strategy = sampling_strategy
209209
self.replacement = replacement
210210

211+
def _validate_y(self, y):
    """Validate ``y`` and prepare the sampling strategy given to the sampler.

    The parent implementation is called first and its return value is
    handed back unchanged. When ``sampling_strategy`` is a dict, it is
    checked against the original ``y`` with ``check_sampling_strategy``
    and each key is replaced by the position of that label in
    ``self.classes_``. Any other kind of ``sampling_strategy`` is stored
    as-is. The result is kept in ``self._sampling_strategy``.

    Parameters
    ----------
    y : array-like
        Target passed to ``fit``, with its original (non-encoded) labels.

    Returns
    -------
    y_encoded
        Whatever the parent ``_validate_y`` returns for ``y``.
    """
    # NOTE(review): presumably the parent call is what populates
    # `self.classes_`, which is read below -- confirm against the base class.
    encoded_target = super()._validate_y(y)

    strategy = self.sampling_strategy
    if not isinstance(strategy, dict):
        self._sampling_strategy = strategy
        return encoded_target

    checked = check_sampling_strategy(strategy, y, 'under-sampling')
    remapped = {}
    for label, n_samples in checked.items():
        # Index of the first entry of `classes_` equal to this label.
        position = np.where(self.classes_ == label)[0][0]
        remapped[position] = n_samples
    self._sampling_strategy = remapped
    return encoded_target
223+
211224
def _validate_estimator(self, default=DecisionTreeClassifier()):
212225
"""Check the estimator and the n_estimator attribute, set the
213226
`base_estimator_` attribute."""
@@ -233,7 +246,7 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
233246
(
234247
"sampler",
235248
RandomUnderSampler(
236-
sampling_strategy=self.sampling_strategy,
249+
sampling_strategy=self._sampling_strategy,
237250
replacement=self.replacement,
238251
),
239252
),

imblearn/ensemble/_easy_ensemble.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from ..under_sampling import RandomUnderSampler
1616
from ..under_sampling.base import BaseUnderSampler
17-
from ..utils import Substitution, check_target_type
17+
from ..utils import Substitution, check_target_type, check_sampling_strategy
1818
from ..utils._docstring import _n_jobs_docstring
1919
from ..utils._docstring import _random_state_docstring
2020
from ..pipeline import Pipeline
@@ -152,6 +152,19 @@ def __init__(
152152
self.sampling_strategy = sampling_strategy
153153
self.replacement = replacement
154154

155+
def _validate_y(self, y):
    """Validate ``y`` and derive ``self._sampling_strategy``.

    A dict ``sampling_strategy`` is validated against the original ``y``
    through ``check_sampling_strategy``; its keys are then swapped for the
    index of the matching label in ``self.classes_``. A non-dict
    ``sampling_strategy`` is copied over untouched.

    Parameters
    ----------
    y : array-like
        Target passed to ``fit``, with its original (non-encoded) labels.

    Returns
    -------
    y_encoded
        The value returned by the parent ``_validate_y`` for ``y``.
    """
    # NOTE(review): `self.classes_` is read below; presumably it is set by
    # the parent call -- confirm against the base class.
    y_out = super()._validate_y(y)

    def _class_index(label):
        # Position of the first entry of `classes_` equal to `label`.
        return np.where(self.classes_ == label)[0][0]

    requested = self.sampling_strategy
    self._sampling_strategy = (
        {
            _class_index(label): count
            for label, count in check_sampling_strategy(
                requested, y, 'under-sampling',
            ).items()
        }
        if isinstance(requested, dict)
        else requested
    )
    return y_out
167+
155168
def _validate_estimator(self, default=AdaBoostClassifier()):
156169
"""Check the estimator and the n_estimator attribute, set the
157170
`base_estimator_` attribute."""
@@ -177,7 +190,7 @@ def _validate_estimator(self, default=AdaBoostClassifier()):
177190
(
178191
"sampler",
179192
RandomUnderSampler(
180-
sampling_strategy=self.sampling_strategy,
193+
sampling_strategy=self._sampling_strategy,
181194
replacement=self.replacement,
182195
),
183196
),

imblearn/ensemble/_forest.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from ..utils import Substitution
3434
from ..utils._docstring import _n_jobs_docstring
3535
from ..utils._docstring import _random_state_docstring
36+
from ..utils._validation import check_sampling_strategy
3637

3738
MAX_INT = np.iinfo(np.int32).max
3839

@@ -364,7 +365,7 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
364365
self.base_estimator_ = clone(default)
365366

366367
self.base_sampler_ = RandomUnderSampler(
367-
sampling_strategy=self.sampling_strategy,
368+
sampling_strategy=self._sampling_strategy,
368369
replacement=self.replacement,
369370
)
370371

@@ -447,10 +448,20 @@ def fit(self, X, y, sample_weight=None):
447448

448449
self.n_outputs_ = y.shape[1]
449450

450-
y, expanded_class_weight = self._validate_y_class_weight(y)
451+
y_encoded, expanded_class_weight = self._validate_y_class_weight(y)
451452

452453
if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
453-
y = np.ascontiguousarray(y, dtype=DOUBLE)
454+
y_encoded = np.ascontiguousarray(y_encoded, dtype=DOUBLE)
455+
456+
if isinstance(self.sampling_strategy, dict):
457+
self._sampling_strategy = {
458+
np.where(self.classes_[0] == key)[0][0]: value
459+
for key, value in check_sampling_strategy(
460+
self.sampling_strategy, y, 'under-sampling',
461+
).items()
462+
}
463+
else:
464+
self._sampling_strategy = self.sampling_strategy
454465

455466
if expanded_class_weight is not None:
456467
if sample_weight is not None:
@@ -523,7 +534,7 @@ def fit(self, X, y, sample_weight=None):
523534
t,
524535
self,
525536
X,
526-
y,
537+
y_encoded,
527538
sample_weight,
528539
i,
529540
len(trees),
@@ -548,7 +559,7 @@ def fit(self, X, y, sample_weight=None):
548559
)
549560

550561
if self.oob_score:
551-
self._set_oob_score(X, y)
562+
self._set_oob_score(X, y_encoded)
552563

553564
# Decapsulate classes_ attributes
554565
if hasattr(self, "classes_") and self.n_outputs_ == 1:

imblearn/utils/_validation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929

3030
class ArraysTransformer:
31-
"""A class to convert sampler ouput arrays to their orinal types."""
31+
"""A class to convert sampler output arrays to their original types."""
3232

3333
def __init__(self, X, y):
3434
self.x_props = self._gets_props(X)

imblearn/utils/estimator_checks.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
from sklearn.base import clone
2020
from sklearn.datasets import (
21+
fetch_openml,
2122
make_classification,
2223
make_multilabel_classification,
2324
) # noqa
@@ -30,6 +31,7 @@
3031
from sklearn.utils._testing import assert_raises_regex
3132
from sklearn.utils.multiclass import type_of_target
3233

34+
from imblearn.datasets import make_imbalance
3335
from imblearn.over_sampling.base import BaseOverSampler
3436
from imblearn.under_sampling.base import BaseCleaningSampler, BaseUnderSampler
3537

@@ -65,6 +67,7 @@ def _yield_sampler_checks(sampler):
6567

6668
def _yield_classifier_checks(classifier):
    """Yield, in order, the check functions run on classifiers.

    The ``classifier`` argument is not inspected; the same checks are
    yielded for every classifier.
    """
    yield from (
        check_classifier_on_multilabel_or_multioutput_targets,
        check_classifiers_with_encoded_labels,
    )
6871

6972

7073
def _yield_all_checks(estimator):
@@ -376,3 +379,24 @@ def check_classifier_on_multilabel_or_multioutput_targets(name, estimator):
376379
msg = "Multilabel and multioutput targets are not supported."
377380
with pytest.raises(ValueError, match=msg):
378381
estimator.fit(X, y)
382+
383+
384+
def check_classifiers_with_encoded_labels(name, classifier):
    """Check that a classifier accepts string labels in ``sampling_strategy``.

    Non-regression test for #709
    https://github.com/scikit-learn-contrib/imbalanced-learn/issues/709

    The iris dataset is made imbalanced, then the classifier is fitted with
    a dict ``sampling_strategy`` keyed by the original string labels. The
    fitted ``classes_`` and the predicted labels must both match the
    categories of ``y``.

    Parameters
    ----------
    name : str
        Name of the estimator under test (not used by this check).
    classifier : estimator
        Classifier to check. It is cloned first, so the instance handed in
        by the caller is neither reconfigured nor fitted.
    """
    pytest.importorskip("pandas")
    # Work on a copy: `set_params` and `fit` below would otherwise mutate
    # the estimator instance owned by the caller.
    classifier = clone(classifier)
    df, y = fetch_openml("iris", version=1, as_frame=True, return_X_y=True)
    df, y = make_imbalance(
        df,
        y,
        sampling_strategy={
            "Iris-setosa": 30,
            "Iris-versicolor": 20,
            "Iris-virginica": 50,
        },
    )
    classifier.set_params(
        sampling_strategy={
            "Iris-setosa": 20,
            "Iris-virginica": 20,
        }
    )
    classifier.fit(df, y)
    expected_labels = set(y.cat.categories.tolist())
    assert set(classifier.classes_) == expected_labels
    y_pred = classifier.predict(df)
    assert set(y_pred) == expected_labels

0 commit comments

Comments
 (0)