Skip to content

Commit b476c7e

Browse files
committed
add torch.cond
1 parent d9660ce commit b476c7e

File tree

3 files changed

+109
-0
lines changed

3 files changed

+109
-0
lines changed

beginner_source/onnx/README.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@ ONNX
1212
3. onnx_registry_tutorial.py
1313
Extending the ONNX Registry
1414
https://pytorch.org/tutorials/beginner/onnx/onnx_registry_tutorial.html
15+
16+
4. export_control_flow_model_to_onnx_tutorial.py
17+
Export a Pytorch model with a test to ONNX
18+
https://pytorch.org/tutorials/beginner/onnx/export_control_flow_model_to_onnx_tutorial.html
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
**Introduction to ONNX** ||
4+
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
5+
`Extending the ONNX Registry <onnx_registry_tutorial.html>`_
6+
7+
Export a control flow model to ONNX
8+
==========================================
9+
10+
**Author**: `Xavier Dupré <https://github.com/xadupre>`_.
11+
12+
Tests cannot be exported into ONNX unless they refactored
13+
to use :func:`torch.cond`. Let's start with a simple model
14+
implementing a test.
15+
"""
16+
17+
from onnx.printer import to_text
18+
import torch
19+
20+
class ForwardWithControlFlowTest(torch.nn.Module):
21+
def forward(self, x):
22+
if x.sum():
23+
return x * 2
24+
return -x
25+
26+
27+
class ModelWithControlFlowTest(torch.nn.Module):
28+
def __init__(self):
29+
super().__init__()
30+
self.mlp = torch.nn.Sequential(
31+
torch.nn.Linear(3, 2),
32+
torch.nn.Linear(2, 1),
33+
ForwardWithControlFlowTest(),
34+
)
35+
36+
def forward(self, x):
37+
out = self.mlp(x)
38+
return out
39+
40+
41+
model = ModelWithControlFlowTest()
42+
43+
# %%
44+
# Let's check it runs.
45+
x = torch.randn(3)
46+
model(x)
47+
48+
# %%
49+
# As expected, it does not export.
50+
try:
51+
torch.export.export(model, (x,))
52+
raise AssertionError("This export should failed unless pytorch now supports this model.")
53+
except Exception as e:
54+
print(e)
55+
56+
# %%
57+
# It does export with :func:`torch.onnx.export` because
58+
# it uses JIT to trace the execution.
59+
# But the model is not exactly the same as the initial model.
60+
ep = torch.onnx.export(model, (x,), dynamo=True)
61+
print(to_text(ep.model_proto))
62+
63+
64+
# %%
65+
# Suggested Patch
66+
# +++++++++++++++
67+
#
68+
# Let's avoid the graph break by replacing the forward.
69+
70+
71+
def new_forward(x):
72+
def identity2(x):
73+
return x * 2
74+
75+
def neg(x):
76+
return -x
77+
78+
return torch.cond(x.sum() > 0, identity2, neg, (x,))
79+
80+
81+
print("the list of submodules")
82+
for name, mod in model.named_modules():
83+
print(name, type(mod))
84+
if isinstance(mod, ForwardWithControlFlowTest):
85+
mod.forward = new_forward
86+
87+
# %%
88+
# Let's see what the fx graph looks like.
89+
90+
print(torch.export.export(model, (x,)).graph)
91+
92+
# %%
93+
# Let's export again.
94+
95+
ep = torch.onnx.export(model, (x,), dynamo=True)
96+
print(to_text(ep.model_proto))
97+
98+
99+
# %%
100+
# Let's optimize to see a small model.
101+
102+
ep = torch.onnx.export(model, (x,), dynamo=True)
103+
ep.optimize()
104+
print(to_text(ep.model_proto))

beginner_source/onnx/onnx_toc.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
| 1. `Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_
22
| 2. `Extending the ONNX registry <onnx_registry_tutorial.html>`_
3+
| 3. `Export a Pytorch model with a test to ONNX <export_control_flow_model_to_onnx_tutorial>`_

0 commit comments

Comments
 (0)