2 changes: 1 addition & 1 deletion CHANGELOGS.rst
@@ -4,7 +4,7 @@ Change Logs
0.7.16
++++++


* :pr:`266`: makes ``patch_torch`` an integer in ``torch_export_patches`` to enable more patches

0.7.15
++++++
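As a quick illustration of the new option, here is a minimal sketch based on the unit tests added further down in this diff, not part of the patch itself; the import path for torch_export_patches is assumed to be the package's public one, and torch.export.Dim.DYNAMIC is used instead of the use_dyn_not_str helper the tests rely on.

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches

class Model(torch.nn.Module):
    def forward(self, x, ind1, ind2):
        return x[ind1, ind2]

inputs = (
    torch.randn(2, 1024),
    torch.tensor([[0, 1]], dtype=torch.int64).T,
    torch.arange(1024, dtype=torch.int64),
)
DYN = torch.export.Dim.DYNAMIC

# patch_torch=True or 1 keeps the previous behaviour; patch_torch=2 additionally
# swaps in the more permissive broadcast_in_dim patch introduced below.
with (
    torch_export_patches(patch_torch=2),
    torch.fx.experimental._config.patch(backed_size_oblivious=True),
):
    ep = torch.export.export(
        Model(),
        inputs,
        dynamic_shapes=({0: DYN, 1: DYN}, {1: DYN}, {0: DYN}),
    )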
50 changes: 50 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_torch.py
@@ -441,6 +441,56 @@ def _batch1(t):
        got = ep.module()(**torch_deepcopy(inputs))
        self.assertEqualArrayAny(expected, got)

    @requires_torch("2.9", "Eq(s3, Max(s10, s3)) is inconsistent!")
    def test_patch_tiny_llm_dim_meta_level_1(self):
        class Model(torch.nn.Module):
            def forward(self, x, ind1, ind2):
                return x[ind1, ind2]

        inputs = (
            torch.randn(2, 1024),
            torch.tensor([[0, 1]], dtype=torch.int64).T,
            torch.arange(1024, dtype=torch.int64),
        )
        model = Model()

        with (
            torch_export_patches(patch_torch=1),
            torch.fx.experimental._config.patch(backed_size_oblivious=True),
        ):
            self.assertRaise(
                lambda: torch.export.export(
                    model,
                    inputs,
                    dynamic_shapes=use_dyn_not_str(({0: "A", 1: "B"}, {1: "D"}, {0: "E"})),
                ),
                RuntimeError,
            )

    def test_patch_tiny_llm_dim_meta_level_2(self):
        class Model(torch.nn.Module):
            def forward(self, x, ind1, ind2):
                return x[ind1, ind2]

        inputs = (
            torch.randn(2, 1024),
            torch.tensor([[0, 1]], dtype=torch.int64).T,
            torch.arange(1024, dtype=torch.int64),
        )
        model = Model()
        expected = model(*inputs)

        with (
            torch_export_patches(patch_torch=2),
            torch.fx.experimental._config.patch(backed_size_oblivious=True),
        ):
            ep = torch.export.export(
                model,
                inputs,
                dynamic_shapes=use_dyn_not_str(({0: "A", 1: "B"}, {1: "D"}, {0: "E"})),
            )
            self.assertEqualArray(expected, ep.module()(*inputs))


if __name__ == "__main__":
    unittest.main(verbosity=2)
14 changes: 10 additions & 4 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -2,7 +2,7 @@
import importlib
import contextlib
import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from .onnx_export_serialization import (
    register_cache_serialization,
    unregister_cache_serialization,
@@ -160,7 +160,7 @@ def register_additional_serialization_functions(
@contextlib.contextmanager
def torch_export_patches(
    patch_sympy: bool = True,
-    patch_torch: bool = True,
+    patch_torch: Union[bool, int] = True,
    patch_transformers: bool = False,
    patch_diffusers: bool = False,
    catch_constraints: bool = True,
@@ -349,6 +349,7 @@ def torch_export_patches(
            _catch_produce_guards_and_solve_constraints,
            patch__check_input_constraints_for_graph,
            patched__broadcast_in_dim_meta,
+            patched__broadcast_in_dim_meta_level_2,
            patched__maybe_broadcast,
            patched_ShapeEnv,
        )
@@ -390,8 +391,13 @@ def torch_export_patches(
        # torch._prims._broadcast_in_dim_meta
        f_broadcast_in_dim = torch._prims.broadcast_in_dim
        f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
-        torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
-        torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
+        _patched_dim_f = (
+            patched__broadcast_in_dim_meta_level_2
+            if patch_torch == 2
+            else patched__broadcast_in_dim_meta
+        )
+        torch._prims._broadcast_in_dim_meta = _patched_dim_f
+        torch._prims.broadcast_in_dim = _patched_dim_f

        # torch._refs._maybe_broadcast
        f__maybe_broadcast = torch._refs._maybe_broadcast
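The assignments above follow the save/replace/restore pattern used for every patched attribute: the original callable is kept in a local (f_broadcast_in_dim, f__broadcast_in_dim_meta, ...) so it can be put back when the context manager exits. A simplified sketch of that idea, written as a generic helper for illustration only, not the library's actual code:

import contextlib

@contextlib.contextmanager
def swap_attr(owner, name, replacement):
    """Temporarily replace an attribute of ``owner`` and restore it on exit."""
    original = getattr(owner, name)    # keep the original callable
    setattr(owner, name, replacement)  # install the patched version
    try:
        yield original
    finally:
        setattr(owner, name, original)  # always restore, even on error

With such a helper, the level selection above amounts to swapping in patched__broadcast_in_dim_meta_level_2 when patch_torch == 2 and patched__broadcast_in_dim_meta otherwise.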
85 changes: 82 additions & 3 deletions onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -982,16 +982,21 @@ def _greater_than_reduce(acc, x):
            elif guard_or_false(a.shape[original_idx] != 1):
                new_strides.append(a.stride()[original_idx])
            else:
+                # This check generates the following issue:
+                # non-broadcasting semantics require s3 == Max(s10, s3), False,
+                # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
+                # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
+                # original_idx=1
                torch._check(
                    a.shape[original_idx] == shape[idx],
                    lambda idx=idx, original_idx=original_idx: (
                        f"non-broadcasting semantics require "
                        f"{a.shape[original_idx]} == {shape[idx]}, "
                        f"{guard_or_false(a.shape[idx] != 1)}, "
-                        f"guard_or_false(a.shape[idx] == 1)="
+                        f"guard_or_false(a.shape[idx]==1)="
                        f"{guard_or_false(a.shape[idx] == 1)}, "
-                        f"a.stride()={a.stride()}, idx={idx}, "
-                        f"original_idx={original_idx}"
+                        f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
+                        f"shape={shape}, original_idx={original_idx}"
                    ),
                )
                new_strides.append(a.stride()[original_idx])
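For context, the failure quoted in the new comment comes from advanced indexing such as x[ind1, ind2] in the unit tests above: the two index tensors are broadcast against each other, and with dynamic dimensions s10 and s3 the broadcast shape becomes Max(s10, s3), which this torch._check cannot prove equal to s3. A small eager illustration of that broadcast, with no export involved:

import torch

x = torch.randn(2, 1024)
ind1 = torch.tensor([[0, 1]], dtype=torch.int64).T  # shape (2, 1)
ind2 = torch.arange(1024, dtype=torch.int64)        # shape (1024,)

# Advanced indexing broadcasts the index tensors to a common shape.
print(torch.broadcast_shapes(ind1.shape, ind2.shape))  # torch.Size([2, 1024])
print(x[ind1, ind2].shape)                              # torch.Size([2, 1024])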
@@ -1006,3 +1011,77 @@ def _greater_than_reduce(acc, x):
                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

    return a.as_strided(shape, new_strides, a.storage_offset())


def patched__broadcast_in_dim_meta_level_2(
    a: torch._prims_common.TensorLikeType,
    shape: torch._prims_common.ShapeType,
    broadcast_dimensions: Sequence[int],
):
    """Patches ``torch._prims._broadcast_in_dim_meta``."""
    from torch.fx.experimental.symbolic_shapes import (
        guard_or_false,
        guard_or_true,
        sym_or,
    )

    # Type checks
    assert isinstance(a, torch._prims_common.TensorLike)
    assert isinstance(shape, Sequence)
    assert isinstance(broadcast_dimensions, Sequence)

    # every dimension must be accounted for
    assert a.ndim == len(broadcast_dimensions)

    # broadcast shape must have weakly more dimensions
    assert len(shape) >= a.ndim

    # broadcast_dimensions must be an ascending sequence
    # (no relative reordering of dims) of integers and
    # each dimension must be within the new shape
    def _greater_than_reduce(acc, x):
        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
        assert x > acc
        assert x < len(shape)

        return x

    reduce(_greater_than_reduce, broadcast_dimensions, -1)

    # shape must be broadcastable to
    for idx, new_idx in enumerate(broadcast_dimensions):
        torch._check(
            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
            lambda idx=idx, new_idx=new_idx: (
                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
            ),
        )

    new_strides = []
    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            # Assigns a stride of zero to dimensions
            # which were actually broadcast
            if guard_or_false(a.shape[original_idx] == 1):
                if guard_or_false(a.shape[original_idx] == shape[idx]):
                    new_strides.append(a.stride()[original_idx])
                else:
                    new_strides.append(0)
            # PATCHED: disabled this check
            elif guard_or_false(a.shape[original_idx] != 1):
                new_strides.append(a.stride()[original_idx])
            else:
                # PATCHED: torch._check was removed
                new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
            if guard_or_true(shape[idx] != 1):
                # consistent with previous use of guard_size_oblivious
                new_strides.append(0)
            elif original_idx == a.ndim:
                new_strides.append(1)
            else:
                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])

    return a.as_strided(shape, new_strides, a.storage_offset())
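For reference, a minimal sketch of what the meta function computes in the plain concrete case, assuming torch._prims.broadcast_in_dim can be called eagerly (it is the module attribute being patched above): the broadcast dimension gets stride 0, and the level-2 variant only changes the symbolic branches, returning the original stride where the removed torch._check used to fail.

import torch

a = torch.randn(2, 3)  # strides (3, 1)

# Map a's dimensions onto positions 0 and 1 of the target shape; position 2
# is new, so it is broadcast and receives stride 0 in the result.
out = torch._prims.broadcast_in_dim(a, (2, 3, 4), (0, 1))
print(out.shape)     # torch.Size([2, 3, 4])
print(out.stride())  # (3, 1, 0)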