77import torch
88# yapf: disable
99from torch import nn
10-
1110from transformers import AutoModel , BatchFeature
1211from transformers .models .gemma3n import (Gemma3nAudioConfig ,
1312 Gemma3nAudioFeatureExtractor ,
1817
1918from vllm .config import ModelConfig , SpeechToTextConfig , VllmConfig
2019from vllm .config .multimodal import BaseDummyOptions
20+ from vllm .config .lora import LoRAConfig
2121from vllm .inputs .data import PromptType
2222from vllm .logger import init_logger
2323from vllm .model_executor .layers .layernorm import RMSNorm
4545from vllm .sequence import IntermediateTensors
4646from vllm .utils .tensor_schema import TensorSchema , TensorShape
4747
48- from .interfaces import (MultiModalEmbeddings , SupportsLoRA , SupportsMultiModal ,
49- SupportsTranscription )
48+ from .interfaces import (MultiModalEmbeddings , SupportsLoRA ,
49+ SupportsMultiModal , SupportsTranscription )
5050from .utils import (AutoWeightsLoader , WeightsMapper , flatten_bn ,
5151 init_vllm_registered_model , maybe_prefix )
5252
@@ -373,6 +373,7 @@ def __init__(
373373 self ,
374374 multimodal_config : Union [Gemma3nAudioConfig , Gemma3nVisionConfig ],
375375 text_config : Gemma3nTextConfig ,
376+ lora_config : Optional [LoRAConfig ] = None ,
376377 ):
377378 super ().__init__ ()
378379
@@ -382,9 +383,14 @@ def __init__(
382383 self .vocab_size = multimodal_config .vocab_size
383384 self .text_hidden_size = text_config .hidden_size
384385
386+ lora_vocab = (lora_config .lora_extra_vocab_size *
387+ (lora_config .max_loras or 1 )) if lora_config else 0
388+ self .vocab_size = self .vocab_size + lora_vocab
389+
385390 self .embedding = VocabParallelEmbedding (
386391 self .vocab_size ,
387392 self .multimodal_hidden_size ,
393+ org_num_embeddings = multimodal_config .vocab_size ,
388394 )
389395
390396 self .hard_embedding_norm = RMSNorm (
@@ -427,7 +433,6 @@ def forward(
427433 if (input_ids is None ) ^ (inputs_embeds is not None ):
428434 raise ValueError (
429435 "You must specify exactly one of input_ids or inputs_embeds" )
430-
431436 if inputs_embeds is not None :
432437 emb_norm = self .soft_embedding_norm (inputs_embeds )
433438 else :
@@ -480,13 +485,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
480485 self .quant_config = quant_config
481486 self .multimodal_config = multimodal_config
482487 self .vocab_size = config .text_config .vocab_size
488+ self .lora_config = vllm_config .lora_config
483489
484490 self .vision_tower = AutoModel .from_config (config = config .vision_config )
485491 self .audio_tower = AutoModel .from_config (config = config .audio_config )
486492 self .embed_vision = Gemma3nMultimodalEmbedder (config .vision_config ,
487- config .text_config )
493+ config .text_config ,
494+ self .lora_config )
488495 self .embed_audio = Gemma3nMultimodalEmbedder (config .audio_config ,
489- config .text_config )
496+ config .text_config ,
497+ self .lora_config )
490498
491499 self .language_model : nn .Module = init_vllm_registered_model (
492500 vllm_config = vllm_config ,
@@ -703,7 +711,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
703711 return MultiModelKeys .from_string_field (
704712 language_model = "language_model" ,
705713 connector = "multi_modal_projector" ,
706- tower_model = "vision_tower" )
714+ tower_model = [ "vision_tower" , "audio_tower" ] )
707715
708716 @classmethod
709717 def get_placeholder_str (cls , modality : str , i : int ) -> Optional [str ]:
0 commit comments