Skip to content

Commit 6e86cb9

Browse files
committed
add onnx
1 parent 1ada158 commit 6e86cb9

File tree

4 files changed

+274
-8
lines changed

4 files changed

+274
-8
lines changed

_doc/cmds/validate.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,3 +84,19 @@ the same model.
8484
from onnx_diagnostic._command_lines_parser import main
8585

8686
main("validate -m arnir0/Tiny-LLM --run -v 1 --export export-nostrict -o dump_models --patch".split())
87+
88+
Validate ONNX discrepancies
89+
+++++++++++++++++++++++++++
90+
91+
Let's export with ONNX this time and check for discrepancies.
92+
93+
.. code-block::
94+
95+
python -m onnx_diagnostic validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir
96+
97+
.. runpython::
98+
99+
from onnx_diagnostic._command_lines_parser import main
100+
101+
main("validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir".split())
102+

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import copy
22
import unittest
3-
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout
3+
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
44
from onnx_diagnostic.torch_models.test_helper import get_inputs_for_task, validate_model
55
from onnx_diagnostic.torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks
66

@@ -47,6 +47,22 @@ def test_validate_model_export(self):
4747
self.assertIsInstance(summary, dict)
4848
self.assertIsInstance(data, dict)
4949

50+
@hide_stdout()
51+
@ignore_warnings(FutureWarning)
52+
def test_validate_model_onnx(self):
53+
mid = "arnir0/Tiny-LLM"
54+
summary, data = validate_model(
55+
mid,
56+
do_run=True,
57+
verbose=10,
58+
exporter="onnx-dynamo",
59+
dump_folder="dump_test_validate_model_onnx",
60+
patch=True,
61+
)
62+
self.assertIsInstance(summary, dict)
63+
self.assertIsInstance(data, dict)
64+
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
65+
5066

5167
if __name__ == "__main__":
5268
unittest.main(verbosity=2)

onnx_diagnostic/helpers/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def string_type(
272272
return "SymFloat"
273273
# Tensors
274274
if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
275-
from .helper.onnx_helper import torch_dtype_to_onnx_dtype
275+
from .onnx_helper import torch_dtype_to_onnx_dtype
276276

277277
i = torch_dtype_to_onnx_dtype(obj.dtype)
278278
prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""

onnx_diagnostic/torch_models/test_helper.py

Lines changed: 240 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import time
44
import torch
55
from ..helpers import max_diff, string_type, string_diff
6+
from ..helpers.helper import flatten_object
7+
from ..helpers.ort_session import make_feeds
68
from ..helpers.torch_test_helper import to_any, torch_deepcopy
79
from ..torch_export_patches import bypass_export_some_errors
810
from .hghub import get_untrained_model_with_inputs
@@ -259,6 +261,14 @@ def validate_model(
259261
f.write(str(ep.graph))
260262
if verbose:
261263
print("[validate_model] done (dump ep)")
264+
if "onnx_program" in data:
265+
epo = data["onnx_program"]
266+
if verbose:
267+
print(f"[validate_model] dumps onnx program in {dump_folder!r}...")
268+
onnx_file_name = os.path.join(dump_folder, f"{folder_name}.onnx")
269+
epo.save(onnx_file_name)
270+
if verbose:
271+
print("[validate_model] done (dump onnx)")
262272
if verbose:
263273
print(f"[validate_model] dumps statistics in {dump_folder!r}...")
264274
with open(os.path.join(dump_folder, f"{folder_name}.stats"), "w") as f:
@@ -267,6 +277,15 @@ def validate_model(
267277
if verbose:
268278
print("[validate_model] done (dump)")
269279

280+
if exporter and exporter.startswith("onnx-") and do_run:
281+
summary_valid, data = validate_onnx_model(
282+
data=data,
283+
quiet=quiet,
284+
verbose=verbose,
285+
optimization=optimization,
286+
)
287+
summary.update(summary_valid)
288+
270289
if verbose:
271290
print("[validate_model] done (final)")
272291
return summary, data
@@ -288,7 +307,6 @@ def call_exporter(
288307
:param exporter: exporter to call
289308
:param quiet: catch exception or not
290309
:param verbose: verbosity
291-
:param patch: apply patches
292310
:param optimization: optimization to do
293311
:param do_run: runs and compute discrepancies
294312
:return: two dictionaries, one with some metrics,
@@ -305,6 +323,16 @@ def call_exporter(
305323
do_run=do_run,
306324
)
307325
return summary, data
326+
if exporter.startswith("onnx-"):
327+
# torch export
328+
summary, data = call_torch_export_onnx(
329+
exporter=exporter,
330+
data=data,
331+
quiet=quiet,
332+
verbose=verbose,
333+
optimization=optimization,
334+
)
335+
return summary, data
308336
raise NotImplementedError(
309337
f"export with {exporter!r} and optimization={optimization!r} not implemented yet"
310338
)
@@ -331,19 +359,23 @@ def call_torch_export_export(
331359
do_run: bool = False,
332360
):
333361
"""
334-
Calls an exporter on a model;
362+
Exports a model with :func:`torch.export.export`.
335363
If a patch must be applied, it should be before this functions.
336364
337-
:param data: dictionary with all the necessary inputs
365+
:param data: dictionary with all the necessary inputs, the dictionary must
366+
contains keys ``model`` and ``inputs_export``
338367
:param exporter: exporter to call
339368
:param quiet: catch exception or not
340369
:param verbose: verbosity
341-
:param patch: apply patches
342370
:param optimization: optimization to do
343371
:param do_run: runs and compute discrepancies
344372
:return: two dictionaries, one with some metrics,
345373
another one with whatever the function produces
346374
"""
375+
assert exporter in {
376+
"export-strict",
377+
"export-nostrict",
378+
}, f"Unexpected value for exporter={exporter!r}"
347379
assert "model" in data, f"model is missing from data: {sorted(data)}"
348380
assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"
349381
summary: Dict[str, Union[str, int, float]] = {}
@@ -355,8 +387,8 @@ def call_torch_export_export(
355387
f"[call_torch_export_export] exporter={exporter!r}, "
356388
f"strict={strict}, optimization={optimization!r}"
357389
)
358-
print(f"[call_torch_export_export] args={string_type(args)}")
359-
print(f"[call_torch_export_export] kwargs={string_type(kwargs)}")
390+
print(f"[call_torch_export_export] args={string_type(args, with_shape=True)}")
391+
print(f"[call_torch_export_export] kwargs={string_type(kwargs, with_shape=True)}")
360392
print(f"[call_torch_export_export] dynamic_shapes={_ds_clean(ds)}")
361393
print("[call_torch_export_export] export...")
362394
summary["export_exporter"] = exporter
@@ -431,3 +463,205 @@ def call_torch_export_export(
431463
f" after: {string_type(data['inputs_export'], with_shape=True)}"
432464
)
433465
return summary, data
466+
467+
468+
def call_torch_export_onnx(
    data: Dict[str, Any],
    exporter: str,
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
):
    """
    Exports a model into ONNX with :func:`torch.onnx.export`.
    If a patch must be applied, it should be done before calling this function.

    :param data: dictionary with all the necessary inputs, it must
        contain keys ``model`` and ``inputs_export``
    :param exporter: exporter to call, ``"onnx-dynamo"`` or ``"onnx-script"``
    :param quiet: catch exceptions or not
    :param verbose: verbosity
    :param optimization: optimization to apply, ``"ir"`` or empty
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    assert optimization in {
        "",
        "ir",
        None,
    }, f"unexpected value for optimization={optimization}"
    assert exporter in {
        "onnx-dynamo",
        "onnx-script",
    }, f"Unexpected value for exporter={exporter!r}"
    assert "model" in data, f"model is missing from data: {sorted(data)}"
    assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}"

    summary: Dict[str, Union[str, int, float]] = {}
    # NOTE(review): "nostrict" never appears in onnx-* exporter names, so this
    # is always True here — presumably kept for symmetry with the torch-export
    # path; confirm before relying on it.
    dynamo = "nostrict" not in exporter
    args, kwargs = split_args_kwargs(data["inputs_export"])
    ds = data.get("dynamic_shapes", None)
    if verbose:
        print(f"[call_torch_export_onnx] exporter={exporter!r}, optimization={optimization!r}")
        print(f"[call_torch_export_onnx] args={string_type(args, with_shape=True)}")
        print(f"[call_torch_export_onnx] kwargs={string_type(kwargs, with_shape=True)}")
        print(f"[call_torch_export_onnx] dynamic_shapes={_ds_clean(ds)}")
        print("[call_torch_export_onnx] export...")
    summary["export_exporter"] = exporter
    summary["export_optimization"] = optimization or ""
    summary["export_dynamo"] = dynamo
    summary["export_args"] = string_type(args, with_shape=True)
    summary["export_kwargs"] = string_type(kwargs, with_shape=True)

    def _export():
        # single definition of the export call, shared by both branches below
        return torch.onnx.export(
            data["model"],
            args,
            kwargs=kwargs,
            dynamic_shapes=ds,
            dynamo=dynamo,
        )

    begin = time.perf_counter()
    if quiet:
        try:
            epo = _export()
        except Exception as e:
            summary["ERR_export_export"] = str(e)
            data["ERR_export_export"] = e
            summary["time_export_export"] = time.perf_counter() - begin
            return summary, data
    else:
        epo = _export()

    summary["time_export_export"] = time.perf_counter() - begin
    assert epo is not None, "no onnx export was found"
    if verbose:
        print("[call_torch_export_onnx] done (export)")
    data["onnx_program"] = epo
    if verbose > 1:
        print("[call_torch_export_onnx] -- ONNXProgram")
        print(epo)
        print("[call_torch_export_onnx] -- End of ONNXProgram")

    begin = time.perf_counter()
    if optimization == "ir":
        if verbose:
            print(f"[call_torch_export_onnx] starts optimization={optimization!r}...")
        if quiet:
            try:
                epo.optimize()
            except Exception as e:
                summary["ERR_export_optimize_ir"] = str(e)
                data["ERR_export_optimize_ir"] = e
                summary["time_export_optimize_ir"] = time.perf_counter() - begin
                return summary, data
        else:
            epo.optimize()
        summary["time_export_optimize_ir"] = time.perf_counter() - begin
        if verbose:
            print("[call_torch_export_onnx] done (optimization)")

    return summary, data
571+
572+
573+
def validate_onnx_model(
    data: Dict[str, Any],
    quiet: bool = False,
    verbose: int = 0,
    optimization: Optional[str] = None,
):
    """
    Verifies that an onnx model produces the same expected outputs
    as the torch model it was exported from.

    :param data: dictionary with all the necessary inputs, it must
        contain key ``inputs`` (torch inputs), key ``expected``
        (expected torch outputs) and either ``onnx_file_name``
        (path to a saved model) or ``onnx_program`` (in-memory
        :class:`torch.onnx.ONNXProgram`)
    :param quiet: catch exceptions or not
    :param verbose: verbosity
    :param optimization: optimization applied to the model
        (currently unused here, kept for interface consistency)
    :return: two dictionaries, one with some metrics,
        another one with whatever the function produces
    """
    import onnxruntime

    summary: Dict[str, Union[str, int, float]] = {}
    flat_inputs = flatten_object(data["inputs"], drop_keys=True)
    # device of the first input decides which execution providers to use
    d = flat_inputs[0].get_device()
    providers = (
        ["CPUExecutionProvider"]
        if d < 0
        else ["CUDAExecutionProvider", "CPUExecutionProvider"]
    )
    if "onnx_file_name" in data:
        source = data["onnx_file_name"]
        summary["onnx_filename"] = source
        # FIX: was `os.stats` (no such function in module os) -> `os.stat`
        summary["onnx_size"] = os.stat(source).st_size
    else:
        assert (
            "onnx_program" in data
        ), f"onnx_program is missing from data which has {sorted(data)}"
        # onnxruntime can load a serialized model directly from bytes,
        # but only below the 2 Gb protobuf limit
        source = data["onnx_program"].model_proto.SerializeToString()
        # FIX: message typo "highger" -> "bigger"
        assert len(source) < 2**31, f"The model is bigger than 2Gb: {len(source) / 2**30} Gb"
        summary["onnx_size"] = len(source)
    if verbose:
        print(f"[validate_onnx_model] verify onnx model with providers {providers}...")

    begin = time.perf_counter()
    if quiet:
        try:
            sess = onnxruntime.InferenceSession(source, providers=providers)
        except Exception as e:
            summary["ERR_onnx_ort_create"] = str(e)
            data["ERR_onnx_ort_create"] = e
            summary["time_onnx_ort_create"] = time.perf_counter() - begin
            return summary, data
    else:
        sess = onnxruntime.InferenceSession(source, providers=providers)

    summary["time_onnx_ort_create"] = time.perf_counter() - begin
    data["onnx_ort_sess"] = sess
    if verbose:
        print("[validate_onnx_model] done (ort_session)")

    # make_feeds: map the torch inputs onto the onnx graph input names
    if verbose:
        print("[validate_onnx_model] make_feeds...")
        print(f"[validate_onnx_model] inputs={string_type(data['inputs'], with_shape=True)}")
    feeds = make_feeds([i.name for i in sess.get_inputs()], data["inputs"], use_numpy=True)
    if verbose:
        print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
    summary["onnx_ort_inputs"] = string_type(feeds, with_shape=True)
    if verbose:
        print("[validate_onnx_model] done (make_feeds)")

    # run ort
    if verbose:
        print("[validate_onnx_model] run session...")
    begin = time.perf_counter()
    if quiet:
        try:
            got = sess.run(None, feeds)
        except Exception as e:
            summary["ERR_onnx_ort_run"] = str(e)
            data["ERR_onnx_ort_run"] = e
            summary["time_onnx_ort_run"] = time.perf_counter() - begin
            return summary, data
    else:
        got = sess.run(None, feeds)
    # FIX: the run time was only recorded on the error path; record it on
    # success as well so the metric is always present
    summary["time_onnx_ort_run"] = time.perf_counter() - begin
    if verbose:
        print("[validate_onnx_model] done (run)")
        print(f"[validate_onnx_model] got={string_type(got, with_shape=True)}")

    # compute discrepancies between torch expected outputs and ort outputs
    disc = max_diff(data["expected"], got, flatten=True)
    if verbose:
        print(f"[validate_onnx_model] discrepancies={string_diff(disc)}")
    for k, v in disc.items():
        summary[f"disc_onnx_ort_run_{k}"] = v
    return summary, data

0 commit comments

Comments
 (0)