|
14 | 14 | # limitations under the License. |
15 | 15 | # ============================================================================== |
16 | 16 |
|
17 | | -"""Tools to support array_api.""" |
| 17 | +"""Tools to support array API.""" |
18 | 18 |
|
19 | 19 | import math |
20 | 20 | from collections.abc import Callable |
|
27 | 27 | from daal4py.sklearn._utils import sklearn_check_version |
28 | 28 | from onedal.utils._array_api import _get_sycl_namespace, _is_numpy_namespace |
29 | 29 |
|
| 30 | +from .._config import get_config |
30 | 31 | from ..base import oneDALEstimator |
31 | 32 |
|
32 | 33 | if sklearn_check_version("1.6"): |
@@ -83,11 +84,19 @@ def get_namespace(*arrays): |
83 | 84 | True of the arrays are containers that implement the Array API spec. |
84 | 85 | """ |
85 | 86 |
|
86 | | - sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays) |
87 | | - |
88 | | - if sycl_type: |
89 | | - return xp, is_array_api_compliant |
90 | | - elif sklearn_check_version("1.2"): |
| 87 | + # check required because _get_sycl_namespace only verifies that *arrays |
| 88 | + # are of the same sycl namespace, not of the same array namespace. |
| 89 | + # When array_api_dispatch is enabled, then sklearn's version is required |
| 90 | + # for the additional array namespace check. This is now possible with |
| 91 | + # dpnp and dpctl as they both support `__array_namespace__`. |
| 92 | + if not get_config().get("array_api_dispatch", False): |
| 93 | + sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays) |
| 94 | + if sycl_type: |
| 95 | + return xp, is_array_api_compliant |
| 96 | + |
| 97 | + # sklearn contains a specially patched numpy wrapper that should be |
| 98 | + # reused which is yielded from sklearn's get_namespace. |
| 99 | + if sklearn_check_version("1.2"): |
91 | 100 | return sklearn_get_namespace(*arrays) |
92 | 101 | else: |
93 | 102 | return np, False |
|
0 commit comments