Skip to content

Commit c630df3

Browse files
authored
FIX properly set the default n_neighbours in SMOTE svm and borderline (#578)
1 parent b2a9941 commit c630df3

File tree

3 files changed

+20
-2
lines changed

3 files changed

+20
-2
lines changed

doc/whats_new/v0.5.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ Bug
6464
estimator.
6565
:pr:`554` by :user:`Oliver Rausch <orausch>`.
6666

67+
- Fix bug in :class:`imblearn.over_sampling.SVMSMOTE` and
68+
:class:`imblearn.over_sampling.BorderlineSMOTE` where ``m_neighbors`` was
69+
ignored because ``k_neighbors`` was used to build the ``nn_m_`` estimator.
70+
:pr:`578` by :user:`Guillaume Lemaitre <glemaitre>`.
71+
6772
- Fix bug by changing the default depth in
6873
:class:`imblearn.ensemble.RUSBoostClassifier` to get a decision stump as a
6974
weak learner as in the original paper.

imblearn/over_sampling/_smote.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ def __init__(self,
323323
def _validate_estimator(self):
324324
super()._validate_estimator()
325325
self.nn_m_ = check_neighbors_object(
326-
'k_neighbors', self.k_neighbors, additional_neighbor=1)
326+
'm_neighbors', self.m_neighbors, additional_neighbor=1)
327327
self.nn_m_.set_params(**{'n_jobs': self.n_jobs})
328328
if self.kind not in ('borderline-1', 'borderline-2'):
329329
raise ValueError('The possible "kind" of algorithm are '
@@ -506,7 +506,7 @@ def __init__(self,
506506
def _validate_estimator(self):
507507
super()._validate_estimator()
508508
self.nn_m_ = check_neighbors_object(
509-
'k_neighbors', self.k_neighbors, additional_neighbor=1)
509+
'm_neighbors', self.m_neighbors, additional_neighbor=1)
510510
self.nn_m_.set_params(**{'n_jobs': self.n_jobs})
511511

512512
if self.svm_estimator is None:

imblearn/over_sampling/tests/test_smote.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from sklearn.svm import SVC
1212

1313
from imblearn.over_sampling import SMOTE
14+
from imblearn.over_sampling import SVMSMOTE
15+
from imblearn.over_sampling import BorderlineSMOTE
1416

1517

1618
RND_SEED = 0
@@ -286,3 +288,14 @@ def test_sample_with_nn_svm():
286288
1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0])
287289
assert_allclose(X_resampled, X_gt, rtol=R_TOL)
288290
assert_array_equal(y_resampled, y_gt)
291+
292+
293+
@pytest.mark.parametrize(
294+
"smote", [BorderlineSMOTE(), SVMSMOTE()], ids=['borderline', 'svm']
295+
)
296+
def test_smote_m_neighbors(smote):
297+
# check that m_neighbors is properly set. Regression test for:
298+
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/568
299+
_ = smote.fit_resample(X, Y)
300+
assert smote.nn_k_.n_neighbors == 6
301+
assert smote.nn_m_.n_neighbors == 11

0 commit comments

Comments
 (0)