Skip to content

Commit 399f4a7

Browse files
author
Guillaume Lemaitre
committed
Advance the compatibility with scikit-learn
1 parent 7ac94b8 commit 399f4a7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+630
-432
lines changed

unbalanced_dataset/base_sampler.py

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,28 @@
33
from __future__ import division
44
from __future__ import print_function
55

6+
import warnings
7+
68
import numpy as np
79

810
from abc import ABCMeta, abstractmethod
911

1012
from collections import Counter
1113

14+
from sklearn.base import BaseEstimator, TransformerMixin
1215
from sklearn.utils import check_X_y
16+
from sklearn.externals import six
1317

1418
from six import string_types
1519

1620

17-
class BaseSampler(object):
21+
class BaseSampler(six.with_metaclass(ABCMeta, BaseEstimator)):
1822
"""Basic class with abstract methods.
1923
2024
Warning: This class should not be used directly. Use the derived classes
2125
instead.
2226
"""
2327

24-
__metaclass__ = ABCMeta
25-
2628
@abstractmethod
2729
def __init__(self, ratio='auto', random_state=None, verbose=True):
2830
"""Initialize this object and its instance variables.
@@ -55,16 +57,16 @@ def __init__(self, ratio='auto', random_state=None, verbose=True):
5557
elif ratio <= 0:
5658
raise ValueError('Ratio cannot be negative.')
5759
else:
58-
self.ratio_ = ratio
60+
self.ratio = ratio
5961
elif isinstance(ratio, string_types):
6062
if ratio == 'auto':
61-
self.ratio_ = ratio
63+
self.ratio = ratio
6264
else:
6365
raise ValueError('Unknown string for the parameter ratio.')
6466
else:
6567
raise ValueError('Unknown parameter type for ratio.')
6668

67-
self.rs_ = random_state
69+
self.random_state = random_state
6870
self.verbose = verbose
6971

7072
# Create the member variables regarding the classes statistics
@@ -100,9 +102,13 @@ def fit(self, X, y):
100102
# Get all the unique elements in the target array
101103
uniques = np.unique(y)
102104

103-
# Raise an error if there is only one class
105+
# # Raise an error if there is only one class
106+
# if uniques.size == 1:
107+
# raise RuntimeError("Only one class detected, aborting...")
108+
# Raise a warning for the moment to be compatible with BaseEstimator
104109
if uniques.size == 1:
105-
raise RuntimeError("Only one class detected, aborting...")
110+
warnings.warn('Only one class detected, something will get wrong',
111+
RuntimeWarning)
106112

107113
# Create a dictionary containing the class statistics
108114
self.stats_c_ = Counter(y)
@@ -116,17 +122,17 @@ def fit(self, X, y):
116122
self.stats_c_))
117123

118124
# Check if the ratio provided at initialisation make sense
119-
if isinstance(self.ratio_, float):
120-
if self.ratio_ < (self.stats_c_[self.min_c_] /
121-
self.stats_c_[self.maj_c_]):
125+
if isinstance(self.ratio, float):
126+
if self.ratio < (self.stats_c_[self.min_c_] /
127+
self.stats_c_[self.maj_c_]):
122128
raise RuntimeError('The ratio requested at initialisation'
123129
' should be greater or equal than the'
124130
' balancing ratio of the current data.')
125131

126132
return self
127133

128134
@abstractmethod
129-
def transform(self, X, y):
135+
def sample(self, X, y):
130136
"""Resample the dataset.
131137
132138
Parameters
@@ -153,7 +159,7 @@ def transform(self, X, y):
153159

154160
return self
155161

156-
def fit_transform(self, X, y):
162+
def fit_sample(self, X, y):
157163
"""Fit the statistics and resample the data directly.
158164
159165
Parameters
@@ -174,4 +180,4 @@ def fit_transform(self, X, y):
174180
175181
"""
176182

177-
return self.fit(X, y).transform(X, y)
183+
return self.fit(X, y).sample(X, y)

unbalanced_dataset/combine/smote_enn.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,28 @@ def __init__(self, ratio='auto', random_state=None, verbose=True,
162162
super(SMOTEENN, self).__init__(ratio=ratio, random_state=random_state,
163163
verbose=verbose)
164164

165-
self.sm = SMOTE(ratio=ratio, random_state=random_state,
166-
verbose=verbose, k=k, m=m, out_step=out_step,
167-
kind=kind_smote, nn_method=nn_method, n_jobs=n_jobs,
168-
**kwargs)
169-
170-
self.enn = EditedNearestNeighbours(random_state=random_state,
171-
verbose=verbose,
172-
size_ngh=size_ngh,
173-
kind_sel=kind_enn, n_jobs=n_jobs)
165+
self.k = k
166+
self.m = m
167+
self.out_step = out_step
168+
self.kind_smote = kind_smote
169+
self.nn_method = nn_method
170+
self.n_jobs = n_jobs
171+
self.kwargs = kwargs
172+
173+
self.sm = SMOTE(ratio=self.ratio, random_state=self.random_state,
174+
verbose=self.verbose, k=self.k, m=self.m,
175+
out_step=self.out_step, kind=self.kind_smote,
176+
nn_method=self.nn_method, n_jobs=self.n_jobs,
177+
**self.kwargs)
178+
179+
self.size_ngh = size_ngh
180+
self.kind_enn = kind_enn
181+
182+
self.enn = EditedNearestNeighbours(random_state=self.random_state,
183+
verbose=self.verbose,
184+
size_ngh=self.size_ngh,
185+
kind_sel=self.kind_enn,
186+
n_jobs=self.n_jobs)
174187

175188
def fit(self, X, y):
176189
"""Find the class statistics before performing sampling.
@@ -199,7 +212,7 @@ def fit(self, X, y):
199212

200213
return self
201214

202-
def transform(self, X, y):
215+
def sample(self, X, y):
203216
"""Resample the dataset.
204217
205218
Parameters
@@ -222,10 +235,10 @@ def transform(self, X, y):
222235
# Check the consistency of X and y
223236
X, y = check_X_y(X, y)
224237

225-
super(SMOTEENN, self).transform(X, y)
238+
super(SMOTEENN, self).sample(X, y)
226239

227240
# Transform using SMOTE
228-
X, y = self.sm.transform(X, y)
241+
X, y = self.sm.sample(X, y)
229242

230243
# Fit and transform using ENN
231-
return self.enn.fit_transform(X, y)
244+
return self.enn.fit_sample(X, y)

unbalanced_dataset/combine/smote_tomek.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,22 @@ def __init__(self, ratio='auto', random_state=None, verbose=True,
150150
random_state=random_state,
151151
verbose=verbose)
152152

153-
self.sm = SMOTE(ratio=ratio, random_state=random_state,
154-
verbose=verbose, k=k, m=m, out_step=out_step,
155-
kind=kind_smote, nn_method=nn_method, n_jobs=n_jobs,
156-
**kwargs)
157-
158-
self.tomek = TomekLinks(random_state=random_state,
159-
verbose=verbose)
153+
self.k = k
154+
self.m = m
155+
self.out_step = out_step
156+
self.kind_smote = kind_smote
157+
self.nn_method = nn_method
158+
self.n_jobs = n_jobs
159+
self.kwargs = kwargs
160+
161+
self.sm = SMOTE(ratio=self.ratio, random_state=self.random_state,
162+
verbose=self.verbose, k=self.k, m=self.m,
163+
out_step=self.out_step, kind=self.kind_smote,
164+
nn_method=self.nn_method, n_jobs=self.n_jobs,
165+
**self.kwargs)
166+
167+
self.tomek = TomekLinks(random_state=self.random_state,
168+
verbose=self.verbose)
160169

161170
def fit(self, X, y):
162171
"""Find the class statistics before performing sampling.
@@ -185,7 +194,7 @@ def fit(self, X, y):
185194

186195
return self
187196

188-
def transform(self, X, y):
197+
def sample(self, X, y):
189198
"""Resample the dataset.
190199
191200
Parameters
@@ -208,10 +217,10 @@ def transform(self, X, y):
208217
# Check the consistency of X and y
209218
X, y = check_X_y(X, y)
210219

211-
super(SMOTETomek, self).transform(X, y)
220+
super(SMOTETomek, self).sample(X, y)
212221

213222
# Transform using SMOTE
214-
X, y = self.sm.transform(X, y)
223+
X, y = self.sm.sample(X, y)
215224

216225
# Fit and resample using Tomek links
217-
return self.tomek.fit_transform(X, y)
226+
return self.tomek.fit_sample(X, y)

unbalanced_dataset/combine/tests/test_smote_enn.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from numpy.testing import assert_raises
88
from numpy.testing import assert_equal
99
from numpy.testing import assert_array_equal
10+
from numpy.testing import assert_warns
1011

1112
from sklearn.datasets import make_classification
13+
from sklearn.utils.estimator_checks import check_estimator
1214

1315
from unbalanced_dataset.combine import SMOTEENN
1416

@@ -20,6 +22,10 @@
2022
n_samples=5000, random_state=RND_SEED)
2123

2224

25+
def test_senn_sk_estimator():
26+
"""Test the sklearn estimator compatibility"""
27+
check_estimator(SMOTEENN)
28+
2329
def test_senn_bad_ratio():
2430
"""Test whether an error is raised with a wrong decimal value for
2531
the ratio"""
@@ -49,7 +55,7 @@ def test_smote_fit_single_class():
4955
# Resample the data
5056
# Create a wrong y
5157
y_single_class = np.zeros((X.shape[0], ))
52-
assert_raises(RuntimeError, smote.fit, X, y_single_class)
58+
assert_warns(RuntimeWarning, smote.fit, X, y_single_class)
5359

5460

5561
def test_smote_fit():
@@ -67,24 +73,24 @@ def test_smote_fit():
6773
assert_equal(smote.stats_c_[1], 4500)
6874

6975

70-
def test_smote_transform_wt_fit():
71-
"""Test either if an error is raised when transform is called before
76+
def test_smote_sample_wt_fit():
77+
"""Test whether an error is raised when sample is called before
7278
fitting"""
7379

7480
# Create the object
7581
smote = SMOTEENN(random_state=RND_SEED)
76-
assert_raises(RuntimeError, smote.transform, X, Y)
82+
assert_raises(RuntimeError, smote.sample, X, Y)
7783

7884

79-
def test_transform_regular():
80-
"""Test transform function with regular SMOTE."""
85+
def test_sample_regular():
86+
"""Test sample function with regular SMOTE."""
8187

8288
# Create the object
8389
smote = SMOTEENN(random_state=RND_SEED)
8490
# Fit the data
8591
smote.fit(X, Y)
8692

87-
X_resampled, y_resampled = smote.fit_transform(X, Y)
93+
X_resampled, y_resampled = smote.fit_sample(X, Y)
8894

8995
currdir = os.path.dirname(os.path.abspath(__file__))
9096
X_gt = np.load(os.path.join(currdir, 'data', 'smote_enn_reg_x.npy'))
@@ -93,16 +99,16 @@ def test_transform_regular():
9399
assert_array_equal(y_resampled, y_gt)
94100

95101

96-
def test_transform_regular_half():
97-
"""Test transform function with regular SMOTE and a ratio of 0.5."""
102+
def test_sample_regular_half():
103+
"""Test sample function with regular SMOTE and a ratio of 0.5."""
98104

99105
# Create the object
100106
ratio = 0.5
101107
smote = SMOTEENN(ratio=ratio, random_state=RND_SEED)
102108
# Fit the data
103109
smote.fit(X, Y)
104110

105-
X_resampled, y_resampled = smote.fit_transform(X, Y)
111+
X_resampled, y_resampled = smote.fit_sample(X, Y)
106112

107113
currdir = os.path.dirname(os.path.abspath(__file__))
108114
X_gt = np.load(os.path.join(currdir, 'data', 'smote_enn_reg_x_05.npy'))

unbalanced_dataset/combine/tests/test_smote_tomek.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77
from numpy.testing import assert_raises
88
from numpy.testing import assert_equal
99
from numpy.testing import assert_array_equal
10+
from numpy.testing import assert_warns
1011

1112
from sklearn.datasets import make_classification
13+
from sklearn.utils.estimator_checks import check_estimator
1214

1315
from unbalanced_dataset.combine import SMOTETomek
1416

@@ -20,6 +22,10 @@
2022
n_samples=5000, random_state=RND_SEED)
2123

2224

25+
def test_smote_sk_estimator():
26+
"""Test the sklearn estimator compatibility"""
27+
check_estimator(SMOTETomek)
28+
2329
def test_smote_bad_ratio():
2430
"""Test whether an error is raised with a wrong decimal value for
2531
the ratio"""
@@ -49,7 +55,7 @@ def test_smote_fit_single_class():
4955
# Resample the data
5056
# Create a wrong y
5157
y_single_class = np.zeros((X.shape[0], ))
52-
assert_raises(RuntimeError, smote.fit, X, y_single_class)
58+
assert_warns(RuntimeWarning, smote.fit, X, y_single_class)
5359

5460

5561
def test_smote_fit():
@@ -67,24 +73,24 @@ def test_smote_fit():
6773
assert_equal(smote.stats_c_[1], 4500)
6874

6975

70-
def test_smote_transform_wt_fit():
71-
"""Test either if an error is raised when transform is called before
76+
def test_smote_sample_wt_fit():
77+
"""Test whether an error is raised when sample is called before
7278
fitting"""
7379

7480
# Create the object
7581
smote = SMOTETomek(random_state=RND_SEED)
76-
assert_raises(RuntimeError, smote.transform, X, Y)
82+
assert_raises(RuntimeError, smote.sample, X, Y)
7783

7884

79-
def test_transform_regular():
80-
"""Test transform function with regular SMOTE."""
85+
def test_sample_regular():
86+
"""Test sample function with regular SMOTE."""
8187

8288
# Create the object
8389
smote = SMOTETomek(random_state=RND_SEED)
8490
# Fit the data
8591
smote.fit(X, Y)
8692

87-
X_resampled, y_resampled = smote.fit_transform(X, Y)
93+
X_resampled, y_resampled = smote.fit_sample(X, Y)
8894

8995
currdir = os.path.dirname(os.path.abspath(__file__))
9096
X_gt = np.load(os.path.join(currdir, 'data', 'smote_tomek_reg_x.npy'))
@@ -93,16 +99,16 @@ def test_transform_regular():
9399
assert_array_equal(y_resampled, y_gt)
94100

95101

96-
def test_transform_regular_half():
97-
"""Test transform function with regular SMOTE and a ratio of 0.5."""
102+
def test_sample_regular_half():
103+
"""Test sample function with regular SMOTE and a ratio of 0.5."""
98104

99105
# Create the object
100106
ratio = 0.5
101107
smote = SMOTETomek(ratio=ratio, random_state=RND_SEED)
102108
# Fit the data
103109
smote.fit(X, Y)
104110

105-
X_resampled, y_resampled = smote.fit_transform(X, Y)
111+
X_resampled, y_resampled = smote.fit_sample(X, Y)
106112

107113
currdir = os.path.dirname(os.path.abspath(__file__))
108114
X_gt = np.load(os.path.join(currdir, 'data', 'smote_tomek_reg_x_05.npy'))

0 commit comments

Comments
 (0)