Skip to content

Commit 6916fe9

Browse files
authored
ENH: random sampler can sample from heterogeneous data (#451)
1 parent 41cd9a6 commit 6916fe9

File tree

14 files changed

+156
-39
lines changed

14 files changed

+156
-39
lines changed

doc/over_sampling.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,22 @@ As a result, the majority class does not take over the other classes during the
5252
training process. Consequently, all classes are represented by the decision
5353
function.
5454

55+
In addition, :class:`RandomOverSampler` can sample heterogeneous data
56+
(e.g. containing some strings)::
57+
58+
>>> import numpy as np
59+
>>> X_hetero = np.array([['xxx', 1, 1.0], ['yyy', 2, 2.0], ['zzz', 3, 3.0]],
60+
... dtype=object)
61+
>>> y_hetero = np.array([0, 0, 1])
62+
>>> X_resampled, y_resampled = ros.fit_sample(X_hetero, y_hetero)
63+
>>> print(X_resampled)
64+
[['xxx' 1 1.0]
65+
['yyy' 2 2.0]
66+
['zzz' 3 3.0]
67+
['zzz' 3 3.0]]
68+
>>> print(y_resampled)
69+
[0 0 1 1]
70+
5571
See :ref:`sphx_glr_auto_examples_over-sampling_plot_random_over_sampling.py`
5672
for usage example.
5773

doc/under_sampling.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,19 @@ by considering independently each targeted class::
103103
>>> print(np.vstack({tuple(row) for row in X_resampled}).shape)
104104
(181, 2)
105105

106+
In addition, :class:`RandomUnderSampler` can sample heterogeneous data
107+
(e.g. containing some strings)::
108+
109+
>>> X_hetero = np.array([['xxx', 1, 1.0], ['yyy', 2, 2.0], ['zzz', 3, 3.0]],
110+
... dtype=object)
111+
>>> y_hetero = np.array([0, 0, 1])
112+
>>> X_resampled, y_resampled = rus.fit_sample(X_hetero, y_hetero)
113+
>>> print(X_resampled)
114+
[['xxx' 1 1.0]
115+
['zzz' 3 3.0]]
116+
>>> print(y_resampled)
117+
[0 1]
118+
106119
See :ref:`sphx_glr_auto_examples_plot_sampling_strategy_usage.py`,
107120
:ref:`sphx_glr_auto_examples_under-sampling_plot_comparison_under_sampling.py`,
108121
and :ref:`sphx_glr_auto_examples_under-sampling_plot_random_under_sampler.py`.

doc/whats_new/v0.0.4.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,11 @@ Enhancement
4545
:issue:`439` by :user:`Hugo Gascon<hgascon>` and
4646
:user:`Guillaume Lemaitre <glemaitre>`.
4747

48+
- Allow :class:`imblearn.under_sampling.RandomUnderSampler` and
49+
:class:`imblearn.over_sampling.RandomOverSampler` to sample object arrays
50+
containing strings.
51+
:issue:`448` by :user:`Guillaume Lemaitre <glemaitre>`.
52+
4853
Bug fixes
4954
.........
5055

imblearn/base.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,13 +31,6 @@ class SamplerMixin(six.with_metaclass(ABCMeta, BaseEstimator)):
3131

3232
_estimator_type = 'sampler'
3333

34-
def _check_X_y(self, X, y):
35-
"""Private function to check that the X and y in fitting are the same
36-
than in sampling."""
37-
X_hash, y_hash = hash_X_y(X, y)
38-
if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
39-
raise RuntimeError("X and y need to be same array earlier fitted.")
40-
4134
def sample(self, X, y):
4235
"""Resample the dataset.
4336
@@ -60,11 +53,10 @@ def sample(self, X, y):
6053
6154
"""
6255
# Check the consistency of X and y
63-
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
64-
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
56+
X, y, binarize_y = self._check_X_y(X, y)
6557

6658
check_is_fitted(self, 'sampling_strategy_')
67-
self._check_X_y(X, y)
59+
self._check_X_y_hash(X, y)
6860

6961
output = self._sample(X, y)
7062

@@ -151,6 +143,19 @@ def __init__(self, sampling_strategy='auto', ratio=None):
151143
self.ratio = ratio
152144
self.logger = logging.getLogger(self.__module__)
153145

146+
@staticmethod
147+
def _check_X_y(X, y):
148+
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
149+
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
150+
return X, y, binarize_y
151+
152+
def _check_X_y_hash(self, X, y):
153+
"""Private function to check that the X and y in fitting are the same
154+
than in sampling."""
155+
X_hash, y_hash = hash_X_y(X, y)
156+
if self.X_hash_ != X_hash or self.y_hash_ != y_hash:
157+
raise RuntimeError("X and y need to be same array earlier fitted.")
158+
154159
@property
155160
def ratio_(self):
156161
# FIXME: remove in 0.6
@@ -183,9 +188,9 @@ def fit(self, X, y):
183188
184189
"""
185190
self._deprecate_ratio()
186-
y = check_target_type(y)
187-
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
191+
X, y, _ = self._check_X_y(X, y)
188192
self.X_hash_, self.y_hash_ = hash_X_y(X, y)
193+
# _sampling_type is defined in the children base class
189194
self.sampling_strategy_ = check_sampling_strategy(
190195
self.sampling_strategy, y, self._sampling_type)
191196

imblearn/combine/smote_enn.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from sklearn.base import clone
1313
from sklearn.utils import check_X_y
1414

15-
from ..base import SamplerMixin
15+
from ..base import BaseSampler
1616
from ..over_sampling import SMOTE
1717
from ..over_sampling.base import BaseOverSampler
1818
from ..under_sampling import EditedNearestNeighbours
@@ -24,7 +24,7 @@
2424
@Substitution(
2525
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
2626
random_state=_random_state_docstring)
27-
class SMOTEENN(SamplerMixin):
27+
class SMOTEENN(BaseSampler):
2828
"""Class to perform over-sampling using SMOTE and cleaning using ENN.
2929
3030
Combine over- and under-sampling using SMOTE and Edited Nearest Neighbours.
@@ -125,14 +125,6 @@ def _validate_estimator(self):
125125
else:
126126
self.enn_ = EditedNearestNeighbours(sampling_strategy='all')
127127

128-
@property
129-
def ratio_(self):
130-
# FIXME: remove in 0.6
131-
warnings.warn("'ratio' and 'ratio_' are deprecated. Use "
132-
"'sampling_strategy' and 'sampling_strategy_' instead.",
133-
DeprecationWarning)
134-
return self.sampling_strategy_
135-
136128
def fit(self, X, y):
137129
"""Find the classes statistics before to perform sampling.
138130

imblearn/combine/smote_tomek.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sklearn.base import clone
1414
from sklearn.utils import check_X_y
1515

16-
from ..base import SamplerMixin
16+
from ..base import BaseSampler
1717
from ..over_sampling import SMOTE
1818
from ..over_sampling.base import BaseOverSampler
1919
from ..under_sampling import TomekLinks
@@ -25,7 +25,7 @@
2525
@Substitution(
2626
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
2727
random_state=_random_state_docstring)
28-
class SMOTETomek(SamplerMixin):
28+
class SMOTETomek(BaseSampler):
2929
"""Class to perform over-sampling using SMOTE and cleaning using
3030
Tomek links.
3131
@@ -133,14 +133,6 @@ def _validate_estimator(self):
133133
else:
134134
self.tomek_ = TomekLinks(sampling_strategy='all')
135135

136-
@property
137-
def ratio_(self):
138-
# FIXME: remove in 0.6
139-
warnings.warn("'ratio' and 'ratio_' are deprecated. Use "
140-
"'sampling_strategy' and 'sampling_strategy_' instead.",
141-
DeprecationWarning)
142-
return self.sampling_strategy_
143-
144136
def fit(self, X, y):
145137
"""Find the classes statistics before to perform sampling.
146138

imblearn/ensemble/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def sample(self, X, y):
6060
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'])
6161

6262
check_is_fitted(self, 'sampling_strategy_')
63-
self._check_X_y(X, y)
63+
self._check_X_y_hash(X, y)
6464

6565
output = self._sample(X, y)
6666

imblearn/over_sampling/random_over_sampler.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
from collections import Counter
99

1010
import numpy as np
11-
from sklearn.utils import check_random_state, safe_indexing
11+
from sklearn.utils import check_X_y, check_random_state, safe_indexing
1212

1313
from .base import BaseOverSampler
14+
from ..utils import check_target_type
1415
from ..utils import Substitution
1516
from ..utils._docstring import _random_state_docstring
1617

@@ -44,6 +45,8 @@ class RandomOverSampler(BaseOverSampler):
4445
Notes
4546
-----
4647
Supports multi-class resampling by sampling each class independently.
48+
Supports heterogeneous data as object array containing string and numeric
49+
data.
4750
4851
See
4952
:ref:`sphx_glr_auto_examples_over-sampling_plot_comparison_over_sampling.py`,
@@ -79,6 +82,12 @@ def __init__(self, sampling_strategy='auto',
7982
self.return_indices = return_indices
8083
self.random_state = random_state
8184

85+
@staticmethod
86+
def _check_X_y(X, y):
87+
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
88+
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], dtype=None)
89+
return X, y, binarize_y
90+
8291
def _sample(self, X, y):
8392
"""Resample the dataset.
8493

imblearn/over_sampling/tests/test_random_over_sampler.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,16 @@ def test_multiclass_fit_sample():
8888
assert count_y_res[0] == 5
8989
assert count_y_res[1] == 5
9090
assert count_y_res[2] == 5
91+
92+
93+
def test_random_over_sampling_heterogeneous_data():
94+
X_hetero = np.array([['xxx', 1, 1.0], ['yyy', 2, 2.0], ['zzz', 3, 3.0]],
95+
dtype=np.object)
96+
y = np.array([0, 0, 1])
97+
ros = RandomOverSampler(random_state=RND_SEED)
98+
X_res, y_res = ros.fit_sample(X_hetero, y)
99+
100+
assert X_res.shape[0] == 4
101+
assert y_res.shape[0] == 4
102+
assert X_res.dtype == object
103+
assert X_res[-1, 0] in X_hetero[:, 0]

imblearn/under_sampling/prototype_selection/random_under_sampler.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
from __future__ import division
88

99
import numpy as np
10-
from sklearn.utils import check_random_state, safe_indexing
10+
11+
from sklearn.utils import check_X_y, check_random_state, safe_indexing
1112

1213
from ..base import BaseUnderSampler
14+
from ...utils import check_target_type
1315
from ...utils import Substitution
1416
from ...utils._docstring import _random_state_docstring
1517

@@ -46,6 +48,8 @@ class RandomUnderSampler(BaseUnderSampler):
4648
Notes
4749
-----
4850
Supports multi-class resampling by sampling each class independently.
51+
Supports heterogeneous data as object array containing string and numeric
52+
data.
4953
5054
See
5155
:ref:`sphx_glr_auto_examples_plot_sampling_strategy_usage.py` and
@@ -82,6 +86,12 @@ def __init__(self,
8286
self.return_indices = return_indices
8387
self.replacement = replacement
8488

89+
@staticmethod
90+
def _check_X_y(X, y):
91+
y, binarize_y = check_target_type(y, indicate_one_vs_all=True)
92+
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc'], dtype=None)
93+
return X, y, binarize_y
94+
8595
def _sample(self, X, y):
8696
"""Resample the dataset.
8797

0 commit comments

Comments
 (0)