1515
1616class FuseConsecutiveTranspose (ExportPass ):
1717 """
18- This pass fuses consecutive transpose / permute into one to reduce runtime
19- overhead
18+ This pass fuses consecutive transpose / permute into one or none to reduce runtime
19+ overhead.
20+ To simplify the fuse logic, we ensure each permute node's output has at most 1 permute node
21+ by cloning transpose.
22+ Example:
23+ Before clone transpose:
24+ relu -> permute1 ─> permute2
25+ |──────> permute3
26+
27+ After clone transpose:
28+ relu ─> permute1 ──────> permute2
29+ |───> permute4(new) ─> permute3
2030 """
2131
2232 def __init__ (self ):
@@ -27,54 +37,81 @@ def __init__(self):
2737 self .visited = set ()
2838 self .nodes = []
2939
def _clone_transpose(
    self, graph_module: torch.fx.GraphModule
) -> torch.fx.GraphModule:
    """Ensure every permute node feeds at most one other permute node.

    For each node whose target is in ``self.op_map`` and which has more
    than one user that is also in ``self.op_map``, the first such user
    keeps the original node as its input and every additional one is
    rewired to a freshly created ``permute_copy`` node with the same
    input and the same axis order.

    Args:
        graph_module: The graph module whose graph is rewritten in place.

    Returns:
        The same ``graph_module`` instance, after rewriting.
    """
    graph = graph_module.graph
    # Walk a snapshot: new nodes are inserted into the graph inside the loop.
    for n in list(graph.nodes):
        if n.target not in self.op_map:
            continue
        permute_users = [user for user in n.users if user.target in self.op_map]
        # The first permute user stays attached to `n`; the rest get a clone.
        for user in permute_users[1:]:
            with graph.inserting_after(n):
                clone_permute_node = graph.create_node(
                    "call_function",
                    exir_ops.edge.aten.permute_copy.default,
                    (n.args[0], n.args[1]),
                )
            # Shallow-copy the metadata so the clone does not share one
            # dict object with `n` (a later write to either node's meta
            # would otherwise show up on both).
            clone_permute_node.meta = dict(n.meta)
            user.replace_input_with(n, clone_permute_node)
    return graph_module
57+
58+ def _is_dispensable (self , axis_order ):
59+ for index , value in enumerate (axis_order ):
60+ if index != value :
61+ return False
62+ return True
63+
3064 def _traverse (self , node ):
3165 if node in self .visited or node .target not in self .op_map :
3266 return
3367
3468 self .nodes .append (node )
3569 self .visited .add (node )
3670 next_users = [n for n in list (node .users ) if n .target in self .op_map ]
71+
72+ assert (
73+ len (next_users ) <= 1
74+ ), "Each permute node should have at most 1 permute output node after _clone_transpose"
3775 if not next_users :
3876 return
39-
40- if len (next_users ) == 1 :
41- self ._traverse (list (node .users )[0 ])
4277 else :
43- raise NotImplementedError (
44- f"Check the node { node } , wich encounter mutilple permute output case"
45- )
78+ self ._traverse (list (node .users )[0 ])
4679
def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Collapse each chain of consecutive permute nodes into one or none.

    For every chain of more than one node found by ``_traverse``, the
    axis orders of the chain are composed. If the composed order is the
    identity, users of the last node in the chain are rewired to the
    chain's input; otherwise they are rewired to a single new
    ``permute_copy`` node applying the composed order to that input.

    Args:
        graph_module: The graph module whose graph is rewritten in place.

    Returns:
        The same ``graph_module`` instance, after rewriting.
    """
    graph = graph_module.graph
    # Walk a snapshot: new nodes may be inserted into the graph inside the loop.
    for n in list(graph.nodes):
        self._traverse(n)
        if len(self.nodes) > 1:
            input_node, output_node = self.nodes[0].args[0], self.nodes[-1]
            input_shape = input_node.meta["val"].shape
            # Start from the identity order and apply each permute in turn.
            axis_order = list(range(len(input_shape)))
            for node in self.nodes:
                axis_order = [axis_order[i] for i in node.args[1]]

            # If axis order is just [0,1,2,3], we ignore permute node
            if self._is_dispensable(axis_order):
                replacement = input_node
            else:
                with graph.inserting_after(input_node):
                    permute_op = exir_ops.edge.aten.permute_copy.default
                    replacement = graph.create_node(
                        "call_function", permute_op, (input_node, axis_order)
                    )
                # copy metadata (shallow copy, so the flag written below
                # does not also land on the node being replaced)
                replacement.meta = dict(output_node.meta)
                # Without "qnn_permute", we might obtain wrong input shape
                # NOTE(review): the previous condition was a non-empty list
                # and therefore always true, so the flag is set
                # unconditionally here to keep behavior. If it should depend
                # on the fused nodes, use
                # any(pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes).
                replacement.meta[QCOM_INSERTED_PERMUTE] = True

            for user in output_node.users.copy():
                user.replace_input_with(output_node, replacement)

        # clear current stack
        self.nodes = []
    return graph_module
76112
77113 def call (self , graph_module : torch .fx .GraphModule ):
114+ self ._clone_transpose (graph_module )
78115 self ._fuse (graph_module )
79116 graph_module .recompile ()
80117 dead_code_elimination_pass (graph_module )
0 commit comments