Skip to content

Commit 61102ef

Browse files
committed
support for inputs3
1 parent 33df61b commit 61102ef

File tree

11 files changed

+240
-119
lines changed

11 files changed

+240
-119
lines changed

CHANGELOGS.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ Change Logs
44
0.7.0
55
+++++
66

7-
* :pr:`143`: compares intermediate results
7+
* :pr:`143`: compares intermediate results,
8+
support for second inputs with different dimension,
9+
rename test_helper into validate,
10+
support ``interpolate_pos_encoding`` for ``VitModel``
811

912
0.6.3
1013
+++++

_doc/api/torch_models/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ onnx_diagnostic.torch_models
77

88
hghub/index
99
llms
10-
test_helper
10+
validate
1111

1212
.. automodule:: onnx_diagnostic.torch_models
1313
:members:

_doc/api/torch_models/test_helper.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.

_doc/api/torch_models/validate.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
2+
onnx_diagnostic.torch_models.validate
3+
=====================================
4+
5+
.. automodule:: onnx_diagnostic.torch_models.validate
6+
:members:
7+
:no-undoc-members:

_doc/cmds/validate.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
===================================================
55

66
The command line is a wrapper around function
7-
:func:`onnx_diagnostic.torch_models.test_helper.validate_model`.
7+
:func:`onnx_diagnostic.torch_models.validate.validate_model`.
88

99
Description
1010
+++++++++++
@@ -110,7 +110,7 @@ Run onnxruntime fusions
110110

111111
This option runs `transformers optimizations <https://onnxruntime.ai/docs/performance/transformers-optimization.html>`_
112112
implemented in :epkg:`onnxruntime`. The list of supported ``model_type`` can be found in the documentation
113-
of function :func:`onnx_diagnostic.torch_models.test_helper.run_ort_fusion`.
113+
of function :func:`onnx_diagnostic.torch_models.validate.run_ort_fusion`.
114114

115115
.. code-block::
116116

_unittests/ut_torch_models/test_test_helpers.py renamed to _unittests/ut_torch_models/test_validate.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
requires_onnxscript,
1212
requires_transformers,
1313
)
14-
from onnx_diagnostic.torch_models.test_helper import (
14+
from onnx_diagnostic.torch_models.validate import (
1515
get_inputs_for_task,
1616
validate_model,
1717
filter_inputs,
@@ -21,7 +21,7 @@
2121
from onnx_diagnostic.tasks import supported_tasks
2222

2323

24-
class TestTestHelper(ExtTestCase):
24+
class TestValidate(ExtTestCase):
2525
def test_get_inputs_for_task(self):
2626
fcts = supported_tasks()
2727
for task in self.subloop(sorted(fcts)):
@@ -221,14 +221,39 @@ def test_validate_model_modelbuilder(self):
221221
do_run=True,
222222
verbose=10,
223223
exporter="modelbuilder",
224-
dump_folder="dump_test_validate_model_onnx_dynamo",
224+
dump_folder="dump_test_validate_model_modelbuilder",
225225
)
226226
self.assertIsInstance(summary, dict)
227227
self.assertIsInstance(data, dict)
228228
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
229229
onnx_filename = data["onnx_filename"]
230230
self.assertExists(onnx_filename)
231231

232+
@requires_torch("2.7")
233+
@hide_stdout()
234+
@ignore_warnings(FutureWarning)
235+
@requires_transformers("4.51")
236+
def test_validate_model_vit_model(self):
237+
mid = "ydshieh/tiny-random-ViTForImageClassification"
238+
summary, data = validate_model(
239+
mid,
240+
do_run=True,
241+
verbose=10,
242+
exporter="onnx-dynamo",
243+
dump_folder="dump_test_validate_model_onnx_dynamo",
244+
inputs2=True,
245+
)
246+
self.assertIsInstance(summary, dict)
247+
self.assertIsInstance(data, dict)
248+
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-3)
249+
self.assertLess(summary["disc_onnx_ort_run2_abs"], 1e-3)
250+
self.assertEqual("dict(pixel_values:A1s2x3x30x30)", summary["run_feeds_inputs"])
251+
self.assertEqual("dict(pixel_values:A1s3x3x31x31)", summary["run_feeds_inputs2"])
252+
self.assertEqual("#1[A1s2x2]", summary["run_output_inputs"])
253+
self.assertEqual("#1[A1s3x2]", summary["run_output_inputs2"])
254+
onnx_filename = data["onnx_filename"]
255+
self.assertExists(onnx_filename)
256+
232257

233258
if __name__ == "__main__":
234259
unittest.main(verbosity=2)

_unittests/ut_torch_models/test_validate_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
requires_experimental,
99
requires_transformers,
1010
)
11-
from onnx_diagnostic.torch_models.test_helper import validate_model
11+
from onnx_diagnostic.torch_models.validate import validate_model
1212

1313

1414
class TestValidateModel(ExtTestCase):

onnx_diagnostic/_command_lines_parser.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,13 @@ def get_parser_validate() -> ArgumentParser:
373373
action=BooleanOptionalAction,
374374
help="validate the trained model (requires downloading)",
375375
)
376+
parser.add_argument(
377+
"--inputs2",
378+
default=True,
379+
action=BooleanOptionalAction,
380+
help="if run is on, the command lines validates the model on a "
381+
"second set of inputs to check the exported model supports dynamism",
382+
)
376383
parser.add_argument(
377384
"--runtime",
378385
choices=["onnxruntime", "torch", "ref"],
@@ -440,7 +447,7 @@ def get_parser_validate() -> ArgumentParser:
440447

441448
def _cmd_validate(argv: List[Any]):
442449
from .helpers import string_type
443-
from .torch_models.test_helper import get_inputs_for_task, validate_model
450+
from .torch_models.validate import get_inputs_for_task, validate_model
444451
from .tasks import supported_tasks
445452

446453
parser = get_parser_validate()
@@ -492,6 +499,7 @@ def _cmd_validate(argv: List[Any]):
492499
runtime=args.runtime,
493500
repeat=args.repeat,
494501
warmup=args.warmup,
502+
inputs2=args.inputs2,
495503
)
496504
print("")
497505
print("-- summary --")

onnx_diagnostic/tasks/image_classification.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_inputs(
5252
input_width, int
5353
), f"Unexpected type for input_width {type(input_width)}{config}"
5454
assert isinstance(
55-
input_width, int
55+
input_height, int
5656
), f"Unexpected type for input_height {type(input_height)}{config}"
5757

5858
shapes = {
@@ -67,6 +67,9 @@ def get_inputs(
6767
-1, 1
6868
),
6969
)
70+
if model.__class__.__name__ == "ViTForImageClassification":
71+
inputs["interpolate_pos_encoding"] = True
72+
shapes["interpolate_pos_encoding"] = None
7073
res = dict(inputs=inputs, dynamic_shapes=shapes)
7174
if add_second_input:
7275
res["inputs2"] = get_inputs(

onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4093,3 +4093,30 @@ def _ccached_microsoft_phi4_reasoning():
40934093
"vocab_size": 200064,
40944094
}
40954095
)
4096+
4097+
4098+
def _ccached_ydshieh_tiny_random_vit_for_image_classification():
4099+
"ydshieh/tiny-random-ViTForImageClassification"
4100+
return transformers.Phi3Config(
4101+
**{
4102+
"_name_or_path": ".temp/dummy/vit/ViTForImageClassification",
4103+
"architectures": ["ViTForImageClassification"],
4104+
"attention_probs_dropout_prob": 0.1,
4105+
"encoder_stride": 2,
4106+
"hidden_act": "gelu",
4107+
"hidden_dropout_prob": 0.1,
4108+
"hidden_size": 32,
4109+
"image_size": 30,
4110+
"initializer_range": 0.02,
4111+
"intermediate_size": 37,
4112+
"layer_norm_eps": 1e-12,
4113+
"model_type": "vit",
4114+
"num_attention_heads": 4,
4115+
"num_channels": 3,
4116+
"num_hidden_layers": 5,
4117+
"patch_size": 2,
4118+
"qkv_bias": true,
4119+
"torch_dtype": "float32",
4120+
"transformers_version": "4.24.0.dev0",
4121+
}
4122+
)

0 commit comments

Comments
 (0)