1616
1717import inspect
1818import logging
19- from collections .abc import Iterable
2019from functools import wraps
2120
2221import numpy as np
2524from onedal import _default_backend as backend
2625
2726from ._config import _get_config
27+ from .datatypes import copy_to_dpnp , copy_to_usm , usm_to_numpy
2828from .utils import _sycl_queue_manager as QM
2929from .utils ._array_api import _asarray , _is_numpy_namespace
30- from .utils ._dpep_helpers import dpctl_available , dpnp_available
31-
32- if dpctl_available :
33- from dpctl .memory import MemoryUSMDevice , as_usm_memory
34- from dpctl .tensor import usm_ndarray
35-
30+ from .utils ._third_party import is_dpnp_ndarray
3631
3732logger = logging .getLogger ("sklearnex" )
3833cpu_dlpack_device = (backend .kDLCPU , 0 )
@@ -69,61 +64,13 @@ def wrapper(self, *args, **kwargs):
6964 return wrapper
7065
7166
72- if dpnp_available :
73- import dpnp
74-
75- from .utils ._array_api import _convert_to_dpnp
76-
77-
78- def _copy_to_usm (queue , array ):
79- if not dpctl_available :
80- raise RuntimeError (
81- "dpctl need to be installed to work " "with __sycl_usm_array_interface__"
82- )
83-
84- if hasattr (array , "__array__" ):
85-
86- try :
87- mem = MemoryUSMDevice (array .nbytes , queue = queue )
88- mem .copy_from_host (array .tobytes ())
89- return usm_ndarray (array .shape , array .dtype , buffer = mem )
90- except ValueError as e :
91- # ValueError will raise if device does not support the dtype
92- # retry with float32 (needed for fp16 and fp64 support issues)
93- # try again as float32, if it is a float32 just raise the error.
94- if array .dtype == np .float32 :
95- raise e
96- return _copy_to_usm (queue , array .astype (np .float32 ))
97- else :
98- if isinstance (array , Iterable ):
99- array = [_copy_to_usm (queue , i ) for i in array ]
100- return array
101-
102-
10367def _transfer_to_host (* data ):
10468 has_usm_data , has_host_data = False , False
10569
10670 host_data = []
10771 for item in data :
108- usm_iface = getattr (item , "__sycl_usm_array_interface__" , None )
109- if usm_iface is not None :
110- if not dpctl_available :
111- raise RuntimeError (
112- "dpctl need to be installed to work "
113- "with __sycl_usm_array_interface__"
114- )
115-
116- buffer = as_usm_memory (item ).copy_to_host ()
117- order = "C"
118- if usm_iface ["strides" ] is not None and len (usm_iface ["strides" ]) > 1 :
119- if usm_iface ["strides" ][0 ] < usm_iface ["strides" ][1 ]:
120- order = "F"
121- item = np .ndarray (
122- shape = usm_iface ["shape" ],
123- dtype = usm_iface ["typestr" ],
124- buffer = buffer ,
125- order = order ,
126- )
72+ if usm_iface := getattr (item , "__sycl_usm_array_interface__" , None ):
73+ item = usm_to_numpy (item , usm_iface )
12774 has_usm_data = True
12875 elif not isinstance (item , np .ndarray ) and (
12976 device := getattr (item , "__dlpack_device__" , None )
@@ -215,15 +162,16 @@ def wrapper_impl(*args, **kwargs):
215162 )
216163 if _get_config ()["use_raw_input" ] is True and not override_raw_input :
217164 if "queue" not in kwargs :
218- usm_iface = getattr (args [0 ], "__sycl_usm_array_interface__" , None )
219- data_queue = usm_iface ["syclobj" ] if usm_iface is not None else None
220- kwargs ["queue" ] = data_queue
165+ if usm_iface := getattr (args [0 ], "__sycl_usm_array_interface__" , None ):
166+ kwargs ["queue" ] = usm_iface ["syclobj" ]
167+ else :
168+ kwargs ["queue" ] = None
221169 return invoke_func (self , * args , ** kwargs )
222170 elif len (args ) == 0 and len (kwargs ) == 0 :
223171 # no arguments, there's nothing we can deduce from them -> just call the function
224172 return invoke_func (self , * args , ** kwargs )
225173
226- data = (* args , * kwargs .values ())
174+ data = (* args , * kwargs .values ())[ 0 ]
227175 # get and set the global queue from the kwarg or data
228176 with QM .manage_global_queue (kwargs .get ("queue" ), * args ) as queue :
229177 hostargs , hostkwargs = _get_host_inputs (* args , ** kwargs )
@@ -232,17 +180,17 @@ def wrapper_impl(*args, **kwargs):
232180 hostkwargs ["queue" ] = queue
233181 result = invoke_func (self , * hostargs , ** hostkwargs )
234182
235- usm_iface = getattr (data [ 0 ] , "__sycl_usm_array_interface__" , None )
236- if queue is not None and usm_iface is not None :
237- result = _copy_to_usm (queue , result )
238- if dpnp_available and isinstance (data [ 0 ], dpnp . ndarray ):
239- result = _convert_to_dpnp ( result )
240- return result
183+ if queue and hasattr (data , "__sycl_usm_array_interface__" ):
184+ return (
185+ copy_to_dpnp (queue , result )
186+ if is_dpnp_ndarray (data )
187+ else copy_to_usm ( queue , result )
188+ )
241189
242190 if get_config ().get ("transform_output" ) in ("default" , None ):
243- input_array_api = getattr (data [ 0 ] , "__array_namespace__" , lambda : None )()
191+ input_array_api = getattr (data , "__array_namespace__" , lambda : None )()
244192 if input_array_api and not _is_numpy_namespace (input_array_api ):
245- input_array_api_device = data [ 0 ] .device
193+ input_array_api_device = data .device
246194 result = _asarray (result , input_array_api , device = input_array_api_device )
247195 return result
248196
0 commit comments