
Commit ea4b570

[VLM] Cleanup validation and update docs (#6149)
1 parent a41357e commit ea4b570

Showing 3 changed files with 87 additions and 82 deletions.

vllm/model_executor/models/llava.py

Lines changed: 30 additions & 18 deletions
@@ -149,14 +149,16 @@ def __init__(self,
                                                 config.vocab_size, logit_scale)
         self.sampler = Sampler()

-    def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
-        if list(data.shape)[1:] != [
-                3, self.config.vision_config.image_size,
-                self.config.vision_config.image_size
-        ]:
+    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+        actual_dims = tuple(data.shape[1:])
+
+        if actual_dims != expected_dims:
+            expected_expr = ("batch_size", *map(str, expected_dims))
             raise ValueError(
-                "The expected image tensor shape is batch dimension plus "
-                "channel, height and width.")
+                f"The expected shape of pixel values is {expected_expr}. "
+                f"You supplied {tuple(data.shape)}.")

         return data

@@ -173,7 +175,7 @@ def _parse_and_validate_image_input(

         return LlavaImagePixelInputs(
             type="pixel_values",
-            data=self._validate_image_data(pixel_values),
+            data=self._validate_pixel_values(pixel_values),
         )

     def _select_image_features(self, image_features: torch.Tensor, *,
@@ -226,18 +228,25 @@ def forward(

         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
-        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
+        `"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.
+
         Tokenizer outputs:
-        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
-        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
-        The to-be-inserted image has a size of 576 (24 * 24) along the context
-        length dimension.
-        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
-        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
-        9047, 13566, 29901].
-        There will be 576 `32000` in the `input_ids`.
-        (32000 is the token id for `<image>`.)
+        `[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
+        278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
+
+        To reserve space in KV cache, we have to insert placeholder tokens
+        before they are inputted to the model, so the input processor prepends
+        additional image tokens (denoted as `32000`), resulting in:
+        `[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
+        29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
+        29901]`.
+
+        We insert 575 tokens so that including the original image token in the
+        input, there are a total of 576 (24 * 24) image tokens, which
+        corresponds to the number of image tokens inputted to the language
+        model, i.e. the number of image tokens outputted by the visual encoder.

         This way, the `positions` and `attn_metadata` are consistent
         with the `input_ids`.
@@ -246,6 +255,9 @@ def forward(
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each input image.
+
+        See also:
+            :class:`LlavaImageInputs`
         """
         image_input = self._parse_and_validate_image_input(**kwargs)
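
To make the new behaviour concrete, here is a rough standalone sketch of the check (a hypothetical free function, not the model's `_validate_pixel_values` method; `image_size` stands in for `config.vision_config.image_size`): the trailing dimensions of the batched tensor must equal `(3, image_size, image_size)`, and the error now reports both the expected and the supplied shape.

```python
import torch

def validate_pixel_values(data: torch.Tensor, image_size: int) -> torch.Tensor:
    """Sketch of the reworked check: trailing dims must be (3, H, W)."""
    expected_dims = (3, image_size, image_size)
    actual_dims = tuple(data.shape[1:])

    if actual_dims != expected_dims:
        # Report the symbolic batch dimension alongside the concrete dims,
        # mirroring the error message introduced in this commit.
        expected_expr = ("batch_size", *map(str, expected_dims))
        raise ValueError(
            f"The expected shape of pixel values is {expected_expr}. "
            f"You supplied {tuple(data.shape)}.")

    return data

# With the CLIP-ViT-L/14-336 vision tower, image_size is 336, and the
# resulting 24 x 24 patch grid gives the 576 image tokens discussed in the
# updated docstring.
validate_pixel_values(torch.zeros(2, 3, 336, 336), image_size=336)
```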

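The updated docstring's counting argument can also be sanity-checked with a short hedged sketch (this is not vLLM's actual input processor; the helper below only reproduces the arithmetic): replacing the single `<image>` token, id `32000`, with 576 = 24 * 24 placeholders is the same as prepending 575 extra copies, so the expanded `input_ids` contains exactly as many image tokens as the vision encoder emits image features.

```python
from typing import List

IMAGE_TOKEN_ID = 32000        # token id of <image>, per the docstring
IMAGE_FEATURE_SIZE = 24 * 24  # 576 features from the 24 x 24 ViT patch grid

def expand_image_tokens(token_ids: List[int]) -> List[int]:
    """Toy stand-in for the placeholder expansion described above."""
    expanded: List[int] = []
    for tok in token_ids:
        if tok == IMAGE_TOKEN_ID:
            # One <image> token becomes 576 placeholders, i.e. 575 extra
            # copies are inserted in front of the original one.
            expanded.extend([IMAGE_TOKEN_ID] * IMAGE_FEATURE_SIZE)
        else:
            expanded.append(tok)
    return expanded

prompt_ids = [1, 3148, 1001, 29901, 29871, 32000, 29871, 13]  # truncated example
out = expand_image_tokens(prompt_ids)
assert out.count(IMAGE_TOKEN_ID) == 576
assert len(out) == len(prompt_ids) + 575
```
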
vllm/model_executor/models/llava_next.py

Lines changed: 41 additions & 50 deletions
@@ -47,7 +47,8 @@ class LlavaNextImagePixelInputs(TypedDict):
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

-    Note that `num_patches` may be different for each batch.
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
     """

     image_sizes: NotRequired[torch.Tensor]
@@ -255,40 +256,20 @@ def _validate_pixel_values(
         self, data: Union[torch.Tensor, List[torch.Tensor]]
     ) -> Union[torch.Tensor, List[torch.Tensor]]:

-        def _validate_shape(data: torch.Tensor):
-
-            dim = data.dim()
-            height = width = self.config.vision_config.image_size
-            # All 4d image tensors have the same number of patches,
-            # so data is a 5d batch of these tensors
-            if dim == 5:
-                if list(data.shape)[2:] != [
-                        3, self.config.vision_config.image_size,
-                        self.config.vision_config.image_size
-                ]:
-                    raise ValueError(
-                        "Expected pixel value tensor in shape of: (batch size, "
-                        f"patch number, 3, {height}, {width}), got {data.shape}"
-                    )
-
-            # 4d image tensors have different number of patches,
-            # so data is each individual tensor.
-            elif dim == 4:
-                if list(data.shape)[1:] != [
-                        3, self.config.vision_config.image_size,
-                        self.config.vision_config.image_size
-                ]:
-                    raise ValueError(
-                        "Expected pixel value tensor in shape of: (patch "
-                        f"number, 3, {height}, {width}), got {data.shape}")
-            else:
+        h = w = self.config.vision_config.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape[1:])
+
+            if actual_dims != expected_dims:
+                expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    f"Invalid pixel value tensor of shape {data.shape}")
+                    "The expected shape of pixel values in each batch element "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

-        if isinstance(data, torch.Tensor):
-            _validate_shape(data)
-        else:
-            [_validate_shape(d) for d in data]
+        for d in data:
+            _validate_shape(d)

         return data

@@ -464,18 +445,33 @@ def forward(

         One key thing to understand is the `input_ids` already accounts for the
         positions of the to-be-inserted image embeddings.
+
         Concretely, consider a text prompt:
-        "<image>\nUSER: What's the content of the image?\nASSISTANT:".
+        `"A chat between a curious human and an artificial intelligence
+        assistant. The assistant gives helpful, detailed, and polite answers to
+        the human's questions.
+        USER: <image>\\nWhat is shown in this image? ASSISTANT:"`.
+
         Tokenizer outputs:
-        [1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
-        2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
-        The to-be-inserted image has a size of 576 (24 * 24) along the context
-        length dimension.
-        `input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
-        1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
-        9047, 13566, 29901].
-        There will be 576 `32000` in the `input_ids`.
-        (32000 is the token id for `<image>`.)
+        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
+        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
+        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
+        29871, 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973, 319, 1799,
+        9047, 13566, 29901]`.
+
+        To reserve space in KV cache, we have to insert placeholder tokens
+        before they are inputted to the model, so the input processor prepends
+        additional image tokens (denoted as `32000`), resulting in:
+        `[1, 319, 13563, 1546, 263, 12758, 5199, 322, 385, 23116, 21082, 20255,
+        29889, 450, 20255, 4076, 8444, 29892, 13173, 29892, 322, 1248, 568,
+        6089, 304, 278, 5199, 29915, 29879, 5155, 29889, 3148, 1001, 29901,
+        29871, 32000, ..., 32000, 13, 5618, 338, 4318, 297, 445, 1967, 29973,
+        319, 1799, 9047, 13566, 29901]`.
+
+        Unlike in LLaVA-1.5, the number of image tokens inputted to the language
+        model depends on the original size of the input image. Including the
+        original image token in the input, the required number of image tokens
+        is given by :func:`get_llava_next_image_feature_size`.

         This way, the `positions` and `attn_metadata` are consistent
         with the `input_ids`.
@@ -484,15 +480,10 @@ def forward(
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each grid patch for each input image.
-                Expects a batch with shape `[1, num_patches, 3, h, w]`.
             image_sizes: The original `(height, width)` for each input image.
-                Expects a batch with shape `[1, 2]`.
-
+
         See also:
-            Each input maps to huggingface implementation, as follows:
-
-            - `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L690
-            - `image_sizes`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava_next/modeling_llava_next.py#L691
+            :class:`LlavaNextImageInputs`
         """
         image_input = self._parse_and_validate_image_input(**kwargs)
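
The list-or-tensor handling that replaces the old `dim == 5` / `dim == 4` branches can be illustrated with a hedged standalone sketch (a hypothetical free function, not the model method; `image_size` again stands in for `config.vision_config.image_size`): since `num_patches` may differ per image, `data` is either one 5-D batched tensor or a list of 4-D tensors, and a single `for d in data` loop covers both cases because iterating a 5-D tensor yields its 4-D `(num_patches, 3, H, W)` elements.

```python
from typing import List, Union

import torch

def validate_pixel_values(
    data: Union[torch.Tensor, List[torch.Tensor]],
    image_size: int,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    """Sketch: every batch element must have trailing dims (3, H, W)."""
    expected_dims = (3, image_size, image_size)

    def _validate_shape(d: torch.Tensor) -> None:
        actual_dims = tuple(d.shape[1:])
        if actual_dims != expected_dims:
            expected_expr = ("num_patches", *map(str, expected_dims))
            raise ValueError(
                "The expected shape of pixel values in each batch element "
                f"is {expected_expr}. You supplied {tuple(d.shape)}.")

    # Iterating a 5-D tensor yields 4-D (num_patches, 3, H, W) slices, so the
    # same loop handles a ragged list of per-image tensors as well.
    for d in data:
        _validate_shape(d)

    return data

# Uniform batch (same num_patches) and ragged list (different num_patches).
validate_pixel_values(torch.zeros(2, 5, 3, 336, 336), image_size=336)
validate_pixel_values(
    [torch.zeros(5, 3, 336, 336), torch.zeros(3, 3, 336, 336)],
    image_size=336)
```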

vllm/model_executor/models/phi3v.py

Lines changed: 16 additions & 14 deletions
@@ -263,7 +263,8 @@ class Phi3VImagePixelInputs(TypedDict):
     """
     Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`

-    Note that `num_patches` may be different for each batch.
+    Note that `num_patches` may be different for each batch, in which case
+    the data is passed as a list instead of a batched tensor.
     """

     image_sizes: torch.Tensor
@@ -466,28 +467,29 @@ def __init__(self,
     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
         if list(data.shape[1:]) != [2]:
             raise ValueError(
-                f"The expected image sizes shape is batch dimension plus "
-                f"{[2]}. You supplied {data.shape}.")
+                f"The expected shape of image sizes is batch dimension plus "
+                f"{[2]}. You supplied {tuple(data.shape)}.")

         return data

     def _validate_pixel_values(
         self, data: Union[torch.Tensor, List[torch.Tensor]]
     ) -> Union[torch.Tensor, List[torch.Tensor]]:

-        def _validate_shape(data: torch.Tensor):
-            if list(data.shape)[2:] != [
-                    3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
-                    CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
-            ]:
+        h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
+        expected_dims = (3, h, w)
+
+        def _validate_shape(d: torch.Tensor):
+            actual_dims = tuple(d.shape[1:])
+
+            if actual_dims != expected_dims:
+                expected_expr = ("num_patches", *map(str, expected_dims))
                 raise ValueError(
-                    "The expected pixel value tensor shape is batch dimension "
-                    "plus patch number, channel, height and width.")
+                    "The expected shape of pixel values in each batch element "
+                    f"is {expected_expr}. You supplied {tuple(d.shape)}.")

-        if isinstance(data, torch.Tensor):
-            _validate_shape(data)
-        else:
-            [_validate_shape(d) for d in data]
+        for d in data:
+            _validate_shape(d)

         return data
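
For phi3v.py, the pixel-value check follows the same pattern as the LLaVA-NeXT sketch above, with `h = w = 336` taken from `CLIP_VIT_LARGE_PATCH14_336_CONFIG`; the separate `_validate_image_sizes` check can be sketched as follows (hypothetical standalone helper, not the model method): each image contributes one `(height, width)` pair, so `image_sizes` must have shape `(batch_size, 2)`.

```python
import torch

def validate_image_sizes(data: torch.Tensor) -> torch.Tensor:
    """Sketch: image_sizes is (batch_size, 2), one (height, width) per image."""
    if list(data.shape[1:]) != [2]:
        raise ValueError(
            "The expected shape of image sizes is batch dimension plus "
            f"{[2]}. You supplied {tuple(data.shape)}.")
    return data

# Two images, each described by its original (height, width).
validate_image_sizes(torch.tensor([[1008, 672], [336, 336]]))
```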
