@@ -1,13 +1,18 @@
 """Class to perform under-sampling using balance cascade."""
 from __future__ import print_function

+import warnings
+
 import numpy as np
+
+from sklearn.base import ClassifierMixin
+from sklearn.neighbors import KNeighborsClassifier
 from sklearn.utils import check_random_state
+from sklearn.utils.validation import has_fit_parameter

-from ..base import BaseBinarySampler
+from six import string_types

-ESTIMATOR_KIND = ('knn', 'decision-tree', 'random-forest', 'adaboost',
-                  'gradient-boosting', 'linear-svm')
+from ..base import BaseBinarySampler


 class BalanceCascade(BaseBinarySampler):
@@ -40,18 +45,29 @@ class BalanceCascade(BaseBinarySampler):
         the training will be selected that could lead to a large number of
         subsets. We can probably deduce this number empirically.

-    classifier : str, optional (default='knn')
+    classifier : str, optional (default=None)
         The classifier that will be selected to confront the prediction
         with the real labels. The choices are the following: 'knn',
         'decision-tree', 'random-forest', 'adaboost', 'gradient-boosting'
         and 'linear-svm'.

+        NOTE: `classifier` is deprecated from 0.2 and will be replaced in
+        0.4. Use `estimator` instead.
+
+    estimator : object, optional (default=KNeighborsClassifier())
+        An estimator inherited from `sklearn.base.ClassifierMixin` and having
+        an attribute `predict_proba`.
+
     bootstrap : bool, optional (default=True)
         Whether to bootstrap the data before each iteration.

     **kwargs : keywords
         The parameters associated with the classifier provided.

+        NOTE: `**kwargs` has been deprecated from 0.2 and will be replaced in
+        0.4. Use an `estimator` object instead to pass parameters associated
+        with an estimator.
+
     Attributes
     ----------
     min_c_ : str or int
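
A minimal usage sketch of the new `estimator` parameter next to the deprecated
`classifier` string (the `imblearn.under_sampling` import path and the 0.2-era
constructor are assumed):

    from sklearn.tree import DecisionTreeClassifier
    from imblearn.under_sampling import BalanceCascade

    # Preferred from 0.2 on: pass a classifier object and configure it directly.
    bc = BalanceCascade(estimator=DecisionTreeClassifier(max_depth=3),
                        random_state=0)

    # Deprecated: string identifier plus **kwargs; emits a DeprecationWarning.
    bc_old = BalanceCascade(classifier='decision-tree', max_depth=3)
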
@@ -100,16 +116,97 @@ class BalanceCascade(BaseBinarySampler):
     """

     def __init__(self, ratio='auto', return_indices=False, random_state=None,
-                 n_max_subset=None, classifier='knn', bootstrap=True,
-                 **kwargs):
+                 n_max_subset=None, classifier=None, estimator=None,
+                 bootstrap=True, **kwargs):
         super(BalanceCascade, self).__init__(ratio=ratio,
                                              random_state=random_state)
         self.return_indices = return_indices
         self.classifier = classifier
+        self.estimator = estimator
         self.n_max_subset = n_max_subset
         self.bootstrap = bootstrap
         self.kwargs = kwargs

+    def _validate_estimator(self):
+        """Private function to create the classifier."""
+
+        if self.classifier is not None:
+            warnings.warn('`classifier` will be replaced in version 0.4. Use'
+                          ' an `estimator` instead.', DeprecationWarning)
+            self.estimator = self.classifier
+
+        if (self.estimator is not None and
+                isinstance(self.estimator, ClassifierMixin) and
+                hasattr(self.estimator, 'predict')):
+            self.estimator_ = self.estimator
+        elif self.estimator is None:
+            self.estimator_ = KNeighborsClassifier()
+        # To be removed in 0.4
+        elif (self.estimator is not None and
+              isinstance(self.estimator, string_types)):
+            warnings.warn('`estimator` will be replaced in version 0.4. Use a'
+                          ' classifier object instead of a string.',
+                          DeprecationWarning)
+            # Define the classifier to use
+            if self.estimator == 'knn':
+                self.estimator_ = KNeighborsClassifier(
+                    **self.kwargs)
+            elif self.estimator == 'decision-tree':
+                from sklearn.tree import DecisionTreeClassifier
+                self.estimator_ = DecisionTreeClassifier(
+                    random_state=self.random_state,
+                    **self.kwargs)
+            elif self.estimator == 'random-forest':
+                from sklearn.ensemble import RandomForestClassifier
+                self.estimator_ = RandomForestClassifier(
+                    random_state=self.random_state,
+                    **self.kwargs)
+            elif self.estimator == 'adaboost':
+                from sklearn.ensemble import AdaBoostClassifier
+                self.estimator_ = AdaBoostClassifier(
+                    random_state=self.random_state,
+                    **self.kwargs)
+            elif self.estimator == 'gradient-boosting':
+                from sklearn.ensemble import GradientBoostingClassifier
+                self.estimator_ = GradientBoostingClassifier(
+                    random_state=self.random_state,
+                    **self.kwargs)
+            elif self.estimator == 'linear-svm':
+                from sklearn.svm import LinearSVC
+                self.estimator_ = LinearSVC(random_state=self.random_state,
+                                            **self.kwargs)
+            else:
+                raise NotImplementedError
+        else:
+            raise ValueError('Invalid parameter `estimator`')
+
+        self.logger.debug(self.estimator_)
+
+    def fit(self, X, y):
+        """Find the classes statistics before performing sampling.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_samples, n_features)
+            Matrix containing the data which have to be sampled.
+
+        y : ndarray, shape (n_samples, )
+            Corresponding label for each sample in X.
+
+        Returns
+        -------
+        self : object
+            Return self.
+
+        """
+
+        super(BalanceCascade, self).fit(X, y)
+
+        self._validate_estimator()
+
+        return self
+

     def _sample(self, X, y):
         """Resample the dataset.
@@ -135,42 +232,9 @@ def _sample(self, X, y):

         """

-        if self.classifier not in ESTIMATOR_KIND:
-            raise NotImplementedError
-
         random_state = check_random_state(self.random_state)
-
-        # Define the classifier to use
-        if self.classifier == 'knn':
-            from sklearn.neighbors import KNeighborsClassifier
-            classifier = KNeighborsClassifier(
-                **self.kwargs)
-        elif self.classifier == 'decision-tree':
-            from sklearn.tree import DecisionTreeClassifier
-            classifier = DecisionTreeClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'random-forest':
-            from sklearn.ensemble import RandomForestClassifier
-            classifier = RandomForestClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'adaboost':
-            from sklearn.ensemble import AdaBoostClassifier
-            classifier = AdaBoostClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'gradient-boosting':
-            from sklearn.ensemble import GradientBoostingClassifier
-            classifier = GradientBoostingClassifier(
-                random_state=random_state,
-                **self.kwargs)
-        elif self.classifier == 'linear-svm':
-            from sklearn.svm import LinearSVC
-            classifier = LinearSVC(random_state=random_state,
-                                   **self.kwargs)
-        else:
-            raise NotImplementedError
+        support_sample_weight = has_fit_parameter(self.estimator_,
+                                                  "sample_weight")

         X_resampled = []
         y_resampled = []
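
The `support_sample_weight` flag above relies on scikit-learn's
`has_fit_parameter` helper, which inspects the signature of `fit`; a quick
illustration:

    from sklearn.neighbors import KNeighborsClassifier
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.utils.validation import has_fit_parameter

    # DecisionTreeClassifier.fit accepts a sample_weight argument -> True
    print(has_fit_parameter(DecisionTreeClassifier(), "sample_weight"))
    # KNeighborsClassifier.fit has no sample_weight parameter -> False
    print(has_fit_parameter(KNeighborsClassifier(), "sample_weight"))
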
@@ -185,6 +249,7 @@ def _sample(self, X, y):
         # return them later
         if self.return_indices:
             idx_min = np.flatnonzero(y == self.min_c_)
+            idx_maj = np.flatnonzero(y == self.maj_c_)

         # Condition to initialise before the search
         b_subset_search = True
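
The new `idx_maj` array is what lets per-subset selections be mapped back to
positions in the original arrays; a toy example of the
`idx_maj[idx_sel_from_maj]` indexing used below:

    import numpy as np

    y = np.array([0, 1, 1, 0, 1, 1, 1])   # 0: minority, 1: majority
    idx_maj = np.flatnonzero(y == 1)      # [1 2 4 5 6], positions in y
    idx_sel_from_maj = np.array([0, 3])   # positions within the majority subset
    print(idx_maj[idx_sel_from_maj])      # [1 5], valid indices into X and y
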
@@ -227,27 +292,42 @@ def _sample(self, X, y):
             X_resampled.append(x_data)
             y_resampled.append(y_data)
             if self.return_indices:
-                idx_under.append(np.concatenate((idx_min, idx_sel_from_maj),
+                idx_under.append(np.concatenate((idx_min,
+                                                 idx_maj[idx_sel_from_maj]),
                                                 axis=0))

-            if (not (self.classifier == 'knn' or
-                     self.classifier == 'linear-svm') and
-                    self.bootstrap):
-                # Apply a bootstrap on x_data
-                curr_sample_weight = np.ones((y_data.size,), dtype=np.float64)
+            # Get the indices of interest
+            if self.bootstrap:
                 indices = random_state.randint(0, y_data.size, y_data.size)
-                sample_counts = np.bincount(indices, minlength=y_data.size)
-                curr_sample_weight *= sample_counts
+            else:
+                indices = np.arange(y_data.size)

-            # Train the classifier using the current data
-            classifier.fit(x_data, y_data, curr_sample_weight)
+            # Draw samples, using sample weights, and then fit
+            if support_sample_weight:
+                self.logger.debug('Sample-weight is supported')
+                curr_sample_weight = np.ones((y_data.size,), dtype=np.float64)

+                if self.bootstrap:
+                    self.logger.debug('Go for a bootstrap')
+                    sample_counts = np.bincount(indices, minlength=y_data.size)
+                    curr_sample_weight *= sample_counts
+                else:
+                    self.logger.debug('No bootstrap')
+                    mask = np.zeros(y_data.size, dtype=bool)
+                    mask[indices] = True
+                    not_indices_mask = ~mask
+                    curr_sample_weight[not_indices_mask] = 0
+
+                self.estimator_.fit(x_data, y_data,
+                                    sample_weight=curr_sample_weight)
+
+            # Draw samples, using a mask, and then fit
             else:
-                # Train the classifier using the current data
-                classifier.fit(x_data, y_data)
+                self.logger.debug('Sample-weight is not supported')
+                self.estimator_.fit(x_data[indices], y_data[indices])

             # Predict using only the majority class
-            pred_label = classifier.predict(N_x[idx_sel_from_maj, :])
+            pred_label = self.estimator_.predict(N_x[idx_sel_from_maj, :])

             # Basically let's find which samples have to be retained for the
             # next round
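
For estimators that accept `sample_weight`, the bootstrap above is emulated by
weighting rather than by copying rows; a sketch of the equivalence (variable
names are illustrative):

    import numpy as np

    rng = np.random.RandomState(0)
    n = 6
    indices = rng.randint(0, n, n)   # bootstrap draw with replacement
    weights = np.ones(n) * np.bincount(indices, minlength=n)
    # weights[i] is the number of times sample i was drawn, so fitting with
    # these sample weights matches fitting on the explicitly resampled rows.
    print(indices, weights)
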
@@ -288,9 +368,8 @@ def _sample(self, X, y):
                 X_resampled.append(x_data)
                 y_resampled.append(y_data)
                 if self.return_indices:
-                    idx_under.append(np.concatenate((idx_min,
-                                                     idx_sel_from_maj),
-                                                    axis=0))
+                    idx_under.append(np.concatenate(
+                        (idx_min, idx_maj[idx_sel_from_maj]), axis=0))

                 self.logger.debug('Creation of the subset #%s', n_subsets)

@@ -321,9 +400,8 @@ def _sample(self, X, y):
                 X_resampled.append(x_data)
                 y_resampled.append(y_data)
                 if self.return_indices:
-                    idx_under.append(np.concatenate((idx_min,
-                                                     idx_sel_from_maj),
-                                                    axis=0))
+                    idx_under.append(np.concatenate(
+                        (idx_min, idx_maj[idx_sel_from_maj]), axis=0))
                 self.logger.debug('Creation of the subset #%s', n_subsets)

             # We found a new subset, increase the counter
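
End to end, the sampler is used as below (a sketch assuming the 0.2-era
`fit_sample` API; with `return_indices=True` a third output holds the indices
into the original arrays, which the `idx_maj` mapping in this change fixes for
the majority class):

    from sklearn.datasets import make_classification
    from imblearn.under_sampling import BalanceCascade

    X, y = make_classification(n_samples=200, weights=[0.9, 0.1],
                               random_state=0)
    bc = BalanceCascade(return_indices=True, random_state=0)
    X_res, y_res, idx_res = bc.fit_sample(X, y)
    # One balanced subset per cascade iteration.
    print(len(X_res), X_res[0].shape)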