# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

import argparse
+import asyncio
from http import HTTPStatus
import json
import time
-from typing import AsyncGenerator, Dict, List, Optional, Union, Any
+from typing import AsyncGenerator, Dict, List, Optional

import fastapi
from fastapi import BackgroundTasks, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
+from fastchat.conversation import (Conversation, SeparatorStyle,
+                                   get_conv_template)
import uvicorn

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
    CompletionRequest, CompletionResponse, CompletionResponseChoice,
    CompletionResponseStreamChoice, CompletionStreamResponse,
-    ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice,
-    ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse,
-    ChatMessage, DeltaMessage, ErrorResponse, LogProbs,
-    ModelCard, ModelList, ModelPermission, UsageInfo)
-from fastchat.conversation import Conversation, SeparatorStyle, get_conv_template
+    ChatCompletionRequest, ChatCompletionResponse,
+    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
+    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
+    LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
@@ -95,15 +97,15 @@ async def get_gen_prompt(request) -> str:
    return prompt


-async def check_length(request, prompt, engine):
-    if hasattr(engine.engine.model_config.hf_config, "max_sequence_length"):
-        context_len = engine.engine.model_config.hf_config.max_sequence_length
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
-    elif hasattr(engine.engine.model_config.hf_config, "max_position_embeddings"):
-        context_len = engine.engine.model_config.hf_config.max_position_embeddings
-    elif hasattr(engine.engine.model_config.hf_config, "seq_length"):
-        context_len = engine.engine.model_config.hf_config.seq_length
+async def check_length(request, prompt, model_config):
+    if hasattr(model_config.hf_config, "max_sequence_length"):
+        context_len = model_config.hf_config.max_sequence_length
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
+    elif hasattr(model_config.hf_config, "max_position_embeddings"):
+        context_len = model_config.hf_config.max_position_embeddings
+    elif hasattr(model_config.hf_config, "seq_length"):
+        context_len = model_config.hf_config.seq_length
    else:
        context_len = 2048

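The refactored check_length no longer reaches through engine.engine for the model config; it receives the config directly and probes its hf_config for the context window in a fixed priority order. A minimal sketch of that lookup order in isolation (SimpleNamespace and resolve_context_len are illustrative stand-ins, not part of the commit):

# Illustrative sketch only: mirrors the attribute-probing order used by the
# patched check_length(); SimpleNamespace stands in for a HF config object.
from types import SimpleNamespace


def resolve_context_len(hf_config) -> int:
    for attr in ("max_sequence_length", "seq_length", "max_position_embeddings"):
        if hasattr(hf_config, attr):
            return getattr(hf_config, attr)
    return 2048  # fallback when no length attribute is present


assert resolve_context_len(SimpleNamespace(max_position_embeddings=4096)) == 4096
assert resolve_context_len(SimpleNamespace()) == 2048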
@@ -182,7 +184,7 @@ async def create_chat_completion(raw_request: Request):
                                     "logit_bias is not currently supported")

    prompt = await get_gen_prompt(request)
-    error_check_ret = await check_length(request, prompt, engine)
+    error_check_ret = await check_length(request, prompt, engine_model_config)
    if error_check_ret is not None:
        return error_check_ret

@@ -206,15 +208,16 @@ async def create_chat_completion(raw_request: Request):
    except ValueError as e:
        return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

-    result_generator = engine.generate(prompt, sampling_params,
-                                       request_id)
+    result_generator = engine.generate(prompt, sampling_params, request_id)

    async def abort_request() -> None:
        await engine.abort(request_id)

-    def create_stream_response_json(index: int,
-                                    text: str,
-                                    finish_reason: Optional[str] = None) -> str:
+    def create_stream_response_json(
+        index: int,
+        text: str,
+        finish_reason: Optional[str] = None,
+    ) -> str:
        choice_data = ChatCompletionResponseStreamChoice(
            index=index,
            delta=DeltaMessage(content=text),
@@ -238,10 +241,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
            delta=DeltaMessage(role="assistant"),
            finish_reason=None,
        )
-        chunk = ChatCompletionStreamResponse(
-            id=request_id, choices=[choice_data], model=model_name
-        )
-        yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
+        chunk = ChatCompletionStreamResponse(id=request_id,
+                                             choices=[choice_data],
+                                             model=model_name)
+        data = chunk.json(exclude_unset=True, ensure_ascii=False)
+        yield f"data: {data}\n\n"

        previous_texts = [""] * request.n
        previous_num_tokens = [0] * request.n
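Each yielded event is a server-sent-events line of the form "data: <json>\n\n", and the first chunk per choice carries only the assistant role in its delta. A small consumer-side sketch, assuming events shaped like the chunks above (the sample payloads are invented for illustration):

# Sketch (not part of the diff): parse "data: ..." events as a client would,
# stopping at the "[DONE]" sentinel the server emits last.
import json


def iter_chunks(event_lines):
    for line in event_lines:
        line = line.strip()
        if not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        yield json.loads(payload)


sample = [
    'data: {"id": "cmpl-1", "choices": [{"index": 0, "delta": {"role": "assistant"}}]}',
    'data: {"id": "cmpl-1", "choices": [{"index": 0, "delta": {"content": "Hi"}}]}',
    'data: [DONE]',
]
for chunk in iter_chunks(sample):
    print(chunk["choices"][0]["delta"])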
@@ -295,8 +299,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
        choices.append(choice_data)

    num_prompt_tokens = len(final_res.prompt_token_ids)
-    num_generated_tokens = sum(len(output.token_ids)
-                               for output in final_res.outputs)
+    num_generated_tokens = sum(
+        len(output.token_ids) for output in final_res.outputs)
    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
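The usage block counts prompt tokens once and sums generated tokens across all n sampled outputs. A toy check of that arithmetic (the token id lists are made up for illustration):

# Sketch (not part of the diff): usage arithmetic for one prompt and two
# sampled outputs (request.n == 2).
prompt_token_ids = [101, 102, 103, 104]        # 4 prompt tokens
outputs_token_ids = [[7, 8, 9], [10, 11]]      # 3 + 2 generated tokens

num_prompt_tokens = len(prompt_token_ids)
num_generated_tokens = sum(len(ids) for ids in outputs_token_ids)
assert (num_prompt_tokens, num_generated_tokens) == (4, 5)
# UsageInfo would then report total_tokens = 4 + 5 = 9.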
@@ -314,9 +318,11 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
        # When user requests streaming but we don't stream, we still need to
        # return a streaming response with a single event.
        response_json = response.json(ensure_ascii=False)
+
        async def fake_stream_generator() -> AsyncGenerator[str, None]:
            yield f"data: {response_json}\n\n"
            yield "data: [DONE]\n\n"
+
        return StreamingResponse(fake_stream_generator(),
                                 media_type="text/event-stream")

@@ -367,9 +373,9 @@ async def create_completion(raw_request: Request):
            return create_error_response(HTTPStatus.BAD_REQUEST,
                                         "please provide at least one prompt")
        if len(request.prompt) > 1:
-            return create_error_response(HTTPStatus.BAD_REQUEST,
-                                         "multiple prompts in a batch is not "
-                                         " currently supported")
+            return create_error_response(
+                HTTPStatus.BAD_REQUEST,
+                "multiple prompts in a batch is not currently supported")
        prompt = request.prompt[0]
    else:
        prompt = request.prompt
@@ -571,6 +577,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
+    engine_model_config = asyncio.run(engine.get_model_config())

    # A separate tokenizer to map token IDs to strings.
    tokenizer = get_tokenizer(engine_args.tokenizer,
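The last hunk fetches the model config once at startup: get_model_config() is a coroutine on the async engine, and since no event loop is running yet at this point in the main block, it is driven with asyncio.run() and the result is then reused for every request-length check. A toy sketch of the same pattern (ToyEngine is an illustrative stand-in, not vLLM's class):

# Sketch (not part of the commit): driving an async config getter to
# completion before the server's event loop starts.
import asyncio


class ToyEngine:
    async def get_model_config(self):
        return {"max_position_embeddings": 4096}


toy_config = asyncio.run(ToyEngine().get_model_config())
assert toy_config["max_position_embeddings"] == 4096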