Skip to content

Commit fe0411f

Browse files
947132885Isotr0py
andauthored
[Bugfix] should use stack instead of concat (#22972)
Signed-off-by: 947132885 <[email protected]> Signed-off-by: Isotr0py <[email protected]> Co-authored-by: Isotr0py <[email protected]>
1 parent 4d4061b commit fe0411f

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

vllm/model_executor/models/transformers.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,17 @@ def compute_logits(
694694
return logits
695695

696696

697+
def flatten_and_concat(x: list[torch.Tensor]) -> torch.Tensor:
698+
"""Flatten until a list of tensors can be concatenated then do concat"""
699+
700+
def _can_concat(x: list[torch.Tensor]):
701+
return len(set(map(lambda _x: _x.shape[1:], x))) == 1
702+
703+
if _can_concat(x):
704+
return torch.concat(x)
705+
return flatten_and_concat(flatten_bn(x))
706+
707+
697708
@MULTIMODAL_REGISTRY.register_processor(
698709
MultiModalProcessor,
699710
info=MultiModalProcessingInfo,
@@ -766,8 +777,7 @@ def get_multimodal_embeddings(self, **kwargs):
766777
if isinstance(pixel_values, torch.Tensor):
767778
pixel_values = flatten_bn(pixel_values).to(self.dtype)
768779
elif is_list_of(pixel_values, torch.Tensor):
769-
pixel_values = flatten_bn(flatten_bn(pixel_values),
770-
concat=True).to(self.dtype)
780+
pixel_values = flatten_and_concat(pixel_values).to(self.dtype)
771781
else:
772782
raise ValueError(
773783
f"Unsupported pixel_values type {type(pixel_values)}. "

0 commit comments

Comments
 (0)