
Commit 22fa2e3

[VLM][Model] Support image input for Chameleon (#6633)
1 parent c520124 commit 22fa2e3

File tree: 7 files changed, +696 −58 lines changed


docs/source/models/supported_models.rst

Lines changed: 4 additions & 0 deletions

@@ -182,6 +182,10 @@ Vision Language Models
     - Models
     - Example HuggingFace Models
     - :ref:`LoRA <lora>`
+  * - :code:`ChameleonForConditionalGeneration`
+    - Chameleon
+    - :code:`facebook/chameleon-7b` etc.
+    -
   * - :code:`FuyuForCausalLM`
     - Fuyu
     - :code:`adept/fuyu-8b` etc.
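With Chameleon registered as a supported vision-language model, it can be exercised through vLLM's offline API. A minimal sketch, assuming the multimodal `LLM.generate` interface from this era of vLLM; the prompt format mirrors the test file below, and `stop_sign.jpg` is a hypothetical local image:

from PIL import Image

from vllm import LLM, SamplingParams

# Chameleon takes a single <image> placeholder in the prompt text.
llm = LLM(model="facebook/chameleon-7b", max_model_len=4096)

image = Image.open("stop_sign.jpg")  # hypothetical local file
outputs = llm.generate(
    {
        "prompt": "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
        "multi_modal_data": {"image": image},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)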

tests/models/test_chameleon.py

Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
import re
from typing import List, Optional, Type

import pytest

from vllm.multimodal.utils import rescale_image_size

from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets

pytestmark = pytest.mark.vlm

HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
    "stop_sign":
    "USER: <image>\nWhat's the content of the image?\nASSISTANT:",
    "cherry_blossom":
    "USER: <image>\nWhat is the season?\nASSISTANT:",
})

models = ["facebook/chameleon-7b"]


#TODO (ywang96): Add correctness test when chameleon is
# available on transformers.
def run_test(
    vllm_runner: Type[VllmRunner],
    image_assets: _ImageAssets,
    model: str,
    *,
    size_factors: List[float],
    dtype: str,
    max_tokens: int,
    tensor_parallel_size: int,
    distributed_executor_backend: Optional[str] = None,
):
    """Test if the model can generate text given
    a batch of images and prompts.
    """
    images = [asset.pil_image for asset in image_assets]

    inputs_per_image = [(
        [prompt for _ in size_factors],
        [rescale_image_size(image, factor) for factor in size_factors],
    ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]

    with vllm_runner(model,
                     max_model_len=4096,
                     dtype=dtype,
                     tensor_parallel_size=tensor_parallel_size,
                     distributed_executor_backend=distributed_executor_backend,
                     enforce_eager=True) as vllm_model:

        for prompts, images in inputs_per_image:
            vllm_outputs = vllm_model.generate_greedy(prompts,
                                                      max_tokens,
                                                      images=images)
            for i in range(len(vllm_outputs)):

                # format prompt back to original
                replacements = {
                    "<racm3:break>": "",
                    "<eoss>": "",
                    "<reserved08706>": ""
                }
                pattern = '|'.join(replacements.keys())
                vllm_result = re.sub(
                    pattern,
                    lambda match: replacements[match.group(0)],  #noqa B023
                    vllm_outputs[i][1])
                vllm_result = vllm_result.replace("<image>", "", 1023)
                assert vllm_result[:len(prompts[i])] == prompts[i]

                # assert at least 10 new characters are generated
                # (to take stop token into account)
                assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10


@pytest.mark.parametrize("model", models)
@pytest.mark.parametrize(
    "size_factors",
    [
        # Single-scale
        [1.0],
        # Single-scale, batched
        [1.0, 1.0, 1.0],
        # Multi-scale
        [0.25, 0.5, 1.0],
    ],
)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(vllm_runner, image_assets, model, size_factors, dtype: str,
                max_tokens: int) -> None:
    run_test(
        vllm_runner,
        image_assets,
        model,
        size_factors=size_factors,
        dtype=dtype,
        max_tokens=max_tokens,
        tensor_parallel_size=1,
    )
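The prompt-reconstruction step in run_test above exists because Chameleon's processor expands the single <image> placeholder into a long run of image tokens (1024 of them, judging by the replace(..., 1023) call), which decode back as repeated <image> strings interleaved with separator tokens. A standalone sketch of that cleanup; the exact decoded shape here is an assumption for illustration:

import re

# Illustrative decoded output: one prompt placeholder expanded to 1024
# <image> strings plus a separator token (exact ordering is assumed).
decoded = ("USER: " + "<image>" * 1024 + "<eoss>" +
           "\nWhat's the content of the image?\nASSISTANT: A stop sign.")

replacements = {"<racm3:break>": "", "<eoss>": "", "<reserved08706>": ""}
cleaned = re.sub("|".join(map(re.escape, replacements)),
                 lambda m: replacements[m.group(0)], decoded)
# Strip all but one <image> so the text lines up with the original prompt.
cleaned = cleaned.replace("<image>", "", 1023)

assert cleaned.startswith("USER: <image>\nWhat's the content of the image?")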

vllm/entrypoints/chat_utils.py

Lines changed: 2 additions & 1 deletion

@@ -105,7 +105,8 @@ def _image_token_str(model_config: ModelConfig,
         return None
     if model_type.startswith("llava"):
         return tokenizer.decode(model_config.hf_config.image_token_index)
-
+    if model_type == "chameleon":
+        return "<image>"
     raise TypeError(f"Unknown model type: {model_type}")
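_image_token_str supplies the placeholder that vLLM's OpenAI-compatible chat frontend splices into the prompt when a chat message carries an image, so this change lets Chameleon be queried through the standard chat schema. A sketch of the client side, assuming a server already launched with facebook/chameleon-7b; the image URL is hypothetical:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

# The server resolves model_type == "chameleon" to the "<image>" token
# and inserts it into the prompt on the client's behalf.
response = client.chat.completions.create(
    model="facebook/chameleon-7b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's the content of the image?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/stop_sign.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)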

vllm/model_executor/models/__init__.py

Lines changed: 4 additions & 3 deletions

@@ -16,9 +16,10 @@
     "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"),  # baichuan-7b
     "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"),  # baichuan-13b
     "BloomForCausalLM": ("bloom", "BloomForCausalLM"),
-    "ChameleonForCausalLM":
-    ("chameleon", "ChameleonForConditionalGeneration"
-     ),  #TODO(ywang96): fix model name when huggingface fixes it
+    #TODO(ywang96): remove this when huggingface fixes the model repo
+    "ChameleonForCausalLM": ("chameleon", "ChameleonForConditionalGeneration"),
+    "ChameleonForConditionalGeneration":
+    ("chameleon", "ChameleonForConditionalGeneration"),
     "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
     "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
     "CohereForCausalLM": ("commandr", "CohereForCausalLM"),
