
Commit b206ad4

check a patch with and without, make patch_torch an int to select more patches (#266)
* check a patch with and without
* doc
* disable a test for torch 2.8
1 parent 4498981 commit b206ad4

File tree

4 files changed: +143 -8 lines changed

CHANGELOGS.rst

Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ Change Logs
 0.7.16
 ++++++
 
-
+* :pr:`266`: makes ``patch_torch`` an integer in ``torch_export_patches`` to enable more patches
 
 0.7.15
 ++++++
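
A hedged reading of the new parameter, based on the tests and the dispatch code below (the exact meaning of each level is otherwise undocumented in this commit): ``True``/``1`` keep the existing set of torch patches, while ``2`` additionally swaps in the more permissive ``patched__broadcast_in_dim_meta_level_2``. Assuming the package's public import path, usage could look like:

    from onnx_diagnostic.torch_export_patches import torch_export_patches

    # Sketch only: patch_torch now accepts an int as well as a bool.
    # Assumption: 1 behaves like True (default patches), 2 also selects
    # patched__broadcast_in_dim_meta_level_2 (see onnx_export_errors.py below).
    with torch_export_patches(patch_torch=2):
        # run torch.export.export(...) here with the level-2 broadcast patch active
        ...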

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 50 additions & 0 deletions

@@ -441,6 +441,56 @@ def _batch1(t):
         got = ep.module()(**torch_deepcopy(inputs))
         self.assertEqualArrayAny(expected, got)
 
+    @requires_torch("2.9", "Eq(s3, Max(s10, s3)) is inconsistent!")
+    def test_patch_tiny_llm_dim_meta_level_1(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, ind1, ind2):
+                return x[ind1, ind2]
+
+        inputs = (
+            torch.randn(2, 1024),
+            torch.tensor([[0, 1]], dtype=torch.int64).T,
+            torch.arange(1024, dtype=torch.int64),
+        )
+        model = Model()
+
+        with (
+            torch_export_patches(patch_torch=1),
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+        ):
+            self.assertRaise(
+                lambda: torch.export.export(
+                    model,
+                    inputs,
+                    dynamic_shapes=use_dyn_not_str(({0: "A", 1: "B"}, {1: "D"}, {0: "E"})),
+                ),
+                RuntimeError,
+            )
+
+    def test_patch_tiny_llm_dim_meta_level_2(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, ind1, ind2):
+                return x[ind1, ind2]
+
+        inputs = (
+            torch.randn(2, 1024),
+            torch.tensor([[0, 1]], dtype=torch.int64).T,
+            torch.arange(1024, dtype=torch.int64),
+        )
+        model = Model()
+        expected = model(*inputs)
+
+        with (
+            torch_export_patches(patch_torch=2),
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+        ):
+            ep = torch.export.export(
+                model,
+                inputs,
+                dynamic_shapes=use_dyn_not_str(({0: "A", 1: "B"}, {1: "D"}, {0: "E"})),
+            )
+            self.assertEqualArray(expected, ep.module()(*inputs))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
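
For context, a minimal standalone sketch of the indexing pattern both tests rely on; the two index tensors broadcast against each other, which is what routes through ``torch._prims._broadcast_in_dim_meta`` once the shapes become symbolic under export (that routing is an assumption inferred from the patch being exercised):

    import torch

    x = torch.randn(2, 1024)
    ind1 = torch.tensor([[0, 1]], dtype=torch.int64).T   # shape (2, 1)
    ind2 = torch.arange(1024, dtype=torch.int64)         # shape (1024,)

    # Advanced indexing broadcasts the index tensors together:
    # (2, 1) and (1024,) broadcast to (2, 1024).
    out = x[ind1, ind2]
    assert out.shape == (2, 1024)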

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 10 additions & 4 deletions

@@ -2,7 +2,7 @@
 import importlib
 import contextlib
 import re
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from .onnx_export_serialization import (
     register_cache_serialization,
     unregister_cache_serialization,
@@ -160,7 +160,7 @@ def register_additional_serialization_functions(
 @contextlib.contextmanager
 def torch_export_patches(
     patch_sympy: bool = True,
-    patch_torch: bool = True,
+    patch_torch: Union[bool, int] = True,
     patch_transformers: bool = False,
     patch_diffusers: bool = False,
     catch_constraints: bool = True,
@@ -349,6 +349,7 @@ def torch_export_patches(
            _catch_produce_guards_and_solve_constraints,
            patch__check_input_constraints_for_graph,
            patched__broadcast_in_dim_meta,
+           patched__broadcast_in_dim_meta_level_2,
            patched__maybe_broadcast,
            patched_ShapeEnv,
        )
@@ -390,8 +391,13 @@ def torch_export_patches(
         # torch._prims._broadcast_in_dim_meta
         f_broadcast_in_dim = torch._prims.broadcast_in_dim
         f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
-        torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
-        torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
+        _patched_dim_f = (
+            patched__broadcast_in_dim_meta_level_2
+            if patch_torch == 2
+            else patched__broadcast_in_dim_meta
+        )
+        torch._prims._broadcast_in_dim_meta = _patched_dim_f
+        torch._prims.broadcast_in_dim = _patched_dim_f
 
         # torch._refs._maybe_broadcast
         f__maybe_broadcast = torch._refs._maybe_broadcast
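
The saved ``f_broadcast_in_dim`` and ``f__broadcast_in_dim_meta`` variables indicate the context manager restores the originals on exit. A minimal sketch of that save/patch/restore pattern, under the assumption that restoration happens in a ``finally`` step (the real function patches and restores many more attributes):

    import contextlib
    import torch

    @contextlib.contextmanager
    def _patch_broadcast_in_dim(replacement):
        # save the original meta function, install the replacement,
        # and put the original back even if export fails
        saved = torch._prims._broadcast_in_dim_meta
        torch._prims._broadcast_in_dim_meta = replacement
        try:
            yield
        finally:
            torch._prims._broadcast_in_dim_meta = saved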

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 82 additions & 3 deletions

@@ -982,16 +982,21 @@ def _greater_than_reduce(acc, x):
             elif guard_or_false(a.shape[original_idx] != 1):
                 new_strides.append(a.stride()[original_idx])
             else:
+                # This checks generates the following issue:
+                # non-broadcasting semantics require s3 == Max(s10, s3), False,
+                # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
+                # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
+                # original_idx=1
                 torch._check(
                     a.shape[original_idx] == shape[idx],
                     lambda idx=idx, original_idx=original_idx: (
                         f"non-broadcasting semantics require "
                         f"{a.shape[original_idx]} == {shape[idx]}, "
                         f"{guard_or_false(a.shape[idx] != 1)}, "
-                        f"guard_or_false(a.shape[idx] == 1)="
+                        f"guard_or_false(a.shape[idx]==1)="
                         f"{guard_or_false(a.shape[idx] == 1)}, "
-                        f"a.stride()={a.stride()}, idx={idx}, "
-                        f"original_idx={original_idx}"
+                        f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
+                        f"shape={shape}, original_idx={original_idx}"
                     ),
                 )
                 new_strides.append(a.stride()[original_idx])
@@ -1006,3 +1011,77 @@ def _greater_than_reduce(acc, x):
                 new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
 
     return a.as_strided(shape, new_strides, a.storage_offset())
+
+
+def patched__broadcast_in_dim_meta_level_2(
+    a: torch._prims_common.TensorLikeType,
+    shape: torch._prims_common.ShapeType,
+    broadcast_dimensions: Sequence[int],
+):
+    """Patches ``torch._prims._broadcast_in_dim_meta``."""
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        sym_or,
+    )
+
+    # Type checks
+    assert isinstance(a, torch._prims_common.TensorLike)
+    assert isinstance(shape, Sequence)
+    assert isinstance(broadcast_dimensions, Sequence)
+
+    # every dimension must be accounted for
+    assert a.ndim == len(broadcast_dimensions)
+
+    # broadcast shape must have weakly more dimensions
+    assert len(shape) >= a.ndim
+
+    # broadcast_dimensions must be an ascending sequence
+    # (no relative reordering of dims) of integers and
+    # each dimension must be within the new shape
+    def _greater_than_reduce(acc, x):
+        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
+        assert x > acc
+        assert x < len(shape)
+
+        return x
+
+    reduce(_greater_than_reduce, broadcast_dimensions, -1)
+
+    # shape must be broadcastable to
+    for idx, new_idx in enumerate(broadcast_dimensions):
+        torch._check(
+            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
+            lambda idx=idx, new_idx=new_idx: (
+                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
+            ),
+        )
+
+    new_strides = []
+    original_idx = 0
+    for idx in range(len(shape)):
+        if idx in broadcast_dimensions:
+            # Assigns a stride of zero to dimensions
+            # which were actually broadcast
+            if guard_or_false(a.shape[original_idx] == 1):
+                if guard_or_false(a.shape[original_idx] == shape[idx]):
+                    new_strides.append(a.stride()[original_idx])
+                else:
+                    new_strides.append(0)
+            # PATCHED: disabled this check
+            elif guard_or_false(a.shape[original_idx] != 1):
+                new_strides.append(a.stride()[original_idx])
+            else:
+                # PATCHED: torch._check was removed
+                new_strides.append(a.stride()[original_idx])
+            original_idx = original_idx + 1
+        else:
+            if guard_or_true(shape[idx] != 1):
+                # consistent with previous use of guard_size_oblivious
+                new_strides.append(0)
+            elif original_idx == a.ndim:
+                new_strides.append(1)
+            else:
+                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
+
+    return a.as_strided(shape, new_strides, a.storage_offset())
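
To make the stride logic above concrete, a small sketch of what ``broadcast_in_dim`` computes for plain (non-symbolic) shapes; the example shapes and the simplified rule (original stride for kept dimensions, stride 0 for broadcast or newly introduced ones) are illustrative assumptions, not part of the commit:

    import torch

    a = torch.randn(2, 3)            # contiguous, strides (3, 1)
    shape = (2, 4, 3)                # target shape
    broadcast_dimensions = (0, 2)    # input dim 0 -> output dim 0, input dim 1 -> output dim 2

    new_strides = []
    original_idx = 0
    for idx in range(len(shape)):
        if idx in broadcast_dimensions:
            # kept dimensions reuse their stride; a size-1 dimension would get stride 0
            new_strides.append(a.stride()[original_idx] if a.shape[original_idx] != 1 else 0)
            original_idx += 1
        else:
            # newly introduced dimensions are pure broadcasts: stride 0
            new_strides.append(0)

    b = a.as_strided(shape, new_strides, a.storage_offset())
    assert b.shape == torch.Size(shape)
    assert torch.equal(b[:, 0, :], a)   # every slice along the new dimension views a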
