Skip to content

Commit de59acd

Browse files
committed
used inputs
1 parent e090445 commit de59acd

File tree

2 files changed

+41
-10
lines changed

2 files changed

+41
-10
lines changed

_unittests/ut_tasks/try_tasks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,8 @@ def test_imagetext2text_generation_gemma3_4b_it(self):
849849
# inputs.pop("token_type_ids", None)
850850
print(f"-- inputs={self.string_type(inputs)}")
851851

852-
# iteration 1
852+
# iteration merge = sequence > 1, cache not empty
853+
# iteration 1 = sequence > 1, no cache
853854
# cache_position:T7s281,
854855
# past_key_values:StaticCache(key_cache=#0[], value_cache=#0[]),
855856
# input_ids:T7s1x281,
@@ -862,7 +863,7 @@ def test_imagetext2text_generation_gemma3_4b_it(self):
862863
# logits_to_keep:None,
863864
# pixel_values:T16s1x3x896x896,
864865
# return_dict:bool)
865-
# iteration 3
866+
# iteration 2 = sequence = 1, cache not empty
866867
# cache_position:T7s1,
867868
# past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
868869
# value_cache=#34[T1s1x4x580x256,...]),

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
_pick,
88
default_num_hidden_layers as nhl,
99
)
10+
from ..helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
11+
from .data import get_data
1012

1113
__TASK__ = "image-text-to-text"
1214

@@ -124,13 +126,39 @@ def _get_inputs_gemma3(
124126
token_type_ids:T7s1x1,
125127
cache_position:T7s1,
126128
logits_to_keep:1)
129+
130+
**google/gemma-3-4b-it**
131+
132+
iteration 1
133+
134+
::
135+
cache_position:T7s281,
136+
input_ids:T7s1x281,
137+
token_type_ids:T7s1x281,
138+
attention_mask:dict(sliding_attention:T9s1x1x281x580,
139+
full_attention:T9s1x1x281x580),
140+
pixel_values:T16s1x3x896x896,
141+
142+
iteration 2
143+
144+
::
145+
146+
cache_position:T7s1,
147+
past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
148+
value_cache=#34[T1s1x4x580x256,...]),
149+
input_ids:T7s1x1,
150+
inputs_embeds:None,
151+
token_type_ids:T7s1x1,
152+
attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
153+
position_ids:None,
154+
use_cache:bool,logits_to_keep:None,return_dict:bool)
155+
127156
"""
128157
assert (
129158
"cls_cache" not in kwargs
130159
), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
131160
batch = "batch"
132-
seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
133-
# cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096)
161+
seq_length = "seq_length"
134162

135163
shapes = {
136164
"input_ids": {0: batch, 1: seq_length},
@@ -149,13 +177,15 @@ def _get_inputs_gemma3(
149177
"use_cache": None,
150178
}
151179

152-
input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
153-
torch.int64
180+
# first iteration
181+
dummies = create_input_tensors_from_onnx_model(
182+
get_data("dummies_imagetext2text_generation_gemma3.onnx")
154183
)
155-
input_ids[:, 1] = image_token_index
156-
# input_ids[input_ids == image_token_index] = pad_token_id
157-
token_type_ids = torch.zeros_like(input_ids)
158-
token_type_ids[input_ids == image_token_index] = 1
184+
dummies = {k: v for k, v in dummies.items() if k in shapes}
185+
expected = {"input_ids", "token_type_ids", "position_ids", "cache_position"}
186+
assert expected & set(
187+
dummies
188+
), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
159189

160190
inputs = dict(
161191
input_ids=input_ids,

0 commit comments

Comments
 (0)