Skip to content

Commit 80e2ebf

Browse files
committed
fix issues
1 parent c2e416a commit 80e2ebf

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,17 @@ def torch_export_patches(
254254
may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
255255
It can be avoided by setting ``strict=False`` when call :func:`torch.export.export`.
256256
"""
257+
if verbose:
258+
print(f"[torch_export_patches] patch_sympy={patch_sympy!r}")
259+
print(f" . patch_torch={patch_torch!r}")
260+
print(f" . patch_transformers={patch_transformers!r}")
261+
print(f" . patch_diffusers={patch_diffusers!r}")
262+
print(f" . catch_constraints={catch_constraints!r}")
263+
print(f" . stop_if_static={stop_if_static!r}")
264+
print(f" . patch={patch!r}")
265+
print(f" . custom_patches={custom_patches!r}")
266+
print(f"[torch_export_patches] dump_rewriting={dump_rewriting!r}")
267+
257268
if rewrite:
258269
from .patch_module import torch_export_rewrite
259270

onnx_diagnostic/torch_models/validate.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -394,12 +394,9 @@ def validate_model(
394394
same_as_pretrained=same_as_pretrained,
395395
use_pretrained=use_pretrained,
396396
)
397+
default_patch = dict(patch_transformers=True, patch_diffusers=True, patch=True)
397398
if isinstance(patch, bool):
398-
patch_kwargs = (
399-
dict(patch_transformers=True, patch_diffusers=True, patch=True)
400-
if patch
401-
else dict(patch=False)
402-
)
399+
patch_kwargs = default_patch if patch else dict(patch=False)
403400
elif isinstance(patch, str):
404401
patch_kwargs = {"patch": True, **{p: True for p in patch.split(",")}} # noqa: C420
405402
else:
@@ -408,6 +405,8 @@ def validate_model(
408405
if "patch" not in patch_kwargs:
409406
if any(patch_kwargs.values()):
410407
patch_kwargs["patch"] = True
408+
elif len(patch) == 1 and patch.get("patch", False):
409+
patch_kwargs.update(default_patch)
411410

412411
assert not rewrite or patch_kwargs.get("patch", False), (
413412
f"rewrite={rewrite}, patch={patch}, patch_kwargs={patch_kwargs} "

0 commit comments

Comments
 (0)