Skip to content

Commit b7d5dd7

Browse files
committed
more tests
1 parent bfac0a1 commit b7d5dd7

File tree

8 files changed

+121
-14
lines changed

8 files changed

+121
-14
lines changed

_doc/examples/plot_export_hub_codellama.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@
2222
from onnx_diagnostic import doc
2323
from onnx_diagnostic.ext_test_case import unit_test_going
2424
from onnx_diagnostic.helpers import string_type
25-
from onnx_diagnostic.torch_models.hghub import (
26-
get_untrained_model_with_inputs,
27-
)
25+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
2826
from onnx_diagnostic.torch_models.hghub.hub_api import (
2927
get_model_info,
3028
get_pretrained_config,

_doc/examples/plot_export_tiny_phi2.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,7 @@
3333
from onnx_diagnostic.helpers.rt_helper import make_feeds
3434
from onnx_diagnostic.torch_export_patches import torch_export_patches
3535
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
36-
from onnx_diagnostic.torch_models.hghub import (
37-
get_untrained_model_with_inputs,
38-
)
36+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
3937

4038
warnings.simplefilter("ignore")
4139

_unittests/ut_export/test_api.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
import unittest
22
import torch
33
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
4+
from onnx_diagnostic.helpers import max_diff
5+
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
6+
from onnx_diagnostic.helpers.rt_helper import make_feeds
7+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
8+
from onnx_diagnostic.torch_export_patches import torch_export_patches
49
from onnx_diagnostic.export.api import to_onnx
510

611

@@ -19,16 +24,66 @@ def forward(self, x, y):
1924
(x, y),
2025
dynamic_shapes=ds,
2126
exporter="custom",
22-
filename=self.get_dump_file("custom.onnx"),
27+
filename=self.get_dump_file("to_onnx_custom.onnx"),
2328
)
2429
to_onnx(
2530
Model(),
2631
(x, y),
2732
dynamic_shapes=ds,
2833
exporter="onnx-dynamo",
29-
filename=self.get_dump_file("onnx-dynamo.onnx"),
34+
filename=self.get_dump_file("to_onnx_onnx-dynamo.onnx"),
3035
)
3136

37+
@hide_stdout()
38+
def test_tiny_llm_to_onnx(self):
39+
import onnxruntime
40+
41+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM")
42+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
43+
b1 = data["inputs_batch1"]
44+
filenames = {
45+
"custom": self.get_dump_file("test_tiny_llm_to_onnx-custom.onnx"),
46+
"onnx-dynamo": self.get_dump_file("test_tiny_llm_to_onnx-dynamo.onnx"),
47+
"modelbuilder": self.get_dump_file("model.onnx"),
48+
}
49+
del inputs["position_ids"]
50+
del ds["position_ids"]
51+
del b1["position_ids"]
52+
53+
expected = model(**torch_deepcopy(b1))
54+
55+
with torch_export_patches(patch_transformers=True):
56+
for exporter, filename in filenames.items():
57+
with self.subTest(exporter=exporter):
58+
to_onnx(
59+
model,
60+
kwargs=inputs,
61+
dynamic_shapes=ds,
62+
exporter=exporter,
63+
filename=filename,
64+
)
65+
for exporter, filename in filenames.items():
66+
with self.subTest(exporter=f"validate-{exporter}"):
67+
sess = onnxruntime.InferenceSession(
68+
filename, providers=["CPUExecutionProvider"]
69+
)
70+
feeds = make_feeds(sess, b1, use_numpy=True)
71+
got = sess.run(None, feeds)
72+
diff = max_diff(expected, got)
73+
assert diff["abs"] <= 1e-5, f"diff={diff}"
74+
75+
b1["attention_mask"][:, :] = 1
76+
expected = model(**torch_deepcopy(b1))
77+
for exporter, filename in filenames.items():
78+
with self.subTest(exporter=f"full-mask-{exporter}"):
79+
sess = onnxruntime.InferenceSession(
80+
filename, providers=["CPUExecutionProvider"]
81+
)
82+
feeds = make_feeds(sess, b1, use_numpy=True)
83+
got = sess.run(None, feeds)
84+
diff = max_diff(expected, got)
85+
assert diff["abs"] <= 1e-5, f"diff={diff}"
86+
3287

3388
if __name__ == "__main__":
3489
unittest.main(verbosity=2)

_unittests/ut_export/test_dynamic_shapes.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,27 @@ def forward(self, cache, z):
916916
ds,
917917
)
918918

919+
def test_invalid_dimensions_for_export(self):
920+
ags = []
921+
kws = dict(
922+
input_ids=torch.randint(0, 10, (2, 3)),
923+
attention_mask=torch.randint(0, 1, (2, 33)),
924+
position_ids=torch.randint(0, 10, (2, 3)),
925+
past_key_values=make_dynamic_cache(
926+
[torch.rand((2, 1, 30, 96)), torch.rand((2, 1, 30, 96))]
927+
),
928+
)
929+
ds = dict(
930+
input_ids={0: "batch", 1: "seq_length"},
931+
attention_mask={0: "batch", 1: "seq_length"},
932+
position_ids={0: "batch", 1: "seq_length"},
933+
past_key_values=[{0: "batch", 2: "cache_length"}, {0: "batch", 2: "cache_length"}],
934+
)
935+
with torch_export_patches(patch_transformers=True):
936+
cpl = CoupleInputsDynamicShapes(ags, kws, ds)
937+
backed_size_oblivious = cpl.invalid_dimensions_for_export()
938+
self.assertFalse(backed_size_oblivious)
939+
919940

920941
if __name__ == "__main__":
921942
unittest.main(verbosity=2)

_unittests/ut_helpers/test_model_builder_helper.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
create_model_builder,
1313
save_model_builder,
1414
)
15-
from onnx_diagnostic.torch_models.hghub import (
16-
get_untrained_model_with_inputs,
17-
)
15+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
1816
from onnx_diagnostic.helpers.rt_helper import make_feeds
1917

2018

_unittests/ut_helpers/test_rt_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_onnx_generate(self):
113113
kwargs=inputs,
114114
dynamic_shapes=ds,
115115
filename=model_name,
116-
exporter="custom",
116+
exporter="modelbuilder",
117117
)
118118

119119
print("-- test_onnx_generate: generate")

_unittests/ut_xrun_doc/test_documentation_examples.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def add_test_methods(cls):
102102

103103
if (
104104
not reason
105-
and name in {"plot_export_with_dynamic_cache.py", "plot_export_tiny_phi2.py"}
105+
and name in {"plot_export_tiny_phi2.py", "plot_export_with_dynamic_cache.py"}
106106
and not has_transformers("4.55")
107107
):
108108
reason = "transformers<4.55"
@@ -117,6 +117,7 @@ def add_test_methods(cls):
117117
"plot_export_locate_issue.py",
118118
"plot_export_with_auto.py",
119119
"plot_export_tiny_llm.py",
120+
"plot_export_with_dynamic_cache.py",
120121
}
121122
and not has_torch("2.8")
122123
):

onnx_diagnostic/export/api.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def to_onnx(
4242
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
4343
epo = torch.onnx.export(
4444
mod,
45-
args=args,
45+
args=args or tuple(),
4646
kwargs=kwargs,
4747
input_names=input_names,
4848
output_names=output_names,
@@ -54,4 +54,40 @@ def to_onnx(
5454
epo.save(filename)
5555
return epo
5656

57+
if exporter == "modelbuilder":
58+
import os
59+
from ..helpers import flatten_object, string_type
60+
from ..helpers.model_builder_helper import create_model_builder, save_model_builder
61+
62+
assert filename, f"filename must be specified for exporter={exporter!r}"
63+
assert (
64+
not output_dynamic_shapes
65+
), f"output_dynamic_shapes not supported for exporter={exporter!r}"
66+
assert hasattr(mod, "config"), f"configuration is missing in model class {type(mod)}"
67+
assert not args, f"only kwargs can be defined with exporter={exporter!r}"
68+
assert list(kwargs) == ["input_ids", "attention_mask", "past_key_values"], (
69+
f"Only a specified set of inputs is supported for exporter={exporter!r}, "
70+
f"but it is {list(kwargs)}"
71+
)
72+
flat_inputs = flatten_object(kwargs, drop_keys=True)
73+
first = flat_inputs[0]
74+
first_float = [
75+
t
76+
for t in flat_inputs
77+
if t.dtype in {torch.float32, torch.double, torch.float16, torch.bfloat16}
78+
]
79+
assert first_float, (
80+
f"Unable to find a float tensor in the inputs "
81+
f"{string_type(kwargs, with_shape=True)}"
82+
)
83+
onx = create_model_builder(
84+
mod.config,
85+
mod,
86+
precision=str(first_float[0].dtype).split(".")[-1],
87+
execution_provider="cuda" if first.is_cuda else "cpu",
88+
cache_dir=os.path.dirname(filename),
89+
)
90+
save_model_builder(onx, os.path.dirname(filename))
91+
return onx
92+
5793
raise ValueError(f"Unknown exporter={exporter!r}")

0 commit comments

Comments
 (0)