@@ -841,6 +841,8 @@ def validate_model(
841841 )
842842 summary .update (summary_valid )
843843
844+ _compute_final_statistics (summary )
845+
844846 if verbose :
845847 print ("[validate_model] -- done (final)" )
846848 if dump_stats :
@@ -853,15 +855,24 @@ def validate_model(
853855def compute_statistics (onnx_filename : str ) -> Dict [str , Union [float , int ]]:
854856 """Computes some statistics on the model itself."""
855857 onx = onnx .load (onnx_filename , load_external_data = False )
858+ cache_functions = {(f .domain , f .name ): f for f in onx .functions }
859+ local_domains = set (f .domain for f in onx .functions )
856860
857861 def node_iter (proto ):
858862 if isinstance (proto , onnx .ModelProto ):
859- yield from node_iter (proto .graph )
860863 for f in proto .functions :
861864 yield from node_iter (f )
865+ yield from node_iter (proto .graph )
862866 elif isinstance (proto , (onnx .FunctionProto , onnx .GraphProto )):
863867 for node in proto .node :
864868 yield node
869+
870+ # Let's inline the function
871+ key = node .domain , node .op_type
872+ if key in cache_functions :
873+ yield from node_iter (cache_functions [key ])
874+
875+ # Let's continue
865876 for att in node .attribute :
866877 if att .type == onnx .AttributeProto .GRAPH :
867878 yield from node_iter (att .g )
@@ -879,6 +890,11 @@ def node_iter(proto):
879890 n_nodes += 1
880891 if proto .op_type != "Constant" :
881892 n_nodes_nocst += 1
893+ if proto .domain in local_domains :
894+ key = "n_node_local_function"
895+ if key not in counts :
896+ counts [key ] = 0
897+ counts [key ] += 1
882898 else :
883899 key = f"n_node_initializer_{ proto .data_type } "
884900
@@ -1400,7 +1416,7 @@ def call_torch_export_onnx(
14001416 :return: two dictionaries, one with some metrics,
14011417 another one with whatever the function produces
14021418 """
1403- available = {None , "" , "ir" , "os_ort" }
1419+ available = {None , "" , "ir" , "os_ort" , "ir+default" }
14041420 assert (
14051421 optimization in available
14061422 ), f"unexpected value for optimization={ optimization } , available={ available } "
@@ -1490,11 +1506,31 @@ def call_torch_export_onnx(
14901506 print (epo )
14911507 print ("[call_torch_export_onnx] -- End of ONNXProgram" )
14921508
1493- if optimization in {"ir" , "os_ort" }:
1509+ if optimization in {"ir" , "os_ort" , "ir+default" }:
14941510 if verbose :
14951511 print (f"[call_torch_export_onnx] starts optimization={ optimization !r} ..." )
14961512 if optimization == "ir" :
14971513 label , f_optim = "export_onnx_opt_ir" , (lambda epo = epo : epo .optimize ())
1514+ elif optimization == "ir+default" :
1515+ import onnxscript
1516+ from experimental_experiment .xbuilder import GraphBuilder , OptimizationOptions
1517+
1518+ def _ir_default_opt (epo ):
1519+ onnxscript .optimizer .optimize_ir (epo .model )
1520+ onx = epo .model_proto
1521+ # not very efficient
1522+ gr = GraphBuilder (
1523+ onx ,
1524+ infer_shapes_options = True ,
1525+ optimization_options = OptimizationOptions (patterns = "default" ),
1526+ )
1527+ cont = gr .to_onnx (large_model = True )
1528+ epo .model = cont .to_ir ()
1529+
1530+ label , f_optim = "export_onnx_opt_ir_default" , (
1531+ lambda epo = epo : _ir_default_opt (epo )
1532+ )
1533+
14981534 else :
14991535 import onnxscript
15001536 import onnxscript .rewriter .ort_fusions as ort_fusions
@@ -1893,3 +1929,21 @@ def run_ort_fusion(
18931929 f"opt_ort_{ model_type } _duration" : duration ,
18941930 f"opt_ort_{ model_type } _duration_save" : d ,
18951931 }, {f"opt_ort_{ model_type } " : output_path }
1932+
1933+
1934+ def _compute_final_statistics (summary : Dict [str , Any ]):
1935+ """
1936+ Updates inline the list of statistics. It adds:
1937+
1938+ - speedup
1939+ """
1940+ stats = {}
1941+ if (
1942+ "time_run_latency" in summary
1943+ and "time_run_onnx_ort_latency" in summary
1944+ and summary ["time_run_onnx_ort_latency" ] > 0
1945+ ):
1946+ stats ["stat_estimated_speedup_ort" ] = (
1947+ summary ["time_run_latency" ] / summary ["time_run_onnx_ort_latency" ]
1948+ )
1949+ summary .update (stats )
0 commit comments