@@ -49,7 +49,7 @@ def _get_policy(self, queue, *data):
4949 def _get_onedal_params (self , dtype = np .float32 ):
5050 intercept = 'intercept|' if self .fit_intercept else ''
5151 return {
52- 'fptype' : 'float' if dtype is np .float32 else 'double' ,
52+ 'fptype' : 'float' if dtype == np .float32 else 'double' ,
5353 'method' : self .algorithm , 'intercept' : self .fit_intercept ,
5454 'result_option' : (intercept + 'coefficients' ),
5555 }
@@ -63,20 +63,19 @@ def _fit(self, X, y, module, queue):
6363
6464 dtype = get_dtype (X_loc )
6565 if dtype not in [np .float32 , np .float64 ]:
66- X_loc = X_loc .astype (np .float64 , copy = self .copy_X )
6766 dtype = np .float64
67+ X_loc = X_loc .astype (dtype , copy = self .copy_X )
6868
6969 y_loc = np .asarray (y_loc ).astype (dtype = dtype )
7070
7171 # Finiteness is checked in the sklearnex wrapper
7272 X_loc , y_loc = _check_X_y (
7373 X_loc , y_loc , force_all_finite = False , accept_2d_y = True )
7474
75- params = self ._get_onedal_params (dtype )
76-
7775 self .n_features_in_ = _num_features (X_loc , fallback_1d = True )
7876
7977 X_loc , y_loc = _convert_to_supported (policy , X_loc , y_loc )
78+ params = self ._get_onedal_params (get_dtype (X_loc ))
8079 X_table , y_table = to_table (X_loc , y_loc )
8180
8281 result = module .train (policy , params , X_table , y_table )
@@ -92,7 +91,7 @@ def _fit(self, X, y, module, queue):
9291
9392 return self
9493
95- def _create_model (self , module ):
94+ def _create_model (self , module , policy ):
9695 m = module .model ()
9796
9897 coefficients = self .coef_
@@ -137,6 +136,8 @@ def _create_model(self, module):
137136 if self .fit_intercept :
138137 packed_coefficients [:, 0 ][:, np .newaxis ] = intercept
139138
139+ packed_coefficients = _convert_to_supported (policy , packed_coefficients )
140+
140141 m .packed_coefficients = to_table (packed_coefficients )
141142
142143 self ._onedal_model = m
@@ -158,15 +159,14 @@ def _predict(self, X, module, queue):
158159 force_all_finite = False , ensure_2d = False )
159160 _check_n_features (self , X_loc , False )
160161
161- params = self ._get_onedal_params (X_loc )
162-
163162 if hasattr (self , '_onedal_model' ):
164163 model = self ._onedal_model
165164 else :
166- model = self ._create_model (module )
165+ model = self ._create_model (module , policy )
167166
168167 X_loc = make2d (X_loc )
169168 X_loc = _convert_to_supported (policy , X_loc )
169+ params = self ._get_onedal_params (get_dtype (X_loc ))
170170
171171 X_table = to_table (X_loc )
172172 result = module .infer (policy , params , model , X_table )
0 commit comments