@@ -264,14 +264,18 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
264264 return new_cfg
265265
266266
267- def _preprocess_model_id (model_id , subfolder ):
267+ def _preprocess_model_id (
268+ model_id : str , subfolder : Optional [str ], same_as_pretrained : bool , use_pretrained : bool
269+ ) -> Tuple [str , Optional [str ], bool , bool ]:
268270 if subfolder or "//" not in model_id :
269- return model_id , subfolder
271+ return model_id , subfolder , same_as_pretrained , use_pretrained
270272 spl = model_id .split ("//" )
273+ if spl [- 1 ] == "pretrained" :
274+ return _preprocess_model_id ("//" .join (spl [:- 1 ]), "" , True , True )
271275 if spl [- 1 ] in {"transformer" , "vae" }:
272276 # known subfolder
273- return "//" .join (spl [:- 1 ]), spl [- 1 ]
274- return model_id , subfolder
277+ return "//" .join (spl [:- 1 ]), spl [- 1 ], same_as_pretrained , use_pretrained
278+ return model_id , subfolder , same_as_pretrained , use_pretrained
275279
276280
277281def validate_model (
@@ -384,7 +388,12 @@ def validate_model(
384388 if ``runtime == 'ref'``,
385389 ``orteval10`` increases the verbosity.
386390 """
387- model_id , subfolder = _preprocess_model_id (model_id , subfolder )
391+ model_id , subfolder , same_as_pretrained , use_pretrained = _preprocess_model_id (
392+ model_id ,
393+ subfolder ,
394+ same_as_pretrained = same_as_pretrained ,
395+ use_pretrained = use_pretrained ,
396+ )
388397 if isinstance (patch , bool ):
389398 patch_kwargs = (
390399 dict (patch_transformers = True , patch_diffusers = True , patch = True )
@@ -841,6 +850,8 @@ def validate_model(
841850 )
842851 summary .update (summary_valid )
843852
853+ _compute_final_statistics (summary )
854+
844855 if verbose :
845856 print ("[validate_model] -- done (final)" )
846857 if dump_stats :
@@ -853,15 +864,24 @@ def validate_model(
853864def compute_statistics (onnx_filename : str ) -> Dict [str , Union [float , int ]]:
854865 """Computes some statistics on the model itself."""
855866 onx = onnx .load (onnx_filename , load_external_data = False )
867+ cache_functions = {(f .domain , f .name ): f for f in onx .functions }
868+ local_domains = set (f .domain for f in onx .functions )
856869
857870 def node_iter (proto ):
858871 if isinstance (proto , onnx .ModelProto ):
859- yield from node_iter (proto .graph )
860872 for f in proto .functions :
861873 yield from node_iter (f )
874+ yield from node_iter (proto .graph )
862875 elif isinstance (proto , (onnx .FunctionProto , onnx .GraphProto )):
863876 for node in proto .node :
864877 yield node
878+
879+ # If this node invokes a local function, recurse into the
879+ # function body so its nodes are counted too (inlining)
880+ key = node .domain , node .op_type
881+ if key in cache_functions :
882+ yield from node_iter (cache_functions [key ])
883+
884+ # Recurse into subgraphs carried by GRAPH/GRAPHS attributes
865885 for att in node .attribute :
866886 if att .type == onnx .AttributeProto .GRAPH :
867887 yield from node_iter (att .g )
@@ -879,6 +899,11 @@ def node_iter(proto):
879899 n_nodes += 1
880900 if proto .op_type != "Constant" :
881901 n_nodes_nocst += 1
902+ if proto .domain in local_domains :
903+ key = "n_node_local_function"
904+ if key not in counts :
905+ counts [key ] = 0
906+ counts [key ] += 1
882907 else :
883908 key = f"n_node_initializer_{ proto .data_type } "
884909
@@ -1400,7 +1425,7 @@ def call_torch_export_onnx(
14001425 :return: two dictionaries, one with some metrics,
14011426 another one with whatever the function produces
14021427 """
1403- available = {None , "" , "ir" , "os_ort" }
1428+ available = {None , "" , "ir" , "os_ort" , "ir+default" }
14041429 assert (
14051430 optimization in available
14061431 ), f"unexpected value for optimization={ optimization } , available={ available } "
@@ -1490,11 +1515,31 @@ def call_torch_export_onnx(
14901515 print (epo )
14911516 print ("[call_torch_export_onnx] -- End of ONNXProgram" )
14921517
1493- if optimization in {"ir" , "os_ort" }:
1518+ if optimization in {"ir" , "os_ort" , "ir+default" }:
14941519 if verbose :
14951520 print (f"[call_torch_export_onnx] starts optimization={ optimization !r} ..." )
14961521 if optimization == "ir" :
14971522 label , f_optim = "export_onnx_opt_ir" , (lambda epo = epo : epo .optimize ())
1523+ elif optimization == "ir+default" :
1524+ import onnxscript
1525+ from experimental_experiment .xbuilder import GraphBuilder , OptimizationOptions
1526+
1527+ def _ir_default_opt (epo ):
1528+ onnxscript .optimizer .optimize_ir (epo .model )
1529+ onx = epo .model_proto
1530+ # not very efficient
1531+ gr = GraphBuilder (
1532+ onx ,
1533+ infer_shapes_options = True ,
1534+ optimization_options = OptimizationOptions (patterns = "default" ),
1535+ )
1536+ cont = gr .to_onnx (large_model = True )
1537+ epo .model = cont .to_ir ()
1538+
1539+ label , f_optim = "export_onnx_opt_ir_default" , (
1540+ lambda epo = epo : _ir_default_opt (epo )
1541+ )
1542+
14981543 else :
14991544 import onnxscript
15001545 import onnxscript .rewriter .ort_fusions as ort_fusions
@@ -1893,3 +1938,21 @@ def run_ort_fusion(
18931938 f"opt_ort_{ model_type } _duration" : duration ,
18941939 f"opt_ort_{ model_type } _duration_save" : d ,
18951940 }, {f"opt_ort_{ model_type } " : output_path }
1941+
1942+
1943+ def _compute_final_statistics (summary : Dict [str , Any ]):
1944+ """
1945+ Updates the summary dictionary in place. It adds:
1946+
1947+ - ``stat_estimated_speedup_ort``: eager latency divided by onnxruntime latency
1948+ """
1949+ stats = {}
1950+ if (
1951+ "time_run_latency" in summary
1952+ and "time_run_onnx_ort_latency" in summary
1953+ and summary ["time_run_onnx_ort_latency" ] > 0
1954+ ):
1955+ stats ["stat_estimated_speedup_ort" ] = (
1956+ summary ["time_run_latency" ] / summary ["time_run_onnx_ort_latency" ]
1957+ )
1958+ summary .update (stats )
0 commit comments