@@ -25,7 +25,7 @@
 
 from .blip import BlipVisionModel
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, init_vllm_registered_model,
+from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
 
 # We use this internally as placeholders since there is no image token
@@ -565,25 +565,23 @@ def _parse_and_validate_image_input(
             return None
 
         if pixel_values is not None:
-            if not isinstance(pixel_values, torch.Tensor):
+            if not isinstance(pixel_values, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of pixel values. "
                                  f"Got type: {type(pixel_values)}")
 
-            # Remove the N dimension until multiple images are supported.
-            pixel_values = pixel_values.squeeze(1)
+            pixel_values = flatten_bn(pixel_values, concat=True)
 
             return Blip2ImagePixelInputs(
                 type="pixel_values",
                 data=self._validate_pixel_values(pixel_values),
             )
 
         if image_embeds is not None:
-            if not isinstance(image_embeds, torch.Tensor):
+            if not isinstance(image_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of image embeddings. "
                                  f"Got type: {type(image_embeds)}")
 
-            # Remove the N dimension until multiple images are supported.
-            image_embeds = image_embeds.squeeze(1)
+            image_embeds = flatten_bn(image_embeds, concat=True)
 
             return Blip2ImageEmbeddingInputs(
                 type="image_embeds",
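For context, here is a minimal sketch of what the flatten_bn helper used above is assumed to do; the real helper is imported from .utils in vllm, and the name flatten_bn_sketch and its body below are illustrative only, not the actual implementation. The idea is to collapse the batch (B) and per-prompt image-count (N) dimensions, which is why the old squeeze(1) (valid only for exactly one image per prompt) can be dropped.

    import torch
    from typing import List, Union

    def flatten_bn_sketch(
        x: Union[torch.Tensor, List[torch.Tensor]],
        *,
        concat: bool = False,
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        # Illustrative stand-in (assumption, not vllm's code): collapse the
        # batch (B) and per-prompt (N) dimensions of multimodal inputs.
        if isinstance(x, torch.Tensor):
            # Batched tensor of shape (B, N, ...) -> (B * N, ...).
            return x.flatten(0, 1)
        if concat:
            # List of per-prompt tensors, possibly with different image
            # counts, concatenated along dim 0 into (total_images, ...).
            return torch.cat(x)
        # Otherwise keep a flat list of per-image tensors.
        return [img for per_prompt in x for img in per_prompt]

Under this reading, a batch where prompts carry different numbers of images can arrive as a list of per-prompt tensors and still collapse to a single (total_images, ...) tensor via concat=True, which also explains the widened isinstance checks accepting list.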