Skip to content

Commit 99ca962

Browse files
authored
vllm batch mitigation (#721)
* debug * wip * checkpoint * checkpoint - workaround * checkpoint - workaround * cleanup
1 parent 3f8705a commit 99ca962

File tree

5 files changed

+263
-123
lines changed

5 files changed

+263
-123
lines changed

model-engine/model_engine_server/inference/vllm/Dockerfile.vllm

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,17 @@ COPY model-engine/model_engine_server/inference/vllm/vllm_batch.py /workspace/vl
4949
COPY model-engine/model_engine_server/inference/vllm/init_ray_batch_inf_v2.py /workspace/init_ray_batch_inf_v2.py
5050

5151
# Need to override entrypoint from parent image
52-
ENTRYPOINT ["/bin/env"]
52+
ENTRYPOINT ["/bin/env"]
53+
54+
FROM vllm_batch_v2 AS vllm_batch_debug
55+
56+
COPY model-engine/model_engine_server/inference/vllm/ray_overrides/_version.py /usr/local/lib/python3.12/dist-packages/ray/_version.py
57+
COPY model-engine/model_engine_server/inference/vllm/ray_overrides/proxier.py /usr/local/lib/python3.12/dist-packages/ray/util/client/server/proxier.py
58+
COPY model-engine/model_engine_server/inference/vllm/ray_overrides/worker.py /usr/local/lib/python3.12/dist-packages/ray/_private/worker.py
59+
COPY model-engine/model_engine_server/inference/vllm/ray_overrides/state.py /usr/local/lib/python3.12/dist-packages/ray/_private/state.py
60+
COPY model-engine/model_engine_server/inference/vllm/ray_overrides/client_model_hook.py /usr/local/lib/python3.12/dist-packages/ray/_private/client_model_hook.py
61+
62+
COPY model-engine/model_engine_server/inference/vllm/vllm_overrides/ray_utils.py /usr/local/lib/python3.12/dist-packages/vllm/executor/ray_utils.py
63+
64+
# Need to override entrypoint from parent image
65+
ENTRYPOINT ["/bin/env"]

model-engine/model_engine_server/inference/vllm/init_ray_batch_inf_v2.py

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@
99
RETRY_INTERVAL_SEC = 5
1010

1111

12-
def get_node_ip_address(leader_addr: str) -> str:
12+
def get_node_fqdn(leader_addr: str) -> str:
    """Build an externally addressable DNS name for the current pod.

    Assumes a K8s cluster (e.g. a JobSet) where the leader address has the
    form ``<leader pod name>.<rest of the FQDN>``; we graft our own hostname
    onto the leader's domain suffix. Kind of a dumb hack, but it yields a
    name other pods can resolve.
    """
    _leader_name, domain_suffix = leader_addr.split(".", 1)
    return f"{socket.gethostname()}.{domain_suffix}"
1919

2020

2121
def wait_for_dns(hostname: str, timeout: int = 300, interval: int = 5):
@@ -48,7 +48,7 @@ def wait_for_cluster_nodes(
4848
bool: True if cluster reached expected size, False if timeout occurred
4949
"""
5050
# Since we've subprocess.run for starting ray, need to connect in the cluster right here.
51-
ray.init()
51+
ray.init(log_to_driver=False)
5252
start_time = time.time()
5353
while time.time() - start_time < timeout:
5454
try:
@@ -114,7 +114,7 @@ def wait_for_head_node_to_exit_process():
114114
# This will run in the subprocess spawned and will conveniently error out
115115
# when the head node is no longer reachable
116116
# The exit gets caught by the `wait_for_head_node_to_exit` function
117-
ray.init()
117+
ray.init(log_to_driver=False)
118118
while True:
119119
nodes = ray.nodes()
120120
print(f"Able to get nodes list {len(nodes)}", flush=True)
@@ -137,6 +137,20 @@ def start_leader(
137137
return False
138138

139139

140+
def is_ipv6_address(ip_address: str) -> bool:
    """Return True iff *ip_address* parses as a literal IPv6 address."""
    try:
        socket.inet_pton(socket.AF_INET6, ip_address)
    except socket.error:
        return False
    return True
146+
147+
148+
def format_ip_address(ip_address: str) -> str:
    """Wrap IPv6 literals in brackets so they can be suffixed with ``:port``.

    IPv4 addresses (and anything that is not a valid IPv6 literal) are
    returned unchanged.
    """
    return f"[{ip_address}]" if is_ipv6_address(ip_address) else ip_address
153+
140154
def start_worker(
141155
ray_port: int,
142156
node_ip_address: str,
@@ -147,17 +161,22 @@ def start_worker(
147161
# node ip address in this case is actually a DNS name for the pod
148162
start_time = time.time()
149163
while time.time() - start_time < timeout:
164+
print(
165+
f"Starting ray worker with head address {format_ip_address(leader_addr)}:{ray_port} and node ip address {node_ip_address}",
166+
flush=True,
167+
)
150168
result = subprocess.run(
151169
[
152170
"ray",
153171
"start",
154172
"--address",
155-
f"{leader_addr}:{ray_port}",
173+
f"{format_ip_address(leader_addr)}:{ray_port}",
156174
"--node-ip-address",
157175
node_ip_address,
158176
],
159177
capture_output=True,
160178
)
179+
print(f"result: {result}", flush=True)
161180
if result.returncode == 0:
162181
print(
163182
f"Worker: Ray runtime started with head address {leader_addr}:{ray_port}",
@@ -175,6 +194,13 @@ def start_worker(
175194
return False
176195

177196

197+
def get_node_ip_address(node_fqdn: str, timeout: int = 300) -> str:
    """Resolve *node_fqdn* to an IP address, waiting up to *timeout* seconds.

    Raises:
        RuntimeError: if DNS does not resolve within *timeout* seconds.
    """
    addr_info = wait_for_dns(node_fqdn, timeout=timeout)
    if addr_info is None:
        raise RuntimeError(f"Timeout waiting for DNS resolution of {node_fqdn}")
    # getaddrinfo-style result: take the first entry's sockaddr tuple,
    # whose first element is the IP address string.
    first_entry = addr_info[0]
    return first_entry[4][0]
202+
203+
178204
def init_ray(
179205
leader_addr: str,
180206
leader_port: int,
@@ -193,21 +219,30 @@ def init_ray(
193219
node_ip_address: IP address of the current node. If None, will be automatically detected
194220
timeout: Maximum time to wait for cluster to reach expected size
195221
"""
196-
node_ip_address = get_node_ip_address(leader_addr)
222+
import os
223+
224+
# NCCL_DEBUG=INFO enables verbose NCCL logging for debugging distributed
# communication (note: it does not disable Ray logging)
225+
os.environ["NCCL_DEBUG"] = "INFO"
226+
227+
# Get FQDN of the current node
228+
node_fqdn = get_node_fqdn(leader_addr)
229+
print(f"node fqdn: {node_fqdn}", flush=True)
197230

198231
print(f"Waiting for head node DNS ({leader_addr}) to be resolvable...", flush=True)
199-
head_ip_info = wait_for_dns(leader_addr, timeout=timeout)
200-
if head_ip_info is None:
201-
raise RuntimeError(f"Timeout waiting for DNS resolution of {leader_addr}")
232+
leader_ip_address = get_node_ip_address(leader_addr, timeout=timeout)
233+
print(f"leader ip: {leader_ip_address}", flush=True)
202234

203235
if is_leader:
204-
if not start_leader(leader_port, node_ip_address):
236+
if not start_leader(leader_port, leader_ip_address):
205237
raise RuntimeError("Failed to start Ray leader node")
206238
else:
207-
if not start_worker(leader_port, node_ip_address, leader_addr, timeout):
239+
print(f"Waiting for worker node DNS ({node_fqdn}) to be resolvable...", flush=True)
240+
worker_ip_address = get_node_ip_address(node_fqdn, timeout=timeout)
241+
print(f"worker ip: {worker_ip_address}", flush=True)
242+
if not start_worker(leader_port, worker_ip_address, leader_ip_address, timeout):
208243
raise RuntimeError("Failed to start Ray worker node")
209244
print(
210-
f"Successfully initialized Ray {'head' if is_leader else 'worker'} node at {node_ip_address}",
245+
f"Successfully initialized Ray {'head' if is_leader else 'worker'} node at {leader_ip_address if is_leader else worker_ip_address}",
211246
flush=True,
212247
)
213248

@@ -231,5 +266,6 @@ def main(mode: str):
231266
if __name__ == "__main__":
232267
parser = argparse.ArgumentParser()
233268
parser.add_argument("--mode", choices=["wait_for_head_node_to_exit"], required=True)
269+
234270
args = parser.parse_args()
235271
main(args.mode)
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
pydantic>=2.8
22
boto3==1.34.15
33
smart-open==6.4.0
4-
ddtrace==2.11.0
5-
datadog==0.49.1
4+
ddtrace==2.21.11
5+
wrapt>=1.15,<2
6+
datadog==0.52.1
67
dataclasses-json~=0.6.7
78
sse-starlette==2.1.3
8-
ray[client]==2.37.0
9+
ray[client]==2.49.2
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
vllm==0.10.2
1+
vllm==0.11.0

0 commit comments

Comments
 (0)