Skip to content

Commit 02606d9

Browse files
committed
[Misc] Getting and passing ray runtime_env to workers
Signed-off-by: Rui Qiao <[email protected]>
1 parent 6d8d0a2 commit 02606d9

File tree

4 files changed

+22
-13
lines changed

vllm/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
if TYPE_CHECKING:
5858
from _typeshed import DataclassInstance
5959
from ray.util.placement_group import PlacementGroup
60+
from ray.runtime_env import RuntimeEnv
6061
from transformers.configuration_utils import PretrainedConfig
6162

6263
import vllm.model_executor.layers.quantization as me_quant
@@ -73,6 +74,7 @@
7374
else:
7475
DataclassInstance = Any
7576
PlacementGroup = Any
77+
RuntimeEnv = Any
7678
PretrainedConfig = Any
7779
ExecutorBase = Any
7880
QuantizationConfig = Any
@@ -1950,6 +1952,9 @@ class ParallelConfig:
19501952
ray_workers_use_nsight: bool = False
19511953
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
19521954

1955+
ray_runtime_env: Optional["RuntimeEnv"] = None
1956+
"""Ray runtime environment to pass to distributed workers."""
1957+
19531958
placement_group: Optional["PlacementGroup"] = None
19541959
"""ray distributed model workers placement group."""
19551960

vllm/engine/arg_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from vllm.logger import init_logger
3737
from vllm.platforms import CpuArchEnum, current_platform
3838
from vllm.plugins import load_general_plugins
39+
from vllm.ray.lazy_utils import is_ray_initialized
3940
from vllm.reasoning import ReasoningParserManager
4041
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
4142
from vllm.transformers_utils.utils import check_gguf_file
@@ -1060,6 +1061,15 @@ def create_engine_config(
10601061
calculate_kv_scales=self.calculate_kv_scales,
10611062
)
10621063

1064+
ray_runtime_env = None
1065+
if is_ray_initialized():
1066+
# Ray Serve LLM calls `create_engine_config` in the context
1067+
# of a Ray task, therefore we check is_ray_initialized()
1068+
# as opposed to is_in_ray_actor().
1069+
import ray
1070+
ray_runtime_env = ray.get_runtime_context().runtime_env
1071+
logger.info(f"Using ray runtime env: {ray_runtime_env}")
1072+
10631073
# Get the current placement group if Ray is initialized and
10641074
# we are in a Ray actor. If so, then the placement group will be
10651075
# passed to spawned processes.
@@ -1172,6 +1182,7 @@ def create_engine_config(
11721182
max_parallel_loading_workers=self.max_parallel_loading_workers,
11731183
disable_custom_all_reduce=self.disable_custom_all_reduce,
11741184
ray_workers_use_nsight=self.ray_workers_use_nsight,
1185+
ray_runtime_env=ray_runtime_env,
11751186
placement_group=placement_group,
11761187
distributed_executor_backend=self.distributed_executor_backend,
11771188
worker_cls=self.worker_cls,

vllm/executor/ray_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,12 @@ def initialize_ray_cluster(
295295
logger.warning(
296296
"No existing RAY instance detected. "
297297
"A new instance will be launched with current node resources.")
298-
ray.init(address=ray_address, num_gpus=parallel_config.world_size)
298+
ray.init(address=ray_address,
299+
num_gpus=parallel_config.world_size,
300+
runtime_env=parallel_config.ray_runtime_env)
299301
else:
300-
ray.init(address=ray_address)
302+
ray.init(address=ray_address,
303+
runtime_env=parallel_config.ray_runtime_env)
301304

302305
device_str = current_platform.ray_device_key
303306
if not device_str:

vllm/utils/__init__.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171

7272
import vllm.envs as envs
7373
from vllm.logger import enable_trace_function_call, init_logger
74+
from vllm.ray.utils import is_in_ray_actor
7475

7576
if TYPE_CHECKING:
7677
from argparse import Namespace
@@ -2864,17 +2865,6 @@ def zmq_socket_ctx(
28642865
ctx.destroy(linger=linger)
28652866

28662867

2867-
def is_in_ray_actor():
2868-
"""Check if we are in a Ray actor."""
2869-
2870-
try:
2871-
import ray
2872-
return (ray.is_initialized()
2873-
and ray.get_runtime_context().get_actor_id() is not None)
2874-
except ImportError:
2875-
return False
2876-
2877-
28782868
def _maybe_force_spawn():
28792869
"""Check if we need to force the use of the `spawn` multiprocessing start
28802870
method.

0 commit comments

Comments
 (0)