Skip to content

Commit 63d3f93

Browse files
committed
anotehr example
1 parent 417e7d9 commit 63d3f93

File tree

5 files changed

+127
-1
lines changed

5 files changed

+127
-1
lines changed

_doc/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,12 +152,14 @@
152152
"within_subsection_order": "ExampleTitleSortKey",
153153
# errors
154154
"abort_on_example_error": True,
155+
"expected_failing_examples": ["examples/plot_export_locate_issue.py"],
155156
# recommendation
156157
"recommender": {"enable": True, "n_examples": 3, "min_df": 3, "max_df": 0.9},
157158
# ignore capture for matplotib axes
158159
"ignore_repr_types": "matplotlib\\.(text|axes)",
159160
# robubstness
160161
"reset_modules_order": "both",
162+
"reset_modules": ("matplotlib", "onnx_diagnostic.reset_torch_transformers"),
161163
}
162164

163165
if int(os.environ.get("UNITTEST_GOING", "0")):
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
.. _l-plot-export-locale-issue:
3+
4+
==================================================
5+
Find and fix an export issue due to dynamic shapes
6+
==================================================
7+
8+
9+
A model with an export issue
10+
============================
11+
12+
The following model implies the first dimension of x is equal to 1
13+
or equal to the number of element in the list ``ys``.
14+
It is not really dynamic. It looks obvious here but
15+
it is difficult to find deep inside a big model.
16+
"""
17+
18+
import torch
19+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
20+
21+
22+
class ModelWithIssue(torch.nn.Module):
23+
def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
24+
caty = torch.cat([y.unsqueeze(0) for y in ys], axis=0)
25+
z = x * caty
26+
return z
27+
28+
29+
inputs = (torch.rand(2, 3, 1), [torch.rand(3, 4), torch.rand(3, 4)])
30+
model = ModelWithIssue()
31+
model(*inputs)
32+
33+
34+
# %%
35+
# Let's export.
36+
37+
DYN = torch.export.Dim.DYNAMIC
38+
dyn_shapes = ({0: DYN, 1: DYN}, [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}])
39+
try:
40+
ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
41+
print(ep)
42+
except Exception as e:
43+
print("-- ERROR:")
44+
print(e)
45+
46+
# %%
47+
# The error shows:
48+
#
49+
# ::
50+
# Constraints violated (L['args'][0][0].size()[0])!
51+
# For more information, run with TORCH_LOGS="+dynamic".
52+
# - Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0])
53+
# are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).
54+
#
55+
# Where does it happens? That's a tricky question we need to answer.
56+
# The message is raised from
57+
# `torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement
58+
# <https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L6239>`_.
59+
# One way to find the exact location is to retrieve a stack trace
60+
# by inserting an assert such as the following:
61+
#
62+
# ::
63+
#
64+
# assert msg != "range_refined_to_singleton", (
65+
# f"A dynamic dimension becomes static! "
66+
# f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
67+
# )
68+
#
69+
# Stop when a dynamic dimension turns static
70+
# ==========================================
71+
#
72+
#
73+
74+
with bypass_export_some_errors(stop_if_static=True, verbose=1):
75+
torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
76+
77+
# The stack trace is quite long but the first line referring to this example
78+
# is the following one. It points out the line turing a dynamic dimension into
79+
# static.
80+
#
81+
# ::
82+
#
83+
# File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 25, in forward
84+
# z = x * caty

onnx_diagnostic/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,12 @@
55

66
__version__ = "0.3.0"
77
__author__ = "Xavier Dupré"
8+
9+
10+
def reset_torch_transformers(gallery_conf, fname):
11+
"Resets torch dynamo for :epkg:`sphinx-gallery`."
12+
import matplotlib.pyplot as plt
13+
import torch
14+
15+
plt.style.use("ggplot")
16+
torch._dynamo.reset()

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,11 @@ def _unregister(cls: type, verbose: int = 0):
145145
# torch >= 2.7
146146
torch.utils._pytree._deregister_pytree_node(cls)
147147
optree.unregister_pytree_node(cls, namespace="torch")
148+
if cls in torch.utils._pytree.SUPPORTED_NODES:
149+
import packaging.version as pv
150+
151+
if pv.Version(torch.__version__) < pv.Version("2.7.0"):
152+
del torch.utils._pytree.SUPPORTED_NODES[cls]
148153
assert cls not in torch.utils._pytree.SUPPORTED_NODES, (
149154
f"{cls} was not successful unregistered "
150155
f"from torch.utils._pytree.SUPPORTED_NODES="
@@ -190,6 +195,7 @@ def bypass_export_some_errors(
190195
patch_torch: bool = True,
191196
patch_transformers: bool = False,
192197
catch_constraints: bool = True,
198+
stop_if_static: bool = False,
193199
verbose: int = 0,
194200
patch: bool = True,
195201
) -> Callable:
@@ -203,8 +209,12 @@ def bypass_export_some_errors(
203209
as a result, some dynamic dimension may turn into static ones,
204210
the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``
205211
can be put to stop at that stage.
212+
:param stop_if_static: see example :ref:`l-plot-export-locale-issue`,
213+
to stop the export as soon as an issue is detected with dyanmic shapes
214+
and show a stack trace indicating the exact location of the issue
206215
:param patch: if False, disable all patches except the registration of
207216
serialization function
217+
:param verbose: to show which patches is applied
208218
209219
The list of available patches.
210220
@@ -348,6 +358,18 @@ def bypass_export_some_errors(
348358
)
349359
)
350360

361+
if stop_if_static:
362+
if verbose:
363+
print(
364+
"[bypass_export_some_errors] assert when a dynamic dimnension turns static"
365+
)
366+
367+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
368+
from .patches.patch_torch import patched_ShapeEnv
369+
370+
f_shape_env__set_replacement = ShapeEnv._set_replacement
371+
ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
372+
351373
####################
352374
# patch transformers
353375
####################
@@ -401,6 +423,12 @@ def bypass_export_some_errors(
401423
if verbose:
402424
print("[bypass_export_some_errors] restored pytorch functions")
403425

426+
if stop_if_static:
427+
if verbose:
428+
print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")
429+
430+
ShapeEnv._set_replacement = f_shape_env__set_replacement
431+
404432
if catch_constraints:
405433
# to catch or skip dynamic_shapes issues
406434
torch._export.non_strict_utils.produce_guards_and_solve_constraints = (

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,10 @@ def _set_replacement(
313313
# "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt
314314
# )
315315
# self.log.debug("SPECIALIZATION", stack_info=True)
316-
assert msg != "range_refined_to_singleton", f"{[a, tgt, msg, tgt_bound]}"
316+
assert msg != "range_refined_to_singleton", (
317+
f"A dynamic dimension becomes static! "
318+
f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
319+
)
317320
# log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound)
318321
self.replacements[a] = tgt
319322
# NB: the replacement may get refined, but the user will find the

0 commit comments

Comments
 (0)