|
21 | 21 | from functools import lru_cache |
22 | 22 | from daal4py.sklearn._utils import daal_check_version, sklearn_check_version |
23 | 23 |
|
24 | | -# Classes for patching |
25 | | -if os.environ.get('OFF_ONEDAL_IFACE') is None and daal_check_version((2021, 'P', 300)): |
26 | | - from ._config import set_config as set_config_sklearnex |
27 | | - from ._config import get_config as get_config_sklearnex |
28 | | - from ._config import config_context as config_context_sklearnex |
29 | 24 |
|
30 | | - from .svm import SVR as SVR_sklearnex |
31 | | - from .svm import SVC as SVC_sklearnex |
32 | | - from .svm import NuSVR as NuSVR_sklearnex |
33 | | - from .svm import NuSVC as NuSVC_sklearnex |
34 | | - |
35 | | - from .neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearnex |
36 | | - from .neighbors import NearestNeighbors as NearestNeighbors_sklearnex |
37 | | - |
38 | | - new_patching_available = True |
39 | | -else: |
40 | | - new_patching_available = False |
41 | | - |
42 | | -# Scikit-learn* modules |
43 | | - |
44 | | -import sklearn as base_module |
45 | | -import sklearn.svm as svm_module |
46 | | -import sklearn.neighbors as neighbors_module |
| 25 | +def _is_new_patching_available(): |
| 26 | + return os.environ.get('OFF_ONEDAL_IFACE') is None \ |
| 27 | + and daal_check_version((2021, 'P', 300)) |
47 | 28 |
|
48 | 29 |
|
49 | 30 | @lru_cache(maxsize=None) |
50 | 31 | def get_patch_map(): |
51 | 32 | from daal4py.sklearn.monkeypatch.dispatcher import _get_map_of_algorithms |
52 | 33 | mapping = _get_map_of_algorithms().copy() |
53 | 34 |
|
54 | | - if new_patching_available: |
| 35 | + if _is_new_patching_available(): |
| 36 | + # Classes for patching |
| 37 | + |
| 38 | + from ._config import set_config as set_config_sklearnex |
| 39 | + from ._config import get_config as get_config_sklearnex |
| 40 | + from ._config import config_context as config_context_sklearnex |
| 41 | + |
| 42 | + from .svm import SVR as SVR_sklearnex |
| 43 | + from .svm import SVC as SVC_sklearnex |
| 44 | + from .svm import NuSVR as NuSVR_sklearnex |
| 45 | + from .svm import NuSVC as NuSVC_sklearnex |
| 46 | + |
| 47 | + from .neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearnex |
| 48 | + from .neighbors import NearestNeighbors as NearestNeighbors_sklearnex |
| 49 | + |
| 50 | + # Scikit-learn* modules |
| 51 | + |
| 52 | + import sklearn as base_module |
| 53 | + import sklearn.svm as svm_module |
| 54 | + import sklearn.neighbors as neighbors_module |
| 55 | + |
| 56 | + # Patch for mapping |
55 | 57 | # Algorithms |
56 | 58 | # SVM |
57 | 59 | mapping.pop('svm') |
@@ -101,7 +103,7 @@ def patch_sklearn(name=None, verbose=True, global_patch=False): |
101 | 103 |
|
102 | 104 | from daal4py.sklearn import patch_sklearn as patch_sklearn_orig |
103 | 105 |
|
104 | | - if new_patching_available: |
| 106 | + if _is_new_patching_available(): |
105 | 107 | for config in ['set_config', 'get_config', 'config_context']: |
106 | 108 | patch_sklearn_orig(config, verbose=False, deprecation=False, |
107 | 109 | get_map=get_patch_map) |
@@ -129,7 +131,7 @@ def unpatch_sklearn(name=None, global_unpatch=False): |
129 | 131 | for algorithm in name: |
130 | 132 | unpatch_sklearn_orig(algorithm, get_map=get_patch_map) |
131 | 133 | else: |
132 | | - if new_patching_available: |
| 134 | + if _is_new_patching_available(): |
133 | 135 | for config in ['set_config', 'get_config', 'config_context']: |
134 | 136 | unpatch_sklearn_orig(config, get_map=get_patch_map) |
135 | 137 | unpatch_sklearn_orig(name, get_map=get_patch_map) |
0 commit comments