
Commit 7cd95dc

[Bugfix] Fix gemma3 with transformers backend (#23178)
Signed-off-by: raushan <[email protected]>
Signed-off-by: Raushan Turganbay <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
1 parent c02058c commit 7cd95dc

File tree

4 files changed: +72 −59 lines

- tests/models/multimodal/generation/test_common.py
- tests/models/multimodal/generation/vlm_utils/model_utils.py
- vllm/model_executor/models/transformers.py
- vllm/model_executor/models/transformers_moe.py

tests/models/multimodal/generation/test_common.py

Lines changed: 14 additions & 0 deletions

@@ -193,6 +193,20 @@
         # when processing the 3rd prompt in vLLM
         marks=[pytest.mark.core_model, pytest.mark.skip(reason="Test hangs")],
     ),
+    # Gemma3 has bidirectional mask on images
+    "gemma3-transformers": VLMTestInfo(
+        models=["google/gemma-3-4b-it"],
+        test_type=VLMTestType.IMAGE,
+        prompt_formatter=lambda vid_prompt: f"<bos><start_of_turn>user\n{vid_prompt}<start_of_image><end_of_turn>\n<start_of_turn>model\n",  # noqa: E501
+        max_model_len=4096,
+        auto_cls=AutoModelForImageTextToText,
+        vllm_output_post_proc=model_utils.gemma3_vllm_to_hf_output,
+        image_size_factors=[(0.25, 0.5, 1.0)],
+        vllm_runner_kwargs={
+            "model_impl": "transformers",
+        },
+        marks=[pytest.mark.core_model],
+    ),
     "idefics3-transformers": VLMTestInfo(
         models=["HuggingFaceTB/SmolVLM-256M-Instruct"],
         test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
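For reference, the prompt_formatter added above is a plain f-string template. A standalone sketch of what it produces for a made-up question (the question text is illustrative; the template mirrors the lambda in the diff):

prompt_formatter = (
    lambda vid_prompt: f"<bos><start_of_turn>user\n{vid_prompt}"
    f"<start_of_image><end_of_turn>\n<start_of_turn>model\n"
)

print(prompt_formatter("What is shown in this image?"))
# <bos><start_of_turn>user
# What is shown in this image?<start_of_image><end_of_turn>
# <start_of_turn>model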

tests/models/multimodal/generation/vlm_utils/model_utils.py

Lines changed: 23 additions & 0 deletions

@@ -342,6 +342,29 @@ def _generate(self, *args, **kwargs):
     return hf_model


+def gemma3_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput:
+    """Sanitize vllm output [gemma-3] to compare with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    config = AutoConfig.from_pretrained(model)
+    image_token_id = config.image_token_id
+
+    tokenizer = AutoTokenizer.from_pretrained(model)
+    eos_token_id = tokenizer.eos_token_id
+
+    hf_output_ids = [
+        token_id
+        for idx, token_id in enumerate(output_ids)
+        if token_id != image_token_id
+    ]
+
+    hf_output_str = output_str
+    if hf_output_ids[-1] == eos_token_id:
+        hf_output_str = hf_output_str + tokenizer.decode(eos_token_id)
+
+    return hf_output_ids, hf_output_str, out_logprobs
+
+
 def glm4v_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
     """Patches and returns an instance of the HfRunner to use for GLM4V."""
     hf_processor = hf_model.processor
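The helper's sanitization logic is easy to check in isolation. A toy sketch with made-up token IDs (999 standing in for the image token, 1 for EOS, "<eos>" for its decoded form), not the real Gemma 3 vocabulary:

image_token_id = 999
eos_token_id = 1

output_ids = [999, 999, 42, 7, 1]  # pretend vLLM output ids
output_str = "a cat on a mat"

hf_output_ids = [t for t in output_ids if t != image_token_id]
hf_output_str = output_str + "<eos>" if hf_output_ids[-1] == eos_token_id else output_str

assert hf_output_ids == [42, 7, 1]
assert hf_output_str == "a cat on a mat<eos>"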

vllm/model_executor/models/transformers.py

Lines changed: 32 additions & 54 deletions

@@ -68,13 +68,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors

-from .interfaces import (
-    MultiModalEmbeddings,
-    SupportsLoRA,
-    SupportsMultiModal,
-    SupportsPP,
-    SupportsQuant,
-)
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -534,7 +528,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.attention_instances = self.create_attention_instances()

         # Input embeddings
-        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
+        input_embeddings = self.model.get_input_embeddings()
+        if not isinstance(input_embeddings, PPMissingLayer):
+            # Some models use embedding scales
+            self.embed_scale = getattr(input_embeddings, "embed_scale", None)
             names = ("embedding_size", "hidden_size")
             embedding_dim = getattr_iter(self.text_config, names, None)
             assert embedding_dim is not None
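The embed_scale probe covers models such as Gemma, whose Transformers implementations attach an embed_scale value to the input-embedding module (typically on the order of sqrt(hidden_size)) and expect hidden states to be scaled accordingly. A minimal sketch of the pattern being probed, using a toy class rather than the actual Transformers module:

import torch
import torch.nn as nn

class ScaledEmbedding(nn.Embedding):
    """Toy embedding that carries an `embed_scale` attribute (hypothetical class)."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings, embedding_dim)
        # Gemma-style normalizer; the exact value is model-specific.
        self.embed_scale = embedding_dim ** 0.5

embed = ScaledEmbedding(256, 64)
scale = getattr(embed, "embed_scale", None)  # what the backend probes for

inputs_embeds = embed(torch.tensor([1, 2, 3]))
if scale is not None:
    inputs_embeds = inputs_embeds * scale  # mirrors get_input_embeddings() below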
@@ -671,6 +668,7 @@ def create_attention_instances(
         num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
+        logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None)
         start, end = get_pp_indices(
             self.text_config.num_hidden_layers,
             self.pp_group.rank_in_group,

@@ -696,6 +694,7 @@ def create_attention_instances(
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
+                logits_soft_cap=logits_soft_cap,
                 per_layer_sliding_window=per_layer_sliding_window,
                 prefix=f"{i}.attn",
                 attn_type=attn_type,
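For context, attention logit soft-capping (configured via attn_logit_softcapping in Gemma 2 style configs; configs without the attribute simply yield None here) bounds raw attention scores with a tanh. A small illustrative sketch of the transform, not vLLM's attention kernel:

import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Monotonic squashing: outputs stay strictly inside (-cap, cap).
    return cap * torch.tanh(logits / cap)

scores = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap(scores, cap=50.0))  # large magnitudes are squashed toward ±50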
@@ -735,6 +734,7 @@ def forward(
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if not self.pp_group.is_first_rank:
             assert intermediate_tensors is not None

@@ -758,6 +758,7 @@ def forward(
             position_ids=position_ids,
             attention_instances=self.attention_instances,
             return_dict=False,
+            **kwargs,
         )[0][0, ...]  # we remove batch dimension for now

         if not self.pp_group.is_last_rank:
@@ -819,7 +820,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = PPMissingLayer()

     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings()(input_ids)
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+        if self.embed_scale is not None:
+            inputs_embeds *= self.embed_scale
+        return inputs_embeds

     def compute_logits(
         self,
@@ -845,6 +849,7 @@ def compute_logits(
     enable_if=can_enable_torch_compile,
 )
 class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
+    supports_multimodal_raw_input_only = True
     merge_by_field_config = True
     # Backwards compatibility for prev released models. State dicts back then
     # had different formats and cannot be loaded with `AutoModel` mapping as is
@@ -883,13 +888,27 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
+        # Gemma3 and PaliGemma need `token_type_ids` to work correctly.
+        # Other models will not have `token_type_ids` in kwargs.
+        kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
         model_output = super().forward(
-            input_ids, positions, intermediate_tensors, inputs_embeds
+            input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
         )
         return model_output

     def get_language_model(self) -> torch.nn.Module:
-        return self.model
+        """`TransformersForMultimodalLM` does not contain a vLLM language model class.
+        Therefore, in order to return a language model vLLM class, we use a wrapper to
+        give `self` the same interface as `TransformersForCausalLM`."""
+
+        class LanguageModelWrapper(TransformersForCausalLM):
+            def __init__(self, multimodal_model):
+                # Don't call super().__init__() to avoid re-initialization
+                self.__dict__.update(multimodal_model.__dict__)
+
+            model = getattr_iter(self.model, ("language_model", "text_model"), None)
+
+        return LanguageModelWrapper(self)

     def get_multimodal_embeddings(self, **kwargs):
         pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None)
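The wrapper relies on a common Python idiom: copying an existing instance's __dict__ into a fresh object of another class gives that object the new class's methods without re-running any construction. A minimal, self-contained sketch of the idiom with toy classes (not the vLLM ones):

class Engine:
    def __init__(self):
        self.state = {"loaded": True}  # pretend this was expensive to build

    def describe(self):
        return "engine"

class EngineView(Engine):
    def __init__(self, wrapped):
        # Don't call super().__init__(): reuse the already-built attributes.
        self.__dict__.update(wrapped.__dict__)

    def describe(self):
        return "view over " + str(self.state)

view = EngineView(Engine())
assert view.state == {"loaded": True}      # shared attributes, no re-initialization
assert view.describe().startswith("view")  # but EngineView's interface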
@@ -905,6 +924,7 @@ def get_multimodal_embeddings(self, **kwargs):
             return None

         num_image_patches = kwargs.pop("num_image_patches")
+        kwargs.pop("token_type_ids", None)  # used only in `forward`
         if pixel_values is not None:
             vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
@@ -925,46 +945,4 @@ def get_multimodal_embeddings(self, **kwargs):

         return vision_embeddings

-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        *,
-        is_multimodal: Optional[torch.Tensor] = None,
-        handle_oov_mm_token: bool = False,
-    ) -> torch.Tensor:
-        """
-        Apply token embeddings to `input_ids`.
-
-        If `multimodal_embeddings` is passed, scatter them into
-        `input_ids` according to the mask `is_multimodal`.
-
-        In case the multi-modal token IDs exceed the vocabulary size of
-        the language model, you can set `handle_oov_mm_token=False`
-        to avoid calling the language model's `get_input_embeddings` method
-        on those tokens.
-        """
-        from .utils import _merge_multimodal_embeddings
-
-        inputs_embeds = self._get_text_embeddings(
-            input_ids,
-            self.model.get_input_embeddings(),
-            is_multimodal=is_multimodal,
-            handle_oov_mm_token=handle_oov_mm_token,
-        )
-
-        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
-            return inputs_embeds
-
-        if is_multimodal is None:
-            raise ValueError(
-                "`get_input_embeddings` now requires `is_multimodal` arg, "
-                "please update your model runner according to "
-                "https://github.com/vllm-project/vllm/pull/16229."
-            )
-
-        return _merge_multimodal_embeddings(
-            inputs_embeds=inputs_embeds,
-            multimodal_embeddings=multimodal_embeddings,
-            is_multimodal=is_multimodal,
-        )
+    get_input_embeddings = SupportsMultiModal.get_input_embeddings
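Assigning get_input_embeddings = SupportsMultiModal.get_input_embeddings in the class body pins the interface's implementation, presumably so that name resolution does not fall through to the text-only version inherited from TransformersForCausalLM (the first base in the MRO). A toy sketch of that resolution behaviour, with hypothetical classes:

class TextOnly:
    def embed(self):
        return "text-only embed"

class MultiModalMixin:
    def embed(self):
        return "multimodal embed"

class WithoutPin(TextOnly, MultiModalMixin):
    pass

class WithPin(TextOnly, MultiModalMixin):
    # Explicitly pin the mixin's method instead of what the MRO would pick.
    embed = MultiModalMixin.embed

assert WithoutPin().embed() == "text-only embed"  # MRO picks the first base
assert WithPin().embed() == "multimodal embed"    # the class-body assignment wins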

vllm/model_executor/models/transformers_moe.py

Lines changed: 3 additions & 5 deletions

@@ -30,7 +30,7 @@
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op

-from .interfaces import MixtureOfExperts
+from .interfaces import MixtureOfExperts, SupportsMultiModal
 from .transformers import (
     TransformersBase,
     TransformersForCausalLM,

@@ -335,7 +335,5 @@ class TransformersMoEForCausalLM(TransformersMoEBase, TransformersForCausalLM):
     },
     enable_if=can_enable_torch_compile,
 )
-class TransformersMoEForMultimodalLM(
-    TransformersMoEForCausalLM, TransformersForMultimodalLM
-):
-    pass
+class TransformersMoEForMultimodalLM(TransformersMoEBase, TransformersForMultimodalLM):
+    get_input_embeddings = SupportsMultiModal.get_input_embeddings
