@@ -23,17 +23,16 @@ Overview
2323
2424Many estimators from the |sklearnex | support passing data classes that conform to the
2525`Array API <https://data-apis.org/array-api/ >`_ specification as inputs to methods like ``.fit() ``
26- and ``.predict() ``, such as :external+dpnp:doc: `dpnp.ndarray <reference/ndarray >` or
27- `torch.tensor <https://docs.pytorch.org/docs/stable/tensors.html >`__. This is particularly
28- useful for GPU computations, as it allows performing operations on inputs that are already
26+ and ``.predict() ``, such as |dpnp_array | or `torch.tensor <https://docs.pytorch.org/docs/stable/tensors.html >`__.
27+ This is particularly useful for GPU computations, as it allows performing operations on inputs that are already
2928on GPU without moving the data from host to device.
3029
3130.. important ::
3231 Array API is disabled by default in |sklearn |. In order to get array API support in the |sklearnex |, it must
3332 be :external+sklearn:doc: `enabled in scikit-learn <modules/array_api >`, which requires either changing
3433 global settings or using a ``config_context ``, plus installing additional dependencies such as ``array-api-compat ``.
3534
36- When passing array API inputs whose data is on a SyCL -enabled device (e.g. an Intel GPU), as
35+ When passing array API inputs whose data is on a SYCL -enabled device (e.g. an Intel GPU), as
3736supported for example by `PyTorch <https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html >`__
3837and |dpnp |, if array API support is enabled and the requested operation (e.g. call to ``.fit() `` / ``.predict() ``
3938on the estimator class being used) is :ref: `supported on device/GPU <sklearn_algorithms_gpu >`, computations
@@ -51,10 +50,10 @@ through options ``allow_sklearn_after_onedal`` (default is ``True``) and ``allow
5150
5251If array API is enabled for |sklearn | and the estimator being used has array API support on |sklearn | (which can be
5352verified by attribute ``array_api_support `` from :obj: `sklearn.utils.get_tags `), then array API inputs whose data
54- is allocated neither on CPU nor on a SyCL device will be forwarded directly to the unpatched methods from |sklearn |,
53+ is allocated neither on CPU nor on a SYCL device will be forwarded directly to the unpatched methods from |sklearn |,
5554without using the accelerated versions from this library, regardless of option ``allow_sklearn_after_onedal ``.
5655
57- While other array API inputs (e.g. torch arrays with data allocated on a non-SyCL device) might be supported
56+ While other array API inputs (e.g. torch arrays with data allocated on a non-SYCL device) might be supported
5857by the |sklearnex | in cases where the same class from |sklearn | doesn't support array API, note that the data will
5958be transferred to host if it isn't already, and the computations will happen on CPU.
6059
@@ -80,6 +79,7 @@ in many cases they are.
8079 classes that have :external+dpctl:doc: `USM data <api_reference/dpctl/memory >`. In order to ensure that computations
8180 happen on the intended device under array API, make sure that the data is already on the desired device.
8281
82+ .. _array_api_estimators :
8383
8484Supported classes
8585=================
@@ -98,11 +98,10 @@ The following patched classes have support for array API inputs:
9898- :obj: `sklearnex.linear_model.IncrementalRidge `
9999
100100.. note ::
101- While full array API support is currently not implemented for all classes, :external+dpnp:doc: `dpnp.ndarray <reference/ndarray >`
102- and :external+dpctl:doc: `dpctl.tensor <api_reference/dpctl/tensor >` inputs are supported by all the classes
103- that have :ref: `GPU support <oneapi_gpu >`. Note however that if array API support is not enabled in |sklearn |,
104- when passing these classes as inputs, data will be transferred to host and then back to device instead of being
105- used directly.
101+ While full array API support is currently not implemented for all classes, |dpnp_array | inputs are supported
102+ by all the classes that have :ref: `GPU support <oneapi_gpu >`. Note however that if array API support is not
103+ enabled in |sklearn |, when passing these classes as inputs, data will be transferred to host and then back to
104+ device instead of being used directly.
106105
107106
108107Example usage
@@ -111,52 +110,102 @@ Example usage
111110GPU operations on GPU arrays
112111----------------------------
113112
114- .. code-block :: python
115-
116- # Array API support from sklearn requires enabling it on SciPy too
117- import os
118- os.environ[" SCIPY_ARRAY_API" ] = " 1"
119-
120- import numpy as np
121- import dpnp
122- from sklearnex import config_context
123- from sklearnex.linear_model import LinearRegression
124-
125- # Random data for a regression problem
126- rng = np.random.default_rng(seed = 123 )
127- X_np = rng.standard_normal(size = (100 , 10 ), dtype = np.float32)
128- y_np = rng.standard_normal(size = 100 , dtype = np.float32)
129-
130- # DPNP offers an array-API-compliant class where data can be on GPU
131- X = dpnp.array(X_np, device = " gpu" )
132- y = dpnp.array(y_np, device = " gpu" )
133-
134- # Important to note again that array API must be enabled on scikit-learn
135- model = LinearRegression()
136- with config_context(array_api_dispatch = True ):
137- model.fit(X, y)
138-
139- # Fitted attributes are now of the same class as inputs
140- assert isinstance (model.coef_, X.__class__ )
141-
142- # Predictions are also of the same class
143- with config_context(array_api_dispatch = True ):
144- pred = model.predict(X[:5 ])
145- assert isinstance (pred, X.__class__ )
146-
147- # Fitted models can be passed array API inputs of a different class
148- # than the training data, as long as their data resides in the same
149- # device. This now fits a model using a non-NumPy class whose data is on CPU.
150- X_cpu = dpnp.array(X_np, device = " cpu" )
151- y_cpu = dpnp.array(y_np, device = " cpu" )
152- model_cpu = LinearRegression()
153- with config_context(array_api_dispatch = True ):
154- model_cpu.fit(X_cpu, y_cpu)
155- pred_dpnp = model_cpu.predict(X_cpu[:5 ])
156- pred_np = model_cpu.predict(X_cpu[:5 ].asnumpy())
157- assert isinstance (pred_dpnp, X_cpu.__class__ )
158- assert isinstance (pred_np, np.ndarray)
159- assert pred_dpnp.__class__ != pred_np.__class__
113+ .. tabs ::
114+ .. tab :: With Torch tensors
115+ .. code-block :: python
116+
117+ # Array API support from sklearn requires enabling it on SciPy too
118+ import os
119+ os.environ[" SCIPY_ARRAY_API" ] = " 1"
120+
121+ import numpy as np
122+ import torch
123+ from sklearnex import config_context
124+ from sklearnex.linear_model import LinearRegression
125+
126+ # Random data for a regression problem
127+ rng = np.random.default_rng(seed = 123 )
128+ X_np = rng.standard_normal(size = (100 , 10 ), dtype = np.float32)
129+ y_np = rng.standard_normal(size = 100 , dtype = np.float32)
130+
131+ # Torch offers an array-API-compliant class where data can be on GPU (referred to as 'xpu')
132+ X = torch.tensor(X_np, device = " xpu" )
133+ y = torch.tensor(y_np, device = " xpu" )
134+
135+ # Important to note again that array API must be enabled on scikit-learn
136+ model = LinearRegression()
137+ with config_context(array_api_dispatch = True ):
138+ model.fit(X, y)
139+
140+ # Fitted attributes are now of the same class as inputs
141+ assert isinstance (model.coef_, torch.Tensor)
142+
143+ # Predictions are also of the same class
144+ with config_context(array_api_dispatch = True ):
145+ pred = model.predict(X[:5 ])
146+ assert isinstance (pred, torch.Tensor)
147+
148+ # Fitted models can be passed array API inputs of a different class
149+ # than the training data, as long as their data resides in the same
150+ # device. This now fits a model using a non-NumPy class whose data is on CPU.
151+ X_cpu = torch.tensor(X_np, device = " cpu" )
152+ y_cpu = torch.tensor(y_np, device = " cpu" )
153+ model_cpu = LinearRegression()
154+ with config_context(array_api_dispatch = True ):
155+ model_cpu.fit(X_cpu, y_cpu)
156+ pred_torch = model_cpu.predict(X_cpu[:5 ])
157+ pred_np = model_cpu.predict(X_cpu[:5 ].numpy())
158+ assert isinstance (pred_torch, X_cpu.__class__ )
159+ assert isinstance (pred_np, np.ndarray)
160+ assert pred_torch.__class__ != pred_np.__class__
161+
162+ .. tab :: With DPNP arrays
163+ .. code-block :: python
164+
165+ # Array API support from sklearn requires enabling it on SciPy too
166+ import os
167+ os.environ[" SCIPY_ARRAY_API" ] = " 1"
168+
169+ import numpy as np
170+ import dpnp
171+ from sklearnex import config_context
172+ from sklearnex.linear_model import LinearRegression
173+
174+ # Random data for a regression problem
175+ rng = np.random.default_rng(seed = 123 )
176+ X_np = rng.standard_normal(size = (100 , 10 ), dtype = np.float32)
177+ y_np = rng.standard_normal(size = 100 , dtype = np.float32)
178+
179+ # DPNP offers an array-API-compliant class where data can be on GPU
180+ X = dpnp.array(X_np, device = " gpu" )
181+ y = dpnp.array(y_np, device = " gpu" )
182+
183+ # Important to note again that array API must be enabled on scikit-learn
184+ model = LinearRegression()
185+ with config_context(array_api_dispatch = True ):
186+ model.fit(X, y)
187+
188+ # Fitted attributes are now of the same class as inputs
189+ assert isinstance (model.coef_, X.__class__ )
190+
191+ # Predictions are also of the same class
192+ with config_context(array_api_dispatch = True ):
193+ pred = model.predict(X[:5 ])
194+ assert isinstance (pred, X.__class__ )
195+
196+ # Fitted models can be passed array API inputs of a different class
197+ # than the training data, as long as their data resides in the same
198+ # device. This now fits a model using a non-NumPy class whose data is on CPU.
199+ X_cpu = dpnp.array(X_np, device = " cpu" )
200+ y_cpu = dpnp.array(y_np, device = " cpu" )
201+ model_cpu = LinearRegression()
202+ with config_context(array_api_dispatch = True ):
203+ model_cpu.fit(X_cpu, y_cpu)
204+ pred_dpnp = model_cpu.predict(X_cpu[:5 ])
205+ pred_np = model_cpu.predict(X_cpu[:5 ].asnumpy())
206+ assert isinstance (pred_dpnp, X_cpu.__class__ )
207+ assert isinstance (pred_np, np.ndarray)
208+ assert pred_dpnp.__class__ != pred_np.__class__
160209
161210
162211 ``array-api-strict ``
0 commit comments