Skip to content

Commit 8a0f010

Browse files
author
Guillaume Lemaitre
committed
Change RENN for scikit-learn compatibility
1 parent 28e1116 commit 8a0f010

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

unbalanced_dataset/under_sampling/edited_nearest_neighbours.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -376,10 +376,10 @@ def __init__(self, return_indices=False, random_state=None, verbose=True,
376376
self.max_iter = max_iter
377377

378378
self.enn_ = EditedNearestNeighbours(
379-
return_indices=return_indices,
380-
random_state=random_state, verbose=False,
381-
size_ngh=size_ngh, kind_sel=kind_sel,
382-
n_jobs=n_jobs)
379+
return_indices=self.return_indices,
380+
random_state=self.random_state, verbose=False,
381+
size_ngh=self.size_ngh, kind_sel=self.kind_sel,
382+
n_jobs=self.n_jobs)
383383

384384
def fit(self, X, y):
385385
"""Find the classes statistics before to perform sampling.
@@ -406,7 +406,7 @@ def fit(self, X, y):
406406

407407
return self
408408

409-
def transform(self, X, y):
409+
def sample(self, X, y):
410410
"""Resample the dataset.
411411
412412
Parameters
@@ -442,10 +442,10 @@ def transform(self, X, y):
442442
for n_iter in range(self.max_iter):
443443
prev_len = y_.shape[0]
444444
if self.return_indices:
445-
X_, y_, idx_ = self.enn_.transform(X_, y_)
445+
X_, y_, idx_ = self.enn_.sample(X_, y_)
446446
idx_under = idx_under[idx_]
447447
else:
448-
X_, y_ = self.enn_.transform(X_, y_)
448+
X_, y_ = self.enn_.sample(X_, y_)
449449

450450
if prev_len == y_.shape[0]:
451451
break

unbalanced_dataset/under_sampling/tests/test_repeated_edited_nearest_neighbours.py

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
n_samples=5000, random_state=RND_SEED)
2323

2424

25+
def test_enn_sk_estimator():
26+
"""Test the sklearn estimator compatibility"""
27+
check_estimator(RepeatedEditedNearestNeighbours)
28+
29+
2530
def test_renn_init():
2631
"""Test the initialisation of the object"""
2732

@@ -33,13 +38,23 @@ def test_renn_init():
3338
assert_equal(renn.size_ngh, 3)
3439
assert_equal(renn.kind_sel, 'all')
3540
assert_equal(renn.n_jobs, -1)
36-
assert_equal(renn.rs_, RND_SEED)
41+
assert_equal(renn.random_state, RND_SEED)
3742
assert_equal(renn.verbose, verbose)
3843
assert_equal(renn.min_c_, None)
3944
assert_equal(renn.maj_c_, None)
4045
assert_equal(renn.stats_c_, {})
4146

4247

48+
def test_renn_iter_wrong():
49+
"""Test either if an error is raised when the numbr of iteration
50+
is wrong"""
51+
52+
# Create the object
53+
max_iter = -1
54+
assert_raises(ValueError, RepeatedEditedNearestNeighbours,
55+
max_iter=max_iter, random_state=RND_SEED)
56+
57+
4358
def test_renn_fit_single_class():
4459
"""Test either if an error when there is a single class"""
4560

@@ -48,7 +63,7 @@ def test_renn_fit_single_class():
4863
# Resample the data
4964
# Create a wrong y
5065
y_single_class = np.zeros((X.shape[0], ))
51-
assert_raises(RuntimeError, renn.fit, X, y_single_class)
66+
assert_warns(RuntimeWarning, renn.fit, X, y_single_class)
5267

5368

5469
def test_renn_fit():
@@ -66,21 +81,21 @@ def test_renn_fit():
6681
assert_equal(renn.stats_c_[1], 4500)
6782

6883

69-
def test_renn_transform_wt_fit():
70-
"""Test either if an error is raised when transform is called before
84+
def test_renn_sample_wt_fit():
85+
"""Test either if an error is raised when sample is called before
7186
fitting"""
7287

7388
# Create the object
7489
renn = RepeatedEditedNearestNeighbours(random_state=RND_SEED)
75-
assert_raises(RuntimeError, renn.transform, X, Y)
90+
assert_raises(RuntimeError, renn.sample, X, Y)
7691

7792

78-
def test_renn_fit_transform():
79-
"""Test the fit transform routine"""
93+
def test_renn_fit_sample():
94+
"""Test the fit sample routine"""
8095

8196
# Resample the data
8297
renn = RepeatedEditedNearestNeighbours(random_state=RND_SEED)
83-
X_resampled, y_resampled = renn.fit_transform(X, Y)
98+
X_resampled, y_resampled = renn.fit_sample(X, Y)
8499

85100
currdir = os.path.dirname(os.path.abspath(__file__))
86101
X_gt = np.load(os.path.join(currdir, 'data', 'renn_x.npy'))
@@ -89,13 +104,13 @@ def test_renn_fit_transform():
89104
assert_array_equal(y_resampled, y_gt)
90105

91106

92-
def test_renn_fit_transform_with_indices():
93-
"""Test the fit transform routine with indices support"""
107+
def test_renn_fit_sample_with_indices():
108+
"""Test the fit sample routine with indices support"""
94109

95110
# Resample the data
96111
renn = RepeatedEditedNearestNeighbours(return_indices=True,
97112
random_state=RND_SEED)
98-
X_resampled, y_resampled, idx_under = renn.fit_transform(X, Y)
113+
X_resampled, y_resampled, idx_under = renn.fit_sample(X, Y)
99114

100115
currdir = os.path.dirname(os.path.abspath(__file__))
101116
X_gt = np.load(os.path.join(currdir, 'data', 'renn_x.npy'))
@@ -106,13 +121,13 @@ def test_renn_fit_transform_with_indices():
106121
assert_array_equal(idx_under, idx_gt)
107122

108123

109-
def test_renn_fit_transform_mode():
110-
"""Test the fit transform routine using the mode as selection"""
124+
def test_renn_fit_sample_mode():
125+
"""Test the fit sample routine using the mode as selection"""
111126

112127
# Resample the data
113128
renn = RepeatedEditedNearestNeighbours(random_state=RND_SEED,
114129
kind_sel='mode')
115-
X_resampled, y_resampled = renn.fit_transform(X, Y)
130+
X_resampled, y_resampled = renn.fit_sample(X, Y)
116131

117132
currdir = os.path.dirname(os.path.abspath(__file__))
118133
X_gt = np.load(os.path.join(currdir, 'data', 'renn_x_mode.npy'))

0 commit comments

Comments
 (0)