77import time
88import numpy as np
99import onnx
10- import onnxscript
11- import onnxscript .rewriter .ort_fusions as ort_fusions
1210import torch
1311from ..export import CoupleInputsDynamicShapes
1412from ..helpers import max_diff , string_type , string_diff
@@ -249,6 +247,7 @@ def _quiet_or_not_quiet(
249247 summary [f"time_{ suffix } _latency_std" ] = a .std ()
250248 summary [f"time_{ suffix } _latency_min" ] = a .min ()
251249 summary [f"time_{ suffix } _latency_max" ] = a .max ()
250+ summary [f"time_{ suffix } _n" ] = len (a )
252251 return res
253252
254253
@@ -337,7 +336,8 @@ def validate_model(
337336 :param subfolder: version or subfolders to uses when retrieving a model id
338337 :param opset: onnx opset to use for the conversion
339338 :param runtime: onnx runtime to use to check about discrepancies,
340- only if `do_run` is true
339+ possible values ``onnxruntime``, ``torch``, ``orteval``,
340+ ``orteval10``, ``ref``, only if `do_run` is true
341341 :param repeat: number of time to measure the model
342342 :param warmup: warmup the model first
343343 :param inputs2: checks that the second set of inputs is running as well,
@@ -364,7 +364,13 @@ def validate_model(
364364
365365 The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
366366 exported model returns the same outputs as the original one, otherwise,
367- :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
367+ :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
368+ if ``runtime == 'torch'`` or
369+ :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
370+ if ``runtime == 'orteval'`` or
371+ :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
372+ if ``runtime == 'ref'``,
373+ ``orteval10`` increases the verbosity.
368374 """
369375 if isinstance (patch , bool ):
370376 patch_kwargs = (
@@ -1155,7 +1161,7 @@ def validate_onnx_model(
11551161 :param quiet: catch exception or not
11561162 :param verbose: verbosity
11571163 :param flavour: use a different version of the inputs
1158- :param runtime: onnx runtime to use, onnxruntime or torch
1164+ :param runtime: onnx runtime to use, onnxruntime, torch, orteval, ref
11591165 :param repeat: run that number of times the model
11601166 :param warmup: warmup the model
11611167 :param inputs2: to validate the model on the second input set
@@ -1202,23 +1208,66 @@ def _mk(key, flavour=flavour):
12021208 f"{ providers } ..., flavour={ flavour !r} "
12031209 )
12041210
1205- if runtime != "onnxruntime" :
1211+ if runtime == "onnxruntime" :
1212+ if os .environ .get ("DUMPORTOPT" , "" ) in ("1" , "true" , "True" ):
1213+ opts = onnxruntime .SessionOptions ()
1214+ opts .optimized_model_filepath = f"{ data ['onnx_filename' ]} .rtopt.onnx"
1215+ if verbose :
1216+ print (
1217+ f"[validate_onnx_model] saved optimized onnxruntime "
1218+ f"in { opts .optimized_model_filepath !r} "
1219+ )
1220+ onnxruntime .InferenceSession (data ["onnx_filename" ], opts , providers = providers )
1221+ if verbose :
1222+ print ("[validate_onnx_model] -- done" )
1223+
1224+ if verbose :
1225+ print ("[validate_onnx_model] runtime is onnxruntime" )
1226+ cls_runtime = lambda model , providers : onnxruntime .InferenceSession (
1227+ (model .SerializeToString () if isinstance (model , onnx .ModelProto ) else model ),
1228+ providers = providers ,
1229+ )
1230+ elif runtime == "torch" :
12061231 from ..reference import TorchOnnxEvaluator
12071232
1208- cls_runtime = (
1209- (
1210- lambda model , providers : onnxruntime . InferenceSession (
1211- ( model . SerializeToString () if isinstance ( model , onnx . ModelProto ) else model ),
1212- providers = providers ,
1233+ if verbose :
1234+ print ( "[validate_onnx_model] runtime is TorchOnnxEvaluator" )
1235+ cls_runtime = (
1236+ lambda model , providers , _cls_ = TorchOnnxEvaluator : _cls_ ( # type: ignore[misc]
1237+ model , providers = providers , verbose = max ( verbose - 1 , 0 )
12131238 )
12141239 )
1215- if runtime == "onnxruntime"
1216- else (
1217- lambda model , providers , _cls_ = TorchOnnxEvaluator : _cls_ ( # type: ignore[misc]
1240+ elif runtime == "orteval" :
1241+ from ..reference import OnnxruntimeEvaluator
1242+
1243+ if verbose :
1244+ print ("[validate_onnx_model] runtime is OnnxruntimeEvaluator" )
1245+ cls_runtime = (
1246+ lambda model , providers , _cls_ = OnnxruntimeEvaluator : _cls_ ( # type: ignore[misc]
12181247 model , providers = providers , verbose = max (verbose - 1 , 0 )
12191248 )
12201249 )
1221- )
1250+ elif runtime == "orteval10" :
1251+ from ..reference import OnnxruntimeEvaluator
1252+
1253+ if verbose :
1254+ print ("[validate_onnx_model] runtime is OnnxruntimeEvaluator(verbose=10)" )
1255+ cls_runtime = (
1256+ lambda model , providers , _cls_ = OnnxruntimeEvaluator : _cls_ ( # type: ignore[misc]
1257+ model , providers = providers , verbose = 10
1258+ )
1259+ )
1260+ elif runtime == "ref" :
1261+ from ..reference import ExtendedReferenceEvaluator
1262+
1263+ if verbose :
1264+ print ("[validate_onnx_model] runtime is ExtendedReferenceEvaluator" )
1265+ cls_runtime = lambda model , providers , _cls_ = ExtendedReferenceEvaluator : _cls_ ( # type: ignore[misc]
1266+ model , verbose = max (verbose - 1 , 0 )
1267+ )
1268+ else :
1269+ raise ValueError (f"Unexpected runtime={ runtime !r} " )
1270+
12221271 sess = _quiet_or_not_quiet (
12231272 quiet ,
12241273 _mk ("create_onnx_ort" ),
@@ -1399,6 +1448,8 @@ def call_torch_export_onnx(
13991448 if optimization == "ir" :
14001449 label , f_optim = "export_onnx_opt_ir" , (lambda epo = epo : epo .optimize ())
14011450 else :
1451+ import onnxscript
1452+ import onnxscript .rewriter .ort_fusions as ort_fusions
14021453
14031454 def _os_ort_optim (epo ):
14041455 onnxscript .optimizer .optimize_ir (epo .model )
@@ -1683,6 +1734,9 @@ def call_torch_export_custom(
16831734 print ("[call_torch_export_custom] done (export)" )
16841735
16851736 if os_ort :
1737+ import onnxscript
1738+ import onnxscript .rewriter .ort_fusions as ort_fusions
1739+
16861740 if verbose :
16871741 print ("[call_torch_export_custom] conversion to IR..." )
16881742 begin = time .perf_counter ()
0 commit comments