9
9
==========================================
10
10
11
11
**Author**: `Xavier Dupré <https://github.com/xadupre>`_.
12
-
13
- Conditional logic cannot be exported into ONNX unless they refactored
14
- to use :func:`torch.cond`. Let's start with a simple model
15
- implementing a test.
16
12
"""
17
13
14
+
15
+ ###############################################################################
16
+ # Overview
17
+ # --------
18
+ #
19
+ # This tutorial demonstrates how to handle control flow logic while exporting
20
+ # a PyTorch model to ONNX. It highlights the challenges of exporting
21
+ # conditional statements directly and provides solutions to circumvent them.
22
+ #
23
+ # Conditional logic cannot be exported into ONNX unless they refactored
24
+ # to use :func:`torch.cond`. Let's start with a simple model
25
+ # implementing a test.
26
+
18
27
import torch
19
28
29
+ ###############################################################################
30
+ # Define the Models
31
+ # --------
32
+ #
33
+ # Two models are defined:
34
+ #
35
+ # ForwardWithControlFlowTest: A model with a forward method containing an
36
+ # if-else conditional.
37
+ #
38
+ # ModelWithControlFlowTest: A model that incorporates ForwardWithControlFlowTest
39
+ # as part of a simple multi-layer perceptron (MLP). The models are tested with
40
+ # a random input tensor to confirm they execute as expected.
41
+
20
42
class ForwardWithControlFlowTest (torch .nn .Module ):
21
43
def forward (self , x ):
22
44
if x .sum ():
@@ -40,33 +62,55 @@ def forward(self, x):
40
62
41
63
model = ModelWithControlFlowTest ()
42
64
43
- # %%
44
- # Let's check it runs.
65
+
66
+ ###############################################################################
67
+ # Exporting the Model: First Attempt
68
+ # --------
69
+ #
70
+ # Exporting this model using torch.export.export fails because the control
71
+ # flow logic in the forward pass creates a graph break that the exporter cannot
72
+ # handle. This behavior is expected, as conditional logic not written using
73
+ # torch.cond is unsupported.
74
+ #
75
+ # A try-except block is used to capture the expected failure during the export
76
+ # process. If the export unexpectedly succeeds, an AssertionError is raised.
77
+
45
78
x = torch .randn (3 )
46
79
model (x )
47
80
48
- # %%
49
- # As expected, it does not export.
50
81
try :
51
82
torch .export .export (model , (x ,), strict = False )
52
- raise AssertionError ("This export should failed unless pytorch now supports this model." )
83
+ raise AssertionError ("This export should failed unless PyTorch now supports this model." )
53
84
except Exception as e :
54
85
print (e )
55
86
56
- # %%
57
- # It does export with :func:`torch.onnx.export` because
58
- # the exporter falls back to use JIT tracing as the graph capturing strategy.
59
- # But the model is not exactly the same as the initial model.
87
+ ###############################################################################
88
+ # Using torch.onnx.export with JIT Tracing
89
+ # --------
90
+ #
91
+ # When exporting the model using torch.onnx.export with the dynamo=True
92
+ # argument, the exporter defaults to using JIT tracing. This fallback allows
93
+ # the model to export, but the resulting ONNX graph may not faithfully represent
94
+ # the original model logic due to the limitations of tracing.
95
+
96
+
60
97
onnx_program = torch .onnx .export (model , (x ,), dynamo = True )
61
98
print (onnx_program .model )
62
99
63
100
64
- # %%
65
- # Suggested Patch
66
- # +++++++++++++++
101
+ ###############################################################################
102
+ # Suggested Patch: Refactoring with torch.cond
103
+ # --------
67
104
#
68
- # Let's avoid the graph break by replacing the forward.
69
-
105
+ # To make the control flow exportable, the tutorial demonstrates replacing the
106
+ # forward method in ForwardWithControlFlowTest with a refactored version that
107
+ # uses torch.cond.
108
+ #
109
+ # Details of the Refactoring:
110
+ #
111
+ # Two helper functions (identity2 and neg) represent the branches of the conditional logic:
112
+ # * torch.cond is used to specify the condition and the two branches along with the input arguments.
113
+ # * The updated forward method is then dynamically assigned to the ForwardWithControlFlowTest instance within the model. A list of submodules is printed to confirm the replacement.
70
114
71
115
def new_forward (x ):
72
116
def identity2 (x ):
@@ -84,21 +128,44 @@ def neg(x):
84
128
if isinstance (mod , ForwardWithControlFlowTest ):
85
129
mod .forward = new_forward
86
130
87
- # %%
131
+ ###############################################################################
88
132
# Let's see what the fx graph looks like.
89
133
90
134
print (torch .export .export (model , (x ,), strict = False ))
91
135
92
- # %%
136
+ ###############################################################################
93
137
# Let's export again.
94
138
95
139
onnx_program = torch .onnx .export (model , (x ,), dynamo = True )
96
140
print (onnx_program .model )
97
141
98
142
99
- # %%
143
+ ###############################################################################
100
144
# We can optimize the model and get rid of the model local functions created to capture the control flow branches.
101
145
102
- onnx_program = torch .onnx .export (model , (x ,), dynamo = True )
103
146
onnx_program .optimize ()
104
- print (onnx_program .model )
147
+ print (onnx_program .model )
148
+
149
+ ###############################################################################
150
+ # Conclusion
151
+ # --------
152
+ # This tutorial demonstrates the challenges of exporting models with conditional
153
+ # logic to ONNX and presents a practical solution using torch.cond.
154
+ # While the default exporters may fail or produce imperfect graphs, refactoring the
155
+ # model's logic ensures compatibility and generates a faithful ONNX representation.
156
+ #
157
+ # By understanding these techniques, we can overcome common pitfalls when
158
+ # working with control flow in PyTorch models and ensure smooth integration with ONNX workflows.
159
+ #
160
+ # Further reading
161
+ # ---------------
162
+ #
163
+ # The list below refers to tutorials that ranges from basic examples to advanced scenarios,
164
+ # not necessarily in the order they are listed.
165
+ # Feel free to jump directly to specific topics of your interest or
166
+ # sit tight and have fun going through all of them to learn all there is about the ONNX exporter.
167
+ #
168
+ # .. include:: /beginner_source/onnx/onnx_toc.txt
169
+ #
170
+ # .. toctree::
171
+ # :hidden:
0 commit comments