@@ -4,7 +4,7 @@
 from collections.abc import Iterable, Mapping, Sequence
 from itertools import product
 from math import ceil, sqrt
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -44,28 +44,48 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import Step3VisionEncoderConfig
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import (
     AutoWeightsLoader,
     WeightsMapper,
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
 from .vision import run_dp_sharded_vision_model
 
 
-class Step3VLImagePixelInputs(TypedDict):
+class Step3VLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height
+        - w: Width
+        - bnp: Batch size * number of images * number of patches
+        - hp: Height of patch
+        - wp: Width of patch
+    """
+
     type: Literal["pixel_values"]
-    pixel_values: torch.Tensor
-    patch_pixel_values: Optional[torch.Tensor]
-    num_patches: list[int]
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
+    patch_pixel_values: Annotated[
+        Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp")
+    ]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]
+
 
+class Step3VLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - f: Image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
 
-class Step3VLImageEmbeddingInputs(TypedDict):
-    type: Literal["image_embeds"]
-    image_embeds: torch.Tensor
+    type: Literal["image_embeds"] = "image_embeds"
+    image_embeds: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
 
 
 Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
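
The move from TypedDict to TensorSchema attaches symbolic shape metadata to every field through the TensorShape annotations. As a rough usage sketch (the 728/504 sizes are arbitrary placeholders, and it is assumed here that TensorSchema checks the annotated shapes when the object is built), the pixel-input schema ties pixel_values, patch_pixel_values, and num_patches together via the shared "bn" dimension:

import torch

from vllm.model_executor.models.step3_vl import Step3VLImagePixelInputs

# One image ("bn" = 1) tiled into 4 patches ("bnp" = 4); the concrete sizes
# are dummies, not Step3's real resolutions.
inputs = Step3VLImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.zeros(1, 3, 728, 728),        # ("bn", 3, "h", "w")
    patch_pixel_values=torch.zeros(4, 3, 504, 504),   # ("bnp", 3, "hp", "wp")
    num_patches=torch.tensor([4]),                    # ("bn",)
)
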
@@ -895,6 +915,8 @@ def forward(
     dummy_inputs=Step3VLDummyInputsBuilder,
 )
 class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_prefix={
             "model.": "language_model.model.",
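
Setting merge_by_field_config = True asks the multimodal input pipeline to deliver each keyword argument already merged across the batch (one batched tensor per field instead of nested per-request lists), which is why the manual flatten_bn reshaping is dropped from _parse_and_validate_image_input in the next hunk. A rough sketch of what that per-field merge amounts to; merge_field is a hypothetical helper, not vLLM's implementation:

import torch


def merge_field(per_item: list[torch.Tensor]) -> torch.Tensor:
    # Each request contributes an (n_i, 3, H, W) tensor; the merged field is
    # their concatenation along the first dimension, i.e. (sum(n_i), 3, H, W).
    return torch.cat(per_item, dim=0)


# Two requests with 2 and 1 images respectively -> a single (3, 3, 728, 728) batch.
merged = merge_field([torch.zeros(2, 3, 728, 728), torch.zeros(1, 3, 728, 728)])
assert merged.shape == (3, 3, 728, 728)
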
@@ -982,41 +1004,22 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            pixel_values = flatten_bn(pixel_values, concat=True)
-            if pixel_values.dim() >= 3:
-                pixel_values = pixel_values.view(-1, *pixel_values.shape[-3:])
-            if patch_pixel_values is not None:
-                patch_pixel_values = flatten_bn(patch_pixel_values, concat=True)
-                patch_pixel_values = patch_pixel_values.view(
-                    -1, *patch_pixel_values.shape[-3:]
-                )
-                # Handle empty patch_pixel_values by setting to None
-                if patch_pixel_values.shape[0] == 0:
-                    patch_pixel_values = None
-            num_patches = flatten_bn(num_patches, concat=True).tolist()
-
             return Step3VLImagePixelInputs(
                 type="pixel_values",
-                pixel_values=pixel_values.to(self.dtype).to(self.device),
-                patch_pixel_values=patch_pixel_values.to(self.dtype).to(self.device)
+                pixel_values=pixel_values.to(self.dtype),
+                patch_pixel_values=patch_pixel_values.to(self.dtype)
                 if patch_pixel_values is not None
                 else None,
                 num_patches=num_patches,
             )
 
         if image_embeds is not None:
-            if image_embeds.dim() == 2 or image_embeds.dim() >= 3:
-                image_embeds = image_embeds.view(-1, image_embeds.shape[-1])
-            else:
-                raise ValueError(
-                    f"Unexpected shape for image_embeds: {image_embeds.shape}"
-                )
-
             return Step3VLImageEmbeddingInputs(
                 type="image_embeds",
-                image_embeds=image_embeds.to(self.dtype).to(self.device),
+                image_embeds=image_embeds.to(self.dtype),
             )
-        return None
+
+        raise AssertionError("This line should be unreachable.")
 
     def _process_image_features(self, image_features: torch.Tensor) -> torch.Tensor:
         B, P = image_features.shape[:2]
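
The trailing return None becomes an assertion because the guard at the top of this hunk already returns None when neither pixel_values nor image_embeds is present, so one of the two branches above must produce a value and the fall-through is unreachable. A minimal standalone sketch of that exhaustive-dispatch pattern (toy types, not the real method signature):

from typing import Optional, Union


def parse(a: Optional[int], b: Optional[str]) -> Optional[Union[int, str]]:
    # Early return when there is nothing to parse at all.
    if a is None and b is None:
        return None
    if a is not None:
        return a
    if b is not None:
        return b
    # At least one of the guards above must have fired.
    raise AssertionError("This line should be unreachable.")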