Skip to content

Commit 137b16c

Browse files
authored
Support for inputs2, refactoring, fix VitModel (#144)
* support for inputs3 * mypy * assign * rename * fix model builder * update changloes
1 parent 33df61b commit 137b16c

File tree

14 files changed

+305
-216
lines changed

14 files changed

+305
-216
lines changed

CHANGELOGS.rst

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@ Change Logs
44
0.7.0
55
+++++
66

7-
* :pr:`143`: compares intermediate results
7+
* :pr:`144`: support for second inputs with different dimension,
8+
rename test_helper into validate,
9+
support ``interpolate_pos_encoding`` for ``VitModel``,
10+
update model builder helpers for this PR
11+
`Use ONNX IR for model builder
12+
<https://github.com/microsoft/onnxruntime-genai/pull/1416>`_
13+
* :pr:`143`: compares intermediate results,
814

915
0.6.3
1016
+++++

_doc/api/torch_models/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ onnx_diagnostic.torch_models
77

88
hghub/index
99
llms
10-
test_helper
10+
validate
1111

1212
.. automodule:: onnx_diagnostic.torch_models
1313
:members:

_doc/api/torch_models/test_helper.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.

_doc/api/torch_models/validate.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_models.validate
3+
=====================================
4+
5+
.. automodule:: onnx_diagnostic.torch_models.validate
6+
:members:
7+
:no-undoc-members:

_doc/cmds/validate.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
===================================================
55

66
The command line is a wrapper around function
7-
:func:`onnx_diagnostic.torch_models.test_helper.validate_model`.
7+
:func:`onnx_diagnostic.torch_models.validate.validate_model`.
88

99
Description
1010
+++++++++++
@@ -110,7 +110,7 @@ Run onnxruntime fusions
110110

111111
This option runs `transformers optimizations <https://onnxruntime.ai/docs/performance/transformers-optimization.html>`_
112112
implemented in :epkg:`onnxruntime`. The list of supported ``model_type`` can be found in the documentation
113-
of function :func:`onnx_diagnostic.torch_models.test_helper.run_ort_fusion`.
113+
of function :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`.
114114

115115
.. code-block::
116116

_unittests/ut_helpers/test_doc_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def test_custom_doc_kernels_layer_normalization(self):
5656
)
5757
expected = torch_sess.run(None, feeds)
5858
got = torch_sess_custom.run(None, feeds)
59-
self.assertEqualAny(expected, got)
59+
self.assertEqualAny(expected, got, atol=1e-3)
6060

6161
def test_custom_doc_kernels_matmul(self):
6262
model = oh.make_model(

_unittests/ut_helpers/test_model_builder_helper.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import os
21
import unittest
32
from onnx_diagnostic.ext_test_case import (
43
ExtTestCase,
@@ -48,32 +47,17 @@ def test_model_builder_id(self):
4847
cache_dir=folder,
4948
verbose=1,
5049
)
51-
self.assertGreater(len(onnx_model.nodes), 5)
50+
self.assertGreater(onnx_model.model.graph.num_nodes(), 5)
51+
model_name = save_model_builder(onnx_model, folder, verbose=1)
52+
self.assertExists(model_name)
5253

53-
proto = save_model_builder(onnx_model, verbose=1)
5454
import onnxruntime
5555

56-
onnxruntime.InferenceSession(
57-
proto.SerializeToString(), providers=["CPUExecutionProvider"]
58-
)
59-
60-
# We need to start again.
61-
onnx_model = create_model_builder(
62-
data["configuration"],
63-
data["model"],
64-
precision="fp32",
65-
execution_provider="cpu",
66-
cache_dir=folder,
67-
verbose=1,
68-
)
69-
save_model_builder(onnx_model, folder, verbose=1)
70-
model_name = os.path.join(folder, "model.onnx")
71-
self.assertExists(model_name)
72-
73-
feeds = make_feeds(proto, data["inputs"], use_numpy=True)
56+
sess = onnxruntime.InferenceSession(model_name, providers=["CPUExecutionProvider"])
57+
del data["inputs"]["position_ids"]
58+
feeds = make_feeds([i.name for i in sess.get_inputs()], data["inputs"], use_numpy=True)
7459
expected = data["model"](**data["inputs"])
7560

76-
sess = onnxruntime.InferenceSession(model_name, providers=["CPUExecutionProvider"])
7761
try:
7862
got = sess.run(None, feeds)
7963
except onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument as e:

_unittests/ut_torch_models/test_validate_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
requires_experimental,
99
requires_transformers,
1010
)
11-
from onnx_diagnostic.torch_models.test_helper import validate_model
11+
from onnx_diagnostic.torch_models.validate import validate_model
1212

1313

1414
class TestValidateModel(ExtTestCase):

_unittests/ut_torch_models/test_test_helpers.py renamed to _unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
requires_onnxscript,
1212
requires_transformers,
1313
)
14-
from onnx_diagnostic.torch_models.test_helper import (
14+
from onnx_diagnostic.torch_models.validate import (
1515
get_inputs_for_task,
1616
validate_model,
1717
filter_inputs,
@@ -21,7 +21,7 @@
2121
from onnx_diagnostic.tasks import supported_tasks
2222

2323

24-
class TestTestHelper(ExtTestCase):
24+
class TestValidateWholeModels(ExtTestCase):
2525
def test_get_inputs_for_task(self):
2626
fcts = supported_tasks()
2727
for task in self.subloop(sorted(fcts)):
@@ -221,14 +221,39 @@ def test_validate_model_modelbuilder(self):
221221
do_run=True,
222222
verbose=10,
223223
exporter="modelbuilder",
224-
dump_folder="dump_test_validate_model_onnx_dynamo",
224+
dump_folder="dump_test_validate_model_modelbuilder",
225225
)
226226
self.assertIsInstance(summary, dict)
227227
self.assertIsInstance(data, dict)
228228
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
229229
onnx_filename = data["onnx_filename"]
230230
self.assertExists(onnx_filename)
231231

232+
@requires_torch("2.7")
233+
@hide_stdout()
234+
@ignore_warnings(FutureWarning)
235+
@requires_transformers("4.51")
236+
def test_validate_model_vit_model(self):
237+
mid = "ydshieh/tiny-random-ViTForImageClassification"
238+
summary, data = validate_model(
239+
mid,
240+
do_run=True,
241+
verbose=10,
242+
exporter="onnx-dynamo",
243+
dump_folder="dump_test_validate_model_onnx_dynamo",
244+
inputs2=True,
245+
)
246+
self.assertIsInstance(summary, dict)
247+
self.assertIsInstance(data, dict)
248+
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-3)
249+
self.assertLess(summary["disc_onnx_ort_run2_abs"], 1e-3)
250+
self.assertEqual("dict(pixel_values:A1s2x3x30x30)", summary["run_feeds_inputs"])
251+
self.assertEqual("dict(pixel_values:A1s3x3x31x31)", summary["run_feeds_inputs2"])
252+
self.assertEqual("#1[A1s2x2]", summary["run_output_inputs"])
253+
self.assertEqual("#1[A1s3x2]", summary["run_output_inputs2"])
254+
onnx_filename = data["onnx_filename"]
255+
self.assertExists(onnx_filename)
256+
232257

233258
if __name__ == "__main__":
234259
unittest.main(verbosity=2)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,13 @@ def get_parser_validate() -> ArgumentParser:
373373
action=BooleanOptionalAction,
374374
help="validate the trained model (requires downloading)",
375375
)
376+
parser.add_argument(
377+
"--inputs2",
378+
default=True,
379+
action=BooleanOptionalAction,
380+
help="if run is on, the command lines validates the model on a "
381+
"second set of inputs to check the exported model supports dynamism",
382+
)
376383
parser.add_argument(
377384
"--runtime",
378385
choices=["onnxruntime", "torch", "ref"],
@@ -440,7 +447,7 @@ def get_parser_validate() -> ArgumentParser:
440447

441448
def _cmd_validate(argv: List[Any]):
442449
from .helpers import string_type
443-
from .torch_models.test_helper import get_inputs_for_task, validate_model
450+
from .torch_models.validate import get_inputs_for_task, validate_model
444451
from .tasks import supported_tasks
445452

446453
parser = get_parser_validate()
@@ -492,6 +499,7 @@ def _cmd_validate(argv: List[Any]):
492499
runtime=args.runtime,
493500
repeat=args.repeat,
494501
warmup=args.warmup,
502+
inputs2=args.inputs2,
495503
)
496504
print("")
497505
print("-- summary --")

0 commit comments

Comments
 (0)