From 74804baa94cbbcfa485bf7cb8b3742e229b5c637 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Fri, 10 Oct 2025 14:34:21 +0200 Subject: [PATCH 1/8] clarify details on gpu support, remove references to dpctl arrays --- doc/sources/array_api.rst | 15 ++- doc/sources/config-contexts.rst | 91 ++++++++++++++++ doc/sources/distributed-mode.rst | 2 +- doc/sources/index.rst | 11 +- doc/sources/input-types.rst | 10 +- doc/sources/oneapi-gpu.rst | 174 ++++++++++++++++++------------- doc/sources/substitutions.rst | 1 + sklearnex/_config.py | 141 ++++++++++++------------- 8 files changed, 282 insertions(+), 163 deletions(-) create mode 100644 doc/sources/config-contexts.rst diff --git a/doc/sources/array_api.rst b/doc/sources/array_api.rst index b2eb7a8bee..2923254d08 100644 --- a/doc/sources/array_api.rst +++ b/doc/sources/array_api.rst @@ -23,9 +23,8 @@ Overview Many estimators from the |sklearnex| support passing data classes that conform to the `Array API `_ specification as inputs to methods like ``.fit()`` -and ``.predict()``, such as :external+dpnp:doc:`dpnp.ndarray ` or -`torch.tensor `__. This is particularly -useful for GPU computations, as it allows performing operations on inputs that are already +and ``.predict()``, such as |dpnp_array| or `torch.tensor `__. +This is particularly useful for GPU computations, as it allows performing operations on inputs that are already on GPU without moving the data from host to device. .. important:: @@ -80,6 +79,7 @@ in many cases they are. classes that have :external+dpctl:doc:`USM data `. In order to ensure that computations happen on the intended device under array API, make sure that the data is already on the desired device. +.. _array_api_estimators: Supported classes ================= @@ -98,11 +98,10 @@ The following patched classes have support for array API inputs: - :obj:`sklearnex.linear_model.IncrementalRidge` .. note:: - While full array API support is currently not implemented for all classes, :external+dpnp:doc:`dpnp.ndarray ` - and :external+dpctl:doc:`dpctl.tensor ` inputs are supported by all the classes - that have :ref:`GPU support `. Note however that if array API support is not enabled in |sklearn|, - when passing these classes as inputs, data will be transferred to host and then back to device instead of being - used directly. + While full array API support is currently not implemented for all classes, |dpnp_array| inputs are supported + by all the classes that have :ref:`GPU support `. Note however that if array API support is not + enabled in |sklearn|, when passing these classes as inputs, data will be transferred to host and then back to + device instead of being used directly. Example usage diff --git a/doc/sources/config-contexts.rst b/doc/sources/config-contexts.rst new file mode 100644 index 0000000000..be11343f08 --- /dev/null +++ b/doc/sources/config-contexts.rst @@ -0,0 +1,91 @@ +.. Copyright contributors to the oneDAL project +.. +.. Licensed under the Apache License, Version 2.0 (the "License"); +.. you may not use this file except in compliance with the License. +.. You may obtain a copy of the License at +.. +.. http://www.apache.org/licenses/LICENSE-2.0 +.. +.. Unless required by applicable law or agreed to in writing, software +.. distributed under the License is distributed on an "AS IS" BASIS, +.. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +.. See the License for the specific language governing permissions and +.. limitations under the License. +.. include:: substitutions.rst +.. _config_contexts: + +========================================= +Configuration Contexts and Global Options +========================================= + +Overview +======== + +Just like |sklearn|, the |sklearnex| offers configurable options which can be managed +locally through a configuration context, or globally through process-wide settings, +by extending the configuration-related functions from |sklearn| (see :obj:`sklearn.config_context` +for details). + +Configurations in the |sklearnex| are particularly useful for :ref:`GPU functionalities ` +and :ref:`SMPD mode `, and are necessary to modify for enabling :ref:`array API `. + +Configuration context and global options manager for the |sklearnex| can either be imported directly +from the module ``sklearnex``, or can be imported from the ``sklearn`` module after applying patching. + +Note that options in the |sklearnex| are a superset of options from |sklearn|, and options passed to +the configuration contexts and global settings of the |sklearnex| will also affect |sklearn| if the +option is supported by it - meaning: the same context manager or global option setter is used for +both libraries. + +Example usage +============= + +Example using the ``target_offload`` option to make computations run on a GPU: + +With a local context +-------------------- + +Here, only the operations from |sklearn| and from the |sklearnex| that happen within the 'with' +block will be affected by the options: + +.. code:: python + + import numpy as np + from sklearnex import config_context + from sklearnex.cluster import DBSCAN + + X = np.array([[1., 2.], [2., 2.], [2., 3.], + [8., 7.], [8., 8.], [25., 80.]], dtype=np.float32) + with config_context(target_offload="gpu"): + clustering = DBSCAN(eps=3, min_samples=2).fit(X) + +As a global option +------------------ + +Here, all computations from |sklearn| and from the |sklearnex| that happen after the option +is modified are affected: + +.. code:: python + + import numpy as np + from sklearnex import set_config + from sklearnex.cluster import DBSCAN + + X = np.array([[1., 2.], [2., 2.], [2., 3.], + [8., 7.], [8., 8.], [25., 80.]], dtype=np.float32) + + set_config(target_offload="gpu") # set it globally + clustering = DBSCAN(eps=3, min_samples=2).fit(X) + set_config(target_offload="auto") # restore it back + +API Reference +============= + +Note that all of the options accepted by these functions in |sklearn| are also accepted +here - these just list the additional options offered by the |sklearnex|. + +.. autofunction:: sklearnex.config_context + +.. autofunction:: sklearnex.get_config + +.. autofunction:: sklearnex.set_config diff --git a/doc/sources/distributed-mode.rst b/doc/sources/distributed-mode.rst index 73e55839c9..924500a340 100644 --- a/doc/sources/distributed-mode.rst +++ b/doc/sources/distributed-mode.rst @@ -85,7 +85,7 @@ data on device without this may lead to a runtime error): :: export I_MPI_OFFLOAD=1 SMPD-aware versions of estimators can be imported from the ``sklearnex.spmd`` module. Data should be distributed across multiple nodes as -desired, and should be transferred to a |dpctl| or `dpnp `__ array before being passed to the estimator. +desired, and should be transferred to a |dpnp_array| before being passed to the estimator. Note that SPMD estimators allow an additional argument ``queue`` in their ``.fit`` / ``.predict`` methods, which accept :obj:`dpctl.SyclQueue` objects. For example, while the signature for :obj:`sklearn.linear_model.LinearRegression.predict` would be diff --git a/doc/sources/index.rst b/doc/sources/index.rst index 2b405c8d10..dc02dbdff3 100755 --- a/doc/sources/index.rst +++ b/doc/sources/index.rst @@ -41,16 +41,16 @@ These performance charts use benchmarks that you can find in the `scikit-learn b Supported Algorithms ---------------------- +-------------------- See all of the :ref:`sklearn_algorithms`. Optimizations ----------------------------------- +------------- Enable CPU Optimizations -********************************* +************************ .. tabs:: .. tab:: By patching @@ -78,7 +78,7 @@ Enable CPU Optimizations Enable GPU optimizations -********************************* +************************ Note: executing on GPU has `additional system software requirements `__ - see :doc:`oneapi-gpu`. @@ -168,6 +168,8 @@ See :ref:`oneapi_gpu` for other ways of executing on GPU. algorithms.rst oneapi-gpu.rst + config-contexts.rst + array_api.rst distributed-mode.rst distributed_daal4py.rst non-scikit-algorithms.rst @@ -175,7 +177,6 @@ See :ref:`oneapi_gpu` for other ways of executing on GPU. model_builders.rst logistic_model_builder.rst input-types.rst - array_api.rst verbose.rst preview.rst deprecation.rst diff --git a/doc/sources/input-types.rst b/doc/sources/input-types.rst index 790080e6bf..28cf7f45c9 100644 --- a/doc/sources/input-types.rst +++ b/doc/sources/input-types.rst @@ -29,10 +29,7 @@ and work with different classes of input data, including: - SciPy :external+scipy:doc:`sparse arrays and sparse matrices ` (depending on the estimator). - Pandas :external+pandas:doc:`DataFrame and Series ` classes. -In addition, |sklearnex| also supports: - -- :external+dpnp:doc:`dpnp.ndarray `. -- :external+dpctl:doc:`dpctl.tensor `. +In addition, |sklearnex| also supports |dpnp_array| arrays, which are particularly useful for GPU computations. Stock Scikit-Learn estimators, depending on the version, might offer support for additional input types beyond this list, such as ``DataFrame`` and ``Series`` classes from other libraries @@ -50,8 +47,9 @@ enabled the input is unsupported). The affected cases are listed below. - Non-contiguous NumPy array - i.e. where strides are wider than one element across both rows and columns - - For SciPy CSR matrix / array, index arrays are always copied. + - For SciPy CSR matrix / array, index arrays are always copied. Note that sparse matrices in formats other than CSR + will be converted to CSR, which implies more than just data copying. - Heterogeneous NumPy array - - If SYCL queue is provided for device without ``float64`` support but data are ``float64``, data are copied with reduced precision. + - If SyCL queue is provided for device without ``float64`` support but data are ``float64``, data are copied with reduced precision. - If :ref:`Array API ` is not enabled then data from GPU devices are always copied to the host device and then result table (for applicable methods) is copied to the source device. diff --git a/doc/sources/oneapi-gpu.rst b/doc/sources/oneapi-gpu.rst index fa54437a77..5a83289851 100644 --- a/doc/sources/oneapi-gpu.rst +++ b/doc/sources/oneapi-gpu.rst @@ -15,20 +15,25 @@ .. include:: substitutions.rst .. _oneapi_gpu: -############################################################## -oneAPI and GPU support in |sklearnex| -############################################################## +########### +GPU support +########### -|sklearnex| can execute computations on different devices (CPUs and GPUs, including integrated GPUs from laptops and desktops) through the SYCL framework in oneAPI. +Overview +-------- -The device used for computations can be easily controlled through the target offloading functionality (e.g. through ``sklearnex.config_context(target_offload="gpu")``, which moves data to GPU if it's not already there - see rest of this page for more details), but for finer-grained controlled (e.g. operating on arrays that are already in a given device's memory), it can also interact with objects from package |dpctl|, which offers a Python interface over SYCL concepts such as devices, queues, and USM (unified shared memory) arrays. +|sklearnex| can execute computations on different devices (CPUs and GPUs, including integrated GPUs from laptops and desktops) supported by the SyCL framework. -While not strictly required, package |dpctl| is recommended for a better experience on GPUs - for example, it can provide GPU-allocated arrays that enable compute-follows-data execution models (i.e. so that ``target_offload`` wouldn't need to move the data from CPU to GPU). +The device used for computations can be easily controlled through the ``target_offload`` option in config contexts, which moves data to GPU if it's not already there - see :ref:`config_contexts` and rest of this page for more details). + +For finer-grained controlled (e.g. operating on arrays that are already in a given device's memory), it can also interact with on-device :ref:`array API classes ` like |dpnp_array|, and with SyCL-related objects from package |dpctl| such as :obj:`dpctl.SyclQueue`. + +.. Note:: Note that not every operation from every estimator is supported on GPU - see the :ref:`GPU support table ` for more information. .. important:: Be aware that GPU usage requires non-Python dependencies on your system, such as the `Intel(R) Compute Runtime `_ (see below). -Prerequisites -------------- +Software Requirements +--------------------- For execution on GPUs, DPC++ runtime and Intel Compute Runtime (also referred to elsewhere as 'GPGPU drivers') are required. @@ -76,93 +81,116 @@ Be aware that datacenter-grade devices, such as 'Flex' and 'Max', require differ For more details, see the `DPC++ requirements page `__. -Device offloading ------------------ +Running operations on GPU +------------------------- -|sklearnex| offers two options for running an algorithm on a specified device: +|sklearnex| offers different options for running an algorithm on a specified device (e.g. a GPU): -- Use global configurations of |sklearnex|: +Target offload option +~~~~~~~~~~~~~~~~~~~~~ - 1. The :code:`target_offload` argument (in ``config_context`` and in ``set_config`` / ``get_config``) - can be used to set the device primarily used to perform computations. Accepted data types are - :code:`str` and :obj:`dpctl.SyclQueue`. Strings must match to device names recognized by - the SYCL* device filter selector - for example, ``"gpu"``. If passing ``"auto"``, - the device will be deduced from the location of the input data. Examples: +Just like |sklearn|, the |sklearnex| can use configuration contexts and global options to modify how it interacts with different inputs - see :ref:`config_contexts` for details. - .. code-block:: python - - from sklearnex import config_context - from sklearnex.linear_model import LinearRegression - - with config_context(target_offload="gpu"): - model = LinearRegression().fit(X, y) +In particular, the |sklearnex| allows an option ``target_offload`` which can be passed a SyCL device name like ``"gpu"`` indicating where the operations should be performed, moving the data to that device in the process if it's not already there; or a :obj:`dpctl.SyclQueue` object from an already-existing queue on a device. - .. code-block:: python - - from sklearnex import set_config - from sklearnex.linear_model import LinearRegression - - set_config(target_offload="gpu") - model = LinearRegression().fit(X, y) +Example: +.. tabs:: + .. tab:: Passing a device name + .. code-block:: python - If passing a string different than ``"auto"``, - it must be a device + from sklearnex import config_context + from sklearnex.linear_model import LinearRegression + from sklearn.datasets import make_regression + X, y = make_regression() + model = LinearRegression() - 2. The :code:`allow_fallback_to_host` argument in those same configuration functions - is a Boolean flag. If set to :code:`True`, the computation is allowed - to fallback to the host device when a particular estimator does not support - the selected device. The default value is :code:`False`. + with config_context(target_offload="gpu"): + model.fit(X, y) + pred = model.predict(X) -These options can be set using :code:`sklearnex.set_config()` function or -:code:`sklearnex.config_context`. To obtain the current values of these options, -call :code:`sklearnex.get_config()`. + .. tab:: Passing a SyCL queue + .. code-block:: python -.. note:: - Functions :code:`set_config`, :code:`get_config` and :code:`config_context` - are always patched after the :code:`sklearnex.patch_sklearn()` call. + import dpctl + from sklearnex import config_context + from sklearnex.linear_model import LinearRegression + from sklearn.datasets import make_regression + X, y = make_regression() + model = LinearRegression() -- Pass input data as :obj:`dpctl.tensor.usm_ndarray` to the algorithm. + queue = dpctl.SyclQueue("gpu") + with config_context(target_offload=queue): + model.fit(X, y) + pred = model.predict(X) - The computation will run on the device where the input data is - located, and the result will be returned as :code:`usm_ndarray` to the same - device. - .. important:: - In order to enable zero-copy operations on GPU arrays, it's necessary to enable - :ref:`array API support ` for scikit-learn. Otherwise, if passing a GPU - array and array API support is not enabled, GPU arrays will first be transferred to - host and then back to GPU. +.. warning:: + When using ``target_offload``, operations on a fitted model must be executed under a context or global option with the same device or queue where the model was fitted - meaning: a model fitted on GPU cannot make predictions on CPU, and vice-versa. Note that upon serialization and subsequent deserialization of models, data is moved to the CPU. - .. note:: - All the input data for an algorithm must reside on the same device. +GPU arrays through array API +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +As another option, computations can also be performed on data that is already on a SyCL device without moving it there if it belongs to an array API-compatible class, such as |dpnp_array| or `torch.tensor `__. + +This is particularly useful when multiple operations are performed on the same data (e.g. cross validators, stacked ensembles, etc.), or when the data is meant to interact with other libraries besides the |sklearnex|. Be aware that it requires enabling array API support in |sklearn|, which comes with additional dependencies. + +See :ref:`array_api` for details, instructions, and limitations. Example: + +.. code-block:: python - .. warning:: - The :code:`usm_ndarray` can only be consumed by the base methods - like :code:`fit`, :code:`predict`, and :code:`transform`. - Note that only the algorithms in |sklearnex| support - :code:`usm_ndarray`. The algorithms from the stock version of |sklearn| - do not support this feature. + # Array API support from sklearn requires enabling it on SciPy too + import os + os.environ["SCIPY_ARRAY_API"] = "1" + import numpy as np + import dpnp + from sklearnex import config_context + from sklearnex.linear_model import LinearRegression -Example -------- + # Random data for a regression problem + rng = np.random.default_rng(seed=123) + X_np = rng.standard_normal(size=(100, 10), dtype=np.float32) + y_np = rng.standard_normal(size=100, dtype=np.float32) -A full example of how to patch your code with Intel CPU/GPU optimizations: + # DPNP offers an array-API-compliant class where data can be on GPU + X = dpnp.array(X_np, device="gpu") + y = dpnp.array(y_np, device="gpu") + + # Important to note again that array API must be enabled on scikit-learn + model = LinearRegression() + with config_context(array_api_dispatch=True): + model.fit(X, y) + +.. note:: + Not all estimator classes in the |sklearnex| support array API objects - see the list of :ref:`estimators with array API support ` for details. + +DPNP Arrays +~~~~~~~~~~~ + +As a special case, GPU arrays from |dpnp| can be used without enabling array API, even for estimators in the |sklearnex| that do not currently support array API, but note that it involves data movement to host and back and is thus not the most efficient route in computational terms. + +Example: .. code-block:: python - from sklearnex import patch_sklearn, config_context - patch_sklearn() + import numpy as np + import dpnp + from sklearnex import config_context + from sklearnex.linear_model import LinearRegression + + rng = np.random.default_rng(seed=123) + X_np = rng.standard_normal(size=(100, 10), dtype=np.float32) + y_np = rng.standard_normal(size=100, dtype=np.float32) + + X = dpnp.array(X_np, device="gpu") + y = dpnp.array(y_np, device="gpu") - from sklearn.cluster import DBSCAN + model = LinearRegression() + model.fit(X, y) - X = np.array([[1., 2.], [2., 2.], [2., 3.], - [8., 7.], [8., 8.], [25., 80.]], dtype=np.float32) - with config_context(target_offload="gpu:0"): - clustering = DBSCAN(eps=3, min_samples=2).fit(X) +Note that, if array API had been enabled, the snippet above would use the data as-is on the device where it resides, but without array API, it implies data movements using the SyCL queue contained by those objects. -.. note:: Current offloading behavior restricts fitting and predictions (a.k.a. inference) of any models to be - in the same context or absence of context. For example, a model whose ``.fit()`` method was called in a GPU context with - ``target_offload="gpu:0"`` will throw an error if a ``.predict()`` call is then made outside the same GPU context. +.. note:: + All the input data for an algorithm must reside on the same device. diff --git a/doc/sources/substitutions.rst b/doc/sources/substitutions.rst index ded175b164..205484f848 100644 --- a/doc/sources/substitutions.rst +++ b/doc/sources/substitutions.rst @@ -14,6 +14,7 @@ .. |dpctl| replace:: :external+dpctl:doc:`dpctl ` .. |dpnp| replace:: :external+dpnp:doc:`dpnp ` +.. |dpnp_array| replace:: :external+dpnp:doc:`dpnp.ndarray ` .. |sklearn| replace:: :external+sklearn:doc:`scikit-learn ` .. |intelex_repo| replace:: |sklearnex| repository .. _intelex_repo: https://github.com/uxlfoundation/scikit-learn-intelex diff --git a/sklearnex/_config.py b/sklearnex/_config.py index 793fc295c9..123ffd690f 100644 --- a/sklearnex/_config.py +++ b/sklearnex/_config.py @@ -14,6 +14,7 @@ # limitations under the License. # ============================================================================== +import sys from contextlib import contextmanager from sklearn import get_config as skl_get_config @@ -22,6 +23,61 @@ from daal4py.sklearn._utils import sklearn_check_version from onedal._config import _get_config as onedal_get_config +__all__ = ["get_config", "set_config", "config_context"] + +tab = " " if (sys.version_info.major == 3 and sys.version_info.minor < 13) else "" +_options_docstring = f"""Parameters +{tab}---------- +{tab}target_offload : str or dpctl.SyclQueue or None +{tab} The device used to perform computations, either as a string indicating a name +{tab} recognized by the SyCL runtime, such as ``"gpu"``, ``"gpu:0"``, or as a +{tab} :obj:`dpctl.SyclQueue` object indicating where to move the data. +{tab} +{tab} Assuming SyCL-related dependencies are installed, the list of devices recognized +{tab} by SyCL can be retrieved through the CLI tool ``sycl-ls`` in a shell, or through +{tab} :obj:`dpctl.get_devices` in a Python process. +{tab} +{tab} String ``"auto"`` is also accepted. +{tab} +{tab} Global default: ``"auto"``. +{tab} +{tab}allow_fallback_to_host : bool or None +{tab} If ``True``, allows computations to fall back to host device (CPU) when an unsupported +{tab} operation is attempted on GPU through ``target_offload``. +{tab} +{tab} Global default: ``False``. +{tab} +{tab}allow_sklearn_after_onedal : bool or None, default=None +{tab} If ``True``, allows computations to fall back to stock scikit-learn when no +{tab} accelered version of the operation is available (see :ref:`algorithms`). +{tab} +{tab} Global default: ``True.`` +{tab} +{tab}use_raw_input : bool or None +{tab} If ``True``, uses the raw input data in some SPMD onedal backend computations +{tab} without any checks on data consistency or validity. Note that this can be +{tab} better achieved through usage of :ref:`array API classes ` without +{tab} ``target_offload``. Not recommended for general use. +{tab} +{tab} Global default: ``False``. +{tab} +{tab} .. deprecated:: 2026.0 +{tab} +{tab}sklearn_configs : kwargs +{tab} Other settings accepted by scikit-learn. See :obj:`sklearn.set_config` for +{tab} details. +{tab} +{tab}Warnings +{tab}-------- +{tab}Using ``use_raw_input=True`` is not recommended for general use as it +{tab}bypasses data consistency checks, which may lead to unexpected behavior. It is +{tab}recommended to use the newer :ref:`array API ` instead. +{tab} +{tab}Note +{tab}---- +{tab}Usage of ``target_offload`` requires additional dependencies - see +{tab}:ref:`GPU support ` for more information.""" + def get_config(): """Retrieve current values for configuration set by :func:`set_config`. @@ -47,52 +103,15 @@ def set_config( allow_sklearn_after_onedal=None, use_raw_input=None, **sklearn_configs, -): +): # numpydoc ignore=PR01,PR07 """Set global configuration. - Parameters - ---------- - target_offload : str or SyclQueue or None, default=None - The device primarily used to perform computations. - If string, expected to be "auto" (the execution context - is deduced from input data location), - or SYCL* filter selector string. Global default: "auto". - - allow_fallback_to_host : bool or None, default=None - If True, allows to fallback computation to host device - in case particular estimator does not support the selected one. - Global default: False. - - allow_sklearn_after_onedal : bool or None, default=None - If True, allows to fallback computation to sklearn after onedal - backend in case of runtime error on onedal backend computations. - Global default: True. - - use_raw_input : bool or None, default=None - If True, uses the raw input data in some SPMD onedal backend computations - without any checks on data consistency or validity. - Not recommended for general use. - Global default: False. - - .. deprecated:: 2026.0 - - **sklearn_configs : kwargs - Scikit-learn configuration settings dependent on the installed version - of scikit-learn. + %_options_docstring% See Also -------- config_context : Context manager for global configuration. get_config : Retrieve current values of the global configuration. - - Warnings - -------- - Using ``use_raw_input=True`` is not recommended for general use as it - bypasses data consistency checks, which may lead to unexpected behavior. - - Use of ``target_offload`` requires the DPC++ backend. Setting a - non-default value (e.g ``cpu`` or ``gpu``) without this backend active - will raise an error. """ skl_set_config(**sklearn_configs) @@ -109,34 +128,16 @@ def set_config( local_config["use_raw_input"] = use_raw_input +set_config.__doc__ = set_config.__doc__.replace( + "%_options_docstring%", _options_docstring +) + + @contextmanager def config_context(**new_config): # numpydoc ignore=PR01,PR07 - """Context manager for global scikit-learn configuration. - - Parameters - ---------- - target_offload : str or SyclQueue or None, default=None - The device primarily used to perform computations. - If string, expected to be "auto" (the execution context - is deduced from input data location), - or SYCL* filter selector string. Global default: "auto". - - allow_fallback_to_host : bool or None, default=None - If True, allows to fallback computation to host device - in case particular estimator does not support the selected one. - Global default: False. - - allow_sklearn_after_onedal : bool or None, default=None - If True, allows to fallback computation to sklearn after onedal - backend in case of runtime error on onedal backend computations. - Global default: True. - - use_raw_input : bool or None, default=None - .. deprecated:: 2026.0 - If True, uses the raw input data in some SPMD onedal backend computations - without any checks on data consistency or validity. - Not recommended for general use. - Global default: False. + """Context manager for local scikit-learn-intelex configurations. + + %_options_docstring% Notes ----- @@ -147,11 +148,6 @@ def config_context(**new_config): # numpydoc ignore=PR01,PR07 -------- set_config : Set global scikit-learn configuration. get_config : Retrieve current values of the global configuration. - - Warnings - -------- - Using ``use_raw_input=True`` is not recommended for general use as it - bypasses data consistency checks, which may lead to unexpected behavior. """ old_config = get_config() set_config(**new_config) @@ -160,3 +156,8 @@ def config_context(**new_config): # numpydoc ignore=PR01,PR07 yield finally: set_config(**old_config) + + +config_context.__doc__ = config_context.__doc__.replace( + "%_options_docstring%", _options_docstring +) From 5db1c51ee7eb8f329a88b9b91b9aff075db9d128 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Fri, 10 Oct 2025 15:07:22 +0200 Subject: [PATCH 2/8] formatting --- sklearnex/_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearnex/_config.py b/sklearnex/_config.py index 123ffd690f..46e3bf5975 100644 --- a/sklearnex/_config.py +++ b/sklearnex/_config.py @@ -139,8 +139,8 @@ def config_context(**new_config): # numpydoc ignore=PR01,PR07 %_options_docstring% - Notes - ----- + Note + ---- All settings, not just those presently modified, will be returned to their previous values when the context manager is exited. From 947b872eb46e3764a021699cfd61ad10f3d2b14e Mon Sep 17 00:00:00 2001 From: David Cortes Date: Mon, 13 Oct 2025 11:09:44 +0200 Subject: [PATCH 3/8] more links --- doc/sources/algorithms.rst | 2 +- doc/sources/oneapi-gpu.rst | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/sources/algorithms.rst b/doc/sources/algorithms.rst index a5428f78a0..855940b46d 100755 --- a/doc/sources/algorithms.rst +++ b/doc/sources/algorithms.rst @@ -21,7 +21,7 @@ Supported Algorithms .. note:: To verify that oneDAL is being used for these algorithms, you can enable verbose mode. - See :ref:`verbose mode documentation ` for details. + See :ref:`verbose` for details. Applying |sklearnex| impacts the following |sklearn| estimators: diff --git a/doc/sources/oneapi-gpu.rst b/doc/sources/oneapi-gpu.rst index 5a83289851..029bc45809 100644 --- a/doc/sources/oneapi-gpu.rst +++ b/doc/sources/oneapi-gpu.rst @@ -28,7 +28,7 @@ The device used for computations can be easily controlled through the ``target_o For finer-grained controlled (e.g. operating on arrays that are already in a given device's memory), it can also interact with on-device :ref:`array API classes ` like |dpnp_array|, and with SyCL-related objects from package |dpctl| such as :obj:`dpctl.SyclQueue`. -.. Note:: Note that not every operation from every estimator is supported on GPU - see the :ref:`GPU support table ` for more information. +.. Note:: Note that not every operation from every estimator is supported on GPU - see the :ref:`GPU support table ` for more information. See also :ref:`verbose` to verify where computations are performed. .. important:: Be aware that GPU usage requires non-Python dependencies on your system, such as the `Intel(R) Compute Runtime `_ (see below). From f21f7411d564fca72eba54966e51103307e85e98 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Mon, 13 Oct 2025 12:33:31 +0200 Subject: [PATCH 4/8] more hints --- doc/sources/oneapi-gpu.rst | 2 ++ 1 file changed, 2 insertions(+) diff --git a/doc/sources/oneapi-gpu.rst b/doc/sources/oneapi-gpu.rst index 029bc45809..a1ef978d7c 100644 --- a/doc/sources/oneapi-gpu.rst +++ b/doc/sources/oneapi-gpu.rst @@ -93,6 +93,8 @@ Just like |sklearn|, the |sklearnex| can use configuration contexts and global o In particular, the |sklearnex| allows an option ``target_offload`` which can be passed a SyCL device name like ``"gpu"`` indicating where the operations should be performed, moving the data to that device in the process if it's not already there; or a :obj:`dpctl.SyclQueue` object from an already-existing queue on a device. +.. hint:: If repeated operations are going to be performed on the same data (e.g. cross-validators, resamplers, missing data imputers, etc.), it's recommended to use the array API option instead - see the next section for details. + Example: .. tabs:: From dbbc3e0577ab2162aa114822adcb39626d1885e3 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 14 Oct 2025 16:18:12 +0200 Subject: [PATCH 5/8] add torch examples --- doc/sources/array_api.rst | 142 +++++++++++++++++++++++++------------ doc/sources/index.rst | 12 ++-- doc/sources/oneapi-gpu.rst | 66 ++++++++++++----- 3 files changed, 149 insertions(+), 71 deletions(-) diff --git a/doc/sources/array_api.rst b/doc/sources/array_api.rst index 2923254d08..3901ae87a1 100644 --- a/doc/sources/array_api.rst +++ b/doc/sources/array_api.rst @@ -110,52 +110,102 @@ Example usage GPU operations on GPU arrays ---------------------------- -.. code-block:: python - - # Array API support from sklearn requires enabling it on SciPy too - import os - os.environ["SCIPY_ARRAY_API"] = "1" - - import numpy as np - import dpnp - from sklearnex import config_context - from sklearnex.linear_model import LinearRegression - - # Random data for a regression problem - rng = np.random.default_rng(seed=123) - X_np = rng.standard_normal(size=(100, 10), dtype=np.float32) - y_np = rng.standard_normal(size=100, dtype=np.float32) - - # DPNP offers an array-API-compliant class where data can be on GPU - X = dpnp.array(X_np, device="gpu") - y = dpnp.array(y_np, device="gpu") - - # Important to note again that array API must be enabled on scikit-learn - model = LinearRegression() - with config_context(array_api_dispatch=True): - model.fit(X, y) - - # Fitted attributes are now of the same class as inputs - assert isinstance(model.coef_, X.__class__) - - # Predictions are also of the same class - with config_context(array_api_dispatch=True): - pred = model.predict(X[:5]) - assert isinstance(pred, X.__class__) - - # Fitted models can be passed array API inputs of a different class - # than the training data, as long as their data resides in the same - # device. This now fits a model using a non-NumPy class whose data is on CPU. - X_cpu = dpnp.array(X_np, device="cpu") - y_cpu = dpnp.array(y_np, device="cpu") - model_cpu = LinearRegression() - with config_context(array_api_dispatch=True): - model_cpu.fit(X_cpu, y_cpu) - pred_dpnp = model_cpu.predict(X_cpu[:5]) - pred_np = model_cpu.predict(X_cpu[:5].asnumpy()) - assert isinstance(pred_dpnp, X_cpu.__class__) - assert isinstance(pred_np, np.ndarray) - assert pred_dpnp.__class__ != pred_np.__class__ +.. tabs:: + .. tab:: With Torch tensors + .. code-block:: python + + # Array API support from sklearn requires enabling it on SciPy too + import os + os.environ["SCIPY_ARRAY_API"] = "1" + + import numpy as np + import torch + from sklearnex import config_context + from sklearnex.linear_model import LinearRegression + + # Random data for a regression problem + rng = np.random.default_rng(seed=123) + X_np = rng.standard_normal(size=(100, 10), dtype=np.float32) + y_np = rng.standard_normal(size=100, dtype=np.float32) + + # Torch offers an array-API-compliant class where data can be on GPU (referred to as 'xpu') + X = torch.tensor(X_np, device="xpu") + y = torch.tensor(y_np, device="xpu") + + # Important to note again that array API must be enabled on scikit-learn + model = LinearRegression() + with config_context(array_api_dispatch=True): + model.fit(X, y) + + # Fitted attributes are now of the same class as inputs + assert isinstance(model.coef_, torch.Tensor) + + # Predictions are also of the same class + with config_context(array_api_dispatch=True): + pred = model.predict(X[:5]) + assert isinstance(pred, torch.Tensor) + + # Fitted models can be passed array API inputs of a different class + # than the training data, as long as their data resides in the same + # device. This now fits a model using a non-NumPy class whose data is on CPU. + X_cpu = torch.tensor(X_np, device="cpu") + y_cpu = torch.tensor(y_np, device="cpu") + model_cpu = LinearRegression() + with config_context(array_api_dispatch=True): + model_cpu.fit(X_cpu, y_cpu) + pred_torch = model_cpu.predict(X_cpu[:5]) + pred_np = model_cpu.predict(X_cpu[:5].numpy()) + assert isinstance(pred_torch, X_cpu.__class__) + assert isinstance(pred_np, np.ndarray) + assert pred_torch.__class__ != pred_np.__class__ + + .. tab:: With DPNP arrays + .. code-block:: python + + # Array API support from sklearn requires enabling it on SciPy too + import os + os.environ["SCIPY_ARRAY_API"] = "1" + + import numpy as np + import dpnp + from sklearnex import config_context + from sklearnex.linear_model import LinearRegression + + # Random data for a regression problem + rng = np.random.default_rng(seed=123) + X_np = rng.standard_normal(size=(100, 10), dtype=np.float32) + y_np = rng.standard_normal(size=100, dtype=np.float32) + + # DPNP offers an array-API-compliant class where data can be on GPU + X = dpnp.array(X_np, device="gpu") + y = dpnp.array(y_np, device="gpu") + + # Important to note again that array API must be enabled on scikit-learn + model = LinearRegression() + with config_context(array_api_dispatch=True): + model.fit(X, y) + + # Fitted attributes are now of the same class as inputs + assert isinstance(model.coef_, X.__class__) + + # Predictions are also of the same class + with config_context(array_api_dispatch=True): + pred = model.predict(X[:5]) + assert isinstance(pred, X.__class__) + + # Fitted models can be passed array API inputs of a different class + # than the training data, as long as their data resides in the same + # device. This now fits a model using a non-NumPy class whose data is on CPU. + X_cpu = dpnp.array(X_np, device="cpu") + y_cpu = dpnp.array(y_np, device="cpu") + model_cpu = LinearRegression() + with config_context(array_api_dispatch=True): + model_cpu.fit(X_cpu, y_cpu) + pred_dpnp = model_cpu.predict(X_cpu[:5]) + pred_np = model_cpu.predict(X_cpu[:5].asnumpy()) + assert isinstance(pred_dpnp, X_cpu.__class__) + assert isinstance(pred_np, np.ndarray) + assert pred_dpnp.__class__ != pred_np.__class__ ``array-api-strict`` diff --git a/doc/sources/index.rst b/doc/sources/index.rst index dc02dbdff3..7d13456527 100755 --- a/doc/sources/index.rst +++ b/doc/sources/index.rst @@ -105,7 +105,7 @@ Note: executing on GPU has `additional system software requirements ` for details. From 21856873907e538a877d5d870d295748e8c7e6a0 Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 14 Oct 2025 16:38:29 +0200 Subject: [PATCH 6/8] add docs about serialization --- doc/sources/index.rst | 1 + doc/sources/oneapi-gpu.rst | 10 +++ doc/sources/serialization.rst | 116 ++++++++++++++++++++++++++++++++++ 3 files changed, 127 insertions(+) create mode 100644 doc/sources/serialization.rst diff --git a/doc/sources/index.rst b/doc/sources/index.rst index 7d13456527..8fca9c14dd 100755 --- a/doc/sources/index.rst +++ b/doc/sources/index.rst @@ -170,6 +170,7 @@ See :ref:`oneapi_gpu` for other ways of executing on GPU. oneapi-gpu.rst config-contexts.rst array_api.rst + serialization.rst distributed-mode.rst distributed_daal4py.rst non-scikit-algorithms.rst diff --git a/doc/sources/oneapi-gpu.rst b/doc/sources/oneapi-gpu.rst index cf3d2614d9..f106d80391 100644 --- a/doc/sources/oneapi-gpu.rst +++ b/doc/sources/oneapi-gpu.rst @@ -86,6 +86,9 @@ Running operations on GPU |sklearnex| offers different options for running an algorithm on a specified device (e.g. a GPU): + +.. _target_offload: + Target offload option ~~~~~~~~~~~~~~~~~~~~~ @@ -130,6 +133,9 @@ Example: .. warning:: When using ``target_offload``, operations on a fitted model must be executed under a context or global option with the same device or queue where the model was fitted - meaning: a model fitted on GPU cannot make predictions on CPU, and vice-versa. Note that upon serialization and subsequent deserialization of models, data is moved to the CPU. +.. hint:: + Serialization of model objects that used target offload will move data to CPU upon deserialization. See :doc:`serialization` for detail about serializing GPU models. + GPU arrays through array API ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -192,6 +198,10 @@ See :ref:`array_api` for details, instructions, and limitations. Example: with config_context(array_api_dispatch=True): model.fit(X, y) + .. hint:: + If serialization of a GPU model is desired, use Torch tensors instead of DPNP arrays. + See :doc:`serialization` for more information. + .. note:: Not all estimator classes in the |sklearnex| support array API objects - see the list of :ref:`estimators with array API support ` for details. diff --git a/doc/sources/serialization.rst b/doc/sources/serialization.rst new file mode 100644 index 0000000000..3d835339af --- /dev/null +++ b/doc/sources/serialization.rst @@ -0,0 +1,116 @@ +.. Copyright contributors to the oneDAL project +.. +.. Licensed under the Apache License, Version 2.0 (the "License"); +.. you may not use this file except in compliance with the License. +.. You may obtain a copy of the License at +.. +.. http://www.apache.org/licenses/LICENSE-2.0 +.. +.. Unless required by applicable law or agreed to in writing, software +.. distributed under the License is distributed on an "AS IS" BASIS, +.. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +.. See the License for the specific language governing permissions and +.. limitations under the License. +.. include:: substitutions.rst + +============================== +Model serialization (pickling) +============================== + +Serializing objects +------------------- + +Objects in Python are bound to the process that creates them. Usually, when it comes to statistical or machine learning models, one typically wants to save fitted models for later usage - for example, by fitting a model on a large machine, saving it to disk storage, and then serving it (making predictions on new data) on other smaller machines. + +Just like other objects in Python, estimator objects from the |sklearnex| can be serialized / persisted / pickled using the built-in ``pickle`` module - for example: + +.. code-block:: python + + import pickle + import numpy as np + from sklearn.datasets import make_regression + from sklearnex.linear_model import LinearRegression + + X, y = make_regression() + model = LinearRegression().fit(X, y) + + model_file = "linear_model.pkl" + with open(model_file, "wb") as output_file: + pickle.dump(model, output_file) + + with open(model_file, "rb") as input_file: + model_deserialized = pickle.load(input_file) + + np.testing.assert_array_equal( + model_deserialized.predict(X), + model.predict(X), + ) + +.. hint:: Note that, while operations performed on CPU are usually deterministic and procedures involving random numbers allow controlling the seed, upon deserializing a model in a different machine, it is not guaranteed that outputs such as predictions on new data will be byte-by-byte reproducible due to differences in instructions sets supported by different CPUs, runtimes of backend libraries, and similar such nuances. Nevertheless, results from predictions of the same model on different machines with compatible environments (see next section) should be within numerical roundoff error. + +Serialization requirements +-------------------------- + +All estimator classes in the |sklearnex| that have a counterpart in |sklearn| (and thus participate in :ref:`patching `) inherit from that respective class from |sklearn|, and expose the same public attributes. Hence, in order to successfully serialize and deserialize a model from the |sklearnex|, it is necessary to satisfy all the requirements for serialization of |sklearn| objects, such as using the same |sklearn| version for serializing and deserializing the object - see :ref:`sklearn:pickle_persistence` for more details. + +In addition to those requirements, additional conditions need to be met in order to ensure that serialization and deserialization of objects belonging to classes from the |sklearnex| will work correctly: + +- The versions of both |sklearn| and the |sklearnex| must be the same for deserializing a given object as the versions used for serializing it. +- The version of the :external+onedal:doc:`oneDAL ` backend used for the |sklearnex| (through Python package ``dal`` or ``daal`` depending on the installation medium) must be either the same or a higher minor version within the same major version series - for example, |onedal| version 2025.10 can deserialize models saved with 2025.8, but not the other way around, and version 2026.0 might not be able to deserialize models from 2025.x versions. +- Other dependencies providing data classes that constitute object attributes, such as NumPy's arrays, must also be able to successfully serialize and deserialize in that same environment. Note that :ref:`array API classes `, which might be used as object attributes when enabling this mode, might have tighter serialization requirements than NumPy. +- The Python major version must be the same, and the minor version must be either the same or higher. + +Just like in |sklearn|, in order to ensure that deserialized models work correctly, it is highly recommended to recreate the same environment that created the serialized model in terms of python versions, package versions, and configurations of packages (e.g. build variants in the case of conda-managed environments). + +.. warning:: Note that, unlike objects from |sklearn|, objects from the |sklearnex| will not necessarily issue a warning when deserializing them with an incompatible library version. + +Serialization of GPU models +--------------------------- + +Be aware that if using the :ref:`target offload option ` to fit models on GPU or on another SyCL device, upon deserialization of those models, the internal data behind them will be re-created on host (CPU), hence the deserialized models will become CPU/host ones and will not be able to make predictions on GPU data. + +If persistence of GPU-only models is desired, one can instead use :ref:`array API classes with GPU support `, which might have a different logic for serialization that preserves the device. + +Currently, the only array API library with SyCL support known to provide serializable GPU arrays is `PyTorch `__. + +.. warning:: If serialization of models is desired, avoid usage of |dpnp| GPU arrays as they are not serilizable. + +Example: + +.. code-block:: python + + import os + os.environ["SCIPY_ARRAY_API"] = "1" + + import pickle + import torch + import numpy as np + from sklearn.datasets import make_regression + from sklearnex import config_context + from sklearnex.linear_model import LinearRegression + + X_np, y_np = make_regression() + X = torch.tensor(X_np, dtype=torch.float32, device="xpu") + y = torch.tensor(y_np, dtype=torch.float32, device="xpu") + + with config_context(array_api_dispatch=True): + model = LinearRegression().fit(X, y) + pred_fresh = model.predict(X) + + assert isinstance(pred_fresh, torch.Tensor) + + model_deserialized = pickle.loads( pickle.dumps(model) ) + with config_context(array_api_dispatch=True): + pred_deserialized = model_deserialized.predict(X) + + np.testing.assert_allclose( + pred_fresh.cpu().numpy(), + pred_deserialized.cpu().numpy(), + ) + +Configurations are not serializable +----------------------------------- + +Be aware that serialization of model objects does not imply saving of global or local configurations. For example, a model that was fitted to :ref:`array API classes ` will have those same array API classes as attributes, but array API mode is not enabled by default in |sklearn| (and by extension, not in the |sklearnex| either). Hence, if the :ref:`global configuration ` was modified to enable array API support, the deserialized model might not be usable in a new Python process until that setting (array API) is enabled. + +Likewise, other process-level internal settings, such as efficiency parameters that are modifiable through static class methods of estimators (currently undocumented), are not saved along with a model object, since they are not managed by it. From b7b0c672d5a9f7a2de5c5e9df96a068661490d1b Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 14 Oct 2025 17:07:32 +0200 Subject: [PATCH 7/8] standardize references to sycl --- doc/sources/array_api.rst | 6 +++--- doc/sources/input-types.rst | 2 +- doc/sources/oneapi-gpu.rst | 12 ++++++------ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/doc/sources/array_api.rst b/doc/sources/array_api.rst index 3901ae87a1..f493878b36 100644 --- a/doc/sources/array_api.rst +++ b/doc/sources/array_api.rst @@ -32,7 +32,7 @@ on GPU without moving the data from host to device. be :external+sklearn:doc:`enabled in scikit-learn `, which requires either changing global settings or using a ``config_context``, plus installing additional dependencies such as ``array-api-compat``. -When passing array API inputs whose data is on a SyCL-enabled device (e.g. an Intel GPU), as +When passing array API inputs whose data is on a SYCL-enabled device (e.g. an Intel GPU), as supported for example by `PyTorch `__ and |dpnp|, if array API support is enabled and the requested operation (e.g. call to ``.fit()`` / ``.predict()`` on the estimator class being used) is :ref:`supported on device/GPU `, computations @@ -50,10 +50,10 @@ through options ``allow_sklearn_after_onedal`` (default is ``True``) and ``allow If array API is enabled for |sklearn| and the estimator being used has array API support on |sklearn| (which can be verified by attribute ``array_api_support`` from :obj:`sklearn.utils.get_tags`), then array API inputs whose data -is allocated neither on CPU nor on a SyCL device will be forwarded directly to the unpatched methods from |sklearn|, +is allocated neither on CPU nor on a SYCL device will be forwarded directly to the unpatched methods from |sklearn|, without using the accelerated versions from this library, regardless of option ``allow_sklearn_after_onedal``. -While other array API inputs (e.g. torch arrays with data allocated on a non-SyCL device) might be supported +While other array API inputs (e.g. torch arrays with data allocated on a non-SYCL device) might be supported by the |sklearnex| in cases where the same class from |sklearn| doesn't support array API, note that the data will be transferred to host if it isn't already, and the computations will happen on CPU. diff --git a/doc/sources/input-types.rst b/doc/sources/input-types.rst index 28cf7f45c9..ceb61c5f80 100644 --- a/doc/sources/input-types.rst +++ b/doc/sources/input-types.rst @@ -50,6 +50,6 @@ enabled the input is unsupported). - For SciPy CSR matrix / array, index arrays are always copied. Note that sparse matrices in formats other than CSR will be converted to CSR, which implies more than just data copying. - Heterogeneous NumPy array - - If SyCL queue is provided for device without ``float64`` support but data are ``float64``, data are copied with reduced precision. + - If SYCL queue is provided for device without ``float64`` support but data are ``float64``, data are copied with reduced precision. - If :ref:`Array API ` is not enabled then data from GPU devices are always copied to the host device and then result table (for applicable methods) is copied to the source device. diff --git a/doc/sources/oneapi-gpu.rst b/doc/sources/oneapi-gpu.rst index cf3d2614d9..6cf81f4a4a 100644 --- a/doc/sources/oneapi-gpu.rst +++ b/doc/sources/oneapi-gpu.rst @@ -22,11 +22,11 @@ GPU support Overview -------- -|sklearnex| can execute computations on different devices (CPUs and GPUs, including integrated GPUs from laptops and desktops) supported by the SyCL framework. +|sklearnex| can execute computations on different devices (CPUs and GPUs, including integrated GPUs from laptops and desktops) supported by the SYCL framework. The device used for computations can be easily controlled through the ``target_offload`` option in config contexts, which moves data to GPU if it's not already there - see :ref:`config_contexts` and rest of this page for more details). -For finer-grained controlled (e.g. operating on arrays that are already in a given device's memory), it can also interact with on-device :ref:`array API classes ` like |dpnp_array|, and with SyCL-related objects from package |dpctl| such as :obj:`dpctl.SyclQueue`. +For finer-grained controlled (e.g. operating on arrays that are already in a given device's memory), it can also interact with on-device :ref:`array API classes ` like |dpnp_array|, and with SYCL-related objects from package |dpctl| such as :obj:`dpctl.SyclQueue`. .. Note:: Note that not every operation from every estimator is supported on GPU - see the :ref:`GPU support table ` for more information. See also :ref:`verbose` to verify where computations are performed. @@ -91,7 +91,7 @@ Target offload option Just like |sklearn|, the |sklearnex| can use configuration contexts and global options to modify how it interacts with different inputs - see :ref:`config_contexts` for details. -In particular, the |sklearnex| allows an option ``target_offload`` which can be passed a SyCL device name like ``"gpu"`` indicating where the operations should be performed, moving the data to that device in the process if it's not already there; or a :obj:`dpctl.SyclQueue` object from an already-existing queue on a device. +In particular, the |sklearnex| allows an option ``target_offload`` which can be passed a SYCL device name like ``"gpu"`` indicating where the operations should be performed, moving the data to that device in the process if it's not already there; or a :obj:`dpctl.SyclQueue` object from an already-existing queue on a device. .. hint:: If repeated operations are going to be performed on the same data (e.g. cross-validators, resamplers, missing data imputers, etc.), it's recommended to use the array API option instead - see the next section for details. @@ -111,7 +111,7 @@ Example: model.fit(X, y) pred = model.predict(X) - .. tab:: Passing a SyCL queue + .. tab:: Passing a SYCL queue .. code-block:: python import dpctl @@ -133,7 +133,7 @@ Example: GPU arrays through array API ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -As another option, computations can also be performed on data that is already on a SyCL device without moving it there if it belongs to an array API-compatible class, such as |dpnp_array| or `torch.tensor `__. +As another option, computations can also be performed on data that is already on a SYCL device without moving it there if it belongs to an array API-compatible class, such as |dpnp_array| or `torch.tensor `__. This is particularly useful when multiple operations are performed on the same data (e.g. cross validators, stacked ensembles, etc.), or when the data is meant to interact with other libraries besides the |sklearnex|. Be aware that it requires enabling array API support in |sklearn|, which comes with additional dependencies. @@ -220,7 +220,7 @@ Example: model.fit(X, y) -Note that, if array API had been enabled, the snippet above would use the data as-is on the device where it resides, but without array API, it implies data movements using the SyCL queue contained by those objects. +Note that, if array API had been enabled, the snippet above would use the data as-is on the device where it resides, but without array API, it implies data movements using the SYCL queue contained by those objects. .. note:: All the input data for an algorithm must reside on the same device. From 03e714260f6e41d4a596f8320a20088cfe45346a Mon Sep 17 00:00:00 2001 From: David Cortes Date: Tue, 14 Oct 2025 17:08:24 +0200 Subject: [PATCH 8/8] standardize references to sycl --- doc/sources/serialization.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/sources/serialization.rst b/doc/sources/serialization.rst index 3d835339af..fc5964b41c 100644 --- a/doc/sources/serialization.rst +++ b/doc/sources/serialization.rst @@ -67,11 +67,11 @@ Just like in |sklearn|, in order to ensure that deserialized models work correct Serialization of GPU models --------------------------- -Be aware that if using the :ref:`target offload option ` to fit models on GPU or on another SyCL device, upon deserialization of those models, the internal data behind them will be re-created on host (CPU), hence the deserialized models will become CPU/host ones and will not be able to make predictions on GPU data. +Be aware that if using the :ref:`target offload option ` to fit models on GPU or on another SYCL device, upon deserialization of those models, the internal data behind them will be re-created on host (CPU), hence the deserialized models will become CPU/host ones and will not be able to make predictions on GPU data. If persistence of GPU-only models is desired, one can instead use :ref:`array API classes with GPU support `, which might have a different logic for serialization that preserves the device. -Currently, the only array API library with SyCL support known to provide serializable GPU arrays is `PyTorch `__. +Currently, the only array API library with SYCL support known to provide serializable GPU arrays is `PyTorch `__. .. warning:: If serialization of models is desired, avoid usage of |dpnp| GPU arrays as they are not serilizable.