Skip to content

Commit 3a9dcbd

Browse files
committed
iter
1 parent a5cb58b commit 3a9dcbd

File tree

8 files changed

+222
-168
lines changed

8 files changed

+222
-168
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ jobs:
1818
ci-py310-min-optional-dependencies,
1919
ci-py310-min-keras,
2020
ci-py310-min-tensorflow,
21-
ci-py311-sklearn-1-3,
2221
ci-py311-sklearn-1-4,
22+
ci-py311-sklearn-1-5,
2323
ci-py311-latest-keras,
2424
ci-py311-latest-tensorflow,
2525
ci-py313-latest-dependencies,

imblearn/ensemble/_easy_ensemble.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,9 @@ def decision_function(self, X):
315315
X=X,
316316
accept_sparse=["csr", "csc"],
317317
dtype=None,
318-
ensure_all_finite=False,
318+
ensure_all_finite=(
319+
"allow_nan" if get_tags(self).input_tags.allow_nan else True
320+
),
319321
reset=False,
320322
)
321323

@@ -352,7 +354,6 @@ def _get_estimator(self):
352354
return AdaBoostClassifier()
353355
return self.estimator
354356

355-
# TODO: remove when minimum supported version of scikit-learn is 1.5
356357
def _more_tags(self):
357358
return {"allow_nan": _safe_tags(self._get_estimator(), "allow_nan")}
358359

imblearn/ensemble/tests/test_easy_ensemble.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import pytest
88
from sklearn.datasets import load_iris, make_hastie_10_2
9-
from sklearn.ensemble import AdaBoostClassifier
9+
from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier
1010
from sklearn.feature_selection import SelectKBest
1111
from sklearn.model_selection import GridSearchCV, train_test_split
1212
from sklearn.utils._testing import assert_allclose, assert_array_equal
@@ -41,8 +41,8 @@
4141
@pytest.mark.parametrize(
4242
"estimator",
4343
[
44-
AdaBoostClassifier(n_estimators=5),
45-
AdaBoostClassifier(n_estimators=10),
44+
GradientBoostingClassifier(n_estimators=5),
45+
GradientBoostingClassifier(n_estimators=10),
4646
],
4747
)
4848
def test_easy_ensemble_classifier(n_estimators, estimator):
@@ -89,10 +89,10 @@ def test_estimator():
8989
assert isinstance(ensemble.estimator_.steps[-1][1], AdaBoostClassifier)
9090

9191
ensemble = EasyEnsembleClassifier(
92-
2, AdaBoostClassifier(), n_jobs=-1, random_state=0
92+
2, GradientBoostingClassifier(), n_jobs=-1, random_state=0
9393
).fit(X_train, y_train)
9494

95-
assert isinstance(ensemble.estimator_.steps[-1][1], AdaBoostClassifier)
95+
assert isinstance(ensemble.estimator_.steps[-1][1], GradientBoostingClassifier)
9696

9797

9898
def test_bagging_with_pipeline():
@@ -104,7 +104,7 @@ def test_bagging_with_pipeline():
104104
)
105105
estimator = EasyEnsembleClassifier(
106106
n_estimators=2,
107-
estimator=make_pipeline(SelectKBest(k=1), AdaBoostClassifier()),
107+
estimator=make_pipeline(SelectKBest(k=1), GradientBoostingClassifier()),
108108
)
109109
estimator.fit(X, y).predict(X)
110110

@@ -196,7 +196,7 @@ def test_easy_ensemble_classifier_single_estimator():
196196
clf1 = EasyEnsembleClassifier(n_estimators=1, random_state=0).fit(X_train, y_train)
197197
clf2 = make_pipeline(
198198
RandomUnderSampler(random_state=0),
199-
AdaBoostClassifier(random_state=0),
199+
GradientBoostingClassifier(random_state=0),
200200
).fit(X_train, y_train)
201201

202202
assert_array_equal(clf1.predict(X_test), clf2.predict(X_test))
@@ -215,7 +215,7 @@ def test_easy_ensemble_classifier_grid_search():
215215
"estimator__n_estimators": [3, 4],
216216
}
217217
grid_search = GridSearchCV(
218-
EasyEnsembleClassifier(estimator=AdaBoostClassifier()),
218+
EasyEnsembleClassifier(estimator=GradientBoostingClassifier()),
219219
parameters,
220220
cv=5,
221221
)

imblearn/pipeline.py

Lines changed: 111 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# License: BSD
1616
import warnings
1717
from contextlib import contextmanager
18+
from copy import deepcopy
1819

1920
import sklearn
2021
from sklearn import pipeline
@@ -25,10 +26,8 @@
2526
METHODS,
2627
MetadataRouter,
2728
MethodMapping,
28-
_raise_for_params,
2929
_routing_enabled,
3030
get_routing_for_object,
31-
process_routing,
3231
)
3332
from sklearn.utils._param_validation import HasMethods
3433
from sklearn.utils.fixes import parse_version
@@ -38,9 +37,14 @@
3837
from .utils._sklearn_compat import (
3938
_fit_context,
4039
_print_elapsed_time,
40+
_raise_for_params,
41+
get_tags,
42+
process_routing,
4143
validate_params,
4244
)
4345

46+
if "fit_predict" not in METHODS:
47+
METHODS.append("fit_predict")
4448
METHODS.append("fit_resample")
4549

4650
__all__ = ["Pipeline", "make_pipeline"]
@@ -245,6 +249,12 @@ class Pipeline(pipeline.Pipeline):
245249
"verbose": ["boolean"],
246250
}
247251

252+
def __init__(self, steps, *, transform_input=None, memory=None, verbose=False):
253+
self.steps = steps
254+
self.transform_input = transform_input
255+
self.memory = memory
256+
self.verbose = verbose
257+
248258
# BaseEstimator interface
249259

250260
def _validate_steps(self):
@@ -1162,35 +1172,29 @@ def get_metadata_routing(self):
11621172
# fit, fit_predict, and fit_transform call fit_transform if it
11631173
# exists, or else fit and transform
11641174
if hasattr(trans, "fit_transform"):
1165-
(
1166-
method_mapping.add(caller="fit", callee="fit_transform")
1167-
.add(caller="fit_transform", callee="fit_transform")
1168-
.add(caller="fit_predict", callee="fit_transform")
1169-
.add(caller="fit_resample", callee="fit_transform")
1170-
)
1175+
method_mapping.add(caller="fit", callee="fit_transform")
1176+
method_mapping.add(caller="fit_transform", callee="fit_transform")
1177+
method_mapping.add(caller="fit_predict", callee="fit_transform")
1178+
method_mapping.add(caller="fit_resample", callee="fit_transform")
11711179
else:
1172-
(
1173-
method_mapping.add(caller="fit", callee="fit")
1174-
.add(caller="fit", callee="transform")
1175-
.add(caller="fit_transform", callee="fit")
1176-
.add(caller="fit_transform", callee="transform")
1177-
.add(caller="fit_predict", callee="fit")
1178-
.add(caller="fit_predict", callee="transform")
1179-
.add(caller="fit_resample", callee="fit")
1180-
.add(caller="fit_resample", callee="transform")
1181-
)
1182-
1183-
(
1184-
method_mapping.add(caller="predict", callee="transform")
1185-
.add(caller="predict", callee="transform")
1186-
.add(caller="predict_proba", callee="transform")
1187-
.add(caller="decision_function", callee="transform")
1188-
.add(caller="predict_log_proba", callee="transform")
1189-
.add(caller="transform", callee="transform")
1190-
.add(caller="inverse_transform", callee="inverse_transform")
1191-
.add(caller="score", callee="transform")
1192-
.add(caller="fit_resample", callee="transform")
1193-
)
1180+
method_mapping.add(caller="fit", callee="fit")
1181+
method_mapping.add(caller="fit", callee="transform")
1182+
method_mapping.add(caller="fit_transform", callee="fit")
1183+
method_mapping.add(caller="fit_transform", callee="transform")
1184+
method_mapping.add(caller="fit_predict", callee="fit")
1185+
method_mapping.add(caller="fit_predict", callee="transform")
1186+
method_mapping.add(caller="fit_resample", callee="fit")
1187+
method_mapping.add(caller="fit_resample", callee="transform")
1188+
1189+
method_mapping.add(caller="predict", callee="transform")
1190+
method_mapping.add(caller="predict", callee="transform")
1191+
method_mapping.add(caller="predict_proba", callee="transform")
1192+
method_mapping.add(caller="decision_function", callee="transform")
1193+
method_mapping.add(caller="predict_log_proba", callee="transform")
1194+
method_mapping.add(caller="transform", callee="transform")
1195+
method_mapping.add(caller="inverse_transform", callee="inverse_transform")
1196+
method_mapping.add(caller="score", callee="transform")
1197+
method_mapping.add(caller="fit_resample", callee="transform")
11941198

11951199
router.add(method_mapping=method_mapping, **{name: trans})
11961200

@@ -1201,30 +1205,24 @@ def get_metadata_routing(self):
12011205
# then we add the last step
12021206
method_mapping = MethodMapping()
12031207
if hasattr(final_est, "fit_transform"):
1204-
(
1205-
method_mapping.add(caller="fit_transform", callee="fit_transform").add(
1206-
caller="fit_resample", callee="fit_transform"
1207-
)
1208-
)
1208+
method_mapping.add(caller="fit_transform", callee="fit_transform")
1209+
method_mapping.add(caller="fit_resample", callee="fit_transform")
12091210
else:
1210-
(
1211-
method_mapping.add(caller="fit", callee="fit")
1212-
.add(caller="fit", callee="transform")
1213-
.add(caller="fit_resample", callee="fit")
1214-
.add(caller="fit_resample", callee="transform")
1215-
)
1216-
(
12171211
method_mapping.add(caller="fit", callee="fit")
1218-
.add(caller="predict", callee="predict")
1219-
.add(caller="fit_predict", callee="fit_predict")
1220-
.add(caller="predict_proba", callee="predict_proba")
1221-
.add(caller="decision_function", callee="decision_function")
1222-
.add(caller="predict_log_proba", callee="predict_log_proba")
1223-
.add(caller="transform", callee="transform")
1224-
.add(caller="inverse_transform", callee="inverse_transform")
1225-
.add(caller="score", callee="score")
1226-
.add(caller="fit_resample", callee="fit_resample")
1227-
)
1212+
method_mapping.add(caller="fit", callee="transform")
1213+
method_mapping.add(caller="fit_resample", callee="fit")
1214+
method_mapping.add(caller="fit_resample", callee="transform")
1215+
1216+
method_mapping.add(caller="fit", callee="fit")
1217+
method_mapping.add(caller="predict", callee="predict")
1218+
method_mapping.add(caller="fit_predict", callee="fit_predict")
1219+
method_mapping.add(caller="predict_proba", callee="predict_proba")
1220+
method_mapping.add(caller="decision_function", callee="decision_function")
1221+
method_mapping.add(caller="predict_log_proba", callee="predict_log_proba")
1222+
method_mapping.add(caller="transform", callee="transform")
1223+
method_mapping.add(caller="inverse_transform", callee="inverse_transform")
1224+
method_mapping.add(caller="score", callee="score")
1225+
method_mapping.add(caller="fit_resample", callee="fit_resample")
12281226

12291227
router.add(method_mapping=method_mapping, **{final_name: final_est})
12301228
return router
@@ -1258,6 +1256,67 @@ def _check_method_params(self, method, props, **kwargs):
12581256
fit_params_steps[step]["fit_predict"][param] = pval
12591257
return fit_params_steps
12601258

1259+
def __sklearn_is_fitted__(self):
1260+
"""Indicate whether pipeline has been fit.
1261+
1262+
This is done by checking whether the last non-`passthrough` step of the
1263+
pipeline is fitted.
1264+
1265+
An empty pipeline is considered fitted.
1266+
"""
1267+
1268+
# First find the last step that is not 'passthrough'
1269+
last_step = None
1270+
for _, estimator in reversed(self.steps):
1271+
if estimator != "passthrough":
1272+
last_step = estimator
1273+
break
1274+
1275+
if last_step is None:
1276+
# All steps are 'passthrough', so the pipeline is considered fitted
1277+
return True
1278+
1279+
try:
1280+
# check if the last step of the pipeline is fitted
1281+
# we only check the last step since if the last step is fit, it
1282+
# means the previous steps should also be fit. This is faster than
1283+
# checking if every step of the pipeline is fit.
1284+
check_is_fitted(last_step)
1285+
return True
1286+
except NotFittedError:
1287+
return False
1288+
1289+
def __sklearn_tags__(self):
1290+
tags = super().__sklearn_tags__()
1291+
1292+
if not self.steps:
1293+
return tags
1294+
1295+
try:
1296+
if self.steps[0][1] is not None and self.steps[0][1] != "passthrough":
1297+
tags.input_tags.pairwise = get_tags(
1298+
self.steps[0][1]
1299+
).input_tags.pairwise
1300+
except (ValueError, AttributeError, TypeError):
1301+
# This happens when `steps` is not a list of (name, estimator)
1302+
# tuples and `fit` has not been called yet to validate the steps.
1303+
pass
1304+
1305+
try:
1306+
if self.steps[-1][1] is not None and self.steps[-1][1] != "passthrough":
1307+
last_step_tags = get_tags(self.steps[-1][1])
1308+
tags.estimator_type = last_step_tags.estimator_type
1309+
tags.target_tags.multi_output = last_step_tags.target_tags.multi_output
1310+
tags.classifier_tags = deepcopy(last_step_tags.classifier_tags)
1311+
tags.regressor_tags = deepcopy(last_step_tags.regressor_tags)
1312+
tags.transformer_tags = deepcopy(last_step_tags.transformer_tags)
1313+
except (ValueError, AttributeError, TypeError):
1314+
# This happens when `steps` is not a list of (name, estimator)
1315+
# tuples and `fit` has not been called yet to validate the steps.
1316+
pass
1317+
1318+
return tags
1319+
12611320

12621321
def _fit_resample_one(sampler, X, y, message_clsname="", message=None, params=None):
12631322
with _print_elapsed_time(message_clsname, message):

imblearn/tests/test_common.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@
99

1010
import numpy as np
1111
import pytest
12-
import sklearn
1312
from sklearn.exceptions import ConvergenceWarning
1413
from sklearn.utils._testing import ignore_warnings
15-
from sklearn.utils.fixes import parse_version
1614

1715
from imblearn.over_sampling import RandomOverSampler
1816
from imblearn.under_sampling import RandomUnderSampler
@@ -32,14 +30,6 @@
3230
)
3331
from imblearn.utils.testing import all_estimators
3432

35-
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
36-
if sklearn_version >= parse_version("1.6"):
37-
kwargs_parametrize_with_checks = {
38-
"expected_failed_checks": _get_expected_failed_checks
39-
}
40-
else:
41-
kwargs_parametrize_with_checks = {}
42-
4333

4434
@pytest.mark.parametrize("name, Estimator", all_estimators())
4535
def test_all_estimator_no_base_class(name, Estimator):
@@ -49,7 +39,7 @@ def test_all_estimator_no_base_class(name, Estimator):
4939

5040

5141
@parametrize_with_checks_sklearn(
52-
list(_tested_estimators()), **kwargs_parametrize_with_checks
42+
list(_tested_estimators()), expected_failed_checks=_get_expected_failed_checks
5343
)
5444
def test_estimators_compatibility_sklearn(estimator, check, request):
5545
_set_checking_parameters(estimator)

0 commit comments

Comments
 (0)