Skip to content

Commit 20d83db

Browse files
committed
fix
1 parent 5c081ba commit 20d83db

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def forward(self, x, ind1, ind2):
321321
got = ep.module()(*inputs)
322322
self.assertEqualArray(expected, got)
323323

324+
@requires_torch("2.11", "until we know more")
324325
def test_patched__broadcast_in_dim_meta(self):
325326
class Model(torch.nn.Module):
326327
def forward(self, x, ind1, ind2):
@@ -336,7 +337,7 @@ def forward(self, x, ind1, ind2):
336337

337338
with (
338339
torch.fx.experimental._config.patch(backed_size_oblivious=True),
339-
torch_export_patches(),
340+
torch_export_patches(patch_torch=True),
340341
):
341342
ep = torch.export.export(
342343
model,

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,12 @@ def _greater_than_reduce(acc, x):
986986
a.shape[original_idx] == shape[idx],
987987
lambda idx=idx, original_idx=original_idx: (
988988
f"non-broadcasting semantics require "
989-
f"{a.shape[original_idx]} == {shape[idx]}"
989+
f"{a.shape[original_idx]} == {shape[idx]}, "
990+
f"{guard_or_false(a.shape[idx] != 1)}, "
991+
f"guard_or_false(a.shape[idx] == 1)="
992+
f"{guard_or_false(a.shape[idx] == 1)}, "
993+
f"a.stride()={a.stride()}, idx={idx}, "
994+
f"original_idx={original_idx}"
990995
),
991996
)
992997
new_strides.append(a.stride()[original_idx])

0 commit comments

Comments
 (0)