11import datetime
22import inspect
33import os
4+ import sys
45from typing import Any , Callable , Dict , List , Optional , Tuple , Union
56import time
67import onnx
@@ -375,8 +376,11 @@ def validate_model(
375376 summary [f"model_{ k .replace ('_' ,'' )} " ] = data [k ]
376377 summary ["model_inputs_opionts" ] = str (input_options or "" )
377378 summary ["model_inputs" ] = string_type (data ["inputs" ], with_shape = True )
378- summary ["model_shapes" ] = string_type (str ( data ["dynamic_shapes" ]) )
379+ summary ["model_shapes" ] = string_type (data ["dynamic_shapes" ])
379380 summary ["model_class" ] = data ["model" ].__class__ .__name__
381+ summary ["model_module" ] = str (data ["model" ].__class__ .__module__ )
382+ if summary ["model_module" ] in sys .modules :
383+ summary ["model_file" ] = str (sys .modules [summary ["model_module" ]].__file__ ) # type: ignore[index]
380384 summary ["model_config_class" ] = data ["configuration" ].__class__ .__name__
381385 summary ["model_config" ] = str (data ["configuration" ].to_dict ()).replace (" " , "" )
382386 summary ["model_id" ] = model_id
@@ -482,6 +486,7 @@ def validate_model(
482486 verbose = verbose ,
483487 optimization = optimization ,
484488 do_run = do_run ,
489+ dump_folder = dump_folder ,
485490 )
486491 else :
487492 data ["inputs_export" ] = data ["inputs" ]
@@ -493,6 +498,7 @@ def validate_model(
493498 verbose = verbose ,
494499 optimization = optimization ,
495500 do_run = do_run ,
501+ dump_folder = dump_folder ,
496502 )
497503 summary .update (summary_export )
498504
@@ -618,6 +624,7 @@ def call_exporter(
618624 verbose : int = 0 ,
619625 optimization : Optional [str ] = None ,
620626 do_run : bool = False ,
627+ dump_folder : Optional [str ] = None ,
621628) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
622629 """
623630 Calls an exporter on a model;
@@ -629,6 +636,7 @@ def call_exporter(
629636 :param verbose: verbosity
630637 :param optimization: optimization to do
631638 :param do_run: runs and compute discrepancies
639+ :param dump_folder: to dump additional information
632640 :return: two dictionaries, one with some metrics,
633641 another one with whatever the function produces
634642 """
@@ -661,6 +669,7 @@ def call_exporter(
661669 quiet = quiet ,
662670 verbose = verbose ,
663671 optimization = optimization ,
672+ dump_folder = dump_folder ,
664673 )
665674 return summary , data
666675 raise NotImplementedError (
@@ -1045,6 +1054,7 @@ def call_torch_export_custom(
10451054 quiet : bool = False ,
10461055 verbose : int = 0 ,
10471056 optimization : Optional [str ] = None ,
1057+ dump_folder : Optional [str ] = None ,
10481058) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
10491059 """
10501060 Exports a model into onnx.
@@ -1056,6 +1066,7 @@ def call_torch_export_custom(
10561066 :param quiet: catch exception or not
10571067 :param verbose: verbosity
10581068 :param optimization: optimization to do
1069+ :param dump_folder: to store additional information
10591070 :return: two dictionaries, one with some metrics,
10601071 another one with whatever the function produces
10611072 """
@@ -1113,6 +1124,7 @@ def call_torch_export_custom(
11131124 decomposition_table = (
11141125 "default" if "-default" in exporter else ("all" if "-all" in exporter else None )
11151126 ),
1127+ save_ep = (os .path .join (dump_folder , f"{ exporter } .ep" ) if dump_folder else None ),
11161128 )
11171129 options = OptimizationOptions (patterns = optimization ) if optimization else None
11181130 model = data ["model" ]
0 commit comments