Skip to content

Commit 194ffc2

Browse files
committed
pass
1 parent 05b04f7 commit 194ffc2

File tree

5 files changed

+14
-277
lines changed

5 files changed

+14
-277
lines changed

_doc/api/torch_export_patches/index.rst

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ onnx_diagnostic.torch_export_patches
77

88
patches/index
99
patch_inputs
10-
patch_module
1110

1211

1312
.. automodule:: onnx_diagnostic.torch_export_patches

_doc/api/torch_export_patches/patch_module.rst

Lines changed: 0 additions & 7 deletions
This file was deleted.

_unittests/ut_torch_export_patches/test_patch_module.py

Lines changed: 0 additions & 33 deletions
This file was deleted.

onnx_diagnostic/torch_export_patches/patch_module.py

Lines changed: 0 additions & 234 deletions
This file was deleted.

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,12 +1046,17 @@ def call_torch_export_custom(
10461046
:return: two dictionaries, one with some metrics,
10471047
another one with whatever the function produces
10481048
"""
1049-
assert optimization in {
1049+
available = {
10501050
"",
10511051
"default",
10521052
"default+onnxruntime",
1053+
"default+os_ort",
1054+
"default+onnxruntime+os_ort",
10531055
None,
1054-
}, f"unexpected value for optimization={optimization}"
1056+
}
1057+
assert (
1058+
optimization in available
1059+
), f"unexpected value for optimization={optimization}, available={available}"
10551060
assert exporter in {
10561061
"custom",
10571062
"custom-strict",
@@ -1085,6 +1090,10 @@ def call_torch_export_custom(
10851090
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
10861091
from experimental_experiment.xbuilder import OptimizationOptions
10871092

1093+
spl = optimization.split("+") if optimization else []
1094+
os_ort = "os_ort" in spl
1095+
optimization = "+".join(_ for _ in spl if _ != "os_ort")
1096+
10881097
export_options = ExportOptions(
10891098
strict=strict,
10901099
decomposition_table=(
@@ -1188,6 +1197,9 @@ def call_torch_export_custom(
11881197
assert epo is not None, "no onnx export was found"
11891198
if verbose:
11901199
print("[call_torch_export_custom] done (export)")
1200+
1201+
if os_ort:
1202+
pass
11911203
data["onnx_program"] = epo
11921204
return summary, data
11931205

0 commit comments

Comments
 (0)