@@ -313,7 +313,7 @@ def export_to_edge(
313313            core_aten_ep , edge_constant_methods , edge_compile_config , verbose = verbose 
314314        )
315315
316-     def  export_for_et (model , device , output_path ) ->  str :
316+     def  export_for_et (model , device , output_path ,  edge_constant_methods ) ->  str :
317317
318318        input  =  (
319319            torch .tensor ([[1 ]], dtype = torch .long , device = device ),
@@ -344,12 +344,15 @@ def export_for_et(model, device, output_path) -> str:
344344        with  torch .nn .attention .sdpa_kernel (
345345            [torch .nn .attention .SDPBackend .MATH ]
346346        ), torch .no_grad ():
347-             m  =  export_for_training (model , input , dynamic_shapes = dynamic_shapes ).module ()
347+             m  =  export_for_training (
348+                 model , input , dynamic_shapes = dynamic_shapes 
349+             ).module ()
348350
349351            edge_manager  =  export_to_edge (
350352                m ,
351353                input ,
352354                dynamic_shapes = dynamic_shapes ,
355+                 edge_constant_methods = edge_constant_methods ,
353356                edge_compile_config = edge_config ,
354357            )
355358        edge_manager  =  edge_manager .to_backend (XnnpackDynamicallyQuantizedPartitioner ())
@@ -365,6 +368,7 @@ def export_for_et(model, device, output_path) -> str:
365368        )
366369
367370        print ("The methods are: " , export_program .methods )
371+         print ("The config methods are: " , export_program .config_methods )
368372        with  open (output_path , "wb" ) as  f :
369373            export_program .write_to_file (f )
370374
@@ -407,7 +411,9 @@ def main(args):
407411            f"Warning! ExecuTorch export target is controlled by export recipe, not device setting. Ignoring device={ builder_args .device }  
408412        )
409413        builder_args .device  =  "cpu" 
410-     elif  (output_pte_path  or  output_dso_path  or  output_aoti_package_path ) and  "mps"  in  builder_args .device :
414+     elif  (
415+         output_pte_path  or  output_dso_path  or  output_aoti_package_path 
416+     ) and  "mps"  in  builder_args .device :
411417        print ("Warning! Device MPS not supported for export. Exporting for device CPU." )
412418        builder_args .device  =  "cpu" 
413419
@@ -473,13 +479,26 @@ def main(args):
473479                support_tensor_subclass = False ,
474480            )
475481            _unset_gguf_kwargs (builder_args )
476-  
482+ 
483+     if  tokenizer_args  is  None :
484+         tokenizer_type  =  "0" 
485+     elif  tokenizer_args .is_sentencepiece :
486+         tokenizer_type  =  "2"   # Corresponding to llama2 
487+     else :
488+         tokenizer_type  =  "3"   # Corresponding to llama3 
489+ 
477490    with  torch .no_grad ():
478491        if  output_pte_path :
479492            output_pte_path  =  str (os .path .abspath (output_pte_path ))
480493            if  executorch_export_available :
481494                print (f"Exporting model using ExecuTorch to { output_pte_path }  )
482-                 export_for_et (model_to_pte , builder_args .device , args .output_pte_path )
495+                 print (f"Tokenizer type is { tokenizer_type }  )
496+                 export_for_et (
497+                     model_to_pte ,
498+                     builder_args .device ,
499+                     args .output_pte_path ,
500+                     {"tokenizer_type" : tokenizer_type },
501+                 )
483502            else :
484503                print (
485504                    "Export with executorch requested but ExecuTorch could not be loaded" 
@@ -503,13 +522,6 @@ def main(args):
503522        if  output_aoti_package_path :
504523            output_aoti_package_path  =  str (os .path .abspath (output_aoti_package_path ))
505524
506-             if  tokenizer_args  is  None :
507-                 tokenizer_type  =  "0" 
508-             elif  tokenizer_args .is_sentencepiece :
509-                 tokenizer_type  =  "2"   # Corresponding to llama2 
510-             else :
511-                 tokenizer_type  =  "3"   # Corresponding to llama3 
512- 
513525            metadata  =  {"tokenizer_type" : tokenizer_type }
514526            print (
515527                "Exporting model using AOT Inductor to "  f"{ output_aoti_package_path }  
0 commit comments