Skip to content

Commit 76a9148

Browse files
author
Guillaume Lemaitre
committed
Enforce to get same data at fitting and sampling
1 parent 52ea6fd commit 76a9148

34 files changed

+268
-2
lines changed

unbalanced_dataset/base.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def __init__(self, ratio='auto', random_state=None, verbose=True):
7373
self.min_c_ = None
7474
self.maj_c_ = None
7575
self.stats_c_ = {}
76+
self.X_shape_ = None
7677

7778
@abstractmethod
7879
def fit(self, X, y):
@@ -110,6 +111,10 @@ def fit(self, X, y):
110111
warnings.warn('Only one class detected, something will get wrong',
111112
RuntimeWarning)
112113

114+
# Store the size of X to check at sampling time if we have the
115+
# same data
116+
self.X_shape_ = X.shape
117+
113118
# Create a dictionary containing the class statistics
114119
self.stats_c_ = Counter(y)
115120

@@ -157,6 +162,12 @@ def sample(self, X, y):
157162
if not self.stats_c_:
158163
raise RuntimeError('You need to fit the data, first!!!')
159164

165+
# Check that the data has the same shape as at fitting time
166+
if X.shape != self.X_shape_:
167+
raise RuntimeError('The data that you attempt to resample do not'
168+
' seem to be the one earlier fitted. Use the'
169+
' fitted data.')
170+
160171
return self
161172

162173
def fit_sample(self, X, y):

unbalanced_dataset/combine/smote_enn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ class SMOTEENN(SamplerMixin):
7979
A dictionary in which the number of occurrences of each class is
8080
reported.
8181
82+
X_shape_ : tuple of int
83+
Shape of the data `X` during fitting.
84+
8285
Notes
8386
-----
8487
The method is presented in [1]_.

unbalanced_dataset/combine/smote_tomek.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ class SMOTETomek(SamplerMixin):
8080
A dictionary in which the number of occurrences of each class is
8181
reported.
8282
83+
X_shape_ : tuple of int
84+
Shape of the data `X` during fitting.
85+
8386
Notes
8487
-----
8588
The method is presented in [1]_.

unbalanced_dataset/combine/tests/test_smote_enn.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,14 @@ def test_sample_regular_half():
116116
y_gt = np.load(os.path.join(currdir, 'data', 'smote_enn_reg_y_05.npy'))
117117
assert_array_equal(X_resampled, X_gt)
118118
assert_array_equal(y_resampled, y_gt)
119+
120+
121+
def test_sample_wrong_X():
122+
"""Test whether an error is raised when X is different at fitting
123+
and sampling"""
124+
125+
# Create the object
126+
sm = SMOTEENN(random_state=RND_SEED)
127+
sm.fit(X, Y)
128+
assert_raises(RuntimeError, sm.sample, np.random.random((100, 40)),
129+
np.array([0] * 50 + [1] * 50))

unbalanced_dataset/combine/tests/test_smote_tomek.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,14 @@ def test_sample_regular_half():
116116
y_gt = np.load(os.path.join(currdir, 'data', 'smote_tomek_reg_y_05.npy'))
117117
assert_array_equal(X_resampled, X_gt)
118118
assert_array_equal(y_resampled, y_gt)
119+
120+
121+
def test_sample_wrong_X():
122+
"""Test whether an error is raised when X is different at fitting
123+
and sampling"""
124+
125+
# Create the object
126+
sm = SMOTETomek(random_state=RND_SEED)
127+
sm.fit(X, Y)
128+
assert_raises(RuntimeError, sm.sample, np.random.random((100, 40)),
129+
np.array([0] * 50 + [1] * 50))

unbalanced_dataset/ensemble/balance_cascade.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,9 @@ class BalanceCascade(EnsembleSampler):
7171
A dictionary in which the number of occurrences of each class is
7272
reported.
7373
74+
X_shape_ : tuple of int
75+
Shape of the data `X` during fitting.
76+
7477
Notes
7578
-----
7679
The method is described in [1]_.

unbalanced_dataset/ensemble/easy_ensemble.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@ class EasyEnsemble(EnsembleSampler):
6060
A dictionary in which the number of occurrences of each class is
6161
reported.
6262
63+
X_shape_ : tuple of int
64+
Shape of the data `X` during fitting.
65+
6366
Notes
6467
-----
6568
The method is described in [1]_.

unbalanced_dataset/ensemble/tests/test_balance_cascade.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,3 +331,14 @@ def test_fit_sample_auto_early_stop():
331331
assert_array_equal(X_resampled[idx], X_gt[idx])
332332
assert_array_equal(y_resampled[idx], y_gt[idx])
333333
assert_array_equal(idx_under[idx], idx_gt[idx])
334+
335+
336+
def test_sample_wrong_X():
337+
"""Test whether an error is raised when X is different at fitting
338+
and sampling"""
339+
340+
# Create the object
341+
bc = BalanceCascade(random_state=RND_SEED)
342+
bc.fit(X, Y)
343+
assert_raises(RuntimeError, bc.sample, np.random.random((100, 40)),
344+
np.array([0] * 50 + [1] * 50))

unbalanced_dataset/ensemble/tests/test_easy_ensemble.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,14 @@ def test_fit_sample_half():
160160
y_gt = np.load(os.path.join(currdir, 'data', 'ee_y_05.npy'))
161161
assert_array_equal(X_resampled, X_gt)
162162
assert_array_equal(y_resampled, y_gt)
163+
164+
165+
def test_sample_wrong_X():
166+
"""Test whether an error is raised when X is different at fitting
167+
and sampling"""
168+
169+
# Create the object
170+
ee = EasyEnsemble(random_state=RND_SEED)
171+
ee.fit(X, Y)
172+
assert_raises(RuntimeError, ee.sample, np.random.random((100, 40)),
173+
np.array([0] * 50 + [1] * 50))

unbalanced_dataset/over_sampling/random_over_sampler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ class RandomOverSampler(OverSampler):
5252
A dictionary in which the number of occurrences of each class is
5353
reported.
5454
55+
X_shape_ : tuple of int
56+
Shape of the data `X` during fitting.
57+
5558
Notes
5659
-----
5760
Supports multiple classes.

0 commit comments

Comments
 (0)