
Commit 977c839

Commit message: merge conflicts
Parents: 05b06e9 + 446c956

24 files changed: +853 additions, -211 deletions

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ Change Logs
 ++++++
 
 * :pr:`256`: add a set of inputs checking models works for an empty cache on task text-generation
+* :pr:`237`: dummy inputs for google/gemma-3-4b-it
 * :pr:`244`: add a patch to bypass the exception raised when the dynamic dimension is in {0,1}
 
 0.7.12

_unittests/ut_helpers/test_torch_helper.py

Lines changed: 33 additions & 0 deletions

@@ -181,6 +181,39 @@ def forward(self, x, y):
             set(restored),
         )
 
+    @hide_stdout()
+    def test_steal_forward_dump_file_steal_append_drop(self):
+        class SubModel(torch.nn.Module):
+            def forward(self, x):
+                return x * x
+
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.s1 = SubModel()
+                self.s2 = SubModel()
+
+            def forward(self, x, y):
+                sx = self.s1(x)
+                steal_append("sx", sx)
+                return sx + self.s2(y)
+
+        inputs = dict(x=torch.rand(3, 4), y=torch.rand(3, 4))
+        model = Model()
+        dump_file = self.get_dump_file("test_steal_forward_dump_file_drop.onnx")
+        with steal_forward(model, dump_file=dump_file, dump_drop={"x"}):
+            model(**inputs)
+            model(**inputs)
+        self.assertExists(dump_file)
+        restored = create_input_tensors_from_onnx_model(dump_file)
+        self.assertEqual(
+            {("", 1, "I"), ("", 1, "O"), "sx", ("", 0, "O"), "sx_1", ("", 0, "I")},
+            set(restored),
+        )
+        first = restored[("", 0, "I")]
+        _a, kws = first
+        self.assertNotIn("x", kws)
+
     @hide_stdout()
     def test_steal_forward_submodules(self):
         class SubModel(torch.nn.Module):
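
Note: as a reading aid, here is a minimal standalone sketch of the dump-and-restore round trip the new test exercises. The function names come from the diff above; the import paths are assumptions inferred from the test file's location.

import torch
# Assumed import paths, inferred from _unittests/ut_helpers/test_torch_helper.py.
from onnx_diagnostic.helpers.torch_helper import steal_forward
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model


class Tiny(torch.nn.Module):
    def forward(self, x, y):
        return x + y


model = Tiny()
# dump_drop={"x"} asks steal_forward to omit input "x" from the dump file.
with steal_forward(model, dump_file="dump.onnx", dump_drop={"x"}):
    model(x=torch.rand(3, 4), y=torch.rand(3, 4))

restored = create_input_tensors_from_onnx_model("dump.onnx")
# Keys follow the (module_name, call_index, "I"|"O") pattern seen in the test.
_args, kws = restored[("", 0, "I")]
assert "x" not in kws  # the dropped input is absent from the restored kwargs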

_unittests/ut_tasks/test_data.py

Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+import unittest
+from onnx_diagnostic.ext_test_case import ExtTestCase
+from onnx_diagnostic.tasks.data import get_data
+
+
+class TestTasks(ExtTestCase):
+    def test_get_data(self):
+        name = "dummies_imagetext2text_generation_gemma3.onnx"
+        data = get_data(name)
+        print(data)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 29 additions & 2 deletions

@@ -22,6 +22,7 @@ def test_image_text_to_text_idefics(self):
         self.assertEqual(data["task"], "image-text-to-text")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**torch_deepcopy(inputs))
+        print("***", self.string_type(data["inputs2"], with_shape=True))
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10):
             torch.export.export(

@@ -31,14 +32,13 @@
     @hide_stdout()
     @requires_transformers("4.57.99")
     @requires_torch("2.7.99")
-    def test_image_text_to_text_gemma3(self):
+    def test_image_text_to_text_tiny_gemma3(self):
         """
         If the model tails because of
         ``if inputs_embeds[special_image_mask].numel() != image_features.numel():```,
         make sure this PR was merged:
         https://github.com/huggingface/transformers/pull/39962.
         """
-        # mid = "google/gemma-3-4b-it"
         mid = "tiny-random/gemma-3"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "image-text-to-text")

@@ -52,6 +52,33 @@ def test_image_text_to_text_gemma3(self):
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
 
+    @hide_stdout()
+    @requires_transformers("4.56.99")
+    @requires_torch("2.8.99")
+    def test_image_text_to_text_gemma3_4b_it(self):
+        mid = "google/gemma-3-4b-it"
+        data = get_untrained_model_with_inputs(
+            mid,
+            verbose=1,
+            add_second_input=False,
+            # inputs_kwargs={
+            #     "sequence_length": 281,
+            #     "batch_size": 1,
+            #     "max_sequence_length": 580,
+            #     "n_images": 1,
+            # },
+        )
+        self.assertEqual(data["task"], "image-text-to-text")
+        # self.assertIn((data["size"], data["n_weights"]), [(17248576, 4312144)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        # inputs.pop("attention_mask")
+        # ds.pop("attention_mask")
+        model(**torch_deepcopy(inputs))
+        with torch_export_patches(patch_transformers=True, verbose=10):
+            torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+
     @hide_stdout()
     @requires_transformers("4.57.99")
     @requires_torch("2.7.99")
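
Note: the gemma-3 tests above follow the same export recipe used throughout these test files. A condensed sketch of that recipe; the helper names appear in the tests above, but the import paths are assumptions.

import torch
# Assumed import paths; only the function names are confirmed by the diffs.
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.export import use_dyn_not_str

data = get_untrained_model_with_inputs("tiny-random/gemma-3", verbose=1)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**inputs)  # run eagerly first to validate the generated dummy inputs
with torch_export_patches(patch_transformers=True):
    torch.export.export(
        model,
        (),
        kwargs=inputs,
        # use_dyn_not_str swaps string axis names for torch.export dynamic dims
        dynamic_shapes=use_dyn_not_str(ds),
        strict=False,
    )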

_unittests/ut_tasks/test_tasks_image_to_video.py

Lines changed: 5 additions & 4 deletions

@@ -54,10 +54,11 @@ def test_image_to_video_oblivious(self):
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
         model(**data["inputs2"])
-        with torch.fx.experimental._config.patch(
-            backed_size_oblivious=True
-        ), torch_export_patches(
-            patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
+        with (
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+            torch_export_patches(
+                patch_transformers=True, patch_diffusers=True, verbose=10, stop_if_static=1
+            ),
         ):
             torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
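
Note: this rewrite (repeated in try_tasks.py, test_patch_torch.py and test_llm_phi2.py below) only changes syntax. Since Python 3.10, multiple context managers can be grouped in parentheses, which avoids contorting line breaks around the managers' own argument lists. A self-contained illustration:

from contextlib import nullcontext

# before: line continuation had to piggyback on the managers' parentheses
with nullcontext() as a, nullcontext() as b:
    pass

# after: one manager per line, trailing comma allowed (Python 3.10+)
with (
    nullcontext() as a,
    nullcontext() as b,
):
    pass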

_unittests/ut_tasks/try_tasks.py

Lines changed: 158 additions & 5 deletions

@@ -1,3 +1,4 @@
+import os
 import unittest
 import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test

@@ -163,9 +164,12 @@ def test_text_generation_tiny_llm(self):
 
         # simply generate a single sequence
         print()
-        with torch_export_patches(
-            patch_transformers=True, patch_torch=False, patch_sympy=False
-        ), steal_forward(model):
+        with (
+            torch_export_patches(
+                patch_transformers=True, patch_torch=False, patch_sympy=False
+            ),
+            steal_forward(model),
+        ):
             generated_ids = model.generate(
                 input_ids=input_ids,
                 max_length=100,

@@ -181,8 +185,9 @@ def test_text_generation_phi4_mini(self):
         import torch
         from transformers import RobertaTokenizer, T5ForConditionalGeneration
 
-        tokenizer = RobertaTokenizer.from_pretrained("microsoft/Phi-4-mini-instruct")
-        model = T5ForConditionalGeneration.from_pretrained("microsoft/Phi-4-mini-instruct")
+        model_id = "microsoft/Phi-4-mini-instruct"
+        tokenizer = RobertaTokenizer.from_pretrained(model_id)
+        model = T5ForConditionalGeneration.from_pretrained(model_id)
 
         text = "def greet(user): print(f'hello <extra_id_0>!')"
         input_ids = tokenizer(text, return_tensors="pt").input_ids

@@ -200,6 +205,41 @@ def test_text_generation_phi4_mini(self):
         )
         print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
 
+    @never_test()
+    def test_text_generation_phi3_mini(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi3_mini
+
+        from transformers import Phi3ForCausalLM, AutoTokenizer
+
+        model_id = "microsoft/Phi-3-mini-4k-instruct"
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+        model = Phi3ForCausalLM.from_pretrained(model_id)
+
+        messages = [
+            {
+                "role": "system",
+                "content": (
+                    "You are a helpful digital assistant. Please provide safe, "
+                    "ethical and accurate information to the user."
+                ),
+            },
+            {
+                "role": "user",
+                "content": (
+                    "Can you provide ways to eat combinations of bananas and dragonfruits?"
+                ),
+            },
+        ]
+        inputs = tokenizer.apply_chat_template(
+            messages, add_generation_prompt=True, return_tensors="pt"
+        )
+
+        # simply generate a single sequence
+        print()
+        with steal_forward(model):
+            generated_ids = model.generate(inputs, max_length=100)
+        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+
     @never_test()
     @unittest.skip(
         reason="AttributeError: 'Phi4MMModel' object has no attribute "

@@ -835,6 +875,119 @@ def test_sentence_similary_alibaba_nlp_gte(self):
         scores = (embeddings[:1] @ embeddings[1:].T) * 100
         print(scores.tolist())
 
+    @never_test()
+    def test_imagetext2text_generation_gemma3_4b_it(self):
+        """
+        clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k gemma3_4b_it
+        """
+        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
+
+        model_id = "google/gemma-3-4b-it"
+        if os.environ.get("PRETRAINED", ""):
+            model = Gemma3ForConditionalGeneration.from_pretrained(
+                model_id, device_map="cpu"
+            ).eval()
+        else:
+            data = get_untrained_model_with_inputs(
+                model_id,
+                verbose=1,
+                add_second_input=False,
+                # same_as_pretrained=True, #use_pretrained=True
+                inputs_kwargs={
+                    "sequence_length": 281,
+                    "batch_size": 1,
+                    "max_sequence_length": 580,
+                    "n_images": 1,
+                },
+            )
+            model = data["model"]
+
+        print(f"-- model.device={model.device}")
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
+        print(f"-- processor={type(processor)}")
+
+        messages = messages = [
+            {
+                "role": "system",
+                "content": [{"type": "text", "text": "You are a helpful assistant."}],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "image",
+                        "image": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+                    },
+                    {"type": "text", "text": "Describe this image in detail."},
+                ],
+            },
+        ]
+        inputs = processor.apply_chat_template(
+            messages,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+        ).to(model.device, dtype=torch.bfloat16)
+        # if "token_type_ids" in inputs:
+        #     print(
+        #         f"-- remove token_type_ids: "
+        #         f"{self.string_type(inputs['token_type_ids'], with_shape=True)}"
+        #     )
+        #     inputs.pop("token_type_ids", None)
+        print(f"-- inputs={self.string_type(inputs)}")
+
+        # iteration merge = sequence > 1, cache not empty
+        # iteration 1 = sequence > 1, no cache
+        #    cache_position:T7s281,
+        #    past_key_values:StaticCache(key_cache=#0[], value_cache=#0[]),
+        #    input_ids:T7s1x281,
+        #    inputs_embeds:None,
+        #    token_type_ids:T7s1x281,
+        #    attention_mask:dict(sliding_attention:T9s1x1x281x580,
+        #                        full_attention:T9s1x1x281x580),
+        #    position_ids:None,
+        #    use_cache:bool,
+        #    logits_to_keep:None,
+        #    pixel_values:T16s1x3x896x896,
+        #    return_dict:bool)
+        # iteration 2 = sequence = 1, cache not empty
+        #    cache_position:T7s1,
+        #    past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
+        #                                value_cache=#34[T1s1x4x580x256,...]),
+        #    input_ids:T7s1x1,
+        #    inputs_embeds:None,
+        #    token_type_ids:T7s1x1,
+        #    attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
+        #    position_ids:None,
+        #    use_cache:bool,logits_to_keep:None,return_dict:bool)
+
+        print()
+        with (
+            torch_export_patches(
+                patch_torch=False, patch_sympy=False, patch_transformers=True
+            ),
+            steal_forward(
+                model,
+                dump_file=self.get_dump_file(
+                    "test_imagetext2text_generation_gemma3_4b_it.onnx"
+                ),
+                dump_drop={"attention_mask", "past_key_values", "pixel_values"},
+                save_as_external_data=False,
+            ),
+        ):
+            generated_ids = model.generate(
+                **inputs,
+                # 282 = value high enough to trigger multiple iterations of the model
+                max_new_tokens=282,
+                do_sample=False,
+                cache_implementation="static",
+            )
+        output_text = processor.decode(
+            generated_ids[0][inputs["input_ids"].shape[1] :], skip_special_tokens=False
+        )
+        print(output_text)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
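
Note: the new gemma-3 scenario dumps every forward call issued by model.generate into an ONNX file while dropping the bulky inputs. A hedged sketch of inspecting such a dump, reusing the loader exercised in test_torch_helper.py above; its import path is an assumption.

# Assumed import path for the loader used in test_torch_helper.py.
from onnx_diagnostic.helpers.mini_onnx_builder import create_input_tensors_from_onnx_model

restored = create_input_tensors_from_onnx_model(
    "test_imagetext2text_generation_gemma3_4b_it.onnx"
)
for key in restored:
    # Keys identify (submodule, call_index, "I"/"O"); with dump_drop set,
    # attention_mask, past_key_values and pixel_values never reach the file.
    print(key)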

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 3 additions & 2 deletions

@@ -309,8 +309,9 @@ def forward(self, x, ind1, ind2):
             with self.subTest(
                 name="patch for 0/1 with oblivious", dynamic_shapes=dynamic_shapes
             ):
-                with torch_export_patches(), torch.fx.experimental._config.patch(
-                    backed_size_oblivious=True
+                with (
+                    torch_export_patches(),
+                    torch.fx.experimental._config.patch(backed_size_oblivious=True),
                 ):
                     ep = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
                     got = ep.module()(*inputs)

_unittests/ut_torch_models/test_llm_phi2.py

Lines changed: 4 additions & 3 deletions

@@ -33,9 +33,10 @@ def test_export_phi2_1_batch_size_1_oblivious(self):
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
-        with torch.fx.experimental._config.patch(
-            backed_size_oblivious=True
-        ), torch_export_patches(patch_transformers=True):
+        with (
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+            torch_export_patches(patch_transformers=True),
+        ):
             ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
             )

onnx_diagnostic/export/dynamic_shapes.py

Lines changed: 11 additions & 2 deletions

@@ -56,6 +56,14 @@ def __init__(
         self.kwargs = kwargs
         self.dynamic_shapes = dynamic_shapes
         self.args_names = args_names
+        if not self.kwargs and isinstance(self.dynamic_shapes, dict):
+            # This assumes the dictionary for the dynamic shapes is ordered
+            # the same way the args are. The input names are not known.
+            assert len(self.dynamic_shapes) == len(self.args), (
+                f"Length mismatch, kwargs is empty, len(dynamic_shapes)="
+                f"{len(self.dynamic_shapes)}, len(args)={len(self.args)}"
+            )
+            self.dynamic_shapes = tuple(self.dynamic_shapes.values())
 
     def __str__(self) -> str:
         return "\n".join(

@@ -232,8 +240,9 @@ def _generic_walker(
         """
         if not self.args:
             assert isinstance(self.kwargs, dict) and isinstance(self.dynamic_shapes, dict), (
-                f"Type mismatch, args={string_type(self.args)} and "
-                f"dynamic_shapes={self.dynamic_shapes} should have the same type."
+                f"Type mismatch, args={string_type(self.args)}, "
+                f"kwargs={string_type(self.kwargs)} and dynamic_shapes="
+                f"{string_type(self.dynamic_shapes)} should have the same type."
             )
         res = self._generic_walker_step(
             processor,
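
Note: the __init__ change normalizes a dict of dynamic shapes into a tuple when only positional args are given, relying on dict insertion order to line the shapes up with the args. A standalone sketch of that normalization; the function name is illustrative, not part of the library.

def normalize_dynamic_shapes(args, kwargs, dynamic_shapes):
    # Mirrors the logic added to __init__ above: with no kwargs, a dict of
    # shapes is assumed to be ordered like the positional args.
    if not kwargs and isinstance(dynamic_shapes, dict):
        assert len(dynamic_shapes) == len(args), (
            f"Length mismatch, kwargs is empty, "
            f"len(dynamic_shapes)={len(dynamic_shapes)}, len(args)={len(args)}"
        )
        return tuple(dynamic_shapes.values())
    return dynamic_shapes


shapes = normalize_dynamic_shapes(
    args=(object(), object()),  # two positional inputs
    kwargs={},
    dynamic_shapes={"x": {0: "batch"}, "y": {0: "batch"}},
)
assert shapes == ({0: "batch"}, {0: "batch"})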
