1717from ..tasks import random_input_kwargs
1818from ..torch_export_patches import torch_export_patches
1919from ..torch_export_patches .patch_inputs import use_dyn_not_str
20+ from ..reference import TorchOnnxEvaluator
2021from .hghub import get_untrained_model_with_inputs
2122
2223
@@ -244,6 +245,7 @@ def validate_model(
244245 model_options : Optional [Dict [str , Any ]] = None ,
245246 subfolder : Optional [str ] = None ,
246247 opset : Optional [int ] = None ,
248+ runtime : str = "onnxruntime" ,
247249) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
248250 """
249251 Validates a model.
@@ -280,6 +282,8 @@ def validate_model(
280282 ``num_hidden_layers`` or ``attn_implementation``
281283 :param subfolder: version or subfolders to uses when retrieving a model id
282284 :param opset: onnx opset to use for the conversion
285+ :param runtime: onnx runtime to use to check about discrepancies,
286+ only if `do_run` is true
283287 :return: two dictionaries, one with some metrics,
284288 another one with whatever the function produces
285289
@@ -308,6 +312,7 @@ def validate_model(
308312 version_ortfusiontype = ortfusiontype or "" ,
309313 version_stop_if_static = str (stop_if_static ),
310314 version_exporter = exporter or "" ,
315+ version_runtime = runtime ,
311316 )
312317 )
313318 if opset :
@@ -633,7 +638,9 @@ def validate_model(
633638 return summary , data
634639
635640 if do_run :
636- summary_valid , data = validate_onnx_model (data = data , quiet = quiet , verbose = verbose )
641+ summary_valid , data = validate_onnx_model (
642+ data = data , quiet = quiet , verbose = verbose , runtime = runtime
643+ )
637644 summary .update (summary_valid )
638645
639646 if ortfusiontype and "onnx_filename" in data :
@@ -686,7 +693,7 @@ def validate_model(
686693
687694 if do_run :
688695 summary_valid , data = validate_onnx_model (
689- data = data , quiet = quiet , verbose = verbose , flavour = flavour
696+ data = data , quiet = quiet , verbose = verbose , flavour = flavour , runtime = runtime
690697 )
691698 summary .update (summary_valid )
692699
@@ -898,6 +905,7 @@ def validate_onnx_model(
898905 quiet : bool = False ,
899906 verbose : int = 0 ,
900907 flavour : Optional [str ] = None ,
908+ runtime : str = "onnxruntime" ,
901909) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
902910 """
903911 Verifies that an onnx model produces the same
@@ -910,6 +918,7 @@ def validate_onnx_model(
910918 :param quiet: catch exception or not
911919 :param verbose: verbosity
912920 :param flavour: use a different version of the inputs
921+ :param runtime: onnx runtime to use, onnxruntime or torch
913922 :return: two dictionaries, one with some metrics,
914923 another one with whatever the function produces
915924 """
@@ -951,16 +960,23 @@ def _mk(key):
951960 f"{ providers } ..., flavour={ flavour !r} "
952961 )
953962
963+ assert runtime == "torch" , f"runtime={ runtime !r} "
964+ cls_runtime = (
965+ (
966+ lambda model , providers : onnxruntime .InferenceSession (
967+ (model .SerializeToString () if isinstance (model , onnx .ModelProto ) else model ),
968+ providers = providers ,
969+ )
970+ )
971+ if runtime == "onnxruntime"
972+ else (lambda model , providers : TorchOnnxEvaluator (model , providers = providers ))
973+ )
954974 sess = _quiet_or_not_quiet (
955975 quiet ,
956976 _mk ("time_onnx_ort_create" ),
957977 summary ,
958978 data ,
959- (
960- lambda source = source , providers = providers : onnxruntime .InferenceSession (
961- source , providers = providers
962- )
963- ),
979+ (lambda source = source , providers = providers : cls_runtime (source , providers )),
964980 )
965981 if f"ERR_{ _mk ('time_onnx_ort_create' )} " in summary :
966982 return summary , data
0 commit comments