Commit 2136348

Merge pull request #3 from chkoar/deprecation_warning

Deprecation warning

2 parents 2b653d0 + 4225dff, commit 2136348

10 files changed: +352 −72 lines

imblearn/over_sampling/smote.py

Lines changed: 16 additions & 7 deletions

@@ -68,6 +68,10 @@ class SMOTE(BaseBinarySampler):
         The type of SMOTE algorithm to use one of the following options:
         'regular', 'borderline1', 'borderline2', 'svm'.
 
+    svm_estimator : object, optional (default=SVC())
+        If `kind='svm'`, a parametrized `sklearn.svm.SVC` classifier can
+        be passed.
+
     n_jobs : int, optional (default=1)
         The number of threads to open if possible.
 
@@ -128,16 +132,16 @@ class SMOTE(BaseBinarySampler):
 
     def __init__(self, ratio='auto', random_state=None, k=None, k_neighbors=5,
                  m=None, m_neighbors=10, out_step=0.5, kind='regular',
-                 n_jobs=1, **kwargs):
+                 svm_estimator=None, n_jobs=1):
         super(SMOTE, self).__init__(ratio=ratio, random_state=random_state)
         self.kind = kind
         self.k = k
         self.k_neighbors = k_neighbors
         self.m = m
         self.m_neighbors = m_neighbors
         self.out_step = out_step
+        self.svm_estimator = svm_estimator
         self.n_jobs = n_jobs
-        self.kwargs = kwargs
 
     def _in_danger_noise(self, samples, y, kind='danger'):
         """Estimate if a set of sample are in danger or noise.
@@ -316,8 +320,13 @@ def _validate_estimator(self):
         # in danger (near the boundary). The level of extrapolation is
         # controled by the out_step.
         if self.kind == 'svm':
-            # Store SVM object with any parameters
-            self.svm = SVC(random_state=self.random_state, **self.kwargs)
+            if self.svm_estimator is None:
+                # Store SVM object with any parameters
+                self.svm_estimator_ = SVC(random_state=self.random_state)
+            elif isinstance(self.svm_estimator, SVC):
+                self.svm_estimator_ = self.svm_estimator
+            else:
+                raise ValueError('`svm_estimator` has to be an SVC object')
 
     def fit(self, X, y):
         """Find the classes statistics before to perform sampling.
@@ -503,11 +512,11 @@ def _sample(self, X, y):
         # belonging to each class.
 
         # Fit SVM to the full data
-        self.svm.fit(X, y)
+        self.svm_estimator_.fit(X, y)
 
         # Find the support vectors and their corresponding indexes
-        support_index = self.svm.support_[y[self.svm.support_] ==
-                                          self.min_c_]
+        support_index = self.svm_estimator_.support_[
+            y[self.svm_estimator_.support_] == self.min_c_]
         support_vector = X[support_index]
 
         # First, find the nn of all the samples to identify samples
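The change above replaces `**kwargs` forwarding with an explicit `svm_estimator` parameter that takes a pre-configured `SVC`. A minimal usage sketch of the new API; the dataset and SVC settings are illustrative, not taken from the commit:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC
    from imblearn.over_sampling import SMOTE

    # Illustrative imbalanced binary dataset.
    X, y = make_classification(n_samples=200, weights=[0.9, 0.1],
                               random_state=42)

    # Pass a parametrized SVC; SMOTE stores it as svm_estimator_ in
    # _validate_estimator and samples around its support vectors.
    svm = SVC(C=10.0, kernel='rbf', random_state=42)
    smote = SMOTE(kind='svm', svm_estimator=svm, random_state=42)
    X_resampled, y_resampled = smote.fit_sample(X, y)
    print(np.bincount(y_resampled))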

imblearn/over_sampling/tests/test_smote.py

Lines changed: 55 additions & 0 deletions

@@ -9,6 +9,7 @@
 from sklearn.datasets import make_classification
 from sklearn.utils.estimator_checks import check_estimator
 from sklearn.neighbors import NearestNeighbors
+from sklearn.svm import SVC
 
 from imblearn.over_sampling import SMOTE
 
@@ -446,3 +447,57 @@ def test_wrong_nn():
                   k_neighbors=nn_k)
 
     assert_raises(ValueError, smote.fit_sample, X, Y)
+
+
+def test_sample_regular_with_nn_svm():
+    """Test sample function with regular SMOTE with a NN object."""
+
+    # Create the object
+    kind = 'svm'
+    nn_k = NearestNeighbors(n_neighbors=6)
+    svm = SVC(random_state=RND_SEED)
+    smote = SMOTE(random_state=RND_SEED, kind=kind,
+                  k_neighbors=nn_k, svm_estimator=svm)
+
+    X_resampled, y_resampled = smote.fit_sample(X, Y)
+
+    X_gt = np.array([[0.11622591, -0.0317206],
+                     [0.77481731, 0.60935141],
+                     [1.25192108, -0.22367336],
+                     [0.53366841, -0.30312976],
+                     [1.52091956, -0.49283504],
+                     [-0.28162401, -2.10400981],
+                     [0.83680821, 1.72827342],
+                     [0.3084254, 0.33299982],
+                     [0.70472253, -0.73309052],
+                     [0.28893132, -0.38761769],
+                     [1.15514042, 0.0129463],
+                     [0.88407872, 0.35454207],
+                     [1.31301027, -0.92648734],
+                     [-1.11515198, -0.93689695],
+                     [-0.18410027, -0.45194484],
+                     [0.9281014, 0.53085498],
+                     [-0.14374509, 0.27370049],
+                     [-0.41635887, -0.38299653],
+                     [0.08711622, 0.93259929],
+                     [1.70580611, -0.11219234],
+                     [0.47436888, -0.2645749],
+                     [1.07844561, -0.19435291],
+                     [1.44015515, -1.30621303]])
+    y_gt = np.array([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1,
+                     0, 0, 0, 0])
+    assert_array_almost_equal(X_resampled, X_gt)
+    assert_array_equal(y_resampled, y_gt)
+
+
+def test_sample_regular_wrong_svm():
+    """Test sample function with regular SMOTE with a NN object."""
+
+    # Create the object
+    kind = 'svm'
+    nn_k = NearestNeighbors(n_neighbors=6)
+    svm = 'rnd'
+    smote = SMOTE(random_state=RND_SEED, kind=kind,
+                  k_neighbors=nn_k, svm_estimator=svm)
+
+    assert_raises(ValueError, smote.fit_sample, X, Y)

imblearn/pipeline.py

Lines changed: 1 addition & 1 deletion

@@ -101,7 +101,7 @@ class Pipeline(pipeline.Pipeline):
     >>> X_train, X_test, y_train, y_test = tts(X, y, random_state=42)
     >>> pipeline.fit(X_train, y_train)
     Pipeline(steps=[('smt', SMOTE(k=None, k_neighbors=5, kind='regular', m=None, m_neighbors=10, n_jobs=1,
-    out_step=0.5, random_state=42, ratio='auto')), ('pca', PCA(copy=True, n_components=None, whiten=False)), ('knn', KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
+    out_step=0.5, random_state=42, ratio='auto', svm_estimator=None)), ('pca', PCA(copy=True, n_components=None, whiten=False)), ('knn', KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
     metric_params=None, n_jobs=1, n_neighbors=5, p=2,
     weights='uniform'))])
     >>> y_hat = pipeline.predict(X_test)

imblearn/under_sampling/cluster_centroids.py

Lines changed: 45 additions & 11 deletions

@@ -35,12 +35,12 @@ class ClusterCentroids(BaseMulticlassSampler):
         If None, the random number generator is the RandomState instance used
         by np.random.
 
+    estimator : object, optional(default=KMeans())
+        Pass a `sklearn.cluster.KMeans` estimator.
+
     n_jobs : int, optional (default=1)
         The number of threads to open if possible.
 
-    **kwargs : keywords
-        Parameter to use for the KMeans object.
-
     Attributes
     ----------
     min_c_ : str or int
@@ -79,11 +79,47 @@ class ClusterCentroids(BaseMulticlassSampler):
 
     """
 
-    def __init__(self, ratio='auto', random_state=None, n_jobs=1, **kwargs):
+    def __init__(self, ratio='auto', random_state=None, estimator=None,
+                 n_jobs=1):
         super(ClusterCentroids, self).__init__(ratio=ratio,
                                                random_state=random_state)
+        self.estimator = estimator
         self.n_jobs = n_jobs
-        self.kwargs = kwargs
+
+    def _validate_estimator(self):
+        """Private function to create the NN estimator"""
+
+        if self.estimator is None:
+            self.estimator_ = KMeans(random_state=self.random_state,
+                                     n_jobs=self.n_jobs)
+        elif isinstance(self.estimator, KMeans):
+            self.estimator_ = self.estimator
+        else:
+            raise ValueError('`estimator` has to be a KMeans clustering.')
+
+    def fit(self, X, y):
+        """Find the classes statistics before to perform sampling.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_samples, n_features)
+            Matrix containing the data which have to be sampled.
+
+        y : ndarray, shape (n_samples, )
+            Corresponding label for each sample in X.
+
+        Returns
+        -------
+        self : object,
+            Return self.
+
+        """
+
+        super(ClusterCentroids, self).fit(X, y)
+
+        self._validate_estimator()
+
+        return self
 
     def _sample(self, X, y):
         """Resample the dataset.
@@ -105,17 +141,15 @@ def _sample(self, X, y):
             The corresponding label of `X_resampled`
 
        """
-        random_state = check_random_state(self.random_state)
 
         # Compute the number of cluster needed
         if self.ratio == 'auto':
            num_samples = self.stats_c_[self.min_c_]
         else:
            num_samples = int(self.stats_c_[self.min_c_] / self.ratio)
 
-        # Create the clustering object
-        kmeans = KMeans(n_clusters=num_samples, random_state=random_state)
-        kmeans.set_params(**self.kwargs)
+        # Set the number of sample for the estimator
+        self.estimator_.set_params(**{'n_clusters': num_samples})
 
         # Start with the minority class
         X_min = X[y == self.min_c_]
@@ -133,8 +167,8 @@ def _sample(self, X, y):
                 continue
 
             # Find the centroids via k-means
-            kmeans.fit(X[y == key])
-            centroids = kmeans.cluster_centers_
+            self.estimator_.fit(X[y == key])
+            centroids = self.estimator_.cluster_centers_
 
             # Concatenate to the minority class
             X_resampled = np.concatenate((X_resampled, centroids), axis=0)
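The same pattern applies here: a pre-built `KMeans` replaces keyword forwarding, and the diff shows that `n_clusters` is overwritten internally through `set_params` during sampling. A short sketch of the new API, with illustrative data and parameters:

    from sklearn.cluster import KMeans
    from sklearn.datasets import make_classification
    from imblearn.under_sampling import ClusterCentroids

    X, y = make_classification(n_samples=200, weights=[0.9, 0.1],
                               random_state=42)

    # Any KMeans settings survive except n_clusters, which _sample
    # resets to the number of centroids to keep per majority class.
    km = KMeans(n_init=20, random_state=42)
    cc = ClusterCentroids(estimator=km, random_state=42)
    X_resampled, y_resampled = cc.fit_sample(X, y)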

imblearn/under_sampling/condensed_nearest_neighbour.py

Lines changed: 50 additions & 18 deletions

@@ -36,20 +36,18 @@ class CondensedNearestNeighbour(BaseMulticlassSampler):
         NOTE: size_ngh is deprecated from 0.2 and will be replaced in 0.4
         Use ``n_neighbors`` instead.
 
-    n_neighbors : int, optional (default=1)
-        Size of the neighbourhood to consider to compute the average
+    n_neighbors : int or object, optional (default=KNeighborsClassifier(n_neighbors=1))
+        If int, size of the neighbourhood to consider to compute the average
         distance to the minority point samples.
+        If object, an object inherited from
+        `sklearn.neigbors.KNeighborsClassifier` should be passed.
 
     n_seeds_S : int, optional (default=1)
         Number of samples to extract in order to build the set S.
 
     n_jobs : int, optional (default=1)
         The number of threads to open if possible.
 
-    **kwargs : keywords
-        Parameter to use for the Neareast Neighbours object.
-
-
     Attributes
     ----------
     min_c_ : str or int
@@ -95,16 +93,55 @@ class CondensedNearestNeighbour(BaseMulticlassSampler):
     """
 
     def __init__(self, return_indices=False, random_state=None,
-                 size_ngh=None, n_neighbors=1, n_seeds_S=1, n_jobs=1,
-                 **kwargs):
+                 size_ngh=None, n_neighbors=None, n_seeds_S=1, n_jobs=1):
         super(CondensedNearestNeighbour, self).__init__(
             random_state=random_state)
         self.return_indices = return_indices
         self.size_ngh = size_ngh
         self.n_neighbors = n_neighbors
         self.n_seeds_S = n_seeds_S
         self.n_jobs = n_jobs
-        self.kwargs = kwargs
+
+    def _validate_estimator(self):
+        """Private function to create the NN estimator"""
+
+        if self.n_neighbors is None:
+            self.estimator_ = KNeighborsClassifier(
+                n_neighbors=1,
+                n_jobs=self.n_jobs)
+        elif isinstance(self.n_neighbors, int):
+            self.estimator_ = KNeighborsClassifier(
+                n_neighbors=self.n_neighbors,
+                n_jobs=self.n_jobs)
+        elif isinstance(self.n_neighbors, KNeighborsClassifier):
+            self.estimator_ = self.n_neighbors
+        else:
+            raise ValueError('`n_neighbors` has to be a in or an object'
+                             ' inhereited from KNeighborsClassifier.')
+
+    def fit(self, X, y):
+        """Find the classes statistics before to perform sampling.
+
+        Parameters
+        ----------
+        X : ndarray, shape (n_samples, n_features)
+            Matrix containing the data which have to be sampled.
+
+        y : ndarray, shape (n_samples, )
+            Corresponding label for each sample in X.
+
+        Returns
+        -------
+        self : object,
+            Return self.
+
+        """
+
+        super(CondensedNearestNeighbour, self).fit(X, y)
+
+        self._validate_estimator()
+
+        return self
 
     def _sample(self, X, y):
        """Resample the dataset.
@@ -167,13 +204,8 @@ def _sample(self, X, y):
             S_x = X[y == key]
             S_y = y[y == key]
 
-            # Create a k-NN classifier
-            knn = KNeighborsClassifier(n_neighbors=self.n_neighbors,
-                                       n_jobs=self.n_jobs,
-                                       **self.kwargs)
-
             # Fit C into the knn
-            knn.fit(C_x, C_y)
+            self.estimator_.fit(C_x, C_y)
 
             good_classif_label = idx_maj_sample.copy()
             # Check each sample in S if we keep it or drop it
@@ -184,7 +216,7 @@ def _sample(self, X, y):
                     continue
 
                 # Classify on S
-                pred_y = knn.predict(x_sam.reshape(1, -1))
+                pred_y = self.estimator_.predict(x_sam.reshape(1, -1))
 
                 # If the prediction do not agree with the true label
                 # append it in C_x
@@ -198,12 +230,12 @@ def _sample(self, X, y):
                         idx_maj_sample.size))
 
                 # Fit C into the knn
-                knn.fit(C_x, C_y)
+                self.estimator_.fit(C_x, C_y)
 
                 # This experimental to speed up the search
                 # Classify all the element in S and avoid to test the
                 # well classified elements
-                pred_S_y = knn.predict(S_x)
+                pred_S_y = self.estimator_.predict(S_x)
                 good_classif_label = np.unique(
                     np.append(idx_maj_sample,
                               np.flatnonzero(pred_S_y == S_y)))
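As with the other samplers, `n_neighbors` now accepts either an int or a pre-built classifier, validated in `_validate_estimator`. A hedged sketch of both call styles; the dataset and the `algorithm` setting are illustrative only:

    from sklearn.datasets import make_classification
    from sklearn.neighbors import KNeighborsClassifier
    from imblearn.under_sampling import CondensedNearestNeighbour

    X, y = make_classification(n_samples=200, weights=[0.9, 0.1],
                               random_state=42)

    # Either pass an int (wrapped in a KNeighborsClassifier internally)...
    cnn = CondensedNearestNeighbour(n_neighbors=1, random_state=42)
    X_resampled, y_resampled = cnn.fit_sample(X, y)

    # ...or pass a configured classifier directly.
    knn = KNeighborsClassifier(n_neighbors=1, algorithm='ball_tree')
    cnn = CondensedNearestNeighbour(n_neighbors=knn, random_state=42)
    X_resampled, y_resampled = cnn.fit_sample(X, y)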
