
Commit c28b9e5: one fix
1 parent 4f52b4f

2 files changed: +177 -93 lines changed


onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 141 additions & 93 deletions
@@ -94,11 +94,11 @@ def _get_inputs_gemma3(
     width: int,
     height: int,
     num_channels: int,
-    batch_size: int = 1,
-    sequence_length: int = 281,
-    n_images: int = 1,
-    max_sequence_length: int = 580,
-    total_sequence_length: int = 860,
+    batch_size: Optional[int] = 1,
+    sequence_length: Optional[int] = 281,
+    n_images: Optional[int] = 1,
+    max_sequence_length: Optional[int] = 580,
+    total_sequence_length: Optional[int] = 860,
     **kwargs,  # unused
 ):
     """
@@ -129,6 +129,12 @@ def _get_inputs_gemma3(
     attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
     position_ids:None,
     """
+    batch_size = 1 if batch_size is None else batch_size
+    sequence_length = 281 if sequence_length is None else sequence_length
+    n_images = 1 if n_images is None else n_images
+    max_sequence_length = 580 if max_sequence_length is None else max_sequence_length
+    total_sequence_length = 860 if total_sequence_length is None else total_sequence_length
+
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
@@ -224,6 +230,111 @@ def _check_():
     return dict(inputs=inputs, dynamic_shapes=shapes)
 
 
+def get_inputs_default(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: Optional[int] = 2,
+    sequence_length: Optional[int] = 43,
+    n_images: Optional[int] = 2,
+    max_sequence_length: Optional[int] = 43,
+    total_sequence_length: Optional[int] = 43,
+    add_second_input: int = 0,
+    **kwargs,  # unused
+):
+    batch_size = 2 if batch_size is None else batch_size
+    sequence_length = 43 if sequence_length is None else sequence_length
+    n_images = 2 if n_images is None else n_images
+    max_sequence_length = 43 if max_sequence_length is None else max_sequence_length
+    total_sequence_length = 43 if total_sequence_length is None else total_sequence_length
+
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    batch_img = torch.export.Dim("batch_img", min=1, max=1024)
+    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+    images = "images"  # torch.export.Dim("images", min=1, max=4096)
+
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "token_type_ids": {0: batch, 1: seq_length},
+        "attention_mask": {0: batch, 1: "cache+seq"},
+        "position_ids": {0: batch, 1: "cache+seq"},
+        "past_key_values": [
+            [{0: batch} for _ in range(num_hidden_layers)],
+            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+        ],
+        "pixel_values": (
+            {0: batch, 1: images}
+            if model.__class__.__name__ == "IdeficsForVisionText2Text"
+            else {0: batch_img}
+        ),
+        "image_attention_mask": {0: batch, 1: seq_length, 2: images},
+        "image_grid_thw": {0: batch},
+        "use_cache": None,
+    }
+
+    input_ids = torch.randint(0, dummy_max_token_id, (batch_size, total_sequence_length)).to(
+        torch.int64
+    )
+    input_ids[0, 0] = image_token_index
+    input_ids[1, 1] = image_token_index
+    # input_ids[input_ids == image_token_index] = pad_token_id
+    token_type_ids = torch.zeros_like(input_ids)
+    token_type_ids[input_ids == image_token_index] = 1
+    image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
+    image_grid_thw[:, 1] = height
+    image_grid_thw[:, 2] = width
+    image_grid_thw[0, :] //= 2
+    image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
+
+    inputs = dict(
+        input_ids=input_ids,
+        attention_mask=torch.cat(
+            [
+                torch.ones((batch_size, sequence_length), dtype=torch.int64),
+                input_ids.ne(pad_token_id).to(torch.int64),
+            ],
+            axis=-1,
+        ),
+        position_ids=torch.arange(0, total_sequence_length)
+        .to(torch.int64)
+        .expand((batch_size, -1)),
+        past_key_values=make_dynamic_cache(
+            [
+                (
+                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
+                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
+                )
+                for i in range(num_hidden_layers)
+            ]
+        ),
+        pixel_values=(
+            torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
+            if model.__class__.__name__ == "IdeficsForVisionText2Text"
+            else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
+        ),
+        # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+        #     torch.int64
+        # ),
+        token_type_ids=token_type_ids,
+        image_grid_thw=image_grid_thw,
+        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    return res
+
+
 def get_inputs(
     model: torch.nn.Module,
     config: Optional[Any],
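
The new `get_inputs_default` returns the same `dict(inputs=..., dynamic_shapes=...)` pair as the other input builders in this module. A sketch of how such a pair might be consumed with `torch.export` (an assumption for illustration; the commit itself only builds the dictionary, and the shapes are filtered here because they may list inputs, such as `image_attention_mask`, that a given model never receives):

```python
import torch

# `model` and the size arguments are assumed to come from the caller/config.
res = get_inputs_default(
    model,
    config=None,
    dummy_max_token_id=1024,
    num_key_value_heads=4,
    num_hidden_layers=2,
    pad_token_id=0,
    image_token_index=99,
    head_dim=16,
    width=224,
    height=224,
    num_channels=3,
)

# Keep only the dynamic shapes for inputs that are actually present.
shapes = {k: v for k, v in res["dynamic_shapes"].items() if k in res["inputs"]}

# Export with named inputs and per-input dynamic dimensions.
ep = torch.export.export(model, (), kwargs=res["inputs"], dynamic_shapes=shapes)
```
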
@@ -236,11 +347,11 @@ def get_inputs(
     width: int,
     height: int,
     num_channels: int,
-    batch_size: int = 1,
-    sequence_length: int = 281,
-    n_images: int = 1,
-    max_sequence_length: int = 580,
-    total_sequence_length: int = 860,
+    batch_size: Optional[int] = None,
+    sequence_length: Optional[int] = None,
+    n_images: Optional[int] = None,
+    max_sequence_length: Optional[int] = None,
+    total_sequence_length: Optional[int] = None,
     add_second_input: int = 0,
     **kwargs,  # unused
 ):
@@ -290,86 +401,26 @@ def get_inputs(
             **kwargs,
         )
     else:
-        assert (
-            "cls_cache" not in kwargs
-        ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
-        batch = "batch"
-        batch_img = torch.export.Dim("batch_img", min=1, max=1024)
-        seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-        cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
-        images = "images"  # torch.export.Dim("images", min=1, max=4096)
-
-        shapes = {
-            "input_ids": {0: batch, 1: seq_length},
-            "token_type_ids": {0: batch, 1: seq_length},
-            "attention_mask": {0: batch, 1: "cache+seq"},
-            "position_ids": {0: batch, 1: "cache+seq"},
-            "past_key_values": [
-                [{0: batch} for _ in range(num_hidden_layers)],
-                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            ],
-            "pixel_values": (
-                {0: batch, 1: images}
-                if model.__class__.__name__ == "IdeficsForVisionText2Text"
-                else {0: batch_img}
-            ),
-            "image_attention_mask": {0: batch, 1: seq_length, 2: images},
-            "image_grid_thw": {0: batch},
-            "use_cache": None,
-        }
-
-        input_ids = torch.randint(
-            0, dummy_max_token_id, (batch_size, total_sequence_length)
-        ).to(torch.int64)
-        input_ids[0, 0] = image_token_index
-        input_ids[1, 1] = image_token_index
-        # input_ids[input_ids == image_token_index] = pad_token_id
-        token_type_ids = torch.zeros_like(input_ids)
-        token_type_ids[input_ids == image_token_index] = 1
-        image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
-        image_grid_thw[:, 1] = height
-        image_grid_thw[:, 2] = width
-        image_grid_thw[0, :] //= 2
-        image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
-
-        inputs = dict(
-            input_ids=input_ids,
-            attention_mask=torch.cat(
-                [
-                    torch.ones((batch_size, sequence_length), dtype=torch.int64),
-                    input_ids.ne(pad_token_id).to(torch.int64),
-                ],
-                axis=-1,
-            ),
-            position_ids=torch.arange(0, total_sequence_length)
-            .to(torch.int64)
-            .expand((batch_size, -1)),
-            past_key_values=make_dynamic_cache(
-                [
-                    (
-                        torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
-                        ),
-                        torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
-                        ),
-                    )
-                    for i in range(num_hidden_layers)
-                ]
-            ),
-            pixel_values=(
-                torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
-                if model.__class__.__name__ == "IdeficsForVisionText2Text"
-                else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
-            ),
-            # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
-            #     torch.int64
-            # ),
-            token_type_ids=token_type_ids,
-            image_grid_thw=image_grid_thw,
-            use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
+        res = get_inputs_default(
+            model,
+            config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
+            n_images=n_images,
+            **kwargs,
         )
-        res = dict(inputs=inputs, dynamic_shapes=shapes)
+
     if add_second_input:
         assert (
             add_second_input > 0
@@ -384,7 +435,7 @@ def get_inputs(
             width=width,
             height=height,
             num_channels=num_channels,
-            batch_size=batch_size + 1,
+            batch_size=3,
             sequence_length=0,
             max_sequence_length=0,
             total_sequence_length=0,
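
The switch from `batch_size + 1` to a fixed `3` follows from the signature change above: `batch_size` in `get_inputs` now defaults to `None`, and incrementing `None` raises a `TypeError`. A tiny illustration (not from the commit):

```python
batch_size = None  # the new default in get_inputs

try:
    second_batch = batch_size + 1   # the old expression
except TypeError:
    second_batch = 3                # the new fixed value

# A constant still differs from the default batch sizes, which is what the
# second input set needs in order to exercise the dynamic batch dimension.
assert second_batch == 3
```
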
@@ -431,9 +482,6 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         text_config = False
         check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
-        sequence_length=281,
-        max_sequence_length=580,
-        total_sequence_length=860,
         head_dim=(
             16
             if config is None

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

Lines changed: 36 additions & 0 deletions
@@ -4829,3 +4829,39 @@ def _ccached_microsoft_phi3_mini_128k_instruct():
             "vocab_size": 32064,
         }
     )
+
+
+def _ccached_google_gemma_3_4b_it_like():
+    "google/gemma-3-4b-it"
+    return transformers.Gemma3Config(
+        **{
+            "architectures": ["Gemma3ForConditionalGeneration"],
+            "boi_token_index": 255999,
+            "eoi_token_index": 256000,
+            "eos_token_id": [1, 106],
+            "image_token_index": 262144,
+            "initializer_range": 0.02,
+            "mm_tokens_per_image": 256,
+            "model_type": "gemma3",
+            "text_config": {
+                "hidden_size": 2560,
+                "intermediate_size": 10240,
+                "model_type": "gemma3_text",
+                "num_hidden_layers": 34,
+                "rope_scaling": {"factor": 8.0, "rope_type": "linear"},
+                "sliding_window": 1024,
+            },
+            "torch_dtype": "bfloat16",
+            "transformers_version": "4.50.0.dev0",
+            "vision_config": {
+                "hidden_size": 1152,
+                "image_size": 896,
+                "intermediate_size": 4304,
+                "model_type": "siglip_vision_model",
+                "num_attention_heads": 16,
+                "num_hidden_layers": 27,
+                "patch_size": 14,
+                "vision_use_head": False,
+            },
+        }
+    )
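
Like the other `_ccached_*` helpers in this file, the new entry rebuilds a realistic configuration locally, so code that only needs the config (dummy inputs, shape inference, and so on) does not have to download `google/gemma-3-4b-it` from the Hub. A short usage sketch (an assumption about how the helper is consumed; the attribute reads match the fields defined above):

```python
from onnx_diagnostic.torch_models.hghub.hub_data_cached_configs import (
    _ccached_google_gemma_3_4b_it_like,
)

cfg = _ccached_google_gemma_3_4b_it_like()
print(cfg.model_type)                     # "gemma3"
print(cfg.image_token_index)              # 262144
print(cfg.text_config.num_hidden_layers)  # 34
print(cfg.vision_config.image_size)       # 896
```
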
