Skip to content

Commit e405447

Browse files
committed
fix grid tests
1 parent ea6e3e1 commit e405447

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

mlinsights/mlbatch/pipeline_cache.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from sklearn.base import clone
22
from sklearn.pipeline import Pipeline, _fit_transform_one
33
from sklearn.utils import _print_elapsed_time
4+
from sklearn.utils.validation import check_memory
45
from .cache_model import MLCache
56

67

@@ -54,10 +55,11 @@ def _get_fit_params_steps(self, fit_params):
5455
fit_params_steps[step][param] = pval
5556
return fit_params_steps
5657

57-
def _fit(self, X, y=None, **fit_params):
58+
def _fit(self, X, y=None, routed_params=None):
5859
self.steps = list(self.steps)
5960
self._validate_steps()
60-
fit_params_steps = self._get_fit_params_steps(fit_params)
61+
memory = check_memory(self.memory)
62+
fit_transform_one_cached = memory.cache(_fit_transform_one)
6163
if not MLCache.has_cache(self.cache_name):
6264
self.cache_ = MLCache.create_cache(self.cache_name)
6365
else:
@@ -80,14 +82,14 @@ def _fit(self, X, y=None, **fit_params):
8082
cached = self.cache_.get(params)
8183
if cached is None:
8284
cloned_transformer = clone(transformer)
83-
Xt, fitted_transformer = _fit_transform_one(
85+
Xt, fitted_transformer = fit_transform_one_cached(
8486
cloned_transformer,
8587
Xt,
8688
y,
8789
None,
8890
message_clsname="PipelineCache",
8991
message=self._log_message(step_idx),
90-
**fit_params_steps[name],
92+
params=routed_params[name],
9193
)
9294
self.cache_.cache(params, fitted_transformer)
9395
else:

0 commit comments

Comments
 (0)