From 18aead7071f4cf688a52e849fbcbd0bb252612bf Mon Sep 17 00:00:00 2001 From: xadupre Date: Sat, 12 Apr 2025 12:32:36 +0200 Subject: [PATCH 1/3] add ort fusion --- .../ut_torch_models/test_test_helpers.py | 6 + onnx_diagnostic/torch_models/test_helper.py | 108 +++++++++++++++--- 2 files changed, 101 insertions(+), 13 deletions(-) diff --git a/_unittests/ut_torch_models/test_test_helpers.py b/_unittests/ut_torch_models/test_test_helpers.py index a20f509a..6f846983 100644 --- a/_unittests/ut_torch_models/test_test_helpers.py +++ b/_unittests/ut_torch_models/test_test_helpers.py @@ -7,6 +7,7 @@ get_inputs_for_task, validate_model, filter_inputs, + run_ort_fusion, ) from onnx_diagnostic.torch_models.hghub.model_inputs import get_get_inputs_function_for_tasks @@ -69,6 +70,11 @@ def test_validate_model_onnx(self): self.assertIsInstance(summary, dict) self.assertIsInstance(data, dict) self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-4) + onnx_filename = data["onnx_filename"] + output_path = f"{onnx_filename}.ortopt.onnx" + run_ort_fusion( + onnx_filename, output_path, num_attention_heads=2, hidden_size=192, verbose=10 + ) def test_filter_inputs(self): inputs, ds = {"a": 1, "b": 2}, {"a": 20, "b": 30} diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index d79ff78e..d94e125f 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -420,6 +420,7 @@ def validate_model( ) summary.update(summary_export) + dump_stats = None if dump_folder: if "exported_program" in data: ep = data["exported_program"] @@ -435,22 +436,27 @@ def validate_model( epo = data["onnx_program"] if verbose: print(f"[validate_model] dumps onnx program in {dump_folder!r}...") - onnx_file_name = os.path.join(dump_folder, f"{folder_name}.onnx") + onnx_filename = os.path.join(dump_folder, f"{folder_name}.onnx") + begin = time.perf_counter() if isinstance(epo, onnx.model_container.ModelContainer): - epo.save(onnx_file_name, all_tensors_to_one_file=True) + epo.save(onnx_filename, all_tensors_to_one_file=True) else: - epo.save(onnx_file_name, external_data=True) + epo.save(onnx_filename, external_data=True) + duration = time.perf_counter() - begin if verbose: - print("[validate_model] done (dump onnx)") + print(f"[validate_model] done (dump onnx) in {duration}") + data["onnx_filename"] = onnx_filename + summary["time_onnx_save"] = duration if verbose: print(f"[validate_model] dumps statistics in {dump_folder!r}...") - with open(os.path.join(dump_folder, f"{folder_name}.stats"), "w") as f: + dump_stats = os.path.join(dump_folder, f"{folder_name}.stats") + with open(dump_stats, "w") as f: for k, v in sorted(summary.items()): f.write(f":{k}:{v};\n") if verbose: print("[validate_model] done (dump)") - if exporter and exporter.startswith("onnx-") and do_run: + if exporter and exporter.startswith(("onnx-", "custom-")) and do_run: summary_valid, data = validate_onnx_model( data=data, quiet=quiet, @@ -461,6 +467,10 @@ def validate_model( if verbose: print("[validate_model] -- done (final)") + if dump_stats: + with open(dump_stats, "w") as f: + for k, v in sorted(summary.items()): + f.write(f":{k}:{v};\n") return summary, data @@ -642,7 +652,7 @@ def validate_onnx_model( quiet: bool = False, verbose: int = 0, optimization: Optional[str] = None, -): +) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ Verifies that an onnx model produces the same expected outputs. 
@@ -665,10 +675,10 @@ def validate_onnx_model(
         if d < 0
         else ["CUDAExecutionProvider", "CPUExecutionProvider"]
     )
-    if "onnx_file_name" in data:
-        source = data["onnx_file_name"]
+    if "onnx_filename" in data:
+        source = data["onnx_filename"]
         summary["onnx_filename"] = source
-        summary["onnx_size"] = os.stats(source).st_size
+        summary["onnx_size"] = os.stat(source).st_size
     else:
         assert (
             "onnx_program" in data
@@ -745,7 +755,7 @@ def call_torch_export_onnx(
     quiet: bool = False,
     verbose: int = 0,
     optimization: Optional[str] = None,
-):
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Exports a model into onnx.
     If a patch must be applied, it should be applied before calling this function.
@@ -818,7 +828,7 @@ def call_torch_export_onnx(
     if verbose:
         print("[call_torch_export_onnx] done (export)")
     data["onnx_program"] = epo
-    if verbose > 1:
+    if verbose > 5:
         print("[call_torch_export_onnx] -- ONNXProgram")
         print(epo)
         print("[call_torch_export_onnx] -- End of ONNXProgram")
@@ -850,7 +860,7 @@ def call_torch_export_custom(
     quiet: bool = False,
     verbose: int = 0,
     optimization: Optional[str] = None,
-):
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
     Exports a model into onnx.
     If a patch must be applied, it should be applied before calling this function.
@@ -1011,3 +1021,75 @@ def call_torch_export_custom(
     print("[call_torch_export_custom] done (export)")
     data["onnx_program"] = epo
     return summary, data
+
+
+def run_ort_fusion(
+    model_or_path: Union[str, onnx.ModelProto],
+    output_path: str,
+    num_attention_heads: int,
+    hidden_size: int,
+    model_type: str = "bert",
+    verbose: int = 0,
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+    """
+    Runs the :epkg:`onnxruntime` fusion optimizer.
+
+    :param model_or_path: path to the ModelProto or the ModelProto itself
+    :param output_path: path where the optimized model is saved
+    :param num_attention_heads: number of heads, usually ``config.num_attention_heads``
+    :param hidden_size: hidden size, usually ``config.hidden_size``
+    :param model_type: type of optimization, see below
+    :param verbose: verbosity
+    :return: two dictionaries, summary and data
+
+    Supported values for ``model_type``:
+
+    .. 
runpython::
+        :showcode:
+
+        import pprint
+        from onnxruntime.transformers.optimizer import MODEL_TYPES
+
+        pprint.pprint(sorted(MODEL_TYPES))
+    """
+    from onnxruntime.transformers.optimizer import optimize_by_fusion
+    from onnxruntime.transformers.fusion_options import FusionOptions
+
+    opts = FusionOptions(model_type)
+
+    if isinstance(model_or_path, str):
+        if verbose:
+            print(f"[run_ort_fusion] loads {model_or_path!r}")
+        onx = onnx.load(model_or_path)
+    else:
+        onx = model_or_path
+    begin = time.perf_counter()
+    n_nodes = len(onx.graph.node)
+    if verbose:
+        print(
+            f"[run_ort_fusion] starts optimization for "
+            f"model_type={model_type!r} with {n_nodes} nodes"
+        )
+    new_onx = optimize_by_fusion(
+        onx,
+        model_type=model_type,
+        num_heads=num_attention_heads,
+        hidden_size=hidden_size,
+        optimization_options=opts,
+    )
+    duration = time.perf_counter() - begin
+    n_nodes_after = len(new_onx.model.graph.node)
+    if verbose:
+        print(f"[run_ort_fusion] done in {duration} with {n_nodes_after} nodes")
+        print(f"[run_ort_fusion] save to {output_path!r}")
+    begin = time.perf_counter()
+    new_onx.save_model_to_file(output_path, use_external_data_format=True)
+    d = time.perf_counter() - begin
+    if verbose:
+        print(f"[run_ort_fusion] done in {d}")
+    return {
+        f"opt_ort_{model_type}_n_nodes1": n_nodes,
+        f"opt_ort_{model_type}_n_nodes2": n_nodes_after,
+        f"opt_ort_{model_type}_duration": duration,
+        f"opt_ort_{model_type}_duration_save": d,
+    }, {f"opt_ort_{model_type}": output_path}

From 31b633715e67b6695a582dce8b38a5a7d0cd140a Mon Sep 17 00:00:00 2001
From: xadupre
Date: Mon, 14 Apr 2025 11:55:52 +0200
Subject: [PATCH 2/3] ortfusion in command line

---
 onnx_diagnostic/_command_lines_parser.py    |   7 +
 onnx_diagnostic/torch_models/test_helper.py | 139 ++++++++++++++++----
 2 files changed, 119 insertions(+), 27 deletions(-)

diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py
index 1e0cb064..7a62c4a1 100644
--- a/onnx_diagnostic/_command_lines_parser.py
+++ b/onnx_diagnostic/_command_lines_parser.py
@@ -287,6 +287,12 @@ def get_parser_validate() -> ArgumentParser:
         help="drops the following input names, it should be a list "
         "with comma separated values",
     )
+    parser.add_argument(
+        "--ortfusiontype",
+        required=False,
+        help="applies onnxruntime fusion; this parameter should contain the "
+        "model type, or multiple values separated by |",
+    )
     parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity")
     parser.add_argument("--dtype", help="changes dtype if necessary")
     parser.add_argument("--device", help="changes the device if necessary")
@@ -338,6 +344,7 @@ def _cmd_validate(argv: List[Any]):
         exporter=args.export,
         dump_folder=args.dump_folder,
         drop_inputs=None if not args.drop else args.drop.split(","),
+        ortfusiontype=args.ortfusiontype,
     )
     print("")
     print("-- summary --")
diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py
index d94e125f..62c9bff8 100644
--- a/onnx_diagnostic/torch_models/test_helper.py
+++ b/onnx_diagnostic/torch_models/test_helper.py
@@ -197,6 +197,7 @@ def validate_model(
     stop_if_static: int = 1,
     dump_folder: Optional[str] = None,
     drop_inputs: Optional[List[str]] = None,
+    ortfusiontype: Optional[str] = None,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model. 
@@ -222,11 +223,33 @@
         see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
     :param dump_folder: dumps everything in a subfolder of this one
     :param drop_inputs: drops this list of inputs (given their names)
+    :param ortfusiontype: runs ort fusion; this parameter defines the fusion type,
+        it accepts multiple values separated by ``|``,
+        see :func:`onnx_diagnostic.torch_models.test_helper.run_ort_fusion`
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
     assert not trained, f"trained={trained} not supported yet"
     summary = version_summary()
+
+    summary.update(
+        dict(
+            version_model_id=model_id,
+            version_do_run=str(do_run),
+            version_dtype=str(dtype or ""),
+            version_device=str(device or ""),
+            version_trained=str(trained),
+            version_optimization=optimization or "",
+            version_quiet=str(quiet),
+            version_patch=str(patch),
+            version_dump_folder=dump_folder or "",
+            version_drop_inputs=str(list(drop_inputs or "")),
+            version_ortfusiontype=ortfusiontype or "",
+            version_stop_if_static=str(stop_if_static),
+            version_exporter=exporter,
+        )
+    )
+
     folder_name = None
     if dump_folder:
         folder_name = _make_folder_name(
@@ -456,15 +479,66 @@ def validate_model(
         if verbose:
             print("[validate_model] done (dump)")
 
-    if exporter and exporter.startswith(("onnx-", "custom-")) and do_run:
-        summary_valid, data = validate_onnx_model(
-            data=data,
-            quiet=quiet,
-            verbose=verbose,
-            optimization=optimization,
-        )
+    if not exporter or not exporter.startswith(("onnx-", "custom-")):
+        if verbose:
+            print("[validate_model] -- done (final)")
+        if dump_stats:
+            with open(dump_stats, "w") as f:
+                for k, v in sorted(summary.items()):
+                    f.write(f":{k}:{v};\n")
+        return summary, data
+
+    if do_run:
+        summary_valid, data = validate_onnx_model(data=data, quiet=quiet, verbose=verbose)
         summary.update(summary_valid)
 
+    if ortfusiontype and "onnx_filename" in data:
+        assert (
+            "configuration" in data
+        ), f"missing configuration in data, cannot run ort fusion for model_id={model_id}"
+        config = data["configuration"]
+        assert hasattr(
+            config, "hidden_size"
+        ), f"Missing attribute hidden_size in configuration {config}"
+        hidden_size = config.hidden_size
+        assert hasattr(
+            config, "num_attention_heads"
+        ), f"Missing attribute num_attention_heads in configuration {config}"
+        num_attention_heads = config.num_attention_heads
+
+        model_types = ortfusiontype.split("|")
+        for model_type in model_types:
+            flavour = f"ort{model_type}"
+            summary[f"version_{flavour}_hidden_size"] = hidden_size
+            summary[f"version_{flavour}_num_attention_heads"] = num_attention_heads
+
+            begin = time.perf_counter()
+            if verbose:
+                print(f"[validate_model] run onnxruntime fusion for {model_type!r}")
+            input_filename = data["onnx_filename"]
+            output_path = f"{os.path.splitext(input_filename)[0]}.ort.{model_type}.onnx"
+            run_ort_fusion(
+                input_filename,
+                output_path,
+                model_type=model_type,
+                num_attention_heads=num_attention_heads,
+                hidden_size=hidden_size,
+            )
+            data[f"onnx_filename_{flavour}"] = output_path
+            duration = time.perf_counter() - begin
+            summary[f"time_ortfusion_{flavour}"] = duration
+            if verbose:
+                print(
+                    f"[validate_model] done {model_type!r} in {duration}, "
+                    f"saved into {output_path!r}"
+                )
+
+            if do_run:
+                summary_valid, data = validate_onnx_model(
+                    data=data, quiet=quiet, verbose=verbose, flavour=flavour
+                )
+                summary.update(summary_valid)
+
     if verbose:
         print("[validate_model] -- done (final)")
     if dump_stats:
@@ -651,22 +725,27 @@ def 
validate_onnx_model(
     data: Dict[str, Any],
     quiet: bool = False,
     verbose: int = 0,
-    optimization: Optional[str] = None,
+    flavour: Optional[str] = None,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Verifies that an onnx model produces the same
-    expected outputs.
+    expected outputs. It uses ``data["onnx_filename"]`` as the input
+    onnx filename or ``data["onnx_filename_{flavour}"]`` if *flavour*
+    is specified.
 
     :param data: dictionary with all the necessary inputs,
         the dictionary must contain keys ``model`` and ``inputs_export``
     :param quiet: catch exception or not
     :param verbose: verbosity
-    :param optimization: optimization to do
+    :param flavour: selects a different version of the saved onnx model
+        (``onnx_filename_{flavour}``)
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
     import onnxruntime
 
+    def _mk(key):
+        return f"{key}_{flavour}" if flavour else key
+
     summary = {}
     flat_inputs = flatten_object(data["inputs"], drop_keys=True)
     d = flat_inputs[0].get_device()
@@ -675,36 +754,42 @@
         if d < 0
         else ["CUDAExecutionProvider", "CPUExecutionProvider"]
     )
-    if "onnx_filename" in data:
-        source = data["onnx_filename"]
-        summary["onnx_filename"] = source
-        summary["onnx_size"] = os.stat(source).st_size
+    input_data_key = f"onnx_filename_{flavour}" if flavour else "onnx_filename"
+
+    if input_data_key in data:
+        source = data[input_data_key]
+        summary[input_data_key] = source
+        summary[_mk("onnx_size")] = os.stat(source).st_size
     else:
+        assert not flavour, f"flavour={flavour!r}, the model must be saved on disk first."
         assert (
             "onnx_program" in data
         ), f"onnx_program is missing from data which has {sorted(data)}"
         source = data["onnx_program"].model_proto.SerializeToString()
         assert len(source) < 2**31, f"The model is larger than 2 GiB: {len(source) / 2**30} GiB"
-        summary["onnx_size"] = len(source)
+        summary[_mk("onnx_size")] = len(source)
 
     if verbose:
-        print(f"[validate_onnx_model] verify onnx model with providers {providers}...")
+        print(
+            f"[validate_onnx_model] verify onnx model with providers "
+            f"{providers}..., flavour={flavour!r}"
+        )
     begin = time.perf_counter()
     if quiet:
         try:
             sess = onnxruntime.InferenceSession(source, providers=providers)
         except Exception as e:
-            summary["ERR_onnx_ort_create"] = str(e)
-            data["ERR_onnx_ort_create"] = e
-            summary["time_onnx_ort_create"] = time.perf_counter() - begin
+            summary[_mk("ERR_onnx_ort_create")] = str(e)
+            data[_mk("ERR_onnx_ort_create")] = e
+            summary[_mk("time_onnx_ort_create")] = time.perf_counter() - begin
             return summary, data
     else:
         sess = onnxruntime.InferenceSession(source, providers=providers)
-    summary["time_onnx_ort_create"] = time.perf_counter() - begin
-    data["onnx_ort_sess"] = sess
+    summary[_mk("time_onnx_ort_create")] = time.perf_counter() - begin
+    data[_mk("onnx_ort_sess")] = sess
     if verbose:
-        print("[validate_onnx_model] done (ort_session)")
+        print(f"[validate_onnx_model] done (ort_session) flavour={flavour!r}")
 
     # make_feeds
     if verbose:
@@ -718,7 +803,7 @@
     )
     if verbose:
         print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
-    summary["onnx_ort_inputs"] = string_type(feeds, with_shape=True)
+    summary[_mk("onnx_ort_inputs")] = string_type(feeds, with_shape=True)
     if verbose:
         print("[validate_onnx_model] done (make_feeds)")
 
@@ -730,9 +815,9 @@
         try:
             got = sess.run(None, feeds)
         except Exception as e:
-            summary["ERR_onnx_ort_run"] = str(e)
-            data["ERR_onnx_ort_run"] = e
-            summary["time_onnx_ort_run"] = time.perf_counter() - begin
+            
summary[_mk("ERR_onnx_ort_run")] = str(e) + data[_mk("ERR_onnx_ort_run")] = e + summary[_mk("time_onnx_ort_run")] = time.perf_counter() - begin return summary, data else: got = sess.run(None, feeds) @@ -745,7 +830,7 @@ def validate_onnx_model( if verbose: print(f"[validate_onnx_model] discrepancies={string_diff(disc)}") for k, v in disc.items(): - summary[f"disc_onnx_ort_run_{k}"] = v + summary[_mk(f"disc_onnx_ort_run_{k}")] = v return summary, data From 84e5d72e22a05dfbd9566f3325e969edff17367a Mon Sep 17 00:00:00 2001 From: xadupre Date: Mon, 14 Apr 2025 12:27:16 +0200 Subject: [PATCH 3/3] mypy --- CHANGELOGS.rst | 1 + onnx_diagnostic/torch_models/test_helper.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 826cf660..2220e036 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.4.0 +++++ +* :pr:`50`: add support for onnxruntime fusion * :pr:`48`: add support for EncoderDecoderCache, test with openai/whisper-tiny * :pr:`45`: improve change_dynamic_dimension to fix some dimensions diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 62c9bff8..a6256345 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -246,7 +246,7 @@ def validate_model( version_drop_inputs=str(list(drop_inputs or "")), version_ortfusiontype=ortfusiontype or "", version_stop_if_static=str(stop_if_static), - version_exporter=exporter, + version_exporter=exporter or "", ) )