Skip to content

Commit 54afb59

Browse files
authored
checks it is working with the latest onnxruntime (#234)
* checks it is working with the latest onnxruntime * remove unnecessary mentionned dependencies * fix a couple of unittests * fix issues * fix dump folder name
1 parent 2aeb3ff commit 54afb59

File tree

5 files changed

+29
-9
lines changed

5 files changed

+29
-9
lines changed

_unittests/ut_xrun_doc/test_check_ort_float16.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ExtTestCase,
1111
ignore_warnings,
1212
requires_cuda,
13+
requires_onnxruntime,
1314
)
1415

1516

@@ -130,6 +131,7 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):
130131

131132
@requires_cuda()
132133
@ignore_warnings(DeprecationWarning)
134+
@requires_onnxruntime("1.23")
133135
def test_scatterels_cuda(self):
134136
default_value = [
135137
"Cast",
@@ -143,6 +145,10 @@ def test_scatterels_cuda(self):
143145
(np.float16, "none"): default_value,
144146
(np.float32, "add"): default_value,
145147
(np.float16, "add"): default_value,
148+
(np.float32, "min"): default_value,
149+
(np.float16, "min"): default_value,
150+
(np.float32, "max"): default_value,
151+
(np.float16, "max"): default_value,
146152
}
147153
for opset, dtype, reduction in itertools.product(
148154
[16, 18], [np.float32, np.float16], ["none", "add", "min", "max"]
@@ -185,14 +191,14 @@ def test_scatternd_cuda(self):
185191
)
186192

187193
@ignore_warnings(DeprecationWarning)
194+
@requires_onnxruntime("1.23")
188195
def test_scatterels_cpu(self):
189196
default_value = [
190197
"Cast",
191198
"ScatterElements",
192199
"Sub",
193200
]
194201
default_value_16 = [
195-
"Cast",
196202
"Cast",
197203
"ScatterElements",
198204
"Cast",
@@ -218,14 +224,14 @@ def test_scatterels_cpu(self):
218224
)
219225

220226
@ignore_warnings(DeprecationWarning)
227+
@requires_onnxruntime("1.23")
221228
def test_scatternd_cpu(self):
222229
default_value = [
223230
"Cast",
224231
"ScatterND",
225232
"Sub",
226233
]
227234
default_value_16 = [
228-
"Cast",
229235
"Cast",
230236
"ScatterND",
231237
"Cast",

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,17 @@ def torch_export_patches(
254254
may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
255255
It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
256256
"""
257+
if verbose:
258+
print(f"[torch_export_patches] patch_sympy={patch_sympy!r}")
259+
print(f" . patch_torch={patch_torch!r}")
260+
print(f" . patch_transformers={patch_transformers!r}")
261+
print(f" . patch_diffusers={patch_diffusers!r}")
262+
print(f" . catch_constraints={catch_constraints!r}")
263+
print(f" . stop_if_static={stop_if_static!r}")
264+
print(f" . patch={patch!r}")
265+
print(f" . custom_patches={custom_patches!r}")
266+
print(f"[torch_export_patches] dump_rewriting={dump_rewriting!r}")
267+
257268
if rewrite:
258269
from .patch_module import torch_export_rewrite
259270

onnx_diagnostic/torch_models/validate.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def _make_folder_name(
112112
device: Optional[Union[str, torch.device]] = None,
113113
subfolder: Optional[str] = None,
114114
opset: Optional[int] = None,
115+
drop_inputs: Optional[List[str]] = None,
115116
) -> str:
116117
"Creates a filename unique based on the given options."
117118
els = [model_id.replace("/", "_")]
@@ -137,6 +138,9 @@ def _make_folder_name(
137138
els.append(sdev)
138139
if opset is not None:
139140
els.append(f"op{opset}")
141+
if drop_inputs:
142+
ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
143+
els.append(f"I-{ii.upper()}")
140144
return "-".join(els)
141145

142146

@@ -394,12 +398,9 @@ def validate_model(
394398
same_as_pretrained=same_as_pretrained,
395399
use_pretrained=use_pretrained,
396400
)
401+
default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
397402
if isinstance(patch, bool):
398-
patch_kwargs = (
399-
dict(patch_transformers=True, patch_diffusers=True, patch=True)
400-
if patch
401-
else dict(patch=False)
402-
)
403+
patch_kwargs = default_patch if patch else dict(patch=False)
403404
elif isinstance(patch, str):
404405
patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420
405406
else:
@@ -408,6 +409,8 @@ def validate_model(
408409
if "patch" not in patch_kwargs:
409410
if any(patch_kwargs.values()):
410411
patch_kwargs["patch"] = True
412+
elif len(patch) == 1 and patch.get("patch", False):
413+
patch_kwargs.update(default_patch)
411414

412415
assert not rewrite or patch_kwargs.get("patch", False), (
413416
f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
@@ -450,6 +453,7 @@ def validate_model(
450453
device=device,
451454
subfolder=subfolder,
452455
opset=opset,
456+
drop_inputs=drop_inputs,
453457
)
454458
dump_folder = os.path.join(dump_folder, folder_name)
455459
if not os.path.exists(dump_folder):

requirements-dev.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ huggingface_hub
77
matplotlib
88
onnx-array-api>=0.3.1
99
onnx
10-
git+https://github.com/onnx/ir-py.git
1110
onnxscript
1211
openpyxl
1312
packaging

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
numpy
22
onnx>=1.16.0
3-
onnxruntime>=1.21
3+
onnxruntime>=1.23
44
optree
55
torch>=2.8
66
torch_geometric

0 commit comments

Comments
 (0)