
Commit 4aeb927

more clean up and coverage
1 parent 42fa8e5 commit 4aeb927

10 files changed: +164 -40 lines changed


imblearn/ensemble/_bagging.py

Lines changed: 1 addition & 4 deletions

@@ -8,7 +8,6 @@
 import numbers
 
 import numpy as np
-import sklearn
 from sklearn.base import clone
 from sklearn.ensemble import BaggingClassifier
 from sklearn.ensemble._bagging import _parallel_decision_function
@@ -25,11 +24,9 @@
 from ..under_sampling.base import BaseUnderSampler
 from ..utils import Substitution, check_sampling_strategy, check_target_type
 from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
-from ..utils._sklearn_compat import _fit_context, validate_data
+from ..utils._sklearn_compat import _fit_context, sklearn_version, validate_data
 from ._common import _bagging_parameter_constraints, _estimator_has
 
-sklearn_version = parse_version(sklearn.__version__)
-
 
 @Substitution(
     sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,
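
Note: this hunk, and the matching hunks in the ensemble modules below, drop the per-module `sklearn_version = parse_version(sklearn.__version__)` constants in favour of one shared constant imported from `imblearn.utils._sklearn_compat`. That compat module is not part of this diff; the following is only a sketch of the assumed shape of the shared constant, not its actual contents.

# Sketch (assumption): a centralised version constant in imblearn/utils/_sklearn_compat.py
import sklearn
from sklearn.utils.fixes import parse_version

# Parsed once here and imported everywhere a version gate is needed, instead of
# being re-derived in each module. Whether the real module strips pre-release
# suffixes via `.base_version` first is not visible in this diff.
sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)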

imblearn/ensemble/_easy_ensemble.py

Lines changed: 6 additions & 3 deletions

@@ -9,7 +9,6 @@
 import warnings
 
 import numpy as np
-import sklearn
 from sklearn.base import clone
 from sklearn.ensemble import AdaBoostClassifier, BaggingClassifier
 from sklearn.ensemble._bagging import _parallel_decision_function
@@ -26,11 +25,15 @@
 from ..under_sampling.base import BaseUnderSampler
 from ..utils import Substitution, check_sampling_strategy, check_target_type
 from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
-from ..utils._sklearn_compat import _fit_context, get_tags, validate_data
+from ..utils._sklearn_compat import (
+    _fit_context,
+    get_tags,
+    sklearn_version,
+    validate_data,
+)
 from ._common import _bagging_parameter_constraints, _estimator_has
 
 MAX_INT = np.iinfo(np.int32).max
-sklearn_version = parse_version(sklearn.__version__)
 
 
 @Substitution(

imblearn/ensemble/_forest.py

Lines changed: 1 addition & 3 deletions

@@ -8,7 +8,6 @@
 from warnings import warn
 
 import numpy as np
-import sklearn
 from numpy import float32 as DTYPE
 from numpy import float64 as DOUBLE
 from scipy.sparse import issparse
@@ -33,12 +32,11 @@
 from ..under_sampling import RandomUnderSampler
 from ..utils import Substitution
 from ..utils._docstring import _n_jobs_docstring, _random_state_docstring
-from ..utils._sklearn_compat import _fit_context, validate_data
+from ..utils._sklearn_compat import _fit_context, sklearn_version, validate_data
 from ..utils._validation import check_sampling_strategy
 from ._common import _random_forest_classifier_parameter_constraints
 
 MAX_INT = np.iinfo(np.int32).max
-sklearn_version = parse_version(parse_version(sklearn.__version__).base_version)
 
 
 def _local_parallel_build_trees(
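
One detail worth noting: the constant removed here parsed `base_version`, while the ones removed in `_bagging.py` above and `_weight_boosting.py` below parsed the full version string. The two disagree on pre-releases, which is exactly the kind of drift a single shared constant removes. A small illustration (hypothetical version strings):

from sklearn.utils.fixes import parse_version

dev = parse_version("1.4.dev0")
# The full version sorts *before* the final release...
print(dev < parse_version("1.4"))                               # True
# ...but its base_version ("1.4") does not.
print(parse_version(dev.base_version) < parse_version("1.4"))   # False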

imblearn/ensemble/_weight_boosting.py

Lines changed: 1 addition & 4 deletions

@@ -4,7 +4,6 @@
 from copy import deepcopy
 
 import numpy as np
-import sklearn
 from sklearn.base import clone
 from sklearn.ensemble import AdaBoostClassifier
 from sklearn.ensemble._base import _set_random_states
@@ -19,11 +18,9 @@
 from ..under_sampling.base import BaseUnderSampler
 from ..utils import Substitution, check_target_type
 from ..utils._docstring import _random_state_docstring
-from ..utils._sklearn_compat import _fit_context
+from ..utils._sklearn_compat import _fit_context, sklearn_version
 from ._common import _adaboost_classifier_parameter_constraints
 
-sklearn_version = parse_version(sklearn.__version__)
-
 
 @Substitution(
     sampling_strategy=BaseUnderSampler._sampling_strategy_docstring,

imblearn/ensemble/tests/test_forest.py

Lines changed: 1 addition & 3 deletions

@@ -1,14 +1,12 @@
 import numpy as np
 import pytest
-import sklearn
 from sklearn.datasets import make_classification
 from sklearn.model_selection import GridSearchCV, train_test_split
 from sklearn.utils._testing import assert_allclose, assert_array_equal
 from sklearn.utils.fixes import parse_version
 
 from imblearn.ensemble import BalancedRandomForestClassifier
-
-sklearn_version = parse_version(sklearn.__version__)
+from imblearn.utils._sklearn_compat import sklearn_version
 
 
 @pytest.fixture

imblearn/ensemble/tests/test_weight_boosting.py

Lines changed: 10 additions & 0 deletions

@@ -82,3 +82,13 @@ def test_rusboost_sample_weight(imbalanced_dataset):
 
     with pytest.raises(AssertionError):
         assert_array_equal(y_pred_no_sample_weight, y_pred_sample_weight)
+
+
+@pytest.mark.parametrize("algorithm", ["SAMME", "SAMME.R"])
+def test_rusboost_algorithm(imbalanced_dataset, algorithm):
+    X, y = imbalanced_dataset
+
+    rusboost = RUSBoostClassifier(algorithm=algorithm)
+    warn_msg = "`algorithm` parameter is deprecated in 0.12 and will be removed"
+    with pytest.warns(FutureWarning, match=warn_msg):
+        rusboost.fit(X, y)
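
The new test covers the deprecation path of the `algorithm` parameter. A minimal usage sketch of what that looks like from the user side (the dataset is illustrative; the warning text is taken from the test above and may change in later releases):

import warnings

from sklearn.datasets import make_classification

from imblearn.ensemble import RUSBoostClassifier

X, y = make_classification(n_samples=100, weights=[0.9, 0.1], random_state=0)

with warnings.catch_warnings(record=True) as records:
    warnings.simplefilter("always")
    RUSBoostClassifier(algorithm="SAMME", random_state=0).fit(X, y)

# Expect a FutureWarning announcing that `algorithm` is deprecated.
print([str(r.message) for r in records if r.category is FutureWarning])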

imblearn/over_sampling/_smote/tests/test_smote_nc.py

Lines changed: 0 additions & 16 deletions

@@ -15,11 +15,6 @@
 
 from imblearn.over_sampling import SMOTENC
 
-# from imblearn.utils.estimator_checks import (
-#     _set_checking_parameters,
-#     check_param_validation,
-# )
-
 
 def data_heterogneous_ordered():
     rng = np.random.RandomState(42)
@@ -293,17 +288,6 @@ def test_smotenc_deprecation_ohe_():
         smote.ohe_
 
 
-# """ def test_smotenc_param_validation():
-#     """Check that we validate the parameters correctly since this estimator requires
-#     a specific parameter.
-#     """
-#     categorical_features = [0]
-#     smote = SMOTENC(categorical_features=categorical_features, random_state=0)
-#     name = smote.__class__.__name__
-#     _set_checking_parameters(smote)
-#     check_param_validation(name, smote) """
-
-
 def test_smotenc_bool_categorical():
     """Check that we don't try to early convert the full input data to numeric when
     handling a pandas dataframe.

imblearn/pipeline.py

Lines changed: 11 additions & 2 deletions

@@ -29,6 +29,7 @@
     get_routing_for_object,
 )
 from sklearn.utils._param_validation import HasMethods
+from sklearn.utils.fixes import parse_version
 from sklearn.utils.metaestimators import available_if
 from sklearn.utils.validation import check_is_fitted, check_memory
 
@@ -38,6 +39,7 @@
     _raise_for_params,
     get_tags,
     process_routing,
+    sklearn_version,
     validate_params,
 )
 
@@ -55,7 +57,7 @@ def _raise_or_warn_if_not_fitted(estimator):
     """A context manager to make sure a NotFittedError is raised, if a sub-estimator
     raises the error.
     Otherwise, we raise a warning if the pipeline is not fitted, with the deprecation.
-    TODO(1.8): remove this context manager and replace with check_is_fitted.
+    TODO(0.15): remove this context manager and replace with check_is_fitted.
     """
     try:
         yield
@@ -70,7 +72,7 @@ def _raise_or_warn_if_not_fitted(estimator):
             (
                 "This Pipeline instance is not fitted yet. Call 'fit' with "
                 "appropriate arguments before using other methods such as transform, "
-                "predict, etc. This will raise an error in 1.8 instead of the current "
+                "predict, etc. This will raise an error in 0.15 instead of the current "
                 "warning."
             ),
             FutureWarning,
@@ -511,6 +513,13 @@ def fit(self, X, y=None, **params):
                 "`sklearn.set_config(enable_metadata_routing=True)`."
             )
 
+        if sklearn_version < parse_version("1.4") and self.transform_input is not None:
+            raise ValueError(
+                "The `transform_input` parameter is not supported in scikit-learn "
+                "versions prior to 1.4. Please upgrade to scikit-learn 1.4 or "
+                "later."
+            )
+
         routed_params = self._check_method_params(method="fit", props=params)
         Xt, yt = self._fit(X, y, routed_params, raw_params=params)
         with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
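
The only behavioural change in `Pipeline.fit` is the new guard: when `transform_input` is set and the installed scikit-learn predates 1.4, fitting now fails fast with an explicit error instead of breaking later. A hedged usage sketch of that error path (the estimator choice is arbitrary; it assumes a scikit-learn version that already exposes the `enable_metadata_routing` config flag, and on scikit-learn >= 1.4 the branch below is simply skipped):

from sklearn import config_context
from sklearn.preprocessing import StandardScaler
from sklearn.utils.fixes import parse_version

from imblearn.pipeline import make_pipeline
from imblearn.utils._sklearn_compat import sklearn_version

X, y = [[0.0], [1.0], [2.0], [3.0]], [0, 0, 1, 1]

if sklearn_version < parse_version("1.4"):
    # Routing must be enabled, otherwise the pre-existing metadata check fires first.
    with config_context(enable_metadata_routing=True):
        try:
            make_pipeline(StandardScaler(), transform_input=["X_val"]).fit(X, y)
        except ValueError as exc:
            print(exc)  # "The `transform_input` parameter is not supported in scikit-learn ..."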

imblearn/tests/test_pipeline.py

Lines changed: 131 additions & 1 deletion

@@ -1,6 +1,7 @@
 """
 Test the pipeline module.
 """
+
 # Authors: Guillaume Lemaitre <[email protected]>
 #          Christos Aridas
 # License: MIT
@@ -15,7 +16,8 @@
 import pytest
 from joblib import Memory
 from pytest import raises
-from sklearn.base import BaseEstimator, clone
+from sklearn import config_context
+from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, clone
 from sklearn.cluster import KMeans
 from sklearn.datasets import load_iris, make_classification
 from sklearn.decomposition import PCA
@@ -30,11 +32,13 @@
     assert_array_almost_equal,
     assert_array_equal,
 )
+from sklearn.utils.fixes import parse_version
 
 from imblearn.datasets import make_imbalance
 from imblearn.pipeline import Pipeline, make_pipeline
 from imblearn.under_sampling import EditedNearestNeighbours as ENN
 from imblearn.under_sampling import RandomUnderSampler
+from imblearn.utils._sklearn_compat import sklearn_version
 from imblearn.utils.estimator_checks import check_param_validation
 
 JUNK_FOOD_DOCS = (
@@ -1365,3 +1369,129 @@ def test_pipeline_with_set_output():
     assert isinstance(X_res, pd.DataFrame)
     # transformer will not change `y` and sampler will always preserve the type of `y`
     assert isinstance(y_res, type(y))
+
+
+# TODO(0.15): change warning to checking for NotFittedError
+@pytest.mark.parametrize(
+    "method",
+    [
+        "predict",
+        "predict_proba",
+        "predict_log_proba",
+        "decision_function",
+        "score",
+        "score_samples",
+        "transform",
+        "inverse_transform",
+    ],
+)
+def test_pipeline_warns_not_fitted(method):
+    class StatelessEstimator(BaseEstimator):
+        """Stateless estimator that doesn't check if it's fitted.
+        Stateless estimators that don't require fit, should properly set the
+        `requires_fit` flag and implement a `__sklearn_check_is_fitted__` returning
+        `True`.
+        """
+
+        def fit(self, X, y):
+            return self  # pragma: no cover
+
+        def transform(self, X):
+            return X
+
+        def predict(self, X):
+            return np.ones(len(X))
+
+        def predict_proba(self, X):
+            return np.ones(len(X))
+
+        def predict_log_proba(self, X):
+            return np.zeros(len(X))
+
+        def decision_function(self, X):
+            return np.ones(len(X))
+
+        def score(self, X, y):
+            return 1
+
+        def score_samples(self, X):
+            return np.ones(len(X))
+
+        def inverse_transform(self, X):
+            return X
+
+    pipe = Pipeline([("estimator", StatelessEstimator())])
+    with pytest.warns(FutureWarning, match="This Pipeline instance is not fitted yet."):
+        getattr(pipe, method)([[1]])
+
+
+# transform_input tests
+# =====================
+
+
+@pytest.mark.skipif(
+    sklearn_version < parse_version("1.4"),
+    reason="scikit-learn < 1.4 does not support transform_input",
+)
+@config_context(enable_metadata_routing=True)
+def test_transform_input_explicit_value_check():
+    """Test that the right transformed values are passed to `fit`."""
+
+    class Transformer(TransformerMixin, BaseEstimator):
+        def fit(self, X, y):
+            self.fitted_ = True
+            return self
+
+        def transform(self, X):
+            return X + 1
+
+    class Estimator(ClassifierMixin, BaseEstimator):
+        def fit(self, X, y, X_val=None, y_val=None):
+            assert_array_equal(X, np.array([[1, 2]]))
+            assert_array_equal(y, np.array([0, 1]))
+            assert_array_equal(X_val, np.array([[2, 3]]))
+            assert_array_equal(y_val, np.array([0, 1]))
+            return self
+
+    X = np.array([[0, 1]])
+    y = np.array([0, 1])
+    X_val = np.array([[1, 2]])
+    y_val = np.array([0, 1])
+    pipe = Pipeline(
+        [
+            ("transformer", Transformer()),
+            ("estimator", Estimator().set_fit_request(X_val=True, y_val=True)),
+        ],
+        transform_input=["X_val"],
+    )
+    pipe.fit(X, y, X_val=X_val, y_val=y_val)
+
+
+def test_transform_input_no_slep6():
+    """Make sure the right error is raised if slep6 is not enabled."""
+    X = np.array([[1, 2], [3, 4]])
+    y = np.array([0, 1])
+    msg = "The `transform_input` parameter can only be set if metadata"
+    with pytest.raises(ValueError, match=msg):
+        make_pipeline(DummyTransf(), transform_input=["blah"]).fit(X, y)
+
+
+@pytest.mark.skipif(
+    sklearn_version >= parse_version("1.4"),
+    reason="scikit-learn >= 1.4 supports transform_input",
+)
+@config_context(enable_metadata_routing=True)
+def test_transform_input_sklearn_version():
+    """Test that transform_input raises error with sklearn < 1.4."""
+    X = np.array([[1, 2], [3, 4]])
+    y = np.array([0, 1])
+    msg = (
+        "The `transform_input` parameter is not supported in scikit-learn versions "
+        "prior to 1.4"
+    )
+    with pytest.raises(ValueError, match=msg):
+        make_pipeline(DummyTransf(), transform_input=["blah"]).fit(X, y)
+
+
+# end of transform_input tests
+# =============================
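
The docstring of `StatelessEstimator` above mentions the protocol that genuinely stateless estimators should follow. A small illustrative sketch of that protocol, separate from the tests (the class name is made up; `check_is_fitted` defers to `__sklearn_check_is_fitted__` when an estimator defines it):

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted


class IdentityTransformer(TransformerMixin, BaseEstimator):
    """Stateless transformer: `fit` learns nothing."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return np.asarray(X)

    def __sklearn_check_is_fitted__(self):
        # There is no fitted state, so always report "fitted".
        return True


check_is_fitted(IdentityTransformer())  # does not raise
print(IdentityTransformer().transform([[1.0, 2.0]]))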

imblearn/utils/_test_common/instance_generator.py

Lines changed: 2 additions & 4 deletions

@@ -8,7 +8,6 @@
 from functools import partial
 from inspect import isfunction
 
-import sklearn
 from sklearn import clone, config_context
 from sklearn.exceptions import SkipTestWarning
 from sklearn.linear_model import LogisticRegression
@@ -42,10 +41,9 @@
     OneSidedSelection,
     RandomUnderSampler,
 )
+from imblearn.utils._sklearn_compat import sklearn_version
 from imblearn.utils.testing import all_estimators
 
-sklearn_version = parse_version(sklearn.__version__).base_version
-
 # The following dictionary is to indicate constructor arguments suitable for the test
 # suite, which uses very small datasets, and is intended to run rather quickly.
 INIT_PARAMS = {
@@ -232,7 +230,7 @@ def _yield_instances_for_check(check, estimator_orig):
     },
 }
 
-if sklearn_version < "1.4":
+if sklearn_version < parse_version("1.4"):
     for _, Estimator in all_estimators():
         if Estimator in PER_ESTIMATOR_XFAIL_CHECKS:
             PER_ESTIMATOR_XFAIL_CHECKS[Estimator]["check_estimators_pickle"] = "FIXME"
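
The last hunk fixes a subtle bug rather than mere cleanup: the removed constant was a *string* (`.base_version`), so `sklearn_version < "1.4"` compared lexicographically, which misorders versions once a component reaches two digits. An illustration with hypothetical version numbers:

from sklearn.utils.fixes import parse_version

# Lexicographic string comparison: "1.10..." sorts before "1.4" because "1" < "4".
print("1.10.0" < "1.4")                                   # True (wrong ordering)

# Parsed version comparison handles multi-digit components correctly.
print(parse_version("1.10.0") < parse_version("1.4"))     # False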
