
Commit 05fae02

bbeckca and Isotr0py authored
Migrate KimiVLImagePixelInputs to TensorSchema (#21769)
Signed-off-by: Benji Beck <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
1 parent d1bf1b9 commit 05fae02

File tree

1 file changed: +15 -9 lines changed


vllm/model_executor/models/kimi_vl.py

Lines changed: 15 additions & 9 deletions
@@ -46,7 +46,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass
-from typing import Any, Literal, Optional, TypedDict, Union
+from typing import Annotated, Any, Literal, Optional, Union
 
 import torch
 from torch import nn
@@ -79,6 +79,7 @@
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import KimiVLConfig, MoonViTConfig
 from vllm.transformers_utils.configs.deepseek_vl2 import DeepseekV2Config
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .utils import is_pp_missing_parameter, maybe_prefix
 
@@ -118,15 +119,22 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
-class KimiVLImagePixelInputs(TypedDict):
-    type: Literal["pixel_values"]
-    pixel_values: Union[torch.Tensor, list[torch.Tensor]]
+class KimiVLImagePixelInputs(TensorSchema):
     """
-    Shape:`(num_patches, num_channels, patch_size, patch_size)`
+    Dimensions:
+        - nc: Number of channels
+        - np: Number of patches
+        - ps: Patch size
+        - ni: Number of images
     """
+    type: Literal["pixel_values"] = "pixel_values"
 
-    image_grid_hws: torch.Tensor
-    """Shape:`(num_images, 2)`"""
+    pixel_values: Annotated[
+        Union[torch.Tensor, list[torch.Tensor]],
+        TensorShape("np", 3, "ps", "ps"),
+    ]
+
+    image_grid_hws: Annotated[torch.Tensor, TensorShape("ni", 2)]
 
 
 # TODO: support embeds too
@@ -348,8 +356,6 @@ def _parse_and_validate_image_input(
         pixel_values = pixel_values.reshape(-1, num_channels, patch_size,
                                             patch_size)
         pixel_values = pixel_values.to(self.vision_tower.dtype)
-        # image_grid_hws.shape = (N, 2)
-        assert image_grid_hws.ndim == 2, f"unexpected shape for image_grid_hws: {image_grid_hws.shape}"
 
         return KimiVLImagePixelInputs(
             type="pixel_values",
