Skip to content

Commit c0181f5

Browse files
ENH: functional support for Array API (#1861)
* updates on the `_device_offload` modules of both `onedal4py` and `sklearnex`. * extensions for the `sklearnex._device_offload.wrap_output_data` decorator and the `sklearnex._device_offload.dispatch` function. * `onedal._device_offload._transfer_to_host` extended for Array API inputs. * `onedal._device_offload.support_usm_ndarray` renamed to `onedal._device_offload.support_input_format` and extended to handle Array API inputs (several files changed for this renaming). * onedal4py's testing utility `get_dataframes_and_queues` extended for array_api inputs. * `_array_api` modules added in both onedal4py and sklearnex. * the `onedal.utils._array_api` module includes several Array API utilities and `_get_sycl_namespace`, which is used in `get_namespace` (from `sklearnex.utils._array_api`), which returns the supported Array API namespace.
1 parent a92177f commit c0181f5

File tree

30 files changed

+243
-141
lines changed

30 files changed

+243
-141
lines changed

onedal/_device_offload.py

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
from functools import wraps
2020

2121
import numpy as np
22+
from sklearn import get_config
2223

2324
from ._config import _get_config
25+
from .utils._array_api import _asarray, _is_numpy_namespace
2426

2527
try:
2628
from dpctl import SyclQueue
@@ -34,6 +36,8 @@
3436
try:
3537
import dpnp
3638

39+
from .utils._array_api import _convert_to_dpnp
40+
3741
dpnp_available = True
3842
except ImportError:
3943
dpnp_available = False
@@ -94,6 +98,7 @@ def _transfer_to_host(queue, *data):
9498
host_data = []
9599
for item in data:
96100
usm_iface = getattr(item, "__sycl_usm_array_interface__", None)
101+
array_api = getattr(item, "__array_namespace__", lambda: None)()
97102
if usm_iface is not None:
98103
if not dpctl_available:
99104
raise RuntimeError(
@@ -120,6 +125,11 @@ def _transfer_to_host(queue, *data):
120125
order=order,
121126
)
122127
has_usm_data = True
128+
elif array_api and not _is_numpy_namespace(array_api):
129+
# The `copy` param of `asarray` is not set.
130+
# The object is copied only if needed.
131+
item = np.asarray(item)
132+
has_host_data = True
123133
else:
124134
has_host_data = True
125135

@@ -153,34 +163,17 @@ def _get_host_inputs(*args, **kwargs):
153163
return q, hostargs, hostkwargs
154164

155165

156-
def _extract_usm_iface(*args, **kwargs):
157-
allargs = (*args, *kwargs.values())
158-
if len(allargs) == 0:
159-
return None
160-
return getattr(allargs[0], "__sycl_usm_array_interface__", None)
161-
162-
163166
def _run_on_device(func, obj=None, *args, **kwargs):
164167
if obj is not None:
165168
return func(obj, *args, **kwargs)
166169
return func(*args, **kwargs)
167170

168171

169-
if dpnp_available:
170-
171-
def _convert_to_dpnp(array):
172-
if isinstance(array, usm_ndarray):
173-
return dpnp.array(array, copy=False)
174-
elif isinstance(array, Iterable):
175-
for i in range(len(array)):
176-
array[i] = _convert_to_dpnp(array[i])
177-
return array
178-
179-
180-
def support_usm_ndarray(freefunc=False, queue_param=True):
172+
def support_input_format(freefunc=False, queue_param=True):
181173
"""
182-
Handles USMArray input. Puts SYCLQueue from data to decorated function arguments.
183-
Converts output of decorated function to dpctl.tensor/dpnp.ndarray if input was of this type.
174+
Converts and moves the output arrays of the decorated function
175+
to match the input array type and device.
176+
Puts SYCLQueue from data to decorated function arguments.
184177
185178
Parameters
186179
----------
@@ -194,17 +187,29 @@ def support_usm_ndarray(freefunc=False, queue_param=True):
194187

195188
def decorator(func):
196189
def wrapper_impl(obj, *args, **kwargs):
197-
usm_iface = _extract_usm_iface(*args, **kwargs)
190+
if len(args) == 0 and len(kwargs) == 0:
191+
return _run_on_device(func, obj, *args, **kwargs)
192+
data = (*args, *kwargs.values())
198193
data_queue, hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
199194
if queue_param and not (
200195
"queue" in hostkwargs and hostkwargs["queue"] is not None
201196
):
202197
hostkwargs["queue"] = data_queue
203198
result = _run_on_device(func, obj, *hostargs, **hostkwargs)
204-
if usm_iface is not None and hasattr(result, "__array_interface__"):
199+
usm_iface = getattr(data[0], "__sycl_usm_array_interface__", None)
200+
if usm_iface is not None:
205201
result = _copy_to_usm(data_queue, result)
206-
if dpnp_available and len(args) > 0 and isinstance(args[0], dpnp.ndarray):
202+
if dpnp_available and isinstance(data[0], dpnp.ndarray):
207203
result = _convert_to_dpnp(result)
204+
return result
205+
config = get_config()
206+
if not ("transform_output" in config and config["transform_output"]):
207+
input_array_api = getattr(data[0], "__array_namespace__", lambda: None)()
208+
if input_array_api:
209+
input_array_api_device = data[0].device
210+
result = _asarray(
211+
result, input_array_api, device=input_array_api_device
212+
)
208213
return result
209214

210215
if freefunc:

onedal/spmd/basic_statistics/basic_statistics.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616

1717
from onedal.basic_statistics import BasicStatistics as BasicStatistics_Batch
1818

19-
from ..._device_offload import support_usm_ndarray
19+
from ..._device_offload import support_input_format
2020
from .._base import BaseEstimatorSPMD
2121

2222

2323
class BasicStatistics(BaseEstimatorSPMD, BasicStatistics_Batch):
24-
@support_usm_ndarray()
24+
@support_input_format()
2525
def compute(self, data, weights=None, queue=None):
2626
return super().compute(data, weights=weights, queue=queue)
2727

28-
@support_usm_ndarray()
28+
@support_input_format()
2929
def fit(self, data, sample_weight=None, queue=None):
3030
return super().fit(data, sample_weight=sample_weight, queue=queue)

onedal/spmd/cluster/kmeans.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from onedal.cluster import KMeansInit as KMeansInit_Batch
1919
from onedal.spmd.basic_statistics import BasicStatistics
2020

21-
from ..._device_offload import support_usm_ndarray
21+
from ..._device_offload import support_input_format
2222
from .._base import BaseEstimatorSPMD
2323

2424

@@ -37,15 +37,15 @@ def _get_basic_statistics_backend(self, result_options):
3737
def _get_kmeans_init(self, cluster_count, seed, algorithm):
3838
return KMeansInit(cluster_count=cluster_count, seed=seed, algorithm=algorithm)
3939

40-
@support_usm_ndarray()
40+
@support_input_format()
4141
def fit(self, X, y=None, queue=None):
4242
return super().fit(X, queue=queue)
4343

44-
@support_usm_ndarray()
44+
@support_input_format()
4545
def predict(self, X, queue=None):
4646
return super().predict(X, queue=queue)
4747

48-
@support_usm_ndarray()
48+
@support_input_format()
4949
def fit_predict(self, X, y=None, queue=None):
5050
return super().fit_predict(X, queue=queue)
5151

onedal/spmd/covariance/covariance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
from onedal.covariance import EmpiricalCovariance as EmpiricalCovariance_Batch
1818

19-
from ..._device_offload import support_usm_ndarray
19+
from ..._device_offload import support_input_format
2020
from .._base import BaseEstimatorSPMD
2121

2222

2323
class EmpiricalCovariance(BaseEstimatorSPMD, EmpiricalCovariance_Batch):
24-
@support_usm_ndarray()
24+
@support_input_format()
2525
def fit(self, X, y=None, queue=None):
2626
return super().fit(X, queue=queue)

onedal/spmd/decomposition/pca.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
from onedal.decomposition.pca import PCA as PCABatch
1818

19-
from ..._device_offload import support_usm_ndarray
19+
from ..._device_offload import support_input_format
2020
from .._base import BaseEstimatorSPMD
2121

2222

2323
class PCA(BaseEstimatorSPMD, PCABatch):
24-
@support_usm_ndarray()
24+
@support_input_format()
2525
def fit(self, X, y=None, queue=None):
2626
return super().fit(X, queue=queue)

onedal/spmd/linear_model/linear_model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616

1717
from onedal.linear_model import LinearRegression as LinearRegression_Batch
1818

19-
from ..._device_offload import support_usm_ndarray
19+
from ..._device_offload import support_input_format
2020
from .._base import BaseEstimatorSPMD
2121

2222

2323
class LinearRegression(BaseEstimatorSPMD, LinearRegression_Batch):
24-
@support_usm_ndarray()
24+
@support_input_format()
2525
def fit(self, X, y, queue=None):
2626
return super().fit(X, y, queue=queue)
2727

28-
@support_usm_ndarray()
28+
@support_input_format()
2929
def predict(self, X, queue=None):
3030
return super().predict(X, queue=queue)

onedal/spmd/linear_model/logistic_regression.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,23 +16,23 @@
1616

1717
from onedal.linear_model import LogisticRegression as LogisticRegression_Batch
1818

19-
from ..._device_offload import support_usm_ndarray
19+
from ..._device_offload import support_input_format
2020
from .._base import BaseEstimatorSPMD
2121

2222

2323
class LogisticRegression(BaseEstimatorSPMD, LogisticRegression_Batch):
24-
@support_usm_ndarray()
24+
@support_input_format()
2525
def fit(self, X, y, queue=None):
2626
return super().fit(X, y, queue=queue)
2727

28-
@support_usm_ndarray()
28+
@support_input_format()
2929
def predict(self, X, queue=None):
3030
return super().predict(X, queue=queue)
3131

32-
@support_usm_ndarray()
32+
@support_input_format()
3333
def predict_proba(self, X, queue=None):
3434
return super().predict_proba(X, queue=queue)
3535

36-
@support_usm_ndarray()
36+
@support_input_format()
3737
def predict_log_proba(self, X, queue=None):
3838
return super().predict_log_proba(X, queue=queue)

onedal/spmd/neighbors/neighbors.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,30 @@
1717
from onedal.neighbors import KNeighborsClassifier as KNeighborsClassifier_Batch
1818
from onedal.neighbors import KNeighborsRegressor as KNeighborsRegressor_Batch
1919

20-
from ..._device_offload import support_usm_ndarray
20+
from ..._device_offload import support_input_format
2121
from .._base import BaseEstimatorSPMD
2222

2323

2424
class KNeighborsClassifier(BaseEstimatorSPMD, KNeighborsClassifier_Batch):
25-
@support_usm_ndarray()
25+
@support_input_format()
2626
def fit(self, X, y, queue=None):
2727
return super().fit(X, y, queue=queue)
2828

29-
@support_usm_ndarray()
29+
@support_input_format()
3030
def predict(self, X, queue=None):
3131
return super().predict(X, queue=queue)
3232

33-
@support_usm_ndarray()
33+
@support_input_format()
3434
def predict_proba(self, X, queue=None):
3535
raise NotImplementedError("predict_proba not supported in distributed mode.")
3636

37-
@support_usm_ndarray()
37+
@support_input_format()
3838
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
3939
return super().kneighbors(X, n_neighbors, return_distance, queue=queue)
4040

4141

4242
class KNeighborsRegressor(BaseEstimatorSPMD, KNeighborsRegressor_Batch):
43-
@support_usm_ndarray()
43+
@support_input_format()
4444
def fit(self, X, y, queue=None):
4545
if queue is not None and queue.sycl_device.is_gpu:
4646
return super()._fit(X, y, queue=queue)
@@ -50,11 +50,11 @@ def fit(self, X, y, queue=None):
5050
"CPU. Consider running on it on GPU."
5151
)
5252

53-
@support_usm_ndarray()
53+
@support_input_format()
5454
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
5555
return super().kneighbors(X, n_neighbors, return_distance, queue=queue)
5656

57-
@support_usm_ndarray()
57+
@support_input_format()
5858
def predict(self, X, queue=None):
5959
return self._predict_gpu(X, queue=queue)
6060

@@ -66,10 +66,10 @@ def _get_onedal_params(self, X, y=None):
6666

6767

6868
class NearestNeighbors(BaseEstimatorSPMD):
69-
@support_usm_ndarray()
69+
@support_input_format()
7070
def fit(self, X, y, queue=None):
7171
return super().fit(X, y, queue=queue)
7272

73-
@support_usm_ndarray()
73+
@support_input_format()
7474
def kneighbors(self, X=None, n_neighbors=None, return_distance=True, queue=None):
7575
return super().kneighbors(X, n_neighbors, return_distance, queue=queue)

onedal/tests/utils/_dataframes_support.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717
import pytest
1818
import scipy.sparse as sp
19-
from sklearn import get_config
19+
20+
from sklearnex import get_config
2021

2122
try:
22-
import dpctl
2323
import dpctl.tensor as dpt
2424

2525
dpctl_available = True
@@ -40,7 +40,6 @@
4040
# GPU-no-copy.
4141
import array_api_strict
4242

43-
# Run check if "array_api_dispatch" is configurable
4443
array_api_enabled = lambda: get_config()["array_api_dispatch"]
4544
array_api_enabled()
4645
array_api_modules = {"array_api": array_api_strict}
@@ -58,7 +57,7 @@
5857

5958

6059
def get_dataframes_and_queues(
61-
dataframe_filter_="numpy,pandas,dpnp,dpctl", device_filter_="cpu,gpu"
60+
dataframe_filter_="numpy,pandas,dpnp,dpctl,array_api", device_filter_="cpu,gpu"
6261
):
6362
"""Get supported dataframes for testing.
6463
@@ -107,13 +106,18 @@ def get_df_and_q(dataframe: str):
107106
dataframes_and_queues.extend(get_df_and_q("dpctl"))
108107
if dpnp_available and "dpnp" in dataframe_filter_:
109108
dataframes_and_queues.extend(get_df_and_q("dpnp"))
110-
if "array_api" in dataframe_filter_ or array_api_enabled():
109+
if (
110+
"array_api" in dataframe_filter_
111+
and "array_api" in array_api_modules
112+
or array_api_enabled()
113+
):
111114
dataframes_and_queues.append(pytest.param("array_api", None, id="array_api"))
112115

113116
return dataframes_and_queues
114117

115118

116119
def _as_numpy(obj, *args, **kwargs):
120+
"""Convert input object to numpy.ndarray format."""
117121
if dpnp_available and isinstance(obj, dpnp.ndarray):
118122
return obj.asnumpy(*args, **kwargs)
119123
if dpctl_available and isinstance(obj, dpt.usm_ndarray):
@@ -155,17 +159,10 @@ def _convert_to_dataframe(obj, sycl_queue=None, target_df=None, *args, **kwargs)
155159
# DPCtl tensor.
156160
return dpt.asarray(obj, usm_type="device", sycl_queue=sycl_queue, *args, **kwargs)
157161
elif target_df in array_api_modules:
158-
# use dpctl to define gpu devices via queues and
159-
# move data to the device. This is necessary as
160-
# the standard for defining devices is
161-
# purposefully not defined in the array_api
162-
# standard, but maintaining data on a device
163-
# using the method `from_dlpack` is.
162+
# Array API input other than DPNP ndarray, DPCtl tensor or
163+
# Numpy ndarray.
164+
164165
xp = array_api_modules[target_df]
165-
return xp.from_dlpack(
166-
_convert_to_dataframe(
167-
obj, sycl_queue=sycl_queue, target_df="dpctl", *args, **kwargs
168-
)
169-
)
166+
return xp.asarray(obj)
170167

171168
raise RuntimeError("Unsupported dataframe conversion")

0 commit comments

Comments
 (0)