Skip to content

Commit 98fe8cb

Browse files
authored
[Server] Add option to specify chat template for chat endpoint (#345)
1 parent ffa6d2f commit 98fe8cb

File tree

2 files changed: +23 -6 lines changed

2 files changed: +23 -6 lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -9,3 +9,4 @@ xformers >= 0.0.19
99
fastapi
1010
uvicorn
1111
pydantic # Required for OpenAI server.
12+
fschat # Required for OpenAI ChatCompletion Endpoint.

vllm/entrypoints/openai/api_server.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@
3636

3737
logger = init_logger(__name__)
3838
served_model = None
39+
chat_template = None
3940
app = fastapi.FastAPI()
4041

4142

@@ -62,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]:
6263

6364

6465
async def get_gen_prompt(request) -> str:
65-
conv = get_conv_template(request.model)
66+
conv = get_conv_template(chat_template)
6667
conv = Conversation(
6768
name=conv.name,
6869
system=conv.system,
@@ -553,13 +554,20 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
553554
type=json.loads,
554555
default=["*"],
555556
help="allowed headers")
557+
parser.add_argument("--served-model-name",
558+
type=str,
559+
default=None,
560+
help="The model name used in the API. If not "
561+
"specified, the model name will be the same as "
562+
"the huggingface name.")
556563
parser.add_argument(
557-
"--served-model-name",
564+
"--chat-template",
558565
type=str,
559566
default=None,
560-
help="The model name used in the API. If not specified, "
561-
"the model name will be the same as the "
562-
"huggingface name.")
567+
help="The chat template name used in the ChatCompletion endpoint. If "
568+
"not specified, we use the API model name as the template name. See "
569+
"https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py "
570+
"for the list of available templates.")
563571
parser = AsyncEngineArgs.add_cli_args(parser)
564572
args = parser.parse_args()
565573

@@ -573,7 +581,15 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
573581

574582
logger.info(f"args: {args}")
575583

576-
served_model = args.served_model_name or args.model
584+
if args.served_model_name is not None:
585+
served_model = args.served_model_name
586+
else:
587+
served_model = args.model
588+
589+
if args.chat_template is not None:
590+
chat_template = args.chat_template
591+
else:
592+
chat_template = served_model
577593

578594
engine_args = AsyncEngineArgs.from_cli_args(args)
579595
engine = AsyncLLMEngine.from_engine_args(engine_args)

0 commit comments

Comments
 (0)