From f91cc32c700996a6adb2982ad68ff62b4cb43d7a Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 14 Apr 2025 15:37:06 +0200 Subject: [PATCH 1/5] fix script --- _unittests/ut_export/test_dynamic_shapes.py | 67 +++++++++++++++++ onnx_diagnostic/_command_lines_parser.py | 3 +- onnx_diagnostic/export/dynamic_shapes.py | 82 ++++++++++++++++++++- onnx_diagnostic/torch_models/test_helper.py | 62 ++++++++++++---- 4 files changed, 195 insertions(+), 19 deletions(-) diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index fc36af57..ec9dcb95 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -679,6 +679,42 @@ def test_couple_input_ds_replace_string(self): ).replace_string_by(value="DYN"), ) + def test_couple_input_ds_replace_by_string(self): + T3x1 = torch.rand((3, 1)) + T3x4 = torch.rand((3, 4)) + T5x1 = torch.rand((5, 1)) + args = (T5x1,) + kwargs = {"A": T3x4, "B": (T3x1, T3x1)} + ds_batch = {0: "batch"} + ds_batch_seq = {0: "batch", 1: "seq"} + ds = {"X": ds_batch, "A": ds_batch, "B": (ds_batch, ds_batch)} + Cls = CoupleInputsDynamicShapes + res = Cls( + args, + kwargs, + ds, + args_names=["X"], + ).replace_by_string() + self.assertEqual(ds, res) + + ds_batch = {0: torch.export.Dim("batch")} + ds_batch_seq = {0: torch.export.Dim("batch"), 1: torch.export.Dim.DYNAMIC} + ds = {"X": ds_batch, "A": ds_batch_seq, "B": (ds_batch_seq, ds_batch_seq)} + res = Cls( + args, + kwargs, + ds, + args_names=["X"], + ).replace_by_string() + self.assertEqual( + { + "X": {0: "batch"}, + "A": {0: "batch", 1: "Dim1"}, + "B": ({0: "batch", 1: "Dim1"}, {0: "batch", 1: "Dim1"}), + }, + res, + ) + def test_couple_input_ds_change_dynamic_dimensions(self): T257 = torch.arange(2 * 5 * 7).reshape((2, 5, 7)) T29 = torch.arange(2 * 9).reshape((2, 9)) @@ -703,6 +739,37 @@ def test_couple_input_ds_change_dynamic_dimensions_fixed(self): self.assertEqual((1, 5, 8), new_input["A"].shape) self.assertEqual((1, 50), new_input["B"].shape) + def test_dynamic_cache_replace_by_string(self): + n_layers = 2 + bsize, nheads, slen, dim = 2, 4, 3, 7 + cache = make_dynamic_cache( + [ + (torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim)) + for i in range(n_layers) + ] + ) + + DYN = torch.export.Dim.DYNAMIC + ds = { + "cache": [ + [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}], + [{0: DYN, 1: DYN}, {0: DYN, 1: DYN}], + ] + } + inst = CoupleInputsDynamicShapes((), dict(cache=cache), ds) + as_string = inst.replace_by_string() + self.assertEqual( + { + "cache": [ + {0: "Dim0", 1: "Dim1"}, + {0: "Dim2", 1: "Dim3"}, + {0: "Dim4", 1: "Dim5"}, + {0: "Dim6", 1: "Dim7"}, + ] + }, + as_string, + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index 7a62c4a1..6d2456d1 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -291,7 +291,8 @@ def get_parser_validate() -> ArgumentParser: "--ortfusiontype", required=False, help="applies onnxruntime fusion, this parameter should contain the " - "model type or multiple values separated by |", + "model type or multiple values separated by `|`. `ALL` can be used " + "to run them all", ) parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity") parser.add_argument("--dtype", help="changes dtype if necessary") diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index 7f6528a7..53143e56 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -1,5 +1,5 @@ import inspect -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union import numpy as np import torch from ..helpers import string_type @@ -8,6 +8,30 @@ DYNAMIC_SHAPES = Tuple[Tuple[Any, ...], Dict[str, Any]] +def flatten_dynamic_shapes(ds: Any) -> Any: + """Flattens the dynamic shapes.""" + if isinstance(ds, list): + return _flat_list([flatten_dynamic_shapes(t) for t in ds]) + if isinstance(ds, tuple): + return tuple(_flat_list([flatten_dynamic_shapes(t) for t in ds])) + if isinstance(ds, dict): + if all(isinstance(i, int) for i in ds): + # That's a dynamic shape + return ds + return _flat_list([flatten_dynamic_shapes(t) for t in ds.values()]) + raise AssertionError(f"Not implemented for {type(ds)}: {ds}") + + +def _flat_list(li: List[Any]) -> List[Dict[int, str]]: + res = [] + for t in li: + if isinstance(t, dict): + res.append(t) + else: + res.extend(t) + return res + + class CoupleInputsDynamicShapes: """ Pair inputs / dynamic shapes. @@ -76,7 +100,7 @@ def _replace_string_dim_tensor(cls, inputs, ds, value=None): assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}" assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), ( f"Unexpected types, inputs is a Tensor but ds is {ds}, " - f"a dictionary is expected to specify a dimension dimension" + f"a dictionary is expected to specify a dimension" ) if value is None: value = torch.export.Dim.DYNAMIC @@ -86,6 +110,57 @@ def _replace_string_dim_tensor(cls, inputs, ds, value=None): new_ds[i] = value return new_ds + def replace_by_string(self): + """ + Replaces dimensions by strings. + + Example: + + .. runpython:: + :showcode: + + import torch + from onnx_diagnostic.export.dynamic_shapes import CoupleInputsDynamicShapes + + Dim = torch.export.Dim + T3x1 = torch.rand((3, 1)) + T3x4 = torch.rand((3, 4)) + ds_batch = {0: Dim("batch")} + ds_batch_seq = {0: Dim("batch"), 1: Dim("seq")} + kwargs = {"A": T3x4, "B": (T3x1, T3x1)} + ds = {"A": ds_batch, "B": (ds_batch, ds_batch_seq)} + print(CoupleInputsDynamicShapes((), kwargs, ds).replace_by_string()) + """ + unique = set() + return self._generic_walker( + lambda inputs, ds, unique=unique: self._replace_dim_tensor_by_string( + inputs, ds, unique=unique + ) + ) + + @classmethod + def _replace_dim_tensor_by_string(cls, inputs, ds, unique: Set[str]): + assert isinstance(inputs, torch.Tensor), f"unexpected type for inputs {type(inputs)}" + assert isinstance(ds, dict) and all(isinstance(s, int) for s in ds), ( + f"Unexpected types, inputs is a Tensor but ds is {ds}, " + f"a dictionary is expected to specify a dimension" + ) + new_ds = ds.copy() + for i, v in ds.items(): + if isinstance(v, str): + assert v not in unique, f"Dimension {v!r} is already defined in {unique}" + unique.add(v) + new_ds[i] = v + elif v in (torch.export.Dim.DYNAMIC, torch.export.Dim.AUTO): + name = f"Dim{len(unique)}" + new_ds[i] = name + unique.add(name) + else: + name = v.__name__ + unique.add(name) + new_ds[i] = name + return new_ds + def invalid_dimensions_for_export(self): """ Tells if the inputs are valid based on the dynamic shapes definition. @@ -252,6 +327,9 @@ def _generic_walker_step(cls, processor: Callable, inputs, ds): f"map this class with the given dynamic shapes." ) flat, _spec = torch.utils._pytree.tree_flatten(inputs) + if all(isinstance(t, torch.Tensor) for t in flat): + # We need to flatten dynamic shapes as well + ds = flatten_dynamic_shapes(ds) return cls._generic_walker_step(processor, flat, ds) class ChangeDimensionProcessor: diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index a6256345..5f3d0172 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -5,6 +5,7 @@ import time import onnx import torch +from ..export import CoupleInputsDynamicShapes from ..helpers import max_diff, string_type, string_diff from ..helpers.helper import flatten_object from ..helpers.rt_helper import make_feeds @@ -506,7 +507,12 @@ def validate_model( ), f"Missing attribute num_attention_heads in configuration {config}" num_attention_heads = config.num_attention_heads - model_types = ortfusiontype.split("|") + if ortfusiontype == "ALL": + from onnxruntime.transformers.optimizer import MODEL_TYPES + + model_types = sorted(MODEL_TYPES) + else: + model_types = ortfusiontype.split("|") for model_type in model_types: flavour = f"ort{model_type}" summary[f"version_{flavour}_hidden_size"] = hidden_size @@ -517,13 +523,15 @@ def validate_model( print(f"[validate_model] run onnxruntime fusion for {model_type!r}") input_filename = data["onnx_filename"] output_path = f"{os.path.splitext(input_filename)[0]}.ort.{model_type}.onnx" - run_ort_fusion( + ort_sum, ort_data = run_ort_fusion( input_filename, output_path, model_type=model_type, num_attention_heads=num_attention_heads, hidden_size=hidden_size, ) + summary.update(ort_sum) + data.update(ort_data) data[f"onnx_filename_{flavour}"] = output_path duration = time.perf_counter() - begin summary[f"time_ortfusion_{flavour}"] = duration @@ -590,7 +598,7 @@ def call_exporter( optimization=optimization, ) return summary, data - if exporter.startswith("custom-"): + if exporter == "custom" or exporter.startswith("custom"): # torch export summary, data = call_torch_export_custom( exporter=exporter, @@ -758,6 +766,9 @@ def _mk(key): if input_data_key in data: source = data[input_data_key] + if not os.path.exists(source): + summary[_mk("ERR_onnx_missing")] = f"FileNotFoundError({source!r})" + return summary, data summary[input_data_key] = source summary[_mk("onnx_size")] = os.stat(source).st_size else: @@ -866,7 +877,7 @@ def call_torch_export_onnx( assert "model" in data, f"model is missing from data: {sorted(data)}" assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}" summary: Dict[str, Union[str, int, float]] = {} - dynamo = "nostrict" not in exporter + dynamo = "dynamo" in exporter args, kwargs = split_args_kwargs(data["inputs_export"]) ds = data.get("dynamic_shapes", None) if verbose: @@ -884,6 +895,15 @@ def call_torch_export_onnx( summary["export_args"] = string_type(args, with_shape=True) summary["export_kwargs"] = string_type(kwargs, with_shape=True) + export_export_kwargs = ( + dict(dynamo=True, dynamic_shapes=ds) + if dynamo + else dict( + dynamo=False, + dynamic_axes=CoupleInputsDynamicShapes(args, kwargs, ds).replace_by_string(), + ) + ) + begin = time.perf_counter() if quiet: try: @@ -891,8 +911,7 @@ def call_torch_export_onnx( data["model"], args, kwargs=kwargs, - dynamic_shapes=ds, - dynamo=dynamo, + **export_export_kwargs, ) except Exception as e: summary["ERR_export_export"] = str(e) @@ -904,8 +923,7 @@ def call_torch_export_onnx( data["model"], args, kwargs=kwargs, - dynamic_shapes=ds, - dynamo=dynamo, + **export_export_kwargs, ) summary["time_export_export"] = time.perf_counter() - begin @@ -966,6 +984,7 @@ def call_torch_export_custom( None, }, f"unexpected value for optimization={optimization}" assert exporter in { + "custom", "custom-strict", "custom-strict-dec", "custom-strict-all", @@ -1155,14 +1174,24 @@ def run_ort_fusion( f"[run_ort_fusion] starts optimization for " f"model_type={model_type!r} with {n_nodes} nodes" ) - new_onx = optimize_by_fusion( - onx, - model_type=model_type, - num_heads=num_attention_heads, - hidden_size=hidden_size, - optimization_options=opts, - ) - duration = {time.perf_counter() - begin} + try: + new_onx = optimize_by_fusion( + onx, + model_type=model_type, + num_heads=num_attention_heads, + hidden_size=hidden_size, + optimization_options=opts, + ) + except Exception as e: + duration = {time.perf_counter() - begin} + if verbose: + print(f"[run_ort_fusion] failed in {duration} for model_type={model_type!r}") + return { + f"ERR_opt_ort_{model_type}": str(e), + f"opt_ort_{model_type}_duration": duration, + }, {} + + duration = time.perf_counter() - begin delta = len(new_onx.model.graph.node) if verbose: print(f"[run_ort_fusion] done in {duration} with {delta} nodes") @@ -1175,6 +1204,7 @@ def run_ort_fusion( return { f"opt_ort_{model_type}_n_nodes1": n_nodes, f"opt_ort_{model_type}_n_nodes2": delta, + f"opt_ort_{model_type}_delta_node": delta - n_nodes, f"opt_ort_{model_type}_duration": duration, f"opt_ort_{model_type}_duration_save": d, }, {f"opt_ort_{model_type}": output_path} From 43b2c3931ba9581016d066f1a321bc6408e2bb70 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 14 Apr 2025 15:39:43 +0200 Subject: [PATCH 2/5] fix mypy --- onnx_diagnostic/torch_models/test_helper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 5f3d0172..b6bdeaf5 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -754,7 +754,7 @@ def validate_onnx_model( def _mk(key): return f"{key}_{flavour}" if flavour else key - summary = {} + summary: Dict[str, Any] = {} flat_inputs = flatten_object(data["inputs"], drop_keys=True) d = flat_inputs[0].get_device() providers = ( From efa0750c7017b215eaa2075428a30490952c6e30 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 14 Apr 2025 16:11:32 +0200 Subject: [PATCH 3/5] ut --- _unittests/ut_export/test_dynamic_shapes.py | 2 +- onnx_diagnostic/export/dynamic_shapes.py | 1 - onnx_diagnostic/torch_models/test_helper.py | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index ec9dcb95..e6354697 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -710,7 +710,7 @@ def test_couple_input_ds_replace_by_string(self): { "X": {0: "batch"}, "A": {0: "batch", 1: "Dim1"}, - "B": ({0: "batch", 1: "Dim1"}, {0: "batch", 1: "Dim1"}), + "B": ({0: "batch", 1: "Dim2"}, {0: "batch", 1: "Dim3"}), }, res, ) diff --git a/onnx_diagnostic/export/dynamic_shapes.py b/onnx_diagnostic/export/dynamic_shapes.py index 53143e56..987346fc 100644 --- a/onnx_diagnostic/export/dynamic_shapes.py +++ b/onnx_diagnostic/export/dynamic_shapes.py @@ -148,7 +148,6 @@ def _replace_dim_tensor_by_string(cls, inputs, ds, unique: Set[str]): new_ds = ds.copy() for i, v in ds.items(): if isinstance(v, str): - assert v not in unique, f"Dimension {v!r} is already defined in {unique}" unique.add(v) new_ds[i] = v elif v in (torch.export.Dim.DYNAMIC, torch.export.Dim.AUTO): diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index b6bdeaf5..9fcdc559 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -1183,7 +1183,7 @@ def run_ort_fusion( optimization_options=opts, ) except Exception as e: - duration = {time.perf_counter() - begin} + duration = time.perf_counter() - begin if verbose: print(f"[run_ort_fusion] failed in {duration} for model_type={model_type!r}") return { From 71b5f8404dac6739d98b676cf7d78167185f213d Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 14 Apr 2025 16:30:49 +0200 Subject: [PATCH 4/5] fix annot --- _unittests/ut_export/test_dynamic_shapes.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/_unittests/ut_export/test_dynamic_shapes.py b/_unittests/ut_export/test_dynamic_shapes.py index e6354697..08da1426 100644 --- a/_unittests/ut_export/test_dynamic_shapes.py +++ b/_unittests/ut_export/test_dynamic_shapes.py @@ -1,6 +1,6 @@ import unittest import torch -from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers from onnx_diagnostic.helpers import string_type from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache from onnx_diagnostic.export import ModelInputs, CoupleInputsDynamicShapes @@ -739,6 +739,7 @@ def test_couple_input_ds_change_dynamic_dimensions_fixed(self): self.assertEqual((1, 5, 8), new_input["A"].shape) self.assertEqual((1, 50), new_input["B"].shape) + @requires_transformers("4.51") def test_dynamic_cache_replace_by_string(self): n_layers = 2 bsize, nheads, slen, dim = 2, 4, 3, 7 From 8730aa3b89e2048fed4080595952f7a048ce96a1 Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 14 Apr 2025 16:42:19 +0200 Subject: [PATCH 5/5] documentation --- _doc/cmds/validate.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/_doc/cmds/validate.rst b/_doc/cmds/validate.rst index 81f685e5..fd19f758 100644 --- a/_doc/cmds/validate.rst +++ b/_doc/cmds/validate.rst @@ -100,3 +100,19 @@ Let's export with ONNX this time and checks for discrepancies. main("validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir".split()) +Run onnxruntime fusions ++++++++++++++++++++++++ + +This option runs `transformers optimizations `_ +implemented in :epkg:`onnxruntime`. The list of supported ``model_type`` can be found in the documentation +of function :func:`onnx_diagnostic.torch_models.test_helper.run_ort_fusion`. + +.. code-block:: + + python -m onnx_diagnostic validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir --ortfusiontype ALL + +.. runpython:: + + from onnx_diagnostic._command_lines_parser import main + + main("validate -m arnir0/Tiny-LLM --run -v 1 --export onnx-dynamo -o dump_models --patch --opt ir --ortfusiontype ALL".split())