36
36
37
37
logger = init_logger (__name__ )
38
38
served_model = None
39
+ chat_template = None
39
40
app = fastapi .FastAPI ()
40
41
41
42
@@ -62,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]:
62
63
63
64
64
65
async def get_gen_prompt (request ) -> str :
65
- conv = get_conv_template (request . model )
66
+ conv = get_conv_template (chat_template )
66
67
conv = Conversation (
67
68
name = conv .name ,
68
69
system = conv .system ,
@@ -553,13 +554,20 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
553
554
type = json .loads ,
554
555
default = ["*" ],
555
556
help = "allowed headers" )
557
+ parser .add_argument ("--served-model-name" ,
558
+ type = str ,
559
+ default = None ,
560
+ help = "The model name used in the API. If not "
561
+ "specified, the model name will be the same as "
562
+ "the huggingface name." )
556
563
parser .add_argument (
557
- "--served-model-name " ,
564
+ "--chat-template " ,
558
565
type = str ,
559
566
default = None ,
560
- help = "The model name used in the API. If not specified, "
561
- "the model name will be the same as the "
562
- "huggingface name." )
567
+ help = "The chat template name used in the ChatCompletion endpoint. If "
568
+ "not specified, we use the API model name as the template name. See "
569
+ "https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py "
570
+ "for the list of available templates." )
563
571
parser = AsyncEngineArgs .add_cli_args (parser )
564
572
args = parser .parse_args ()
565
573
@@ -573,7 +581,15 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
573
581
574
582
logger .info (f"args: { args } " )
575
583
576
- served_model = args .served_model_name or args .model
584
+ if args .served_model_name is not None :
585
+ served_model = args .served_model_name
586
+ else :
587
+ served_model = args .model
588
+
589
+ if args .chat_template is not None :
590
+ chat_template = args .chat_template
591
+ else :
592
+ chat_template = served_model
577
593
578
594
engine_args = AsyncEngineArgs .from_cli_args (args )
579
595
engine = AsyncLLMEngine .from_engine_args (engine_args )
0 commit comments