Skip to content

Commit 4ac8437

Browse files
authored
[Misc] Getting and passing ray runtime_env to workers (#22040)
Signed-off-by: Rui Qiao <[email protected]>
1 parent d3a6f21 commit 4ac8437

File tree

6 files changed

+77
-13
lines changed

6 files changed

+77
-13
lines changed

tests/config/test_config_generation.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,3 +36,36 @@ def create_config():
3636
assert deep_compare(normal_config_dict, empty_config_dict), (
3737
"Configs with normal CUDA_VISIBLE_DEVICES and CUDA_VISIBLE_DEVICES=\"\""
3838
" should be equivalent")
39+
40+
41+
def test_ray_runtime_env(monkeypatch: pytest.MonkeyPatch):
    """Check that a Ray task's runtime env ends up on the parallel config.

    The engine config is first built in the plain test process, where
    ``parallel_config.ray_runtime_env`` must be ``None``. It is then built
    again inside a Ray task launched with an explicit ``runtime_env``, and
    that environment must be visible on the returned config.
    """

    # In testing, this function needs to be nested inside the test as ray
    # does not see the test module.
    def create_config():
        engine_args = EngineArgs(model="deepseek-ai/DeepSeek-V2-Lite",
                                 trust_remote_code=True)
        return engine_args.create_engine_config()

    # Ray has not been started yet, so no runtime env should be recorded.
    config = create_config()
    parallel_config = config.parallel_config
    assert parallel_config.ray_runtime_env is None

    import ray
    ray.init()
    try:
        runtime_env = {
            "env_vars": {
                "TEST_ENV_VAR": "test_value",
            },
        }

        # Build the config inside a Ray task that carries the runtime env.
        config_ref = ray.remote(create_config).options(
            runtime_env=runtime_env).remote()

        config = ray.get(config_ref)
        parallel_config = config.parallel_config
        assert parallel_config.ray_runtime_env is not None
        assert parallel_config.ray_runtime_env.env_vars().get(
            "TEST_ENV_VAR") == "test_value"
    finally:
        # Always tear Ray down, even when an assertion above fails, so a
        # running Ray instance does not leak into subsequent tests.
        ray.shutdown()

vllm/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757

5858
if TYPE_CHECKING:
5959
from _typeshed import DataclassInstance
60+
from ray.runtime_env import RuntimeEnv
6061
from ray.util.placement_group import PlacementGroup
6162
from transformers.configuration_utils import PretrainedConfig
6263

@@ -74,6 +75,7 @@
7475
else:
7576
DataclassInstance = Any
7677
PlacementGroup = Any
78+
RuntimeEnv = Any
7779
PretrainedConfig = Any
7880
ExecutorBase = Any
7981
QuantizationConfig = Any
@@ -2098,6 +2100,9 @@ class ParallelConfig:
20982100
ray_workers_use_nsight: bool = False
20992101
"""Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler."""
21002102

2103+
ray_runtime_env: Optional["RuntimeEnv"] = None
2104+
"""Ray runtime environment to pass to distributed workers."""
2105+
21012106
placement_group: Optional["PlacementGroup"] = None
21022107
"""ray distributed model workers placement group."""
21032108

vllm/engine/arg_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from vllm.logger import init_logger
3737
from vllm.platforms import CpuArchEnum, current_platform
3838
from vllm.plugins import load_general_plugins
39+
from vllm.ray.lazy_utils import is_ray_initialized
3940
from vllm.reasoning import ReasoningParserManager
4041
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
4142
from vllm.transformers_utils.utils import check_gguf_file
@@ -1099,6 +1100,15 @@ def create_engine_config(
10991100
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
11001101
)
11011102

1103+
ray_runtime_env = None
1104+
if is_ray_initialized():
1105+
# Ray Serve LLM calls `create_engine_config` in the context
1106+
# of a Ray task, therefore we check is_ray_initialized()
1107+
# as opposed to is_in_ray_actor().
1108+
import ray
1109+
ray_runtime_env = ray.get_runtime_context().runtime_env
1110+
logger.info("Using ray runtime env: %s", ray_runtime_env)
1111+
11021112
# Get the current placement group if Ray is initialized and
11031113
# we are in a Ray actor. If so, then the placement group will be
11041114
# passed to spawned processes.
@@ -1211,6 +1221,7 @@ def create_engine_config(
12111221
max_parallel_loading_workers=self.max_parallel_loading_workers,
12121222
disable_custom_all_reduce=self.disable_custom_all_reduce,
12131223
ray_workers_use_nsight=self.ray_workers_use_nsight,
1224+
ray_runtime_env=ray_runtime_env,
12141225
placement_group=placement_group,
12151226
distributed_executor_backend=self.distributed_executor_backend,
12161227
worker_cls=self.worker_cls,

vllm/executor/ray_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,12 @@ def initialize_ray_cluster(
295295
logger.warning(
296296
"No existing RAY instance detected. "
297297
"A new instance will be launched with current node resources.")
298-
ray.init(address=ray_address, num_gpus=parallel_config.world_size)
298+
ray.init(address=ray_address,
299+
num_gpus=parallel_config.world_size,
300+
runtime_env=parallel_config.ray_runtime_env)
299301
else:
300-
ray.init(address=ray_address)
302+
ray.init(address=ray_address,
303+
runtime_env=parallel_config.ray_runtime_env)
301304

302305
device_str = current_platform.ray_device_key
303306
if not device_str:

vllm/ray/lazy_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
5+
def is_ray_initialized():
    """Report whether Ray is importable and currently initialized.

    Ray is imported lazily inside the function. If the import raises
    ``ImportError`` the answer is ``False``.
    """
    try:
        import ray
        initialized = ray.is_initialized()
    except ImportError:
        initialized = False
    return initialized
12+
13+
14+
def is_in_ray_actor():
    """Report whether the current process is running inside a Ray actor.

    Returns ``False`` when Ray cannot be imported or is not initialized;
    otherwise the result depends on whether the runtime context reports an
    actor id.
    """
    try:
        import ray
        if not ray.is_initialized():
            return False
        actor_id = ray.get_runtime_context().get_actor_id()
        return actor_id is not None
    except ImportError:
        return False

vllm/utils/__init__.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272

7373
import vllm.envs as envs
7474
from vllm.logger import enable_trace_function_call, init_logger
75+
from vllm.ray.lazy_utils import is_in_ray_actor
7576

7677
if TYPE_CHECKING:
7778
from argparse import Namespace
@@ -2835,17 +2836,6 @@ def zmq_socket_ctx(
28352836
ctx.destroy(linger=linger)
28362837

28372838

2838-
def is_in_ray_actor():
    """Tell whether the calling code runs inside a Ray actor.

    Ray is imported lazily; an ``ImportError`` yields ``False``. Otherwise
    the result is true only when Ray is initialized and the runtime context
    has a non-``None`` actor id.
    """
    try:
        import ray
        ray_is_up = ray.is_initialized()
        return (ray_is_up
                and ray.get_runtime_context().get_actor_id() is not None)
    except ImportError:
        return False
2847-
2848-
28492839
def _maybe_force_spawn():
28502840
"""Check if we need to force the use of the `spawn` multiprocessing start
28512841
method.

0 commit comments

Comments
 (0)