 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import (
-    MultiModalEmbeddings,
-    SupportsLoRA,
-    SupportsMultiModal,
-    SupportsPP,
-    SupportsQuant,
-)
+from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP, SupportsQuant
 from .utils import (
     AutoWeightsLoader,
     PPMissingLayer,
@@ -534,7 +528,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.attention_instances = self.create_attention_instances()
 
         # Input embeddings
-        if not isinstance(self.model.get_input_embeddings(), PPMissingLayer):
+        input_embeddings = self.model.get_input_embeddings()
+        if not isinstance(input_embeddings, PPMissingLayer):
+            # Some models use embedding scales
+            self.embed_scale = getattr(input_embeddings, "embed_scale", None)
             names = ("embedding_size", "hidden_size")
             embedding_dim = getattr_iter(self.text_config, names, None)
             assert embedding_dim is not None
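Context for the new attribute: some checkpoints (notably Gemma) multiply their token embeddings by a constant normalizer before the first decoder layer and expose that constant on the embedding module, which is what the `getattr` probe above picks up. A minimal sketch of such a module; the class and the `hidden_size ** 0.5` value are illustrative assumptions, not vLLM API:

    import torch
    import torch.nn as nn

    class ScaledEmbedding(nn.Embedding):
        # Hypothetical embedding module exposing an `embed_scale`
        # attribute, as Gemma-style models do
        def __init__(self, vocab_size: int, hidden_size: int):
            super().__init__(vocab_size, hidden_size)
            self.embed_scale = hidden_size**0.5

    embed = ScaledEmbedding(vocab_size=256, hidden_size=64)
    scale = getattr(embed, "embed_scale", None)  # None for plain nn.Embedding
    ids = torch.tensor([1, 2, 3])
    hidden = embed(ids) if scale is None else embed(ids) * scale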
@@ -671,6 +668,7 @@ def create_attention_instances(
         num_heads = self.model_config.get_num_attention_heads(self.parallel_config)
         head_size = self.model_config.get_head_size()
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
+        logits_soft_cap = getattr(self.text_config, "attn_logit_softcapping", None)
         start, end = get_pp_indices(
             self.text_config.num_hidden_layers,
             self.pp_group.rank_in_group,
@@ -696,6 +694,7 @@ def create_attention_instances(
                 num_kv_heads=num_kv_heads,
                 cache_config=self.cache_config,
                 quant_config=self.quant_config,
+                logits_soft_cap=logits_soft_cap,
                 per_layer_sliding_window=per_layer_sliding_window,
                 prefix=f"{i}.attn",
                 attn_type=attn_type,
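`attn_logit_softcapping` is the Gemma 2-style bound on raw attention scores: logits are squashed through a tanh so their magnitude can never exceed the cap. The standard formula, sketched outside any vLLM type (the function name is mine):

    import torch

    def soft_cap(scores: torch.Tensor, cap: float) -> torch.Tensor:
        # Smoothly clamps logits into (-cap, cap); near-linear around zero
        return cap * torch.tanh(scores / cap)

    scores = 100.0 * torch.randn(4, 4)
    assert soft_cap(scores, cap=50.0).abs().max() < 50.0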
@@ -735,6 +734,7 @@ def forward(
         positions: torch.Tensor,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, IntermediateTensors]:
         if not self.pp_group.is_first_rank:
             assert intermediate_tensors is not None
@@ -758,6 +758,7 @@ def forward(
             position_ids=position_ids,
             attention_instances=self.attention_instances,
             return_dict=False,
+            **kwargs,
         )[0][0, ...]  # we remove batch dimension for now
 
         if not self.pp_group.is_last_rank:
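Threading `**kwargs` from `forward` into the Transformers model call is what lets subclasses hand extra per-request tensors (here, `token_type_ids`) to the backbone without the base class naming them. A toy sketch of the pass-through pattern, with hypothetical functions:

    def backbone(input_ids, token_type_ids=None):
        # Stand-in for the Transformers model call
        return input_ids, token_type_ids

    def forward(input_ids, **kwargs):
        # The base class never inspects kwargs; they flow through untouched
        return backbone(input_ids, **kwargs)

    forward([1, 2])                         # token_type_ids stays None
    forward([1, 2], token_type_ids=[0, 0])  # subclass-supplied value arrives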
@@ -819,7 +820,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.lm_head = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings()(input_ids)
+        inputs_embeds = self.model.get_input_embeddings()(input_ids)
+        if self.embed_scale is not None:
+            inputs_embeds *= self.embed_scale
+        return inputs_embeds
 
     def compute_logits(
         self,
@@ -845,6 +849,7 @@ def compute_logits(
     enable_if=can_enable_torch_compile,
 )
 class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal):
+    supports_multimodal_raw_input_only = True
     merge_by_field_config = True
     # Backwards compatibility for prev released models. State dicts back then
     # had different formats and cannot be loaded with `AutoModel` mapping as is
@@ -883,13 +888,27 @@ def forward(
         inputs_embeds: Optional[torch.Tensor] = None,
         **kwargs: object,
     ) -> Union[torch.Tensor, IntermediateTensors]:
+        # Gemma3 and PaliGemma need `token_type_ids` to work correctly.
+        # Other models will not have `token_type_ids` in kwargs.
+        kwargs = {k: v for k, v in kwargs.items() if k == "token_type_ids"}
         model_output = super().forward(
-            input_ids, positions, intermediate_tensors, inputs_embeds
+            input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
         )
         return model_output
 
     def get_language_model(self) -> torch.nn.Module:
-        return self.model
+        """`TransformersForMultimodalLM` does not contain a vLLM language model
+        class. Therefore, in order to return a vLLM language model class, we use
+        a wrapper that gives `self` the same interface as `TransformersForCausalLM`."""
+
+        class LanguageModelWrapper(TransformersForCausalLM):
+            def __init__(self, multimodal_model):
+                # Don't call super().__init__() to avoid re-initialization
+                self.__dict__.update(multimodal_model.__dict__)
+                # Expose the text backbone as `self.model`, matching the
+                # `TransformersForCausalLM` interface
+                self.model = getattr_iter(self.model, ("language_model", "text_model"), None)
+
+        return LanguageModelWrapper(self)
 
     def get_multimodal_embeddings(self, **kwargs):
         pixel_values: Optional[torch.Tensor] = kwargs.pop("pixel_values", None)
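The wrapper above is a shallow re-typing trick: copying `__dict__` makes the wrapper share every attribute (weights, config, caches) with the original object, while its class pointer makes `isinstance` checks against `TransformersForCausalLM` succeed. A self-contained sketch of the same pattern with toy classes:

    class CausalLM:
        def __init__(self):
            self.weights = [1.0, 2.0]

    class MultimodalLM:
        def __init__(self):
            self.weights = [3.0, 4.0]

        def as_causal_lm(self):
            class Wrapper(CausalLM):
                def __init__(self, wrapped):
                    # Skip super().__init__(): share state, don't rebuild it
                    self.__dict__.update(wrapped.__dict__)
            return Wrapper(self)

    mm = MultimodalLM()
    lm = mm.as_causal_lm()
    assert isinstance(lm, CausalLM)
    assert lm.weights is mm.weights  # attributes are shared, not copied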
@@ -905,6 +924,7 @@ def get_multimodal_embeddings(self, **kwargs):
             return None
 
         num_image_patches = kwargs.pop("num_image_patches")
+        kwargs.pop("token_type_ids", None)  # used only in `forward`
         if pixel_values is not None:
             vision_embeddings = self.model.get_image_features(pixel_values, **kwargs)
@@ -925,46 +945,4 @@ def get_multimodal_embeddings(self, **kwargs):
 
         return vision_embeddings
 
-    def get_input_embeddings(
-        self,
-        input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
-        *,
-        is_multimodal: Optional[torch.Tensor] = None,
-        handle_oov_mm_token: bool = False,
-    ) -> torch.Tensor:
-        """
-        Apply token embeddings to `input_ids`.
-
-        If `multimodal_embeddings` is passed, scatter them into
-        `input_ids` according to the mask `is_multimodal`.
-
-        In case the multi-modal token IDs exceed the vocabulary size of
-        the language model, you can set `handle_oov_mm_token=False`
-        to avoid calling the language model's `get_input_embeddings` method
-        on those tokens.
-        """
-        from .utils import _merge_multimodal_embeddings
-
-        inputs_embeds = self._get_text_embeddings(
-            input_ids,
-            self.model.get_input_embeddings(),
-            is_multimodal=is_multimodal,
-            handle_oov_mm_token=handle_oov_mm_token,
-        )
-
-        if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
-            return inputs_embeds
-
-        if is_multimodal is None:
-            raise ValueError(
-                "`get_input_embeddings` now requires `is_multimodal` arg, "
-                "please update your model runner according to "
-                "https://github.com/vllm-project/vllm/pull/16229."
-            )
-
-        return _merge_multimodal_embeddings(
-            inputs_embeds=inputs_embeds,
-            multimodal_embeddings=multimodal_embeddings,
-            is_multimodal=is_multimodal,
-        )
+    get_input_embeddings = SupportsMultiModal.get_input_embeddings
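Replacing the override with `get_input_embeddings = SupportsMultiModal.get_input_embeddings` is not a no-op: with `class TransformersForMultimodalLM(TransformersForCausalLM, SupportsMultiModal)`, Python's MRO would otherwise resolve `get_input_embeddings` to the text-only version defined on `TransformersForCausalLM`. The explicit class-body assignment pins the interface's default implementation instead. A toy illustration of the aliasing:

    class Base:
        def embed(self):
            return "text-only"

    class Mixin:
        def embed(self):
            return "multimodal default"

    class Child(Base, Mixin):
        # MRO alone would pick Base.embed; pin the mixin's version explicitly
        embed = Mixin.embed

    assert Child().embed() == "multimodal default"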