Skip to content

Commit 2534137

Browse files
committed
reformatting
1 parent 0afd428 commit 2534137

File tree

2 files changed

+90
-24
lines changed

2 files changed

+90
-24
lines changed

beginner_source/onnx/export_control_flow_model_to_onnx_tutorial.py

Lines changed: 90 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,36 @@
99
==========================================
1010
1111
**Author**: `Xavier Dupré <https://github.com/xadupre>`_.
12-
13-
Conditional logic cannot be exported into ONNX unless they refactored
14-
to use :func:`torch.cond`. Let's start with a simple model
15-
implementing a test.
1612
"""
1713

14+
15+
###############################################################################
16+
# Overview
17+
# --------
18+
#
19+
# This tutorial demonstrates how to handle control flow logic while exporting
20+
# a PyTorch model to ONNX. It highlights the challenges of exporting
21+
# conditional statements directly and provides solutions to circumvent them.
22+
#
23+
# Conditional logic cannot be exported into ONNX unless they refactored
24+
# to use :func:`torch.cond`. Let's start with a simple model
25+
# implementing a test.
26+
1827
import torch
1928

29+
###############################################################################
30+
# Define the Models
31+
# --------
32+
#
33+
# Two models are defined:
34+
#
35+
# ForwardWithControlFlowTest: A model with a forward method containing an
36+
# if-else conditional.
37+
#
38+
# ModelWithControlFlowTest: A model that incorporates ForwardWithControlFlowTest
39+
# as part of a simple multi-layer perceptron (MLP). The models are tested with
40+
# a random input tensor to confirm they execute as expected.
41+
2042
class ForwardWithControlFlowTest(torch.nn.Module):
2143
def forward(self, x):
2244
if x.sum():
@@ -40,33 +62,55 @@ def forward(self, x):
4062

4163
model = ModelWithControlFlowTest()
4264

43-
# %%
44-
# Let's check it runs.
65+
66+
###############################################################################
67+
# Exporting the Model: First Attempt
68+
# --------
69+
#
70+
# Exporting this model using torch.export.export fails because the control
71+
# flow logic in the forward pass creates a graph break that the exporter cannot
72+
# handle. This behavior is expected, as conditional logic not written using
73+
# torch.cond is unsupported.
74+
#
75+
# A try-except block is used to capture the expected failure during the export
76+
# process. If the export unexpectedly succeeds, an AssertionError is raised.
77+
4578
x = torch.randn(3)
4679
model(x)
4780

48-
# %%
49-
# As expected, it does not export.
5081
try:
5182
torch.export.export(model, (x,), strict=False)
52-
raise AssertionError("This export should failed unless pytorch now supports this model.")
83+
raise AssertionError("This export should failed unless PyTorch now supports this model.")
5384
except Exception as e:
5485
print(e)
5586

56-
# %%
57-
# It does export with :func:`torch.onnx.export` because
58-
# the exporter falls back to use JIT tracing as the graph capturing strategy.
59-
# But the model is not exactly the same as the initial model.
87+
###############################################################################
88+
# Using torch.onnx.export with JIT Tracing
89+
# --------
90+
#
91+
# When exporting the model using torch.onnx.export with the dynamo=True
92+
# argument, the exporter defaults to using JIT tracing. This fallback allows
93+
# the model to export, but the resulting ONNX graph may not faithfully represent
94+
# the original model logic due to the limitations of tracing.
95+
96+
6097
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
6198
print(onnx_program.model)
6299

63100

64-
# %%
65-
# Suggested Patch
66-
# +++++++++++++++
101+
###############################################################################
102+
# Suggested Patch: Refactoring with torch.cond
103+
# --------
67104
#
68-
# Let's avoid the graph break by replacing the forward.
69-
105+
# To make the control flow exportable, the tutorial demonstrates replacing the
106+
# forward method in ForwardWithControlFlowTest with a refactored version that
107+
# uses torch.cond.
108+
#
109+
# Details of the Refactoring:
110+
#
111+
# Two helper functions (identity2 and neg) represent the branches of the conditional logic:
112+
# * torch.cond is used to specify the condition and the two branches along with the input arguments.
113+
# * The updated forward method is then dynamically assigned to the ForwardWithControlFlowTest instance within the model. A list of submodules is printed to confirm the replacement.
70114

71115
def new_forward(x):
72116
def identity2(x):
@@ -84,21 +128,44 @@ def neg(x):
84128
if isinstance(mod, ForwardWithControlFlowTest):
85129
mod.forward = new_forward
86130

87-
# %%
131+
###############################################################################
88132
# Let's see what the fx graph looks like.
89133

90134
print(torch.export.export(model, (x,), strict=False))
91135

92-
# %%
136+
###############################################################################
93137
# Let's export again.
94138

95139
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
96140
print(onnx_program.model)
97141

98142

99-
# %%
143+
###############################################################################
100144
# We can optimize the model and get rid of the model local functions created to capture the control flow branches.
101145

102-
onnx_program = torch.onnx.export(model, (x,), dynamo=True)
103146
onnx_program.optimize()
104-
print(onnx_program.model)
147+
print(onnx_program.model)
148+
149+
###############################################################################
150+
# Conclusion
151+
# --------
152+
# This tutorial demonstrates the challenges of exporting models with conditional
153+
# logic to ONNX and presents a practical solution using torch.cond.
154+
# While the default exporters may fail or produce imperfect graphs, refactoring the
155+
# model's logic ensures compatibility and generates a faithful ONNX representation.
156+
#
157+
# By understanding these techniques, we can overcome common pitfalls when
158+
# working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.
159+
#
160+
# Further reading
161+
# ---------------
162+
#
163+
# The list below refers to tutorials that ranges from basic examples to advanced scenarios,
164+
# not necessarily in the order they are listed.
165+
# Feel free to jump directly to specific topics of your interest or
166+
# sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
167+
#
168+
# .. include:: /beginner_source/onnx/onnx_toc.txt
169+
#
170+
# .. toctree::
171+
# :hidden:

beginner_source/onnx/onnx_registry_tutorial.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
# before proceeding.
5959

6060
import torch
61-
import onnxruntime
6261
import onnxscript
6362

6463
# Opset 18 is the standard supported version as of PyTorch 2.6

0 commit comments

Comments
 (0)