3 files changed: +11 -3 lines changed.
File 1 of 3:

@@ -57,6 +57,7 @@
 if TYPE_CHECKING:
     from _typeshed import DataclassInstance
     from ray.util.placement_group import PlacementGroup
+    from ray.runtime_env import RuntimeEnv
     from transformers.configuration_utils import PretrainedConfig

     import vllm.model_executor.layers.quantization as me_quant
@@ -73,6 +74,7 @@
 else:
     DataclassInstance = Any
     PlacementGroup = Any
+    RuntimeEnv = Any
     PretrainedConfig = Any
     ExecutorBase = Any
     QuantizationConfig = Any
@@ -1902,6 +1904,9 @@ class ParallelConfig:
     placement_group: Optional["PlacementGroup"] = None
     """ray distributed model workers placement group."""

+    runtime_env: Optional["RuntimeEnv"] = None
+    """ray runtime environment for distributed workers"""
+
     distributed_executor_backend: Optional[Union[DistributedExecutorBackend,
                                                  type["ExecutorBase"]]] = None
     """Backend to use for distributed model
File 2 of 3:

@@ -1088,12 +1088,14 @@ def create_engine_config(
         # we are in a Ray actor. If so, then the placement group will be
         # passed to spawned processes.
         placement_group = None
+        runtime_env = None
         if is_in_ray_actor():
             import ray

             # This call initializes Ray automatically if it is not initialized,
             # but we should not do this here.
             placement_group = ray.util.get_current_placement_group()
+            runtime_env = ray.get_runtime_context().runtime_env

         data_parallel_external_lb = self.data_parallel_rank is not None
         if data_parallel_external_lb:
@@ -1170,6 +1172,7 @@ def create_engine_config(
             disable_custom_all_reduce=self.disable_custom_all_reduce,
             ray_workers_use_nsight=self.ray_workers_use_nsight,
             placement_group=placement_group,
+            runtime_env=runtime_env,
             distributed_executor_backend=self.distributed_executor_backend,
             worker_cls=self.worker_cls,
             worker_extension_cls=self.worker_extension_cls,
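The detection above only fires when the engine itself runs inside a Ray actor. A standalone sketch of the inheritance behavior it relies on, using only stock Ray APIs; the Launcher actor and MY_FLAG variable are made up for illustration.

    import ray

    @ray.remote(runtime_env={"env_vars": {"MY_FLAG": "1"}})
    class Launcher:
        def current_env(self):
            # Inside an actor, this returns the runtime env the actor
            # was started with; create_engine_config() forwards exactly
            # this value so spawned workers inherit it.
            return ray.get_runtime_context().runtime_env

    ray.init()
    launcher = Launcher.remote()
    print(ray.get(launcher.current_env.remote()))
    # -> {'env_vars': {'MY_FLAG': '1'}}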
Original file line number Diff line number Diff line change @@ -288,14 +288,14 @@ def initialize_ray_cluster(
288
288
elif current_platform .is_rocm () or current_platform .is_xpu ():
289
289
# Try to connect existing ray instance and create a new one if not found
290
290
try :
291
- ray .init ("auto" )
291
+ ray .init ("auto" , runtime_env = parallel_config . runtime_env )
292
292
except ConnectionError :
293
293
logger .warning (
294
294
"No existing RAY instance detected. "
295
295
"A new instance will be launched with current node resources." )
296
- ray .init (address = ray_address , num_gpus = parallel_config .world_size )
296
+ ray .init (address = ray_address , num_gpus = parallel_config .world_size , runtime_env = parallel_config . runtime_env )
297
297
else :
298
- ray .init (address = ray_address )
298
+ ray .init (address = ray_address , runtime_env = parallel_config . runtime_env )
299
299
300
300
device_str = current_platform .ray_device_key
301
301
if not device_str :
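Forwarding runtime_env at ray.init() is sufficient because a job-level runtime environment becomes the default for every task and actor started in that job, so vLLM's Ray workers pick it up without further plumbing. A minimal sketch of that propagation outside vLLM; the HF_HUB_OFFLINE variable is just an example.

    import ray

    ray.init(runtime_env={"env_vars": {"HF_HUB_OFFLINE": "1"}})

    @ray.remote
    def worker_sees_flag():
        # Runs in a worker process that inherited the job-level env.
        import os
        return os.environ.get("HF_HUB_OFFLINE")

    print(ray.get(worker_sees_flag.remote()))  # -> "1"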