diff --git a/vllm/config.py b/vllm/config.py
index 22f740171369..80d7d4e17602 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -56,6 +56,7 @@
 
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
+    from ray.runtime_env import RuntimeEnv
     from ray.util.placement_group import PlacementGroup
     from transformers.configuration_utils import PretrainedConfig
 
@@ -73,6 +74,7 @@
 else:
     DataclassInstance = Any
     PlacementGroup = Any
+    RuntimeEnv = Any
     PretrainedConfig = Any
     ExecutorBase = Any
     QuantizationConfig = Any
@@ -1902,6 +1904,9 @@ class ParallelConfig:
     placement_group: Optional["PlacementGroup"] = None
     """ray distributed model workers placement group."""
 
+    runtime_env: Optional["RuntimeEnv"] = None
+    """ray runtime environment for distributed workers"""
+
     distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
                                                  type["ExecutorBase"]]] = None
     """Backend to use for distributed model
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index ae5eb46fa967..d1667e78b87f 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1088,12 +1088,14 @@ def create_engine_config(
         # we are in a Ray actor. If so, then the placement group will be
         # passed to spawned processes.
         placement_group = None
+        runtime_env = None
         if is_in_ray_actor():
             import ray
 
             # This call initializes Ray automatically if it is not initialized,
             # but we should not do this here.
             placement_group = ray.util.get_current_placement_group()
+            runtime_env = ray.get_runtime_context().runtime_env
 
         data_parallel_external_lb = self.data_parallel_rank is not None
         if data_parallel_external_lb:
@@ -1170,6 +1172,7 @@
             disable_custom_all_reduce=self.disable_custom_all_reduce,
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             placement_group=placement_group,
+            runtime_env=runtime_env,
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index c222f1609096..64f736d525ab 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -288,14 +288,16 @@
     elif current_platform.is_rocm() or current_platform.is_xpu():
         # Try to connect existing ray instance and create a new one if not found
         try:
-            ray.init("auto")
+            ray.init("auto", runtime_env=parallel_config.runtime_env)
         except ConnectionError:
             logger.warning(
                 "No existing RAY instance detected. "
                 "A new instance will be launched with current node resources.")
-            ray.init(address=ray_address, num_gpus=parallel_config.world_size)
+            ray.init(address=ray_address,
+                     num_gpus=parallel_config.world_size,
+                     runtime_env=parallel_config.runtime_env)
     else:
-        ray.init(address=ray_address)
+        ray.init(address=ray_address, runtime_env=parallel_config.runtime_env)
 
     device_str = current_platform.ray_device_key
     if not device_str:
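
For context on the Ray APIs this diff relies on, here is a minimal sketch (not part of the change) of the propagation pattern it implements: a runtime environment set on the Ray job is visible inside an actor via ray.get_runtime_context().runtime_env, and that captured value can later be passed back to ray.init() so newly spawned workers inherit it. The actor name Probe and the env_vars payload below are illustrative assumptions, not vLLM code.

# Sketch of the capture-and-propagate pattern, assuming Ray is installed.
import ray


@ray.remote
class Probe:
    def captured_runtime_env(self):
        # The same call create_engine_config() now makes when it detects
        # it is running inside a Ray actor.
        return ray.get_runtime_context().runtime_env


if __name__ == "__main__":
    # Start Ray with a runtime env; actors inherit it by default.
    ray.init(runtime_env={"env_vars": {"MY_FLAG": "1"}})
    env = ray.get(Probe.remote().captured_runtime_env.remote())
    # The captured env can then be handed to ray.init(..., runtime_env=env),
    # which is what initialize_ray_cluster() now does via ParallelConfig.
    print(env)  # contains {"env_vars": {"MY_FLAG": "1"}}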