Skip to content

Commit f8c27ae

Browse files
authored
MAINT be more inclusive regarding dict (#958)
1 parent 628f4a4 commit f8c27ae

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

imblearn/over_sampling/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
# License: MIT
77

88
import numbers
9+
from collections.abc import Mapping
910

1011
from ..base import BaseSampler
1112
from ..utils._param_validation import Interval, StrOptions
@@ -61,7 +62,7 @@ class BaseOverSampler(BaseSampler):
6162
"sampling_strategy": [
6263
Interval(numbers.Real, 0, 1, closed="right"),
6364
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
64-
dict,
65+
Mapping,
6566
callable,
6667
],
6768
"random_state": ["random_state"],

imblearn/tests/test_common.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
# Christos Aridas
44
# License: MIT
55

6+
from collections import OrderedDict
7+
8+
import numpy as np
69
import pytest
710
from sklearn.base import clone
811
from sklearn.exceptions import ConvergenceWarning
@@ -12,7 +15,8 @@
1215
parametrize_with_checks as parametrize_with_checks_sklearn,
1316
)
1417

15-
from imblearn.under_sampling import NearMiss
18+
from imblearn.over_sampling import RandomOverSampler
19+
from imblearn.under_sampling import NearMiss, RandomUnderSampler
1620
from imblearn.utils.estimator_checks import (
1721
_set_checking_parameters,
1822
check_param_validation,
@@ -73,3 +77,19 @@ def test_check_param_validation(estimator):
7377
print(name)
7478
_set_checking_parameters(estimator)
7579
check_param_validation(name, estimator)
80+
81+
82+
@pytest.mark.parametrize("Sampler", [RandomOverSampler, RandomUnderSampler])
83+
def test_strategy_as_ordered_dict(Sampler):
84+
"""Check that it is possible to pass an `OrderedDict` as strategy."""
85+
rng = np.random.RandomState(42)
86+
X, y = rng.randn(30, 2), np.array([0] * 10 + [1] * 20)
87+
sampler = Sampler(random_state=42)
88+
if isinstance(sampler, RandomOverSampler):
89+
strategy = OrderedDict({0: 20, 1: 20})
90+
else:
91+
strategy = OrderedDict({0: 10, 1: 10})
92+
sampler.set_params(sampling_strategy=strategy)
93+
X_res, y_res = sampler.fit_resample(X, y)
94+
assert X_res.shape[0] == sum(strategy.values())
95+
assert y_res.shape[0] == sum(strategy.values())

imblearn/under_sampling/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
# License: MIT
66

77
import numbers
8+
from collections.abc import Mapping
89

910
from ..base import BaseSampler
1011
from ..utils._param_validation import Interval, StrOptions
@@ -61,7 +62,7 @@ class BaseUnderSampler(BaseSampler):
6162
"sampling_strategy": [
6263
Interval(numbers.Real, 0, 1, closed="right"),
6364
StrOptions({"auto", "majority", "not minority", "not majority", "all"}),
64-
dict,
65+
Mapping,
6566
callable,
6667
],
6768
}

0 commit comments

Comments
 (0)