diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7b8ac9df..4a871d7a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,20 +15,28 @@ jobs: strategy: matrix: os: [ubuntu-latest] - python: ['3.10', '3.11', '3.12'] - transformers: ['4.48.3', '4.51.3', 'main'] - torch: ['2.6', '2.7', 'main'] + python: ['3.10', '3.11', '3.12', '3.13'] + transformers: ['4.48.3', '4.51.3', '4.52.1', 'main'] + torch: ['2.7', 'main'] exclude: - python: '3.10' - transformers: 'main' + torch: 'main' + - python: '3.11' + torch: 'main' - python: '3.10' - torch: '2.7' + transformers: '4.52.1' + - python: '3.10' + transformers: 'main' - python: '3.11' - transformers: '4.51.3' + transformers: '4.52.1' - python: '3.11' + transformers: 'main' + - python: '3.13' torch: '2.7' - - python: '3.12' - torch: '2.6' + - python: '3.13' + transformers: '4.48.3' + - python: '3.13' + transformers: '4.51.3' steps: - uses: actions/checkout@v3 diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 554d3fce..895ea020 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -1,6 +1,9 @@ Change Logs =========== +0.5.1 ++++++ + 0.5.0 +++++ diff --git a/_doc/api/torch_export_patches/eval/index.rst b/_doc/api/torch_export_patches/eval/index.rst new file mode 100644 index 00000000..b710e5d9 --- /dev/null +++ b/_doc/api/torch_export_patches/eval/index.rst @@ -0,0 +1,12 @@ +onnx_diagnostic.torch_export_patches.eval +========================================= + +.. toctree:: + :maxdepth: 1 + :caption: modules + + model_cases + +.. automodule:: onnx_diagnostic.torch_export_patches.eval + :members: + :no-undoc-members: diff --git a/_doc/api/torch_export_patches/eval/model_cases.rst b/_doc/api/torch_export_patches/eval/model_cases.rst new file mode 100644 index 00000000..82fa0b80 --- /dev/null +++ b/_doc/api/torch_export_patches/eval/model_cases.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.torch_export_patches.eval.model_cases +===================================================== + +.. automodule:: onnx_diagnostic.torch_export_patches.eval.model_cases + :members: + :undoc-members: diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst index 7227ed31..4493ffec 100644 --- a/_doc/api/torch_export_patches/index.rst +++ b/_doc/api/torch_export_patches/index.rst @@ -5,6 +5,7 @@ onnx_diagnostic.torch_export_patches :maxdepth: 1 :caption: submodules + eval/index onnx_export_errors onnx_export_serialization patches/index diff --git a/_doc/conf.py b/_doc/conf.py index 8a3d1faa..9d0b5265 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -92,6 +92,8 @@ def linkcode_resolve(domain, info): "matplotlib": ("https://matplotlib.org/stable/", None), "numpy": ("https://numpy.org/doc/stable", None), "onnx": ("https://onnx.ai/onnx/", None), + "onnxruntime": ("https://onnxruntime.ai/docs/api/python/", None), + "onnxscript": ("https://microsoft.github.io/onnxscript/", None), "onnx_array_api": ("https://sdpython.github.io/doc/onnx-array-api/dev/", None), "onnx_diagnostic": ("https://sdpython.github.io/doc/onnx-diagnostic/dev/", None), "onnx_extended": ("https://sdpython.github.io/doc/onnx-extended/dev/", None), diff --git a/_doc/index.rst b/_doc/index.rst index b422e024..1a479225 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -36,6 +36,7 @@ It also implements tools to investigate, validate exported models (ExportedProgr :caption: Contents patches + status/index api/index cmds/index auto_examples/index @@ -195,14 +196,18 @@ See :meth:`onnx_diagnostic.export.ModelInputs.guess_dynamic_shapes`. 
>>> (({0: 'dim_0I0', 1: 'dim_0I1'}, {1: 'dim_1I1'}), {}) -use_dyn_for_str +use_dyn_not_str +++++++++++++++ +See :meth:`onnx_diagnostic.torch_export_patches.patch_inputs.use_dyn_not_str`. +The function replaces dynamic dimensions defined as strings by +``torch.export.Dim.DYNAMIC``. Older versions ++++++++++++++ +* `0.5.1 <../v0.5.1/index.html>`_ * `0.5.0 <../v0.5.0/index.html>`_ * `0.4.4 <../v0.4.4/index.html>`_ * `0.4.3 <../v0.4.3/index.html>`_ diff --git a/_doc/status/exported_program_dynamic.rst b/_doc/status/exported_program_dynamic.rst new file mode 100644 index 00000000..0311e577 --- /dev/null +++ b/_doc/status/exported_program_dynamic.rst @@ -0,0 +1,90 @@ +===================================== +Exported Programs with Dynamic Shapes +===================================== + +The following script shows the exported program for many short cases +and various l-plot-export-with-dynamic-shape to retrieve an ONNX model equivalent +to the original model. + +.. runpython:: + :showcode: + :rst: + :toggle: code + :warningout: UserWarning + + import inspect + import textwrap + import pandas + from onnx_diagnostic.helpers import string_type + from onnx_diagnostic.torch_export_patches.eval import discover, run_exporter + from onnx_diagnostic.ext_test_case import unit_test_going + + cases = discover() + print() + print(":ref:`Summary `") + print() + sorted_cases = sorted(cases.items()) + if unit_test_going(): + sorted_cases = sorted_cases[:3] + for name, cls_model in sorted_cases: + print(f"* :ref:`{name} `") + print() + + obs = [] + for name, cls_model in sorted(cases.items()): + print() + print(f".. _led-model-case-export-{name}:") + print() + print(name) + print("=" * len(name)) + print() + print("forward") + print("+++++++") + print() + print("::") + print() + print(textwrap.indent(textwrap.dedent(inspect.getsource(cls_model.forward)), " ")) + print() + for exporter in ( + "export-strict", + "export-nostrict", + "export-nostrict-decall", + ): + expname = exporter.replace("export-", "") + print() + print(expname) + print("+" * len(expname)) + print() + res = run_exporter(exporter, cls_model, True, quiet=True) + case_ref = f":ref:`{name} `" + expo = exporter.split("-", maxsplit=1)[-1] + if "inputs" in res: + print(f"* **inputs:** ``{string_type(res['inputs'], with_shape=True)}``") + if "dynamic_shapes" in res: + print(f"* **shapes:** ``{string_type(res['dynamic_shapes'])}``") + print() + if "exported" in res: + print("::") + print() + print(textwrap.indent(str(res["exported"].graph), " ")) + print() + obs.append(dict(case=case_ref, error="", exporter=expo)) + else: + print("**FAILED**") + print() + print("::") + print() + print(textwrap.indent(str(res["error"]), " ")) + print() + obs.append(dict(case=case_ref, error="FAIL", exporter=expo)) + + print() + print(".. _led-summary-exported-program:") + print() + print("Summary") + print("+++++++") + print() + df = pandas.DataFrame(obs) + piv = df.pivot(index="case", columns="exporter", values="error") + print(piv.to_markdown(tablefmt="rst")) + print() diff --git a/_doc/status/index.rst b/_doc/status/index.rst new file mode 100644 index 00000000..265f3e98 --- /dev/null +++ b/_doc/status/index.rst @@ -0,0 +1,12 @@ +=============== +Exporter Status +=============== + +Following sections tries to capture what patches are put in place to export, +what works and what does not with :func:`torch.export.export`. + +.. 
toctree:: + :maxdepth: 1 + + exported_program_dynamic + patches_coverage diff --git a/_doc/status/patches_coverage.rst b/_doc/status/patches_coverage.rst new file mode 100644 index 00000000..bb523b68 --- /dev/null +++ b/_doc/status/patches_coverage.rst @@ -0,0 +1,45 @@ +======================= +Coverage of the Patches +======================= + +Serialized Classes +================== + +The following code shows the list of serialized classes in transformers. + +.. runpython:: + :showcode: + + import onnx_diagnostic.torch_export_patches.onnx_export_serialization as p + + print('\n'.join(sorted(p.serialization_functions()))) + +Patched Classes +=============== + +The following script shows the list of methods patched +for transformers. + +.. runpython:: + :showcode: + + import onnx_diagnostic.torch_export_patches.patches.patch_transformers as p + + for name, cls in p.__dict__.items(): + if name.startswith("patched_"): + print(f"{cls._PATCHED_CLASS_.__name__}: {', '.join(cls._PATCHES_)}") + +Half Automated Rewrites for Control Flows +========================================= + +The following script shows the list of methods automatically rewritten +due to control flows. + +.. runpython:: + :showcode: + + import onnx_diagnostic.torch_export_patches.patch_module_helper as p + + for name, f in p.__dict__.items(): + if name.startswith("_rewrite_"): + print(f.__doc__) diff --git a/_unittests/ut_tasks/test_tasks_image_classification.py b/_unittests/ut_tasks/test_tasks_image_classification.py index 2b723ee5..e9856d10 100644 --- a/_unittests/ut_tasks/test_tasks_image_classification.py +++ b/_unittests/ut_tasks/test_tasks_image_classification.py @@ -16,7 +16,7 @@ def test_image_classification(self): model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] model(**inputs) model(**data["inputs2"]) - if not has_transformers("4.51.999"): + if not has_transformers("4.52.999"): raise unittest.SkipTest("Requires transformers>=4.52") with torch_export_patches(patch_transformers=True, verbose=10): torch.export.export( diff --git a/_unittests/ut_torch_export_patches/test_eval.py b/_unittests/ut_torch_export_patches/test_eval.py new file mode 100644 index 00000000..6c898b62 --- /dev/null +++ b/_unittests/ut_torch_export_patches/test_eval.py @@ -0,0 +1,51 @@ +import unittest +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch +from onnx_diagnostic.torch_export_patches.eval import discover, evaluation + + +class TestEval(ExtTestCase): + @requires_torch("2.7", "scan") + def test_discover(self): + res = discover() + self.assertNotEmpty(res) + for mod in res.values(): + if mod.__name__ == "ControlFlowCondIdentity_153832": + continue + with self.subTest(name=mod.__name__): + m = mod() + if isinstance(m._inputs, tuple): + m(*m._inputs) + else: + m(*m._inputs[0]) + + def test_eval(self): + d = list(discover().items())[0] # noqa: RUF015 + ev = evaluation( + quiet=False, + cases={d[0]: d[1]}, + exporters=( + "export-strict", + "export-nostrict", + "custom", + "dynamo", + "dynamo-ir", + "export-tracing", + ), + ) + self.assertIsInstance(ev, list) + self.assertIsInstance(ev[0], dict) + + def test_run_exporter(self): + evaluation( + cases="SignatureListFixedLength", + exporters="custom-strict", + quiet=False, + dynamic=False, + ) + + def test_run_exporter_regex(self): + evaluation(cases=".*Aten.*", exporters="custom-strict", quiet=False, dynamic=False) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_models/test_llm_phi2.py 
b/_unittests/ut_torch_models/test_llm_phi2.py index af4f616a..090c9b6e 100644 --- a/_unittests/ut_torch_models/test_llm_phi2.py +++ b/_unittests/ut_torch_models/test_llm_phi2.py @@ -13,7 +13,7 @@ def test_get_phi2(self): model(**inputs) @ignore_warnings(UserWarning) - @requires_transformers("4.52") + @requires_transformers("4.53") def test_export_phi2_1(self): data = get_phi2(num_hidden_layers=2) model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] diff --git a/_unittests/ut_torch_models/test_tiny_llms.py b/_unittests/ut_torch_models/test_tiny_llms.py index 909bdc05..ae4f5682 100644 --- a/_unittests/ut_torch_models/test_tiny_llms.py +++ b/_unittests/ut_torch_models/test_tiny_llms.py @@ -14,7 +14,7 @@ def test_get_tiny_llm(self): model(**inputs) @ignore_warnings(UserWarning) - @requires_transformers("4.52") + @requires_transformers("4.53") def test_export_tiny_llm_1(self): data = get_tiny_llm() model, inputs = data["model"], data["inputs"] diff --git a/onnx_diagnostic/__init__.py b/onnx_diagnostic/__init__.py index 02f6377d..b01c76c0 100644 --- a/onnx_diagnostic/__init__.py +++ b/onnx_diagnostic/__init__.py @@ -1,7 +1,7 @@ """ -Investigates onnx models. +Patches, Investigates onnx models. Functions, classes to dig into a model when this one is right, slow, wrong... """ -__version__ = "0.5.0" +__version__ = "0.5.1" __author__ = "Xavier Dupré" diff --git a/onnx_diagnostic/torch_export_patches/eval/__init__.py b/onnx_diagnostic/torch_export_patches/eval/__init__.py new file mode 100644 index 00000000..c59c089b --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/eval/__init__.py @@ -0,0 +1,621 @@ +import contextlib +import io +import itertools +import re +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import onnx + + +def discover(): + """ + Discovers all model cases used to evaluate an exporter. + + .. runpython:: + :showcode: + + import pprint + from onnx_diagnostic.torch_export_patches.eval import discover + + pprint.pprint(discover()) + """ + from . import model_cases + + res = {} + for m in model_cases.__dict__.values(): + if m is None or isinstance(m, str): + continue + if not hasattr(m, "forward"): + continue + assert m.__name__ not in res, f"Case {m.__name__!r} is duplicated." + assert hasattr(m, "_inputs"), f"Attribute '_inputs' is missing from class {m}" + assert hasattr(m, "_dynamic"), f"Attribute '_dynamic' is missing from class {m}" + res[m.__name__] = m + return res + + +def evaluation( + exporters: Tuple[str] = ( + "export-strict", + "export-nostrict", + "export-nostrict-decall", + ), + dynamic: Tuple[bool] = (False, True), + cases: Optional[Union[str, Dict[str, type]]] = None, + verbose: int = 0, + quiet: bool = True, +) -> List[Dict[str, Any]]: + """ + Evaluates exporter for a list of cases. + + :param exporters: exporters to evaluate + :param dynamic: evaluate static shape and dynamic shapes + :param cases: model cases to evaluate + :param verbose: verbosity + :param quiet: catch exception + :return: results, list of dictionaries + """ + if isinstance(exporters, str): + exporters = (exporters,) + if isinstance(dynamic, (bool, int)): + dynamic = (dynamic,) + + if cases is None: + cases = discover() + elif cases in ("three", ["three"]): + all_cases = discover() + cases = dict(list(all_cases.items())[:3]) + elif isinstance(cases, str): + cases = (cases,) + + if isinstance(cases, (list, tuple)): + all_cases = discover() + new_cases = [] # type: ignore[var-annotated] + for c in cases: + if "*" in c or "?" 
in c: + # regex + reg = re.compile(c) + new_cases.extend(k for k in all_cases if reg.match(k)) + else: + new_cases.append(c) + cases = {k: v for k, v in all_cases.items() if k in set(new_cases)} + + sorted_cases = sorted(cases.items()) + loop = list(itertools.product(sorted_cases, dynamic, exporters)) + if verbose: + try: + import tqdm + + loop = tqdm.tqdm(loop) + except ImportError: + + def _loop(): + for _ in loop: + print(f"[evaluation] {_}") + yield _ + + assert len(loop) > 0, f"No case to test for cases={cases!r}." + obs = [] + for case, dyn, exporter in loop: + name, cls_model = case + res = run_exporter(exporter, cls_model, dyn, quiet=quiet, verbose=max(0, verbose - 1)) + res.update(dict(name=name, dynamic=int(dyn), exporter=exporter)) + obs.append(res) + return obs + + +def _flatten_inputs(x: Any) -> List["torch.Tensor"]: # noqa: F821 + """ + Flatten inputs. + """ + if x is None: + return x + import torch + + if isinstance(x, (list, tuple)): + res = [] + for i in x: + if i is None or isinstance( + i, + ( + torch.Tensor, + torch.SymInt, + torch.SymFloat, + int, + float, + ), + ): + res.append(i) + else: + res.extend(_flatten_inputs(i)) + return tuple(res) if isinstance(x, tuple) else res + raise AssertionError(f"Unexpected type {type(x)} for x") + + +def _to_numpy(x): + if hasattr(x, "numpy"): + return x.numpy() + if isinstance(x, int): + # onnxruntime does not like scalar + return np.array([x], dtype=np.int64) + if isinstance(x, float): + # onnxruntime does not like scalar + return np.array([x], dtype=np.float32) + if isinstance(x, list): + return [_to_numpy(_) for _ in x] + if isinstance(x, tuple): + return tuple(_to_numpy(_) for _ in x) + raise TypeError(f"Unable to convert type {type(x)}, x={x} into numpy") + + +def _make_feeds(names, args): + if len(names) == len(args): + return {k: _to_numpy(v) for k, v in zip(names, args)} + if len(names) > len(args): + flats = _flatten_inputs(args) + return {k: _to_numpy(v) for k, v in zip(names, flats)} + from ...helpers import string_type + + raise RuntimeError( + f"Unable to handle names={names!r} and args={string_type(args, limit=20)}" + ) + + +def _clone(x): + if hasattr(x, "clone"): + return x.clone() + if isinstance(x, (int, float)): + return x + if isinstance(x, list): + return [_clone(_) for _ in x] + if isinstance(x, tuple): + return tuple(_clone(_) for _ in x) + raise TypeError(f"Unable to clone type {type(x)}, x={x} into numpy") + + +def _make_exporter_export( + exporter: str, + model: "torch.nn.Module", # noqa: F821 + inputs: Tuple[Any, ...], + dynamic_shapes: Optional[Any] = None, + verbose: int = 0, + quiet: bool = True, +) -> Union[Dict, Callable]: + import torch + + if exporter == "export-strict": + try: + exported = torch.export.export( + model, inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + except Exception as e: + if not quiet: + raise + return dict(error=str(e), success=0, error_step="export") + if verbose >= 9: + print("-- graph") + print(exported.graph) + return exported.module() + if exporter in ("export-strict-dec", "export-strict-decall"): + try: + exported = torch.export.export( + model, inputs, dynamic_shapes=dynamic_shapes, strict=True + ) + if verbose >= 9: + print("-- graph before decomposition") + print(exported.graph) + exported = ( + exported.run_decompositions() + if "decall" in exporter + else exported.run_decompositions({}) + ) + except Exception as e: + if not quiet: + raise + return dict(error=str(e), success=0, error_step="export") + if verbose >= 9: + print("-- graph after decomposition") + 
print(exported.graph) + return exported.module() + if exporter == "export-nostrict": + try: + exported = torch.export.export( + model, inputs, dynamic_shapes=dynamic_shapes, strict=False + ) + except Exception as e: + if not quiet: + raise + return dict(error=str(e), success=0, error_step="export") + if verbose >= 9: + print("-- graph") + print(exported.graph) + return exported.module() + if exporter in ("export-nostrict-dec", "export-nostrict-decall"): + try: + exported = torch.export.export( + model, inputs, dynamic_shapes=dynamic_shapes, strict=False + ) + if verbose >= 9: + print("-- graph before decomposition") + print(exported.graph) + exported = ( + exported.run_decompositions() + if "decall" in exporter + else exported.run_decompositions({}) + ) + except Exception as e: + if not quiet: + raise + return dict(error=str(e), success=0, error_step="export") + if verbose >= 9: + print("-- graph after decomposition") + print(exported.graph) + return exported.module() + if exporter == "export-tracing": + from experimental_experiment.torch_interpreter.tracing import CustomTracer + + try: + graph = CustomTracer().trace(model) + mod = torch.fx.GraphModule(model, graph) + except Exception as e: + if not quiet: + raise + return dict(error=str(e), success=0, error_step="export") + if verbose >= 9: + print("-- graph") + print(graph) + return mod + raise AssertionError(f"Unexpected exporter={exporter!r}") + + +def _make_exporter_onnx( + exporter: str, + model: "torch.nn.Module", # noqa: F821 + inputs: Tuple[Any, ...], + dynamic_shapes: Optional[Any] = None, + verbose: int = 0, + quiet: bool = True, +) -> Union[Dict, Tuple[onnx.ModelProto, Any]]: + from ...helpers import string_type + + if exporter.startswith("custom"): + from experimental_experiment.torch_interpreter import to_onnx, ExportOptions + + opts = {} + opts["strict"] = "-nostrict" not in exporter + opts["fallback"] = "-fallback" in exporter + opts["tracing"] = "-tracing" in exporter + opts["jit"] = "-jit" in exporter + if "-dec" in exporter: + opts["decomposition_table"] = "all" if "-decall" in exporter else "default" + try: + onx, builder = to_onnx( + model, + inputs, + dynamic_shapes=dynamic_shapes, + export_options=ExportOptions(**opts), + return_builder=True, + ) + except Exception as e: + if not quiet: + raise RuntimeError( + f"Unable to convert model={model.__class__.__name__}, " + f"input={string_type(inputs[0], with_shape=True)}, " + f"dynamic_shapes={dynamic_shapes}, " + f"exporter={exporter!r}" + ) from e + return dict(error=str(e), success=0, error_step="export") + return onx, builder + if exporter == "dynamo": + import torch + + try: + if verbose >= 2: + onx = torch.onnx.export( + model, + inputs, + dynamic_shapes=dynamic_shapes, + dynamo=True, + report=True, + ).model_proto + else: + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( + io.StringIO() + ): + onx = torch.onnx.export( + model, + inputs, + dynamic_shapes=dynamic_shapes, + dynamo=True, + ).model_proto + except Exception as e: + if not quiet: + raise RuntimeError( + f"Unable to convert model={model.__class__.__name__}, " + f"input={string_type(inputs[0], with_shape=True)}, " + f"dynamic_shapes={dynamic_shapes}, " + f"exporter={exporter!r}" + ) from e + return dict(error=str(e), success=0, error_step="export") + return onx, None + if exporter == "dynamo-ir": + import torch + + try: + if verbose >= 2: + ep = torch.onnx.export( + model, + inputs, + dynamic_shapes=dynamic_shapes, + dynamo=True, + report=True, + ) + ep.optimize() + onx = 
ep.model_proto + else: + with contextlib.redirect_stdout(io.StringIO()), contextlib.redirect_stderr( + io.StringIO() + ): + ep = torch.onnx.export( + model, + inputs, + dynamic_shapes=dynamic_shapes, + dynamo=True, + ) + ep.optimize() + onx = ep.model_proto + except Exception as e: + if not quiet: + raise RuntimeError( + f"Unable to convert model={model.__class__.__name__}, " + f"input={string_type(inputs[0], with_shape=True)}, " + f"dynamic_shapes={dynamic_shapes}, " + f"exporter={exporter!r}" + ) from e + return dict(error=str(e), success=0, error_step="export") + return onx, None + raise AssertionError(f"Unexpected exporter={exporter!r}") + + +def run_exporter( + exporter: str, + cls_model: type, + dynamic: bool = False, + quiet: bool = False, + verbose: int = 0, +) -> Dict[str, Any]: + """ + Runs an exporter and returns whether it fails or not. + + :param exporter: exporter + :param cls_model: model class to create + :param inputs: list of inputs to try + :param dynamic: use dynamic shape or not + :param quiet: raise exception or not + :param verbose: verbosity + :return: results + """ + from onnx_diagnostic.helpers import max_diff, string_type + from onnx_diagnostic.helpers.onnx_helper import pretty_onnx + + assert hasattr( + cls_model, "_inputs" + ), f"Attribute '_inputs' is missing from class {cls_model}" + + model = cls_model() + inputs = cls_model._inputs + if isinstance(inputs, tuple): + inputs = [inputs] + if dynamic: + assert hasattr( + cls_model, "_dynamic" + ), f"Attribute '_inputs' is missing from class {cls_model}" + dynamic_shapes = cls_model._dynamic + else: + dynamic_shapes = None + + base = dict(inputs=inputs, model=model, dynamic_shapes=dynamic_shapes) + + if verbose > 0: + print( + f"[run_exporter] exporter={exporter}, model={cls_model.__name__}, " + f"dynamic={dynamic}, inputs={string_type(inputs, with_shape=True)}" + ) + + builder = None + onx = None + + if exporter.startswith("export-"): + mod = _make_exporter_export( + exporter, + model, + inputs[0], + dynamic_shapes=dynamic_shapes, + verbose=verbose, + quiet=quiet, + ) + if isinstance(mod, dict): + # something went wrong + return mod + else: + res = _make_exporter_onnx( + exporter, + model, + inputs[0], + dynamic_shapes=dynamic_shapes, + verbose=verbose, + quiet=quiet, + ) + if isinstance(res, dict): + # something went wrong + return res + + onx, builder = res + if verbose >= 9: + print("[run_exporter] onnx model") + print( + builder.pretty_text(add_fx_graph=True) + if builder is not None + else pretty_onnx(onx) + ) + if verbose >= 2: + onnx.save(onx, f"evaluation-{model.__class__.__name__}-{dynamic}-{exporter}.onnx") + + names = [i.name for i in onx.graph.input] + flats = _flatten_inputs(inputs[0]) if len(names) > len(inputs[0]) else inputs[0] + + assert quiet or len(names) == len(flats), ( + f"Input mismatch, inputs[0]={string_type(inputs[0])} " + f"inputs but names={names!r}, " + f"model={cls_model.__name__}, export={exporter!r}" + ) + if len(names) != len(flats): + res = dict( + error=f"Input mismatch, inputs[0]={string_type(inputs[0])} " + f"but names={names!r}, model={cls_model.__name__}, export={exporter!r}", + success=0, + error_step="inputs", + ) + res.update(base) + return res + + import onnxruntime + + try: + sess = onnxruntime.InferenceSession( + onx.SerializeToString(), providers=["CPUExecutionProvider"] + ) + except Exception as e: + if not quiet: + raise + res = dict(error=str(e), success=0, error_step="ort-init") + res.update(base) + return res + + mod = lambda *args, names=names: sess.run(None, 
_make_feeds(names, args)) # noqa: E731 + + # we need to clone for models modifying the inputs + try: + expected = model(*_clone(inputs[0])) + except Exception as e: + if not quiet: + raise RuntimeError( + f"eager mode failed=\n{string_type(inputs[0], with_shape=True)} " + f"\nmodel=\n{type(model)}" + ) from e + res = dict(error=str(e), success=0, error_step="eager") + res.update(base) + return res + try: + got = mod(*inputs[0]) + except Exception as e: + if not quiet: + raise RuntimeError( + f"onnxruntime failed, feeds=\n{string_type(inputs[0], with_shape=True)} " + f"\nmodel=\n{pretty_onnx(onx)}" + ) from e + res = dict(error=str(e), success=0, error_step="run.0") + res.update(base) + return res + + base["expected"] = expected + base["obtained"] = got + + try: + disc = max_diff(expected, got) + except Exception as e: + if not quiet: + raise + res = dict(error=str(e), success=0, error_step="discrepancy") + res.update(base) + return res + + if verbose >= 5 and np.isinf(disc["abs"]): + print("[run_exporter] comparison issues with") + print(f"-- inputs={string_type(inputs[0], with_shape=True, limit=20)}") + print(f"-- expected={string_type(expected, with_shape=True, limit=20)}") + print(f"-- got={string_type(got, with_shape=True, limit=20)}") + elif verbose >= 9: + print("[run_exporter] inputs and outputs") + print( + f"-- inputs=" + f"{string_type(inputs[0], with_shape=True, with_min_max=True, limit=20)}" + ) + print( + f"-- expected=" + f"{string_type(expected, with_shape=True, with_min_max=True, limit=20)}" + ) + print(f"-- got={string_type(got, with_shape=True, with_min_max=True, limit=20)}") + del disc["n"] + del disc["sum"] + disc.update( + dict( + success=1 if disc["abs"] < 0.1 else 0, + model_cls=model.__class__, + exported=mod, # type: ignore[dict-item] + onnx=onx, # type: ignore[dict-item] + ) + ) + if disc["abs"] >= 0.1: + disc["error"] = "diff.0" + disc["error_step"] = "diff.0" + if verbose >= 9: + max_diff(expected, got, verbose=verbose) + else: + disc["success"] = 1 + + if dynamic and onx is not None: + ds = [] + for i in onx.graph.input: + if i.type.tensor_type: + for di, dim in enumerate(i.type.tensor_type.shape.dim): + if dim.dim_param: + ds.append((i.name, di, dim.dim_param)) + if verbose >= 2: + print(f"[run_exporter] dynamic dimension={ds}") + if not ds: + return dict(error="no dynamic shape", success=0, error_step="dynamic") + + if dynamic and len(inputs) > 1: + for index, i in enumerate(inputs): + expected = model(*_clone(i)) + try: + got = mod(*i) + except Exception as e: + if not quiet: + raise RuntimeError( + f"onnxruntime failed,\n-- feeds=\n{string_type(i, with_shape=True)} " + f"exporter={exporter!r}, dynamic_shapes={dynamic_shapes}" + f"\n-- model=\n{pretty_onnx(onx) if onx is not None else type(model)}" + ) from e + return dict(error=str(e), success=0, error_step=f"run.{index}") + + try: + d = max_diff(expected, got) + except Exception as e: + if not quiet: + raise + return dict(error=str(e), success=0, error_step=f"discrepancy.{index}") + + if verbose >= 5 and np.isinf(d["abs"]): + print(f"[run_exporter] comparison issues iteration {index}") + print(f"-- inputs={string_type(i, with_shape=True)}") + print(f"-- expected={string_type(expected, with_shape=True)}") + print(f"-- got={string_type(got, with_shape=True)}") + elif verbose >= 9: + print(f"[run_exporter] inputs and outputs iteration {index}") + print(f"-- inputs={string_type(i, with_shape=True, with_min_max=True)}") + print( + f"-- expected={string_type(expected, with_shape=True, with_min_max=True)}" + ) + 
print(f"-- got={string_type(got, with_shape=True, with_min_max=True)}") + del d["n"] + del d["sum"] + if d["abs"] >= 0.1: + d["error"] = f"diff.{index}" + d["error_step"] = f"diff.{index}" + d["success"] = 0 + disc.update(d) + + disc.update(base) + return disc diff --git a/onnx_diagnostic/torch_export_patches/eval/model_cases.py b/onnx_diagnostic/torch_export_patches/eval/model_cases.py new file mode 100644 index 00000000..35c2586e --- /dev/null +++ b/onnx_diagnostic/torch_export_patches/eval/model_cases.py @@ -0,0 +1,877 @@ +import numpy as np +import torch + +DIM = torch.export.Dim +DYN = torch.export.Dim.DYNAMIC + + +class AtenRollRelu(torch.nn.Module): + def forward(self, x): + return torch.relu(torch.roll(x, -1, -1)) + + _inputs = ((torch.arange(8 * 3) + 10).reshape((2, -1, 4)).to(torch.float32),) + _dynamic = {"x": {0: DIM("batch")}} + + +class AtenRollPos(torch.nn.Module): + def forward(self, x): + return torch.roll(x, 1, -1) + + _inputs = ((torch.arange(8 * 3) + 10).reshape((2, -1, 4)).to(torch.float32),) + _dynamic = {"x": {0: DIM("batch")}} + + +class InplaceAdd(torch.nn.Module): + + def __init__(self): + super().__init__() + self.bias = torch.ones((1, 4), dtype=torch.float32) + + def forward(self, x): + x += self.bias + return x + + _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)] + _dynamic = {"x": {0: DIM("batch")}} + + +class InplaceAdd2(torch.nn.Module): + + def __init__(self): + super().__init__() + self.bias = torch.ones((1, 4), dtype=torch.float32) + + def forward(self, x): + x.add_(self.bias) + return x + + _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)] + _dynamic = {"x": {0: DIM("batch")}} + + +class InplaceAdd_Mul(torch.nn.Module): + + def __init__(self): + super().__init__() + self.bias = torch.ones((1, 4), dtype=torch.float32) + + def forward(self, x): + x.add_(self.bias) + return x * 2 + + _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)] + _dynamic = {"x": {0: DIM("batch")}} + + +class InplaceCloneAdd_(torch.nn.Module): + + def __init__(self): + super().__init__() + self.bias = torch.ones((1, 4), dtype=torch.float32) + + def forward(self, x): + x = x.clone() + x.add_(self.bias) + return x + + _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)] + _dynamic = {"x": {0: DIM("batch")}} + + +class InplaceSetItemSquare(torch.nn.Module): + + def forward(self, x): + x[:2, :3] = 1 + return x + + _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)] + _dynamic = {"x": {0: DIM("batch")}} + + +class InplaceSetItemSquareAdd(torch.nn.Module): + + def forward(self, x): + x[:2, :3] = 1 + return x + 2 + + _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)] + _dynamic = {"x": {0: DIM("batch")}} + + +class InplaceSetItemSquareAdd2(torch.nn.Module): + + def forward(self, x): + x[:2, :3] = 1 + return x + 2, x + 3 + + _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)] + _dynamic = {"x": {0: DIM("batch")}} + + +class InplaceSetItemEllipsis_1(torch.nn.Module): + + def __init__(self): + super().__init__() + self.params = torch.zeros((1, 8192, 4), dtype=torch.float32) + + def forward(self, index, update): + copy = self.params.clone() + copy[..., index] = update + return copy + + _inputs = ( + (torch.from_numpy(np.array([0, 3, 2, 1])).to(torch.int64)), + (torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32), + ) + _dynamic = {"x": {0: DIM("batch")}} + + +class InplaceSetItemEllipsis_2(torch.nn.Module): + + def __init__(self): + super().__init__() + self.params = torch.zeros((1, 8192, 6), dtype=torch.float32) + + def forward(self, index, update): + copy = 
self.params.clone() + copy[..., index] = update + return copy + + _inputs = ( + torch.from_numpy(np.array([0, 3, 2, 5])).to(torch.int64), + (torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32), + ) + _dynamic = {"x": {0: DIM("batch")}} + + +class InplaceSetItemMask(torch.nn.Module): + def forward(self, x): + mask = x.to(bool) + x[mask] = 2 + return x + + _inputs = [(torch.randn((2, 3, 3)),), (torch.randn((3, 3, 3)),)] + _dynamic = {"x": {0: DIM("batch")}} + + +class AtenInterpolate(torch.nn.Module): + + def forward(self, x): + y = torch.nn.functional.interpolate( + x, + scale_factor=2.0, + mode="bilinear", + recompute_scale_factor=False, + ) + return y + + _inputs = (torch.randn(2, 2, 3, 4, requires_grad=False),) + _dynamic = {"x": {0: DIM("batch")}} + + +class AtenNonZero(torch.nn.Module): + + def forward(self, x): + y = torch.nonzero(x) + return y + + _inputs = (torch.randn(3, 4, requires_grad=False),) + _dynamic = {"x": {0: DIM("batch")}} + + +class AtenNonZeroTuple(torch.nn.Module): + + def forward(self, x): + y = torch.nonzero(x, as_tuple=True) + return y[0], y[1] + + _inputs = (torch.randn(3, 4, requires_grad=False),) + _dynamic = {"x": {0: DIM("batch")}} + + +class AtenAsStrided(torch.nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, x): + y = torch.as_strided(x, (2, 2, 8, 4), (128, 8, 16, 1)) + return y + + _inputs = (torch.randn((2, 2, 8, 8), requires_grad=False),) + _dynamic = {"x": {0: DIM("batch")}} + + +class ComplexPolar(torch.nn.Module): + def forward(self, x, angle): + return torch.polar(x, angle) + + _inputs = (torch.rand(4, 4), torch.rand(4, 4)) + _dynamic = {"x": {0: DIM("batch")}, "angle": {0: DIM("batch")}} + + +class ControlFlowCond(torch.nn.Module): + def forward(self, x): + def true_fn(x): + return torch.sin(x) + + def false_fn(x): + return torch.cos(x) + + return torch.cond(x.sum() > 0, true_fn, false_fn, [x]) + + _inputs = (torch.rand(5, 3),) + _dynamic = {"x": {0: DIM("batch")}} + + +class ControlFlowCond2Outputs(torch.nn.Module): + def forward(self, x): + def true_fn(x): + return torch.sin(x), torch.cos(x) + + def false_fn(x): + return torch.cos(x), torch.sin(x) + + return torch.cond(x.sum() > 0, true_fn, false_fn, [x]) + + _inputs = (torch.rand(5, 3),) + _dynamic = {"x": {0: DIM("batch")}} + + +class ControlFlowCond2Inputs(torch.nn.Module): + def forward(self, x, y): + def true_fn(x, y): + return torch.sin(x), torch.cos(x) + y + + def false_fn(x, y): + return torch.cos(x), torch.sin(x) + y + + return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y]) + + _inputs = torch.rand(5, 3), torch.rand(5, 3) + _dynamic = {"x": {0: DIM("batch")}, "y": {0: DIM("batch")}} + + +class ControlFlowNestCond(torch.nn.Module): + def forward(self, x): + def true_fn2(x): + def true_fn1(x): + return torch.sin(x) + + def false_fn1(x): + return torch.cos(x) + + return torch.cond(x.sum() < 0, true_fn1, false_fn1, [x]) + + def false_fn2(x): + return -x + + return torch.cond(x.sum() > 0, true_fn2, false_fn2, [x]) + + _inputs = (torch.rand(5, 3),) + _dynamic = {"x": {0: DIM("batch")}} + + +class ControlFlowCondConstant(torch.nn.Module): + def forward(self, x): + def true_fn(x): + return torch.sin(x) - torch.ones(x.shape, dtype=x.dtype) + + def false_fn(x): + return torch.cos(x) + torch.ones((1, 1024), dtype=x.dtype) + + return torch.cond(x.sum() > 0, true_fn, false_fn, [x]) + + _inputs = (torch.rand(1024, 1024),) + _dynamic = {"x": {0: DIM("batch")}} + + +class ControlFlowCondNestedModule(torch.nn.Module): + + class Submodule(torch.nn.Module): + def 
__init__(self): + super().__init__() + # Nested weight + self.weight = torch.nn.Parameter(torch.tensor([100.0])) + + def forward(self, x): + def true_fn(x): + return x * self.weight + + def false_fn(x): + return x / self.weight + + y = torch.cond(torch.abs(x).sum() > 100, true_fn, false_fn, [x]) + return y + + def __init__(self): + super().__init__() + self.submodule = ControlFlowCondNestedModule.Submodule() + self.weight = torch.nn.Parameter(torch.tensor([42.0])) + + def forward(self, x): + def true_fn(x): + return self.submodule(x) + + def false_fn(x): + return x - self.weight + + y = torch.cond(x.sum() > 0, true_fn, false_fn, [x]) + return y + + _inputs = (torch.tensor([-1, 2]),) + _dynamic = {"x": {0: DIM("batch")}} + + +class ControlFlowCondNonZero(torch.nn.Module): + def forward(self, input_ids, image_features, vocab_size): + def then_branch(input_ids, image_features, vocab_size): + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + condition = (input_ids < 0) & (input_ids > -int(1e9)) + positions = torch.nonzero(condition, as_tuple=True) + input_ids = input_ids.clamp_min(0).clamp_max(vocab_size) + return (input_ids, positions[0], positions[1]) + + def else_branch(input_ids, image_features, vocab_size): + r = torch.where(torch.zeros((1, 1), dtype=torch.bool)) + return (input_ids, r[0], r[1]) + + a, b, c = torch.cond( + image_features.numel() > 0, + then_branch, + else_branch, + [input_ids, image_features, vocab_size], + ) + return a, b, c + + _inputs = [ + ( + (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64), + torch.arange(32).reshape((2, -1)).to(torch.float32), + 1025, + ), + ( + (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64), + torch.tensor([[], []], dtype=torch.float32), + 1025, + ), + ] + _dynamic = ( + {0: DIM("batch")}, + {0: DIM("batch"), 1: DIM("seq_length")}, + None, + ) + + +class ControlFlowCondIdentity_153832(torch.nn.Module): + """`#153832 `_""" + + def forward(self, x, y): + + def branch_cond_then_1(x): + x = torch.abs(x) + 1 + return x + + def branch_cond_else_1(x): + return x # fails but succeeds with x.clone() + + x = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x]) + return x + y + + _inputs = [ + (torch.rand((3, 4)), torch.rand((3, 4))), + (torch.rand((4, 5)), torch.rand((4, 5))), + ] + _dynamic = {"x": {0: DYN, 1: DYN}, "y": {0: DYN, 1: DYN}} + + +class ControlFlowScan(torch.nn.Module): + + @staticmethod + def add(carry: torch.Tensor, y: torch.Tensor): + next_carry = carry + y + return [next_carry, next_carry] + + def forward(self, x): + init = torch.zeros_like(x[0]) + carry, out = torch.ops.higher_order.scan( + ControlFlowScan.add, [init], [x], additional_inputs=[] + ) + return carry + + _inputs = (torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32),) + _dynamic = {"x": {0: DIM("batch")}} + + +class ControlFlowScan2Carried(torch.nn.Module): + @staticmethod + def add(carry1: torch.Tensor, carry2: torch.Tensor, y1: torch.Tensor, y2: torch.Tensor): + next_carry1 = carry1 + y1 + next_carry2 = carry2 * y2 + return [next_carry1, next_carry2, next_carry1, next_carry2] + + def forward(self, x): + init1 = torch.zeros_like(x[0]) + init2 = torch.ones_like(x[0]) + carry1, carry2, out1, out2 = torch.ops.higher_order.scan( + ControlFlowScan2Carried.add, + [init1, init2], + [x, x * 2], + # dim=0, # 01/31/2025, not supported anymore + additional_inputs=[], + ) + return carry1, carry2, out1, out2 + + _inputs = ( + torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], 
dtype=torch.float32), + ) + _dynamic = {"x": {0: DIM("batch")}} + + +class ControlFlowScanCDist(torch.nn.Module): + @staticmethod + def dist(carry: torch.Tensor, x: torch.Tensor): + sub = carry - x.reshape((1, -1)) + sq = sub * sub + rd = sq.sum(axis=1) ** 0.5 + # clone --> UnsupportedAliasMutationException: + # Combine_fn might be aliasing the input! + return [carry.clone(), rd] + + def forward(self, x): + carry, out = torch.ops.higher_order.scan( + ControlFlowScanCDist.dist, + [x], + [x], + # dim=0, # 01/31/2025, not supported anymore + additional_inputs=[], + ) + return out + + _inputs = ( + torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32), + ) + _dynamic = {"x": {0: DIM("batch")}} + + +class ControlFlowScanCDist2(torch.nn.Module): + @staticmethod + def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor): + sub = samex - x.reshape((1, -1)) + sq = sub * sub + rd = torch.sqrt(sq.sum(axis=1)) + # clone --> UnsupportedAliasMutationException: + # Combine_fn might be aliasing the input! + return [unused.clone(), rd] + + def forward(self, x): + z = torch.tensor([0], dtype=torch.float32) + y = x.clone() + out = torch.ops.higher_order.scan( + ControlFlowScanCDist2.dist, + [z], + [x], + # dim=0, # 01/31/2025, not supported anymore + additional_inputs=[y], + ) + return out[1] + + _inputs = ( + torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32), + ) + _dynamic = {"x": {0: DIM("batch")}} + + +class ControlFlowScanCDistXY(torch.nn.Module): + + @staticmethod + def dist(y: torch.Tensor, scanned_x: torch.Tensor): + sub = y - scanned_x.reshape((1, -1)) + sq = sub * sub + rd = torch.sqrt(sq.sum(axis=1)) + # clone --> UnsupportedAliasMutationException: + # Combine_fn might be aliasing the input! 
+ return [y.clone(), rd] + + def forward(self, x, y): + carry, out = torch.ops.higher_order.scan( + ControlFlowScanCDistXY.dist, + [y], + [x], + # dim=0, # 01/31/2025, not supported anymore + additional_inputs=[], + ) + return out + + _inputs = [ + (torch.randn(3, 4), torch.randn(5, 4)), + (torch.randn(13, 14), torch.randn(15, 14)), + ] + _dynamic = { + "x": {0: DIM("x_rows"), 1: DIM("dim")}, + "y": {0: DIM("y_rows"), 1: DIM("dim")}, + } + + +class ControlFlowScanInplace_153705(torch.nn.Module): + """`#153705 `_""" + + def forward(self, x, y): + def loop_body_1(z, iv, x, y): + z = z.clone() + i = iv.item() + z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1) + return [z, iv] + + z = torch.empty((x.shape[0], y.shape[0])) + r = torch.ops.higher_order.scan( + loop_body_1, [z], [torch.arange(x.shape[0], dtype=torch.int64)], [x, y] + ) + return r[0] + + _inputs = [ + (torch.rand((3, 4)), torch.rand((5, 4))), + (torch.rand((4, 5)), torch.rand((6, 5))), + ] + _dynamic = {"x": {0: DYN, 1: DYN}, "y": {0: DYN, 1: DYN}} + + +class ControlFlowScanDecomposition_151564(torch.nn.Module): + """`#151564 `_""" + + @classmethod + def dummy_loop(cls, padded: torch.Tensor, pos: torch.Tensor): + copy = torch.zeros(padded.shape) + for i in range(pos.shape[0]): + p = pos[i] + copy[i, :p] = padded[i, :p] + return copy + + @classmethod + def dummy_loop_with_scan(cls, padded: torch.Tensor, pos: torch.Tensor): + def pad_row(padded, p): + row = torch.zeros((padded.shape[0],)) + torch._check(p.item() > 0) + torch._check(p.item() < padded.shape[0]) + # this check is not always true, we add it anyway to make this dimension >= 2 + # and avoid raising an exception about dynamic dimension in {0, 1} + if torch.compiler.is_exporting(): + torch._check(p.item() > 1) + row[: p.item()] = padded[: p.item()] + return (row,) + + return torch.ops.higher_order.scan( + pad_row, + [], + [padded, pos], + [], + ) + + @classmethod + def select_when_exporting(cls, f, f_scan): + return f_scan if torch.compiler.is_exporting() else f + + def forward(self, images, position): + return self.select_when_exporting(self.dummy_loop, self.dummy_loop_with_scan)( + images, position + ) + + _inputs = [(torch.randn((5, 6)), torch.arange(5, dtype=torch.int64) + 1)] + _dynamic = {"images": {0: DYN, 1: DYN}, "position": {0: DYN}} + + +class SignatureInt1(torch.nn.Module): + def __init__(self, n_dims: int = 3, n_targets: int = 1): + super().__init__() + self.linear = torch.nn.Linear(n_dims, n_targets) + self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets)) + + def forward(self, x, i: int = 2): + return torch.sigmoid(self.linear(x)) - self.buff + x[:, i : i + 1] + + _inputs = [ + ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1), + ((torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), 2), + ] + _dynamic = ({0: DIM("batch", min=1, max=1024)}, None) + + +class SignatureFloat1(torch.nn.Module): + def __init__(self, n_dims: int = 3, n_targets: int = 1): + super().__init__() + self.linear = torch.nn.Linear(n_dims, n_targets) + self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets)) + + def forward(self, x, alpha: float = 2.0): + return torch.sigmoid(self.linear(x)) - self.buff * alpha + + _inputs = [ + ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1.5), + ((torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), 2.5), + ] + _dynamic = ({0: DIM("batch", min=1, max=1024)}, None) + + +class SignatureInt2(torch.nn.Module): + def __init__(self, n_dims: int = 3, n_targets: int = 1): + 
super().__init__() + self.linear = torch.nn.Linear(n_dims, n_targets) + self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets)) + + def forward(self, x, i: int = 2): + return torch.sigmoid(self.linear(x)) - self.buff + x[:, i] + + _inputs = ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1) + _dynamic = { + "x": {0: DIM("batch")}, + "i": None, # DIM("ii", min=0, max=3)} + } + + +class SignatureListFixedLength(torch.nn.Module): + def __init__(self, n_dims: int = 3, n_targets: int = 1): + super().__init__() + self.linear = torch.nn.Linear(n_dims, n_targets) + self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets)) + + def forward(self, x, lx: list): + return ( + torch.sigmoid(self.linear(x)) - self.buff + lx[0] * lx[1].sum(axis=1, keepdim=True) + ) + + _inputs = [ + ( + (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), + [ + (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32), + (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32), + ], + ), + ( + (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), + [ + (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32), + (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32), + ], + ), + ] + _dynamic = { + "x": {0: DIM("batch")}, + "lx": [{0: DIM("batch")}, {0: DIM("batch")}], + } + + +class SignatureListVariableLength(torch.nn.Module): + def __init__(self, n_dims: int = 3, n_targets: int = 1): + super().__init__() + self.linear = torch.nn.Linear(n_dims, n_targets) + self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets)) + + def forward(self, x, lx: list): + t = torch.cat(lx, dim=1).sum(axis=1, keepdim=True) + return torch.sigmoid(self.linear(x)) - self.buff + t + + _inputs = [ + ( + (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), + [ + (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32), + (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32), + ], + ), + ( + (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), + [ + (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32), + (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32), + (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), + ], + ), + ] + _dynamic = { + "x": {0: DIM("batch")}, + "lx": [{0: DIM("batch")}, {0: DIM("batch")}], + } + + +class BuildInLen(torch.nn.Module): + def __init__(self, n_dims: int = 3, n_targets: int = 1): + super().__init__() + self.linear = torch.nn.Linear(n_dims, n_targets) + self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets)) + + def forward(self, x, lx: list): + t = lx[0] * lx[1].sum(axis=1, keepdim=True) + if len(lx) > 2: + t = t + lx[2].sum(axis=1, keepdim=True) + return torch.sigmoid(self.linear(x)) - self.buff + t + + _inputs = [ + ( + (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), + [ + (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32), + (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32), + ], + ), + ( + (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), + [ + (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32), + (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32), + (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), + ], + ), + ] + _dynamic = { + "x": {0: DIM("batch")}, + "lx": [{0: DIM("batch")}, {0: DIM("batch")}], + } + + +class BuildInIsInstance(torch.nn.Module): + def __init__(self, n_dims: int = 3, n_targets: int = 1): + super().__init__() + self.linear = torch.nn.Linear(n_dims, n_targets) + 
self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets)) + + def forward(self, x, lx: list | torch.Tensor): + if isinstance(lx, list): + t = lx[0] * lx[1].sum(axis=1, keepdim=True) + return torch.sigmoid(self.linear(x)) - self.buff + t + return torch.sigmoid(self.linear(x)) - self.buff + lx + + _inputs = [ + ( + (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), + [ + (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32), + (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32), + ], + ), + ( + (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), + [ + (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32), + (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32), + ], + ), + ] + _dynamic = { + "x": {0: DIM("batch")}, + "lx": [{0: DIM("batch")}, {0: DIM("batch")}], + } + + +class SignatureShapeAsIndex(torch.nn.Module): + def __init__(self, n_dims: int = 3, n_targets: int = 1): + super().__init__() + self.linear = torch.nn.Linear(n_dims, n_targets) + self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets)) + + def forward(self, x, y): + t = torch.sigmoid(self.linear(x)) + x + return t[:, : y.shape[1]] + + _inputs = ( + (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), + (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32), + ) + _dynamic = { + "x": {0: DIM("batch", min=0, max=1024)}, + "y": { + 0: DIM("batch", min=0, max=1024), + 1: DIM("length", min=0, max=2), + }, + } + + +class TypeBFloat16(torch.nn.Module): + + def forward(self, x): + xb = x.to(torch.bfloat16) + return (xb + xb).to(torch.float32) + + _inputs = (torch.rand(4, 4).to(torch.float32),) + _dynamic = {"x": {0: DIM("batch")}} + + +class CropLastDimensionWithTensorShape(torch.nn.Module): + + def forward(self, x, y): + return x[..., : y.shape[0]] + + _inputs = [ + ( + torch.rand(3, 4, 4).to(torch.float32), + torch.rand( + 2, + ).to(torch.float32), + ), + ( + torch.rand(6, 4, 4).to(torch.float32), + torch.rand( + 3, + ).to(torch.float32), + ), + ] + _dynamic = { + "x": {0: DIM("batch")}, + "y": {0: DIM("crop", min=1, max=3)}, + } + + +class CropLastDimensionWithTensorContent(torch.nn.Module): + + def forward(self, x, shape): + return x[..., : shape[0]] + + _inputs = [ + (torch.rand(3, 4, 4).to(torch.float32), torch.tensor([2], dtype=torch.int64)), + (torch.rand(6, 4, 4).to(torch.float32), torch.tensor([3], dtype=torch.int64)), + ] + _dynamic = {"x": {0: DIM("batch")}} + + +class SignatureListFixedWithNone(torch.nn.Module): + + def forward(self, lx): + x = lx[0] + if lx[1] is not None: + x += lx[1] + if lx[2] is not None: + x += lx[2] + return x + + _inputs = [ + ([torch.rand((4, 4)), torch.rand((4, 4)), None],), + ([torch.rand((4, 4)), torch.rand((4, 4)), torch.rand((4, 4))],), + ] + _dynamic = { + "lx": [{0: DIM("batch")}, {0: DIM("batch")}], + } + + +class CreateFromShape(torch.nn.Module): + def forward(self, x): + y = torch.ones((x.shape[0], x.shape[1] + 1)) + return y + + _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)] + _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}} + + +class CreateFromShapeThroughFunction(torch.nn.Module): + @staticmethod + def add_one(dim): + return dim + 1 + + def forward(self, x): + dy1 = CreateFromShapeThroughFunction.add_one(x.shape[1]) + y = torch.ones((x.shape[0], dy1)) + return y + + _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)] + _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}} diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py 
b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 0d6777df..10589ff9 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -1,5 +1,5 @@ import pprint -from typing import Any, Callable, Dict, List, Optional, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import packaging.version as pv import optree import torch @@ -133,6 +133,11 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: # To avoid doing it multiple times. PATCH_OF_PATCHES.add(BaseModelOutput) + return serialization_functions(verbose=verbose) + + +def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]: + """Returns the list of serialization functions.""" return dict( DynamicCache=register_class_serialization( DynamicCache, diff --git a/onnx_diagnostic/torch_export_patches/patch_module_helper.py b/onnx_diagnostic/torch_export_patches/patch_module_helper.py index fbedb2a2..dd178d4f 100644 --- a/onnx_diagnostic/torch_export_patches/patch_module_helper.py +++ b/onnx_diagnostic/torch_export_patches/patch_module_helper.py @@ -19,6 +19,28 @@ def ast_or_into_bitor(node: "ast.Node") -> "ast.Node": return new_node +def _rewrite_bart_encoder_layer(): + "BartEncoderLayer, PLBartEncoderLayer" + import transformers + + bd = dict( + filter_node=( + lambda node: isinstance(node, ast.If) and not isinstance(node.test, ast.Name) + ), + pre_rewriter=ast_or_into_bitor, + ) + + def _add(f): + g = bd.copy() + g["function"] = f + return g + + return [ + _add(transformers.models.bart.modeling_bart.BartEncoderLayer.forward), + _add(transformers.models.plbart.modeling_plbart.PLBartEncoderLayer.forward), + ] + + def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]: """ Returns a known list of methods or functions to rewrite because of control flow @@ -43,22 +65,5 @@ def code_needing_rewriting(cls_name: str) -> Optional[List[Any]]: "PLBartEncoderLayer", "PLBartForConditionalGeneration", }: - import transformers - - bd = dict( - filter_node=( - lambda node: isinstance(node, ast.If) and not isinstance(node.test, ast.Name) - ), - pre_rewriter=ast_or_into_bitor, - ) - - def _add(f): - g = bd.copy() - g["function"] = f - return g - - return [ - _add(transformers.models.bart.modeling_bart.BartEncoderLayer.forward), - _add(transformers.models.plbart.modeling_plbart.PLBartEncoderLayer.forward), - ] + return _rewrite_bart_encoder_layer() return None
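
A minimal usage sketch of the ``onnx_diagnostic.torch_export_patches.eval`` module added by this patch, assuming the patch is installed and a recent torch (>= 2.7 per the CI matrix above) is available; the case name, exporter list and selected columns below are only illustrative, not part of the patch::

    import pandas
    from onnx_diagnostic.torch_export_patches.eval import discover, evaluation

    # every registered model case, e.g. ControlFlowCond, ControlFlowScan, ...
    print(sorted(discover()))

    # runs two torch.export based exporters on one case, with static and
    # dynamic shapes; a regex such as ".*Scan.*" would select several cases
    rows = evaluation(
        exporters=("export-strict", "export-nostrict"),
        dynamic=(False, True),
        cases="ControlFlowCond",
        quiet=True,
    )

    # each row also carries the model, the inputs and the exported program;
    # only a few informative columns are kept here
    df = pandas.DataFrame(rows)
    print(df[["name", "exporter", "dynamic", "success"]])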