Commit 38720ad

fix
1 parent 32f8457 commit 38720ad

File tree

3 files changed: +30 -12 lines

  _unittests/ut_torch_export_patches/test_patch_torch.py
  onnx_diagnostic/torch_export_patches/onnx_export_errors.py
  onnx_diagnostic/torch_export_patches/patches/patch_torch.py


_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 15 additions & 0 deletions
@@ -8,6 +8,8 @@
     requires_transformers,
     has_torch,
 )
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
+from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 
@@ -341,6 +343,19 @@ def forward(self, x, ind1, ind2):
         )
         self.assertEqualArray(expected, ep.module()(*inputs), atol=1e-2)
 
+    @requires_torch("2.7.9999")
+    @requires_transformers("4.49.9999")
+    def test_export_tiny_llm_dim_meta(self):
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        expected = model(**torch_deepcopy(inputs))
+        with torch_export_patches(patch_transformers=True):
+            ep = torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
+            )
+        got = ep.module()(**inputs)
+        self.assertEqualArrayAny(expected, got)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
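
The new test builds an untrained arnir0/Tiny-LLM from the hub helpers, runs it eagerly on a deep copy of its inputs (the model may mutate caches in place), exports it under torch_export_patches(patch_transformers=True), and checks the exported module against the eager outputs. The helper use_dyn_not_str converts the string axis names in dynamic_shapes into torch.export dimension specifications; a minimal sketch of that conversion, assuming it simply maps every string to torch.export.Dim.DYNAMIC (illustrative only, not the library's actual code):

    import torch

    def sketch_use_dyn_not_str(ds):
        # hypothetical stand-in for onnx_diagnostic's use_dyn_not_str:
        # recursively swap string axis names for torch.export.Dim.DYNAMIC
        if isinstance(ds, str):
            return torch.export.Dim.DYNAMIC
        if isinstance(ds, dict):
            return {k: sketch_use_dyn_not_str(v) for k, v in ds.items()}
        if isinstance(ds, (list, tuple)):
            return type(ds)(sketch_use_dyn_not_str(v) for v in ds)
        return ds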

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 4 additions & 3 deletions
@@ -386,8 +386,10 @@ def torch_export_patches(
         )
 
     # torch._prims._broadcast_in_dim_meta
+    f_broadcast_in_dim = torch._prims.broadcast_in_dim
     f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
     torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
+    torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
 
     # torch._refs._maybe_broadcast
     f__maybe_broadcast = torch._refs._maybe_broadcast
@@ -595,6 +597,7 @@ def torch_export_patches(
             f___constrain_user_specified_dimhint_range
         )
         torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
+        torch._prims.broadcast_in_dim = f_broadcast_in_dim
         torch._refs._maybe_broadcast = f__maybe_broadcast
 
         if verbose:
@@ -735,9 +738,7 @@
 
 
 def replacement_before_exporting(args: Any) -> Any:
-    """
-    Does replacements on the given inputs if needed.
-    """
+    """Does replacements on the given inputs if needed."""
     if args is None:
         return None
     if isinstance(args, (int, float)):
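
The fix follows the save/patch/restore idiom torch_export_patches already uses: the original callable is stashed before patching and reinstated when the context exits. Previously only torch._prims._broadcast_in_dim_meta was swapped; this commit also swaps the public alias torch._prims.broadcast_in_dim so both names resolve to the patched function. A minimal sketch of the idiom, assuming a context-manager shape (illustrative; the real torch_export_patches does much more):

    import contextlib
    import torch
    from onnx_diagnostic.torch_export_patches.patches.patch_torch import (
        patched__broadcast_in_dim_meta,
    )

    @contextlib.contextmanager
    def sketch_patches():
        # save the originals so they can be restored on exit
        f_broadcast_in_dim = torch._prims.broadcast_in_dim
        f_meta = torch._prims._broadcast_in_dim_meta
        try:
            # install the patched implementation under both names,
            # keeping the meta function and its public alias in sync
            torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
            torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
            yield
        finally:
            # restore in all cases, even if export fails midway
            torch._prims._broadcast_in_dim_meta = f_meta
            torch._prims.broadcast_in_dim = f_broadcast_in_dim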

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 11 additions & 9 deletions
@@ -608,7 +608,7 @@ def should_expand(a: ShapeType, b: ShapeType) -> bool:
 
         # u0==u1 assume the same, no broadcasting!
         # PATCHED: avoid errors
-        return x != y
+        return True  # guard_or_true(x != y)
         # torch._check(
         #     x == y,
         #     lambda x=x, y=y: (
@@ -665,12 +665,13 @@ def patched__broadcast_in_dim_meta(
     # (no relative reordering of dims) of integers and
     # each dimension must be within the new shape
     def _greater_than_reduce(acc, x):
-        assert isinstance(x, torch.export.Dim)
+        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
         assert x > acc
         assert x < len(shape)
 
         return x
 
+    print("****", broadcast_dimensions)
     reduce(_greater_than_reduce, broadcast_dimensions, -1)
 
     # shape must be broadcastable to
@@ -694,13 +695,14 @@ def _greater_than_reduce(acc, x):
             else:
                 new_strides.append(0)
         else:
-            torch._check(
-                a.shape[original_idx] == shape[idx],
-                lambda idx=idx, original_idx=original_idx: (
-                    f"non-broadcasting semantics require "
-                    f"{a.shape[original_idx]} == {shape[idx]}"
-                ),
-            )
+            # PATCHED: disabled this check
+            # torch._check(
+            #     a.shape[original_idx] == shape[idx],
+            #     lambda idx=idx, original_idx=original_idx: (
+            #         f"non-broadcasting semantics require "
+            #         f"{a.shape[original_idx]} == {shape[idx]}"
+            #     ),
+            # )
             new_strides.append(a.stride()[original_idx])
             original_idx = original_idx + 1
     else:
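
Both changes in this file relax shape checks that fail with data-dependent errors when the dimensions involved are unbacked SymInts (the u0, u1 in the comment): x != y cannot be evaluated when neither side has a concrete value, so should_expand now conservatively returns True, and the strict equality check in patched__broadcast_in_dim_meta is commented out. The inline comment points at guard_or_true as the cleaner alternative; a hedged sketch, assuming torch.fx.experimental.symbolic_shapes.guard_or_true is importable (it exists in recent PyTorch builds, but treating it as available here is an assumption):

    from torch.fx.experimental.symbolic_shapes import guard_or_true

    def should_expand_sketch(x, y) -> bool:
        # True when x != y is provably true or cannot be decided under
        # symbolic shapes, i.e. assume expansion unless proven equal
        return guard_or_true(x != y)

Note that the second hunk also leaves a debug print("****", broadcast_dimensions) in the patched function.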
