Skip to content

Commit be13281

Browse files
DarkLight1337 authored and simon-mo committed
[Bugfix] Loosen type check to avoid errors in V1 (#15021)
Signed-off-by: DarkLight1337 <[email protected]>
1 parent 54e084f commit be13281

File tree

9 files changed

+28
-37
lines changed

9 files changed

+28
-37
lines changed

vllm/model_executor/models/blip2.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from .blip import BlipVisionModel
2727
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
28-
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
28+
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
2929
maybe_prefix, merge_multimodal_embeddings)
3030

3131
# We use this internally as placeholders since there is no image token
@@ -565,25 +565,23 @@ def _parse_and_validate_image_input(
565565
return None
566566

567567
if pixel_values is not None:
568-
if not isinstance(pixel_values, torch.Tensor):
568+
if not isinstance(pixel_values, (torch.Tensor, list)):
569569
raise ValueError("Incorrect type of pixel values. "
570570
f"Got type: {type(pixel_values)}")
571571

572-
# Remove the N dimension until multiple images are supported.
573-
pixel_values = pixel_values.squeeze(1)
572+
pixel_values = flatten_bn(pixel_values, concat=True)
574573

575574
return Blip2ImagePixelInputs(
576575
type="pixel_values",
577576
data=self._validate_pixel_values(pixel_values),
578577
)
579578

580579
if image_embeds is not None:
581-
if not isinstance(image_embeds, torch.Tensor):
580+
if not isinstance(image_embeds, (torch.Tensor, list)):
582581
raise ValueError("Incorrect type of image embeddings. "
583582
f"Got type: {type(image_embeds)}")
584583

585-
# Remove the N dimension until multiple images are supported.
586-
image_embeds = image_embeds.squeeze(1)
584+
image_embeds = flatten_bn(image_embeds, concat=True)
587585

588586
return Blip2ImageEmbeddingInputs(
589587
type="image_embeds",

vllm/model_executor/models/chameleon.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from vllm.sequence import IntermediateTensors
4040

4141
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
42-
from .utils import (is_pp_missing_parameter,
42+
from .utils import (flatten_bn, is_pp_missing_parameter,
4343
make_empty_intermediate_tensors_factory, make_layers,
4444
maybe_prefix, merge_multimodal_embeddings)
4545

@@ -972,12 +972,11 @@ def _parse_and_validate_image_input(
972972
if pixel_values is None:
973973
return None
974974

975-
if not isinstance(pixel_values, torch.Tensor):
975+
if not isinstance(pixel_values, (torch.Tensor, list)):
976976
raise ValueError("Incorrect type of pixel values. "
977977
f"Got type: {type(pixel_values)}")
978978

979-
# Remove the N dimension until multiple images are supported.
980-
pixel_values = pixel_values.squeeze(1)
979+
pixel_values = flatten_bn(pixel_values, concat=True)
981980

982981
return ChameleonImagePixelInputs(
983982
type="pixel_values",

vllm/model_executor/models/deepseek_vl2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def _parse_and_validate_image_input(
478478
flatten_bn(images_spatial_crop, concat=True)))
479479

480480
if image_embeds is not None:
481-
if not isinstance(image_embeds, torch.Tensor):
481+
if not isinstance(image_embeds, (torch.Tensor, list)):
482482
raise ValueError("Incorrect type of image embeddings. "
483483
f"Got type: {type(image_embeds)}")
484484

vllm/model_executor/models/glm4v.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def _parse_and_validate_image_input(
578578
pixel_values = kwargs.pop("pixel_values", None)
579579

580580
if pixel_values is not None:
581-
if not isinstance(pixel_values, torch.Tensor):
581+
if not isinstance(pixel_values, (torch.Tensor, list)):
582582
raise ValueError("Incorrect type of pixel values. "
583583
f"Got type: {type(pixel_values)}")
584584

vllm/model_executor/models/internvl.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -838,7 +838,7 @@ def _parse_and_validate_image_input(
838838
return None
839839

840840
if image_embeds is not None:
841-
if not isinstance(image_embeds, torch.Tensor):
841+
if not isinstance(image_embeds, (torch.Tensor, list)):
842842
raise ValueError("Incorrect type of image embeddings. "
843843
f"Got type: {type(image_embeds)}")
844844

@@ -856,7 +856,9 @@ def _parse_and_validate_image_input(
856856
raise ValueError("Incorrect type of pixel values. "
857857
f"Got type: {type(pixel_values_flat)}")
858858

859-
assert isinstance(image_num_patches, (torch.Tensor, list))
859+
if not isinstance(image_num_patches, (torch.Tensor, list)):
860+
raise ValueError("Incorrect type of image_num_patches. "
861+
f"Got type: {type(pixel_values_flat)}")
860862

861863
return InternVLImagePixelInputs(
862864
type="pixel_values",

vllm/model_executor/models/llava_next_video.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -349,21 +349,18 @@ def _parse_and_validate_video_input(
349349
List[b, Tensor(nb_frames, nb_channels, height, width)]
350350
}
351351
"""
352-
pixel_values = kwargs.pop("pixel_values_videos", None)
352+
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
353353

354-
if pixel_values is None:
354+
if pixel_values_videos is None:
355355
return None
356356

357-
if not (is_list_of(pixel_values,
358-
(torch.Tensor)) # different shape videos
359-
or isinstance(pixel_values,
360-
torch.Tensor)): # same shape videos
361-
raise ValueError("Incorrect type of pixel values. "
362-
f"Got type: {type(pixel_values)}")
357+
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
358+
raise ValueError("Incorrect type of pixel_values_videos. "
359+
f"Got type: {type(pixel_values_videos)}")
363360

364361
return LlavaNextVideoPixelInputs(
365362
type="pixel_values_videos",
366-
data=pixel_values,
363+
data=pixel_values_videos,
367364
)
368365

369366
def _select_image_features(self, image_features: torch.Tensor, *,

vllm/model_executor/models/llava_onevision.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -574,10 +574,7 @@ def _parse_and_validate_video_input(
574574
if pixel_values_videos is None:
575575
return None
576576

577-
if not (is_list_of(pixel_values_videos,
578-
torch.Tensor) # different shape videos
579-
or isinstance(pixel_values_videos,
580-
torch.Tensor)): # same shape videos
577+
if not isinstance(pixel_values_videos, (torch.Tensor, list)):
581578
raise ValueError("Incorrect type of pixel_values_videos. "
582579
f"Got type: {type(pixel_values_videos)}")
583580

vllm/model_executor/models/paligemma.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
2525
from .siglip import SiglipVisionModel
26-
from .utils import (AutoWeightsLoader, init_vllm_registered_model,
26+
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
2727
maybe_prefix, merge_multimodal_embeddings)
2828
from .vision import get_vision_encoder_info
2929

@@ -270,12 +270,11 @@ def _parse_and_validate_image_input(
270270
return None
271271

272272
if pixel_values is not None:
273-
if not isinstance(pixel_values, torch.Tensor):
273+
if not isinstance(pixel_values, (torch.Tensor, list)):
274274
raise ValueError("Incorrect type of pixel values. "
275275
f"Got type: {type(pixel_values)}")
276276

277-
# Remove the N dimension until multiple images are supported.
278-
pixel_values = pixel_values.squeeze(1)
277+
pixel_values = flatten_bn(pixel_values, concat=True)
279278

280279
return PaliGemmaImagePixelInputs(
281280
type="pixel_values",
@@ -287,8 +286,7 @@ def _parse_and_validate_image_input(
287286
raise ValueError("Incorrect type of image embeddings. "
288287
f"Got type: {type(image_embeds)}")
289288

290-
# Remove the N dimension until multiple images are supported.
291-
image_embeds = image_embeds.squeeze(1)
289+
image_embeds = flatten_bn(image_embeds, concat=True)
292290

293291
return PaliGemmaImageEmbeddingInputs(
294292
type="image_embeds",

vllm/model_executor/models/qwen_vl.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ def _parse_and_validate_image_input(
711711
image_embeds = kwargs.pop("image_embeds", None)
712712

713713
if pixel_values is not None:
714-
if not isinstance(pixel_values, torch.Tensor):
714+
if not isinstance(pixel_values, (torch.Tensor, list)):
715715
raise ValueError("Incorrect type of pixel values. "
716716
f"Got type: {type(pixel_values)}")
717717

@@ -722,13 +722,13 @@ def _parse_and_validate_image_input(
722722
)
723723

724724
if image_embeds is not None:
725-
if not isinstance(image_embeds, torch.Tensor):
725+
if not isinstance(image_embeds, (torch.Tensor, list)):
726726
raise ValueError("Incorrect type of image embeddings. "
727727
f"Got type: {type(image_embeds)}")
728728

729729
return QwenImageEmbeddingInputs(
730730
type="image_embeds",
731-
data=flatten_bn(image_embeds),
731+
data=flatten_bn(image_embeds, concat=True),
732732
)
733733

734734
return None

0 commit comments

Comments (0)