6161from executorch .backends .arm .tosa_partitioner import TOSAPartitioner
6262from executorch .backends .arm .tosa_specification import TosaSpecification
6363
64+ from executorch .backends .test .harness .stages import Stage , StageType
6465from executorch .backends .xnnpack .test .tester import Tester
6566from executorch .devtools .backend_debug import get_delegation_info
6667
@@ -259,10 +260,13 @@ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram:
259260 super ().run (artifact , inputs )
260261
261262
262- class InitialModel (tester . Stage ):
263+ class InitialModel (Stage ):
263264 def __init__ (self , model : torch .nn .Module ):
264265 self .model = model
265266
267+ def stage_type (self ) -> StageType :
268+ return StageType .INITIAL_MODEL
269+
266270 def run (self , artifact , inputs = None ) -> None :
267271 pass
268272
@@ -305,13 +309,13 @@ def __init__(
305309 self .constant_methods = constant_methods
306310 self .compile_spec = compile_spec
307311 super ().__init__ (model , example_inputs , dynamic_shapes )
308- self .pipeline [self . stage_name ( InitialModel ) ] = [
309- self . stage_name ( tester . Quantize ) ,
310- self . stage_name ( tester . Export ) ,
312+ self .pipeline [StageType . INITIAL_MODEL ] = [
313+ StageType . QUANTIZE ,
314+ StageType . EXPORT ,
311315 ]
312316
313317 # Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
314- self .stages [self . stage_name ( InitialModel ) ] = None
318+ self .stages [StageType . INITIAL_MODEL ] = None
315319 self ._run_stage (InitialModel (self .original_module ))
316320
317321 def quantize (
@@ -413,7 +417,7 @@ def serialize(
413417 return super ().serialize (serialize_stage )
414418
415419 def is_quantized (self ) -> bool :
416- return self .stages [self . stage_name ( tester . Quantize ) ] is not None
420+ return self .stages [StageType . QUANTIZE ] is not None
417421
418422 def run_method_and_compare_outputs (
419423 self ,
@@ -442,18 +446,16 @@ def run_method_and_compare_outputs(
442446 """
443447
444448 if not run_eager_mode :
445- edge_stage = self .stages [self . stage_name ( tester . ToEdge ) ]
449+ edge_stage = self .stages [StageType . TO_EDGE ]
446450 if edge_stage is None :
447- edge_stage = self .stages [
448- self .stage_name (tester .ToEdgeTransformAndLower )
449- ]
451+ edge_stage = self .stages [StageType .TO_EDGE_TRANSFORM_AND_LOWER ]
450452 assert (
451453 edge_stage is not None
452454 ), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
453455 else :
454456 # Run models in eager mode. We do this when we want to check that the passes
455457 # are numerically accurate and the exported graph is correct.
456- export_stage = self .stages [self . stage_name ( tester . Export ) ]
458+ export_stage = self .stages [StageType . EXPORT ]
457459 assert (
458460 export_stage is not None
459461 ), "To compare outputs in eager mode, the model must be at Export stage"
@@ -463,11 +465,11 @@ def run_method_and_compare_outputs(
463465 is_quantized = self .is_quantized ()
464466
465467 if is_quantized :
466- reference_stage = self .stages [self . stage_name ( tester . Quantize ) ]
468+ reference_stage = self .stages [StageType . QUANTIZE ]
467469 else :
468- reference_stage = self .stages [self . stage_name ( InitialModel ) ]
470+ reference_stage = self .stages [StageType . INITIAL_MODEL ]
469471
470- exported_program = self .stages [self . stage_name ( tester . Export ) ].artifact
472+ exported_program = self .stages [StageType . EXPORT ].artifact
471473 output_nodes = get_output_nodes (exported_program )
472474
473475 output_qparams = get_output_quantization_params (output_nodes )
@@ -477,7 +479,7 @@ def run_method_and_compare_outputs(
477479 quantization_scales .append (getattr (output_qparams [node ], "scale" , None ))
478480
479481 logger .info (
480- f"Comparing Stage '{ self . stage_name ( test_stage )} ' with Stage '{ self . stage_name ( reference_stage )} '"
482+ f"Comparing Stage '{ test_stage . stage_type ( )} ' with Stage '{ reference_stage . stage_type ( )} '"
481483 )
482484
483485 # Loop inputs and compare reference stage with the compared stage.
@@ -528,14 +530,12 @@ def get_graph(self, stage: str | None = None) -> Graph:
528530 stage = self .cur
529531 artifact = self .get_artifact (stage )
530532 if (
531- self .cur == self . stage_name ( tester . ToEdge )
532- or self .cur == self . stage_name ( Partition )
533- or self .cur == self . stage_name ( ToEdgeTransformAndLower )
533+ self .cur == StageType . TO_EDGE
534+ or self .cur == StageType . PARTITION
535+ or self .cur == StageType . TO_EDGE_TRANSFORM_AND_LOWER
534536 ):
535537 graph = artifact .exported_program ().graph
536- elif self .cur == self .stage_name (tester .Export ) or self .cur == self .stage_name (
537- tester .Quantize
538- ):
538+ elif self .cur == StageType .EXPORT or self .cur == StageType .QUANTIZE :
539539 graph = artifact .graph
540540 else :
541541 raise RuntimeError (
@@ -556,13 +556,13 @@ def dump_operator_distribution(
556556 Returns self for daisy-chaining.
557557 """
558558 line = "#" * 10
559- to_print = f"{ line } { self .cur . capitalize () } Operator Distribution { line } \n "
559+ to_print = f"{ line } { self .cur } Operator Distribution { line } \n "
560560
561561 if (
562562 self .cur
563563 in (
564- self . stage_name ( tester . Partition ) ,
565- self . stage_name ( ToEdgeTransformAndLower ) ,
564+ StageType . PARTITION ,
565+ StageType . TO_EDGE_TRANSFORM_AND_LOWER ,
566566 )
567567 and print_table
568568 ):
@@ -602,9 +602,7 @@ def dump_dtype_distribution(
602602 """
603603
604604 line = "#" * 10
605- to_print = (
606- f"{ line } { self .cur .capitalize ()} Placeholder Dtype Distribution { line } \n "
607- )
605+ to_print = f"{ line } { self .cur } Placeholder Dtype Distribution { line } \n "
608606
609607 graph = self .get_graph (self .cur )
610608 tosa_spec = get_tosa_spec (self .compile_spec )
@@ -653,7 +651,7 @@ def run_transform_for_annotation_pipeline(
653651 stage = self .cur
654652 # We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
655653 artifact = self .get_artifact (stage )
656- if self .cur == self . stage_name ( tester . Export ) :
654+ if self .cur == StageType . EXPORT :
657655 new_gm = ArmPassManager (get_tosa_spec (self .compile_spec )).transform_for_annotation_pipeline ( # type: ignore[arg-type]
658656 graph_module = artifact .graph_module
659657 )
0 commit comments