
Commit c07ece5

Yard1 and avnishn authored
Make AsyncLLMEngine more robust & fix batched abort (#969)
Signed-off-by: Antoni Baum <[email protected]>
Co-authored-by: Avnish Narayan <[email protected]>
1 parent 7a9c20c commit c07ece5

File tree

7 files changed: +345 −55 lines changed

api_server_async_engine.py (new file)
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
"""vllm.entrypoints.api_server with some extra logging for testing."""
import argparse
from typing import Any, Dict

import uvicorn
from fastapi.responses import JSONResponse, Response

import vllm.entrypoints.api_server
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

app = vllm.entrypoints.api_server.app


class AsyncLLMEngineWithStats(AsyncLLMEngine):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._num_aborts = 0

    async def abort(self, request_id: str) -> None:
        await super().abort(request_id)
        self._num_aborts += 1

    def testing_stats(self) -> Dict[str, Any]:
        return {"num_aborted_requests": self._num_aborts}


@app.get("/stats")
def stats() -> Response:
    """Get the statistics of the engine."""
    return JSONResponse(engine.testing_stats())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngineWithStats.from_engine_args(engine_args,
                                                      start_engine_loop=False)
    vllm.entrypoints.api_server.engine = engine
    uvicorn.run(
        app,
        host=args.host,
        port=args.port,
        log_level="debug",
        timeout_keep_alive=vllm.entrypoints.api_server.TIMEOUT_KEEP_ALIVE)
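
For a quick sanity check of the instrumented server above, one can start it and hit the /generate and /stats endpoints directly. A minimal sketch, assuming the server was launched with --model facebook/opt-125m on the default port 8000 (matching the test below); the /generate request shape is taken from the test, not from this diff:

# Hypothetical smoke check for the server above.
import requests

resp = requests.post("http://localhost:8000/generate",
                     json={"prompt": "Hello world", "max_tokens": 16})
resp.raise_for_status()
print(resp.json())

# AsyncLLMEngineWithStats counts aborts, so a fresh server with no
# cancelled requests should report zero.
print(requests.get("http://localhost:8000/stats").json())
# {"num_aborted_requests": 0}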
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
import subprocess
import sys
import time
from multiprocessing import Pool
from pathlib import Path

import pytest
import requests


def _query_server(prompt: str) -> dict:
    response = requests.post("http://localhost:8000/generate",
                             json={
                                 "prompt": prompt,
                                 "max_tokens": 100,
                                 "temperature": 0,
                                 "ignore_eos": True
                             })
    response.raise_for_status()
    return response.json()


@pytest.fixture
def api_server():
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    uvicorn_process = subprocess.Popen([
        sys.executable, "-u",
        str(script_path), "--model", "facebook/opt-125m"
    ])
    yield
    uvicorn_process.terminate()


def test_api_server(api_server):
    """
    Run the API server and test it.

    We run both the server and requests in separate processes.

    We test that the server can handle incoming requests, including
    multiple requests at the same time, and that it can handle requests
    being cancelled without crashing.
    """
    with Pool(32) as pool:
        # Wait until the server is ready
        prompts = ["Hello world"] * 1
        result = None
        while not result:
            try:
                for result in pool.map(_query_server, prompts):
                    break
            except:
                time.sleep(1)

        # Actual tests start here
        # Try with 1 prompt
        for result in pool.map(_query_server, prompts):
            assert result

        num_aborted_requests = requests.get(
            "http://localhost:8000/stats").json()["num_aborted_requests"]
        assert num_aborted_requests == 0

        # Try with 100 prompts
        prompts = ["Hello world"] * 100
        for result in pool.map(_query_server, prompts):
            assert result

        # Cancel requests
        pool.map_async(_query_server, prompts)
        time.sleep(0.01)
        pool.terminate()
        pool.join()

    # check cancellation stats
    num_aborted_requests = requests.get(
        "http://localhost:8000/stats").json()["num_aborted_requests"]
    assert num_aborted_requests > 0

    # check that server still runs after cancellations
    with Pool(32) as pool:
        # Try with 100 prompts
        prompts = ["Hello world"] * 100
        for result in pool.map(_query_server, prompts):
            assert result
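
The cancellation step works because pool.terminate() kills the client worker processes mid-request, so the server sees their HTTP connections drop. The handler behind /generate (in vllm.entrypoints.api_server, not part of this diff) aborts the engine request when it detects such a disconnect. A self-contained sketch of that mechanism with a stand-in engine — illustrative only, not the actual vLLM handler:

# Minimal FastAPI sketch of the disconnect-to-abort path this test
# exercises. The "engine" here is a fake; only request.is_disconnected()
# and the abort bookkeeping mirror what the real server does.
import asyncio
import uuid

from fastapi import FastAPI, Request
from fastapi.responses import Response

app = FastAPI()
num_aborted = 0


async def fake_engine_generate(request_id: str):
    # Pretend to decode 100 tokens, yielding control between each one.
    for i in range(100):
        await asyncio.sleep(0.05)
        yield f"token-{i}"


@app.post("/generate")
async def generate(request: Request) -> Response:
    global num_aborted
    request_id = str(uuid.uuid4())
    tokens = []
    async for token in fake_engine_generate(request_id):
        # A killed client (e.g. pool.terminate() in the test) shows up
        # as a disconnect; abort instead of decoding to completion.
        if await request.is_disconnected():
            num_aborted += 1
            return Response(status_code=499)
        tokens.append(token)
    return Response(" ".join(tokens))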
Lines changed: 54 additions & 0 deletions
@@ -0,0 +1,54 @@
import pytest

from vllm.engine.async_llm_engine import RequestTracker
from vllm.outputs import RequestOutput


def test_request_tracker():
    tracker = RequestTracker()
    stream_1 = tracker.add_request("1")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(new) == 1
    assert new[0]["request_id"] == "1"
    assert not finished
    assert not stream_1.finished

    stream_2 = tracker.add_request("2")
    stream_3 = tracker.add_request("3")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(new) == 2
    assert new[0]["request_id"] == "2"
    assert new[1]["request_id"] == "3"
    assert not finished
    assert not stream_2.finished
    assert not stream_3.finished

    # request_ids must be unique
    with pytest.raises(KeyError):
        tracker.add_request("1")

    tracker.abort_request("1")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "1" in finished
    assert not new
    assert stream_1.finished

    stream_4 = tracker.add_request("4")
    tracker.abort_request("4")
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "4" in finished
    assert not new
    assert stream_4.finished

    stream_5 = tracker.add_request("5")
    tracker.process_request_output(
        RequestOutput("2", "output", [], [], finished=True))
    new, finished = tracker.get_new_and_finished_requests()
    assert len(finished) == 1
    assert "2" in finished
    assert len(new) == 1
    assert new[0]["request_id"] == "5"
    assert stream_2.finished
    assert not stream_5.finished
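
Taken together, the assertions pin down the RequestTracker contract: add_request returns a per-request stream and rejects duplicate ids, aborting marks the stream finished, a finished RequestOutput retires its request, and get_new_and_finished_requests drains both queues while dropping "new" requests that were aborted before the engine ever saw them. A standalone sketch of that contract — a toy model satisfying the test above, not vLLM's implementation; every name except the method names mirrored from the test is hypothetical:

from typing import Dict, List, Set, Tuple


class _Stream:
    """Stand-in for the per-request output stream the tracker returns."""

    def __init__(self, request_id: str):
        self.request_id = request_id
        self.finished = False


class TrackerSketch:

    def __init__(self):
        self._streams: Dict[str, _Stream] = {}
        self._new: List[dict] = []
        self._finished: Set[str] = set()

    def add_request(self, request_id: str, **engine_kwargs) -> _Stream:
        if request_id in self._streams:
            raise KeyError(f"Request {request_id} already exists.")
        stream = _Stream(request_id)
        self._streams[request_id] = stream
        self._new.append({"request_id": request_id, **engine_kwargs})
        return stream

    def abort_request(self, request_id: str) -> None:
        self._finished.add(request_id)
        stream = self._streams.pop(request_id, None)
        if stream is not None:
            stream.finished = True

    def process_request_output(self, output) -> None:
        # Only the id and finished flag of a RequestOutput matter here.
        if output.finished:
            self.abort_request(output.request_id)

    def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]:
        finished, self._finished = self._finished, set()
        # Requests aborted before being handed to the engine are dropped
        # from "new" rather than surfaced.
        new = [r for r in self._new if r["request_id"] not in finished]
        self._new = []
        return new, finished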

vllm/core/scheduler.py

Lines changed: 4 additions & 1 deletion
@@ -92,7 +92,10 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
             request_id = (request_id, )
         request_ids = set(request_id)
         for state_queue in [self.waiting, self.running, self.swapped]:
-            for seq_group in state_queue:
+            # We need to reverse the list as we are removing elements
+            # from it as we iterate over it. If we don't do it,
+            # indices will get messed up and we will skip over elements.
+            for seq_group in reversed(state_queue):
                 if seq_group.request_id in request_ids:
                     # Remove the sequence group from the state queue.
                     state_queue.remove(seq_group)
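
The reversed() change fixes a classic Python pitfall: removing items from a list while iterating forward over it shifts the remaining elements left, so the iterator skips the element that slides into the freed slot — which is what broke batched aborts. A standalone demonstration:

# Forward iteration with removal skips elements...
queue = ["a", "b", "b", "c"]
for item in queue:
    if item == "b":
        queue.remove(item)
print(queue)  # ['a', 'b', 'c'] -- the second "b" survived

# ...while reverse iteration is safe, because a removal only shifts
# elements the iterator has already visited.
queue = ["a", "b", "b", "c"]
for item in reversed(queue):
    if item == "b":
        queue.remove(item)
print(queue)  # ['a', 'c'] -- both "b"s removed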
