Commit 7b86e7c

[Model] Add multi-image support for minicpmv (#7122)

Authored by HwwwwwwwH and DarkLight1337
Co-authored-by: hezhihui <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>

1 parent f80ab35 · commit 7b86e7c
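
At a high level, this commit lets a single MiniCPM-V prompt reference several images at once. The snippet below is a minimal usage sketch, assuming the offline `LLM`/`SamplingParams` API of this vLLM release; the image file names are placeholders and the prompt template is taken from the multi-image test added below (the test additionally passes model-specific `stop_token_ids`).

```python
# Minimal sketch of the capability this commit enables: more than one image
# per prompt for MiniCPM-Llama3-V-2_5. Assumes the offline LLM API of this
# vLLM release; file names are placeholders.
from PIL import Image
from vllm import LLM, SamplingParams

llm = LLM(model="openbmb/MiniCPM-Llama3-V-2_5",
          trust_remote_code=True,
          max_model_len=4096)

# Two image tags in the prompt, two PIL images in multi_modal_data["image"]
# (a list is now accepted, see vllm/multimodal/image.py below).
prompt = ("<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
          "(<image>./</image>)\n(<image>./</image>)\n"
          "Describe these images.<|eot_id|>"
          "<|start_header_id|>assistant<|end_header_id|>\n\n")
images = [Image.open("image_1.jpg"), Image.open("image_2.jpg")]

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"image": images}},
    SamplingParams(temperature=0.0, max_tokens=128),
)
print(outputs[0].outputs[0].text)
```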

File tree: 4 files changed (+172 −37 lines)

tests/conftest.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -3,7 +3,7 @@
 import os
 import sys
 from collections import UserList
-from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar
+from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union
 
 import pytest
 import torch
@@ -508,7 +508,8 @@ def generate_greedy_logprobs(
         prompts: List[str],
         max_tokens: int,
         num_logprobs: int,
-        images: Optional[List[Image.Image]] = None,
+        images: Optional[Union[List[Image.Image],
+                               List[List[Image.Image]]]] = None,
         stop_token_ids: Optional[List[int]] = None,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         greedy_logprobs_params = SamplingParams(temperature=0.0,
```
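
The change here only widens the `images` annotation of `generate_greedy_logprobs` so a caller can pass either one image per prompt or a list of images per prompt. A small, self-contained illustration of the two accepted shapes (the blank in-memory PIL images are stand-ins for real test assets):

```python
# Illustrative values for the widened `images` parameter.
from typing import List
from PIL import Image

img_a = Image.new("RGB", (32, 32))
img_b = Image.new("RGB", (32, 32))

# One image per prompt (previous shape):
flat: List[Image.Image] = [img_a, img_b]
# Several images per prompt (newly allowed shape), one inner list per prompt:
nested: List[List[Image.Image]] = [[img_a, img_b]]
```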

tests/models/test_minicpmv.py

Lines changed: 133 additions & 13 deletions
```diff
@@ -14,6 +14,18 @@
 
 pytestmark = pytest.mark.vlm
 
+
+class NestedInputs(UserDict):
+
+    def __init__(self, model_inputs: BatchFeature):
+        super().__init__({"model_inputs": model_inputs})
+
+        self.model_inputs = model_inputs
+
+    def to(self, device: torch.types.Device):
+        return NestedInputs(self.model_inputs.to(device))
+
+
 # The image token is placed before "user" on purpose so that the test can pass
 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
     "stop_sign":
@@ -23,7 +35,7 @@
     "cherry_blossom":
     "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
     "(<image>./</image>)\nWhat is the season?<|eot_id|>" \
-    "<|start_header_id|>assistant<|end_header_id|>\n\n"
+    "<|start_header_id|>assistant<|end_header_id|>\n\n",
 })
 
 models = ["openbmb/MiniCPM-Llama3-V-2_5"]
@@ -94,22 +106,10 @@ def run_test(
     ]
 
     with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
-
-        class NestedInputs(UserDict):
-
-            def __init__(self, model_inputs: BatchFeature):
-                super().__init__({"model_inputs": model_inputs})
-
-                self.model_inputs = model_inputs
-
-            def to(self, device: torch.types.Device):
-                return NestedInputs(self.model_inputs.to(device))
-
         hf_processor = hf_model.processor
         hf_model.processor = lambda **kw: NestedInputs(
             hf_processor(**kw)  # type: ignore
         )
-
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
@@ -161,3 +161,123 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
         num_logprobs=num_logprobs,
         tensor_parallel_size=1,
     )
+
+
+HF_MULTIIMAGE_IMAGE_PROMPT = \
+    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \
+    "(<image>./</image>)\n(<image>./</image>)\n" \
+    "Describe these images.<|eot_id|>" \
+    "<|start_header_id|>assistant<|end_header_id|>\n\n"
+
+
+def run_multi_image_test(
+    hf_runner: Type[HfRunner],
+    vllm_runner: Type[VllmRunner],
+    image_assets: _ImageAssets,
+    model: str,
+    *,
+    size_factors: List[float],
+    dtype: str,
+    max_tokens: int,
+    num_logprobs: int,
+    tensor_parallel_size: int,
+    distributed_executor_backend: Optional[str] = None,
+):
+    """Inference result should be the same between hf and vllm.
+
+    All the image fixtures for the test is under tests/images.
+    For huggingface runner, we provide the PIL images as input.
+    For vllm runner, we provide MultiModalDataDict objects
+    and corresponding vision language config as input.
+    Note, the text input is also adjusted to abide by vllm contract.
+    The text output is sanitized to be able to compare with hf.
+    """
+    images = [asset.pil_image for asset in image_assets]
+
+    inputs_per_case = [
+        ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors],
+         [[rescale_image_size(image, factor) for image in images]
+          for factor in size_factors])
+    ]
+
+    # NOTE: take care of the order. run vLLM first, and then run HF.
+    # vLLM needs a fresh new process without cuda initialization.
+    # if we run HF first, the cuda initialization will be done and it
+    # will hurt multiprocessing backend with fork method (the default method).
+
+    # max_model_len should be greater than image_feature_size
+    with vllm_runner(model,
+                     max_model_len=4096,
+                     max_num_seqs=1,
+                     dtype=dtype,
+                     tensor_parallel_size=tensor_parallel_size,
+                     distributed_executor_backend=distributed_executor_backend,
+                     enforce_eager=True) as vllm_model:
+        tokenizer = vllm_model.model.get_tokenizer()
+        stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id]
+        vllm_outputs_per_case = [
+            vllm_model.generate_greedy_logprobs(prompts,
+                                                max_tokens,
+                                                num_logprobs=num_logprobs,
+                                                images=images,
+                                                stop_token_ids=stop_token_ids)
+            for prompts, images in inputs_per_case
+        ]
+
+    with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad():
+        hf_processor = hf_model.processor
+        hf_model.processor = lambda **kw: NestedInputs(
+            hf_processor(**kw)  # type: ignore
+        )
+        hf_outputs_per_case = [
+            hf_model.generate_greedy_logprobs_limit(prompts,
+                                                    max_tokens,
+                                                    num_logprobs=num_logprobs,
+                                                    images=images,
+                                                    tokenizer=tokenizer)
+            for prompts, images in inputs_per_case
+        ]
+
+    for hf_outputs, vllm_outputs in zip(hf_outputs_per_case,
+                                        vllm_outputs_per_case):
+        check_logprobs_close(
+            outputs_0_lst=[
+                trunc_hf_output(hf_output) for hf_output in hf_outputs
+            ],
+            outputs_1_lst=vllm_outputs,
+            name_0="hf",
+            name_1="vllm",
+        )
+
+
+@pytest.mark.parametrize("model", models)
+@pytest.mark.parametrize(
+    "size_factors",
+    [
+        # No image
+        [],
+        # Single-scale
+        [1.0],
+        # Single-scale, batched
+        [1.0, 1.0, 1.0],
+        # Multi-scale
+        [0.25, 0.5, 1.0],
+    ],
+)
+@pytest.mark.parametrize("dtype", [target_dtype])
+@pytest.mark.parametrize("max_tokens", [128])
+@pytest.mark.parametrize("num_logprobs", [5])
+def test_multi_images_models(hf_runner, vllm_runner, image_assets, model,
+                             size_factors, dtype: str, max_tokens: int,
+                             num_logprobs: int) -> None:
+    run_multi_image_test(
+        hf_runner,
+        vllm_runner,
+        image_assets,
+        model,
+        size_factors=size_factors,
+        dtype=dtype,
+        max_tokens=max_tokens,
+        num_logprobs=num_logprobs,
+        tensor_parallel_size=1,
+    )
```
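
For reference, `run_multi_image_test` pairs each size factor with the full set of images rescaled by that factor. The standalone sketch below mirrors how `inputs_per_case` is assembled; the rescale helper is re-implemented locally for illustration rather than imported from vLLM, and the dummy images stand in for the test assets.

```python
# Standalone sketch of the (prompts, images) pairing built by the test: each
# size factor yields one prompt paired with *all* images rescaled by that factor.
from PIL import Image

def rescale(image: Image.Image, factor: float) -> Image.Image:
    # Local stand-in for the test's rescale helper.
    return image.resize((round(image.width * factor), round(image.height * factor)))

images = [Image.new("RGB", (640, 480)), Image.new("RGB", (800, 600))]
size_factors = [0.25, 0.5, 1.0]
prompt = "(<image>./</image>)\n(<image>./</image>)\nDescribe these images."

prompts_per_case = [prompt for _ in size_factors]          # 3 identical prompts
images_per_case = [[rescale(img, f) for img in images]     # 3 lists of 2 images
                   for f in size_factors]

assert len(prompts_per_case) == len(images_per_case) == 3
assert all(len(group) == 2 for group in images_per_case)
```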

vllm/model_executor/models/minicpmv.py

Lines changed: 35 additions & 21 deletions
```diff
@@ -392,6 +392,20 @@ def forward(self, x: torch.Tensor,
         return x
 
 
+def get_version_by_config(config: PretrainedConfig) -> Tuple[int, ...]:
+    version_float = getattr(config, "version", None)
+
+    # The old configs do not include version number
+    # TODO: Remove this after the HF repos are updated
+    if version_float is None:
+        if config.hidden_size == 2304 and config.query_num == 64:
+            return (2, 0)
+        return (2, 5)
+
+    version_str = str(version_float)
+    return tuple(int(x) for x in version_str.split("."))
+
+
 def get_max_minicpmv_image_tokens(ctx: InputContext):
     hf_config = ctx.get_hf_config(PretrainedConfig)
     return getattr(hf_config, "query_num", 64)
@@ -421,36 +435,43 @@ def input_processor_for_minicpmv(ctx: InputContext, llm_inputs: LLMInputs):
     multi_modal_data = llm_inputs.get("multi_modal_data")
     if multi_modal_data is None or "image" not in multi_modal_data:
         return llm_inputs
-
     model_config = ctx.model_config
-
+    version = get_version_by_config(model_config.hf_config)
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
+    image_processor = cached_get_image_processor(model_config.tokenizer)
+
+    def get_placeholder(image_size: Tuple[int, int], num_image: int):
+        if version == (2, 0) or version == (2, 5):
+            return image_processor. \
+                get_slice_image_placeholder(image_size)
+        return image_processor. \
+            get_slice_image_placeholder(image_size, num_image)
 
     prompt = llm_inputs.get("prompt")
     if prompt is None:
         token_ids = llm_inputs.get("prompt_token_ids")
         prompt = tokenizer.decode(token_ids)
-    image_processor = cached_get_image_processor(model_config.tokenizer)
 
     pattern = "(<image>./</image>)"
-    image = multi_modal_data["image"]
+    images = multi_modal_data["image"]
+    if isinstance(images, Image.Image):
+        images = [images]
     image_tags = re.findall(pattern, prompt)
 
     if len(image_tags) == 0:
         new_token_ids = token_ids
         new_prompt = prompt
     else:
-        if len(image_tags) > 1:
-            logger.warning("Multiple image input is not supported yet, "
-                           "so any extra image tokens will be treated "
-                           "as plain text.")
-
         text_chunks = prompt.split(pattern)
-        new_prompt = (text_chunks[0] +
-                      image_processor.get_slice_image_placeholder(image.size) +
-                      "".join(text_chunks[1:]))
-
+        new_prompt_chunks: List[str] = []
+        for i in range(len(images)):
+            new_prompt_chunks += [
+                text_chunks[i],
+                get_placeholder(images[i].size, i)
+            ]
+        new_prompt_chunks.append(text_chunks[-1])
+        new_prompt = "".join(new_prompt_chunks)
         new_token_ids = tokenizer.encode(new_prompt)
 
     llm_inputs = LLMInputs(
@@ -478,14 +499,7 @@ def __init__(
         self.config = config
         self.multimodal_config = multimodal_config
 
-        if not hasattr(self.config, "version"):
-            if self.config.hidden_size == 2304 and self.config.query_num == 64:
-                self.version = (2, 0)
-            else:
-                self.version = (2, 5)
-        else:
-            self.version = str(self.config.version).split(".")
-            self.version = tuple([int(x) for x in self.version])
+        self.version = get_version_by_config(self.config)
         self.llm = self.init_llm(config, cache_config, quant_config)
         self.vpm = self.init_vision_module()
         param_dtype = torch.get_default_dtype()
```
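
The heart of the change is the prompt rewriting in `input_processor_for_minicpmv`: the prompt is split on the image tag and the chunks are interleaved with one placeholder per image. The pure-Python sketch below mirrors that loop; `get_placeholder` here is a stub standing in for the image processor's `get_slice_image_placeholder`, and the placeholder strings it returns are invented for illustration.

```python
# Pure-Python sketch of the new prompt expansion: split on the image tag, then
# interleave text chunks with one placeholder per image.
from typing import List, Tuple

pattern = "(<image>./</image>)"
prompt = f"user: {pattern}\n{pattern}\nDescribe these images.\nassistant: "
image_sizes = [(640, 480), (800, 600)]  # stand-ins for images[i].size

def get_placeholder(image_size: Tuple[int, int], num_image: int) -> str:
    # Stub: the real code asks the HF image processor for the slice placeholder.
    return f"<image_placeholder_{num_image}:{image_size[0]}x{image_size[1]}>"

text_chunks = prompt.split(pattern)
new_prompt_chunks: List[str] = []
for i, size in enumerate(image_sizes):
    new_prompt_chunks += [text_chunks[i], get_placeholder(size, i)]
new_prompt_chunks.append(text_chunks[-1])
new_prompt = "".join(new_prompt_chunks)

print(new_prompt)
# user: <image_placeholder_0:640x480>
# <image_placeholder_1:800x600>
# Describe these images.
# assistant:
```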

vllm/multimodal/image.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -113,7 +113,7 @@ def _get_hf_image_processor(self, model_config: ModelConfig):
     def _default_input_mapper(self, ctx: InputContext,
                               data: object) -> MultiModalInputs:
         model_config = ctx.model_config
-        if isinstance(data, Image.Image):
+        if isinstance(data, (Image.Image, list)):
             image_processor = self._get_hf_image_processor(model_config)
             if image_processor is None:
                 raise RuntimeError("No HuggingFace processor is available "
```
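
This one-line change lets the default mapper forward a list of PIL images, not just a single image, to the HuggingFace image processor, which already handles batched input. A small hedged illustration of that behavior with an off-the-shelf processor (the CLIP checkpoint is only an example, not the MiniCPM-V processor):

```python
# HF image processors accept either one PIL image or a list of them, which is
# why passing `list` data straight through works here.
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
single = processor(Image.new("RGB", (224, 224)), return_tensors="pt")
batch = processor([Image.new("RGB", (224, 224))] * 2, return_tensors="pt")
print(single["pixel_values"].shape, batch["pixel_values"].shape)
# torch.Size([1, 3, 224, 224]) torch.Size([2, 3, 224, 224])
```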
