@@ -94,11 +94,11 @@ def _get_inputs_gemma3(
9494 width : int ,
9595 height : int ,
9696 num_channels : int ,
97- batch_size : int = 1 ,
98- sequence_length : int = 281 ,
99- n_images : int = 1 ,
100- max_sequence_length : int = 580 ,
101- total_sequence_length : int = 860 ,
97+ batch_size : Optional [ int ] = 1 ,
98+ sequence_length : Optional [ int ] = 281 ,
99+ n_images : Optional [ int ] = 1 ,
100+ max_sequence_length : Optional [ int ] = 580 ,
101+ total_sequence_length : Optional [ int ] = 860 ,
102102 ** kwargs , # unused
103103):
104104 """
@@ -129,6 +129,12 @@ def _get_inputs_gemma3(
129129 attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
130130 position_ids:None,
131131 """
132+ batch_size = 1 if batch_size is None else batch_size
133+ sequence_length = 281 if sequence_length is None else sequence_length
134+ n_images = 1 if n_images is None else n_images
135+ max_sequence_length = 580 if max_sequence_length is None else max_sequence_length
136+ total_sequence_length = 860 if total_sequence_length is None else total_sequence_length
137+
132138 assert (
133139 "cls_cache" not in kwargs
134140 ), f"Not yet implemented for cls_cache={ kwargs ['cls_cache' ]!r} ."
@@ -224,6 +230,111 @@ def _check_():
224230 return dict (inputs = inputs , dynamic_shapes = shapes )
225231
226232
def get_inputs_default(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads: int,
    num_hidden_layers: int,
    pad_token_id: int,
    image_token_index: int,
    head_dim: int,
    width: int,
    height: int,
    num_channels: int,
    batch_size: Optional[int] = 2,
    sequence_length: Optional[int] = 43,
    n_images: Optional[int] = 2,
    max_sequence_length: Optional[int] = 43,
    total_sequence_length: Optional[int] = 43,
    add_second_input: int = 0,
    **kwargs,  # unused
) -> Dict[str, Any]:
    """
    Builds random dummy inputs and the matching dynamic shapes
    for an image-text-to-text model (generic, non model-specific path).

    :param model: the model, only its class name is inspected
        (``IdeficsForVisionText2Text`` gets a 5D ``pixel_values``)
    :param config: unused by this function
    :param dummy_max_token_id: exclusive upper bound of the random token ids
    :param num_key_value_heads: number of heads of every cached key/value
    :param num_hidden_layers: number of (key, value) pairs put in the cache
    :param pad_token_id: tokens equal to this id get a zero in the attention mask
    :param image_token_index: id written in ``input_ids`` to mark image positions
    :param head_dim: last dimension of every cached key/value
    :param width: image width
    :param height: image height
    :param num_channels: number of channels of the images
    :param batch_size: batch size, ``None`` means 2
    :param sequence_length: length of the cache, ``None`` means 43
    :param n_images: number of images, ``None`` means 2
    :param max_sequence_length: unused by this function, ``None`` means 43
    :param total_sequence_length: length of ``input_ids``, ``None`` means 43
    :param add_second_input: unused by this function
    :param kwargs: unused, but ``cls_cache`` is explicitly rejected
    :return: dictionary with two keys, ``inputs`` and ``dynamic_shapes``
    """
    batch_size = 2 if batch_size is None else batch_size
    sequence_length = 43 if sequence_length is None else sequence_length
    n_images = 2 if n_images is None else n_images
    max_sequence_length = 43 if max_sequence_length is None else max_sequence_length
    total_sequence_length = 43 if total_sequence_length is None else total_sequence_length

    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."

    is_idefics = model.__class__.__name__ == "IdeficsForVisionText2Text"

    batch = "batch"
    batch_img = torch.export.Dim("batch_img", min=1, max=1024)
    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
    images = "images"  # torch.export.Dim("images", min=1, max=4096)

    # NOTE(review): "image_attention_mask" is described here but no such input
    # is produced below, and "image_grid_thw" has n_images rows although its
    # first dimension is named after the batch -- confirm both are intended.
    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "token_type_ids": {0: batch, 1: seq_length},
        "attention_mask": {0: batch, 1: "cache+seq"},
        "position_ids": {0: batch, 1: "cache+seq"},
        "past_key_values": [
            [{0: batch} for _ in range(num_hidden_layers)],
            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
        ],
        "pixel_values": {0: batch, 1: images} if is_idefics else {0: batch_img},
        "image_attention_mask": {0: batch, 1: seq_length, 2: images},
        "image_grid_thw": {0: batch},
        "use_cache": None,
    }

    # torch.randint already returns int64.
    input_ids = torch.randint(0, dummy_max_token_id, (batch_size, total_sequence_length))
    # Image tokens are placed on the first two diagonal positions; the guards
    # keep the function usable with a single row or an empty/one-token sequence
    # instead of failing with an IndexError.
    if batch_size > 0 and total_sequence_length > 0:
        input_ids[0, 0] = image_token_index
    if batch_size > 1 and total_sequence_length > 1:
        input_ids[1, 1] = image_token_index

    token_type_ids = torch.zeros_like(input_ids)
    token_type_ids[input_ids == image_token_index] = 1

    # One row per image, columns 1 and 2 receive height and width,
    # the first row is halved, then column 0 is overwritten with the image index.
    image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
    image_grid_thw[:, 1] = height
    image_grid_thw[:, 2] = width
    if n_images > 0:
        image_grid_thw[0, :] //= 2
    image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)

    # The mask covers the cache (all ones) followed by the new tokens
    # (zero wherever the token is the padding token).
    attention_mask = torch.cat(
        [
            torch.ones((batch_size, sequence_length), dtype=torch.int64),
            input_ids.ne(pad_token_id).to(torch.int64),
        ],
        dim=-1,
    )
    position_ids = torch.arange(0, total_sequence_length, dtype=torch.int64).expand(
        (batch_size, -1)
    )

    # Random tensors are drawn in the same order as before
    # (keys/values layer by layer, then the images) so seeded runs are unchanged.
    kv_shape = (batch_size, num_key_value_heads, sequence_length, head_dim)
    past_key_values = make_dynamic_cache(
        [(torch.randn(*kv_shape), torch.randn(*kv_shape)) for _ in range(num_hidden_layers)]
    )
    if is_idefics:
        pixel_values = torch.randn((batch_size, n_images, num_channels, width, height))
    else:
        pixel_values = torch.randn(n_images, num_channels, width, height)
    pixel_values = pixel_values.clamp(-1, 1)

    inputs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        pixel_values=pixel_values,
        token_type_ids=token_type_ids,
        image_grid_thw=image_grid_thw,
        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
    )
    return dict(inputs=inputs, dynamic_shapes=shapes)
336+
337+
227338def get_inputs (
228339 model : torch .nn .Module ,
229340 config : Optional [Any ],
@@ -236,11 +347,11 @@ def get_inputs(
236347 width : int ,
237348 height : int ,
238349 num_channels : int ,
239- batch_size : int = 1 ,
240- sequence_length : int = 281 ,
241- n_images : int = 1 ,
242- max_sequence_length : int = 580 ,
243- total_sequence_length : int = 860 ,
350+ batch_size : Optional [ int ] = None ,
351+ sequence_length : Optional [ int ] = None ,
352+ n_images : Optional [ int ] = None ,
353+ max_sequence_length : Optional [ int ] = None ,
354+ total_sequence_length : Optional [ int ] = None ,
244355 add_second_input : int = 0 ,
245356 ** kwargs , # unused
246357):
@@ -290,86 +401,26 @@ def get_inputs(
290401 ** kwargs ,
291402 )
292403 else :
293- assert (
294- "cls_cache" not in kwargs
295- ), f"Not yet implemented for cls_cache={ kwargs ['cls_cache' ]!r} ."
296- batch = "batch"
297- batch_img = torch .export .Dim ("batch_img" , min = 1 , max = 1024 )
298- seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
299- cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
300- images = "images" # torch.export.Dim("images", min=1, max=4096)
301-
302- shapes = {
303- "input_ids" : {0 : batch , 1 : seq_length },
304- "token_type_ids" : {0 : batch , 1 : seq_length },
305- "attention_mask" : {0 : batch , 1 : "cache+seq" },
306- "position_ids" : {0 : batch , 1 : "cache+seq" },
307- "past_key_values" : [
308- [{0 : batch } for _ in range (num_hidden_layers )],
309- [{0 : batch , 2 : cache_length } for _ in range (num_hidden_layers )],
310- ],
311- "pixel_values" : (
312- {0 : batch , 1 : images }
313- if model .__class__ .__name__ == "IdeficsForVisionText2Text"
314- else {0 : batch_img }
315- ),
316- "image_attention_mask" : {0 : batch , 1 : seq_length , 2 : images },
317- "image_grid_thw" : {0 : batch },
318- "use_cache" : None ,
319- }
320-
321- input_ids = torch .randint (
322- 0 , dummy_max_token_id , (batch_size , total_sequence_length )
323- ).to (torch .int64 )
324- input_ids [0 , 0 ] = image_token_index
325- input_ids [1 , 1 ] = image_token_index
326- # input_ids[input_ids == image_token_index] = pad_token_id
327- token_type_ids = torch .zeros_like (input_ids )
328- token_type_ids [input_ids == image_token_index ] = 1
329- image_grid_thw = torch .zeros ((n_images , 3 ), dtype = torch .int64 )
330- image_grid_thw [:, 1 ] = height
331- image_grid_thw [:, 2 ] = width
332- image_grid_thw [0 , :] //= 2
333- image_grid_thw [:, 0 ] = torch .arange (n_images , dtype = image_grid_thw .dtype )
334-
335- inputs = dict (
336- input_ids = input_ids ,
337- attention_mask = torch .cat (
338- [
339- torch .ones ((batch_size , sequence_length ), dtype = torch .int64 ),
340- input_ids .ne (pad_token_id ).to (torch .int64 ),
341- ],
342- axis = - 1 ,
343- ),
344- position_ids = torch .arange (0 , total_sequence_length )
345- .to (torch .int64 )
346- .expand ((batch_size , - 1 )),
347- past_key_values = make_dynamic_cache (
348- [
349- (
350- torch .randn (
351- batch_size , num_key_value_heads , sequence_length , head_dim
352- ),
353- torch .randn (
354- batch_size , num_key_value_heads , sequence_length , head_dim
355- ),
356- )
357- for i in range (num_hidden_layers )
358- ]
359- ),
360- pixel_values = (
361- torch .randn ((batch_size , n_images , num_channels , width , height )).clamp (- 1 , 1 )
362- if model .__class__ .__name__ == "IdeficsForVisionText2Text"
363- else torch .randn (n_images , num_channels , width , height ).clamp (- 1 , 1 )
364- ),
365- # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
366- # torch.int64
367- # ),
368- token_type_ids = token_type_ids ,
369- image_grid_thw = image_grid_thw ,
370- use_cache = True , # Gemma3 does not set this value to true when a cache is provided
404+ res = get_inputs_default (
405+ model ,
406+ config ,
407+ dummy_max_token_id = dummy_max_token_id ,
408+ num_key_value_heads = num_key_value_heads ,
409+ num_hidden_layers = num_hidden_layers ,
410+ pad_token_id = pad_token_id ,
411+ image_token_index = image_token_index ,
412+ head_dim = head_dim ,
413+ width = width ,
414+ height = height ,
415+ num_channels = num_channels ,
416+ batch_size = batch_size ,
417+ sequence_length = sequence_length ,
418+ max_sequence_length = max_sequence_length ,
419+ total_sequence_length = total_sequence_length ,
420+ n_images = n_images ,
421+ ** kwargs ,
371422 )
372- res = dict ( inputs = inputs , dynamic_shapes = shapes )
423+
373424 if add_second_input :
374425 assert (
375426 add_second_input > 0
@@ -384,7 +435,7 @@ def get_inputs(
384435 width = width ,
385436 height = height ,
386437 num_channels = num_channels ,
387- batch_size = batch_size + 1 ,
438+ batch_size = 3 ,
388439 sequence_length = 0 ,
389440 max_sequence_length = 0 ,
390441 total_sequence_length = 0 ,
@@ -431,9 +482,6 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
431482 text_config = False
432483 check_hasattr (config .vision_config , ("num_channels" , "in_chans" , "in_channels" ))
433484 kwargs = dict (
434- sequence_length = 281 ,
435- max_sequence_length = 580 ,
436- total_sequence_length = 860 ,
437485 head_dim = (
438486 16
439487 if config is None
0 commit comments