Commit 446c956

Improves export settings for image-text-to-text (#237)
* gemma
* change
* first step for gemma3
* data
* clean
* patch
* another step for gemma
* used inputs
* better stats
* style
* better stat
* gemma
* fix a few things
* validate
* one fix
* tiny changes
* fix a few things
* fix
* add patch
* fix patch
* fix test
* add pr link
* install
1 parent 943e44b commit 446c956

24 files changed: +1021 −201 lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.7.13
 ++++++
 
+* :pr:`237`: dummy inputs for google/gemma-3-4b-it
 * :pr:`244`: add a patch to bypass the exception raised when the dynamic dimension is in {0,1}
 
 0.7.12
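
The :pr:`244` entry refers to cases where torch.export rejects or specializes an example input whose dynamic dimension happens to be 0 or 1. A minimal sketch of that situation, assuming onnx_diagnostic is installed; the toy Model below is illustrative only and not part of the commit:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

class Model(torch.nn.Module):
    def forward(self, x):
        return x + 1

x = torch.rand(1, 4)  # the batch dimension is 1 but is declared dynamic below
dynamic_shapes = ({0: torch.export.Dim("batch")},)

# torch_export_patches() is used without arguments elsewhere in this commit;
# depending on the torch version, backed_size_oblivious=True may also be needed.
with torch_export_patches():
    ep = torch.export.export(Model(), (x,), dynamic_shapes=dynamic_shapes)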

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 33 additions & 0 deletions
@@ -181,6 +181,39 @@ def forward(self, x, y):
             set(restored),
         )
 
+    @hide_stdout()
+    def test_steal_forward_dump_file_steal_append_drop(self):
+        class SubModel(torch.nn.Module):
+            def forward(self, x):
+                return x * x
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.s1 = SubModel()
+                self.s2 = SubModel()
+
+            def forward(self, x, y):
+                sx = self.s1(x)
+                steal_append("sx", sx)
+                return sx + self.s2(y)
+
+        inputs = dict(x=torch.rand(3, 4), y=torch.rand(3, 4))
+        model = Model()
+        dump_file = self.get_dump_file("test_steal_forward_dump_file_drop.onnx")
+        with steal_forward(model, dump_file=dump_file, dump_drop={"x"}):
+            model(**inputs)
+            model(**inputs)
+        self.assertExists(dump_file)
+        restored = create_input_tensors_from_onnx_model(dump_file)
+        self.assertEqual(
+            {("", 1, "I"), ("", 1, "O"), "sx", ("", 0, "O"), "sx_1", ("", 0, "I")},
+            set(restored),
+        )
+        first = restored[("", 0, "I")]
+        _a, kws = first
+        self.assertNotIn("x", kws)
+
     @hide_stdout()
     def test_steal_forward_submodules(self):
         class SubModel(torch.nn.Module):
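
A hedged reading of the assertions above, reusing the dump_file produced by the test; the key layout is inferred from the expected set, not from documentation:

# keys appear to be (module_name, call_index, "I"/"O") for stolen inputs/outputs,
# plus the names passed to steal_append ("sx", then "sx_1" on the second call)
restored = create_input_tensors_from_onnx_model(dump_file)
args, kwargs = restored[("", 0, "I")]  # inputs of the first top-level call
assert "x" not in kwargs               # "x" was removed by dump_drop={"x"}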

_unittests/ut_tasks/test_data.py

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.tasks.data import get_data
+
+
+class TestTasks(ExtTestCase):
+    def test_get_data(self):
+        name = "dummies_imagetext2text_generation_gemma3.onnx"
+        data = get_data(name)
+        print(data)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 29 additions & 2 deletions
@@ -22,6 +22,7 @@ def test_image_text_to_text_idefics(self):
         self.assertEqual(data["task"], "image-text-to-text")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**torch_deepcopy(inputs))
+        print("***", self.string_type(data["inputs2"], with_shape=True))
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10):
             torch.export.export(
@@ -31,14 +32,13 @@ def test_image_text_to_text_idefics(self):
     @hide_stdout()
     @requires_transformers("4.57.99")
     @requires_torch("2.7.99")
-    def test_image_text_to_text_gemma3(self):
+    def test_image_text_to_text_tiny_gemma3(self):
         """
         If the model tails because of
         ``if inputs_embeds[special_image_mask].numel() != image_features.numel():```,
         make sure this PR was merged:
         https://github.com/huggingface/transformers/pull/39962.
         """
-        # mid = "google/gemma-3-4b-it"
         mid = "tiny-random/gemma-3"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "image-text-to-text")
@@ -52,6 +52,33 @@ def test_image_text_to_text_gemma3(self):
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
 
+    @hide_stdout()
+    @requires_transformers("4.56.99")
+    @requires_torch("2.8.99")
+    def test_image_text_to_text_gemma3_4b_it(self):
+        mid = "google/gemma-3-4b-it"
+        data = get_untrained_model_with_inputs(
+            mid,
+            verbose=1,
+            add_second_input=False,
+            # inputs_kwargs={
+            #     "sequence_length": 281,
+            #     "batch_size": 1,
+            #     "max_sequence_length": 580,
+            #     "n_images": 1,
+            # },
+        )
+        self.assertEqual(data["task"], "image-text-to-text")
+        # self.assertIn((data["size"], data["n_weights"]), [(17248576, 4312144)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        # inputs.pop("attention_mask")
+        # ds.pop("attention_mask")
+        model(**torch_deepcopy(inputs))
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
     @hide_stdout()
     @requires_transformers("4.57.99")
     @requires_torch("2.7.99")

_unittests/ut_tasks/test_tasks_image_to_video.py

Lines changed: 5 additions & 4 deletions
@@ -54,10 +54,11 @@ def test_image_to_video_oblivious(self):
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
         model(**data["inputs2"])
-        with torch.fx.experimental._config.patch(
-            backed_size_oblivious=True
-        ), torch_export_patches(
-            patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
+        with (
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+            torch_export_patches(
+                patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
+            ),
         ):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
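
The rewrite above, and the matching ones in test_patch_torch.py and test_llm_phi2.py below, switch to parenthesized context managers. A small standalone example of the pattern, plain Python with no project dependencies; this syntax requires Python 3.10 or later:

# equivalent to: with open("a.txt", "w") as fa, open("b.txt", "w") as fb:
with (
    open("a.txt", "w") as fa,
    open("b.txt", "w") as fb,
):
    fa.write("a")
    fb.write("b")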

_unittests/ut_tasks/try_tasks.py

Lines changed: 153 additions & 2 deletions
@@ -1,10 +1,12 @@
+import os
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 from onnx_diagnostic.helpers.torch_helper import steal_forward
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches
 
 
 class TestHuggingFaceHubModel(ExtTestCase):
@@ -137,8 +139,9 @@ def test_text_generation_phi4_mini(self):
         import torch
         from transformers import RobertaTokenizer, T5ForConditionalGeneration
 
-        tokenizer = RobertaTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
-        model = T5ForConditionalGeneration.from_pretrained("microsoft/Phi-4-mini-instruct")
+        model_id = "microsoft/Phi-4-mini-instruct"
+        tokenizer = RobertaTokenizer.from_pretrained(model_id)
+        model = T5ForConditionalGeneration.from_pretrained(model_id)
 
         text = "def greet(user): print(f'hello <extra_id_0>!')"
         input_ids = tokenizer(text, return_tensors="pt").input_ids
@@ -156,6 +159,41 @@ def test_text_generation_phi4_mini(self):
         )
         print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
 
+    @never_test()
+    def test_text_generation_phi3_mini(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi3_mini
+
+        from transformers import Phi3ForCausalLM, AutoTokenizer
+
+        model_id = "microsoft/Phi-3-mini-4k-instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = Phi3ForCausalLM.from_pretrained(model_id)
+
+        messages = [
+            {
+                "role": "system",
+                "content": (
+                    "You are a helpful digital assistant. Please provide safe, "
+                    "ethical and accurate information to the user."
+                ),
+            },
+            {
+                "role": "user",
+                "content": (
+                    "Can you provide ways to eat combinations of bananas and dragonfruits?"
+                ),
+            },
+        ]
+        inputs = tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, return_tensors="pt"
+        )
+
+        # simply generate a single sequence
+        print()
+        with steal_forward(model):
+            generated_ids = model.generate(inputs, max_length=100)
+        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+
     @never_test()
     @unittest.skip(
         reason="AttributeError: 'Phi4MMModel' object has no attribute "
@@ -791,6 +829,119 @@ def test_sentence_similary_alibaba_nlp_gte(self):
         scores = (embeddings[:1] @ embeddings[1:].T) * 100
         print(scores.tolist())
 
+    @never_test()
+    def test_imagetext2text_generation_gemma3_4b_it(self):
+        """
+        clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k gemma3_4b_it
+        """
+        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+        model_id = "google/gemma-3-4b-it"
+        if os.environ.get("PRETRAINED", ""):
+            model = Gemma3ForConditionalGeneration.from_pretrained(
+                model_id, device_map="cpu"
+            ).eval()
+        else:
+            data = get_untrained_model_with_inputs(
+                model_id,
+                verbose=1,
+                add_second_input=False,
+                # same_as_pretrained=True, #use_pretrained=True
+                inputs_kwargs={
+                    "sequence_length": 281,
+                    "batch_size": 1,
+                    "max_sequence_length": 580,
+                    "n_images": 1,
+                },
+            )
+            model = data["model"]
+
+        print(f"-- model.device={model.device}")
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+        print(f"-- processor={type(processor)}")
+
+        messages = messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+                    },
+                    {"type": "text", "text": "Describe this image in detail."},
+                ],
+            },
+        ]
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device, dtype=torch.bfloat16)
+        # if "token_type_ids" in inputs:
+        #     print(
+        #         f"-- remove token_type_ids: "
+        #         f"{self.string_type(inputs['token_type_ids'], with_shape=True)}"
+        #     )
+        #     inputs.pop("token_type_ids", None)
+        print(f"-- inputs={self.string_type(inputs)}")
+
+        # iteration merge = sequence > 1, cache not empty
+        # iteration 1 = sequence > 1, no cache
+        # cache_position:T7s281,
+        # past_key_values:StaticCache(key_cache=#0[], value_cache=#0[]),
+        # input_ids:T7s1x281,
+        # inputs_embeds:None,
+        # token_type_ids:T7s1x281,
+        # attention_mask:dict(sliding_attention:T9s1x1x281x580,
+        # full_attention:T9s1x1x281x580),
+        # position_ids:None,
+        # use_cache:bool,
+        # logits_to_keep:None,
+        # pixel_values:T16s1x3x896x896,
+        # return_dict:bool)
+        # iteration 2 = sequence = 1, cache not empty
+        # cache_position:T7s1,
+        # past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
+        # value_cache=#34[T1s1x4x580x256,...]),
+        # input_ids:T7s1x1,
+        # inputs_embeds:None,
+        # token_type_ids:T7s1x1,
+        # attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
+        # position_ids:None,
+        # use_cache:bool,logits_to_keep:None,return_dict:bool)
+
+        print()
+        with (
+            torch_export_patches(
+                patch_torch=False, patch_sympy=False, patch_transformers=True
+            ),
+            steal_forward(
+                model,
+                dump_file=self.get_dump_file(
+                    "test_imagetext2text_generation_gemma3_4b_it.onnx"
+                ),
+                dump_drop={"attention_mask", "past_key_values", "pixel_values"},
+                save_as_external_data=False,
+            ),
+        ):
+            generated_ids = model.generate(
+                **inputs,
+                # 282 = value high enough to trigger multiple iterations of the model
+                max_new_tokens=282,
+                do_sample=False,
+                cache_implementation="static",
+            )
+        output_text = processor.decode(
+            generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
+        )
+        print(output_text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
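
A hedged decoding of the shorthand used in the comments of the new test, assuming it is the notation produced by onnx_diagnostic.helpers.string_type: "T<code>s<shape>" uses ONNX dtype codes, so T7s1x281 would be an int64 tensor of shape 1x281, T9 a bool tensor, and T16 a bfloat16 tensor.

import torch
from onnx_diagnostic.helpers import string_type

# expected to print something close to "T7s1x281" (int64, shape 1x281)
print(string_type(torch.zeros((1, 281), dtype=torch.int64), with_shape=True))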

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 3 additions & 2 deletions
@@ -309,8 +309,9 @@ def forward(self, x, ind1, ind2):
         with self.subTest(
             name="patch for 0/1 with oblivious", dynamic_shapes=dynamic_shapes
         ):
-            with torch_export_patches(), torch.fx.experimental._config.patch(
-                backed_size_oblivious=True
+            with (
+                torch_export_patches(),
+                torch.fx.experimental._config.patch(backed_size_oblivious=True),
             ):
                 ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
                 got = ep.module()(*inputs)

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 4 additions & 3 deletions
@@ -33,9 +33,10 @@ def test_export_phi2_1_batch_size_1_oblivious(self):
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
-        with torch.fx.experimental._config.patch(
-            backed_size_oblivious=True
-        ), torch_export_patches(patch_transformers=True):
+        with (
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+            torch_export_patches(patch_transformers=True),
+        ):
            ep = torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
            )

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 11 additions & 2 deletions
@@ -56,6 +56,14 @@ def __init__(
         self.kwargs = kwargs
         self.dynamic_shapes = dynamic_shapes
         self.args_names = args_names
+        if not self.kwargs and isinstance(self.dynamic_shapes, dict):
+            # This assumes the dictionary for the dynamic shapes is ordered
+            # the same way the args are. The input names are not known.
+            assert len(self.dynamic_shapes) == len(self.args), (
+                f"Length mismatch, kwargs is empty, len(dynamic_shapes)="
+                f"{len(self.dynamic_shapes)}, len(args)={len(self.args)}"
+            )
+            self.dynamic_shapes = tuple(self.dynamic_shapes.values())
 
     def __str__(self) -> str:
         return "\n".join(
@@ -232,8 +240,9 @@ def _generic_walker(
         """
         if not self.args:
            assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
-                f"Type mismatch, args={string_type(self.args)} and "
-                f"dynamic_shapes={self.dynamic_shapes} should have the same type."
+                f"Type mismatch, args={string_type(self.args)}, "
+                f"kwargs={string_type(self.kwargs)} and dynamic_shapes="
+                f"{string_type(self.dynamic_shapes)} should have the same type."
            )
            res = self._generic_walker_step(
                processor,
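
A minimal sketch of the normalization added in __init__ above: when only positional args are given and dynamic_shapes is a dict, the dict is assumed to follow the order of the args and is flattened into a tuple (the variable names below are illustrative only):

import torch

args = (torch.rand(2, 8), torch.rand(2, 8))
dynamic_shapes = {"x": {0: "batch"}, "y": {0: "batch"}}  # assumed ordered like args
assert len(dynamic_shapes) == len(args)
dynamic_shapes = tuple(dynamic_shapes.values())
# -> ({0: "batch"}, {0: "batch"}), the positional form torch.export.export expects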

onnx_diagnostic/helpers/helper.py

Lines changed: 2 additions & 2 deletions
@@ -397,7 +397,7 @@ def string_type(
             return "AUTO"
         if verbose:
             print(f"[string_type] Y7:{type(obj)}")
-        return str(obj)
+        return str(obj).replace("DimHint(DYNAMIC)", "DYNAMIC").replace("DimHint(AUTO)", "AUTO")
 
     if isinstance(obj, bool):
         if with_min_max:
@@ -939,7 +939,7 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
             return flatten_object(list(x.values()), drop_keys=drop_keys)
         return flatten_object(list(x.items()), drop_keys=drop_keys)
 
-    if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
+    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
         from .cache_helper import CacheKeyValue
 
         kc = CacheKeyValue(x)
