|
35 | 35 | set_random_state, |
36 | 36 | ) |
37 | 37 | from sklearn.utils.estimator_checks import ( |
| 38 | + _enforce_estimator_tags_X, |
38 | 39 | _enforce_estimator_tags_y, |
39 | 40 | _get_check_estimator_ids, |
40 | 41 | _maybe_mark_xfail, |
41 | 42 | ) |
42 | | - |
43 | | -try: |
44 | | - from sklearn.utils.estimator_checks import _enforce_estimator_tags_x |
45 | | -except ImportError: |
46 | | - # scikit-learn >= 1.2 |
47 | | - from sklearn.utils.estimator_checks import ( |
48 | | - _enforce_estimator_tags_X as _enforce_estimator_tags_x, |
49 | | - ) |
50 | | - |
51 | 43 | from sklearn.utils.fixes import parse_version |
52 | 44 | from sklearn.utils.multiclass import type_of_target |
53 | 45 |
|
@@ -602,7 +594,7 @@ def check_dataframe_column_names_consistency(name, estimator_orig): |
602 | 594 |
|
603 | 595 | X_orig = rng.normal(size=(150, 8)) |
604 | 596 |
|
605 | | - X_orig = _enforce_estimator_tags_x(estimator, X_orig) |
| 597 | + X_orig = _enforce_estimator_tags_X(estimator, X_orig) |
606 | 598 | n_samples, n_features = X_orig.shape |
607 | 599 |
|
608 | 600 | names = np.array([f"col_{i}" for i in range(n_features)]) |
@@ -756,7 +748,7 @@ def check_sampler_get_feature_names_out(name, sampler_orig): |
756 | 748 | X = StandardScaler().fit_transform(X) |
757 | 749 |
|
758 | 750 | sampler = clone(sampler_orig) |
759 | | - X = _enforce_estimator_tags_x(sampler, X) |
| 751 | + X = _enforce_estimator_tags_X(sampler, X) |
760 | 752 |
|
761 | 753 | n_features = X.shape[1] |
762 | 754 | set_random_state(sampler) |
@@ -804,7 +796,7 @@ def check_sampler_get_feature_names_out_pandas(name, sampler_orig): |
804 | 796 | X = StandardScaler().fit_transform(X) |
805 | 797 |
|
806 | 798 | sampler = clone(sampler_orig) |
807 | | - X = _enforce_estimator_tags_x(sampler, X) |
| 799 | + X = _enforce_estimator_tags_X(sampler, X) |
808 | 800 |
|
809 | 801 | n_features = X.shape[1] |
810 | 802 | set_random_state(sampler) |
|
0 commit comments