@@ -976,10 +976,12 @@ def _process_image_input(
976
976
image_embeds = self .visual (pixel_values , grid_thw = grid_thw_list )
977
977
978
978
# Split concatenated embeddings for each image item.
979
+ # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
979
980
merge_size = self .visual .spatial_merge_size
980
- sizes = grid_thw .prod (- 1 ) // merge_size // merge_size
981
+ sizes = (torch .tensor (grid_thw_list , dtype = torch .long ).prod (- 1 ) //
982
+ (merge_size * merge_size )).tolist ()
981
983
982
- return image_embeds .split (sizes . tolist () )
984
+ return image_embeds .split (sizes )
983
985
984
986
def _process_video_input (
985
987
self ,
@@ -998,9 +1000,11 @@ def _process_video_input(
998
1000
999
1001
# Split concatenated embeddings for each video item.
1000
1002
merge_size = self .visual .spatial_merge_size
1001
- sizes = grid_thw .prod (- 1 ) // merge_size // merge_size
1003
+ # Using prod on grid_thw_list instead of grid_thw.prod avoids CUDA sync
1004
+ sizes = (torch .tensor (grid_thw_list , dtype = torch .long ).prod (- 1 ) //
1005
+ (merge_size * merge_size )).tolist ()
1002
1006
1003
- return video_embeds .split (sizes . tolist () )
1007
+ return video_embeds .split (sizes )
1004
1008
1005
1009
def _parse_and_validate_multimodal_inputs (self , ** kwargs : object ) -> dict :
1006
1010
mm_input_by_modality = {}
0 commit comments