Skip to content
Open
21 changes: 15 additions & 6 deletions sklearnex/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License.
# ==============================================================================

"""Tools to support array_api."""
"""Tools to support array API."""

import math
from collections.abc import Callable
Expand All @@ -27,6 +27,7 @@
from daal4py.sklearn._utils import sklearn_check_version
from onedal.utils._array_api import _get_sycl_namespace, _is_numpy_namespace

from .._config import get_config
from ..base import oneDALEstimator

if sklearn_check_version("1.6"):
Expand Down Expand Up @@ -83,11 +84,19 @@ def get_namespace(*arrays):
True of the arrays are containers that implement the Array API spec.
"""

sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)

if sycl_type:
return xp, is_array_api_compliant
elif sklearn_check_version("1.2"):
# check required because _get_sycl_namespace only verifies that *arrays
# are of the same sycl namespace, not of the same array namespace.
# When array_api_dispatch is enabled, then sklearn's version is required
# for the additional array namespace check. This is now possible with
# dpnp and dpctl as they both support `__array_namespace__`.
if not get_config().get("array_api_dispatch", False):
sycl_type, xp, is_array_api_compliant = _get_sycl_namespace(*arrays)
if sycl_type:
return xp, is_array_api_compliant

# sklearn contains a specially patched numpy wrapper that should be
# reused which is yielded from sklearn's get_namespace.
if sklearn_check_version("1.2"):
return sklearn_get_namespace(*arrays)
else:
return np, False
Expand Down
Loading