Skip to content

Commit c1514dc

Browse files
committed
fix _more_tags
1 parent fa206e4 commit c1514dc

File tree

18 files changed

+103
-18
lines changed

18 files changed

+103
-18
lines changed

imblearn/base.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,12 @@
1212
from sklearn.preprocessing import label_binarize
1313
from sklearn.utils.metaestimators import available_if
1414
from sklearn.utils.multiclass import check_classification_targets
15-
from sklearn.utils.fixes import parse_version
1615

1716
from .utils import check_sampling_strategy, check_target_type
18-
from .utils.fixes import validate_data
17+
from .utils.fixes import check_version_package, validate_data
1918
from .utils._param_validation import validate_parameter_constraints
2019
from .utils._validation import ArraysTransformer
2120

22-
23-
def check_version(estimator):
24-
return parse_version(
25-
parse_version(sklearn.__version__).base_version
26-
) < parse_version("1.6")
27-
28-
2921
class _ParamsValidationMixin:
3022
"""Mixin class to validate parameters."""
3123

@@ -206,10 +198,11 @@ def fit_resample(self, X, y):
206198
self._validate_params()
207199
return super().fit_resample(X, y)
208200

209-
@available_if(check_version)
201+
@available_if(check_version_package("sklearn", "<", "1.6"))
210202
def _more_tags(self):
211203
return {"X_types": ["2darray", "sparse", "dataframe"]}
212204

205+
@available_if(check_version_package("sklearn", ">=", "1.6"))
213206
def __sklearn_tags__(self):
214207
tags = super().__sklearn_tags__()
215208

imblearn/ensemble/_bagging.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..utils import Substitution, check_sampling_strategy, check_target_type
2727
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
2828
from ..utils._param_validation import HasMethods, Interval, StrOptions
29-
from ..utils.fixes import _fit_context, validate_data
29+
from ..utils.fixes import _fit_context, check_version_package, validate_data
3030
from ._common import _bagging_parameter_constraints, _estimator_has
3131

3232
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
@@ -420,6 +420,7 @@ def base_estimator_(self):
420420
)
421421
raise error
422422

423+
@available_if(check_version_package("sklearn", "<", "1.6"))
423424
def _more_tags(self):
424425
tags = super()._more_tags()
425426
tags_key = "_xfail_checks"

imblearn/ensemble/_easy_ensemble.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from ..utils import Substitution, check_sampling_strategy, check_target_type
2727
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
2828
from ..utils._param_validation import Interval, StrOptions
29-
from ..utils.fixes import _fit_context, get_tags, validate_data
29+
from ..utils.fixes import _fit_context, check_version_package, get_tags, validate_data
3030
from ._common import _bagging_parameter_constraints, _estimator_has
3131

3232
MAX_INT = np.iinfo(np.int32).max
@@ -354,6 +354,7 @@ def _get_estimator(self):
354354
return self.estimator
355355

356356
# TODO: remove when minimum supported version of scikit-learn is 1.5
357+
@available_if(check_version_package("sklearn", "<", "1.6"))
357358
def _more_tags(self):
358359
# This code should not be called for scikit-learn >= 1.6
359360
# Therefore, get_tags corresponds to _safe_tags that returns a dict

imblearn/ensemble/_forest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from sklearn.tree import DecisionTreeClassifier
2525
from sklearn.utils import _safe_indexing, check_random_state
2626
from sklearn.utils.fixes import parse_version
27+
from sklearn.utils.metaestimators import available_if
2728
from sklearn.utils.multiclass import type_of_target
2829
from sklearn.utils.parallel import Parallel, delayed
2930
from sklearn.utils.validation import _check_sample_weight
@@ -35,7 +36,7 @@
3536
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
3637
from ..utils._param_validation import Hidden, Interval, StrOptions
3738
from ..utils._validation import check_sampling_strategy
38-
from ..utils.fixes import _fit_context, validate_data
39+
from ..utils.fixes import _fit_context, check_version_package, validate_data
3940
from ._common import _random_forest_classifier_parameter_constraints
4041

4142
MAX_INT = np.iinfo(np.int32).max
@@ -884,5 +885,6 @@ def _compute_oob_predictions(self, X, y):
884885

885886
return oob_pred
886887

888+
@available_if(check_version_package("sklearn", "<", "1.6"))
887889
def _more_tags(self):
888890
return {"multioutput": False, "multilabel": False}

imblearn/metrics/pairwise.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,13 @@
99
from scipy.spatial import distance_matrix
1010
from sklearn.base import BaseEstimator
1111
from sklearn.utils import check_consistent_length
12+
from sklearn.utils.metaestimators import available_if
1213
from sklearn.utils.multiclass import unique_labels
1314
from sklearn.utils.validation import check_array, check_is_fitted
1415

1516
from ..base import _ParamsValidationMixin
1617
from ..utils._param_validation import StrOptions
17-
from ..utils.fixes import validate_data
18+
from ..utils.fixes import check_version_package, validate_data
1819

1920

2021
class ValueDifferenceMetric(_ParamsValidationMixin, BaseEstimator):
@@ -229,6 +230,7 @@ def pairwise(self, X, Y=None):
229230
)
230231
return distance
231232

233+
@available_if(check_version_package("sklearn", "<", "1.6"))
232234
def _more_tags(self):
233235
return {
234236
"requires_positive_X": True, # X should be encoded with OrdinalEncoder

imblearn/over_sampling/_adasyn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
import numpy as np
1111
from scipy import sparse
1212
from sklearn.utils import _safe_indexing, check_random_state
13+
from sklearn.utils.metaestimators import available_if
1314

1415
from ..utils import Substitution, check_neighbors_object
16+
from ..utils.fixes import check_version_package
1517
from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
1618
from ..utils._param_validation import HasMethods, Interval
1719
from .base import BaseOverSampler
@@ -229,6 +231,7 @@ def _fit_resample(self, X, y):
229231

230232
return X_resampled, y_resampled
231233

234+
@available_if(check_version_package("sklearn", "<", "1.6"))
232235
def _more_tags(self):
233236
return {
234237
"X_types": ["2darray"],

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@
1010
import numpy as np
1111
from scipy import sparse
1212
from sklearn.utils import _safe_indexing, check_array, check_random_state
13+
from sklearn.utils.metaestimators import available_if
1314
from sklearn.utils.sparsefuncs import mean_variance_axis
1415

1516
from ..utils import Substitution, check_target_type
1617
from ..utils._docstring import _random_state_docstring
1718
from ..utils._param_validation import Interval
18-
from ..utils.fixes import _check_n_features, _check_feature_names
19+
from ..utils.fixes import _check_n_features, _check_feature_names, check_version_package
1920
from ..utils._validation import _check_X
2021
from .base import BaseOverSampler
2122

@@ -250,6 +251,7 @@ def _fit_resample(self, X, y):
250251

251252
return X_resampled, y_resampled
252253

254+
@available_if(check_version_package("sklearn", "<", "1.6"))
253255
def _more_tags(self):
254256
return {
255257
"X_types": ["2darray", "string", "sparse", "dataframe"],

imblearn/over_sampling/_smote/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
check_array,
2222
check_random_state,
2323
)
24+
from sklearn.utils.metaestimators import available_if
2425
from sklearn.utils.fixes import parse_version
2526
from sklearn.utils.sparsefuncs_fast import (
2627
csr_mean_variance_axis0,
@@ -32,7 +33,14 @@
3233
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
3334
from ...utils._param_validation import HasMethods, Interval, StrOptions
3435
from ...utils._validation import _check_X
35-
from ...utils.fixes import _check_n_features, _check_feature_names, _is_pandas_df, _mode, validate_data
36+
from ...utils.fixes import (
37+
_check_n_features,
38+
_check_feature_names,
39+
_is_pandas_df,
40+
_mode,
41+
check_version_package,
42+
validate_data,
43+
)
3644
from ..base import BaseOverSampler
3745

3846
sklearn_version = parse_version(sklearn.__version__).base_version
@@ -1062,5 +1070,6 @@ def _fit_resample(self, X, y):
10621070
else:
10631071
return X_resampled, y_resampled
10641072

1073+
@available_if(check_version_package("sklearn", "<", "1.6"))
10651074
def _more_tags(self):
10661075
return {"X_types": ["2darray", "dataframe", "string"]}

imblearn/under_sampling/_prototype_generation/_cluster_centroids.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
from sklearn.cluster import KMeans
1313
from sklearn.neighbors import NearestNeighbors
1414
from sklearn.utils import _safe_indexing
15+
from sklearn.utils.metaestimators import available_if
1516

1617
from ...utils import Substitution
1718
from ...utils._docstring import _random_state_docstring
1819
from ...utils._param_validation import HasMethods, StrOptions
20+
from ...utils.fixes import check_version_package
1921
from ..base import BaseUnderSampler
2022

2123
VOTING_KIND = ("auto", "hard", "soft")
@@ -201,5 +203,6 @@ def _fit_resample(self, X, y):
201203

202204
return X_resampled, np.array(y_resampled, dtype=y.dtype)
203205

206+
@available_if(check_version_package("sklearn", "<", "1.6"))
204207
def _more_tags(self):
205208
return {"sample_indices": False}

imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414
from sklearn.base import clone
1515
from sklearn.neighbors import KNeighborsClassifier
1616
from sklearn.utils import _safe_indexing, check_random_state
17+
from sklearn.utils.metaestimators import available_if
1718

1819
from ...utils import Substitution
1920
from ...utils._docstring import _n_jobs_docstring, _random_state_docstring
21+
from ...utils.fixes import check_version_package
2022
from ...utils._param_validation import HasMethods, Interval
2123
from ..base import BaseCleaningSampler
2224

@@ -259,5 +261,6 @@ def estimator_(self):
259261
)
260262
return self.estimators_[-1]
261263

264+
@available_if(check_version_package("sklearn", "<", "1.6"))
262265
def _more_tags(self):
263266
return {"sample_indices": True}

0 commit comments

Comments
 (0)