diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b89177a9..8d5c5626 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -17,7 +17,7 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python: ['3.10', '3.11', '3.12', '3.13']
-        transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', 'main']
+        transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', '4.57', 'main']
         torch: ['2.8', 'main']
         exclude:
           - python: '3.10'
@@ -30,6 +30,8 @@ jobs:
             transformers: '4.55.4'
           - python: '3.10'
             transformers: '4.56.2'
+          - python: '3.10'
+            transformers: '4.57.0'
           - python: '3.11'
             torch: 'main'
           - python: '3.11'
@@ -38,6 +40,8 @@ jobs:
             transformers: '4.55.4'
           - python: '3.11'
             transformers: '4.56.2'
+          - python: '3.11'
+            transformers: '4.57.0'
           - python: '3.13'
             torch: '2.8'
           - python: '3.13'
diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index f61e7960..f39c13a2 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,8 @@ Change Logs
 0.7.14
 ++++++
 
+* :pr:`249`: patches _maybe_broadcast to support a corner case
+
 0.7.13
 ++++++
 
diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py
index 975372ec..6815e8dc 100644
--- a/_unittests/ut_tasks/test_tasks.py
+++ b/_unittests/ut_tasks/test_tasks.py
@@ -270,7 +270,7 @@ def test_falcon_mamba_dev(self):
         model(**inputs)
         model(**data["inputs2"])
         self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
-        if not has_transformers("4.57"):
+        if not has_transformers("4.57.99"):
             raise unittest.SkipTest("The model has control flow.")
         with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
             torch.export.export(
diff --git a/_unittests/ut_torch_export_patches/test_eval.py b/_unittests/ut_torch_export_patches/test_eval.py
index 902df00a..1aca744e 100644
--- a/_unittests/ut_torch_export_patches/test_eval.py
+++ b/_unittests/ut_torch_export_patches/test_eval.py
@@ -1,5 +1,5 @@
 import unittest
-from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, long_test
 from onnx_diagnostic.torch_export_patches.eval import discover, evaluation
 
 
@@ -9,14 +9,20 @@ def test_discover(self):
         res = discover()
         self.assertNotEmpty(res)
         for mod in res.values():
-            if mod.__name__ == "ControlFlowCondIdentity_153832":
-                continue
             with self.subTest(name=mod.__name__):
+                if mod.__name__ == "ControlFlowCondIdentity_153832":
+                    raise unittest.SkipTest(
+                        "ControlFlowCondIdentity_153832 needs missing clone."
+                    )
                 m = mod()
                 if isinstance(m._inputs, tuple):
                     m(*m._inputs)
                 else:
-                    m(*m._inputs[0])
+                    for v in m._inputs:
+                        m(*v)
+                if hasattr(m, "_valid"):
+                    for v in m._valid:
+                        m(*v)
 
     def test_eval(self):
         d = list(discover().items())[0]  # noqa: RUF015
@@ -102,6 +108,100 @@ def test_run_exporter_dimension1(self):
             dynamic=True,
         )
 
+    @long_test()
+    def test_documentation(self):
+        import inspect
+        import textwrap
+        import pandas
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.torch_export_patches.eval import discover, run_exporter
+        from onnx_diagnostic.ext_test_case import unit_test_going
+
+        cases = discover()
+        print()
+        print(":ref:`Summary <led-summary-exported-program>`")
+        print()
+        sorted_cases = sorted(cases.items())
+        if unit_test_going():
+            sorted_cases = sorted_cases[:3]
+        for name, _cls_model in sorted_cases:
+            print(f"* :ref:`{name} <led-model-case-export-{name}>`")
+        print()
+        print()
+
+        obs = []
+        for name, cls_model in sorted(cases.items()):
+            print()
+            print(f".. _led-model-case-export-{name}:")
+            print()
+            print(name)
+            print("=" * len(name))
+            print()
+            print("forward")
+            print("+++++++")
+            print()
+            print(".. code-block:: python")
+            print()
+            src = inspect.getsource(cls_model.forward)
+            if src:
+                print(textwrap.indent(textwrap.dedent(src), " "))
+            else:
+                print(" # code is missing")
+            print()
+            print()
+            for exporter in (
+                "export-strict",
+                "export-nostrict",
+                "export-nostrict-oblivious",
+                "export-nostrict-decall",
+                "export-tracing",
+            ):
+                expname = exporter.replace("export-", "")
+                print()
+                print(expname)
+                print("+" * len(expname))
+                print()
+                res = run_exporter(exporter, cls_model, True, quiet=True)
+                case_ref = f":ref:`{name} <led-model-case-export-{name}>`"
+                expo = exporter.split("-", maxsplit=1)[-1]
+                if "inputs" in res:
+                    print(f"* **inputs:** ``{string_type(res['inputs'], with_shape=True)}``")
+                if "dynamic_shapes" in res:
+                    print(f"* **shapes:** ``{string_type(res['dynamic_shapes'])}``")
+                print()
+                print()
+                if "exported" in res:
+                    print(".. code-block:: text")
+                    print()
+                    print(textwrap.indent(str(res["exported"].graph), " "))
+                    print()
+                    print()
+                    obs.append(dict(case=case_ref, error="", exporter=expo))
+                else:
+                    print("**FAILED**")
+                    print()
+                    print(".. code-block:: text")
+                    print()
+                    err = str(res["error"])
+                    if err:
+                        print(textwrap.indent(err, " "))
+                    else:
+                        print(" # no error found for the failure")
+                    print()
+                    print()
+                    obs.append(dict(case=case_ref, error="FAIL", exporter=expo))
+
+        print()
+        print(".. _led-summary-exported-program:")
+        print()
+        print("Summary")
+        print("+++++++")
+        print()
+        df = pandas.DataFrame(obs)
+        piv = df.pivot(index="case", columns="exporter", values="error")
+        print(piv.to_markdown(tablefmt="rst"))
+        print()
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py
index c27a536d..98bb720a 100644
--- a/_unittests/ut_torch_export_patches/test_patch_torch.py
+++ b/_unittests/ut_torch_export_patches/test_patch_torch.py
@@ -6,8 +6,12 @@
     ExtTestCase,
     requires_torch,
     requires_transformers,
+    has_transformers,
     has_torch,
 )
+from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
+from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 
@@ -317,6 +321,125 @@ def forward(self, x, ind1, ind2):
         got = ep.module()(*inputs)
         self.assertEqualArray(expected, got)
 
+    def test_patched__broadcast_in_dim_meta(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, ind1, ind2):
+                return x[ind1, ind2]
+
+        inputs = (
+            torch.randn(2, 1024),
+            torch.tensor([[0, 1]], dtype=torch.int64).T,
+            torch.arange(1024, dtype=torch.int64),
+        )
+        model = Model()
+        expected = model(*inputs)
+
+        with (
+            torch.fx.experimental._config.patch(backed_size_oblivious=True),
+            torch_export_patches(),
+        ):
+            ep = torch.export.export(
+                model,
+                inputs,
+                dynamic_shapes=use_dyn_not_str(({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"})),
+            )
+        self.assertEqualArray(expected, ep.module()(*inputs), atol=1e-2)
+
+    @requires_torch("2.7.9999")
+    @requires_transformers("4.49.9999")
+    def test_export_with_patch_tiny_llm_dim_meta(self):
+        data = get_untrained_model_with_inputs("arnir0/Tiny-LLM", verbose=0)
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        order = ["input_ids", "attention_mask", "position_ids", "past_key_values"]
+        self.assertEqual(list(inputs), order)
+        expected = model(**torch_deepcopy(inputs))
+        with self.subTest(input="no01", backed_size_oblivious=False):
+            with torch_export_patches(patch_transformers=True):
+                ep = torch.export.export(
+                    model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
+                )
+                got = ep.module()(**torch_deepcopy(inputs))
+                self.assertEqualArrayAny(expected, got)
+
+        with self.subTest(input="no01", backed_size_oblivious=True):
+            if not has_transformers("4.55"):
+                raise unittest.SkipTest("test not working with transformers<4.55")
+            with (
+                torch.fx.experimental._config.patch(backed_size_oblivious=True),
+                torch_export_patches(patch_transformers=True),
+            ):
+                ep = torch.export.export(
+                    model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
+                )
+                got = ep.module()(**torch_deepcopy(inputs))
+                self.assertEqualArrayAny(expected, got)
+
+        def _batch1(t):
+            if t.__class__.__name__ == "DynamicCache":
+                kv = CacheKeyValue(t)
+                keys = [t[:1] for t in kv.key_cache]
+                values = [t[:1] for t in kv.value_cache]
+                return make_dynamic_cache(tuple(zip(keys, values)))
+            if t.ndim > 1:
+                return t[:1]
+            return t
+
+        export_inputs = {k: _batch1(v) for k, v in inputs.items()}
+
+        # with self.subTest(input="batch1", backed_size_oblivious=False):
+        #     with torch_export_patches(patch_transformers=True):
+        #         ep = torch.export.export(
+        #             model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
+        #         )
+        #         got = ep.module()(**torch_deepcopy(inputs))
+        #         self.assertEqualArrayAny(expected, got)
+
+        with self.subTest(input="batch1", backed_size_oblivious=True):
+            if not has_transformers("4.55"):
+                raise unittest.SkipTest("test not working with transformers<4.55")
+            with (
+                torch.fx.experimental._config.patch(backed_size_oblivious=True),
+                torch_export_patches(patch_transformers=True),
+            ):
+                ep = torch.export.export(
+                    model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
+                )
+                try:
+                    got = ep.module()(**torch_deepcopy(inputs))
+                except AssertionError as e:
+                    got = None
+                    if "Guard failed: position_ids.size()[0] == 1" not in str(e):
+                        raise
+
+            if got is not None:
+                self.assertEqualArrayAny(expected, got)
+
+        if "inputs_empty_cache" not in data:
+            return
+
+        export_inputs = data["inputs_empty_cache"]
+
+        # with self.subTest(input="cache0", backed_size_oblivious=False):
+        #     with torch_export_patches(patch_transformers=True):
+        #         ep = torch.export.export(
+        #             model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
+        #         )
+        #         got = ep.module()(**torch_deepcopy(inputs))
+        #         self.assertEqualArrayAny(expected, got)
+
+        with self.subTest(input="cache0", backed_size_oblivious=True):
+            if not has_transformers("4.55"):
+                raise unittest.SkipTest("test not working with transformers<4.55")
+            with (
+                torch.fx.experimental._config.patch(backed_size_oblivious=True),
+                torch_export_patches(patch_transformers=True),
+            ):
+                ep = torch.export.export(
+                    model, (), kwargs=export_inputs, dynamic_shapes=use_dyn_not_str(ds)
+                )
+                got = ep.module()(**torch_deepcopy(inputs))
+                self.assertEqualArrayAny(expected, got)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
diff --git a/onnx_diagnostic/torch_export_patches/eval/__init__.py b/onnx_diagnostic/torch_export_patches/eval/__init__.py
index 9cfe8d64..da74585e 100644
--- a/onnx_diagnostic/torch_export_patches/eval/__init__.py
+++ b/onnx_diagnostic/torch_export_patches/eval/__init__.py
@@ -676,7 +676,13 @@ def run_exporter(
 
     if dynamic and len(inputs) > 1:
         for index, i in enumerate(inputs):
-            expected = model(*_clone(i))
+            if quiet:
+                try:
+                    expected = model(*_clone(i))
+                except Exception as e:
+                    return dict(error=str(e), success=0, error_step=f"run0.{index}")
+            else:
+                expected = model(*_clone(i))
             try:
                 got = mod(*i)
             except Exception as e:
diff --git a/onnx_diagnostic/torch_export_patches/eval/model_cases.py b/onnx_diagnostic/torch_export_patches/eval/model_cases.py
index f67c39b9..586bf027 100644
--- a/onnx_diagnostic/torch_export_patches/eval/model_cases.py
+++ b/onnx_diagnostic/torch_export_patches/eval/model_cases.py
@@ -353,12 +353,9 @@ def else_branch(input_ids, image_features, vocab_size):
 
 
 class ControlFlowCondIdentity_153832(torch.nn.Module):
-    """
-    `#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
-    """
+    """`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""
 
     def forward(self, x, y):
-
         def branch_cond_then_1(x):
             x = torch.abs(x) + 1
             return x
diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
index 0ed5692c..972c728c 100644
--- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
+++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -347,6 +347,8 @@ def torch_export_patches(
         patched__constrain_user_specified_dimhint_range,
         _catch_produce_guards_and_solve_constraints,
         patch__check_input_constraints_for_graph,
+        patched__broadcast_in_dim_meta,
+        patched__maybe_broadcast,
     )
 
     if verbose:
@@ -383,6 +385,16 @@ def torch_export_patches(
             patched__constrain_user_specified_dimhint_range
         )
 
+        # torch._prims._broadcast_in_dim_meta
+        f_broadcast_in_dim = torch._prims.broadcast_in_dim
+        f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
+        torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
+        torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
+
+        # torch._refs._maybe_broadcast
+        f__maybe_broadcast = torch._refs._maybe_broadcast
+        torch._refs._maybe_broadcast = patched__maybe_broadcast
+
         # torch._export.non_strict_utils.produce_guards_and_solve_constraints
         if patch_torch and catch_constraints:
             if verbose:
@@ -584,6 +596,9 @@ def torch_export_patches(
             torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
                 f___constrain_user_specified_dimhint_range
            )
+            torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
+            torch._prims.broadcast_in_dim = f_broadcast_in_dim
+            torch._refs._maybe_broadcast = f__maybe_broadcast
 
         if verbose:
             print("[torch_export_patches] restored pytorch functions")
@@ -723,9 +738,7 @@ def torch_export_patches(
 
 
 def replacement_before_exporting(args: Any) -> Any:
-    """
-    Does replacements on the given inputs if needed.
-    """
+    """Does replacements on the given inputs if needed."""
     if args is None:
         return None
     if isinstance(args, (int, float)):
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
index 6e24835b..bee4ba19 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_torch.py
@@ -1,6 +1,7 @@
 import inspect
 import os
 import traceback
+from functools import reduce
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 import torch
 from torch._subclasses.fake_tensor import FakeTensorMode
@@ -570,3 +571,146 @@ def patched__constrain_user_specified_dimhint_range(
                 return msg
     return None
+
+
+def patched__maybe_broadcast(*args, preserve_cpu_scalar_tensors=True):
+    """Patches ``torch._refs._maybe_broadcast``."""
+    from torch._prims_common import ShapeType, TensorLike, Number
+
+    # Computes common shape
+    common_shape = patched__broadcast_shapes(
+        *(t.shape if isinstance(t, TensorLike) else None for t in args)
+    )
+
+    def should_expand(a: ShapeType, b: ShapeType) -> bool:
+        from torch.fx.experimental.symbolic_shapes import (
+            guard_or_false,
+            sym_and,
+            sym_or,
+        )
+
+        if len(a) != len(b):
+            return True
+
+        for x, y in zip(a, b):
+            if guard_or_false(x != y):
+                # We know they are not the same.
+                return True
+
+            # They are the same or we do not know if they are the same or not.
+            # 1==1 no-broadcast
+            # u0==1 and 1==u0 cases. We broadcast!
+            if guard_or_false(sym_and(x == 1, y == 1)):
+                pass
+            elif guard_or_false(sym_or(x == 1, y == 1)):
+                # assume broadcasting.
+                return True
+
+            # u0==u1 assume the same, no broadcasting!
+            # PATCHED: avoid errors
+            return True  # guard_or_true(x != y)
+            # torch._check(
+            #     x == y,
+            #     lambda x=x, y=y: (
+            #         f"sizes assumed to be the same due to unbacked "
+            #         f"broadcasting semantics x={x!r}, y={y!r}"
+            #     ),
+            # )
+
+        return False
+
+    def __maybe_broadcast(x, shape):
+        if x is None:
+            return None
+        elif isinstance(x, Number):
+            return x
+        elif isinstance(x, TensorLike):
+            if preserve_cpu_scalar_tensors and torch._prims_common.is_cpu_scalar_tensor(x):
+                return x
+
+            if should_expand(x.shape, common_shape):
+                return x.expand(common_shape)
+
+            return x
+        else:
+            raise RuntimeError(f"Unexpected type when broadcasting: {str(type(x))}!")
+
+    return tuple(__maybe_broadcast(x, common_shape) for x in args)
+
+
+def patched__broadcast_in_dim_meta(
+    a: torch._prims_common.TensorLikeType,
+    shape: torch._prims_common.ShapeType,
+    broadcast_dimensions: Sequence[int],
+):
+    """Patches ``torch._prims._broadcast_in_dim_meta``."""
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        sym_or,
+    )
+
+    # Type checks
+    assert isinstance(a, torch._prims_common.TensorLike)
+    assert isinstance(shape, Sequence)
+    assert isinstance(broadcast_dimensions, Sequence)
+
+    # every dimension must be accounted for
+    assert a.ndim == len(broadcast_dimensions)
+
+    # broadcast shape must have weakly more dimensions
+    assert len(shape) >= a.ndim
+
+    # broadcast_dimensions must be an ascending sequence
+    # (no relative reordering of dims) of integers and
+    # each dimension must be within the new shape
+    def _greater_than_reduce(acc, x):
+        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
+        assert x > acc
+        assert x < len(shape)
+
+        return x
+
+    reduce(_greater_than_reduce, broadcast_dimensions, -1)
+
+    # shape must be broadcastable to
+    for idx, new_idx in enumerate(broadcast_dimensions):
+        torch._check(
+            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
+            lambda idx=idx, new_idx=new_idx: (
+                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
+            ),
+        )
+
+    new_strides = []
+    original_idx = 0
+    for idx in range(len(shape)):
+        if idx in broadcast_dimensions:
+            # Assigns a stride of zero to dimensions
+            # which were actually broadcast
+            if guard_or_false(a.shape[original_idx] == 1):
+                if guard_or_false(a.shape[original_idx] == shape[idx]):
+                    new_strides.append(a.stride()[original_idx])
+                else:
+                    new_strides.append(0)
+            else:
+                # PATCHED: disabled this check
+                # torch._check(
+                #     a.shape[original_idx] == shape[idx],
+                #     lambda idx=idx, original_idx=original_idx: (
+                #         f"non-broadcasting semantics require "
+                #         f"{a.shape[original_idx]} == {shape[idx]}"
+                #     ),
+                # )
+                new_strides.append(a.stride()[original_idx])
+            original_idx = original_idx + 1
+        else:
+            if guard_or_true(shape[idx] != 1):
+                # consistent with previous use of guard_size_oblivious
+                new_strides.append(0)
+            elif original_idx == a.ndim:
+                new_strides.append(1)
+            else:
+                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
+
+    return a.as_strided(shape, new_strides, a.storage_offset())
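
Note (not part of the diff above): a minimal usage sketch of the corner case this change targets, adapted from the new test_patched__broadcast_in_dim_meta. It assumes onnx_diagnostic with this patch and a PyTorch build exposing torch.fx.experimental._config; with backed_size_oblivious=True, the patched _broadcast_in_dim_meta / _maybe_broadcast let advanced indexing export without the broadcasting guard failure.

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class Model(torch.nn.Module):
    def forward(self, x, ind1, ind2):
        # advanced indexing broadcasts ind1 (B, 1) against ind2 (E,)
        return x[ind1, ind2]


inputs = (
    torch.randn(2, 1024),
    torch.tensor([[0, 1]], dtype=torch.int64).T,
    torch.arange(1024, dtype=torch.int64),
)
# backed_size_oblivious=True plus the patches installed by torch_export_patches()
with (
    torch.fx.experimental._config.patch(backed_size_oblivious=True),
    torch_export_patches(),
):
    ep = torch.export.export(
        Model(),
        inputs,
        dynamic_shapes=use_dyn_not_str(({0: "A", 1: "B"}, {0: "C", 1: "D"}, {0: "E"})),
    )
print(ep.graph)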