@@ -482,6 +482,7 @@ def validate_model(
482482 verbose = verbose ,
483483 optimization = optimization ,
484484 do_run = do_run ,
485+ dump_folder = dump_folder ,
485486 )
486487 else :
487488 data ["inputs_export" ] = data ["inputs" ]
@@ -493,6 +494,7 @@ def validate_model(
493494 verbose = verbose ,
494495 optimization = optimization ,
495496 do_run = do_run ,
497+ dump_folder = dump_folder ,
496498 )
497499 summary .update (summary_export )
498500
@@ -618,6 +620,7 @@ def call_exporter(
618620 verbose : int = 0 ,
619621 optimization : Optional [str ] = None ,
620622 do_run : bool = False ,
623+ dump_folder : Optional [None ] = None ,
621624) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
622625 """
623626 Calls an exporter on a model;
@@ -629,6 +632,7 @@ def call_exporter(
629632 :param verbose: verbosity
630633 :param optimization: optimization to do
631634 :param do_run: runs and compute discrepancies
635+ :param dump_folder: to dump additional information
632636 :return: two dictionaries, one with some metrics,
633637 another one with whatever the function produces
634638 """
@@ -661,6 +665,7 @@ def call_exporter(
661665 quiet = quiet ,
662666 verbose = verbose ,
663667 optimization = optimization ,
668+ dump_folder = dump_folder ,
664669 )
665670 return summary , data
666671 raise NotImplementedError (
@@ -1045,6 +1050,7 @@ def call_torch_export_custom(
10451050 quiet : bool = False ,
10461051 verbose : int = 0 ,
10471052 optimization : Optional [str ] = None ,
1053+ dump_folder : Optional [str ] = None ,
10481054) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
10491055 """
10501056 Exports a model into onnx.
@@ -1056,6 +1062,7 @@ def call_torch_export_custom(
10561062 :param quiet: catch exception or not
10571063 :param verbose: verbosity
10581064 :param optimization: optimization to do
1065+ :param dump_folder: to store additional information
10591066 :return: two dictionaries, one with some metrics,
10601067 another one with whatever the function produces
10611068 """
@@ -1113,6 +1120,7 @@ def call_torch_export_custom(
11131120 decomposition_table = (
11141121 "default" if "-default" in exporter else ("all" if "-all" in exporter else None )
11151122 ),
1123+ save_ep = (os .path .join (dump_folder , f"{ exporter } .ep" ) if dump_folder else None ),
11161124 )
11171125 options = OptimizationOptions (patterns = optimization ) if optimization else None
11181126 model = data ["model" ]
0 commit comments