@@ -55,7 +55,16 @@ def _get_fit_params_steps(self, fit_params):
                 fit_params_steps[step][param] = pval
         return fit_params_steps

-    def _fit(self, X, y=None, routed_params=None):
+    def _fit(self, X, y=None, *args, **fit_params_steps):
+        if "routed_params" in fit_params_steps:
+            # scikit-learn>=1.4
+            routed_params = fit_params_steps["routed_params"]
+        elif len(args) == 1:
+            # scikit-learn>=1.4
+            routed_params = args[0]
+        else:
+            # scikit-learn<1.4
+            routed_params = None
         self.steps = list(self.steps)
         self._validate_steps()
         memory = check_memory(self.memory)
@@ -82,15 +91,26 @@ def _fit(self, X, y=None, routed_params=None):
             cached = self.cache_.get(params)
             if cached is None:
                 cloned_transformer = clone(transformer)
-                Xt, fitted_transformer = fit_transform_one_cached(
-                    cloned_transformer,
-                    Xt,
-                    y,
-                    None,
-                    message_clsname="PipelineCache",
-                    message=self._log_message(step_idx),
-                    params=routed_params[name],
-                )
+                if routed_params is None:
+                    Xt, fitted_transformer = fit_transform_one_cached(
+                        cloned_transformer,
+                        Xt,
+                        y,
+                        None,
+                        message_clsname="PipelineCache",
+                        message=self._log_message(step_idx),
+                        **fit_params_steps[name],
+                    )
+                else:
+                    Xt, fitted_transformer = fit_transform_one_cached(
+                        cloned_transformer,
+                        Xt,
+                        y,
+                        None,
+                        message_clsname="PipelineCache",
+                        message=self._log_message(step_idx),
+                        params=routed_params[name],
+                    )
                 self.cache_.cache(params, fitted_transformer)
             else:
                 fitted_transformer = cached
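For context on the dispatch added at the top of `_fit`: on scikit-learn>=1.4 the parent `Pipeline.fit` hands over a single pre-routed `routed_params` object (positionally or by keyword, depending on the minor version), while on scikit-learn<1.4 it expands the per-step fit params as `**fit_params_steps`. Below is a minimal, self-contained sketch of that dispatch in isolation; the helper name `resolve_fit_args` and the example dictionaries are illustrative only, not part of the patch.

```python
def resolve_fit_args(*args, **fit_params_steps):
    """Return (routed_params, fit_params_steps) for either calling convention.

    Illustrative helper only; it mirrors the branching added to _fit() above.
    """
    if "routed_params" in fit_params_steps:
        # scikit-learn>=1.4, keyword call: _fit(X, y, routed_params=...)
        return fit_params_steps["routed_params"], {}
    if len(args) == 1:
        # scikit-learn>=1.4, positional call: _fit(X, y, routed_params)
        return args[0], {}
    # scikit-learn<1.4 call: _fit(X, y, **fit_params_steps)
    return None, fit_params_steps


routed = {"scaler": {"fit": {"sample_weight": [1.0, 2.0]}}}    # assumed shape
print(resolve_fit_args(routed))                                # (routed, {})
print(resolve_fit_args(routed_params=routed))                  # (routed, {})
print(resolve_fit_args(scaler={"sample_weight": [1.0, 2.0]}))  # (None, {'scaler': {...}})
```

The second hunk then picks the matching call into `fit_transform_one_cached`: `params=routed_params[name]` on the routed (>=1.4) path, `**fit_params_steps[name]` on the legacy (<1.4) path.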