@@ -420,6 +420,7 @@ def validate_model(
420420 )
421421 summary .update (summary_export )
422422
423+ dump_stats = None
423424 if dump_folder :
424425 if "exported_program" in data :
425426 ep = data ["exported_program" ]
@@ -435,22 +436,27 @@ def validate_model(
435436 epo = data ["onnx_program" ]
436437 if verbose :
437438 print (f"[validate_model] dumps onnx program in { dump_folder !r} ..." )
438- onnx_file_name = os .path .join (dump_folder , f"{ folder_name } .onnx" )
439+ onnx_filename = os .path .join (dump_folder , f"{ folder_name } .onnx" )
440+ begin = time .perf_counter ()
439441 if isinstance (epo , onnx .model_container .ModelContainer ):
440- epo .save (onnx_file_name , all_tensors_to_one_file = True )
442+ epo .save (onnx_filename , all_tensors_to_one_file = True )
441443 else :
442- epo .save (onnx_file_name , external_data = True )
444+ epo .save (onnx_filename , external_data = True )
445+ duration = time .perf_counter () - begin
443446 if verbose :
444- print ("[validate_model] done (dump onnx)" )
447+ print (f"[validate_model] done (dump onnx) in { duration } " )
448+ data ["onnx_filename" ] = onnx_filename
449+ summary ["time_onnx_save" ] = duration
445450 if verbose :
446451 print (f"[validate_model] dumps statistics in { dump_folder !r} ..." )
447- with open (os .path .join (dump_folder , f"{ folder_name } .stats" ), "w" ) as f :
452+ dump_stats = os .path .join (dump_folder , f"{ folder_name } .stats" )
453+ with open (dump_stats , "w" ) as f :
448454 for k , v in sorted (summary .items ()):
449455 f .write (f":{ k } :{ v } ;\n " )
450456 if verbose :
451457 print ("[validate_model] done (dump)" )
452458
453- if exporter and exporter .startswith ("onnx-" ) and do_run :
459+ if exporter and exporter .startswith (( "onnx-" , "custom-" ) ) and do_run :
454460 summary_valid , data = validate_onnx_model (
455461 data = data ,
456462 quiet = quiet ,
@@ -461,6 +467,10 @@ def validate_model(
461467
462468 if verbose :
463469 print ("[validate_model] -- done (final)" )
470+ if dump_stats :
471+ with open (dump_stats , "w" ) as f :
472+ for k , v in sorted (summary .items ()):
473+ f .write (f":{ k } :{ v } ;\n " )
464474 return summary , data
465475
466476
@@ -642,7 +652,7 @@ def validate_onnx_model(
642652 quiet : bool = False ,
643653 verbose : int = 0 ,
644654 optimization : Optional [str ] = None ,
645- ):
655+ ) -> Tuple [ Dict [ str , Any ], Dict [ str , Any ]] :
646656 """
647657 Verifies that an onnx model produces the same
648658 expected outputs.
@@ -665,10 +675,10 @@ def validate_onnx_model(
665675 if d < 0
666676 else ["CUDAExecutionProvider" , "CPUExecutionProvider" ]
667677 )
668- if "onnx_file_name " in data :
669- source = data ["onnx_file_name " ]
678+ if "onnx_filename " in data :
679+ source = data ["onnx_filename " ]
670680 summary ["onnx_filename" ] = source
671- summary ["onnx_size" ] = os .stats (source ).st_size
681+ summary ["onnx_size" ] = os .stat (source ).st_size
672682 else :
673683 assert (
674684 "onnx_program" in data
@@ -745,7 +755,7 @@ def call_torch_export_onnx(
745755 quiet : bool = False ,
746756 verbose : int = 0 ,
747757 optimization : Optional [str ] = None ,
748- ):
758+ ) -> Tuple [ Dict [ str , Any ], Dict [ str , Any ]] :
749759 """
750760 Exports a model into onnx.
751761 If a patch must be applied, it should be before this functions.
@@ -818,7 +828,7 @@ def call_torch_export_onnx(
818828 if verbose :
819829 print ("[call_torch_export_onnx] done (export)" )
820830 data ["onnx_program" ] = epo
821- if verbose > 1 :
831+ if verbose > 5 :
822832 print ("[call_torch_export_onnx] -- ONNXProgram" )
823833 print (epo )
824834 print ("[call_torch_export_onnx] -- End of ONNXProgram" )
@@ -850,7 +860,7 @@ def call_torch_export_custom(
850860 quiet : bool = False ,
851861 verbose : int = 0 ,
852862 optimization : Optional [str ] = None ,
853- ):
863+ ) -> Tuple [ Dict [ str , Any ], Dict [ str , Any ]] :
854864 """
855865 Exports a model into onnx.
856866 If a patch must be applied, it should be before this functions.
@@ -1011,3 +1021,75 @@ def call_torch_export_custom(
10111021 print ("[call_torch_export_custom] done (export)" )
10121022 data ["onnx_program" ] = epo
10131023 return summary , data
1024+
1025+
1026+ def run_ort_fusion (
1027+ model_or_path : Union [str , onnx .ModelProto ],
1028+ output_path : str ,
1029+ num_attention_heads : int ,
1030+ hidden_size : int ,
1031+ model_type : str = "bert" ,
1032+ verbose : int = 0 ,
1033+ ) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
1034+ """
1035+ Runs :epkg:`onnxruntime` fusion optimizer.
1036+
1037+ :param model_or_path: path to the ModelProto or the ModelProto itself
1038+ :param output_path: the model to save
1039+ :param num_attention_heads: number of heads, usually ``config.num_attention_heads``
1040+ :param hidden_size: hidden size, usually ``config.hidden_size``
1041+ :param model_type: type of optimization, see below
1042+ :param verbose: verbosity
1043+ :return: two dictionaries, summary and data
1044+
1045+ Supported values for ``model_type``:
1046+
1047+ .. runpython::
1048+ :showcode:
1049+
1050+ import pprint
1051+ from onnxruntime.transformers.optimizer import MODEL_TYPES
1052+
1053+ pprint.pprint(sorted(MODEL_TYPES))
1054+ """
1055+ from onnxruntime .transformers .optimizer import optimize_by_fusion
1056+ from onnxruntime .transformers .fusion_options import FusionOptions
1057+
1058+ opts = FusionOptions (model_type )
1059+
1060+ if isinstance (model_or_path , str ):
1061+ if verbose :
1062+ print (f"[run_ort_fusion] loads { model_or_path !r} " )
1063+ onx = onnx .load (model_or_path )
1064+ else :
1065+ onx = model_or_path
1066+ begin = time .perf_counter ()
1067+ n_nodes = len (onx .graph .node )
1068+ if verbose :
1069+ print (
1070+ f"[run_ort_fusion] starts optimization for "
1071+ f"model_type={ model_type !r} with { n_nodes } nodes"
1072+ )
1073+ new_onx = optimize_by_fusion (
1074+ onx ,
1075+ model_type = model_type ,
1076+ num_heads = num_attention_heads ,
1077+ hidden_size = hidden_size ,
1078+ optimization_options = opts ,
1079+ )
1080+ duration = {time .perf_counter () - begin }
1081+ delta = len (new_onx .model .graph .node )
1082+ if verbose :
1083+ print (f"[run_ort_fusion] done in { duration } with { delta } nodes" )
1084+ print (f"[run_ort_fusion] save to { output_path !r} " )
1085+ begin = time .perf_counter ()
1086+ new_onx .save_model_to_file (output_path , use_external_data_format = True )
1087+ d = time .perf_counter () - begin
1088+ if verbose :
1089+ print (f"[run_ort_fusion] done in { d } " )
1090+ return {
1091+ f"opt_ort_{ model_type } _n_nodes1" : n_nodes ,
1092+ f"opt_ort_{ model_type } _n_nodes2" : delta ,
1093+ f"opt_ort_{ model_type } _duration" : duration ,
1094+ f"opt_ort_{ model_type } _duration_save" : d ,
1095+ }, {f"opt_ort_{ model_type } " : output_path }
0 commit comments