From 94d69b87b7e0a8b00fb7d314806339b014bec698 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Tue, 8 Apr 2025 18:44:49 +0200
Subject: [PATCH 1/3] example

---
 _doc/recipes/plot_dynamic_shapes_max.py       | 184 ++++++++++++++++++
 .../recipes/plot_dynamic_shapes_python_int.py |   2 +-
 2 files changed, 185 insertions(+), 1 deletion(-)
 create mode 100644 _doc/recipes/plot_dynamic_shapes_max.py

diff --git a/_doc/recipes/plot_dynamic_shapes_max.py b/_doc/recipes/plot_dynamic_shapes_max.py
new file mode 100644
index 00000000..4d1a1909
--- /dev/null
+++ b/_doc/recipes/plot_dynamic_shapes_max.py
@@ -0,0 +1,184 @@
+"""
+Cannot export ``torch.sym_max(x.shape[0], y.shape[0])``
+=======================================================
+
+This is related to the following issue:
+`Cannot export torch.sym_max(x.shape[0], y.shape[0])
+`_.
+
+The algorithm that automatically infers shapes after every operator
+in the exported program is very aggressive. Here is a case where
+it makes a wrong decision, and how to work around it.
+
+Wrong Model
++++++++++++
+"""
+
+import torch
+from onnx_diagnostic import doc
+
+
+class Model(torch.nn.Module):
+    def forward(self, x, y, fact):
+        s1 = max(x.shape[0], y.shape[0])
+        s2 = max(x.shape[1], y.shape[1])
+        # Shapes cannot be known here.
+        z = torch.zeros((s1, s2), dtype=x.dtype)
+        z[: x.shape[0], : x.shape[1]] = x
+        z[: y.shape[0], : y.shape[1]] += y
+        return z * fact
+
+
+model = Model()
+x = torch.arange(6).reshape((2, 3))
+y = torch.arange(6).reshape((3, 2)) * 10
+fact = torch.tensor([[1, 2, 3]], dtype=x.dtype)
+z = model(x, y, fact)
+print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
+
+# %%
+# Export
+# ++++++
+DYN = torch.export.Dim.DYNAMIC
+
+ep = torch.export.export(
+    model, (x, y, fact), dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN})
+)
+print(ep)
+
+# %%
+# But does it really work? Let's print the shapes.
+model_ep = ep.module()
+ez = model_ep(x, y, fact)
+print("case 1:", z.shape, ez.shape)
+
+# %%
+# Case with different shapes.
+
+x = torch.arange(4).reshape((2, 2))
+y = torch.arange(9).reshape((3, 3))
+try:
+    ez = model_ep(x, y, fact)
+    print("case 2:", model(x, y, fact).shape, ez.shape)
+except Exception as e:
+    print("case 2 failed:", e)
+
+# %%
+# It does not even run. The exported program does not infer the correct shape.
+#
+# Rewritten Model
+# +++++++++++++++
+#
+# ``max`` does not get captured, :func:`torch.sym_max` is no better,
+# :func:`torch.max` only works on tensors. Nothing really works.
+# We use a trick to introduce a new shape the shape inference algorithm
+# cannot know. This requires hiding the failing logic in a custom operator.
+
+
+def make_undefined_dimension(i: int) -> torch.SymInt:
+    """
+    Used for a custom op when a new dimension must be introduced to bypass
+    some verficiation. The following function creates a dummy output
+    with a dimension based on the content.
+
+    .. code-block:: python
+
+        def symbolic_shape(x, y):
+            return torch.empty(
+                x.shape[0],
+                make_undefined_dimension(min(x.shape[1], y[0])),
+            )
+    """
+    t = torch.ones((i * 2,))
+    t[:i] = 0
+    res = torch.nonzero(t).shape[0]
+    return res
+
+
+def copy_max_dimensions(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    shape = torch.max(torch.tensor(x.shape), torch.tensor(y.shape))
+    z = torch.zeros(tuple(shape), dtype=x.dtype)
+    z[0 : x.shape[0], 0 : x.shape[1]] = x[0 : x.shape[0], 0 : x.shape[1]]
+    z[0 : y.shape[0], 0 : y.shape[1]] += y[0 : y.shape[0], 0 : y.shape[1]]
+    return z
+
+
+def symbolic_shape(x, y):
+    return torch.empty(
+        tuple(
+            make_undefined_dimension(max(x.shape[i], y.shape[i])) for i in range(len(x.shape))
+        ),
+        dtype=x.dtype,
+    )
+
+
+def register(fct, fct_shape, namespace, fname):
+    schema_str = torch.library.infer_schema(fct, mutates_args=())
+    custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
+    custom_def.register_kernel("cpu")(fct)
+    custom_def._abstract_fn = fct_shape
+
+
+register(
+    copy_max_dimensions, lambda x, y: symbolic_shape(x, y), "mylib", "copy_max_dimensions"
+)
+
+# %%
+# Now everything is registered. Let's rewrite the model.
+
+class RewrittenModel(torch.nn.Module):
+    def forward(self, x, y, fact):
+        z = torch.ops.mylib.copy_max_dimensions(x, y)
+        return z * fact
+
+# %%
+# And check it works.
+
+rewritten_model = RewrittenModel()
+x = torch.arange(6).reshape((2, 3))
+y = torch.arange(6).reshape((3, 2)) * 10
+z = rewritten_model(x, y, fact)
+print(f"x.shape={x.shape}, y.shape={y.shape}, z.shape={z.shape}")
+
+# %%
+# Export again
+# ++++++++++++
+
+ep = torch.export.export(
+    rewritten_model,
+    (x, y, fact),
+    dynamic_shapes=({0: DYN, 1: DYN}, {0: DYN, 1: DYN}, {1: DYN}),
+)
+print(ep)
+
+# %%
+# We check it works.
+
+model_ep = ep.module()
+ez = model_ep(x, y, fact)
+print("case 1:", z.shape, ez.shape)
+
+x = torch.arange(4).reshape((2, 2))
+y = torch.arange(9).reshape((3, 3))
+try:
+    ez = model_ep(x, y, fact)
+    print("case 2:", rewritten_model(x, y, fact).shape, ez.shape)
+except Exception as e:
+    print("case 2 failed:", e)
+
+# %%
+# Final Check on very different dimensions
+# ++++++++++++++++++++++++++++++++++++++++
+
+x = torch.arange(6 * 8).reshape((6, 8))
+y = torch.arange(10 * 4).reshape((10, 4)) * 10
+fact = torch.arange(8).reshape((1, -1))
+
+print("final case:", rewritten_model(x, y, fact).shape, model_ep(x, y, fact).shape)
+
+# %%
+# This is not perfect, as we get an exported program, but some logic
+# is hidden in a custom operator.
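+
+# %%
+# A final consistency check: the rewritten model should return the same
+# values as the original eager model, not just the same shapes.
+
+assert torch.equal(model(x, y, fact), rewritten_model(x, y, fact))
+print("value check passed")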
+
+
+doc.plot_legend("dynamic shapes\nworkaround\nmax(d1, d2)", "dynamic shapes", "yellow")
diff --git a/_doc/recipes/plot_dynamic_shapes_python_int.py b/_doc/recipes/plot_dynamic_shapes_python_int.py
index e9c385a0..126bb3c2 100644
--- a/_doc/recipes/plot_dynamic_shapes_python_int.py
+++ b/_doc/recipes/plot_dynamic_shapes_python_int.py
@@ -1,5 +1,5 @@
 """
-Do not use python int with dynamic shape
+Do not use python int with dynamic shapes
 =========================================
 
 :func:`torch.export.export` uses :class:`torch.SymInt` to operate on shapes and

From ff916788585378d1ed65143573d9ba3f40231399 Mon Sep 17 00:00:00 2001
From: xadupre
Date: Tue, 8 Apr 2025 19:02:06 +0200
Subject: [PATCH 2/3] fix example

---
 _doc/recipes/plot_dynamic_shapes_max.py        | 2 ++
 _doc/recipes/plot_dynamic_shapes_nonzero.py    | 4 ++--
 _doc/recipes/plot_dynamic_shapes_python_int.py | 6 +++---
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/_doc/recipes/plot_dynamic_shapes_max.py b/_doc/recipes/plot_dynamic_shapes_max.py
index 4d1a1909..2d22053c 100644
--- a/_doc/recipes/plot_dynamic_shapes_max.py
+++ b/_doc/recipes/plot_dynamic_shapes_max.py
@@ -126,11 +126,13 @@ def register(fct, fct_shape, namespace, fname):
 # %%
 # Now everything is registered. Let's rewrite the model.
 
+
 class RewrittenModel(torch.nn.Module):
     def forward(self, x, y, fact):
         z = torch.ops.mylib.copy_max_dimensions(x, y)
         return z * fact
 
+
 # %%
 # And check it works.
 
diff --git a/_doc/recipes/plot_dynamic_shapes_nonzero.py b/_doc/recipes/plot_dynamic_shapes_nonzero.py
index a8e7cca3..d91c4593 100644
--- a/_doc/recipes/plot_dynamic_shapes_nonzero.py
+++ b/_doc/recipes/plot_dynamic_shapes_nonzero.py
@@ -2,7 +2,7 @@
 Half certain nonzero
 ====================
 
-:func:`torch.nonzero` returns the indices or the first zero found
+:func:`torch.nonzero` returns the indices of the nonzero elements found
 in a tensor. The output shape is unknown in the generic case but...
 If you have a 2D tensor with at least a nonzero value in every row,
 you can guess the dimension. But :func:`torch.export.export`
@@ -49,7 +49,7 @@ def forward(self, x):
 # ++++++
 DYN = torch.export.Dim.DYNAMIC
 
-ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
+ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
 print(ep)
 
 
diff --git a/_doc/recipes/plot_dynamic_shapes_python_int.py b/_doc/recipes/plot_dynamic_shapes_python_int.py
index 126bb3c2..054f9a86 100644
--- a/_doc/recipes/plot_dynamic_shapes_python_int.py
+++ b/_doc/recipes/plot_dynamic_shapes_python_int.py
@@ -36,7 +36,7 @@ def forward(self, x):
 # ++++++
 DYN = torch.export.Dim.DYNAMIC
 
-ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
+ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
 print(ep)
 
 # %%
@@ -65,7 +65,7 @@ def forward(self, x):
 # Export
 # ++++++
 
-ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=((DYN, DYN),))
+ep = torch.export.export(rewritten_model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
 print(ep)
 
 
@@ -79,7 +79,7 @@ def forward(self, x):
 
 
 with bypass_export_some_errors(stop_if_static=True):
-    ep = torch.export.export(model, (x,), dynamic_shapes=((DYN, DYN),))
+    ep = torch.export.export(model, (x,), dynamic_shapes=({0: DYN, 1: DYN},))
     print(ep)
 
 # %%

From 4fbcf9bcf3b6bcc5349a2568864a274bfba055ea Mon Sep 17 00:00:00 2001
From: xadupre
Date: Tue, 8 Apr 2025 19:20:03 +0200
Subject: [PATCH 3/3] fix batch size

---
 .github/workflows/ci.yml                                      | 2 +-
 _doc/recipes/plot_dynamic_shapes_max.py                       | 2 +-
 _unittests/ut_torch_export_patches/test_onnx_export_errors.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ab17b1bb..460d6881 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -16,7 +16,7 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python: ['3.11', '3.12']
-        transformers: ['4.48.3', '4.50.3', 'main']
+        transformers: ['4.48.3', '4.51.1', 'main']
         torch: ['2.6', 'main']
 
     steps:
diff --git a/_doc/recipes/plot_dynamic_shapes_max.py b/_doc/recipes/plot_dynamic_shapes_max.py
index 2d22053c..880200ae 100644
--- a/_doc/recipes/plot_dynamic_shapes_max.py
+++ b/_doc/recipes/plot_dynamic_shapes_max.py
@@ -78,7 +78,7 @@ def forward(self, x, y, fact):
 def make_undefined_dimension(i: int) -> torch.SymInt:
     """
     Used for a custom op when a new dimension must be introduced to bypass
-    some verficiation. The following function creates a dummy output
+    some verification. The following function creates a dummy output
     with a dimension based on the content.
 
     .. code-block:: python
diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
index f220783c..435f328e 100644
--- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
+++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py
@@ -114,7 +114,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
 
         DYN = torch.export.Dim.DYNAMIC
         with bypass_export_some_errors():
-            cache = MambaCache(_config(), max_batch_size=1, device="cpu")
+            cache = MambaCache(_config(), max_batch_size=2, device="cpu")
             torch.export.export(
                 Model(),
                 (x, cache),