@@ -256,6 +256,7 @@ def get_inputs_default(
     max_sequence_length = 43 if max_sequence_length is None else max_sequence_length
     total_sequence_length = 43 if total_sequence_length is None else total_sequence_length

+    assert batch_size > 0, "batch_size must be positive"
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
@@ -287,19 +288,22 @@ def get_inputs_default(
     input_ids = torch.randint(0, dummy_max_token_id, (batch_size, total_sequence_length)).to(
         torch.int64
     )
-    input_ids[0, 0] = image_token_index
-    input_ids[1, 1] = image_token_index
+    if total_sequence_length > 0:
+        input_ids[0, 0] = image_token_index
+        input_ids[1, 1] = image_token_index
     # input_ids[input_ids == image_token_index] = pad_token_id
     token_type_ids = torch.zeros_like(input_ids)
     token_type_ids[input_ids == image_token_index] = 1
     image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
-    image_grid_thw[:, 1] = height
-    image_grid_thw[:, 2] = width
-    image_grid_thw[0, :] //= 2
-    image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
+    if n_images > 0:
+        image_grid_thw[:, 1] = height
+        image_grid_thw[:, 2] = width
+        image_grid_thw[0, :] //= 2
+        image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)

     inputs = dict(
         input_ids=input_ids,
+        token_type_ids=token_type_ids,
         attention_mask=torch.cat(
             [
                 torch.ones((batch_size, sequence_length), dtype=torch.int64),
@@ -324,10 +328,9 @@ def get_inputs_default(
             if model.__class__.__name__ == "IdeficsForVisionText2Text"
             else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
         ),
-        # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
-        #     torch.int64
-        # ),
-        token_type_ids=token_type_ids,
+        image_attention_mask=torch.ones((batch_size, total_sequence_length, n_images)).to(
+            torch.int64
+        ),
         image_grid_thw=image_grid_thw,
         use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
     )
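
For context, the sketch below reproduces just the guarded parts of the construction above as a standalone script. It is illustrative only: the helper name build_dummy_inputs and its default values are not from the repository, and the real get_inputs_default builds additional inputs such as attention_mask and pixel_values. It shows what the new `if total_sequence_length > 0` and `if n_images > 0` branches protect against: indexed writes like input_ids[0, 0] and image_grid_thw[0, :] raise IndexError on an empty dimension, while the slice assignments would merely be no-ops.

import torch


def build_dummy_inputs(
    batch_size=2,
    total_sequence_length=43,
    n_images=2,
    height=24,
    width=32,
    image_token_index=99,
    dummy_max_token_id=31999,
):
    # Mirrors only the guarded parts of the diff; names and defaults are illustrative.
    assert batch_size > 1, "input_ids[1, 1] below assumes at least two rows"
    input_ids = torch.randint(
        0, dummy_max_token_id, (batch_size, total_sequence_length)
    ).to(torch.int64)
    if total_sequence_length > 0:
        # Unguarded, these writes raise IndexError when the sequence dimension is
        # empty (and input_ids[1, 1] further assumes total_sequence_length >= 2).
        input_ids[0, 0] = image_token_index
        input_ids[1, 1] = image_token_index
    token_type_ids = torch.zeros_like(input_ids)
    token_type_ids[input_ids == image_token_index] = 1
    image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
    if n_images > 0:
        # image_grid_thw[0, :] indexes row 0, which does not exist when n_images == 0.
        image_grid_thw[:, 1] = height
        image_grid_thw[:, 2] = width
        image_grid_thw[0, :] //= 2
        image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
    return dict(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        image_attention_mask=torch.ones(
            (batch_size, total_sequence_length, n_images)
        ).to(torch.int64),
        image_grid_thw=image_grid_thw,
    )


# Degenerate shapes that previously crashed now yield empty tensors:
inputs = build_dummy_inputs(total_sequence_length=0, n_images=0)
print(inputs["input_ids"].shape)       # torch.Size([2, 0])
print(inputs["image_grid_thw"].shape)  # torch.Size([0, 3])

Note that the added `assert batch_size > 0` does not cover the input_ids[1, 1] write, which needs batch_size >= 2; the sketch asserts that stronger condition explicitly.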