Skip to content

Commit 3c6d232

Browse files
bganglia and glemaitre
authored
FIX Prevent incorrect class category resampling in SMOTENC when median_std_ is 0 (#675)
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent 5a2d34f commit 3c6d232

File tree

3 files changed

+40
-1
lines changed

3 files changed

+40
-1
lines changed

doc/whats_new/v0.7.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ Bug fixes
4242
are given in :class:`imblearn.over_sampling.SMOTENC`.
4343
:pr:`720` by :user:`Guillaume Lemaitre <glemaitre>`.
4444

45+
- Fix a bug when the median of the standard deviation is zero in
46+
:class:`imblearn.over_sampling.SMOTENC`.
47+
:pr:`675` by :user:`bganglia <bganglia>`.
48+
4549
Enhancements
4650
............
4751

imblearn/over_sampling/_smote.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ def _validate_estimator(self):
5454
self.nn_k_ = check_neighbors_object(
5555
"k_neighbors", self.k_neighbors, additional_neighbor=1
5656
)
57-
self.nn_k_.set_params(**{"n_jobs": self.n_jobs})
5857

5958
def _make_samples(
6059
self, X, y_dtype, y_type, nn_data, nn_num, n_samples, step_size=1.0
@@ -956,6 +955,7 @@ def _fit_resample(self, X, y):
956955
self.ohe_ = OneHotEncoder(
957956
sparse=True, handle_unknown="ignore", dtype=dtype_ohe
958957
)
958+
959959
# the input of the OneHotEncoder needs to be dense
960960
X_ohe = self.ohe_.fit_transform(
961961
X_categorical.toarray()
@@ -967,6 +967,15 @@ def _fit_resample(self, X, y):
967967
# median of the standard deviation. It will ensure that whenever
968968
# distance is computed between 2 samples, the difference will be equal
969969
# to the median of the standard deviation as in the original paper.
970+
971+
# In the edge case where the median of the std is equal to 0, the 1s
972+
# entries will also be nullified. In this case, we store the original
973+
# categorical encoding which will be later used for inverting the OHE
974+
if math.isclose(self.median_std_, 0):
975+
self._X_categorical_minority_encoded = _safe_indexing(
976+
X_ohe.toarray(), np.flatnonzero(y == class_minority)
977+
)
978+
970979
X_ohe.data = (
971980
np.ones_like(X_ohe.data, dtype=X_ohe.dtype) * self.median_std_ / 2
972981
)
@@ -1027,6 +1036,14 @@ def _generate_samples(self, X, nn_data, nn_num, rows, cols, steps):
10271036

10281037
# convert to dense array since scipy.sparse doesn't handle 3D
10291038
nn_data = (nn_data.toarray() if sparse.issparse(nn_data) else nn_data)
1039+
1040+
# In the case that the median std was equal to zero, we have to
1041+
# create non-null entries based on the OHE encoding
1042+
if math.isclose(self.median_std_, 0):
1043+
nn_data[:, self.continuous_features_.size:] = (
1044+
self._X_categorical_minority_encoded
1045+
)
1046+
10301047
all_neighbors = nn_data[nn_num[rows]]
10311048

10321049
categories_size = [self.continuous_features_.size] + [

imblearn/over_sampling/tests/test_smote_nc.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,21 @@ def test_smotenc_raising_error_all_categorical(categorical_features):
218218
err_msg = "SMOTE-NC is not designed to work only with categorical features"
219219
with pytest.raises(ValueError, match=err_msg):
220220
smote.fit_resample(X, y)
221+
222+
223+
def test_smote_nc_with_null_median_std():
224+
# Non-regression test for #662
225+
# https://github.com/scikit-learn-contrib/imbalanced-learn/issues/662
226+
data = np.array([[1, 2, 1, 'A'],
227+
[2, 1, 2, 'A'],
228+
[1, 2, 3, 'B'],
229+
[1, 2, 4, 'C'],
230+
[1, 2, 5, 'C']], dtype="object")
231+
labels = np.array(
232+
['class_1', 'class_1', 'class_1', 'class_2', 'class_2'], dtype=object
233+
)
234+
smote = SMOTENC(categorical_features=[3], k_neighbors=1, random_state=0)
235+
X_res, y_res = smote.fit_resample(data, labels)
236+
# check that the categorical feature is not random but corresponds to the
237+
# categories seen in the minority class samples
238+
assert X_res[-1, -1] == "C"

0 commit comments

Comments
 (0)