Skip to content

Commit 82ad2e9

Browse files
committed
improves documentation
1 parent 6fafaa6 commit 82ad2e9

File tree

5 files changed

+149
-6
lines changed

5 files changed

+149
-6
lines changed

_doc/status/exported_program_dynamic.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22
Exported Programs with Dynamic Shapes
33
=====================================
44

5-
The following script shows the exported program for many short cases
6-
and various l-plot-export-with-dynamic-shape to retrieve an ONNX model equivalent
7-
to the original model.
5+
The following script shows the exported program for many short cases exported
6+
with different options. This steps happens before converting into ONNX.
87

98
.. runpython::
109
:showcode:
@@ -55,6 +54,7 @@ to the original model.
5554
"export-strict",
5655
"export-nostrict",
5756
"export-nostrict-decall",
57+
"export-tracing",
5858
):
5959
expname = exporter.replace("export-", "")
6060
print()

_doc/status/exporter_dynamic.rst

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
=================================
2+
Exported ONNX with Dynamic Shapes
3+
=================================
4+
5+
The following script shows the exported program for many short cases
6+
and various l-plot-export-with-dynamic-shape to retrieve an ONNX model equivalent
7+
to the original model.
8+
9+
.. runpython::
10+
:showcode:
11+
:rst:
12+
:toggle: code
13+
:warningout: UserWarning
14+
15+
import inspect
16+
import textwrap
17+
import pandas
18+
from onnx_diagnostic.helpers import string_type
19+
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx
20+
from onnx_diagnostic.torch_export_patches.eval import discover, run_exporter
21+
from onnx_diagnostic.ext_test_case import unit_test_going
22+
23+
cases = discover()
24+
print()
25+
print(":ref:`Summary <ledx-summary-exported-program>`")
26+
print()
27+
sorted_cases = sorted(cases.items())
28+
if unit_test_going():
29+
sorted_cases = sorted_cases[:3]
30+
for name, cls_model in sorted_cases:
31+
print(f"* :ref:`{name} <ledx-model-case-export-{name}>`")
32+
print()
33+
print()
34+
35+
obs = []
36+
for name, cls_model in sorted(cases.items()):
37+
print()
38+
print(f".. _ledx-model-case-export-{name}:")
39+
print()
40+
print(name)
41+
print("=" * len(name))
42+
print()
43+
print("forward")
44+
print("+++++++")
45+
print()
46+
print(".. code-block:: python")
47+
print()
48+
src = inspect.getsource(cls_model.forward)
49+
if src:
50+
print(textwrap.indent(textwrap.dedent(src), " "))
51+
else:
52+
print(" # code is missing")
53+
print()
54+
print()
55+
for exporter in ("custom", "dynamo-ir"):
56+
expname = exporter.replace("export-", "")
57+
print()
58+
print(expname)
59+
print("+" * len(expname))
60+
print()
61+
res = run_exporter(exporter, cls_model, True, quiet=True)
62+
case_ref = f":ref:`{name} <ledx-model-case-export-{name}>`"
63+
expo = exporter.split("-", maxsplit=1)[-1]
64+
if "inputs" in res:
65+
print(f"* **inputs:** ``{string_type(res['inputs'], with_shape=True)}``")
66+
if "dynamic_shapes" in res:
67+
print(f"* **shapes:** ``{string_type(res['dynamic_shapes'])}``")
68+
print()
69+
print()
70+
if "onx" in res:
71+
print(".. code-block:: text")
72+
print()
73+
print(textwrap.indent(pretty_onnx(res["onx"]), " "))
74+
print()
75+
print()
76+
obs.append(dict(case=case_ref, error="", exporter=expo))
77+
if "error" in res:
78+
print("**FAILED**")
79+
print()
80+
print(".. code-block:: text")
81+
print()
82+
err = str(res["error"])
83+
if err:
84+
print(textwrap.indent(err, " "))
85+
else:
86+
print(" # no error found for the failure")
87+
print()
88+
print()
89+
obs.append(dict(case=case_ref, error="FAIL", exporter=expo))
90+
91+
print()
92+
print(".. _ledx-summary-exported-program:")
93+
print()
94+
print("Summary")
95+
print("+++++++")
96+
print()
97+
df = pandas.DataFrame(obs)
98+
piv = df.pivot(index="case", columns="exporter", values="error")
99+
print(piv.to_markdown(tablefmt="rst"))
100+
print()

_doc/status/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ what works and what does not with :func:`torch.export.export`.
99
:maxdepth: 1
1010

1111
exported_program_dynamic
12+
exporter_dynamic
1213
patches_coverage
1314

1415
Some PRs in :epkg:`transformers` to keep in mind when it comes to export

_unittests/ut_torch_export_patches/test_eval.py

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,57 @@ def test_eval(self):
3535
self.assertIsInstance(ev, list)
3636
self.assertIsInstance(ev[0], dict)
3737

38-
def test_run_exporter(self):
38+
def test_run_exporter_custom(self):
3939
evaluation(
4040
cases="SignatureListFixedLength",
41-
exporters="custom-strict",
41+
exporters="custom",
42+
quiet=False,
43+
dynamic=False,
44+
)
45+
46+
def test_run_exporter_dynamo(self):
47+
evaluation(
48+
cases="SignatureListFixedLength",
49+
exporters="dynamo",
50+
quiet=False,
51+
dynamic=False,
52+
)
53+
54+
def test_run_exporter_dynamo_ir(self):
55+
evaluation(
56+
cases="SignatureListFixedLength",
57+
exporters="dynamo-ir",
58+
quiet=False,
59+
dynamic=False,
60+
)
61+
62+
def test_run_exporter_nostrict(self):
63+
evaluation(
64+
cases="SignatureListFixedLength",
65+
exporters="export-nostrict",
66+
quiet=False,
67+
dynamic=False,
68+
)
69+
70+
def test_run_exporter_tracing(self):
71+
evaluation(
72+
cases="SignatureListFixedLength",
73+
exporters="export-tracing",
4274
quiet=False,
4375
dynamic=False,
4476
)
4577

4678
def test_run_exporter_regex(self):
4779
evaluation(cases=".*Aten.*", exporters="custom-strict", quiet=False, dynamic=False)
4880

81+
def test_run_exporter_custom_nested_cond(self):
82+
evaluation(
83+
cases="ControlFlowNestCond",
84+
exporters="custom",
85+
quiet=False,
86+
dynamic=False,
87+
)
88+
4989

5090
if __name__ == "__main__":
5191
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/eval/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def _make_exporter_onnx(
337337
from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
338338

339339
opts = {}
340-
opts["strict"] = "-nostrict" not in exporter
340+
opts["strict"] = "-strict" in exporter
341341
opts["fallback"] = "-fallback" in exporter
342342
opts["tracing"] = "-tracing" in exporter
343343
opts["jit"] = "-jit" in exporter
@@ -520,6 +520,8 @@ def run_exporter(
520520
return res
521521

522522
onx, builder = res
523+
base["onx"] = onx
524+
base["builder"] = builder
523525
if verbose >= 9:
524526
print("[run_exporter] onnx model")
525527
print(

0 commit comments

Comments
 (0)