Skip to content

Commit 37f70e0

Browse files
authored
Patches _maybe_broadcast to support a corner case (#249)
* Patches _maybe_broadcast to support a corner case * changes * fix * add 4.57.0 to ci * fix test * spell * disable one test * fix
1 parent 667aaf6 commit 37f70e0

File tree

9 files changed

+403
-14
lines changed

9 files changed

+403
-14
lines changed

.github/workflows/ci.yml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
matrix:
1818
os: [ubuntu-latest]
1919
python: ['3.10', '3.11', '3.12', '3.13']
20-
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', 'main']
20+
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', '4.57', 'main']
2121
torch: ['2.8', 'main']
2222
exclude:
2323
- python: '3.10'
@@ -30,6 +30,8 @@ jobs:
3030
transformers: '4.55.4'
3131
- python: '3.10'
3232
transformers: '4.56.2'
33+
- python: '3.10'
34+
transformers: '4.57.0'
3335
- python: '3.11'
3436
torch: 'main'
3537
- python: '3.11'
@@ -38,6 +40,8 @@ jobs:
3840
transformers: '4.55.4'
3941
- python: '3.11'
4042
transformers: '4.56.2'
43+
- python: '3.11'
44+
transformers: '4.57.0'
4145
- python: '3.13'
4246
torch: '2.8'
4347
- python: '3.13'

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Change Logs
44
0.7.14
55
++++++
66

7+
* :pr:`249`: patches _maybe_broadcast to support a corner case
8+
79
0.7.13
810
++++++
911

_unittests/ut_tasks/test_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def test_falcon_mamba_dev(self):
270270
model(**inputs)
271271
model(**data["inputs2"])
272272
self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
273-
if not has_transformers("4.57"):
273+
if not has_transformers("4.57.99"):
274274
raise unittest.SkipTest("The model has control flow.")
275275
with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
276276
torch.export.export(

_unittests/ut_torch_export_patches/test_eval.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import unittest
2-
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
2+
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, long_test
33
from onnx_diagnostic.torch_export_patches.eval import discover, evaluation
44

55

@@ -9,14 +9,20 @@ def test_discover(self):
99
res = discover()
1010
self.assertNotEmpty(res)
1111
for mod in res.values():
12-
if mod.__name__ == "ControlFlowCondIdentity_153832":
13-
continue
1412
with self.subTest(name=mod.__name__):
13+
if mod.__name__ == "ControlFlowCondIdentity_153832":
14+
raise unittest.SkipTest(
15+
"ControlFlowCondIdentity_153832 needs missing clone."
16+
)
1517
m = mod()
1618
if isinstance(m._inputs, tuple):
1719
m(*m._inputs)
1820
else:
19-
m(*m._inputs[0])
21+
for v in m._inputs:
22+
m(*v)
23+
if hasattr(m, "_valid"):
24+
for v in m._valid:
25+
m(*v)
2026

2127
def test_eval(self):
2228
d = list(discover().items())[0] # noqa: RUF015
@@ -102,6 +108,100 @@ def test_run_exporter_dimension1(self):
102108
dynamic=True,
103109
)
104110

111+
@long_test()
112+
def test_documentation(self):
113+
import inspect
114+
import textwrap
115+
import pandas
116+
from onnx_diagnostic.helpers import string_type
117+
from onnx_diagnostic.torch_export_patches.eval import discover, run_exporter
118+
from onnx_diagnostic.ext_test_case import unit_test_going
119+
120+
cases = discover()
121+
print()
122+
print(":ref:`Summary <led-summary-exported-program>`")
123+
print()
124+
sorted_cases = sorted(cases.items())
125+
if unit_test_going():
126+
sorted_cases = sorted_cases[:3]
127+
for name, _cls_model in sorted_cases:
128+
print(f"* :ref:`{name} <led-model-case-export-{name}>`")
129+
print()
130+
print()
131+
132+
obs = []
133+
for name, cls_model in sorted(cases.items()):
134+
print()
135+
print(f".. _led-model-case-export-{name}:")
136+
print()
137+
print(name)
138+
print("=" * len(name))
139+
print()
140+
print("forward")
141+
print("+++++++")
142+
print()
143+
print(".. code-block:: python")
144+
print()
145+
src = inspect.getsource(cls_model.forward)
146+
if src:
147+
print(textwrap.indent(textwrap.dedent(src), " "))
148+
else:
149+
print(" # code is missing")
150+
print()
151+
print()
152+
for exporter in (
153+
"export-strict",
154+
"export-nostrict",
155+
"export-nostrict-oblivious",
156+
"export-nostrict-decall",
157+
"export-tracing",
158+
):
159+
expname = exporter.replace("export-", "")
160+
print()
161+
print(expname)
162+
print("+" * len(expname))
163+
print()
164+
res = run_exporter(exporter, cls_model, True, quiet=True)
165+
case_ref = f":ref:`{name} <led-model-case-export-{name}>`"
166+
expo = exporter.split("-", maxsplit=1)[-1]
167+
if "inputs" in res:
168+
print(f"* **inputs:** ``{string_type(res['inputs'], with_shape=True)}``")
169+
if "dynamic_shapes" in res:
170+
print(f"* **shapes:** ``{string_type(res['dynamic_shapes'])}``")
171+
print()
172+
print()
173+
if "exported" in res:
174+
print(".. code-block:: text")
175+
print()
176+
print(textwrap.indent(str(res["exported"].graph), " "))
177+
print()
178+
print()
179+
obs.append(dict(case=case_ref, error="", exporter=expo))
180+
else:
181+
print("**FAILED**")
182+
print()
183+
print(".. code-block:: text")
184+
print()
185+
err = str(res["error"])
186+
if err:
187+
print(textwrap.indent(err, " "))
188+
else:
189+
print(" # no error found for the failure")
190+
print()
191+
print()
192+
obs.append(dict(case=case_ref, error="FAIL", exporter=expo))
193+
194+
print()
195+
print(".. _led-summary-exported-program:")
196+
print()
197+
print("Summary")
198+
print("+++++++")
199+
print()
200+
df = pandas.DataFrame(obs)
201+
piv = df.pivot(index="case", columns="exporter", values="error")
202+
print(piv.to_markdown(tablefmt="rst"))
203+
print()
204+
105205

106206
if __name__ == "__main__":
107207
unittest.main(verbosity=2)

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
ExtTestCase,
77
requires_torch,
88
requires_transformers,
9+
has_transformers,
910
has_torch,
1011
)
12+
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache
13+
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
14+
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
1115
from onnx_diagnostic.torch_export_patches import torch_export_patches
1216
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
1317

@@ -317,6 +321,125 @@ def forward(self, x, ind1, ind2):
317321
got = ep.module()(*inputs)
318322
self.assertEqualArray(expected, got)
319323

324+
def test_patched__broadcast_in_dim_meta(self):
325+
class Model(torch.nn.Module):
326+
def forward(self, x, ind1, ind2):
327+
return x[ind1, ind2]
328+
329+
inputs = (
330+
torch.randn(2, 1024),
331+
torch.tensor([[0, 1]], dtype=torch.int64).T,
332+
torch.arange(1024, dtype=torch.int64),
333+
)
334+
model = Model()
335+
expected = model(*inputs)
336+
337+
with (
338+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
339+
torch_export_patches(),
340+
):
341+
ep = torch.export.export(
342+
model,
343+
inputs,
344+
dynamic_shapes=use_dyn_not_str(({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"})),
345+
)
346+
self.assertEqualArray(expected, ep.module()(*inputs), atol=1e-2)
347+
348+
@requires_torch("2.7.9999")
349+
@requires_transformers("4.49.9999")
350+
def test_export_with_patch_tiny_llm_dim_meta(self):
351+
data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
352+
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
353+
order = ["input_ids", "attention_mask", "position_ids", "past_key_values"]
354+
self.assertEqual(list(inputs), order)
355+
expected = model(**torch_deepcopy(inputs))
356+
with self.subTest(input="no01", backed_size_oblivious=False):
357+
with torch_export_patches(patch_transformers=True):
358+
ep = torch.export.export(
359+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
360+
)
361+
got = ep.module()(**torch_deepcopy(inputs))
362+
self.assertEqualArrayAny(expected, got)
363+
364+
with self.subTest(input="no01", backed_size_oblivious=True):
365+
if not has_transformers("4.55"):
366+
raise unittest.SkipTest("test not working with transformers<4.55")
367+
with (
368+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
369+
torch_export_patches(patch_transformers=True),
370+
):
371+
ep = torch.export.export(
372+
model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
373+
)
374+
got = ep.module()(**torch_deepcopy(inputs))
375+
self.assertEqualArrayAny(expected, got)
376+
377+
def _batch1(t):
378+
if t.__class__.__name__ == "DynamicCache":
379+
kv = CacheKeyValue(t)
380+
keys = [t[:1] for t in kv.key_cache]
381+
values = [t[:1] for t in kv.value_cache]
382+
return make_dynamic_cache(tuple(zip(keys, values)))
383+
if t.ndim > 1:
384+
return t[:1]
385+
return t
386+
387+
export_inputs = {k: _batch1(v) for k, v in inputs.items()}
388+
389+
# with self.subTest(input="batch1", backed_size_oblivious=False):
390+
# with torch_export_patches(patch_transformers=True):
391+
# ep = torch.export.export(
392+
# model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
393+
# )
394+
# got = ep.module()(**torch_deepcopy(inputs))
395+
# self.assertEqualArrayAny(expected, got)
396+
397+
with self.subTest(input="batch1", backed_size_oblivious=True):
398+
if not has_transformers("4.55"):
399+
raise unittest.SkipTest("test not working with transformers<4.55")
400+
with (
401+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
402+
torch_export_patches(patch_transformers=True),
403+
):
404+
ep = torch.export.export(
405+
model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
406+
)
407+
try:
408+
got = ep.module()(**torch_deepcopy(inputs))
409+
except AssertionError as e:
410+
got = None
411+
if "Guard failed: position_ids.size()[0] == 1" not in str(e):
412+
raise
413+
414+
if got is not None:
415+
self.assertEqualArrayAny(expected, got)
416+
417+
if "inputs_empty_cache" not in data:
418+
return
419+
420+
export_inputs = data["inputs_empty_cache"]
421+
422+
# with self.subTest(input="cache0", backed_size_oblivious=False):
423+
# with torch_export_patches(patch_transformers=True):
424+
# ep = torch.export.export(
425+
# model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
426+
# )
427+
# got = ep.module()(**torch_deepcopy(inputs))
428+
# self.assertEqualArrayAny(expected, got)
429+
430+
with self.subTest(input="cache0", backed_size_oblivious=True):
431+
if not has_transformers("4.55"):
432+
raise unittest.SkipTest("test not working with transformers<4.55")
433+
with (
434+
torch.fx.experimental._config.patch(backed_size_oblivious=True),
435+
torch_export_patches(patch_transformers=True),
436+
):
437+
ep = torch.export.export(
438+
model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
439+
)
440+
got = ep.module()(**torch_deepcopy(inputs))
441+
self.assertEqualArrayAny(expected, got)
442+
320443

321444
if __name__ == "__main__":
322445
unittest.main(verbosity=2)

onnx_diagnostic/torch_export_patches/eval/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,13 @@ def run_exporter(
676676

677677
if dynamic and len(inputs) > 1:
678678
for index, i in enumerate(inputs):
679-
expected = model(*_clone(i))
679+
if quiet:
680+
try:
681+
expected = model(*_clone(i))
682+
except Exception as e:
683+
return dict(error=str(e), success=0, error_step=f"run0.{index}")
684+
else:
685+
expected = model(*_clone(i))
680686
try:
681687
got = mod(*i)
682688
except Exception as e:

onnx_diagnostic/torch_export_patches/eval/model_cases.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,12 +353,9 @@ def else_branch(input_ids, image_features, vocab_size):
353353

354354

355355
class ControlFlowCondIdentity_153832(torch.nn.Module):
356-
"""
357-
`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
358-
"""
356+
"""`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""
359357

360358
def forward(self, x, y):
361-
362359
def branch_cond_then_1(x):
363360
x = torch.abs(x) + 1
364361
return x

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,8 @@ def torch_export_patches(
347347
patched__constrain_user_specified_dimhint_range,
348348
_catch_produce_guards_and_solve_constraints,
349349
patch__check_input_constraints_for_graph,
350+
patched__broadcast_in_dim_meta,
351+
patched__maybe_broadcast,
350352
)
351353

352354
if verbose:
@@ -383,6 +385,16 @@ def torch_export_patches(
383385
patched__constrain_user_specified_dimhint_range
384386
)
385387

388+
# torch._prims._broadcast_in_dim_meta
389+
f_broadcast_in_dim = torch._prims.broadcast_in_dim
390+
f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
391+
torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
392+
torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
393+
394+
# torch._refs._maybe_broadcast
395+
f__maybe_broadcast = torch._refs._maybe_broadcast
396+
torch._refs._maybe_broadcast = patched__maybe_broadcast
397+
386398
# torch._export.non_strict_utils.produce_guards_and_solve_constraints
387399
if patch_torch and catch_constraints:
388400
if verbose:
@@ -584,6 +596,9 @@ def torch_export_patches(
584596
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
585597
f___constrain_user_specified_dimhint_range
586598
)
599+
torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
600+
torch._prims.broadcast_in_dim = f_broadcast_in_dim
601+
torch._refs._maybe_broadcast = f__maybe_broadcast
587602

588603
if verbose:
589604
print("[torch_export_patches] restored pytorch functions")
@@ -723,9 +738,7 @@ def torch_export_patches(
723738

724739

725740
def replacement_before_exporting(args: Any) -> Any:
726-
"""
727-
Does replacements on the given inputs if needed.
728-
"""
741+
"""Does replacements on the given inputs if needed."""
729742
if args is None:
730743
return None
731744
if isinstance(args, (int, float)):

0 commit comments

Comments
 (0)