Commit 0ff1824

Bump vllm to v0.5.0.post1 (#547)
1 parent 6c7c924 commit 0ff1824

4 files changed: +91 -136 lines changed
Lines changed: 0 additions & 40 deletions
@@ -1,37 +1,3 @@
-#################### BASE BUILD IMAGE ####################
-FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev
-RUN apt-get update -y \
-    && apt-get install -y python3-pip git
-# Workaround for https://github.com/openai/triton/issues/2507 and
-# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
-# this won't be needed for future versions of this docker image
-# or future versions of triton.
-RUN ldconfig /usr/local/cuda-12.1/compat/
-WORKDIR /workspace
-
-COPY requirements-build.txt requirements-build.txt
-RUN --mount=type=cache,target=/root/.cache/pip \
-    pip install -r requirements-build.txt
-#################### BASE BUILD IMAGE ####################
-
-#################### FLASH_ATTENTION Build IMAGE ####################
-FROM dev as flash-attn-builder
-# max jobs used for build
-ARG max_jobs=2
-ENV MAX_JOBS=${max_jobs}
-# flash attention version
-ARG flash_attn_version=v2.4.2
-ENV FLASH_ATTN_VERSION=${flash_attn_version}
-
-WORKDIR /usr/src/flash-attention-v2
-
-# Download the wheel or build it if a pre-compiled release doesn't exist
-RUN pip --verbose wheel flash-attn==${FLASH_ATTN_VERSION} \
-    --no-build-isolation --no-deps --no-cache-dir
-
-#################### FLASH_ATTENTION Build IMAGE ####################
-
-#################### Runtime IMAGE ####################
 FROM nvcr.io/nvidia/pytorch:23.09-py3
 
 RUN apt-get update \
@@ -41,10 +7,6 @@ RUN apt-get update \
     && apt-get autoremove -y \
     && rm -rf /var/lib/apt/lists/*
 
-# Install flash attention (from pre-built wheel)
-RUN --mount=type=bind,from=flash-attn-builder,src=/usr/src/flash-attention-v2,target=/usr/src/flash-attention-v2 \
-    pip install /usr/src/flash-attention-v2/*.whl --no-cache-dir
-
 RUN pip uninstall torch -y
 COPY requirements.txt /workspace/requirements.txt
 RUN pip install -r requirements.txt
@@ -53,5 +15,3 @@ RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linu
 RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz
 
 COPY vllm_server.py /workspace/vllm_server.py
-
-#################### Runtime IMAGE ####################
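With vllm now pinned to 0.5.0.post1 (see requirements.txt below), the separate flash-attention build stage is dropped and the runtime image relies on whatever attention kernels the vllm wheel pulls in itself. A minimal sanity check one could run inside the built image, assuming the kernels arrive via the vllm-flash-attn dependency (that package name is an assumption, not something stated in this commit):

# Assumed post-build check, not part of this commit: confirm vllm is importable and
# that a bundled flash-attention package is present now that the dedicated
# flash-attn build stage has been removed.
import importlib.util

for mod in ("vllm", "vllm_flash_attn"):  # module name "vllm_flash_attn" is assumed
    found = importlib.util.find_spec(mod) is not None
    print(f"{mod}: {'installed' if found else 'MISSING'}")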

model-engine/model_engine_server/inference/vllm/requirements-build.txt

Lines changed: 0 additions & 8 deletions
This file was deleted.
model-engine/model_engine_server/inference/vllm/requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,2 +1,2 @@
-vllm==0.4.2
+vllm==0.5.0.post1
 pydantic>=2.0
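A quick way to confirm an environment picked up the new pin is a small check like the one below; this is an assumed verification step, not part of the commit, but the import mirrors what the updated vllm_server.py expects to find in vllm 0.5.0.post1:

# Assumed verification snippet (not part of this commit): check the installed vllm
# version and the AsyncEngineDeadError import that the updated server relies on.
import vllm
from vllm.engine.async_llm_engine import AsyncEngineDeadError, AsyncLLMEngine  # noqa: F401

print("vllm version:", vllm.__version__)
assert vllm.__version__.startswith("0.5.0"), f"unexpected vllm version: {vllm.__version__}"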

model-engine/model_engine_server/inference/vllm/vllm_server.py

Lines changed: 90 additions & 87 deletions
@@ -11,7 +11,7 @@
 from fastapi import BackgroundTasks, FastAPI, HTTPException, Request
 from fastapi.responses import Response, StreamingResponse
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.async_llm_engine import AsyncEngineDeadError, AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import CompletionRequest as OpenAICompletionRequest
 from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor
 from vllm.outputs import CompletionOutput
@@ -43,97 +43,101 @@ async def generate(request: Request) -> Response:
     # check health before accepting request and fail fast if engine isn't healthy
     try:
         await engine.check_health()
-    except Exception as e:
-        print(f"The vllm engine is dead, exiting the pod: {e}")
-        os.kill(os.getpid(), signal.SIGINT)
 
-    request_dict = await request.json()
-    prompt = request_dict.pop("prompt")
-    stream = request_dict.pop("stream", False)
-    guided_json = request_dict.pop("guided_json", None)
-    guided_regex = request_dict.pop("guided_regex", None)
-    guided_choice = request_dict.pop("guided_choice", None)
-    guided_grammar = request_dict.pop("guided_grammar", None)
-    sampling_params = SamplingParams(**request_dict)
+        request_dict = await request.json()
+        prompt = request_dict.pop("prompt")
+        stream = request_dict.pop("stream", False)
+        guided_json = request_dict.pop("guided_json", None)
+        guided_regex = request_dict.pop("guided_regex", None)
+        guided_choice = request_dict.pop("guided_choice", None)
+        guided_grammar = request_dict.pop("guided_grammar", None)
+        sampling_params = SamplingParams(**request_dict)
+
+        # Dummy request to get guided decode logit processor
+        try:
+            partial_openai_request = OpenAICompletionRequest.model_validate(
+                {
+                    "model": "",
+                    "prompt": "",
+                    "guided_json": guided_json,
+                    "guided_regex": guided_regex,
+                    "guided_choice": guided_choice,
+                    "guided_grammar": guided_grammar,
+                }
+            )
+        except Exception:
+            raise HTTPException(
+                status_code=400, detail="Bad request: failed to parse guided decoding parameters."
+            )
 
-    # Dummy request to get guided decode logit processor
-    try:
-        partial_openai_request = OpenAICompletionRequest.model_validate(
-            {
-                "model": "",
-                "prompt": "",
-                "guided_json": guided_json,
-                "guided_regex": guided_regex,
-                "guided_choice": guided_choice,
-                "guided_grammar": guided_grammar,
-            }
-        )
-    except Exception:
-        raise HTTPException(
-            status_code=400, detail="Bad request: failed to parse guided decoding parameters."
+        guided_decoding_backend = engine.engine.decoding_config.guided_decoding_backend
+        guided_decode_logit_processor = await get_guided_decoding_logits_processor(
+            guided_decoding_backend, partial_openai_request, await engine.get_tokenizer()
         )
+        if guided_decode_logit_processor is not None:
+            if sampling_params.logits_processors is None:
+                sampling_params.logits_processors = []
+            sampling_params.logits_processors.append(guided_decode_logit_processor)
 
-    guided_decoding_backend = engine.engine.decoding_config.guided_decoding_backend
-    guided_decode_logit_processor = await get_guided_decoding_logits_processor(
-        guided_decoding_backend, partial_openai_request, await engine.get_tokenizer()
-    )
-    if guided_decode_logit_processor is not None:
-        if sampling_params.logits_processors is None:
-            sampling_params.logits_processors = []
-        sampling_params.logits_processors.append(guided_decode_logit_processor)
-
-    request_id = random_uuid()
-
-    results_generator = engine.generate(prompt, sampling_params, request_id)
-
-    async def abort_request() -> None:
-        await engine.abort(request_id)
-
-    if stream:
-        # Streaming case
-        async def stream_results() -> AsyncGenerator[str, None]:
-            last_output_text = ""
-            async for request_output in results_generator:
-                log_probs = format_logprobs(request_output)
-                ret = {
-                    "text": request_output.outputs[-1].text[len(last_output_text) :],
-                    "count_prompt_tokens": len(request_output.prompt_token_ids),
-                    "count_output_tokens": len(request_output.outputs[0].token_ids),
-                    "log_probs": log_probs[-1] if log_probs and sampling_params.logprobs else None,
-                    "finished": request_output.finished,
-                }
-                last_output_text = request_output.outputs[-1].text
-                yield f"data:{json.dumps(ret)}\n\n"
-
-        background_tasks = BackgroundTasks()
-        # Abort the request if the client disconnects.
-        background_tasks.add_task(abort_request)
-
-        return StreamingResponse(stream_results(), background=background_tasks)
-
-    # Non-streaming case
-    final_output = None
-    tokens = []
-    last_output_text = ""
-    async for request_output in results_generator:
-        tokens.append(request_output.outputs[-1].text[len(last_output_text) :])
-        last_output_text = request_output.outputs[-1].text
-        if await request.is_disconnected():
-            # Abort the request if the client disconnects.
+        request_id = random_uuid()
+
+        results_generator = engine.generate(prompt, sampling_params, request_id)
+
+        async def abort_request() -> None:
             await engine.abort(request_id)
-            return Response(status_code=499)
-        final_output = request_output
 
-    assert final_output is not None
-    prompt = final_output.prompt
-    ret = {
-        "text": final_output.outputs[0].text,
-        "count_prompt_tokens": len(final_output.prompt_token_ids),
-        "count_output_tokens": len(final_output.outputs[0].token_ids),
-        "log_probs": format_logprobs(final_output),
-        "tokens": tokens,
-    }
-    return Response(content=json.dumps(ret))
+        if stream:
+            # Streaming case
+            async def stream_results() -> AsyncGenerator[str, None]:
+                last_output_text = ""
+                async for request_output in results_generator:
+                    log_probs = format_logprobs(request_output)
+                    ret = {
+                        "text": request_output.outputs[-1].text[len(last_output_text) :],
+                        "count_prompt_tokens": len(request_output.prompt_token_ids),
+                        "count_output_tokens": len(request_output.outputs[0].token_ids),
+                        "log_probs": log_probs[-1]
+                        if log_probs and sampling_params.logprobs
+                        else None,
+                        "finished": request_output.finished,
+                    }
+                    last_output_text = request_output.outputs[-1].text
+                    yield f"data:{json.dumps(ret)}\n\n"
+
+            background_tasks = BackgroundTasks()
+            # Abort the request if the client disconnects.
+            background_tasks.add_task(abort_request)
+
+            return StreamingResponse(stream_results(), background=background_tasks)
+
+        # Non-streaming case
+        final_output = None
+        tokens = []
+        last_output_text = ""
+        async for request_output in results_generator:
+            tokens.append(request_output.outputs[-1].text[len(last_output_text) :])
+            last_output_text = request_output.outputs[-1].text
+            if await request.is_disconnected():
+                # Abort the request if the client disconnects.
+                await engine.abort(request_id)
+                return Response(status_code=499)
+            final_output = request_output
+
+        assert final_output is not None
+        prompt = final_output.prompt
+        ret = {
+            "text": final_output.outputs[0].text,
+            "count_prompt_tokens": len(final_output.prompt_token_ids),
+            "count_output_tokens": len(final_output.outputs[0].token_ids),
+            "log_probs": format_logprobs(final_output),
+            "tokens": tokens,
+        }
+        return Response(content=json.dumps(ret))
+
+    except AsyncEngineDeadError as e:
+        print(f"The vllm engine is dead, exiting the pod: {e}")
+        os.kill(os.getpid(), signal.SIGINT)
+        raise e
 
 
 def get_gpu_free_memory():
@@ -206,7 +210,6 @@ def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]:
 
     engine_args = AsyncEngineArgs.from_cli_args(args)
     engine = AsyncLLMEngine.from_engine_args(engine_args)
-    engine.check_health()
 
     signal.signal(signal.SIGUSR1, debug)
 
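Taken together, the vllm_server.py changes move dead-engine handling from a narrow pre-check (plus the un-awaited engine.check_health() call at startup, removed above) into a single try/except around the entire request path, so an AsyncEngineDeadError raised at any point is logged, the pod is signalled to exit, and the error is re-raised. A simplified sketch of that control flow, with request parsing, guided decoding, and streaming elided; names follow the diff, but this is an illustration rather than the full handler:

# Simplified sketch of the restructured generate() handler (illustration only).
import json
import os
import signal

from fastapi.responses import Response
from vllm.engine.async_llm_engine import AsyncEngineDeadError


async def generate_sketch(engine, request) -> Response:
    try:
        # Fail fast if the engine is already unhealthy ...
        await engine.check_health()
        # ... but because the whole request path sits inside this try block,
        # an engine death mid-request is caught below as well.
        request_dict = await request.json()
        ...  # build SamplingParams, call engine.generate(...), collect outputs
        return Response(content=json.dumps({"text": "..."}))
    except AsyncEngineDeadError as e:
        # The engine cannot recover: log, signal the pod to shut down so the
        # orchestrator restarts it, and re-raise so this request fails loudly.
        print(f"The vllm engine is dead, exiting the pod: {e}")
        os.kill(os.getpid(), signal.SIGINT)
        raise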