|
17 | 17 | assert_true, assert_warns_message)
|
18 | 18 |
|
19 | 19 | from imblearn.pipeline import Pipeline, make_pipeline
|
20 |
| -from imblearn.under_sampling import RandomUnderSampler |
| 20 | +from imblearn.under_sampling import RandomUnderSampler, EditedNearestNeighbours as ENN |
21 | 21 |
|
22 | 22 | JUNK_FOOD_DOCS = (
|
23 | 23 | "the pizza pizza beer copyright",
|
@@ -473,4 +473,36 @@ def test_pipeline_with_step_that_it_is_pipeline():
|
473 | 473 | filter1 = SelectKBest(f_classif, k=2)
|
474 | 474 | pipe1 = Pipeline([('rus', rus), ('anova', filter1)])
|
475 | 475 | assert_raises(TypeError, Pipeline, [('pipe1', pipe1), ('logistic', clf)])
|
476 |
| - |
| 476 | + |
| 477 | +def test_pipeline_fit_then_sample_with_sampler_last_estimator(): |
| 478 | + X, y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9], |
| 479 | + n_informative=3, n_redundant=1, flip_y=0, |
| 480 | + n_features=20, n_clusters_per_class=1, |
| 481 | + n_samples=50000, random_state=0) |
| 482 | + |
| 483 | + rus = RandomUnderSampler(random_state=42) |
| 484 | + enn = ENN() |
| 485 | + pipeline = make_pipeline(rus, enn) |
| 486 | + X_fit_sample_resampled, y_fit_sample_resampled = pipeline.fit_sample(X,y) |
| 487 | + pipeline = make_pipeline(rus, enn) |
| 488 | + pipeline.fit(X,y) |
| 489 | + X_fit_then_sample_resampled, y_fit_then_sample_resampled = pipeline.sample(X,y) |
| 490 | + assert_array_equal(X_fit_sample_resampled, X_fit_then_sample_resampled) |
| 491 | + assert_array_equal(y_fit_sample_resampled, y_fit_then_sample_resampled) |
| 492 | + |
| 493 | + |
| 494 | +def test_pipeline_fit_then_sample_of_three_samplers_with_sampler_last_estimator(): |
| 495 | + X, y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9], |
| 496 | + n_informative=3, n_redundant=1, flip_y=0, |
| 497 | + n_features=20, n_clusters_per_class=1, |
| 498 | + n_samples=50000, random_state=0) |
| 499 | + |
| 500 | + rus = RandomUnderSampler(random_state=42) |
| 501 | + enn = ENN() |
| 502 | + pipeline = make_pipeline(rus, enn, rus) |
| 503 | + X_fit_sample_resampled, y_fit_sample_resampled = pipeline.fit_sample(X,y) |
| 504 | + pipeline = make_pipeline(rus, enn, rus) |
| 505 | + pipeline.fit(X,y) |
| 506 | + X_fit_then_sample_resampled, y_fit_then_sample_resampled = pipeline.sample(X,y) |
| 507 | + assert_array_equal(X_fit_sample_resampled, X_fit_then_sample_resampled) |
| 508 | + assert_array_equal(y_fit_sample_resampled, y_fit_then_sample_resampled) |
0 commit comments