
Commit 53336a1

Add repeat and warmup
1 parent 00480d1 commit 53336a1


4 files changed (+76 −8 lines)


_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 2 additions & 0 deletions
@@ -153,6 +153,8 @@ def test_validate_model_custom(self):
             stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
             optimization="default",
             quiet=False,
+            repeat=2,
+            warmup=1,
         )
         self.assertIsInstance(summary, dict)
         self.assertIsInstance(data, dict)

onnx_diagnostic/_command_lines_parser.py

Lines changed: 11 additions & 0 deletions
@@ -405,6 +405,15 @@ def get_parser_validate() -> ArgumentParser:
         "example: --mop attn_implementation=eager",
         action=_ParseDict,
     )
+    parser.add_argument(
+        "--repeat",
+        default=1,
+        type=int,
+        help="number of times to run the model to measure inference time",
+    )
+    parser.add_argument(
+        "--warmup", default=0, type=int, help="number of times to run the model for warmup"
+    )
     return parser


@@ -460,6 +469,8 @@ def _cmd_validate(argv: List[Any]):
         subfolder=args.subfolder,
         opset=args.opset,
         runtime=args.runtime,
+        repeat=args.repeat,
+        warmup=args.warmup,
     )
     print("")
     print("-- summary --")

onnx_diagnostic/helpers/helper.py

Lines changed: 2 additions & 1 deletion
@@ -698,7 +698,8 @@ def string_type(
         print(f"[string_type] CONFIG:{type(obj)}")
         s = str(obj.to_diff_dict()).replace("\n", "").replace(" ", "")
         return f"{obj.__class__.__name__}(**{s})"
-
+    if obj.__class__.__name__ in {"TorchModelContainer", "InferenceSession"}:
+        return f"{obj.__class__.__name__}(...)"
     if verbose:
         print(f"[string_type] END:{type(obj)}")
     raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 61 additions & 7 deletions
@@ -4,6 +4,7 @@
 import sys
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import time
+import numpy as np
 import onnx
 import onnxscript
 import onnxscript.rewriter.ort_fusions as ort_fusions
@@ -193,20 +194,46 @@ def _quiet_or_not_quiet(
     summary: Dict[str, Any],
     data: Optional[Dict[str, Any]],
     fct: Callable,
+    repeat: int = 1,
+    warmup: int = 0,
 ) -> Any:
     begin = time.perf_counter()
     if quiet:
         try:
-            return fct()
+            res = fct()
+            summary[f"time_{suffix}"] = time.perf_counter() - begin
+            if warmup + repeat == 1:
+                return res
         except Exception as e:
             summary[f"ERR_{suffix}"] = str(e)
             summary[f"time_{suffix}"] = time.perf_counter() - begin
             if data is None:
                 return {f"ERR_{suffix}": e}
             data[f"ERR_{suffix}"] = e
             return None
-    res = fct()
+    else:
+        res = fct()
     summary[f"time_{suffix}"] = time.perf_counter() - begin
+    if warmup + repeat > 1:
+        if suffix == "run":
+            res = torch_deepcopy(res)
+        summary[f"{suffix}_output"] = string_type(res, with_shape=True, with_min_max=True)
+        summary[f"{suffix}_warmup"] = warmup
+        summary[f"{suffix}_repeat"] = repeat
+        for _w in range(max(0, warmup - 1)):
+            t = fct()
+            summary[f"io_{suffix}_{_w+1}"] = string_type(t, with_shape=True, with_min_max=True)
+        summary[f"time_{suffix}_warmup"] = time.perf_counter() - begin
+        times = []
+        for _r in range(repeat):
+            begin = time.perf_counter()
+            t = fct()
+            times.append(time.perf_counter() - begin)
+        a = np.array(times)
+        summary[f"time_{suffix}_latency"] = a.mean()
+        summary[f"time_{suffix}_latency_std"] = a.std()
+        summary[f"time_{suffix}_latency_min"] = a.min()
+        summary[f"time_{suffix}_latency_max"] = a.max()
     return res


@@ -246,6 +273,8 @@ def validate_model(
     subfolder: Optional[str] = None,
     opset: Optional[int] = None,
     runtime: str = "onnxruntime",
+    repeat: int = 1,
+    warmup: int = 0,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -284,6 +313,8 @@ def validate_model(
     :param opset: onnx opset to use for the conversion
     :param runtime: onnx runtime to use to check about discrepancies,
         only if `do_run` is true
+    :param repeat: number of times to measure the model
+    :param warmup: number of warmup runs before measuring
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces

@@ -480,7 +511,13 @@ def validate_model(
     model = data["model"]

     expected = _quiet_or_not_quiet(
-        quiet, "run", summary, data, (lambda m=model, inp=inputs: m(**inp))
+        quiet,
+        "run",
+        summary,
+        data,
+        (lambda m=model, inp=inputs: m(**inp)),
+        repeat=repeat,
+        warmup=warmup,
     )
     if "ERR_run" in summary:
         return summary, data
@@ -639,7 +676,12 @@ def validate_model(

     if do_run:
         summary_valid, data = validate_onnx_model(
-            data=data, quiet=quiet, verbose=verbose, runtime=runtime
+            data=data,
+            quiet=quiet,
+            verbose=verbose,
+            runtime=runtime,
+            repeat=repeat,
+            warmup=warmup,
         )
         summary.update(summary_valid)

@@ -693,7 +735,13 @@ def validate_model(

     if do_run:
         summary_valid, data = validate_onnx_model(
-            data=data, quiet=quiet, verbose=verbose, flavour=flavour, runtime=runtime
+            data=data,
+            quiet=quiet,
+            verbose=verbose,
+            flavour=flavour,
+            runtime=runtime,
+            repeat=repeat,
+            warmup=warmup,
         )
         summary.update(summary_valid)

@@ -906,6 +954,8 @@ def validate_onnx_model(
     verbose: int = 0,
     flavour: Optional[str] = None,
     runtime: str = "onnxruntime",
+    repeat: int = 1,
+    warmup: int = 0,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Verifies that an onnx model produces the same
@@ -919,6 +969,8 @@ def validate_onnx_model(
     :param verbose: verbosity
     :param flavour: use a different version of the inputs
     :param runtime: onnx runtime to use, onnxruntime or torch
+    :param repeat: number of times to run the model
+    :param warmup: number of warmup runs before measuring
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -976,12 +1028,12 @@ def _mk(key):
     )
     sess = _quiet_or_not_quiet(
         quiet,
-        _mk("time_onnx_ort_create"),
+        _mk("onnx_ort_create"),
         summary,
         data,
         (lambda source=source, providers=providers: cls_runtime(source, providers)),
     )
-    if f"ERR_{_mk('time_onnx_ort_create')}" in summary:
+    if f"ERR_{_mk('onnx_ort_create')}" in summary:
         return summary, data

     data[_mk("onnx_ort_sess")] = sess
@@ -1009,6 +1061,8 @@ def _mk(key):
         summary,
         data,
         (lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
+        repeat=repeat,
+        warmup=warmup,
     )
     if f"ERR_{_mk('time_onnx_ort_run')}" in summary:
         return summary, data
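
The timing logic added to _quiet_or_not_quiet is essentially the usual warmup-then-repeat benchmarking pattern. Below is a self-contained sketch of that pattern, independent of the helper; names are illustrative and run_once stands for the fct callable passed in.

# Self-contained sketch of the warmup/repeat pattern used in _quiet_or_not_quiet.
import time
import numpy as np

def benchmark(run_once, repeat=10, warmup=2):
    for _ in range(warmup):  # warmup runs are executed but not reported individually
        run_once()
    times = []
    for _ in range(repeat):  # each timed run measured with perf_counter
        begin = time.perf_counter()
        run_once()
        times.append(time.perf_counter() - begin)
    a = np.array(times)
    return {
        "latency": a.mean(),
        "latency_std": a.std(),
        "latency_min": a.min(),
        "latency_max": a.max(),
    }

print(benchmark(lambda: sum(range(100_000)), repeat=5, warmup=1))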
