@@ -557,6 +557,7 @@ def _onedal_fit(self, X, y, queue):
557557 params = self ._get_onedal_params (X , y )
558558 train_alg_regr = _backend .neighbors .regression .train
559559 train_alg_srch = _backend .neighbors .search .train
560+
560561 if gpu_device :
561562 return train_alg_regr (policy , params , * to_table (X , y )).model
562563 return train_alg_srch (policy , params , to_table (X )).model
@@ -581,6 +582,7 @@ def _onedal_predict(self, model, X, params, queue):
581582 model = self ._onedal_model
582583 else :
583584 model = self ._create_model (backend )
585+ params ['fptype' ] = 'float' if X .dtype == np .float32 else 'double'
584586 result = backend .infer (policy , params , model , to_table (X ))
585587
586588 return result
@@ -708,6 +710,8 @@ def _onedal_predict(self, model, X, params, queue):
708710 model = self ._onedal_model
709711 else :
710712 model = self ._create_model (_backend .neighbors .search )
713+
714+ params ['fptype' ] = 'float' if X .dtype == np .float32 else 'double'
711715 result = _backend .neighbors .search .infer (policy , params , model , to_table (X ))
712716
713717 return result
0 commit comments