@@ -376,12 +376,16 @@ def __init__(self, return_indices=False, random_state=None, verbose=True,
376
376
self .kind_sel = kind_sel
377
377
self .n_jobs = n_jobs
378
378
379
- self .max_iter = max_iter
379
+ if max_iter < 2 :
380
+ raise ValueError ('max_iter must be greater than 1.' )
381
+ else :
382
+ self .max_iter = max_iter
383
+
380
384
self .enn_ = EditedNearestNeighbours (
381
- return_indices = return_indices ,
382
- random_state = random_state , verbose = False ,
383
- size_ngh = size_ngh , kind_sel = kind_sel ,
384
- n_jobs = n_jobs )
385
+ return_indices = return_indices ,
386
+ random_state = random_state , verbose = False ,
387
+ size_ngh = size_ngh , kind_sel = kind_sel ,
388
+ n_jobs = n_jobs )
385
389
386
390
def fit (self , X , y ):
387
391
"""Find the classes statistics before to perform sampling.
@@ -405,7 +409,7 @@ def fit(self, X, y):
405
409
406
410
super (RepeatedEditedNearestNeighbours , self ).fit (X , y )
407
411
self .enn_ .fit (X , y )
408
-
412
+
409
413
return self
410
414
411
415
def transform (self , X , y ):
@@ -434,11 +438,10 @@ def transform(self, X, y):
434
438
"""
435
439
# Check the consistency of X and y
436
440
X , y = check_X_y (X , y )
437
-
438
- X_ , y_ = X , y
441
+ X_ , y_ = X .copy (), y .copy ()
439
442
440
443
if self .return_indices :
441
- idx_under = np .arange (len ( X .shape [0 ]) , dtype = int )
444
+ idx_under = np .arange (X .shape [0 ], dtype = int )
442
445
443
446
prev_len = y .shape [0 ]
444
447
@@ -456,7 +459,6 @@ def transform(self, X, y):
456
459
if self .verbose :
457
460
print ("Under-sampling performed: {}" .format (Counter (y_ )))
458
461
459
- #X_resampled, y_resampled = X_.copy(), y_.copy()
460
462
X_resampled , y_resampled = X_ , y_
461
463
462
464
# Check if the indices of the samples selected should be returned too
@@ -465,4 +467,3 @@ def transform(self, X, y):
465
467
return X_resampled , y_resampled , idx_under
466
468
else :
467
469
return X_resampled , y_resampled
468
-
0 commit comments