Skip to content

Commit e2b7784

Browse files
committed
fixes
2 parents 71ea535 + f358ccb commit e2b7784

File tree

5 files changed

+19
-7
lines changed

5 files changed

+19
-7
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ jobs:
3737
- python: '3.11' # 3.11
3838
torch: 'main'
3939
- python: '3.11'
40-
torch: '2.9'
40+
torch: '2.8'
4141
- python: '3.11'
4242
transformers: 'main'
4343
- python: '3.11'

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.15
55
++++++
66

7+
* :pr:`261`: updates to support ``transformers>=5.0``
8+
79
0.7.14
810
++++++
911

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,14 @@ def test_g_validate_model_onnx_dynamo_os_ort(self):
107107
verbose=10,
108108
exporter="onnx-dynamo",
109109
dump_folder="dump_test/validate_model_onnx_dynamo_os_ort",
110-
patch=True,
110+
patch=dict(patch_torch=False, patch_transformers=True, patch_sympy=True),
111111
stop_if_static=2 if pv.Version(torch.__version__) > pv.Version("2.6.1") else 0,
112112
optimization="os_ort",
113+
quiet_input_sets={"inputs", "inputs22"},
113114
)
114115
self.assertIsInstance(summary, dict)
115116
self.assertIsInstance(data, dict)
116-
self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4)
117+
self.assertLess(summary["disc_onnx_ort_run2_batch1_abs"], 1e-4)
117118
onnx_filename = data["onnx_filename"]
118119
self.assertExists(onnx_filename)
119120

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ def torch_export_patches(
422422
)
423423
)
424424

425-
if stop_if_static:
425+
if patch_torch and stop_if_static:
426426
ShapeEnv._log_guard_remember = ShapeEnv._log_guard
427427

428428
if verbose:

onnx_diagnostic/torch_models/validate.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os
55
import pprint
66
import sys
7-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
7+
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
88
import time
99
import numpy as np
1010
import onnx
@@ -319,6 +319,7 @@ def validate_model(
319319
inputs2: int = 1,
320320
output_names: Optional[List[str]] = None,
321321
ort_logs: bool = False,
322+
quiet_input_sets: Optional[Set[str]] = None,
322323
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
323324
"""
324325
Validates a model.
@@ -373,6 +374,8 @@ def validate_model(
373374
or an empty cache for example
374375
:param output_names: output names the onnx exporter should use
375376
:param ort_logs: increases onnxruntime verbosity when creating the session
377+
:param quiet_input_sets: avoid raising an exception if the inputs belongs to that set
378+
even if quiet is False
376379
:return: two dictionaries, one with some metrics,
377380
another one with whatever the function produces
378381
@@ -842,6 +845,7 @@ def validate_model(
842845
warmup=warmup,
843846
second_input_keys=second_input_keys,
844847
ort_logs=ort_logs,
848+
quiet_input_sets=quiet_input_sets,
845849
)
846850
summary.update(summary_valid)
847851
summary["time_total_validation_onnx"] = time.perf_counter() - validation_begin
@@ -904,6 +908,7 @@ def validate_model(
904908
repeat=repeat,
905909
warmup=warmup,
906910
second_input_keys=second_input_keys,
911+
quiet_input_sets=quiet_input_sets,
907912
)
908913
summary.update(summary_valid)
909914

@@ -1289,6 +1294,7 @@ def validate_onnx_model(
12891294
warmup: int = 0,
12901295
second_input_keys: Optional[List[str]] = None,
12911296
ort_logs: bool = False,
1297+
quiet_input_sets: Optional[Set[str]] = None,
12921298
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
12931299
"""
12941300
Verifies that an onnx model produces the same
@@ -1308,6 +1314,7 @@ def validate_onnx_model(
13081314
to make sure the exported model supports dynamism, the value is
13091315
used as an increment added to the first set of inputs (added to dimensions)
13101316
:param ort_logs: triggers the logs for onnxruntime
1317+
:param quiet_input_sets: avoid raising an exception for these sets of inputs
13111318
:return: two dictionaries, one with some metrics,
13121319
another one with whatever the function produces
13131320
"""
@@ -1455,10 +1462,12 @@ def _mk(key, flavour=flavour):
14551462

14561463
# run ort
14571464
if verbose:
1458-
print("[validate_onnx_model] run session...")
1465+
print(f"[validate_onnx_model] run session on inputs 'inputs{suffix}'...")
1466+
if quiet_input_sets:
1467+
print(f"[validate_onnx_model] quiet_input_sets={quiet_input_sets}")
14591468

14601469
got = _quiet_or_not_quiet(
1461-
quiet,
1470+
quiet or (quiet_input_sets is not None and f"inputs{suffix}" in quiet_input_sets),
14621471
_mk(f"run_onnx_ort{suffix}"),
14631472
summary,
14641473
data,

0 commit comments

Comments
 (0)