@@ -7,7 +7,6 @@
 import torch
 # yapf: disable
 from torch import nn
-
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (Gemma3nAudioConfig,
                                          Gemma3nAudioFeatureExtractor,
@@ -17,6 +16,7 @@
 from transformers.models.siglip import SiglipImageProcessorFast
 
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
+from vllm.config.lora import LoRAConfig
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -44,8 +44,8 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal,
-                         SupportsTranscription)
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+                         SupportsMultiModal, SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
 
@@ -365,6 +365,7 @@ def __init__(
         self,
         multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
         text_config: Gemma3nTextConfig,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
 
@@ -374,9 +375,14 @@ def __init__(
         self.vocab_size = multimodal_config.vocab_size
         self.text_hidden_size = text_config.hidden_size
 
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = self.vocab_size + lora_vocab
+
         self.embedding = VocabParallelEmbedding(
             self.vocab_size,
             self.multimodal_hidden_size,
+            org_num_embeddings=multimodal_config.vocab_size,
         )
 
         self.hard_embedding_norm = RMSNorm(
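
The sizing above mirrors how vLLM pads LoRA-aware embeddings elsewhere: a block of extra rows is reserved per adapter slot, and org_num_embeddings records where the checkpoint vocabulary ends so weight loading and tensor-parallel sharding treat the trailing rows separately. A minimal sketch of the arithmetic, with illustrative numbers that are not taken from this diff:

    # Illustrative values only; the real ones come from the model/LoRA configs.
    base_vocab_size = 262_144        # multimodal_config.vocab_size
    lora_extra_vocab_size = 256      # rows reserved for adapter-added tokens
    max_loras = 4                    # adapters that may be resident at once

    # One block of extra rows per adapter slot, appended after the base vocab.
    lora_vocab = lora_extra_vocab_size * (max_loras or 1)
    padded_vocab_size = base_vocab_size + lora_vocab
    assert padded_vocab_size == 263_168

    # org_num_embeddings = base_vocab_size marks the checkpoint-owned rows.
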
@@ -419,7 +425,6 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
                 "You must specify exactly one of input_ids or inputs_embeds")
-
         if inputs_embeds is not None:
            emb_norm = self.soft_embedding_norm(inputs_embeds)
         else:
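
The only change in this hunk is a dropped blank line, but the guard it touches deserves a gloss: (input_ids is None) ^ (inputs_embeds is not None) is the usual "exactly one of" check. A self-contained sketch of its behavior:

    def exactly_one_of(input_ids, inputs_embeds):
        # The XOR is True when both arguments are given or both are missing.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds")

    exactly_one_of([1, 2, 3], None)   # ok: ids only
    exactly_one_of(None, object())    # ok: embeds only
    # exactly_one_of(None, None)      # raises
    # exactly_one_of([1], object())   # raises
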
@@ -472,13 +477,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
         self.vocab_size = config.text_config.vocab_size
+        self.lora_config = vllm_config.lora_config
 
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.audio_tower = AutoModel.from_config(config=config.audio_config)
         self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
-                                                      config.text_config)
+                                                      config.text_config,
+                                                      self.lora_config)
         self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
-                                                     config.text_config)
+                                                     config.text_config,
+                                                     self.lora_config)
 
         self.language_model: nn.Module = init_vllm_registered_model(
             vllm_config=vllm_config,
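
Because both modality embedders receive the same lora_config, their embedding tables are padded identically, and passing None (LoRA disabled) leaves the sizes exactly as before. A sketch of that invariant using stand-in classes, not the real vLLM types:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class StubLoRAConfig:               # stand-in for vllm.config.lora.LoRAConfig
        lora_extra_vocab_size: int = 256
        max_loras: Optional[int] = 2

    class StubEmbedder:                 # stand-in for Gemma3nMultimodalEmbedder
        def __init__(self, vocab_size: int,
                     lora_config: Optional[StubLoRAConfig] = None):
            lora_vocab = (lora_config.lora_extra_vocab_size *
                          (lora_config.max_loras or 1)) if lora_config else 0
            self.vocab_size = vocab_size + lora_vocab

    cfg = StubLoRAConfig()
    vision = StubEmbedder(128, cfg)
    audio = StubEmbedder(128, cfg)
    assert vision.vocab_size == audio.vocab_size == 128 + 256 * 2
    assert StubEmbedder(128).vocab_size == 128   # no LoRA: sizes unchanged
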
@@ -695,7 +703,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
         return MultiModelKeys.from_string_field(
             language_model="language_model",
             connector="multi_modal_projector",
-            tower_model="vision_tower")
+            tower_model=["vision_tower", "audio_tower"])
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
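
The switch to a list matters because MultiModelKeys is how vLLM's LoRA machinery classifies submodules; with only "vision_tower" listed, modules under audio_tower were not recognized as tower weights. A sketch of a consumer of this mapping (is_tower_module is a hypothetical helper, and the import path is assumed from the current vLLM layout):

    from vllm.model_executor.models.module_mapping import MultiModelKeys

    # Same call as in the diff above, with the list-valued tower_model.
    mapping = MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="multi_modal_projector",
        tower_model=["vision_tower", "audio_tower"])

    def is_tower_module(name: str) -> bool:
        # Hypothetical helper: LoRA wiring typically skips tower submodules.
        towers = mapping.tower_model
        towers = [towers] if isinstance(towers, str) else towers
        return any(name.startswith(t) for t in towers)

    assert is_tower_module("audio_tower.encoder.layers.0")      # now covered
    assert not is_tower_module("language_model.model.layers.0")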