1111from ..export import CoupleInputsDynamicShapes
1212from ..helpers import max_diff , string_type , string_diff
1313from ..helpers .helper import flatten_object
14- from ..helpers .rt_helper import make_feeds
14+ from ..helpers .rt_helper import make_feeds , reorder_modelbuilder_cache_to_torch
1515from ..helpers .torch_helper import to_any , torch_deepcopy
1616from ..helpers .cache_helper import flatten_unflatten_for_dynamic_shapes
1717from ..tasks import random_input_kwargs
@@ -264,14 +264,18 @@ def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
264264 return new_cfg
265265
266266
267- def _preprocess_model_id (model_id , subfolder ):
267+ def _preprocess_model_id (
268+ model_id : str , subfolder : Optional [str ], same_as_pretrained : bool , use_pretrained : bool
269+ ) -> Tuple [str , Optional [str ], bool , bool ]:
268270 if subfolder or "//" not in model_id :
269- return model_id , subfolder
271+ return model_id , subfolder , same_as_pretrained , use_pretrained
270272 spl = model_id .split ("//" )
273+ if spl [- 1 ] == "pretrained" :
274+ return _preprocess_model_id ("//" .join (spl [:- 1 ]), "" , True , True )
271275 if spl [- 1 ] in {"transformer" , "vae" }:
272276 # known subfolder
273- return "//" .join (spl [:- 1 ]), spl [- 1 ]
274- return model_id , subfolder
277+ return "//" .join (spl [:- 1 ]), spl [- 1 ], same_as_pretrained , use_pretrained
278+ return model_id , subfolder , same_as_pretrained , use_pretrained
275279
276280
277281def validate_model (
@@ -384,7 +388,12 @@ def validate_model(
384388 if ``runtime == 'ref'``,
385389 ``orteval10`` increases the verbosity.
386390 """
387- model_id , subfolder = _preprocess_model_id (model_id , subfolder )
391+ model_id , subfolder , same_as_pretrained , use_pretrained = _preprocess_model_id (
392+ model_id ,
393+ subfolder ,
394+ same_as_pretrained = same_as_pretrained ,
395+ use_pretrained = use_pretrained ,
396+ )
388397 if isinstance (patch , bool ):
389398 patch_kwargs = (
390399 dict (patch_transformers = True , patch_diffusers = True , patch = True )
@@ -812,6 +821,8 @@ def validate_model(
812821 )
813822 summary .update (summary_valid )
814823
824+ _compute_final_statistics (summary )
825+
815826 if verbose :
816827 print ("[validate_model] -- done (final)" )
817828 if dump_stats :
@@ -824,15 +835,24 @@ def validate_model(
824835def compute_statistics (onnx_filename : str ) -> Dict [str , Union [float , int ]]:
825836 """Computes some statistics on the model itself."""
826837 onx = onnx .load (onnx_filename , load_external_data = False )
838+ cache_functions = {(f .domain , f .name ): f for f in onx .functions }
839+ local_domains = set (f .domain for f in onx .functions )
827840
828841 def node_iter (proto ):
829842 if isinstance (proto , onnx .ModelProto ):
830- yield from node_iter (proto .graph )
831843 for f in proto .functions :
832844 yield from node_iter (f )
845+ yield from node_iter (proto .graph )
833846 elif isinstance (proto , (onnx .FunctionProto , onnx .GraphProto )):
834847 for node in proto .node :
835848 yield node
849+
850+ # Let's inline the function
851+ key = node .domain , node .op_type
852+ if key in cache_functions :
853+ yield from node_iter (cache_functions [key ])
854+
855+ # Let's continue
836856 for att in node .attribute :
837857 if att .type == onnx .AttributeProto .GRAPH :
838858 yield from node_iter (att .g )
@@ -850,6 +870,11 @@ def node_iter(proto):
850870 n_nodes += 1
851871 if proto .op_type != "Constant" :
852872 n_nodes_nocst += 1
873+ if proto .domain in local_domains :
874+ key = "n_node_local_function"
875+ if key not in counts :
876+ counts [key ] = 0
877+ counts [key ] += 1
853878 else :
854879 key = f"n_node_initializer_{ proto .data_type } "
855880
@@ -1298,7 +1323,13 @@ def _mk(key, flavour=flavour):
12981323 print (
12991324 f"[validate_onnx_model] inputs={ string_type (data [k_input ], with_shape = True )} "
13001325 )
1301- feeds = make_feeds (sess , data [k_input ], use_numpy = True , check_flatten = False )
1326+ feeds = make_feeds (
1327+ sess ,
1328+ data [k_input ],
1329+ use_numpy = True ,
1330+ check_flatten = False ,
1331+ is_modelbuilder = data ["exporter" ] == "modelbuilder" ,
1332+ )
13021333 if verbose :
13031334 print (f"[validate_onnx_model] ort inputs={ string_type (feeds , with_shape = True )} " )
13041335 summary [_mk (f"onnx_ort_inputs{ suffix } " )] = string_type (feeds , with_shape = True )
@@ -1318,6 +1349,13 @@ def _mk(key, flavour=flavour):
13181349 repeat = repeat ,
13191350 warmup = warmup ,
13201351 )
1352+ # NOTE: modelbuilder has different order on past_kv outputs
1353+ if data ["exporter" ] == "modelbuilder" :
1354+ logits = got [:1 ]
1355+ past_key_values = got [1 :]
1356+ reorder_past_key_values = reorder_modelbuilder_cache_to_torch (past_key_values )
1357+ got = logits + reorder_past_key_values
1358+
13211359 if f"ERR_{ _mk (f'time_onnx_ort_run{ suffix } ' )} " in summary :
13221360 return summary , data
13231361
@@ -1358,7 +1396,7 @@ def call_torch_export_onnx(
13581396 :return: two dictionaries, one with some metrics,
13591397 another one with whatever the function produces
13601398 """
1361- available = {None , "" , "ir" , "os_ort" }
1399+ available = {None , "" , "ir" , "os_ort" , "ir+default" }
13621400 assert (
13631401 optimization in available
13641402 ), f"unexpected value for optimization={ optimization } , available={ available } "
@@ -1448,11 +1486,31 @@ def call_torch_export_onnx(
14481486 print (epo )
14491487 print ("[call_torch_export_onnx] -- End of ONNXProgram" )
14501488
1451- if optimization in {"ir" , "os_ort" }:
1489+ if optimization in {"ir" , "os_ort" , "ir+default" }:
14521490 if verbose :
14531491 print (f"[call_torch_export_onnx] starts optimization={ optimization !r} ..." )
14541492 if optimization == "ir" :
14551493 label , f_optim = "export_onnx_opt_ir" , (lambda epo = epo : epo .optimize ())
1494+ elif optimization == "ir+default" :
1495+ import onnxscript
1496+ from experimental_experiment .xbuilder import GraphBuilder , OptimizationOptions
1497+
1498+ def _ir_default_opt (epo ):
1499+ onnxscript .optimizer .optimize_ir (epo .model )
1500+ onx = epo .model_proto
1501+ # not very efficient
1502+ gr = GraphBuilder (
1503+ onx ,
1504+ infer_shapes_options = True ,
1505+ optimization_options = OptimizationOptions (patterns = "default" ),
1506+ )
1507+ cont = gr .to_onnx (large_model = True )
1508+ epo .model = cont .to_ir ()
1509+
1510+ label , f_optim = "export_onnx_opt_ir_default" , (
1511+ lambda epo = epo : _ir_default_opt (epo )
1512+ )
1513+
14561514 else :
14571515 import onnxscript
14581516 import onnxscript .rewriter .ort_fusions as ort_fusions
@@ -1851,3 +1909,21 @@ def run_ort_fusion(
18511909 f"opt_ort_{ model_type } _duration" : duration ,
18521910 f"opt_ort_{ model_type } _duration_save" : d ,
18531911 }, {f"opt_ort_{ model_type } " : output_path }
1912+
1913+
1914+ def _compute_final_statistics (summary : Dict [str , Any ]):
1915+ """
1916+ Updates inline the list of statistics. It adds:
1917+
1918+ - speedup
1919+ """
1920+ stats = {}
1921+ if (
1922+ "time_run_latency" in summary
1923+ and "time_run_onnx_ort_latency" in summary
1924+ and summary ["time_run_onnx_ort_latency" ] > 0
1925+ ):
1926+ stats ["stat_estimated_speedup_ort" ] = (
1927+ summary ["time_run_latency" ] / summary ["time_run_onnx_ort_latency" ]
1928+ )
1929+ summary .update (stats )
0 commit comments