Skip to content

Commit 77a6db1

Browse files
authored
documentation (#112)
* documentation * fix test case
1 parent ac301d1 commit 77a6db1

File tree

9 files changed

+170
-71
lines changed

9 files changed

+170
-71
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ temp_dump_models/*
4343
dump_dort_bench/*
4444
test_zoo_*/
4545
temp_llama_model_*/*
46+
test_example_transform_method_*.py
4647
test_*.txt
4748
_cache/*
4849
.coverage

_doc/index.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ It also implements tools to investigate, validate exported models (ExportedProgr
3636
:caption: Contents
3737

3838
patches
39-
status/index
4039
api/index
4140
cmds/index
4241
auto_examples/index
@@ -75,7 +74,7 @@ Enlightening Examples
7574
**Torch Export**
7675

7776
* :ref:`l-plot-export-cond`
78-
* :ref:`l-plot-sxport-with-auto`
77+
* :ref:`l-plot-export-with-dynamic`
7978
* :ref:`l-plot-export-with-dynamic-shape`
8079
* :ref:`l-plot-export-locale-issue`
8180
* :ref:`l-plot-tiny-llm-export`

_doc/patches.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@ implements four kinds of patches to make it easier to export a model, usually
99
coming from :epkg:`transformers`.
1010
All patches takes place in :mod:`onnx_diagnostic.torch_export_patches`.
1111

12+
.. toctree::
13+
14+
status/index
15+
16+
Four Kinds of Patches
17+
=====================
18+
1219
.. code-block:: python
1320
1421
with torch_export_patches(...) as f:

_doc/examples/plot_export_cond.py renamed to _doc/recipes/plot_export_cond.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import torch
1616
from onnx_diagnostic import doc
17+
from onnx_diagnostic.torch_export_patches import torch_export_rewrite
1718

1819

1920
# %%
@@ -24,7 +25,8 @@ class ForwardWithControlFlowTest(torch.nn.Module):
2425
def forward(self, x):
2526
if x.sum():
2627
return x * 2
27-
return -x
28+
else:
29+
return -x
2830

2931

3032
class ModelWithControlFlow(torch.nn.Module):
@@ -86,7 +88,28 @@ def neg(x):
8688
ep = torch.export.export(model, (x,))
8789
print(ep.graph)
8890

91+
# %%
92+
# Automated Rewrite of the Control Flow
93+
# +++++++++++++++++++++++++++++++++++++
94+
#
95+
# Functions :func:`torch_export_rewrite
96+
# <onnx_diagnostic.torch_export_patches.torch_export_rewrite>`
97+
# or :func:`torch_export_patches <onnx_diagnostic.torch_export_patches.torch_export_patches>`
98+
# can automatically rewrite a method of a class or a function,
99+
# the method to rewrite is specified parameter ``rewrite``.
100+
# It is experimental. The function contains options to
101+
# rewrite one test but not another one already supported by the exporter.
102+
# It may give a first version of the rewritten code if only a manual
103+
# rewriting can make the model exportable.
104+
105+
with torch_export_rewrite(rewrite=[ForwardWithControlFlowTest.forward], verbose=2) as f:
106+
ep = torch.export.export(model, (x,))
107+
108+
# %%
109+
# This gives:
110+
111+
print(ep.graph)
89112

90113
# %%
91114

92-
doc.plot_legend("If -> torch.cond", "torch.export.export", "tomato")
115+
doc.plot_legend("If -> torch.cond", "torch.export.export", "yellowgreen")

_doc/examples/plot_export_with_auto.py renamed to _doc/recipes/plot_export_with_dynamic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
.. _l-plot-sxport-with-auto:
2+
.. _l-plot-export-with-dynamic:
33
44
Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
55
====================================================================
@@ -93,4 +93,4 @@ def forward(self, x, y, z):
9393

9494
# %%
9595

96-
doc.plot_legend("torch.export.Dim\nor DYNAMIC\nor AUTO", "torch.export.export", "tomato")
96+
doc.plot_legend("torch.export.Dim\nor DYNAMIC\nor AUTO", "torch.export.export", "green")

_doc/status/exported_program_dynamic.rst

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ to the original model.
2929
for name, cls_model in sorted_cases:
3030
print(f"* :ref:`{name} <led-model-case-export-{name}>`")
3131
print()
32+
print()
3233

3334
obs = []
3435
for name, cls_model in sorted(cases.items()):
@@ -41,14 +42,15 @@ to the original model.
4142
print("forward")
4243
print("+++++++")
4344
print()
44-
print("::")
45+
print(".. code-block:: python")
4546
print()
4647
src = inspect.getsource(cls_model.forward)
4748
if src:
4849
print(textwrap.indent(textwrap.dedent(src), " "))
4950
else:
5051
print(" # code is missing")
5152
print()
53+
print()
5254
for exporter in (
5355
"export-strict",
5456
"export-nostrict",
@@ -67,23 +69,26 @@ to the original model.
6769
if "dynamic_shapes" in res:
6870
print(f"* **shapes:** ``{string_type(res['dynamic_shapes'])}``")
6971
print()
72+
print()
7073
if "exported" in res:
71-
print("::")
74+
print(".. code-block:: text")
7275
print()
7376
print(textwrap.indent(str(res["exported"].graph), " "))
7477
print()
78+
print()
7579
obs.append(dict(case=case_ref, error="", exporter=expo))
7680
else:
7781
print("**FAILED**")
7882
print()
79-
print("::")
83+
print(".. code-block:: text")
8084
print()
8185
err = str(res["error"])
8286
if err:
8387
print(textwrap.indent(err, " "))
8488
else:
8589
print(" # no error found for the failure")
8690
print()
91+
print()
8792
obs.append(dict(case=case_ref, error="FAIL", exporter=expo))
8893

8994
print()

_doc/status/index.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,11 @@ what works and what does not with :func:`torch.export.export`.
1010

1111
exported_program_dynamic
1212
patches_coverage
13+
14+
Some PRs in :epkg:`transformers` to keep in mind when it comes to export
15+
a model using a cache or a custom class:
16+
17+
* `Completely rewrite the masking logic for all attentions <https://github.com/huggingface/transformers/pull/37866>`_
18+
* `Fix bugs in DynamicCache <https://github.com/huggingface/transformers/pull/37880>`_
19+
* `Fixes DynamicCache export issues due to control flow and inplace modifications <https://github.com/huggingface/transformers/pull/36652>`_
20+
* `Support tracable dynamicKVcache <https://github.com/huggingface/transformers/pull/36311>`_

onnx_diagnostic/torch_export_patches/eval/__init__.py

Lines changed: 106 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,17 @@ def _make_exporter_export(
185185

186186
if exporter == "export-strict":
187187
try:
188-
exported = torch.export.export(
189-
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
190-
)
188+
if verbose >= 2:
189+
exported = torch.export.export(
190+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
191+
)
192+
else:
193+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
194+
io.StringIO()
195+
):
196+
exported = torch.export.export(
197+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
198+
)
191199
except Exception as e:
192200
if not quiet:
193201
raise
@@ -198,17 +206,33 @@ def _make_exporter_export(
198206
return exported.module()
199207
if exporter in ("export-strict-dec", "export-strict-decall"):
200208
try:
201-
exported = torch.export.export(
202-
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
203-
)
204-
if verbose >= 9:
205-
print("-- graph before decomposition")
206-
print(exported.graph)
207-
exported = (
208-
exported.run_decompositions()
209-
if "decall" in exporter
210-
else exported.run_decompositions({})
211-
)
209+
if verbose >= 2:
210+
exported = torch.export.export(
211+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
212+
)
213+
if verbose >= 9:
214+
print("-- graph before decomposition")
215+
print(exported.graph)
216+
exported = (
217+
exported.run_decompositions()
218+
if "decall" in exporter
219+
else exported.run_decompositions({})
220+
)
221+
else:
222+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
223+
io.StringIO()
224+
):
225+
exported = torch.export.export(
226+
model, inputs, dynamic_shapes=dynamic_shapes, strict=True
227+
)
228+
if verbose >= 9:
229+
print("-- graph before decomposition")
230+
print(exported.graph)
231+
exported = (
232+
exported.run_decompositions()
233+
if "decall" in exporter
234+
else exported.run_decompositions({})
235+
)
212236
except Exception as e:
213237
if not quiet:
214238
raise
@@ -219,9 +243,17 @@ def _make_exporter_export(
219243
return exported.module()
220244
if exporter == "export-nostrict":
221245
try:
222-
exported = torch.export.export(
223-
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
224-
)
246+
if verbose >= 2:
247+
exported = torch.export.export(
248+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
249+
)
250+
else:
251+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
252+
io.StringIO()
253+
):
254+
exported = torch.export.export(
255+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
256+
)
225257
except Exception as e:
226258
if not quiet:
227259
raise
@@ -232,17 +264,33 @@ def _make_exporter_export(
232264
return exported.module()
233265
if exporter in ("export-nostrict-dec", "export-nostrict-decall"):
234266
try:
235-
exported = torch.export.export(
236-
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
237-
)
238-
if verbose >= 9:
239-
print("-- graph before decomposition")
240-
print(exported.graph)
241-
exported = (
242-
exported.run_decompositions()
243-
if "decall" in exporter
244-
else exported.run_decompositions({})
245-
)
267+
if verbose >= 2:
268+
exported = torch.export.export(
269+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
270+
)
271+
if verbose >= 9:
272+
print("-- graph before decomposition")
273+
print(exported.graph)
274+
exported = (
275+
exported.run_decompositions()
276+
if "decall" in exporter
277+
else exported.run_decompositions({})
278+
)
279+
else:
280+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
281+
io.StringIO()
282+
):
283+
exported = torch.export.export(
284+
model, inputs, dynamic_shapes=dynamic_shapes, strict=False
285+
)
286+
if verbose >= 9:
287+
print("-- graph before decomposition")
288+
print(exported.graph)
289+
exported = (
290+
exported.run_decompositions()
291+
if "decall" in exporter
292+
else exported.run_decompositions({})
293+
)
246294
except Exception as e:
247295
if not quiet:
248296
raise
@@ -255,8 +303,15 @@ def _make_exporter_export(
255303
from experimental_experiment.torch_interpreter.tracing import CustomTracer
256304

257305
try:
258-
graph = CustomTracer().trace(model)
259-
mod = torch.fx.GraphModule(model, graph)
306+
if verbose >= 2:
307+
graph = CustomTracer().trace(model)
308+
mod = torch.fx.GraphModule(model, graph)
309+
else:
310+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
311+
io.StringIO()
312+
):
313+
graph = CustomTracer().trace(model)
314+
mod = torch.fx.GraphModule(model, graph)
260315
except Exception as e:
261316
if not quiet:
262317
raise
@@ -289,13 +344,25 @@ def _make_exporter_onnx(
289344
if "-dec" in exporter:
290345
opts["decomposition_table"] = "all" if "-decall" in exporter else "default"
291346
try:
292-
onx, builder = to_onnx(
293-
model,
294-
inputs,
295-
dynamic_shapes=dynamic_shapes,
296-
export_options=ExportOptions(**opts),
297-
return_builder=True,
298-
)
347+
if verbose >= 2:
348+
onx, builder = to_onnx(
349+
model,
350+
inputs,
351+
dynamic_shapes=dynamic_shapes,
352+
export_options=ExportOptions(**opts),
353+
return_builder=True,
354+
)
355+
else:
356+
with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr(
357+
io.StringIO()
358+
):
359+
onx, builder = to_onnx(
360+
model,
361+
inputs,
362+
dynamic_shapes=dynamic_shapes,
363+
export_options=ExportOptions(**opts),
364+
return_builder=True,
365+
)
299366
except Exception as e:
300367
if not quiet:
301368
raise RuntimeError(
@@ -306,6 +373,7 @@ def _make_exporter_onnx(
306373
) from e
307374
return dict(error=str(e), success=0, error_step="export")
308375
return onx, builder
376+
309377
if exporter == "dynamo":
310378
import torch
311379

@@ -338,6 +406,7 @@ def _make_exporter_onnx(
338406
) from e
339407
return dict(error=str(e), success=0, error_step="export")
340408
return onx, None
409+
341410
if exporter == "dynamo-ir":
342411
import torch
343412

0 commit comments

Comments
 (0)