docs/models/supported_models.md (2 additions & 2 deletions)

@@ -366,7 +366,7 @@ th {
 | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
+| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | ✅︎ | | ✅︎ |
 | `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4MoeForCausalLM` | GLM-4.5, GLM-4.6 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |

@@ -671,7 +671,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I<sup>+</sup>/ V<sup>+</sup> | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ | ✅︎ |
 | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
 | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
-| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
+| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | ✅︎ | | ✅︎ |
 | `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |
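The flipped check mark is the LoRA column of the support matrix, matching the code changes below. A minimal sketch of what this enables through vLLM's public LoRA API; the adapter path is a placeholder, not something from this PR:

# Hypothetical usage once Gemma 3n advertises LoRA support.
# "path/to/gemma3n-lora" is a placeholder adapter directory.
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

llm = LLM(model="google/gemma-3n-E2B-it", enable_lora=True)
outputs = llm.generate(
    ["Summarize LoRA in one sentence."],
    SamplingParams(temperature=0.0, max_tokens=48),
    lora_request=LoRARequest("gemma3n-adapter", 1, "path/to/gemma3n-lora"),
)
print(outputs[0].outputs[0].text)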
vllm/lora/layers/base_linear.py (6 additions & 1 deletion)

@@ -157,7 +157,12 @@ def apply(
         # In transformers backend, x and output have extra batch dimension like
         # (1, seq_len, hidden_dim), while punica expects (seq_len, hidden_dim),
         # therefore we need to flatten the batch dimensions.
-        if x.ndim == 3 and output.ndim == 3:
+        if (
+            x.ndim == 3
+            and x.shape[0] == 1
+            and output.ndim == 3
+            and output.shape[0] == 1
+        ):
             output = output.flatten(0, 1)
             x = x.flatten(0, 1)
 
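A standalone sketch of the shape contract being tightened here (plain PyTorch, not part of the diff): the guard now only collapses a genuine one-element batch dimension, and the embedder change further down restores it after projection.

# Sketch: the transformers backend yields (1, seq_len, hidden_dim), while
# punica kernels want (seq_len, hidden_dim), so the size-1 batch dim is
# flattened away and later restored with unsqueeze(0).
import torch

x = torch.randn(1, 5, 8)                     # (batch=1, seq_len, hidden_dim)
if x.ndim == 3 and x.shape[0] == 1:
    x_flat = x.flatten(0, 1)                 # -> (5, 8)
assert x_flat.shape == (5, 8)
assert torch.equal(x_flat.unsqueeze(0), x)   # round-trips losslessly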
vllm/model_executor/models/gemma3n.py (2 additions & 2 deletions)

@@ -52,7 +52,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.v1.attention.backends.utils import KVSharingFastPrefillMetadata
 
-from .interfaces import SupportsQuant
+from .interfaces import SupportsLoRA, SupportsQuant
 from .utils import (
     AutoWeightsLoader,
     extract_layer_index,
@@ -1081,7 +1081,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         return loaded_params
 
 
-class Gemma3nForCausalLM(nn.Module):
+class Gemma3nForCausalLM(nn.Module, SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
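`SupportsLoRA` is a marker interface, so the visible effect is that registry-level LoRA checks start passing; `packed_modules_mapping` then tells the LoRA manager how per-projection adapter weights fold into fused layers. A sketch, assuming vLLM's `supports_lora` helper in `.interfaces`:

# Sketch: supports_lora() is vLLM's runtime check for the marker interface;
# the mapping routes q/k/v adapter weights into the fused qkv_proj layer.
from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
from vllm.model_executor.models.interfaces import supports_lora

assert supports_lora(Gemma3nForCausalLM)
print(Gemma3nForCausalLM.packed_modules_mapping["qkv_proj"])
# -> ["q_proj", "k_proj", "v_proj"]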
vllm/model_executor/models/gemma3n_mm.py (24 additions & 6 deletions)

@@ -20,6 +20,7 @@
 
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
+from vllm.config.lora import LoRAConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -54,7 +55,12 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsLoRA,
+    SupportsMultiModal,
+    SupportsTranscription,
+)
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
@@ -390,6 +396,7 @@ def __init__(
         self,
         multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
         text_config: Gemma3nTextConfig,
+        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
 
@@ -399,9 +406,17 @@ def __init__(
         self.vocab_size = multimodal_config.vocab_size
         self.text_hidden_size = text_config.hidden_size
 
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        self.vocab_size += lora_vocab
+
         self.embedding = VocabParallelEmbedding(
             self.vocab_size,
             self.multimodal_hidden_size,
+            org_num_embeddings=multimodal_config.vocab_size,
         )
 
         self.hard_embedding_norm = RMSNorm(
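The extra rows reserved above mirror vLLM's handling of the text vocab under LoRA: `lora_extra_vocab_size` rows per adapter slot are appended so adapters that introduce new tokens have somewhere to land, while `org_num_embeddings` keeps checkpoint loading keyed to the original vocab size. A worked example with illustrative numbers (256 is vLLM's default `lora_extra_vocab_size`; the base vocab is a placeholder, not Gemma 3n's real value):

# Illustrative arithmetic only; values are assumptions, not from the PR.
lora_extra_vocab_size = 256      # vLLM's default
max_loras = 4
base_vocab = 262_144             # placeholder embedder vocab size

lora_vocab = lora_extra_vocab_size * (max_loras or 1)   # 1024
padded_vocab = base_vocab + lora_vocab                  # 263168
print(padded_vocab)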
@@ -445,14 +460,16 @@ def forward(
             raise ValueError(
                 "You must specify exactly one of input_ids or inputs_embeds"
             )
-
         if inputs_embeds is not None:
             emb_norm = self.soft_embedding_norm(inputs_embeds)
         else:
             hard_emb = self.embedding(input_ids - self.vocab_offset)
             emb_norm = self.hard_embedding_norm(hard_emb)
 
         emb_norm_proj, _ = self.embedding_projection(emb_norm)
+        if emb_norm_proj.ndim == 2:
+            # The LoRA path flattened the one-element batch dim; restore it.
+            emb_norm_proj = emb_norm_proj.unsqueeze(0)
         return self.embedding_post_projection_norm(emb_norm_proj)
 
 
@@ -462,7 +479,7 @@ def forward(
     dummy_inputs=Gemma3nDummyInputsBuilder,
 )
 class Gemma3nForConditionalGeneration(
-    nn.Module, SupportsMultiModal, SupportsTranscription
+    nn.Module, SupportsMultiModal, SupportsTranscription, SupportsLoRA
 ):
     merge_by_field_config = True
     supported_languages = ISO639_1_SUPPORTED_LANGS
@@ -502,14 +519,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
         self.vocab_size = config.text_config.vocab_size
+        self.lora_config = vllm_config.lora_config
 
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.audio_tower = AutoModel.from_config(config=config.audio_config)
         self.embed_vision = Gemma3nMultimodalEmbedder(
-            config.vision_config, config.text_config
+            config.vision_config, config.text_config, self.lora_config
         )
         self.embed_audio = Gemma3nMultimodalEmbedder(
-            config.audio_config, config.text_config
+            config.audio_config, config.text_config, self.lora_config
         )
 
         self.language_model: nn.Module = init_vllm_registered_model(
@@ -745,7 +763,7 @@ def get_mm_mapping(self) -> MultiModelKeys:
         return MultiModelKeys.from_string_field(
             language_model="language_model",
             connector="multi_modal_projector",
-            tower_model="vision_tower",
+            tower_model=["vision_tower", "audio_tower"],
         )
 
     @classmethod
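The `get_mm_mapping` fix matters for LoRA because vLLM uses these keys to decide which submodules are adapter-targetable language-model weights versus connector and tower weights; listing both towers now excludes the audio tower as well. A sketch of the resulting mapping, assuming `MultiModelKeys` lives in `vllm.model_executor.models.module_mapping`:

# Sketch: tower_model accepts a list, so both encoders are filtered out
# of LoRA module matching, not just the vision tower.
from vllm.model_executor.models.module_mapping import MultiModelKeys

keys = MultiModelKeys.from_string_field(
    language_model="language_model",
    connector="multi_modal_projector",
    tower_model=["vision_tower", "audio_tower"],
)
print(keys.tower_model)   # ['vision_tower', 'audio_tower']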