44import sys
55from typing import Any , Callable , Dict , List , Optional , Tuple , Union
66import time
7+ import numpy as np
78import onnx
89import onnxscript
910import onnxscript .rewriter .ort_fusions as ort_fusions
@@ -193,20 +194,46 @@ def _quiet_or_not_quiet(
193194 summary : Dict [str , Any ],
194195 data : Optional [Dict [str , Any ]],
195196 fct : Callable ,
197+ repeat : int = 1 ,
198+ warmup : int = 0 ,
196199) -> Any :
197200 begin = time .perf_counter ()
198201 if quiet :
199202 try :
200- return fct ()
203+ res = fct ()
204+ summary [f"time_{ suffix } " ] = time .perf_counter () - begin
205+ if warmup + repeat == 1 :
206+ return res
201207 except Exception as e :
202208 summary [f"ERR_{ suffix } " ] = str (e )
203209 summary [f"time_{ suffix } " ] = time .perf_counter () - begin
204210 if data is None :
205211 return {f"ERR_{ suffix } " : e }
206212 data [f"ERR_{ suffix } " ] = e
207213 return None
208- res = fct ()
214+ else :
215+ res = fct ()
209216 summary [f"time_{ suffix } " ] = time .perf_counter () - begin
217+ if warmup + repeat > 1 :
218+ if suffix == "run" :
219+ res = torch_deepcopy (res )
220+ summary [f"{ suffix } _output" ] = string_type (res , with_shape = True , with_min_max = True )
221+ summary [f"{ suffix } _warmup" ] = warmup
222+ summary [f"{ suffix } _repeat" ] = repeat
223+ for _w in range (max (0 , warmup - 1 )):
224+ t = fct ()
225+ summary [f"io_{ suffix } _{ _w + 1 } " ] = string_type (t , with_shape = True , with_min_max = True )
226+ summary [f"time_{ suffix } _warmup" ] = time .perf_counter () - begin
227+ times = []
228+ for _r in range (repeat ):
229+ begin = time .perf_counter ()
230+ t = fct ()
231+ times .append (time .perf_counter () - begin )
232+ a = np .array (times )
233+ summary [f"time_{ suffix } _latency" ] = a .mean ()
234+ summary [f"time_{ suffix } _latency_std" ] = a .std ()
235+ summary [f"time_{ suffix } _latency_min" ] = a .min ()
236+ summary [f"time_{ suffix } _latency_min" ] = a .max ()
210237 return res
211238
212239
@@ -246,6 +273,8 @@ def validate_model(
246273 subfolder : Optional [str ] = None ,
247274 opset : Optional [int ] = None ,
248275 runtime : str = "onnxruntime" ,
276+ repeat : int = 1 ,
277+ warmup : int = 0 ,
249278) -> Tuple [Dict [str , Union [int , float , str ]], Dict [str , Any ]]:
250279 """
251280 Validates a model.
@@ -284,6 +313,8 @@ def validate_model(
284313 :param opset: onnx opset to use for the conversion
285314 :param runtime: onnx runtime to use to check about discrepancies,
286315 only if `do_run` is true
316+ :param repeat: number of time to measure the model
317+ :param warmup: warmup the model first
287318 :return: two dictionaries, one with some metrics,
288319 another one with whatever the function produces
289320
@@ -480,7 +511,13 @@ def validate_model(
480511 model = data ["model" ]
481512
482513 expected = _quiet_or_not_quiet (
483- quiet , "run" , summary , data , (lambda m = model , inp = inputs : m (** inp ))
514+ quiet ,
515+ "run" ,
516+ summary ,
517+ data ,
518+ (lambda m = model , inp = inputs : m (** inp )),
519+ repeat = repeat ,
520+ warmup = warmup ,
484521 )
485522 if "ERR_run" in summary :
486523 return summary , data
@@ -639,7 +676,12 @@ def validate_model(
639676
640677 if do_run :
641678 summary_valid , data = validate_onnx_model (
642- data = data , quiet = quiet , verbose = verbose , runtime = runtime
679+ data = data ,
680+ quiet = quiet ,
681+ verbose = verbose ,
682+ runtime = runtime ,
683+ repeat = repeat ,
684+ warmup = warmup ,
643685 )
644686 summary .update (summary_valid )
645687
@@ -693,7 +735,13 @@ def validate_model(
693735
694736 if do_run :
695737 summary_valid , data = validate_onnx_model (
696- data = data , quiet = quiet , verbose = verbose , flavour = flavour , runtime = runtime
738+ data = data ,
739+ quiet = quiet ,
740+ verbose = verbose ,
741+ flavour = flavour ,
742+ runtime = runtime ,
743+ repeat = repeat ,
744+ warmup = warmup ,
697745 )
698746 summary .update (summary_valid )
699747
@@ -906,6 +954,8 @@ def validate_onnx_model(
906954 verbose : int = 0 ,
907955 flavour : Optional [str ] = None ,
908956 runtime : str = "onnxruntime" ,
957+ repeat : int = 1 ,
958+ warmup : int = 0 ,
909959) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
910960 """
911961 Verifies that an onnx model produces the same
@@ -919,6 +969,8 @@ def validate_onnx_model(
919969 :param verbose: verbosity
920970 :param flavour: use a different version of the inputs
921971 :param runtime: onnx runtime to use, onnxruntime or torch
972+ :param repeat: run that number of times the model
973+ :param warmup: warmup the model
922974 :return: two dictionaries, one with some metrics,
923975 another one with whatever the function produces
924976 """
@@ -976,12 +1028,12 @@ def _mk(key):
9761028 )
9771029 sess = _quiet_or_not_quiet (
9781030 quiet ,
979- _mk ("time_onnx_ort_create " ),
1031+ _mk ("onnx_ort_create " ),
9801032 summary ,
9811033 data ,
9821034 (lambda source = source , providers = providers : cls_runtime (source , providers )),
9831035 )
984- if f"ERR_{ _mk ('time_onnx_ort_create ' )} " in summary :
1036+ if f"ERR_{ _mk ('onnx_ort_create ' )} " in summary :
9851037 return summary , data
9861038
9871039 data [_mk ("onnx_ort_sess" )] = sess
@@ -1009,6 +1061,8 @@ def _mk(key):
10091061 summary ,
10101062 data ,
10111063 (lambda sess = sess , feeds = feeds : sess .run (None , feeds )),
1064+ repeat = repeat ,
1065+ warmup = warmup ,
10121066 )
10131067 if f"ERR_{ _mk ('time_onnx_ort_run' )} " in summary :
10141068 return summary , data
0 commit comments