 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    get_node_arg,
     insert_q_dq_pair,
 )
 from executorch.backends.arm.tosa_quant_utils import dq_op, q_op, register_passable_op
@@ -83,14 +84,48 @@ def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):

         return False

-    def insert_input_transpose(self, node, input_node, graph_module):
+    @staticmethod
+    def memory_format_differs(shape):
+        """Returns true if the shape will have a different memory layout in NCHW and NHWC format"""
+        if len(shape) >= 4:
+            C = shape[1]
+            H = shape[2]
+            W = shape[3]
+        elif len(shape) == 3:
+            C = shape[0]
+            H = shape[1]
+            W = shape[2]
+        if len(shape) <= 2:
+            return False
+
+        return C > 1 and (H > 1 or W > 1)
+
+    @staticmethod
+    def is_channel_reshape(input_shape, output_shape):
+        """Returns true if the reshape changes the channel dimension"""
+        if not len(input_shape) == len(output_shape) == 4:
+            return False
+
+        C_old = input_shape[1]
+        C_new = output_shape[1]
+
+        N_new = output_shape[0]
+        N_old = input_shape[0]
+
+        return (N_old != N_new) or (C_old != C_new)
+
+    @staticmethod
+    def insert_input_transpose(node, input_node, graph_module):
         quantize = input_node.target == dq_op
         q_params = input_node.args[1:] if quantize else None
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose,
-                args=(input_node, list(self.NHWC_inverse_order)),
+                args=(
+                    input_node,
+                    list(AnnotateChannelsLastDimOrder.NHWC_inverse_order),
+                ),
                 quantize=quantize,
                 q_params=q_params,
             )
@@ -100,14 +135,17 @@ def insert_input_transpose(self, node, input_node, graph_module):
                 range(len(input_node.meta["val"].size()))
             )

-    def insert_output_transpose(self, node, graph_module):
+    @staticmethod
+    def insert_output_transpose(node, graph_module):
         with graph_module.graph.inserting_after(node):
             permute_node = create_node(
                 graph_module.graph,
                 torch.ops.passthrough_to_tosa._transpose,
-                args=(node, list(self.NHWC_order)),
+                args=(node, list(AnnotateChannelsLastDimOrder.NHWC_order)),
+            )
+            permute_node.meta["tosa_dim_order"] = (
+                AnnotateChannelsLastDimOrder.NHWC_order
             )
-            permute_node.meta["tosa_dim_order"] = self.NHWC_order
             node.meta["tosa_dim_order"] = (0, 1, 2, 3)
             users = [user for user in node.users if user != permute_node]
             for user in users:
@@ -118,54 +156,96 @@ def insert_output_transpose(self, node, graph_module):
                 q_params = node.args[0].args[1:]
                 insert_q_dq_pair(graph_module.graph, node, q_params)

+    @staticmethod
+    def _insert_squeeze_transpose(
+        input_shape, output_shape, node, input_node, graph_module
+    ):
+        nhwc_to_nhwc = len(input_shape) == 4 and len(output_shape) <= 3
+
+        if nhwc_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
+            input_shape
+        ):
+            AnnotateChannelsLastDimOrder.insert_input_transpose(
+                node, input_node, graph_module
+            )
+
+    @staticmethod
+    def _insert_unsqueeze_transpose(input_shape, output_shape, node, graph_module):
+        nchw_to_nhwc = len(input_shape) == 3 and len(output_shape) == 4
+        if nchw_to_nhwc and AnnotateChannelsLastDimOrder.memory_format_differs(
+            output_shape
+        ):
+            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+
+    @staticmethod
+    def _insert_view_transpose(
+        input_shape, output_shape, node, input_node, graph_module
+    ):
+        nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) == 4
+        nhwc_to_nchw = len(input_shape) == 4 and len(output_shape) < 4
+        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
+            output_shape, input_shape
+        )
+
+        if (
+            channel_reshape or nhwc_to_nchw
+        ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
+            AnnotateChannelsLastDimOrder.insert_input_transpose(
+                node, input_node, graph_module
+            )
+        if (
+            channel_reshape or nchw_to_nhwc
+        ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
+            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+
     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         """
-        Reshape operations are not equivalent in NCHW and NHWC.
-        To get around this, transposes need to be added if the previous or new shape
-        fulfil the following condition:
-            C > 1 and (H or W > 1)
-
-        This is relevant for the following operations;
-            squeeze:    4D -> 3D
-            unsqueeze: <4D -> 4D
-            view:      <4D -> 4D
-            view:       4D -> <4D
-            view:       4D -> 4D
-        """
-
-        def transpose_condition(shape):
-            if len(shape) != 4:
-                return False
-            C = shape[1]
-            H = shape[2]
-            W = shape[3]
-            return C > 1 and (H > 1 or W > 1)
+        Transposes are needed for operators transforming the input to a different rank, as 4D tensors are assumed to be in NHWC format, whereas all others are in NCHW format.
+        This is relevant for the following cases:
+        - squeeze:    4D -> <4D
+        - unsqueeze:  3D ->  4D
+        - view:      <4D ->  4D
+        - view:       4D -> <4D
+        Additionally, a 4D->4D view operation acting on the channel dimension currently needs to be performed in NCHW format, leading to one extra input and one extra output transpose in this case.

+        Transposes can be avoided for shapes where there is no difference in actual memory, e.g. for
+        - H == W == 1
+        - C == 1
+        - 1D/2D tensors
+        """
         for node in graph_module.graph.nodes:
             if node.op != "call_function":
                 continue
+
             if node.target == exir_ops.edge.aten.squeeze_copy.dims:
                 input_node = node.args[0]
                 input_shape = input_node.meta["val"].shape
-                if transpose_condition(input_shape):
-                    self.insert_input_transpose(node, input_node, graph_module)
+                output_shape = node.meta["val"].shape
+
+                self._insert_squeeze_transpose(
+                    input_shape, output_shape, node, input_node, graph_module
+                )

             elif node.target == exir_ops.edge.aten.unsqueeze_copy.default:
+                input_node = get_node_arg(node.args, 0, default_value=False)
+                if input_node:
+                    input_shape = input_node.meta["val"].shape
+                else:
+                    input_shape = ()
                 output_shape = node.meta["val"].shape
-                if transpose_condition(output_shape):
-                    self.insert_output_transpose(node, graph_module)
+
+                self._insert_unsqueeze_transpose(
+                    input_shape, output_shape, node, graph_module
+                )

             elif node.target == exir_ops.edge.aten.view_copy.default:
                 input_node = node.args[0]
+                input_shape = input_node.meta["val"].shape
+                output_shape = node.meta["val"].shape

-                old_shape = input_node.meta["val"].shape
-                new_shape = node.meta["val"].shape
-
-                if transpose_condition(old_shape):
-                    self.insert_input_transpose(node, input_node, graph_module)
-
-                if transpose_condition(new_shape):
-                    self.insert_output_transpose(node, graph_module)
+                self._insert_view_transpose(
+                    input_shape, output_shape, node, input_node, graph_module
+                )

     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
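
To make the new shape checks concrete, here is a minimal standalone sketch of the two conditions added above, restated as free functions for rank 3 and rank 4 shapes (the pass itself implements them as static methods on AnnotateChannelsLastDimOrder); the example shapes are illustrative only:

def memory_format_differs(shape):
    # True if the shape occupies memory differently in NCHW and NHWC order.
    if len(shape) <= 2:
        return False
    # Rank 3 is treated as CHW, rank 4 as NCHW, mirroring the pass for these ranks.
    C, H, W = shape[-3], shape[-2], shape[-1]
    return C > 1 and (H > 1 or W > 1)

def is_channel_reshape(input_shape, output_shape):
    # True if a 4D -> 4D reshape changes the batch or channel dimension.
    if not (len(input_shape) == len(output_shape) == 4):
        return False
    return input_shape[0] != output_shape[0] or input_shape[1] != output_shape[1]

assert not memory_format_differs((1, 1, 8, 8))     # C == 1: NCHW and NHWC coincide
assert not memory_format_differs((1, 16, 1, 1))    # H == W == 1: no transpose needed
assert memory_format_differs((1, 16, 8, 8))        # layouts differ, transposes may be inserted
assert is_channel_reshape((1, 16, 8, 8), (1, 8, 16, 8))   # channel dimension changes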
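
For intuition on why the rank-changing cases need a transpose at all, the following plain-PyTorch sketch (shapes chosen purely for illustration) shows that reshaping a tensor directly in its NHWC buffer scrambles the data whenever C > 1 and H or W > 1, while transposing back to NCHW first gives the expected result:

import torch

x_nchw = torch.arange(6).reshape(1, 2, 1, 3)       # N=1, C=2, H=1, W=3
expected = x_nchw.squeeze()                         # [[0, 1, 2], [3, 4, 5]]

# 4D tensors are assumed to be channels-last, i.e. held in an NHWC buffer.
x_nhwc_buffer = x_nchw.permute(0, 2, 3, 1).contiguous()

# Reshaping that buffer straight to the squeezed shape interleaves the channels ...
wrong = x_nhwc_buffer.reshape(2, 3)                 # [[0, 3, 1], [4, 2, 5]]
assert not torch.equal(wrong, expected)

# ... so the pass inserts an NHWC -> NCHW transpose before the rank change.
right = x_nhwc_buffer.permute(0, 3, 1, 2).reshape(2, 3)
assert torch.equal(right, expected)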