Skip to content

Commit 37ce239

Browse files
committed
add one more example
1 parent 6bde8a7 commit 37ce239

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

_doc/examples/plot_export_cond.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""
2+
.. _l-plot-export-cond:
3+
4+
Export a model with a control flow (If)
5+
=======================================
6+
7+
Control flow cannot be exported with a change.
8+
The code of the model can be changed or patched
9+
to introduce function :func:`torch.cond`.
10+
11+
A model with a test
12+
+++++++++++++++++++
13+
"""
14+
15+
import torch
16+
17+
18+
# %%
19+
# We define a model with a control flow (-> graph break)
20+
21+
22+
class ForwardWithControlFlowTest(torch.nn.Module):
23+
def forward(self, x):
24+
if x.sum():
25+
return x * 2
26+
return -x
27+
28+
29+
class ModelWithControlFlow(torch.nn.Module):
30+
def __init__(self):
31+
super().__init__()
32+
self.mlp = torch.nn.Sequential(
33+
torch.nn.Linear(3, 2),
34+
torch.nn.Linear(2, 1),
35+
ForwardWithControlFlowTest(),
36+
)
37+
38+
def forward(self, x):
39+
out = self.mlp(x)
40+
return out
41+
42+
43+
model = ModelWithControlFlow()
44+
45+
# %%
46+
# Let's check it runs.
47+
x = torch.randn(1, 3)
48+
model(x)
49+
50+
# %%
51+
# As expected, it does not export.
52+
try:
53+
torch.export.export(model, (x,))
54+
raise AssertionError("This export should failed unless pytorch now supports this model.")
55+
except Exception as e:
56+
print(e)
57+
58+
59+
# %%
60+
# Suggested Patch
61+
# +++++++++++++++
62+
#
63+
# Let's avoid the graph break by replacing the forward.
64+
65+
66+
def new_forward(x):
67+
def identity2(x):
68+
return x * 2
69+
70+
def neg(x):
71+
return -x
72+
73+
return torch.cond(x.sum() > 0, identity2, neg, (x,))
74+
75+
76+
print("the list of submodules")
77+
for name, mod in model.named_modules():
78+
print(name, type(mod))
79+
if isinstance(mod, ForwardWithControlFlowTest):
80+
mod.forward = new_forward
81+
82+
# %%
83+
# Let's see what the fx graph looks like.
84+
85+
ep = torch.export.export(model, (x,))
86+
print(ep.graph)

_doc/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ Source are `sdpython/onnx-diagnostic
4747

4848
**Enlightening Examples**
4949

50+
* :ref:`l-plot-export-cond`
5051
* :ref:`l-plot-sxport-with-dynamio-shapes-auto`
5152
* :ref:`l-plot-tiny-llm-export`
5253
* :ref:`l-plot-failing-model-extract`

0 commit comments

Comments
 (0)