 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import time
 import onnx
+import onnxscript
+import onnxscript.rewriter.ort_fusions as ort_fusions
 import torch
 from ..export import CoupleInputsDynamicShapes
 from ..helpers import max_diff, string_type, string_diff
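The two imports added above are module-level, so `onnxscript` (and its onnxruntime fusion rules) becomes a dependency of this module. A standalone check, not part of the patch, to confirm both modules resolve in a given environment:

```python
# Sanity check for the dependencies added above; not part of the patch.
for name in ("onnxscript", "onnxscript.rewriter.ort_fusions"):
    try:
        __import__(name)
        print(f"{name}: available")
    except ImportError as exc:
        print(f"{name}: missing ({exc})")
```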
@@ -917,11 +919,10 @@ def call_torch_export_onnx(
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
-    assert optimization in {
-        "",
-        "ir",
-        None,
-    }, f"unexpected value for optimization={optimization}"
+    available = {"", "ir", "os_ort", None}
+    assert (
+        optimization in available
+    ), f"unexpected value for optimization={optimization}, available={available}"
     assert exporter in {
         "onnx-dynamo",
         "onnx-script",
@@ -1001,16 +1002,25 @@ def call_torch_export_onnx(
         print(epo)
         print("[call_torch_export_onnx] -- End of ONNXProgram")
 
-    if optimization == "ir":
+    if optimization in {"ir", "os_ort"}:
         if verbose:
             print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
-        _quiet_or_not_quiet(
-            quiet,
-            "export_onnx_opt_ir",
-            summary,
-            data,
-            (lambda epo=epo: epo.optimize()),
-        )
+        if optimization == "ir":
+            label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize())
+        else:
+
+            def _os_ort_optim(epo):
+                onnxscript.optimizer.optimize_ir(epo.model)
+                optimized = ort_fusions.optimize_for_ort(epo.model)
+                if isinstance(optimized, tuple):
+                    for k, v in optimized[1].items():
+                        summary[f"op_opt_fused_{k}"] = v
+                    epo.model = optimized[0]
+                else:
+                    epo.model = optimized
+
+            label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo))
+        _quiet_or_not_quiet(quiet, label, summary, data, f_optim)
-    if "ERR_export_onnx_opt_ir" in summary:
+    if "ERR_export_onnx_opt_ir" in summary or "ERR_export_onnx_opt_os_ort" in summary:
         return summary, data
     if verbose:
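For reference, a minimal sketch of what the new `os_ort` branch does, applied directly to an `ONNXProgram` outside of `_quiet_or_not_quiet`. The toy `torch.nn.Linear` model, the output path, and the explicit `import onnxscript.optimizer` are assumptions for illustration; the optimization sequence itself mirrors the lines above.

```python
import torch
import onnxscript.optimizer
import onnxscript.rewriter.ort_fusions as ort_fusions

# Hypothetical toy model; any module exportable with the dynamo-based exporter works.
model = torch.nn.Linear(8, 4)
epo = torch.onnx.export(model, (torch.randn(2, 8),), dynamo=True)

# Same sequence as the "os_ort" branch: generic IR optimization in place,
# then onnxruntime-oriented fusions.
onnxscript.optimizer.optimize_ir(epo.model)
optimized = ort_fusions.optimize_for_ort(epo.model)
if isinstance(optimized, tuple):
    # Some onnxscript versions return (model, fusion_report).
    epo.model, report = optimized
    print(report)
else:
    epo.model = optimized
epo.save("linear_os_ort.onnx")
```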
@@ -1039,12 +1049,17 @@ def call_torch_export_custom(
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
-    assert optimization in {
+    available = {
         "",
         "default",
         "default+onnxruntime",
+        "default+os_ort",
+        "default+onnxruntime+os_ort",
         None,
-    }, f"unexpected value for optimization={optimization}"
+    }
+    assert (
+        optimization in available
+    ), f"unexpected value for optimization={optimization}, available={available}"
     assert exporter in {
         "custom",
         "custom-strict",
@@ -1078,6 +1093,10 @@ def call_torch_export_custom(
     from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
     from experimental_experiment.xbuilder import OptimizationOptions
 
+    spl = optimization.split("+") if optimization else []
+    os_ort = "os_ort" in spl
+    optimization = "+".join(_ for _ in spl if _ != "os_ort")
+
     export_options = ExportOptions(
         strict=strict,
         decomposition_table=(
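A quick illustration of how the `os_ort` token is peeled off before the remaining string reaches the existing optimization machinery; the loop and the printed layout are only for demonstration.

```python
# Demonstration of the split above; the values mirror the newly accepted strings.
for optimization in ("default+onnxruntime+os_ort", "default+os_ort", "default", None):
    spl = optimization.split("+") if optimization else []
    os_ort = "os_ort" in spl
    remaining = "+".join(_ for _ in spl if _ != "os_ort")
    print(f"{str(optimization):30} -> os_ort={os_ort}, optimization={remaining!r}")
```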
@@ -1181,6 +1200,31 @@ def call_torch_export_custom(
     assert epo is not None, "no onnx export was found"
     if verbose:
         print("[call_torch_export_custom] done (export)")
+
+    if os_ort:
+        if verbose:
+            print("[call_torch_export_custom] conversion to IR...")
+        begin = time.perf_counter()
+        ir_model = epo.to_ir()
+        duration = time.perf_counter() - begin
+        summary["time_optim_to_ir"] = duration
+        if verbose:
+            print(f"[call_torch_export_custom] done in {duration}")
+            print("[call_torch_export_custom] start optimization...")
+        begin = time.perf_counter()
+        onnxscript.optimizer.optimize_ir(ir_model)
+        ir_optimized = ort_fusions.optimize_for_ort(ir_model)
+        if isinstance(ir_optimized, tuple):
+            report = ir_optimized[1]
+            for k, v in report.items():
+                summary[f"op_opt_fused_{k}"] = v
+            ir_optimized = ir_optimized[0]
+        epo.model = ir_optimized
+        duration = time.perf_counter() - begin
+        summary["time_optim_os_ort"] = duration
+        if verbose:
+            print(f"[call_torch_export_custom] done in {duration}")
+
     data["onnx_program"] = epo
     return summary, data
 
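The block above adds three kinds of metrics to `summary`: `time_optim_to_ir`, `time_optim_os_ort`, and one `op_opt_fused_*` counter per fusion reported by `optimize_for_ort`. The dictionary below is a made-up illustration of how these keys could be read back; the fusion names and numbers are invented, only the key patterns come from the patch.

```python
# Invented example values; only the key names/prefixes match the code above.
summary = {
    "time_optim_to_ir": 0.12,
    "time_optim_os_ort": 0.85,
    "op_opt_fused_skip_layer_normalization": 4,
    "op_opt_fused_mha": 2,
}

print("conversion to IR:", summary["time_optim_to_ir"], "s")
print("optimize_ir + optimize_for_ort:", summary["time_optim_os_ort"], "s")
for key, count in sorted(summary.items()):
    if key.startswith("op_opt_fused_"):
        print(key[len("op_opt_fused_"):], count)
```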