
Commit 43596ee

Authored by samir-nasibli, icfaust, david-cortes-intel, and ahuber21
ENH: Array API dispatching (#2096)
* ENH: array api dispatching; added array-api-compat to test env
* Deselect some scikit-learn Array API tests
* deselect more tests
* deselect more tests
* disabled tests for
* fix the deselection comment
* disabled test for Ridge regression
* Disabled tests and added comment
* ENH: Array API dispatching
* Revert adding dpctl into Array API conformance testing; added versioning for get_namespace
* minor refactoring of onedal _array_api
* add tests
* addressed memory usage tests
* Address some array api test fails
* linting
* addressed test_get_namespace
* adding test case for validate_data check with Array API inputs
* minor refactoring
* addressed test_patch_map_match fail
* Added docstrings for get_namespace
* docstrings for Array API tests
* updated minimal scikit-learn version for Array API dispatching
* updated minimal scikit-learn version for Array API dispatching in _device_offload.py and _array_api.py
* fix test test_get_namespace_with_config_context
* refactor onedal/datatypes/_data_conversion.py
* correction for array api
* Update conftest.py
* introduce tags
* fix imports
* see if this works
* really lazy logic introduction
* introduce IntelEstimator
* missing change in knn
* reconfigure logic
* strip out dpnp/dpctl special code, will come back to it later
* switchover
* Update __init__.py
* Update test_array_api.py
* merge main into PR
* try to fix changes:
* add test:
* attempt to get things running
* remove tsne failure
* try again
* add sklearn_check_version
* oops
* update
* another try
* try again
* Update test_common.py (x5)
* fixes
* attempt to re-introduce changes
* Update _device_offload.py
* switch get_tags
* merge main into branch
* remove leftover code
* Update base.py
* formatting
* correct bad sklearn recommendation
* remove irrelevant new tests
* fix mistake
* lint fix
* add fallback test
* make test fixes
* formatting
* rejigger the logic again
* Update test_config.py
* interim fix
* Update test_config.py
* attempt to get test_config test to skip
* fix test
* interim fixes
* fixes for IncrementalEmpiricalCovariance
* fixes for wrap_output_data
* solve issues related to older numpy
* Update test_run_to_run_stability.py
* Update __init__.py (x2)
* Update test_run_to_run_stability.py
* Update base.py
* Update test_run_to_run_stability.py
* Update test_config.py
* Update test_patching.py (x3)
* Update k_means.py (x3)
* Update _array_api.py (x2)
* Update _device_offload.py (x4)
* Update test_patching.py (x3)
* local verify doclinks
* fix doclink
* fixes
* remove print statement
* fix changes from scikit-learn/scikit-learn#29774
* Update sklearnex/_device_offload.py (Co-authored-by: david-cortes-intel <[email protected]>)
* add type hints and fix docs for dispatch
* remove change from local testing
* be more explicit with a type
* python3.9 fix
* change language which was bothering me
* formatting
* This should fix it (I think)
* Apply suggestions from code review (Co-authored-by: david-cortes-intel <[email protected]>)
* Update base.py
* Update __init__.py
* remove hack on config
* Apply suggestions from code review (Co-authored-by: david-cortes-intel <[email protected]>)
* make suggested changes
* fix import in test_patching
* Update test_patching.py
* Update build-and-test-lnx.yml (x3)
* Apply suggestions from code review (Co-authored-by: david-cortes-intel <[email protected]>)
* Apply suggestions from code review (Co-authored-by: Andreas Huber <[email protected]>, david-cortes-intel <[email protected]>)
* formatting and ABC change
* fix gpu test
* remove workaround

---------

Co-authored-by: Faust, Ian <[email protected]>
Co-authored-by: Ian Faust <[email protected]>
Co-authored-by: david-cortes-intel <[email protected]>
Co-authored-by: Andreas Huber <[email protected]>
1 parent ae4febb commit 43596ee

34 files changed (+585 / -293 lines)

deselected_tests.yaml

Lines changed: 8 additions & 0 deletions
@@ -309,6 +309,14 @@ deselected_tests:
   # Failure occurs in python3.9 on windows CPU only - not easy to reproduce
   - ensemble/tests/test_weight_boosting.py::test_estimator >= 1.4 win32

+  # array_api dispatching for oneDAL is recognized using the sklearn tag system. This is an expected
+  # break in sklearn conformance, as additional non-default tags cause errors in this test. This
+  # non-conformance does not impact the verification system of sklearn. A new system (sklearn >=1.6)
+  # allows for public tag addition. Sklearn also acknowledges the abuse of the old system by others,
+  # meaning it's not too impactful to use it so long as the default keys and value types set by sklearn
+  # are respected.
+  - tests/test_common.py::test_valid_tag_types <1.6

   # --------------------------------------------------------
   # No need to test daal4py patching
 reduced_tests:
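For context, the two tag mechanisms the deselection comment above refers to can be sketched as follows; the estimator below is hypothetical and only illustrates the pattern, it is not part of this commit:

    from sklearn.base import BaseEstimator

    class ArrayAPICapableEstimator(BaseEstimator):
        # sklearn < 1.6: tags are a private dict; adding a non-default key such
        # as "onedal_array_api" works in practice but trips the type check in
        # tests/test_common.py::test_valid_tag_types, hence the deselection.
        def _more_tags(self):
            return {"onedal_array_api": True}

        # sklearn >= 1.6: tags are public objects returned by __sklearn_tags__(),
        # so capabilities can be declared without breaking conformance tests.
        def __sklearn_tags__(self):
            tags = super().__sklearn_tags__()
            tags.array_api_support = True
            return tags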

onedal/_device_offload.py

Lines changed: 2 additions & 2 deletions
@@ -198,9 +198,9 @@ def wrapper_impl(*args, **kwargs):
                 result = _convert_to_dpnp(result)
             return result

-        if not get_config().get("transform_output"):
+        if get_config().get("transform_output") in ("default", None):
             input_array_api = getattr(data[0], "__array_namespace__", lambda: None)()
-            if input_array_api:
+            if input_array_api and not _is_numpy_namespace(input_array_api):
                 input_array_api_device = data[0].device
                 result = _asarray(result, input_array_api, device=input_array_api_device)
         return result
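Sketched below is the intended effect of this change, assuming an array-API-conformant input (array_api_strict stands in for the caller's namespace; the variable names here are illustrative, not from the module):

    import numpy as np
    import array_api_strict as xp  # assumption: any array-API namespace works

    data = (xp.asarray([1.0, 2.0]),)  # original input to the wrapped function
    result = np.asarray([3.0, 4.0])   # host (numpy) result of the computation

    # Convert back only when transform_output is unset or "default" and the
    # input namespace is a genuine array-API namespace rather than numpy.
    ns = getattr(data[0], "__array_namespace__", lambda: None)()
    if ns is not None and "numpy" not in ns.__name__:
        result = ns.asarray(result, device=data[0].device)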

onedal/utils/_array_api.py

Lines changed: 11 additions & 1 deletion
@@ -40,9 +40,19 @@ def _convert_to_dpnp(array):
     return array


+def _supports_buffer_protocol(obj):
+    # the array_api standard mandates conversion with the buffer protocol,
+    # which can only be checked via a try-catch in native python
+    try:
+        memoryview(obj)
+    except TypeError:
+        return False
+    return True
+
+
 def _asarray(data, xp, *args, **kwargs):
     """Converted input object to array format of xp namespace provided."""
-    if hasattr(data, "__array_namespace__"):
+    if hasattr(data, "__array_namespace__") or _supports_buffer_protocol(data):
         return xp.asarray(data, *args, **kwargs)
     elif isinstance(data, Iterable):
         if isinstance(data, tuple):
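A quick standalone illustration of the helper's behavior (the same try/except logic, repeated here purely for demonstration):

    import numpy as np

    def _supports_buffer_protocol(obj):
        # memoryview() succeeds only for objects exporting the buffer protocol
        try:
            memoryview(obj)
        except TypeError:
            return False
        return True

    assert _supports_buffer_protocol(b"raw bytes")   # bytes export buffers
    assert _supports_buffer_protocol(np.zeros(3))    # numpy arrays do as well
    assert not _supports_buffer_protocol([1, 2, 3])  # plain lists do not, so
                                                     # they still take the
                                                     # Iterable branch of _asarray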

onedal/utils/_sycl_queue_manager.py

Lines changed: 16 additions & 2 deletions
@@ -26,6 +26,11 @@

 SyclQueue = getattr(_dpc_backend, "SyclQueue", None)

+# This special object signifies that the queue system should be
+# disabled. It will force computation to host. This occurs when the
+# global queue is set to this value (and therefore should not be
+# modified).
+__fallback_queue = object()
 # single instance of global queue
 __global_queue = None

@@ -46,8 +51,11 @@ def __create_sycl_queue(target):
 def get_global_queue():
     """Get the global queue. Retrieve it from the config if not set."""
     if (queue := __global_queue) is not None:
-        if SyclQueue and not isinstance(queue, SyclQueue):
-            raise ValueError("Global queue is not a SyclQueue object.")
+        if SyclQueue:
+            if queue is __fallback_queue:
+                return None
+            elif not isinstance(queue, SyclQueue):
+                raise ValueError("Global queue is not a SyclQueue object.")
         return queue

     target = _get_config()["target_offload"]

@@ -73,6 +81,12 @@ def update_global_queue(queue):
     __global_queue = queue


+def fallback_to_host():
+    """Enforce a host queue."""
+    global __global_queue
+    __global_queue = __fallback_queue
+
+
 def from_data(*data):
     """Extract the queue from provided data. This updates the global queue as well."""
     for item in data:
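The sentinel pattern introduced above, shown in isolation (a minimal sketch with shortened names; the real module uses __fallback_queue and __global_queue):

    _FALLBACK = object()  # module-private sentinel, compared only by identity
    _queue = None

    def fallback_to_host():
        global _queue
        _queue = _FALLBACK

    def get_queue():
        # The sentinel is distinguishable from both None ("nothing set yet,
        # derive a queue from config") and a real SyclQueue, so a forced host
        # run can be reported as None without erasing the fact that it was forced.
        if _queue is _FALLBACK:
            return None
        return _queue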

sklearnex/_device_offload.py

Lines changed: 140 additions & 89 deletions
@@ -14,126 +14,177 @@
 # limitations under the License.
 # ==============================================================================

+from collections.abc import Callable
 from functools import wraps
+from typing import Any, Union

 from onedal._device_offload import _copy_to_usm, _transfer_to_host
 from onedal.utils import _sycl_queue_manager as QM
-from onedal.utils._array_api import _asarray
+from onedal.utils._array_api import _asarray, _is_numpy_namespace
 from onedal.utils._dpep_helpers import dpnp_available

 if dpnp_available:
     import dpnp
     from onedal.utils._array_api import _convert_to_dpnp

-from ._config import get_config
-
-
-def _get_backend(obj, queue, method_name, *data):
-    with QM.manage_global_queue(queue, *data) as queue:
-        cpu_device = queue is None or getattr(queue.sycl_device, "is_cpu", True)
-        gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
-
-        if cpu_device:
-            patching_status = obj._onedal_cpu_supported(method_name, *data)
-            if patching_status.get_status():
-                return "onedal", patching_status
-            else:
-                return "sklearn", patching_status
+from ._config import config_context, get_config, set_config
+from ._utils import PatchingConditionsChain, get_tags
+from .base import oneDALEstimator
+
+
+def _get_backend(
+    obj: type[oneDALEstimator], method_name: str, *data
+) -> tuple[Union[bool, None], PatchingConditionsChain]:
+    """This function verifies the hardware conditions, data characteristics, and
+    estimator parameters necessary for offloading computation to oneDAL. The status
+    of this patching is returned as a PatchingConditionsChain object along with a
+    boolean flag signaling whether the computation can be offloaded to oneDAL or not.
+    It is assumed that the queue (which determines what hardware to possibly use for
+    oneDAL) has been previously and extensively collected (i.e. the data has already
+    been checked using onedal's SyclQueueManager for queues)."""
+    queue = QM.get_global_queue()
+    cpu_device = queue is None or getattr(queue.sycl_device, "is_cpu", True)
+    gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False)
+
+    if cpu_device:
+        patching_status = obj._onedal_cpu_supported(method_name, *data)
+        return patching_status.get_status(), patching_status
+
+    if gpu_device:
+        patching_status = obj._onedal_gpu_supported(method_name, *data)
+        if (
+            not patching_status.get_status()
+            and (config := get_config())["allow_fallback_to_host"]
+        ):
+            QM.fallback_to_host()
+            return None, patching_status
+        return patching_status.get_status(), patching_status
+
+    raise RuntimeError("Device support is not implemented for the supplied data type.")
+
+
+if "array_api_dispatch" in get_config():
+    _array_api_offload = lambda: get_config()["array_api_dispatch"]
+else:
+    _array_api_offload = lambda: False
+
+
+def dispatch(
+    obj: type[oneDALEstimator],
+    method_name: str,
+    branches: dict[Callable, Callable],
+    *args,
+    **kwargs,
+) -> Any:
+    """Dispatch object method call to oneDAL if conditionally possible.
+    Depending on support conditions, oneDAL will be called; otherwise it will
+    fall back to calling scikit-learn. Dispatching to oneDAL can be influenced
+    by the 'use_raw_input' or 'allow_fallback_to_host' config parameters.
+
+    Parameters
+    ----------
+    obj : object
+        sklearnex object which inherits from oneDALEstimator and contains
+        ``onedal_cpu_supported`` and ``onedal_gpu_supported`` methods which
+        evaluate oneDAL support.
+
+    method_name : string
+        name of the method to be evaluated for oneDAL support
+
+    branches : dict
+        dictionary containing the functions to be called. Only the keys
+        'sklearn' and 'onedal' are used, which should contain the relevant
+        scikit-learn and onedal object methods respectively. All functions
+        should accept the inputs from *args and **kwargs. Additionally, the
+        onedal object method must accept a 'queue' keyword.
+
+    *args : tuple
+        arguments to be supplied to the dispatched method
+
+    **kwargs : dict
+        keyword arguments to be supplied to the dispatched method
+
+    Returns
+    -------
+    unknown : object
+        Returned object dependent on the supplied branches. Implicitly the
+        returned object types should match for the sklearn and onedal object
+        methods.
+    """

-        allow_fallback_to_host = get_config()["allow_fallback_to_host"]
+    if get_config()["use_raw_input"]:
+        return branches["onedal"](obj, *args, **kwargs)

-        if gpu_device:
-            patching_status = obj._onedal_gpu_supported(method_name, *data)
-            if patching_status.get_status():
-                return "onedal", patching_status
-            else:
-                QM.remove_global_queue()
-                if allow_fallback_to_host:
-                    patching_status = obj._onedal_cpu_supported(method_name, *data)
-                    if patching_status.get_status():
-                        return "onedal", patching_status
-                    else:
-                        return "sklearn", patching_status
-                else:
-                    return "sklearn", patching_status
-
-    raise RuntimeError("Device support is not implemented")
-
-
-def get_array_api_support_tag(estimator):
-    """Gets the value of the 'array_api_support' tag from the estimator
-    using the correct code path depending on the scikit-learn version."""
-    if hasattr(estimator, "__sklearn_tags__"):
-        return estimator.__sklearn_tags__().array_api_support
-    elif hasattr(estimator, "_get_tags"):
-        tags = estimator._get_tags()
-        if "array_api_support" in tags:
-            return tags["array_api_support"]
-    return False
-
-
-def dispatch(obj, method_name, branches, *args, **kwargs):
-    if get_config()["use_raw_input"] is False:
-        with QM.manage_global_queue(None, *args) as queue:
-            has_usm_data_for_args, hostargs = _transfer_to_host(*args)
-            has_usm_data_for_kwargs, hostvalues = _transfer_to_host(*kwargs.values())
-            hostkwargs = dict(zip(kwargs.keys(), hostvalues))
-
-            backend, patching_status = _get_backend(obj, queue, method_name, *hostargs)
-            has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
-            if backend == "onedal":
-                # Host args only used before onedal backend call.
-                # Device will be offloaded when onedal backend will be called.
+    # Determine if array_api dispatching is enabled and if the estimator is capable
+    onedal_array_api = _array_api_offload() and get_tags(obj).onedal_array_api
+    sklearn_array_api = _array_api_offload() and get_tags(obj).array_api_support
+
+    # backend can only be a boolean or None; None signifies an unverified backend
+    backend: "bool | None" = None
+
+    # The config context needs to be saved, as the sycl_queue_manager interacts with
+    # target_offload, which can regenerate a GPU queue later on. Therefore, if a
+    # fallback occurs, the state of target_offload must be set to default so that
+    # later use of get_global_queue only sends to host. We must modify the target
+    # offload settings, but we must also restore the original value at the end,
+    # hence the need for a contextmanager.
+    with QM.manage_global_queue(None, *args):
+        if onedal_array_api:
+            backend, patching_status = _get_backend(obj, method_name, *args)
+            if backend:
+                queue = QM.get_global_queue()
                 patching_status.write_log(queue=queue, transferred_to_host=False)
-                return branches[backend](obj, *hostargs, **hostkwargs, queue=queue)
-            if backend == "sklearn":
-                if (
-                    "array_api_dispatch" in get_config()
-                    and get_config()["array_api_dispatch"]
-                    and get_array_api_support_tag(obj)
-                    and not has_usm_data
-                ):
-                    # USM ndarrays are also excluded for the fallback Array API. Currently, DPNP.ndarray is
-                    # not compliant with the Array API standard, and DPCTL usm_ndarray Array API is compliant,
-                    # except for the linalg module. There is no guarantee that stock scikit-learn will
-                    # work with such input data. The condition will be updated after DPNP.ndarray and
-                    # DPCTL usm_ndarray enabling for conformance testing and these arrays supportance
-                    # of the fallback cases.
-                    # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn,
-                    # then raw inputs are used for the fallback.
-                    patching_status.write_log(transferred_to_host=False)
-                    return branches[backend](obj, *args, **kwargs)
-                else:
-                    patching_status.write_log()
-                    return branches[backend](obj, *hostargs, **hostkwargs)
-            raise RuntimeError(
-                f"Undefined backend {backend} in " f"{obj.__class__.__name__}.{method_name}"
-            )
-    else:
-        return branches["onedal"](obj, *args, **kwargs)
+                return branches["onedal"](obj, *args, **kwargs, queue=queue)
+            elif sklearn_array_api and backend is False:
+                patching_status.write_log(transferred_to_host=False)
+                return branches["sklearn"](obj, *args, **kwargs)
+
+        # Move the data to host for any of several reasons: an array_api fallback
+        # to host, oneDAL code that does not support array_api, or issues with
+        # usm support in sklearn.
+        has_usm_data_for_args, hostargs = _transfer_to_host(*args)
+        has_usm_data_for_kwargs, hostvalues = _transfer_to_host(*kwargs.values())
+
+        hostkwargs = dict(zip(kwargs.keys(), hostvalues))
+        has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs
+
+        while backend is None:
+            backend, patching_status = _get_backend(obj, method_name, *hostargs)
+
+        if backend:
+            queue = QM.get_global_queue()
+            patching_status.write_log(queue=queue, transferred_to_host=False)
+            return branches["onedal"](obj, *hostargs, **hostkwargs, queue=queue)
+        else:
+            if sklearn_array_api and not has_usm_data:
+                # dpnp fallback is not handled properly yet.
+                patching_status.write_log(transferred_to_host=False)
+                return branches["sklearn"](obj, *args, **kwargs)
+            else:
+                patching_status.write_log()
+                return branches["sklearn"](obj, *hostargs, **hostkwargs)


-def wrap_output_data(func):
+def wrap_output_data(func: Callable) -> Callable:
     """
     Converts and moves the output arrays of the decorated function
     to match the input array type and device.
     """

     @wraps(func)
-    def wrapper(self, *args, **kwargs):
+    def wrapper(self, *args, **kwargs) -> Any:
         result = func(self, *args, **kwargs)
         if not (len(args) == 0 and len(kwargs) == 0):
             data = (*args, *kwargs.values())
+
             usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
             if usm_iface is not None:
                 result = _copy_to_usm(usm_iface["syclobj"], result)
                 if dpnp_available and isinstance(data[0], dpnp.ndarray):
                     result = _convert_to_dpnp(result)
                 return result
-            config = get_config()
-            if not ("transform_output" in config and config["transform_output"]):
+
+            if get_config().get("transform_output") in ("default", None):
                 input_array_api = getattr(data[0], "__array_namespace__", lambda: None)()
-                if input_array_api:
+                if input_array_api and not _is_numpy_namespace(input_array_api):
                     input_array_api_device = data[0].device
                     result = _asarray(
                         result, input_array_api, device=input_array_api_device