File tree Expand file tree Collapse file tree 1 file changed +12
-2
lines changed
vllm/model_executor/models Expand file tree Collapse file tree 1 file changed +12
-2
lines changed Original file line number Diff line number Diff line change @@ -694,6 +694,17 @@ def compute_logits(
694
694
return logits
695
695
696
696
697
+ def flatten_and_concat (x : list [torch .Tensor ]) -> torch .Tensor :
698
+ """Flatten until a list of tensors can be concatenated then do concat"""
699
+
700
+ def _can_concat (x : list [torch .Tensor ]):
701
+ return len (set (map (lambda _x : _x .shape [1 :], x ))) == 1
702
+
703
+ if _can_concat (x ):
704
+ return torch .concat (x )
705
+ return flatten_and_concat (flatten_bn (x ))
706
+
707
+
697
708
@MULTIMODAL_REGISTRY .register_processor (
698
709
MultiModalProcessor ,
699
710
info = MultiModalProcessingInfo ,
@@ -766,8 +777,7 @@ def get_multimodal_embeddings(self, **kwargs):
766
777
if isinstance (pixel_values , torch .Tensor ):
767
778
pixel_values = flatten_bn (pixel_values ).to (self .dtype )
768
779
elif is_list_of (pixel_values , torch .Tensor ):
769
- pixel_values = flatten_bn (flatten_bn (pixel_values ),
770
- concat = True ).to (self .dtype )
780
+ pixel_values = flatten_and_concat (pixel_values ).to (self .dtype )
771
781
else :
772
782
raise ValueError (
773
783
f"Unsupported pixel_values type { type (pixel_values )} . "
You can’t perform that action at this time.
0 commit comments