diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index e1d5e4bd..bd9ed6ee 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,7 +4,7 @@ Change Logs
 0.7.16
 ++++++
 
-
+* :pr:`266`: allows ``patch_torch`` to be an integer in ``torch_export_patches`` to enable more patches
 
 0.7.15
 ++++++
diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py
index a6cb206c..bb8d9c60 100644
--- a/_unittests/ut_torch_export_patches/test_patch_torch.py
+++ b/_unittests/ut_torch_export_patches/test_patch_torch.py
@@ -441,6 +441,56 @@ def _batch1(t):
         got = ep.module()(**torch_deepcopy(inputs))
         self.assertEqualArrayAny(expected, got)
 
+    @requires_torch("2.9", "Eq(s3, Max(s10, s3)) is inconsistent!")
+    def test_patch_tiny_llm_dim_meta_level_1(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, ind1, ind2):
+                return x[ind1, ind2]
+
+        inputs = (
+            torch.randn(2, 1024),
+            torch.tensor([[0, 1]], dtype=torch.int64).T,
+            torch.arange(1024, dtype=torch.int64),
+        )
+        model = Model()
+
+        with (
+            torch_export_patches(patch_torch=1),
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+        ):
+            self.assertRaise(
+                lambda: torch.export.export(
+                    model,
+                    inputs,
+                    dynamic_shapes=use_dyn_not_str(({0: "A", 1: "B"}, {1: "D"}, {0: "E"})),
+                ),
+                RuntimeError,
+            )
+
+    def test_patch_tiny_llm_dim_meta_level_2(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, ind1, ind2):
+                return x[ind1, ind2]
+
+        inputs = (
+            torch.randn(2, 1024),
+            torch.tensor([[0, 1]], dtype=torch.int64).T,
+            torch.arange(1024, dtype=torch.int64),
+        )
+        model = Model()
+        expected = model(*inputs)
+
+        with (
+            torch_export_patches(patch_torch=2),
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+        ):
+            ep = torch.export.export(
+                model,
+                inputs,
+                dynamic_shapes=use_dyn_not_str(({0: "A", 1: "B"}, {1: "D"}, {0: "E"})),
+            )
+            self.assertEqualArray(expected, ep.module()(*inputs))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
index 42b9a394..6eb2b642 100644
--- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
+++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -2,7 +2,7 @@
 import importlib
 import contextlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from .onnx_export_serialization import (
     register_cache_serialization,
     unregister_cache_serialization,
@@ -160,7 +160,7 @@ def register_additional_serialization_functions(
 @contextlib.contextmanager
 def torch_export_patches(
     patch_sympy: bool = True,
-    patch_torch: bool = True,
+    patch_torch: Union[bool, int] = True,
     patch_transformers: bool = False,
     patch_diffusers: bool = False,
     catch_constraints: bool = True,
@@ -349,6 +349,7 @@ def torch_export_patches(
             _catch_produce_guards_and_solve_constraints,
             patch__check_input_constraints_for_graph,
             patched__broadcast_in_dim_meta,
+            patched__broadcast_in_dim_meta_level_2,
             patched__maybe_broadcast,
             patched_ShapeEnv,
         )
@@ -390,8 +391,13 @@ def torch_export_patches(
         # torch._prims._broadcast_in_dim_meta
         f_broadcast_in_dim = torch._prims.broadcast_in_dim
         f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
-        torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
-        torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
+        _patched_dim_f = (
+            patched__broadcast_in_dim_meta_level_2
+            if patch_torch == 2
+            else patched__broadcast_in_dim_meta
+        )
+        torch._prims._broadcast_in_dim_meta = _patched_dim_f
+        torch._prims.broadcast_in_dim = _patched_dim_f
 
         # torch._refs._maybe_broadcast
         f__maybe_broadcast = torch._refs._maybe_broadcast
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
index e4753cd6..8ce3d72f 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -982,16 +982,21 @@ def _greater_than_reduce(acc, x):
             elif guard_or_false(a.shape[original_idx] != 1):
                 new_strides.append(a.stride()[original_idx])
             else:
+                # This check generates the following issue:
+                # non-broadcasting semantics require s3 == Max(s10, s3), False,
+                # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
+                # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
+                # original_idx=1
                 torch._check(
                     a.shape[original_idx] == shape[idx],
                     lambda idx=idx, original_idx=original_idx: (
                         f"non-broadcasting semantics require "
                         f"{a.shape[original_idx]} == {shape[idx]}, "
                         f"{guard_or_false(a.shape[idx] != 1)}, "
-                        f"guard_or_false(a.shape[idx] == 1)="
+                        f"guard_or_false(a.shape[idx]==1)="
                         f"{guard_or_false(a.shape[idx] == 1)}, "
-                        f"a.stride()={a.stride()}, idx={idx}, "
-                        f"original_idx={original_idx}"
+                        f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
+                        f"shape={shape}, original_idx={original_idx}"
                     ),
                 )
                 new_strides.append(a.stride()[original_idx])
@@ -1006,3 +1011,77 @@ def _greater_than_reduce(acc, x):
                 new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
 
     return a.as_strided(shape, new_strides, a.storage_offset())
+
+
+def patched__broadcast_in_dim_meta_level_2(
+    a: torch._prims_common.TensorLikeType,
+    shape: torch._prims_common.ShapeType,
+    broadcast_dimensions: Sequence[int],
+):
+    """Level 2 patch for ``torch._prims._broadcast_in_dim_meta`` (strict check removed)."""
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        sym_or,
+    )
+
+    # Type checks
+    assert isinstance(a, torch._prims_common.TensorLike)
+    assert isinstance(shape, Sequence)
+    assert isinstance(broadcast_dimensions, Sequence)
+
+    # every dimension must be accounted for
+    assert a.ndim == len(broadcast_dimensions)
+
+    # broadcast shape must have weakly more dimensions
+    assert len(shape) >= a.ndim
+
+    # broadcast_dimensions must be an ascending sequence
+    # (no relative reordering of dims) of integers and
+    # each dimension must be within the new shape
+    def _greater_than_reduce(acc, x):
+        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
+        assert x > acc
+        assert x < len(shape)
+
+        return x
+
+    reduce(_greater_than_reduce, broadcast_dimensions, -1)
+
+    # shape must be broadcastable to
+    for idx, new_idx in enumerate(broadcast_dimensions):
+        torch._check(
+            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
+            lambda idx=idx, new_idx=new_idx: (
+                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
+            ),
+        )
+
+    new_strides = []
+    original_idx = 0
+    for idx in range(len(shape)):
+        if idx in broadcast_dimensions:
+            # Assigns a stride of zero to dimensions
+            # which were actually broadcast
+            if guard_or_false(a.shape[original_idx] == 1):
+                if guard_or_false(a.shape[original_idx] == shape[idx]):
+                    new_strides.append(a.stride()[original_idx])
+                else:
+                    new_strides.append(0)
+            # PATCHED: disabled this check
+            elif guard_or_false(a.shape[original_idx] != 1):
+                new_strides.append(a.stride()[original_idx])
+            else:
+                # PATCHED: torch._check was removed
+                new_strides.append(a.stride()[original_idx])
+            original_idx = original_idx + 1
+        else:
+            if guard_or_true(shape[idx] != 1):
+                # consistent with previous use of guard_size_oblivious
+                new_strides.append(0)
+            elif original_idx == a.ndim:
+                new_strides.append(1)
+            else:
+                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
+
+    return a.as_strided(shape, new_strides, a.storage_offset())
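A minimal usage sketch of the new ``patch_torch=2`` level, adapted from ``test_patch_tiny_llm_dim_meta_level_2`` above. The ``onnx_diagnostic.torch_export_patches`` import path and the explicit ``import torch.fx.experimental._config`` are assumptions, and ``torch.export.Dim.DYNAMIC`` stands in for the ``use_dyn_not_str`` helper used in the test.

import torch
import torch.fx.experimental._config  # assumed explicit import of the fx experimental config module
from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed import path


class Model(torch.nn.Module):
    def forward(self, x, ind1, ind2):
        return x[ind1, ind2]


inputs = (
    torch.randn(2, 1024),
    torch.tensor([[0, 1]], dtype=torch.int64).T,
    torch.arange(1024, dtype=torch.int64),
)

DYN = torch.export.Dim.DYNAMIC  # stands in for use_dyn_not_str({0: "A", ...}) from the test
with (
    torch_export_patches(patch_torch=2),  # level 2 selects patched__broadcast_in_dim_meta_level_2
    torch.fx.experimental._config.patch(backed_size_oblivious=True),
):
    ep = torch.export.export(
        Model(),
        inputs,
        dynamic_shapes=({0: DYN, 1: DYN}, {1: DYN}, {0: DYN}),
    )
print(ep)

With ``patch_torch=1`` (or ``True``), the same export is expected to fail with a ``RuntimeError``, which is what ``test_patch_tiny_llm_dim_meta_level_1`` asserts.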