@@ -1047,7 +1047,12 @@ def get_specializations(
10471047 max_pixels = mm_processor_kwargs .get ("max_pixels" , max_pixels )
10481048
10491049 vision = []
1050- min_vision_size = ctx_len
1050+ min_vision_size = None
1051+ user_vision_size = compiler_options .pop ("vision_size" , None )
1052+ if user_vision_size :
1053+ assert user_vision_size < ctx_len , "vision_size must be less than ctx_len"
1054+ else :
1055+ min_vision_size = ctx_len
10511056 for h , w in zip (height , width ):
10521057 resized_height , resized_width = smart_resize (
10531058 height = h , width = w , factor = IMAGE_FACTOR , min_pixels = min_pixels , max_pixels = max_pixels
@@ -1057,7 +1062,8 @@ def get_specializations(
10571062 grid_width = patch_size * patch_size * temporal_patch_size * channel
10581063 vision_size = grid_height // 4
10591064 grid_height = grid_height * batch_size
1060- min_vision_size = min (min_vision_size , vision_size * num_frames )
1065+ if not user_vision_size :
1066+ min_vision_size = min (min_vision_size , vision_size * num_frames )
10611067
10621068 vision .append (
10631069 {
@@ -1078,7 +1084,7 @@ def get_specializations(
10781084 "batch_size" : 1 if continuous_batching else batch_size ,
10791085 "seq_len" : prefill_seq_len ,
10801086 "ctx_len" : ctx_len ,
1081- "vision_size" : min_vision_size ,
1087+ "vision_size" : min_vision_size if not user_vision_size else user_vision_size ,
10821088 "comp_ctx_lengths" : comp_ctx_lengths_prefill [i ],
10831089 "vision_batch_size" : batch_size ,
10841090 }
@@ -1097,7 +1103,7 @@ def get_specializations(
10971103 "batch_size" : full_batch_size if continuous_batching else batch_size ,
10981104 "seq_len" : "1" ,
10991105 "ctx_len" : ctx_len ,
1100- "vision_size" : min_vision_size ,
1106+ "vision_size" : min_vision_size if not user_vision_size else user_vision_size ,
11011107 "comp_ctx_lengths" : comp_ctx_lengths_decode [i ],
11021108 "vision_batch_size" : batch_size ,
11031109 }
@@ -1113,7 +1119,7 @@ def get_specializations(
11131119 "batch_size" : 1 if continuous_batching else batch_size ,
11141120 "seq_len" : prefill_seq_len ,
11151121 "ctx_len" : ctx_len ,
1116- "vision_size" : min_vision_size ,
1122+ "vision_size" : min_vision_size if not user_vision_size else user_vision_size ,
11171123 "vision_batch_size" : batch_size ,
11181124 }
11191125
@@ -1128,7 +1134,7 @@ def get_specializations(
11281134 "batch_size" : full_batch_size if continuous_batching else batch_size ,
11291135 "seq_len" : 1 ,
11301136 "ctx_len" : ctx_len ,
1131- "vision_size" : min_vision_size ,
1137+ "vision_size" : min_vision_size if not user_vision_size else user_vision_size ,
11321138 "vision_batch_size" : batch_size ,
11331139 }
11341140
0 commit comments