Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions _unittests/ut_xrun_doc/test_check_ort_float16.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ExtTestCase,
ignore_warnings,
requires_cuda,
requires_onnxruntime,
)


Expand Down Expand Up @@ -130,6 +131,7 @@ def common_scatter(self, opset, providers, dtype, reduction, expected_names):

@requires_cuda()
@ignore_warnings(DeprecationWarning)
@requires_onnxruntime("1.23")
def test_scatterels_cuda(self):
default_value = [
"Cast",
Expand All @@ -143,6 +145,10 @@ def test_scatterels_cuda(self):
(np.float16, "none"): default_value,
(np.float32, "add"): default_value,
(np.float16, "add"): default_value,
(np.float32, "min"): default_value,
(np.float16, "min"): default_value,
(np.float32, "max"): default_value,
(np.float16, "max"): default_value,
}
for opset, dtype, reduction in itertools.product(
[16, 18], [np.float32, np.float16], ["none", "add", "min", "max"]
Expand Down Expand Up @@ -185,14 +191,14 @@ def test_scatternd_cuda(self):
)

@ignore_warnings(DeprecationWarning)
@requires_onnxruntime("1.23")
def test_scatterels_cpu(self):
default_value = [
"Cast",
"ScatterElements",
"Sub",
]
default_value_16 = [
"Cast",
"Cast",
"ScatterElements",
"Cast",
Expand All @@ -218,14 +224,14 @@ def test_scatterels_cpu(self):
)

@ignore_warnings(DeprecationWarning)
@requires_onnxruntime("1.23")
def test_scatternd_cpu(self):
default_value = [
"Cast",
"ScatterND",
"Sub",
]
default_value_16 = [
"Cast",
"Cast",
"ScatterND",
"Cast",
Expand Down
11 changes: 11 additions & 0 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,17 @@ def torch_export_patches(
may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
"""
if verbose:
print(f"[torch_export_patches] patch_sympy={patch_sympy!r}")
print(f" . patch_torch={patch_torch!r}")
print(f" . patch_transformers={patch_transformers!r}")
print(f" . patch_diffusers={patch_diffusers!r}")
print(f" . catch_constraints={catch_constraints!r}")
print(f" . stop_if_static={stop_if_static!r}")
print(f" . patch={patch!r}")
print(f" . custom_patches={custom_patches!r}")
print(f"[torch_export_patches] dump_rewriting={dump_rewriting!r}")

if rewrite:
from .patch_module import torch_export_rewrite

Expand Down
14 changes: 9 additions & 5 deletions onnx_diagnostic/torch_models/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def _make_folder_name(
device: Optional[Union[str, torch.device]] = None,
subfolder: Optional[str] = None,
opset: Optional[int] = None,
drop_inputs: Optional[List[str]] = None,
) -> str:
"Creates a filename unique based on the given options."
els = [model_id.replace("/", "_")]
Expand All @@ -137,6 +138,9 @@ def _make_folder_name(
els.append(sdev)
if opset is not None:
els.append(f"op{opset}")
if drop_inputs:
ii = "-".join(f"{s[0]}{s[-1]}" for s in drop_inputs)
els.append(f"I-{ii.upper()}")
return "-".join(els)


Expand Down Expand Up @@ -394,12 +398,9 @@ def validate_model(
same_as_pretrained=same_as_pretrained,
use_pretrained=use_pretrained,
)
default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
if isinstance(patch, bool):
patch_kwargs = (
dict(patch_transformers=True, patch_diffusers=True, patch=True)
if patch
else dict(patch=False)
)
patch_kwargs = default_patch if patch else dict(patch=False)
elif isinstance(patch, str):
patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420
else:
Expand All @@ -408,6 +409,8 @@ def validate_model(
if "patch" not in patch_kwargs:
if any(patch_kwargs.values()):
patch_kwargs["patch"] = True
elif len(patch) == 1 and patch.get("patch", False):
patch_kwargs.update(default_patch)

assert not rewrite or patch_kwargs.get("patch", False), (
f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "
Expand Down Expand Up @@ -450,6 +453,7 @@ def validate_model(
device=device,
subfolder=subfolder,
opset=opset,
drop_inputs=drop_inputs,
)
dump_folder = os.path.join(dump_folder, folder_name)
if not os.path.exists(dump_folder):
Expand Down
1 change: 0 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ huggingface_hub
matplotlib
onnx-array-api>=0.3.1
onnx
git+https://github.com/onnx/ir-py.git
onnxscript
openpyxl
packaging
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy
onnx>=1.16.0
onnxruntime>=1.21
onnxruntime>=1.23
optree
torch>=2.8
torch_geometric
Loading