Skip to content

Commit 0573d83

Browse files
chkoarglemaitre
authored andcommitted
Adress #176 - Fix "fit then sample" bug in pipeline (#181)
1 parent ee33364 commit 0573d83

File tree

2 files changed

+40
-4
lines changed

2 files changed

+40
-4
lines changed

imblearn/pipeline.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,14 @@ def sample(self, X, y):
230230
Xt = X
231231
for _, transform in self.steps[:-1]:
232232
if hasattr(transform, "fit_sample"):
233-
pass
233+
# XXX: Calling sample in pipeline it means that the
234+
# last estimator is a sampler. Samplers don't carry
235+
# the sampled data. So, call 'fit_sample' in all intermediate
236+
# steps to get the sampled data for the last estimator.
237+
Xt, y = transform.fit_sample(Xt, y)
234238
else:
235239
Xt = transform.transform(Xt)
236-
return self.steps[-1][-1].sample(Xt, y)
240+
return self.steps[-1][-1].fit_sample(Xt, y)
237241

238242
@if_delegate_has_method(delegate='_final_estimator')
239243
def predict(self, X):

imblearn/tests/test_pipeline.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
assert_true, assert_warns_message)
1818

1919
from imblearn.pipeline import Pipeline, make_pipeline
20-
from imblearn.under_sampling import RandomUnderSampler
20+
from imblearn.under_sampling import RandomUnderSampler, EditedNearestNeighbours as ENN
2121

2222
JUNK_FOOD_DOCS = (
2323
"the pizza pizza beer copyright",
@@ -473,4 +473,36 @@ def test_pipeline_with_step_that_it_is_pipeline():
473473
filter1 = SelectKBest(f_classif, k=2)
474474
pipe1 = Pipeline([('rus', rus), ('anova', filter1)])
475475
assert_raises(TypeError, Pipeline, [('pipe1', pipe1), ('logistic', clf)])
476-
476+
477+
def test_pipeline_fit_then_sample_with_sampler_last_estimator():
478+
X, y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9],
479+
n_informative=3, n_redundant=1, flip_y=0,
480+
n_features=20, n_clusters_per_class=1,
481+
n_samples=50000, random_state=0)
482+
483+
rus = RandomUnderSampler(random_state=42)
484+
enn = ENN()
485+
pipeline = make_pipeline(rus, enn)
486+
X_fit_sample_resampled, y_fit_sample_resampled = pipeline.fit_sample(X,y)
487+
pipeline = make_pipeline(rus, enn)
488+
pipeline.fit(X,y)
489+
X_fit_then_sample_resampled, y_fit_then_sample_resampled = pipeline.sample(X,y)
490+
assert_array_equal(X_fit_sample_resampled, X_fit_then_sample_resampled)
491+
assert_array_equal(y_fit_sample_resampled, y_fit_then_sample_resampled)
492+
493+
494+
def test_pipeline_fit_then_sample_of_three_samplers_with_sampler_last_estimator():
495+
X, y = make_classification(n_classes=2, class_sep=2, weights=[0.1, 0.9],
496+
n_informative=3, n_redundant=1, flip_y=0,
497+
n_features=20, n_clusters_per_class=1,
498+
n_samples=50000, random_state=0)
499+
500+
rus = RandomUnderSampler(random_state=42)
501+
enn = ENN()
502+
pipeline = make_pipeline(rus, enn, rus)
503+
X_fit_sample_resampled, y_fit_sample_resampled = pipeline.fit_sample(X,y)
504+
pipeline = make_pipeline(rus, enn, rus)
505+
pipeline.fit(X,y)
506+
X_fit_then_sample_resampled, y_fit_then_sample_resampled = pipeline.sample(X,y)
507+
assert_array_equal(X_fit_sample_resampled, X_fit_then_sample_resampled)
508+
assert_array_equal(y_fit_sample_resampled, y_fit_then_sample_resampled)

0 commit comments

Comments
 (0)