@@ -144,7 +144,7 @@ def _get_inputs_gemma3(
144144 "sliding_attention" : {0 : batch , 2 : seq_length , 3 : tot_length },
145145 },
146146 "position_ids" : {0 : batch , 1 : seq_length },
147- "cache_position" : {1 : seq_length },
147+ "cache_position" : {0 : seq_length },
148148 "past_key_values" : [
149149 [{0 : batch } for _ in range (num_hidden_layers )],
150150 [{0 : batch } for _ in range (num_hidden_layers )],
@@ -159,31 +159,37 @@ def _get_inputs_gemma3(
159159 dummies = dummies [("" , 0 , "I" )][1 ]
160160 dummies = {k : v for k , v in dummies .items () if k in shapes }
161161 expected = {"input_ids" , "token_type_ids" , "position_ids" , "cache_position" }
162- assert expected & set (
163- dummies
164- ), f"Unable to find expected inputs { expected } in loaded inputs { set (dummies )} "
165- assert sequence_length == dummies ["input_ids" ].shape [- 1 ], (
166- f"sequence_length={ sequence_length } != { dummies ['input_ids' ].shape [- 1 ]} for "
167- f"model class { model .__class__ .__name__ } "
168- )
169- assert batch_size == dummies ["input_ids" ].shape [0 ], (
170- f"batch_size={ batch_size } != { dummies ['input_ids' ].shape [0 ]} for "
171- f"model class { model .__class__ .__name__ } "
172- )
173- assert max_sequence_length == 580 , (
174- f"max_sequence_length={ max_sequence_length } != 580 "
175- f"for model { model .__class__ .__name__ } "
176- )
177- assert total_sequence_length == 860 , (
178- f"total_sequence_length={ total_sequence_length } != 860 "
179- f"for model { model .__class__ .__name__ } "
180- )
181- assert head_dim == 256 , f"head_dim={ head_dim } != 256 for model { model .__class__ .__name__ } "
182- assert n_images == 1 , f"n_images={ n_images } != 1 for model { model .__class__ .__name__ } "
183- assert num_key_value_heads == 4 , (
184- f"num_key_value_heads={ num_key_value_heads } != 256 "
185- f"for this model { model .__class__ .__name__ } "
186- )
162+
163+ def _check_ ():
164+ assert expected & set (
165+ dummies
166+ ), f"Unable to find expected inputs { expected } in loaded inputs { set (dummies )} "
167+ assert sequence_length == dummies ["input_ids" ].shape [- 1 ], (
168+ f"sequence_length={ sequence_length } != { dummies ['input_ids' ].shape [- 1 ]} for "
169+ f"model class { model .__class__ .__name__ } "
170+ )
171+ assert batch_size == dummies ["input_ids" ].shape [0 ], (
172+ f"batch_size={ batch_size } != { dummies ['input_ids' ].shape [0 ]} for "
173+ f"model class { model .__class__ .__name__ } "
174+ )
175+ assert max_sequence_length == 580 , (
176+ f"max_sequence_length={ max_sequence_length } != 580 "
177+ f"for model { model .__class__ .__name__ } "
178+ )
179+ assert total_sequence_length == 860 , (
180+ f"total_sequence_length={ total_sequence_length } != 860 "
181+ f"for model { model .__class__ .__name__ } "
182+ )
183+ assert (
184+ head_dim == 256
185+ ), f"head_dim={ head_dim } != 256 for model { model .__class__ .__name__ } "
186+ assert n_images == 1 , f"n_images={ n_images } != 1 for model { model .__class__ .__name__ } "
187+ assert num_key_value_heads == 4 , (
188+ f"num_key_value_heads={ num_key_value_heads } != 4 "
189+ f"for this model { model .__class__ .__name__ } "
190+ )
191+
192+ _check_ ()
187193
188194 inputs = dict (
189195 input_ids = dummies ["input_ids" ],
0 commit comments