@@ -197,6 +197,7 @@ def validate_model(
197197 stop_if_static : int = 1 ,
198198 dump_folder : Optional [str ] = None ,
199199 drop_inputs : Optional [List [str ]] = None ,
200+ ortfusiontype : Optional [str ] = None ,
200201) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
201202 """
202203 Validates a model.
@@ -222,11 +223,33 @@ def validate_model(
222223 see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
223224 :param dump_folder: dumps everything in a subfolder of this one
224225 :param drop_inputs: drops this list of inputs (given their names)
226+ :param ortfusiontype: runs ort fusion, the parameter defines the fusion type,
227+ it accepts multiple values separated by ``|``,
228+ see :func:`onnx_diagnostic.torch_models.test_helper.run_ort_fusion`
225229 :return: two dictionaries, one with some metrics,
226230 another one with whatever the function produces
227231 """
228232 assert not trained , f"trained={ trained } not supported yet"
229233 summary = version_summary ()
234+
235+ summary .update (
236+ dict (
237+ version_model_id = model_id ,
238+ version_do_run = str (do_run ),
239+ version_dtype = str (dtype or "" ),
240+ version_device = str (device or "" ),
241+ version_trained = str (trained ),
242+ version_optimization = optimization or "" ,
243+ version_quiet = str (quiet ),
244+ version_patch = str (patch ),
245+ version_dump_folder = dump_folder or "" ,
246+ version_drop_inputs = str (list (drop_inputs or "" )),
247+ version_ortfusiontype = ortfusiontype or "" ,
248+ version_stop_if_static = str (stop_if_static ),
249+ version_exporter = exporter ,
250+ )
251+ )
252+
230253 folder_name = None
231254 if dump_folder :
232255 folder_name = _make_folder_name (
@@ -456,15 +479,66 @@ def validate_model(
456479 if verbose :
457480 print ("[validate_model] done (dump)" )
458481
459- if exporter and exporter .startswith (("onnx-" , "custom-" )) and do_run :
460- summary_valid , data = validate_onnx_model (
461- data = data ,
462- quiet = quiet ,
463- verbose = verbose ,
464- optimization = optimization ,
465- )
482+ if not exporter or not exporter .startswith (("onnx-" , "custom-" )):
483+ if verbose :
484+ print ("[validate_model] -- done (final)" )
485+ if dump_stats :
486+ with open (dump_stats , "w" ) as f :
487+ for k , v in sorted (summary .items ()):
488+ f .write (f":{ k } :{ v } ;\n " )
489+ return summary , data
490+
491+ if do_run :
492+ summary_valid , data = validate_onnx_model (data = data , quiet = quiet , verbose = verbose )
466493 summary .update (summary_valid )
467494
495+ if ortfusiontype and "onnx_filename" in data :
496+ assert (
497+ "configuration" in data
498+ ), f"missing configuration in data, cannot run ort fusion for model_id={ model_id } "
499+ config = data ["configuration" ]
500+ assert hasattr (
501+ config , "hidden_size"
502+ ), f"Missing attribute hidden_size in configuration { config } "
503+ hidden_size = config .hidden_size
504+ assert hasattr (
505+ config , "num_attention_heads"
506+ ), f"Missing attribute num_attention_heads in configuration { config } "
507+ num_attention_heads = config .num_attention_heads
508+
509+ model_types = ortfusiontype .split ("|" )
510+ for model_type in model_types :
511+ flavour = f"ort{ model_type } "
512+ summary [f"version_{ flavour } _hidden_size" ] = hidden_size
513+ summary [f"version_{ flavour } _num_attention_heads" ] = num_attention_heads
514+
515+ begin = time .perf_counter ()
516+ if verbose :
517+ print (f"[validate_model] run onnxruntime fusion for { model_type !r} " )
518+ input_filename = data ["onnx_filename" ]
519+ output_path = f"{ os .path .splitext (input_filename )[0 ]} .ort.{ model_type } .onnx"
520+ run_ort_fusion (
521+ input_filename ,
522+ output_path ,
523+ model_type = model_type ,
524+ num_attention_heads = num_attention_heads ,
525+ hidden_size = hidden_size ,
526+ )
527+ data [f"onnx_filename_{ flavour } " ] = output_path
528+ duration = time .perf_counter () - begin
529+ summary [f"time_ortfusion_{ flavour } " ] = duration
530+ if verbose :
531+ print (
532+ f"[validate_model] done { model_type !r} in { duration } , "
533+ f"saved into { output_path !r} "
534+ )
535+
536+ if do_run :
537+ summary_valid , data = validate_onnx_model (
538+ data = data , quiet = quiet , verbose = verbose , flavour = flavour
539+ )
540+ summary .update (summary_valid )
541+
468542 if verbose :
469543 print ("[validate_model] -- done (final)" )
470544 if dump_stats :
@@ -651,22 +725,27 @@ def validate_onnx_model(
651725 data : Dict [str , Any ],
652726 quiet : bool = False ,
653727 verbose : int = 0 ,
654- optimization : Optional [str ] = None ,
728+ flavour : Optional [str ] = None ,
655729) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
656730 """
657731 Verifies that an onnx model produces the same
658- expected outputs.
732+ expected outputs. It uses ``data["onnx_filename"]`` as the input
733+ onnx filename or ``data["onnx_filename_{flavour}"]`` if *flavour*
734+ is specified.
659735
660736 :param data: dictionary with all the necessary inputs, the dictionary must
661737 contain keys ``model`` and ``inputs_export``
662738 :param quiet: catch exception or not
663739 :param verbose: verbosity
664- :param optimization: optimization to do
740+ :param flavour: suffix selecting the onnx file to check, ``data["onnx_filename_{flavour}"]``
665741 :return: two dictionaries, one with some metrics,
666742 another one with whatever the function produces
667743 """
668744 import onnxruntime
669745
746+ def _mk (key ):
747+ return f"{ key } _{ flavour } " if flavour else key
748+
670749 summary = {}
671750 flat_inputs = flatten_object (data ["inputs" ], drop_keys = True )
672751 d = flat_inputs [0 ].get_device ()
@@ -675,36 +754,42 @@ def validate_onnx_model(
675754 if d < 0
676755 else ["CUDAExecutionProvider" , "CPUExecutionProvider" ]
677756 )
678- if "onnx_filename" in data :
679- source = data ["onnx_filename" ]
680- summary ["onnx_filename" ] = source
681- summary ["onnx_size" ] = os .stat (source ).st_size
757+ input_data_key = f"onnx_filename_{ flavour } " if flavour else "onnx_filename"
758+
759+ if input_data_key in data :
760+ source = data [input_data_key ]
761+ summary [input_data_key ] = source
762+ summary [_mk ("onnx_size" )] = os .stat (source ).st_size
682763 else :
764+ assert not flavour , f"flavour={ flavour !r} , the filename must be saved."
683765 assert (
684766 "onnx_program" in data
685767 ), f"onnx_program is missing from data which has { sorted (data )} "
686768 source = data ["onnx_program" ].model_proto .SerializeToString ()
687769 assert len (source ) < 2 ** 31 , f"The model is highger than 2Gb: { len (source ) / 2 ** 30 } Gb"
688- summary ["onnx_size" ] = len (source )
770+ summary [_mk ( "onnx_size" ) ] = len (source )
689771 if verbose :
690- print (f"[validate_onnx_model] verify onnx model with providers { providers } ..." )
772+ print (
773+ f"[validate_onnx_model] verify onnx model with providers "
774+ f"{ providers } ..., flavour={ flavour !r} "
775+ )
691776
692777 begin = time .perf_counter ()
693778 if quiet :
694779 try :
695780 sess = onnxruntime .InferenceSession (source , providers = providers )
696781 except Exception as e :
697- summary ["ERR_onnx_ort_create" ] = str (e )
698- data ["ERR_onnx_ort_create" ] = e
699- summary ["time_onnx_ort_create" ] = time .perf_counter () - begin
782+ summary [_mk ( "ERR_onnx_ort_create" ) ] = str (e )
783+ data [_mk ( "ERR_onnx_ort_create" ) ] = e
784+ summary [_mk ( "time_onnx_ort_create" ) ] = time .perf_counter () - begin
700785 return summary , data
701786 else :
702787 sess = onnxruntime .InferenceSession (source , providers = providers )
703788
704- summary ["time_onnx_ort_create" ] = time .perf_counter () - begin
705- data ["onnx_ort_sess" ] = sess
789+ summary [_mk ( "time_onnx_ort_create" ) ] = time .perf_counter () - begin
790+ data [_mk ( "onnx_ort_sess" ) ] = sess
706791 if verbose :
707- print ("[validate_onnx_model] done (ort_session)" )
792+ print (f "[validate_onnx_model] done (ort_session) flavour= { flavour !r } " )
708793
709794 # make_feeds
710795 if verbose :
@@ -718,7 +803,7 @@ def validate_onnx_model(
718803 )
719804 if verbose :
720805 print (f"[validate_onnx_model] ort inputs={ string_type (feeds , with_shape = True )} " )
721- summary ["onnx_ort_inputs" ] = string_type (feeds , with_shape = True )
806+ summary [_mk ( "onnx_ort_inputs" ) ] = string_type (feeds , with_shape = True )
722807 if verbose :
723808 print ("[validate_onnx_model] done (make_feeds)" )
724809
@@ -730,9 +815,9 @@ def validate_onnx_model(
730815 try :
731816 got = sess .run (None , feeds )
732817 except Exception as e :
733- summary ["ERR_onnx_ort_run" ] = str (e )
734- data ["ERR_onnx_ort_run" ] = e
735- summary ["time_onnx_ort_run" ] = time .perf_counter () - begin
818+ summary [_mk ( "ERR_onnx_ort_run" ) ] = str (e )
819+ data [_mk ( "ERR_onnx_ort_run" ) ] = e
820+ summary [_mk ( "time_onnx_ort_run" ) ] = time .perf_counter () - begin
736821 return summary , data
737822 else :
738823 got = sess .run (None , feeds )
@@ -745,7 +830,7 @@ def validate_onnx_model(
745830 if verbose :
746831 print (f"[validate_onnx_model] discrepancies={ string_diff (disc )} " )
747832 for k , v in disc .items ():
748- summary [f"disc_onnx_ort_run_{ k } " ] = v
833+ summary [_mk ( f"disc_onnx_ort_run_{ k } " ) ] = v
749834 return summary , data
750835
751836
0 commit comments