13
13
from fastapi .exceptions import RequestValidationError
14
14
from fastapi .middleware .cors import CORSMiddleware
15
15
from fastapi .responses import JSONResponse , StreamingResponse
16
- from fastchat .conversation import (Conversation , SeparatorStyle ,
17
- get_conv_template )
16
+ from fastchat .conversation import Conversation , SeparatorStyle
17
+ from fastchat .model .model_adapter import get_conversation_template
18
+
18
19
import uvicorn
19
20
20
21
from vllm .engine .arg_utils import AsyncEngineArgs
36
37
37
38
logger = init_logger (__name__ )
38
39
served_model = None
39
- chat_template = None
40
40
app = fastapi .FastAPI ()
41
41
42
42
@@ -63,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]:
63
63
64
64
65
65
async def get_gen_prompt (request ) -> str :
66
- conv = get_conv_template ( chat_template )
66
+ conv = get_conversation_template ( request . model )
67
67
conv = Conversation (
68
68
name = conv .name ,
69
69
system = conv .system ,
@@ -560,14 +560,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
560
560
help = "The model name used in the API. If not "
561
561
"specified, the model name will be the same as "
562
562
"the huggingface name." )
563
- parser .add_argument (
564
- "--chat-template" ,
565
- type = str ,
566
- default = None ,
567
- help = "The chat template name used in the ChatCompletion endpoint. If "
568
- "not specified, we use the API model name as the template name. See "
569
- "https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py "
570
- "for the list of available templates." )
563
+
571
564
parser = AsyncEngineArgs .add_cli_args (parser )
572
565
args = parser .parse_args ()
573
566
@@ -586,11 +579,6 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
586
579
else :
587
580
served_model = args .model
588
581
589
- if args .chat_template is not None :
590
- chat_template = args .chat_template
591
- else :
592
- chat_template = served_model
593
-
594
582
engine_args = AsyncEngineArgs .from_cli_args (args )
595
583
engine = AsyncLLMEngine .from_engine_args (engine_args )
596
584
engine_model_config = asyncio .run (engine .get_model_config ())
0 commit comments