diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py
index c4033dd12558..5ec7845a122f 100644
--- a/vllm/model_executor/models/step3_vl.py
+++ b/vllm/model_executor/models/step3_vl.py
@@ -4,7 +4,7 @@
 from collections.abc import Iterable, Mapping, Sequence
 from itertools import product
 from math import ceil, sqrt
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -44,28 +44,48 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Step3VisionEncoderConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
 from .vision import run_dp_sharded_vision_model
 
 
-class Step3VLImagePixelInputs(TypedDict):
+class Step3VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+        - bnp: Batch size * number of images * number of patches
+        - hp: Height of patch
+        - wp: Width of patch
+    """
+
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    patch_pixel_values: Optional[torch.Tensor]
-    num_patches: list[int]
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
+    patch_pixel_values: Annotated[
+        Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp")
+    ]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+
 
+class Step3VLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - f: Image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
 
-class Step3VLImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
+    type: Literal["image_embeds"] = "image_embeds"
+    image_embeds: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
 
 
 Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
@@ -895,6 +915,8 @@ def forward(
     dummy_inputs=Step3VLDummyInputsBuilder,
 )
 class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "model.": "language_model.model.",
@@ -982,41 +1004,22 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            pixel_values = flatten_bn(pixel_values, concat=True)
-            if pixel_values.dim() >= 3:
-                pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
-            if patch_pixel_values is not None:
-                patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
-                patch_pixel_values = patch_pixel_values.view(
-                    -1, *patch_pixel_values.shape[-3:]
-                )
-                # Handle empty patch_pixel_values by setting to None
-                if patch_pixel_values.shape[0] == 0:
-                    patch_pixel_values = None
-            num_patches = flatten_bn(num_patches, concat=True).tolist()
-
             return Step3VLImagePixelInputs(
                 type="pixel_values",
-                pixel_values=pixel_values.to(self.dtype).to(self.device),
-                patch_pixel_values=patch_pixel_values.to(self.dtype).to(self.device)
+                pixel_values=pixel_values.to(self.dtype),
+                patch_pixel_values=patch_pixel_values.to(self.dtype)
                 if patch_pixel_values is not None
                 else None,
                 num_patches=num_patches,
             )
 
         if image_embeds is not None:
-            if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
-                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
-            else:
-                raise ValueError(
-                    f"Unexpected shape for image_embeds: {image_embeds.shape}"
-                )
-
             return Step3VLImageEmbeddingInputs(
                 type="image_embeds",
-                image_embeds=image_embeds.to(self.dtype).to(self.device),
+                image_embeds=image_embeds.to(self.dtype),
             )
-        return None
+
+        raise AssertionError("This line should be unreachable.")
 
     def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
         B, P = image_features.shape[:2]
diff --git a/vllm/model_executor/models/tarsier.py b/vllm/model_executor/models/tarsier.py
index 482ad9cb7748..6a224fe9288b 100644
--- a/vllm/model_executor/models/tarsier.py
+++ b/vllm/model_executor/models/tarsier.py
@@ -47,11 +47,7 @@
 from .clip import CLIPVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .siglip import SiglipVisionModel
-from .utils import (
-    AutoWeightsLoader,
-    init_vllm_registered_model,
-    maybe_prefix,
-)
+from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
 from .vision import (
     VisionEncoderInfo,
     get_num_selected_vision_tokens,
diff --git a/vllm/model_executor/models/terratorch.py b/vllm/model_executor/models/terratorch.py
index c7c82e9e10d1..13d2e8eacc01 100644
--- a/vllm/model_executor/models/terratorch.py
+++ b/vllm/model_executor/models/terratorch.py
@@ -87,12 +87,10 @@ def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
             if input.type == InputTypeEnum.tensor:
                 fields[input_name] = "image"
 
-        mm_fields_config = {}
-        for field_name, field_modality in fields.items():
-            mm_fields_config[field_name] = MultiModalFieldConfig.shared(
-                batch_size=1, modality=field_modality
-            )
-        return mm_fields_config
+        return {
+            field_name: MultiModalFieldConfig.batched(modality=field_modality)
+            for field_name, field_modality in fields.items()
+        }
 
     return _terratorch_field_config
 
@@ -192,9 +190,12 @@ def apply(
     ) -> MultiModalInputs:
         if "image" in mm_data:
             image_data = mm_data["image"]
+            image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
         else:
             image_data = mm_data
-            mm_data = {"image": mm_data}
+            image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
+
+        mm_data = {"image": image_data}
 
         mm_items = self._to_mm_items(mm_data)
         tokenization_kwargs = tokenization_kwargs or {}
@@ -226,6 +227,7 @@ def apply(
     dummy_inputs=TerratorchInputBuilder,
 )
 class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
+    merge_by_field_config = True
     supports_multimodal_raw_input_only = True
     is_pooling_model = True
 
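As a quick illustration of the pattern this patch adopts (not part of the change itself), the sketch below constructs the new `Step3VLImagePixelInputs` schema with dummy tensors. The concrete sizes are invented for the example, and it assumes `TensorSchema` checks the `TensorShape` annotations when the object is built, which is what lets the patch drop the manual `flatten_bn`/`dim()`/`view()` handling.

```python
# Illustrative sketch only. Assumes vLLM is installed; the tensor sizes below
# are arbitrary, and shape validation at construction time is the behaviour
# the patch relies on rather than something shown in the diff itself.
import torch

from vllm.model_executor.models.step3_vl import Step3VLImagePixelInputs

# Two images ("bn" = 2), each contributing four patches ("bnp" = 8).
# Only the ranks and the fixed channel dimension (3) matter here.
inputs = Step3VLImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.zeros(2, 3, 224, 224),        # ("bn", 3, "h", "w")
    patch_pixel_values=torch.zeros(8, 3, 112, 112),  # ("bnp", 3, "hp", "wp")
    num_patches=torch.tensor([4, 4]),                # ("bn",)
)

# A tensor of the wrong rank, e.g. an un-flattened 5-D batch, is expected to
# be rejected by the schema instead of being silently reshaped the way the
# old _parse_and_validate_image_input code did.
```

With `merge_by_field_config = True`, the multimodal kwargs are expected to arrive already merged and flattened per field, so the schema classes take over the shape checking that the removed `flatten_bn`/`view` code previously performed by hand.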