"""
Cannot export ``torch.sym_max(x.shape[0], y.shape[0])``
=======================================================

This is related to the following issue:
`Cannot export torch.sym_max(x.shape[0], y.shape[0])
<https://github.com/pytorch/pytorch/issues/150851>`_.

The algorithm that tries to automatically infer shapes after every operator
in the exported program is very aggressive. Here is a case where
it makes a wrong decision and how to work around it.

Wrong Model
+++++++++++
"""

import torch
from onnx_diagnostic import doc


class Model(torch.nn.Module):
    def forward(self, x, y, fact):
        s1 = max(x.shape[0], y.shape[0])
        s2 = max(x.shape[1], y.shape[1])
        # Shapes cannot be known here.
        z = torch.zeros((s1, s2), dtype=x.dtype)
        z[: x.shape[0], : x.shape[1]] = x
        z[: y.shape[0], : y.shape[1]] += y
        return z * fact


model = Model()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
fact = torch.tensor([[1, 2, 3]], dtype=x.dtype)
z = model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")

# %%
# Export
# ++++++
DYN = torch.export.Dim.DYNAMIC

ep = torch.export.export(
    model, (x, y, fact), dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN})
)
print(ep)

# %%
# But does it really work? Let's print the shapes.
model_ep = ep.module()
ez = model_ep(x, y, fact)
print("case 1:", z.shape, ez.shape)

# %%
# Case with different shapes.

x = torch.arange(4).reshape((2, 2))
y = torch.arange(9).reshape((3, 3))
try:
    ez = model_ep(x, y, fact)
    print("case 2:", model(x, y, fact).shape, ez.shape)
except Exception as e:
    print("case 2 failed:", e)

# %%
# It does not even run: the exported program fails to compute the correct
# shape.
#
# Rewritten Model
# +++++++++++++++
#
# ``max`` does not get captured, :func:`torch.sym_max` is no better, and
# :func:`torch.max` only works on tensors. Nothing really works.
# We use a trick to introduce a new dimension the shape inference algorithm
# cannot infer. This requires hiding the failing logic in a custom operator,
# as sketched below.
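
# %%
# A quick check, for illustration only: in eager mode :func:`torch.sym_max`
# behaves exactly like ``max`` on plain integers, so the problem only
# appears when the exporter traces the model with symbolic dimensions.
print("torch.sym_max(2, 3) =", torch.sym_max(2, 3))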


def make_undefined_dimension(i: int) -> torch.SymInt:
    """
    Used in a custom op when a new dimension must be introduced to bypass
    some verification. The following function creates a dummy output
    with a dimension based on the content.

    .. code-block:: python

        def symbolic_shape(x, y):
            return torch.empty(
                x.shape[0],
                make_undefined_dimension(min(x.shape[1], y[0])),
            )
    """
    # torch.nonzero produces a data-dependent size: the exporter cannot
    # infer it, so the returned value becomes an unbacked symbolic integer.
    t = torch.ones((i * 2,))
    t[:i] = 0
    res = torch.nonzero(t).shape[0]
    return res

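# %%
# In eager mode this helper simply returns ``i``; during export, the
# data-dependent size produced by ``torch.nonzero`` becomes an unbacked
# symbolic dimension instead.
print("make_undefined_dimension(3) =", make_undefined_dimension(3))
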
def copy_max_dimensions(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Create a tensor whose shape is the element-wise maximum of both
    # shapes, copy x into its top-left corner, then add y.
    shape = torch.max(torch.tensor(x.shape), torch.tensor(y.shape))
    z = torch.zeros(tuple(shape), dtype=x.dtype)
    z[0 : x.shape[0], 0 : x.shape[1]] = x[0 : x.shape[0], 0 : x.shape[1]]
    z[0 : y.shape[0], 0 : y.shape[1]] += y[0 : y.shape[0], 0 : y.shape[1]]
    return z

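# %%
# A quick eager check of the padded copy before it is wrapped in a custom
# operator; the input values are only illustrative.
print(copy_max_dimensions(torch.ones((2, 3)), torch.full((3, 2), 2.0)))
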
def symbolic_shape(x, y):
    # The fake implementation: every output dimension is an undefined
    # dimension, so no constraint gets propagated.
    return torch.empty(
        tuple(
            make_undefined_dimension(max(x.shape[i], y.shape[i])) for i in range(len(x.shape))
        ),
        dtype=x.dtype,
    )


def register(fct, fct_shape, namespace, fname):
    # Infer the operator schema from the annotated signature, register the
    # eager implementation as the CPU kernel, and use the fake
    # implementation for shape inference.
    schema_str = torch.library.infer_schema(fct, mutates_args=())
    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
    custom_def.register_kernel("cpu")(fct)
    custom_def._abstract_fn = fct_shape


register(
    copy_max_dimensions, lambda x, y: symbolic_shape(x, y), "mylib", "copy_max_dimensions"
)
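
# %%
# A quick sanity check: the new operator is now reachable under
# ``torch.ops.mylib``.
print(torch.ops.mylib.copy_max_dimensions)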

# %%
# Now everything is registered. Let's rewrite the model.


class RewrittenModel(torch.nn.Module):
    def forward(self, x, y, fact):
        z = torch.ops.mylib.copy_max_dimensions(x, y)
        return z * fact


# %%
# And check that it works.

rewritten_model = RewrittenModel()
x = torch.arange(6).reshape((2, 3))
y = torch.arange(6).reshape((3, 2)) * 10
z = rewritten_model(x, y, fact)
print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")

# %%
# Export again
# ++++++++++++

ep = torch.export.export(
    rewritten_model,
    (x, y, fact),
    dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN}),
)
print(ep)

# %%
# We check it works.

model_ep = ep.module()
ez = model_ep(x, y, fact)
print("case 1:", z.shape, ez.shape)

x = torch.arange(4).reshape((2, 2))
y = torch.arange(9).reshape((3, 3))
try:
    ez = model_ep(x, y, fact)
    print("case 2:", rewritten_model(x, y, fact).shape, ez.shape)
except Exception as e:
    print("case 2 failed:", e)

# %%
# Final check on very different dimensions
# ++++++++++++++++++++++++++++++++++++++++

x = torch.arange(6 * 8).reshape((6, 8))
y = torch.arange(10 * 4).reshape((10, 4)) * 10
fact = torch.arange(8).reshape((1, -1))

print("final case:", rewritten_model(x, y, fact).shape, model_ep(x, y, fact).shape)

# %%
# This approach is not perfect: we do get an exported program, but some of
# the logic is hidden in a custom operator.


doc.plot_legend("dynamic shapes\nworkaround\nmax(d1, d2)", "dynamic shapes", "yellow")