     _pick,
     default_num_hidden_layers as nhl,
 )
-from ..helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
 from .data import get_data
 
 __TASK__ = "image-text-to-text"
@@ -95,37 +94,15 @@ def _get_inputs_gemma3(
     width: int,
     height: int,
     num_channels: int,
-    batch_size: int = 2,
-    sequence_length: int = 43,
-    sequence_length2: int = 43,
-    n_images: int = 2,
-    dynamic_rope: bool = False,
-    max_sequence_length: int = 380,
+    batch_size: int = 1,
+    sequence_length: int = 281,
+    n_images: int = 1,
+    max_sequence_length: int = 580,
+    total_sequence_length: int = 860,
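+    # the defaults above mirror the recorded dummies; the asserts below reject other values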
     **kwargs,  # unused
 ):
     """
-    ::
-
-        dict(input_ids:T7s1x281,
-            pixel_values:T16s1x3x896x896,
-            attention_mask:dict(full_attention:T9s1x1x281x380,sliding_attention:T9s1x1x281x380),
-            position_ids:T7s1x281,
-            past_key_values:HybridCache(
-                key_cache=#34[T1s1x4x380x256,...],
-                value_cache=#34[T1s1x4x380x256,...]),
-            token_type_ids:T7s1x281,
-            cache_position:T7s281,
-            logits_to_keep:1)
-        dict(input_ids:T7s1x1,
-            pixel_values:None,
-            attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
-            position_ids:T7s1x1,
-            past_key_values:HybridCache(
-                key_cache=#34[T1s1x4x380x256,...],
-                value_cache=#34[T1s1x4x380x256,...]),
-            token_type_ids:T7s1x1,
-            cache_position:T7s1,
-            logits_to_keep:1)
+    The function uses predefined values for input_ids and token_type_ids.
 
     **google/gemma-3-4b-it**
 
@@ -151,21 +128,20 @@ def _get_inputs_gemma3(
         token_type_ids:T7s1x1,
         attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
         position_ids:None,
-        use_cache:bool,logits_to_keep:None,return_dict:bool)
-
     """
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = "batch"
     seq_length = "seq_length"
+    tot_length = "total_length"
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
         "token_type_ids": {0: batch, 1: seq_length},
         "attention_mask": {
-            "full_attention": {0: batch, 2: seq_length},
-            "sliding_attention": {0: batch, 2: seq_length},
+            "full_attention": {0: batch, 2: seq_length, 3: tot_length},
+            "sliding_attention": {0: batch, 2: seq_length, 3: tot_length},
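+            # the last dimension of both masks is now dynamic as well (tot_length)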
         },
         "position_ids": {0: batch, 1: seq_length},
         "cache_position": {1: seq_length},
@@ -177,22 +153,46 @@ def _get_inputs_gemma3(
         "use_cache": None,
     }
 
-    # first iteration
-    dummies = create_input_tensors_from_onnx_model(
-        get_data("dummies_imagetext2text_generation_gemma3.onnx")
-    )
+    # retrieve specific inputs to keep the consistency between
+    # ids and images
+    dummies = get_data("dummies_imagetext2text_generation_gemma3.onnx")
+    dummies = dummies[("", 0, "I")][1]
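+    # the key ("", 0, "I") selects one recorded input set from the stored file;
+    # what the key encodes is an assumption based on how the dummies were serialized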
     dummies = {k: v for k, v in dummies.items() if k in shapes}
     expected = {"input_ids", "token_type_ids", "position_ids", "cache_position"}
     assert expected & set(
         dummies
     ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
+    assert sequence_length == dummies["input_ids"].shape[-1], (
+        f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
+        f"model class {model.__class__.__name__}"
+    )
+    assert batch_size == dummies["input_ids"].shape[0], (
+        f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
+        f"model class {model.__class__.__name__}"
+    )
+    assert max_sequence_length == 580, (
+        f"max_sequence_length={max_sequence_length} != 580 "
+        f"for model {model.__class__.__name__}"
+    )
+    assert total_sequence_length == 860, (
+        f"total_sequence_length={total_sequence_length} != 860 "
+        f"for model {model.__class__.__name__}"
+    )
+    assert head_dim == 256, f"head_dim={head_dim} != 256 for model {model.__class__.__name__}"
+    assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
+    assert num_key_value_heads == 4, (
+        f"num_key_value_heads={num_key_value_heads} != 4 "
+        f"for model {model.__class__.__name__}"
+    )
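+    # the hard-coded dimensions (580, 860, 256, 4) match the recorded dummies,
+    # presumably produced with google/gemma-3-4b-it (see the docstring above)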
 
     inputs = dict(
-        input_ids=input_ids,
-        token_type_ids=token_type_ids,
+        input_ids=dummies["input_ids"],
+        token_type_ids=dummies["token_type_ids"],
         attention_mask=dict(
-            full_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
-            sliding_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
+            full_attention=torch.randn(batch_size, 1, sequence_length, total_sequence_length),
+            sliding_attention=torch.randn(
+                batch_size, 1, sequence_length, total_sequence_length
+            ),
         ),
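+        # the random masks above span total_sequence_length, the dimension
+        # declared as ``tot_length`` in the dynamic shapes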
         cache_position=torch.arange(0, sequence_length).to(torch.int64),
         position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
@@ -210,9 +210,9 @@ def _get_inputs_gemma3(
             ]
         ),
         pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
-        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
-            torch.int64
-        ),
+        # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+        #     torch.int64
+        # ),
         use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
     )
     return dict(inputs=inputs, dynamic_shapes=shapes)
@@ -230,12 +230,12 @@ def get_inputs(
     width: int,
     height: int,
     num_channels: int,
-    batch_size: int = 2,
-    sequence_length: int = 43,
-    sequence_length2: int = 43,
-    n_images: int = 2,
-    dynamic_rope: bool = False,
-    add_second_input: int = 1,
+    batch_size: int = 1,
+    sequence_length: int = 281,
+    n_images: int = 1,
+    max_sequence_length: int = 580,
+    total_sequence_length: int = 860,
+    add_second_input: int = 0,
     **kwargs,  # unused
 ):
     """
@@ -249,13 +249,19 @@ def get_inputs(
     :param image_token_index: image_token_index
     :param batch_size: batch size
     :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
+    :param max_sequence_length: maximum sequence length stored in the cache
+    :param total_sequence_length: length of the last dimension of the attention masks
     :param n_images: number of images
     :param width: width of the image
     :param height: height of the image
     :param num_channels: number of channels
-    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
+
+    .. note::
+
+        The content and shape of ``input_ids`` are correlated with the images.
+        The function uses predefined values and raises an exception
+        if the dimensions are not the expected ones.
     """
     if model.__class__.__name__.startswith("Gemma3"):
         res = _get_inputs_gemma3(
@@ -272,9 +278,9 @@ def get_inputs(
             num_channels=num_channels,
             batch_size=batch_size,
             sequence_length=sequence_length,
-            sequence_length2=sequence_length2,
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
             n_images=n_images,
-            dynamic_rope=dynamic_rope,
             **kwargs,
         )
     else:
@@ -306,9 +312,9 @@ def get_inputs(
             "use_cache": None,
         }
 
-        input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
-        )
+        input_ids = torch.randint(
+            0, dummy_max_token_id, (batch_size, total_sequence_length)
+        ).to(torch.int64)
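+        # non-Gemma3 path: random token ids drawn over the full total_sequence_length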
         input_ids[0, 0] = image_token_index
         input_ids[1, 1] = image_token_index
         # input_ids[input_ids == image_token_index] = pad_token_id
@@ -329,7 +335,7 @@ def get_inputs(
                 ],
                 axis=-1,
             ),
-            position_ids=torch.arange(0, sequence_length2)
+            position_ids=torch.arange(0, total_sequence_length)
             .to(torch.int64)
             .expand((batch_size, -1)),
             past_key_values=make_dynamic_cache(
@@ -350,9 +356,9 @@ def get_inputs(
                 if model.__class__.__name__ == "IdeficsForVisionText2Text"
                 else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
             ),
-            image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
-                torch.int64
-            ),
+            # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+            #     torch.int64
+            # ),
             token_type_ids=token_type_ids,
             image_grid_thw=image_grid_thw,
             use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
@@ -373,10 +379,10 @@ def get_inputs(
             height=height,
             num_channels=num_channels,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length + add_second_input,
-            sequence_length2=sequence_length2 + 1,
-            n_images=n_images + 1,
-            dynamic_rope=dynamic_rope,
+            sequence_length=0,
+            max_sequence_length=0,
+            total_sequence_length=0,
+            n_images=0,
             pad_token_id=pad_token_id,
             image_token_index=image_token_index,
             add_second_input=0,
@@ -419,9 +425,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         text_config = False
         check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
-        batch_size=2,
-        sequence_length=43,
-        sequence_length2=43,
+        sequence_length=281,
+        max_sequence_length=580,
+        total_sequence_length=860,
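+        # these lengths mirror the recorded Gemma3 dummies used by _get_inputs_gemma3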
         head_dim=(
             16
             if config is None