77import torch
88
99from torch import nn
10-
1110from transformers import AutoModel , BatchFeature
1211from transformers .models .gemma3n import (
1312 Gemma3nAudioConfig ,
2120
2221from vllm .config import ModelConfig , SpeechToTextConfig , VllmConfig
2322from vllm .config .multimodal import BaseDummyOptions
23+ from vllm .config .lora import LoRAConfig
2424from vllm .inputs .data import PromptType
2525from vllm .logger import init_logger
2626from vllm .model_executor .layers .layernorm import RMSNorm
5555from vllm .sequence import IntermediateTensors
5656from vllm .utils .tensor_schema import TensorSchema , TensorShape
5757
58- from .interfaces import (MultiModalEmbeddings , SupportsLoRA , SupportsMultiModal ,
59- SupportsTranscription )
58+ from .interfaces import (MultiModalEmbeddings , SupportsLoRA ,
59+ SupportsMultiModal , SupportsTranscription )
6060from .utils import (AutoWeightsLoader , WeightsMapper , flatten_bn ,
6161 init_vllm_registered_model , maybe_prefix )
6262
@@ -387,6 +387,7 @@ def __init__(
387387 self ,
388388 multimodal_config : Union [Gemma3nAudioConfig , Gemma3nVisionConfig ],
389389 text_config : Gemma3nTextConfig ,
390+ lora_config : Optional [LoRAConfig ] = None ,
390391 ):
391392 super ().__init__ ()
392393
@@ -396,9 +397,14 @@ def __init__(
396397 self .vocab_size = multimodal_config .vocab_size
397398 self .text_hidden_size = text_config .hidden_size
398399
400+ lora_vocab = (lora_config .lora_extra_vocab_size *
401+ (lora_config .max_loras or 1 )) if lora_config else 0
402+ self .vocab_size = self .vocab_size + lora_vocab
403+
399404 self .embedding = VocabParallelEmbedding (
400405 self .vocab_size ,
401406 self .multimodal_hidden_size ,
407+ org_num_embeddings = multimodal_config .vocab_size ,
402408 )
403409
404410 self .hard_embedding_norm = RMSNorm (
@@ -440,9 +446,7 @@ def forward(
440446 """ # noqa: E501
441447 if (input_ids is None ) ^ (inputs_embeds is not None ):
442448 raise ValueError (
443- "You must specify exactly one of input_ids or inputs_embeds"
444- )
445-
449+ "You must specify exactly one of input_ids or inputs_embeds" )
446450 if inputs_embeds is not None :
447451 emb_norm = self .soft_embedding_norm (inputs_embeds )
448452 else :
@@ -496,15 +500,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
496500 self .quant_config = quant_config
497501 self .multimodal_config = multimodal_config
498502 self .vocab_size = config .text_config .vocab_size
503+ self .lora_config = vllm_config .lora_config
499504
500505 self .vision_tower = AutoModel .from_config (config = config .vision_config )
501506 self .audio_tower = AutoModel .from_config (config = config .audio_config )
502- self .embed_vision = Gemma3nMultimodalEmbedder (
503- config . vision_config , config .text_config
504- )
505- self .embed_audio = Gemma3nMultimodalEmbedder (
506- config . audio_config , config .text_config
507- )
507+ self .embed_vision = Gemma3nMultimodalEmbedder (config . vision_config ,
508+ config .text_config ,
509+ self . lora_config )
510+ self .embed_audio = Gemma3nMultimodalEmbedder (config . audio_config ,
511+ config .text_config ,
512+ self . lora_config )
508513
509514 self .language_model : nn .Module = init_vllm_registered_model (
510515 vllm_config = vllm_config ,
@@ -739,8 +744,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
739744 return MultiModelKeys .from_string_field (
740745 language_model = "language_model" ,
741746 connector = "multi_modal_projector" ,
742- tower_model = "vision_tower" ,
743- )
747+ tower_model = ["vision_tower" , "audio_tower" ])
744748
745749 @classmethod
746750 def get_placeholder_str (cls , modality : str , i : int ) -> Optional [str ]:
0 commit comments