Skip to content

Commit b987ca3

Browse files
committed
Patches _maybe_broadcast to support a corner case
1 parent 667aaf6 commit b987ca3

File tree

6 files changed

+291
-9
lines changed

6 files changed

+291
-9
lines changed

_unittests/ut_torch_export_patches/test_eval.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
2+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, long_test
33
from onnx_diagnostic.torch_export_patches.eval import discover, evaluation
44

55

@@ -9,14 +9,20 @@ def test_discover(self):
99
res = discover()
1010
self.assertNotEmpty(res)
1111
for mod in res.values():
12-
if mod.__name__ == "ControlFlowCondIdentity_153832":
13-
continue
1412
with self.subTest(name=mod.__name__):
13+
if mod.__name__ == "ControlFlowCondIdentity_153832":
14+
raise unittest.SkipTest(
15+
"ControlFlowCondIdentity_153832 needs missing clone."
16+
)
1517
m = mod()
1618
if isinstance(m._inputs, tuple):
1719
m(*m._inputs)
1820
else:
19-
m(*m._inputs[0])
21+
for v in m._inputs:
22+
m(*v)
23+
if hasattr(m, "_valid"):
24+
for v in m._valid:
25+
m(*v)
2026

2127
def test_eval(self):
2228
d = list(discover().items())[0] # noqa: RUF015
@@ -102,6 +108,100 @@ def test_run_exporter_dimension1(self):
102108
dynamic=True,
103109
)
104110

111+
@long_test()
112+
def test_documentation(self):
113+
import inspect
114+
import textwrap
115+
import pandas
116+
from onnx_diagnostic.helpers import string_type
117+
from onnx_diagnostic.torch_export_patches.eval import discover, run_exporter
118+
from onnx_diagnostic.ext_test_case import unit_test_going
119+
120+
cases = discover()
121+
print()
122+
print(":ref:`Summary <led-summary-exported-program>`")
123+
print()
124+
sorted_cases = sorted(cases.items())
125+
if unit_test_going():
126+
sorted_cases = sorted_cases[:3]
127+
for name, _cls_model in sorted_cases:
128+
print(f"* :ref:`{name} <led-model-case-export-{name}>`")
129+
print()
130+
print()
131+
132+
obs = []
133+
for name, cls_model in sorted(cases.items()):
134+
print()
135+
print(f".. _led-model-case-export-{name}:")
136+
print()
137+
print(name)
138+
print("=" * len(name))
139+
print()
140+
print("forward")
141+
print("+++++++")
142+
print()
143+
print(".. code-block:: python")
144+
print()
145+
src = inspect.getsource(cls_model.forward)
146+
if src:
147+
print(textwrap.indent(textwrap.dedent(src), " "))
148+
else:
149+
print(" # code is missing")
150+
print()
151+
print()
152+
for exporter in (
153+
"export-strict",
154+
"export-nostrict",
155+
"export-nostrict-oblivious",
156+
"export-nostrict-decall",
157+
"export-tracing",
158+
):
159+
expname = exporter.replace("export-", "")
160+
print()
161+
print(expname)
162+
print("+" * len(expname))
163+
print()
164+
res = run_exporter(exporter, cls_model, True, quiet=True)
165+
case_ref = f":ref:`{name} <led-model-case-export-{name}>`"
166+
expo = exporter.split("-", maxsplit=1)[-1]
167+
if "inputs" in res:
168+
print(f"* **inputs:** ``{string_type(res['inputs'], with_shape=True)}``")
169+
if "dynamic_shapes" in res:
170+
print(f"* **shapes:** ``{string_type(res['dynamic_shapes'])}``")
171+
print()
172+
print()
173+
if "exported" in res:
174+
print(".. code-block:: text")
175+
print()
176+
print(textwrap.indent(str(res["exported"].graph), " "))
177+
print()
178+
print()
179+
obs.append(dict(case=case_ref, error="", exporter=expo))
180+
else:
181+
print("**FAILED**")
182+
print()
183+
print(".. code-block:: text")
184+
print()
185+
err = str(res["error"])
186+
if err:
187+
print(textwrap.indent(err, " "))
188+
else:
189+
print(" # no error found for the failure")
190+
print()
191+
print()
192+
obs.append(dict(case=case_ref, error="FAIL", exporter=expo))
193+
194+
print()
195+
print(".. _led-summary-exported-program:")
196+
print()
197+
print("Summary")
198+
print("+++++++")
199+
print()
200+
df = pandas.DataFrame(obs)
201+
piv = df.pivot(index="case", columns="exporter", values="error")
202+
print(piv.to_markdown(tablefmt="rst"))
203+
print()
204+
105205

106206
if __name__ == "__main__":
107207
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,30 @@ def forward(self, x, ind1, ind2):
317317
got = ep.module()(*inputs)
318318
self.assertEqualArray(expected, got)
319319

320+
def test_patched__broadcast_in_dim_meta(self):
321+
class Model(torch.nn.Module):
322+
def forward(self, x, ind1, ind2):
323+
return x[ind1, ind2]
324+
325+
inputs = (
326+
torch.randn(2, 1024),
327+
torch.tensor([[0, 1]], dtype=torch.int64).T,
328+
torch.arange(1024, dtype=torch.int64),
329+
)
330+
model = Model()
331+
expected = model(*inputs)
332+
333+
with (
334+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
335+
torch_export_patches(),
336+
):
337+
ep = torch.export.export(
338+
model,
339+
inputs,
340+
dynamic_shapes=use_dyn_not_str(({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"})),
341+
)
342+
self.assertEqualArray(expected, ep.module()(*inputs), atol=1e-2)
343+
320344

321345
if __name__ == "__main__":
322346
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/eval/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,13 @@ def run_exporter(
676676

677677
if dynamic and len(inputs) > 1:
678678
for index, i in enumerate(inputs):
679-
expected = model(*_clone(i))
679+
if quiet:
680+
try:
681+
expected = model(*_clone(i))
682+
except Exception as e:
683+
return dict(error=str(e), success=0, error_step=f"run0.{index}")
684+
else:
685+
expected = model(*_clone(i))
680686
try:
681687
got = mod(*i)
682688
except Exception as e:

onnx_diagnostic/torch_export_patches/eval/model_cases.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,12 +353,9 @@ def else_branch(input_ids, image_features, vocab_size):
353353

354354

355355
class ControlFlowCondIdentity_153832(torch.nn.Module):
356-
"""
357-
`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
358-
"""
356+
"""`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""
359357

360358
def forward(self, x, y):
361-
362359
def branch_cond_then_1(x):
363360
x = torch.abs(x) + 1
364361
return x

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,8 @@ def torch_export_patches(
347347
patched__constrain_user_specified_dimhint_range,
348348
_catch_produce_guards_and_solve_constraints,
349349
patch__check_input_constraints_for_graph,
350+
patched__broadcast_in_dim_meta,
351+
patched__maybe_broadcast,
350352
)
351353

352354
if verbose:
@@ -383,6 +385,14 @@ def torch_export_patches(
383385
patched__constrain_user_specified_dimhint_range
384386
)
385387

388+
# torch._prims._broadcast_in_dim_meta
389+
f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
390+
torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
391+
392+
# torch._refs._maybe_broadcast
393+
f__maybe_broadcast = torch._refs._maybe_broadcast
394+
torch._refs._maybe_broadcast = patched__maybe_broadcast
395+
386396
# torch._export.non_strict_utils.produce_guards_and_solve_constraints
387397
if patch_torch and catch_constraints:
388398
if verbose:
@@ -584,6 +594,8 @@ def torch_export_patches(
584594
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
585595
f___constrain_user_specified_dimhint_range
586596
)
597+
torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
598+
torch._refs._maybe_broadcast = f__maybe_broadcast
587599

588600
if verbose:
589601
print("[torch_export_patches] restored pytorch functions")

onnx_diagnostic/torch_export_patches/patches/patch_torch.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import inspect
22
import os
33
import traceback
4+
from functools import reduce
45
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
56
import torch
67
from torch._subclasses.fake_tensor import FakeTensorMode
@@ -570,3 +571,145 @@ def patched__constrain_user_specified_dimhint_range(
570571
return msg
571572

572573
return None
574+
575+
576+
def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
577+
"""Patches ``torch._refs._maybe_broadcast``."""
578+
from torch._prims_common import ShapeType, TensorLike, Number
579+
580+
# Computes common shape
581+
common_shape = patched__broadcast_shapes(
582+
*(t.shape if isinstance(t, TensorLike) else None for t in args)
583+
)
584+
585+
def should_expand(a: ShapeType, b: ShapeType) -> bool:
586+
from torch.fx.experimental.symbolic_shapes import (
587+
guard_or_false,
588+
sym_and,
589+
sym_or,
590+
)
591+
592+
if len(a) != len(b):
593+
return True
594+
595+
for x, y in zip(a, b):
596+
if guard_or_false(x != y):
597+
# We know they are not the same.
598+
return True
599+
600+
# They are the same or we do not know if they are the same or not.
601+
# 1==1 no-broadcast
602+
# u0==1 and 1==u0 cases. We broadcast!
603+
if guard_or_false(sym_and(x == 1, y == 1)):
604+
pass
605+
elif guard_or_false(sym_or(x == 1, y == 1)):
606+
# assume broadcasting.
607+
return True
608+
609+
# u0==u1 assume the same, no broadcasting!
610+
# PATCHED: avoid errors
611+
return x != y
612+
# torch._check(
613+
# x == y,
614+
# lambda x=x, y=y: (
615+
# f"sizes assumed to be the same due to unbacked "
616+
# f"broadcasting semantics x={x!r}, y={y!r}"
617+
# ),
618+
# )
619+
620+
return False
621+
622+
def __maybe_broadcast(x, shape):
623+
if x is None:
624+
return None
625+
elif isinstance(x, Number):
626+
return x
627+
elif isinstance(x, TensorLike):
628+
if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
629+
return x
630+
631+
if should_expand(x.shape, common_shape):
632+
return x.expand(common_shape)
633+
634+
return x
635+
else:
636+
raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
637+
638+
return tuple(__maybe_broadcast(x, common_shape) for x in args)
639+
640+
641+
def patched__broadcast_in_dim_meta(
642+
a: torch._prims_common.TensorLikeType,
643+
shape: torch._prims_common.ShapeType,
644+
broadcast_dimensions: Sequence[int],
645+
):
646+
"""Patches ``torch._prims._broadcast_in_dim_meta``."""
647+
from torch.fx.experimental.symbolic_shapes import (
648+
guard_or_false,
649+
guard_or_true,
650+
sym_or,
651+
)
652+
653+
# Type checks
654+
assert isinstance(a, torch._prims_common.TensorLike)
655+
assert isinstance(shape, Sequence)
656+
assert isinstance(broadcast_dimensions, Sequence)
657+
658+
# every dimension must be accounted for
659+
assert a.ndim == len(broadcast_dimensions)
660+
661+
# broadcast shape must have weakly more dimensions
662+
assert len(shape) >= a.ndim
663+
664+
# broadcast_dimensions must be an ascending sequence
665+
# (no relative reordering of dims) of integers and
666+
# each dimension must be within the new shape
667+
def _greater_than_reduce(acc, x):
668+
assert isinstance(x, torch.export.Dim)
669+
assert x > acc
670+
assert x < len(shape)
671+
672+
return x
673+
674+
reduce(_greater_than_reduce, broadcast_dimensions, -1)
675+
676+
# shape must be broadcastable to
677+
for idx, new_idx in enumerate(broadcast_dimensions):
678+
torch._check(
679+
sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
680+
lambda idx=idx, new_idx=new_idx: (
681+
f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
682+
),
683+
)
684+
685+
new_strides = []
686+
original_idx = 0
687+
for idx in range(len(shape)):
688+
if idx in broadcast_dimensions:
689+
# Assigns a stride of zero to dimensions
690+
# which were actually broadcast
691+
if guard_or_false(a.shape[original_idx] == 1):
692+
if guard_or_false(a.shape[original_idx] == shape[idx]):
693+
new_strides.append(a.stride()[original_idx])
694+
else:
695+
new_strides.append(0)
696+
else:
697+
torch._check(
698+
a.shape[original_idx] == shape[idx],
699+
lambda idx=idx, original_idx=original_idx: (
700+
f"non-broadcasting semantics require "
701+
f"{a.shape[original_idx]} == {shape[idx]}"
702+
),
703+
)
704+
new_strides.append(a.stride()[original_idx])
705+
original_idx = original_idx + 1
706+
else:
707+
if guard_or_true(shape[idx] != 1):
708+
# consistent with previous use of guard_size_oblivious
709+
new_strides.append(0)
710+
elif original_idx == a.ndim:
711+
new_strides.append(1)
712+
else:
713+
new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
714+
715+
return a.as_strided(shape, new_strides, a.storage_offset())

0 commit comments

Comments
 (0)