@@ -27,7 +27,7 @@
 from collections import defaultdict
 from collections.abc import Iterable, Mapping, Sequence
 from functools import partial
-from typing import Any, Callable, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Callable, Literal, Optional, Union
 
 import numpy as np
 import torch
@@ -63,6 +63,7 @@
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import flatten_2d_lists
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .idefics2_vision_model import Idefics2VisionTransformer
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
@@ -74,36 +75,47 @@
 _MAX_FRAMES_PER_VIDEO = 16
 
 
-class MiniCPMVImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: list[torch.Tensor]
+class MiniCPMVImagePixelInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images * num_slices, num_channels, height, width)`
-
-    Note that the image size may vary, so we pass it as a list
-    instead of a batched tensor.
+    Dimensions:
+        - bns: Batch size * number of images * number of slices
+        - bn: Batch size * number of images
+        - c: Number of channels
+        - h: Height
+        - w: Width
     """
 
-    tgt_sizes: torch.Tensor
+    type: Literal["pixel_values"] = "pixel_values"
+
+    # Note that the image size may vary, so we pass it as a list instead of a
+    # batched tensor.
+    pixel_values: Annotated[
+        list[torch.Tensor],
+        TensorShape("bns", "c", "h", "w"),
+    ]
+    tgt_sizes: Annotated[
+        torch.Tensor,
+        TensorShape("bns", 2),  # This should be in `(height, width)` format.
+    ]
+    num_slices: Annotated[
+        torch.Tensor,
+        TensorShape("bn"),
+    ]
+
+
+class MiniCPMVImageEmbeddingInputs(TensorSchema):
     """
-    Shape: `(batch_size * num_images * num_slices, 2)`
-
-    This should be in `(height, width)` format.
+    Dimensions:
+        - bn: Batch size * number of images
+        - ns: Number of slices
+        - hs: Hidden size (must match language model backbone)
    """
 
-    num_slices: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
-
-
-class MiniCPMVImageEmbeddingInputs(TypedDict):
     type: Literal["image_embeds"]
-    image_embeds: Union[torch.Tensor, list[torch.Tensor]]
-    """
-    Shape: `(batch_size * num_images, num_slices, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    instead of a batched tensor.
-    """
+    image_embeds: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("bn", "ns", "hs"),
+    ]
 
 
 MiniCPMVImageInputs = Union[MiniCPMVImagePixelInputs,
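For context, here is a minimal usage sketch (not part of the diff) of the new schema-based input class. It assumes TensorSchema subclasses accept their annotated fields as keyword arguments and validate each field against its TensorShape annotation at construction time; the image count, slice count, and 448x448 sizes below are made-up illustrative values.

import torch

# Hypothetical batch: 2 images, each split into 3 slices -> "bn" = 2, "bns" = 6.
num_images, slices_per_image = 2, 3
bns = num_images * slices_per_image

image_input = MiniCPMVImagePixelInputs(
    type="pixel_values",
    # One (c, h, w) tensor per slice; the list length supplies the "bns" dimension.
    pixel_values=[torch.rand(3, 448, 448) for _ in range(bns)],
    # One (height, width) pair per slice -> shape (bns, 2).
    tgt_sizes=torch.tensor([[448, 448]] * bns),
    # Number of slices contributed by each image -> shape (bn,).
    num_slices=torch.tensor([slices_per_image] * num_images),
)

Because pixel_values and tgt_sizes both bind the shared symbolic "bns" dimension, the schema can enforce their consistency itself, which is what makes the manual length check removed in the hunk below redundant.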
@@ -832,11 +844,6 @@ def _parse_and_validate_vision_input(
         pixel_values_flat = flatten_bn(flatten_2d_lists(pixel_values))
         tgt_sizes_flat = flatten_bn(flatten_2d_lists(tgt_sizes), concat=True)
 
-        if len(pixel_values_flat) != len(tgt_sizes_flat):
-            raise ValueError("Inconsistent flattened lengths, found: "
-                             f"{len(pixel_values_flat)} vs. "
-                             f"{len(tgt_sizes_flat)}")
-
         return MiniCPMVImagePixelInputs(
             type="pixel_values",
             pixel_values=pixel_values_flat,
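The deleted ValueError branch is subsumed by the schema: pixel_values and tgt_sizes both bind the symbolic "bns" dimension, so an inconsistent pair should now be rejected when MiniCPMVImagePixelInputs is constructed. A hedged sketch of that behaviour (the exact exception type raised by TensorSchema's validation is an assumption here):

import torch

try:
    # 6 pixel_values entries but only 5 tgt_sizes rows: both map to "bns",
    # so the shared-dimension check is expected to reject this input.
    MiniCPMVImagePixelInputs(
        type="pixel_values",
        pixel_values=[torch.rand(3, 448, 448) for _ in range(6)],
        tgt_sizes=torch.tensor([[448, 448]] * 5),
        num_slices=torch.tensor([3, 3]),
    )
except Exception as exc:  # assumed shape-validation failure from TensorSchema
    print(f"rejected by schema validation: {exc}")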