2525from executorch .backends .arm .util .arm_model_evaluator import GenericModelEvaluator
2626
2727from executorch .devtools .backend_debug import get_delegation_info
28- from executorch .exir import EdgeCompileConfig , ExecutorchBackendConfig
29- from executorch .extension .export_util .utils import export_to_edge , save_pte_program
28+ from executorch .exir import (
29+ EdgeCompileConfig ,
30+ ExecutorchBackendConfig ,
31+ to_edge_transform_and_lower ,
32+ )
33+ from executorch .extension .export_util .utils import save_pte_program
3034from tabulate import tabulate
3135
3236# Quantize model if required using the standard export quantizaion flow.
@@ -170,7 +174,9 @@ def forward(self, x):
170174]
171175
172176
173- def get_compile_spec (target : str , intermediates : bool ) -> ArmCompileSpecBuilder :
177+ def get_compile_spec (
178+ target : str , intermediates : Optional [str ] = None
179+ ) -> ArmCompileSpecBuilder :
174180 spec_builder = None
175181 if target == "TOSA" :
176182 spec_builder = (
@@ -185,7 +191,7 @@ def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder:
185191 memory_mode = "Shared_Sram" ,
186192 extra_flags = "--debug-force-regor --output-format=raw" ,
187193 )
188- .set_permute_memory_format (args . model_name in MODEL_NAME_TO_MODEL . keys () )
194+ .set_permute_memory_format (True )
189195 .set_quantize_io (True )
190196 )
191197 elif "ethos-u85" in target :
@@ -202,7 +208,7 @@ def get_compile_spec(target: str, intermediates: bool) -> ArmCompileSpecBuilder:
202208 )
203209
204210 if intermediates is not None :
205- spec_builder .dump_intermediate_artifacts_to (args . intermediates )
211+ spec_builder .dump_intermediate_artifacts_to (intermediates )
206212
207213 return spec_builder .build ()
208214
@@ -356,40 +362,42 @@ def get_args():
356362 model , example_inputs = get_model_and_inputs_from_name (args .model_name )
357363 model = model .eval ()
358364
365+ # export_for_training under the assumption we quantize, the exported form also works
366+ # in to_edge if we don't quantize
367+ exported_program = torch .export .export_for_training (model , example_inputs )
368+ model = exported_program .module ()
359369 model_fp32 = model
360370
361- # pre-autograd export. eventually this will become torch.export
362- model = torch .export .export_for_training (model , example_inputs ).module ()
363-
364371 # Quantize if required
365372 model_int8 = None
366373 if args .quantize :
367374 model = quantize (model , example_inputs )
368375 model_int8 = model
376+ # Wrap quantized model back into an exported_program
377+ exported_program = torch .export .export_for_training (model , example_inputs )
378+
379+ if args .delegate :
380+ # As we can target multiple output encodings from ArmBackend, one must
381+ # be specified.
382+ compile_spec = get_compile_spec (args .target , args .intermediates )
383+ edge = to_edge_transform_and_lower (
384+ exported_program ,
385+ partitioner = [ArmPartitioner (compile_spec )],
386+ compile_config = EdgeCompileConfig (
387+ _check_ir_validity = False ,
388+ _skip_dim_order = True ,
389+ ),
390+ )
391+ else :
392+ edge = to_edge_transform_and_lower (
393+ exported_program ,
394+ compile_config = EdgeCompileConfig (
395+ _check_ir_validity = False ,
396+ _skip_dim_order = True ,
397+ ),
398+ )
369399
370- edge = export_to_edge (
371- model ,
372- example_inputs ,
373- edge_compile_config = EdgeCompileConfig (
374- _check_ir_validity = False ,
375- ),
376- )
377-
378- # As we can target multiple output encodings from ArmBackend, one must
379- # be specified.
380- compile_spec = (
381- get_compile_spec (args .target , args .intermediates )
382- if args .delegate is True
383- else None
384- )
385-
386- logging .debug (f"Exported graph:\n { edge .exported_program ().graph } " )
387- if args .delegate is True :
388- edge = edge .to_backend (ArmPartitioner (compile_spec ))
389-
390- dump_delegation_info (edge , args .intermediates )
391-
392- logging .debug (f"Lowered graph:\n { edge .exported_program ().graph } " )
400+ dump_delegation_info (edge , args .intermediates )
393401
394402 try :
395403 exec_prog = edge .to_executorch (
0 commit comments