Skip to content

Commit e2115d6

Browse files
author
Guillaume Lemaitre
committed
Raise an error at fitting time if the ratio do not make sense.
1 parent bd25c8b commit e2115d6

11 files changed

+112
-3
lines changed

unbalanced_dataset/base_sampler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,14 @@ def fit(self, X, y):
115115
print('{} classes detected: {}'.format(uniques.size,
116116
self.stats_c_))
117117

118+
# Check if the ratio provided at initialisation make sense
119+
if isinstance(self.ratio_, float):
120+
if self.ratio_ < (self.stats_c_[self.min_c_] /
121+
self.stats_c_[self.maj_c_]):
122+
raise RuntimeError('The ratio requested at initialisation'
123+
' should be greater or equal than the'
124+
' balancing ratio of the current data.')
125+
118126
return self
119127

120128
@abstractmethod

unbalanced_dataset/ensemble/tests/test_balance_cascade.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,17 @@ def test_bc_fit_single_class():
7373
assert_raises(RuntimeError, bc.fit, X, y_single_class)
7474

7575

76+
def test_bc_fit_invalid_ratio():
77+
"""Test either if an error is raised when the balancing ratio to fit is
78+
smaller than the one of the data"""
79+
80+
# Create the object
81+
ratio = 1. / 10000.
82+
bc = BalanceCascade(ratio=ratio, random_state=RND_SEED)
83+
# Fit the data
84+
assert_raises(RuntimeError, bc.fit, X, Y)
85+
86+
7687
def test_bc_fit():
7788
"""Test the fitting method"""
7889

unbalanced_dataset/ensemble/tests/test_easy_ensemble.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,17 @@ def test_ee_fit_single_class():
7373
assert_raises(RuntimeError, ee.fit, X, y_single_class)
7474

7575

76+
def test_ee_fit_invalid_ratio():
77+
"""Test either if an error is raised when the balancing ratio to fit is
78+
smaller than the one of the data"""
79+
80+
# Create the object
81+
ratio = 1. / 10000.
82+
ee = EasyEnsemble(ratio=ratio, random_state=RND_SEED)
83+
# Fit the data
84+
assert_raises(RuntimeError, ee.fit, X, Y)
85+
86+
7687
def test_ee_fit():
7788
"""Test the fitting method"""
7889

unbalanced_dataset/over_sampling/random_over_sampler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,11 @@ def transform(self, X, y):
147147

148148
# Define the number of sample to create
149149
if self.ratio_ == 'auto':
150-
num_samples = self.stats_c_[self.maj_c_] - self.stats_c_[key]
150+
num_samples = int(self.stats_c_[self.maj_c_] -
151+
self.stats_c_[key])
151152
else:
152-
num_samples = ((self.ratio_ * self.stats_c_[self.maj_c_]) -
153-
self.stats_c_[key])
153+
num_samples = int((self.ratio_ * self.stats_c_[self.maj_c_]) -
154+
self.stats_c_[key])
154155

155156
# Pick some elements at random
156157
np.random.seed(self.rs_)

unbalanced_dataset/over_sampling/tests/test_random_over_sampler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ def test_ros_fit_single_class():
6868
assert_raises(RuntimeError, ros.fit, X, y_single_class)
6969

7070

71+
def test_ros_fit_invalid_ratio():
72+
"""Test either if an error is raised when the balancing ratio to fit is
73+
smaller than the one of the data"""
74+
75+
# Create the object
76+
ratio = 1. / 10000.
77+
ros = RandomOverSampler(ratio=ratio, random_state=RND_SEED)
78+
# Fit the data
79+
assert_raises(RuntimeError, ros.fit, X, Y)
80+
81+
7182
def test_ros_fit():
7283
"""Test the fitting method"""
7384

unbalanced_dataset/under_sampling/tests/test_cluster_centroids.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,17 @@ def test_cc_fit_single_class():
7171
assert_raises(RuntimeError, cc.fit, X, y_single_class)
7272

7373

74+
def test_cc_fit_invalid_ratio():
75+
"""Test either if an error is raised when the balancing ratio to fit is
76+
smaller than the one of the data"""
77+
78+
# Create the object
79+
ratio = 1. / 10000.
80+
cc = ClusterCentroids(ratio=ratio, random_state=RND_SEED)
81+
# Fit the data
82+
assert_raises(RuntimeError, cc.fit, X, Y)
83+
84+
7485
def test_cc_fit():
7586
"""Test the fitting method"""
7687

unbalanced_dataset/under_sampling/tests/test_instance_hardness_threshold.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@ def test_iht_fit_single_class():
8787
assert_raises(RuntimeError, iht.fit, X, y_single_class)
8888

8989

90+
def test_iht_fit_invalid_ratio():
91+
"""Test either if an error is raised when the balancing ratio to fit is
92+
smaller than the one of the data"""
93+
94+
# Create the object
95+
ratio = 1. / 10000.
96+
iht = InstanceHardnessThreshold(ESTIMATOR, ratio=ratio,
97+
random_state=RND_SEED)
98+
# Fit the data
99+
assert_raises(RuntimeError, iht.fit, X, Y)
100+
101+
90102
def test_iht_fit():
91103
"""Test the fitting method"""
92104

unbalanced_dataset/under_sampling/tests/test_nearmiss_1.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ def test_nearmiss_fit_single_class():
8383
assert_raises(RuntimeError, nm1.fit, X, y_single_class)
8484

8585

86+
def test_nm_fit_invalid_ratio():
87+
"""Test either if an error is raised when the balancing ratio to fit is
88+
smaller than the one of the data"""
89+
90+
# Create the object
91+
ratio = 1. / 10000.
92+
nm = NearMiss(ratio=ratio, random_state=RND_SEED)
93+
# Fit the data
94+
assert_raises(RuntimeError, nm.fit, X, Y)
95+
96+
8697
def test_nm1_fit():
8798
"""Test the fitting method"""
8899

unbalanced_dataset/under_sampling/tests/test_nearmiss_2.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ def test_nearmiss_fit_single_class():
8383
assert_raises(RuntimeError, nm2.fit, X, y_single_class)
8484

8585

86+
def test_nm_fit_invalid_ratio():
87+
"""Test either if an error is raised when the balancing ratio to fit is
88+
smaller than the one of the data"""
89+
90+
# Create the object
91+
ratio = 1. / 10000.
92+
nm = NearMiss(ratio=ratio, random_state=RND_SEED)
93+
# Fit the data
94+
assert_raises(RuntimeError, nm.fit, X, Y)
95+
96+
8697
def test_nm2_fit():
8798
"""Test the fitting method"""
8899

unbalanced_dataset/under_sampling/tests/test_nearmiss_3.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,17 @@ def test_nearmiss_fit_single_class():
8383
assert_raises(RuntimeError, nm3.fit, X, y_single_class)
8484

8585

86+
def test_nm_fit_invalid_ratio():
87+
"""Test either if an error is raised when the balancing ratio to fit is
88+
smaller than the one of the data"""
89+
90+
# Create the object
91+
ratio = 1. / 10000.
92+
nm = NearMiss(ratio=ratio, random_state=RND_SEED)
93+
# Fit the data
94+
assert_raises(RuntimeError, nm.fit, X, Y)
95+
96+
8697
def test_nm3_fit():
8798
"""Test the fitting method"""
8899

0 commit comments

Comments
 (0)