Skip to content

Commit 2b126e3

Browse files
authored
API change default of replacement and sampling_strategy in BRF (#1006)
1 parent 92b5305 commit 2b126e3

File tree

7 files changed

+160
-26
lines changed

7 files changed

+160
-26
lines changed

doc/ensemble.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,9 @@ each tree of the forest will be provided a balanced bootstrap sample
7777
:class:`~sklearn.ensemble.RandomForestClassifier`::
7878

7979
>>> from imblearn.ensemble import BalancedRandomForestClassifier
80-
>>> brf = BalancedRandomForestClassifier(n_estimators=100, random_state=0)
80+
>>> brf = BalancedRandomForestClassifier(
81+
... n_estimators=100, random_state=0, sampling_strategy="all", replacement=True
82+
... )
8183
>>> brf.fit(X_train, y_train)
8284
BalancedRandomForestClassifier(...)
8385
>>> y_pred = brf.predict(X_test)

doc/whats_new/v0.11.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ Deprecation
3030
and will be removed in version 0.13. Use `categorical_encoder_` instead.
3131
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.
3232

33+
- The default of the parameters `sampling_strategy` and `replacement` will change in
34+
:class:`~imblearn.ensemble.BalancedRandomForestClassifier` to follow the
35+
implementation of the original paper. These changes will take effect in version 0.13.
36+
:pr:`1006` by :user:`Guillaume Lemaitre <glemaitre>`.
37+
3338
Enhancements
3439
............
3540

examples/applications/plot_impact_imbalanced_classes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,9 @@
319319

320320
rf_clf = make_pipeline(
321321
preprocessor_tree,
322-
BalancedRandomForestClassifier(random_state=42, n_jobs=2),
322+
BalancedRandomForestClassifier(
323+
sampling_strategy="all", replacement=True, random_state=42, n_jobs=2
324+
),
323325
)
324326

325327
# %%

examples/ensemble/plot_comparison_ensemble_classifier.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@
143143
from imblearn.ensemble import BalancedRandomForestClassifier
144144

145145
rf = RandomForestClassifier(n_estimators=50, random_state=0)
146-
brf = BalancedRandomForestClassifier(n_estimators=50, random_state=0)
146+
brf = BalancedRandomForestClassifier(
147+
n_estimators=50, sampling_strategy="all", replacement=True, random_state=0
148+
)
147149

148150
rf.fit(X_train, y_train)
149151
brf.fit(X_train, y_train)

imblearn/ensemble/_forest.py

Lines changed: 81 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,9 @@
3636
from ..base import _ParamsValidationMixin
3737
from ..pipeline import make_pipeline
3838
from ..under_sampling import RandomUnderSampler
39-
from ..under_sampling.base import BaseUnderSampler
4039
from ..utils import Substitution
4140
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
42-
from ..utils._param_validation import Interval, StrOptions
41+
from ..utils._param_validation import Hidden, Interval, StrOptions
4342
from ..utils._validation import check_sampling_strategy
4443
from ..utils.fixes import _fit_context
4544
from ._common import _random_forest_classifier_parameter_constraints
@@ -100,7 +99,6 @@ def _local_parallel_build_trees(
10099

101100

102101
@Substitution(
103-
sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
104102
n_jobs=_n_jobs_docstring,
105103
random_state=_random_state_docstring,
106104
)
@@ -193,11 +191,56 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassif
193191
Whether to use out-of-bag samples to estimate
194192
the generalization accuracy.
195193
196-
{sampling_strategy}
194+
sampling_strategy : float, str, dict, callable, default="auto"
195+
Sampling information to sample the data set.
196+
197+
- When ``float``, it corresponds to the desired ratio of the number of
198+
samples in the minority class over the number of samples in the
199+
majority class after resampling. Therefore, the ratio is expressed as
200+
:math:`\\alpha_{{us}} = N_{{m}} / N_{{rM}}` where :math:`N_{{m}}` is the
201+
number of samples in the minority class and
202+
:math:`N_{{rM}}` is the number of samples in the majority class
203+
after resampling.
204+
205+
.. warning::
206+
``float`` is only available for **binary** classification. An
207+
error is raised for multi-class classification.
208+
209+
- When ``str``, specify the class targeted by the resampling. The
210+
number of samples in the different classes will be equalized.
211+
Possible choices are:
212+
213+
``'majority'``: resample only the majority class;
214+
215+
``'not minority'``: resample all classes but the minority class;
216+
217+
``'not majority'``: resample all classes but the majority class;
218+
219+
``'all'``: resample all classes;
220+
221+
``'auto'``: equivalent to ``'not minority'``.
222+
223+
- When ``dict``, the keys correspond to the targeted classes. The
224+
values correspond to the desired number of samples for each targeted
225+
class.
226+
227+
- When callable, a function taking ``y`` and returning a ``dict``. The keys
228+
correspond to the targeted classes. The values correspond to the
229+
desired number of samples for each class.
230+
231+
.. versionchanged:: 0.11
232+
The default of `sampling_strategy` will change from `"auto"` to
233+
`"all"` in version 0.13. This forces the use of a bootstrap of the
234+
minority class as proposed in [1]_.
197235
198236
replacement : bool, default=False
199237
Whether to sample randomly with replacement.
200238
239+
.. versionchanged:: 0.11
240+
The default of `replacement` will change from `False` to `True` in
241+
version 0.13. This forces the use of a bootstrap of the
242+
minority class, drawn with replacement, as proposed in [1]_.
243+
201244
{n_jobs}
202245
203246
{random_state}
@@ -351,7 +394,8 @@ class labels (multi-output problem).
351394
>>> X, y = make_classification(n_samples=1000, n_classes=3,
352395
... n_informative=4, weights=[0.2, 0.3, 0.5],
353396
... random_state=0)
354-
>>> clf = BalancedRandomForestClassifier(max_depth=2, random_state=0)
397+
>>> clf = BalancedRandomForestClassifier(
398+
... sampling_strategy="all", replacement=True, max_depth=2, random_state=0)
355399
>>> clf.fit(X, y)
356400
BalancedRandomForestClassifier(...)
357401
>>> print(clf.feature_importances_)
@@ -376,8 +420,9 @@ class labels (multi-output problem).
376420
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
377421
dict,
378422
callable,
423+
Hidden(StrOptions({"warn"})),
379424
],
380-
"replacement": ["boolean"],
425+
"replacement": ["boolean", Hidden(StrOptions({"warn"}))],
381426
}
382427
)
383428

@@ -395,8 +440,8 @@ def __init__(
395440
min_impurity_decrease=0.0,
396441
bootstrap=True,
397442
oob_score=False,
398-
sampling_strategy="auto",
399-
replacement=False,
443+
sampling_strategy="warn",
444+
replacement="warn",
400445
n_jobs=None,
401446
random_state=None,
402447
verbose=0,
@@ -450,7 +495,7 @@ def _validate_estimator(self, default=DecisionTreeClassifier()):
450495

451496
self.base_sampler_ = RandomUnderSampler(
452497
sampling_strategy=self._sampling_strategy,
453-
replacement=self.replacement,
498+
replacement=self._replacement,
454499
)
455500

456501
def _make_sampler_estimator(self, random_state=None):
@@ -496,6 +541,31 @@ def fit(self, X, y, sample_weight=None):
496541
The fitted instance.
497542
"""
498543
self._validate_params()
544+
# TODO: remove in 0.13
545+
if self.sampling_strategy == "warn":
546+
warn(
547+
"The default of `sampling_strategy` will change from `'auto'` to "
548+
"`'all'` in version 0.13. This change will follow the implementation "
549+
"proposed in the original paper. Set to `'all'` to silence this "
550+
"warning and adopt the future behaviour.",
551+
FutureWarning,
552+
)
553+
self._sampling_strategy = "auto"
554+
else:
555+
self._sampling_strategy = self.sampling_strategy
556+
557+
if self.replacement == "warn":
558+
warn(
559+
"The default of `replacement` will change from `False` to "
560+
"`True` in version 0.13. This change will follow the implementation "
561+
"proposed in the original paper. Set to `True` to silence this "
562+
"warning and adopt the future behaviour.",
563+
FutureWarning,
564+
)
565+
self._replacement = False
566+
else:
567+
self._replacement = self.replacement
568+
499569
# Validate or convert input data
500570
if issparse(y):
501571
raise ValueError("sparse multilabel-indicator for y is not supported.")
@@ -533,7 +603,7 @@ def fit(self, X, y, sample_weight=None):
533603
if getattr(y, "dtype", None) != DOUBLE or not y.flags.contiguous:
534604
y_encoded = np.ascontiguousarray(y_encoded, dtype=DOUBLE)
535605

536-
if isinstance(self.sampling_strategy, dict):
606+
if isinstance(self._sampling_strategy, dict):
537607
self._sampling_strategy = {
538608
np.where(self.classes_[0] == key)[0][0]: value
539609
for key, value in check_sampling_strategy(
@@ -543,7 +613,7 @@ def fit(self, X, y, sample_weight=None):
543613
).items()
544614
}
545615
else:
546-
self._sampling_strategy = self.sampling_strategy
616+
self._sampling_strategy = self._sampling_strategy
547617

548618
if expanded_class_weight is not None:
549619
if sample_weight is not None:

imblearn/ensemble/tests/test_forest.py

Lines changed: 61 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ def imbalanced_dataset():
2828

2929

3030
def test_balanced_random_forest_error_warning_warm_start(imbalanced_dataset):
31-
brf = BalancedRandomForestClassifier(n_estimators=5)
31+
brf = BalancedRandomForestClassifier(
32+
n_estimators=5, sampling_strategy="all", replacement=True
33+
)
3234
brf.fit(*imbalanced_dataset)
3335

3436
with pytest.raises(ValueError, match="must be larger or equal to"):
@@ -44,7 +46,12 @@ def test_balanced_random_forest_error_warning_warm_start(imbalanced_dataset):
4446

4547
def test_balanced_random_forest(imbalanced_dataset):
4648
n_estimators = 10
47-
brf = BalancedRandomForestClassifier(n_estimators=n_estimators, random_state=0)
49+
brf = BalancedRandomForestClassifier(
50+
n_estimators=n_estimators,
51+
random_state=0,
52+
sampling_strategy="all",
53+
replacement=True,
54+
)
4855
brf.fit(*imbalanced_dataset)
4956

5057
assert len(brf.samplers_) == n_estimators
@@ -56,7 +63,12 @@ def test_balanced_random_forest(imbalanced_dataset):
5663
def test_balanced_random_forest_attributes(imbalanced_dataset):
5764
X, y = imbalanced_dataset
5865
n_estimators = 10
59-
brf = BalancedRandomForestClassifier(n_estimators=n_estimators, random_state=0)
66+
brf = BalancedRandomForestClassifier(
67+
n_estimators=n_estimators,
68+
random_state=0,
69+
sampling_strategy="all",
70+
replacement=True,
71+
)
6072
brf.fit(X, y)
6173

6274
for idx in range(n_estimators):
@@ -80,7 +92,9 @@ def test_balanced_random_forest_sample_weight(imbalanced_dataset):
8092
rng = np.random.RandomState(42)
8193
X, y = imbalanced_dataset
8294
sample_weight = rng.rand(y.shape[0])
83-
brf = BalancedRandomForestClassifier(n_estimators=5, random_state=0)
95+
brf = BalancedRandomForestClassifier(
96+
n_estimators=5, random_state=0, sampling_strategy="all", replacement=True
97+
)
8498
brf.fit(X, y, sample_weight)
8599

86100

@@ -95,6 +109,8 @@ def test_balanced_random_forest_oob(imbalanced_dataset):
95109
random_state=0,
96110
n_estimators=1000,
97111
min_samples_leaf=2,
112+
sampling_strategy="all",
113+
replacement=True,
98114
)
99115

100116
est.fit(X_train, y_train)
@@ -104,14 +120,19 @@ def test_balanced_random_forest_oob(imbalanced_dataset):
104120

105121
# Check warning if not enough estimators
106122
est = BalancedRandomForestClassifier(
107-
oob_score=True, random_state=0, n_estimators=1, bootstrap=True
123+
oob_score=True,
124+
random_state=0,
125+
n_estimators=1,
126+
bootstrap=True,
127+
sampling_strategy="all",
128+
replacement=True,
108129
)
109130
with pytest.warns(UserWarning) and np.errstate(divide="ignore", invalid="ignore"):
110131
est.fit(X, y)
111132

112133

113134
def test_balanced_random_forest_grid_search(imbalanced_dataset):
114-
brf = BalancedRandomForestClassifier()
135+
brf = BalancedRandomForestClassifier(sampling_strategy="all", replacement=True)
115136
grid = GridSearchCV(brf, {"n_estimators": (1, 2), "max_depth": (1, 2)}, cv=3)
116137
grid.fit(*imbalanced_dataset)
117138

@@ -127,13 +148,17 @@ def test_little_tree_with_small_max_samples():
127148
n_estimators=1,
128149
random_state=rng,
129150
max_samples=None,
151+
sampling_strategy="all",
152+
replacement=True,
130153
)
131154

132155
# Second fit with max samples restricted to just 2
133156
est2 = BalancedRandomForestClassifier(
134157
n_estimators=1,
135158
random_state=rng,
136159
max_samples=2,
160+
sampling_strategy="all",
161+
replacement=True,
137162
)
138163

139164
est1.fit(X, y)
@@ -147,11 +172,13 @@ def test_little_tree_with_small_max_samples():
147172

148173

149174
def test_balanced_random_forest_pruning(imbalanced_dataset):
150-
brf = BalancedRandomForestClassifier()
175+
brf = BalancedRandomForestClassifier(sampling_strategy="all", replacement=True)
151176
brf.fit(*imbalanced_dataset)
152177
n_nodes_no_pruning = brf.estimators_[0].tree_.node_count
153178

154-
brf_pruned = BalancedRandomForestClassifier(ccp_alpha=0.015)
179+
brf_pruned = BalancedRandomForestClassifier(
180+
ccp_alpha=0.015, sampling_strategy="all", replacement=True
181+
)
155182
brf_pruned.fit(*imbalanced_dataset)
156183
n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count
157184

@@ -168,25 +195,47 @@ def test_balanced_random_forest_oob_binomial(ratio):
168195
X = np.arange(n_samples).reshape(-1, 1)
169196
y = rng.binomial(1, ratio, size=n_samples)
170197

171-
erf = BalancedRandomForestClassifier(oob_score=True, random_state=42)
198+
erf = BalancedRandomForestClassifier(
199+
oob_score=True,
200+
random_state=42,
201+
sampling_strategy="not minority",
202+
replacement=False,
203+
)
172204
erf.fit(X, y)
173205
assert np.abs(erf.oob_score_ - 0.5) < 0.1
174206

175207

176208
def test_balanced_bagging_classifier_n_features():
177209
"""Check that we raise a FutureWarning when accessing `n_features_`."""
178210
X, y = load_iris(return_X_y=True)
179-
estimator = BalancedRandomForestClassifier().fit(X, y)
211+
estimator = BalancedRandomForestClassifier(
212+
sampling_strategy="all", replacement=True
213+
).fit(X, y)
180214
with pytest.warns(FutureWarning, match="`n_features_` was deprecated"):
181215
estimator.n_features_
182216

183217

184218
@pytest.mark.skipif(
185219
sklearn_version < parse_version("1.2"), reason="requires scikit-learn>=1.2"
186220
)
187-
def test_balanced_bagging_classifier_base_estimator():
221+
def test_balanced_random_forest_classifier_base_estimator():
188222
"""Check that we raise a FutureWarning when accessing `base_estimator_`."""
189223
X, y = load_iris(return_X_y=True)
190-
estimator = BalancedRandomForestClassifier().fit(X, y)
224+
estimator = BalancedRandomForestClassifier(
225+
sampling_strategy="all", replacement=True
226+
).fit(X, y)
191227
with pytest.warns(FutureWarning, match="`base_estimator_` was deprecated"):
192228
estimator.base_estimator_
229+
230+
231+
# TODO: remove in 0.13
232+
def test_balanced_random_forest_change_behaviour(imbalanced_dataset):
233+
"""Check that we raise a change of behaviour for the parameters `sampling_strategy`
234+
and `replacement`.
235+
"""
236+
estimator = BalancedRandomForestClassifier(sampling_strategy="all")
237+
with pytest.warns(FutureWarning, match="The default of `replacement`"):
238+
estimator.fit(*imbalanced_dataset)
239+
estimator = BalancedRandomForestClassifier(replacement=True)
240+
with pytest.warns(FutureWarning, match="The default of `sampling_strategy`"):
241+
estimator.fit(*imbalanced_dataset)

imblearn/utils/estimator_checks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ def _set_checking_parameters(estimator):
7575
)
7676
if name == "KMeansSMOTE":
7777
estimator.set_params(kmeans_estimator=12)
78+
if name == "BalancedRandomForestClassifier":
79+
# TODO: remove in 0.13
80+
# future default in 0.13
81+
estimator.set_params(replacement=True, sampling_strategy="all")
7882

7983

8084
def _yield_sampler_checks(sampler):

0 commit comments

Comments
 (0)