From 3c187d4395ec30ce7c8ada45d73a413083f60379 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Sun, 31 Aug 2025 09:42:30 +0000
Subject: [PATCH 1/5] gemma3n lora

Signed-off-by: NickLucche
---
 docs/models/supported_models.md          |  4 ++--
 vllm/model_executor/models/gemma3n_mm.py | 26 ++++++++++----------------
 2 files changed, 12 insertions(+), 18 deletions(-)

diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 60fe5b887952..3eb179385637 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -366,7 +366,7 @@ th {
 | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
+| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | ✅︎ | | ✅︎ |
 | `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -671,7 +671,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ |
 | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
 | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
-| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
+| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | ✅︎ | | ✅︎ |
 | `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + IE+ + VE+ | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index 0e69fcfd8feb..cebab4895839 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -7,6 +7,7 @@
 
 import torch
 from torch import nn
+
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (
     Gemma3nAudioConfig,
@@ -54,14 +55,10 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
-from .utils import (
-    AutoWeightsLoader,
-    WeightsMapper,
-    flatten_bn,
-    init_vllm_registered_model,
-    maybe_prefix,
-)
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal,
+                         SupportsTranscription)
+from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+                    init_vllm_registered_model, maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -456,14 +453,11 @@ def forward(
         return self.embedding_post_projection_norm(emb_norm_proj)
 
 
-@MULTIMODAL_REGISTRY.register_processor(
-    Gemma3nMultiModalProcessor,
-    info=Gemma3nProcessingInfo,
-    dummy_inputs=Gemma3nDummyInputsBuilder,
-)
-class Gemma3nForConditionalGeneration(
-    nn.Module, SupportsMultiModal, SupportsTranscription
-):
+@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
+                                        info=Gemma3nProcessingInfo,
+                                        dummy_inputs=Gemma3nDummyInputsBuilder)
+class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
+                                      SupportsTranscription, SupportsLoRA):
     merge_by_field_config = True
 
     supported_languages = ISO639_1_SUPPORTED_LANGS
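[Note on PATCH 1/5] This patch advertises LoRA support on the multimodal Gemma 3n
model and flips the corresponding LoRA column in the docs tables. In vLLM, LoRA
capability is declared by mixing `SupportsLoRA` into the model class. A minimal
sketch of what that marker carries, simplified from my reading of
vllm/model_executor/models/interfaces.py (exact fields vary across vLLM versions):

    from typing import ClassVar, Literal

    class SupportsLoRA:
        # Inspected at engine start-up when --enable-lora is passed; models
        # without this marker are rejected as not LoRA-capable.
        supports_lora: ClassVar[Literal[True]] = True
        # Maps a fused vLLM module to the HF module names adapters are
        # trained against, e.g. {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}.
        packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
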
From a607d3818857ac99ad776334e3f304268b1ee622 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Sun, 31 Aug 2025 09:51:25 +0000
Subject: [PATCH 2/5] gemma3n lora

Signed-off-by: NickLucche
---
 vllm/model_executor/models/gemma3n.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index e4ea4256ebc2..96c9fb161ad9 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -20,8 +20,8 @@
 
 import torch
 from torch import nn
-from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig
 
+from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
@@ -52,14 +52,9 @@
 from vllm.sequence import IntermediateTensors
 from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
 
-from .interfaces import SupportsQuant
-from .utils import (
-    AutoWeightsLoader,
-    extract_layer_index,
-    is_pp_missing_parameter,
-    make_layers,
-    maybe_prefix,
-)
+from .interfaces import SupportsLoRA, SupportsQuant
+from .utils import (AutoWeightsLoader, extract_layer_index,
+                    is_pp_missing_parameter, make_layers, maybe_prefix)
 
 logger = init_logger(__name__)
 
@@ -1081,7 +1076,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loaded_params
 
 
-class Gemma3nForCausalLM(nn.Module):
+class Gemma3nForCausalLM(nn.Module, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
            "q_proj",
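[Note on PATCH 2/5] The text-only backbone needs nothing beyond the `SupportsLoRA`
base, because `Gemma3nForCausalLM` already declares `packed_modules_mapping` (the
hunk context above shows `qkv_proj` covering the separate q/k/v projections). A
toy illustration of why that mapping matters; this is not vLLM's actual packing
code, just the idea behind it:

    import torch

    def pack_qkv_lora_b(lora_b: dict[str, torch.Tensor]) -> torch.Tensor:
        # Adapters ship LoRA B matrices for q_proj/k_proj/v_proj separately,
        # while vLLM runs one fused qkv_proj GEMM, so the per-module B
        # matrices get stacked along the output dimension in the order
        # listed under packed_modules_mapping["qkv_proj"].
        return torch.cat(
            [lora_b["q_proj"], lora_b["k_proj"], lora_b["v_proj"]], dim=0
        )
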
From bbf097fe0d0e91015dcddd242e5be760b4962a08 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Fri, 3 Oct 2025 10:57:35 +0000
Subject: [PATCH 3/5] do not flatten input on linear lora layer

Signed-off-by: NickLucche
---
 vllm/lora/layers/base_linear.py          |  3 ++-
 vllm/model_executor/models/gemma3n_mm.py | 32 ++++++++++++++++++--------------
 2 files changed, 20 insertions(+), 15 deletions(-)

diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py
index d2f017c19ccd..410b9f5605cb 100644
--- a/vllm/lora/layers/base_linear.py
+++ b/vllm/lora/layers/base_linear.py
@@ -157,7 +157,8 @@ def apply(
         # In transformers backend, x and output have extra batch dimension like
         # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
         # therefore we need to flatten the batch dimensions.
-        if x.ndim == 3 and output.ndim == 3:
+        if (x.shape[0] ==1 and x.ndim == 3 and
+            output.shape[0] == 1 and output.ndim == 3):
             output = output.flatten(0, 1)
             x = x.flatten(0, 1)
 
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index cebab4895839..4a826165c5ce 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -7,7 +7,6 @@
 
 import torch
 from torch import nn
-
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (
     Gemma3nAudioConfig,
@@ -21,6 +20,7 @@
 
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
+from vllm.config.lora import LoRAConfig
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -55,8 +55,8 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal,
-                         SupportsTranscription)
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+                         SupportsMultiModal, SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
 
@@ -387,6 +387,7 @@ def __init__(
         self,
         multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
         text_config: Gemma3nTextConfig,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
 
@@ -396,9 +397,14 @@
         self.vocab_size = multimodal_config.vocab_size
         self.text_hidden_size = text_config.hidden_size
 
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = self.vocab_size + lora_vocab
+
         self.embedding = VocabParallelEmbedding(
             self.vocab_size,
             self.multimodal_hidden_size,
+            org_num_embeddings=multimodal_config.vocab_size,
         )
 
         self.hard_embedding_norm = RMSNorm(
@@ -440,9 +446,7 @@
         """  # noqa: E501
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds"
-            )
-
+                "You must specify exactly one of input_ids or inputs_embeds")
         if inputs_embeds is not None:
             emb_norm = self.soft_embedding_norm(inputs_embeds)
         else:
@@ -496,15 +500,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
         self.vocab_size = config.text_config.vocab_size
+        self.lora_config = vllm_config.lora_config
 
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.audio_tower = AutoModel.from_config(config=config.audio_config)
-        self.embed_vision = Gemma3nMultimodalEmbedder(
-            config.vision_config, config.text_config
-        )
-        self.embed_audio = Gemma3nMultimodalEmbedder(
-            config.audio_config, config.text_config
-        )
+        self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
+                                                      config.text_config,
+                                                      self.lora_config)
+        self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
+                                                     config.text_config,
+                                                     self.lora_config)
 
         self.language_model: nn.Module = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -739,8 +744,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
         return MultiModelKeys.from_string_field(
             language_model="language_model",
             connector="multi_modal_projector",
-            tower_model="vision_tower",
-        )
+            tower_model=["vision_tower", "audio_tower"])
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
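[Note on PATCH 3/5] Two coupled changes. First, the LoRA linear layer now flattens
the transformers-backend (1, seq_len, hidden_dim) activations into the
(seq_len, hidden_dim) layout punica expects only when the leading batch dim is
exactly 1, leaving genuinely batched 3-D inputs alone. A standalone sketch of that
guard with made-up shapes:

    import torch

    def maybe_flatten(x: torch.Tensor, out: torch.Tensor):
        # Collapse only the unambiguous one-element batch case; anything
        # else passes through unchanged.
        if (x.ndim == 3 and x.shape[0] == 1
                and out.ndim == 3 and out.shape[0] == 1):
            return x.flatten(0, 1), out.flatten(0, 1)
        return x, out

    x = torch.randn(1, 128, 2048)
    x_flat, _ = maybe_flatten(x, x.clone())
    assert x_flat.shape == (128, 2048)

Second, `Gemma3nMultimodalEmbedder` now receives the LoRA config and pads its
`VocabParallelEmbedding` by lora_extra_vocab_size * max_loras rows (for example,
256 * 2 = 512 extra rows, with both values made up for illustration), while
`org_num_embeddings` pins weight loading to the original checkpoint vocab size.
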
From 1fed5dd54ab38242dc58ebf9e4900f9aaa9fd223 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Mon, 6 Oct 2025 08:01:05 +0000
Subject: [PATCH 4/5] precommit

Signed-off-by: NickLucche
---
 vllm/lora/layers/base_linear.py          |  8 +++-
 vllm/model_executor/models/gemma3n.py    | 11 +++--
 vllm/model_executor/models/gemma3n_mm.py | 55 ++++++++++++++++++++++--------------
 3 files changed, 50 insertions(+), 24 deletions(-)

diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py
index 410b9f5605cb..20ff51b0da61 100644
--- a/vllm/lora/layers/base_linear.py
+++ b/vllm/lora/layers/base_linear.py
@@ -157,8 +157,12 @@ def apply(
         # In transformers backend, x and output have extra batch dimension like
         # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
         # therefore we need to flatten the batch dimensions.
-        if (x.shape[0] ==1 and x.ndim == 3 and
-            output.shape[0] == 1 and output.ndim == 3):
+        if (
+            x.shape[0] == 1
+            and x.ndim == 3
+            and output.shape[0] == 1
+            and output.ndim == 3
+        ):
             output = output.flatten(0, 1)
             x = x.flatten(0, 1)
 
diff --git a/vllm/model_executor/models/gemma3n.py b/vllm/model_executor/models/gemma3n.py
index 96c9fb161ad9..a274e7a21f59 100644
--- a/vllm/model_executor/models/gemma3n.py
+++ b/vllm/model_executor/models/gemma3n.py
@@ -20,8 +20,8 @@
 
 import torch
 from torch import nn
-
 from transformers.models.gemma3n.configuration_gemma3n import Gemma3nTextConfig
+
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
@@ -53,8 +53,13 @@
 from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
 
 from .interfaces import SupportsLoRA, SupportsQuant
-from .utils import (AutoWeightsLoader, extract_layer_index,
-                    is_pp_missing_parameter, make_layers, maybe_prefix)
+from .utils import (
+    AutoWeightsLoader,
+    extract_layer_index,
+    is_pp_missing_parameter,
+    make_layers,
+    maybe_prefix,
+)
 
 logger = init_logger(__name__)
 
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index 4a826165c5ce..e80b0c0de971 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -55,10 +55,19 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsTranscription)
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
-                    init_vllm_registered_model, maybe_prefix)
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsTranscription,
+)
+from .utils import (
+    AutoWeightsLoader,
+    WeightsMapper,
+    flatten_bn,
+    init_vllm_registered_model,
+    maybe_prefix,
+)
 
 logger = init_logger(__name__)
 
@@ -397,8 +406,11 @@
         self.vocab_size = multimodal_config.vocab_size
         self.text_hidden_size = text_config.hidden_size
 
-        lora_vocab = (lora_config.lora_extra_vocab_size *
-                      (lora_config.max_loras or 1)) if lora_config else 0
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
         self.vocab_size = self.vocab_size + lora_vocab
 
         self.embedding = VocabParallelEmbedding(
@@ -446,7 +458,8 @@
         """  # noqa: E501
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds")
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )
         if inputs_embeds is not None:
             emb_norm = self.soft_embedding_norm(inputs_embeds)
         else:
@@ -457,11 +470,14 @@
         return self.embedding_post_projection_norm(emb_norm_proj)
 
 
-@MULTIMODAL_REGISTRY.register_processor(Gemma3nMultiModalProcessor,
-                                        info=Gemma3nProcessingInfo,
-                                        dummy_inputs=Gemma3nDummyInputsBuilder)
-class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                      SupportsTranscription, SupportsLoRA):
+@MULTIMODAL_REGISTRY.register_processor(
+    Gemma3nMultiModalProcessor,
+    info=Gemma3nProcessingInfo,
+    dummy_inputs=Gemma3nDummyInputsBuilder,
+)
+class Gemma3nForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsTranscription, SupportsLoRA
+):
     merge_by_field_config = True
 
     supported_languages = ISO639_1_SUPPORTED_LANGS
@@ -504,12 +520,12 @@
 
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.audio_tower = AutoModel.from_config(config=config.audio_config)
-        self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
-                                                      config.text_config,
-                                                      self.lora_config)
-        self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
-                                                     config.text_config,
-                                                     self.lora_config)
+        self.embed_vision = Gemma3nMultimodalEmbedder(
+            config.vision_config, config.text_config, self.lora_config
+        )
+        self.embed_audio = Gemma3nMultimodalEmbedder(
+            config.audio_config, config.text_config, self.lora_config
+        )
 
         self.language_model: nn.Module = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -744,7 +760,8 @@ def get_mm_mapping(self) -> MultiModelKeys:
         return MultiModelKeys.from_string_field(
             language_model="language_model",
             connector="multi_modal_projector",
-            tower_model=["vision_tower", "audio_tower"])
+            tower_model=["vision_tower", "audio_tower"],
+        )
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
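[Note on PATCH 4/5] This is the mechanical pre-commit/ruff pass over the earlier
commits; behaviour is unchanged. One piece worth highlighting is the final
`get_mm_mapping` result, which tells the LoRA machinery which submodules are
adapter-eligible language-model weights and which belong to the towers or the
connector. The call below is lifted from the patch itself; the import path is my
assumption about where `MultiModelKeys` lives:

    from vllm.model_executor.models.module_mapping import MultiModelKeys

    keys = MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="multi_modal_projector",
        tower_model=["vision_tower", "audio_tower"],
    )
    # Both towers are now classified as tower modules; PATCH 3 changed
    # tower_model from the single string "vision_tower" to this list.
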
From 7663ed42431b2232b9fc30f2181a185cfbe20620 Mon Sep 17 00:00:00 2001
From: NickLucche
Date: Mon, 13 Oct 2025 17:25:04 +0000
Subject: [PATCH 5/5] single-elem case squeeze

Signed-off-by: NickLucche
---
 vllm/model_executor/models/gemma3n_mm.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index e80b0c0de971..e0dbbf9dfd5b 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -467,6 +467,9 @@ def forward(
             emb_norm = self.hard_embedding_norm(hard_emb)
 
         emb_norm_proj, _ = self.embedding_projection(emb_norm)
+        if emb_norm_proj.ndim == 2:
+            # Restore the one-element batch dim squeezed by the LoRA layer
+            emb_norm_proj = emb_norm_proj.unsqueeze(0)
         return self.embedding_post_projection_norm(emb_norm_proj)
 
 
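[Note on PATCH 5/5] With LoRA enabled, the base linear LoRA layer patched earlier
returns a 2-D (seq_len, hidden_dim) tensor for one-element batches, so the
embedder's projection output loses the batch dim that
`embedding_post_projection_norm` and its callers expect; the added guard restores
it. The round trip, with made-up sizes:

    import torch

    emb = torch.randn(1, 16, 2048)           # (batch=1, seq, multimodal_hidden)
    flat = emb.flatten(0, 1)                  # LoRA path computes on (seq, hidden)
    proj = flat @ torch.randn(2048, 1024)     # projection output comes back 2-D
    assert proj.ndim == 2
    proj = proj.unsqueeze(0)                  # re-add the batch dim
    assert proj.shape == (1, 16, 1024)

End to end, the series should allow serving Gemma 3n with an adapter along the
lines of the following (adapter name and path are placeholders):

    vllm serve google/gemma-3n-E2B-it --enable-lora \
        --lora-modules my-adapter=/path/to/adapter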