
Commit c575232

[Model] Support Llama4 in vLLM (#16104)
1 parent 63375f0 commit c575232

File tree

35 files changed: +2369 −142 lines


benchmarks/kernels/benchmark_moe.py

Lines changed: 3 additions & 0 deletions
@@ -553,6 +553,9 @@ def main(args: argparse.Namespace):
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
     else:
+        if not hasattr(config, "hidden_size"):
+            # Support for llama4
+            config = config.text_config
         # Default: Mixtral.
         E = config.num_local_experts
         topk = config.num_experts_per_tok
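
The guard above is needed because the Llama 4 Hugging Face config nests its language-model settings under `text_config`, so the top-level config object has no `hidden_size` and no MoE fields. A minimal sketch of that lookup (illustrative only, not part of the commit; requires the transformers >= 4.51.0 bump made later in this PR, and the checkpoint name is the one used elsewhere in the diff):

from transformers import AutoConfig

# Llama 4 wraps separate text/vision sub-configs, so MoE fields such as
# num_local_experts and num_experts_per_tok live on the text sub-config.
config = AutoConfig.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct")
if not hasattr(config, "hidden_size"):
    config = config.text_config  # same fallback as benchmark_moe.py above
print(config.num_local_experts, config.num_experts_per_tok)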

docs/source/models/supported_models.md

Lines changed: 13 additions & 6 deletions
@@ -24,7 +24,7 @@ vLLM also supports model implementations that are available in Transformers. Thi
 
 To check if the modeling backend is Transformers, you can simply do this:
 
-```python
+```python
 from vllm import LLM
 llm = LLM(model=..., task="generate") # Name or path of your model
 llm.apply_model(lambda model: print(type(model)))
@@ -55,7 +55,7 @@ If your model is neither supported natively by vLLM or Transformers, you can sti
 Simply set `trust_remote_code=True` and vLLM will run any model on the Model Hub that is compatible with Transformers.
 Provided that the model writer implements their model in a compatible way, this means that you can run new models before they are officially supported in Transformers or vLLM!
 
-```python
+```python
 from vllm import LLM
 llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
 llm.apply_model(lambda model: print(model.__class__))
@@ -840,6 +840,13 @@ See [this page](#generative-models) for more information on how to use generativ
   *
   * ✅︎
   * ✅︎
+- * `Llama4ForConditionalGeneration`
+  * Llama-4-17B-Omni-Instruct
+  * T + I<sup>+</sup>
+  * `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc.
+  *
+  *
+  * ✅︎
 - * `LlavaForConditionalGeneration`
   * LLaVA-1.5
   * T + I<sup>E+</sup>
@@ -982,10 +989,10 @@ See [this page](#generative-models) for more information on how to use generativ
   * ✅︎
 :::
 
-<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
-&nbsp;&nbsp;&nbsp;&nbsp;• For example, to use DeepSeek-VL2 series models:
-&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;`--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'`
-<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
+<sup>^</sup> You need to set the architecture name via `--hf-overrides` to match the one in vLLM.
+&nbsp;&nbsp;&nbsp;&nbsp;• For example, to use DeepSeek-VL2 series models:
+&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;&nbsp;`--hf-overrides '{"architectures": ["DeepseekVLV2ForCausalLM"]}'`
+<sup>E</sup> Pre-computed embeddings can be inputted for this modality.
 <sup>+</sup> Multiple items can be inputted per text prompt for this modality.
 
 :::{important}
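
With `Llama4ForConditionalGeneration` added to the supported-models table, the new checkpoints load the same way the snippets earlier in this doc show. A hedged sketch (not part of the diff; the parallelism and context-length settings are borrowed from the examples and tests later in this commit, and may need adjusting for your hardware):

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    task="generate",
    tensor_parallel_size=8,  # the examples/tests in this commit use 8-way TP
    max_model_len=8192,
)
# Prints the resolved vLLM model class for the loaded checkpoint.
llm.apply_model(lambda model: print(type(model)))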

examples/offline_inference/vision_language.py

Lines changed: 37 additions & 0 deletions
@@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
     )
 
 
+def run_llama4(questions: list[str], modality: str):
+    assert modality == "image"
+
+    model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=4,
+        tensor_parallel_size=8,
+        disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
+        gpu_memory_utilization=0.4,
+    )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    messages = [[{
+        "role":
+        "user",
+        "content": [{
+            "type": "image"
+        }, {
+            "type": "text",
+            "text": f"{question}"
+        }]
+    }] for question in questions]
+    prompts = tokenizer.apply_chat_template(messages,
+                                            add_generation_prompt=True,
+                                            tokenize=False)
+    stop_token_ids = None
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompts=prompts,
+        stop_token_ids=stop_token_ids,
+    )
+
+
 # Molmo
 def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
     assert modality == "image"
@@ -907,6 +943,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
     "minicpmv": run_minicpmv,
     "mistral3": run_mistral3,
     "mllama": run_mllama,
+    "llama4": run_llama4,
     "molmo": run_molmo,
     "NVLM_D": run_nvlm_d,
    "paligemma": run_paligemma,

examples/offline_inference/vision_language_multi_image.py

Lines changed: 38 additions & 0 deletions
@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
     )
 
 
+def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+    engine_args = EngineArgs(
+        model=model_name,
+        max_model_len=8192,
+        max_num_seqs=4,
+        tensor_parallel_size=8,
+        limit_mm_per_prompt={"image": len(image_urls)},
+    )
+
+    placeholders = [{"type": "image", "image": url} for url in image_urls]
+    messages = [{
+        "role":
+        "user",
+        "content": [
+            *placeholders,
+            {
+                "type": "text",
+                "text": question
+            },
+        ],
+    }]
+
+    processor = AutoProcessor.from_pretrained(model_name)
+
+    prompt = processor.apply_chat_template(messages,
+                                           tokenize=False,
+                                           add_generation_prompt=True)
+
+    return ModelRequestData(
+        engine_args=engine_args,
+        prompt=prompt,
+        image_data=[fetch_image(url) for url in image_urls],
+    )
+
+
 def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
     model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
 
@@ -567,6 +604,7 @@ def load_qwen2_5_vl(question: str, image_urls: list[str]) -> ModelRequestData:
     "h2ovl_chat": load_h2ovl,
     "idefics3": load_idefics3,
    "internvl_chat": load_internvl,
+    "llama4": load_llama4,
     "mistral3": load_mistral3,
     "mllama": load_mllama,
     "NVLM_D": load_nvlm_d,

requirements/common.txt

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ requests >= 2.26.0
 tqdm
 blake3
 py-cpuinfo
-transformers >= 4.50.3
+transformers >= 4.51.0
 huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
 tokenizers >= 0.19.1 # Required for Llama 3.
 protobuf # Required by LlamaTokenizer.

requirements/test.in

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ mistral_common[opencv] >= 1.5.4 # required for pixtral test
 opencv-python-headless >= 4.11.0 # required for video test
 datamodel_code_generator # required for minicpm3 test
 lm-eval[api]==0.4.8 # required for model evaluation test
-transformers==4.50.3
+transformers==4.51.0
 huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
 # quantization
 bitsandbytes>=0.45.3

requirements/test.txt

Lines changed: 1 addition & 1 deletion
@@ -645,7 +645,7 @@ tqdm==4.66.6
     # transformers
 tqdm-multiprocess==0.0.11
     # via lm-eval
-transformers==4.50.3
+transformers==4.51.0
     # via
     #   -r requirements/test.in
     #   genai-perf

tests/models/decoder_only/vision_language/test_models.py

Lines changed: 16 additions & 0 deletions
@@ -536,6 +536,22 @@
             limit_mm_per_prompt={"image": 1},
         )],
     ),
+    "llama4": VLMTestInfo(
+        models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
+        prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501
+        img_idx_to_prompt=lambda _: "<|image|>",
+        test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
+        distributed_executor_backend="mp",
+        image_size_factors=[(.25, 0.5, 1.0)],
+        hf_model_kwargs={"device_map": "auto"},
+        max_model_len=8192,
+        max_num_seqs=4,
+        dtype="bfloat16",
+        auto_cls=AutoModelForImageTextToText,
+        tensor_parallel_size=8,
+        vllm_runner_kwargs={"gpu_memory_utilization": 0.8},
+        marks=[large_gpu_mark(min_gb=80), multi_gpu_marks(num_gpus=8)],
+    ),
 }
 # yapf: enable
 
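
For reference, the `prompt_formatter` and `img_idx_to_prompt` entries above expand a single-image question into the Llama 4 chat format. A tiny illustration, restating those two lambdas as functions (the question text is made up):

def img_idx_to_prompt(_: int) -> str:
    # Same placeholder the test config returns for every image index.
    return "<|image|>"

def prompt_formatter(img_prompt: str) -> str:
    # Same template string the test config wraps around the image prompt.
    return ("<|begin_of_text|><|header_start|>user<|header_end|>\n\n"
            f"{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n")

question = "What is shown in this image?"  # hypothetical question text
print(prompt_formatter(img_idx_to_prompt(0) + question))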

tests/models/multimodal/processing/test_common.py

Lines changed: 1 addition & 0 deletions
@@ -280,6 +280,7 @@ def _test_processing_correctness_mistral(
     "Skywork/Skywork-R1V-38B",
     "fixie-ai/ultravox-v0_5-llama-3_2-1b",
     "openai/whisper-large-v3",
+    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
Lines changed: 99 additions & 0 deletions (new file)
@@ -0,0 +1,99 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Tests for Llama4's multimodal preprocessing kwargs."""
+
+import pytest
+
+from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.transformers_utils.tokenizer import encode_tokens
+
+from ....conftest import _ImageAssets
+from ...utils import build_model_context
+
+
+@pytest.mark.parametrize("model_id",
+                         ["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
+@pytest.mark.parametrize("mm_processor_kwargs", [{}])
+@pytest.mark.parametrize("num_imgs", [1, 5])
+@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
+@pytest.mark.parametrize("tokenized_prompt", [True, False])
+def test_processor_override(
+    image_assets: _ImageAssets,
+    model_id: str,
+    mm_processor_kwargs: dict,
+    num_imgs: int,
+    disable_mm_preprocessor_cache: bool,
+    tokenized_prompt: bool,
+):
+    """Ensure llama4 processor works properly."""
+    ctx = build_model_context(
+        model_id,
+        mm_processor_kwargs=mm_processor_kwargs,
+        limit_mm_per_prompt={"image": num_imgs},
+        disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
+    )
+    processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
+    config = processor.info.get_hf_config()
+    tokenizer = processor.info.get_tokenizer()
+    hf_processor = processor.info.get_hf_processor()
+    vocab = tokenizer.get_vocab()
+
+    prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \
+        + "<|image|>" * num_imgs \
+        + "<|eot|><|header_start|>assistant<|header_end|>"
+    mm_data = {
+        "image": [
+            image_assets[(i % len(image_assets))].pil_image
+            for i in range(num_imgs)
+        ]
+    }
+    if tokenized_prompt:
+        prompt = encode_tokens(tokenizer, prompt)
+
+    processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
+    mm_kwargs = processed_inputs["mm_kwargs"]
+
+    # place holder replacements
+    prompt_token_ids = processed_inputs["prompt_token_ids"]
+    assert prompt_token_ids.count(config.boi_token_index) == num_imgs
+    assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
+    assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
+    aspect_ratios = mm_kwargs["aspect_ratios"]
+    num_x_separators = num_y_separators = 0
+    for tiles_y, tiles_x in aspect_ratios:
+        if tiles_x * tiles_y > 1:
+            num_x_separators += (tiles_x - 1) * tiles_y
+            num_y_separators += tiles_y
+    assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \
+        == num_x_separators
+    assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \
+        == num_y_separators
+
+    # image token offsets
+    img_locs = processed_inputs["mm_placeholders"].get("image", [])
+    assert len(img_locs) == num_imgs
+    assert [img_loc["offset"] for img_loc in img_locs] == \
+        [i for i, v in enumerate(prompt_token_ids) \
+            if v == config.boi_token_index]
+
+    # patch sizes and masks
+    assert prompt_token_ids.count(config.image_token_index) \
+        == sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"])
+    patch_token_id = vocab[hf_processor.img_patch_token]
+    num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id)
+    mm_counts = {"image": num_imgs}
+    assert num_patches / num_imgs <= \
+        processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"]
+    num_patches_per_chunk = processor.info.get_patch_per_chunk(
+        config.vision_config)
+    assert prompt_token_ids.count(config.image_token_index) \
+        == mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
+    assert mm_kwargs["pixel_values"].shape[0] \
+        == mm_kwargs["patches_per_image"].sum()
+
+    for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"],
+                                            mm_kwargs["aspect_ratios"]):
+        assert embed_is_patch.shape[0] == \
+            len(tokenizer.encode(
+                hf_processor._prompt_split_image(
+                    aspect_ratio, num_patches_per_chunk),
+                add_special_tokens=False))
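
The tile-separator assertions in the test above can be hand-checked: per its counting logic, an image split into tiles_y x tiles_x chunks contributes (tiles_x - 1) * tiles_y tile tokens and tiles_y global-tile tokens whenever it is tiled at all. A small illustrative helper (not part of the test file):

def expected_separators(aspect_ratios):
    """Mirror the tile / tile-global counting used in the test above."""
    num_x_separators = num_y_separators = 0
    for tiles_y, tiles_x in aspect_ratios:
        if tiles_x * tiles_y > 1:
            num_x_separators += (tiles_x - 1) * tiles_y
            num_y_separators += tiles_y
    return num_x_separators, num_y_separators

# One image tiled 2x2 plus one kept as a single global chunk (made-up values).
print(expected_separators([(2, 2), (1, 1)]))  # -> (2, 2)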
