
Commit cbcdf2c

Authored by DarkLight1337, chaunceyjiang, and ywang96
[Bugfix] Fix chat template loading (#15143)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: chaunceyjiang <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
1 parent 038de04 commit cbcdf2c

7 files changed: +187 −47 lines changed


tests/entrypoints/openai/test_chat_template.py

Lines changed: 2 additions & 0 deletions
@@ -107,8 +107,10 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
     # Call the function and get the result
     result = apply_hf_chat_template(
         tokenizer,
+        trust_remote_code=True,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
+        tools=None,
         add_generation_prompt=mock_request.add_generation_prompt,
         continue_final_message=mock_request.continue_final_message,
     )
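
Why `tools` now has to be threaded through: a HF tokenizer's `chat_template` can be a plain Jinja string (Qwen2-VL in the tests below) or a dict of named templates (Hermes), and transformers picks the "tool_use" entry when tools are passed. A minimal sketch of that selection rule, assuming transformers' documented dict-template behavior; `pick_template` is an illustrative stand-in, not the actual vLLM or transformers helper:

from typing import Optional, Union

def pick_template(chat_template: Union[str, dict[str, str], None],
                  tools: Optional[list]) -> Optional[str]:
    # Plain string templates (or None) pass through unchanged.
    if not isinstance(chat_template, dict):
        return chat_template
    # Dict-valued templates: prefer the "tool_use" entry when tools are given.
    if tools is not None and "tool_use" in chat_template:
        return chat_template["tool_use"]
    return chat_template.get("default")

# str-valued template (e.g. Qwen2-VL) passes through.
assert pick_template("{{ messages }}", tools=None) == "{{ messages }}"
# dict-valued template (e.g. Hermes) resolves to a str either way,
# which is what the new test_resolve_hf_chat_template below asserts.
templates = {"default": "plain template", "tool_use": "tool template"}
assert pick_template(templates, tools=[{"type": "function"}]) == "tool template"
assert pick_template(templates, tools=None) == "plain template"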

tests/entrypoints/openai/test_video.py

Lines changed: 2 additions & 2 deletions
@@ -87,7 +87,7 @@ async def test_single_chat_session_video(client: openai.AsyncOpenAI,
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=6299, total_tokens=6309)
+        completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
 
     message = choice.message
     message = chat_completion.choices[0].message
@@ -180,7 +180,7 @@ async def test_single_chat_session_video_base64encoded(
     choice = chat_completion.choices[0]
     assert choice.finish_reason == "length"
     assert chat_completion.usage == openai.types.CompletionUsage(
-        completion_tokens=10, prompt_tokens=6299, total_tokens=6309)
+        completion_tokens=10, prompt_tokens=6287, total_tokens=6297)
 
     message = choice.message
     message = chat_completion.choices[0].message

tests/entrypoints/test_chat_utils.py

Lines changed: 62 additions & 2 deletions
@@ -4,10 +4,13 @@
 from typing import Optional
 
 import pytest
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
+from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
+                                         _try_extract_ast, load_chat_template,
                                          parse_chat_messages,
                                          parse_chat_messages_futures,
                                          resolve_chat_template_content_format)
@@ -23,8 +26,10 @@
 PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
 ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
+QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
+HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
 
 
 @pytest.fixture(scope="function")
@@ -703,25 +708,70 @@ def get_conversation(is_hf: bool):
 
     vllm_result = apply_hf_chat_template(
         tokenizer,
+        trust_remote_code=model_config.trust_remote_code,
         conversation=conversation,
         chat_template=None,
+        tools=None,
         add_generation_prompt=True,
     )
 
     assert hf_result == vllm_result
 
 
+@pytest.mark.parametrize(
+    "model",
+    [
+        QWEN2VL_MODEL_ID,  # tokenizer.chat_template is of type str
+        HERMES_MODEL_ID,  # tokenizer.chat_template is of type dict
+    ])
+@pytest.mark.parametrize("use_tools", [True, False])
+def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
+    """checks that chat_template is a dict type for HF models."""
+
+    # Build the tokenizer group and grab the underlying tokenizer
+    tokenizer_group = TokenizerGroup(
+        model,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+    )
+    tokenizer = tokenizer_group.tokenizer
+
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "dummy_function_name",
+            "description": "This is a dummy function",
+            "parameters": sample_json_schema
+        }
+    }] if use_tools else None
+
+    # Test detecting the tokenizer's chat_template
+    chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=None,
+        tools=tools,
+        trust_remote_code=True,
+    )
+    assert isinstance(chat_template, str)
+
+
 # yapf: disable
 @pytest.mark.parametrize(
     ("model", "expected_format"),
     [(PHI3V_MODEL_ID, "string"),
      (QWEN2VL_MODEL_ID, "openai"),
+     (QWEN25VL_MODEL_ID, "openai"),
      (ULTRAVOX_MODEL_ID, "string"),
      (MLLAMA_MODEL_ID, "openai"),
      (LLAMA_GUARD_MODEL_ID, "openai")],
 )
 # yapf: enable
 def test_resolve_content_format_hf_defined(model, expected_format):
+    if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
+            "4.49.0"):
+        pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
+
     tokenizer_group = TokenizerGroup(
         model,
         enable_lora=False,
@@ -730,7 +780,13 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     )
     tokenizer = tokenizer_group.tokenizer
 
-    chat_template = tokenizer.chat_template
+    # Test detecting the tokenizer's chat_template
+    chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=None,
+        tools=None,
+        trust_remote_code=True,
+    )
     assert isinstance(chat_template, str)
 
     print("[TEXT]")
@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
 
     resolved_format = resolve_chat_template_content_format(
         None,  # Test detecting the tokenizer's chat_template
+        None,
         "auto",
         tokenizer,
+        trust_remote_code=True,
     )
 
     assert resolved_format == expected_format
@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
 
     resolved_format = resolve_chat_template_content_format(
         chat_template,
+        None,
         "auto",
         dummy_tokenizer,
+        trust_remote_code=True,
    )
 
     assert resolved_format == expected_format
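
The Qwen2.5-VL skip above relies on PEP 440 ordering from `packaging`; a small self-contained illustration (the version strings here are examples, not taken from this commit):

from packaging.version import Version

# PEP 440: dev and rc pre-releases sort before the final release, so an
# in-progress 4.49.0 checkout would still hit the skip.
assert Version("4.48.3") < Version("4.49.0")
assert Version("4.49.0.dev0") < Version("4.49.0")
assert Version("4.49.0") >= Version("4.49.0")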

tests/tool_use/utils.py

Lines changed: 4 additions & 1 deletion
@@ -39,7 +39,10 @@ def ensure_system_prompt(messages: list[dict[str, Any]],
 
 # universal args for all models go here. also good if you need to test locally
 # and change type or KV cache quantization or something.
-ARGS: list[str] = ["--enable-auto-tool-choice", "--max-model-len", "1024"]
+ARGS: list[str] = [
+    "--enable-auto-tool-choice", "--max-model-len", "1024", "--max-num-seqs",
+    "256"
+]
 
 CONFIGS: dict[str, ServerConfig] = {
     "hermes": {
