Skip to content

Commit 326119b

Browse files
authored
TST add common test for string and nan (#804)
1 parent d9ba4af commit 326119b

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

imblearn/utils/estimator_checks.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from sklearn.utils.estimator_checks import _maybe_mark_xfail
2929
from sklearn.utils.estimator_checks import _get_check_estimator_ids
3030
from sklearn.utils._testing import assert_allclose
31+
from sklearn.utils._testing import assert_array_equal
3132
from sklearn.utils._testing import assert_raises_regex
3233
from sklearn.utils.multiclass import type_of_target
3334

@@ -61,6 +62,10 @@ def _yield_sampler_checks(sampler):
6162
yield check_samplers_sparse
6263
if "dataframe" in tags["X_types"]:
6364
yield check_samplers_pandas
65+
if "string" in tags["X_types"]:
66+
yield check_samplers_string
67+
if tags["allow_nan"]:
68+
yield check_samplers_nan
6469
yield check_samplers_list
6570
yield check_samplers_multiclass_ova
6671
yield check_samplers_preserve_dtype
@@ -399,6 +404,36 @@ def check_samplers_sample_indices(name, sampler_orig):
399404
assert not hasattr(sampler, "sample_indices_")
400405

401406

407+
def check_samplers_string(name, sampler_orig):
    """Check that a sampler can resample a string-valued feature column.

    A single-column object array holding the strings "A", "B" and "C" is
    resampled together with a binary target (10 samples of class 0 and 20
    samples of class 1). The check passes when the resampled ``X`` is still
    of object dtype, has as many rows as the resampled ``y``, and still
    contains every one of the three original strings.

    Parameters
    ----------
    name : str
        Name of the sampler; not used in the body of this check.
    sampler_orig : sampler
        Sampler under test. It is cloned, so the passed object is not fitted.
    """
    random_state = np.random.RandomState(0)
    labels = np.array(["A", "B", "C"], dtype=object)

    # Draw order matters for reproducibility: first the category codes for
    # the feature column, then the shuffled target.
    codes = random_state.randint(low=0, high=3, size=30).reshape(-1, 1)
    X = labels[codes]
    y = random_state.permutation([0] * 10 + [1] * 20)

    estimator = clone(sampler_orig)
    X_res, y_res = estimator.fit_resample(X, y)

    assert X_res.dtype == object
    assert X_res.shape[0] == y_res.shape[0]
    assert_array_equal(np.unique(X_res.ravel()), labels)
420+
421+
422+
def check_samplers_nan(name, sampler_orig):
    """Check that a sampler can resample data containing NaN values.

    A single-column float64 array drawn from the values 0, 1 and NaN is
    resampled together with a binary target (40 samples of class 0 and 60
    samples of class 1). The check passes when the resampled ``X`` is still
    float64, has as many rows as the resampled ``y``, and still contains at
    least one NaN.

    Parameters
    ----------
    name : str
        Name of the sampler; not used in the body of this check.
    sampler_orig : sampler
        Sampler under test. It is cloned, so the passed object is not fitted.
    """
    random_state = np.random.RandomState(0)
    values = np.array([0, 1, np.nan], dtype=np.float64)

    # Draw order matters for reproducibility: first the indices into
    # ``values`` for the feature column, then the shuffled target.
    codes = random_state.randint(low=0, high=3, size=100).reshape(-1, 1)
    X = values[codes]
    y = random_state.permutation([0] * 40 + [1] * 60)

    estimator = clone(sampler_orig)
    X_res, y_res = estimator.fit_resample(X, y)

    assert X_res.dtype == np.float64
    assert X_res.shape[0] == y_res.shape[0]
    assert np.isnan(X_res.ravel()).any()
435+
436+
402437
def check_classifier_on_multilabel_or_multioutput_targets(name, estimator_orig):
403438
estimator = clone(estimator_orig)
404439
X, y = make_multilabel_classification(n_samples=30)

imblearn/utils/tests/test_estimator_checks.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,18 @@
22
import numpy as np
33

44
from sklearn.base import BaseEstimator
5-
from sklearn.utils import check_X_y
65
from sklearn.utils.multiclass import check_classification_targets
76

87
from imblearn.base import BaseSampler
9-
8+
from imblearn.over_sampling.base import BaseOverSampler
9+
from imblearn.utils import check_target_type as target_check
1010
from imblearn.utils.estimator_checks import check_target_type
1111
from imblearn.utils.estimator_checks import check_samplers_one_label
1212
from imblearn.utils.estimator_checks import check_samplers_fit
1313
from imblearn.utils.estimator_checks import check_samplers_sparse
1414
from imblearn.utils.estimator_checks import check_samplers_preserve_dtype
15+
from imblearn.utils.estimator_checks import check_samplers_string
16+
from imblearn.utils.estimator_checks import check_samplers_nan
1517

1618

1719
class BaseBadSampler(BaseEstimator):
@@ -64,6 +66,34 @@ def _fit_resample(self, X, y):
6466
return X.astype(np.float64), y.astype(np.int64)
6567

6668

69+
class IndicesSampler(BaseOverSampler):
    """Test sampler that resamples by indexing rows of ``X`` and ``y``.

    The rows are only selected, never transformed, so the values and dtype
    of the input are carried over to the output unchanged.
    """

    def _check_X_y(self, X, y):
        """Validate the target and the data before resampling.

        ``dtype=None`` and ``force_all_finite=False`` are forwarded to
        ``self._validate_data``; presumably this is what lets string and
        NaN inputs through validation (confirm against scikit-learn's
        ``check_array`` documentation).

        Returns
        -------
        X, y
            The values returned by ``self._validate_data``.
        binarize_y
            The second value returned by ``check_target_type`` when called
            with ``indicate_one_vs_all=True``.
        """
        y, binarize_y = target_check(y, indicate_one_vs_all=True)
        X, y = self._validate_data(
            X,
            y,
            reset=True,
            dtype=None,
            force_all_finite=False,
        )
        return X, y, binarize_y

    def _fit_resample(self, X, y):
        """Return rows of ``X`` and ``y`` picked at random with replacement.

        The number of rows drawn is twice the size of the most frequent
        class in ``y``.
        """
        n_max_count_class = np.bincount(y).max()
        # Use a local, seeded generator instead of the global ``np.random``
        # state: the checks run against this sampler assert on which values
        # survive resampling, so the draw must be reproducible.
        rng = np.random.RandomState(0)
        indices = rng.choice(X.shape[0], size=n_max_count_class * 2)
        return X[indices], y[indices]
85+
86+
87+
def test_check_samplers_string():
    """Run ``check_samplers_string`` against an ``IndicesSampler``."""
    indices_sampler = IndicesSampler()
    sampler_name = indices_sampler.__class__.__name__
    check_samplers_string(sampler_name, indices_sampler)
90+
91+
92+
def test_check_samplers_nan():
    """Run ``check_samplers_nan`` against an ``IndicesSampler``."""
    indices_sampler = IndicesSampler()
    sampler_name = indices_sampler.__class__.__name__
    check_samplers_nan(sampler_name, indices_sampler)
95+
96+
6797
mapping_estimator_error = {
6898
"BaseBadSampler": (AssertionError, "ValueError not raised by fit"),
6999
"SamplerSingleClass": (AssertionError, "Sampler can't balance when only"),

0 commit comments

Comments
 (0)