Skip to content

Commit 3d64cf0

Browse files
authored
[Server] use the fastchat.model.model_adapter.get_conversation_template method to get the model template (#357)
1 parent 98fe8cb commit 3d64cf0

File tree

1 file changed

+5
-17
lines changed

1 file changed

+5
-17
lines changed

vllm/entrypoints/openai/api_server.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
from fastapi.exceptions import RequestValidationError
1414
from fastapi.middleware.cors import CORSMiddleware
1515
from fastapi.responses import JSONResponse, StreamingResponse
16-
from fastchat.conversation import (Conversation, SeparatorStyle,
17-
get_conv_template)
16+
from fastchat.conversation import Conversation, SeparatorStyle
17+
from fastchat.model.model_adapter import get_conversation_template
18+
1819
import uvicorn
1920

2021
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -36,7 +37,6 @@
3637

3738
logger = init_logger(__name__)
3839
served_model = None
39-
chat_template = None
4040
app = fastapi.FastAPI()
4141

4242

@@ -63,7 +63,7 @@ async def check_model(request) -> Optional[JSONResponse]:
6363

6464

6565
async def get_gen_prompt(request) -> str:
66-
conv = get_conv_template(chat_template)
66+
conv = get_conversation_template(request.model)
6767
conv = Conversation(
6868
name=conv.name,
6969
system=conv.system,
@@ -560,14 +560,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
560560
help="The model name used in the API. If not "
561561
"specified, the model name will be the same as "
562562
"the huggingface name.")
563-
parser.add_argument(
564-
"--chat-template",
565-
type=str,
566-
default=None,
567-
help="The chat template name used in the ChatCompletion endpoint. If "
568-
"not specified, we use the API model name as the template name. See "
569-
"https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py "
570-
"for the list of available templates.")
563+
571564
parser = AsyncEngineArgs.add_cli_args(parser)
572565
args = parser.parse_args()
573566

@@ -586,11 +579,6 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]:
586579
else:
587580
served_model = args.model
588581

589-
if args.chat_template is not None:
590-
chat_template = args.chat_template
591-
else:
592-
chat_template = served_model
593-
594582
engine_args = AsyncEngineArgs.from_cli_args(args)
595583
engine = AsyncLLMEngine.from_engine_args(engine_args)
596584
engine_model_config = asyncio.run(engine.get_model_config())

0 commit comments

Comments (0)