
Commit bbb7003

maxdebayser authored and simon-mo committed
Enable conversion of multimodal models to pooling tasks (#24451)
Signed-off-by: Max de Bayser <[email protected]>
1 parent 89da8d9 commit bbb7003

File tree

5 files changed: +266 -59 lines changed

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from vllm.platforms import current_platform


def test_idefics_multimodal(
    vllm_runner,
    monkeypatch,
) -> None:
    if current_platform.is_rocm():
        # ROCm Triton FA does not currently support sliding window attention
        # switch to use ROCm CK FA backend
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    with vllm_runner(model_name="HuggingFaceM4/Idefics3-8B-Llama3",
                     runner="pooling",
                     task="classify",
                     convert="classify",
                     load_format="dummy",
                     max_model_len=512,
                     enforce_eager=True,
                     tensor_parallel_size=1,
                     disable_log_stats=True,
                     dtype="bfloat16") as vllm_model:
        llm = vllm_model.get_llm()
        outputs = llm.classify(prompts)
        for output in outputs:
            assert len(output.outputs.probs) == 2


def update_config(config):
    config.text_config.update({
        "architectures": ["Gemma3ForSequenceClassification"],
        "classifier_from_token": ["A", "B", "C", "D", "E"],
        "method": "no_post_processing",
        "id2label": {
            "A": "Chair",
            "B": "Couch",
            "C": "Table",
            "D": "Bed",
            "E": "Cupboard"
        },
    })
    return config


def test_gemma_multimodal(
    vllm_runner,
    monkeypatch,
) -> None:
    if current_platform.is_rocm():
        # ROCm Triton FA does not currently support sliding window attention
        # switch to use ROCm CK FA backend
        monkeypatch.setenv("VLLM_USE_TRITON_FLASH_ATTN", "False")

    messages = [{
        "role": "system",
        "content": """
        You are a helpful assistant. You will be given a product description
        which may also include an image. Classify the following product into
        one of the categories:

        A = chair
        B = couch
        C = table
        D = bed
        E = cupboard

        You'll answer with exactly one letter (A, B, C, D, or E)."""
    }, {
        "role": "user",
        "content": [{
            "type": "image_url",
            "image_url": {
                "url":
                "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"
            }
        }, {
            "type": "text",
            "text": "A fine 19th century piece of furniture."
        }]
    }]

    with vllm_runner(model_name="google/gemma-3-4b-it",
                     runner="pooling",
                     task="classify",
                     convert="classify",
                     load_format="auto",
                     hf_overrides=update_config,
                     override_pooler_config={"pooling_type": "LAST"},
                     max_model_len=512,
                     enforce_eager=True,
                     tensor_parallel_size=1,
                     disable_log_stats=True,
                     dtype="bfloat16") as vllm_model:

        llm = vllm_model.get_llm()
        prompts = llm.preprocess_chat(messages)

        result = llm.classify(prompts)
        assert result[0].outputs.probs[0] > 0.95
        assert all(c < 0.05 for c in result[0].outputs.probs[1:])

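For readers who want to exercise the new conversion path outside the pytest harness, the sketch below mirrors the arguments that test_gemma_multimodal passes through vllm_runner (the LLM constructor accepts the same keyword arguments). Treat it as an illustration of the tests above rather than a documented recipe: the system message with the category definitions is omitted for brevity, and defaults are assumed for everything not shown.

# Minimal sketch, mirroring test_gemma_multimodal above (not a supported recipe).
from vllm import LLM


def update_config(config):
    # Same HF-config override as in the test: expose Gemma3 as a sequence
    # classifier whose five classes are read off the tokens A-E.
    config.text_config.update({
        "architectures": ["Gemma3ForSequenceClassification"],
        "classifier_from_token": ["A", "B", "C", "D", "E"],
        "method": "no_post_processing",
        "id2label": {"A": "Chair", "B": "Couch", "C": "Table",
                     "D": "Bed", "E": "Cupboard"},
    })
    return config


llm = LLM(model="google/gemma-3-4b-it",
          runner="pooling",
          convert="classify",
          hf_overrides=update_config,
          override_pooler_config={"pooling_type": "LAST"},
          max_model_len=512,
          enforce_eager=True,
          dtype="bfloat16")

# Shortened prompt: the test additionally sends a system message that defines
# the categories A-E before asking for a single-letter answer.
messages = [{
    "role": "user",
    "content": [{
        "type": "image_url",
        "image_url": {
            "url":
            "https://upload.wikimedia.org/wikipedia/commons/c/c6/Set_of_fourteen_side_chairs_MET_DP110780.jpg"
        }
    }, {
        "type": "text",
        "text": "A fine 19th century piece of furniture."
    }]
}]

prompts = llm.preprocess_chat(messages)  # new helper added in this commit
print(llm.classify(prompts)[0].outputs.probs)
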
vllm/entrypoints/llm.py

Lines changed: 90 additions & 52 deletions
@@ -703,13 +703,10 @@ def create_tokens_prompt_from_beam(

         return outputs

-    def chat(
+    def preprocess_chat(
         self,
         messages: Union[list[ChatCompletionMessageParam],
                         list[list[ChatCompletionMessageParam]]],
-        sampling_params: Optional[Union[SamplingParams,
-                                        list[SamplingParams]]] = None,
-        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
         lora_request: Optional[LoRARequest] = None,
         chat_template: Optional[str] = None,
         chat_template_content_format: ChatTemplateContentFormatOption = "auto",
@@ -718,56 +715,16 @@ def chat(
         tools: Optional[list[dict[str, Any]]] = None,
         chat_template_kwargs: Optional[dict[str, Any]] = None,
         mm_processor_kwargs: Optional[dict[str, Any]] = None,
-    ) -> list[RequestOutput]:
+    ) -> list[TokensPrompt]:
         """
-        Generate responses for a chat conversation.
-
-        The chat conversation is converted into a text prompt using the
-        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
-        the responses.
-
-        Multi-modal inputs can be passed in the same way you would pass them
-        to the OpenAI API.
-
-        Args:
-            messages: A list of conversations or a single conversation.
-
-                - Each conversation is represented as a list of messages.
-                - Each message is a dictionary with 'role' and 'content' keys.
-
-            sampling_params: The sampling parameters for text generation.
-                If None, we use the default sampling parameters. When it
-                is a single value, it is applied to every prompt. When it
-                is a list, the list must have the same length as the
-                prompts and it is paired one by one with the prompt.
-            use_tqdm: If `True`, shows a tqdm progress bar.
-                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
-                it is used to create the progress bar.
-                If `False`, no progress bar is created.
-            lora_request: LoRA request to use for generation, if any.
-            chat_template: The template to use for structuring the chat.
-                If not provided, the model's default chat template will be used.
-            chat_template_content_format: The format to render message content.
-
-                - "string" will render the content as a string.
-                  Example: `"Who are you?"`
-                - "openai" will render the content as a list of dictionaries,
-                  similar to OpenAI schema.
-                  Example: `[{"type": "text", "text": "Who are you?"}]`
-
-            add_generation_prompt: If True, adds a generation template
-                to each message.
-            continue_final_message: If True, continues the final message in
-                the conversation instead of starting a new one. Cannot be
-                `True` if `add_generation_prompt` is also `True`.
-            chat_template_kwargs: Additional kwargs to pass to the chat
-                template.
-            mm_processor_kwargs: Multimodal processor kwarg overrides for this
-                chat request. Only used for offline requests.
+        Generate prompt for a chat conversation. The pre-processed
+        prompt can then be used as input for the other LLM methods.

+        Refer to `chat` for a complete description of the arguments.
         Returns:
-            A list of `RequestOutput` objects containing the generated
-            responses in the same order as the input messages.
+            A list of `TokensPrompts` objects containing the tokenized
+            prompt after chat template interpolation, and the
+            pre-processed multi-modal inputs.
         """
         list_of_messages: list[list[ChatCompletionMessageParam]]

@@ -800,7 +757,7 @@ def chat(
         )
         _chat_template_kwargs.update(chat_template_kwargs or {})

-        prompts: list[Union[TokensPrompt, TextPrompt]] = []
+        prompts: list[TokensPrompt] = []

         for msgs in list_of_messages:
             # NOTE: _parse_chat_message_content_parts() currently doesn't
@@ -844,6 +801,87 @@ def chat(

             prompts.append(prompt)

+        return prompts
+
+    def chat(
+        self,
+        messages: Union[list[ChatCompletionMessageParam],
+                        list[list[ChatCompletionMessageParam]]],
+        sampling_params: Optional[Union[SamplingParams,
+                                        list[SamplingParams]]] = None,
+        use_tqdm: Union[bool, Callable[..., tqdm]] = True,
+        lora_request: Optional[LoRARequest] = None,
+        chat_template: Optional[str] = None,
+        chat_template_content_format: ChatTemplateContentFormatOption = "auto",
+        add_generation_prompt: bool = True,
+        continue_final_message: bool = False,
+        tools: Optional[list[dict[str, Any]]] = None,
+        chat_template_kwargs: Optional[dict[str, Any]] = None,
+        mm_processor_kwargs: Optional[dict[str, Any]] = None,
+    ) -> list[RequestOutput]:
+        """
+        Generate responses for a chat conversation.
+
+        The chat conversation is converted into a text prompt using the
+        tokenizer and calls the [generate][vllm.LLM.generate] method to generate
+        the responses.
+
+        Multi-modal inputs can be passed in the same way you would pass them
+        to the OpenAI API.
+
+        Args:
+            messages: A list of conversations or a single conversation.
+
+                - Each conversation is represented as a list of messages.
+                - Each message is a dictionary with 'role' and 'content' keys.
+
+            sampling_params: The sampling parameters for text generation.
+                If None, we use the default sampling parameters. When it
+                is a single value, it is applied to every prompt. When it
+                is a list, the list must have the same length as the
+                prompts and it is paired one by one with the prompt.
+            use_tqdm: If `True`, shows a tqdm progress bar.
+                If a callable (e.g., `functools.partial(tqdm, leave=False)`),
+                it is used to create the progress bar.
+                If `False`, no progress bar is created.
+            lora_request: LoRA request to use for generation, if any.
+            chat_template: The template to use for structuring the chat.
+                If not provided, the model's default chat template will be used.
+            chat_template_content_format: The format to render message content.
+
+                - "string" will render the content as a string.
+                  Example: `"Who are you?"`
+                - "openai" will render the content as a list of dictionaries,
+                  similar to OpenAI schema.
+                  Example: `[{"type": "text", "text": "Who are you?"}]`
+
+            add_generation_prompt: If True, adds a generation template
+                to each message.
+            continue_final_message: If True, continues the final message in
+                the conversation instead of starting a new one. Cannot be
+                `True` if `add_generation_prompt` is also `True`.
+            chat_template_kwargs: Additional kwargs to pass to the chat
+                template.
+            mm_processor_kwargs: Multimodal processor kwarg overrides for this
+                chat request. Only used for offline requests.
+
+        Returns:
+            A list of `RequestOutput` objects containing the generated
+            responses in the same order as the input messages.
+        """
+
+        prompts = self.preprocess_chat(
+            messages=messages,
+            lora_request=lora_request,
+            chat_template=chat_template,
+            chat_template_content_format=chat_template_content_format,
+            add_generation_prompt=add_generation_prompt,
+            continue_final_message=continue_final_message,
+            tools=tools,
+            chat_template_kwargs=chat_template_kwargs,
+            mm_processor_kwargs=mm_processor_kwargs,
+        )
+
         return self.generate(
             prompts,
             sampling_params=sampling_params,

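The net effect of this refactor: the chat-template rendering and multimodal preprocessing that used to live inside chat() is now exposed as preprocess_chat(), and chat() simply forwards the resulting TokensPrompt list to generate(). As the new docstring notes, the same prompts can be fed to the other LLM methods; a rough sketch of that usage, assuming an llm instance already configured for a pooling task as in the tests above:

# `llm` is assumed to be an LLM constructed with runner="pooling" /
# convert="classify", e.g. as in the sketch following the new test file.
messages = [{"role": "user",
             "content": "Answer with exactly one letter: A, B, C, D or E."}]

prompts = llm.preprocess_chat(messages)  # chat template + multimodal preprocessing only
results = llm.classify(prompts)          # reuse the tokenized prompts with a pooling API
print(results[0].outputs.probs)

# The generative path is unchanged: chat() now calls preprocess_chat() and
# passes the prompts straight to generate().
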
vllm/model_executor/model_loader/utils.py

Lines changed: 14 additions & 4 deletions
@@ -19,10 +19,11 @@
 from vllm.model_executor.layers.linear import QKVCrossParallelLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
-from vllm.model_executor.models.adapters import (as_embedding_model,
-                                                 as_reward_model,
-                                                 as_seq_cls_model)
-from vllm.model_executor.models.interfaces import SupportsQuant
+from vllm.model_executor.models.adapters import (
+    as_embedding_model, as_reward_model, as_seq_cls_model,
+    try_create_mm_pooling_model_cls)
+from vllm.model_executor.models.interfaces import (SupportsQuant,
+                                                   supports_multimodal)
 from vllm.utils import is_pin_memory_available

 logger = init_logger(__name__)
@@ -183,6 +184,15 @@ def get_model_architecture(
                 "performance may not be optimal.", arch)

     convert_type = model_config.convert_type
+    if convert_type != "none" and supports_multimodal(model_cls):
+        logger.debug_once("Detected conversion of Multi Modal model.")
+        converted = try_create_mm_pooling_model_cls(model_cls)
+        if converted is not None:
+            logger.debug_once("Creating wrapper class to forward pooler.")
+            return converted, arch
+        else:
+            logger.debug_once("Attempting direct conversion.")
+
     if convert_type == "none":
         pass
     elif convert_type == "embed":

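The new branch in get_model_architecture() only announces itself through debug logging, so the easiest way to confirm which path was taken is to raise the log level. A rough sketch follows; it assumes VLLM_LOGGING_LEVEL is the environment variable controlling vLLM's log verbosity (an assumption about the runtime environment, not something introduced by this diff) and reuses the Idefics3 arguments from the new test so that dummy weights are used.

import os

# Assumption: VLLM_LOGGING_LEVEL=DEBUG surfaces the logger.debug_once() messages.
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"

from vllm import LLM  # noqa: E402  (import after setting the env var)

llm = LLM(model="HuggingFaceM4/Idefics3-8B-Llama3",
          runner="pooling",
          convert="classify",
          load_format="dummy",   # as in test_idefics_multimodal: random weights
          max_model_len=512,
          enforce_eager=True,
          dtype="bfloat16")

# Expected debug output from the new branch:
#   "Detected conversion of Multi Modal model."
# followed by either
#   "Creating wrapper class to forward pooler."   (try_create_mm_pooling_model_cls succeeded)
# or
#   "Attempting direct conversion."               (fallback to the existing adapters)
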
0 commit comments
