72 changes: 39 additions & 33 deletions vllm/model_executor/models/step3_vl.py
@@ -4,7 +4,7 @@
from collections.abc import Iterable, Mapping, Sequence
from itertools import product
from math import ceil, sqrt
from typing import Any, Literal, Optional, TypedDict, Union
from typing import Annotated, Any, Literal, Optional, Union

import numpy as np
import torch
@@ -44,28 +44,51 @@
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import Step3VisionEncoderConfig
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import (
AutoWeightsLoader,
WeightsMapper,
flatten_bn,
init_vllm_registered_model,
maybe_prefix,
)
from .vision import run_dp_sharded_vision_model


class Step3VLImagePixelInputs(TypedDict):
class Step3VLImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height
- w: Width
- bnp: Batch size * number of images * number of patches
- hp: Height of patch
- wp: Width of patch

Note that `num_patches` may be different per batch and image,
in which case the data is passed as a list instead of a batched tensor.
"""

type: Literal["pixel_values"]
pixel_values: torch.Tensor
patch_pixel_values: Optional[torch.Tensor]
num_patches: list[int]
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
patch_pixel_values: Annotated[
Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp")
]
num_patches: Annotated[torch.Tensor, TensorShape("bn")]


class Step3VLImageEmbeddingInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- f: Image feature size
- h: Hidden size (must match the hidden size of language model backbone)
"""

class Step3VLImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
image_embeds: torch.Tensor
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]


Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
@@ -895,6 +918,8 @@ def forward(
dummy_inputs=Step3VLDummyInputsBuilder,
)
class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True

hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.": "language_model.model.",
@@ -982,41 +1007,22 @@ def _parse_and_validate_image_input(
return None

if pixel_values is not None:
pixel_values = flatten_bn(pixel_values, concat=True)
if pixel_values.dim() >= 3:
pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
if patch_pixel_values is not None:
patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
patch_pixel_values = patch_pixel_values.view(
-1, *patch_pixel_values.shape[-3:]
)
# Handle empty patch_pixel_values by setting to None
if patch_pixel_values.shape[0] == 0:
patch_pixel_values = None
num_patches = flatten_bn(num_patches, concat=True).tolist()

return Step3VLImagePixelInputs(
type="pixel_values",
pixel_values=pixel_values.to(self.dtype).to(self.device),
patch_pixel_values=patch_pixel_values.to(self.dtype).to(self.device)
pixel_values=pixel_values.to(self.dtype),
patch_pixel_values=patch_pixel_values.to(self.dtype)
if patch_pixel_values is not None
else None,
num_patches=num_patches,
)

if image_embeds is not None:
if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
else:
raise ValueError(
f"Unexpected shape for image_embeds: {image_embeds.shape}"
)

return Step3VLImageEmbeddingInputs(
type="image_embeds",
image_embeds=image_embeds.to(self.dtype).to(self.device),
image_embeds=image_embeds.to(self.dtype),
)
return None

raise AssertionError("This line should be unreachable.")

def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
B, P = image_features.shape[:2]
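For context, a minimal usage sketch of the new `TensorSchema`-based input class, assuming the `Annotated`/`TensorShape` hints are validated when the object is constructed; the concrete sizes below are invented and only need to be mutually consistent.

```python
import torch

from vllm.model_executor.models.step3_vl import Step3VLImagePixelInputs

# Invented, mutually consistent sizes: bn=2 images, bnp=5 patches in total.
inputs = Step3VLImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.zeros(2, 3, 728, 728),        # ("bn", 3, "h", "w")
    patch_pixel_values=torch.zeros(5, 3, 364, 364),   # ("bnp", 3, "hp", "wp")
    num_patches=torch.tensor([2, 3]),                 # ("bn",)
)
# If pixel_values were e.g. (2, 4, 728, 728), the schema check is expected to
# raise at construction instead of letting the bad shape reach the vision tower.
```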
6 changes: 1 addition & 5 deletions vllm/model_executor/models/tarsier.py
@@ -47,11 +47,7 @@
from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .siglip import SiglipVisionModel
from .utils import (
AutoWeightsLoader,
init_vllm_registered_model,
maybe_prefix,
)
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
from .vision import (
VisionEncoderInfo,
get_num_selected_vision_tokens,
16 changes: 9 additions & 7 deletions vllm/model_executor/models/terratorch.py
@@ -87,12 +87,10 @@ def _terratorch_field_config(hf_inputs: Mapping[str, torch.Tensor]):
if input.type == InputTypeEnum.tensor:
fields[input_name] = "image"

mm_fields_config = {}
for field_name, field_modality in fields.items():
mm_fields_config[field_name] = MultiModalFieldConfig.shared(
batch_size=1, modality=field_modality
)
return mm_fields_config
return {
field_name: MultiModalFieldConfig.batched(modality=field_modality)
for field_name, field_modality in fields.items()
}

return _terratorch_field_config
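
For context, a short sketch of the two `MultiModalFieldConfig` factories this hunk swaps; the summary in the comments is my reading of the field semantics rather than authoritative vLLM documentation.

```python
from vllm.multimodal.inputs import MultiModalFieldConfig

# batched: dim 0 of each "image" tensor indexes the image items, one slice per
# item (which is why the following apply() hunk adds unsqueeze(0) to give each
# raw tensor a leading item dimension).
batched_cfg = MultiModalFieldConfig.batched(modality="image")

# shared: a single tensor is reused across `batch_size` items rather than split.
shared_cfg = MultiModalFieldConfig.shared(batch_size=1, modality="image")
```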

@@ -192,9 +190,12 @@ def apply(
) -> MultiModalInputs:
if "image" in mm_data:
image_data = mm_data["image"]
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}
else:
image_data = mm_data
mm_data = {"image": mm_data}
image_data = {k: v.unsqueeze(0) for k, v in image_data.items()}

mm_data = {"image": image_data}

mm_items = self._to_mm_items(mm_data)
tokenization_kwargs = tokenization_kwargs or {}
@@ -226,6 +227,7 @@
dummy_inputs=TerratorchInputBuilder,
)
class Terratorch(nn.Module, IsAttentionFree, SupportsMultiModal):
merge_by_field_config = True
supports_multimodal_raw_input_only = True
is_pooling_model = True

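For context, a hedged sketch (plain `torch`, no vLLM internals) of why the manual merging code can go away once `merge_by_field_config = True` is set: the assumption is that multimodal kwargs now arrive already merged per field, which is roughly what the removed `flatten_bn(..., concat=True)` calls used to produce by hand.

```python
import torch

# Roughly what the removed flatten_bn(..., concat=True) calls did by hand when
# each prompt contributed its own stack of images (illustrative shapes only).
per_prompt = [torch.zeros(1, 3, 224, 224), torch.zeros(2, 3, 224, 224)]
merged_manually = torch.cat(per_prompt, dim=0)
assert merged_manually.shape == (3, 3, 224, 224)

# With merge_by_field_config = True, the model is assumed to receive the
# already-merged ("bn", 3, "h", "w") tensor directly, so the per-model
# flatten/view/device-move boilerplate shown in the old lines above is no
# longer needed.
```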