@@ -671,7 +671,16 @@ def _call_exporter(
671671 do_run ,
672672 output_names ,
673673 exporter_options ,
674+ save_ep ,
674675):
676+ if save_ep and dump_folder :
677+ for name in data :
678+ if name .startswith ("inputs" ):
679+ if verbose :
680+ print (f"[validate_model] -- dump { name !r} " )
681+ filename = os .path .join (dump_folder , f"{ save_ep } .{ name } .pt" )
682+ torch .save (data [name ], filename )
683+
675684 if exporter :
676685 expop = exporter_options or {}
677686 if verbose :
@@ -711,6 +720,7 @@ def _call_exporter(
711720 dump_folder = dump_folder ,
712721 output_names = output_names ,
713722 exporter_options = expop ,
723+ save_ep = save_ep ,
714724 )
715725 else :
716726 data ["inputs_export" ] = data ["inputs" ]
@@ -831,6 +841,7 @@ def validate_model(
831841 output_names : Optional [List [str ]] = None ,
832842 ort_logs : bool = False ,
833843 quiet_input_sets : Optional [Set [str ]] = None ,
844+ save_ep : Optional [str ] = None ,
834845) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
835846 """
836847 Validates a model.
@@ -889,6 +900,8 @@ def validate_model(
889900 :param ort_logs: increases onnxruntime verbosity when creating the session
890901 :param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
891902 even if quiet is False
903+ :param save_ep: if not empty, this can be used to save the input sets and
904+ the exported program
892905 :return: two dictionaries, one with some metrics,
893906 another one with whatever the function produces
894907
@@ -952,6 +965,7 @@ def validate_model(
952965 subfolder = subfolder ,
953966 use_pretrained = use_pretrained ,
954967 same_as_pretrained = same_as_pretrained ,
968+ save_ep = save_ep ,
955969 )
956970 if dump_folder :
957971 with open (dump_stats , "w" ) as f :
@@ -1038,6 +1052,7 @@ def _validate_model_step1(
10381052 subfolder ,
10391053 use_pretrained ,
10401054 same_as_pretrained ,
1055+ save_ep ,
10411056):
10421057 assert not do_same or do_run , (
10431058 f"Discrepancies cannot be measured if the model is not run, "
@@ -1153,6 +1168,7 @@ def _validate_model_step1(
11531168 do_run = do_run ,
11541169 output_names = output_names ,
11551170 exporter_options = exporter_options ,
1171+ save_ep = save_ep ,
11561172 )
11571173
11581174 cont , dump_stats = _dump_onnx_model (
@@ -1426,6 +1442,7 @@ def call_exporter(
14261442 dump_folder : Optional [str ] = None ,
14271443 output_names : Optional [List [str ]] = None ,
14281444 exporter_options : Optional [Dict [str , Any ]] = None ,
1445+ save_ep : Optional [str ] = None ,
14291446) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
14301447 """
14311448 Calls an exporter on a model;
@@ -1440,6 +1457,7 @@ def call_exporter(
14401457 :param dump_folder: to dump additional information
14411458 :param output_names: list of output names to use with the onnx exporter
14421459 :param exporter_options: exporter options
1460+ :param save_ep: saves the exported program
14431461 :return: two dictionaries, one with some metrics,
14441462 another one with whatever the function produces
14451463 """
@@ -1456,6 +1474,8 @@ def call_exporter(
14561474 optimization = optimization ,
14571475 do_run = do_run ,
14581476 exporter_options = exporter_options ,
1477+ save_ep = save_ep ,
1478+ dump_folder = dump_folder ,
14591479 )
14601480 _restore_torch_export_export (summary )
14611481 return summary , data
@@ -1469,6 +1489,8 @@ def call_exporter(
14691489 optimization = optimization ,
14701490 output_names = output_names ,
14711491 exporter_options = exporter_options ,
1492+ dump_folder = dump_folder ,
1493+ save_ep = save_ep ,
14721494 )
14731495 _restore_torch_export_export (summary )
14741496 return summary , data
@@ -1483,6 +1505,7 @@ def call_exporter(
14831505 dump_folder = dump_folder ,
14841506 output_names = output_names ,
14851507 exporter_options = exporter_options ,
1508+ save_ep = save_ep ,
14861509 )
14871510 _restore_torch_export_export (summary )
14881511 return summary , data
@@ -1516,6 +1539,8 @@ def call_torch_export_export(
15161539 optimization : Optional [str ] = None ,
15171540 do_run : bool = False ,
15181541 exporter_options : Optional [Dict [str , Any ]] = None ,
1542+ dump_folder : Optional [str ] = None ,
1543+ save_ep : Optional [str ] = None ,
15191544):
15201545 """
15211546 Exports a model with :func:`torch.export.export`.
@@ -1529,6 +1554,8 @@ def call_torch_export_export(
15291554 :param optimization: optimization to do
15301555 :param do_run: runs and compute discrepancies
15311556 :param exporter_options: additional options given to the exporter
1557+ :param dump_folder: folder where to dump the exported program
1558+ :param save_ep: to save the exported program
15321559 :return: two dictionaries, one with some metrics,
15331560 another one with whatever the function produces
15341561 """
@@ -1604,6 +1631,12 @@ def call_torch_export_export(
16041631 print (ep )
16051632 print ("[call_torch_export_export] -- End of ExportedProgram" )
16061633
1634+ if dump_folder and save_ep :
1635+ fname = f"{ save_ep } .pt2"
1636+ if verbose :
1637+ print (f"[call_torch_export_export] -- save the exported program in { fname !r} " )
1638+ torch .export .save (ep , os .path .join (dump_folder , fname ))
1639+
16071640 if do_run :
16081641 # We check for discrepancies.
16091642 if verbose :
@@ -1880,6 +1913,8 @@ def call_torch_export_onnx(
18801913 optimization : Optional [str ] = None ,
18811914 output_names : Optional [List [str ]] = None ,
18821915 exporter_options : Optional [Dict [str , Any ]] = None ,
1916+ dump_folder : Optional [str ] = None ,
1917+ save_ep : Optional [str ] = None ,
18831918) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
18841919 """
18851920 Exports a model into onnx.
@@ -1893,6 +1928,8 @@ def call_torch_export_onnx(
18931928 :param optimization: optimization to do
18941929 :param output_names: output names to use
18951930 :param exporter_options: additional options to give the exporter
1931+ :param dump_folder: to know where to dump the exported program
1932+ :param save_ep: to save the exported program
18961933 :return: two dictionaries, one with some metrics,
18971934 another one with whatever the function produces
18981935 """
@@ -1986,6 +2023,12 @@ def call_torch_export_onnx(
19862023 return summary , data
19872024
19882025 assert epo is not None , "no onnx export was found"
2026+ if dump_folder and save_ep :
2027+ fname = f"{ save_ep } .pt2"
2028+ if verbose :
2029+ print (f"[call_torch_export_export] -- save the exported program in { fname !r} " )
2030+ torch .export .save (epo .exported_program , os .path .join (dump_folder , fname ))
2031+
19892032 if verbose :
19902033 print ("[call_torch_export_onnx] done (export)" )
19912034 data ["onnx_program" ] = epo
@@ -2219,6 +2262,7 @@ def call_torch_export_custom(
22192262 dump_folder : Optional [str ] = None ,
22202263 output_names : Optional [List [str ]] = None ,
22212264 exporter_options : Optional [Dict [str , Any ]] = None ,
2265+ save_ep : Optional [str ] = None ,
22222266) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
22232267 """
22242268 Exports a model into onnx.
@@ -2233,6 +2277,7 @@ def call_torch_export_custom(
22332277 :param dump_folder: to store additional information
22342278 :param output_names: list of output names to use
22352279 :param exporter_options: additional exporter options
2280+ :param save_ep: to save the exported program
22362281 :return: two dictionaries, one with some metrics,
22372282 another one with whatever the function produces
22382283 """
@@ -2345,7 +2390,11 @@ def call_torch_export_custom(
23452390 export_options = ExportOptions (
23462391 strict = strict ,
23472392 decomposition_table = decomposition_table ,
2348- save_ep = (os .path .join (dump_folder , f"{ exporter } .ep" ) if dump_folder else None ),
2393+ save_ep = (
2394+ (os .path .join (dump_folder , f"{ exporter } .ep" ), 2 ** 35 if save_ep else 2 ** 18 )
2395+ if dump_folder
2396+ else None
2397+ ),
23492398 ** exporter_options ,
23502399 )
23512400 options = OptimizationOptions (patterns = optimization ) if optimization else None
0 commit comments