1616
1717# daal4py Scikit-Learn examples for GPU
1818# run like this:
19- # python -m daal4py ./sklearn_sycl.py
19+ # python -m sklearnex ./sklearn_sycl.py
2020
2121import numpy as np
2222
2727
2828from sklearn .datasets import load_iris
2929
30- dpctx_available = False
30+ dpctl_available = False
3131try :
32- from dpctx import device_context , device_type
33- dpctx_available = True
32+ import dpctl
33+ from sklearnex ._config import config_context
34+ dpctl_available = True
3435except ImportError :
3536 try :
3637 from daal4py .oneapi import sycl_context
37- sycl_extention_available = True
38- except :
39- sycl_extention_available = False
38+ print ("*" * 80 )
39+ print ("\n dpctl package not found, switched to daal4py package\n " )
40+ print ("*" * 80 )
41+ except ImportError :
42+ print ("\n Required packages not found, aborting...\n " )
43+ exit ()
4044
41- gpu_available = False
42- if dpctx_available :
43- try :
44- with device_context (device_type .gpu , 0 ):
45- gpu_available = True
46- except :
47- gpu_available = False
4845
49- elif sycl_extention_available :
46+ gpu_available = False
47+ if not dpctl_available :
5048 try :
5149 with sycl_context ('gpu' ):
5250 gpu_available = True
53- except :
51+ except Exception :
5452 gpu_available = False
5553
5654
@@ -136,11 +134,23 @@ def dbscan():
136134
137135
138136def get_context (device ):
139- if dpctx_available :
140- return device_context (device , 0 )
141- if sycl_extention_available :
142- return sycl_context (device )
143- return None
137+ if dpctl_available :
138+ return config_context (target_offload = device )
139+ return sycl_context (device )
140+
141+
142+ def device_type_to_str (queue ):
143+ if queue is None :
144+ return 'host'
145+
146+ from dpctl import device_type
147+ if queue .sycl_device .device_type == device_type .cpu :
148+ return 'cpu'
149+ if queue .sycl_device .device_type == device_type .gpu :
150+ return 'gpu'
151+ if queue .sycl_device .device_type == device_type .host :
152+ return 'host'
153+ return 'unknown'
144154
145155
146156if __name__ == "__main__" :
@@ -154,13 +164,13 @@ def get_context(device):
154164 ]
155165 devices = []
156166
157- if dpctx_available :
158- devices .append (device_type . host )
159- devices .append (device_type . cpu )
160- if gpu_available :
161- devices .append (device_type . gpu )
167+ if dpctl_available :
168+ devices .append (None )
169+ devices .append (dpctl . SyclQueue ( ' cpu' ) )
170+ if dpctl . has_gpu_devices :
171+ devices .append (dpctl . SyclQueue ( ' gpu' ) )
162172
163- elif sycl_extention_available :
173+ else :
164174 devices .append ('host' )
165175 devices .append ('cpu' )
166176 if gpu_available :
@@ -169,7 +179,10 @@ def get_context(device):
169179 for device in devices :
170180 for e in examples :
171181 print ("*" * 80 )
172- print ("device context:" , device )
182+ if (dpctl_available ):
183+ print ("device context:" , device_type_to_str (device ))
184+ else :
185+ print ("device context:" , device )
173186 with get_context (device ):
174187 e ()
175188 print ("*" * 80 )
0 commit comments