77 _pick ,
88 default_num_hidden_layers as nhl ,
99)
10+ from ..helpers .mini_onnx_builder import create_input_tensors_from_onnx_model
11+ from .data import get_data
1012
1113__TASK__ = "image-text-to-text"
1214
@@ -124,13 +126,39 @@ def _get_inputs_gemma3(
124126 token_type_ids:T7s1x1,
125127 cache_position:T7s1,
126128 logits_to_keep:1)
129+
130+ **google/gemma-3-4b-it**
131+
132+ iteration 1
133+
134+ ::
135+ cache_position:T7s281,
136+ input_ids:T7s1x281,
137+ token_type_ids:T7s1x281,
138+ attention_mask:dict(sliding_attention:T9s1x1x281x580,
139+ full_attention:T9s1x1x281x580),
140+ pixel_values:T16s1x3x896x896,
141+
142+ iteration 2
143+
144+ ::
145+
146+ cache_position:T7s1,
147+ past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
148+ value_cache=#34[T1s1x4x580x256,...]),
149+ input_ids:T7s1x1,
150+ inputs_embeds:None,
151+ token_type_ids:T7s1x1,
152+ attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
153+ position_ids:None,
154+ use_cache:bool,logits_to_keep:None,return_dict:bool)
155+
127156 """
128157 assert (
129158 "cls_cache" not in kwargs
130159 ), f"Not yet implemented for cls_cache={ kwargs ['cls_cache' ]!r} ."
131160 batch = "batch"
132- seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
133- # cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
161+ seq_length = "seq_length"
134162
135163 shapes = {
136164 "input_ids" : {0 : batch , 1 : seq_length },
@@ -149,13 +177,15 @@ def _get_inputs_gemma3(
149177 "use_cache" : None ,
150178 }
151179
152- input_ids = torch .randint (0 , dummy_max_token_id , (batch_size , sequence_length2 )).to (
153- torch .int64
180+ # first iteration
181+ dummies = create_input_tensors_from_onnx_model (
182+ get_data ("dummies_imagetext2text_generation_gemma3.onnx" )
154183 )
155- input_ids [:, 1 ] = image_token_index
156- # input_ids[input_ids == image_token_index] = pad_token_id
157- token_type_ids = torch .zeros_like (input_ids )
158- token_type_ids [input_ids == image_token_index ] = 1
184+ dummies = {k : v for k , v in dummies .items () if k in shapes }
185+ expected = {"input_ids" , "token_type_ids" , "position_ids" , "cache_position" }
186+ assert expected & set (
187+ dummies
188+ ), f"Unable to find expected inputs { expected } in loaded inputs { set (dummies )} "
159189
160190 inputs = dict (
161191 input_ids = input_ids ,
0 commit comments