Skip to content

Commit e3df215

Browse files
authored
FEA implements SMOTEN to handle nominal categorical features (#802)
1 parent b6621f9 commit e3df215

File tree

9 files changed

+281
-6
lines changed

9 files changed

+281
-6
lines changed

README.rst

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -168,11 +168,12 @@ Below is a list of the methods currently implemented in this module.
168168
1. Random minority over-sampling with replacement
169169
2. SMOTE - Synthetic Minority Over-sampling Technique [8]_
170170
3. SMOTENC - SMOTE for Nominal Continuous [8]_
171-
4. bSMOTE(1 & 2) - Borderline SMOTE of types 1 and 2 [9]_
172-
5. SVM SMOTE - Support Vectors SMOTE [10]_
173-
6. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_
174-
7. KMeans-SMOTE [17]_
175-
8. ROSE - Random OverSampling Examples [19]_
171+
4. SMOTEN - SMMOTE for Nominal only [8]_
172+
5. bSMOTE(1 & 2) - Borderline SMOTE of types 1 and 2 [9]_
173+
6. SVM SMOTE - Support Vectors SMOTE [10]_
174+
7. ADASYN - Adaptive synthetic sampling approach for imbalanced learning [15]_
175+
8. KMeans-SMOTE [17]_
176+
9. ROSE - Random OverSampling Examples [19]_
176177

177178
* Over-sampling followed by under-sampling
178179
1. SMOTE + Tomek links [12]_

doc/over_sampling.rst

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,44 @@ Therefore, it can be seen that the samples generated in the first and last
211211
columns are belonging to the same categories originally presented without any
212212
other extra interpolation.
213213

214+
However, :class:`SMOTENC` is working with data composed of categorical data
215+
only. WHen data are made of only nominal categorical data, one can use the
216+
:class:`SMOTEN` variant :cite:`chawla2002smote`. The algorithm changes in
217+
two ways:
218+
219+
* the nearest neighbors search does not rely on the Euclidean distance. Indeed,
220+
the value difference metric (VDM) also implemented in the class
221+
:class:`~imblearn.metrics.ValueDifferenceMetric` is used.
222+
* the new sample generation is based on majority vote per feature to generate
223+
the most common category seen in the neighbors samples.
224+
225+
Let's take the following example::
226+
227+
>>> import numpy as np
228+
>>> X = np.array(["green"] * 5 + ["red"] * 10 + ["blue"] * 7,
229+
... dtype=object).reshape(-1, 1)
230+
>>> y = np.array(["apple"] * 5 + ["not apple"] * 3 + ["apple"] * 7 +
231+
... ["not apple"] * 5 + ["apple"] * 2, dtype=object)
232+
233+
We generate a dataset associating a color to being an apple or not an apple.
234+
We strongly associated "green" and "red" to being an apple. The minority class
235+
being "not apple", we expect new data generated belonging to the category
236+
"blue"::
237+
238+
>>> from imblearn.over_sampling import SMOTEN
239+
>>> sampler = SMOTEN(random_state=0)
240+
>>> X_res, y_res = sampler.fit_resample(X, y)
241+
>>> X_res[y.size:]
242+
array([['blue'],
243+
['blue'],
244+
['blue'],
245+
['blue'],
246+
['blue'],
247+
['blue']], dtype=object)
248+
>>> y_res[y.size:]
249+
array(['not apple', 'not apple', 'not apple', 'not apple', 'not apple',
250+
'not apple'], dtype=object)
251+
214252
Mathematical formulation
215253
========================
216254

doc/references/over_sampling.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ SMOTE algorithms
2727

2828
SMOTE
2929
SMOTENC
30+
SMOTEN
3031
ADASYN
3132
BorderlineSMOTE
3233
KMeansSMOTE

doc/whats_new/v0.8.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ New features
1919
compute pairwise distances between samples containing only nominal values.
2020
:pr:`796` by :user:`Guillaume Lemaitre <glemaitre>`.
2121

22+
- Add the class :class:`imblearn.over_sampling.SMOTEN` to over-sample data
23+
only containing nominal categorical features.
24+
:pr:`802` by :user:`Guillaume Lemaitre <glemaitre>`.
25+
2226
Enhancements
2327
............
2428

imblearn/over_sampling/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from ._smote import KMeansSMOTE
1111
from ._smote import SVMSMOTE
1212
from ._smote import SMOTENC
13+
from ._smote import SMOTEN
1314

1415
__all__ = [
1516
"ADASYN",
@@ -19,4 +20,5 @@
1920
"BorderlineSMOTE",
2021
"SVMSMOTE",
2122
"SMOTENC",
23+
"SMOTEN",
2224
]

imblearn/over_sampling/_adasyn.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ class ADASYN(BaseOverSampler):
5050
--------
5151
SMOTE : Over-sample using SMOTE.
5252
53+
SMOTENC : Over-sample using SMOTE for continuous and categorical features.
54+
55+
SMOTEN : Over-sample using the SMOTE variable specifically for categorical
56+
features only.
57+
58+
SVMSMOTE : Over-sample using SVM-SMOTE variant.
59+
60+
BorderlineSMOTE : Over-sample using Borderline-SMOTE variant.
61+
5362
Notes
5463
-----
5564
The implementation is based on [1]_.
@@ -169,3 +178,8 @@ def _fit_resample(self, X, y):
169178
y_resampled = np.hstack(y_resampled)
170179

171180
return X_resampled, y_resampled
181+
182+
def _more_tags(self):
183+
return {
184+
"X_types": ["2darray"],
185+
}

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ class RandomOverSampler(BaseOverSampler):
7676
7777
SMOTENC : Over-sample using SMOTE for continuous and categorical features.
7878
79+
SMOTEN : Over-sample using the SMOTE variable specifically for categorical
80+
features only.
81+
7982
SVMSMOTE : Over-sample using SVM-SMOTE variant.
8083
8184
ADASYN : Over-sample using ADASYN.

imblearn/over_sampling/_smote.py

Lines changed: 159 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111

1212
import numpy as np
1313
from scipy import sparse
14+
from scipy import stats
1415

1516
from sklearn.base import clone
1617
from sklearn.cluster import MiniBatchKMeans
1718
from sklearn.metrics import pairwise_distances
18-
from sklearn.preprocessing import OneHotEncoder
19+
from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder
1920
from sklearn.svm import SVC
2021
from sklearn.utils import check_random_state
2122
from sklearn.utils import _safe_indexing
@@ -25,6 +26,7 @@
2526

2627
from .base import BaseOverSampler
2728
from ..exceptions import raise_isinstance_error
29+
from ..metrics.pairwise import ValueDifferenceMetric
2830
from ..utils import check_neighbors_object
2931
from ..utils import check_target_type
3032
from ..utils import Substitution
@@ -448,6 +450,9 @@ class SVMSMOTE(BaseSMOTE):
448450
449451
SMOTENC : Over-sample using SMOTE for continuous and categorical features.
450452
453+
SMOTEN : Over-sample using the SMOTE variable specifically for categorical
454+
features only.
455+
451456
BorderlineSMOTE : Over-sample using Borderline-SMOTE.
452457
453458
ADASYN : Over-sample using ADASYN.
@@ -643,6 +648,9 @@ class SMOTE(BaseSMOTE):
643648
--------
644649
SMOTENC : Over-sample using SMOTE for continuous and categorical features.
645650
651+
SMOTEN : Over-sample using the SMOTE variable specifically for categorical
652+
features only.
653+
646654
BorderlineSMOTE : Over-sample using the borderline-SMOTE variant.
647655
648656
SVMSMOTE : Over-sample using the SVM-SMOTE variant.
@@ -766,6 +774,9 @@ class SMOTENC(SMOTE):
766774
--------
767775
SMOTE : Over-sample using SMOTE.
768776
777+
SMOTEN : Over-sample using the SMOTE variable specifically for categorical
778+
features only.
779+
769780
SVMSMOTE : Over-sample using SVM-SMOTE variant.
770781
771782
BorderlineSMOTE : Over-sample using Borderline-SMOTE variant.
@@ -1055,6 +1066,11 @@ class KMeansSMOTE(BaseSMOTE):
10551066
--------
10561067
SMOTE : Over-sample using SMOTE.
10571068
1069+
SMOTENC : Over-sample using SMOTE for continuous and categorical features.
1070+
1071+
SMOTEN : Over-sample using the SMOTE variable specifically for categorical
1072+
features only.
1073+
10581074
SVMSMOTE : Over-sample using SVM-SMOTE variant.
10591075
10601076
BorderlineSMOTE : Over-sample using Borderline-SMOTE variant.
@@ -1248,3 +1264,145 @@ def _fit_resample(self, X, y):
12481264
y_resampled = np.hstack((y_resampled, y_new))
12491265

12501266
return X_resampled, y_resampled
1267+
1268+
1269+
@Substitution(
1270+
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
1271+
n_jobs=_n_jobs_docstring,
1272+
random_state=_random_state_docstring,
1273+
)
1274+
class SMOTEN(SMOTE):
1275+
"""Perform SMOTE over-sampling for nominal categorical features only.
1276+
1277+
This method is refered as SMOTEN in [1]_.
1278+
1279+
Read more in the :ref:`User Guide <smote_adasyn>`.
1280+
1281+
Parameters
1282+
----------
1283+
{sampling_strategy}
1284+
1285+
{random_state}
1286+
1287+
k_neighbors : int or object, default=5
1288+
If ``int``, number of nearest neighbours to used to construct synthetic
1289+
samples. If object, an estimator that inherits from
1290+
:class:`~sklearn.neighbors.base.KNeighborsMixin` that will be used to
1291+
find the k_neighbors.
1292+
1293+
{n_jobs}
1294+
1295+
See Also
1296+
--------
1297+
SMOTE : Over-sample using SMOTE.
1298+
1299+
SMOTENC : Over-sample using SMOTE for continuous and categorical features.
1300+
1301+
BorderlineSMOTE : Over-sample using the borderline-SMOTE variant.
1302+
1303+
SVMSMOTE : Over-sample using the SVM-SMOTE variant.
1304+
1305+
ADASYN : Over-sample using ADASYN.
1306+
1307+
KMeansSMOTE : Over-sample applying a clustering before to oversample using
1308+
SMOTE.
1309+
1310+
Notes
1311+
-----
1312+
See the original papers: [1]_ for more details.
1313+
1314+
Supports multi-class resampling. A one-vs.-rest scheme is used as
1315+
originally proposed in [1]_.
1316+
1317+
References
1318+
----------
1319+
.. [1] N. V. Chawla, K. W. Bowyer, L. O.Hall, W. P. Kegelmeyer, "SMOTE:
1320+
synthetic minority over-sampling technique," Journal of artificial
1321+
intelligence research, 321-357, 2002.
1322+
1323+
Examples
1324+
--------
1325+
>>> import numpy as np
1326+
>>> X = np.array(["A"] * 10 + ["B"] * 20 + ["C"] * 30, dtype=object).reshape(-1, 1)
1327+
>>> y = np.array([0] * 20 + [1] * 40, dtype=np.int32)
1328+
>>> from collections import Counter
1329+
>>> print(f"Original class counts: {{Counter(y)}}")
1330+
Original class counts: Counter({{1: 40, 0: 20}})
1331+
>>> from imblearn.over_sampling import SMOTEN
1332+
>>> sampler = SMOTEN(random_state=0)
1333+
>>> X_res, y_res = sampler.fit_resample(X, y)
1334+
>>> print(f"Class counts after resampling {{Counter(y_res)}}")
1335+
Class counts after resampling Counter({{0: 40, 1: 40}})
1336+
"""
1337+
1338+
def _check_X_y(self, X, y):
1339+
"""Check should accept strings and not sparse matrices."""
1340+
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
1341+
X, y = self._validate_data(
1342+
X,
1343+
y,
1344+
reset=True,
1345+
dtype=None,
1346+
accept_sparse=False,
1347+
)
1348+
return X, y, binarize_y
1349+
1350+
def _validate_estimator(self):
1351+
"""Force to use precomputed distance matrix."""
1352+
super()._validate_estimator()
1353+
self.nn_k_.set_params(metric="precomputed")
1354+
1355+
def _make_samples(self, X_class, klass, y_dtype, nn_indices, n_samples):
1356+
random_state = check_random_state(self.random_state)
1357+
# generate sample indices that will be used to generate new samples
1358+
samples_indices = random_state.choice(
1359+
np.arange(X_class.shape[0]), size=n_samples, replace=True
1360+
)
1361+
# for each drawn samples, select its k-neighbors and generate a sample
1362+
# where for each feature individually, each category generated is the
1363+
# most common category
1364+
X_new = np.squeeze(
1365+
stats.mode(X_class[nn_indices[samples_indices]], axis=1).mode, axis=1
1366+
)
1367+
y_new = np.full(n_samples, fill_value=klass, dtype=y_dtype)
1368+
return X_new, y_new
1369+
1370+
def _fit_resample(self, X, y):
1371+
self._validate_estimator()
1372+
1373+
X_resampled = [X.copy()]
1374+
y_resampled = [y.copy()]
1375+
1376+
encoder = OrdinalEncoder(dtype=np.int32)
1377+
X_encoded = encoder.fit_transform(X)
1378+
1379+
vdm = ValueDifferenceMetric(
1380+
n_categories=[len(cat) for cat in encoder.categories_]
1381+
).fit(X_encoded, y)
1382+
1383+
for class_sample, n_samples in self.sampling_strategy_.items():
1384+
if n_samples == 0:
1385+
continue
1386+
target_class_indices = np.flatnonzero(y == class_sample)
1387+
X_class = _safe_indexing(X_encoded, target_class_indices)
1388+
1389+
X_class_dist = vdm.pairwise(X_class)
1390+
self.nn_k_.fit(X_class_dist)
1391+
# the kneigbors search will include the sample itself which is
1392+
# expected from the original algorithm
1393+
nn_indices = self.nn_k_.kneighbors(X_class_dist, return_distance=False)
1394+
X_new, y_new = self._make_samples(
1395+
X_class, class_sample, y.dtype, nn_indices, n_samples
1396+
)
1397+
1398+
X_new = encoder.inverse_transform(X_new)
1399+
X_resampled.append(X_new)
1400+
y_resampled.append(y_new)
1401+
1402+
X_resampled = np.vstack(X_resampled)
1403+
y_resampled = np.hstack(y_resampled)
1404+
1405+
return X_resampled, y_resampled
1406+
1407+
def _more_tags(self):
1408+
return {"X_types": ["2darray", "dataframe", "string"]}
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import numpy as np
2+
import pytest
3+
4+
from imblearn.over_sampling import SMOTEN
5+
6+
7+
@pytest.fixture
8+
def data():
9+
rng = np.random.RandomState(0)
10+
11+
feature_1 = ["A"] * 10 + ["B"] * 20 + ["C"] * 30
12+
feature_2 = ["A"] * 40 + ["B"] * 20
13+
feature_3 = ["A"] * 20 + ["B"] * 20 + ["C"] * 10 + ["D"] * 10
14+
X = np.array([feature_1, feature_2, feature_3], dtype=object).T
15+
rng.shuffle(X)
16+
y = np.array([0] * 20 + [1] * 40, dtype=np.int32)
17+
y_labels = np.array(["not apple", "apple"], dtype=object)
18+
y = y_labels[y]
19+
return X, y
20+
21+
22+
def test_smoten(data):
23+
# overall check for SMOTEN
24+
X, y = data
25+
sampler = SMOTEN(random_state=0)
26+
X_res, y_res = sampler.fit_resample(X, y)
27+
28+
assert X_res.shape == (80, 3)
29+
assert y_res.shape == (80,)
30+
31+
32+
def test_smoten_resampling():
33+
# check if the SMOTEN resample data as expected
34+
# we generate data such that "not apple" will be the minority class and
35+
# samples from this class will be generated. We will force the "blue"
36+
# category to be associated with this class. Therefore, the new generated
37+
# samples should as well be from the "blue" category.
38+
X = np.array(["green"] * 5 + ["red"] * 10 + ["blue"] * 7, dtype=object).reshape(
39+
-1, 1
40+
)
41+
y = np.array(
42+
["apple"] * 5
43+
+ ["not apple"] * 3
44+
+ ["apple"] * 7
45+
+ ["not apple"] * 5
46+
+ ["apple"] * 2,
47+
dtype=object,
48+
)
49+
sampler = SMOTEN(random_state=0)
50+
X_res, y_res = sampler.fit_resample(X, y)
51+
52+
X_generated, y_generated = X_res[X.shape[0] :], y_res[X.shape[0] :]
53+
np.testing.assert_array_equal(X_generated, "blue")
54+
np.testing.assert_array_equal(y_generated, "not apple")

0 commit comments

Comments
 (0)