Skip to content

Commit c70ae30

Browse files
committed
RepeatedEditedNearestNeighbors pep8
1 parent 6f3c6fa commit c70ae30

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

unbalanced_dataset/under_sampling/edited_nearest_neighbours.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -376,12 +376,16 @@ def __init__(self, return_indices=False, random_state=None, verbose=True,
376376
self.kind_sel = kind_sel
377377
self.n_jobs = n_jobs
378378

379-
self.max_iter = max_iter
379+
if max_iter < 2:
380+
raise ValueError('max_iter must be greater than 1.')
381+
else:
382+
self.max_iter = max_iter
383+
380384
self.enn_ = EditedNearestNeighbours(
381-
return_indices=return_indices,
382-
random_state=random_state, verbose=False,
383-
size_ngh=size_ngh, kind_sel=kind_sel,
384-
n_jobs=n_jobs)
385+
return_indices=return_indices,
386+
random_state=random_state, verbose=False,
387+
size_ngh=size_ngh, kind_sel=kind_sel,
388+
n_jobs=n_jobs)
385389

386390
def fit(self, X, y):
387391
"""Find the classes statistics before to perform sampling.
@@ -405,7 +409,7 @@ def fit(self, X, y):
405409

406410
super(RepeatedEditedNearestNeighbours, self).fit(X, y)
407411
self.enn_.fit(X, y)
408-
412+
409413
return self
410414

411415
def transform(self, X, y):
@@ -434,11 +438,10 @@ def transform(self, X, y):
434438
"""
435439
# Check the consistency of X and y
436440
X, y = check_X_y(X, y)
437-
438-
X_, y_ = X, y
441+
X_, y_ = X.copy(), y.copy()
439442

440443
if self.return_indices:
441-
idx_under = np.arange(len(X.shape[0]), dtype=int)
444+
idx_under = np.arange(X.shape[0], dtype=int)
442445

443446
prev_len = y.shape[0]
444447

@@ -456,7 +459,6 @@ def transform(self, X, y):
456459
if self.verbose:
457460
print("Under-sampling performed: {}".format(Counter(y_)))
458461

459-
#X_resampled, y_resampled = X_.copy(), y_.copy()
460462
X_resampled, y_resampled = X_, y_
461463

462464
# Check if the indices of the samples selected should be returned too
@@ -465,4 +467,3 @@ def transform(self, X, y):
465467
return X_resampled, y_resampled, idx_under
466468
else:
467469
return X_resampled, y_resampled
468-

0 commit comments

Comments
 (0)