Skip to content

Commit 30ef056

Browse files
committed
improve catching
1 parent abbcc6b commit 30ef056

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def bypass_export_some_errors(
211211
patch_torch: bool = True,
212212
patch_transformers: bool = False,
213213
catch_constraints: bool = True,
214-
stop_if_static: bool = False,
214+
stop_if_static: int = 0,
215215
verbose: int = 0,
216216
patch: bool = True,
217217
) -> Callable:
@@ -227,7 +227,9 @@ def bypass_export_some_errors(
227227
can be put to stop at that stage.
228228
:param stop_if_static: see example :ref:`l-plot-export-locale-issue`,
229229
to stop the export as soon as an issue is detected with dynamic shapes
230-
and show a stack trace indicating the exact location of the issue
230+
and show a stack trace indicating the exact location of the issue,
231+
``if stop_if_static > 1``, more methods are replace to catch more
232+
issues
231233
:param patch: if False, disable all patches except the registration of
232234
serialization function
233235
:param verbose: to show which patches is applied
@@ -375,17 +377,24 @@ def bypass_export_some_errors(
375377
)
376378

377379
if stop_if_static:
380+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
381+
from .patches.patch_torch import patched_ShapeEnv
382+
378383
if verbose:
379384
print(
380385
"[bypass_export_some_errors] assert when a dynamic dimension turns static"
381386
)
382-
383-
from torch.fx.experimental.symbolic_shapes import ShapeEnv
384-
from .patches.patch_torch import patched_ShapeEnv
387+
print("[bypass_export_some_errors] replaces ShapeEnv._set_replacement")
385388

386389
f_shape_env__set_replacement = ShapeEnv._set_replacement
387390
ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
388391

392+
if stop_if_static > 1:
393+
if verbose:
394+
print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen")
395+
f_shape_env__check_frozen = ShapeEnv._check_frozen
396+
ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
397+
389398
####################
390399
# patch transformers
391400
####################
@@ -444,6 +453,10 @@ def bypass_export_some_errors(
444453
print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")
445454

446455
ShapeEnv._set_replacement = f_shape_env__set_replacement
456+
if stop_if_static > 1:
457+
if verbose:
458+
print("[bypass_export_some_errors] restored ShapeEnv._check_frozen")
459+
ShapeEnv._check_frozen = f_shape_env__check_frozen
447460

448461
if catch_constraints:
449462
# to catch or skip dynamic_shapes issues

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,16 @@ def patched__broadcast_shapes(*_shapes):
150150

151151
class patched_ShapeEnv:
152152

153+
def _check_frozen(
154+
self, expr: "sympy.Basic", concrete_val: "sympy.Basic" # noqa: F821
155+
) -> None:
156+
if self.frozen:
157+
self.counter["ignored_backward_guard"] += 1
158+
raise AssertionError(
159+
f"[patched_ShapeEnv] Ignored guard {expr} == {concrete_val}, "
160+
f"this could result in accuracy problems."
161+
)
162+
153163
def _set_replacement(
154164
self, a: "sympy.Symbol", tgt: "sympy.Expr", msg: str # noqa: F821
155165
) -> None:
@@ -314,7 +324,7 @@ def _set_replacement(
314324
# )
315325
# self.log.debug("SPECIALIZATION", stack_info=True)
316326
assert msg != "range_refined_to_singleton", (
317-
f"A dynamic dimension becomes static! "
327+
f"patched_ShapeEnv: A dynamic dimension becomes static! "
318328
f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
319329
)
320330
# log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)

0 commit comments

Comments
 (0)