Skip to content

Commit 54245d1

Browse files
authored
Check custom (#64)
* Check custom * doc
1 parent e96d7ec commit 54245d1

File tree

3 files changed

+49
-2
lines changed

3 files changed

+49
-2
lines changed

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@
187187
"ExecuTorch": "https://pytorch.org/executorch/stable/intro-overview.html",
188188
"ExecuTorch Runtime Python API Reference": "https://pytorch.org/executorch/stable/runtime-python-api-reference.html",
189189
"ExecuTorch Tutorial": "https://pytorch.org/executorch/stable/tutorials/export-to-executorch-tutorial.html",
190+
"experimental-experiment": "https://sdpython.github.io/doc/experimental-experiment/dev/",
190191
"JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation",
191192
"FunctionProto": "https://onnx.ai/onnx/api/classes.html#functionproto",
192193
"graph break": "https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks",

_unittests/ut_torch_models/test_test_helpers.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
hide_stdout,
88
ignore_warnings,
99
requires_torch,
10+
requires_experimental,
1011
)
1112
from onnx_diagnostic.torch_models.test_helper import (
1213
get_inputs_for_task,
@@ -62,16 +63,42 @@ def test_validate_model_export(self):
6263
@requires_torch("2.7")
6364
@hide_stdout()
6465
@ignore_warnings(FutureWarning)
65-
def test_validate_model_onnx(self):
66+
def test_validate_model_onnx_dynamo(self):
6667
mid = "arnir0/Tiny-LLM"
6768
summary, data = validate_model(
6869
mid,
6970
do_run=True,
7071
verbose=10,
7172
exporter="onnx-dynamo",
72-
dump_folder="dump_test_validate_model_onnx",
73+
dump_folder="dump_test_validate_model_onnx_dynamo",
7374
patch=True,
7475
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
76+
optimization="ir",
77+
)
78+
self.assertIsInstance(summary, dict)
79+
self.assertIsInstance(data, dict)
80+
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
81+
onnx_filename = data["onnx_filename"]
82+
output_path = f"{onnx_filename}.ortopt.onnx"
83+
run_ort_fusion(
84+
onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10
85+
)
86+
87+
@requires_torch("2.7")
88+
@hide_stdout()
89+
@ignore_warnings(FutureWarning)
90+
@requires_experimental()
91+
def test_validate_model_custom(self):
92+
mid = "arnir0/Tiny-LLM"
93+
summary, data = validate_model(
94+
mid,
95+
do_run=True,
96+
verbose=10,
97+
exporter="custom",
98+
dump_folder="dump_test_validate_model_custom",
99+
patch=True,
100+
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
101+
optimization="default",
75102
)
76103
self.assertIsInstance(summary, dict)
77104
self.assertIsInstance(data, dict)

onnx_diagnostic/ext_test_case.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,25 @@ def requires_sklearn(version: str, msg: str = "") -> Callable:
461461
return lambda x: x
462462

463463

464+
def requires_experimental(version: str = "", msg: str = "") -> Callable:
465+
"""Skips a unit test if :epkg:`experimental-experiment` is not recent enough."""
466+
import packaging.version as pv
467+
468+
try:
469+
import experimental_experiment
470+
except ImportError:
471+
msg = f"experimental-experiment not installed: {msg}"
472+
return unittest.skip(msg)
473+
474+
if pv.Version(experimental_experiment.__version__) < pv.Version(version):
475+
msg = (
476+
f"experimental-experiment version "
477+
f"{experimental_experiment.__version__} < {version}: {msg}"
478+
)
479+
return unittest.skip(msg)
480+
return lambda x: x
481+
482+
464483
def has_torch(version: str) -> bool:
465484
"Returns True if torch transformers is higher."
466485
import packaging.version as pv

0 commit comments

Comments
 (0)