Skip to content

Commit b468f7f

Browse files
authored
FIX remove spurious warning raised when over-sampling the minority class (#1007)
1 parent f14033b commit b468f7f

File tree

3 files changed

+6
-17
lines changed

3 files changed

+6
-17
lines changed

doc/whats_new/v0.11.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ Bug fixes
1717
since it requires a conversion to dense matrices.
1818
:pr:`1003` by :user:`Guillaume Lemaitre <glemaitre>`.
1919

20+
- Remove spurious warning raised when minority class get over-sampled more than the
21+
number of sample in the majority class.
22+
:pr:`1007` by :user:`Guillaume Lemaitre <glemaitre>`.
23+
2024
Compatibility
2125
.............
2226

imblearn/utils/_validation.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,8 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
307307
)
308308
sampling_strategy_ = {}
309309
if sampling_type == "over-sampling":
310-
n_samples_majority = max(target_stats.values())
311-
class_majority = max(target_stats, key=target_stats.get)
310+
max(target_stats.values())
311+
max(target_stats, key=target_stats.get)
312312
for class_sample, n_samples in sampling_strategy.items():
313313
if n_samples < target_stats[class_sample]:
314314
raise ValueError(
@@ -318,13 +318,6 @@ def _sampling_strategy_dict(sampling_strategy, y, sampling_type):
318318
f" Originally, there is {target_stats[class_sample]} "
319319
f"samples and {n_samples} samples are asked."
320320
)
321-
if n_samples > n_samples_majority:
322-
warnings.warn(
323-
f"After over-sampling, the number of samples ({n_samples})"
324-
f" in class {class_sample} will be larger than the number of"
325-
f" samples in the majority class (class #{class_majority} ->"
326-
f" {n_samples_majority})"
327-
)
328321
sampling_strategy_[class_sample] = n_samples - target_stats[class_sample]
329322
elif sampling_type == "under-sampling":
330323
for class_sample, n_samples in sampling_strategy.items():

imblearn/utils/tests/test_validation.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,6 @@ def test_check_sampling_strategy(
256256
assert sampling_strategy_ == expected_sampling_strategy
257257

258258

259-
def test_sampling_strategy_dict_over_sampling():
260-
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
261-
sampling_strategy = {1: 70, 2: 140, 3: 70}
262-
expected_msg = "After over-sampling, the number of samples "
263-
with pytest.warns(UserWarning, match=expected_msg):
264-
check_sampling_strategy(sampling_strategy, y, "over-sampling")
265-
266-
267259
def test_sampling_strategy_callable_args():
268260
y = np.array([1] * 50 + [2] * 100 + [3] * 25)
269261
multiplier = {1: 1.5, 2: 1, 3: 3}

0 commit comments

Comments
 (0)