Skip to content

Commit acb8234

Browse files
committed
iter
1 parent c1514dc commit acb8234

21 files changed

+257
-42
lines changed

imblearn/base.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class attribute, which is a dictionary `param_name: list of constraints`. See
3737
)
3838

3939

40-
class SamplerMixin(_ParamsValidationMixin, BaseEstimator, metaclass=ABCMeta):
40+
class SamplerMixin(_ParamsValidationMixin, metaclass=ABCMeta):
4141
"""Mixin class for samplers with abstract method.
4242
4343
Warning: This class should not be used directly. Use the derive classes
@@ -135,7 +135,7 @@ def _fit_resample(self, X, y):
135135
pass
136136

137137

138-
class BaseSampler(SamplerMixin, OneToOneFeatureMixin):
138+
class BaseSampler(SamplerMixin, OneToOneFeatureMixin, BaseEstimator):
139139
"""Base class for sampling algorithms.
140140
141141
Warning: This class should not be used directly. Use the derive classes
@@ -204,9 +204,15 @@ def _more_tags(self):
204204

205205
@available_if(check_version_package("sklearn", ">=", "1.6"))
206206
def __sklearn_tags__(self):
207-
tags = super().__sklearn_tags__()
208-
209-
from .utils._tags import InputTags
207+
from .utils._tags import Tags, SamplerTags, TargetTags, InputTags
208+
tags = Tags(
209+
estimator_type="sampler",
210+
target_tags=TargetTags(required=True),
211+
transformer_tags=None,
212+
regressor_tags=None,
213+
classifier_tags=None,
214+
sampler_tags=SamplerTags(),
215+
)
210216
tags.input_tags = InputTags()
211217
tags.input_tags.two_d_array = True
212218
tags.input_tags.sparse = True

imblearn/ensemble/_bagging.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -382,17 +382,13 @@ def decision_function(self, X):
382382
check_is_fitted(self)
383383

384384
# Check data
385-
if sklearn_version < parse_version("1.6"):
386-
kwargs = {"force_all_finite": False}
387-
else:
388-
kwargs = {"ensure_all_finite": False}
389385
X = validate_data(
390386
self,
391387
X=X,
392388
accept_sparse=["csr", "csc"],
393389
dtype=None,
394390
reset=False,
395-
**kwargs
391+
ensure_all_finite=False,
396392
)
397393

398394
# Parallel loop

imblearn/ensemble/_easy_ensemble.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,17 +310,13 @@ def decision_function(self, X):
310310
check_is_fitted(self)
311311

312312
# Check data
313-
if sklearn_version < parse_version("1.6"):
314-
kwargs = {"force_all_finite": False}
315-
else:
316-
kwargs = {"ensure_all_finite": False}
317313
X = validate_data(
318314
self,
319315
X=X,
320316
accept_sparse=["csr", "csc"],
321317
dtype=None,
322318
reset=False,
323-
**kwargs,
319+
ensure_all_finite=False,
324320
)
325321

326322
# Parallel loop

imblearn/ensemble/_forest.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -598,12 +598,9 @@ def fit(self, X, y, sample_weight=None):
598598
# TODO: remove when the minimum supported version of scipy will be 1.4
599599
# Support for missing values
600600
if parse_version(sklearn_version.base_version) >= parse_version("1.4"):
601-
if sklearn_version >= parse_version("1.6"):
602-
kwargs = {"ensure_all_finite": False}
603-
else:
604-
kwargs = {"force_all_finite": False}
601+
ensure_all_finite = False
605602
else:
606-
kwargs = {"force_all_finite": False}
603+
ensure_all_finite = False
607604

608605
X, y = validate_data(
609606
self,
@@ -612,7 +609,7 @@ def fit(self, X, y, sample_weight=None):
612609
multi_output=True,
613610
accept_sparse="csc",
614611
dtype=DTYPE,
615-
**kwargs,
612+
ensure_all_finite=ensure_all_finite,
616613
)
617614

618615
# TODO: remove when the minimum supported version of scikit-learn will be 1.4

imblearn/metrics/tests/test_classification.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# Christos Aridas
55
# License: MIT
66

7+
import warnings
78
from functools import partial
89

910
import numpy as np
@@ -23,7 +24,6 @@
2324
from sklearn.utils._testing import (
2425
assert_allclose,
2526
assert_array_equal,
26-
assert_no_warnings,
2727
)
2828
from sklearn.utils.validation import check_random_state
2929

@@ -105,11 +105,13 @@ def test_sensitivity_specificity_score_binary():
105105
# binary class case the score is the value of the measure for the positive
106106
# class (e.g. label == 1). This is deprecated for average != 'binary'.
107107
for kwargs in ({}, {"average": "binary"}):
108-
sen = assert_no_warnings(sensitivity_score, y_true, y_pred, **kwargs)
109-
assert sen == pytest.approx(0.68, rel=R_TOL)
108+
with warnings.catch_warnings():
109+
warnings.simplefilter("error")
110+
sen = sensitivity_score(y_true, y_pred, **kwargs)
111+
assert sen == pytest.approx(0.68, rel=R_TOL)
110112

111-
spe = assert_no_warnings(specificity_score, y_true, y_pred, **kwargs)
112-
assert spe == pytest.approx(0.88, rel=R_TOL)
113+
spe = specificity_score(y_true, y_pred, **kwargs)
114+
assert spe == pytest.approx(0.88, rel=R_TOL)
113115

114116

115117
@pytest.mark.filterwarnings("ignore:Specificity is ill-defined")

imblearn/over_sampling/_random_over_sampler.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,3 +261,11 @@ def _more_tags(self):
261261
"check_complex_data": "Robust to this type of data.",
262262
},
263263
}
264+
265+
@available_if(check_version_package("sklearn", ">=", "1.6"))
266+
def __sklearn_tags__(self):
267+
tags = super().__sklearn_tags__()
268+
tags.input_tags.allow_nan = True
269+
tags.input_tags.string = True
270+
tags.sampler_tags.sample_indices = True
271+
return tags

imblearn/over_sampling/_smote/base.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,17 @@ def ohe_(self):
829829
)
830830
return self.categorical_encoder_
831831

832+
@available_if(check_version_package("sklearn", "<", "1.6"))
833+
def _more_tags(self):
834+
return {"X_types": ["2darray", "dataframe", "string"]}
835+
836+
@available_if(check_version_package("sklearn", ">=", "1.6"))
837+
def __sklearn_tags__(self):
838+
tags = super().__sklearn_tags__()
839+
tags.input_tags.sparse = False
840+
tags.input_tags.string = True
841+
return tags
842+
832843

833844
@Substitution(
834845
sampling_strategy=BaseOverSampler._sampling_strategy_docstring,
@@ -1073,3 +1084,10 @@ def _fit_resample(self, X, y):
10731084
@available_if(check_version_package("sklearn", "<", "1.6"))
10741085
def _more_tags(self):
10751086
return {"X_types": ["2darray", "dataframe", "string"]}
1087+
1088+
@available_if(check_version_package("sklearn", ">=", "1.6"))
1089+
def __sklearn_tags__(self):
1090+
tags = super().__sklearn_tags__()
1091+
tags.input_tags.sparse = False
1092+
tags.input_tags.string = True
1093+
return tags

imblearn/under_sampling/_prototype_selection/_condensed_nearest_neighbour.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,9 @@ def estimator_(self):
264264
@available_if(check_version_package("sklearn", "<", "1.6"))
265265
def _more_tags(self):
266266
return {"sample_indices": True}
267+
268+
@available_if(check_version_package("sklearn", ">=", "1.6"))
269+
def __sklearn_tags__(self):
270+
tags = super().__sklearn_tags__()
271+
tags.sampler_tags.sample_indices = True
272+
return tags

imblearn/under_sampling/_prototype_selection/_edited_nearest_neighbours.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,12 @@ def _fit_resample(self, X, y):
194194
def _more_tags(self):
195195
return {"sample_indices": True}
196196

197+
@available_if(check_version_package("sklearn", ">=", "1.6"))
198+
def __sklearn_tags__(self):
199+
tags = super().__sklearn_tags__()
200+
tags.sampler_tags.sample_indices = True
201+
return tags
202+
197203

198204
@Substitution(
199205
sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
@@ -416,6 +422,12 @@ def _fit_resample(self, X, y):
416422
def _more_tags(self):
417423
return {"sample_indices": True}
418424

425+
@available_if(check_version_package("sklearn", ">=", "1.6"))
426+
def __sklearn_tags__(self):
427+
tags = super().__sklearn_tags__()
428+
tags.sampler_tags.sample_indices = True
429+
return tags
430+
419431

420432
@Substitution(
421433
sampling_strategy=BaseCleaningSampler._sampling_strategy_docstring,
@@ -625,3 +637,9 @@ def _fit_resample(self, X, y):
625637
@available_if(check_version_package("sklearn", "<", "1.6"))
626638
def _more_tags(self):
627639
return {"sample_indices": True}
640+
641+
@available_if(check_version_package("sklearn", ">=", "1.6"))
642+
def __sklearn_tags__(self):
643+
tags = super().__sklearn_tags__()
644+
tags.sampler_tags.sample_indices = True
645+
return tags

imblearn/under_sampling/_prototype_selection/_instance_hardness_threshold.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,9 @@ def _fit_resample(self, X, y):
205205
@available_if(check_version_package("sklearn", "<", "1.6"))
206206
def _more_tags(self):
207207
return {"sample_indices": True}
208+
209+
@available_if(check_version_package("sklearn", ">=", "1.6"))
210+
def __sklearn_tags__(self):
211+
tags = super().__sklearn_tags__()
212+
tags.sampler_tags.sample_indices = True
213+
return tags

0 commit comments

Comments
 (0)