diff --git a/tests/config/test_config_generation.py b/tests/config/test_config_generation.py index 024e81fccc5..e37b6b95941 100644 --- a/tests/config/test_config_generation.py +++ b/tests/config/test_config_generation.py @@ -36,3 +36,36 @@ def create_config(): assert deep_compare(normal_config_dict, empty_config_dict), ( "Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\"" " should be equivalent") + + +def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch): + # In testing, this method needs to be nested inside as ray does not + # see the test module. + def create_config(): + engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite", + trust_remote_code=True) + return engine_args.create_engine_config() + + config = create_config() + parallel_config = config.parallel_config + assert parallel_config.ray_runtime_env is None + + import ray + ray.init() + + runtime_env = { + "env_vars": { + "TEST_ENV_VAR": "test_value", + }, + } + + config_ref = ray.remote(create_config).options( + runtime_env=runtime_env).remote() + + config = ray.get(config_ref) + parallel_config = config.parallel_config + assert parallel_config.ray_runtime_env is not None + assert parallel_config.ray_runtime_env.env_vars().get( + "TEST_ENV_VAR") == "test_value" + + ray.shutdown() diff --git a/vllm/config.py b/vllm/config.py index f038cdd64c6..08bd59be7c1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -56,6 +56,7 @@ if TYPE_CHECKING: from _typeshed import DataclassInstance + from ray.runtime_env import RuntimeEnv from ray.util.placement_group import PlacementGroup from transformers.configuration_utils import PretrainedConfig @@ -73,6 +74,7 @@ else: DataclassInstance = Any PlacementGroup = Any + RuntimeEnv = Any PretrainedConfig = Any ExecutorBase = Any QuantizationConfig = Any @@ -1950,6 +1952,9 @@ class ParallelConfig: ray_workers_use_nsight: bool = False """Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.""" + ray_runtime_env: Optional["RuntimeEnv"] = None + """Ray runtime environment to pass to distributed workers.""" + placement_group: Optional["PlacementGroup"] = None """ray distributed model workers placement group.""" diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index aec75f82631..ba40262e538 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -36,6 +36,7 @@ from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform from vllm.plugins import load_general_plugins +from vllm.ray.lazy_utils import is_ray_initialized from vllm.reasoning import ReasoningParserManager from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.transformers_utils.utils import check_gguf_file @@ -1060,6 +1061,15 @@ def create_engine_config( calculate_kv_scales=self.calculate_kv_scales, ) + ray_runtime_env = None + if is_ray_initialized(): + # Ray Serve LLM calls `create_engine_config` in the context + # of a Ray task, therefore we check is_ray_initialized() + # as opposed to is_in_ray_actor(). + import ray + ray_runtime_env = ray.get_runtime_context().runtime_env + logger.info("Using ray runtime env: %s", ray_runtime_env) + # Get the current placement group if Ray is initialized and # we are in a Ray actor. If so, then the placement group will be # passed to spawned processes. @@ -1172,6 +1182,7 @@ def create_engine_config( max_parallel_loading_workers=self.max_parallel_loading_workers, disable_custom_all_reduce=self.disable_custom_all_reduce, ray_workers_use_nsight=self.ray_workers_use_nsight, + ray_runtime_env=ray_runtime_env, placement_group=placement_group, distributed_executor_backend=self.distributed_executor_backend, worker_cls=self.worker_cls, diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index 033ecc00853..7abaffa54c0 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -295,9 +295,12 @@ def initialize_ray_cluster( logger.warning( "No existing RAY instance detected. " "A new instance will be launched with current node resources.") - ray.init(address=ray_address, num_gpus=parallel_config.world_size) + ray.init(address=ray_address, + num_gpus=parallel_config.world_size, + runtime_env=parallel_config.ray_runtime_env) else: - ray.init(address=ray_address) + ray.init(address=ray_address, + runtime_env=parallel_config.ray_runtime_env) device_str = current_platform.ray_device_key if not device_str: diff --git a/vllm/ray/lazy_utils.py b/vllm/ray/lazy_utils.py new file mode 100644 index 00000000000..bb3535579cf --- /dev/null +++ b/vllm/ray/lazy_utils.py @@ -0,0 +1,22 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +def is_ray_initialized(): + """Check if Ray is initialized.""" + try: + import ray + return ray.is_initialized() + except ImportError: + return False + + +def is_in_ray_actor(): + """Check if we are in a Ray actor.""" + + try: + import ray + return (ray.is_initialized() + and ray.get_runtime_context().get_actor_id() is not None) + except ImportError: + return False diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 5b9c3b6a50c..afacea1eac6 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -71,6 +71,7 @@ import vllm.envs as envs from vllm.logger import enable_trace_function_call, init_logger +from vllm.ray.lazy_utils import is_in_ray_actor if TYPE_CHECKING: from argparse import Namespace @@ -2864,17 +2865,6 @@ def zmq_socket_ctx( ctx.destroy(linger=linger) -def is_in_ray_actor(): - """Check if we are in a Ray actor.""" - - try: - import ray - return (ray.is_initialized() - and ray.get_runtime_context().get_actor_id() is not None) - except ImportError: - return False - - def _maybe_force_spawn(): """Check if we need to force the use of the `spawn` multiprocessing start method.