77import time
88import numpy as np
99import onnx
10- import onnxscript
11- import onnxscript .rewriter .ort_fusions as ort_fusions
1210import torch
1311from ..export import CoupleInputsDynamicShapes
1412from ..helpers import max_diff , string_type , string_diff
@@ -249,6 +247,7 @@ def _quiet_or_not_quiet(
249247 summary [f"time_{ suffix } _latency_std" ] = a .std ()
250248 summary [f"time_{ suffix } _latency_min" ] = a .min ()
251249 summary [f"time_{ suffix } _latency_max" ] = a .max ()
250+ summary [f"time_{ suffix } _n" ] = len (a )
252251 return res
253252
254253
@@ -337,7 +336,8 @@ def validate_model(
337336 :param subfolder: version or subfolders to uses when retrieving a model id
338337 :param opset: onnx opset to use for the conversion
339338 :param runtime: onnx runtime to use to check about discrepancies,
340- only if `do_run` is true
339+ possible values ``onnxruntime``, ``torch``, ``orteval``,
340+ ``orteval10``, ``ref`` only if `do_run` is true
341341 :param repeat: number of time to measure the model
342342 :param warmup: warmup the model first
343343 :param inputs2: checks that the second set of inputs is running as well,
@@ -364,7 +364,13 @@ def validate_model(
364364
365365 The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
366366 exported model returns the same outputs as the original one, otherwise,
367- :class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
367+ :class:`onnx_diagnostic.reference.TorchOnnxEvaluator`
368+ if ``runtime == 'torch'`` or
369+ :class:`onnx_diagnostic.reference.OnnxruntimeEvaluator`
370+ if ``runtime == 'orteval'`` or
371+ :class:`onnx_diagnostic.reference.ExtendedReferenceEvaluator`
372+ if ``runtime == 'ref'``,
373+ ``orteval10`` increases the verbosity.
368374 """
369375 if isinstance (patch , bool ):
370376 patch_kwargs = (
@@ -846,15 +852,24 @@ def node_iter(proto):
846852 raise NotImplementedError (f"Unexpected type={ type (proto )} " )
847853
848854 counts : Dict [str , Union [float , int ]] = {}
855+ n_nodes = 0
856+ n_nodes_nocst = 0
849857 for proto in node_iter (onx ):
850858 if isinstance (proto , onnx .NodeProto ):
851859 key = f"n_node_{ proto .op_type } "
860+ n_nodes += 1
861+ if proto .op_type != "Constant" :
862+ n_nodes_nocst += 1
852863 else :
853864 key = f"n_node_initializer_{ proto .data_type } "
854865
855866 if key not in counts :
856867 counts [key ] = 0
857868 counts [key ] += 1
869+
870+ counts ["n_node_nodes" ] = n_nodes
871+ counts ["n_node_nodes_nocst" ] = n_nodes_nocst
872+ counts ["n_node_functions" ] = len (onx .functions )
858873 return counts
859874
860875
@@ -1155,7 +1170,7 @@ def validate_onnx_model(
11551170 :param quiet: catch exception or not
11561171 :param verbose: verbosity
11571172 :param flavour: use a different version of the inputs
1158- :param runtime: onnx runtime to use, onnxruntime or torch
1173+ :param runtime: onnx runtime to use, onnxruntime, torch, orteval, ref
11591174 :param repeat: run that number of times the model
11601175 :param warmup: warmup the model
11611176 :param inputs2: to validate the model on the second input set
@@ -1202,23 +1217,66 @@ def _mk(key, flavour=flavour):
12021217 f"{ providers } ..., flavour={ flavour !r} "
12031218 )
12041219
1205- if runtime != "onnxruntime" :
1220+ if runtime == "onnxruntime" :
1221+ if os .environ .get ("DUMPORTOPT" , "" ) in ("1" , "true" , "True" ):
1222+ opts = onnxruntime .SessionOptions ()
1223+ opts .optimized_model_filepath = f"{ data ['onnx_filename' ]} .rtopt.onnx"
1224+ if verbose :
1225+ print (
1226+ f"[validate_onnx_model] saved optimized onnxruntime "
1227+ f"in { opts .optimized_model_filepath !r} "
1228+ )
1229+ onnxruntime .InferenceSession (data ["onnx_filename" ], opts , providers = providers )
1230+ if verbose :
1231+ print ("[validate_onnx_model] -- done" )
1232+
1233+ if verbose :
1234+ print ("[validate_onnx_model] runtime is onnxruntime" )
1235+ cls_runtime = lambda model , providers : onnxruntime .InferenceSession (
1236+ (model .SerializeToString () if isinstance (model , onnx .ModelProto ) else model ),
1237+ providers = providers ,
1238+ )
1239+ elif runtime == "torch" :
12061240 from ..reference import TorchOnnxEvaluator
12071241
1208- cls_runtime = (
1209- (
1210- lambda model , providers : onnxruntime . InferenceSession (
1211- ( model . SerializeToString () if isinstance ( model , onnx . ModelProto ) else model ),
1212- providers = providers ,
1242+ if verbose :
1243+ print ( "[validate_onnx_model] runtime is TorchOnnxEvaluator" )
1244+ cls_runtime = (
1245+ lambda model , providers , _cls_ = TorchOnnxEvaluator : _cls_ ( # type: ignore[misc]
1246+ model , providers = providers , verbose = max ( verbose - 1 , 0 )
12131247 )
12141248 )
1215- if runtime == "onnxruntime"
1216- else (
1217- lambda model , providers , _cls_ = TorchOnnxEvaluator : _cls_ ( # type: ignore[misc]
1249+ elif runtime == "orteval" :
1250+ from ..reference import OnnxruntimeEvaluator
1251+
1252+ if verbose :
1253+ print ("[validate_onnx_model] runtime is OnnxruntimeEvaluator" )
1254+ cls_runtime = (
1255+ lambda model , providers , _cls_ = OnnxruntimeEvaluator : _cls_ ( # type: ignore[misc]
12181256 model , providers = providers , verbose = max (verbose - 1 , 0 )
12191257 )
12201258 )
1221- )
1259+ elif runtime == "orteval10" :
1260+ from ..reference import OnnxruntimeEvaluator
1261+
1262+ if verbose :
1263+ print ("[validate_onnx_model] runtime is OnnxruntimeEvaluator(verbose=10)" )
1264+ cls_runtime = (
1265+ lambda model , providers , _cls_ = OnnxruntimeEvaluator : _cls_ ( # type: ignore[misc]
1266+ model , providers = providers , verbose = 10
1267+ )
1268+ )
1269+ elif runtime == "ref" :
1270+ from ..reference import ExtendedReferenceEvaluator
1271+
1272+ if verbose :
1273+ print ("[validate_onnx_model] runtime is ExtendedReferenceEvaluator" )
1274+ cls_runtime = lambda model , providers , _cls_ = ExtendedReferenceEvaluator : _cls_ ( # type: ignore[misc]
1275+ model , verbose = max (verbose - 1 , 0 )
1276+ )
1277+ else :
1278+ raise ValueError (f"Unexpected runtime={ runtime !r} " )
1279+
12221280 sess = _quiet_or_not_quiet (
12231281 quiet ,
12241282 _mk ("create_onnx_ort" ),
@@ -1399,6 +1457,8 @@ def call_torch_export_onnx(
13991457 if optimization == "ir" :
14001458 label , f_optim = "export_onnx_opt_ir" , (lambda epo = epo : epo .optimize ())
14011459 else :
1460+ import onnxscript
1461+ import onnxscript .rewriter .ort_fusions as ort_fusions
14021462
14031463 def _os_ort_optim (epo ):
14041464 onnxscript .optimizer .optimize_ir (epo .model )
@@ -1683,6 +1743,9 @@ def call_torch_export_custom(
16831743 print ("[call_torch_export_custom] done (export)" )
16841744
16851745 if os_ort :
1746+ import onnxscript
1747+ import onnxscript .rewriter .ort_fusions as ort_fusions
1748+
16861749 if verbose :
16871750 print ("[call_torch_export_custom] conversion to IR..." )
16881751 begin = time .perf_counter ()
0 commit comments