
Commit 19a00eb

[Model] Use merge_by_field_config for MM models (Llava family) (#26280)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 391612e commit 19a00eb

File tree

9 files changed (+99, -173 lines)
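Context for the diffs below: with merge_by_field_config = True, vLLM's multimodal input pipeline hands the model each keyword field (such as pixel_values) already merged across the batch, so per-model code can drop its flatten_bn calls and manual isinstance checks. A minimal sketch of what that merging amounts to, assuming each request contributes an image tensor of shape [num_images, C, H, W]; the helper name flatten_bn_sketch and the shapes are illustrative, not vLLM's actual implementation:

# Illustrative only: what "merging by field" means for a batched multimodal
# kwarg such as `pixel_values` (not vLLM's real code).
import torch


def flatten_bn_sketch(per_request: list[torch.Tensor]) -> torch.Tensor:
    # Concatenate per-request tensors of shape [num_images_i, C, H, W]
    # into one [total_images, C, H, W] tensor.
    return torch.cat(per_request, dim=0)


# A batch of two requests: the first carries 2 images, the second 1.
batch = [torch.randn(2, 3, 336, 336), torch.randn(1, 3, 336, 336)]
merged = flatten_bn_sketch(batch)
print(merged.shape)  # torch.Size([3, 3, 336, 336])

With the flag set, the model only ever sees the merged tensor, which is why the Llava parsers in this commit shrink to simple pass-throughs.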


examples/offline_inference/vision_language_multi_image.py

Lines changed: 53 additions & 53 deletions
@@ -371,13 +371,14 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
     )
 
 
-def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
-    model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "Kwai-Keye/Keye-VL-8B-Preview"
 
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=131072,
-        tensor_parallel_size=8,
+        trust_remote_code=True,
+        max_model_len=8192,
+        max_num_seqs=5,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
@@ -389,29 +390,32 @@ def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
                 *placeholders,
                 {"type": "text", "text": question},
             ],
-        }
+        },
     ]
 
-    processor = AutoProcessor.from_pretrained(model_name)
+    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
 
     prompt = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
 
+    image_data = [fetch_image(url) for url in image_urls]
+
     return ModelRequestData(
         engine_args=engine_args,
         prompt=prompt,
-        image_data=[fetch_image(url) for url in image_urls],
+        image_data=image_data,
     )
 
 
-def load_llava(question: str, image_urls: list[str]) -> ModelRequestData:
-    # NOTE: CAUTION! Original Llava models wasn't really trained on multi-image inputs,
-    # it will generate poor response for multi-image inputs!
-    model_name = "llava-hf/llava-1.5-7b-hf"
+def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "Kwai-Keye/Keye-VL-1_5-8B"
+
     engine_args = EngineArgs(
         model=model_name,
-        max_num_seqs=16,
+        trust_remote_code=True,
+        max_model_len=32768,
+        max_num_seqs=5,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
@@ -423,28 +427,32 @@ def load_llava(question: str, image_urls: list[str]) -> ModelRequestData:
                 *placeholders,
                 {"type": "text", "text": question},
             ],
-        }
+        },
     ]
 
-    processor = AutoProcessor.from_pretrained(model_name)
+    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
 
     prompt = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
 
+    image_data = [fetch_image(url) for url in image_urls]
+
     return ModelRequestData(
         engine_args=engine_args,
         prompt=prompt,
-        image_data=[fetch_image(url) for url in image_urls],
+        image_data=image_data,
     )
 
 
-def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData:
-    model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
+def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "moonshotai/Kimi-VL-A3B-Instruct"
+
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=8192,
-        max_num_seqs=16,
+        trust_remote_code=True,
+        max_model_len=4096,
+        max_num_seqs=4,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
@@ -459,7 +467,7 @@ def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData:
         }
     ]
 
-    processor = AutoProcessor.from_pretrained(model_name)
+    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
 
     prompt = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
@@ -472,12 +480,13 @@ def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData:
     )
 
 
-def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestData:
-    model_name = "llava-hf/llava-onevision-qwen2-7b-ov-hf"
+def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
     engine_args = EngineArgs(
         model=model_name,
-        max_model_len=16384,
-        max_num_seqs=16,
+        max_model_len=131072,
+        tensor_parallel_size=8,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
@@ -505,14 +514,13 @@ def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestDa
     )
 
 
-def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
-    model_name = "Kwai-Keye/Keye-VL-8B-Preview"
-
+def load_llava(question: str, image_urls: list[str]) -> ModelRequestData:
+    # NOTE: CAUTION! Original Llava models wasn't really trained on multi-image inputs,
+    # it will generate poor response for multi-image inputs!
+    model_name = "llava-hf/llava-1.5-7b-hf"
     engine_args = EngineArgs(
         model=model_name,
-        trust_remote_code=True,
-        max_model_len=8192,
-        max_num_seqs=5,
+        max_num_seqs=16,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
@@ -524,32 +532,28 @@ def load_keye_vl(question: str, image_urls: list[str]) -> ModelRequestData:
                 *placeholders,
                 {"type": "text", "text": question},
             ],
-        },
+        }
     ]
 
-    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+    processor = AutoProcessor.from_pretrained(model_name)
 
     prompt = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
 
-    image_data = [fetch_image(url) for url in image_urls]
-
     return ModelRequestData(
         engine_args=engine_args,
         prompt=prompt,
-        image_data=image_data,
+        image_data=[fetch_image(url) for url in image_urls],
     )
 
 
-def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
-    model_name = "Kwai-Keye/Keye-VL-1_5-8B"
-
+def load_llava_next(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "llava-hf/llava-v1.6-mistral-7b-hf"
     engine_args = EngineArgs(
         model=model_name,
-        trust_remote_code=True,
-        max_model_len=32768,
-        max_num_seqs=5,
+        max_model_len=8192,
+        max_num_seqs=16,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
@@ -561,32 +565,28 @@ def load_keye_vl1_5(question: str, image_urls: list[str]) -> ModelRequestData:
                 *placeholders,
                 {"type": "text", "text": question},
             ],
-        },
+        }
     ]
 
-    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+    processor = AutoProcessor.from_pretrained(model_name)
 
     prompt = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
 
-    image_data = [fetch_image(url) for url in image_urls]
-
     return ModelRequestData(
         engine_args=engine_args,
         prompt=prompt,
-        image_data=image_data,
+        image_data=[fetch_image(url) for url in image_urls],
     )
 
 
-def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
-    model_name = "moonshotai/Kimi-VL-A3B-Instruct"
-
+def load_llava_onevision(question: str, image_urls: list[str]) -> ModelRequestData:
+    model_name = "llava-hf/llava-onevision-qwen2-7b-ov-hf"
     engine_args = EngineArgs(
         model=model_name,
-        trust_remote_code=True,
-        max_model_len=4096,
-        max_num_seqs=4,
+        max_model_len=16384,
+        max_num_seqs=16,
         limit_mm_per_prompt={"image": len(image_urls)},
     )
 
@@ -601,7 +601,7 @@ def load_kimi_vl(question: str, image_urls: list[str]) -> ModelRequestData:
         }
     ]
 
-    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+    processor = AutoProcessor.from_pretrained(model_name)
 
     prompt = processor.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
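The reordered load_* helpers above are consumed by driver code elsewhere in this example script. The sketch below paraphrases that usage; the URLs, sampling settings, and the choice of load_llava are placeholders rather than part of this commit:

# Rough usage sketch (paraphrased driver code, not part of this diff).
from dataclasses import asdict

from vllm import LLM, SamplingParams

QUESTION = "What is shown in each image?"
IMAGE_URLS = [
    "https://example.com/image-1.jpg",  # placeholder URLs
    "https://example.com/image-2.jpg",
]

req_data = load_llava(QUESTION, IMAGE_URLS)

llm = LLM(**asdict(req_data.engine_args))
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)

outputs = llm.generate(
    {
        "prompt": req_data.prompt,
        "multi_modal_data": {"image": req_data.image_data},
    },
    sampling_params=sampling_params,
)
print(outputs[0].outputs[0].text)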

vllm/model_executor/models/llava.py

Lines changed: 5 additions & 15 deletions
@@ -57,7 +57,6 @@
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -507,6 +506,8 @@ def init_vision_tower_for_llava(
     dummy_inputs=LlavaDummyInputsBuilder,
 )
 class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -592,37 +593,26 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
-                )
-
             if self.config.vision_config.model_type == "pixtral":
                 return PixtralHFImagePixelInputs(
                     type="pixel_values_pixtral",
-                    pixel_values=flatten_bn(pixel_values),
+                    pixel_values=pixel_values,
                 )
 
             expected_h = expected_w = self.config.vision_config.image_size
             return LlavaImagePixelInputs(
                 type="pixel_values",
-                pixel_values=flatten_bn(pixel_values, concat=True),
+                pixel_values=pixel_values,
                 resolve_bindings={"h": expected_h, "w": expected_w},
             )
 
         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image embeddings. "
-                    f"Got type: {type(image_embeds)}"
-                )
-
             if self.config.vision_config.model_type == "pixtral":
                 raise ValueError("Pixtral-HF does not support image_embeds.")
 
             return LlavaImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds, concat=True),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
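The llava.py change is the whole pattern in miniature: add the class-level flag, then delete the manual flattening and type checks from _parse_and_validate_image_input. A schematic sketch of that opt-in pattern, using an invented ToyMultiModalModel rather than a real vLLM class:

# Schematic opt-in pattern (illustrative class, not real vLLM model code).
from torch import nn


class ToyMultiModalModel(nn.Module):
    # With this flag, vLLM delivers per-field kwargs already merged across the
    # batch, e.g. pixel_values as one [total_images, 3, h, w] tensor instead of
    # a list of per-request batches.
    merge_by_field_config = True

    def _parse_and_validate_image_input(self, **kwargs):
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None
        # No flatten_bn() or isinstance() checks: pass the merged tensor along.
        return {"type": "pixel_values", "pixel_values": pixel_values}

Shape validation stays with the input schema (the resolve_bindings arguments visible in the diff) rather than with hand-written checks in each model.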

vllm/model_executor/models/llava_next.py

Lines changed: 5 additions & 19 deletions
@@ -34,7 +34,6 @@
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -222,6 +221,8 @@ def _get_mm_fields_config(
     dummy_inputs=LlavaDummyInputsBuilder,
 )
 class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             # mapping for new names in checkpoint saved after transformers v4.52
@@ -302,36 +303,21 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            if not isinstance(pixel_values, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of pixel values. Got type: {type(pixel_values)}"
-                )
-
-            if not isinstance(image_sizes, (torch.Tensor, list)):
-                raise ValueError(
-                    f"Incorrect type of image sizes. Got type: {type(image_sizes)}"
-                )
-
             expected_h = expected_w = self.config.vision_config.image_size
             return LlavaNextImagePixelInputs(
                 type="pixel_values",
-                pixel_values=flatten_bn(pixel_values),
-                image_sizes=flatten_bn(image_sizes, concat=True),
+                pixel_values=pixel_values,
+                image_sizes=image_sizes,
                 resolve_bindings={
                     "h": expected_h,
                     "w": expected_w,
                 },
             )
 
         if image_embeds is not None:
-            if not isinstance(image_embeds, torch.Tensor):
-                raise ValueError(
-                    f"Incorrect type of image embeds. Got type: {type(image_embeds)}"
-                )
-
             return LlavaNextImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
