5555from vllm .sequence import IntermediateTensors
5656from vllm .utils .tensor_schema import TensorSchema , TensorShape
5757
58- from .interfaces import (MultiModalEmbeddings , SupportsLoRA ,
59- SupportsMultiModal , SupportsTranscription )
60- from .utils import (AutoWeightsLoader , WeightsMapper , flatten_bn ,
61- init_vllm_registered_model , maybe_prefix )
58+ from .interfaces import (
59+ MultiModalEmbeddings ,
60+ SupportsLoRA ,
61+ SupportsMultiModal ,
62+ SupportsTranscription ,
63+ )
64+ from .utils import (
65+ AutoWeightsLoader ,
66+ WeightsMapper ,
67+ flatten_bn ,
68+ init_vllm_registered_model ,
69+ maybe_prefix ,
70+ )
6271
6372logger = init_logger (__name__ )
6473
@@ -397,8 +406,11 @@ def __init__(
397406 self .vocab_size = multimodal_config .vocab_size
398407 self .text_hidden_size = text_config .hidden_size
399408
400- lora_vocab = (lora_config .lora_extra_vocab_size *
401- (lora_config .max_loras or 1 )) if lora_config else 0
409+ lora_vocab = (
410+ (lora_config .lora_extra_vocab_size * (lora_config .max_loras or 1 ))
411+ if lora_config
412+ else 0
413+ )
402414 self .vocab_size = self .vocab_size + lora_vocab
403415
404416 self .embedding = VocabParallelEmbedding (
@@ -446,7 +458,8 @@ def forward(
446458 """ # noqa: E501
447459 if (input_ids is None ) ^ (inputs_embeds is not None ):
448460 raise ValueError (
449- "You must specify exactly one of input_ids or inputs_embeds" )
461+ "You must specify exactly one of input_ids or inputs_embeds"
462+ )
450463 if inputs_embeds is not None :
451464 emb_norm = self .soft_embedding_norm (inputs_embeds )
452465 else :
@@ -457,11 +470,14 @@ def forward(
457470 return self .embedding_post_projection_norm (emb_norm_proj )
458471
459472
460- @MULTIMODAL_REGISTRY .register_processor (Gemma3nMultiModalProcessor ,
461- info = Gemma3nProcessingInfo ,
462- dummy_inputs = Gemma3nDummyInputsBuilder )
463- class Gemma3nForConditionalGeneration (nn .Module , SupportsMultiModal ,
464- SupportsTranscription , SupportsLoRA ):
473+ @MULTIMODAL_REGISTRY .register_processor (
474+ Gemma3nMultiModalProcessor ,
475+ info = Gemma3nProcessingInfo ,
476+ dummy_inputs = Gemma3nDummyInputsBuilder ,
477+ )
478+ class Gemma3nForConditionalGeneration (
479+ nn .Module , SupportsMultiModal , SupportsTranscription , SupportsLoRA
480+ ):
465481 merge_by_field_config = True
466482 supported_languages = ISO639_1_SUPPORTED_LANGS
467483
@@ -504,12 +520,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
504520
505521 self .vision_tower = AutoModel .from_config (config = config .vision_config )
506522 self .audio_tower = AutoModel .from_config (config = config .audio_config )
507- self .embed_vision = Gemma3nMultimodalEmbedder (config . vision_config ,
508- config .text_config ,
509- self . lora_config )
510- self .embed_audio = Gemma3nMultimodalEmbedder (config . audio_config ,
511- config .text_config ,
512- self . lora_config )
523+ self .embed_vision = Gemma3nMultimodalEmbedder (
524+ config . vision_config , config .text_config , self . lora_config
525+ )
526+ self .embed_audio = Gemma3nMultimodalEmbedder (
527+ config . audio_config , config .text_config , self . lora_config
528+ )
513529
514530 self .language_model : nn .Module = init_vllm_registered_model (
515531 vllm_config = vllm_config ,
@@ -744,7 +760,8 @@ def get_mm_mapping(self) -> MultiModelKeys:
744760 return MultiModelKeys .from_string_field (
745761 language_model = "language_model" ,
746762 connector = "multi_modal_projector" ,
747- tower_model = ["vision_tower" , "audio_tower" ])
763+ tower_model = ["vision_tower" , "audio_tower" ],
764+ )
748765
749766 @classmethod
750767 def get_placeholder_str (cls , modality : str , i : int ) -> Optional [str ]:
0 commit comments