diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e9aa96db..4b03de8f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ jobs: matrix: os: [ubuntu-latest] python: ['3.11', '3.12'] - transformers: ['4.48.3', '4.51.2', 'main'] + transformers: ['4.48.3', '4.51.3', 'main'] torch: ['2.6', 'main'] steps: diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index 896c3541..738708e9 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.4.0 +++++ +* :pr:`65`: support SlidingWindowCache * :pr:`63`: support option ``--trained`` * :pr:`61`: improves dynamic shapes for EncoderDecoderCache * :pr:`58`: add function use_dyn_not_str to replace string by ``torch.export.Dim.DYNAMIC``, diff --git a/_doc/conf.py b/_doc/conf.py index 0545e77e..e82b88db 100644 --- a/_doc/conf.py +++ b/_doc/conf.py @@ -123,6 +123,7 @@ ("py:class", "transformers.cache_utils.DynamicCache"), ("py:class", "transformers.cache_utils.EncoderDecoderCache"), ("py:class", "transformers.cache_utils.MambaCache"), + ("py:class", "transformers.cache_utils.SlidingWindowCache"), ("py:class", "transformers.configuration_utils.PretrainedConfig"), ("py:func", "torch.export._draft_export.draft_export"), ("py:func", "torch._export.tools.report_exportability"), @@ -187,6 +188,7 @@ "ExecuTorch": "https://pytorch.org/executorch/stable/intro-overview.html", "ExecuTorch Runtime Python API Reference": "https://pytorch.org/executorch/stable/runtime-python-api-reference.html", "ExecuTorch Tutorial": "https://pytorch.org/executorch/stable/tutorials/export-to-executorch-tutorial.html", + "experimental-experiment": "https://sdpython.github.io/doc/experimental-experiment/dev/", "JIT": "https://en.wikipedia.org/wiki/Just-in-time_compilation", "FunctionProto": "https://onnx.ai/onnx/api/classes.html#functionproto", "graph break": "https://pytorch.org/docs/stable/torch.compiler_faq.html#graph-breaks", diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py index 6e2ac748..0b752a0d 100644 --- a/_unittests/ut_helpers/test_cache_helper.py +++ b/_unittests/ut_helpers/test_cache_helper.py @@ -1,12 +1,14 @@ import unittest import torch import transformers -from onnx_diagnostic.ext_test_case import ExtTestCase +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers from onnx_diagnostic.helpers import string_type from onnx_diagnostic.helpers.cache_helper import ( + flatten_unflatten_for_dynamic_shapes, make_dynamic_cache, make_encoder_decoder_cache, - flatten_unflatten_for_dynamic_shapes, + make_mamba_cache, + make_sliding_window_cache, ) from onnx_diagnostic.export import CoupleInputsDynamicShapes from onnx_diagnostic.torch_export_patches.patch_inputs import ( @@ -132,6 +134,37 @@ def test_unflatten_flatten_encoder_decoder_cache(self): self.string_type(c2, with_shape=True), ) + @requires_transformers("4.51") # the structure changes + def test_make_mamba_cache(self): + cache = make_mamba_cache( + [ + (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), + (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), + (torch.rand((4, 4, 4)), torch.rand((4, 4, 4))), + ] + ) + text = self.string_type(cache, with_shape=True) + self.assertEqual( + "MambaCache(conv_states=#3[T10s4x4x4,T10s4x4x4,T10s4x4x4], " + "ssm_states=#3[T10s4x4x4,T10s4x4x4,T10s4x4x4])", + text, + ) + + def test_make_sliding_window_cache(self): + cache = make_sliding_window_cache( + [ + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 
7))), + (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))), + ] + ) + text = self.string_type(cache, with_shape=True) + self.assertEqual( + "SlidingWindowCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], " + "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])", + text, + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization.py b/_unittests/ut_torch_export_patches/test_patch_serialization.py index 5cbe258d..e4d207a3 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization.py +++ b/_unittests/ut_torch_export_patches/test_patch_serialization.py @@ -1,10 +1,11 @@ import unittest import torch from transformers.modeling_outputs import BaseModelOutput -from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_torch from onnx_diagnostic.helpers.cache_helper import ( make_encoder_decoder_cache, make_dynamic_cache, + make_sliding_window_cache, flatten_unflatten_for_dynamic_shapes, ) from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( @@ -164,6 +165,53 @@ def test_base_model_output_unflatten_flatten(self): self.assertIsInstance(unflat, dict) self.assertEqual(list(unflat), ["last_hidden_state"]) + @ignore_warnings(UserWarning) + def test_base_sliding_window_cache_unflatten_flatten(self): + cache = make_sliding_window_cache( + [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))] + ) + with bypass_export_some_errors(): + cache2 = torch_deepcopy([cache]) + self.assertEqualAny([cache], cache2) + + @ignore_warnings(UserWarning) + @requires_torch("2.7") + def test_sliding_window_cache_export(self): + class Model(torch.nn.Module): + def forward(self, cache): + return cache.key_cache[0] + + cache = make_sliding_window_cache( + [ + (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))), + (torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))), + ] + ) + model = Model() + model(cache) + DYN = torch.export.Dim.DYNAMIC + ds = [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]] + + with bypass_export_some_errors(patch_transformers=True): + torch.export.export(model, (cache,), dynamic_shapes=(ds,)) + + @ignore_warnings(UserWarning) + def test_sliding_window_cache_flatten(self): + cache = make_sliding_window_cache( + [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))] + ) + with bypass_export_some_errors(): + flat, _spec = torch.utils._pytree.tree_flatten(cache) + self.assertEqual( + "#2[T1s4x4x4x4,T1s4x4x4x4]", + self.string_type(flat, with_shape=True), + ) + cache2 = torch.utils._pytree.tree_unflatten(flat, _spec) + self.assertEqual( + self.string_type(cache, with_shape=True, with_min_max=True), + self.string_type(cache2, with_shape=True, with_min_max=True), + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index d1edec06..631fde3f 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -1,3 +1,4 @@ +import argparse import json import sys import textwrap @@ -227,6 +228,21 @@ def _cmd_config(argv: List[Any]): print(f"task: {task_from_id(args.mid)}") +class _ParseDict(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + d = getattr(namespace, self.dest) or {} + + if values: + for item in values: + split_items = item.split("=", 1) + key = split_items[0].strip() # we remove blanks around keys, as is logical + value = 
split_items[1] + + d[key] = value + + setattr(namespace, self.dest, d) + + def get_parser_validate() -> ArgumentParser: parser = ArgumentParser( prog="test", @@ -297,6 +313,14 @@ def get_parser_validate() -> ArgumentParser: parser.add_argument("-v", "--verbose", default=0, type=int, help="verbosity") parser.add_argument("--dtype", help="changes dtype if necessary") parser.add_argument("--device", help="changes the device if necessary") + parser.add_argument( + "--iop", + metavar="KEY=VALUE", + nargs="*", + help="Additional input options, used to change the default " + "inputs used to export, for example: --iop cls_cache=SlidingWindowCache", + action=_ParseDict, + ) return parser @@ -346,6 +370,7 @@ def _cmd_validate(argv: List[Any]): dump_folder=args.dump_folder, drop_inputs=None if not args.drop else args.drop.split(","), ortfusiontype=args.ortfusiontype, + input_options=args.iop, ) print("") print("-- summary --") diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index fb786e38..2fd9092e 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -920,7 +920,7 @@ def assertEqualAny( else: for e, g in zip(expected, value): self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol) - elif expected.__class__.__name__ == "DynamicCache": + elif expected.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"): self.assertEqual(type(expected), type(value), msg=msg) atts = ["key_cache", "value_cache"] self.assertEqualAny( diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index c14594d0..2040e6ff 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -26,12 +26,8 @@ def flatten_unflatten_for_dynamic_shapes(obj: Any, use_dict: bool = False) -> An subtrees = [] for subspec in spec.children_specs: end += subspec.num_leaves - if use_dict and (subspec.type is dict or subspec.context): - value = subspec.unflatten(flat[start:end]) - value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict) - else: - value = subspec.unflatten(flat[start:end]) - value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict) + value = subspec.unflatten(flat[start:end]) + value = flatten_unflatten_for_dynamic_shapes(value, use_dict=use_dict) subtrees.append(value) start = end if use_dict and (spec.type is dict or spec.context): @@ -185,3 +181,36 @@ def __init__(self): ) cache.ssm_states[i][:, :, :] = key_value_pairs[i][1] return cache + + +def make_sliding_window_cache( + key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]], +) -> transformers.cache_utils.SlidingWindowCache: + "Creates a :class:`transformers.cache_utils.SlidingWindowCache`."
+ + class _config: + def __init__(self): + self.head_dim = key_value_pairs[0][0].shape[-1] + self.num_attention_heads = key_value_pairs[0][0].shape[1] + self.num_hidden_layers = len(key_value_pairs) + self.sliding_window = key_value_pairs[0][0].shape[2] + + cache = transformers.cache_utils.SlidingWindowCache( + _config(), + max_batch_size=key_value_pairs[0][0].shape[0], + max_cache_len=key_value_pairs[0][0].shape[2], # same as sliding_window + device=key_value_pairs[0][0].device, + dtype=key_value_pairs[0][0].dtype, + ) + for i in range(len(key_value_pairs)): + assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, ( + f"Shape mismatch, expected {cache.key_cache[i].shape}, " + f"got {key_value_pairs[i][0].shape}" + ) + cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0] + assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, ( + f"Shape mismatch, expected {cache.value_cache[i].shape}, " + f"got {key_value_pairs[i][1].shape}" + ) + cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1] + return cache diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index f84064f0..ae6379f1 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -534,7 +534,7 @@ def string_type( print(f"[string_type] CACHE1:{type(obj)}") return f"MambaCache(conv_states={c}, ssm_states={d})" - if obj.__class__.__name__ == "DynamicCache": + if obj.__class__.__name__ in ("DynamicCache", "SlidingWindowCache"): kc = string_type( obj.key_cache, with_shape=with_shape, diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py index d7bec618..e8ede458 100644 --- a/onnx_diagnostic/helpers/torch_test_helper.py +++ b/onnx_diagnostic/helpers/torch_test_helper.py @@ -4,7 +4,11 @@ import numpy as np import torch from .helper import string_type -from .cache_helper import make_dynamic_cache, make_encoder_decoder_cache +from .cache_helper import ( + make_dynamic_cache, + make_encoder_decoder_cache, + make_sliding_window_cache, +) def _forward_(*args, _f=None, _context=None, **kwargs): @@ -363,6 +367,10 @@ def torch_deepcopy(value: Any) -> Any: return make_dynamic_cache( torch_deepcopy(list(zip(value.key_cache, value.value_cache))) ) + if value.__class__.__name__ == "SlidingWindowCache": + return make_sliding_window_cache( + torch_deepcopy(list(zip(value.key_cache, value.value_cache))) + ) if value.__class__.__name__ == "EncoderDecoderCache": return make_encoder_decoder_cache( torch_deepcopy(value.self_attention_cache), diff --git a/onnx_diagnostic/tasks/text_generation.py b/onnx_diagnostic/tasks/text_generation.py index 7ce039ed..551a3985 100644 --- a/onnx_diagnostic/tasks/text_generation.py +++ b/onnx_diagnostic/tasks/text_generation.py @@ -1,6 +1,11 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import torch -from ..helpers.cache_helper import make_dynamic_cache, make_mamba_cache +import transformers +from ..helpers.cache_helper import ( + make_dynamic_cache, + make_mamba_cache, + make_sliding_window_cache, +) from ..helpers.config_helper import update_config, check_hasattr, _pick __TASK__ = "text-generation" @@ -88,6 +93,10 @@ def get_inputs( cache_length = "cache_length" # torch.export.Dim("cache_length", min=1, max=4096) if config is not None and config.__class__.__name__ == "FalconMambaConfig": + assert cls_cache in ( + "MambaCache", + transformers.cache_utils.MambaCache, + ), f"Unexpected value for cls_cache={cls_cache} and config={config}" seq_length_multiple = 8 
sequence_length = ( (sequence_length + seq_length_multiple) @@ -156,6 +165,13 @@ def get_inputs( [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)], ], } + + make_cache = ( + make_sliding_window_cache + if cls_cache in ("SlidingWindowCache", transformers.cache_utils.SlidingWindowCache) + else make_dynamic_cache + ) + inputs = dict( input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length2)).to( torch.int64 @@ -166,7 +182,7 @@ def get_inputs( position_ids=torch.arange(sequence_length, sequence_length + sequence_length2) .to(torch.int64) .expand((batch_size, -1)), - past_key_values=make_dynamic_cache( + past_key_values=make_cache( [ ( torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim), diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index dcb52b7d..93109037 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -4,7 +4,12 @@ import optree import torch import transformers -from transformers.cache_utils import DynamicCache, MambaCache, EncoderDecoderCache +from transformers.cache_utils import ( + DynamicCache, + MambaCache, + EncoderDecoderCache, + SlidingWindowCache, +) from transformers.modeling_outputs import BaseModelOutput from ..helpers import string_type @@ -112,41 +117,43 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: # To avoid doing it multiple times. PATCH_OF_PATCHES.add(BaseModelOutput) - unregistered_dynamic_cache = _register_class_serialization( - DynamicCache, - flatten_dynamic_cache, - unflatten_dynamic_cache, - flatten_with_keys_dynamic_cache, - # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), - verbose=verbose, - ) - unregistered_base_model_output = _register_class_serialization( - BaseModelOutput, - flatten_base_model_output, - unflatten_base_model_output, - flatten_with_keys_base_model_output, - verbose=verbose, - ) - unregistered_encode_decode_cache = _register_class_serialization( - EncoderDecoderCache, - flatten_encoder_decoder_cache, - unflatten_encoder_decoder_cache, - flatten_with_keys_encoder_decoder_cache, - verbose=verbose, - ) - unregistered_mamba_cache = _register_class_serialization( - MambaCache, - flatten_mamba_cache, - unflatten_mamba_cache, - flatten_with_keys_mamba_cache, - verbose=verbose, - ) - return dict( - DynamicCache=unregistered_dynamic_cache, - MambaCache=unregistered_mamba_cache, - EncoderDecoderCache=unregistered_encode_decode_cache, - BaseModelOutput=unregistered_base_model_output, + DynamicCache=_register_class_serialization( + DynamicCache, + flatten_dynamic_cache, + unflatten_dynamic_cache, + flatten_with_keys_dynamic_cache, + # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), + verbose=verbose, + ), + MambaCache=_register_class_serialization( + MambaCache, + flatten_mamba_cache, + unflatten_mamba_cache, + flatten_with_keys_mamba_cache, + verbose=verbose, + ), + EncoderDecoderCache=_register_class_serialization( + EncoderDecoderCache, + flatten_encoder_decoder_cache, + unflatten_encoder_decoder_cache, + flatten_with_keys_encoder_decoder_cache, + verbose=verbose, + ), + BaseModelOutput=_register_class_serialization( + BaseModelOutput, + flatten_base_model_output, + unflatten_base_model_output, + flatten_with_keys_base_model_output, + verbose=verbose, + ), + SlidingWindowCache=_register_class_serialization( + 
SlidingWindowCache, + flatten_sliding_window_cache, + unflatten_sliding_window_cache, + flatten_with_keys_sliding_window_cache, + verbose=verbose, + ), ) @@ -279,6 +286,60 @@ def unflatten_dynamic_cache( return cache + +#################### +# SlidingWindowCache +#################### + + +def flatten_sliding_window_cache( + cache: SlidingWindowCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.cache_utils.SlidingWindowCache` + with python objects. + """ + flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)] + return [f[1] for f in flat], [f[0] for f in flat] + + +def flatten_with_keys_sliding_window_cache( + cache: SlidingWindowCache, +) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.cache_utils.SlidingWindowCache` + with python objects. + """ + values, context = flatten_sliding_window_cache(cache) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_sliding_window_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> SlidingWindowCache: + """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects.""" + key_cache, value_cache = values + + class _config: + def __init__(self): + self.head_dim = key_cache[0].shape[-1] + self.num_attention_heads = key_cache[0].shape[1] + self.num_hidden_layers = len(key_cache) + self.sliding_window = key_cache[0].shape[2] + + cache = SlidingWindowCache( + _config(), + max_batch_size=key_cache[0].shape[0], + max_cache_len=key_cache[0].shape[2], # sliding window + device=key_cache[0].device, + dtype=key_cache[0].dtype, + ) + + values = dict(zip(context, values)) + for k, v in values.items(): + setattr(cache, k, v) + return cache + + ##################### # EncoderDecoderCache ##################### diff --git a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index 400a7385..a226d488 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -97,6 +97,8 @@ def get_untrained_model_with_inputs( # input kwargs kwargs, fct = random_input_kwargs(config, task) + if verbose: + print(f"[get_untrained_model_with_inputs] use fct={fct}") if inputs_kwargs: kwargs.update(inputs_kwargs) diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 45fe4e94..e3b5e53e 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -1,7 +1,7 @@ import datetime import inspect import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import time import onnx import torch @@ -180,6 +180,29 @@ def version_summary() -> Dict[str, Union[int, float, str]]: return summary + +def _quiet_or_not_quiet( + quiet: bool, + suffix: str, + summary: Dict[str, Any], + data: Optional[Dict[str, Any]], + fct: Callable, +) -> Any: + begin = time.perf_counter() + if quiet: + try: + res = fct() + summary[f"time_{suffix}"] = time.perf_counter() - begin + return res + except Exception as e: + summary[f"ERR_{suffix}"] = str(e) + summary[f"time_{suffix}"] = time.perf_counter() - begin + if data is None: + return {f"ERR_{suffix}": e} + data[f"ERR_{suffix}"] = e + return None + res = fct() + summary[f"time_{suffix}"] = time.perf_counter() - begin + return res + + def validate_model( model_id: str, task: Optional[str] = 
None, @@ -197,6 +220,7 @@ def validate_model( dump_folder: Optional[str] = None, drop_inputs: Optional[List[str]] = None, ortfusiontype: Optional[str] = None, + input_options: Optional[Dict[str, Any]] = None, ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]: """ Validates a model. @@ -225,6 +249,8 @@ :param ortfusiontype: runs ort fusion, the parameters defines the fusion type, it accepts multiple values separated by ``|``, see :func:`onnx_diagnostic.torch_models.test_helper.run_ort_fusion` + :param input_options: additional options to define the dummy inputs + used to export :return: two dictionaries, one with some metrics, another one with whatever the function produces """ @@ -263,24 +289,26 @@ if verbose: print(f"[validate_model] validate model id {model_id!r}") - print("[validate_model] get dummy inputs...") + print(f"[validate_model] get dummy inputs with input_options={input_options}...") summary["model_id"] = model_id - begin = time.perf_counter() - if quiet: - try: - data = get_untrained_model_with_inputs( - model_id, verbose=verbose, task=task, same_as_pretrained=trained + iop = input_options or {} + data = _quiet_or_not_quiet( + quiet, + "create", + summary, + None, + ( + lambda mid=model_id, v=verbose, task=task, tr=trained, iop=iop: ( + get_untrained_model_with_inputs( + mid, verbose=v, task=task, same_as_pretrained=tr, inputs_kwargs=iop + ) ) - except Exception as e: - summary["ERR_create"] = str(e) - data["ERR_create"] = e - summary["time_create"] = time.perf_counter() - begin - return summary, {} - else: - data = get_untrained_model_with_inputs( - model_id, verbose=verbose, task=task, same_as_pretrained=trained - ) + ), + ) + data["input_options"] = input_options + if "ERR_create" in summary: + return summary, data if drop_inputs: if verbose: @@ -316,14 +344,14 @@ data["inputs"] = to_any(data["inputs"], device) # type: ignore summary["model_device"] = str(device) - summary["time_create"] = time.perf_counter() - begin for k in ["task", "size", "n_weights"]: summary[f"model_{k.replace('_','')}"] = data[k] - summary["model_inputs"] = string_type(data["inputs"], with_shape=True) - summary["model_shapes"] = string_type(str(data["dynamic_shapes"])) - summary["model_class"] = data["model"].__class__.__name__ - summary["model_config_class"] = data["configuration"].__class__.__name__ - summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "") + summary["model_inputs_options"] = str(input_options or "") + summary["model_inputs"] = string_type(data["inputs"], with_shape=True) + summary["model_shapes"] = string_type(str(data["dynamic_shapes"])) + summary["model_class"] = data["model"].__class__.__name__ + summary["model_config_class"] = data["configuration"].__class__.__name__ + summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "") summary["model_id"] = model_id if verbose: @@ -344,18 +372,14 @@ # We make a copy of the input just in case the model modifies them inplace hash_inputs = string_type(data["inputs"], with_shape=True) inputs = torch_deepcopy(data["inputs"]) - begin = time.perf_counter() - if quiet: - try: - expected = data["model"](**inputs) - except Exception as e: - summary["ERR_run"] = str(e) - data["ERR_run"] = e - summary["time_run"] = time.perf_counter() - begin - return summary, data - else: - expected = data["model"](**inputs) - summary["time_run"] = time.perf_counter() - begin + model = data["model"] + + expected = 
_quiet_or_not_quiet( + quiet, "run", summary, data, (lambda m=model, inp=inputs: m(**inp)) + ) + if "ERR_run" in summary: + return summary, data + summary["model_expected"] = string_type(expected, with_shape=True) if verbose: print("[validate_model] done (run)") @@ -397,18 +421,18 @@ # We make a copy of the input just in case the model modifies them inplace inputs = torch_deepcopy(data["inputs_export"]) - begin = time.perf_counter() - if quiet: - try: - expected = data["model"](**inputs) - except Exception as e: - summary["ERR_run_patched"] = str(e) - data["ERR_run_patched"] = e - summary["time_run_patched"] = time.perf_counter() - begin - return summary, data - else: - expected = data["model"](**inputs) - summary["time_run_patched"] = time.perf_counter() - begin + model = data["model"] + + expected = _quiet_or_not_quiet( + quiet, + "run_patched", + summary, + data, + (lambda m=model, inp=inputs: m(**inp)), + ) + if "ERR_run_patched" in summary: + return summary, data + disc = max_diff(data["expected"], expected) for k, v in disc.items(): summary[f"disc_patched_{k}"] = v @@ -578,7 +602,7 @@ def call_exporter( :return: two dictionaries, one with some metrics, another one with whatever the function produces """ - if exporter.startswith("export-"): + if exporter == "export" or exporter.startswith("export-"): # torch export summary, data = call_torch_export_export( exporter=exporter, @@ -673,23 +697,21 @@ def call_torch_export_export( print(f"[call_torch_export_export] dynamic_shapes_export_export={string_type(dse)}") print("[call_torch_export_export] export...") - begin = time.perf_counter() - if quiet: - try: - ep = torch.export.export( - data["model"], args, kwargs=kwargs, dynamic_shapes=dse, strict=strict + model = data["model"] + ep = _quiet_or_not_quiet( + quiet, + "export_export", + summary, + data, + ( + lambda m=model, args=args, kws=kwargs, dse=dse, s=strict: ( + torch.export.export(m, args, kwargs=kws, dynamic_shapes=dse, strict=s) ) - except Exception as e: - summary["ERR_export_export"] = str(e) - data["ERR_export_export"] = e - summary["time_export_export"] = time.perf_counter() - begin - return summary, data - else: - ep = torch.export.export( - data["model"], args, kwargs=kwargs, dynamic_shapes=dse, strict=strict - ) + ), + ) + if "ERR_export_export" in summary: + return summary, data - summary["time_export_export"] = time.perf_counter() - begin summary["export_graph_nodes"] = len(ep.graph.nodes) if verbose: print( @@ -715,18 +737,17 @@ # We make a copy of the input just in case the model modifies them inplace inputs = torch_deepcopy(data["inputs_export"]) model = ep.module() - begin = time.perf_counter() - if quiet: - try: - expected = model(**inputs) - except Exception as e: - summary["ERR_run_exported"] = str(e) - data["ERR_run_exported"] = e - summary["time_run_exported"] = time.perf_counter() - begin - return summary, data - else: - expected = model(**inputs) - summary["time_run_exported"] = time.perf_counter() - begin + + expected = _quiet_or_not_quiet( + quiet, + "run_exported", + summary, + data, + (lambda m=model, inp=inputs: m(**inp)), + ) + if "ERR_run_exported" in summary: + return summary, data + disc = max_diff(data["expected"], expected) for k, v in disc.items(): summary[f"disc_exported_{k}"] = v @@ -797,19 +818,20 @@ def _mk(key): f"{providers}..., flavour={flavour!r}" ) - begin = time.perf_counter() - if quiet: - try: - sess = onnxruntime.InferenceSession(source, providers=providers) - except 
Exception as e: - summary[_mk("ERR_onnx_ort_create")] = str(e) - data[_mk("ERR_onnx_ort_create")] = e - summary[_mk("time_onnx_ort_create")] = time.perf_counter() - begin - return summary, data - else: - sess = onnxruntime.InferenceSession(source, providers=providers) + sess = _quiet_or_not_quiet( + quiet, + _mk("time_onnx_ort_create"), + summary, + data, + ( + lambda source=source, providers=providers: onnxruntime.InferenceSession( + source, providers=providers + ) + ), + ) + if f"ERR_{_mk('time_onnx_ort_create')}" in summary: + return summary, data - summary[_mk("time_onnx_ort_create")] = time.perf_counter() - begin data[_mk("onnx_ort_sess")] = sess if verbose: print(f"[validate_onnx_model] done (ort_session) flavour={flavour!r}") @@ -833,17 +855,17 @@ def _mk(key): # run ort if verbose: print("[validate_onnx_model] run session...") - begin = time.perf_counter() - if quiet: - try: - got = sess.run(None, feeds) - except Exception as e: - summary[_mk("ERR_onnx_ort_run")] = str(e) - data[_mk("ERR_onnx_ort_run")] = e - summary[_mk("time_onnx_ort_run")] = time.perf_counter() - begin - return summary, data - else: - got = sess.run(None, feeds) + + got = _quiet_or_not_quiet( + quiet, + _mk("time_onnx_ort_run"), + summary, + data, + (lambda sess=sess, feeds=feeds: sess.run(None, feeds)), + ) + if f"ERR_{_mk('time_onnx_ort_run')}" in summary: + return summary, data + if verbose: print("[validate_onnx_model] done (run)") print(f"[validate_onnx_model] got={string_type(got, with_shape=True)}") @@ -931,29 +953,27 @@ def call_torch_export_onnx( f"[call_torch_export_onnx] export_export_kwargs=" f"{string_type(export_export_kwargs, with_shape=True)}" ) - begin = time.perf_counter() - if quiet: - try: - epo = torch.onnx.export( - data["model"], - args, - kwargs=kwargs, - **export_export_kwargs, + model = data["model"] + + epo = _quiet_or_not_quiet( + quiet, + "export_onnx", + summary, + data, + ( + lambda m=model, args=args, kws=kwargs, ekws=export_export_kwargs: ( + torch.onnx.export( + m, + args, + kwargs=kws, + **ekws, + ) ) - except Exception as e: - summary["ERR_export_export"] = str(e) - data["ERR_export_export"] = e - summary["time_export_export"] = time.perf_counter() - begin - return summary, data - else: - epo = torch.onnx.export( - data["model"], - args, - kwargs=kwargs, - **export_export_kwargs, - ) + ), + ) + if "ERR_export_onnx" in summary: + return summary, data - summary["time_export_export"] = time.perf_counter() - begin assert epo is not None, "no onnx export was found" if verbose: print("[call_torch_export_onnx] done (export)") @@ -963,21 +983,18 @@ def call_torch_export_onnx( print(epo) print("[call_torch_export_onnx] -- End of ONNXProgram") - begin = time.perf_counter() if optimization == "ir": if verbose: print(f"[call_torch_export_onnx] starts optimization={optimization!r}...") - if quiet: - try: - epo.optimize() - except Exception as e: - summary["ERR_export_optimize_ir"] = str(e) - data["ERR_export_optimize_ir"] = e - summary["time_export_optimize_ir"] = time.perf_counter() - begin - return summary, data - else: - epo.optimize() - summary["time_export_optimize_ir"] = time.perf_counter() - begin + _quiet_or_not_quiet( + quiet, + "export_onnx_opt_ir", + summary, + data, + (lambda epo=epo: epo.optimize()), + ) + if "ERR_export_onnx_opt_ir" in summary: + return summary, data if verbose: print("[call_torch_export_onnx] done (optimization)") @@ -1050,40 +1067,35 @@ def call_torch_export_custom( ), ) options = OptimizationOptions(patterns=optimization) if optimization else None + model 
= data["model"] + kws = dict( + dynamic_shapes=ds, + export_options=export_options, + options=options, + optimize=bool(optimization), + large_model=True, + return_optimize_report=True, + verbose=max(verbose - 2, 0), + ) - begin = time.perf_counter() - if quiet: - try: - epo, opt_stats = to_onnx( - data["model"], - args, - kwargs=kwargs, - dynamic_shapes=ds, - export_options=export_options, - options=options, - optimize=bool(optimization), - large_model=True, - return_optimize_report=True, - verbose=max(verbose - 2, 0), + res = _quiet_or_not_quiet( + quiet, + "export_onnx_c", + summary, + data, + ( + lambda m=model, args=args, kwargs=kwargs, kws=kws: ( + to_onnx( + m, + args, + kwargs=kwargs, + **kws, + ) ) - except Exception as e: - summary["ERR_export_export"] = str(e) - data["ERR_export_export"] = e - summary["time_export_export"] = time.perf_counter() - begin - return summary, data - else: - epo, opt_stats = to_onnx( - data["model"], - args, - kwargs=kwargs, - dynamic_shapes=ds, - export_options=export_options, - options=options, - optimize=bool(optimization), - large_model=True, - return_optimize_report=True, - verbose=max(verbose - 2, 0), - ) + ), + ) + if "ERR_export_onnx_c" in summary: + return summary, data + epo, opt_stats = res new_stat = {} if "optimization" in opt_stats: @@ -1147,7 +1159,6 @@ def call_torch_export_custom( ) ) - summary["time_export_export"] = time.perf_counter() - begin summary.update(new_stat) assert epo is not None, "no onnx export was found" if verbose: