
Commit 08d26a1

Authored by Isotr0py
[Model] Use merge_by_field_config for MM models (Ovis family) (#26308)
Signed-off-by: Isotr0py <[email protected]>
1 parent 63773a6 commit 08d26a1

File tree: 5 files changed, +80 −75 lines


examples/offline_inference/vision_language.py

Lines changed: 2 additions & 6 deletions
@@ -1140,14 +1140,10 @@ def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData:
     elif modality == "video":
         placeholder = "<video>"

-    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-    messages = [
-        [{"role": "user", "content": f"{placeholder}\n{question}"}]
+    prompts = [
+        f"<|im_start|>user\n\n{placeholder}\n{question}<|im_end|>\n<|im_start|>assistant\n"
         for question in questions
     ]
-    prompts = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )

     return ModelRequestData(
         engine_args=engine_args,
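Note: the example now hard-codes the ChatML-style prompt instead of loading the remote-code tokenizer just to render it. A quick way to check that the literal string matches what `apply_chat_template` used to produce — a sketch only; the checkpoint id below is illustrative and loading it requires network access and `trust_remote_code`:

```python
# Sanity-check sketch (not part of the example script).
from transformers import AutoTokenizer

model_name = "AIDC-AI/Ovis2.5-2B"  # assumption: any Ovis2.5 checkpoint works here
question = "What is in the image?"
placeholder = "<image>"

manual_prompt = (
    f"<|im_start|>user\n\n{placeholder}\n{question}<|im_end|>\n"
    "<|im_start|>assistant\n"
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
templated_prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": f"{placeholder}\n{question}"}],
    tokenize=False,
    add_generation_prompt=True,
)

# If the two strings differ, the hard-coded prompt needs updating.
print(manual_prompt == templated_prompt)
print(repr(templated_prompt))
```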

examples/offline_inference/vision_language_multi_image.py

Lines changed: 3 additions & 5 deletions
@@ -713,11 +713,9 @@ def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData:
     placeholders = "\n".join(
         f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
     )
-    messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]
-
-    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-    prompt = tokenizer.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
+    prompt = (
+        f"<|im_start|>user\n\n{placeholders}\n{question}<|im_end|>\n"
+        "<|im_start|>assistant\n"
     )

     return ModelRequestData(
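The multi-image example gets the same replacement. Purely as an illustration (the URLs are placeholders, not real assets), the hand-built prompt for two images looks like this:

```python
# Illustration only: what the new hand-built prompt contains for two images.
image_urls = [
    "https://example.com/a.jpg",  # placeholder URL
    "https://example.com/b.jpg",  # placeholder URL
]
question = "What differs between the two images?"

placeholders = "\n".join(
    f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
prompt = (
    f"<|im_start|>user\n\n{placeholders}\n{question}<|im_end|>\n"
    "<|im_start|>assistant\n"
)
print(prompt)
```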

vllm/model_executor/models/ovis.py

Lines changed: 12 additions & 13 deletions
@@ -217,17 +217,17 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
 class OvisImagePatchInputs(TensorSchema):
     """
     Dimensions:
-        - batch_patches: Batch size * number of patches
-        - patch_size: patch_size_x * patch_size_y * num_channels
+        - bnp: Batch size * number of images * number of patches
+        - h: Height of each patch
+        - w: Width of each patch
         - patch_indicators: Batch size * (number of patches + 1)
-        - patches_per_image: List of number of total patches for each image
-          in the batch.
+        - bn: Batch size * number of images
     """

     type: Literal["image_patches"]
-    flat_data: Annotated[torch.Tensor, TensorShape("batch_patches", "patch_size")]
+    flat_data: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
     indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
-    patches_per_image: Annotated[list[int], TensorShape("num_patches_per_image")]
+    patches_per_image: Annotated[list[int], TensorShape("bn")]
     # This is used to restore the first two dimensions of `flat_data`.

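With the schema rewritten, `flat_data` keeps its natural `(total_patches, 3, h, w)` layout instead of being pre-flattened to 2-D. A minimal shape sketch follows; the patch counts and patch size are made up, and it assumes the schema accepts keyword construction exactly as the parser further down in this file uses:

```python
import torch

from vllm.model_executor.models.ovis import OvisImagePatchInputs

patches_per_image = [5, 3]  # "bn": two images in the batch (invented counts)
flat_data = torch.randn(sum(patches_per_image), 3, 16, 16)  # ("bnp", 3, "h", "w")
# One indicator sequence per image; the total length here is illustrative only.
indicator_tokens = torch.zeros(
    sum(p + 1 for p in patches_per_image), dtype=torch.long
)

inputs = OvisImagePatchInputs(
    type="image_patches",
    flat_data=flat_data,
    indicator_tokens=indicator_tokens,
    patches_per_image=patches_per_image,
)
```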
@@ -366,7 +366,7 @@ def _call_hf_processor(
             self.image_indicators_to_visual_tokens(indicator)
             for indicator in image_indicators
         ]
-        processed_outputs["indicator_tokens"] = indicator_tokens
+        processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens)
         return processed_outputs

     def _apply_hf_processor_tokens_only(
@@ -414,6 +414,8 @@ def get_replacement_ovis(item_idx: int):
     dummy_inputs=OvisDummyInputsBuilder,
 )
 class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
         if modality.startswith("image"):
@@ -470,14 +472,11 @@ def _parse_and_validate_image_input(
                     f"Got type: {type(pixel_values)}"
                 )

-            flat_data = flatten_bn(pixel_values, concat=True)
-            if flat_data.ndim >= 3:
-                flat_data = flat_data.flatten(start_dim=1)
             return OvisImagePatchInputs(
                 type="image_patches",
-                flat_data=flat_data,
-                patches_per_image=[x.shape[0] for x in flatten_bn(pixel_values)],
-                indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
+                flat_data=flatten_bn(pixel_values, concat=True),
+                patches_per_image=[x.shape[0] for x in pixel_values],
+                indicator_tokens=flatten_bn(indicator_tokens, concat=True),
             )

         raise AssertionError("This line should be unreachable.")
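Why the double `flatten_bn` disappears: with `merge_by_field_config = True`, the parser receives a list with one tensor per image rather than a nested batch, so a single concatenation is enough. A conceptual sketch with made-up sizes:

```python
import torch

from vllm.model_executor.models.utils import flatten_bn

# What the updated parser assumes it receives: one (num_patches, 3, h, w)
# tensor per image in the batch.
pixel_values = [torch.randn(5, 3, 16, 16), torch.randn(3, 3, 16, 16)]

flat_data = flatten_bn(pixel_values, concat=True)        # -> shape (8, 3, 16, 16)
patches_per_image = [x.shape[0] for x in pixel_values]   # -> [5, 3]
print(flat_data.shape, patches_per_image)
```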

vllm/model_executor/models/ovis2_5.py

Lines changed: 62 additions & 50 deletions
@@ -4,7 +4,7 @@

 from collections.abc import Iterable, Mapping
 from functools import partial
-from typing import Literal, Optional, TypedDict, Union
+from typing import Annotated, Literal, Optional, Union

 import torch
 import torch.nn as nn
@@ -14,7 +14,7 @@
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.models.ovis import OvisImagePatchInputs, VisualEmbedding
+from vllm.model_executor.models.ovis import VisualEmbedding
 from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
 from vllm.model_executor.models.utils import (
     AutoWeightsLoader,
@@ -37,6 +37,7 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

@@ -58,36 +59,38 @@
 }


-class OvisVideoPatchInputs(TypedDict):
-    type: Literal["video_patches"]
-    flat_data: torch.Tensor
+class Ovis2_5ImagePatchInputs(TensorSchema):
     """
-    Shape:
-    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
+    Dimensions:
+        - bnp: Batch size * number of images * number of patches
+        - patch_size: patch_size_x * patch_size_y * num_channels
+        - patch_indicators: Batch size * (number of patches + 1)
+        - bn: Batch size * number of images
     """

-    indicator_tokens: torch.Tensor
-    """
-    Shape:
-    `(batch_size * (num_patches + 1))`
-    """
+    type: Literal["image_patches"]
+    flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")]
+    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
+    patches_per_item: Annotated[list[int], TensorShape("bn")]
+    grids: Annotated[torch.Tensor, TensorShape("bn", 3)]
+    # This is used to restore the first two dimensions of `flat_data`.
+

-    patches_per_image: list[int]
+class Ovis2_5VideoPatchInputs(TensorSchema):
     """
-    List of number of total patches for each frame in the video.
-    This is used to restore the first two dimensions of `flat_data`.
+    Dimensions:
+        - bnp: Batch size * number of videos * number of patches
+        - patch_size: patch_size_x * patch_size_y * num_channels
+        - patch_indicators: Batch size * (number of patches + 1)
+        - bn: Batch size * number of videos
     """

-
-def _ovis2_5_field_config():
-    return dict(
-        pixel_values=MultiModalFieldConfig.batched("image"),
-        grids=MultiModalFieldConfig.batched("image"),
-        indicator_tokens=MultiModalFieldConfig.batched("image"),
-        video_pixel_values=MultiModalFieldConfig.batched("video"),
-        video_indicator_tokens=MultiModalFieldConfig.batched("video"),
-        video_grids=MultiModalFieldConfig.batched("video"),
-    )
+    type: Literal["video_patches"]
+    flat_data: Annotated[torch.Tensor, TensorShape("bnp", "patch_size")]
+    indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
+    patches_per_item: Annotated[list[int], TensorShape("bn")]
+    grids: Annotated[torch.Tensor, TensorShape("bn", 3)]
+    # This is used to restore the first two dimensions of `flat_data`.


 class VisualTokenizer(torch.nn.Module):
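The two new Ovis2.5 schemas mirror the Ovis one but keep the flat 2-D patch layout and add `grids` (one `(t, h, w)` row per item, as `_process_visual_input` reads them). A shape-only sketch — all sizes are invented, and it assumes keyword construction exactly as `_parse_and_validate_image_input` below uses:

```python
import torch

from vllm.model_executor.models.ovis2_5 import Ovis2_5ImagePatchInputs

patches_per_item = [4, 1]  # "bn": two images (invented counts)
patch_size = 768           # illustrative flattened patch dimension
flat_data = torch.randn(20, patch_size)        # ("bnp", "patch_size"), rows invented
indicator_tokens = torch.zeros(5, dtype=torch.long)  # length illustrative only
grids = torch.tensor([[1, 4, 4], [1, 2, 2]])   # ("bn", 3): one grid row per item

inputs = Ovis2_5ImagePatchInputs(
    type="image_patches",
    flat_data=flat_data,
    indicator_tokens=indicator_tokens,
    patches_per_item=patches_per_item,
    grids=grids,
)
```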
@@ -380,7 +383,7 @@ def _call_hf_processor(
                 self.visual_indicators_to_visual_tokens(indicator)
                 for indicator in visual_indicators
             ]
-            processed_outputs["video_indicator_tokens"] = indicator_tokens
+            processed_outputs["video_indicator_tokens"] = torch.tensor(indicator_tokens)
         if "images" in mm_data:
             visual_indicators = [
                 hf_processor.construct_visual_indicators((1, 1, 1), False)
@@ -391,7 +394,7 @@ def _call_hf_processor(
                 for indicator in visual_indicators
             ]

-            processed_outputs["indicator_tokens"] = indicator_tokens
+            processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens)
         return processed_outputs

     def _apply_hf_processor_tokens_only(
@@ -405,7 +408,14 @@ def _get_mm_fields_config(
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return _ovis2_5_field_config()
+        return dict(
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            grids=MultiModalFieldConfig.batched("image"),
+            indicator_tokens=MultiModalFieldConfig.batched("image"),
+            video_pixel_values=MultiModalFieldConfig.batched("video"),
+            video_indicator_tokens=MultiModalFieldConfig.batched("video"),
+            video_grids=MultiModalFieldConfig.batched("video"),
+        )

     def _get_prompt_updates(
         self,
@@ -441,6 +451,8 @@ def get_replacement_ovis(item_idx, modality: str):
     dummy_inputs=Ovis2_5DummyInputsBuilder,
 )
 class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
@@ -470,7 +482,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

     def _parse_and_validate_image_input(
         self, **kwargs: object
-    ) -> Optional[OvisImagePatchInputs]:
+    ) -> Optional[Ovis2_5ImagePatchInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         indicator_tokens = kwargs.pop("indicator_tokens", None)
         grids = kwargs.pop("grids", None)
@@ -489,22 +501,22 @@ def _parse_and_validate_image_input(
                     f"Got type: {type(indicator_tokens)}"
                 )

-            return OvisImagePatchInputs(
+            return Ovis2_5ImagePatchInputs(
                 type="image_patches",
-                flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
-                patches_per_image=[
+                flat_data=flatten_bn(pixel_values, concat=True),
+                patches_per_item=[
                     x.shape[0] // (self.config.vit_config.hidden_stride**2)
-                    for x in flatten_bn(pixel_values)
+                    for x in pixel_values
                 ],
-                indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
-                grids=flatten_bn(flatten_bn(grids), concat=True),
+                indicator_tokens=flatten_bn(indicator_tokens, concat=True),
+                grids=flatten_bn(grids, concat=True),
             )

         raise AssertionError("This line should be unreachable.")

     def _parse_and_validate_video_input(
         self, **kwargs: object
-    ) -> Optional[OvisImagePatchInputs]:
+    ) -> Optional[Ovis2_5VideoPatchInputs]:
         pixel_values = kwargs.pop("video_pixel_values", None)
         indicator_tokens = kwargs.pop("video_indicator_tokens", None)
         grids = kwargs.pop("video_grids", None)
@@ -523,26 +535,26 @@ def _parse_and_validate_video_input(
                     f"Got type: {type(indicator_tokens)}"
                 )

-            return OvisVideoPatchInputs(
+            return Ovis2_5VideoPatchInputs(
                 type="video_patches",
-                flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
-                patches_per_image=[
+                flat_data=flatten_bn(pixel_values, concat=True),
+                patches_per_item=[
                     x.shape[0] // (self.config.vit_config.hidden_stride**2)
-                    for x in flatten_bn(pixel_values)
+                    for x in pixel_values
                 ],
-                indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
-                grids=flatten_bn(flatten_bn(grids), concat=True),
+                indicator_tokens=flatten_bn(indicator_tokens, concat=True),
+                grids=flatten_bn(grids, concat=True),
             )

         raise AssertionError("This line should be unreachable.")

-    def _process_image_input(
-        self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
+    def _process_visual_input(
+        self, visual_input: Union[Ovis2_5ImagePatchInputs, Ovis2_5VideoPatchInputs]
     ) -> MultiModalEmbeddings:
-        image_patches_flat = image_input["flat_data"]
-        patches_per_image = image_input["patches_per_image"]
-        indicator_tokens = image_input["indicator_tokens"]
-        grid_thws = image_input["grids"]
+        image_patches_flat = visual_input["flat_data"]
+        patches_per_image = visual_input["patches_per_item"]
+        indicator_tokens = visual_input["indicator_tokens"]
+        grid_thws = visual_input["grids"]

         indicator_per_image = list(
             map(lambda x: 2 if x > 1 else x + 2, patches_per_image)
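The per-item bookkeeping in the parser and in `_process_visual_input` can be sanity-checked by hand. For illustration only — `hidden_stride = 2` is an assumed value here, not read from any config:

```python
hidden_stride = 2       # assumed, illustrative value
raw_patches = [64, 4]   # rows of each per-item pixel_values tensor (invented)

# _parse_and_validate_*: patches remaining after hidden-stride merging
patches_per_item = [n // hidden_stride**2 for n in raw_patches]          # [16, 1]

# _process_visual_input: indicator tokens bracketing each item
indicator_per_image = [2 if p > 1 else p + 2 for p in patches_per_item]  # [2, 3]
print(patches_per_item, indicator_per_image)
```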
@@ -604,11 +616,11 @@ def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
         for modality in modalities:
             if modality == "images":
                 image_input = modalities["images"]
-                vision_embeddings = self._process_image_input(image_input)
+                vision_embeddings = self._process_visual_input(image_input)
                 multimodal_embeddings += vision_embeddings
             if modality == "videos":
                 video_input = modalities["videos"]
-                video_embeddings = self._process_image_input(video_input)
+                video_embeddings = self._process_visual_input(video_input)
                 multimodal_embeddings += video_embeddings

         return multimodal_embeddings

vllm/transformers_utils/processors/ovis.py

Lines changed: 1 addition & 1 deletion
@@ -408,7 +408,7 @@ def _get_best_grid(img, side):
         crops.insert(0, image)
         pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
         image_placeholders = self.construct_image_placeholders(grid)
-        return pixel_values, image_placeholders, grid
+        return torch.tensor(pixel_values), image_placeholders, torch.tensor(grid)

     def batch_decode(self, *args, **kwargs):
         """
