diff --git a/python/cudf_polars/cudf_polars/utils/config.py b/python/cudf_polars/cudf_polars/utils/config.py index d774b2a6624..6cc10e77049 100644 --- a/python/cudf_polars/cudf_polars/utils/config.py +++ b/python/cudf_polars/cudf_polars/utils/config.py @@ -1096,12 +1096,21 @@ def from_polars_engine( cuda_stream_policy = _convert_cuda_stream_policy(user_cuda_stream_policy) # Pool policy is only supported by the rapidsmpf runtime. - if isinstance(cuda_stream_policy, CUDAStreamPoolConfig) and ( + is_pool = isinstance(cuda_stream_policy, CUDAStreamPoolConfig) + if is_pool and ( (executor.name != "streaming") or (executor.name == "streaming" and executor.runtime != Runtime.RAPIDSMPF) ): raise ValueError( - "CUDAStreamPolicy.POOL is only supported by the rapidsmpf runtime." + "The rapidsmpf pool policy is only supported with 'runtime=\"rapidsmpf\"'." + ) + + elif not is_pool and ( + executor.name == "streaming" and executor.runtime == Runtime.RAPIDSMPF + ): + # Validate that we're using the rapidsmpf pool with the rapidsmpf runtime. + raise ValueError( + f"The rapidsmpf runtime must use the rapidsmpf pool policy, not {cuda_stream_policy}." ) kwargs["cuda_stream_policy"] = cuda_stream_policy diff --git a/python/cudf_polars/tests/test_config.py b/python/cudf_polars/tests/test_config.py index a1c35e08b79..ca31e4b9581 100644 --- a/python/cudf_polars/tests/test_config.py +++ b/python/cudf_polars/tests/test_config.py @@ -796,10 +796,31 @@ def test_cuda_stream_policy_default_rapidsmpf(monkeypatch: pytest.MonkeyPatch) - # "new" user argument monkeypatch.setenv("CUDF_POLARS__CUDA_STREAM_POLICY", "new") - config = ConfigOptions.from_polars_engine( - pl.GPUEngine(executor_options={"runtime": "rapidsmpf"}) - ) - assert config.cuda_stream_policy == CUDAStreamPolicy.NEW + with pytest.raises( + ValueError, + match="The rapidsmpf runtime must use the rapidsmpf pool policy, not CUDAStreamPolicy.NEW", + ): + config = ConfigOptions.from_polars_engine( + pl.GPUEngine(executor_options={"runtime": "rapidsmpf"}) + ) + + +@pytest.mark.parametrize( + "cuda_stream_policy", [CUDAStreamPolicy.NEW, CUDAStreamPolicy.DEFAULT] +) +def test_rapidsmpf_runtime_requires_pool_policy( + cuda_stream_policy: CUDAStreamPolicy, +) -> None: + with pytest.raises( + ValueError, + match=f"The rapidsmpf runtime must use the rapidsmpf pool policy, not {cuda_stream_policy}", + ): + ConfigOptions.from_polars_engine( + pl.GPUEngine( + executor_options={"runtime": "rapidsmpf"}, + cuda_stream_policy=cuda_stream_policy, + ) + ) @pytest.mark.parametrize( @@ -814,7 +835,7 @@ def test_cuda_stream_policy_pool_only_supported_by_rapidsmpf( ) -> None: with pytest.raises( ValueError, - match="CUDAStreamPolicy.POOL is only supported by the rapidsmpf runtime.", + match="The rapidsmpf pool policy is only supported with", ): ConfigOptions.from_polars_engine( pl.GPUEngine(