
Commit 6eaf1e5

[Misc] Add --seed option to offline multi-modal examples (#14934)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 868a8c5 commit 6eaf1e5

File tree

6 files changed, +537 -315 lines changed


.buildkite/test-pipeline.yaml

Lines changed: 5 additions & 2 deletions
@@ -226,10 +226,13 @@ steps:
   - python3 offline_inference/basic/chat.py
   - python3 offline_inference/prefix_caching.py
   - python3 offline_inference/llm_engine_example.py
-  - python3 offline_inference/vision_language.py
-  - python3 offline_inference/vision_language_multi_image.py
+  - python3 offline_inference/audio_language.py --seed 0
+  - python3 offline_inference/vision_language.py --seed 0
+  - python3 offline_inference/vision_language_embedding.py --seed 0
+  - python3 offline_inference/vision_language_multi_image.py --seed 0
   - VLLM_USE_V1=0 python3 other/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 other/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
   - python3 offline_inference/encoder_decoder.py
+  - python3 offline_inference/encoder_decoder_multimodal.py --model-type whisper --seed 0
   - python3 offline_inference/basic/classify.py
   - python3 offline_inference/basic/embed.py
   - python3 offline_inference/basic/score.py
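The multi-modal example scripts above now run under `--seed 0` so that this CI step is deterministic. As a rough, stand-alone illustration (not part of this commit; it reuses `facebook/opt-125m` from the tensorize step only because it is small), fixing the seed when constructing `vllm.LLM` is what makes repeated sampled runs reproducible:

# Minimal sketch: re-running this with the same seed is expected to
# reproduce the same sampled continuation.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m", seed=0)
sampling_params = SamplingParams(temperature=0.8, max_tokens=16)
outputs = llm.generate(["The capital of France is"], sampling_params)
print(outputs[0].outputs[0].text)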

examples/offline_inference/audio_language.py

Lines changed: 88 additions & 44 deletions
@@ -7,11 +7,13 @@
 on HuggingFace model repository.
 """
 import os
+from dataclasses import asdict
+from typing import NamedTuple, Optional
 
 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 
-from vllm import LLM, SamplingParams
+from vllm import LLM, EngineArgs, SamplingParams
 from vllm.assets.audio import AudioAsset
 from vllm.lora.request import LoRARequest
 from vllm.utils import FlexibleArgumentParser
@@ -23,21 +25,31 @@
     2: "What sport and what nursery rhyme are referenced?"
 }
 
+
+class ModelRequestData(NamedTuple):
+    engine_args: EngineArgs
+    prompt: str
+    stop_token_ids: Optional[list[int]] = None
+    lora_requests: Optional[list[LoRARequest]] = None
+
+
 # NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
 # lower-end GPUs.
 # Unless specified, these settings have been tested to work on a single L4.
 
 
 # MiniCPM-O
-def run_minicpmo(question: str, audio_count: int):
+def run_minicpmo(question: str, audio_count: int) -> ModelRequestData:
     model_name = "openbmb/MiniCPM-o-2_6"
     tokenizer = AutoTokenizer.from_pretrained(model_name,
                                               trust_remote_code=True)
-    llm = LLM(model=model_name,
-              trust_remote_code=True,
-              max_model_len=4096,
-              max_num_seqs=5,
-              limit_mm_per_prompt={"audio": audio_count})
+    engine_args = EngineArgs(
+        model=model_name,
+        trust_remote_code=True,
+        max_model_len=4096,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"audio": audio_count},
+    )
 
     stop_tokens = ['<|im_end|>', '<|endoftext|>']
     stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens]
@@ -52,11 +64,16 @@ def run_minicpmo(question: str, audio_count: int):
                                            tokenize=False,
                                            add_generation_prompt=True,
                                            chat_template=audio_chat_template)
-    return llm, prompt, stop_token_ids
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        stop_token_ids=stop_token_ids,
+    )
 
 
 # Phi-4-multimodal-instruct
-def run_phi4mm(questions: str, audio_count: int):
+def run_phi4mm(question: str, audio_count: int) -> ModelRequestData:
     """
     Phi-4-multimodal-instruct supports both image and audio inputs. Here, we
     show how to process audio inputs.
@@ -67,9 +84,9 @@ def run_phi4mm(questions: str, audio_count: int):
     speech_lora_path = os.path.join(model_path, "speech-lora")
     placeholders = "".join([f"<|audio_{i+1}|>" for i in range(audio_count)])
 
-    prompts = f"<|user|>{placeholders}{questions}<|end|><|assistant|>"
+    prompts = f"<|user|>{placeholders}{question}<|end|><|assistant|>"
 
-    llm = LLM(
+    engine_args = EngineArgs(
         model=model_path,
         trust_remote_code=True,
         max_model_len=4096,
@@ -79,24 +96,24 @@ def run_phi4mm(questions: str, audio_count: int):
         lora_extra_vocab_size=0,
         limit_mm_per_prompt={"audio": audio_count},
     )
-    lora_request = LoRARequest("speech", 1, speech_lora_path)
-    # To maintain code compatibility in this script, we add LoRA here.
-    llm.llm_engine.add_lora(lora_request=lora_request)
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
 
-    stop_token_ids = None
-    return llm, prompts, stop_token_ids
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompts,
+        lora_requests=[LoRARequest("speech", 1, speech_lora_path)],
+    )
 
 
 # Qwen2-Audio
-def run_qwen2_audio(question: str, audio_count: int):
+def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
     model_name = "Qwen/Qwen2-Audio-7B-Instruct"
 
-    llm = LLM(model=model_name,
-              max_model_len=4096,
-              max_num_seqs=5,
-              limit_mm_per_prompt={"audio": audio_count})
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"audio": audio_count},
+    )
 
     audio_in_prompt = "".join([
         f"Audio {idx+1}: "
@@ -107,12 +124,15 @@ def run_qwen2_audio(question: str, audio_count: int):
               "<|im_start|>user\n"
              f"{audio_in_prompt}{question}<|im_end|>\n"
              "<|im_start|>assistant\n")
-    stop_token_ids = None
-    return llm, prompt, stop_token_ids
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+    )
 
 
 # Ultravox 0.5-1B
-def run_ultravox(question: str, audio_count: int):
+def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
     model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -124,29 +144,39 @@ def run_ultravox(question: str, audio_count: int):
                                            tokenize=False,
                                            add_generation_prompt=True)
 
-    llm = LLM(model=model_name,
-              max_model_len=4096,
-              max_num_seqs=5,
-              trust_remote_code=True,
-              limit_mm_per_prompt={"audio": audio_count})
-    stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=4096,
+        max_num_seqs=5,
+        trust_remote_code=True,
+        limit_mm_per_prompt={"audio": audio_count},
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+    )
 
 
 # Whisper
-def run_whisper(question: str, audio_count: int):
+def run_whisper(question: str, audio_count: int) -> ModelRequestData:
     assert audio_count == 1, (
         "Whisper only support single audio input per prompt")
     model_name = "openai/whisper-large-v3-turbo"
 
     prompt = "<|startoftranscript|>"
 
-    llm = LLM(model=model_name,
-              max_model_len=448,
-              max_num_seqs=5,
-              limit_mm_per_prompt={"audio": audio_count})
-    stop_token_ids = None
-    return llm, prompt, stop_token_ids
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=448,
+        max_num_seqs=5,
+        limit_mm_per_prompt={"audio": audio_count},
+    )
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+    )
 
 
 model_example_map = {
@@ -164,14 +194,24 @@ def main(args):
         raise ValueError(f"Model type {model} is not supported.")
 
     audio_count = args.num_audios
-    llm, prompt, stop_token_ids = model_example_map[model](
-        question_per_audio_count[audio_count], audio_count)
+    req_data = model_example_map[model](question_per_audio_count[audio_count],
+                                        audio_count)
+
+    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+    llm = LLM(**engine_args)
+
+    # To maintain code compatibility in this script, we add LoRA here.
+    # You can also add LoRA using:
+    # llm.generate(prompts, lora_request=lora_request,...)
+    if req_data.lora_requests:
+        for lora_request in req_data.lora_requests:
+            llm.llm_engine.add_lora(lora_request=lora_request)
 
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
     sampling_params = SamplingParams(temperature=0.2,
                                      max_tokens=64,
-                                     stop_token_ids=stop_token_ids)
+                                     stop_token_ids=req_data.stop_token_ids)
 
     mm_data = {}
     if audio_count > 0:
@@ -183,7 +223,7 @@ def main(args):
         }
 
     assert args.num_prompts > 0
-    inputs = {"prompt": prompt, "multi_modal_data": mm_data}
+    inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
     if args.num_prompts > 1:
         # Batch inference
         inputs = [inputs] * args.num_prompts
@@ -214,6 +254,10 @@ def main(args):
                         default=1,
                         choices=[0, 1, 2],
                         help="Number of audio items per prompt.")
+    parser.add_argument("--seed",
+                        type=int,
+                        default=None,
+                        help="Set the seed when initializing `vllm.LLM`.")
 
     args = parser.parse_args()
     main(args)
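The diff above is the core of the change: each `run_*` helper now returns a `ModelRequestData` carrying an unbuilt `EngineArgs`, and `main()` overlays the `--seed` value before constructing the engine. A condensed sketch of that flow, with the engine arguments borrowed from `run_ultravox` and a placeholder prompt (the real script builds the prompt from the tokenizer's chat template and attaches the audio data):

# Condensed from the diff above; not a drop-in replacement for the example.
from dataclasses import asdict
from typing import NamedTuple, Optional

from vllm import LLM, EngineArgs, SamplingParams
from vllm.lora.request import LoRARequest


class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompt: str
    stop_token_ids: Optional[list[int]] = None
    lora_requests: Optional[list[LoRARequest]] = None


req_data = ModelRequestData(
    engine_args=EngineArgs(
        model="fixie-ai/ultravox-v0_5-llama-3_2-1b",
        max_model_len=4096,
        max_num_seqs=5,
        trust_remote_code=True,
        limit_mm_per_prompt={"audio": 1},
    ),
    prompt="<placeholder prompt>",  # built from the chat template in the example
)

# EngineArgs is a dataclass, so it can be flattened and the --seed value
# merged in before the LLM is built.
engine_args = asdict(req_data.engine_args) | {"seed": 0}
llm = LLM(**engine_args)

sampling_params = SamplingParams(temperature=0.2,
                                 max_tokens=64,
                                 stop_token_ids=req_data.stop_token_ids)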

examples/offline_inference/encoder_decoder_multimodal.py

Lines changed: 37 additions & 11 deletions
@@ -4,16 +4,23 @@
 the explicit/implicit prompt format on enc-dec LMMs for text generation.
 """
 import time
+from collections.abc import Sequence
+from dataclasses import asdict
+from typing import NamedTuple
 
-from vllm import LLM, SamplingParams
+from vllm import LLM, EngineArgs, PromptType, SamplingParams
 from vllm.assets.audio import AudioAsset
 from vllm.assets.image import ImageAsset
 from vllm.utils import FlexibleArgumentParser
 
 
+class ModelRequestData(NamedTuple):
+    engine_args: EngineArgs
+    prompts: Sequence[PromptType]
+
+
 def run_florence2():
-    # Create a Florence-2 encoder/decoder model instance
-    llm = LLM(
+    engine_args = EngineArgs(
         model="microsoft/Florence-2-large",
         tokenizer="facebook/bart-large",
         max_num_seqs=8,
@@ -39,12 +46,15 @@ def run_florence2():
             "decoder_prompt": "",
         },
     ]
-    return llm, prompts
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
 
 
 def run_mllama():
-    # Create a Mllama encoder/decoder model instance
-    llm = LLM(
+    engine_args = EngineArgs(
         model="meta-llama/Llama-3.2-11B-Vision-Instruct",
         max_model_len=4096,
         max_num_seqs=2,
@@ -69,12 +79,15 @@ def run_mllama():
            "decoder_prompt": "<|image|><|begin_of_text|>Please describe the image.",  # noqa: E501
         },
     ]
-    return llm, prompts
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
 
 
 def run_whisper():
-    # Create a Whisper encoder/decoder model instance
-    llm = LLM(
+    engine_args = EngineArgs(
         model="openai/whisper-large-v3-turbo",
         max_model_len=448,
         max_num_seqs=16,
@@ -99,7 +112,11 @@ def run_whisper():
             "decoder_prompt": "<|startoftranscript|>",
         }
     ]
-    return llm, prompts
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
 
 
 model_example_map = {
@@ -114,7 +131,12 @@ def main(args):
     if model not in model_example_map:
         raise ValueError(f"Model type {model} is not supported.")
 
-    llm, prompts = model_example_map[model]()
+    req_data = model_example_map[model]()
+
+    engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
+    llm = LLM(**engine_args)
+
+    prompts = req_data.prompts
 
     # Create a sampling params object.
     sampling_params = SamplingParams(
@@ -153,6 +175,10 @@ def main(args):
                        default="mllama",
                        choices=model_example_map.keys(),
                        help='Huggingface "model_type".')
+    parser.add_argument("--seed",
+                        type=int,
+                        default=None,
+                        help="Set the seed when initializing `vllm.LLM`.")
 
    args = parser.parse_args()
    main(args)
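As in the audio example, `--seed` defaults to `None` and is merged into the engine arguments before the `LLM` is built. A small sketch of that plumbing, reusing the Whisper engine arguments from `run_whisper` above (the `description` string is made up for illustration):

# Sketch of the --seed plumbing shared by both updated examples.
from dataclasses import asdict

from vllm import LLM, EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = FlexibleArgumentParser(description="Seed plumbing sketch")
parser.add_argument("--seed",
                    type=int,
                    default=None,
                    help="Set the seed when initializing `vllm.LLM`.")
args = parser.parse_args()

engine_args = EngineArgs(model="openai/whisper-large-v3-turbo",
                         max_model_len=448,
                         max_num_seqs=16)
# Passing --seed 0 pins the engine's sampling RNG for reproducible output;
# leaving the flag unset forwards seed=None to the engine.
llm = LLM(**(asdict(engine_args) | {"seed": args.seed}))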
