1111import io
1212import logging
1313import os
14+ from collections import defaultdict
1415from typing import Any , Dict , List , Optional , Sequence , Set , TextIO , Type , Union
1516
1617import torch
@@ -1136,25 +1137,16 @@ def keep(op):
11361137
11371138
11381139def _can_skip_using_EDGE_DO_NOT_DECOMP (
1139- partitioner : Dict [ str , List [ Partitioner ]], aten_programs : Dict [ str , ExportedProgram ]
1140+ partitioner : Partitioner , program : ExportedProgram
11401141) -> bool :
11411142 # THe current design of using EDGE_DO_NOT_DECOMP to prevent decomposition
11421143 # has long standing issues. _remove_invalid_ops_for_not_decompose was a band-aid to
11431144 # fix some of the issues, but more issues are coming up over time, including a new issue with SDPA
11441145 # and contiguous views: https://fb.workplace.com/groups/pytorch.edge.users/permalink/1796069037930048/
11451146 # EDGE_DO_NOT_DECOMP is only needed by partitioners that specify check_op_support
11461147 # As a temp fix, we give a more reliable path for backends that do not specify check_op_support
1147- can_skip_using_EDGE_DO_NOT_DECOMP = True
1148- for name , program in aten_programs .items ():
1149- if partitioner is not None :
1150- for curr_partitioner in partitioner .get (name , []):
1151- (
1152- curr_ops_no_decomp ,
1153- check_op_support ,
1154- ) = curr_partitioner .ops_to_not_decompose (program )
1155- if check_op_support is not None :
1156- can_skip_using_EDGE_DO_NOT_DECOMP = False
1157- return can_skip_using_EDGE_DO_NOT_DECOMP
1148+ _ , check_op_support = partitioner .ops_to_not_decompose (program )
1149+ return check_op_support is None
11581150
11591151
11601152def _gen_edge_manager_for_partitioners (
@@ -1177,60 +1169,75 @@ def _gen_edge_manager_for_partitioners(
11771169 on nodes with preserved aten targets. They are then replaces with transformed ops to
11781170 keep them through the second pass of decompositions
11791171 """
1180- can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1181- partitioner , aten_programs
1182- )
1183- ops_set_to_not_decompose_by_program = {}
1172+ ops_set_to_not_decompose_by_program = defaultdict (list )
11841173 edge_programs : Dict [str , ExportedProgram ] = {}
11851174 for name , program in aten_programs .items ():
11861175 # Functionalize program before asking partitioners to preserve ops
11871176 program = program .run_decompositions ({})
11881177
11891178 if partitioner is not None :
1190- # preserve all ops listed by all partitioners first
1191- all_ops_no_decomp = set ()
1192- all_ops_no_decomp_needing_preservation = []
1193- for curr_partitioner in partitioner .get (name , []):
1179+ partitioners_for_program = partitioner .get (name , [])
1180+ final_ops_to_preserve = set ()
1181+
1182+ # Decompose by default if there are no partitioners for the method
1183+ if not partitioners_for_program :
1184+ program = program .run_decompositions (_default_decomposition_table ())
1185+
1186+ # Process each partitioner individually using their specific requirements
1187+ for curr_partitioner in partitioners_for_program :
11941188 curr_ops_no_decomp , _ = curr_partitioner .ops_to_not_decompose (program )
1195- all_ops_no_decomp |= set (curr_ops_no_decomp )
11961189
1197- # If not using the can_skip_using_EDGE_DO_NOT_DECOMP path, we need to remove invalid ops
1198- # Otherwise there will be issues
1199- if not can_skip_using_EDGE_DO_NOT_DECOMP :
1200- all_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1201- list (all_ops_no_decomp )
1202- )
1203- all_ops_no_decomp = set (all_ops_no_decomp )
1204-
1205- # Run default decompositions, except for those in all_ops_no_decomp
1206- table = _default_decomposition_table ()
1207- for op in all_ops_no_decomp :
1208- if table .pop (op , None ) is not None :
1209- all_ops_no_decomp_needing_preservation .append (op )
1210- program = program .run_decompositions (table )
1211-
1212- # Among all the preserved aten ops, use the check_op_fn to do an additional
1213- # check on which ops need to be preserved and which ops need to be decomposed
1214- # Those which are truly preserved will be replaced with transformed ops
1215- if can_skip_using_EDGE_DO_NOT_DECOMP :
1216- ops_set_to_not_decompose_by_program [name ] = (
1217- all_ops_no_decomp_needing_preservation
1218- )
1219- else :
1220- ops_set_to_not_decompose_by_program [name ] = (
1221- _replace_aten_ops_with_transformed_ops (name , program , partitioner )
1222- or []
1190+ # Check if this partitioner can skip using EDGE_DO_NOT_DECOMP
1191+ can_skip_using_edge_do_not_decomp = _can_skip_using_EDGE_DO_NOT_DECOMP (
1192+ curr_partitioner , program
12231193 )
12241194
1225- if not can_skip_using_EDGE_DO_NOT_DECOMP :
1226- program = program .run_decompositions (_default_decomposition_table ())
1227- _restore_transformed_ops_to_aten_ops (program )
1195+ if can_skip_using_edge_do_not_decomp :
1196+ # Preserve all ops in curr_ops_no_decomp from decomposition
1197+ table = _default_decomposition_table ()
1198+ ops_needing_preservation = []
1199+
1200+ for op in curr_ops_no_decomp :
1201+ if table .pop (op , None ) is not None :
1202+ ops_needing_preservation .append (op )
1203+
1204+ program = program .run_decompositions (table )
1205+ final_ops_to_preserve .update (ops_needing_preservation )
1206+ else :
1207+ # EDGE_DO_NOT_DECOMP path for the partitioner
1208+ curr_ops_no_decomp = _remove_invalid_ops_for_not_decompose (
1209+ curr_ops_no_decomp
1210+ )
1211+
1212+ # Apply decompositions with this partitioner's preserved ops
1213+ table = _default_decomposition_table ()
1214+ for op in curr_ops_no_decomp :
1215+ table .pop (op , None )
1216+
1217+ # First pass of decompositions with this partitioner's preserved ops
1218+ program = program .run_decompositions (table )
1219+
1220+ # Filter ops using EDGE_DO_NOT_DECOMP
1221+ temp_partitioner_dict = {name : [curr_partitioner ]}
1222+ preserved_ops = (
1223+ _replace_aten_ops_with_transformed_ops (
1224+ name , program , temp_partitioner_dict
1225+ )
1226+ or []
1227+ )
1228+ final_ops_to_preserve .update (preserved_ops )
1229+
1230+ # Second pass of decompositions with this partitioner's preserved ops after filtering
1231+ program = program .run_decompositions (_default_decomposition_table ())
1232+
1233+ # Restore ops from edge_no_decomp_namespace to aten ops
1234+ _restore_transformed_ops_to_aten_ops (program )
1235+ ops_set_to_not_decompose_by_program [name ].extend (final_ops_to_preserve )
12281236
1229- edge_programs [name ] = program
12301237 edge_programs [name ] = _generate_edge_program (
12311238 config ,
12321239 program ,
1233- preserve_ops = list ( ops_set_to_not_decompose_by_program .get (name , []) ),
1240+ preserve_ops = ops_set_to_not_decompose_by_program .get (name , []),
12341241 )
12351242
12361243 edge_manager = EdgeProgramManager (
@@ -1349,9 +1356,6 @@ def to_edge_transform_and_lower( # noqa: C901
13491356 elif partitioner is None :
13501357 partitioner = {name : [] for name in aten_programs .keys ()}
13511358
1352- can_skip_using_EDGE_DO_NOT_DECOMP = _can_skip_using_EDGE_DO_NOT_DECOMP (
1353- partitioner , aten_programs
1354- )
13551359 edge_manager = _gen_edge_manager_for_partitioners (
13561360 partitioner , aten_programs , config , constant_methods , generate_etrecord
13571361 )
@@ -1377,7 +1381,8 @@ def to_edge_transform_and_lower( # noqa: C901
13771381 curr_op_set , check_op_support = curr_partitioner .ops_to_not_decompose (
13781382 program
13791383 )
1380- if not can_skip_using_EDGE_DO_NOT_DECOMP :
1384+
1385+ if not _can_skip_using_EDGE_DO_NOT_DECOMP (curr_partitioner , program ):
13811386 curr_op_set = _remove_invalid_ops_for_not_decompose (curr_op_set )
13821387 ops_set_to_not_decompose = ops_set_to_not_decompose .union (curr_op_set )
13831388 _sanity_check_graph_for_non_decomp_ops (
0 commit comments