
Commit 9319514

[Bugfix][VLM] Fix failing Phi-4-MM multi-images tests and add vision-speech test (#16424)
Signed-off-by: Isotr0py <[email protected]>
1 parent ed37599 commit 9319514

File tree: 5 files changed (+119 additions, -46 deletions)

examples/offline_inference/audio_language.py

Lines changed: 9 additions & 9 deletions
@@ -199,13 +199,6 @@ def main(args):
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)
 
-    # To maintain code compatibility in this script, we add LoRA here.
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
-    if req_data.lora_requests:
-        for lora_request in req_data.lora_requests:
-            llm.llm_engine.add_lora(lora_request=lora_request)
-
     # We set temperature to 0.2 so that outputs can be different
     # even when all prompts are identical when running batch inference.
     sampling_params = SamplingParams(temperature=0.2,
@@ -226,8 +219,15 @@ def main(args):
     if args.num_prompts > 1:
         # Batch inference
        inputs = [inputs] * args.num_prompts
-
-    outputs = llm.generate(inputs, sampling_params=sampling_params)
+    # Add LoRA request if applicable
+    lora_request = (req_data.lora_requests *
+                    args.num_prompts if req_data.lora_requests else None)
+
+    outputs = llm.generate(
+        inputs,
+        sampling_params=sampling_params,
+        lora_request=lora_request,
+    )
 
     for o in outputs:
         generated_text = o.outputs[0].text
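
Note on the change above: instead of registering adapters on the engine with
llm.llm_engine.add_lora(), the script now passes a lora_request argument
directly to llm.generate(). A minimal standalone sketch of that pattern is
below; it is not part of this commit, and the model id, adapter path, and
max_lora_rank value are illustrative assumptions.

    from vllm import LLM, SamplingParams
    from vllm.assets.audio import AudioAsset
    from vllm.lora.request import LoRARequest

    # Assumed engine setup: any LoRA-capable multimodal model would do.
    llm = LLM(
        model="microsoft/Phi-4-multimodal-instruct",
        trust_remote_code=True,
        enable_lora=True,
        max_lora_rank=320,  # assumption: must be >= the adapter's rank
        limit_mm_per_prompt={"audio": 1},
    )
    # Hypothetical adapter path; Phi-4-MM ships speech-lora next to the base weights.
    speech_lora = LoRARequest("speech", 1, "/path/to/speech-lora")

    audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
    inputs = [{
        "prompt": "<|user|><|audio_1|>Transcribe the audio clip.<|end|><|assistant|>",
        "multi_modal_data": {"audio": [audio]},
    }] * 2

    outputs = llm.generate(
        inputs,
        sampling_params=SamplingParams(temperature=0.2, max_tokens=64),
        # One LoRARequest per prompt; a single request applied to all also works.
        lora_request=[speech_lora] * len(inputs),
    )
    for o in outputs:
        print(o.outputs[0].text)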

examples/offline_inference/vision_language.py

Lines changed: 24 additions & 10 deletions
@@ -8,6 +8,7 @@
 """
 import os
 import random
+from contextlib import contextmanager
 from dataclasses import asdict
 from typing import NamedTuple, Optional
 
@@ -1055,6 +1056,20 @@ def apply_image_repeat(image_repeat_prob, num_prompts, data,
     return inputs
 
 
+@contextmanager
+def time_counter(enable: bool):
+    if enable:
+        import time
+        start_time = time.time()
+        yield
+        elapsed_time = time.time() - start_time
+        print("-" * 50)
+        print("-- generate time = {}".format(elapsed_time))
+        print("-" * 50)
+    else:
+        yield
+
+
 def main(args):
     model = args.model_type
     if model not in model_example_map:
@@ -1113,17 +1128,16 @@ def main(args):
         },
     } for i in range(args.num_prompts)]
 
-    if args.time_generate:
-        import time
-        start_time = time.time()
-        outputs = llm.generate(inputs, sampling_params=sampling_params)
-        elapsed_time = time.time() - start_time
-        print("-" * 50)
-        print("-- generate time = {}".format(elapsed_time))
-        print("-" * 50)
+    # Add LoRA request if applicable
+    lora_request = (req_data.lora_requests *
+                    args.num_prompts if req_data.lora_requests else None)
 
-    else:
-        outputs = llm.generate(inputs, sampling_params=sampling_params)
+    with time_counter(args.time_generate):
+        outputs = llm.generate(
+            inputs,
+            sampling_params=sampling_params,
+            lora_request=lora_request,
+        )
 
     print("-" * 50)
     for o in outputs:
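
The timing logic above is now factored into a reusable time_counter context
manager. A self-contained sketch of the same pattern follows (not from this
commit): it adds try/finally so the elapsed time is printed even if the timed
body raises, and uses time.perf_counter rather than time.time.

    import time
    from contextlib import contextmanager

    @contextmanager
    def time_counter(enable: bool):
        # No-op when disabled, so call sites need no if/else branching.
        if not enable:
            yield
            return
        start = time.perf_counter()
        try:
            yield
        finally:
            elapsed = time.perf_counter() - start
            print("-" * 50)
            print(f"-- generate time = {elapsed}")
            print("-" * 50)

    # Usage: the timed block stands in for llm.generate(...).
    with time_counter(enable=True):
        total = sum(i * i for i in range(1_000_000))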

examples/offline_inference/vision_language_multi_image.py

Lines changed: 4 additions & 8 deletions
@@ -661,13 +661,6 @@ def run_generate(model, question: str, image_urls: list[str],
     engine_args = asdict(req_data.engine_args) | {"seed": args.seed}
     llm = LLM(**engine_args)
 
-    # To maintain code compatibility in this script, we add LoRA here.
-    # You can also add LoRA using:
-    # llm.generate(prompts, lora_request=lora_request,...)
-    if req_data.lora_requests:
-        for lora_request in req_data.lora_requests:
-            llm.llm_engine.add_lora(lora_request=lora_request)
-
     sampling_params = SamplingParams(temperature=0.0,
                                      max_tokens=256,
                                      stop_token_ids=req_data.stop_token_ids)
@@ -679,7 +672,9 @@ def run_generate(model, question: str, image_urls: list[str],
                 "image": req_data.image_data
             },
         },
-        sampling_params=sampling_params)
+        sampling_params=sampling_params,
+        lora_request=req_data.lora_requests,
+    )
 
     print("-" * 50)
     for o in outputs:
@@ -724,6 +719,7 @@ def run_chat(model: str, question: str, image_urls: list[str],
         }],
         sampling_params=sampling_params,
         chat_template=req_data.chat_template,
+        lora_request=req_data.lora_requests,
     )
 
     print("-" * 50)

tests/models/decoder_only/vision_language/test_models.py

Lines changed: 2 additions & 2 deletions
@@ -433,8 +433,8 @@
         max_model_len=4096,
         max_num_seqs=2,
         task="generate",
-        # use eager mode for hf runner since phi3v didn't work with flash_attn
-        hf_model_kwargs={"_attn_implementation": "eager"},
+        # use sdpa mode for hf runner since phi3v didn't work with flash_attn
+        hf_model_kwargs={"_attn_implementation": "sdpa"},
         use_tokenizer_eos=True,
         vllm_output_post_proc=model_utils.phi3v_vllm_to_hf_output,
         num_logprobs=10,
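
The HF reference runner for the phi3v/Phi-4-MM comparisons now requests SDPA
attention instead of eager. A hedged sketch of how an attention backend is
chosen when loading a transformers model directly (the model id is
illustrative; the test itself passes the underscored config key
"_attn_implementation" through hf_model_kwargs rather than calling
from_pretrained):

    from transformers import AutoModelForCausalLM

    # "eager" = plain PyTorch attention, "sdpa" = scaled_dot_product_attention,
    # "flash_attention_2" requires the flash-attn package.
    hf_model = AutoModelForCausalLM.from_pretrained(
        "microsoft/Phi-4-multimodal-instruct",
        trust_remote_code=True,
        attn_implementation="sdpa",
    )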

tests/models/decoder_only/vision_language/test_phi4mm.py

Lines changed: 80 additions & 17 deletions
@@ -2,18 +2,22 @@
 
 import os
 import re
+from collections.abc import Sequence
 from typing import Optional
 
+import librosa
 import pytest
 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 
+from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.image import rescale_image_size
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
 
-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
+                          PromptImageInput, VllmRunner)
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
@@ -29,6 +33,8 @@
 # Since the vision-lora and speech-lora co-exist with the base model,
 # we have to manually specify the path of the lora weights.
 vision_lora_path = os.path.join(model_path, "vision-lora")
+speech_question = os.path.join(model_path, "examples",
+                               "what_is_shown_in_this_image.wav")
 models = [model_path]
 
 
@@ -64,7 +70,8 @@ def vllm_to_hf_output(vllm_output: tuple[list[int], str,
 def run_test(
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
-    inputs: list[tuple[list[str], PromptImageInput]],
+    inputs: Sequence[tuple[list[str], PromptImageInput,
+                           Optional[PromptAudioInput]]],
     model: str,
     *,
     max_model_len: int,
@@ -104,28 +111,49 @@ def run_test(
                     enforce_eager=True,
                     ) as vllm_model:
         lora_request = LoRARequest("vision", 1, vision_lora_path)
-        vllm_model.model.llm_engine.add_lora(lora_request=lora_request)
         vllm_outputs_per_case = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
-                                                images=images)
-            for prompts, images in inputs
+                                                images=images,
+                                                audios=audios,
+                                                lora_request=lora_request)
+            for prompts, images, audios in inputs
         ]
 
-    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
-    hf_model_kwargs = {"_attn_implementation": "eager"}
+    hf_model_kwargs = {"_attn_implementation": "sdpa"}
     with hf_runner(model, dtype=dtype,
                    model_kwargs=hf_model_kwargs) as hf_model:
-        eos_token_id = hf_model.processor.tokenizer.eos_token_id
+
+        hf_processor = hf_model.processor
+        eos_token_id = hf_processor.tokenizer.eos_token_id
+
+        def patch_hf_processor(*args,
+                               text="",
+                               images=None,
+                               audio=None,
+                               sampling_rate=None,
+                               **kwargs):
+            audios = None
+            if audio is not None and sampling_rate is not None:
+                audios = [(audio, sampling_rate)]
+            return hf_processor(*args,
+                                text=text,
+                                images=images,
+                                audios=audios,
+                                **kwargs)
+
+        hf_model.processor = patch_hf_processor
+
         hf_outputs_per_case = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images,
+                                                    audios=audios,
                                                     eos_token_id=eos_token_id,
                                                     num_logits_to_keep=0)
-            for prompts, images in inputs
+            for prompts, images, audios in inputs
         ]
 
     for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
@@ -138,8 +166,6 @@ def run_test(
     )
 
 
-# Since we use _attn_implementation="eager" for hf_runner, there is more
-# significant numerical difference. The basic `logprobs=5` fails to pass.
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
@@ -151,7 +177,7 @@ def run_test(
         # Single-scale, batched
         [1.0, 1.0, 1.0],
         # Multi-scale
-        [0.7, 0.75, 1.0],
+        [0.25, 0.5, 1.0],
     ],
 )
 @pytest.mark.parametrize("dtype", [target_dtype])
@@ -166,6 +192,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     inputs_per_image = [(
         [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
+        None,
     ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
 
     run_test(
@@ -201,17 +228,18 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
 @pytest.mark.parametrize("max_model_len", [10000])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
-@pytest.mark.xfail(
-    reason="Phi-4-MM multi-image inference is divergent with hf model.")
 def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
                              size_factors, dtype: str, max_model_len: int,
                              max_tokens: int, num_logprobs: int) -> None:
     images = [asset.pil_image for asset in image_assets]
 
     inputs_per_case = [
-        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
-         [[rescale_image_size(image, factor) for image in images]
-          for factor in size_factors])
+        (
+            [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+            [[rescale_image_size(image, factor) for image in images]
+             for factor in size_factors],
+            None,
+        ),
     ]
 
     run_test(
@@ -226,3 +254,38 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
         mm_limit=2,
         tensor_parallel_size=1,
     )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_model_len", [10000])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
+                              max_model_len: int, max_tokens: int,
+                              num_logprobs: int) -> None:
+
+    # use the example speech question so that the model outputs are reasonable
+    audio = librosa.load(speech_question, sr=None)
+    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+
+    inputs_vision_speech = [
+        (
+            ["<|user|><|image_1|><|audio_1|><|end|><|assistant|>"],
+            [image],
+            [audio],
+        ),
+    ]
+
+    run_test(
+        hf_runner,
+        vllm_runner,
+        inputs_vision_speech,
+        model,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        mm_limit=1,
+        tensor_parallel_size=1,
+    )
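
The new test_vision_speech_models case pairs one image with a spoken question
in a single prompt. A rough offline-inference sketch of the same request shape
follows; it is not part of this commit, and the model id, adapter path,
max_lora_rank, and WAV path are assumptions (the test uses the
what_is_shown_in_this_image.wav example bundled with the checkpoint).

    import librosa
    from vllm import LLM, SamplingParams
    from vllm.assets.image import ImageAsset
    from vllm.lora.request import LoRARequest

    llm = LLM(
        model="microsoft/Phi-4-multimodal-instruct",
        trust_remote_code=True,
        enable_lora=True,
        max_lora_rank=320,  # assumption: must be >= the adapter's rank
        limit_mm_per_prompt={"image": 1, "audio": 1},
    )
    vision_lora = LoRARequest("vision", 1, "/path/to/vision-lora")  # hypothetical path

    # librosa.load returns (waveform, sampling_rate); vLLM accepts that tuple directly.
    audio = librosa.load("/path/to/what_is_shown_in_this_image.wav", sr=None)
    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")

    outputs = llm.generate(
        {
            "prompt": "<|user|><|image_1|><|audio_1|><|end|><|assistant|>",
            "multi_modal_data": {"image": image, "audio": audio},
        },
        sampling_params=SamplingParams(temperature=0.0, max_tokens=128),
        lora_request=vision_lora,
    )
    print(outputs[0].outputs[0].text)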
