|
1 | 1 | """ |
2 | 2 | Test the pipeline module. |
3 | 3 | """ |
| 4 | + |
4 | 5 | # Authors: Guillaume Lemaitre <[email protected]> |
5 | 6 | # Christos Aridas |
6 | 7 | # License: MIT |
|
15 | 16 | import pytest |
16 | 17 | from joblib import Memory |
17 | 18 | from pytest import raises |
18 | | -from sklearn.base import BaseEstimator, clone |
| 19 | +from sklearn import config_context |
| 20 | +from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin, clone |
19 | 21 | from sklearn.cluster import KMeans |
20 | 22 | from sklearn.datasets import load_iris, make_classification |
21 | 23 | from sklearn.decomposition import PCA |
|
30 | 32 | assert_array_almost_equal, |
31 | 33 | assert_array_equal, |
32 | 34 | ) |
| 35 | +from sklearn.utils.fixes import parse_version |
33 | 36 |
|
34 | 37 | from imblearn.datasets import make_imbalance |
35 | 38 | from imblearn.pipeline import Pipeline, make_pipeline |
36 | 39 | from imblearn.under_sampling import EditedNearestNeighbours as ENN |
37 | 40 | from imblearn.under_sampling import RandomUnderSampler |
| 41 | +from imblearn.utils._sklearn_compat import sklearn_version |
38 | 42 | from imblearn.utils.estimator_checks import check_param_validation |
39 | 43 |
|
40 | 44 | JUNK_FOOD_DOCS = ( |
@@ -1365,3 +1369,129 @@ def test_pipeline_with_set_output(): |
1365 | 1369 | assert isinstance(X_res, pd.DataFrame) |
1366 | 1370 | # transformer will not change `y` and sampler will always preserve the type of `y` |
1367 | 1371 | assert isinstance(y_res, type(y)) |
| 1372 | + |
| 1373 | + |
| 1374 | +# TODO(0.15): change warning to checking for NotFittedError |
| 1375 | +@pytest.mark.parametrize( |
| 1376 | + "method", |
| 1377 | + [ |
| 1378 | + "predict", |
| 1379 | + "predict_proba", |
| 1380 | + "predict_log_proba", |
| 1381 | + "decision_function", |
| 1382 | + "score", |
| 1383 | + "score_samples", |
| 1384 | + "transform", |
| 1385 | + "inverse_transform", |
| 1386 | + ], |
| 1387 | +) |
| 1388 | +def test_pipeline_warns_not_fitted(method): |
| 1389 | + class StatelessEstimator(BaseEstimator): |
| 1390 | + """Stateless estimator that doesn't check if it's fitted. |
| 1391 | + Stateless estimators that don't require fit, should properly set the |
| 1392 | + `requires_fit` flag and implement a `__sklearn_check_is_fitted__` returning |
| 1393 | + `True`. |
| 1394 | + """ |
| 1395 | + |
| 1396 | + def fit(self, X, y): |
| 1397 | + return self # pragma: no cover |
| 1398 | + |
| 1399 | + def transform(self, X): |
| 1400 | + return X |
| 1401 | + |
| 1402 | + def predict(self, X): |
| 1403 | + return np.ones(len(X)) |
| 1404 | + |
| 1405 | + def predict_proba(self, X): |
| 1406 | + return np.ones(len(X)) |
| 1407 | + |
| 1408 | + def predict_log_proba(self, X): |
| 1409 | + return np.zeros(len(X)) |
| 1410 | + |
| 1411 | + def decision_function(self, X): |
| 1412 | + return np.ones(len(X)) |
| 1413 | + |
| 1414 | + def score(self, X, y): |
| 1415 | + return 1 |
| 1416 | + |
| 1417 | + def score_samples(self, X): |
| 1418 | + return np.ones(len(X)) |
| 1419 | + |
| 1420 | + def inverse_transform(self, X): |
| 1421 | + return X |
| 1422 | + |
| 1423 | + pipe = Pipeline([("estimator", StatelessEstimator())]) |
| 1424 | + with pytest.warns(FutureWarning, match="This Pipeline instance is not fitted yet."): |
| 1425 | + getattr(pipe, method)([[1]]) |
| 1426 | + |
| 1427 | + |
| 1428 | +# transform_input tests |
| 1429 | +# ===================== |
| 1430 | + |
| 1431 | + |
| 1432 | +@pytest.mark.skipif( |
| 1433 | + sklearn_version < parse_version("1.4"), |
| 1434 | + reason="scikit-learn < 1.4 does not support transform_input", |
| 1435 | +) |
| 1436 | +@config_context(enable_metadata_routing=True) |
| 1437 | +def test_transform_input_explicit_value_check(): |
| 1438 | + """Test that the right transformed values are passed to `fit`.""" |
| 1439 | + |
| 1440 | + class Transformer(TransformerMixin, BaseEstimator): |
| 1441 | + def fit(self, X, y): |
| 1442 | + self.fitted_ = True |
| 1443 | + return self |
| 1444 | + |
| 1445 | + def transform(self, X): |
| 1446 | + return X + 1 |
| 1447 | + |
| 1448 | + class Estimator(ClassifierMixin, BaseEstimator): |
| 1449 | + def fit(self, X, y, X_val=None, y_val=None): |
| 1450 | + assert_array_equal(X, np.array([[1, 2]])) |
| 1451 | + assert_array_equal(y, np.array([0, 1])) |
| 1452 | + assert_array_equal(X_val, np.array([[2, 3]])) |
| 1453 | + assert_array_equal(y_val, np.array([0, 1])) |
| 1454 | + return self |
| 1455 | + |
| 1456 | + X = np.array([[0, 1]]) |
| 1457 | + y = np.array([0, 1]) |
| 1458 | + X_val = np.array([[1, 2]]) |
| 1459 | + y_val = np.array([0, 1]) |
| 1460 | + pipe = Pipeline( |
| 1461 | + [ |
| 1462 | + ("transformer", Transformer()), |
| 1463 | + ("estimator", Estimator().set_fit_request(X_val=True, y_val=True)), |
| 1464 | + ], |
| 1465 | + transform_input=["X_val"], |
| 1466 | + ) |
| 1467 | + pipe.fit(X, y, X_val=X_val, y_val=y_val) |
| 1468 | + |
| 1469 | + |
| 1470 | +def test_transform_input_no_slep6(): |
| 1471 | + """Make sure the right error is raised if slep6 is not enabled.""" |
| 1472 | + X = np.array([[1, 2], [3, 4]]) |
| 1473 | + y = np.array([0, 1]) |
| 1474 | + msg = "The `transform_input` parameter can only be set if metadata" |
| 1475 | + with pytest.raises(ValueError, match=msg): |
| 1476 | + make_pipeline(DummyTransf(), transform_input=["blah"]).fit(X, y) |
| 1477 | + |
| 1478 | + |
| 1479 | +@pytest.mark.skipif( |
| 1480 | + sklearn_version >= parse_version("1.4"), |
| 1481 | + reason="scikit-learn >= 1.4 supports transform_input", |
| 1482 | +) |
| 1483 | +@config_context(enable_metadata_routing=True) |
| 1484 | +def test_transform_input_sklearn_version(): |
| 1485 | + """Test that transform_input raises error with sklearn < 1.4.""" |
| 1486 | + X = np.array([[1, 2], [3, 4]]) |
| 1487 | + y = np.array([0, 1]) |
| 1488 | + msg = ( |
| 1489 | + "The `transform_input` parameter is not supported in scikit-learn versions " |
| 1490 | + "prior to 1.4" |
| 1491 | + ) |
| 1492 | + with pytest.raises(ValueError, match=msg): |
| 1493 | + make_pipeline(DummyTransf(), transform_input=["blah"]).fit(X, y) |
| 1494 | + |
| 1495 | + |
| 1496 | +# end of transform_input tests |
| 1497 | +# ============================= |
0 commit comments