8 changes: 2 additions & 6 deletions examples/offline_inference/vision_language.py
@@ -1140,14 +1140,10 @@ def run_ovis2_5(questions: list[str], modality: str) -> ModelRequestData:
elif modality == "video":
placeholder = "<video>"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
messages = [
[{"role": "user", "content": f"{placeholder}\n{question}"}]
prompts = [
f"<|im_start|>user\n\n{placeholder}\n{question}<|im_end|>\n<|im_start|>assistant\n"
for question in questions
]
prompts = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

return ModelRequestData(
engine_args=engine_args,
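For reference, here is a minimal sketch of the prompt string the example now builds by hand instead of calling AutoTokenizer.apply_chat_template. The template literal is copied from the diff above; the default placeholder and the sample question are illustrative assumptions.

```python
# Hand-built Ovis2.5 chat prompt, mirroring the hardcoded template in the diff.
# The default placeholder and the sample question are illustrative assumptions.
def build_ovis2_5_prompt(question: str, placeholder: str = "<image>") -> str:
    return (
        f"<|im_start|>user\n\n{placeholder}\n{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


if __name__ == "__main__":
    print(build_ovis2_5_prompt("What is the content of this image?"))
```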
8 changes: 3 additions & 5 deletions examples/offline_inference/vision_language_multi_image.py
@@ -713,11 +713,9 @@ def load_ovis2_5(question: str, image_urls: list[str]) -> ModelRequestData:
placeholders = "\n".join(
f"Image-{i}: <image>\n" for i, _ in enumerate(image_urls, start=1)
)
messages = [{"role": "user", "content": f"{placeholders}\n{question}"}]

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
prompt = (
f"<|im_start|>user\n\n{placeholders}\n{question}<|im_end|>\n"
"<|im_start|>assistant\n"
)

return ModelRequestData(
18 changes: 9 additions & 9 deletions vllm/model_executor/models/ovis.py
@@ -218,14 +218,15 @@ class OvisImagePatchInputs(TensorSchema):
"""
Dimensions:
- batch_patches: Batch size * number of patches
- patch_size: patch_size_x * patch_size_y * num_channels
- h: Height of each patch
- w: Width of each patch
- patch_indicators: Batch size * (number of patches + 1)
- patches_per_image: List of number of total patches for each image
in the batch.
"""

type: Literal["image_patches"]
flat_data: Annotated[torch.Tensor, TensorShape("batch_patches", "patch_size")]
flat_data: Annotated[torch.Tensor, TensorShape("batch_patches", 3, "h", "w")]
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
patches_per_image: Annotated[list[int], TensorShape("num_patches_per_image")]
# This is used to restore the first two dimensions of `flat_data`.
@@ -366,7 +367,7 @@ def _call_hf_processor(
self.image_indicators_to_visual_tokens(indicator)
for indicator in image_indicators
]
processed_outputs["indicator_tokens"] = indicator_tokens
processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens)
return processed_outputs

def _apply_hf_processor_tokens_only(
@@ -414,6 +415,8 @@ def get_replacement_ovis(item_idx: int):
dummy_inputs=OvisDummyInputsBuilder,
)
class Ovis(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True

@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
if modality.startswith("image"):
@@ -470,14 +473,11 @@ def _parse_and_validate_image_input(
f"Got type: {type(pixel_values)}"
)

flat_data = flatten_bn(pixel_values, concat=True)
if flat_data.ndim >= 3:
flat_data = flat_data.flatten(start_dim=1)
return OvisImagePatchInputs(
type="image_patches",
flat_data=flat_data,
patches_per_image=[x.shape[0] for x in flatten_bn(pixel_values)],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
flat_data=flatten_bn(pixel_values, concat=True),
patches_per_image=[x.shape[0] for x in pixel_values],
indicator_tokens=flatten_bn(indicator_tokens, concat=True),
)

raise AssertionError("This line should be unreachable.")
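As a rough illustration of the new flat_data layout in OvisImagePatchInputs: per-image patch tensors keep their (num_patches, 3, h, w) shape and are only concatenated along the first dimension, with patches_per_image recording how to split them apart again. The patch counts and patch size below are assumptions for illustration.

```python
# Sketch of the reshaped flat_data layout (assumed sizes, not model defaults).
import torch

# Two images with 4 and 6 patches of 3 x 14 x 14.
pixel_values = [torch.randn(4, 3, 14, 14), torch.randn(6, 3, 14, 14)]

flat_data = torch.cat(pixel_values, dim=0)              # shape (10, 3, 14, 14)
patches_per_image = [x.shape[0] for x in pixel_values]  # [4, 6]

# patches_per_image lets the per-image grouping be restored later.
recovered = flat_data.split(patches_per_image, dim=0)
assert [t.shape for t in recovered] == [x.shape for x in pixel_values]
```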
94 changes: 53 additions & 41 deletions vllm/model_executor/models/ovis2_5.py
@@ -4,7 +4,7 @@

from collections.abc import Iterable, Mapping
from functools import partial
from typing import Literal, Optional, TypedDict, Union
from typing import Annotated, Literal, Optional, Union

import torch
import torch.nn as nn
@@ -14,7 +14,7 @@
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.ovis import OvisImagePatchInputs, VisualEmbedding
from vllm.model_executor.models.ovis import VisualEmbedding
from vllm.model_executor.models.siglip2navit import Siglip2NavitModel
from vllm.model_executor.models.utils import (
AutoWeightsLoader,
@@ -37,6 +37,7 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP

@@ -58,36 +59,38 @@
}


class OvisVideoPatchInputs(TypedDict):
type: Literal["video_patches"]
flat_data: torch.Tensor
class Ovis2_5ImagePatchInputs(TensorSchema):
"""
Shape:
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
Dimensions:
- batch_patches: Batch size * number of patches
- patch_size: patch_size_x * patch_size_y * num_channels
- patch_indicators: Batch size * (number of patches + 1)
- patches_per_image: List of number of total patches for each image
in the batch.
"""

indicator_tokens: torch.Tensor
"""
Shape:
`(batch_size * (num_patches + 1))`
"""
type: Literal["image_patches"]
flat_data: Annotated[torch.Tensor, TensorShape("batch_patches", "patch_size")]
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
patches_per_image: Annotated[list[int], TensorShape("num_patches_per_image")]
# This is used to restore the first two dimensions of `flat_data`.


patches_per_image: list[int]
class Ovis2_5VideoPatchInputs(TensorSchema):
"""
List of number of total patches for each frame in the video.
This is used to restore the first two dimensions of `flat_data`.
Dimensions:
- batch_patches: Batch size * number of patches
- patch_size: patch_size_x * patch_size_y * num_channels
- patch_indicators: Batch size * (number of patches + 1)
- patches_per_image: List of number of total patches for each frame
in the video.
"""


def _ovis2_5_field_config():
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
grids=MultiModalFieldConfig.batched("image"),
indicator_tokens=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_indicator_tokens=MultiModalFieldConfig.batched("video"),
video_grids=MultiModalFieldConfig.batched("video"),
)
type: Literal["image_patches"]
flat_data: Annotated[torch.Tensor, TensorShape("batch_patches", "patch_size")]
indicator_tokens: Annotated[torch.Tensor, TensorShape("patch_indicators")]
patches_per_image: Annotated[list[int], TensorShape("num_patches_per_image")]
# This is used to restore the first two dimensions of `flat_data`.


class VisualTokenizer(torch.nn.Module):
@@ -380,7 +383,7 @@ def _call_hf_processor(
self.visual_indicators_to_visual_tokens(indicator)
for indicator in visual_indicators
]
processed_outputs["video_indicator_tokens"] = indicator_tokens
processed_outputs["video_indicator_tokens"] = torch.tensor(indicator_tokens)
if "images" in mm_data:
visual_indicators = [
hf_processor.construct_visual_indicators((1, 1, 1), False)
@@ -391,7 +394,7 @@
for indicator in visual_indicators
]

processed_outputs["indicator_tokens"] = indicator_tokens
processed_outputs["indicator_tokens"] = torch.tensor(indicator_tokens)
return processed_outputs

def _apply_hf_processor_tokens_only(
@@ -405,7 +408,14 @@ def _get_mm_fields_config(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return _ovis2_5_field_config()
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
grids=MultiModalFieldConfig.batched("image"),
indicator_tokens=MultiModalFieldConfig.batched("image"),
video_pixel_values=MultiModalFieldConfig.batched("video"),
video_indicator_tokens=MultiModalFieldConfig.batched("video"),
video_grids=MultiModalFieldConfig.batched("video"),
)

def _get_prompt_updates(
self,
@@ -441,6 +451,8 @@ def get_replacement_ovis(item_idx, modality: str):
dummy_inputs=Ovis2_5DummyInputsBuilder,
)
class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -470,7 +482,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

def _parse_and_validate_image_input(
self, **kwargs: object
) -> Optional[OvisImagePatchInputs]:
) -> Optional[Ovis2_5ImagePatchInputs]:
pixel_values = kwargs.pop("pixel_values", None)
indicator_tokens = kwargs.pop("indicator_tokens", None)
grids = kwargs.pop("grids", None)
@@ -489,22 +501,22 @@ def _parse_and_validate_image_input(
f"Got type: {type(indicator_tokens)}"
)

return OvisImagePatchInputs(
return Ovis2_5ImagePatchInputs(
type="image_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
flat_data=flatten_bn(pixel_values, concat=True),
patches_per_image=[
x.shape[0] // (self.config.vit_config.hidden_stride**2)
for x in flatten_bn(pixel_values)
for x in pixel_values
],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
grids=flatten_bn(flatten_bn(grids), concat=True),
indicator_tokens=flatten_bn(indicator_tokens, concat=True),
grids=flatten_bn(grids, concat=True),
)

raise AssertionError("This line should be unreachable.")

def _parse_and_validate_video_input(
self, **kwargs: object
) -> Optional[OvisImagePatchInputs]:
) -> Optional[Ovis2_5VideoPatchInputs]:
pixel_values = kwargs.pop("video_pixel_values", None)
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
grids = kwargs.pop("video_grids", None)
@@ -523,21 +535,21 @@ def _parse_and_validate_video_input(
f"Got type: {type(indicator_tokens)}"
)

return OvisVideoPatchInputs(
return Ovis2_5VideoPatchInputs(
type="video_patches",
flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
flat_data=flatten_bn(pixel_values, concat=True),
patches_per_image=[
x.shape[0] // (self.config.vit_config.hidden_stride**2)
for x in flatten_bn(pixel_values)
for x in pixel_values
],
indicator_tokens=flatten_bn(flatten_bn(indicator_tokens), concat=True),
grids=flatten_bn(flatten_bn(grids), concat=True),
indicator_tokens=flatten_bn(indicator_tokens, concat=True),
grids=flatten_bn(grids, concat=True),
)

raise AssertionError("This line should be unreachable.")

def _process_image_input(
self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
self, image_input: Union[Ovis2_5ImagePatchInputs, Ovis2_5VideoPatchInputs]
) -> MultiModalEmbeddings:
image_patches_flat = image_input["flat_data"]
patches_per_image = image_input["patches_per_image"]
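To illustrate the patches_per_image arithmetic used by both parse methods above: every visual token corresponds to hidden_stride**2 rows of the flattened patch tensor, so the per-item count is the row count divided by that factor. The hidden_stride value and tensor shapes below are assumptions chosen for illustration.

```python
# Sketch of the patches_per_image computation (illustrative shapes only).
import torch

hidden_stride = 2  # assumed value; the real one comes from vit_config
pixel_values = [torch.randn(16, 588), torch.randn(32, 588)]

patches_per_image = [x.shape[0] // (hidden_stride**2) for x in pixel_values]
print(patches_per_image)  # [4, 8]
```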
2 changes: 1 addition & 1 deletion vllm/transformers_utils/processors/ovis.py
@@ -408,7 +408,7 @@ def _get_best_grid(img, side):
crops.insert(0, image)
pixel_values = torch.cat([_preprocess(crop, side) for crop in crops], dim=0)
image_placeholders = self.construct_image_placeholders(grid)
return pixel_values, image_placeholders, grid
return torch.tensor(pixel_values), image_placeholders, torch.tensor(grid)

def batch_decode(self, *args, **kwargs):
"""
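Finally, a small sketch of the effect of the processor change above, assuming grid is a plain Python tuple such as (1, 2, 2): wrapping it in torch.tensor hands the downstream multimodal field handling a tensor instead of a tuple. The grid value is an illustrative assumption.

```python
# Illustrative conversion of a grid tuple to a tensor (assumed grid value).
import torch

grid = (1, 2, 2)
grid_tensor = torch.tensor(grid)
print(grid_tensor)        # tensor([1, 2, 2])
print(grid_tensor.shape)  # torch.Size([3])
```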