
Commit 0e8155c

Add example around sym_max (#47)
* example
* fix example
* fix batch size

1 parent 4d93786 · commit 0e8155c

File tree

5 files changed: +194 −8 lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python: ['3.11', '3.12']
-        transformers: ['4.48.3', '4.50.3', 'main']
+        transformers: ['4.48.3', '4.51.1', 'main']
         torch: ['2.6', 'main']

     steps:
_doc/recipes/plot_dynamic_shapes_max.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
"""
Cannot export ``torch.sym_max(x.shape[0], y.shape[0])``
=======================================================

This is related to the following issue:
`Cannot export torch.sym_max(x.shape[0], y.shape[0])
<https://github.com/pytorch/pytorch/issues/150851>`_.

The algorithm that tries to automatically infer shapes after every operator
in the exported program is very aggressive. Here is a case where it makes
a wrong decision and how to get around it.

Wrong Model
+++++++++++
"""

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    def forward(self, x, y, fact):
        s1 = max(x.shape[0], y.shape[0])
        s2 = max(x.shape[1], y.shape[1])
        # Shapes cannot be known here.
        z = torch.zeros((s1, s2), dtype=x.dtype)
        z[: x.shape[0], : x.shape[1]] = x
        z[: y.shape[0], : y.shape[1]] += y
        return z * fact


model = Model()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
fact = torch.tensor([[1, 2, 3]], dtype=x.dtype)
z = model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
# %%
# Export
# ++++++
DYN = torch.export.Dim.DYNAMIC

ep = torch.export.export(
    model, (x, y, fact), dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN})
)
print(ep)

# %%
# But does it really work? Let's print the shapes.
model_ep = ep.module()
ez = model_ep(x, y, fact)
print("case 1:", z.shape, ez.shape)

# %%
# Case with different shapes.

x = torch.arange(4).reshape((2, 2))
y = torch.arange(9).reshape((3, 3))
try:
    ez = model_ep(x, y, fact)
    print("case 2:", model(x, y, fact).shape, ez.shape)
except Exception as e:
    print("case 2 failed:", e)
# %%
# It does not even compute. The exported program does not produce the
# correct shape.
#
# Rewritten Model
# +++++++++++++++
#
# ``max`` does not get captured, :func:`torch.sym_max` is no better
# (see the sketch below), and :func:`torch.max` only works on tensors.
# Nothing really works. We use a trick to introduce a new dimension the
# shape inference algorithm cannot know about. This requires hiding the
# failing logic in a custom operator.
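# %%
# A minimal sketch of the :func:`torch.sym_max` variant mentioned above.
# This class is illustrative only; depending on the PyTorch version,
# the export is expected to fail or to over-specialize the dimensions,
# which is what the issue linked above reports.


class SymMaxModel(torch.nn.Module):
    def forward(self, x, y, fact):
        s1 = torch.sym_max(x.shape[0], y.shape[0])
        s2 = torch.sym_max(x.shape[1], y.shape[1])
        z = torch.zeros((s1, s2), dtype=x.dtype)
        z[: x.shape[0], : x.shape[1]] = x
        z[: y.shape[0], : y.shape[1]] += y
        return z * fact


try:
    torch.export.export(
        SymMaxModel(),
        (x, y, fact),
        dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN}),
    )
    print("sym_max variant exported")
except Exception as e:
    print("sym_max variant failed:", e)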


def make_undefined_dimension(i: int) -> torch.SymInt:
    """
    Used in a custom op when a new dimension must be introduced to bypass
    some verification. The following function creates a dummy output
    with a dimension based on the content.

    .. code-block:: python

        def symbolic_shape(x, y):
            return torch.empty(
                x.shape[0],
                make_undefined_dimension(min(x.shape[1], y[0])),
            )
    """
    t = torch.ones((i * 2,))
    t[:i] = 0
    res = torch.nonzero(t).shape[0]
    return res
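# %%
# A quick sanity check: in eager mode the function simply returns ``i``.
# The detour through :func:`torch.nonzero` only matters at trace time,
# where the result becomes a data-dependent, unbacked symbolic integer
# that the exporter cannot relate to the inputs.

print("make_undefined_dimension(3) =", make_undefined_dimension(3))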
def copy_max_dimensions(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Eager implementation: the output takes the maximum of each dimension.
    shape = torch.max(torch.tensor(x.shape), torch.tensor(y.shape))
    z = torch.zeros(tuple(shape), dtype=x.dtype)
    z[0 : x.shape[0], 0 : x.shape[1]] = x[0 : x.shape[0], 0 : x.shape[1]]
    z[0 : y.shape[0], 0 : y.shape[1]] += y[0 : y.shape[0], 0 : y.shape[1]]
    return z


def symbolic_shape(x, y):
    # Shape function: every output dimension is made undefined so that
    # the exporter does not try to infer it.
    return torch.empty(
        tuple(
            make_undefined_dimension(max(x.shape[i], y.shape[i]))
            for i in range(len(x.shape))
        ),
        dtype=x.dtype,
    )
def register(fct, fct_shape, namespace, fname):
    # Registers the eager implementation and the shape function
    # as a custom operator.
    schema_str = torch.library.infer_schema(fct, mutates_args=())
    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
    custom_def.register_kernel("cpu")(fct)
    custom_def._abstract_fn = fct_shape


register(
    copy_max_dimensions, lambda x, y: symbolic_shape(x, y), "mylib", "copy_max_dimensions"
)
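# %%
# A quick sanity check that the operator is reachable under
# ``torch.ops.mylib``. Note that ``_abstract_fn`` is a private attribute;
# recent versions of :mod:`torch.library` also expose
# :func:`torch.library.register_fake` as a public way to attach the
# shape function.

print("eager check:", torch.ops.mylib.copy_max_dimensions(x, y).shape)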
# %%
# Now everything is registered. Let's rewrite the model.


class RewrittenModel(torch.nn.Module):
    def forward(self, x, y, fact):
        z = torch.ops.mylib.copy_max_dimensions(x, y)
        return z * fact


# %%
# And check that it works.

rewritten_model = RewrittenModel()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
z = rewritten_model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
# %%
# Export again
# ++++++++++++

ep = torch.export.export(
    rewritten_model,
    (x, y, fact),
    dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN}),
)
print(ep)

# %%
# We check that it works.

model_ep = ep.module()
ez = model_ep(x, y, fact)
print("case 1:", z.shape, ez.shape)

x = torch.arange(4).reshape((2, 2))
y = torch.arange(9).reshape((3, 3))
try:
    ez = model_ep(x, y, fact)
    print("case 2:", rewritten_model(x, y, fact).shape, ez.shape)
except Exception as e:
    print("case 2 failed:", e)
# %%
# Final check with very different dimensions
# ++++++++++++++++++++++++++++++++++++++++++

x = torch.arange(6 * 8).reshape((6, 8))
y = torch.arange(10 * 4).reshape((10, 4)) * 10
fact = torch.arange(8).reshape((1, -1))

print("final case:", rewritten_model(x, y, fact).shape, model_ep(x, y, fact).shape)

# %%
# This is not perfect: we do get an exported program, but some of the logic
# is hidden in a custom operator.


doc.plot_legend("dynamic shapes\nworkaround\nmax(d1, d2)", "dynamic shapes", "yellow")

_doc/recipes/plot_dynamic_shapes_nonzero.py

Lines changed: 2 additions & 2 deletions

@@ -2,7 +2,7 @@
 Half certain nonzero
 ====================

-:func:`torch.nonzero` returns the indices or the first zero found
+:func:`torch.nonzero` returns the indices of the first zero found
 in a tensor. The output shape is unknown in the generic case
 but... If you have a 2D tensor with at least a nonzero value
 in every row, you can guess the dimension. But :func:`torch.export.export`
@@ -49,7 +49,7 @@ def forward(self, x):
 # ++++++

 DYN = torch.export.Dim.DYNAMIC
-ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
+ep = torch.export.export(model, (x,), dynamic_shapes=(({0: DYN, 1: DYN}),))
 print(ep)
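For context, a small sketch of why the output shape of :func:`torch.nonzero`
is data-dependent (it returns the indices of all nonzero values):

    import torch

    x = torch.tensor([[0, 1, 2], [3, 0, 4]])
    print(torch.nonzero(x))        # indices of the 4 nonzero values
    print(torch.nonzero(x).shape)  # torch.Size([4, 2]), unknown before running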

_doc/recipes/plot_dynamic_shapes_python_int.py

Lines changed: 4 additions & 4 deletions

@@ -1,5 +1,5 @@
 """
-Do not use python int with dynamic shape
+Do not use python int with dynamic shapes
 =========================================

 :func:`torch.export.export` uses :class:`torch.SymInt` to operate on shapes and
@@ -36,7 +36,7 @@ def forward(self, x):
 # ++++++

 DYN = torch.export.Dim.DYNAMIC
-ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
+ep = torch.export.export(model, (x,), dynamic_shapes=(({0: DYN, 1: DYN}),))
 print(ep)

 # %%
@@ -65,7 +65,7 @@ def forward(self, x):
 # Export
 # ++++++

-ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=((DYN, DYN),))
+ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
 print(ep)

@@ -79,7 +79,7 @@ def forward(self, x):


 with bypass_export_some_errors(stop_if_static=True):
-    ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
+    ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
 print(ep)

 # %%
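For reference, a minimal sketch of the two accepted spellings: for each input,
``dynamic_shapes`` takes either a tuple of dimensions or a dict mapping an
axis index to a dimension; the dict form used above makes the axis explicit.

    import torch

    DYN = torch.export.Dim.DYNAMIC

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    x = torch.ones((4, 5))
    ep = torch.export.export(M(), (x,), dynamic_shapes=({0: DYN, 1: DYN},))
    print(ep)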

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 1 addition & 1 deletion

@@ -114,7 +114,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
         DYN = torch.export.Dim.DYNAMIC

         with bypass_export_some_errors():
-            cache = MambaCache(_config(), max_batch_size=1, device="cpu")
+            cache = MambaCache(_config(), max_batch_size=2, device="cpu")
             torch.export.export(
                 Model(),
                 (x, cache),
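A plausible reason for the new batch size (an assumption, not stated in the
commit): :func:`torch.export.export` specializes dimensions whose example
value is 0 or 1, so a cache built with ``max_batch_size=1`` cannot keep a
dynamic batch dimension. A minimal sketch of that behavior:

    import torch

    DYN = torch.export.Dim.DYNAMIC

    class M(torch.nn.Module):
        def forward(self, x):
            return x * 2

    try:
        # an example batch of size 1 is specialized; marking it dynamic may fail
        torch.export.export(M(), (torch.ones((1, 4)),), dynamic_shapes=({0: DYN},))
        print("exported")
    except Exception as e:
        print("failed:", e)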
