Skip to content

Commit 6632523

Browse files
authored
[maintenance] lazy load dpnp.tensor/dpnp and prepare for array_api lazy importing (#2509)
* starting point * first cut * rename * fix various testing imports * don't get ahead of my skis * attempt to further move things apart * remove get_unique_values_with_dpep * remove actually * Update _array_api.py * try to fix * Update _device_offload.py * Update _device_offload.py * Update _device_offload.py * Update _device_offload.py * Update _sycl_usm.py * Update _third_party.py * Update _device_offload.py * Update _device_offload.py * Update _device_offload.py * Update _device_offload.py * Update _sycl_usm.py * Update _sycl_usm.py * Update _third_party.py * Update _third_party.py * Update _sycl_usm.py * Update _third_party.py * Update _third_party.py * Update _array_api.py * Update _array_api.py * formatting * Update setup.py * Update _third_party.py * Update _third_party.py * Update _third_party.py * Update _data_conversion.py * Update __init__.py * Update __init__.py * Update _data_conversion.py * Update _device_offload.py * Update _third_party.py * add requested comments to code * add requested comments to code * fix codespell hits * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update test_common.py * Update _data_conversion.py * Update _data_conversion.py * Update test_common.py * Update _data_conversion.py * Update _data_conversion.py
1 parent c8fb137 commit 6632523

File tree

20 files changed

+398
-265
lines changed

20 files changed

+398
-265
lines changed

onedal/_device_offload.py

Lines changed: 17 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import inspect
1818
import logging
19-
from collections.abc import Iterable
2019
from functools import wraps
2120

2221
import numpy as np
@@ -25,14 +24,10 @@
2524
from onedal import _default_backend as backend
2625

2726
from ._config import _get_config
27+
from .datatypes import copy_to_dpnp, copy_to_usm, usm_to_numpy
2828
from .utils import _sycl_queue_manager as QM
2929
from .utils._array_api import _asarray, _is_numpy_namespace
30-
from .utils._dpep_helpers import dpctl_available, dpnp_available
31-
32-
if dpctl_available:
33-
from dpctl.memory import MemoryUSMDevice, as_usm_memory
34-
from dpctl.tensor import usm_ndarray
35-
30+
from .utils._third_party import is_dpnp_ndarray
3631

3732
logger = logging.getLogger("sklearnex")
3833
cpu_dlpack_device = (backend.kDLCPU, 0)
@@ -69,61 +64,13 @@ def wrapper(self, *args, **kwargs):
6964
return wrapper
7065

7166

72-
if dpnp_available:
73-
import dpnp
74-
75-
from .utils._array_api import _convert_to_dpnp
76-
77-
78-
def _copy_to_usm(queue, array):
79-
if not dpctl_available:
80-
raise RuntimeError(
81-
"dpctl need to be installed to work " "with __sycl_usm_array_interface__"
82-
)
83-
84-
if hasattr(array, "__array__"):
85-
86-
try:
87-
mem = MemoryUSMDevice(array.nbytes, queue=queue)
88-
mem.copy_from_host(array.tobytes())
89-
return usm_ndarray(array.shape, array.dtype, buffer=mem)
90-
except ValueError as e:
91-
# ValueError will raise if device does not support the dtype
92-
# retry with float32 (needed for fp16 and fp64 support issues)
93-
# try again as float32, if it is a float32 just raise the error.
94-
if array.dtype == np.float32:
95-
raise e
96-
return _copy_to_usm(queue, array.astype(np.float32))
97-
else:
98-
if isinstance(array, Iterable):
99-
array = [_copy_to_usm(queue, i) for i in array]
100-
return array
101-
102-
10367
def _transfer_to_host(*data):
10468
has_usm_data, has_host_data = False, False
10569

10670
host_data = []
10771
for item in data:
108-
usm_iface = getattr(item, "__sycl_usm_array_interface__", None)
109-
if usm_iface is not None:
110-
if not dpctl_available:
111-
raise RuntimeError(
112-
"dpctl need to be installed to work "
113-
"with __sycl_usm_array_interface__"
114-
)
115-
116-
buffer = as_usm_memory(item).copy_to_host()
117-
order = "C"
118-
if usm_iface["strides"] is not None and len(usm_iface["strides"]) > 1:
119-
if usm_iface["strides"][0] < usm_iface["strides"][1]:
120-
order = "F"
121-
item = np.ndarray(
122-
shape=usm_iface["shape"],
123-
dtype=usm_iface["typestr"],
124-
buffer=buffer,
125-
order=order,
126-
)
72+
if usm_iface := getattr(item, "__sycl_usm_array_interface__", None):
73+
item = usm_to_numpy(item, usm_iface)
12774
has_usm_data = True
12875
elif not isinstance(item, np.ndarray) and (
12976
device := getattr(item, "__dlpack_device__", None)
@@ -215,15 +162,16 @@ def wrapper_impl(*args, **kwargs):
215162
)
216163
if _get_config()["use_raw_input"] is True and not override_raw_input:
217164
if "queue" not in kwargs:
218-
usm_iface = getattr(args[0], "__sycl_usm_array_interface__", None)
219-
data_queue = usm_iface["syclobj"] if usm_iface is not None else None
220-
kwargs["queue"] = data_queue
165+
if usm_iface := getattr(args[0], "__sycl_usm_array_interface__", None):
166+
kwargs["queue"] = usm_iface["syclobj"]
167+
else:
168+
kwargs["queue"] = None
221169
return invoke_func(self, *args, **kwargs)
222170
elif len(args) == 0 and len(kwargs) == 0:
223171
# no arguments, there's nothing we can deduce from them -> just call the function
224172
return invoke_func(self, *args, **kwargs)
225173

226-
data = (*args, *kwargs.values())
174+
data = (*args, *kwargs.values())[0]
227175
# get and set the global queue from the kwarg or data
228176
with QM.manage_global_queue(kwargs.get("queue"), *args) as queue:
229177
hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
@@ -232,17 +180,17 @@ def wrapper_impl(*args, **kwargs):
232180
hostkwargs["queue"] = queue
233181
result = invoke_func(self, *hostargs, **hostkwargs)
234182

235-
usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
236-
if queue is not None and usm_iface is not None:
237-
result = _copy_to_usm(queue, result)
238-
if dpnp_available and isinstance(data[0], dpnp.ndarray):
239-
result = _convert_to_dpnp(result)
240-
return result
183+
if queue and hasattr(data, "__sycl_usm_array_interface__"):
184+
return (
185+
copy_to_dpnp(queue, result)
186+
if is_dpnp_ndarray(data)
187+
else copy_to_usm(queue, result)
188+
)
241189

242190
if get_config().get("transform_output") in ("default", None):
243-
input_array_api = getattr(data[0], "__array_namespace__", lambda: None)()
191+
input_array_api = getattr(data, "__array_namespace__", lambda: None)()
244192
if input_array_api and not _is_numpy_namespace(input_array_api):
245-
input_array_api_device = data[0].device
193+
input_array_api_device = data.device
246194
result = _asarray(result, input_array_api, device=input_array_api_device)
247195
return result
248196

onedal/common/tests/test_sycl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from onedal import _default_backend as backend
2020
from onedal.tests.utils._device_selection import get_queues
21-
from onedal.utils._dpep_helpers import dpctl_available
21+
from onedal.utils._third_party import dpctl_available
2222

2323

2424
@pytest.mark.skipif(

onedal/datatypes/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,12 @@
1515
# ==============================================================================
1616

1717
from ._data_conversion import from_table, to_table
18+
from ._sycl_usm import copy_to_dpnp, copy_to_usm, usm_to_numpy
1819

19-
__all__ = ["from_table", "to_table"]
20+
__all__ = [
21+
"copy_to_dpnp",
22+
"copy_to_usm",
23+
"from_table",
24+
"to_table",
25+
"usm_to_numpy",
26+
]

onedal/datatypes/_data_conversion.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17+
from types import ModuleType
18+
1719
import numpy as np
1820

1921
from onedal import _default_backend as backend
@@ -61,29 +63,6 @@ def to_table(*args, queue=None):
6163

6264
if backend.is_dpc:
6365

64-
try:
65-
# try/catch is used here instead of dpep_helpers because
66-
# of circular import issues of _data_conversion.py and
67-
# utils/validation.py. This is a temporary fix until the
68-
# issue with dpnp is addressed, at which point this can
69-
# be removed entirely.
70-
import dpnp
71-
72-
def _table_to_array(table, xp=None):
73-
# By default DPNP ndarray created with a copy.
74-
# TODO:
75-
# investigate why dpnp.array(table, copy=False) doesn't work.
76-
# Work around with using dpctl.tensor.asarray.
77-
if xp == dpnp:
78-
return dpnp.array(dpnp.dpctl.tensor.asarray(table), copy=False)
79-
else:
80-
return xp.asarray(table)
81-
82-
except ImportError:
83-
84-
def _table_to_array(table, xp=None):
85-
return xp.asarray(table)
86-
8766
def convert_one_from_table(table, sycl_queue=None, sua_iface=None, xp=None):
8867
# Currently only `__sycl_usm_array_interface__` protocol used to
8968
# convert into dpnp/dpctl tensors.
@@ -102,7 +81,14 @@ def convert_one_from_table(table, sycl_queue=None, sua_iface=None, xp=None):
10281
backend.from_table(table), usm_type="device", sycl_queue=sycl_queue
10382
)
10483
else:
105-
return _table_to_array(table, xp=xp)
84+
# By default DPNP ndarray created with a copy.
85+
# TODO:
86+
# investigate why dpnp.array(table, copy=False) doesn't work.
87+
# Work around with using dpctl.tensor.asarray.
88+
if isinstance(xp, ModuleType) and xp.__name__ == "dpnp":
89+
return xp.array(xp.dpctl.tensor.asarray(table), copy=False)
90+
else:
91+
return xp.asarray(table)
10692

10793
return backend.from_table(table)
10894

onedal/datatypes/_sycl_usm.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# ==============================================================================
2+
# Copyright Contributors to the oneDAL Project
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
# ==============================================================================
16+
17+
from collections.abc import Iterable
18+
19+
import numpy as np
20+
21+
from ..utils._third_party import lazy_import
22+
23+
24+
@lazy_import("dpctl.memory", "dpctl.tensor")
25+
def array_to_usm(memory, tensor, queue, array):
26+
try:
27+
mem = memory.MemoryUSMDevice(array.nbytes, queue=queue)
28+
mem.copy_from_host(array.tobytes())
29+
return tensor.usm_ndarray(array.shape, array.dtype, buffer=mem)
30+
except ValueError as e:
31+
# ValueError will raise if device does not support the dtype
32+
# retry with float32 (needed for fp16 and fp64 support issues)
33+
# try again as float32, if it is a float32 just raise the error.
34+
if array.dtype == np.float32:
35+
raise e
36+
return _array_to_usm(queue, array.astype(np.float32))
37+
38+
39+
@lazy_import("dpnp", "dpctl.tensor")
40+
def to_dpnp(dpnp, tensor, array):
41+
if isinstance(array, tensor.usm_ndarray):
42+
return dpnp.array(array, copy=False)
43+
else:
44+
return array
45+
46+
47+
def copy_to_usm(queue, array):
48+
if hasattr(array, "__array__"):
49+
return array_to_usm(queue, array)
50+
else:
51+
if isinstance(array, Iterable):
52+
array = [copy_to_usm(queue, i) for i in array]
53+
return array
54+
55+
56+
def copy_to_dpnp(queue, array):
57+
if hasattr(array, "__array__"):
58+
return to_dpnp(array_to_usm(queue, array))
59+
else:
60+
if isinstance(array, Iterable):
61+
array = [copy_to_dpnp(queue, i) for i in array]
62+
return array
63+
64+
65+
@lazy_import("dpctl.memory")
66+
def usm_to_numpy(memorymod, item, usm_iface):
67+
buffer = memorymod.as_usm_memory(item).copy_to_host()
68+
order = "C"
69+
if usm_iface["strides"] is not None and len(usm_iface["strides"]) > 1:
70+
if usm_iface["strides"][0] < usm_iface["strides"][1]:
71+
order = "F"
72+
item = np.ndarray(
73+
shape=usm_iface["shape"],
74+
dtype=usm_iface["typestr"],
75+
buffer=buffer,
76+
order=order,
77+
)
78+
return item

onedal/datatypes/tests/common.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,16 @@
1414
# limitations under the License.
1515
# ===============================================================================
1616

17-
from onedal.utils._dpep_helpers import dpctl_available, dpnp_available
17+
from onedal.utils._third_party import dpctl_available
1818

19-
if dpnp_available:
19+
try:
2020
import dpnp
2121

22+
dpnp_available = True
23+
except ImportError:
24+
dpnp_available = False
25+
26+
2227
if dpctl_available:
2328
import dpctl
2429
from dpctl.tensor import usm_ndarray

onedal/datatypes/tests/test_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from onedal import _default_backend, _dpc_backend
2525
from onedal._device_offload import supports_queue
2626
from onedal.datatypes import from_table, to_table
27-
from onedal.utils._dpep_helpers import dpctl_available
27+
from onedal.utils._third_party import dpctl_available
2828

2929
backend = _dpc_backend or _default_backend
3030

onedal/ensemble/forest.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
from ..common._mixin import ClassifierMixin, RegressorMixin
3535
from ..datatypes import from_table, to_table
3636
from ..utils._array_api import _get_sycl_namespace
37-
from ..utils._dpep_helpers import get_unique_values_with_dpep
3837
from ..utils.validation import (
3938
_check_array,
4039
_check_n_features,
@@ -315,7 +314,13 @@ def _fit(self, X, y, sample_weight):
315314
else:
316315
if sua_iface is not None:
317316
queue = X.sycl_queue
318-
self.classes_ = get_unique_values_with_dpep(y)
317+
# try catch needed for raw_inputs + array_api data where unlike
318+
# numpy the way to yield unique values is via `unique_values`
319+
# This should be removed when refactored for gpu zero-copy
320+
try:
321+
self.classes_ = xp.unique(y)
322+
except AttributeError:
323+
self.classes_ = xp.unique_values(y)
319324

320325
self.n_features_in_ = X.shape[1]
321326

onedal/linear_model/logistic_regression.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from ..common._mixin import ClassifierMixin
3030
from ..datatypes import from_table, to_table
3131
from ..utils._array_api import _get_sycl_namespace
32-
from ..utils._dpep_helpers import get_unique_values_with_dpep
3332
from ..utils.validation import (
3433
_check_array,
3534
_check_n_features,
@@ -96,7 +95,15 @@ def _fit(self, X, y):
9695
self.classes_, y = np.unique(y, return_inverse=True)
9796
y = y.astype(dtype=np.int32)
9897
else:
99-
self.classes_ = get_unique_values_with_dpep(y)
98+
_, xp, _ = _get_sycl_namespace(X)
99+
# try catch needed for raw_inputs + array_api data where unlike
100+
# numpy the way to yield unique values is via `unique_values`
101+
# This should be removed when refactored for gpu zero-copy
102+
try:
103+
self.classes_ = xp.unique(y)
104+
except AttributeError:
105+
self.classes_ = xp.unique_values(y)
106+
100107
n_classes = len(self.classes_)
101108
if n_classes != 2:
102109
raise ValueError("Only binary classification is supported")

onedal/tests/utils/_dataframes_support.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,18 @@
2020

2121
from sklearnex import get_config
2222

23-
from ...utils._dpep_helpers import dpctl_available, dpnp_available
23+
from ...utils._third_party import dpctl_available
2424

2525
if dpctl_available:
2626
import dpctl.tensor as dpt
2727

28-
if dpnp_available:
28+
try:
2929
import dpnp
3030

31+
dpnp_available = True
32+
except ImportError:
33+
dpnp_available = False
34+
3135
try:
3236
# This should be lazy imported in the
3337
# future along with other popular

0 commit comments

Comments
 (0)