@@ -4,10 +4,13 @@
 from typing import Optional
 
 import pytest
+from packaging.version import Version
+from transformers import __version__ as TRANSFORMERS_VERSION
 
 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
-from vllm.entrypoints.chat_utils import (_try_extract_ast, load_chat_template,
+from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
+                                         _try_extract_ast, load_chat_template,
                                          parse_chat_messages,
                                          parse_chat_messages_futures,
                                          resolve_chat_template_content_format)
@@ -23,8 +26,10 @@
 PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
 ULTRAVOX_MODEL_ID = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct"
+QWEN25VL_MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
 MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct"
 LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B"
+HERMES_MODEL_ID = "NousResearch/Hermes-3-Llama-3.1-8B"
 
 
 @pytest.fixture(scope="function")
@@ -703,25 +708,70 @@ def get_conversation(is_hf: bool):
 
     vllm_result = apply_hf_chat_template(
         tokenizer,
+        trust_remote_code=model_config.trust_remote_code,
         conversation=conversation,
         chat_template=None,
+        tools=None,
         add_generation_prompt=True,
     )
 
     assert hf_result == vllm_result
 
 
+@pytest.mark.parametrize(
+    "model",
+    [
+        QWEN2VL_MODEL_ID,  # tokenizer.chat_template is of type str
+        HERMES_MODEL_ID,  # tokenizer.chat_template is of type dict
+    ])
+@pytest.mark.parametrize("use_tools", [True, False])
+def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
+    """Checks that the resolved chat_template is a string, whether the
+    tokenizer stores it as a str or as a dict."""
+
+    # Build the tokenizer group and grab the underlying tokenizer
+    tokenizer_group = TokenizerGroup(
+        model,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+    )
+    tokenizer = tokenizer_group.tokenizer
+
+    tools = [{
+        "type": "function",
+        "function": {
+            "name": "dummy_function_name",
+            "description": "This is a dummy function",
+            "parameters": sample_json_schema
+        }
+    }] if use_tools else None
+
+    # Test detecting the tokenizer's chat_template
+    chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=None,
+        tools=tools,
+        trust_remote_code=True,
+    )
+    assert isinstance(chat_template, str)
+
+
 # yapf: disable
 @pytest.mark.parametrize(
     ("model", "expected_format"),
     [(PHI3V_MODEL_ID, "string"),
      (QWEN2VL_MODEL_ID, "openai"),
+     (QWEN25VL_MODEL_ID, "openai"),
      (ULTRAVOX_MODEL_ID, "string"),
      (MLLAMA_MODEL_ID, "openai"),
      (LLAMA_GUARD_MODEL_ID, "openai")],
 )
 # yapf: enable
 def test_resolve_content_format_hf_defined(model, expected_format):
+    if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
+            "4.49.0"):
+        pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
+
     tokenizer_group = TokenizerGroup(
         model,
         enable_lora=False,
@@ -730,7 +780,13 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     )
     tokenizer = tokenizer_group.tokenizer
 
-    chat_template = tokenizer.chat_template
+    # Test detecting the tokenizer's chat_template
+    chat_template = _resolve_hf_chat_template(
+        tokenizer,
+        chat_template=None,
+        tools=None,
+        trust_remote_code=True,
+    )
     assert isinstance(chat_template, str)
 
     print("[TEXT]")
@@ -740,8 +796,10 @@ def test_resolve_content_format_hf_defined(model, expected_format):
 
     resolved_format = resolve_chat_template_content_format(
         None,  # Test detecting the tokenizer's chat_template
+        None,
         "auto",
         tokenizer,
+        trust_remote_code=True,
     )
 
     assert resolved_format == expected_format
@@ -791,8 +849,10 @@ def test_resolve_content_format_examples(template_path, expected_format):
 
     resolved_format = resolve_chat_template_content_format(
         chat_template,
+        None,
         "auto",
         dummy_tokenizer,
+        trust_remote_code=True,
     )
 
     assert resolved_format == expected_format
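
Taken together, these hunks thread a tools argument and a trust_remote_code flag through vLLM's chat-template helpers. The sketch below (not part of the diff) mirrors the tests above to show how the updated helpers fit together; the TokenizerGroup import path and the model ID are assumptions drawn from vLLM's test suite, while the argument orders follow the calls shown in this diff.

# Sketch mirroring the tests above. TokenizerGroup's import path and the
# model ID are assumptions, not taken from this diff.
from vllm.entrypoints.chat_utils import (_resolve_hf_chat_template,
                                         resolve_chat_template_content_format)
from vllm.transformers_utils.tokenizer_group import TokenizerGroup

tokenizer_group = TokenizerGroup(
    "Qwen/Qwen2-VL-2B-Instruct",  # any HF model ID with a chat template
    enable_lora=False,
    max_num_seqs=5,
    max_input_length=None,
)
tokenizer = tokenizer_group.tokenizer

# With chat_template=None and tools=None, the helper falls back to the
# template bundled with the tokenizer and returns it as a plain string.
chat_template = _resolve_hf_chat_template(
    tokenizer,
    chat_template=None,
    tools=None,
    trust_remote_code=True,
)
assert isinstance(chat_template, str)

# The content-format resolver now takes tools as the second positional
# argument and trust_remote_code as a keyword.
resolved_format = resolve_chat_template_content_format(
    None,  # chat_template=None: detect the tokenizer's own template
    None,  # tools
    "auto",
    tokenizer,
    trust_remote_code=True,
)
print(resolved_format)  # "string" or "openai"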