Skip to content

Commit e6e9a28

Browse files
Author: Guillaume Lemaitre
Commit message: Improve testing of instance hardness threshold
1 parent ef11cc8 commit e6e9a28

File tree

4 files changed: +115 / -25 lines changed

unbalanced_dataset/ensemble/balance_cascade.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,12 @@ def __init__(self, ratio='auto', return_indices=False, random_state=None,
135135
verbose=verbose,
136136
random_state=random_state)
137137
# Define the classifier to use
138-
self.classifier = classifier
138+
list_classifier = ('knn', 'decision-tree', 'random-forest', 'adaboost',
139+
'gradient-boosting', 'linear-svm')
140+
if classifier in list_classifier:
141+
self.classifier = classifier
142+
else:
143+
raise NotImplementedError
139144
self.n_max_subset = n_max_subset
140145
self.bootstrap = bootstrap
141146
self.kwargs = kwargs
@@ -223,8 +228,7 @@ def sample(self, X, y):
223228
classifier = LinearSVC(random_state=self.random_state,
224229
**self.kwargs)
225230
else:
226-
raise RuntimeError('UnbalancedData.BalanceCascade: classifier '
227-
'not yet supported.')
231+
raise NotImplementedError
228232

229233
X_resampled = []
230234
y_resampled = []

unbalanced_dataset/ensemble/tests/test_balance_cascade.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -302,11 +302,9 @@ def test_init_wrong_classifier():
302302
ratio = 'auto'
303303
classifier = 'rnd'
304304

305-
bc = BalanceCascade(ratio=ratio, random_state=RND_SEED,
306-
return_indices=True, classifier=classifier)
307-
308-
# Create the sampling object
309-
assert_raises(RuntimeError, bc.fit_sample, X, Y)
305+
assert_raises(NotImplementedError, BalanceCascade, ratio=ratio,
306+
random_state=RND_SEED, return_indices=True,
307+
classifier=classifier)
310308

311309

312310
def test_fit_sample_auto_early_stop():

unbalanced_dataset/under_sampling/instance_hardness_threshold.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,13 @@ def __init__(self, estimator='linear-svm', ratio='auto',
132132
random_state=random_state,
133133
verbose=verbose)
134134

135-
# if not hasattr(estimator, 'predict_proba'):
136-
# raise ValueError('Estimator does not have predict_proba method.')
137-
# else:
138-
# self.estimator = estimator
139-
140135
# Define the estimator to use
141-
self.estimator = estimator
136+
list_estimator = ('knn', 'decision-tree', 'random-forest', 'adaboost',
137+
'gradient-boosting', 'linear-svm')
138+
if estimator in list_estimator:
139+
self.estimator = estimator
140+
else:
141+
raise NotImplementedError
142142
self.kwargs = kwargs
143143
self.cv = cv
144144
self.n_jobs = n_jobs
@@ -200,7 +200,6 @@ def sample(self, X, y):
200200
if self.estimator == 'knn':
201201
from sklearn.neighbors import KNeighborsClassifier
202202
estimator = KNeighborsClassifier(
203-
random_state=self.random_state,
204203
**self.kwargs)
205204
elif self.estimator == 'decision-tree':
206205
from sklearn.tree import DecisionTreeClassifier
@@ -227,8 +226,7 @@ def sample(self, X, y):
227226
estimator = SVC(probability=True,
228227
random_state=self.random_state, **self.kwargs)
229228
else:
230-
raise ValueError('UnbalancedData.BalanceCascade: classifier '
231-
'not yet supported.')
229+
raise NotImplementedError
232230

233231
# Create the different folds
234232
skf = StratifiedKFold(y, n_folds=self.cv, shuffle=False,

unbalanced_dataset/under_sampling/tests/test_instance_hardness_threshold.py

Lines changed: 98 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,14 @@ def test_iht_bad_ratio():
5454
ratio=ratio)
5555

5656

57-
# def test_iht_estimator_no_proba():
58-
# """Test either if an error is raised when the estimator does not have
59-
# predict_proba function"""
57+
def test_iht_wrong_estimator():
58+
"""Test either if an error is raised when the estimator is unknown"""
6059

61-
# # Resample the data
62-
# ratio = 0.5
63-
# est = 'linear-svm'
64-
# assert_raises(ValueError, InstanceHardnessThreshold, est, ratio=ratio,
65-
# random_state=RND_SEED)
60+
# Resample the data
61+
ratio = 0.5
62+
est = 'rnd'
63+
assert_raises(NotImplementedError, InstanceHardnessThreshold, est,
64+
ratio=ratio, random_state=RND_SEED)
6665

6766
def test_iht_init():
6867
"""Test the initialisation of the object"""
@@ -174,3 +173,94 @@ def test_iht_fit_sample_half():
174173
y_gt = np.load(os.path.join(currdir, 'data', 'iht_y_05.npy'))
175174
assert_array_equal(X_resampled, X_gt)
176175
assert_array_equal(y_resampled, y_gt)
176+
177+
178+
def test_iht_fit_sample_knn():
179+
"""Test the fit sample routine with knn"""
180+
181+
# Resample the data
182+
est = 'knn'
183+
iht = InstanceHardnessThreshold(est, random_state=RND_SEED)
184+
X_resampled, y_resampled = iht.fit_sample(X, Y)
185+
186+
currdir = os.path.dirname(os.path.abspath(__file__))
187+
X_gt = np.load(os.path.join(currdir, 'data', 'iht_x_knn.npy'))
188+
y_gt = np.load(os.path.join(currdir, 'data', 'iht_y_knn.npy'))
189+
assert_array_equal(X_resampled, X_gt)
190+
assert_array_equal(y_resampled, y_gt)
191+
192+
193+
def test_iht_fit_sample_decision_tree():
194+
"""Test the fit sample routine with decision-tree"""
195+
196+
# Resample the data
197+
est = 'decision-tree'
198+
iht = InstanceHardnessThreshold(est, random_state=RND_SEED)
199+
X_resampled, y_resampled = iht.fit_sample(X, Y)
200+
201+
currdir = os.path.dirname(os.path.abspath(__file__))
202+
X_gt = np.load(os.path.join(currdir, 'data', 'iht_x_dt.npy'))
203+
y_gt = np.load(os.path.join(currdir, 'data', 'iht_y_dt.npy'))
204+
assert_array_equal(X_resampled, X_gt)
205+
assert_array_equal(y_resampled, y_gt)
206+
207+
208+
def test_iht_fit_sample_random_forest():
209+
"""Test the fit sample routine with random forest"""
210+
211+
# Resample the data
212+
est = 'random-forest'
213+
iht = InstanceHardnessThreshold(est, random_state=RND_SEED)
214+
X_resampled, y_resampled = iht.fit_sample(X, Y)
215+
216+
currdir = os.path.dirname(os.path.abspath(__file__))
217+
X_gt = np.load(os.path.join(currdir, 'data', 'iht_x_rf.npy'))
218+
y_gt = np.load(os.path.join(currdir, 'data', 'iht_y_rf.npy'))
219+
assert_array_equal(X_resampled, X_gt)
220+
assert_array_equal(y_resampled, y_gt)
221+
222+
223+
def test_iht_fit_sample_adaboost():
224+
"""Test the fit sample routine with adaboost"""
225+
226+
# Resample the data
227+
est = 'adaboost'
228+
iht = InstanceHardnessThreshold(est, random_state=RND_SEED)
229+
X_resampled, y_resampled = iht.fit_sample(X, Y)
230+
231+
currdir = os.path.dirname(os.path.abspath(__file__))
232+
X_gt = np.load(os.path.join(currdir, 'data', 'iht_x_adb.npy'))
233+
y_gt = np.load(os.path.join(currdir, 'data', 'iht_y_adb.npy'))
234+
assert_array_equal(X_resampled, X_gt)
235+
assert_array_equal(y_resampled, y_gt)
236+
237+
238+
239+
def test_iht_fit_sample_gradient_boosting():
240+
"""Test the fit sample routine with gradient boosting"""
241+
242+
# Resample the data
243+
est = 'gradient-boosting'
244+
iht = InstanceHardnessThreshold(est, random_state=RND_SEED)
245+
X_resampled, y_resampled = iht.fit_sample(X, Y)
246+
247+
currdir = os.path.dirname(os.path.abspath(__file__))
248+
X_gt = np.load(os.path.join(currdir, 'data', 'iht_x_gb.npy'))
249+
y_gt = np.load(os.path.join(currdir, 'data', 'iht_y_gb.npy'))
250+
assert_array_equal(X_resampled, X_gt)
251+
assert_array_equal(y_resampled, y_gt)
252+
253+
254+
def test_iht_fit_sample_linear_svm():
255+
"""Test the fit sample routine with linear SVM"""
256+
257+
# Resample the data
258+
est = 'linear-svm'
259+
iht = InstanceHardnessThreshold(est, random_state=RND_SEED)
260+
X_resampled, y_resampled = iht.fit_sample(X, Y)
261+
262+
currdir = os.path.dirname(os.path.abspath(__file__))
263+
X_gt = np.load(os.path.join(currdir, 'data', 'iht_x_svm.npy'))
264+
y_gt = np.load(os.path.join(currdir, 'data', 'iht_y_svm.npy'))
265+
assert_array_equal(X_resampled, X_gt)
266+
assert_array_equal(y_resampled, y_gt)

0 commit comments

Comments (0)