Skip to content

Commit 1d85db2

Browse files
Update sklearn_sycl example (#813)
* update example * fix codefactor * fix codestyle
1 parent 9ee4400 commit 1d85db2

File tree

1 file changed

+41
-28
lines changed

1 file changed

+41
-28
lines changed

examples/daal4py/sycl/sklearn_sycl.py

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
# daal4py Scikit-Learn examples for GPU
1818
# run like this:
19-
# python -m daal4py ./sklearn_sycl.py
19+
# python -m sklearnex ./sklearn_sycl.py
2020

2121
import numpy as np
2222

@@ -27,30 +27,28 @@
2727

2828
from sklearn.datasets import load_iris
2929

30-
dpctx_available = False
30+
dpctl_available = False
3131
try:
32-
from dpctx import device_context, device_type
33-
dpctx_available = True
32+
import dpctl
33+
from sklearnex._config import config_context
34+
dpctl_available = True
3435
except ImportError:
3536
try:
3637
from daal4py.oneapi import sycl_context
37-
sycl_extention_available = True
38-
except:
39-
sycl_extention_available = False
38+
print("*" * 80)
39+
print("\ndpctl package not found, switched to daal4py package\n")
40+
print("*" * 80)
41+
except ImportError:
42+
print("\nRequired packages not found, aborting...\n")
43+
exit()
4044

41-
gpu_available = False
42-
if dpctx_available:
43-
try:
44-
with device_context(device_type.gpu, 0):
45-
gpu_available = True
46-
except:
47-
gpu_available = False
4845

49-
elif sycl_extention_available:
46+
gpu_available = False
47+
if not dpctl_available:
5048
try:
5149
with sycl_context('gpu'):
5250
gpu_available = True
53-
except:
51+
except Exception:
5452
gpu_available = False
5553

5654

@@ -136,11 +134,23 @@ def dbscan():
136134

137135

138136
def get_context(device):
139-
if dpctx_available:
140-
return device_context(device, 0)
141-
if sycl_extention_available:
142-
return sycl_context(device)
143-
return None
137+
if dpctl_available:
138+
return config_context(target_offload=device)
139+
return sycl_context(device)
140+
141+
142+
def device_type_to_str(queue):
143+
if queue is None:
144+
return 'host'
145+
146+
from dpctl import device_type
147+
if queue.sycl_device.device_type == device_type.cpu:
148+
return 'cpu'
149+
if queue.sycl_device.device_type == device_type.gpu:
150+
return 'gpu'
151+
if queue.sycl_device.device_type == device_type.host:
152+
return 'host'
153+
return 'unknown'
144154

145155

146156
if __name__ == "__main__":
@@ -154,13 +164,13 @@ def get_context(device):
154164
]
155165
devices = []
156166

157-
if dpctx_available:
158-
devices.append(device_type.host)
159-
devices.append(device_type.cpu)
160-
if gpu_available:
161-
devices.append(device_type.gpu)
167+
if dpctl_available:
168+
devices.append(None)
169+
devices.append(dpctl.SyclQueue('cpu'))
170+
if dpctl.has_gpu_devices:
171+
devices.append(dpctl.SyclQueue('gpu'))
162172

163-
elif sycl_extention_available:
173+
else:
164174
devices.append('host')
165175
devices.append('cpu')
166176
if gpu_available:
@@ -169,7 +179,10 @@ def get_context(device):
169179
for device in devices:
170180
for e in examples:
171181
print("*" * 80)
172-
print("device context:", device)
182+
if(dpctl_available):
183+
print("device context:", device_type_to_str(device))
184+
else:
185+
print("device context:", device)
173186
with get_context(device):
174187
e()
175188
print("*" * 80)

0 commit comments

Comments
 (0)