Skip to content

Commit 18aead7

Browse files
committed
add ort fusion
1 parent bc28e6a commit 18aead7

File tree

2 files changed

+101
-13
lines changed

2 files changed

+101
-13
lines changed

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
get_inputs_for_task,
88
validate_model,
99
filter_inputs,
10+
run_ort_fusion,
1011
)
1112
from onnx_diagnostic.torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks
1213

@@ -69,6 +70,11 @@ def test_validate_model_onnx(self):
6970
self.assertIsInstance(summary, dict)
7071
self.assertIsInstance(data, dict)
7172
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
73+
onnx_filename = data["onnx_filename"]
74+
output_path = f"{onnx_filename}.ortopt.onnx"
75+
run_ort_fusion(
76+
onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10
77+
)
7278

7379
def test_filter_inputs(self):
7480
inputs, ds = {"a": 1, "b": 2}, {"a": 20, "b": 30}

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 95 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ def validate_model(
420420
)
421421
summary.update(summary_export)
422422

423+
dump_stats = None
423424
if dump_folder:
424425
if "exported_program" in data:
425426
ep = data["exported_program"]
@@ -435,22 +436,27 @@ def validate_model(
435436
epo = data["onnx_program"]
436437
if verbose:
437438
print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
438-
onnx_file_name = os.path.join(dump_folder, f"{folder_name}.onnx")
439+
onnx_filename = os.path.join(dump_folder, f"{folder_name}.onnx")
440+
begin = time.perf_counter()
439441
if isinstance(epo, onnx.model_container.ModelContainer):
440-
epo.save(onnx_file_name, all_tensors_to_one_file=True)
442+
epo.save(onnx_filename, all_tensors_to_one_file=True)
441443
else:
442-
epo.save(onnx_file_name, external_data=True)
444+
epo.save(onnx_filename, external_data=True)
445+
duration = time.perf_counter() - begin
443446
if verbose:
444-
print("[validate_model] done (dump onnx)")
447+
print(f"[validate_model] done (dump onnx) in {duration}")
448+
data["onnx_filename"] = onnx_filename
449+
summary["time_onnx_save"] = duration
445450
if verbose:
446451
print(f"[validate_model] dumps statistics in {dump_folder!r}...")
447-
with open(os.path.join(dump_folder, f"{folder_name}.stats"), "w") as f:
452+
dump_stats = os.path.join(dump_folder, f"{folder_name}.stats")
453+
with open(dump_stats, "w") as f:
448454
for k, v in sorted(summary.items()):
449455
f.write(f":{k}:{v};\n")
450456
if verbose:
451457
print("[validate_model] done (dump)")
452458

453-
if exporter and exporter.startswith("onnx-") and do_run:
459+
if exporter and exporter.startswith(("onnx-", "custom-")) and do_run:
454460
summary_valid, data = validate_onnx_model(
455461
data=data,
456462
quiet=quiet,
@@ -461,6 +467,10 @@ def validate_model(
461467

462468
if verbose:
463469
print("[validate_model] -- done (final)")
470+
if dump_stats:
471+
with open(dump_stats, "w") as f:
472+
for k, v in sorted(summary.items()):
473+
f.write(f":{k}:{v};\n")
464474
return summary, data
465475

466476

@@ -642,7 +652,7 @@ def validate_onnx_model(
642652
quiet: bool = False,
643653
verbose: int = 0,
644654
optimization: Optional[str] = None,
645-
):
655+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
646656
"""
647657
Verifies that an onnx model produces the same
648658
expected outputs.
@@ -665,10 +675,10 @@ def validate_onnx_model(
665675
if d < 0
666676
else ["CUDAExecutionProvider", "CPUExecutionProvider"]
667677
)
668-
if "onnx_file_name" in data:
669-
source = data["onnx_file_name"]
678+
if "onnx_filename" in data:
679+
source = data["onnx_filename"]
670680
summary["onnx_filename"] = source
671-
summary["onnx_size"] = os.stats(source).st_size
681+
summary["onnx_size"] = os.stat(source).st_size
672682
else:
673683
assert (
674684
"onnx_program" in data
@@ -745,7 +755,7 @@ def call_torch_export_onnx(
745755
quiet: bool = False,
746756
verbose: int = 0,
747757
optimization: Optional[str] = None,
748-
):
758+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
749759
"""
750760
Exports a model into onnx.
751761
If a patch must be applied, it should be applied before this function.
@@ -818,7 +828,7 @@ def call_torch_export_onnx(
818828
if verbose:
819829
print("[call_torch_export_onnx] done (export)")
820830
data["onnx_program"] = epo
821-
if verbose > 1:
831+
if verbose > 5:
822832
print("[call_torch_export_onnx] -- ONNXProgram")
823833
print(epo)
824834
print("[call_torch_export_onnx] -- End of ONNXProgram")
@@ -850,7 +860,7 @@ def call_torch_export_custom(
850860
quiet: bool = False,
851861
verbose: int = 0,
852862
optimization: Optional[str] = None,
853-
):
863+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
854864
"""
855865
Exports a model into onnx.
856866
If a patch must be applied, it should be applied before this function.
@@ -1011,3 +1021,75 @@ def call_torch_export_custom(
10111021
print("[call_torch_export_custom] done (export)")
10121022
data["onnx_program"] = epo
10131023
return summary, data
1024+
1025+
1026+
def run_ort_fusion(
1027+
model_or_path: Union[str, onnx.ModelProto],
1028+
output_path: str,
1029+
num_attention_heads: int,
1030+
hidden_size: int,
1031+
model_type: str = "bert",
1032+
verbose: int = 0,
1033+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
1034+
"""
1035+
Runs :epkg:`onnxruntime` fusion optimizer.
1036+
1037+
:param model_or_path: path to the ModelProto or the ModelProto itself
1038+
:param output_path: the model to save
1039+
:param num_attention_heads: number of heads, usually ``config.num_attention_heads``
1040+
:param hidden_size: hidden size, usually ``config.hidden_size``
1041+
:param model_type: type of optimization, see below
1042+
:param verbose: verbosity
1043+
:return: two dictionaries, summary and data
1044+
1045+
Supported values for ``model_type``:
1046+
1047+
.. runpython::
1048+
:showcode:
1049+
1050+
import pprint
1051+
from onnxruntime.transformers.optimizer import MODEL_TYPES
1052+
1053+
pprint.pprint(sorted(MODEL_TYPES))
1054+
"""
1055+
from onnxruntime.transformers.optimizer import optimize_by_fusion
1056+
from onnxruntime.transformers.fusion_options import FusionOptions
1057+
1058+
opts = FusionOptions(model_type)
1059+
1060+
if isinstance(model_or_path, str):
1061+
if verbose:
1062+
print(f"[run_ort_fusion] loads {model_or_path!r}")
1063+
onx = onnx.load(model_or_path)
1064+
else:
1065+
onx = model_or_path
1066+
begin = time.perf_counter()
1067+
n_nodes = len(onx.graph.node)
1068+
if verbose:
1069+
print(
1070+
f"[run_ort_fusion] starts optimization for "
1071+
f"model_type={model_type!r} with {n_nodes} nodes"
1072+
)
1073+
new_onx = optimize_by_fusion(
1074+
onx,
1075+
model_type=model_type,
1076+
num_heads=num_attention_heads,
1077+
hidden_size=hidden_size,
1078+
optimization_options=opts,
1079+
)
1080+
duration = time.perf_counter() - begin
1081+
delta = len(new_onx.model.graph.node)
1082+
if verbose:
1083+
print(f"[run_ort_fusion] done in {duration} with {delta} nodes")
1084+
print(f"[run_ort_fusion] save to {output_path!r}")
1085+
begin = time.perf_counter()
1086+
new_onx.save_model_to_file(output_path, use_external_data_format=True)
1087+
d = time.perf_counter() - begin
1088+
if verbose:
1089+
print(f"[run_ort_fusion] done in {d}")
1090+
return {
1091+
f"opt_ort_{model_type}_n_nodes1": n_nodes,
1092+
f"opt_ort_{model_type}_n_nodes2": delta,
1093+
f"opt_ort_{model_type}_duration": duration,
1094+
f"opt_ort_{model_type}_duration_save": d,
1095+
}, {f"opt_ort_{model_type}": output_path}

0 commit comments

Comments
 (0)