Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ jobs:
matrix:
os: [ubuntu-latest]
python: ['3.10', '3.11', '3.12', '3.13']
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', 'main']
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', '4.57.0', 'main']
torch: ['2.8', 'main']
exclude:
- python: '3.10'
Expand All @@ -30,6 +30,8 @@ jobs:
transformers: '4.55.4'
- python: '3.10'
transformers: '4.56.2'
- python: '3.10'
transformers: '4.57.0'
- python: '3.11'
torch: 'main'
- python: '3.11'
Expand All @@ -38,6 +40,8 @@ jobs:
transformers: '4.55.4'
- python: '3.11'
transformers: '4.56.2'
- python: '3.11'
transformers: '4.57.0'
- python: '3.13'
torch: '2.8'
- python: '3.13'
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Change Logs
0.7.14
++++++

* :pr:`249`: patches _maybe_broadcast to support a corner case

0.7.13
++++++

Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_tasks/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def test_falcon_mamba_dev(self):
model(**inputs)
model(**data["inputs2"])
self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
if not has_transformers("4.57"):
if not has_transformers("4.57.99"):
raise unittest.SkipTest("The model has control flow.")
with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
torch.export.export(
Expand Down
108 changes: 104 additions & 4 deletions _unittests/ut_torch_export_patches/test_eval.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import unittest
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, long_test
from onnx_diagnostic.torch_export_patches.eval import discover, evaluation


Expand All @@ -9,14 +9,20 @@ def test_discover(self):
res = discover()
self.assertNotEmpty(res)
for mod in res.values():
if mod.__name__ == "ControlFlowCondIdentity_153832":
continue
with self.subTest(name=mod.__name__):
if mod.__name__ == "ControlFlowCondIdentity_153832":
raise unittest.SkipTest(
"ControlFlowCondIdentity_153832 needs missing clone."
)
m = mod()
if isinstance(m._inputs, tuple):
m(*m._inputs)
else:
m(*m._inputs[0])
for v in m._inputs:
m(*v)
if hasattr(m, "_valid"):
for v in m._valid:
m(*v)

def test_eval(self):
d = list(discover().items())[0] # noqa: RUF015
Expand Down Expand Up @@ -102,6 +108,100 @@ def test_run_exporter_dimension1(self):
dynamic=True,
)

@long_test()
def test_documentation(self):
    """Prints an RST report covering the discovered model cases.

    For every case returned by ``discover()``, the report shows the source
    of its ``forward`` method, then runs each exporter through
    ``run_exporter`` and prints either the exported graph or the error.
    A pivot table (case x exporter) closes the report.

    When ``unit_test_going()`` is true, only the first three cases
    (in sorted order) are listed *and* processed.
    """
    import inspect
    import textwrap
    import pandas
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.torch_export_patches.eval import discover, run_exporter
    from onnx_diagnostic.ext_test_case import unit_test_going

    cases = discover()
    print()
    print(":ref:`Summary <led-summary-exported-program>`")
    print()
    sorted_cases = sorted(cases.items())
    if unit_test_going():
        # Shortened run: the same truncated list drives both the table of
        # contents below and the detailed sections further down.
        sorted_cases = sorted_cases[:3]
    for name, _cls_model in sorted_cases:
        print(f"* :ref:`{name} <led-model-case-export-{name}>`")
    print()
    print()

    # One row per (case, exporter) pair, pivoted into the final summary.
    obs = []
    for name, cls_model in sorted_cases:
        print()
        print(f".. _led-model-case-export-{name}:")
        print()
        print(name)
        print("=" * len(name))
        print()
        print("forward")
        print("+++++++")
        print()
        print(".. code-block:: python")
        print()
        src = inspect.getsource(cls_model.forward)
        if src:
            print(textwrap.indent(textwrap.dedent(src), " "))
        else:
            print(" # code is missing")
        print()
        print()
        for exporter in (
            "export-strict",
            "export-nostrict",
            "export-nostrict-oblivious",
            "export-nostrict-decall",
            "export-tracing",
        ):
            expname = exporter.replace("export-", "")
            print()
            print(expname)
            print("+" * len(expname))
            print()
            res = run_exporter(exporter, cls_model, True, quiet=True)
            case_ref = f":ref:`{name} <led-model-case-export-{name}>`"
            expo = exporter.split("-", maxsplit=1)[-1]
            if "inputs" in res:
                print(f"* **inputs:** ``{string_type(res['inputs'], with_shape=True)}``")
            if "dynamic_shapes" in res:
                print(f"* **shapes:** ``{string_type(res['dynamic_shapes'])}``")
            print()
            print()
            if "exported" in res:
                print(".. code-block:: text")
                print()
                print(textwrap.indent(str(res["exported"].graph), " "))
                print()
                print()
                obs.append(dict(case=case_ref, error="", exporter=expo))
            else:
                # No exported program: report the failure and its message.
                print("**FAILED**")
                print()
                print(".. code-block:: text")
                print()
                err = str(res["error"])
                if err:
                    print(textwrap.indent(err, " "))
                else:
                    print(" # no error found for the failure")
                print()
                print()
                obs.append(dict(case=case_ref, error="FAIL", exporter=expo))

    print()
    print(".. _led-summary-exported-program:")
    print()
    print("Summary")
    print("+++++++")
    print()
    df = pandas.DataFrame(obs)
    piv = df.pivot(index="case", columns="exporter", values="error")
    print(piv.to_markdown(tablefmt="rst"))
    print()


# Run the tests of this module with verbose output when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
123 changes: 123 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
ExtTestCase,
requires_torch,
requires_transformers,
has_transformers,
has_torch,
)
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

Expand Down Expand Up @@ -317,6 +321,125 @@ def forward(self, x, ind1, ind2):
got = ep.module()(*inputs)
self.assertEqualArray(expected, got)

def test_patched__broadcast_in_dim_meta(self):
    """Exports ``x[ind1, ind2]`` with ``backed_size_oblivious=True``.

    The model is exported inside ``torch_export_patches()`` with every
    dimension declared dynamic, then the exported module is run on the
    same inputs and compared with the eager result (``atol=1e-2``).
    """

    class Model(torch.nn.Module):
        def forward(self, x, ind1, ind2):
            return x[ind1, ind2]

    data = torch.randn(2, 1024)
    row_index = torch.tensor([[0, 1]], dtype=torch.int64).T
    col_index = torch.arange(1024, dtype=torch.int64)
    inputs = (data, row_index, col_index)

    model = Model()
    expected = model(*inputs)

    with torch.fx.experimental._config.patch(backed_size_oblivious=True):
        with torch_export_patches():
            shapes = use_dyn_not_str(({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"}))
            ep = torch.export.export(model, inputs, dynamic_shapes=shapes)

    got = ep.module()(*inputs)
    self.assertEqualArray(expected, got, atol=1e-2)

@requires_torch("2.7.9999")
@requires_transformers("4.49.9999")
def test_export_with_patch_tiny_llm_dim_meta(self):
    """Exports the untrained ``arnir0/Tiny-LLM`` model from several input sets.

    Each scenario runs in its own ``subTest``:

    * ``no01``: export from the default inputs, with and without
      ``backed_size_oblivious=True``;
    * ``batch1``: export from inputs sliced to their first row;
    * ``cache0``: export from ``data["inputs_empty_cache"]`` when available.

    In every scenario the exported module is run on the *default* inputs
    and compared with the eager output computed once at the beginning.
    The ``backed_size_oblivious=True`` scenarios are skipped when
    ``has_transformers("4.55")`` is false.
    """
    data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
    model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
    order = ["input_ids", "attention_mask", "position_ids", "past_key_values"]
    self.assertEqual(list(inputs), order)
    # The inputs are deep-copied before every call
    # (presumably because the cache is modified in place -- TODO confirm).
    expected = model(**torch_deepcopy(inputs))
    with self.subTest(input="no01", backed_size_oblivious=False):
        with torch_export_patches(patch_transformers=True):
            ep = torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
            )
        got = ep.module()(**torch_deepcopy(inputs))
        self.assertEqualArrayAny(expected, got)

    with self.subTest(input="no01", backed_size_oblivious=True):
        if not has_transformers("4.55"):
            raise unittest.SkipTest("test not working with transformers<4.55")
        with (
            torch.fx.experimental._config.patch(backed_size_oblivious=True),
            torch_export_patches(patch_transformers=True),
        ):
            ep = torch.export.export(
                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
            )
        got = ep.module()(**torch_deepcopy(inputs))
        self.assertEqualArrayAny(expected, got)

    def _batch1(t):
        # Keeps only the first row of the input along its first dimension.
        # A DynamicCache is rebuilt from its sliced key and value tensors,
        # tensors with a single dimension (or none) are returned unchanged.
        if t.__class__.__name__ == "DynamicCache":
            kv = CacheKeyValue(t)
            keys = [t[:1] for t in kv.key_cache]
            values = [t[:1] for t in kv.value_cache]
            return make_dynamic_cache(tuple(zip(keys, values)))
        if t.ndim > 1:
            return t[:1]
        return t

    export_inputs = {k: _batch1(v) for k, v in inputs.items()}

    # Disabled scenario, kept for reference.
    # with self.subTest(input="batch1", backed_size_oblivious=False):
    #     with torch_export_patches(patch_transformers=True):
    #         ep = torch.export.export(
    #             model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
    #         )
    #     got = ep.module()(**torch_deepcopy(inputs))
    #     self.assertEqualArrayAny(expected, got)

    with self.subTest(input="batch1", backed_size_oblivious=True):
        if not has_transformers("4.55"):
            raise unittest.SkipTest("test not working with transformers<4.55")
        with (
            torch.fx.experimental._config.patch(backed_size_oblivious=True),
            torch_export_patches(patch_transformers=True),
        ):
            ep = torch.export.export(
                model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
            )
        try:
            got = ep.module()(**torch_deepcopy(inputs))
        except AssertionError as e:
            # A guard failure on the first dimension of position_ids is
            # tolerated: the comparison below is then skipped.
            # Any other assertion error is re-raised.
            got = None
            if "Guard failed: position_ids.size()[0] == 1" not in str(e):
                raise

        if got is not None:
            self.assertEqualArrayAny(expected, got)

    # The last scenario needs inputs with an empty cache.
    if "inputs_empty_cache" not in data:
        return

    export_inputs = data["inputs_empty_cache"]

    # Disabled scenario, kept for reference.
    # with self.subTest(input="cache0", backed_size_oblivious=False):
    #     with torch_export_patches(patch_transformers=True):
    #         ep = torch.export.export(
    #             model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
    #         )
    #     got = ep.module()(**torch_deepcopy(inputs))
    #     self.assertEqualArrayAny(expected, got)

    with self.subTest(input="cache0", backed_size_oblivious=True):
        if not has_transformers("4.55"):
            raise unittest.SkipTest("test not working with transformers<4.55")
        with (
            torch.fx.experimental._config.patch(backed_size_oblivious=True),
            torch_export_patches(patch_transformers=True),
        ):
            ep = torch.export.export(
                model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
            )
        got = ep.module()(**torch_deepcopy(inputs))
        self.assertEqualArrayAny(expected, got)


# Run the tests of this module with verbose output when executed as a script.
if __name__ == "__main__":
    unittest.main(verbosity=2)
8 changes: 7 additions & 1 deletion onnx_diagnostic/torch_export_patches/eval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,13 @@ def run_exporter(

if dynamic and len(inputs) > 1:
for index, i in enumerate(inputs):
expected = model(*_clone(i))
if quiet:
try:
expected = model(*_clone(i))
except Exception as e:
return dict(error=str(e), success=0, error_step=f"run0.{index}")
else:
expected = model(*_clone(i))
try:
got = mod(*i)
except Exception as e:
Expand Down
5 changes: 1 addition & 4 deletions onnx_diagnostic/torch_export_patches/eval/model_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,9 @@ def else_branch(input_ids, image_features, vocab_size):


class ControlFlowCondIdentity_153832(torch.nn.Module):
"""
`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
"""
"""`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""

def forward(self, x, y):

def branch_cond_then_1(x):
x = torch.abs(x) + 1
return x
Expand Down
19 changes: 16 additions & 3 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ def torch_export_patches(
patched__constrain_user_specified_dimhint_range,
_catch_produce_guards_and_solve_constraints,
patch__check_input_constraints_for_graph,
patched__broadcast_in_dim_meta,
patched__maybe_broadcast,
)

if verbose:
Expand Down Expand Up @@ -383,6 +385,16 @@ def torch_export_patches(
patched__constrain_user_specified_dimhint_range
)

# torch._prims._broadcast_in_dim_meta
f_broadcast_in_dim = torch._prims.broadcast_in_dim
f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta

# torch._refs._maybe_broadcast
f__maybe_broadcast = torch._refs._maybe_broadcast
torch._refs._maybe_broadcast = patched__maybe_broadcast

# torch._export.non_strict_utils.produce_guards_and_solve_constraints
if patch_torch and catch_constraints:
if verbose:
Expand Down Expand Up @@ -584,6 +596,9 @@ def torch_export_patches(
torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
f___constrain_user_specified_dimhint_range
)
torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
torch._prims.broadcast_in_dim = f_broadcast_in_dim
torch._refs._maybe_broadcast = f__maybe_broadcast

if verbose:
print("[torch_export_patches] restored pytorch functions")
Expand Down Expand Up @@ -723,9 +738,7 @@ def torch_export_patches(


def replacement_before_exporting(args: Any) -> Any:
"""
Does replacements on the given inputs if needed.
"""
"""Does replacements on the given inputs if needed."""
if args is None:
return None
if isinstance(args, (int, float)):
Expand Down
Loading
Loading