1515import  torch 
1616from  executorch .exir  import  CaptureConfig 
1717from  executorch .exir .passes  import  MemoryPlanningPass 
18+ from  executorch .exir .program ._program  import  ExecutorchProgramManager 
1819from  torch  import  nn 
1920from  torch .export  import  Dim 
2021
@@ -190,7 +191,8 @@ def export_joint():
190191def  export_module_to_program (
191192    module_class : Type [nn .Module ],
192193    skip_type_promotion : bool ,
193- ):
194+     external_constants : bool  =  False ,
195+ ) ->  ExecutorchProgramManager :
194196    """Exports the module and returns the serialized program data.""" 
195197    torch .manual_seed (0 )
196198    # Look for an optional @staticmethod that defines custom trace params. 
@@ -211,9 +213,10 @@ def export_module_to_program(
211213        methods ,
212214        skip_type_promotion = skip_type_promotion ,
213215        export_joint_graph = export_joint ,
216+         external_constants = external_constants ,
214217        ** export_kwargs ,
215218    )
216-     return  module .executorch_program . buffer 
219+     return  module .executorch_program 
217220
218221
219222def  main () ->  None :
@@ -235,7 +238,12 @@ def main() -> None:
235238        "--outdir" ,
236239        type = str ,
237240        required = True ,
238-         help = "Path to the directory to write <classname>.pte files to" ,
241+         help = "Path to the directory to write <classname>.pte files and .ptd files to" ,
242+     )
243+     parser .add_argument (
244+         "--external-constants" ,
245+         action = "store_true" ,
246+         help = "Export the model with external constants" ,
239247    )
240248    args  =  parser .parse_args ()
241249
@@ -257,14 +265,16 @@ def main() -> None:
257265            # Type promotion will convert to fp32. 
258266            skip_type_promotion  =  True 
259267        outfile  =  os .path .join (args .outdir , f"{ module_name }  )
268+         prog  =  export_module_to_program (
269+             module_class ,
270+             skip_type_promotion = skip_type_promotion ,
271+             external_constants = args .external_constants ,
272+         )
260273        with  open (outfile , "wb" ) as  fp :
261-             fp .write (
262-                 export_module_to_program (
263-                     module_class ,
264-                     skip_type_promotion = skip_type_promotion ,
265-                 )
266-             )
267-         print (f"Exported { module_name } { outfile }  )
274+             prog .write_to_file (fp )
275+             print (f"Exported { module_name } { outfile }  )
276+ 
277+         prog .write_tensor_data_to_file (args .outdir )
268278
269279
270280if  __name__  ==  "__main__" :
0 commit comments