Skip to content

Commit 6f3201a

Browse files
[maintenance] change logic in sklearnex get_namespace to follow sklearn array API expected behavior (#2747)
* Update _array_api.py * Update _array_api.py * Update _array_api.py * Update _array_api.py --------- Co-authored-by: David Cortes <[email protected]>
1 parent 59675a6 commit 6f3201a

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

sklearnex/utils/_array_api.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515
# ==============================================================================
1616

17-
"""Tools to support array_api."""
17+
"""Tools to support array API."""
1818

1919
import math
2020
from collections.abc import Callable
@@ -27,6 +27,7 @@
2727
from daal4py.sklearn._utils import sklearn_check_version
2828
from onedal.utils._array_api import _get_sycl_namespace, _is_numpy_namespace
2929

30+
from .._config import get_config
3031
from ..base import oneDALEstimator
3132

3233
if sklearn_check_version("1.6"):
@@ -83,11 +84,19 @@ def get_namespace(*arrays):
8384
True of the arrays are containers that implement the Array API spec.
8485
"""
8586

86-
sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)
87-
88-
if sycl_type:
89-
return xp, is_array_api_compliant
90-
elif sklearn_check_version("1.2"):
87+
# check required because _get_sycl_namespace only verifies that *arrays
88+
# are of the same sycl namespace, not of the same array namespace.
89+
# When array_api_dispatch is enabled, then sklearn's version is required
90+
# for the additional array namespace check. This is now possible with
91+
# dpnp and dpctl as they both support `__array_namespace__`.
92+
if not get_config().get("array_api_dispatch", False):
93+
sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)
94+
if sycl_type:
95+
return xp, is_array_api_compliant
96+
97+
# sklearn contains a specially patched numpy wrapper that should be
98+
# reused which is yielded from sklearn's get_namespace.
99+
if sklearn_check_version("1.2"):
91100
return sklearn_get_namespace(*arrays)
92101
else:
93102
return np, False

0 commit comments

Comments
 (0)