Commit 283b2cd

Adds unit test and example for broadcast rules related to dimension (#269)
* add unit test and example for broadcast_max
* documentation
1 parent bbce496 commit 283b2cd

File tree

4 files changed: +166 -1 lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.7.16
 ++++++

+* :pr:`269`: adds one unit test to track a patch fixing broadcast output shape
 * :pr:`267`: patches ``sdpa_attention_forward`` because of a control flow (``transformers>=5.0``)
 * :pr:`266`: makes ``patch_torch`` an integer in ``torch_export_patches`` to enable more patches

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
"""
Dynamic Shapes and Broadcasting
===============================

:func:`torch.export.export` makes stricter assumptions on dynamic shapes
than the generic case allows. Let's consider two tensors with only one
dimension each. ``x * y`` allows four configurations:

* ``shape(x) = (1,)`` and ``shape(y) = (1,)``
* ``shape(x) = (1,)`` and ``shape(y) = (p,)``
* ``shape(x) = (q,)`` and ``shape(y) = (1,)``
* ``shape(x) = (p,)`` and ``shape(y) = (p,)``

The expected shape for ``shape(x * y)`` is ``(max(p,q),)``.

Simple Case
+++++++++++

"""

import torch
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from onnx_diagnostic.torch_export_patches import torch_export_patches
from torch.fx import Tracer
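
# %%
# A quick eager-mode check of the rule above (a sketch added for illustration,
# not part of the original example): multiplying one-dimensional tensors
# broadcasts a dimension of 1, so ``shape(x * y) == (max(p, q),)``.

for sx, sy in [((1,), (1,)), ((1,), (4,)), ((3,), (1,)), ((4,), (4,))]:
    x, y = torch.zeros(sx), torch.zeros(sy)
    print(f"shape(x)={sx}, shape(y)={sy} -> shape(x * y)={tuple((x * y).shape)}")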


class Model(torch.nn.Module):
    def forward(self, x, y):
        return x * y


Dim = torch.export.Dim

ep = torch.export.export(
    Model(),
    (torch.tensor([2, 3], dtype=torch.float32), torch.tensor([2, 3], dtype=torch.float32)),
    dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
)
print(ep)
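
# %%
# An extra inspection (an added check, not in the original example): both
# placeholders carry the same symbolic dimension in their ``meta["val"]``.

placeholders = [n for n in ep.graph.nodes if n.op == "placeholder"]
print("placeholder shapes:", [(n.name, n.meta["val"].shape) for n in placeholders])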

# %%
# We see clearly that the export assumed that ``x`` and ``y`` had the same shape.
# No other configuration seemed to work at export time, including
# ``with torch.fx.experimental._config.patch(backed_size_oblivious=True):``
# with the shape of one tensor equal to ``(1,)``.

output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])

# %%
# The final shape is:

shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
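
# %%
# One more detail worth printing (an added illustration): the exported program
# also records the range constraints attached to each symbolic dimension.

print("range constraints:", ep.range_constraints)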

# %%
# Tracing
# +++++++
#
# Let's compare with what a simple tracing would do. Let's use :class:`torch.fx.Tracer`.

graph = Tracer().trace(Model())
print(graph)

# %%
output = [n for n in graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
print("The tracer leaves no trace:", output.args[0].__dict__)
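
# %%
# The same holds for node metadata (an added check for illustration): plain
# symbolic tracing leaves ``node.meta`` essentially empty, unlike the export above.

print("node metadata:", [(n.name, dict(n.meta)) for n in graph.nodes])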

# %%
# Shape propagation
# +++++++++++++++++

gm = torch.fx.GraphModule(Model(), graph)

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
# d1 = shape_env.create_unbacked_symint()
# d2 = shape_env.create_unbacked_symint()
fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
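
# %%
# For illustration (an added sketch): the ``ShapeEnv`` keeps track of the
# symbols it created while faking the inputs and their concrete hints.

print("symbols in the shape environment:", shape_env.var_to_val)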

# %%
# Handle Different Shapes
# +++++++++++++++++++++++

fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((1,), dtype=torch.float32), static_shapes=False)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)

# %%
# Conclusion
# ++++++++++
#
# We need to give the inputs distinct dimensions to get distinct symbolic names.

fake_inputs = fake_mode.from_tensor(
    torch.zeros((2,), dtype=torch.float32), static_shapes=False
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)
print("fake_inputs are ", fake_inputs)

# %%
try:
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
except Exception as e:
    print(e)

# %%
# By applying the patches:

with torch_export_patches():
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)

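# %%
# For completeness (an added print, mirroring the unit test in this commit):
# the propagated fake output now carries the broadcast dimension.

print("propagated output shape:", res.shape)
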
# %%
# This is what we want. Let's go back to :func:`torch.export.export`.

with torch_export_patches():
    ep = torch.export.export(
        Model(),
        (
            torch.tensor([2, 3], dtype=torch.float32),
            torch.tensor([2, 3, 4], dtype=torch.float32),
        ),
        dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
    )
print(ep)

# %%
output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
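
# %%
# A final sanity check (added here for illustration, mirroring the unit test
# in this commit): the output dimension is the maximum of the two inputs.

assert "Max" in str(shape), f"unexpected output shape {shape!r}"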

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 19 additions & 0 deletions
@@ -491,6 +491,25 @@ def forward(self, x, ind1, ind2):
         )
         self.assertEqualArray(expected, ep.module()(*inputs))
 
+    def test_broadcast_max(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                return x * y
+
+        Dim = torch.export.Dim
+        with torch_export_patches():
+            ep = torch.export.export(
+                Model(),
+                (
+                    torch.tensor([2, 3], dtype=torch.float32),
+                    torch.tensor([2, 3, 4], dtype=torch.float32),
+                ),
+                dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
+            )
+        output = [n for n in ep.graph.nodes if n.op == "output"]
+        shape = output[0].args[0][0].meta["val"].shape
+        self.assertEqual(str(shape), "torch.Size([Max(s17, s77)])")
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -150,7 +150,7 @@ select = [
 "_doc/notebooks/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
 "_doc/recipes/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
 "_scripts/compare_model_execution.py" = ["E402", "F401"]
-"_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
+"_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "RUF015", "SIM105", "SIM117"]
 "_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"]
 "onnx_diagnostic/export/__init__.py" = ["F401"]
 "onnx_diagnostic/helpers/__init__.py" = ["F401"]
