Skip to content

Commit 23a6901

Browse files
authored
Add phi (#19)
* Add phi * import * add position_ids * fix position ids * add missing patched method * fix seq_length * payvhrd' * ci * torch * torch * fix contiguous * fix contiguous * fix ort_session * anotehr example * fix issues * documentation * doc * less ambitious goal
1 parent dec843b commit 23a6901

30 files changed

+785
-123
lines changed

.github/workflows/ci.yml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@ on:
1010

1111
jobs:
1212
run:
13-
name: tr-${{ matrix.transformers }}-ci ${{ matrix.os }}-${{ matrix.python }}
13+
name: to-${{ matrix.torch }}-tr-${{ matrix.transformers }}-ci ${{ matrix.os }}-${{ matrix.python }}
1414
runs-on: ${{ matrix.os }}
1515
strategy:
1616
matrix:
1717
os: [ubuntu-latest]
1818
python: ['3.11', '3.12']
1919
transformers: ['4.48', '4.50', 'main']
20+
torch: ['main']
2021

2122
steps:
2223
- uses: actions/checkout@v3
@@ -26,7 +27,13 @@ jobs:
2627
python-version: ${{ matrix.python }}
2728

2829
- name: Install pytorch
29-
run: python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
30+
run: |
31+
if [[ "${{ matrix.torch }}" == "main" ]]; then
32+
python -m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
33+
else
34+
echo "install torch==${{ matrix.torch }}"
35+
pip install torch==${{ matrix.torch }}
36+
fi
3037
3138
- name: Install transformers ${{ matrix.transformers }}
3239
run: |

README.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ Enlightening Examples
5050

5151
* `Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
5252
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_shapes_auto.html>`_
53+
* `Find and fix an export issue due to dynamic shapes
54+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_locate_issue.html>`_
5355
* `Export with DynamicCache and dynamic shapes
5456
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_cache.html>`_
5557
* `Steel method forward to guess the dynamic shapes (with Tiny-LLM)

_doc/api/torch_models/llms.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@ onnx_diagnostic.torch_models.llms
33
=================================
44

55
.. automodule:: onnx_diagnostic.torch_models.llms
6-
:members:
7-
:no-undoc-members:
6+
:members: get_phi2, get_tiny_llm

_doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@
158158
"ignore_repr_types": "matplotlib\\.(text|axes)",
159159
# robubstness
160160
"reset_modules_order": "both",
161+
"reset_modules": ("matplotlib", "onnx_diagnostic.doc.reset_torch_transformers"),
161162
}
162163

163164
if int(os.environ.get("UNITTEST_GOING", "0")):

_doc/examples/plot_export_cond.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"""
1414

1515
import torch
16+
from onnx_diagnostic import doc
1617

1718

1819
# %%
@@ -84,3 +85,8 @@ def neg(x):
8485

8586
ep = torch.export.export(model, (x,))
8687
print(ep.graph)
88+
89+
90+
# %%
91+
92+
doc.plot_legend("If -> torch.cond", "torch.export.export", "tomato")
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""
2+
.. _l-plot-export-locale-issue:
3+
4+
==================================================
5+
Find and fix an export issue due to dynamic shapes
6+
==================================================
7+
8+
LLMs must be exported with dynamic shapes and it is common that
9+
a static dimension turns into a static ones. The error message from
10+
:epkg:`pytorch` tells the user to define ``TORCH_LOGS="+dynamic"``
11+
but it shows a very long list of messages where we need
12+
to find the string ``range_refined_to_singleton`` and that
13+
does not really indicates where it comes from. The example
14+
shows how to tweak pytorch to get that information until
15+
it gets better.
16+
17+
A model with an export issue
18+
============================
19+
20+
The following model implies the first dimension of x is equal to 1
21+
or equal to the number of element in the list ``ys``.
22+
It is not really dynamic. It looks obvious here but
23+
it is difficult to find deep inside a big model.
24+
"""
25+
26+
import traceback
27+
import torch
28+
from onnx_diagnostic import doc
29+
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
30+
31+
32+
class ModelWithIssue(torch.nn.Module):
33+
def forward(self, x: torch.Tensor, ys: list[torch.Tensor]):
34+
caty = torch.cat([y.unsqueeze(0) for y in ys], axis=0)
35+
z = x * caty
36+
return z
37+
38+
39+
inputs = (torch.rand(2, 3, 1), [torch.rand(3, 4), torch.rand(3, 4)])
40+
model = ModelWithIssue()
41+
model(*inputs)
42+
43+
44+
# %%
45+
# Let's export.
46+
47+
DYN = torch.export.Dim.DYNAMIC
48+
dyn_shapes = ({0: DYN, 1: DYN}, [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}])
49+
try:
50+
ep = torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
51+
print(ep)
52+
except Exception as e:
53+
print("-- ERROR:")
54+
print(e)
55+
56+
# %%
57+
# The error shows:
58+
#
59+
# .. code-block::
60+
#
61+
# Constraints violated (L['args'][0][0].size()[0])!
62+
# For more information, run with TORCH_LOGS="+dynamic".
63+
# - Not all values of RelaxedUnspecConstraint(L['args'][0][0].size()[0])
64+
# are valid because L['args'][0][0].size()[0] was inferred to be a constant (2).
65+
#
66+
# Where does it happens? That's a tricky question we need to answer.
67+
# The message is raised from
68+
# `torch.fx.experimental.symbolic_shapes.ShapeEnv._set_replacement
69+
# <https://github.com/pytorch/pytorch/blob/main/torch/fx/experimental/symbolic_shapes.py#L6239>`_.
70+
# One way to find the exact location is to retrieve a stack trace
71+
# by inserting an assert such as the following:
72+
#
73+
# .. code-block::
74+
#
75+
# assert msg != "range_refined_to_singleton", (
76+
# f"A dynamic dimension becomes static! "
77+
# f"a={a!r}, tgt={tgt!r}, msg={msg!r}, tgt_bound={tgt_bound}"
78+
# )
79+
#
80+
# Stop when a dynamic dimension turns static
81+
# ==========================================
82+
#
83+
# We use :func:`bypass_export_some_errors
84+
# <onnx_diagnostic.torch_export_patches.bypass_export_some_errors>`
85+
# to replace torch implementation by a new one raising the exception
86+
# mentioned in previous section.
87+
88+
with bypass_export_some_errors(stop_if_static=True, verbose=1):
89+
try:
90+
torch.export.export(model, inputs, dynamic_shapes=dyn_shapes)
91+
except AssertionError:
92+
print("-- It failed as excepted. Let's print the stack trace.")
93+
print(traceback.format_exc())
94+
95+
# The stack trace is quite long but the first line referring to this example
96+
# is the following one. It points out the line turing a dynamic dimension into
97+
# static.
98+
#
99+
# .. code-block::
100+
#
101+
# File "onnx-diagnostic/_doc/examples/plot_export_locate_issue.py", line 25, in forward
102+
# z = x * caty
103+
104+
105+
doc.plot_legend("was inferred to be a constant", "torch.export.export", "tomato")

_doc/examples/plot_export_tiny_llm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import pprint
3131
import torch
3232
import transformers
33+
from onnx_diagnostic import doc
3334
from onnx_diagnostic.helpers import string_type
3435
from onnx_diagnostic.torch_models.llms import get_tiny_llm
3536

@@ -44,10 +45,11 @@
4445

4546
def _forward_(*args, _f=None, **kwargs):
4647
assert _f is not None
47-
if not torch.compiler.is_exporting():
48+
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
49+
# torch.compiler.is_exporting requires torch>=2.7
4850
print("<-", string_type((args, kwargs), with_shape=True, with_min_max=True))
4951
res = _f(*args, **kwargs)
50-
if not torch.compiler.is_exporting():
52+
if not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting():
5153
print("->", string_type((args, kwargs), with_shape=True, with_min_max=True))
5254
return res
5355

@@ -67,7 +69,8 @@ def _forward_(*args, _f=None, **kwargs):
6769
)
6870

6971
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
70-
print(generated_text)
72+
print("-- prompt", prompt)
73+
print("-- answer", generated_text)
7174

7275
# %%
7376
# Let's restore the forward as it was.
@@ -168,3 +171,5 @@ def _forward_(*args, _f=None, **kwargs):
168171
# %%
169172
# If you have any error, then look at example
170173
# :ref:`l-plot-tiny-llm-export-patched`.
174+
175+
doc.plot_legend("Tiny-LLM fails", "torch.export.export", "tomato")

_doc/examples/plot_export_tiny_llm_patched.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
import pprint
6767
import torch
6868
import transformers
69+
from onnx_diagnostic import doc
6970
from onnx_diagnostic.helpers import string_type
7071
from onnx_diagnostic.torch_export_patches.onnx_export_errors import bypass_export_some_errors
7172
from onnx_diagnostic.torch_models.llms import get_tiny_llm
@@ -122,3 +123,6 @@
122123
)
123124
print("It worked:")
124125
print(ep)
126+
127+
# %%
128+
doc.plot_legend("Tiny-LLM patched", "torch.export.export", "green")

_doc/examples/plot_export_with_dynamic_cache.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import pprint
2525
import torch
26+
from onnx_diagnostic import doc
2627
from onnx_diagnostic.cache_helpers import make_dynamic_cache
2728
from onnx_diagnostic.helpers import string_type
2829
from onnx_diagnostic.export import ModelInputs
@@ -221,3 +222,7 @@ def forward(self, cache, z):
221222
model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False
222223
)
223224
print(ep)
225+
226+
# %%
227+
228+
doc.plot_legend("dynamic shapes", "torch.export.export", "tomato")

_doc/examples/plot_export_with_dynamic_shapes_auto.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"""
1313

1414
import torch
15+
from onnx_diagnostic import doc
1516

1617

1718
class Model(torch.nn.Module):
@@ -57,7 +58,7 @@ def forward(self, x, y, z):
5758
},
5859
)
5960
print(ep)
60-
raise AssertionError("able to export this moel, please update the tutorial")
61+
raise AssertionError("able to export this model, please update the tutorial")
6162
except torch._dynamo.exc.UserError as e:
6263
print(f"unable to use Dim('dz') because {type(e)}, {e}")
6364

@@ -90,3 +91,7 @@ def forward(self, x, y, z):
9091
dynamic_shapes=({0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}, {0: AUTO, 1: AUTO}),
9192
)
9293
)
94+
95+
# %%
96+
97+
doc.plot_legend("dynamic shapes inferred", "torch.export.export", "tomato")

0 commit comments

Comments
 (0)