 from http import HTTPStatus
 import json
 import time
-from typing import AsyncGenerator, Dict, List, Optional
+from typing import AsyncGenerator, Dict, List, Optional, Union, Any

 import fastapi
 from fastapi import BackgroundTasks, Request
@@ -17,8 +17,12 @@
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.openai.protocol import (
     CompletionRequest, CompletionResponse, CompletionResponseChoice,
-    CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
-    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
+    CompletionResponseStreamChoice, CompletionStreamResponse,
+    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
+    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
+    ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
+    ModelCard, ModelList, ModelPermission, UsageInfo)
+from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
@@ -55,6 +59,70 @@ async def check_model(request) -> Optional[JSONResponse]:
     return ret


+async def get_gen_prompt(request) -> str:
+    conv = get_conv_template(request.model)
+    conv = Conversation(
+        name=conv.name,
+        system=conv.system,
+        roles=conv.roles,
+        messages=list(conv.messages),  # prevent in-place modification
+        offset=conv.offset,
+        sep_style=SeparatorStyle(conv.sep_style),
+        sep=conv.sep,
+        sep2=conv.sep2,
+        stop_str=conv.stop_str,
+        stop_token_ids=conv.stop_token_ids,
+    )
+
+    if isinstance(request.messages, str):
+        prompt = request.messages
+    else:
+        for message in request.messages:
+            msg_role = message["role"]
+            if msg_role == "system":
+                conv.system = message["content"]
+            elif msg_role == "user":
+                conv.append_message(conv.roles[0], message["content"])
+            elif msg_role == "assistant":
+                conv.append_message(conv.roles[1], message["content"])
+            else:
+                raise ValueError(f"Unknown role: {msg_role}")
+
+        # Add a blank message for the assistant.
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+
+    return prompt
+
+
+async def check_length(request, prompt, engine):
+    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
+        context_len = engine.engine.model_config.hf_config.max_sequence_length
+    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
+        context_len = engine.engine.model_config.hf_config.seq_length
+    elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
+        context_len = engine.engine.model_config.hf_config.max_position_embeddings
+    else:
+        context_len = 2048
+
+    # `tokenizer` is the module-level tokenizer initialized at server startup.
+    input_ids = tokenizer(prompt).input_ids
+    token_num = len(input_ids)
+
+    if token_num + request.max_tokens > context_len:
+        return create_error_response(
+            HTTPStatus.BAD_REQUEST,
+            f"This model's maximum context length is {context_len} tokens. "
+            f"However, you requested {request.max_tokens + token_num} tokens "
+            f"({token_num} in the messages, "
+            f"{request.max_tokens} in the completion). "
+            f"Please reduce the length of the messages or completion.",
+        )
+    else:
+        return None
+
+
 @app.get("/v1/models")
 async def show_available_models():
     """Show available models. Right now we only have one model."""
@@ -85,6 +153,171 @@ def create_logprobs(token_ids: List[int],
     return logprobs


+@app.post("/v1/chat/completions")
+async def create_chat_completion(raw_request: Request):
+    """Chat completion API similar to OpenAI's API.
+
+    See https://platform.openai.com/docs/api-reference/chat/create
+    for the API specification. This API mimics the OpenAI ChatCompletion API.
+
+    NOTE: Currently we do not support the following features:
+        - function_call (Users should implement this by themselves)
+        - logit_bias (to be supported by vLLM engine)
+    """
+    request = ChatCompletionRequest(**await raw_request.json())
+    logger.info(f"Received chat completion request: {request}")
+
+    error_check_ret = await check_model(request)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    if request.logit_bias is not None:
+        # TODO: support logit_bias in vLLM engine.
+        return create_error_response(HTTPStatus.BAD_REQUEST,
+                                     "logit_bias is not currently supported")
+
+    prompt = await get_gen_prompt(request)
+    error_check_ret = await check_length(request, prompt, engine)
+    if error_check_ret is not None:
+        return error_check_ret
+
+    model_name = request.model
+    request_id = f"cmpl-{random_uuid()}"
+    created_time = int(time.time())
+    try:
+        sampling_params = SamplingParams(
+            n=request.n,
+            presence_penalty=request.presence_penalty,
+            frequency_penalty=request.frequency_penalty,
+            temperature=request.temperature,
+            top_p=request.top_p,
+            stop=request.stop,
+            max_tokens=request.max_tokens,
+            best_of=request.best_of,
+            top_k=request.top_k,
+            ignore_eos=request.ignore_eos,
+            use_beam_search=request.use_beam_search,
+        )
+    except ValueError as e:
+        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))
+
+    result_generator = engine.generate(prompt, sampling_params,
+                                       request_id)
+
+    async def abort_request() -> None:
+        await engine.abort(request_id)
+
+    def create_stream_response_json(index: int,
+                                    text: str,
+                                    finish_reason: Optional[str] = None) -> str:
+        choice_data = ChatCompletionResponseStreamChoice(
+            index=index,
+            delta=DeltaMessage(content=text),
+            finish_reason=finish_reason,
+        )
+        response = ChatCompletionStreamResponse(
+            id=request_id,
+            created=created_time,
+            model=model_name,
+            choices=[choice_data],
+        )
+        response_json = response.json(ensure_ascii=False)
+
+        return response_json
+
+    async def completion_stream_generator() -> AsyncGenerator[str, None]:
+        # First chunk with role
+        for i in range(request.n):
+            choice_data = ChatCompletionResponseStreamChoice(
+                index=i,
+                delta=DeltaMessage(role="assistant"),
+                finish_reason=None,
+            )
+            chunk = ChatCompletionStreamResponse(
+                id=request_id, choices=[choice_data], model=model_name
+            )
+            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+
+        previous_texts = [""] * request.n
+        previous_num_tokens = [0] * request.n
+        async for res in result_generator:
+            res: RequestOutput
+            for output in res.outputs:
+                i = output.index
+                delta_text = output.text[len(previous_texts[i]):]
+                previous_texts[i] = output.text
+                previous_num_tokens[i] = len(output.token_ids)
+                response_json = create_stream_response_json(
+                    index=i,
+                    text=delta_text,
+                )
+                yield f"data: {response_json}\n\n"
+                if output.finish_reason is not None:
+                    response_json = create_stream_response_json(
+                        index=i,
+                        text="",
+                        finish_reason=output.finish_reason,
+                    )
+                    yield f"data: {response_json}\n\n"
+        yield "data: [DONE]\n\n"
+
+    # Streaming response
+    if request.stream:
+        background_tasks = BackgroundTasks()
+        # Abort the request if the client disconnects.
+        background_tasks.add_task(abort_request)
+        return StreamingResponse(completion_stream_generator(),
+                                 media_type="text/event-stream",
+                                 background=background_tasks)
+
+    # Non-streaming response
+    final_res: RequestOutput = None
+    async for res in result_generator:
+        if await raw_request.is_disconnected():
+            # Abort the request if the client disconnects.
+            await abort_request()
+            return create_error_response(HTTPStatus.BAD_REQUEST,
+                                         "Client disconnected")
+        final_res = res
+    assert final_res is not None
+    choices = []
+    for output in final_res.outputs:
+        choice_data = ChatCompletionResponseChoice(
+            index=output.index,
+            message=ChatMessage(role="assistant", content=output.text),
+            finish_reason=output.finish_reason,
+        )
+        choices.append(choice_data)
+
+    num_prompt_tokens = len(final_res.prompt_token_ids)
+    num_generated_tokens = sum(len(output.token_ids)
+                               for output in final_res.outputs)
+    usage = UsageInfo(
+        prompt_tokens=num_prompt_tokens,
+        completion_tokens=num_generated_tokens,
+        total_tokens=num_prompt_tokens + num_generated_tokens,
+    )
+    response = ChatCompletionResponse(
+        id=request_id,
+        created=created_time,
+        model=model_name,
+        choices=choices,
+        usage=usage,
+    )
+
+    if request.stream:
+        # When user requests streaming but we don't stream, we still need to
+        # return a streaming response with a single event.
+        response_json = response.json(ensure_ascii=False)
+
+        async def fake_stream_generator() -> AsyncGenerator[str, None]:
+            yield f"data: {response_json}\n\n"
+            yield "data: [DONE]\n\n"
+
+        return StreamingResponse(fake_stream_generator(),
+                                 media_type="text/event-stream")
+
+    return response
+
+
 @app.post("/v1/completions")
 async def create_completion(raw_request: Request):
     """Completion API similar to OpenAI's API.
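
For reference (not part of the diff above), here is a minimal sketch of how a client could exercise the new `/v1/chat/completions` endpoint. It assumes the server is running locally on port 8000 and serving a model whose name matches a FastChat conversation template; the model name, port, and message contents are illustrative only.

```python
# Minimal sketch of a non-streaming chat request against the new endpoint.
# The model name "vicuna-7b-v1.3" and the local URL are assumptions.
import requests

payload = {
    "model": "vicuna-7b-v1.3",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"},
    ],
    "max_tokens": 64,
    "temperature": 0.7,
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
resp.raise_for_status()
# Non-streaming responses carry the full assistant message per choice.
print(resp.json()["choices"][0]["message"]["content"])
```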
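When `"stream": true` is set, the endpoint instead returns server-sent events: each event line is prefixed with `data: ` and the stream ends with a `data: [DONE]` sentinel, as produced by `completion_stream_generator` above. A rough sketch of consuming that stream, under the same assumptions as the previous example:

```python
# Rough sketch: read the SSE stream emitted by the chat completions endpoint.
# Same hypothetical model name and local server as in the previous example.
import json
import requests

payload = {
    "model": "vicuna-7b-v1.3",
    "messages": [{"role": "user", "content": "Tell me a short joke."}],
    "max_tokens": 64,
    "stream": True,
}
with requests.post("http://localhost:8000/v1/chat/completions",
                   json=payload, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue  # skip blank separator lines between events
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        delta = chunk["choices"][0]["delta"]
        # The first chunk per choice carries only the role; later chunks carry text.
        print(delta.get("content", ""), end="", flush=True)
print()
```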