44import os
55import pprint
66import sys
7- from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union
7+ from typing import Any , Callable , Dict , List , Optional , Sequence , Set , Tuple , Union
88import time
99import numpy as np
1010import onnx
@@ -319,6 +319,7 @@ def validate_model(
319319 inputs2 : int = 1 ,
320320 output_names : Optional [List [str ]] = None ,
321321 ort_logs : bool = False ,
322+ quiet_input_sets : Optional [Set [str ]] = None ,
322323) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
323324 """
324325 Validates a model.
@@ -373,6 +374,8 @@ def validate_model(
373374 or an empty cache for example
374375 :param output_names: output names the onnx exporter should use
375376 :param ort_logs: increases onnxruntime verbosity when creating the session
377+ :param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
378+ even if quiet is False
376379 :return: two dictionaries, one with some metrics,
377380 another one with whatever the function produces
378381
@@ -842,6 +845,7 @@ def validate_model(
842845 warmup = warmup ,
843846 second_input_keys = second_input_keys ,
844847 ort_logs = ort_logs ,
848+ quiet_input_sets = quiet_input_sets ,
845849 )
846850 summary .update (summary_valid )
847851 summary ["time_total_validation_onnx" ] = time .perf_counter () - validation_begin
@@ -904,6 +908,7 @@ def validate_model(
904908 repeat = repeat ,
905909 warmup = warmup ,
906910 second_input_keys = second_input_keys ,
911+ quiet_input_sets = quiet_input_sets ,
907912 )
908913 summary .update (summary_valid )
909914
@@ -1289,6 +1294,7 @@ def validate_onnx_model(
12891294 warmup : int = 0 ,
12901295 second_input_keys : Optional [List [str ]] = None ,
12911296 ort_logs : bool = False ,
1297+ quiet_input_sets : Optional [Set [str ]] = None ,
12921298) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
12931299 """
12941300 Verifies that an onnx model produces the same
@@ -1308,6 +1314,7 @@ def validate_onnx_model(
13081314 to make sure the exported model supports dynamism, the value is
13091315 used as an increment added to the first set of inputs (added to dimensions)
13101316 :param ort_logs: triggers the logs for onnxruntime
1317+ :param quiet_input_sets: avoid raising an exception for these sets of inputs
13111318 :return: two dictionaries, one with some metrics,
13121319 another one with whatever the function produces
13131320 """
@@ -1455,10 +1462,12 @@ def _mk(key, flavour=flavour):
14551462
14561463 # run ort
14571464 if verbose :
1458- print ("[validate_onnx_model] run session..." )
1465+ print (f"[validate_onnx_model] run session on inputs 'inputs{ suffix } '..." )
1466+ if quiet_input_sets :
1467+ print (f"[validate_onnx_model] quiet_input_sets={ quiet_input_sets } " )
14591468
14601469 got = _quiet_or_not_quiet (
1461- quiet ,
1470+ quiet or ( quiet_input_sets is not None and f"inputs { suffix } " in quiet_input_sets ) ,
14621471 _mk (f"run_onnx_ort{ suffix } " ),
14631472 summary ,
14641473 data ,
0 commit comments