Skip to content

Commit e9b5a81

Browse files
authored
Merge pull request #73 from dvro/renn
RENN - Repeated Edited Nearest Neighbors undersampling method
2 parents e2115d6 + c70ae30 commit e9b5a81

File tree

3 files changed

+292
-0
lines changed

3 files changed

+292
-0
lines changed
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
"""
==================================
Repeated Edited nearest-neighbours
==================================

An illustration of the repeated edited nearest-neighbours method.

"""

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import make_classification
from sklearn.decomposition import PCA

from unbalanced_dataset.under_sampling import (
    EditedNearestNeighbours,
    RepeatedEditedNearestNeighbours,
)

print(__doc__)

sns.set()

# Define some color for the plotting
almost_black = '#262626'
palette = sns.color_palette()


def plot_classes(ax, X_2d, labels, title):
    """Scatter the two classes of a 2D-projected dataset on one axis.

    Parameters
    ----------
    ax : matplotlib axes
        Axes on which the samples are drawn.

    X_2d : ndarray, shape (n_samples, 2)
        Samples projected in the 2D visualisation space.

    labels : ndarray, shape (n_samples, )
        Class label (0 or 1) of each row of `X_2d`.

    title : str
        Title given to the axes.

    """
    # Class #0 and class #1 use the first and third colors of the palette
    for class_id, color in ((0, palette[0]), (1, palette[2])):
        mask = labels == class_id
        ax.scatter(X_2d[mask, 0], X_2d[mask, 1],
                   label="Class #{}".format(class_id), alpha=.5,
                   edgecolor=almost_black, facecolor=color, linewidth=0.15)
    ax.set_title(title)


# Generate the dataset
X, y = make_classification(n_classes=2, class_sep=1.25, weights=[0.3, 0.7],
                           n_informative=3, n_redundant=1, flip_y=0,
                           n_features=5, n_clusters_per_class=1,
                           n_samples=5000, random_state=10)

# Instantiate a PCA object for the sake of easy visualisation
pca = PCA(n_components=2)
# Fit and transform x to visualise inside a 2D feature space
X_vis = pca.fit_transform(X)

# Three subplots, unpack the axes array immediately
f, (ax1, ax2, ax3) = plt.subplots(1, 3)

plot_classes(ax1, X_vis, y, 'Original set')

# Apply the ENN
print('ENN')
enn = EditedNearestNeighbours()
X_resampled, y_resampled = enn.fit_transform(X, y)
# Reuse the PCA fitted on the original data so the three plots share axes
X_res_vis = pca.transform(X_resampled)

plot_classes(ax2, X_res_vis, y_resampled, 'Edited nearest neighbours')

# Apply the RENN
print('RENN')
renn = RepeatedEditedNearestNeighbours()
X_resampled, y_resampled = renn.fit_transform(X, y)
X_res_vis = pca.transform(X_resampled)

plot_classes(ax3, X_res_vis, y_resampled,
             'Repeated Edited nearest neighbours')

plt.show()

unbalanced_dataset/under_sampling/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .one_sided_selection import OneSidedSelection
1313
from .neighbourhood_cleaning_rule import NeighbourhoodCleaningRule
1414
from .edited_nearest_neighbours import EditedNearestNeighbours
15+
from .edited_nearest_neighbours import RepeatedEditedNearestNeighbours
1516
from .instance_hardness_threshold import InstanceHardnessThreshold
1617

1718
__all__ = ['UnderSampler',
@@ -23,4 +24,5 @@
2324
'OneSidedSelection',
2425
'NeighbourhoodCleaningRule',
2526
'EditedNearestNeighbours',
27+
'RepeatedEditedNearestNeighbours',
2628
'InstanceHardnessThreshold']

unbalanced_dataset/under_sampling/edited_nearest_neighbours.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,3 +251,219 @@ def transform(self, X, y):
251251
return X_resampled, y_resampled, idx_under
252252
else:
253253
return X_resampled, y_resampled
254+
255+
256+
class RepeatedEditedNearestNeighbours(UnderSampler):
    """Class to perform under-sampling based on the repeated edited nearest
    neighbour method.

    The edited nearest neighbours step is applied repeatedly, each pass
    working on the output of the previous one, until a pass removes no
    sample or `max_iter` passes have been performed.

    Parameters
    ----------
    return_indices : bool, optional (default=False)
        Either to return or not the indices of the samples which have been
        selected.

    random_state : int or None, optional (default=None)
        Seed for random number generation.

    verbose : bool, optional (default=True)
        Boolean to either or not print information about the processing

    size_ngh : int, optional (default=3)
        Size of the neighbourhood, forwarded to the underlying edited
        nearest neighbours step.

    max_iter : int, optional (default=100)
        Maximum number of iterations of the edited nearest neighbours
        algorithm for a single run. Must be greater than 1.

    kind_sel : str, optional (default='all')
        Strategy to use in order to exclude samples.

        - If 'all', all neighbours will have to agree with the samples of
        interest to not be excluded.
        - If 'mode', the majority vote of the neighbours will be used in
        order to exclude a sample.

    n_jobs : int, optional (default=-1)
        The number of thread to open when it is possible.

    Attributes
    ----------
    ratio_ : str or float, optional (default='auto')
        If 'auto', the ratio will be defined automatically to balance
        the dataset. Otherwise, the ratio will corresponds to the number
        of samples in the minority class over the number of samples
        in the majority class.

    rs_ : int or None, optional (default=None)
        Seed for random number generation.

    min_c_ : str or int
        The identifier of the minority class.

    max_c_ : str or int
        The identifier of the majority class.

    stats_c_ : dict of str/int : int
        A dictionary in which the number of occurences of each class is
        reported.

    enn_ : EditedNearestNeighbours
        The edited nearest neighbours object applied at each iteration.

    Notes
    -----
    The method is based on [1]_.

    This class supports multi-class.

    References
    ----------
    .. [1] I. Tomek, "An Experiment with the Edited Nearest-Neighbor
       Rule," IEEE Trans. Systems, Man, and Cybernetics, vol. 6, no. 6,
       pp. 448-452, June 1976.

    """

    def __init__(self, return_indices=False, random_state=None, verbose=True,
                 size_ngh=3, max_iter=100, kind_sel='all', n_jobs=-1):
        """Initialisation of RENN object.

        Parameters
        ----------
        return_indices : bool, optional (default=False)
            Either to return or not the indices of the samples which have
            been selected.

        random_state : int or None, optional (default=None)
            Seed for random number generation.

        verbose : bool, optional (default=True)
            Boolean to either or not print information about the processing

        size_ngh : int, optional (default=3)
            Size of the neighbourhood, forwarded to the underlying edited
            nearest neighbours step.

        max_iter : int, optional (default=100)
            Maximum number of iterations of the edited nearest neighbours
            algorithm for a single run. Must be greater than 1.

        kind_sel : str, optional (default='all')
            Strategy to use in order to exclude samples.

            - If 'all', all neighbours will have to agree with the samples of
            interest to not be excluded.
            - If 'mode', the majority vote of the neighbours will be used in
            order to exclude a sample.

        n_jobs : int, optional (default=-1)
            The number of thread to open when it is possible.

        Returns
        -------
        None

        Raises
        ------
        NotImplementedError
            If `kind_sel` is neither 'all' nor 'mode'.

        ValueError
            If `max_iter` is lower than 2.

        """
        super(RepeatedEditedNearestNeighbours, self).__init__(
            return_indices=return_indices,
            random_state=random_state,
            verbose=verbose)

        possible_kind_sel = ('all', 'mode')
        if kind_sel not in possible_kind_sel:
            # Exception type kept for backward compatibility; the message
            # tells the caller which values are accepted.
            raise NotImplementedError(
                'kind_sel must be one of {}; got {!r}.'.format(
                    possible_kind_sel, kind_sel))
        if max_iter < 2:
            raise ValueError('max_iter must be greater than 1.')

        self.size_ngh = size_ngh
        self.kind_sel = kind_sel
        self.n_jobs = n_jobs
        self.max_iter = max_iter

        # The inner sampler is kept silent: only the final class counts are
        # reported by this object.
        self.enn_ = EditedNearestNeighbours(
            return_indices=return_indices,
            random_state=random_state, verbose=False,
            size_ngh=size_ngh, kind_sel=kind_sel,
            n_jobs=n_jobs)

    def fit(self, X, y):
        """Find the classes statistics before to perform sampling.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Matrix containing the data which have to be sampled.

        y : ndarray, shape (n_samples, )
            Corresponding label for each sample in X.

        Returns
        -------
        self : object,
            Return self.

        """
        # Check the consistency of X and y
        X, y = check_X_y(X, y)

        super(RepeatedEditedNearestNeighbours, self).fit(X, y)
        # The inner sampler is fitted once, on the original data
        self.enn_.fit(X, y)

        return self

    def transform(self, X, y):
        """Resample the dataset.

        Parameters
        ----------
        X : ndarray, shape (n_samples, n_features)
            Matrix containing the data which have to be sampled.

        y : ndarray, shape (n_samples, )
            Corresponding label for each sample in X.

        Returns
        -------
        X_resampled : ndarray, shape (n_samples_new, n_features)
            The array containing the resampled data.

        y_resampled : ndarray, shape (n_samples_new)
            The corresponding label of `X_resampled`

        idx_under : ndarray of int
            Only returned if `return_indices` is `True`. Built from
            ``np.arange(n_samples)`` by applying, in turn, the selection
            returned by each edited nearest neighbours pass.

        """
        # Check the consistency of X and y
        X, y = check_X_y(X, y)
        X_, y_ = X.copy(), y.copy()

        if self.return_indices:
            idx_under = np.arange(X.shape[0], dtype=int)

        for _ in range(self.max_iter):
            prev_len = y_.shape[0]
            if self.return_indices:
                X_, y_, idx_ = self.enn_.transform(X_, y_)
                # Compose with the previous selections
                idx_under = idx_under[idx_]
            else:
                X_, y_ = self.enn_.transform(X_, y_)

            # Stop as soon as a pass does not remove any sample
            if prev_len == y_.shape[0]:
                break

        if self.verbose:
            print("Under-sampling performed: {}".format(Counter(y_)))

        # Check if the indices of the samples selected should be returned too
        if self.return_indices:
            return X_, y_, idx_under
        return X_, y_

0 commit comments

Comments
 (0)