|
| 1 | +""" |
| 2 | +.. _l-plot-export-locale-issue: |
| 3 | +
|
| 4 | +================================================== |
| 5 | +Find and fix an export issue due to dynamic shapes |
| 6 | +================================================== |
| 7 | +
|
| 8 | +LLMs must be exported with dynamic shapes and it is common that |
| 9 | +a static dimension turns into a static ones. The error message from |
| 10 | +:epkg:`pytorch` tells the user to define ``TORCH_LOGS="+dynamic"`` |
| 11 | +but it shows a very long list of messages where we need |
| 12 | +to find the string ``range_refined_to_singleton`` and that |
| 13 | +does not really indicates where it comes from. The example |
| 14 | +shows how to tweak pytorch to get that information until |
| 15 | +it gets better. |
| 16 | +
|
| 17 | +A model with an export issue |
| 18 | +============================ |
| 19 | +
|
| 20 | +The following model implies the first dimension of x is equal to 1 |
| 21 | +or equal to the number of element in the list ``ys``. |
| 22 | +It is not really dynamic. It looks obvious here but |
| 23 | +it is difficult to find deep inside a big model. |
| 24 | +""" |
| 25 | + |
| 26 | +import traceback |
| 27 | +import torch |
| 28 | +from onnx_diagnostic import doc |
| 29 | +from onnx_diagnostic.torch_export_patches import bypass_export_some_errors |
| 30 | + |
| 31 | + |
| 32 | +class ModelWithIssue(torch.nn.Module): |
| 33 | + def forward(self, x: torch.Tensor, ys: list[torch.Tensor]): |
| 34 | + caty = torch.cat([y.unsqueeze(0) for y in ys], axis=0) |
| 35 | + z = x * caty |
| 36 | + return z |
| 37 | + |
| 38 | + |
| 39 | +inputs = (torch.rand(2, 3, 1), [torch.rand(3, 4), torch.rand(3, 4)]) |
| 40 | +model = ModelWithIssue() |
| 41 | +model(*inputs) |
| 42 | + |
| 43 | + |
| 44 | +# %% |
| 45 | +# Let's export. |
| 46 | + |
| 47 | +DYN = torch.export.Dim.DYNAMIC |
| 48 | +dyn_shapes = ({0: DYN, 1: DYN}, [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}]) |
| 49 | +try: |
| 50 | + ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes) |
| 51 | + print(ep) |
| 52 | +except Exception as e: |
| 53 | + print("-- ERROR:") |
| 54 | + print(e) |
| 55 | + |
| 56 | +# %% |
| 57 | +# The error shows: |
| 58 | +# |
| 59 | +# .. code-block:: |
| 60 | +# |
| 61 | +# Constraints violated (L['args'][0][0].size()[0])! |
| 62 | +# For more information, run with TORCH_LOGS="+dynamic". |
| 63 | +# - Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0]) |
| 64 | +# are valid because L['args'][0][0].size()[0] was inferred to be a constant (2). |
| 65 | +# |
| 66 | +# Where does it happens? That's a tricky question we need to answer. |
| 67 | +# The message is raised from |
| 68 | +# `torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement |
| 69 | +# <https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L6239>`_. |
| 70 | +# One way to find the exact location is to retrieve a stack trace |
| 71 | +# by inserting an assert such as the following: |
| 72 | +# |
| 73 | +# .. code-block:: |
| 74 | +# |
| 75 | +# assert msg != "range_refined_to_singleton", ( |
| 76 | +# f"A dynamic dimension becomes static! " |
| 77 | +# f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}" |
| 78 | +# ) |
| 79 | +# |
| 80 | +# Stop when a dynamic dimension turns static |
| 81 | +# ========================================== |
| 82 | +# |
| 83 | +# We use :func:`bypass_export_some_errors |
| 84 | +# <onnx_diagnostic.torch_export_patches.bypass_export_some_errors>` |
| 85 | +# to replace torch implementation by a new one raising the exception |
| 86 | +# mentioned in previous section. |
| 87 | + |
| 88 | +with bypass_export_some_errors(stop_if_static=True, verbose=1): |
| 89 | + try: |
| 90 | + torch.export.export(model, inputs, dynamic_shapes=dyn_shapes) |
| 91 | + except AssertionError: |
| 92 | + print("-- It failed as excepted. Let's print the stack trace.") |
| 93 | + print(traceback.format_exc()) |
| 94 | + |
| 95 | +# The stack trace is quite long but the first line referring to this example |
| 96 | +# is the following one. It points out the line turing a dynamic dimension into |
| 97 | +# static. |
| 98 | +# |
| 99 | +# .. code-block:: |
| 100 | +# |
| 101 | +# File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 25, in forward |
| 102 | +# z = x * caty |
| 103 | + |
| 104 | + |
| 105 | +doc.plot_legend("was inferred to be a constant", "torch.export.export", "tomato") |
0 commit comments