Commit 08b751b

Authored by Roger Wang, andyxning, tjtanaa, qthequartermasterman, and Edwardf0t1
Implicit language-model-only mode via limit-mm-per-prompt (#22299)
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: Andy Xie <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Signed-off-by: Andrew Sansom <[email protected]>
Signed-off-by: Zhiyu Cheng <[email protected]>
Signed-off-by: Shu Wang <[email protected]>
Signed-off-by: Po-Han Huang <[email protected]>
Signed-off-by: Shu Wang. <[email protected]>
Signed-off-by: XIn Li <[email protected]>
Signed-off-by: Junhao Li <[email protected]>
Signed-off-by: chaunceyjiang <[email protected]>
Signed-off-by: zRzRzRzRzRzRzR <[email protected]>
Signed-off-by: zitian.zhao <[email protected]>
Signed-off-by: zitian zhao <[email protected]>
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: iAmir97 <[email protected]>
Signed-off-by: iAmir97 <[email protected]>
Signed-off-by: Linkun <[email protected]>
Co-authored-by: Ning Xie <[email protected]>
Co-authored-by: TJian <[email protected]>
Co-authored-by: Andrew Sansom <[email protected]>
Co-authored-by: Zhiyu <[email protected]>
Co-authored-by: Shu Wang <[email protected]>
Co-authored-by: XIn Li <[email protected]>
Co-authored-by: Junhao Li <[email protected]>
Co-authored-by: Chauncey <[email protected]>
Co-authored-by: Yuxuan Zhang <[email protected]>
Co-authored-by: ZiTian Zhao <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Po-Han Huang (NVIDIA) <[email protected]>
Co-authored-by: iAmir97 <[email protected]>
Co-authored-by: iAmir97 <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Hong Hanh <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Co-authored-by: lkchen <[email protected]>
1 parent 429e4e2 commit 08b751b
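
For context, the mode added by this commit can be exercised from the Python API roughly as follows. This is a minimal sketch based on the test cases added below: the model name and the zero-limit semantics come from those tests, while the prompt and the generate call are purely illustrative.

# Minimal sketch: load a multimodal checkpoint in language-model-only mode by
# capping every modality at 0 per prompt. Per the changes below, the vision
# tower is then never initialized and its weights are skipped at load time.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    limit_mm_per_prompt={"image": 0, "video": 0},
)
outputs = llm.generate("Explain what language-model-only mode means.")
print(outputs[0].outputs[0].text)

Leaving any modality unset (or nonzero) keeps the multimodal path enabled, as the parametrized cases in tests/multimodal/test_registry.py show.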

File tree

16 files changed, +271 -116 lines changed


tests/multimodal/test_registry.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Unit tests for MultiModalRegistry.supports_multimodal_inputs and
+Qwen2.5-VL visual component loading behavior.
+"""
+
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+
+from ..models.utils import build_model_context
+
+
+@pytest.mark.parametrize(
+    "model_id,limit_mm_per_prompt,expected",
+    [
+        ("Qwen/Qwen2-0.5B-Instruct", {}, False),
+        ("Qwen/Qwen2.5-VL-3B-Instruct", {}, True),
+        ("Qwen/Qwen2.5-VL-3B-Instruct", {
+            "image": 0,
+            "video": 0
+        }, False),
+        ("Qwen/Qwen2.5-VL-3B-Instruct", {
+            "image": 0
+        }, True),
+    ],
+)
+@pytest.mark.core_model
+def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected):
+    """Test supports_multimodal_inputs returns correct boolean for various
+    configs."""
+    ctx = build_model_context(
+        model_id,
+        limit_mm_per_prompt=limit_mm_per_prompt,
+    )
+    assert MULTIMODAL_REGISTRY.supports_multimodal_inputs(
+        ctx.model_config) is expected
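
The check exercised by this test can also be reproduced interactively. A rough sketch, assuming the tests package is importable from the repository root (build_model_context is a test-only helper, not part of the public API):

from vllm.multimodal import MULTIMODAL_REGISTRY

from tests.models.utils import build_model_context  # test-only helper

# Capping both modalities at 0 makes the registry treat the model as
# text-only; capping only "image" leaves video enabled, so it stays True.
ctx = build_model_context(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    limit_mm_per_prompt={"image": 0, "video": 0},
)
assert not MULTIMODAL_REGISTRY.supports_multimodal_inputs(ctx.model_config)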

vllm/config/__init__.py

Lines changed: 0 additions & 9 deletions
@@ -1695,15 +1695,6 @@ def enable_mm_processor_cache(self) -> bool:
 
         return mm_config.mm_processor_cache_gb > 0
 
-    @property
-    def enable_mm_input_cache(self) -> bool:
-        """Whether the multi-modal input cache should be enabled."""
-        mm_config = self.multimodal_config
-        if mm_config is None:
-            return False
-
-        return mm_config.mm_processor_cache_gb > 0
-
     def get_mm_input_cache_gb(self) -> int:
         mm_config = self.multimodal_config
         if mm_config is None:

vllm/model_executor/models/llava.py

Lines changed: 21 additions & 13 deletions
@@ -521,18 +521,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
             config.projector_hidden_act = "gelu"
 
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = init_vision_tower_for_llava(
-            config,
-            quant_config,
-            require_post_norm=False,
-            prefix=maybe_prefix(prefix, "vision_tower"))
-        self.multi_modal_projector = LlavaMultiModalProjector(
-            vision_hidden_size=config.vision_config.hidden_size,
-            text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act,
-            multimodal_projector_bias=config.multimodal_projector_bias,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_tower = init_vision_tower_for_llava(
+                config,
+                quant_config,
+                require_post_norm=False,
+                prefix=maybe_prefix(prefix, "vision_tower"))
+            self.multi_modal_projector = LlavaMultiModalProjector(
+                vision_hidden_size=config.vision_config.hidden_size,
+                text_hidden_size=config.text_config.hidden_size,
+                projector_hidden_act=config.projector_hidden_act,
+                multimodal_projector_bias=config.multimodal_projector_bias,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        else:
+            self.vision_tower = None
+            self.multi_modal_projector = None
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -756,7 +760,11 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.vision_tower is None and self.multi_modal_projector is None:
+            skip_prefixes.extend(["vision_tower.", "multi_modal_projector."])
+
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
 
vllm/model_executor/models/mistral3.py

Lines changed: 23 additions & 15 deletions
@@ -428,20 +428,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
             config.projector_hidden_act = "gelu"
 
         # TODO: Optionally initializes this for supporting embeddings.
-        self.vision_tower = init_vision_tower_for_llava(
-            config,
-            quant_config,
-            require_post_norm=False,
-            prefix=maybe_prefix(prefix, "vision_tower"))
-        self.multi_modal_projector = Mistral3MultiModalProjector(
-            vision_hidden_size=config.vision_config.hidden_size,
-            text_hidden_size=config.text_config.hidden_size,
-            projector_hidden_act=config.projector_hidden_act,
-            spatial_merge_size=config.spatial_merge_size,
-            patch_size=config.vision_config.patch_size,
-            multimodal_projector_bias=config.multimodal_projector_bias,
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_tower = init_vision_tower_for_llava(
+                config,
+                quant_config,
+                require_post_norm=False,
+                prefix=maybe_prefix(prefix, "vision_tower"))
+            self.multi_modal_projector = Mistral3MultiModalProjector(
+                vision_hidden_size=config.vision_config.hidden_size,
+                text_hidden_size=config.text_config.hidden_size,
+                projector_hidden_act=config.projector_hidden_act,
+                spatial_merge_size=config.spatial_merge_size,
+                patch_size=config.vision_config.patch_size,
+                multimodal_projector_bias=config.multimodal_projector_bias,
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        else:
+            self.vision_tower = None
+            self.multi_modal_projector = None
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -611,7 +615,11 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.vision_tower is None and self.multi_modal_projector is None:
+            skip_prefixes = ["vision_tower.", "multi_modal_projector."]
+
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:

vllm/model_executor/models/mllama4.py

Lines changed: 20 additions & 10 deletions
@@ -737,16 +737,20 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.vision_model = Llama4VisionModel(
-            config.vision_config,
-            None,
-            prefix=maybe_prefix(prefix, "vision_model"),
-            use_data_parallel=self.use_data_parallel,
-        )
-        self.multi_modal_projector = Llama4MultiModalProjector(
-            self.config,
-            None,
-            prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_model = Llama4VisionModel(
+                config.vision_config,
+                None,
+                prefix=maybe_prefix(prefix, "vision_model"),
+                use_data_parallel=self.use_data_parallel,
+            )
+            self.multi_modal_projector = Llama4MultiModalProjector(
+                self.config,
+                None,
+                prefix=maybe_prefix(prefix, "multi_modal_projector"))
+        else:
+            self.vision_model = None
+            self.multi_modal_projector = None
         self.language_model = initialize_model(
             vllm_config=vllm_config.with_hf_config(config.text_config,
                                                    ["LlamaForCausalLM"]),
@@ -783,6 +787,8 @@ def _parse_and_validate_image_input(
 
     def _process_image_input(
             self, image_input: Llama4ImagePatchInputs) -> MultiModalEmbeddings:
+
+        assert self.vision_model and self.multi_modal_projector
         flat_data = image_input["flat_data"]
         patches_per_image = image_input["patches_per_image"].tolist()
 
@@ -1048,6 +1054,10 @@ def load_weights(self, weights: Iterable[tuple[str,
         language_model_weights, other_weights = (
             self._separate_and_rename_weights(weights))
 
+        # Skip loading vision model and projector if they're not initialized.
+        if self.vision_model is None and self.multi_modal_projector is None:
+            other_weights = []
+
         # Handle expert scale parameters
         regular_weights, expert_scale_weights, updated_params_from_experts = (
             self._handle_expert_scale_broadcasting(language_model_weights,

vllm/model_executor/models/qwen2_5_omni_thinker.py

Lines changed: 25 additions & 8 deletions
@@ -722,13 +722,24 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             "exactly same result as the transformers implementation "
             "in the audio tower part.")
 
-        self.audio_tower = Qwen2_5OmniAudioEncoder(thinker_config.audio_config)
-        self.visual = Qwen2_5_VisionTransformer(
-            vision_config=thinker_config.vision_config,
-            norm_eps=getattr(thinker_config.text_config, "rms_norm_eps", 1e-6),
-            quant_config=quant_config,
-            prefix=maybe_prefix(prefix, "visual"),
-        )
+        if multimodal_config.get_limit_per_prompt("audio"):
+            self.audio_tower = Qwen2_5OmniAudioEncoder(
+                thinker_config.audio_config)
+        else:
+            self.audio_tower = None
+
+        if multimodal_config.get_limit_per_prompt(
+                "image") or multimodal_config.get_limit_per_prompt("video"):
+            self.visual = Qwen2_5_VisionTransformer(
+                vision_config=thinker_config.vision_config,
+                norm_eps=getattr(thinker_config.text_config, "rms_norm_eps",
+                                 1e-6),
+                quant_config=quant_config,
+                prefix=maybe_prefix(prefix, "visual"),
+            )
+        else:
+            self.visual = None
+
         self.quant_config = quant_config
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -886,9 +897,15 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        skip_prefixes = ["talker.", "token2wav."]
+        if self.audio_tower is None:
+            skip_prefixes.extend(["audio_tower."])
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+
         loader = AutoWeightsLoader(
             self,
-            skip_prefixes=["talker.", "token2wav."],
+            skip_prefixes=skip_prefixes,
         )
         loaded_weights = loader.load_weights(weights,
                                              mapper=self.hf_to_vllm_mapper)

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 15 additions & 7 deletions
@@ -843,12 +843,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.multimodal_config = multimodal_config
 
-        self.visual = Qwen2_5_VisionTransformer(
-            config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(self.quant_config),
-            prefix=maybe_prefix(prefix, "visual"),
-        )
+        if multimodal_config.get_limit_per_prompt("image") or \
+                multimodal_config.get_limit_per_prompt("video"):
+            self.visual = Qwen2_5_VisionTransformer(
+                config.vision_config,
+                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+                quant_config=self._maybe_ignore_quant_config(
+                    self.quant_config),
+                prefix=maybe_prefix(prefix, "visual"),
+            )
+        else:
+            self.visual = None
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -1152,7 +1157,10 @@ def compute_logits(
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
 
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:

vllm/model_executor/models/qwen2_vl.py

Lines changed: 18 additions & 8 deletions
@@ -1049,12 +1049,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.config = config
         self.multimodal_config = multimodal_config
 
-        self.visual = Qwen2VisionTransformer(
-            config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(quant_config),
-            prefix=maybe_prefix(prefix, "visual"),
-        )
+        if multimodal_config.get_limit_per_prompt("image") or \
+                multimodal_config.get_limit_per_prompt("video"):
+            self.visual = Qwen2VisionTransformer(
+                config.vision_config,
+                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
+                quant_config=self._maybe_ignore_quant_config(quant_config),
+                prefix=maybe_prefix(prefix, "visual"),
+            )
+        else:
+            self.visual = None
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -1350,7 +1354,10 @@ def compute_logits(
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
 
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
 
     def get_mm_mapping(self) -> MultiModelKeys:
@@ -1445,5 +1452,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
 
-        loader = AutoWeightsLoader(self)
+        skip_prefixes = []
+        if self.visual is None:
+            skip_prefixes.extend(["visual."])
+        loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
