Skip to content

Commit cf08bd9

Browse files
committed
add unit test and example for boradcast_max
1 parent bbce496 commit cf08bd9

File tree

3 files changed

+167
-1
lines changed

3 files changed

+167
-1
lines changed
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
"""
2+
Dynamic Shapes and Broadcasting
3+
===============================
4+
5+
:func:`torch.exoprt.export` makes strict assumption on dynamic shapes
6+
to the generic case. Let's consider two tensors with only one dimension.
7+
``x * y`` allows four configurations:
8+
9+
* ``shape(x) = (1,)`` and ``shape(y) = (1,)``
10+
* ``shape(x) = (1,)`` and ``shape(y) = (p,)``
11+
* ``shape(x) = (q,)`` and ``shape(y) = (1,)``
12+
* ``shape(x) = (p,)`` and ``shape(y) = (p,)``
13+
14+
The expected shape for ``shape(x * y)`` is ``(max(p,q),)``.
15+
16+
Simple Case
17+
+++++++++++
18+
19+
"""
20+
21+
import torch
22+
from torch.fx.experimental.symbolic_shapes import ShapeEnv
23+
from torch._subclasses.fake_tensor import FakeTensorMode
24+
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
25+
from onnx_diagnostic.torch_export_patches import torch_export_patches
26+
from experimental_experiment.torch_interpreter.tracing import CustomTracer
27+
28+
29+
class Model(torch.nn.Module):
30+
def forward(self, x, y):
31+
return x * y
32+
33+
34+
Dim = torch.export.Dim
35+
36+
ep = torch.export.export(
37+
Model(),
38+
(torch.tensor([2, 3], dtype=torch.float32), torch.tensor([2, 3], dtype=torch.float32)),
39+
dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
40+
)
41+
print(ep)
42+
43+
# %%
44+
# We see clearly that the export assummed that ``x`` ad ``y`` had the same shape.
45+
# No other configuration seemed to work at export time,
46+
# including ``with torch.fx.experimental._config.patch(backed_size_oblivious=True):``
47+
# the shape of one tensor equal to ``(1,)``.
48+
49+
output = [n for n in ep.graph.nodes if n.op == "output"][0]
50+
print("output is ", output.name, " arg is", output.args[0])
51+
52+
# %%
53+
# The final shape is:
54+
55+
shape = output.args[0][0].meta["val"].shape
56+
print("output shape is ", shape)
57+
58+
# %%
59+
# Tracing
60+
# +++++++
61+
#
62+
# Let's compare with what a simple tracing would do. :class:`CustomTracer
63+
# <experimental_experiment.torch_interpreter.tracing.CustomTracer>`
64+
# is a customize tracer built on top of :class:`torch.fx.Tracer`.
65+
66+
graph = CustomTracer().trace(Model())
67+
print(graph)
68+
69+
# %%
70+
output = [n for n in graph.nodes if n.op == "output"][0]
71+
print("output is ", output.name, " arg is", output.args[0])
72+
print("The tracer leaves no trace:", output.args[0].__dict__)
73+
74+
# %%
75+
# Shape propagation
76+
# +++++++++++++++++
77+
78+
gm = torch.fx.GraphModule(Model(), graph)
79+
80+
shape_env = ShapeEnv()
81+
fake_mode = FakeTensorMode(shape_env=shape_env)
82+
# d1 = shape_env.create_unbacked_symint()
83+
# d2 = shape_env.create_unbacked_symint()
84+
fake_inputs = fake_mode.from_tensor(
85+
torch.zeros((2,), dtype=torch.float32), static_shapes=False
86+
), fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False)
87+
88+
print("fake_inputs are ", fake_inputs)
89+
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
90+
print("output is", res)
91+
92+
# %%
93+
# Handle Different Shapes
94+
# +++++++++++++++++++++++
95+
96+
fake_inputs = fake_mode.from_tensor(
97+
torch.zeros((2,), dtype=torch.float32), static_shapes=False
98+
), fake_mode.from_tensor(torch.zeros((1,), dtype=torch.float32), static_shapes=False)
99+
100+
print("fake_inputs are ", fake_inputs)
101+
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
102+
print("output is", res)
103+
104+
# %%
105+
# Conclusion
106+
# ++++++++++
107+
#
108+
# We need to give distinct dimensions to get distinct names.
109+
110+
fake_inputs = fake_mode.from_tensor(
111+
torch.zeros((2,), dtype=torch.float32), static_shapes=False
112+
), fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False)
113+
print("fake_inputs are ", fake_inputs)
114+
115+
116+
# %%
117+
try:
118+
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
119+
except Exception as e:
120+
print(e)
121+
122+
# %%
123+
# By applying the patches:
124+
125+
with torch_export_patches():
126+
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
127+
print("output is", res)
128+
129+
# %%
130+
# This is what we want. Let's go back to :func:`torch.export.export`
131+
132+
with torch_export_patches():
133+
ep = torch.export.export(
134+
Model(),
135+
(
136+
torch.tensor([2, 3], dtype=torch.float32),
137+
torch.tensor([2, 3, 4], dtype=torch.float32),
138+
),
139+
dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
140+
)
141+
print(ep)
142+
143+
# %%
144+
output = [n for n in ep.graph.nodes if n.op == "output"][0]
145+
print("output is ", output.name, " arg is", output.args[0])
146+
shape = output.args[0][0].meta["val"].shape
147+
print("output shape is ", shape)

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,25 @@ def forward(self, x, ind1, ind2):
491491
)
492492
self.assertEqualArray(expected, ep.module()(*inputs))
493493

494+
def test_broadcast_max(self):
495+
class Model(torch.nn.Module):
496+
def forward(self, x, y):
497+
return x * y
498+
499+
Dim = torch.export.Dim
500+
with torch_export_patches():
501+
ep = torch.export.export(
502+
Model(),
503+
(
504+
torch.tensor([2, 3], dtype=torch.float32),
505+
torch.tensor([2, 3, 4], dtype=torch.float32),
506+
),
507+
dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
508+
)
509+
output = [n for n in ep.graph.nodes if n.op == "output"]
510+
shape = output[0].args[0][0].meta["val"].shape
511+
self.assertEqual(str(shape), "torch.Size([Max(s17, s77)])")
512+
494513

495514
if __name__ == "__main__":
496515
unittest.main(verbosity=2)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ select = [
150150
"_doc/notebooks/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
151151
"_doc/recipes/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
152152
"_scripts/compare_model_execution.py" = ["E402", "F401"]
153-
"_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "SIM105", "SIM117"]
153+
"_doc/technical/plot_*.py" = ["E402", "B018", "PIE808", "RUF015", "SIM105", "SIM117"]
154154
"_unittests/*/test*.py" = ["B008", "B904", "PIE808", "SIM117", "SIM105", "UP008"]
155155
"onnx_diagnostic/export/__init__.py" = ["F401"]
156156
"onnx_diagnostic/helpers/__init__.py" = ["F401"]

0 commit comments

Comments
 (0)