Skip to content

Commit a67a2f0

Browse files
committed
fix for multiple versions of scikit-learn
1 parent f917b39 commit a67a2f0

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

mlinsights/mlbatch/pipeline_cache.py

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,16 @@ def _get_fit_params_steps(self, fit_params):
5555
fit_params_steps[step][param] = pval
5656
return fit_params_steps
5757

58-
def _fit(self, X, y=None, routed_params=None):
58+
def _fit(self, X, y=None, *args, **fit_params_steps):
59+
if "routed_params" in fit_params_steps:
60+
# scikit-learn>=1.4
61+
routed_params = fit_params_steps["routed_params"]
62+
elif len(args) == 1:
63+
# scikit-learn>=1.4
64+
routed_params = args[0]
65+
else:
66+
# scikit-learn<1.4
67+
routed_params = None
5968
self.steps = list(self.steps)
6069
self._validate_steps()
6170
memory = check_memory(self.memory)
@@ -82,15 +91,26 @@ def _fit(self, X, y=None, routed_params=None):
8291
cached = self.cache_.get(params)
8392
if cached is None:
8493
cloned_transformer = clone(transformer)
85-
Xt, fitted_transformer = fit_transform_one_cached(
86-
cloned_transformer,
87-
Xt,
88-
y,
89-
None,
90-
message_clsname="PipelineCache",
91-
message=self._log_message(step_idx),
92-
params=routed_params[name],
93-
)
94+
if routed_params is None:
95+
Xt, fitted_transformer = fit_transform_one_cached(
96+
cloned_transformer,
97+
Xt,
98+
y,
99+
None,
100+
message_clsname="PipelineCache",
101+
message=self._log_message(step_idx),
102+
**fit_params_steps[name],
103+
)
104+
else:
105+
Xt, fitted_transformer = fit_transform_one_cached(
106+
cloned_transformer,
107+
Xt,
108+
y,
109+
None,
110+
message_clsname="PipelineCache",
111+
message=self._log_message(step_idx),
112+
params=routed_params[name],
113+
)
94114
self.cache_.cache(params, fitted_transformer)
95115
else:
96116
fitted_transformer = cached

0 commit comments

Comments
 (0)