3333)
3434from executorch .backends .cadence .aot .quantizer .utils import (
3535 check_out_zero_point_is_min_range ,
36+ copy_node_metadata ,
3637 create_zero_bias_int32 ,
3738 find_sequential_partitions_aten ,
3839 get_conv_args ,
@@ -159,6 +160,8 @@ def get_args_and_kwargs_layer_norm(
159160 ),
160161 {"dtype" : torch .float32 },
161162 )
163+ if len (inputs_inputs ) > 0 :
164+ copy_node_metadata (weight , inputs_inputs [0 ])
162165
163166 bias = other_inputs [2 ] if len (other_inputs ) > 2 else None
164167
@@ -171,6 +174,8 @@ def get_args_and_kwargs_layer_norm(
171174 ),
172175 {"dtype" : torch .float32 },
173176 )
177+ if len (inputs_inputs ) > 0 :
178+ copy_node_metadata (bias , inputs_inputs [0 ])
174179
175180 # Make the args and kwargs for the replacement op
176181 args = tuple (inputs_inputs + [scale , zero_point ])
@@ -346,6 +351,8 @@ def get_args_and_kwargs_softmax(
346351 ),
347352 {"dtype" : torch .int32 },
348353 )
354+ if len (inputs_inputs ) > 0 :
355+ copy_node_metadata (mask_tensor , inputs_inputs [0 ])
349356 # Make the scale and zero_point tensors
350357 in_scale = dequants_inputs [0 ].args [1 ]
351358 in_zero_point = dequants_inputs [0 ].args [2 ]
@@ -395,10 +402,13 @@ def get_args_and_kwargs_mixed_w8a32_conv(
395402 torch .ops .aten .permute .default ,
396403 (other_inputs [0 ], [0 , 2 , 1 ]), # NCL -> NLC
397404 )
405+ copy_node_metadata (transposed_inputs , other_inputs [0 ])
406+
398407 transposed_weights = graph_module .graph .call_function (
399408 torch .ops .aten .permute .default ,
400409 (weights_inputs [0 ], [2 , 0 , 1 ]), # NCL -> LNC
401410 )
411+ copy_node_metadata (transposed_weights , weights_inputs [0 ])
402412
403413 args = (
404414 transposed_inputs ,
@@ -582,6 +592,26 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
582592 torch .ops .aten .transpose .int ,
583593 (weights_inputs [0 ], 0 , 1 ),
584594 )
595+ if "val" in weights_inputs [0 ].meta :
596+ original_val = weights_inputs [0 ].meta ["val" ]
597+ fake_mode = original_val .fake_mode
598+ if fake_mode is not None :
599+ with fake_mode :
600+ transposed_val = torch .ops .aten .transpose .int (
601+ original_val , 0 , 1
602+ )
603+ transposed_weights .meta ["val" ] = transposed_val
604+ else :
605+ transposed_shape = list (original_val .shape )
606+ transposed_shape [0 ], transposed_shape [1 ] = (
607+ transposed_shape [1 ],
608+ transposed_shape [0 ],
609+ )
610+ transposed_weights .meta ["val" ] = torch .zeros (
611+ transposed_shape , dtype = original_val .dtype
612+ )
613+ copy_node_metadata (transposed_weights , weights_inputs [0 ])
614+
585615 # Call linear with transposed weight
586616 args , kwargs = get_args_and_kwargs_linear (
587617 graph_module ,
@@ -654,6 +684,19 @@ def call(self, graph_module: fx.GraphModule) -> PassResult: # noqa: C901
654684
655685 legalize_graph (graph_module )
656686 graph_module .graph .eliminate_dead_code ()
687+ nodes_list = list (graph_module .graph .nodes )
688+
689+ if len (nodes_list ) > 0 and nodes_list [- 1 ].op != "output" :
690+ output_nodes = [n for n in nodes_list if n .op == "output" ]
691+ output_arg = output_nodes [0 ].args [0 ]
692+ original_meta = output_nodes [0 ].meta .copy ()
693+
694+ for out_node in output_nodes :
695+ graph_module .graph .erase_node (out_node )
696+
697+ new_output_node = graph_module .graph .output (output_arg )
698+ new_output_node .meta .update (original_meta )
699+
657700 graph_module .recompile ()
658701 return PassResult (graph_module , True )
659702
0 commit comments