"""
Dynamic Shapes and Broadcasting
===============================

:func:`torch.export.export` makes strict assumptions on dynamic shapes
instead of covering the generic case. Let's consider two tensors with
only one dimension each. ``x * y`` allows four configurations:

* ``shape(x) = (1,)`` and ``shape(y) = (1,)``
* ``shape(x) = (1,)`` and ``shape(y) = (p,)``
* ``shape(x) = (q,)`` and ``shape(y) = (1,)``
* ``shape(x) = (p,)`` and ``shape(y) = (p,)``

In every case, the expected ``shape(x * y)`` is ``(max(p, q),)``.
A short eager-mode check after the imports illustrates these configurations.

Simple Case
+++++++++++

"""

import torch
from torch.fx import Tracer
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch._subclasses.fake_tensor import FakeTensorMode
from onnx_diagnostic.torch_export_patches import torch_export_patches

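# %%
# In eager mode, broadcasting handles all four configurations.
# A quick check with arbitrary sizes ``p = 4`` and ``q = 3``:

for sx, sy in [(1, 1), (1, 4), (3, 1), (4, 4)]:
    x = torch.zeros((sx,), dtype=torch.float32)
    y = torch.zeros((sy,), dtype=torch.float32)
    print("shape(x) =", tuple(x.shape), ", shape(y) =", tuple(y.shape), "->", tuple((x * y).shape))
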
class Model(torch.nn.Module):
    def forward(self, x, y):
        return x * y


Dim = torch.export.Dim

ep = torch.export.export(
    Model(),
    (torch.tensor([2, 3], dtype=torch.float32), torch.tensor([2, 3], dtype=torch.float32)),
    dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
)
print(ep)

# %%
# We clearly see that the export assumed ``x`` and ``y`` have the same shape.
# No other configuration seemed to work at export time, not even
# ``with torch.fx.experimental._config.patch(backed_size_oblivious=True):``
# when the shape of one tensor is equal to ``(1,)`` (a sketch of that
# attempt follows below).

output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])

# %%
# The final shape is:

shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
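
# %%
# Here is a sketch of the attempt mentioned above: exporting with
# ``backed_size_oblivious=True`` and one input of shape ``(1,)``.
# It is guarded by ``try/except`` since it is expected to fail or to
# specialize the shapes rather than produce the generic broadcast.

import torch.fx.experimental._config

try:
    with torch.fx.experimental._config.patch(backed_size_oblivious=True):
        ep_oblivious = torch.export.export(
            Model(),
            (
                torch.tensor([2, 3], dtype=torch.float32),
                torch.tensor([5], dtype=torch.float32),
            ),
            dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
        )
    print(ep_oblivious)
except Exception as e:
    print("export failed:", e)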

# %%
# Tracing
# +++++++
#
# Let's compare with what a simple tracing would do, using :class:`torch.fx.Tracer`.

graph = Tracer().trace(Model())
print(graph)

# %%
output = [n for n in graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
print("The tracer leaves no shape information:", output.args[0].__dict__)

# %%
# Shape propagation
# +++++++++++++++++

gm = torch.fx.GraphModule(Model(), graph)

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
fake_inputs = (
    fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False),
    fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False),
)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
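
# %%
# Both inputs received the same symbolic dimension: by default,
# :class:`ShapeEnv` reuses an existing symbol for dimensions of equal
# size (duck shaping), so two tensors of shape ``(2,)`` share ``s0``.

print("input shapes:", fake_inputs[0].shape, fake_inputs[1].shape)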

# %%
# Handle Different Shapes
# +++++++++++++++++++++++

fake_inputs = (
    fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False),
    fake_mode.from_tensor(torch.zeros((1,), dtype=torch.float32), static_shapes=False),
)

print("fake_inputs are ", fake_inputs)
res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
print("output is", res)
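
# %%
# This works because a dimension equal to 1 is specialized to a constant
# by default, so broadcasting against the symbolic dimension succeeds.

print("output shape is", res.shape)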

# %%
# Conclusion
# ++++++++++
#
# We need to give the inputs distinct dimension values so that they
# receive distinct symbolic names.

fake_inputs = (
    fake_mode.from_tensor(torch.zeros((2,), dtype=torch.float32), static_shapes=False),
    fake_mode.from_tensor(torch.zeros((3,), dtype=torch.float32), static_shapes=False),
)
print("fake_inputs are ", fake_inputs)


# %%
# The propagation now fails: two distinct symbolic dimensions cannot be
# broadcast together without more information.

try:
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
except Exception as e:
    print(e)

# %%
# By applying the patches:

with torch_export_patches():
    res = FakeTensorProp(gm, fake_mode).propagate(*fake_inputs)
    print("output is", res)
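
# %%
# The symbolic shape chosen for the output can be inspected directly:

print("output shape is", res.shape)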

# %%
# This is what we want. Let's go back to :func:`torch.export.export`.

with torch_export_patches():
    ep = torch.export.export(
        Model(),
        (
            torch.tensor([2, 3], dtype=torch.float32),
            torch.tensor([2, 3, 4], dtype=torch.float32),
        ),
        dynamic_shapes=({0: Dim.DYNAMIC}, {0: Dim.DYNAMIC}),
    )
    print(ep)

# %%
output = [n for n in ep.graph.nodes if n.op == "output"][0]
print("output is ", output.name, " arg is", output.args[0])
shape = output.args[0][0].meta["val"].shape
print("output shape is ", shape)
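
# %%
# A quick sanity check on the exported program with new sizes, guarded by
# ``try/except`` since runtime assertions may still reject some combinations:

mod = ep.module()
try:
    print("(4,) * (4,) ->", mod(torch.zeros((4,)), torch.zeros((4,))).shape)
    print("(4,) * (1,) ->", mod(torch.zeros((4,)), torch.zeros((1,))).shape)
except Exception as e:
    print("run failed:", e)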