|
4 | 4 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union |
5 | 5 | import time |
6 | 6 | import onnx |
| 7 | +import onnxscript |
| 8 | +import onnxscript.rewriter.ort_fusions as ort_fusions |
7 | 9 | import torch |
8 | 10 | from ..export import CoupleInputsDynamicShapes |
9 | 11 | from ..helpers import max_diff, string_type, string_diff |
@@ -917,11 +919,10 @@ def call_torch_export_onnx( |
917 | 919 | :return: two dictionaries, one with some metrics, |
918 | 920 | another one with whatever the function produces |
919 | 921 | """ |
920 | | - assert optimization in { |
921 | | - "", |
922 | | - "ir", |
923 | | - None, |
924 | | - }, f"unexpected value for optimization={optimization}" |
| 922 | + available = {"", "ir", "os_ort"} |
| 923 | + assert ( |
| 924 | + optimization in available |
| 925 | + ), f"unexpected value for optimization={optimization}, available={available}" |
925 | 926 | assert exporter in { |
926 | 927 | "onnx-dynamo", |
927 | 928 | "onnx-script", |
@@ -1001,16 +1002,22 @@ def call_torch_export_onnx( |
1001 | 1002 | print(epo) |
1002 | 1003 | print("[call_torch_export_onnx] -- End of ONNXProgram") |
1003 | 1004 |
|
1004 | | - if optimization == "ir": |
| 1005 | + if optimization in {"ir", "os_ort"}: |
1005 | 1006 | if verbose: |
1006 | 1007 | print(f"[call_torch_export_onnx] starts optimization={optimization!r}...") |
1007 | | - _quiet_or_not_quiet( |
1008 | | - quiet, |
1009 | | - "export_onnx_opt_ir", |
1010 | | - summary, |
1011 | | - data, |
1012 | | - (lambda epo=epo: epo.optimize()), |
1013 | | - ) |
| 1008 | + if optimization == "ir": |
| 1009 | + label, f_optim = "export_onnx_opt_ir", (lambda epo=epo: epo.optimize()) |
| 1010 | + else: |
| 1011 | + |
def _os_ort_optim(epo):
    """Optimize ``epo.model`` with onnxscript, then apply the ORT fusions.

    The result of ``ort_fusions.optimize_for_ort`` is stored back into
    ``epo.model``.

    :param epo: object exposing a ``model`` attribute that is read and
        then overwritten (presumably the exported ONNXProgram -- confirm
        against the caller)
    """
    # First pass: generic onnxscript optimizer; its return value is not used.
    onnxscript.optimizer.optimize_ir(epo.model)
    fused = ort_fusions.optimize_for_ort(epo.model)
    # The call may hand back something other than a bare ir.Model
    # (presumably a tuple whose first item is the model -- confirm against
    # the installed onnxscript version); keep only the model in that case.
    if not isinstance(fused, onnxscript.ir.Model):
        fused = fused[0]
    epo.model = fused
| 1018 | + |
| 1019 | + label, f_optim = "export_onnx_opt_os_ort", (lambda epo=epo: _os_ort_optim(epo)) |
| 1020 | + _quiet_or_not_quiet(quiet, label, summary, data, f_optim) |
1014 | 1021 | if "ERR_export_onnx_opt_ir" in summary: |
1015 | 1022 | return summary, data |
1016 | 1023 | if verbose: |
|
0 commit comments