
Commit 49b26e2

feat: add ChatCompletion endpoint in OpenAI demo server. (#330)
1 parent dafd924 commit 49b26e2
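
For context, a minimal sketch of how a client could call the new endpoint once this demo server is running. The localhost:8000 address, the model/template name, and the example messages are illustrative assumptions, not part of the commit; in this version the model field is also used to look up a FastChat conversation template (see get_gen_prompt in the diff below), so it has to resolve to a known template name.

# Hypothetical client call against the new /v1/chat/completions route.
import requests

payload = {
    "model": "vicuna_v1.1",  # placeholder; must match the served model and a FastChat template name
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is vLLM?"},
    ],
    "temperature": 0.7,
    "max_tokens": 128,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])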

File tree

2 files changed: +284 / -6 lines

    vllm/entrypoints/openai/api_server.py
    vllm/entrypoints/openai/protocol.py

vllm/entrypoints/openai/api_server.py

Lines changed: 236 additions & 3 deletions
@@ -4,7 +4,7 @@
 from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional, Union, Any
 
 import fastapi
 from fastapi import BackgroundTasks, Request
@@ -17,8 +17,12 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
-    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
-    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
+    CompletionResponseStreamChoice, CompletionStreamResponse,
+    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
+    ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
+    ModelCard, ModelList, ModelPermission, UsageInfo)
+from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -55,6 +59,70 @@ async def check_model(request) -> Optional[JSONResponse]:
     return ret
 
 
+async def get_gen_prompt(request) -> str:
+    conv = get_conv_template(request.model)
+    conv = Conversation(
+        name=conv.name,
+        system=conv.system,
+        roles=conv.roles,
+        messages=list(conv.messages),  # prevent in-place modification
+        offset=conv.offset,
+        sep_style=SeparatorStyle(conv.sep_style),
+        sep=conv.sep,
+        sep2=conv.sep2,
+        stop_str=conv.stop_str,
+        stop_token_ids=conv.stop_token_ids,
+    )
+
+    if isinstance(request.messages, str):
+        prompt = request.messages
+    else:
+        for message in request.messages:
+            msg_role = message["role"]
+            if msg_role == "system":
+                conv.system = message["content"]
+            elif msg_role == "user":
+                conv.append_message(conv.roles[0], message["content"])
+            elif msg_role == "assistant":
+                conv.append_message(conv.roles[1], message["content"])
+            else:
+                raise ValueError(f"Unknown role: {msg_role}")
+
+        # Add a blank message for the assistant.
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+    return prompt
+
+
+async def check_length(request, prompt, engine):
+    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
+        context_len = engine.engine.model_config.hf_config.max_sequence_length
+    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
+        context_len = engine.engine.model_config.hf_config.seq_length
+    elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
+        context_len = engine.engine.model_config.hf_config.max_position_embeddings
+    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
+        context_len = engine.engine.model_config.hf_config.seq_length
+    else:
+        context_len = 2048
+
+    input_ids = tokenizer(prompt).input_ids
+    token_num = len(input_ids)
+
+    if token_num + request.max_tokens > context_len:
+        return create_error_response(
+            HTTPStatus.BAD_REQUEST,
+            f"This model's maximum context length is {context_len} tokens. "
+            f"However, you requested {request.max_tokens + token_num} tokens "
+            f"({token_num} in the messages, "
+            f"{request.max_tokens} in the completion). "
+            f"Please reduce the length of the messages or completion.",
+        )
+    else:
+        return None
+
+
 @app.get("/v1/models")
 async def show_available_models():
     """Show available models. Right now we only have one model."""
@@ -85,6 +153,171 @@ def create_logprobs(token_ids: List[int],
     return logprobs
 
 
+@app.post("/v1/chat/completions")
+async def create_chat_completion(raw_request: Request):
+    """Completion API similar to OpenAI's API.
+
+    See https://platform.openai.com/docs/api-reference/chat/create
+    for the API specification. This API mimics the OpenAI ChatCompletion API.
+
+    NOTE: Currently we do not support the following features:
+        - function_call (Users should implement this by themselves)
+        - logit_bias (to be supported by vLLM engine)
+    """
+    request = ChatCompletionRequest(**await raw_request.json())
+    logger.info(f"Received chat completion request: {request}")
+
+    error_check_ret = await check_model(request)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    if request.logit_bias is not None:
+        # TODO: support logit_bias in vLLM engine.
+        return create_error_response(HTTPStatus.BAD_REQUEST,
+                                     "logit_bias is not currently supported")
+
+    prompt = await get_gen_prompt(request)
+    error_check_ret = await check_length(request, prompt, engine)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    model_name = request.model
+    request_id = f"cmpl-{random_uuid()}"
+    created_time = int(time.time())
+    try:
+        sampling_params = SamplingParams(
+            n=request.n,
+            presence_penalty=request.presence_penalty,
+            frequency_penalty=request.frequency_penalty,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            stop=request.stop,
+            max_tokens=request.max_tokens,
+            best_of=request.best_of,
+            top_k=request.top_k,
+            ignore_eos=request.ignore_eos,
+            use_beam_search=request.use_beam_search,
+        )
+    except ValueError as e:
+        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
+
+    result_generator = engine.generate(prompt, sampling_params,
+                                       request_id)
+
+    async def abort_request() -> None:
+        await engine.abort(request_id)
+
+    def create_stream_response_json(index: int,
+                                    text: str,
+                                    finish_reason: Optional[str] = None) -> str:
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=index,
+            delta=DeltaMessage(content=text),
+            finish_reason=finish_reason,
+        )
+        response = ChatCompletionStreamResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=[choice_data],
+        )
+        response_json = response.json(ensure_ascii=False)
+
+        return response_json
+
+    async def completion_stream_generator() -> AsyncGenerator[str, None]:
+        # First chunk with role
+        for i in range(request.n):
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=i,
+                delta=DeltaMessage(role="assistant"),
+                finish_reason=None,
+            )
+            chunk = ChatCompletionStreamResponse(
+                id=request_id, choices=[choice_data], model=model_name
+            )
+            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        previous_texts = [""] * request.n
+        previous_num_tokens = [0] * request.n
+        async for res in result_generator:
+            res: RequestOutput
+            for output in res.outputs:
+                i = output.index
+                delta_text = output.text[len(previous_texts[i]):]
+                previous_texts[i] = output.text
+                previous_num_tokens[i] = len(output.token_ids)
+                response_json = create_stream_response_json(
+                    index=i,
+                    text=delta_text,
+                )
+                yield f"data: {response_json}\n\n"
+                if output.finish_reason is not None:
+                    response_json = create_stream_response_json(
+                        index=i,
+                        text="",
+                        finish_reason=output.finish_reason,
+                    )
+                    yield f"data: {response_json}\n\n"
+        yield "data: [DONE]\n\n"
+
+    # Streaming response
+    if request.stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
+        return StreamingResponse(completion_stream_generator(),
+                                 media_type="text/event-stream",
+                                 background=background_tasks)
+
+    # Non-streaming response
+    final_res: RequestOutput = None
+    async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await abort_request()
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
+        final_res = res
+    assert final_res is not None
+    choices = []
+    for output in final_res.outputs:
+        choice_data = ChatCompletionResponseChoice(
+            index=output.index,
+            message=ChatMessage(role="assistant", content=output.text),
+            finish_reason=output.finish_reason,
+        )
+        choices.append(choice_data)
+
+    num_prompt_tokens = len(final_res.prompt_token_ids)
+    num_generated_tokens = sum(len(output.token_ids)
+                               for output in final_res.outputs)
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        total_tokens=num_prompt_tokens + num_generated_tokens,
+    )
+    response = ChatCompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+    if request.stream:
+        # When user requests streaming but we don't stream, we still need to
+        # return a streaming response with a single event.
+        response_json = response.json(ensure_ascii=False)
+        async def fake_stream_generator() -> AsyncGenerator[str, None]:
+            yield f"data: {response_json}\n\n"
+            yield "data: [DONE]\n\n"
+        return StreamingResponse(fake_stream_generator(),
+                                 media_type="text/event-stream")
+
+    return response
+
+
 @app.post("/v1/completions")
 async def create_completion(raw_request: Request):
     """Completion API similar to OpenAI's API.

vllm/entrypoints/openai/protocol.py

Lines changed: 48 additions & 3 deletions
@@ -53,16 +53,22 @@ class UsageInfo(BaseModel):
 
 class ChatCompletionRequest(BaseModel):
     model: str
-    messages: List[Dict[str, str]]
+    messages: Union[str, List[Dict[str, str]]]
     temperature: Optional[float] = 0.7
     top_p: Optional[float] = 1.0
     n: Optional[int] = 1
-    max_tokens: Optional[int] = None
-    stop: Optional[Union[str, List[str]]] = None
+    max_tokens: Optional[int] = 16
+    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
     stream: Optional[bool] = False
     presence_penalty: Optional[float] = 0.0
     frequency_penalty: Optional[float] = 0.0
+    logit_bias: Optional[Dict[str, float]] = None
     user: Optional[str] = None
+    # Additional parameters supported by vLLM
+    best_of: Optional[int] = None
+    top_k: Optional[int] = -1
+    ignore_eos: Optional[bool] = False
+    use_beam_search: Optional[bool] = False
 
 
 class CompletionRequest(BaseModel):
@@ -124,3 +130,42 @@ class CompletionStreamResponse(BaseModel):
     created: int = Field(default_factory=lambda: int(time.time()))
     model: str
     choices: List[CompletionResponseStreamChoice]
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+
+
+class ChatCompletionResponseChoice(BaseModel):
+    index: int
+    message: ChatMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseChoice]
+    usage: UsageInfo
+
+
+class DeltaMessage(BaseModel):
+    role: Optional[str] = None
+    content: Optional[str] = None
+
+
+class ChatCompletionResponseStreamChoice(BaseModel):
+    index: int
+    delta: DeltaMessage
+    finish_reason: Optional[Literal["stop", "length"]] = None
+
+
+class ChatCompletionStreamResponse(BaseModel):
+    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
+    object: str = "chat.completion.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: List[ChatCompletionResponseStreamChoice]
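
These pydantic models define the wire format for the new endpoint. A quick sketch, assuming a vLLM install that includes this commit, showing how a non-streaming response object serializes to OpenAI-compatible JSON; the model name, message text, and token counts are made up for illustration:

# Illustration of the JSON shape produced by the response models above.
from vllm.entrypoints.openai.protocol import (
    ChatCompletionResponse, ChatCompletionResponseChoice, ChatMessage, UsageInfo)

response = ChatCompletionResponse(
    model="example-model",  # placeholder model name
    choices=[
        ChatCompletionResponseChoice(
            index=0,
            message=ChatMessage(role="assistant", content="Hello!"),
            finish_reason="stop",
        )
    ],
    usage=UsageInfo(prompt_tokens=10, completion_tokens=3, total_tokens=13),
)
# id and created are filled in by the default factories; .json() is pydantic v1 style.
print(response.json(indent=2))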
