99from executorch .exir .pass_base import ExportPass , PassResult
1010from executorch .exir .passes import dead_code_elimination_pass
1111
12- from torch .fx .passes .utils .source_matcher_utils import get_source_partitions
13-
1412
1513class FixedLinearKeepDim (ExportPass ):
1614 """
@@ -24,61 +22,58 @@ def __init__(self):
2422 super (FixedLinearKeepDim , self ).__init__ ()
2523
2624 def _fixed_keep_dim (self , graph_module : torch .fx .GraphModule ):
27- partitions = get_source_partitions (
28- graph_module .graph , [torch .nn .Linear , torch .ops .aten .linear .default ]
29- )
30- for _ , src_partitions in partitions .items ():
31- for src_partition in src_partitions :
32- linear_node = [
33- n for n in src_partition .nodes if n .target == self .linear
34- ][0 ]
35- input_node = linear_node .args [0 ]
36- # Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node
37- # TODO: Find a more general conditional statement.
38- linear_output = linear_node .meta ["val" ]
39- if linear_output .dim () >= 3 :
40- with graph_module .graph .inserting_after (input_node ):
41- input_users = list (input_node .users .keys ())
42- input_tensor = input_node .meta ["val" ]
43- squeeze_dim = (- 1 , input_tensor .shape [- 1 ])
44- squeeze_node = graph_module .graph .create_node (
45- "call_function" ,
46- self .view_copy ,
47- (
48- input_node ,
49- squeeze_dim ,
50- ),
51- )
52- # meta needs to be copied elementwisely for fake-tensor
53- # to be updated correctly and not affect meta of input_node
54- for k , v in input_node .meta .items ():
55- squeeze_node .meta [k ] = v
56- squeeze_node .meta ["val" ] = input_tensor .reshape (squeeze_dim )
57- for user in input_users :
58- if user == linear_node :
59- user .replace_input_with (input_node , squeeze_node )
25+ for node in graph_module .graph .nodes :
26+ if node .target != self .linear :
27+ continue
28+
29+ linear_node = node
30+ input_node = linear_node .args [0 ]
31+ # Since QNN has no keep dims for linear op, we will need to add squeeze and unsqueeze around linear node
32+ # TODO: Find a more general conditional statement.
33+ linear_output = linear_node .meta ["val" ]
34+ if linear_output .dim () >= 3 :
35+ with graph_module .graph .inserting_after (input_node ):
36+ input_users = list (input_node .users .keys ())
37+ input_tensor = input_node .meta ["val" ]
38+ squeeze_dim = (- 1 , input_tensor .shape [- 1 ])
39+ squeeze_node = graph_module .graph .create_node (
40+ "call_function" ,
41+ self .view_copy ,
42+ (
43+ input_node ,
44+ squeeze_dim ,
45+ ),
46+ )
47+ # meta needs to be copied elementwisely for fake-tensor
48+ # to be updated correctly and not affect meta of input_node
49+ for k , v in input_node .meta .items ():
50+ squeeze_node .meta [k ] = v
51+ squeeze_node .meta ["val" ] = input_tensor .reshape (squeeze_dim )
52+ for user in input_users :
53+ if user == linear_node :
54+ user .replace_input_with (input_node , squeeze_node )
6055
61- with graph_module .graph .inserting_after (linear_node ):
62- output_users = list (linear_node .users .keys ())
63- unsqueeze_dim = linear_output .shape
64- unsqueeze_node = graph_module .graph .create_node (
65- "call_function" ,
66- self .view_copy ,
67- (
68- linear_node ,
69- unsqueeze_dim ,
70- ),
71- )
72- # meta needs to be copied elementwisely for fake-tensor
73- # to be updated correctly and not affect meta of unsqueeze_node
74- for k , v in linear_node .meta .items ():
75- unsqueeze_node .meta [k ] = v
76- # update linear node's shape
77- linear_node .meta ["val" ] = linear_output .reshape (
78- (squeeze_node .meta ["val" ].shape [0 ], linear_output .shape [- 1 ])
79- )
80- for user in output_users :
81- user .replace_input_with (linear_node , unsqueeze_node )
56+ with graph_module .graph .inserting_after (linear_node ):
57+ output_users = list (linear_node .users .keys ())
58+ unsqueeze_dim = linear_output .shape
59+ unsqueeze_node = graph_module .graph .create_node (
60+ "call_function" ,
61+ self .view_copy ,
62+ (
63+ linear_node ,
64+ unsqueeze_dim ,
65+ ),
66+ )
67+ # meta needs to be copied elementwisely for fake-tensor
68+ # to be updated correctly and not affect meta of unsqueeze_node
69+ for k , v in linear_node .meta .items ():
70+ unsqueeze_node .meta [k ] = v
71+ # update linear node's shape
72+ linear_node .meta ["val" ] = linear_output .reshape (
73+ (squeeze_node .meta ["val" ].shape [0 ], linear_output .shape [- 1 ])
74+ )
75+ for user in output_users :
76+ user .replace_input_with (linear_node , unsqueeze_node )
8277
8378 def call (self , graph_module : torch .fx .GraphModule ):
8479 self ._fixed_keep_dim (graph_module )
0 commit comments