
Commit bbf097f

do not flatten input on linear lora layer
Signed-off-by: NickLucche <[email protected]>
1 parent a607d38 · commit bbf097f

File tree

vllm/lora/layers/base_linear.py
vllm/model_executor/models/gemma3n_mm.py

2 files changed: +20 -15 lines changed

vllm/lora/layers/base_linear.py

Lines changed: 2 additions & 1 deletion
@@ -157,7 +157,8 @@ def apply(
         # In transformers backend, x and output have extra batch dimension like
         # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
         # therefore we need to flatten the batch dimensions.
-        if x.ndim == 3 and output.ndim == 3:
+        if (x.shape[0] == 1 and x.ndim == 3 and
+                output.shape[0] == 1 and output.ndim == 3):
             output = output.flatten(0, 1)
             x = x.flatten(0, 1)
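
The guard above now only collapses the leading batch dimension when it is exactly 1, which is the shape the transformers backend produces for a single sequence; a genuinely batched 3-D tensor is left untouched. A minimal standalone sketch of the same check (the helper name `maybe_flatten_for_punica` is ours, not vLLM's):

```python
import torch

def maybe_flatten_for_punica(x: torch.Tensor, output: torch.Tensor):
    """Sketch of the guarded flatten added in this commit.

    Only collapse the leading batch dimension when it is exactly 1, so that
    (1, seq_len, hidden) becomes (seq_len, hidden) while a real batch
    (batch > 1) is left untouched.
    """
    if (x.ndim == 3 and x.shape[0] == 1 and
            output.ndim == 3 and output.shape[0] == 1):
        x = x.flatten(0, 1)
        output = output.flatten(0, 1)
    return x, output

# (1, 4, 8) is flattened to (4, 8); a (2, 4, 8) batch is no longer flattened.
x, out = torch.randn(1, 4, 8), torch.zeros(1, 4, 8)
x2, out2 = maybe_flatten_for_punica(x, out)
assert x2.shape == (4, 8) and out2.shape == (4, 8)

xb, outb = torch.randn(2, 4, 8), torch.zeros(2, 4, 8)
xb2, _ = maybe_flatten_for_punica(xb, outb)
assert xb2.shape == (2, 4, 8)
```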

vllm/model_executor/models/gemma3n_mm.py

Lines changed: 18 additions & 14 deletions
@@ -7,7 +7,6 @@
 import torch
 
 from torch import nn
-
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (
     Gemma3nAudioConfig,
@@ -21,6 +20,7 @@
 
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
+from vllm.config.lora import LoRAConfig
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -55,8 +55,8 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA, SupportsMultiModal,
-                         SupportsTranscription)
+from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
+                         SupportsMultiModal, SupportsTranscription)
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix)

@@ -387,6 +387,7 @@ def __init__(
         self,
         multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
         text_config: Gemma3nTextConfig,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()

@@ -396,9 +397,14 @@ def __init__(
         self.vocab_size = multimodal_config.vocab_size
         self.text_hidden_size = text_config.hidden_size
 
+        lora_vocab = (lora_config.lora_extra_vocab_size *
+                      (lora_config.max_loras or 1)) if lora_config else 0
+        self.vocab_size = self.vocab_size + lora_vocab
+
         self.embedding = VocabParallelEmbedding(
             self.vocab_size,
             self.multimodal_hidden_size,
+            org_num_embeddings=multimodal_config.vocab_size,
         )
 
         self.hard_embedding_norm = RMSNorm(
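
The new lines above reserve extra embedding rows for tokens added by LoRA adapters: `lora_extra_vocab_size * max_loras` rows are appended to the embedder's vocabulary, while `org_num_embeddings` keeps the original size so the base checkpoint weights still map onto the first rows. A rough sketch of the arithmetic with made-up numbers (a plain `nn.Embedding` stands in for `VocabParallelEmbedding`):

```python
from torch import nn

# Illustrative values only; the real ones come from the model and LoRA configs.
base_vocab_size = 262_400        # multimodal_config.vocab_size
hidden_size = 2_048              # multimodal_hidden_size
lora_extra_vocab_size = 256      # lora_config.lora_extra_vocab_size
max_loras = 4                    # lora_config.max_loras

# Same arithmetic as the diff: reserve room for every concurrently loaded adapter.
lora_vocab = lora_extra_vocab_size * (max_loras or 1)
padded_vocab_size = base_vocab_size + lora_vocab   # 262400 + 1024 = 263424

# Stand-in for VocabParallelEmbedding; in vLLM the extra argument
# org_num_embeddings=base_vocab_size marks which rows are backed by checkpoint
# weights and which are reserved for LoRA-added tokens.
embedding = nn.Embedding(padded_vocab_size, hidden_size)
```

Without the `org_num_embeddings` hint, the padded rows would be indistinguishable from the base vocabulary when the checkpoint weights are loaded.
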
@@ -440,9 +446,7 @@ def forward(
         """ # noqa: E501
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds"
-            )
-
+                "You must specify exactly one of input_ids or inputs_embeds")
         if inputs_embeds is not None:
             emb_norm = self.soft_embedding_norm(inputs_embeds)
         else:
@@ -496,15 +500,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
496500
self.quant_config = quant_config
497501
self.multimodal_config = multimodal_config
498502
self.vocab_size = config.text_config.vocab_size
503+
self.lora_config = vllm_config.lora_config
499504

500505
self.vision_tower = AutoModel.from_config(config=config.vision_config)
501506
self.audio_tower = AutoModel.from_config(config=config.audio_config)
502-
self.embed_vision = Gemma3nMultimodalEmbedder(
503-
config.vision_config, config.text_config
504-
)
505-
self.embed_audio = Gemma3nMultimodalEmbedder(
506-
config.audio_config, config.text_config
507-
)
507+
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
508+
config.text_config,
509+
self.lora_config)
510+
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
511+
config.text_config,
512+
self.lora_config)
508513

509514
self.language_model: nn.Module = init_vllm_registered_model(
510515
vllm_config=vllm_config,
@@ -739,8 +744,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
         return MultiModelKeys.from_string_field(
             language_model="language_model",
             connector="multi_modal_projector",
-            tower_model="vision_tower",
-        )
+            tower_model=["vision_tower", "audio_tower"])
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
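
With this change `get_mm_mapping()` reports both encoder towers, so LoRA machinery that consults the mapping handles `audio_tower` the same way it already handled `vision_tower`. A usage sketch (the import path is an assumption based on how other vLLM model files reference this class):

```python
from vllm.model_executor.models.module_mapping import MultiModelKeys

# tower_model accepts a single name or a list of names; listing both towers
# lets LoRA filtering cover the audio encoder as well as the vision encoder.
mm_keys = MultiModelKeys.from_string_field(
    language_model="language_model",
    connector="multi_modal_projector",
    tower_model=["vision_tower", "audio_tower"],
)
print(mm_keys.tower_model)
```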
