Commit 72d939a

MNT synchronize pipeline with scikit-learn implementation (#795)
1 parent 1130324 commit 72d939a

File tree

1 file changed (+25, -33 lines)


imblearn/pipeline.py

Lines changed: 25 additions & 33 deletions
@@ -63,7 +63,7 @@ class Pipeline(pipeline.Pipeline):
 
     Attributes
     ----------
-    named_steps : bunch object, a dictionary with attribute access
+    named_steps : :class:`~sklearn.utils.Bunch`
         Read-only attribute to access any step parameter by user given name.
         Keys are step names and values are steps parameters.
 
@@ -179,7 +179,7 @@ def _iter(self, with_final=True, filter_passthrough=True, filter_resample=True):
 
     # Estimator interface
 
-    def _fit(self, X, y=None, **fit_params):
+    def _fit(self, X, y=None, **fit_params_steps):
         self.steps = list(self.steps)
         self._validate_steps()
         # Setup the memory
@@ -188,18 +188,6 @@ def _fit(self, X, y=None, **fit_params):
         fit_transform_one_cached = memory.cache(pipeline._fit_transform_one)
         fit_resample_one_cached = memory.cache(_fit_resample_one)
 
-        fit_params_steps = {name: {} for name, step in self.steps if step is not None}
-        for pname, pval in fit_params.items():
-            if "__" not in pname:
-                raise ValueError(
-                    f"Pipeline.fit does not accept the {pname} parameter. "
-                    "You can pass parameters to specific steps of your "
-                    "pipeline using the stepname__parameter format, e.g. "
-                    "`Pipeline.fit(X, y, logisticregression__sample_weight"
-                    "=sample_weight)`."
-                )
-            step, param = pname.split("__", 1)
-            fit_params_steps[step][param] = pval
         for (step_idx, name, transformer) in self._iter(
             with_final=False, filter_passthrough=False, filter_resample=False
         ):
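For reference, a standalone sketch of what the deleted block did: it split `stepname__parameter` keys into a per-step dict and rejected anything without the `__` separator. After this commit the pipeline delegates that parsing to the `_check_fit_params` helper it now calls (inherited from the scikit-learn base class, as the diff below shows); the function and names here are illustrative only.

```python
# Illustrative sketch of the removed routing logic, not the library's API.
def route_fit_params(step_names, **fit_params):
    # One empty dict per step name.
    fit_params_steps = {name: {} for name in step_names}
    for pname, pval in fit_params.items():
        if "__" not in pname:
            # Mirrors the ValueError the removed block raised.
            raise ValueError(
                f"Pipeline.fit does not accept the {pname} parameter. "
                "Pass parameters as stepname__parameter instead."
            )
        step, param = pname.split("__", 1)
        fit_params_steps[step][param] = pval
    return fit_params_steps


# Prints: {'sampler': {}, 'classifier': {'sample_weight': [1.0, 2.0, 1.0]}}
print(route_fit_params(
    ["sampler", "classifier"],
    classifier__sample_weight=[1.0, 2.0, 1.0],
))
```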
@@ -241,9 +229,7 @@ def _fit(self, X, y=None, **fit_params):
             # transformer. This is necessary when loading the transformer
             # from the cache.
             self.steps[step_idx] = (name, fitted_transformer)
-        if self._final_estimator == "passthrough":
-            return X, y, {}
-        return X, y, fit_params_steps[self.steps[-1][0]]
+        return X, y
 
     def fit(self, X, y=None, **fit_params):
         """Fit the model.
@@ -272,10 +258,12 @@ def fit(self, X, y=None, **fit_params):
         self : Pipeline
             This estimator.
         """
-        Xt, yt, fit_params = self._fit(X, y, **fit_params)
+        fit_params_steps = self._check_fit_params(**fit_params)
+        Xt, yt = self._fit(X, y, **fit_params_steps)
         with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
             if self._final_estimator != "passthrough":
-                self._final_estimator.fit(Xt, yt, **fit_params)
+                fit_params_last_step = fit_params_steps[self.steps[-1][0]]
+                self._final_estimator.fit(Xt, yt, **fit_params_last_step)
         return self
 
     def fit_transform(self, X, y=None, **fit_params):
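A hedged usage sketch of the routing that `fit` now performs: parameters are still addressed as `stepname__parameter`, and the final step's parameters are looked up under the last step's name (`self.steps[-1][0]`). The step names and estimators below are chosen for illustration, not taken from the diff.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from imblearn.pipeline import Pipeline

X, y = make_classification(random_state=0)
pipe = Pipeline([
    ("scaler", StandardScaler()),
    ("classifier", LogisticRegression()),
])
# "classifier__sample_weight" is validated by _check_fit_params and handed
# to the final estimator's fit() as sample_weight.
weights = np.where(y == 1, 2.0, 1.0)
pipe.fit(X, y, classifier__sample_weight=weights)
```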
@@ -305,15 +293,18 @@ def fit_transform(self, X, y=None, **fit_params):
         Xt : array-like of shape (n_samples, n_transformed_features)
             Transformed samples.
         """
+        fit_params_steps = self._check_fit_params(**fit_params)
+        Xt, yt = self._fit(X, y, **fit_params_steps)
+
         last_step = self._final_estimator
-        Xt, yt, fit_params = self._fit(X, y, **fit_params)
         with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
             if last_step == "passthrough":
                 return Xt
-            elif hasattr(last_step, "fit_transform"):
-                return last_step.fit_transform(Xt, yt, **fit_params)
+            fit_params_last_step = fit_params_steps[self.steps[-1][0]]
+            if hasattr(last_step, "fit_transform"):
+                return last_step.fit_transform(Xt, yt, **fit_params_last_step)
             else:
-                return last_step.fit(Xt, yt, **fit_params).transform(Xt)
+                return last_step.fit(Xt, yt, **fit_params_last_step).transform(Xt)
 
     def fit_resample(self, X, y=None, **fit_params):
         """Fit the model and sample with the final estimator.
@@ -345,13 +336,15 @@ def fit_resample(self, X, y=None, **fit_params):
         yt : array-like of shape (n_samples, n_transformed_features)
             Transformed target.
         """
+        fit_params_steps = self._check_fit_params(**fit_params)
+        Xt, yt = self._fit(X, y, **fit_params_steps)
         last_step = self._final_estimator
-        Xt, yt, fit_params = self._fit(X, y, **fit_params)
         with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
             if last_step == "passthrough":
                 return Xt
-            elif hasattr(last_step, "fit_resample"):
-                return last_step.fit_resample(Xt, yt, **fit_params)
+            fit_params_last_step = fit_params_steps[self.steps[-1][0]]
+            if hasattr(last_step, "fit_resample"):
+                return last_step.fit_resample(Xt, yt, **fit_params_last_step)
 
     @if_delegate_has_method(delegate="_final_estimator")
     def fit_predict(self, X, y=None, **fit_params):
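`fit_resample` follows the same pattern as `fit` and `fit_transform`. A usage sketch with a sampler as the final step (dataset and step choices are illustrative):

```python
from collections import Counter
from sklearn.datasets import make_classification
from sklearn.preprocessing import StandardScaler
from imblearn.pipeline import make_pipeline
from imblearn.under_sampling import RandomUnderSampler

X, y = make_classification(weights=[0.9, 0.1], random_state=0)
pipe = make_pipeline(StandardScaler(), RandomUnderSampler(random_state=0))
# With a sampler as the last step, fit_resample returns the resampled data.
X_res, y_res = pipe.fit_resample(X, y)
print(Counter(y_res))  # roughly balanced classes after undersampling
```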
@@ -381,9 +374,12 @@ def fit_predict(self, X, y=None, **fit_params):
         y_pred : ndarray of shape (n_samples,)
             The predicted target.
         """
-        Xt, yt, fit_params = self._fit(X, y, **fit_params)
+        fit_params_steps = self._check_fit_params(**fit_params)
+        Xt, yt = self._fit(X, y, **fit_params_steps)
+
+        fit_params_last_step = fit_params_steps[self.steps[-1][0]]
         with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
-            y_pred = self.steps[-1][-1].fit_predict(Xt, yt, **fit_params)
+            y_pred = self.steps[-1][-1].fit_predict(Xt, yt, **fit_params_last_step)
         return y_pred
 
 
@@ -394,7 +390,7 @@ def _fit_resample_one(sampler, X, y, message_clsname="", message=None, **fit_par
     return X_res, y_res, sampler
 
 
-def make_pipeline(*steps, **kwargs):
+def make_pipeline(*steps, memory=None, verbose=False):
     """Construct a Pipeline from the given estimators.
 
     This is a shorthand for the Pipeline constructor; it does not require, and
@@ -438,8 +434,4 @@ def make_pipeline(*steps, **kwargs):
     Pipeline(steps=[('standardscaler', StandardScaler()),
                     ('gaussiannb', GaussianNB())])
     """
-    memory = kwargs.pop("memory", None)
-    verbose = kwargs.pop("verbose", False)
-    if kwargs:
-        raise TypeError(f'Unknown keyword arguments: "{list(kwargs.keys())[0]}"')
     return Pipeline(pipeline._name_estimators(steps), memory=memory, verbose=verbose)
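With the new signature, `memory` and `verbose` are keyword-only; any other keyword now fails with Python's own TypeError rather than the hand-written one removed above. A usage sketch (estimator choices mirror the docstring example):

```python
from sklearn.naive_bayes import GaussianNB
from sklearn.preprocessing import StandardScaler
from imblearn.pipeline import make_pipeline

# memory and verbose are keyword-only, matching the scikit-learn signature.
pipe = make_pipeline(StandardScaler(), GaussianNB(), memory=None, verbose=True)

# Any other keyword, e.g. make_pipeline(StandardScaler(), GaussianNB(), foo=1),
# now raises a plain TypeError from Python itself instead of the custom
# message deleted in this hunk.
```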
