Commit 355c944

DOC add note to create balanced RF (#373)
1 parent 03148fe commit 355c944

3 files changed: +48 -6 lines changed


doc/ensemble.rst

Lines changed: 20 additions & 4 deletions
@@ -100,13 +100,29 @@ takes the same parameters than the scikit-learn
     ...                           ratio='auto',
     ...                           replacement=False,
     ...                           random_state=0)
-    >>> bbc.fit(X, y) # doctest: +ELLIPSIS
+    >>> bbc.fit(X_train, y_train) # doctest: +ELLIPSIS
     BalancedBaggingClassifier(...)
     >>> y_pred = bbc.predict(X_test)
     >>> confusion_matrix(y_test, y_pred)
-    array([[  12,    0,    0],
-           [   1,   54,    4],
-           [  49,   53, 1077]])
+    array([[   9,    1,    2],
+           [   0,   55,    4],
+           [  42,   46, 1091]])
+
+It is also possible to turn a balanced bagging classifier into a balanced
+random forest by using a decision tree classifier together with the
+parameter ``max_features='auto'``, which lets each tree randomly select a
+subset of the features::
+
+    >>> brf = BalancedBaggingClassifier(
+    ...         base_estimator=DecisionTreeClassifier(max_features='auto'),
+    ...         random_state=0)
+    >>> brf.fit(X_train, y_train) # doctest: +ELLIPSIS
+    BalancedBaggingClassifier(...)
+    >>> y_pred = brf.predict(X_test)
+    >>> confusion_matrix(y_test, y_pred)
+    array([[   9,    1,    2],
+           [   0,   54,    5],
+           [  31,   34, 1114]])

 See
 :ref:`sphx_glr_auto_examples_ensemble_plot_comparison_bagging_classifier.py`.
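
The doctests above rely on the X_train/X_test split created earlier in the
user guide. As a rough, self-contained illustration of the same recipe, the
sketch below builds its own data: the synthetic dataset from
make_classification, its class weights, and the n_estimators value are
assumptions made purely for this example and are not part of the commit.

    # Minimal, self-contained sketch of the balanced random forest recipe.
    from sklearn.datasets import make_classification
    from sklearn.model_selection import train_test_split
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.metrics import confusion_matrix

    from imblearn.ensemble import BalancedBaggingClassifier

    # Hypothetical imbalanced three-class dataset (~1% / 5% / 94%), used only
    # so the snippet runs on its own.
    X, y = make_classification(n_samples=5000, n_classes=3, n_informative=4,
                               weights=[0.01, 0.05, 0.94], random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # Bag decision trees that each draw a random feature subset at every split
    # (max_features='auto'); each bootstrap sample is first balanced by random
    # under-sampling, which is what makes this a balanced random forest.
    brf = BalancedBaggingClassifier(
        base_estimator=DecisionTreeClassifier(max_features='auto'),
        n_estimators=10,
        random_state=0)
    brf.fit(X_train, y_train)

    print(confusion_matrix(y_test, brf.predict(X_test)))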

examples/ensemble/plot_comparison_bagging_classifier.py

Lines changed: 22 additions & 0 deletions
@@ -26,6 +26,7 @@

 from sklearn.model_selection import train_test_split
 from sklearn.ensemble import BaggingClassifier
+from sklearn.tree import DecisionTreeClassifier
 from sklearn.metrics import confusion_matrix

 from imblearn.datasets import fetch_datasets
@@ -99,4 +100,25 @@ def plot_confusion_matrix(cm, classes,
 plot_confusion_matrix(cm_balanced_bagging, classes=np.unique(ozone.target),
                       title='Confusion matrix using BalancedBaggingClassifier')

+###############################################################################
+# Turning the balanced bagging classifier into a balanced random forest
+###############################################################################
+# It is possible to turn the ``BalancedBaggingClassifier`` into a balanced
+# random forest by using a ``DecisionTreeClassifier`` with
+# ``max_features='auto'``. We illustrate this change below.
+
+balanced_random_forest = BalancedBaggingClassifier(
+    base_estimator=DecisionTreeClassifier(max_features='auto'),
+    random_state=0)
+
+balanced_random_forest.fit(X_train, y_train)
+print('Classification results using a balanced random forest classifier on'
+      ' imbalanced data')
+y_pred_balanced_rf = balanced_random_forest.predict(X_test)
+print(classification_report_imbalanced(y_test, y_pred_balanced_rf))
+cm_balanced_rf = confusion_matrix(y_test, y_pred_balanced_rf)
+plt.figure()
+plot_confusion_matrix(cm_balanced_rf, classes=np.unique(ozone.target),
+                      title='Confusion matrix using balanced random forest')
+
 plt.show()

imblearn/ensemble/classifier.py

Lines changed: 6 additions & 2 deletions
@@ -8,12 +8,10 @@

 import numpy as np

-import sklearn
 from sklearn.base import clone
 from sklearn.ensemble import BaggingClassifier
 from sklearn.tree import DecisionTreeClassifier
 from sklearn.ensemble.bagging import _generate_bagging_indices
-from sklearn.utils import indices_to_mask

 from ..pipeline import Pipeline
 from ..under_sampling import RandomUnderSampler
@@ -136,6 +134,9 @@ class BalancedBaggingClassifier(BaggingClassifier):

     Notes
     -----
+    It is possible to turn this classifier into a balanced random forest [5]_
+    by passing a :class:`sklearn.tree.DecisionTreeClassifier` with
+    ``max_features='auto'`` as a base estimator.

     See
     :ref:`sphx_glr_auto_examples_ensemble_plot_comparison_bagging_classifier.py`.
@@ -155,6 +156,9 @@ class BalancedBaggingClassifier(BaggingClassifier):
        1998.
     .. [4] G. Louppe and P. Geurts, "Ensembles on Random Patches", Machine
        Learning and Knowledge Discovery in Databases, 346-361, 2012.
+    .. [5] Chen, Chao, Andy Liaw, and Leo Breiman. "Using random forest to
+       learn imbalanced data." University of California, Berkeley 110,
+       2004.

     Examples
     --------
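
The docstring note reflects how the classifier is assembled: as the imports in
this file suggest, every fitted member of the ensemble is an imblearn Pipeline
that chains a RandomUnderSampler with the base estimator. One way to check
this on a fitted model, assuming brf was trained as in the sketch shown after
the first diff:

    # Inspect the fitted ensemble; `brf` is assumed to be the
    # BalancedBaggingClassifier fitted in the earlier sketch.
    for est in brf.estimators_[:2]:   # look at the first two members
        print(type(est).__name__)     # expected: Pipeline
        print(est)                    # shows the sampler and tree inside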
