Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ jobs:
- name: Install requirements dev
run: python -m pip install -r requirements-dev.txt

- name: Uninstall onnx-diagnostic
run: python -m pip uninstall -y onnx-diagnostic

- name: Uninstall onnx and install onnx-weekly
run: |
python -m pip uninstall -y onnx
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ jobs:
- name: Install requirements dev
run: python -m pip install -r requirements-dev.txt

- name: Uninstall onnx-diagnostic
run: python -m pip uninstall -y onnx-diagnostic

- name: Uninstall onnx and install onnx-weekly
run: |
python -m pip uninstall -y onnx
Expand Down
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.6.1
+++++

* :pr:`126`: add repeat and warmup to command line validate
* :pr:`125`: handles sequences in TorchOnnxEvaluator
* :pr:`123`: add subgraphs to TorchOnnxEvaluator
* :pr:`122`: add local functions to TorchOnnxEvaluator
Expand Down
2 changes: 2 additions & 0 deletions _unittests/ut_torch_models/test_test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ def test_validate_model_custom(self):
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
optimization="default",
quiet=False,
repeat=2,
warmup=1,
)
self.assertIsInstance(summary, dict)
self.assertIsInstance(data, dict)
Expand Down
11 changes: 11 additions & 0 deletions onnx_diagnostic/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,15 @@ def get_parser_validate() -> ArgumentParser:
"example: --mop attn_implementation=eager",
action=_ParseDict,
)
parser.add_argument(
"--repeat",
default=1,
type=int,
help="number of times to run the model to measures inference time",
)
parser.add_argument(
"--warmup", default=0, type=int, help="number of times to run the model to do warmup"
)
return parser


Expand Down Expand Up @@ -460,6 +469,8 @@ def _cmd_validate(argv: List[Any]):
subfolder=args.subfolder,
opset=args.opset,
runtime=args.runtime,
repeat=args.repeat,
warmup=args.warmup,
)
print("")
print("-- summary --")
Expand Down
3 changes: 2 additions & 1 deletion onnx_diagnostic/helpers/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -698,7 +698,8 @@ def string_type(
print(f"[string_type] CONFIG:{type(obj)}")
s = str(obj.to_diff_dict()).replace("\n", "").replace(" ", "")
return f"{obj.__class__.__name__}(**{s})"

if obj.__class__.__name__ in {"TorchModelContainer", "InferenceSession"}:
return f"{obj.__class__.__name__}(...)"
if verbose:
print(f"[string_type] END:{type(obj)}")
raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
Expand Down
68 changes: 61 additions & 7 deletions onnx_diagnostic/torch_models/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import time
import numpy as np
import onnx
import onnxscript
import onnxscript.rewriter.ort_fusions as ort_fusions
Expand Down Expand Up @@ -193,20 +194,46 @@ def _quiet_or_not_quiet(
summary: Dict[str, Any],
data: Optional[Dict[str, Any]],
fct: Callable,
repeat: int = 1,
warmup: int = 0,
) -> Any:
begin = time.perf_counter()
if quiet:
try:
return fct()
res = fct()
summary[f"time_{suffix}"] = time.perf_counter() - begin
if warmup + repeat == 1:
return res
except Exception as e:
summary[f"ERR_{suffix}"] = str(e)
summary[f"time_{suffix}"] = time.perf_counter() - begin
if data is None:
return {f"ERR_{suffix}": e}
data[f"ERR_{suffix}"] = e
return None
res = fct()
else:
res = fct()
summary[f"time_{suffix}"] = time.perf_counter() - begin
if warmup + repeat > 1:
if suffix == "run":
res = torch_deepcopy(res)
summary[f"{suffix}_output"] = string_type(res, with_shape=True, with_min_max=True)
summary[f"{suffix}_warmup"] = warmup
summary[f"{suffix}_repeat"] = repeat
for _w in range(max(0, warmup - 1)):
t = fct()
summary[f"io_{suffix}_{_w+1}"] = string_type(t, with_shape=True, with_min_max=True)
summary[f"time_{suffix}_warmup"] = time.perf_counter() - begin
times = []
for _r in range(repeat):
begin = time.perf_counter()
t = fct()
times.append(time.perf_counter() - begin)
a = np.array(times)
summary[f"time_{suffix}_latency"] = a.mean()
summary[f"time_{suffix}_latency_std"] = a.std()
summary[f"time_{suffix}_latency_min"] = a.min()
summary[f"time_{suffix}_latency_min"] = a.max()
return res


Expand Down Expand Up @@ -246,6 +273,8 @@ def validate_model(
subfolder: Optional[str] = None,
opset: Optional[int] = None,
runtime: str = "onnxruntime",
repeat: int = 1,
warmup: int = 0,
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
"""
Validates a model.
Expand Down Expand Up @@ -284,6 +313,8 @@ def validate_model(
:param opset: onnx opset to use for the conversion
:param runtime: onnx runtime to use to check about discrepancies,
only if `do_run` is true
:param repeat: number of time to measure the model
:param warmup: warmup the model first
:return: two dictionaries, one with some metrics,
another one with whatever the function produces

Expand Down Expand Up @@ -480,7 +511,13 @@ def validate_model(
model = data["model"]

expected = _quiet_or_not_quiet(
quiet, "run", summary, data, (lambda m=model, inp=inputs: m(**inp))
quiet,
"run",
summary,
data,
(lambda m=model, inp=inputs: m(**torch_deepcopy(inp))),
repeat=repeat,
warmup=warmup,
)
if "ERR_run" in summary:
return summary, data
Expand Down Expand Up @@ -639,7 +676,12 @@ def validate_model(

if do_run:
summary_valid, data = validate_onnx_model(
data=data, quiet=quiet, verbose=verbose, runtime=runtime
data=data,
quiet=quiet,
verbose=verbose,
runtime=runtime,
repeat=repeat,
warmup=warmup,
)
summary.update(summary_valid)

Expand Down Expand Up @@ -693,7 +735,13 @@ def validate_model(

if do_run:
summary_valid, data = validate_onnx_model(
data=data, quiet=quiet, verbose=verbose, flavour=flavour, runtime=runtime
data=data,
quiet=quiet,
verbose=verbose,
flavour=flavour,
runtime=runtime,
repeat=repeat,
warmup=warmup,
)
summary.update(summary_valid)

Expand Down Expand Up @@ -906,6 +954,8 @@ def validate_onnx_model(
verbose: int = 0,
flavour: Optional[str] = None,
runtime: str = "onnxruntime",
repeat: int = 1,
warmup: int = 0,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
Verifies that an onnx model produces the same
Expand All @@ -919,6 +969,8 @@ def validate_onnx_model(
:param verbose: verbosity
:param flavour: use a different version of the inputs
:param runtime: onnx runtime to use, onnxruntime or torch
:param repeat: run that number of times the model
:param warmup: warmup the model
:return: two dictionaries, one with some metrics,
another one with whatever the function produces
"""
Expand Down Expand Up @@ -976,12 +1028,12 @@ def _mk(key):
)
sess = _quiet_or_not_quiet(
quiet,
_mk("time_onnx_ort_create"),
_mk("onnx_ort_create"),
summary,
data,
(lambda source=source, providers=providers: cls_runtime(source, providers)),
)
if f"ERR_{_mk('time_onnx_ort_create')}" in summary:
if f"ERR_{_mk('onnx_ort_create')}" in summary:
return summary, data

data[_mk("onnx_ort_sess")] = sess
Expand Down Expand Up @@ -1009,6 +1061,8 @@ def _mk(key):
summary,
data,
(lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
repeat=repeat,
warmup=warmup,
)
if f"ERR_{_mk('time_onnx_ort_run')}" in summary:
return summary, data
Expand Down
Loading