|
28 | 28 | from sklearn.utils.estimator_checks import _maybe_mark_xfail
|
29 | 29 | from sklearn.utils.estimator_checks import _get_check_estimator_ids
|
30 | 30 | from sklearn.utils._testing import assert_allclose
|
| 31 | +from sklearn.utils._testing import assert_array_equal |
31 | 32 | from sklearn.utils._testing import assert_raises_regex
|
32 | 33 | from sklearn.utils.multiclass import type_of_target
|
33 | 34 |
|
@@ -61,6 +62,10 @@ def _yield_sampler_checks(sampler):
|
61 | 62 | yield check_samplers_sparse
|
62 | 63 | if "dataframe" in tags["X_types"]:
|
63 | 64 | yield check_samplers_pandas
|
| 65 | + if "string" in tags["X_types"]: |
| 66 | + yield check_samplers_string |
| 67 | + if tags["allow_nan"]: |
| 68 | + yield check_samplers_nan |
64 | 69 | yield check_samplers_list
|
65 | 70 | yield check_samplers_multiclass_ova
|
66 | 71 | yield check_samplers_preserve_dtype
|
@@ -399,6 +404,36 @@ def check_samplers_sample_indices(name, sampler_orig):
|
399 | 404 | assert not hasattr(sampler, "sample_indices_")
|
400 | 405 |
|
401 | 406 |
|
def check_samplers_string(name, sampler_orig):
    """Check that a sampler can resample a column of string categories.

    A single-column object array holding the labels "A", "B" and "C" is
    resampled against an imbalanced binary target (10 vs. 20 samples).
    The check passes when the resampled features keep the ``object``
    dtype, have as many rows as the resampled target, and contain exactly
    the original three categories.

    Parameters
    ----------
    name : str
        Name of the sampler; not used in the body of this check.
    sampler_orig : sampler instance
        Sampler under test. It is cloned, so the instance passed in is not
        the one that gets fitted.
    """
    random_state = np.random.RandomState(0)
    estimator = clone(sampler_orig)

    labels = np.array(["A", "B", "C"], dtype=object)
    n_minority, n_majority = 10, 20

    # Draw the features before the target so the random stream is consumed
    # in a fixed order: category codes first, then the label permutation.
    codes = random_state.randint(
        low=0, high=3, size=n_minority + n_majority
    )
    X = labels[codes.reshape(-1, 1)]
    y = random_state.permutation([0] * n_minority + [1] * n_majority)

    X_res, y_res = estimator.fit_resample(X, y)

    assert X_res.dtype == object
    assert X_res.shape[0] == y_res.shape[0]
    assert_array_equal(np.unique(X_res.ravel()), labels)
| 420 | + |
| 421 | + |
def check_samplers_nan(name, sampler_orig):
    """Check that a sampler can resample features containing NaN.

    A single-column float64 array drawn from the values 0, 1 and NaN is
    resampled against an imbalanced binary target (40 vs. 60 samples).
    The check passes when the resampled features keep the ``float64``
    dtype, have as many rows as the resampled target, and still contain
    at least one NaN.

    Parameters
    ----------
    name : str
        Name of the sampler; not used in the body of this check.
    sampler_orig : sampler instance
        Sampler under test. It is cloned, so the instance passed in is not
        the one that gets fitted.
    """
    random_state = np.random.RandomState(0)
    estimator = clone(sampler_orig)

    values = np.array([0, 1, np.nan], dtype=np.float64)
    n_minority, n_majority = 40, 60

    # Draw the features before the target so the random stream is consumed
    # in a fixed order: value codes first, then the label permutation.
    codes = random_state.randint(
        low=0, high=3, size=n_minority + n_majority
    )
    X = values[codes.reshape(-1, 1)]
    y = random_state.permutation([0] * n_minority + [1] * n_majority)

    X_res, y_res = estimator.fit_resample(X, y)

    assert X_res.dtype == np.float64
    assert X_res.shape[0] == y_res.shape[0]
    assert np.isnan(X_res.ravel()).any()
| 435 | + |
| 436 | + |
402 | 437 | def check_classifier_on_multilabel_or_multioutput_targets(name, estimator_orig):
|
403 | 438 | estimator = clone(estimator_orig)
|
404 | 439 | X, y = make_multilabel_classification(n_samples=30)
|
|
0 commit comments