Commit 92dbf02

gemma
1 parent 37c01b9 commit 92dbf02

File tree: 4 files changed, +111 -71 lines changed


CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.7.13
 ++++++
 
+* :pr:`237`: dummy inputs for gemma-3-4b-it
 * :pr:`244`: add a patch to bypass the exception raised when the dynamic dimension is in {0,1}
 
 0.7.12

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 28 additions & 2 deletions
@@ -31,14 +31,13 @@ def test_image_text_to_text_idefics(self):
     @hide_stdout()
     @requires_transformers("4.57.99")
     @requires_torch("2.7.99")
-    def test_image_text_to_text_gemma3(self):
+    def test_image_text_to_text_tiny_gemma3(self):
         """
         If the model fails because of
         ``if inputs_embeds[special_image_mask].numel() != image_features.numel():``,
         make sure this PR was merged:
         https://github.com/huggingface/transformers/pull/39962.
         """
-        # mid = "google/gemma-3-4b-it"
         mid = "tiny-random/gemma-3"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "image-text-to-text")
@@ -52,6 +51,33 @@ def test_image_text_to_text_gemma3(self):
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
 
+    @hide_stdout()
+    @requires_transformers("4.56.2")
+    @requires_torch("2.7.99")
+    def test_image_text_to_text_gemma3_4b_it(self):
+        mid = "google/gemma-3-4b-it"
+        data = get_untrained_model_with_inputs(
+            mid,
+            verbose=1,
+            add_second_input=False,
+            # inputs_kwargs={
+            #     "sequence_length": 281,
+            #     "batch_size": 1,
+            #     "max_sequence_length": 580,
+            #     "n_images": 1,
+            # },
+        )
+        self.assertEqual(data["task"], "image-text-to-text")
+        # self.assertIn((data["size"], data["n_weights"]), [(17248576, 4312144)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        # inputs.pop("attention_mask")
+        # ds.pop("attention_mask")
+        model(**torch_deepcopy(inputs))
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
     @hide_stdout()
     @requires_transformers("4.57.99")
     @requires_torch("2.7.99")
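As a side note (not part of the commit), once `torch.export.export` succeeds in the new `test_image_text_to_text_gemma3_4b_it`, the resulting `ExportedProgram` can be inspected and re-run. A minimal sketch, assuming `model`, `inputs` and `ds` are built exactly as in the test above:

    # minimal sketch, assuming model/inputs/ds come from get_untrained_model_with_inputs
    # as in test_image_text_to_text_gemma3_4b_it above
    with torch_export_patches(patch_transformers=True):
        ep = torch.export.export(
            model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
        )
    print(ep.graph)  # the captured FX graph
    ep.module()(**torch_deepcopy(inputs))  # the exported program runs like a regular module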

_unittests/ut_tasks/try_tasks.py

Lines changed: 8 additions & 1 deletion
@@ -845,8 +845,14 @@ def test_imagetext2text_generation_gemma3_4b_it(self):
         data = get_untrained_model_with_inputs(
             model_id,
             verbose=1,
-            add_second_input=True,
+            add_second_input=False,
             # same_as_pretrained=True, #use_pretrained=True
+            inputs_kwargs={
+                "sequence_length": 281,
+                "batch_size": 1,
+                "max_sequence_length": 580,
+                "n_images": 1,
+            },
         )
         model = data["model"]
 
@@ -921,6 +927,7 @@ def test_imagetext2text_generation_gemma3_4b_it(self):
         ):
             generated_ids = model.generate(
                 **inputs,
+                # 282 = value high enough to trigger multiple iterations of the model
                 max_new_tokens=282,
                 do_sample=False,
                 cache_implementation="static",
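A possible follow-up (an assumption, not shown in this diff): the ids returned by `model.generate` would typically be decoded with the processor matching `model_id`:

    # hypothetical decoding step, assuming `processor` is the AutoProcessor for model_id
    # and `inputs`/`generated_ids` come from the snippet above
    from transformers import AutoProcessor

    processor = AutoProcessor.from_pretrained(model_id)
    new_tokens = generated_ids[:, inputs["input_ids"].shape[1]:]  # drop the prompt tokens
    print(processor.batch_decode(new_tokens, skip_special_tokens=True))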

onnx_diagnostic/tasks/image_text_to_text.py

Lines changed: 74 additions & 68 deletions
@@ -7,7 +7,6 @@
     _pick,
     default_num_hidden_layers as nhl,
 )
-from ..helpers.mini_onnx_builder import create_input_tensors_from_onnx_model
 from .data import get_data
 
 __TASK__ = "image-text-to-text"
@@ -95,37 +94,15 @@ def _get_inputs_gemma3(
     width: int,
     height: int,
     num_channels: int,
-    batch_size: int = 2,
-    sequence_length: int = 43,
-    sequence_length2: int = 43,
-    n_images: int = 2,
-    dynamic_rope: bool = False,
-    max_sequence_length: int = 380,
+    batch_size: int = 1,
+    sequence_length: int = 281,
+    n_images: int = 1,
+    max_sequence_length: int = 580,
+    total_sequence_length: int = 860,
     **kwargs,  # unused
 ):
     """
-    ::
-
-        dict(input_ids:T7s1x281,
-             pixel_values:T16s1x3x896x896,
-             attention_mask:dict(full_attention:T9s1x1x281x380,sliding_attention:T9s1x1x281x380),
-             position_ids:T7s1x281,
-             past_key_values:HybridCache(
-                 key_cache=#34[T1s1x4x380x256,...],
-                 value_cache=#34[T1s1x4x380x256,...]),
-             token_type_ids:T7s1x281,
-             cache_position:T7s281,
-             logits_to_keep:1)
-        dict(input_ids:T7s1x1,
-             pixel_values:None,
-             attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
-             position_ids:T7s1x1,
-             past_key_values:HybridCache(
-                 key_cache=#34[T1s1x4x380x256,...],
-                 value_cache=#34[T1s1x4x380x256,...]),
-             token_type_ids:T7s1x1,
-             cache_position:T7s1,
-             logits_to_keep:1)
+    The function uses predefined values for input_ids and token_type_ids.
 
     **google/gemma-3-4b-it**
 
@@ -151,21 +128,20 @@ def _get_inputs_gemma3(
             token_type_ids:T7s1x1,
             attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
             position_ids:None,
-            use_cache:bool,logits_to_keep:None,return_dict:bool)
-
     """
     assert (
         "cls_cache" not in kwargs
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = "batch"
     seq_length = "seq_length"
+    tot_length = "total_length"
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},
         "token_type_ids": {0: batch, 1: seq_length},
         "attention_mask": {
-            "full_attention": {0: batch, 2: seq_length},
-            "sliding_attention": {0: batch, 2: seq_length},
+            "full_attention": {0: batch, 2: seq_length, 3: tot_length},
+            "sliding_attention": {0: batch, 2: seq_length, 3: tot_length},
         },
         "position_ids": {0: batch, 1: seq_length},
         "cache_position": {1: seq_length},
@@ -177,22 +153,46 @@ def _get_inputs_gemma3(
         "use_cache": None,
     }
 
-    # first iteration
-    dummies = create_input_tensors_from_onnx_model(
-        get_data("dummies_imagetext2text_generation_gemma3.onnx")
-    )
+    # retrieve specific inputs to keep the consistency between
+    # ids and images
+    dummies = get_data("dummies_imagetext2text_generation_gemma3.onnx")
+    dummies = dummies[("", 0, "I")][1]
     dummies = {k: v for k, v in dummies.items() if k in shapes}
     expected = {"input_ids", "token_type_ids", "position_ids", "cache_position"}
     assert expected & set(
         dummies
     ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
+    assert sequence_length == dummies["input_ids"].shape[-1], (
+        f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
+        f"model class {model.__class__.__name__}"
+    )
+    assert batch_size == dummies["input_ids"].shape[0], (
+        f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
+        f"model class {model.__class__.__name__}"
+    )
+    assert max_sequence_length == 580, (
+        f"max_sequence_length={max_sequence_length} != 580 "
+        f"for model {model.__class__.__name__}"
+    )
+    assert total_sequence_length == 860, (
+        f"total_sequence_length={total_sequence_length} != 860 "
+        f"for model {model.__class__.__name__}"
+    )
+    assert head_dim == 256, f"head_dim={head_dim} != 256 for model {model.__class__.__name__}"
+    assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
+    assert num_key_value_heads == 4, (
+        f"num_key_value_heads={num_key_value_heads} != 4 "
+        f"for this model {model.__class__.__name__}"
+    )
 
     inputs = dict(
-        input_ids=input_ids,
-        token_type_ids=token_type_ids,
+        input_ids=dummies["input_ids"],
+        token_type_ids=dummies["token_type_ids"],
         attention_mask=dict(
-            full_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
-            sliding_attention=torch.randn(batch_size, 1, sequence_length, max_sequence_length),
+            full_attention=torch.randn(batch_size, 1, sequence_length, total_sequence_length),
+            sliding_attention=torch.randn(
+                batch_size, 1, sequence_length, total_sequence_length
+            ),
         ),
         cache_position=torch.arange(0, sequence_length).to(torch.int64),
         position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
@@ -210,9 +210,9 @@ def _get_inputs_gemma3(
             ]
         ),
         pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
-        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
-            torch.int64
-        ),
+        # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+        #     torch.int64
+        # ),
        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
     )
     return dict(inputs=inputs, dynamic_shapes=shapes)
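The returned `dynamic_shapes` dictionary uses plain strings as dimension names; before calling `torch.export.export` they are converted into dynamic dimensions (the repository does this through `use_dyn_not_str`). A rough, hypothetical sketch of such a conversion, only to illustrate the idea:

    # hypothetical helper, only illustrating what a string-to-dynamic-dim conversion does;
    # the actual logic lives in use_dyn_not_str
    from torch.export import Dim

    def to_dynamic(spec):
        if isinstance(spec, dict):
            return {k: to_dynamic(v) for k, v in spec.items()}
        if isinstance(spec, str):
            return Dim.AUTO  # let torch.export infer a dynamic dimension for this axis
        return spec

    dynamic_shapes = to_dynamic({"input_ids": {0: "batch", 1: "seq_length"}})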
@@ -230,12 +230,12 @@ def get_inputs(
     width: int,
     height: int,
     num_channels: int,
-    batch_size: int = 2,
-    sequence_length: int = 43,
-    sequence_length2: int = 43,
-    n_images: int = 2,
-    dynamic_rope: bool = False,
-    add_second_input: int = 1,
+    batch_size: int = 1,
+    sequence_length: int = 281,
+    n_images: int = 1,
+    max_sequence_length: int = 580,
+    total_sequence_length: int = 860,
+    add_second_input: int = 0,
     **kwargs,  # unused
 ):
     """
@@ -249,13 +249,19 @@ def get_inputs(
     :param image_token_index: image_token_index
     :param batch_size: batch size
     :param sequence_length: sequence length
-    :param sequence_length2: new sequence length
+    :param max_sequence_length: for the cache
+    :param total_sequence_length: for the mask
     :param n_images: number of images
     :param width: width of the image
     :param height: height of the image
     :param num_channels: number of channels
-    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
+
+    .. note::
+
+        The content of input_ids and its shape are correlated to the images.
+        The function uses predefined values and raises an exception
+        if the dimensions are not the expected ones.
     """
     if model.__class__.__name__.startswith("Gemma3"):
         res = _get_inputs_gemma3(
@@ -272,9 +278,9 @@ def get_inputs(
             num_channels=num_channels,
             batch_size=batch_size,
             sequence_length=sequence_length,
-            sequence_length2=sequence_length2,
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
             n_images=n_images,
-            dynamic_rope=dynamic_rope,
             **kwargs,
         )
     else:
@@ -306,9 +312,9 @@ def get_inputs(
             "use_cache": None,
         }
 
-        input_ids = torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to(
-            torch.int64
-        )
+        input_ids = torch.randint(
+            0, dummy_max_token_id, (batch_size, total_sequence_length)
+        ).to(torch.int64)
         input_ids[0, 0] = image_token_index
         input_ids[1, 1] = image_token_index
         # input_ids[input_ids == image_token_index] = pad_token_id
@@ -329,7 +335,7 @@ def get_inputs(
                 ],
                 axis=-1,
             ),
-            position_ids=torch.arange(0, sequence_length2)
+            position_ids=torch.arange(0, total_sequence_length)
             .to(torch.int64)
             .expand((batch_size, -1)),
             past_key_values=make_dynamic_cache(
@@ -350,9 +356,9 @@ def get_inputs(
                 if model.__class__.__name__ == "IdeficsForVisionText2Text"
                 else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
             ),
-            image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
-                torch.int64
-            ),
+            # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+            #     torch.int64
+            # ),
             token_type_ids=token_type_ids,
             image_grid_thw=image_grid_thw,
             use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
@@ -373,10 +379,10 @@ def get_inputs(
             height=height,
             num_channels=num_channels,
             batch_size=batch_size + 1,
-            sequence_length=sequence_length + add_second_input,
-            sequence_length2=sequence_length2 + 1,
-            n_images=n_images + 1,
-            dynamic_rope=dynamic_rope,
+            sequence_length=0,
+            max_sequence_length=0,
+            total_sequence_length=0,
+            n_images=0,
             pad_token_id=pad_token_id,
             image_token_index=image_token_index,
             add_second_input=0,
@@ -419,9 +425,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         text_config = False
     check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
-        batch_size=2,
-        sequence_length=43,
-        sequence_length2=43,
+        sequence_length=281,
+        max_sequence_length=580,
+        total_sequence_length=860,
         head_dim=(
             16
             if config is None
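Taken together, the new defaults mean the Gemma3 path only accepts the predefined geometry: the assertions in `_get_inputs_gemma3` pin sequence_length=281, batch_size=1, max_sequence_length=580, total_sequence_length=860 and n_images=1. A short usage sketch mirroring the tests in this commit (same values; anything else raises):

    # sketch mirroring the tests above; only the values below pass the new assertions
    data = get_untrained_model_with_inputs(
        "google/gemma-3-4b-it",
        add_second_input=False,
        inputs_kwargs={
            "sequence_length": 281,
            "batch_size": 1,
            "max_sequence_length": 580,
            "n_images": 1,
        },
    )
    inputs, ds = data["inputs"], data["dynamic_shapes"]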
