Commit 5a16fa6

NickLucche, ShriKode, and ywang96 authored

[Model] Gemma3n MM (#20495)

Signed-off-by: ShriKode <[email protected]>
Signed-off-by: NickLucche <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Co-authored-by: ShriKode <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
1 parent 2d18256 commit 5a16fa6

11 files changed: +864 −55 lines changed


docs/models/supported_models.md

Lines changed: 11 additions & 4 deletions
@@ -349,7 +349,7 @@ th {
 | `GemmaForCausalLM` | Gemma | `google/gemma-2b`, `google/gemma-1.1-2b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma2ForCausalLM` | Gemma 2 | `google/gemma-2-9b`, `google/gemma-2-27b`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Gemma3ForCausalLM` | Gemma 3 | `google/gemma-3-1b-it`, etc. | ✅︎ | ✅︎ | ✅︎ |
-| `Gemma3nForConditionalGeneration` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
+| `Gemma3nForCausalLM` | Gemma 3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
 | `GlmForCausalLM` | GLM-4 | `zai-org/glm-4-9b-chat-hf`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4ForCausalLM` | GLM-4-0414 | `zai-org/GLM-4-32B-0414`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4MoeForCausalLM` | GLM-4.5 | `zai-org/GLM-4.5`, etc. | ✅︎ | ✅︎ | ✅︎ |

@@ -412,9 +412,6 @@ th {
 !!! note
     Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
 
-!!! note
-    Only text inputs are currently supported for `Gemma3nForConditionalGeneration`. To use this model, please upgrade Hugging Face Transformers to version 4.53.0.
-
 ### Pooling Models
 
 See [this page](./pooling_models.md) for more information on how to use pooling models.

@@ -608,6 +605,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `Florence2ForConditionalGeneration` | Florence-2 | T + I | `microsoft/Florence-2-base`, `microsoft/Florence-2-large`, etc. | | | |
 | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ | ✅︎ |
 | `Gemma3ForConditionalGeneration` | Gemma 3 | T + I<sup>+</sup> | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ | ⚠️ |
+| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | ✅︎ |
 | `GLM4VForCausalLM`<sup>^</sup> | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ | ✅︎ |
 | `Glm4vMoeForConditionalGeneration` | GLM-4.5V | T + I<sup>E+</sup> + V<sup>E+</sup> | `zai-org/GLM-4.5V`, etc. | ✅︎ | ✅︎ | ✅︎ |

@@ -677,6 +675,15 @@ Some models are supported only via the [Transformers backend](#transformers). Th
 
     This limitation exists because the model's mixed attention pattern (bidirectional for images, causal otherwise) is not yet supported by vLLM's attention backends.
 
+!!! note
+    `Gemma3nForConditionalGeneration` is only supported on V1 due to shared KV caching and it depends on `timm>=1.0.17` to make use of its
+    MobileNet-v5 vision backbone.
+
+    Performance is not yet fully optimized mainly due to:
+
+    - Both audio and vision MM encoders use `transformers.AutoModel` implementation.
+    - There's no PLE caching or out-of-memory swapping support, as described in [Google's blog](https://developers.googleblog.com/en/introducing-gemma-3n/). These features might be too model-specific for vLLM, and swapping in particular may be better suited for constrained setups.
+
 !!! note
     Only `InternVLChatModel` with Qwen2.5 text backbone (`OpenGVLab/InternVL3-2B`, `OpenGVLab/InternVL2.5-1B` etc) has video inputs support currently.
 
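
As a quick way to exercise the newly documented entries, here is a minimal text-only sketch (not part of the commit). It assumes a local GPU plus `transformers>=4.53` and `timm>=1.0.17`, per the note above; the prompt format follows the Gemma chat markers used elsewhere in this change.

```python
# Minimal text-only smoke test for the Gemma 3n entries documented above.
# The model name comes from the table; everything else is an illustrative choice.
from vllm import LLM, SamplingParams

llm = LLM(model="google/gemma-3n-E2B-it", max_model_len=2048, enforce_eager=True)

prompt = ("<start_of_turn>user\nWhy is the sky blue?<end_of_turn>\n"
          "<start_of_turn>model\n")
outputs = llm.generate([prompt], SamplingParams(temperature=0.0, max_tokens=64))
print(outputs[0].outputs[0].text)
```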

examples/offline_inference/audio_language.py

Lines changed: 20 additions & 0 deletions
@@ -96,6 +96,25 @@ def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
     )
 
 
+# Gemma3N
+def run_gemma3n(question: str, audio_count: int) -> ModelRequestData:
+    model_name = "google/gemma-3n-E2B-it"
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=2048,
+        max_num_batched_tokens=2048,
+        max_num_seqs=2,
+        limit_mm_per_prompt={"audio": audio_count},
+        enforce_eager=True,
+    )
+    prompt = (f"<start_of_turn>user\n<audio_soft_token>{question}"
+              "<end_of_turn>\n<start_of_turn>model\n")
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+    )
+
+
 # Granite Speech
 def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
     # NOTE - the setting in this example are somehat different than what is

@@ -331,6 +350,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
 
 model_example_map = {
     "voxtral": run_voxtral,
+    "gemma3n": run_gemma3n,
     "granite_speech": run_granite_speech,
     "minicpmo": run_minicpmo,
     "phi4_mm": run_phi4mm,

examples/offline_inference/vision_language.py

Lines changed: 27 additions & 0 deletions
@@ -211,7 +211,33 @@ def run_gemma3(questions: list[str], modality: str) -> ModelRequestData:
         )
         for question in questions
     ]
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+    )
+
 
+# Gemma3N
+def run_gemma3n(questions: list[str], modality: str) -> ModelRequestData:
+    assert modality == "image"
+    model_name = "google/gemma-3n-E2B-it"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=2048,
+        max_num_seqs=2,
+        limit_mm_per_prompt={modality: 1},
+        enforce_eager=True,
+    )
+
+    prompts = [
+        (
+            "<start_of_turn>user\n"
+            f"<image_soft_token>{question}<end_of_turn>\n"
+            "<start_of_turn>model\n"
+        )
+        for question in questions
+    ]
     return ModelRequestData(
         engine_args=engine_args,
         prompts=prompts,

@@ -1395,6 +1421,7 @@ def run_tarsier2(questions: list[str], modality: str) -> ModelRequestData:
     "florence2": run_florence2,
     "fuyu": run_fuyu,
     "gemma3": run_gemma3,
+    "gemma3n": run_gemma3n,
     "glm4v": run_glm4v,
     "glm4_1v": run_glm4_1v,
     "h2ovl_chat": run_h2ovl,

requirements/test.in

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@ ray[cgraph,default]>=2.48.0 # Ray Compiled Graph, required by pipeline paralleli
 sentence-transformers # required for embedding tests
 soundfile # required for audio tests
 jiwer # required for audio tests
-timm # required for internvl test
+timm >=1.0.17 # required for internvl and gemma3n-mm test
 torch==2.7.1
 torchaudio==2.7.1
 torchvision==0.22.1

requirements/test.txt

Lines changed: 1 addition & 1 deletion
@@ -1051,7 +1051,7 @@ tiktoken==0.7.0
     # via
     #   lm-eval
     #   mistral-common
-timm==1.0.15
+timm==1.0.17
     # via
     #   -r requirements/test.in
     #   open-clip-torch

tests/models/multimodal/processing/test_common.py

Lines changed: 4 additions & 1 deletion
@@ -271,6 +271,7 @@ def _test_processing_correctness_one(
     "microsoft/Florence-2-base",
     "adept/fuyu-8b",
     "google/gemma-3-4b-it",
+    "google/gemma-3n-E2B-it",
     "zai-org/glm-4v-9b",
     "zai-org/GLM-4.1V-9B-Thinking",
     "ibm-granite/granite-speech-3.3-2b",

@@ -315,7 +316,7 @@ def _test_processing_correctness_one(
     "fixie-ai/ultravox-v0_5-llama-3_2-1b",
     "openai/whisper-large-v3",
     "omni-research/Tarsier-7b",
-    "omni-research/Tarsier2-Recap-7b"
+    "omni-research/Tarsier2-Recap-7b",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])

@@ -327,6 +328,8 @@ def test_processing_correctness(
     num_batches: int,
     simplify_rate: float,
 ):
+    if model_id == "google/gemma-3n-E2B-it":
+        pytest.skip("Skipping gemma-3n-E2B-it due to transformers #39911 bug.")
     _test_processing_correctness(
         model_id,
         hit_rate=hit_rate,

tests/models/registry.py

Lines changed: 3 additions & 1 deletion
@@ -186,7 +186,7 @@ def check_available_online(
     "GemmaForCausalLM": _HfExamplesInfo("google/gemma-1.1-2b-it"),
     "Gemma2ForCausalLM": _HfExamplesInfo("google/gemma-2-9b"),
     "Gemma3ForCausalLM": _HfExamplesInfo("google/gemma-3-1b-it"),
-    "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it",  # noqa: E501
+    "Gemma3nForCausalLM": _HfExamplesInfo("google/gemma-3n-E2B-it",
                                           min_transformers_version="4.53"),
     "GlmForCausalLM": _HfExamplesInfo("zai-org/glm-4-9b-chat-hf"),
     "Glm4ForCausalLM": _HfExamplesInfo("zai-org/GLM-4-9B-0414"),

@@ -391,6 +391,8 @@ def check_available_online(
     "Emu3ForConditionalGeneration": _HfExamplesInfo("BAAI/Emu3-Chat-hf"),
     "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"),
     "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"),
+    "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it",  # noqa: E501
+                                                       min_transformers_version="4.53"),
     "GraniteSpeechForConditionalGeneration": _HfExamplesInfo("ibm-granite/granite-speech-3.3-2b"),  # noqa: E501
     "GLM4VForCausalLM": _HfExamplesInfo("zai-org/glm-4v-9b",
                                         trust_remote_code=True,

tests/test_test.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+
+from vllm import LLM, envs
+from vllm.sampling_params import SamplingParams
+
+if not envs.VLLM_USE_V1:
+    pytest.skip(
+        "Skipping V1 tests. Rerun with `VLLM_USE_V1=1` to test.",
+        allow_module_level=True,
+    )
+
+
+@pytest.mark.parametrize("model_name", ["Qwen/Qwen2.5-1.5B-Instruct"])
+# TODO TPU will appear busy if we fan-out test params here
+@pytest.mark.parametrize("n_prompts", [1])
+def test_logprobs(model_name: str, n_prompts: int):
+    """
+    Request top logprobs with different sampling settings and check
+    that results contain the requested number, in ascending rank order.
+    """
+
+    def check_num_logprobs(logprobs, expected_num: int):
+        for step in logprobs:
+            prev_logp = 1.0
+            # order by rank
+            sorted_step = dict(
+                sorted(step.items(), key=lambda item: item[1].rank))
+
+            if len(step) != expected_num:
+                print("watch out", sorted_step)
+
+            # check results are ordered by prob value
+            # assert len(step) == expected_num
+            for rankno, (tid, logp) in enumerate(sorted_step.items()):
+                assert logp.logprob <= prev_logp
+                prev_logp = logp.logprob
+                assert logp.rank == rankno + 1
+
+    llm = LLM(model_name,
+              enforce_eager=False,
+              max_num_seqs=1,
+              max_model_len=128,
+              max_num_batched_tokens=128)
+    prompts = [
+        "Write a short story about a robot that dreams for the first time."
+    ] * n_prompts
+    greedy_sampling_params = SamplingParams(temperature=0.0, max_tokens=64,
+                                            logprobs=4)
+    regular_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,
+                                             logprobs=4)
+    topkp_sampling_params = SamplingParams(temperature=0.4, max_tokens=64,
+                                           logprobs=4, top_k=12, top_p=0.5)
+
+    for sp in [greedy_sampling_params, regular_sampling_params,
+               topkp_sampling_params]:
+        output = llm.generate(prompts, sp)
+        for o in output:
+            check_num_logprobs(o.outputs[0].logprobs, 4)
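
For context on the structure this test iterates over: with `logprobs=N` requested, each element of `output.outputs[0].logprobs` is a dict mapping token id to a `Logprob` object carrying `logprob`, `rank`, and `decoded_token`. A minimal inspection sketch (model, prompt, and settings here are arbitrary, not taken from the commit):

```python
# Sketch of inspecting per-step sampled logprobs; assumes the Logprob fields
# (logprob, rank, decoded_token) relied on by the test above.
from vllm import LLM, SamplingParams

llm = LLM("Qwen/Qwen2.5-1.5B-Instruct", max_model_len=128, max_num_seqs=1)
out = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=4, logprobs=4),
)[0].outputs[0]

# One dict per generated token, keyed by token id; iterate in rank order.
for step in out.logprobs:
    for token_id, lp in sorted(step.items(), key=lambda kv: kv[1].rank):
        print(token_id, lp.rank, round(lp.logprob, 3), repr(lp.decoded_token))
```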
