Skip to content

Commit 09902a8

Browse files
Device support in Intel(R) Extension for Scikit-learn* (#772) (#787)
* policy support in onedal * add global config to sklearnex * last changes * refactor device offload * refactor onedal * fix pep8 onedal * refactor sklearnex * add patching for configs * enable devices for fit_proba, predict_proba is svm * add usm_ndarray support in daal4py * reset config in test_config * fix dispatch in SVR, NuSVR * remove sklearex deps in AdaBoost and GBT * fix pep * implement device dispatch for daal4py.sklearn * update support_usm_ndarray decorator usage * fix daal4py dispatch * pep fix * fix for no data parameters in device offload * deselect config_context test * add requirements file for dppy * fix daal4py dispatching * fix SVC pathing under multithreaded sklearn metaestimators * always patch configs in sklearnex * remove cyclic deps between sklearnex and daal4py * add host device in patch logging * fix pep * update deselected tests * fix device offload connection between sklearnex and daal4py * separate device utils in daal4py from common utils * fix config patching * fix pep * fix device offload in daal4py * debug print * fix patching for previous dal versions * unbound sklearnex from dpctl * pep fix * fix import of host binary * switch off onedal datatypes tests on GPU * pep fix * turn off sparse gpu support in svc * add debug print * fix verbose mode in case dpctl is not installed * switch off linear kernel tests on gpu (cherry picked from commit 9866376) Co-authored-by: Michael Smirnov <[email protected]>
1 parent 76e0cb7 commit 09902a8

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

50 files changed

+1467
-479
lines changed

daal4py/oneapi/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
from daal4py._oneapi import (
3939
_get_sycl_ctxt,
4040
_get_device_name_sycl_ctxt,
41-
_get_sycl_ctxt_params
41+
_get_sycl_ctxt_params,
42+
_get_in_sycl_ctxt
4243
)
4344
except ModuleNotFoundError:
4445
raise

daal4py/sklearn/_device_offload.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
#===============================================================================
2+
# Copyright 2014-2021 Intel Corporation
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#===============================================================================
16+
17+
from functools import wraps
18+
19+
try:
20+
from sklearnex._config import get_config
21+
from sklearnex._device_offload import (_get_global_queue,
22+
_transfer_to_host,
23+
_copy_to_usm)
24+
_sklearnex_available = True
25+
except ImportError:
26+
import logging
27+
logging.warning('Device support is limited in daal4py patching. '
28+
'Use Intel(R) Extension for Scikit-learn* '
29+
'for full experience.')
30+
_sklearnex_available = False
31+
32+
33+
def _get_host_inputs(*args, **kwargs):
34+
q = _get_global_queue()
35+
q, hostargs = _transfer_to_host(q, *args)
36+
q, hostvalues = _transfer_to_host(q, *kwargs.values())
37+
hostkwargs = dict(zip(kwargs.keys(), hostvalues))
38+
return q, hostargs, hostkwargs
39+
40+
41+
def _extract_usm_iface(*args, **kwargs):
42+
allargs = (*args, *kwargs.values())
43+
if len(allargs) == 0:
44+
return None
45+
return getattr(allargs[0],
46+
'__sycl_usm_array_interface__',
47+
None)
48+
49+
50+
def _run_on_device(func, queue, obj=None, *args, **kwargs):
51+
def dispatch_by_obj(obj, func, *args, **kwargs):
52+
if obj is not None:
53+
return func(obj, *args, **kwargs)
54+
return func(*args, **kwargs)
55+
56+
if queue is not None:
57+
from daal4py.oneapi import sycl_context, _get_in_sycl_ctxt
58+
59+
if _get_in_sycl_ctxt() is False:
60+
host_offload = get_config()['allow_fallback_to_host']
61+
62+
with sycl_context('gpu' if queue.sycl_device.is_gpu else 'cpu',
63+
host_offload_on_fail=host_offload):
64+
return dispatch_by_obj(obj, func, *args, **kwargs)
65+
return dispatch_by_obj(obj, func, *args, **kwargs)
66+
67+
68+
def support_usm_ndarray(freefunc=False):
69+
def decorator(func):
70+
def wrapper_impl(obj, *args, **kwargs):
71+
if _sklearnex_available:
72+
usm_iface = _extract_usm_iface(*args, **kwargs)
73+
q, hostargs, hostkwargs = _get_host_inputs(*args, **kwargs)
74+
result = _run_on_device(func, q, obj, *hostargs, **hostkwargs)
75+
if usm_iface is not None and hasattr(result, '__array_interface__'):
76+
return _copy_to_usm(q, result)
77+
return result
78+
return _run_on_device(func, None, obj, *args, **kwargs)
79+
80+
if freefunc:
81+
@wraps(func)
82+
def wrapper_free(*args, **kwargs):
83+
return wrapper_impl(None, *args, **kwargs)
84+
return wrapper_free
85+
else:
86+
@wraps(func)
87+
def wrapper_with_self(self, *args, **kwargs):
88+
return wrapper_impl(self, *args, **kwargs)
89+
return wrapper_with_self
90+
return decorator

daal4py/sklearn/_utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#===============================================================================
1616

1717
import numpy as np
18+
import sys
1819

1920
from daal4py import _get__daal_link_version__ as dv
2021
from sklearn import __version__ as sklearn_version
@@ -25,7 +26,6 @@ def set_idp_sklearn_verbose():
2526
import logging
2627
import warnings
2728
import os
28-
import sys
2929
logLevel = os.environ.get("IDP_SKLEARN_VERBOSE")
3030
try:
3131
if logLevel is not None:
@@ -94,7 +94,6 @@ def make2d(X):
9494

9595

9696
def get_patch_message(s):
97-
import sys
9897
if s == "daal":
9998
message = "running accelerated version on "
10099
if 'daal4py.oneapi' in sys.modules:

daal4py/sklearn/cluster/_dbscan.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
from daal4py.sklearn._utils import (make2d, getFPType, get_patch_message)
2727
import logging
2828

29+
from .._device_offload import support_usm_ndarray
30+
2931

3032
def _daal_dbscan(X, eps=0.5, min_samples=5, sample_weight=None):
3133
ww = make2d(sample_weight) if sample_weight is not None else None
@@ -203,6 +205,7 @@ def __init__(
203205
self.p = p
204206
self.n_jobs = n_jobs
205207

208+
@support_usm_ndarray()
206209
def fit(self, X, y=None, sample_weight=None):
207210
"""Perform DBSCAN clustering from features, or distance matrix.
208211
@@ -258,3 +261,7 @@ def fit(self, X, y=None, sample_weight=None):
258261
"sklearn.cluster.DBSCAN."
259262
"fit: " + get_patch_message("sklearn"))
260263
return super().fit(X, y, sample_weight=sample_weight)
264+
265+
@support_usm_ndarray()
266+
def fit_predict(self, X, y=None, sample_weight=None):
267+
return super().fit_predict(X, y, sample_weight)

daal4py/sklearn/cluster/_k_means_0_22.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535

3636
import daal4py
3737
from .._utils import getFPType, get_patch_message, daal_check_version
38+
from .._device_offload import support_usm_ndarray
3839
import logging
3940

4041

@@ -320,8 +321,14 @@ def __init__(self, n_clusters=8, init='k-means++', n_init=10,
320321
n_init=n_init, verbose=verbose, random_state=random_state,
321322
copy_x=copy_x, n_jobs=n_jobs, algorithm=algorithm)
322323

324+
@support_usm_ndarray()
323325
def fit(self, X, y=None, sample_weight=None):
324326
return _fit(self, X, y=y, sample_weight=sample_weight)
325327

328+
@support_usm_ndarray()
326329
def predict(self, X, sample_weight=None):
327330
return _predict(self, X, sample_weight=sample_weight)
331+
332+
@support_usm_ndarray()
333+
def fit_predict(self, X, y=None, sample_weight=None):
334+
return super().fit_predict(X, y, sample_weight)

daal4py/sklearn/cluster/_k_means_0_23.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
getFPType,
4040
get_patch_message,
4141
sklearn_check_version)
42+
from .._device_offload import support_usm_ndarray
4243
import logging
4344

4445

@@ -430,8 +431,14 @@ def __init__(
430431
algorithm=algorithm,
431432
)
432433

434+
@support_usm_ndarray()
433435
def fit(self, X, y=None, sample_weight=None):
434436
return _fit(self, X, y=y, sample_weight=sample_weight)
435437

438+
@support_usm_ndarray()
436439
def predict(self, X, sample_weight=None):
437440
return _predict(self, X, sample_weight=sample_weight)
441+
442+
@support_usm_ndarray()
443+
def fit_predict(self, X, y=None, sample_weight=None):
444+
return super().fit_predict(X, y, sample_weight)

daal4py/sklearn/decomposition/_pca.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
import daal4py
2727
from .._utils import getFPType, get_patch_message, sklearn_check_version
28+
from .._device_offload import support_usm_ndarray
2829
import logging
2930

3031
if sklearn_check_version('0.22'):
@@ -278,6 +279,7 @@ def _transform_daal4py(self, X, whiten=False, scale_eigenvalues=True, check_X=Tr
278279

279280
return tr_res.transformedData
280281

282+
@support_usm_ndarray()
281283
def transform(self, X):
282284
if self.n_components_ > 0:
283285
logging.info(
@@ -291,6 +293,7 @@ def transform(self, X):
291293
"transform: " + get_patch_message("sklearn"))
292294
return PCA_original.transform(self, X)
293295

296+
@support_usm_ndarray()
294297
def fit_transform(self, X, y=None):
295298
U, S, _ = self._fit(X)
296299

daal4py/sklearn/ensemble/_forest.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import daal4py
2323
from .._utils import (getFPType, get_patch_message)
24+
from .._device_offload import support_usm_ndarray
2425
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
2526
import logging
2627

@@ -615,6 +616,7 @@ def __init__(self,
615616
self.maxBins = maxBins
616617
self.minBinSize = minBinSize
617618

619+
@support_usm_ndarray()
618620
def fit(self, X, y, sample_weight=None):
619621
"""
620622
Build a forest of trees from the training set (X, y).
@@ -643,6 +645,7 @@ def fit(self, X, y, sample_weight=None):
643645
"""
644646
return _fit_classifier(self, X, y, sample_weight=sample_weight)
645647

648+
@support_usm_ndarray()
646649
def predict(self, X):
647650
"""
648651
Predict class for X.
@@ -682,6 +685,7 @@ def predict(self, X):
682685
"predict: " + get_patch_message("daal"))
683686
return _daal_predict_classifier(self, X)
684687

688+
@support_usm_ndarray()
685689
def predict_proba(self, X):
686690
"""
687691
Predict class probabilities for X.
@@ -904,6 +908,7 @@ def __init__(self,
904908
self.maxBins = maxBins
905909
self.minBinSize = minBinSize
906910

911+
@support_usm_ndarray()
907912
def fit(self, X, y, sample_weight=None):
908913
"""
909914
Build a forest of trees from the training set (X, y).
@@ -932,6 +937,7 @@ def fit(self, X, y, sample_weight=None):
932937
"""
933938
return _fit_regressor(self, X, y, sample_weight=sample_weight)
934939

940+
@support_usm_ndarray()
935941
def predict(self, X):
936942
"""
937943
Predict class for X.

daal4py/sklearn/linear_model/_coordinate_descent.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from sklearn.exceptions import ConvergenceWarning
3434
from sklearn.preprocessing import normalize
3535

36+
from .._device_offload import support_usm_ndarray
37+
3638

3739
def _daal4py_check(self, X, y, check_input):
3840
_fptype = getFPType(X)
@@ -502,12 +504,15 @@ def __init__(
502504
)
503505

504506
if sklearn_check_version('0.23'):
507+
@support_usm_ndarray()
505508
def fit(self, X, y, sample_weight=None, check_input=True):
506509
return _fit(self, X, y, sample_weight=sample_weight, check_input=check_input)
507510
else:
511+
@support_usm_ndarray()
508512
def fit(self, X, y, check_input=True):
509513
return _fit(self, X, y, check_input=check_input)
510514

515+
@support_usm_ndarray()
511516
def predict(self, X):
512517
"""Predict using the linear model
513518
@@ -642,12 +647,15 @@ def __init__(
642647
)
643648

644649
if sklearn_check_version('0.23'):
650+
@support_usm_ndarray()
645651
def fit(self, X, y, sample_weight=None, check_input=True):
646652
return _fit(self, X, y, sample_weight, check_input)
647653
else:
654+
@support_usm_ndarray()
648655
def fit(self, X, y, check_input=True):
649656
return _fit(self, X, y, check_input)
650657

658+
@support_usm_ndarray()
651659
def predict(self, X):
652660
"""Predict using the linear model
653661

daal4py/sklearn/linear_model/_linear.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..utils.validation import _daal_check_array, _daal_check_X_y
2323
from ..utils.base import _daal_validate_data
2424
from .._utils import sklearn_check_version
25+
from .._device_offload import support_usm_ndarray
2526
from sklearn.utils.fixes import sparse_lsqr
2627
from sklearn.utils.validation import _check_sample_weight
2728
from sklearn.utils import check_array
@@ -253,6 +254,7 @@ def __init__(
253254
n_jobs=n_jobs
254255
)
255256

257+
@support_usm_ndarray()
256258
def fit(self, X, y, sample_weight=None):
257259
if sklearn_check_version('1.0'):
258260
self._normalize = _deprecate_normalize(
@@ -271,5 +273,6 @@ def fit(self, X, y, sample_weight=None):
271273
)
272274
return _fit_linear(self, X, y, sample_weight=sample_weight)
273275

276+
@support_usm_ndarray()
274277
def predict(self, X):
275278
return _predict_linear(self, X)

0 commit comments

Comments
 (0)