@@ -723,42 +723,64 @@ def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
 
         return input_lengths_after_cnn, input_lengths_after_pooling
 
-    def get_image_features(self, tgt_sizes, all_pixel_values, dtype, device):
-        tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
-        tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+    def get_image_features(self, pixel_values_list, tgt_sizes, dtype, device):
+        vision_hidden_states = []
+        all_pixel_values = []
+        img_cnt = []
+        for pixel_values in pixel_values_list:
+            img_cnt.append(len(pixel_values))
+            all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
+
+        # exist image
+        if all_pixel_values:
+            tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
+            tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+
+            max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+
+            all_pixel_values = torch.nn.utils.rnn.pad_sequence(
+                all_pixel_values, batch_first=True, padding_value=0.0
+            )
+            B, L, _ = all_pixel_values.shape
+            all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
+
+            patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
+            for i in range(B):
+                patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
+
+            vision_batch_size = self.config.vision_batch_size
+            all_pixel_values = all_pixel_values.type(dtype)
+            if B > vision_batch_size:
+                hs = []
+                for i in range(0, B, vision_batch_size):
+                    start_idx = i
+                    end_idx = i + vision_batch_size
+                    tmp_hs = self.vpm(
+                        all_pixel_values[start_idx:end_idx],
+                        patch_attention_mask=patch_attn_mask[start_idx:end_idx],
+                        tgt_sizes=tgt_sizes[start_idx:end_idx],
+                    ).last_hidden_state
+                    hs.append(tmp_hs)
+                vision_embedding = torch.cat(hs, dim=0)
+            else:
+                vision_embedding = self.vpm(
+                    all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
+                ).last_hidden_state
 
-        max_patches = torch.max(tgt_sizes[:, 0] * tgt_sizes[:, 1])
+            vision_embedding = self.resampler(vision_embedding, tgt_sizes)
 
-        all_pixel_values = torch.nn.utils.rnn.pad_sequence(
-            all_pixel_values, batch_first=True, padding_value=0.0
-        )
-        B, L, _ = all_pixel_values.shape
-        all_pixel_values = all_pixel_values.permute(0, 2, 1).reshape(B, 3, -1, L)
-
-        patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=device)
-        for i in range(B):
-            patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True
-
-        vision_batch_size = self.config.vision_batch_size
-        all_pixel_values = all_pixel_values.type(dtype)
-        if B > vision_batch_size:
-            hs = []
-            for i in range(0, B, vision_batch_size):
-                start_idx = i
-                end_idx = i + vision_batch_size
-                tmp_hs = self.vpm(
-                    all_pixel_values[start_idx:end_idx],
-                    patch_attention_mask=patch_attn_mask[start_idx:end_idx],
-                    tgt_sizes=tgt_sizes[start_idx:end_idx],
-                ).last_hidden_state
-                hs.append(tmp_hs)
-            vision_embedding = torch.cat(hs, dim=0)
-        else:
-            vision_embedding = self.vpm(
-                all_pixel_values, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes
-            ).last_hidden_state
-        vision_embedding = self.resampler(vision_embedding, tgt_sizes)
-        return vision_embedding
+            start = 0
+            for pixel_values in pixel_values_list:
+                img_cnt = len(pixel_values)
+                if img_cnt > 0:
+                    vision_hidden_states.append(vision_embedding[start : start + img_cnt])
+                    start += img_cnt
+                else:
+                    vision_hidden_states.append([])
+        else:  # no image
+            vision_hidden_states.extend([[]] * len(pixel_values_list))
+
+        return vision_hidden_states
 
     def get_vllm_embedding(self, data):
         """
@@ -773,34 +795,12 @@ def get_vllm_embedding(self, data):
         Returns:
             embedding with vision, vision_hidden_states
         """
+        dtype = self.language_model.embed_tokens.weight.dtype
+        device = self.language_model.embed_tokens.weight.device
         if "vision_hidden_states" not in data:
-            dtype = self.language_model.embed_tokens.weight.dtype
-            device = self.language_model.embed_tokens.weight.device
-            tgt_sizes = data["tgt_sizes"]
-            pixel_values_list = data["pixel_values"]
-            vision_hidden_states = []
-            all_pixel_values = []
-            img_cnt = []
-            for pixel_values in pixel_values_list:
-                img_cnt.append(len(pixel_values))
-                all_pixel_values.extend([i.flatten(end_dim=1).permute(1, 0) for i in pixel_values])
-
-            # exist image
-            if all_pixel_values:
-
-                vision_embedding = self.get_image_features(tgt_sizes=tgt_sizes, all_pixel_values=all_pixel_values, dtype=dtype, device=device)
-
-                start = 0
-                for pixel_values in pixel_values_list:
-                    img_cnt = len(pixel_values)
-                    if img_cnt > 0:
-                        vision_hidden_states.append(vision_embedding[start : start + img_cnt])
-                        start += img_cnt
-                    else:
-                        vision_hidden_states.append([])
-            else:  # no image
-                vision_hidden_states.extend([[]] * len(pixel_values_list))
-
+            vision_hidden_states = self.get_image_features(
+                data["pixel_values"], data["tgt_sizes"], dtype, device
+            )
         else:
             vision_hidden_states = data["vision_hidden_states"]
 
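A minimal caller-side sketch of the refactored interface, assuming a loaded model instance `model` and a processor-built `data` dict (both hypothetical, not part of this diff): `get_image_features` now takes the raw per-sample pixel-value lists and returns one entry per sample, so `get_vllm_embedding` no longer needs to flatten and re-split the batch itself.

# Hypothetical usage sketch; `model` and `data` are assumed to come from the
# usual model-loading / processor pipeline and are not defined in this diff.
dtype = model.language_model.embed_tokens.weight.dtype
device = model.language_model.embed_tokens.weight.device

# data["pixel_values"] holds one list of image tensors per batch sample;
# samples without images contribute an empty list.
vision_hidden_states = model.get_image_features(
    data["pixel_values"], data["tgt_sizes"], dtype, device
)

# The result is aligned with the input samples: a tensor of resampled vision
# tokens for samples that contain images, an empty list otherwise.
for per_sample in vision_hidden_states:
    if len(per_sample) > 0:
        print(per_sample.shape)  # (num_images, num_queries, hidden_size)
    else:
        print("no image in this sample")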