Commit 821f9be

torch-onnx
1 parent 10afdd3 commit 821f9be

File tree: 5 files changed, +62 -8 lines

* CHANGELOGS.rst
* _unittests/ut_torch_models/test_test_helpers.py
* onnx_diagnostic/_command_lines_parser.py
* onnx_diagnostic/reference/torch_evaluator.py
* onnx_diagnostic/torch_models/test_helper.py

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion
```diff
@@ -4,7 +4,8 @@ Change Logs
 0.6.1
 +++++
 
-* :pr:`115`, :pr:`116`, :pr:`117`, :pr:`118`: first steps for TorchOnnxEvaluator
+* :pr:`115`, :pr:`116`, :pr:`117`, :pr:`118`, :pr:`119`:
+  first steps for TorchOnnxEvaluator
 * :pr:`114`: extends the list of known rewritings
 * :pr:`113`: fixes a couple of issues with ModelBuilder
 
```

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 28 additions & 0 deletions
```diff
@@ -164,6 +164,34 @@ def test_validate_model_custom(self):
             onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10
         )
 
+    @requires_torch("2.7")
+    @hide_stdout()
+    @ignore_warnings(FutureWarning)
+    @requires_experimental()
+    def test_validate_model_custom_torch(self):
+        mid = "arnir0/Tiny-LLM"
+        summary, data = validate_model(
+            mid,
+            do_run=True,
+            verbose=10,
+            exporter="custom",
+            dump_folder="dump_test_validate_model_custom_torch",
+            patch=True,
+            stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
+            optimization="default",
+            quiet=False,
+            runtime="torch",
+        )
+        self.assertIsInstance(summary, dict)
+        self.assertIsInstance(data, dict)
+        self.assertIn("disc_onnx_ort_run_abs", summary)
+        self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
+        onnx_filename = data["onnx_filename"]
+        output_path = f"{onnx_filename}.ortopt.onnx"
+        run_ort_fusion(
+            onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10
+        )
+
     def test_filter_inputs(self):
         inputs, ds = {"a": 1, "b": 2}, {"a": 20, "b": 30}
         ni, nd = filter_inputs(inputs, dynamic_shapes=ds, drop_names=["a"])
```

onnx_diagnostic/_command_lines_parser.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -352,6 +352,12 @@ def get_parser_validate() -> ArgumentParser:
         action=BooleanOptionalAction,
         help="validate the trained model (requires downloading)",
     )
+    parser.add_argument(
+        "--runtime",
+        choices=["onnxruntime", "torch", "ref"],
+        default="onnxruntime",
+        help="onnx runtime to use, ",
+    )
     parser.add_argument(
         "-o",
         "--dump-folder",
@@ -453,6 +459,7 @@ def _cmd_validate(argv: List[Any]):
         model_options=args.mop,
         subfolder=args.subfolder,
         opset=args.opset,
+        runtime=args.runtime,
     )
     print("")
     print("-- summary --")
```

onnx_diagnostic/reference/torch_evaluator.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -60,6 +60,8 @@ def __init__(
         else:
             self.default_device = self.CPU
 
+        if isinstance(proto, str):
+            proto = onnx.load(proto)
         if isinstance(proto, onnx.ModelProto):
             assert opsets is None, "proto is a model, opsets must be None in that case"
             assert not proto.graph.sparse_initializer, "sparse_initializer not support yet"
```
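
With these two lines, `TorchOnnxEvaluator` accepts a path to an ONNX file as well as an already loaded `onnx.ModelProto`. A minimal sketch of both forms, assuming `model.onnx` is a placeholder for an existing model file and that provider names follow the onnxruntime convention used elsewhere in this commit:

```python
# Sketch only: "model.onnx" is a placeholder path; the provider list is an
# assumption based on how test_helper.py passes providers to the evaluator.
import onnx
from onnx_diagnostic.reference import TorchOnnxEvaluator

providers = ["CPUExecutionProvider"]

# After this change, both forms should build an equivalent evaluator.
sess_from_path = TorchOnnxEvaluator("model.onnx", providers=providers)
sess_from_proto = TorchOnnxEvaluator(onnx.load("model.onnx"), providers=providers)
```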

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 23 additions & 7 deletions
```diff
@@ -17,6 +17,7 @@
 from ..tasks import random_input_kwargs
 from ..torch_export_patches import torch_export_patches
 from ..torch_export_patches.patch_inputs import use_dyn_not_str
+from ..reference import TorchOnnxEvaluator
 from .hghub import get_untrained_model_with_inputs
 
 
@@ -244,6 +245,7 @@ def validate_model(
     model_options: Optional[Dict[str, Any]] = None,
     subfolder: Optional[str] = None,
     opset: Optional[int] = None,
+    runtime: str = "onnxruntime",
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -280,6 +282,8 @@ def validate_model(
         ``num_hidden_layers`` or ``attn_implementation``
     :param subfolder: version or subfolders to uses when retrieving a model id
     :param opset: onnx opset to use for the conversion
+    :param runtime: onnx runtime to use to check about discrepancies,
+        only if `do_run` is true
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
 
@@ -308,6 +312,7 @@ def validate_model(
             version_ortfusiontype=ortfusiontype or "",
             version_stop_if_static=str(stop_if_static),
             version_exporter=exporter or "",
+            version_runtime=runtime,
         )
     )
     if opset:
@@ -633,7 +638,9 @@ def validate_model(
         return summary, data
 
     if do_run:
-        summary_valid, data = validate_onnx_model(data=data, quiet=quiet, verbose=verbose)
+        summary_valid, data = validate_onnx_model(
+            data=data, quiet=quiet, verbose=verbose, runtime=runtime
+        )
         summary.update(summary_valid)
 
     if ortfusiontype and "onnx_filename" in data:
@@ -686,7 +693,7 @@ def validate_model(
 
     if do_run:
         summary_valid, data = validate_onnx_model(
-            data=data, quiet=quiet, verbose=verbose, flavour=flavour
+            data=data, quiet=quiet, verbose=verbose, flavour=flavour, runtime=runtime
         )
         summary.update(summary_valid)
 
@@ -898,6 +905,7 @@ def validate_onnx_model(
     quiet: bool = False,
     verbose: int = 0,
     flavour: Optional[str] = None,
+    runtime: str = "onnxruntime",
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Verifies that an onnx model produces the same
@@ -910,6 +918,7 @@ def validate_onnx_model(
     :param quiet: catch exception or not
     :param verbose: verbosity
     :param flavour: use a different version of the inputs
+    :param runtime: onnx runtime to use, onnxruntime or torch
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -951,16 +960,23 @@ def _mk(key):
             f"{providers}..., flavour={flavour!r}"
         )
 
+    assert runtime == "torch", f"runtime={runtime!r}"
+    cls_runtime = (
+        (
+            lambda model, providers: onnxruntime.InferenceSession(
+                (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
+                providers=providers,
+            )
+        )
+        if runtime == "onnxruntime"
+        else (lambda model, providers: TorchOnnxEvaluator(model, providers=providers))
+    )
     sess = _quiet_or_not_quiet(
         quiet,
         _mk("time_onnx_ort_create"),
         summary,
         data,
-        (
-            lambda source=source, providers=providers: onnxruntime.InferenceSession(
-                source, providers=providers
-            )
-        ),
+        (lambda source=source, providers=providers: cls_runtime(source, providers)),
    )
     if f"ERR_{_mk('time_onnx_ort_create')}" in summary:
         return summary, data
```
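
The `cls_runtime` expression above acts as a small session factory: `onnxruntime.InferenceSession` needs a filename or serialized bytes, while `TorchOnnxEvaluator` also takes a `ModelProto` directly, and the rest of `validate_onnx_model` only sees a callable. The same dispatch can be written as a plain helper; this is an illustrative rewrite, not code from the commit:

```python
# Illustrative rewrite of the cls_runtime dispatch above (not part of the commit).
import onnx
import onnxruntime
from onnx_diagnostic.reference import TorchOnnxEvaluator


def make_session(model, providers, runtime="onnxruntime"):
    """Builds an inference session for the requested runtime.

    ``model`` may be an onnx.ModelProto or a path to an .onnx file.
    """
    if runtime == "onnxruntime":
        # InferenceSession expects a filename or serialized bytes, not a ModelProto.
        source = model.SerializeToString() if isinstance(model, onnx.ModelProto) else model
        return onnxruntime.InferenceSession(source, providers=providers)
    # Any other value falls back to the torch evaluator, mirroring the else branch above.
    return TorchOnnxEvaluator(model, providers=providers)
```

Swapping the session factory is the only change the discrepancy check needs; the timing and output comparison around it stay identical.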
