Commit 94d69b8

example

1 parent 4d93786 commit 94d69b8

File tree: 2 files changed, +185 -1 lines changed

Lines changed: 184 additions & 0 deletions

@@ -0,0 +1,184 @@
"""
Cannot export ``torch.sym_max(x.shape[0], y.shape[0])``
=======================================================

This is related to the following issue:
`Cannot export torch.sym_max(x.shape[0], y.shape[0])
<https://github.com/pytorch/pytorch/issues/150851>`_.

The algorithm that automatically infers shapes after every operator
in the exported program is very aggressive. Here is a case where
it takes a wrong decision and how to get around it.

Wrong Model
+++++++++++
"""

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    def forward(self, x, y, fact):
        s1 = max(x.shape[0], y.shape[0])
        s2 = max(x.shape[1], y.shape[1])
        # Shapes cannot be known here.
        z = torch.zeros((s1, s2), dtype=x.dtype)
        z[: x.shape[0], : x.shape[1]] = x
        z[: y.shape[0], : y.shape[1]] += y
        return z * fact


model = Model()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
fact = torch.tensor([[1, 2, 3]], dtype=x.dtype)
z = model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")

# %%
# Export
# ++++++
DYN = torch.export.Dim.DYNAMIC

ep = torch.export.export(
    model, (x, y, fact), dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN})
)
print(ep)

# %%
# But does it really work? Let's print the shapes.
model_ep = ep.module()
ez = model_ep(x, y, fact)
print("case 1:", z.shape, ez.shape)

# %%
# Case with different shapes.

x = torch.arange(4).reshape((2, 2))
y = torch.arange(9).reshape((3, 3))
try:
    ez = model_ep(x, y, fact)
    print("case 2:", model(x, y, fact).shape, ez.shape)
except Exception as e:
    print("case 2 failed:", e)

# %%
# It does not even compute. The exported program does not get the correct shape.
#
# Rewritten Model
# +++++++++++++++
#
# ``max`` does not get captured, :func:`torch.sym_max` is no better,
# and :func:`torch.max` only works on tensors. Nothing really works.
# We use a trick to introduce a new shape that the shape inference algorithm
# cannot know. This requires hiding the failing logic in a custom operator.

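# %%
# For illustration only (the exact outcome depends on the installed PyTorch
# version), here is the direct attempt with :func:`torch.sym_max` that the
# issue referenced at the top describes. ``NaiveModel`` is a sketch, not part
# of the workaround; it only shows why another route is needed.


class NaiveModel(torch.nn.Module):
    def forward(self, x, y, fact):
        s1 = torch.sym_max(x.shape[0], y.shape[0])
        s2 = torch.sym_max(x.shape[1], y.shape[1])
        z = torch.zeros((s1, s2), dtype=x.dtype)
        z[: x.shape[0], : x.shape[1]] = x
        z[: y.shape[0], : y.shape[1]] += y
        return z * fact


try:
    torch.export.export(
        NaiveModel(),
        (x, y, fact),
        dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN}),
    )
    print("naive attempt: export succeeded")
except Exception as e:
    print("naive attempt: export failed:", e)
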
def make_undefined_dimension(i: int) -> torch.SymInt:
    """
    Used in a custom op when a new dimension must be introduced to bypass
    some verification. The following function creates a dummy output
    with a dimension based on the content.

    .. code-block:: python

        def symbolic_shape(x, y):
            return torch.empty(
                x.shape[0],
                make_undefined_dimension(min(x.shape[1], y[0])),
            )
    """
    t = torch.ones((i * 2,))
    t[:i] = 0
    res = torch.nonzero(t).shape[0]
    return res

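# %%
# In eager mode this helper simply returns ``i`` (the number of non-zero
# elements left in the dummy tensor). The point of the detour through
# :func:`torch.nonzero`, whose output shape is data-dependent, is that shape
# inference during export cannot evaluate the result and has to create a new
# symbolic dimension instead. A quick eager check, for illustration only:

print("make_undefined_dimension(5) =", make_undefined_dimension(5))
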
def copy_max_dimensions(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    shape = torch.max(torch.tensor(x.shape), torch.tensor(y.shape))
    z = torch.zeros(tuple(shape), dtype=x.dtype)
    z[0 : x.shape[0], 0 : x.shape[1]] = x[0 : x.shape[0], 0 : x.shape[1]]
    z[0 : y.shape[0], 0 : y.shape[1]] += y[0 : y.shape[0], 0 : y.shape[1]]
    return z


def symbolic_shape(x, y):
    return torch.empty(
        tuple(
            make_undefined_dimension(max(x.shape[i], y.shape[i])) for i in range(len(x.shape))
        ),
        dtype=x.dtype,
    )


def register(fct, fct_shape, namespace, fname):
    # Infer the operator schema from the type annotations of ``fct``.
    schema_str = torch.library.infer_schema(fct, mutates_args=())
    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
    # Eager implementation on CPU.
    custom_def.register_kernel("cpu")(fct)
    # Abstract (fake) implementation used by export to infer output shapes.
    custom_def._abstract_fn = fct_shape


register(
    copy_max_dimensions, lambda x, y: symbolic_shape(x, y), "mylib", "copy_max_dimensions"
)

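# %%
# A note on the registration above: it instantiates
# :class:`torch.library.CustomOpDef` directly and assigns ``_abstract_fn``,
# which is a private attribute. On recent PyTorch versions (2.4 or later, as
# an assumption), the public decorator API should achieve the same result;
# ``copy_max_dimensions_alt`` is only an illustrative name, not part of the
# recipe.
#
# .. code-block:: python
#
#     @torch.library.custom_op("mylib::copy_max_dimensions_alt", mutates_args=())
#     def copy_max_dimensions_alt(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
#         return copy_max_dimensions(x, y)
#
#     # The fake implementation plays the role of ``_abstract_fn`` above.
#     copy_max_dimensions_alt.register_fake(symbolic_shape)
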
# %%
# Now everything is registered. Let's rewrite the model.

class RewrittenModel(torch.nn.Module):
    def forward(self, x, y, fact):
        z = torch.ops.mylib.copy_max_dimensions(x, y)
        return z * fact

# %%
# And check it works.

rewritten_model = RewrittenModel()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
z = rewritten_model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")

# %%
# Export again
# ++++++++++++

ep = torch.export.export(
    rewritten_model,
    (x, y, fact),
    dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN}),
)
print(ep)

# %%
# We check it works.

model_ep = ep.module()
ez = model_ep(x, y, fact)
print("case 1:", z.shape, ez.shape)

x = torch.arange(4).reshape((2, 2))
y = torch.arange(9).reshape((3, 3))
try:
    ez = model_ep(x, y, fact)
    print("case 2:", rewritten_model(x, y, fact).shape, ez.shape)
except Exception as e:
    print("case 2 failed:", e)

# %%
# Final check on very different dimensions
# +++++++++++++++++++++++++++++++++++++++++

x = torch.arange(6 * 8).reshape((6, 8))
y = torch.arange(10 * 4).reshape((10, 4)) * 10
fact = torch.arange(8).reshape((1, -1))

print("final case:", rewritten_model(x, y, fact).shape, model_ep(x, y, fact).shape)

# %%
# This is not perfect: we do get an exported program, but some of the logic
# is hidden in a custom operator.


doc.plot_legend("dynamic shapes\nworkaround\nmax(d1, d2)", "dynamic shapes", "yellow")

_doc/recipes/plot_dynamic_shapes_python_int.py

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 """
-Do not use python int with dynamic shape
+Do not use python int with dynamic shapes
 =========================================

 :func:`torch.export.export` uses :class:`torch.SymInt` to operate on shapes and
