diff --git a/onedal/tests/utils/_device_selection.py b/onedal/tests/utils/_device_selection.py index 4b1eb3c5ed..e84f52931a 100644 --- a/onedal/tests/utils/_device_selection.py +++ b/onedal/tests/utils/_device_selection.py @@ -15,17 +15,23 @@ # ============================================================================== import functools +from collections.abc import Iterable import pytest -from ...utils._third_party import dpctl_available +from onedal.utils._third_party import SyclQueue, dpctl_available if dpctl_available: import dpctl - from dpctl.memory import MemoryUSMDevice, MemoryUSMShared + queue_creation_err = dpctl._sycl_queue.SyclQueueCreationError +else: + queue_creation_err = (RuntimeError, ValueError) -def get_queues(filter_="cpu,gpu"): + +# lru_cache is used to limit the number of SyclQueues generated +# @functools.lru_cache() +def get_queues(filter_: str = "cpu,gpu") -> list[SyclQueue]: """Get available dpctl.SycQueues for testing. This is meant to be used for testing purposes only. @@ -33,12 +39,14 @@ def get_queues(filter_="cpu,gpu"): Parameters ---------- filter_ : str, default="cpu,gpu" - Configure output list with available dpctl.SycQueues for testing. + Configure output list with availabe SyclQueues for testing. + SyclQueues are generated from a comma-separated string with + each element conforming to SYCL's ``filter_selector``. Returns ------- - list[dpctl.SycQueue] - The list of dpctl.SycQueue. + list[SyclQueue] + The list of SyclQueues. Notes ----- @@ -47,32 +55,41 @@ def get_queues(filter_="cpu,gpu"): """ queues = [None] if "cpu" in filter_ else [] - if dpctl_available: - if dpctl.has_cpu_devices() and "cpu" in filter_: - queues.append(pytest.param(dpctl.SyclQueue("cpu"), id="SyclQueue_CPU")) - if dpctl.has_gpu_devices() and "gpu" in filter_: - queues.append(pytest.param(dpctl.SyclQueue("gpu"), id="SyclQueue_GPU")) + for i in filter_.split(","): + try: + queues.append(pytest.param(SyclQueue(i), id=f"SyclQueue_{i.upper()}")) + except queue_creation_err: + pass return queues -def get_memory_usm(): - if dpctl_available: - return [MemoryUSMDevice, MemoryUSMShared] - return [] +def is_sycl_device_available(targets: Iterable[str]) -> bool: + """Check if a SYCL device is available. + + This is meant to be used for testing purposes only. + The check succeeds if all SYCL devices in targets are + available. + Parameters + ---------- + targets : Iterable[str] + SYCL filter strings of possible devices. -def is_dpctl_device_available(targets): - if not isinstance(targets, (list, tuple)): - raise TypeError("`targets` should be a list or tuple of strings.") - if dpctl_available: - for device in targets: - if device == "cpu" and not dpctl.has_cpu_devices(): - return False - if device == "gpu" and not dpctl.has_gpu_devices(): - return False - return True - return False + Returns + ------- + bool + Flag if all of the SYCL targets are available. + + """ + if not isinstance(targets, Iterable): + raise TypeError("`targets` should be an iterable of strings.") + for device in targets: + try: + SyclQueue(device) + except queue_creation_err: + return False + return True def pass_if_not_implemented_for_gpu(reason=""): diff --git a/sklearnex/tests/test_config.py b/sklearnex/tests/test_config.py index 317b7531ed..a93d920260 100644 --- a/sklearnex/tests/test_config.py +++ b/sklearnex/tests/test_config.py @@ -24,7 +24,7 @@ import onedal import sklearnex -from onedal.tests.utils._device_selection import is_dpctl_device_available +from onedal.tests.utils._device_selection import is_sycl_device_available def test_get_config_contains_sklearn_params(): @@ -152,7 +152,7 @@ def test_host_backend_target_offload(target): @pytest.mark.skipif( - not is_dpctl_device_available(["gpu"]), reason="Requires a gpu for fallback testing" + not is_sycl_device_available(["gpu"]), reason="Requires a gpu for fallback testing" ) def test_fallback_to_host(caplog): # force a fallback to cpu with direct use of dispatch and PatchingConditionsChain diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py index f6f84b45ae..e8586f2e52 100644 --- a/sklearnex/tests/test_memory_usage.py +++ b/sklearnex/tests/test_memory_usage.py @@ -33,7 +33,7 @@ _convert_to_dataframe, get_dataframes_and_queues, ) -from onedal.tests.utils._device_selection import get_queues, is_dpctl_device_available +from onedal.tests.utils._device_selection import get_queues, is_sycl_device_available from onedal.utils._array_api import _get_sycl_namespace from sklearnex import config_context from sklearnex.tests.utils import ( @@ -288,7 +288,7 @@ def test_memory_leaks(estimator, dataframe, queue, order, data_shape): @pytest.mark.skipif( - os.getenv("ZES_ENABLE_SYSMAN") is None or not is_dpctl_device_available(["gpu"]), + os.getenv("ZES_ENABLE_SYSMAN") is None or not is_sycl_device_available(["gpu"]), reason="SYCL device memory leak check requires the level zero sysman", ) @pytest.mark.parametrize("queue", get_queues("gpu"))