@@ -364,14 +364,56 @@ def _sample(self, X, y):
364
364
365
365
prev_len = y_ .shape [0 ]
366
366
if self .return_indices :
367
- X_ , y_ , idx_ = self .enn_ .fit_sample (X_ , y_ )
368
- idx_under = idx_under [idx_ ]
367
+ X_enn , y_enn , idx_enn = self .enn_ .fit_sample (X_ , y_ )
369
368
else :
370
- X_ , y_ = self .enn_ .fit_sample (X_ , y_ )
371
-
372
- if prev_len == y_ .shape [0 ]:
369
+ X_enn , y_enn = self .enn_ .fit_sample (X_ , y_ )
370
+
371
+ # Check the stopping criterion
372
+ # 1. If there are no changes in the vector y
373
+ # 2. If the number of samples in one of the other classes becomes lower
374
+ # than the number of samples in the minority class
375
+ # 3. If one of the classes disappears
376
+
377
+ # Case 1
378
+ b_conv = (prev_len == y_enn .shape [0 ])
379
+
380
+ # Case 2
381
+ stats_enn = Counter (y_enn )
382
+ self .logger .debug ('Current ENN stats: %s' , stats_enn )
383
+ # Get the number of samples in the non-minority classes
384
+ count_non_min = np .array ([val for val , key
385
+ in zip (stats_enn .itervalues (),
386
+ stats_enn .iterkeys ())
387
+ if key != self .min_c_ ])
388
+ self .logger .debug ('Number of samples in the non-majority'
389
+ ' classes: %s' , count_non_min )
390
+ # Check whether the minority class stops being the minority
391
+ b_min_bec_maj = np .any (count_non_min < self .stats_c_ [self .min_c_ ])
392
+
393
+ # Case 3
394
+ b_remove_maj_class = (len (stats_enn ) < len (self .stats_c_ ))
395
+
396
+ if b_conv or b_min_bec_maj or b_remove_maj_class :
397
+ # If this is a normal convergence, get the last data
398
+ if b_conv :
399
+ if self .return_indices :
400
+ X_ , y_ , = X_enn , y_enn
401
+ idx_under = idx_under [idx_enn ]
402
+ else :
403
+ X_ , y_ , = X_enn , y_enn
404
+ # Log the variables to explain the stop of the algorithm
405
+ self .logger .debug ('RENN converged: %s' , b_conv )
406
+ self .logger .debug ('RENN minority become majority: %s' ,
407
+ b_min_bec_maj )
408
+ self .logger .debug ('RENN remove one class: %s' ,
409
+ b_remove_maj_class )
373
410
break
374
411
412
+ # Update the data for the next iteration
413
+ X_ , y_ , = X_enn , y_enn
414
+ if self .return_indices :
415
+ idx_under = idx_under [idx_enn ]
416
+
375
417
self .logger .info ('Under-sampling performed: %s' , Counter (y_ ))
376
418
377
419
X_resampled , y_resampled = X_ , y_
0 commit comments