Skip to content

Commit d8cf8d6

Browse files
authored
API change behaviour of bootstrap in BRF (#1010)
1 parent 124d108 commit d8cf8d6

File tree

8 files changed

+81
-22
lines changed

8 files changed

+81
-22
lines changed

doc/ensemble.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,8 @@ each tree of the forest will be provided a balanced bootstrap sample
7878

7979
>>> from imblearn.ensemble import BalancedRandomForestClassifier
8080
>>> brf = BalancedRandomForestClassifier(
81-
... n_estimators=100, random_state=0, sampling_strategy="all", replacement=True
81+
... n_estimators=100, random_state=0, sampling_strategy="all", replacement=True,
82+
... bootstrap=False,
8283
... )
8384
>>> brf.fit(X_train, y_train)
8485
BalancedRandomForestClassifier(...)

doc/whats_new/v0.11.rst

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
.. _changes_0_11:
22

3+
Version 0.11.1
4+
==============
5+
6+
Changelog
7+
---------
8+
9+
310
Version 0.11.0
411
==============
512

@@ -40,9 +47,11 @@ Deprecation
4047
and will be removed in version 0.13. Use `categorical_encoder_` instead.
4148
:pr:`1000` by :user:`Guillaume Lemaitre <glemaitre>`.
4249

43-
- The default of the parameters `sampling_strategy` and `replacement` will change in
50+
- The default of the parameters `sampling_strategy`, `bootstrap` and
51+
`replacement` will change in
4452
:class:`~imblearn.ensemble.BalancedRandomForestClassifier` to follow the
45-
implementation of the original paper. This changes will take effect in version 0.13.
53+
implementation of the original paper. These changes will take effect in
54+
version 0.13.
4655
:pr:`1006` by :user:`Guillaume Lemaitre <glemaitre>`.
4756

4857
Enhancements

examples/applications/plot_impact_imbalanced_classes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,11 @@
320320
rf_clf = make_pipeline(
321321
preprocessor_tree,
322322
BalancedRandomForestClassifier(
323-
sampling_strategy="all", replacement=True, random_state=42, n_jobs=2
323+
sampling_strategy="all",
324+
replacement=True,
325+
bootstrap=False,
326+
random_state=42,
327+
n_jobs=2,
324328
),
325329
)
326330

examples/ensemble/plot_comparison_ensemble_classifier.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,11 @@
144144

145145
rf = RandomForestClassifier(n_estimators=50, random_state=0)
146146
brf = BalancedRandomForestClassifier(
147-
n_estimators=50, sampling_strategy="all", replacement=True, random_state=0
147+
n_estimators=50,
148+
sampling_strategy="all",
149+
replacement=True,
150+
bootstrap=False,
151+
random_state=0,
148152
)
149153

150154
rf.fit(X_train, y_train)

imblearn/ensemble/_forest.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,10 @@ def _local_parallel_build_trees(
105105
class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassifier):
106106
"""A balanced random forest classifier.
107107
108-
A balanced random forest randomly under-samples each bootstrap sample to
109-
balance it.
108+
A balanced random forest differs from a classical random forest by the
109+
fact that it will draw a bootstrap sample from the minority class and
110+
sample with replacement the same number of samples from the majority
111+
class.
110112
111113
Read more in the :ref:`User Guide <forest>`.
112114
@@ -187,6 +189,12 @@ class BalancedRandomForestClassifier(_ParamsValidationMixin, RandomForestClassif
187189
bootstrap : bool, default=True
188190
Whether bootstrap samples are used when building trees.
189191
192+
.. versionchanged:: 0.13
193+
The default of `bootstrap` will change from `True` to `False` in
194+
version 0.13. Bootstrapping is already taken care of by the internal
195+
sampler using `replacement=True`. This implementation follows the
196+
algorithm proposed in [1]_.
197+
190198
oob_score : bool, default=False
191199
Whether to use out-of-bag samples to estimate
192200
the generalization accuracy.
@@ -395,7 +403,8 @@ class labels (multi-output problem).
395403
... n_informative=4, weights=[0.2, 0.3, 0.5],
396404
... random_state=0)
397405
>>> clf = BalancedRandomForestClassifier(
398-
... sampling_strategy="all", replacement=True, max_depth=2, random_state=0)
406+
... sampling_strategy="all", replacement=True, max_depth=2, random_state=0,
407+
... bootstrap=False)
399408
>>> clf.fit(X, y)
400409
BalancedRandomForestClassifier(...)
401410
>>> print(clf.feature_importances_)
@@ -415,6 +424,7 @@ class labels (multi-output problem).
415424

416425
_parameter_constraints.update(
417426
{
427+
"bootstrap": ["boolean", Hidden(StrOptions({"warn"}))],
418428
"sampling_strategy": [
419429
Interval(numbers.Real, 0, 1, closed="right"),
420430
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
@@ -438,7 +448,7 @@ def __init__(
438448
max_features="sqrt",
439449
max_leaf_nodes=None,
440450
min_impurity_decrease=0.0,
441-
bootstrap=True,
451+
bootstrap="warn",
442452
oob_score=False,
443453
sampling_strategy="warn",
444454
replacement="warn",
@@ -566,6 +576,18 @@ def fit(self, X, y, sample_weight=None):
566576
else:
567577
self._replacement = self.replacement
568578

579+
if self.bootstrap == "warn":
580+
warn(
581+
"The default of `bootstrap` will change from `True` to "
582+
"`False` in version 0.13. This change will follow the implementation "
583+
"proposed in the original paper. Set to `False` to silence this "
584+
"warning and adopt the future behaviour.",
585+
FutureWarning,
586+
)
587+
self._bootstrap = True
588+
else:
589+
self._bootstrap = self.bootstrap
590+
569591
# Validate or convert input data
570592
if issparse(y):
571593
raise ValueError("sparse multilabel-indicator for y is not supported.")
@@ -629,7 +651,7 @@ def fit(self, X, y, sample_weight=None):
629651
# Check parameters
630652
self._validate_estimator()
631653

632-
if not self.bootstrap and self.oob_score:
654+
if not self._bootstrap and self.oob_score:
633655
raise ValueError("Out of bag estimation only available if bootstrap=True")
634656

635657
random_state = check_random_state(self.random_state)
@@ -681,7 +703,7 @@ def fit(self, X, y, sample_weight=None):
681703
delayed(_local_parallel_build_trees)(
682704
s,
683705
t,
684-
self.bootstrap,
706+
self._bootstrap,
685707
X,
686708
y_encoded,
687709
sample_weight,

imblearn/ensemble/tests/test_forest.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def imbalanced_dataset():
2929

3030
def test_balanced_random_forest_error_warning_warm_start(imbalanced_dataset):
3131
brf = BalancedRandomForestClassifier(
32-
n_estimators=5, sampling_strategy="all", replacement=True
32+
n_estimators=5, sampling_strategy="all", replacement=True, bootstrap=False
3333
)
3434
brf.fit(*imbalanced_dataset)
3535

@@ -51,6 +51,7 @@ def test_balanced_random_forest(imbalanced_dataset):
5151
random_state=0,
5252
sampling_strategy="all",
5353
replacement=True,
54+
bootstrap=False,
5455
)
5556
brf.fit(*imbalanced_dataset)
5657

@@ -68,6 +69,7 @@ def test_balanced_random_forest_attributes(imbalanced_dataset):
6869
random_state=0,
6970
sampling_strategy="all",
7071
replacement=True,
72+
bootstrap=False,
7173
)
7274
brf.fit(X, y)
7375

@@ -93,7 +95,11 @@ def test_balanced_random_forest_sample_weight(imbalanced_dataset):
9395
X, y = imbalanced_dataset
9496
sample_weight = rng.rand(y.shape[0])
9597
brf = BalancedRandomForestClassifier(
96-
n_estimators=5, random_state=0, sampling_strategy="all", replacement=True
98+
n_estimators=5,
99+
random_state=0,
100+
sampling_strategy="all",
101+
replacement=True,
102+
bootstrap=False,
97103
)
98104
brf.fit(X, y, sample_weight)
99105

@@ -111,6 +117,7 @@ def test_balanced_random_forest_oob(imbalanced_dataset):
111117
min_samples_leaf=2,
112118
sampling_strategy="all",
113119
replacement=True,
120+
bootstrap=True,
114121
)
115122

116123
est.fit(X_train, y_train)
@@ -132,7 +139,9 @@ def test_balanced_random_forest_oob(imbalanced_dataset):
132139

133140

134141
def test_balanced_random_forest_grid_search(imbalanced_dataset):
135-
brf = BalancedRandomForestClassifier(sampling_strategy="all", replacement=True)
142+
brf = BalancedRandomForestClassifier(
143+
sampling_strategy="all", replacement=True, bootstrap=False
144+
)
136145
grid = GridSearchCV(brf, {"n_estimators": (1, 2), "max_depth": (1, 2)}, cv=3)
137146
grid.fit(*imbalanced_dataset)
138147

@@ -150,6 +159,7 @@ def test_little_tree_with_small_max_samples():
150159
max_samples=None,
151160
sampling_strategy="all",
152161
replacement=True,
162+
bootstrap=True,
153163
)
154164

155165
# Second fit with max samples restricted to just 2
@@ -159,6 +169,7 @@ def test_little_tree_with_small_max_samples():
159169
max_samples=2,
160170
sampling_strategy="all",
161171
replacement=True,
172+
bootstrap=True,
162173
)
163174

164175
est1.fit(X, y)
@@ -172,12 +183,14 @@ def test_little_tree_with_small_max_samples():
172183

173184

174185
def test_balanced_random_forest_pruning(imbalanced_dataset):
175-
brf = BalancedRandomForestClassifier(sampling_strategy="all", replacement=True)
186+
brf = BalancedRandomForestClassifier(
187+
sampling_strategy="all", replacement=True, bootstrap=False
188+
)
176189
brf.fit(*imbalanced_dataset)
177190
n_nodes_no_pruning = brf.estimators_[0].tree_.node_count
178191

179192
brf_pruned = BalancedRandomForestClassifier(
180-
ccp_alpha=0.015, sampling_strategy="all", replacement=True
193+
ccp_alpha=0.015, sampling_strategy="all", replacement=True, bootstrap=False
181194
)
182195
brf_pruned.fit(*imbalanced_dataset)
183196
n_nodes_pruning = brf_pruned.estimators_[0].tree_.node_count
@@ -200,6 +213,7 @@ def test_balanced_random_forest_oob_binomial(ratio):
200213
random_state=42,
201214
sampling_strategy="not minority",
202215
replacement=False,
216+
bootstrap=True,
203217
)
204218
erf.fit(X, y)
205219
assert np.abs(erf.oob_score_ - 0.5) < 0.1
@@ -209,7 +223,7 @@ def test_balanced_bagging_classifier_n_features():
209223
"""Check that we raise a FutureWarning when accessing `n_features_`."""
210224
X, y = load_iris(return_X_y=True)
211225
estimator = BalancedRandomForestClassifier(
212-
sampling_strategy="all", replacement=True
226+
sampling_strategy="all", replacement=True, bootstrap=False
213227
).fit(X, y)
214228
with pytest.warns(FutureWarning, match="`n_features_` was deprecated"):
215229
estimator.n_features_
@@ -222,7 +236,7 @@ def test_balanced_random_forest_classifier_base_estimator():
222236
"""Check that we raise a FutureWarning when accessing `base_estimator_`."""
223237
X, y = load_iris(return_X_y=True)
224238
estimator = BalancedRandomForestClassifier(
225-
sampling_strategy="all", replacement=True
239+
sampling_strategy="all", replacement=True, bootstrap=False
226240
).fit(X, y)
227241
with pytest.warns(FutureWarning, match="`base_estimator_` was deprecated"):
228242
estimator.base_estimator_
@@ -233,9 +247,14 @@ def test_balanced_random_forest_change_behaviour(imbalanced_dataset):
233247
"""Check that we raise a FutureWarning about the change of behaviour for the
234248
parameters `sampling_strategy`, `replacement` and `bootstrap`.
235249
"""
236-
estimator = BalancedRandomForestClassifier(sampling_strategy="all")
250+
estimator = BalancedRandomForestClassifier(sampling_strategy="all", bootstrap=False)
237251
with pytest.warns(FutureWarning, match="The default of `replacement`"):
238252
estimator.fit(*imbalanced_dataset)
239-
estimator = BalancedRandomForestClassifier(replacement=True)
253+
estimator = BalancedRandomForestClassifier(replacement=True, bootstrap=False)
240254
with pytest.warns(FutureWarning, match="The default of `sampling_strategy`"):
241255
estimator.fit(*imbalanced_dataset)
256+
estimator = BalancedRandomForestClassifier(
257+
sampling_strategy="all", replacement=True
258+
)
259+
with pytest.warns(FutureWarning, match="The default of `bootstrap`"):
260+
estimator.fit(*imbalanced_dataset)

imblearn/tests/test_docstring_parameters.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def test_fit_docstring_attributes(name, Estimator):
205205
X = _enforce_estimator_tags_x(est, X)
206206

207207
if "oob_score" in est.get_params():
208-
est.set_params(oob_score=True)
208+
est.set_params(bootstrap=True, oob_score=True)
209209

210210
if is_sampler(est):
211211
est.fit_resample(X, y)

imblearn/utils/estimator_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def _set_checking_parameters(estimator):
9494
if name == "BalancedRandomForestClassifier":
9595
# TODO: remove in 0.13
9696
# future default in 0.13
97-
estimator.set_params(replacement=True, sampling_strategy="all")
97+
estimator.set_params(replacement=True, sampling_strategy="all", bootstrap=False)
9898

9999

100100
def _yield_sampler_checks(sampler):

0 commit comments

Comments
 (0)