Skip to content

Commit d56958e

Browse files
authored
fix: forest dpctl predict queue misalignment (#2507)
* fix: forest dpctl predict queue misalignment * black formatting * sycl_queue -> device
1 parent 6632523 commit d56958e

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

onedal/ensemble/forest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -522,7 +522,8 @@ def predict(self, X, queue=None):
522522

523523
try:
524524
return xp.take(
525-
xp.asarray(self.classes_), xp.astype(xp.reshape(pred, (-1,)), xp.int64)
525+
xp.asarray(self.classes_, device=pred.sycl_queue),
526+
xp.astype(xp.reshape(pred, (-1,)), xp.int64),
526527
)
527528
except AttributeError:
528529
return np.take(self.classes_, pred.ravel().astype(np.int64, casting="unsafe"))

sklearnex/ensemble/_forest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,8 @@ def _onedal_predict(self, X, queue=None):
831831
res = self._onedal_estimator.predict(X, queue=queue)
832832
try:
833833
return xp.take(
834-
xp.asarray(self.classes_), xp.astype(xp.reshape(res, (-1,)), xp.int64)
834+
xp.asarray(self.classes_, device=res.sycl_queue),
835+
xp.astype(xp.reshape(res, (-1,)), xp.int64),
835836
)
836837
except AttributeError:
837838
return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe"))

0 commit comments

Comments
 (0)