Skip to content

Commit a8d07ac

Browse files
committed
do not flatten input on linear lora layer
Signed-off-by: NickLucche <[email protected]>
1 parent bc27552 commit a8d07ac

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

vllm/lora/layers/base_linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,8 @@ def apply(self,
143143
# In transformers backend, x and output have extra batch dimension like
144144
# (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
145145
# therefore we need to flatten the batch dimensions.
146-
if x.ndim == 3 and output.ndim == 3:
146+
if (x.shape[0] ==1 and x.ndim == 3 and
147+
output.shape[0] == 1 and output.ndim == 3):
147148
output = output.flatten(0, 1)
148149
x = x.flatten(0, 1)
149150

vllm/model_executor/models/gemma3n_mm.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import torch
88
# yapf: disable
99
from torch import nn
10-
1110
from transformers import AutoModel, BatchFeature
1211
from transformers.models.gemma3n import (Gemma3nAudioConfig,
1312
Gemma3nAudioFeatureExtractor,
@@ -18,6 +17,7 @@
1817

1918
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
2019
from vllm.config.multimodal import BaseDummyOptions
20+
from vllm.config.lora import LoRAConfig
2121
from vllm.inputs.data import PromptType
2222
from vllm.logger import init_logger
2323
from vllm.model_executor.layers.layernorm import RMSNorm
@@ -45,8 +45,8 @@
4545
from vllm.sequence import IntermediateTensors
4646
from vllm.utils.tensor_schema import TensorSchema, TensorShape
4747

48-
from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal,
49-
SupportsTranscription)
48+
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
49+
SupportsMultiModal, SupportsTranscription)
5050
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
5151
init_vllm_registered_model, maybe_prefix)
5252

@@ -373,6 +373,7 @@ def __init__(
373373
self,
374374
multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
375375
text_config: Gemma3nTextConfig,
376+
lora_config: Optional[LoRAConfig] = None,
376377
):
377378
super().__init__()
378379

@@ -382,9 +383,14 @@ def __init__(
382383
self.vocab_size = multimodal_config.vocab_size
383384
self.text_hidden_size = text_config.hidden_size
384385

386+
lora_vocab = (lora_config.lora_extra_vocab_size *
387+
(lora_config.max_loras or 1)) if lora_config else 0
388+
self.vocab_size = self.vocab_size + lora_vocab
389+
385390
self.embedding = VocabParallelEmbedding(
386391
self.vocab_size,
387392
self.multimodal_hidden_size,
393+
org_num_embeddings=multimodal_config.vocab_size,
388394
)
389395

390396
self.hard_embedding_norm = RMSNorm(
@@ -427,7 +433,6 @@ def forward(
427433
if (input_ids is None) ^ (inputs_embeds is not None):
428434
raise ValueError(
429435
"You must specify exactly one of input_ids or inputs_embeds")
430-
431436
if inputs_embeds is not None:
432437
emb_norm = self.soft_embedding_norm(inputs_embeds)
433438
else:
@@ -480,13 +485,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
480485
self.quant_config = quant_config
481486
self.multimodal_config = multimodal_config
482487
self.vocab_size = config.text_config.vocab_size
488+
self.lora_config = vllm_config.lora_config
483489

484490
self.vision_tower = AutoModel.from_config(config=config.vision_config)
485491
self.audio_tower = AutoModel.from_config(config=config.audio_config)
486492
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
487-
config.text_config)
493+
config.text_config,
494+
self.lora_config)
488495
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
489-
config.text_config)
496+
config.text_config,
497+
self.lora_config)
490498

491499
self.language_model: nn.Module = init_vllm_registered_model(
492500
vllm_config=vllm_config,
@@ -703,7 +711,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
703711
return MultiModelKeys.from_string_field(
704712
language_model="language_model",
705713
connector="multi_modal_projector",
706-
tower_model="vision_tower")
714+
tower_model=["vision_tower", "audio_tower"])
707715

708716
@classmethod
709717
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:

0 commit comments

Comments (0)