@@ -46,7 +46,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -79,6 +79,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
 from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .utils import is_pp_missing_parameter, maybe_prefix
 
@@ -118,15 +119,22 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
-class KimiVLImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class KimiVLImagePixelInputs(TensorSchema):
     """
-    Shape: `(num_patches, num_channels, patch_size, patch_size)`
+    Dimensions:
+        - nc: Number of channels
+        - np: Number of patches
+        - ps: Patch size
+        - ni: Number of images
     """
+    type: Literal["pixel_values"] = "pixel_values"
 
-    image_grid_hws: torch.Tensor
-    """Shape: `(num_images, 2)`"""
+    pixel_values: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("np", 3, "ps", "ps"),
+    ]
+
+    image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)]
 
 
 # TODO: support embeds too
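The `TensorShape` entries mix fixed sizes (the literal `3` channel dim, the trailing `2` in `image_grid_hws`) with symbolic dims (`"np"`, `"ps"`, `"ni"`) that are resolved against the actual tensors. A minimal sketch of that idea, using a simplified stand-in for vLLM's `TensorSchema`/`TensorShape` (not the actual implementation):

```python
# Hypothetical, simplified re-implementation of the TensorSchema idea:
# Annotated fields declare expected shapes; ints are fixed sizes and
# strings are symbols that must bind to one consistent size per instance.
from typing import Annotated, Union, get_type_hints

import torch


class TensorShape:
    def __init__(self, *dims: Union[int, str]) -> None:
        self.dims = dims


class TensorSchema:
    def __init__(self, **kwargs) -> None:
        for name, value in kwargs.items():
            setattr(self, name, value)
        self._validate()

    def _validate(self) -> None:
        bindings: dict[str, int] = {}
        hints = get_type_hints(type(self), include_extras=True)
        for name, hint in hints.items():
            shape = next((m for m in getattr(hint, "__metadata__", ())
                          if isinstance(m, TensorShape)), None)
            value = getattr(self, name, None)
            if shape is None or not isinstance(value, torch.Tensor):
                continue  # e.g. the `type` discriminator, or list inputs
            if value.ndim != len(shape.dims):
                raise ValueError(
                    f"{name}: expected {len(shape.dims)} dims, got {value.ndim}")
            for dim, actual in zip(shape.dims, value.shape):
                if isinstance(dim, int):
                    if actual != dim:
                        raise ValueError(f"{name}: expected {dim}, got {actual}")
                elif bindings.setdefault(dim, actual) != actual:
                    raise ValueError(
                        f"{name}: {dim!r} bound to {bindings[dim]}, got {actual}")
```

Note that `"ps"` appears twice in `pixel_values`, so square patches are enforced, while `"np"` and `"ni"` are independent symbols (patches vs. images).
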
@@ -348,8 +356,6 @@ def _parse_and_validate_image_input(
         pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
                                             patch_size)
         pixel_values = pixel_values.to(self.vision_tower.dtype)
-        # image_grid_hws.shape = (N, 2)
-        assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
 
         return KimiVLImagePixelInputs(
             type="pixel_values",
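
With the shape constraints declared on the schema, the hand-rolled `ndim` assert becomes redundant: validation runs when `KimiVLImagePixelInputs` is constructed. A hedged usage sketch (the tensor sizes are made up for illustration, and the exception type follows the sketch above; vLLM's `TensorSchema` raises its own validation error):

```python
import torch

# Well-formed: 2x2 + 1x1 = 5 patches of 3 x 14 x 14, two (h, w) grid rows.
inputs = KimiVLImagePixelInputs(
    type="pixel_values",
    pixel_values=torch.rand(5, 3, 14, 14),
    image_grid_hws=torch.tensor([[2, 2], [1, 1]]),
)

# A 1-D image_grid_hws is now rejected at construction time, which is
# what the removed assert in _parse_and_validate_image_input used to catch.
try:
    KimiVLImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.rand(5, 3, 14, 14),
        image_grid_hws=torch.tensor([2, 2]),
    )
except ValueError as exc:
    print(f"rejected: {exc}")
```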