on HuggingFace model repository.
"""
import os
+from dataclasses import asdict
+from typing import NamedTuple, Optional

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer

-from vllm import LLM, SamplingParams
+from vllm import LLM, EngineArgs, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.lora.request import LoRARequest
from vllm.utils import FlexibleArgumentParser
    2: "What sport and what nursery rhyme are referenced?"
}

+
+class ModelRequestData(NamedTuple):
+    engine_args: EngineArgs
+    prompt: str
+    stop_token_ids: Optional[list[int]] = None
+    lora_requests: Optional[list[LoRARequest]] = None
+
+
# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.


# MiniCPM-O
-def run_minicpmo(question: str, audio_count: int):
+def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
    model_name = "openbmb/MiniCPM-o-2_6"
    tokenizer = AutoTokenizer.from_pretrained(model_name,
                                              trust_remote_code=True)
-    llm = LLM(model=model_name,
-              trust_remote_code=True,
-              max_model_len=4096,
-              max_num_seqs=5,
-              limit_mm_per_prompt={"audio": audio_count})
+    engine_args = EngineArgs(
+        model=model_name,
+        trust_remote_code=True,
+        max_model_len=4096,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"audio": audio_count},
+    )

    stop_tokens = ['<|im_end|>', '<|endoftext|>']
    stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
@@ -52,11 +64,16 @@ def run_minicpmo(question: str, audio_count: int):
                                           tokenize=False,
                                           add_generation_prompt=True,
                                           chat_template=audio_chat_template)
-    return llm, prompt, stop_token_ids
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+    )


# Phi-4-multimodal-instruct
-def run_phi4mm(questions: str, audio_count: int):
+def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
    """
    Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
    show how to process audio inputs.
@@ -67,9 +84,9 @@ def run_phi4mm(questions: str, audio_count: int):
    speech_lora_path = os.path.join(model_path, "speech-lora")
    placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])

-    prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>"
+    prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"

-    llm = LLM(
+    engine_args = EngineArgs(
        model=model_path,
        trust_remote_code=True,
        max_model_len=4096,
@@ -79,24 +96,24 @@ def run_phi4mm(questions: str, audio_count: int):
        lora_extra_vocab_size=0,
        limit_mm_per_prompt={"audio": audio_count},
    )
-    lora_request = LoRARequest("speech", 1, speech_lora_path)
-    # To maintain code compatibility in this script, we add LoRA here.
-    llm.llm_engine.add_lora(lora_request=lora_request)
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)

-    stop_token_ids = None
-    return llm, prompts, stop_token_ids
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompts,
+        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
+    )


# Qwen2-Audio
-def run_qwen2_audio(question: str, audio_count: int):
+def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
    model_name = "Qwen/Qwen2-Audio-7B-Instruct"

-    llm = LLM(model=model_name,
-              max_model_len=4096,
-              max_num_seqs=5,
-              limit_mm_per_prompt={"audio": audio_count})
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"audio": audio_count},
+    )

    audio_in_prompt = "".join([
        f"Audio {idx+1}: "
@@ -107,12 +124,15 @@ def run_qwen2_audio(question: str, audio_count: int):
              "<|im_start|>user\n"
              f"{audio_in_prompt}{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
-    stop_token_ids = None
-    return llm, prompt, stop_token_ids
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+    )


# Ultravox 0.5-1B
-def run_ultravox(question: str, audio_count: int):
+def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
    model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"

    tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -124,29 +144,39 @@ def run_ultravox(question: str, audio_count: int):
                                           tokenize=False,
                                           add_generation_prompt=True)

-    llm = LLM(model=model_name,
-              max_model_len=4096,
-              max_num_seqs=5,
-              trust_remote_code=True,
-              limit_mm_per_prompt={"audio": audio_count})
-    stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        trust_remote_code=True,
+        limit_mm_per_prompt={"audio": audio_count},
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+    )


# Whisper
-def run_whisper(question: str, audio_count: int):
+def run_whisper(question: str, audio_count: int) -> ModelRequestData:
    assert audio_count == 1, (
        "Whisper only support single audio input per prompt")
    model_name = "openai/whisper-large-v3-turbo"

    prompt = "<|startoftranscript|>"

-    llm = LLM(model=model_name,
-              max_model_len=448,
-              max_num_seqs=5,
-              limit_mm_per_prompt={"audio": audio_count})
-    stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=448,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"audio": audio_count},
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+    )


model_example_map = {
@@ -164,14 +194,24 @@ def main(args):
        raise ValueError(f"Model type {model} is not supported.")

    audio_count = args.num_audios
-    llm, prompt, stop_token_ids = model_example_map[model](
-        question_per_audio_count[audio_count], audio_count)
+    req_data = model_example_map[model](question_per_audio_count[audio_count],
+                                        audio_count)
+
+    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+    llm = LLM(**engine_args)
+
+    # To maintain code compatibility in this script, we add LoRA here.
+    # You can also add LoRA using:
+    # llm.generate(prompts, lora_request=lora_request,...)
+    if req_data.lora_requests:
+        for lora_request in req_data.lora_requests:
+            llm.llm_engine.add_lora(lora_request=lora_request)

    # We set temperature to 0.2 so that outputs can be different
    # even when all prompts are identical when running batch inference.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
-                                     stop_token_ids=stop_token_ids)
+                                     stop_token_ids=req_data.stop_token_ids)

    mm_data = {}
    if audio_count > 0:
@@ -183,7 +223,7 @@ def main(args):
        }

    assert args.num_prompts > 0
-    inputs = {"prompt": prompt, "multi_modal_data": mm_data}
+    inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
    if args.num_prompts > 1:
        # Batch inference
        inputs = [inputs] * args.num_prompts
@@ -214,6 +254,10 @@ def main(args):
                        default=1,
                        choices=[0, 1, 2],
                        help="Number of audio items per prompt.")
+    parser.add_argument("--seed",
+                        type=int,
+                        default=None,
+                        help="Set the seed when initializing `vllm.LLM`.")

    args = parser.parse_args()
    main(args)
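
For readers skimming the diff, the refactor can be summarized in a short, self-contained sketch: each run_* helper now returns a ModelRequestData carrying EngineArgs instead of constructing the LLM itself, and main() merges the value of the new --seed flag into those arguments before building the engine once. The model name, prompt, and seed below are placeholders chosen for illustration, not values taken from the commit.

# Minimal sketch of the pattern introduced above (illustrative values only).
from dataclasses import asdict
from typing import NamedTuple, Optional

from vllm import LLM, EngineArgs


class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompt: str
    stop_token_ids: Optional[list[int]] = None


def run_example() -> ModelRequestData:
    # Describe the engine configuration without instantiating the engine yet.
    # "my-org/my-audio-model" is a hypothetical model repo.
    engine_args = EngineArgs(model="my-org/my-audio-model",
                             max_model_len=4096,
                             max_num_seqs=5)
    # Placeholder prompt; real helpers build model-specific prompts.
    return ModelRequestData(engine_args=engine_args, prompt="Describe the audio.")


req_data = run_example()
# EngineArgs is a dataclass, so asdict() yields LLM keyword arguments;
# the CLI-provided seed is merged in before the single LLM construction.
llm = LLM(**(asdict(req_data.engine_args) | {"seed": 0}))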