
Commit 7a5afeb

ENH: Pass a classifier object instead of string (#186)

glemaitre authored and chkoar committed

1 parent a1af197 commit 7a5afeb

File tree

4 files changed: +524 −211 lines changed
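
The commit lets users hand `BalanceCascade` any scikit-learn classifier object through a new `estimator` parameter, instead of naming one of six hard-coded classifiers through the `classifier` string. A minimal before/after sketch, assuming the post-commit API (the parameters shown are illustrative):

    from sklearn.tree import DecisionTreeClassifier
    from imblearn.ensemble import BalanceCascade

    # Before this commit: pick a classifier by name; its parameters go in **kwargs
    bc_old = BalanceCascade(classifier='decision-tree', max_depth=3)

    # After this commit: pass a configured classifier object directly
    bc_new = BalanceCascade(estimator=DecisionTreeClassifier(max_depth=3))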

imblearn/ensemble/balance_cascade.py

Lines changed: 138 additions & 60 deletions
@@ -1,13 +1,18 @@
 """Class to perform under-sampling using balace cascade."""
 from __future__ import print_function
 
+import warnings
+
 import numpy as np
+
+from sklearn.base import ClassifierMixin
+from sklearn.neighbors import KNeighborsClassifier
 from sklearn.utils import check_random_state
+from sklearn.utils.validation import has_fit_parameter
 
-from ..base import BaseBinarySampler
+from six import string_types
 
-ESTIMATOR_KIND = ('knn', 'decision-tree', 'random-forest', 'adaboost',
-                  'gradient-boosting', 'linear-svm')
+from ..base import BaseBinarySampler
 
 
 class BalanceCascade(BaseBinarySampler):
@@ -40,18 +45,29 @@ class BalanceCascade(BaseBinarySampler):
         the training will be selected that could lead to a large number of
         subsets. We can probably deduce this number empirically.
 
-    classifier : str, optional (default='knn')
+    classifier : str, optional (default=None)
         The classifier that will be selected to confront the prediction
         with the real labels. The choices are the following: 'knn',
         'decision-tree', 'random-forest', 'adaboost', 'gradient-boosting'
         and 'linear-svm'.
 
+        NOTE: `classifier` is deprecated from 0.2 and will be replaced in 0.4.
+        Use `estimator` instead.
+
+    estimator : object, optional (default=KNeighborsClassifier())
+        An estimator inherited from `sklearn.base.ClassifierMixin` and having
+        an attribute `predict_proba`.
+
     bootstrap : bool, optional (default=True)
         Whether to bootstrap the data before each iteration.
 
     **kwargs : keywords
         The parameters associated with the classifier provided.
 
+        NOTE: `**kwargs` has been deprecated from 0.2 and will be replaced in
+        0.4. Use `estimator` object instead to pass parameters associated
+        to an estimator.
+
     Attributes
     ----------
     min_c_ : str or int
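
The new `estimator` docstring states the contract enforced later by `_validate_estimator`: the object must inherit from `sklearn.base.ClassifierMixin` and expose a prediction method (note the docstring asks for `predict_proba` while the validation code below only checks `predict`). A quick illustrative check, not part of the commit:

    from sklearn.base import ClassifierMixin
    from sklearn.svm import LinearSVC

    clf = LinearSVC()
    print(isinstance(clf, ClassifierMixin))  # True, LinearSVC is a classifier
    print(hasattr(clf, 'predict'))           # True, so BalanceCascade accepts it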
@@ -100,16 +116,97 @@ class BalanceCascade(BaseBinarySampler):
         """
 
     def __init__(self, ratio='auto', return_indices=False, random_state=None,
-                 n_max_subset=None, classifier='knn', bootstrap=True,
-                 **kwargs):
+                 n_max_subset=None, classifier=None, estimator=None,
+                 bootstrap=True, **kwargs):
         super(BalanceCascade, self).__init__(ratio=ratio,
                                              random_state=random_state)
         self.return_indices = return_indices
         self.classifier = classifier
+        self.estimator = estimator
         self.n_max_subset = n_max_subset
         self.bootstrap = bootstrap
         self.kwargs = kwargs
 
+    def _validate_estimator(self):
+        """Private function to create the classifier"""
+
+        if self.classifier is not None:
+            warnings.warn('`classifier` will be replaced in version'
+                          ' 0.4. Use a `estimator` instead.',
+                          DeprecationWarning)
+            self.estimator = self.classifier
+
+        if (self.estimator is not None and
+                isinstance(self.estimator, ClassifierMixin) and
+                hasattr(self.estimator, 'predict')):
+            self.estimator_ = self.estimator
+        elif self.estimator is None:
+            self.estimator_ = KNeighborsClassifier()
+        # To be removed in 0.4
+        elif (self.estimator is not None and
+              isinstance(self.estimator, string_types)):
+            warnings.warn('`estimator` will be replaced in version'
+                          ' 0.4. Use a classifier object instead of a string.',
+                          DeprecationWarning)
+            # Define the classifier to use
+            if self.estimator == 'knn':
+                self.estimator_ = KNeighborsClassifier(
+                    **self.kwargs)
+            elif self.estimator == 'decision-tree':
+                from sklearn.tree import DecisionTreeClassifier
+                self.estimator_ = DecisionTreeClassifier(
+                    random_state=self.random_state,
+                    **self.kwargs)
+            elif self.estimator == 'random-forest':
+                from sklearn.ensemble import RandomForestClassifier
+                self.estimator_ = RandomForestClassifier(
+                    random_state=self.random_state,
+                    **self.kwargs)
+            elif self.estimator == 'adaboost':
+                from sklearn.ensemble import AdaBoostClassifier
+                self.estimator_ = AdaBoostClassifier(
+                    random_state=self.random_state,
+                    **self.kwargs)
+            elif self.estimator == 'gradient-boosting':
+                from sklearn.ensemble import GradientBoostingClassifier
+                self.estimator_ = GradientBoostingClassifier(
+                    random_state=self.random_state,
+                    **self.kwargs)
+            elif self.estimator == 'linear-svm':
+                from sklearn.svm import LinearSVC
+                self.estimator_ = LinearSVC(random_state=self.random_state,
+                                            **self.kwargs)
+            else:
+                raise NotImplementedError
+        else:
+            raise ValueError('Invalid parameter `estimator`')
+
+        self.logger.debug(self.estimator_)
+
+    def fit(self, X, y):
+        """Find the classes statistics before to perform sampling.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_samples, n_features)
+            Matrix containing the data which have to be sampled.
+
+        y : ndarray, shape (n_samples, )
+            Corresponding label for each sample in X.
+
+        Returns
+        -------
+        self : object,
+            Return self.
+
+        """
+
+        super(BalanceCascade, self).fit(X, y)
+
+        self._validate_estimator()
+
+        return self
+
     def _sample(self, X, y):
         """Resample the dataset.
 
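
A sketch of the deprecation path implemented by `_validate_estimator` and wired into the new `fit` above, assuming imbalanced-learn at this commit; `make_classification` merely supplies a toy imbalanced dataset:

    import warnings

    from sklearn.datasets import make_classification
    from imblearn.ensemble import BalanceCascade

    X, y = make_classification(n_samples=100, weights=[0.9, 0.1],
                               random_state=0)

    bc = BalanceCascade(classifier='knn')  # deprecated string spelling
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter('always')
        bc.fit(X, y)  # fit() now calls _validate_estimator()
    print(any(w.category is DeprecationWarning for w in caught))  # True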
@@ -135,42 +232,9 @@ def _sample(self, X, y):
 
         """
 
-        if self.classifier not in ESTIMATOR_KIND:
-            raise NotImplementedError
-
         random_state = check_random_state(self.random_state)
-
-        # Define the classifier to use
-        if self.classifier == 'knn':
-            from sklearn.neighbors import KNeighborsClassifier
-            classifier = KNeighborsClassifier(
-                **self.kwargs)
-        elif self.classifier == 'decision-tree':
-            from sklearn.tree import DecisionTreeClassifier
-            classifier = DecisionTreeClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'random-forest':
-            from sklearn.ensemble import RandomForestClassifier
-            classifier = RandomForestClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'adaboost':
-            from sklearn.ensemble import AdaBoostClassifier
-            classifier = AdaBoostClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'gradient-boosting':
-            from sklearn.ensemble import GradientBoostingClassifier
-            classifier = GradientBoostingClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'linear-svm':
-            from sklearn.svm import LinearSVC
-            classifier = LinearSVC(random_state=random_state,
-                                   **self.kwargs)
-        else:
-            raise NotImplementedError
+        support_sample_weight = has_fit_parameter(self.estimator_,
+                                                  "sample_weight")
 
         X_resampled = []
         y_resampled = []
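
`has_fit_parameter` inspects the signature of the estimator's `fit` method; this is how `_sample` now decides between the weight-based and index-based training paths further down. For example, with two stock scikit-learn classifiers:

    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.utils.validation import has_fit_parameter

    # DecisionTreeClassifier.fit accepts sample_weight; KNeighborsClassifier.fit
    # does not, so BalanceCascade falls back to fancy indexing for the latter
    print(has_fit_parameter(DecisionTreeClassifier(), 'sample_weight'))  # True
    print(has_fit_parameter(KNeighborsClassifier(), 'sample_weight'))    # False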
@@ -185,6 +249,7 @@ def _sample(self, X, y):
             # return them later
             if self.return_indices:
                 idx_min = np.flatnonzero(y == self.min_c_)
+                idx_maj = np.flatnonzero(y == self.maj_c_)
 
             # Condition to initiliase before the search
             b_subset_search = True
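
The added `idx_maj` line supports the indexing fix visible in the later hunks: `idx_sel_from_maj` holds positions within the majority-class subset, so it must be mapped through `idx_maj` to recover indices into the original `X`. A small numpy illustration of the mapping:

    import numpy as np

    y = np.array([0, 1, 1, 0, 1])        # class 0 is the minority
    idx_maj = np.flatnonzero(y == 1)     # positions of majority samples: [1 2 4]
    idx_sel_from_maj = np.array([0, 2])  # selection *within* the majority subset
    print(idx_maj[idx_sel_from_maj])     # [1 4], valid indices into the full y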
@@ -227,27 +292,42 @@ def _sample(self, X, y):
             X_resampled.append(x_data)
             y_resampled.append(y_data)
             if self.return_indices:
-                idx_under.append(np.concatenate((idx_min, idx_sel_from_maj),
+                idx_under.append(np.concatenate((idx_min,
+                                                 idx_maj[idx_sel_from_maj]),
                                                 axis=0))
 
-            if (not (self.classifier == 'knn' or
-                     self.classifier == 'linear-svm') and
-                    self.bootstrap):
-                # Apply a bootstrap on x_data
-                curr_sample_weight = np.ones((y_data.size,), dtype=np.float64)
+            # Get the indices of interest
+            if self.bootstrap:
                 indices = random_state.randint(0, y_data.size, y_data.size)
-                sample_counts = np.bincount(indices, minlength=y_data.size)
-                curr_sample_weight *= sample_counts
+            else:
+                indices = np.arange(y_data.size)
 
-                # Train the classifier using the current data
-                classifier.fit(x_data, y_data, curr_sample_weight)
+            # Draw samples, using sample weights, and then fit
+            if support_sample_weight:
+                self.logger.debug('Sample-weight is supported')
+                curr_sample_weight = np.ones((y_data.size,), dtype=np.float64)
 
+                if self.bootstrap:
+                    self.logger.debug('Go for a bootstrap')
+                    sample_counts = np.bincount(indices, minlength=y_data.size)
+                    curr_sample_weight *= sample_counts
+                else:
+                    self.logger.debug('No bootstrap')
+                    mask = np.zeros(y_data.size, dtype=np.bool)
+                    mask[indices] = True
+                    not_indices_mask = ~mask
+                    curr_sample_weight[not_indices_mask] = 0
+
+                self.estimator_.fit(x_data, y_data,
+                                    sample_weight=curr_sample_weight)
+
+            # Draw samples, using a mask, and then fit
             else:
-                # Train the classifier using the current data
-                classifier.fit(x_data, y_data)
+                self.logger.debug('Sample-weight is not supported')
+                self.estimator_.fit(x_data[indices], y_data[indices])
 
             # Predict using only the majority class
-            pred_label = classifier.predict(N_x[idx_sel_from_maj, :])
+            pred_label = self.estimator_.predict(N_x[idx_sel_from_maj, :])
 
             # Basically let's find which sample have to be retained for the
             # next round
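
When the estimator supports `sample_weight`, the bootstrap above is emulated by weighting every sample by its draw count rather than materialising a resampled copy of the data; the fit is equivalent for estimators that use per-sample weights. A minimal sketch of the weight computation:

    import numpy as np

    rng = np.random.RandomState(0)
    n_samples = 8
    indices = rng.randint(0, n_samples, n_samples)  # bootstrap draw, replacement
    # Count how many times each sample was drawn; never-drawn samples get weight 0
    sample_counts = np.bincount(indices, minlength=n_samples)
    curr_sample_weight = np.ones(n_samples) * sample_counts
    print(curr_sample_weight)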
@@ -288,9 +368,8 @@ def _sample(self, X, y):
                 X_resampled.append(x_data)
                 y_resampled.append(y_data)
                 if self.return_indices:
-                    idx_under.append(np.concatenate((idx_min,
-                                                     idx_sel_from_maj),
-                                                    axis=0))
+                    idx_under.append(np.concatenate(
+                        (idx_min, idx_maj[idx_sel_from_maj]), axis=0))
 
                 self.logger.debug('Creation of the subset #%s', n_subsets)
 

@@ -321,9 +400,8 @@ def _sample(self, X, y):
                 X_resampled.append(x_data)
                 y_resampled.append(y_data)
                 if self.return_indices:
-                    idx_under.append(np.concatenate((idx_min,
-                                                     idx_sel_from_maj),
-                                                    axis=0))
+                    idx_under.append(np.concatenate(
+                        (idx_min, idx_maj[idx_sel_from_maj]), axis=0))
                 self.logger.debug('Creation of the subset #%s', n_subsets)
 
                 # We found a new subset, increase the counter
