Skip to content

Commit 9e10a90

Browse files
author
KulikovNikita
authored
FP64 fix (#1225)
1 parent 97f0d47 commit 9e10a90

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

onedal/neighbors/neighbors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,6 +557,7 @@ def _onedal_fit(self, X, y, queue):
557557
params = self._get_onedal_params(X, y)
558558
train_alg_regr = _backend.neighbors.regression.train
559559
train_alg_srch = _backend.neighbors.search.train
560+
560561
if gpu_device:
561562
return train_alg_regr(policy, params, *to_table(X, y)).model
562563
return train_alg_srch(policy, params, to_table(X)).model
@@ -581,6 +582,7 @@ def _onedal_predict(self, model, X, params, queue):
581582
model = self._onedal_model
582583
else:
583584
model = self._create_model(backend)
585+
params['fptype'] = 'float' if X.dtype == np.float32 else 'double'
584586
result = backend.infer(policy, params, model, to_table(X))
585587

586588
return result
@@ -708,6 +710,8 @@ def _onedal_predict(self, model, X, params, queue):
708710
model = self._onedal_model
709711
else:
710712
model = self._create_model(_backend.neighbors.search)
713+
714+
params['fptype'] = 'float' if X.dtype == np.float32 else 'double'
711715
result = _backend.neighbors.search.infer(policy, params, model, to_table(X))
712716

713717
return result

0 commit comments

Comments
 (0)