
Commit 4f1c350

do not flatten input on linear lora layer
Signed-off-by: NickLucche <[email protected]>
1 parent f81bf71 commit 4f1c350

2 files changed: +17 −8 lines


vllm/lora/layers/base_linear.py

Lines changed: 2 additions & 1 deletion
@@ -143,7 +143,8 @@ def apply(self,
         # In transformers backend, x and output have extra batch dimension like
         # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
         # therefore we need to flatten the batch dimensions.
-        if x.ndim == 3 and output.ndim == 3:
+        if (x.shape[0] == 1 and x.ndim == 3 and
+                output.shape[0] == 1 and output.ndim == 3):
             output = output.flatten(0, 1)
             x = x.flatten(0, 1)

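For context, a minimal standalone sketch of the guarded flattening (plain PyTorch, not the vLLM layer itself; the helper name is illustrative): the leading batch dimension is only collapsed when it is 1, so a genuine multi-sequence batch passes through unchanged.

```python
import torch

def maybe_flatten(x: torch.Tensor, output: torch.Tensor):
    # Collapse the leading batch dim only when it is 1, mirroring the guard above.
    if (x.shape[0] == 1 and x.ndim == 3 and
            output.shape[0] == 1 and output.ndim == 3):
        return x.flatten(0, 1), output.flatten(0, 1)
    return x, output

# (1, seq_len, hidden) collapses to (seq_len, hidden), the layout punica expects.
x1, o1 = maybe_flatten(torch.randn(1, 5, 16), torch.zeros(1, 5, 32))
assert x1.shape == (5, 16) and o1.shape == (5, 32)

# A real batch (leading dim > 1) is left untouched.
x2, o2 = maybe_flatten(torch.randn(2, 5, 16), torch.zeros(2, 5, 32))
assert x2.shape == (2, 5, 16) and o2.shape == (2, 5, 32)
```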
vllm/model_executor/models/gemma3n_mm.py

Lines changed: 15 additions & 7 deletions
@@ -7,7 +7,6 @@
 import torch
 # yapf: disable
 from torch import nn
-
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (Gemma3nAudioConfig,
                                          Gemma3nAudioFeatureExtractor,
@@ -17,6 +16,7 @@
 from transformers.models.siglip import SiglipImageProcessorFast
 
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
+from vllm.config.lora import LoRAConfig
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -44,8 +44,8 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal,
-                         SupportsTranscription)
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+                         SupportsMultiModal, SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)
 
@@ -365,6 +365,7 @@ def __init__(
         self,
         multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
         text_config: Gemma3nTextConfig,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
 
@@ -374,9 +375,14 @@ def __init__(
         self.vocab_size = multimodal_config.vocab_size
         self.text_hidden_size = text_config.hidden_size
 
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = self.vocab_size + lora_vocab
+
         self.embedding = VocabParallelEmbedding(
             self.vocab_size,
             self.multimodal_hidden_size,
+            org_num_embeddings=multimodal_config.vocab_size,
         )
 
         self.hard_embedding_norm = RMSNorm(
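A back-of-the-envelope sketch of the vocab padding introduced above; the concrete numbers are placeholders, not Gemma3n's actual config values:

```python
# Hypothetical numbers, only to illustrate the padding arithmetic above.
base_vocab_size = 262_144          # stands in for multimodal_config.vocab_size
lora_extra_vocab_size = 256        # stands in for lora_config.lora_extra_vocab_size
max_loras = 4                      # stands in for lora_config.max_loras

lora_vocab = lora_extra_vocab_size * (max_loras or 1)      # 1_024
padded_vocab_size = base_vocab_size + lora_vocab           # 263_168

# The embedding table is allocated with the padded size, while
# org_num_embeddings records the original vocab size so the checkpoint
# weights keep mapping onto the first base_vocab_size rows and the padded
# tail stays reserved for LoRA-added tokens.
```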
@@ -419,7 +425,6 @@ def forward(
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
                 "You must specify exactly one of input_ids or inputs_embeds")
-
         if inputs_embeds is not None:
             emb_norm = self.soft_embedding_norm(inputs_embeds)
         else:
@@ -472,13 +477,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
         self.vocab_size = config.text_config.vocab_size
+        self.lora_config = vllm_config.lora_config
 
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.audio_tower = AutoModel.from_config(config=config.audio_config)
         self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
-                                                      config.text_config)
+                                                      config.text_config,
+                                                      self.lora_config)
         self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
-                                                     config.text_config)
+                                                     config.text_config,
+                                                     self.lora_config)
 
         self.language_model: nn.Module = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -695,7 +703,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
         return MultiModelKeys.from_string_field(
             language_model="language_model",
             connector="multi_modal_projector",
-            tower_model="vision_tower")
+            tower_model=["vision_tower", "audio_tower"])
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:

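For reference, a standalone sketch of the updated multimodal mapping; the import path is an assumption (MultiModelKeys lives among vLLM's model module-mapping helpers), only the call itself is taken from the diff:

```python
from vllm.model_executor.models.module_mapping import MultiModelKeys

# Both towers are now reported, so consumers of this mapping (e.g. LoRA
# target-module resolution) treat the audio tower like the vision tower.
mapping = MultiModelKeys.from_string_field(
    language_model="language_model",
    connector="multi_modal_projector",
    tower_model=["vision_tower", "audio_tower"],
)
```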