|
14 | 14 | # limitations under the License. |
15 | 15 | # ============================================================================== |
16 | 16 |
|
| 17 | +from collections.abc import Callable |
17 | 18 | from functools import wraps |
| 19 | +from typing import Any, Union |
18 | 20 |
|
19 | 21 | from onedal._device_offload import _copy_to_usm, _transfer_to_host |
20 | 22 | from onedal.utils import _sycl_queue_manager as QM |
21 | | -from onedal.utils._array_api import _asarray |
| 23 | +from onedal.utils._array_api import _asarray, _is_numpy_namespace |
22 | 24 | from onedal.utils._dpep_helpers import dpnp_available |
23 | 25 |
|
24 | 26 | if dpnp_available: |
25 | 27 | import dpnp |
26 | 28 | from onedal.utils._array_api import _convert_to_dpnp |
27 | 29 |
|
28 | | -from ._config import get_config |
29 | | - |
30 | | - |
31 | | -def _get_backend(obj, queue, method_name, *data): |
32 | | - with QM.manage_global_queue(queue, *data) as queue: |
33 | | - cpu_device = queue is None or getattr(queue.sycl_device, "is_cpu", True) |
34 | | - gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) |
35 | | - |
36 | | - if cpu_device: |
37 | | - patching_status = obj._onedal_cpu_supported(method_name, *data) |
38 | | - if patching_status.get_status(): |
39 | | - return "onedal", patching_status |
40 | | - else: |
41 | | - return "sklearn", patching_status |
| 30 | +from ._config import config_context, get_config, set_config |
| 31 | +from ._utils import PatchingConditionsChain, get_tags |
| 32 | +from .base import oneDALEstimator |
| 33 | + |
| 34 | + |
|  35 | +def _get_backend(
|  36 | +    obj: oneDALEstimator, method_name: str, *data
|  37 | +) -> tuple[Union[bool, None], PatchingConditionsChain]:
|  38 | +    """Verify the hardware conditions, data characteristics, and estimator
|  39 | +    parameters necessary for offloading computation to oneDAL. The status of
|  40 | +    this patching is returned as a PatchingConditionsChain object along with
|  41 | +    a flag signaling whether the computation can be offloaded to oneDAL (None
|  42 | +    signals a fallback to host, after which support must be re-evaluated). It
|  43 | +    is assumed that the queue (which determines what hardware oneDAL may use)
|  44 | +    has already been collected from the data via onedal's SyclQueueManager."""
| 45 | + queue = QM.get_global_queue() |
| 46 | + cpu_device = queue is None or getattr(queue.sycl_device, "is_cpu", True) |
| 47 | + gpu_device = queue is not None and getattr(queue.sycl_device, "is_gpu", False) |
| 48 | + |
| 49 | + if cpu_device: |
| 50 | + patching_status = obj._onedal_cpu_supported(method_name, *data) |
| 51 | + return patching_status.get_status(), patching_status |
| 52 | + |
| 53 | + if gpu_device: |
| 54 | + patching_status = obj._onedal_gpu_supported(method_name, *data) |
| 55 | + if ( |
| 56 | + not patching_status.get_status() |
| 57 | + and (config := get_config())["allow_fallback_to_host"] |
| 58 | + ): |
| 59 | + QM.fallback_to_host() |
| 60 | + return None, patching_status |
| 61 | + return patching_status.get_status(), patching_status |
| 62 | + |
| 63 | + raise RuntimeError("Device support is not implemented for the supplied data type.") |
| 64 | + |
| 65 | + |
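For context, the tri-state contract established above can be sketched as follows. This is an illustrative sketch only, not part of the patch: `est` and `X` are hypothetical stand-ins for a sklearnex estimator and its input data, and the call is assumed to happen while a global queue is being managed (as `dispatch` does below).

```python
# Sketch: interpreting the tri-state result of _get_backend.
# `est` and `X` are hypothetical; the real call site is dispatch() below.
backend, patching_status = _get_backend(est, "fit", X)
if backend:
    # True: oneDAL supports this call on the active device
    branch = "onedal"
elif backend is None:
    # None: GPU support was rejected and allow_fallback_to_host is set;
    # the queue manager fell back to host, so support is re-evaluated.
    backend, patching_status = _get_backend(est, "fit", X)
else:
    # False: fall back to stock scikit-learn
    branch = "sklearn"
```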
| 66 | +if "array_api_dispatch" in get_config(): |
| 67 | + _array_api_offload = lambda: get_config()["array_api_dispatch"] |
| 68 | +else: |
| 69 | + _array_api_offload = lambda: False |
| 70 | + |
| 71 | + |
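The module-level `_array_api_offload` helper only reports True when the installed scikit-learn exposes the `array_api_dispatch` config option. A minimal usage sketch, assuming sklearnex is installed on top of a scikit-learn version that supports array API dispatch:

```python
# Sketch: enabling array API dispatch so that dispatch() may keep inputs in
# their native array namespace instead of transferring them to host NumPy.
from sklearnex import config_context

with config_context(array_api_dispatch=True):
    # fit/predict calls on patched estimators go here
    pass
```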
| 72 | +def dispatch( |
|  73 | +    obj: oneDALEstimator,
| 74 | + method_name: str, |
|  75 | +    branches: dict[str, Callable],
| 76 | + *args, |
| 77 | + **kwargs, |
| 78 | +) -> Any: |
| 79 | + """Dispatch object method call to oneDAL if conditionally possible. |
| 80 | + Depending on support conditions, oneDAL will be called, otherwise it will |
| 81 | + fall back to calling scikit-learn. Dispatching to oneDAL can be influenced |
| 82 | + by the 'use_raw_input' or 'allow_fallback_to_host' config parameters. |
| 83 | +
|
| 84 | + Parameters |
| 85 | + ---------- |
|  86 | +    obj : object
|  87 | +        sklearnex object which inherits from oneDALEstimator and provides the
|  88 | +        ``_onedal_cpu_supported`` and ``_onedal_gpu_supported`` methods, which
|  89 | +        evaluate oneDAL support.
| 90 | +
|
|  91 | +    method_name : str
|  92 | +        Name of the method to be evaluated for oneDAL support.
| 93 | +
|
|  94 | +    branches : dict
|  95 | +        Dictionary containing the functions to be called. Only the keys
|  96 | +        'sklearn' and 'onedal' are used; they should contain the relevant
|  97 | +        scikit-learn and onedal object methods, respectively. All functions
|  98 | +        must accept the inputs from *args and **kwargs. Additionally, the
|  99 | +        onedal object method must accept a 'queue' keyword argument.
| 100 | +
|
| 101 | + *args : tuple |
|  102 | +        Arguments to be supplied to the dispatched method.
| 103 | +
|
| 104 | + **kwargs : dict |
|  105 | +        Keyword arguments to be supplied to the dispatched method.
| 106 | +
|
| 107 | + Returns |
| 108 | + ------- |
|  109 | +    unknown : object
|  110 | +        Returned object depends on the supplied branches. The sklearn and
|  111 | +        onedal object methods are implicitly expected to return matching types.
| 112 | + """ |
42 | 113 |
|
43 | | - allow_fallback_to_host = get_config()["allow_fallback_to_host"] |
| 114 | + if get_config()["use_raw_input"]: |
| 115 | + return branches["onedal"](obj, *args, **kwargs) |
44 | 116 |
|
45 | | - if gpu_device: |
46 | | - patching_status = obj._onedal_gpu_supported(method_name, *data) |
47 | | - if patching_status.get_status(): |
48 | | - return "onedal", patching_status |
49 | | - else: |
50 | | - QM.remove_global_queue() |
51 | | - if allow_fallback_to_host: |
52 | | - patching_status = obj._onedal_cpu_supported(method_name, *data) |
53 | | - if patching_status.get_status(): |
54 | | - return "onedal", patching_status |
55 | | - else: |
56 | | - return "sklearn", patching_status |
57 | | - else: |
58 | | - return "sklearn", patching_status |
59 | | - |
60 | | - raise RuntimeError("Device support is not implemented") |
61 | | - |
62 | | - |
63 | | -def get_array_api_support_tag(estimator): |
64 | | - """Gets the value of the 'array_api_support' tag from the estimator |
65 | | - using correct code path depending on the scikit-learn version.""" |
66 | | - if hasattr(estimator, "__sklearn_tags__"): |
67 | | - return estimator.__sklearn_tags__().array_api_support |
68 | | - elif hasattr(estimator, "_get_tags"): |
69 | | - tags = estimator._get_tags() |
70 | | - if "array_api_support" in tags: |
71 | | - return tags["array_api_support"] |
72 | | - return False |
73 | | - |
74 | | - |
75 | | -def dispatch(obj, method_name, branches, *args, **kwargs): |
76 | | - if get_config()["use_raw_input"] is False: |
77 | | - with QM.manage_global_queue(None, *args) as queue: |
78 | | - has_usm_data_for_args, hostargs = _transfer_to_host(*args) |
79 | | - has_usm_data_for_kwargs, hostvalues = _transfer_to_host(*kwargs.values()) |
80 | | - hostkwargs = dict(zip(kwargs.keys(), hostvalues)) |
81 | | - |
82 | | - backend, patching_status = _get_backend(obj, queue, method_name, *hostargs) |
83 | | - has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs |
84 | | - if backend == "onedal": |
85 | | - # Host args only used before onedal backend call. |
86 | | - # Device will be offloaded when onedal backend will be called. |
|  117 | +    # Determine if array_api dispatching is enabled and if the estimator is capable
| 118 | + onedal_array_api = _array_api_offload() and get_tags(obj).onedal_array_api |
| 119 | + sklearn_array_api = _array_api_offload() and get_tags(obj).array_api_support |
| 120 | + |
|  121 | +    # backend is a boolean or None; None signifies that support is not yet verified
| 122 | + backend: "bool | None" = None |
| 123 | + |
|  124 | +    # The config context needs to be saved, as the sycl_queue_manager interacts
|  125 | +    # with target_offload, which can regenerate a GPU queue later on. Therefore,
|  126 | +    # if a fallback occurs, the state of target_offload must be set to default
|  127 | +    # so that later use of get_global_queue only sends to host. We must modify
|  128 | +    # the target offload settings but also restore the original value at the
|  129 | +    # end, hence the need for a context manager.
| 130 | + with QM.manage_global_queue(None, *args): |
| 131 | + if onedal_array_api: |
| 132 | + backend, patching_status = _get_backend(obj, method_name, *args) |
| 133 | + if backend: |
| 134 | + queue = QM.get_global_queue() |
87 | 135 | patching_status.write_log(queue=queue, transferred_to_host=False) |
88 | | - return branches[backend](obj, *hostargs, **hostkwargs, queue=queue) |
89 | | - if backend == "sklearn": |
90 | | - if ( |
91 | | - "array_api_dispatch" in get_config() |
92 | | - and get_config()["array_api_dispatch"] |
93 | | - and get_array_api_support_tag(obj) |
94 | | - and not has_usm_data |
95 | | - ): |
96 | | - # USM ndarrays are also excluded for the fallback Array API. Currently, DPNP.ndarray is |
97 | | - # not compliant with the Array API standard, and DPCTL usm_ndarray Array API is compliant, |
98 | | - # except for the linalg module. There is no guarantee that stock scikit-learn will |
99 | | - # work with such input data. The condition will be updated after DPNP.ndarray and |
100 | | - # DPCTL usm_ndarray enabling for conformance testing and these arrays supportance |
101 | | - # of the fallback cases. |
102 | | - # If `array_api_dispatch` enabled and array api is supported for the stock scikit-learn, |
103 | | - # then raw inputs are used for the fallback. |
104 | | - patching_status.write_log(transferred_to_host=False) |
105 | | - return branches[backend](obj, *args, **kwargs) |
106 | | - else: |
107 | | - patching_status.write_log() |
108 | | - return branches[backend](obj, *hostargs, **hostkwargs) |
109 | | - raise RuntimeError( |
110 | | - f"Undefined backend {backend} in " f"{obj.__class__.__name__}.{method_name}" |
111 | | - ) |
112 | | - else: |
113 | | - return branches["onedal"](obj, *args, **kwargs) |
| 136 | + return branches["onedal"](obj, *args, **kwargs, queue=queue) |
| 137 | + elif sklearn_array_api and backend is False: |
| 138 | + patching_status.write_log(transferred_to_host=False) |
| 139 | + return branches["sklearn"](obj, *args, **kwargs) |
| 140 | + |
|  141 | +        # Move data to host for several reasons: array_api fallback to host,
|  142 | +        # oneDAL code lacking array_api support, and issues with usm support in sklearn.
| 143 | + has_usm_data_for_args, hostargs = _transfer_to_host(*args) |
| 144 | + has_usm_data_for_kwargs, hostvalues = _transfer_to_host(*kwargs.values()) |
| 145 | + |
| 146 | + hostkwargs = dict(zip(kwargs.keys(), hostvalues)) |
| 147 | + has_usm_data = has_usm_data_for_args or has_usm_data_for_kwargs |
| 148 | + |
| 149 | + while backend is None: |
| 150 | + backend, patching_status = _get_backend(obj, method_name, *hostargs) |
| 151 | + |
| 152 | + if backend: |
| 153 | + queue = QM.get_global_queue() |
| 154 | + patching_status.write_log(queue=queue, transferred_to_host=False) |
| 155 | + return branches["onedal"](obj, *hostargs, **hostkwargs, queue=queue) |
| 156 | + else: |
| 157 | + if sklearn_array_api and not has_usm_data: |
| 158 | + # dpnp fallback is not handled properly yet. |
| 159 | + patching_status.write_log(transferred_to_host=False) |
| 160 | + return branches["sklearn"](obj, *args, **kwargs) |
| 161 | + else: |
| 162 | + patching_status.write_log() |
| 163 | + return branches["sklearn"](obj, *hostargs, **hostkwargs) |
114 | 164 |
|
115 | 165 |
|
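Taken together, a patched estimator method typically reduces to a single `dispatch` call. A hedged sketch follows, where `MyEstimator` and `_onedal_fit` are hypothetical stand-ins (the real wiring lives in the sklearnex estimator classes):

```python
# Sketch: wiring fit() through dispatch(). MyEstimator and _onedal_fit are
# hypothetical; only the shape of the branches dict matters here.
from sklearn.linear_model import LinearRegression as _sklearn_LinearRegression


class MyEstimator(oneDALEstimator, _sklearn_LinearRegression):
    def fit(self, X, y):
        return dispatch(
            self,
            "fit",
            {
                "onedal": self.__class__._onedal_fit,  # must accept queue=
                "sklearn": _sklearn_LinearRegression.fit,
            },
            X,
            y,
        )
```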
116 | | -def wrap_output_data(func): |
| 166 | +def wrap_output_data(func: Callable) -> Callable: |
117 | 167 | """ |
118 | 168 | Converts and moves the output arrays of the decorated function |
119 | 169 | to match the input array type and device. |
120 | 170 | """ |
121 | 171 |
|
122 | 172 | @wraps(func) |
123 | | - def wrapper(self, *args, **kwargs): |
| 173 | + def wrapper(self, *args, **kwargs) -> Any: |
124 | 174 | result = func(self, *args, **kwargs) |
125 | 175 | if not (len(args) == 0 and len(kwargs) == 0): |
126 | 176 | data = (*args, *kwargs.values()) |
| 177 | + |
127 | 178 | usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None) |
128 | 179 | if usm_iface is not None: |
129 | 180 | result = _copy_to_usm(usm_iface["syclobj"], result) |
130 | 181 | if dpnp_available and isinstance(data[0], dpnp.ndarray): |
131 | 182 | result = _convert_to_dpnp(result) |
132 | 183 | return result |
133 | | - config = get_config() |
134 | | - if not ("transform_output" in config and config["transform_output"]): |
| 184 | + |
| 185 | + if get_config().get("transform_output") in ("default", None): |
135 | 186 | input_array_api = getattr(data[0], "__array_namespace__", lambda: None)() |
136 | | - if input_array_api: |
| 187 | + if input_array_api and not _is_numpy_namespace(input_array_api): |
137 | 188 | input_array_api_device = data[0].device |
138 | 189 | result = _asarray( |
139 | 190 | result, input_array_api, device=input_array_api_device |
|
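For reference, a hedged sketch of how `wrap_output_data` is typically applied; `_Example` and its `predict` body are hypothetical, shown only to illustrate the decorator:

```python
# Sketch: wrap_output_data keeps the output's array type and device aligned
# with the first positional input (USM, dpnp, or other array API arrays).
class _Example(oneDALEstimator):
    @wrap_output_data
    def predict(self, X):
        # Work may happen on host NumPy arrays internally; the decorator
        # converts the result back to X's namespace/device on return.
        return X.sum(axis=1)
```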