1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ **Introduction to ONNX** ||
4
+ `Exporting a PyTorch model to ONNX <export_simple_model_to_onnx_tutorial.html>`_ ||
5
+ `Extending the ONNX Registry <onnx_registry_tutorial.html>`_
6
+
7
+ Export a control flow model to ONNX
8
+ ==========================================
9
+
10
+ **Author**: `Xavier Dupré <https://github.com/xadupre>`_.
11
+
12
+ Tests cannot be exported into ONNX unless they refactored
13
+ to use :func:`torch.cond`. Let's start with a simple model
14
+ implementing a test.
15
+ """
16
+
17
+ from onnx .printer import to_text
18
+ import torch
19
+
20
+ class ForwardWithControlFlowTest (torch .nn .Module ):
21
+ def forward (self , x ):
22
+ if x .sum ():
23
+ return x * 2
24
+ return - x
25
+
26
+
27
+ class ModelWithControlFlowTest (torch .nn .Module ):
28
+ def __init__ (self ):
29
+ super ().__init__ ()
30
+ self .mlp = torch .nn .Sequential (
31
+ torch .nn .Linear (3 , 2 ),
32
+ torch .nn .Linear (2 , 1 ),
33
+ ForwardWithControlFlowTest (),
34
+ )
35
+
36
+ def forward (self , x ):
37
+ out = self .mlp (x )
38
+ return out
39
+
40
+
41
+ model = ModelWithControlFlowTest ()
42
+
43
+ # %%
44
+ # Let's check it runs.
45
+ x = torch .randn (3 )
46
+ model (x )
47
+
48
+ # %%
49
+ # As expected, it does not export.
50
+ try :
51
+ torch .export .export (model , (x ,))
52
+ raise AssertionError ("This export should failed unless pytorch now supports this model." )
53
+ except Exception as e :
54
+ print (e )
55
+
56
+ # %%
57
+ # It does export with :func:`torch.onnx.export` because
58
+ # it uses JIT to trace the execution.
59
+ # But the model is not exactly the same as the initial model.
60
+ ep = torch .onnx .export (model , (x ,), dynamo = True )
61
+ print (to_text (ep .model_proto ))
62
+
63
+
64
+ # %%
65
+ # Suggested Patch
66
+ # +++++++++++++++
67
+ #
68
+ # Let's avoid the graph break by replacing the forward.
69
+
70
+
71
+ def new_forward (x ):
72
+ def identity2 (x ):
73
+ return x * 2
74
+
75
+ def neg (x ):
76
+ return - x
77
+
78
+ return torch .cond (x .sum () > 0 , identity2 , neg , (x ,))
79
+
80
+
81
+ print ("the list of submodules" )
82
+ for name , mod in model .named_modules ():
83
+ print (name , type (mod ))
84
+ if isinstance (mod , ForwardWithControlFlowTest ):
85
+ mod .forward = new_forward
86
+
87
+ # %%
88
+ # Let's see what the fx graph looks like.
89
+
90
+ print (torch .export .export (model , (x ,)).graph )
91
+
92
+ # %%
93
+ # Let's export again.
94
+
95
+ ep = torch .onnx .export (model , (x ,), dynamo = True )
96
+ print (to_text (ep .model_proto ))
97
+
98
+
99
+ # %%
100
+ # Let's optimize to see a small model.
101
+
102
+ ep = torch .onnx .export (model , (x ,), dynamo = True )
103
+ ep .optimize ()
104
+ print (to_text (ep .model_proto ))
0 commit comments