Skip to content

Commit 3047405

Browse files
authored
MAINT: rebase Pipeline with sklearn (#486)
1 parent ecab162 commit 3047405

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

imblearn/pipeline.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def fit_resample(self, X, y=None, **fit_params):
316316
return last_step.fit_resample(Xt, yt, **fit_params)
317317

318318
@if_delegate_has_method(delegate='_final_estimator')
319-
def predict(self, X):
319+
def predict(self, X, **predict_params):
320320
"""Apply transformers/samplers to the data, and predict with the final
321321
estimator
322322
@@ -326,6 +326,14 @@ def predict(self, X):
326326
Data to predict on. Must fulfill input requirements of first step
327327
of the pipeline.
328328
329+
**predict_params : dict of string -> object
330+
Parameters to the ``predict`` called at the end of all
331+
transformations in the pipeline. Note that while this may be
332+
used to return uncertainties from some models with return_std
333+
or return_cov, uncertainties that are generated by the
334+
transformations in the pipeline are not propagated to the
335+
final estimator.
336+
329337
Returns
330338
-------
331339
y_pred : array-like
@@ -339,7 +347,7 @@ def predict(self, X):
339347
pass
340348
else:
341349
Xt = transform.transform(Xt)
342-
return self.steps[-1][-1].predict(Xt)
350+
return self.steps[-1][-1].predict(Xt, **predict_params)
343351

344352
@if_delegate_has_method(delegate='_final_estimator')
345353
def fit_predict(self, X, y=None, **fit_params):

imblearn/tests/test_pipeline.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,16 @@ def fit(self, X, y):
139139
return self
140140

141141

142+
class DummyEstimatorParams(BaseEstimator):
143+
"""Mock classifier that takes params on predict"""
144+
def fit(self, X, y):
145+
return self
146+
147+
def predict(self, X, got_attribute=False):
148+
self.got_attribute = got_attribute
149+
return self
150+
151+
142152
class DummySampler(NoTrans):
143153
"""Samplers which returns a balanced number of samples"""
144154

@@ -1085,3 +1095,12 @@ def test_make_pipeline_memory():
10851095
assert pipeline.memory is None
10861096
finally:
10871097
shutil.rmtree(cachedir)
1098+
1099+
1100+
def test_predict_with_predict_params():
1101+
# tests that Pipeline passes predict_params to the final estimator
1102+
# when predict is invoked
1103+
pipe = Pipeline([('transf', Transf()), ('clf', DummyEstimatorParams())])
1104+
pipe.fit(None, None)
1105+
pipe.predict(X=None, got_attribute=True)
1106+
assert pipe.named_steps['clf'].got_attribute

0 commit comments

Comments
 (0)