@@ -2,18 +2,22 @@
 
 import os
 import re
+from collections.abc import Sequence
 from typing import Optional
 
+import librosa
 import pytest
 from huggingface_hub import snapshot_download
 from transformers import AutoTokenizer
 
+from vllm.assets.image import ImageAsset
 from vllm.lora.request import LoRARequest
 from vllm.multimodal.image import rescale_image_size
 from vllm.platforms import current_platform
 from vllm.sequence import SampleLogprobs
 
-from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner
+from ....conftest import (IMAGE_ASSETS, HfRunner, PromptAudioInput,
+                          PromptImageInput, VllmRunner)
 from ....utils import large_gpu_test
 from ...utils import check_logprobs_close
 
@@ -29,6 +33,8 @@
 # Since the vision-lora and speech-lora co-exist with the base model,
 # we have to manually specify the path of the lora weights.
 vision_lora_path = os.path.join(model_path, "vision-lora")
+speech_question = os.path.join(model_path, "examples",
+                               "what_is_shown_in_this_image.wav")
 models = [model_path]
 
 
@@ -64,7 +70,8 @@ def vllm_to_hf_output(vllm_output: tuple[list[int], str,
 def run_test(
     hf_runner: type[HfRunner],
     vllm_runner: type[VllmRunner],
-    inputs: list[tuple[list[str], PromptImageInput]],
+    inputs: Sequence[tuple[list[str], PromptImageInput,
+                           Optional[PromptAudioInput]]],
     model: str,
     *,
     max_model_len: int,
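Note on the signature change above: each `inputs` entry now carries an optional audio payload alongside the prompts and images. A minimal sketch of one such entry, assuming (as the vision+speech test at the end of this diff suggests) that a `PromptAudioInput` is a list of `(waveform, sample_rate)` pairs such as the tuple returned by `librosa.load(..., sr=None)`; the `image` variable here is hypothetical:

    # one input case = (prompts, images, audios); audios may also be None
    waveform_and_rate = librosa.load(speech_question, sr=None)   # -> (waveform, sample rate)
    example_case = (
        ["<|user|><|image_1|><|audio_1|><|end|><|assistant|>"],  # prompts
        [image],                                                 # PromptImageInput
        [waveform_and_rate],                                     # PromptAudioInput
    )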
@@ -104,28 +111,49 @@ def run_test(
         enforce_eager=True,
     ) as vllm_model:
         lora_request = LoRARequest("vision", 1, vision_lora_path)
-        vllm_model.model.llm_engine.add_lora(lora_request=lora_request)
         vllm_outputs_per_case = [
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
-                                                images=images)
-            for prompts, images in inputs
+                                                images=images,
+                                                audios=audios,
+                                                lora_request=lora_request)
+            for prompts, images, audios in inputs
         ]
 
-    # use eager mode for hf runner, since phi3_v didn't work with flash_attn
-    hf_model_kwargs = {"_attn_implementation": "eager"}
+    hf_model_kwargs = {"_attn_implementation": "sdpa"}
     with hf_runner(model, dtype=dtype,
                    model_kwargs=hf_model_kwargs) as hf_model:
-        eos_token_id = hf_model.processor.tokenizer.eos_token_id
+
+        hf_processor = hf_model.processor
+        eos_token_id = hf_processor.tokenizer.eos_token_id
+
+        def patch_hf_processor(*args,
+                               text="",
+                               images=None,
+                               audio=None,
+                               sampling_rate=None,
+                               **kwargs):
+            audios = None
+            if audio is not None and sampling_rate is not None:
+                audios = [(audio, sampling_rate)]
+            return hf_processor(*args,
+                                text=text,
+                                images=images,
+                                audios=audios,
+                                **kwargs)
+
+        hf_model.processor = patch_hf_processor
+
         hf_outputs_per_case = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images,
+                                                    audios=audios,
                                                     eos_token_id=eos_token_id,
                                                     num_logits_to_keep=0)
-            for prompts, images in inputs
+            for prompts, images, audios in inputs
         ]
 
         for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
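The monkey-patch above exists because the HF runner calls the processor with separate `audio` and `sampling_rate` keyword arguments (that is what the wrapper's signature implies), while the underlying Phi-4-multimodal processor is fed a single `audios` list of `(waveform, sampling_rate)` tuples. An illustrative call with hypothetical values, showing how the wrapper translates one convention into the other:

    # waveform is a 1-D numpy array; 16000 is a made-up sampling rate
    outputs = patch_hf_processor(text="<|user|><|audio_1|><|end|><|assistant|>",
                                 audio=waveform,
                                 sampling_rate=16000)
    # internally forwarded as: hf_processor(text=..., images=None,
    #                                       audios=[(waveform, 16000)])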
@@ -138,8 +166,6 @@ def run_test(
         )
 
 
-# Since we use _attn_implementation="eager" for hf_runner, there is more
-# significant numerical difference. The basic `logprobs=5` fails to pass.
 @pytest.mark.parametrize("model", models)
 @pytest.mark.parametrize(
     "size_factors",
@@ -151,7 +177,7 @@ def run_test(
         # Single-scale, batched
         [1.0, 1.0, 1.0],
         # Multi-scale
-        [0.7, 0.75, 1.0],
+        [0.25, 0.5, 1.0],
     ],
 )
 @pytest.mark.parametrize("dtype", [target_dtype])
@@ -166,6 +192,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
     inputs_per_image = [(
         [prompt for _ in size_factors],
         [rescale_image_size(image, factor) for factor in size_factors],
+        None,
     ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
 
     run_test(
@@ -201,17 +228,18 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
 @pytest.mark.parametrize("max_model_len", [10000])
 @pytest.mark.parametrize("max_tokens", [128])
 @pytest.mark.parametrize("num_logprobs", [10])
-@pytest.mark.xfail(
-    reason="Phi-4-MM multi-image inference is divergent with hf model.")
 def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
                              size_factors, dtype: str, max_model_len: int,
                              max_tokens: int, num_logprobs: int) -> None:
     images = [asset.pil_image for asset in image_assets]
 
     inputs_per_case = [
-        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
-         [[rescale_image_size(image, factor) for image in images]
-          for factor in size_factors])
+        (
+            [HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+            [[rescale_image_size(image, factor) for image in images]
+             for factor in size_factors],
+            None,
+        ),
     ]
 
     run_test(
@@ -226,3 +254,38 @@ def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
         mm_limit=2,
         tensor_parallel_size=1,
     )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_model_len", [10000])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [10])
+def test_vision_speech_models(hf_runner, vllm_runner, model, dtype: str,
+                              max_model_len: int, max_tokens: int,
+                              num_logprobs: int) -> None:
+
+    # use the example speech question so that the model outputs are reasonable
+    audio = librosa.load(speech_question, sr=None)
+    image = ImageAsset("cherry_blossom").pil_image.convert("RGB")
+
+    inputs_vision_speech = [
+        (
+            ["<|user|><|image_1|><|audio_1|><|end|><|assistant|>"],
+            [image],
+            [audio],
+        ),
+    ]
+
+    run_test(
+        hf_runner,
+        vllm_runner,
+        inputs_vision_speech,
+        model,
+        dtype=dtype,
+        max_model_len=max_model_len,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        mm_limit=1,
+        tensor_parallel_size=1,
+    )