Skip to content

Commit 859338f

Browse files
authored
Add onnxruntime fusion in command line validate (#50)
* add ort fusion * ortfusion in command line * mypy
1 parent bc28e6a commit 859338f

File tree

4 files changed

+217
-36
lines changed

4 files changed

+217
-36
lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.4.0
55
+++++
66

7+
* :pr:`50`: add support for onnxruntime fusion
78
* :pr:`48`: add support for EncoderDecoderCache, test with openai/whisper-tiny
89
* :pr:`45`: improve change_dynamic_dimension to fix some dimensions
910

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
get_inputs_for_task,
88
validate_model,
99
filter_inputs,
10+
run_ort_fusion,
1011
)
1112
from onnx_diagnostic.torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks
1213

@@ -69,6 +70,11 @@ def test_validate_model_onnx(self):
6970
self.assertIsInstance(summary, dict)
7071
self.assertIsInstance(data, dict)
7172
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
73+
onnx_filename = data["onnx_filename"]
74+
output_path = f"{onnx_filename}.ortopt.onnx"
75+
run_ort_fusion(
76+
onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10
77+
)
7278

7379
def test_filter_inputs(self):
7480
inputs, ds = {"a": 1, "b": 2}, {"a": 20, "b": 30}

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,12 @@ def get_parser_validate() -> ArgumentParser:
287287
help="drops the following inputs names, it should be a list "
288288
"with comma separated values",
289289
)
290+
parser.add_argument(
291+
"--ortfusiontype",
292+
required=False,
293+
help="applies onnxruntime fusion, this parameter should contain the "
294+
"model type or multiple values separated by |",
295+
)
290296
parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
291297
parser.add_argument("--dtype", help="changes dtype if necessary")
292298
parser.add_argument("--device", help="changes the device if necessary")
@@ -338,6 +344,7 @@ def _cmd_validate(argv: List[Any]):
338344
exporter=args.export,
339345
dump_folder=args.dump_folder,
340346
drop_inputs=None if not args.drop else args.drop.split(","),
347+
ortfusiontype=args.ortfusiontype,
341348
)
342349
print("")
343350
print("-- summary --")

0 commit comments

Comments
 (0)