Skip to content

Commit b12c4f5

Browse files
authored
Fix recursive imports for daal4py module (#978)
1 parent 90d1e45 commit b12c4f5

File tree

1 file changed

+27
-25
lines changed

1 file changed

+27
-25
lines changed

sklearnex/dispatcher.py

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,37 +21,39 @@
2121
from functools import lru_cache
2222
from daal4py.sklearn._utils import daal_check_version, sklearn_check_version
2323

24-
# Classes for patching
25-
if os.environ.get('OFF_ONEDAL_IFACE') is None and daal_check_version((2021, 'P', 300)):
26-
from ._config import set_config as set_config_sklearnex
27-
from ._config import get_config as get_config_sklearnex
28-
from ._config import config_context as config_context_sklearnex
2924

30-
from .svm import SVR as SVR_sklearnex
31-
from .svm import SVC as SVC_sklearnex
32-
from .svm import NuSVR as NuSVR_sklearnex
33-
from .svm import NuSVC as NuSVC_sklearnex
34-
35-
from .neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearnex
36-
from .neighbors import NearestNeighbors as NearestNeighbors_sklearnex
37-
38-
new_patching_available = True
39-
else:
40-
new_patching_available = False
41-
42-
# Scikit-learn* modules
43-
44-
import sklearn as base_module
45-
import sklearn.svm as svm_module
46-
import sklearn.neighbors as neighbors_module
25+
def _is_new_patching_available():
26+
return os.environ.get('OFF_ONEDAL_IFACE') is None \
27+
and daal_check_version((2021, 'P', 300))
4728

4829

4930
@lru_cache(maxsize=None)
5031
def get_patch_map():
5132
from daal4py.sklearn.monkeypatch.dispatcher import _get_map_of_algorithms
5233
mapping = _get_map_of_algorithms().copy()
5334

54-
if new_patching_available:
35+
if _is_new_patching_available():
36+
# Classes for patching
37+
38+
from ._config import set_config as set_config_sklearnex
39+
from ._config import get_config as get_config_sklearnex
40+
from ._config import config_context as config_context_sklearnex
41+
42+
from .svm import SVR as SVR_sklearnex
43+
from .svm import SVC as SVC_sklearnex
44+
from .svm import NuSVR as NuSVR_sklearnex
45+
from .svm import NuSVC as NuSVC_sklearnex
46+
47+
from .neighbors import KNeighborsClassifier as KNeighborsClassifier_sklearnex
48+
from .neighbors import NearestNeighbors as NearestNeighbors_sklearnex
49+
50+
# Scikit-learn* modules
51+
52+
import sklearn as base_module
53+
import sklearn.svm as svm_module
54+
import sklearn.neighbors as neighbors_module
55+
56+
# Patch for mapping
5557
# Algorithms
5658
# SVM
5759
mapping.pop('svm')
@@ -101,7 +103,7 @@ def patch_sklearn(name=None, verbose=True, global_patch=False):
101103

102104
from daal4py.sklearn import patch_sklearn as patch_sklearn_orig
103105

104-
if new_patching_available:
106+
if _is_new_patching_available():
105107
for config in ['set_config', 'get_config', 'config_context']:
106108
patch_sklearn_orig(config, verbose=False, deprecation=False,
107109
get_map=get_patch_map)
@@ -129,7 +131,7 @@ def unpatch_sklearn(name=None, global_unpatch=False):
129131
for algorithm in name:
130132
unpatch_sklearn_orig(algorithm, get_map=get_patch_map)
131133
else:
132-
if new_patching_available:
134+
if _is_new_patching_available():
133135
for config in ['set_config', 'get_config', 'config_context']:
134136
unpatch_sklearn_orig(config, get_map=get_patch_map)
135137
unpatch_sklearn_orig(name, get_map=get_patch_map)

0 commit comments

Comments
 (0)