Skip to content

Commit 4889c2e

Browse files
glemaitre authored and chkoar committed
ENH: Pass a nearest neighbor estimator in *_neighbors parameter PR#182
1 parent 70f2e75 commit 4889c2e

15 files changed

+890
-169
lines changed

doc/whats_new.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ API changes summary
5656
- Provide estimators instead of parameters in :class:`combine.SMOTEENN` and :class:`combine.SMOTETomek`. Therefore, the list of parameters has been deprecated. By `Guillaume Lemaitre`_ and `Christos Aridas`_.
5757
- `k` has been deprecated in :class:`over_sampling.ADASYN`. Use `n_neighbors` instead. By `Guillaume Lemaitre`_.
5858
- `k` and `m` have been deprecated in :class:`over_sampling.SMOTE`. Use `k_neighbors` and `m_neighbors` instead. By `Guillaume Lemaitre`_.
59-
59+
- `n_neighbors` accepts a `KNeighborsMixin`-based object for :class:`under_sampling.EditedNearestNeighbours`, :class:`under_sampling.CondensedNearestNeighbour`, :class:`under_sampling.NeighbourhoodCleaningRule`, :class:`under_sampling.RepeatedEditedNearestNeighbours`, and :class:`under_sampling.AllKNN`. By `Guillaume Lemaitre`_.
6060

6161
Documentation changes
6262
~~~~~~~~~~~~~~~~~~~~~

imblearn/over_sampling/adasyn.py

Lines changed: 48 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
from sklearn.neighbors import NearestNeighbors
8+
from sklearn.neighbors.base import KNeighborsMixin
89
from sklearn.utils import check_random_state
910

1011
from ..base import BaseBinarySampler
@@ -37,8 +38,12 @@ class ADASYN(BaseBinarySampler):
3738
NOTE: `k` is deprecated from 0.2 and will be replaced in 0.4
3839
Use ``n_neighbors`` instead.
3940
40-
n_neighbours : int, optional (default=5)
41-
Number of nearest neighbours to used to construct synthetic samples.
41+
n_neighbors : int or object, optional (default=5)
42+
If int, number of nearest neighbours to use to construct
43+
synthetic samples.
44+
If object, an estimator that inherits from
45+
`sklearn.neighbors.base.KNeighborsMixin` that will be used to find
46+
the nearest neighbours.
4247
4348
n_jobs : int, optional (default=1)
4449
Number of threads to run the algorithm when it is possible.
@@ -96,9 +101,42 @@ def __init__(self, ratio='auto', random_state=None, k=None, n_neighbors=5,
96101
self.k = k
97102
self.n_neighbors = n_neighbors
98103
self.n_jobs = n_jobs
99-
self.nearest_neighbour = NearestNeighbors(
100-
n_neighbors=self.n_neighbors + 1,
101-
n_jobs=self.n_jobs)
104+
105+
def _validate_estimator(self):
106+
"""Private function to create the NN estimator"""
107+
108+
if isinstance(self.n_neighbors, int):
109+
self.nn_ = NearestNeighbors(n_neighbors=self.n_neighbors + 1,
110+
n_jobs=self.n_jobs)
111+
elif isinstance(self.n_neighbors, KNeighborsMixin):
112+
self.nn_ = self.n_neighbors
113+
else:
114+
raise ValueError('`n_neighbors` has to be be either int or a'
115+
' subclass of KNeighborsMixin.')
116+
117+
def fit(self, X, y):
118+
"""Find the classes statistics before performing sampling.
119+
120+
Parameters
121+
----------
122+
X : ndarray, shape (n_samples, n_features)
123+
Matrix containing the data which have to be sampled.
124+
125+
y : ndarray, shape (n_samples, )
126+
Corresponding label for each sample in X.
127+
128+
Returns
129+
-------
130+
self : object,
131+
Return self.
132+
133+
"""
134+
135+
super(ADASYN, self).fit(X, y)
136+
137+
self._validate_estimator()
138+
139+
return self
102140

103141
def _sample(self, X, y):
104142
"""Resample the dataset.
@@ -140,18 +178,18 @@ def _sample(self, X, y):
140178

141179
# Print if verbose is true
142180
self.logger.debug('Finding the %s nearest neighbours ...',
143-
self.n_neighbors)
181+
self.nn_.n_neighbors - 1)
144182

145183
# Look for k-th nearest neighbours, excluding, of course, the
146184
# point itself.
147-
self.nearest_neighbour.fit(X)
185+
self.nn_.fit(X)
148186

149187
# Get the distance to the NN
150-
_, ind_nn = self.nearest_neighbour.kneighbors(X_min)
188+
_, ind_nn = self.nn_.kneighbors(X_min)
151189

152190
# Compute the ratio of majority samples next to minority samples
153191
ratio_nn = (np.sum(y[ind_nn[:, 1:]] == self.maj_c_, axis=1) /
154-
self.n_neighbors)
192+
(self.nn_.n_neighbors - 1))
155193
# Check that we found at least some neighbours belonging to the
156194
# majority class
157195
if not np.sum(ratio_nn):
@@ -169,7 +207,7 @@ def _sample(self, X, y):
169207
for x_i, x_i_nn, num_sample_i in zip(X_min, ind_nn, num_samples_nn):
170208

171209
# Pick-up the neighbors wanted
172-
nn_zs = random_state.randint(1, high=self.n_neighbors + 1,
210+
nn_zs = random_state.randint(1, high=self.nn_.n_neighbors,
173211
size=num_sample_i)
174212

175213
# Create a new sample

0 commit comments

Comments
 (0)