1
1
# -*- coding: utf-8 -*-
2
2
"""
3
- ** Introduction to ONNX** ||
3
+ ` Introduction to ONNX <intro_onnx.html>`_ ||
4
4
`Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
5
- `Extending the ONNX Registry <onnx_registry_tutorial.html>`_
5
+ `Extending the ONNX exporter operator support <onnx_registry_tutorial.html>`_ ||
6
+ **`Export a model with control flow to ONNX**
6
7
7
- Export a control flow model to ONNX
8
+ Export a model with control flow to ONNX
8
9
==========================================
9
10
10
11
**Author**: `Xavier Dupré <https://github.com/xadupre>`_.
11
12
12
- Tests cannot be exported into ONNX unless they refactored
13
+ Conditional logic cannot be exported into ONNX unless they refactored
13
14
to use :func:`torch.cond`. Let's start with a simple model
14
15
implementing a test.
15
16
"""
16
17
17
- from onnx .printer import to_text
18
18
import torch
19
19
20
20
class ForwardWithControlFlowTest (torch .nn .Module ):
@@ -48,17 +48,17 @@ def forward(self, x):
48
48
# %%
49
49
# As expected, it does not export.
50
50
try :
51
- torch .export .export (model , (x ,))
51
+ torch .export .export (model , (x ,), strict = False )
52
52
raise AssertionError ("This export should failed unless pytorch now supports this model." )
53
53
except Exception as e :
54
54
print (e )
55
55
56
56
# %%
57
57
# It does export with :func:`torch.onnx.export` because
58
- # it uses JIT to trace the execution .
58
+ # the exporter falls back to use JIT tracing as the graph capturing strategy .
59
59
# But the model is not exactly the same as the initial model.
60
- ep = torch .onnx .export (model , (x ,), dynamo = True )
61
- print (to_text ( ep . model_proto ) )
60
+ onnx_program = torch .onnx .export (model , (x ,), dynamo = True )
61
+ print (onnx_program . model )
62
62
63
63
64
64
# %%
@@ -87,18 +87,18 @@ def neg(x):
87
87
# %%
88
88
# Let's see what the fx graph looks like.
89
89
90
- print (torch .export .export (model , (x ,)). graph )
90
+ print (torch .export .export (model , (x ,), strict = False ))
91
91
92
92
# %%
93
93
# Let's export again.
94
94
95
- ep = torch .onnx .export (model , (x ,), dynamo = True )
96
- print (to_text ( ep . model_proto ))
95
+ onnx_program = torch .onnx .export (model , (x ,), dynamo = True )
96
+ print (onnx_program . model )
97
97
98
98
99
- # %%
100
- # Let's optimize to see a small model.
99
+ # %%
100
+ # We can optimize the model and get rid of the model local functions created to capture the control flow branches.
101
101
102
- ep = torch .onnx .export (model , (x ,), dynamo = True )
103
- ep .optimize ()
104
- print (to_text ( ep . model_proto ))
102
+ onnx_program = torch .onnx .export (model , (x ,), dynamo = True )
103
+ onnx_program .optimize ()
104
+ print (onnx_program . model )
0 commit comments