|
12 | 12 | from sklearn.preprocessing import label_binarize |
13 | 13 | from sklearn.utils.metaestimators import available_if |
14 | 14 | from sklearn.utils.multiclass import check_classification_targets |
15 | | -from sklearn.utils.fixes import parse_version |
16 | 15 |
|
17 | 16 | from .utils import check_sampling_strategy, check_target_type |
18 | | -from .utils.fixes import validate_data |
| 17 | +from .utils.fixes import check_version_package, validate_data |
19 | 18 | from .utils._param_validation import validate_parameter_constraints |
20 | 19 | from .utils._validation import ArraysTransformer |
21 | 20 |
|
22 | | - |
23 | | -def check_version(estimator): |
24 | | - return parse_version( |
25 | | - parse_version(sklearn.__version__).base_version |
26 | | - ) < parse_version("1.6") |
27 | | - |
28 | | - |
29 | 21 | class _ParamsValidationMixin: |
30 | 22 | """Mixin class to validate parameters.""" |
31 | 23 |
|
@@ -206,10 +198,11 @@ def fit_resample(self, X, y): |
206 | 198 | self._validate_params() |
207 | 199 | return super().fit_resample(X, y) |
208 | 200 |
|
209 | | - @available_if(check_version) |
| 201 | + @available_if(check_version_package("sklearn", "<", "1.6")) |
210 | 202 | def _more_tags(self): |
211 | 203 | return {"X_types": ["2darray", "sparse", "dataframe"]} |
212 | 204 |
|
| 205 | + @available_if(check_version_package("sklearn", ">=", "1.6")) |
213 | 206 | def __sklearn_tags__(self): |
214 | 207 | tags = super().__sklearn_tags__() |
215 | 208 |
|
|
0 commit comments