11from sklearn .base import clone
22from sklearn .pipeline import Pipeline , _fit_transform_one
33from sklearn .utils import _print_elapsed_time
4+ from sklearn .utils .validation import check_memory
45from .cache_model import MLCache
56
67
@@ -54,10 +55,11 @@ def _get_fit_params_steps(self, fit_params):
5455 fit_params_steps [step ][param ] = pval
5556 return fit_params_steps
5657
57- def _fit (self , X , y = None , ** fit_params ):
58+ def _fit (self , X , y = None , routed_params = None ):
5859 self .steps = list (self .steps )
5960 self ._validate_steps ()
60- fit_params_steps = self ._get_fit_params_steps (fit_params )
61+ memory = check_memory (self .memory )
62+ fit_transform_one_cached = memory .cache (_fit_transform_one )
6163 if not MLCache .has_cache (self .cache_name ):
6264 self .cache_ = MLCache .create_cache (self .cache_name )
6365 else :
@@ -80,14 +82,14 @@ def _fit(self, X, y=None, **fit_params):
8082 cached = self .cache_ .get (params )
8183 if cached is None :
8284 cloned_transformer = clone (transformer )
83- Xt , fitted_transformer = _fit_transform_one (
85+ Xt , fitted_transformer = fit_transform_one_cached (
8486 cloned_transformer ,
8587 Xt ,
8688 y ,
8789 None ,
8890 message_clsname = "PipelineCache" ,
8991 message = self ._log_message (step_idx ),
90- ** fit_params_steps [name ],
92+ params = routed_params [name ],
9193 )
9294 self .cache_ .cache (params , fitted_transformer )
9395 else :
0 commit comments