
Commit 3f08c5f

Enables export with fake tensors (#273)
* enables export with fake tensors
* improves tests
* fix
* fix fake export
* fix
* fix
* push
* doc
* doc
* foc
* fix index
1 parent: 68d71cf

18 files changed: +172 / -81 lines
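
The commit enables torch.export.export when the example inputs are FakeTensors: the inputs only carry shape/dtype/device metadata while the module keeps its real weights, so the resulting exported program can still be run on real tensors afterwards. A minimal sketch of that idea with a toy module (the model, shapes and names below are illustrative assumptions, not the repository's code):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x):
        return torch.nn.functional.relu(self.fc(x))


model = TinyModel()          # real weights
real_x = torch.randn(3, 8)   # real data, used later for the numerical check

# Fake copy of the example input: same shape/dtype/device, no storage.
fake_x = FakeTensorMode().from_tensor(real_x)

# Export traces with the fake input; dimension 0 is declared dynamic.
batch = torch.export.Dim("batch")
ep = torch.export.export(model, (fake_x,), dynamic_shapes=({0: batch},), strict=False)

# The exported module still holds the real parameters, so it runs on real data.
torch.testing.assert_close(model(real_x), ep.module()(real_x))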

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.7.16
 ++++++

+* :pr:`273`: enables export with FakeTensor
 * :pr:`272`: makes patches work with FakeTensor
 * :pr:`270`: add export sample code to export a specific model id with the appropriate inputs
 * :pr:`269`: adds one unit test to track a patch fixing broadcast output shape

_doc/conf.py

Lines changed: 2 additions & 0 deletions
@@ -119,6 +119,8 @@ def linkcode_resolve(domain, info):
     ("py:class", "ast.Node"),
     ("py:class", "dtype"),
     ("py:class", "False"),
+    ("py:class", "FakeTensor"),
+    ("py:class", "FakeTensorMode"),
     ("py:class", "True"),
     ("py:class", "Argument"),
     ("py:class", "CacheProcessor"),

_doc/index.rst

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ See :func:`onnx_diagnostic.torch_export_patches.torch_export_rewrite`.
     # ...

 all_dynamic_shapes_from_inputs
-+++++++++++++++++++++++++++++
+++++++++++++++++++++++++++++++

 See :func:`onnx_diagnostic.export.shape_helper.all_dynamic_shapes_from_inputs`.

_unittests/ut_tasks/test_tasks_image_classification.py

Lines changed: 3 additions & 2 deletions
@@ -14,14 +14,15 @@ def test_image_classification(self):
         self.assertEqual(data["task"], "image-classification")
         self.assertIn((data["size"], data["n_weights"]), [(56880, 14220)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
+        expected = model(**inputs)
         model(**data["inputs2"])
         if not has_transformers("4.52.999"):
             raise unittest.SkipTest("Requires transformers>=4.52")
         with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            self.assertEqualAny(expected, ep.module()(**inputs))


 if __name__ == "__main__":
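
This test (and the ones below) now validates the exported program instead of only checking that export succeeds: the eager output is captured before export and compared against the output of ep.module() on the same inputs. A self-contained sketch of that pattern with a made-up module returning a dict (assertEqualAny is the test-suite helper; plain torch.testing.assert_close stands in for it here):

import torch


class Toy(torch.nn.Module):
    def forward(self, x):
        return {"hidden": x * 2.0, "pooled": x.sum(dim=-1)}


model = Toy()
inputs = {"x": torch.randn(2, 5)}

expected = model(**inputs)                  # eager reference, computed before export
ep = torch.export.export(model, (), kwargs=inputs, strict=False)
got = ep.module()(**inputs)                 # replay the exported graph on the same inputs

for name in expected:                       # compare leaf by leaf
    torch.testing.assert_close(expected[name], got[name])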

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 15 additions & 9 deletions
@@ -14,19 +14,21 @@

 class TestTasksImageTextToText(ExtTestCase):
     @hide_stdout()
-    @requires_transformers("4.53")
+    @requires_transformers("4.56")
     @requires_torch("2.7.99")
     def test_image_text_to_text_idefics(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "image-text-to-text")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**torch_deepcopy(inputs))
+        expected = model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            # The conversion does not work. Tolerance is set to 1.
+            self.assertEqualAny(expected, ep.module()(**inputs), atol=1)

     @hide_stdout()
     @requires_transformers("5.0.99")
@@ -44,12 +46,13 @@ def test_image_text_to_text_tiny_gemma3(self):
         # self.assertIn((data["size"], data["n_weights"]), [(17248576, 4312144)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         print("--", self.string_type(data["inputs"], with_shape=True))
-        model(**torch_deepcopy(inputs))
+        expected = model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            self.assertEqualAny(expected, ep.module()(**inputs))

     @hide_stdout()
     @requires_transformers("4.56.99")
@@ -72,11 +75,13 @@ def test_image_text_to_text_gemma3_4b_it(self):
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         # inputs.pop("attention_mask")
         # ds.pop("attention_mask")
-        model(**torch_deepcopy(inputs))
+        expected = model(**torch_deepcopy(inputs))
         with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            # The conversion does not work. Tolerance is set to 1.
+            self.assertEqualAny(expected, ep.module()(**inputs))

     @hide_stdout()
     @requires_transformers("5.0.99")
@@ -93,12 +98,13 @@ def test_image_text_to_text_zai_glm(self):
         self.assertEqual(data["task"], "image-text-to-text")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         print("--", self.string_type(data["inputs"], with_shape=True))
-        model(**torch_deepcopy(inputs))
+        expected = model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            self.assertEqualAny(expected, ep.module()(**inputs))


 if __name__ == "__main__":

_unittests/ut_tasks/test_tasks_mask_generation.py

Lines changed: 6 additions & 4 deletions
@@ -21,12 +21,13 @@ def test_mask_generation(self):
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "mask-generation")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**torch_deepcopy(inputs))
+        expected = model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         with torch_export_patches(patch_torch=False, patch_transformers=True, verbose=1):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            self.assertEqualAny(expected, ep.module()(**inputs))

     @hide_stdout()
     @requires_transformers("4.53")
@@ -36,14 +37,15 @@ def test_mask_generation_with_torch_patches(self):
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "mask-generation")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**torch_deepcopy(inputs))
+        expected = model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         with torch_export_patches(
             patch_torch=True, patch_sympy=True, patch_transformers=True, verbose=1
         ):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            self.assertEqualAny(expected, ep.module()(**inputs))


 if __name__ == "__main__":

_unittests/ut_tasks/test_tasks_object_detection.py

Lines changed: 3 additions & 2 deletions
@@ -14,14 +14,15 @@ def test_object_detection(self):
         self.assertEqual(data["task"], "object-detection")
         self.assertIn((data["size"], data["n_weights"]), [(8160384, 2040096)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
+        expected = model(**inputs)
         model(**data["inputs2"])
         if not has_transformers("4.51.999"):
             raise unittest.SkipTest("Requires transformers>=4.52")
         with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            self.assertEqualAny(expected, ep.module()(**inputs))


 if __name__ == "__main__":

_unittests/ut_tasks/test_tasks_text_generation.py

Lines changed: 30 additions & 8 deletions
@@ -10,40 +10,62 @@
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions


 class TestTasksTextGeneration(ExtTestCase):
     @hide_stdout()
     @requires_transformers("4.53")
     @requires_torch("2.7.99")
-    def test_tet_generation_gemma3_for_causallm(self):
+    def test_text_generation_gemma3_for_causallm(self):
         mid = "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "text-generation")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**torch_deepcopy(inputs))
+        expected = model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            self.assertEqualAny(expected, ep.module()(**inputs))

     @hide_stdout()
     @requires_transformers("4.53")
     @requires_torch("2.7.99")
-    def test_itext_generation_phi_3_mini_128k_instruct(self):
+    def test_text_generation_phi_3_mini_128k_instruct(self):
         mid = "microsoft/Phi-3-mini-128k-instruct"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "text-generation")
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        print("--", self.string_type(inputs, with_shape=True))
-        print("--", self.string_type(ds))
-        model(**torch_deepcopy(inputs))
+        expected = model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            self.assertEqualAny(expected, ep.module()(**inputs))
+
+    @hide_stdout()
+    @requires_transformers("4.53")
+    @requires_torch("2.7.99")
+    def test_text_generation_tiny_llm(self):
+        mid = "arnir0/Tiny-LLM"
+        data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
+        self.assertEqual(data["task"], "text-generation")
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        inputs_copied = torch_deepcopy(inputs)
+        expected = model(**torch_deepcopy(inputs))
+        model(**data["inputs2"])
+        fake = make_fake_with_dynamic_dimensions(inputs, dynamic_shapes=ds)[0]
+        with torch_export_patches(patch_transformers=True, verbose=10, patch_torch=False):
+            ep = torch.export.export(
+                model, (), kwargs=fake, dynamic_shapes=use_dyn_not_str(ds), strict=False
+            )
+            # print(ep)
+            got = ep.module()(**inputs_copied)
+            self.assertEqualAny(expected.past_key_values, got.past_key_values)
+            self.assertEqualArray(expected.logits, got.logits)


 if __name__ == "__main__":
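
The new test_text_generation_tiny_llm is the test that actually exercises export with fake inputs: judging from its usage, make_fake_with_dynamic_dimensions(inputs, dynamic_shapes=ds) builds fake stand-ins for the input kwargs (with the requested dimensions dynamic), export traces on those stand-ins, and the real inputs are only used afterwards to run ep.module(). A hedged sketch of the fakeification step alone, using plain torch utilities instead of the repository helper (the kwargs below are illustrative, not the model's real inputs):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.utils import _pytree as pytree


def fakeify_kwargs(kwargs):
    """Replace every tensor in a (possibly nested) kwargs structure by a FakeTensor."""
    mode = FakeTensorMode()
    return pytree.tree_map(
        lambda v: mode.from_tensor(v) if isinstance(v, torch.Tensor) else v, kwargs
    )


real_kwargs = {
    "input_ids": torch.randint(0, 1000, (2, 6)),
    "attention_mask": torch.ones((2, 6), dtype=torch.int64),
}
fake_kwargs = fakeify_kwargs(real_kwargs)
print({k: (tuple(v.shape), v.dtype) for k, v in fake_kwargs.items()})
# Shapes and dtypes are preserved but the fake tensors hold no data, so they can be
# passed to torch.export.export as kwargs while the real kwargs are kept to check
# the exported module, as the test above does.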

_unittests/ut_tasks/test_tasks_text_to_image.py

Lines changed: 3 additions & 2 deletions
@@ -23,12 +23,13 @@ def test_text_to_image(self):
         self.assertEqual(data["task"], "text-to-image")
         self.assertIn((data["size"], data["n_weights"]), [(5708048, 1427012)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
+        expected = model(**inputs)
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            self.assertEqualAny(expected, ep.module()(**inputs))


 if __name__ == "__main__":

_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py

Lines changed: 3 additions & 2 deletions
@@ -15,12 +15,13 @@ def test_zero_shot_image_classification(self):
         self.assertEqual(data["task"], "zero-shot-image-classification")
         self.assertIn((data["size"], data["n_weights"]), [(188872708, 47218177)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
+        expected = model(**inputs)
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10):
-            torch.export.export(
+            ep = torch.export.export(
                 model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False
             )
+            self.assertEqualAny(expected, ep.module()(**inputs))


 if __name__ == "__main__":
