diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 460d6881..ff6a7916 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -16,7 +16,7 @@ jobs: matrix: os: [ubuntu-latest] python: ['3.11', '3.12'] - transformers: ['4.48.3', '4.51.1', 'main'] + transformers: ['4.48.3', '4.51.2', 'main'] torch: ['2.6', 'main'] steps: diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index ec1fc68b..826cf660 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.4.0 +++++ +* :pr:`48`: add support for EncoderDecoderCache, test with openai/whisper-tiny * :pr:`45`: improve change_dynamic_dimension to fix some dimensions 0.3.0 diff --git a/_doc/api/helpers/index.rst b/_doc/api/helpers/index.rst index 8974298c..79703adf 100644 --- a/_doc/api/helpers/index.rst +++ b/_doc/api/helpers/index.rst @@ -13,6 +13,7 @@ onnx_diagnostic.helpers memory_peak onnx_helper ort_session + rt_helper torch_test_helper .. autofunction:: onnx_diagnostic.helpers.max_diff diff --git a/_doc/api/helpers/rt_helper.rst b/_doc/api/helpers/rt_helper.rst new file mode 100644 index 00000000..2b7c6383 --- /dev/null +++ b/_doc/api/helpers/rt_helper.rst @@ -0,0 +1,7 @@ + +onnx_diagnostic.helpers.rt_helper +================================= + +.. automodule:: onnx_diagnostic.helpers.rt_helper + :members: + :no-undoc-members: diff --git a/_doc/examples/plot_export_tiny_phi2.py b/_doc/examples/plot_export_tiny_phi2.py index c15dd025..7022c981 100644 --- a/_doc/examples/plot_export_tiny_phi2.py +++ b/_doc/examples/plot_export_tiny_phi2.py @@ -25,7 +25,7 @@ from onnx_diagnostic import doc from onnx_diagnostic.helpers import max_diff, string_diff, string_type from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered -from onnx_diagnostic.helpers.ort_session import make_feeds +from onnx_diagnostic.helpers.rt_helper import make_feeds from onnx_diagnostic.torch_export_patches import bypass_export_some_errors from onnx_diagnostic.torch_models.hghub import ( get_untrained_model_with_inputs, diff --git a/_unittests/ut_helpers/test_ort_session_tinyllm.py b/_unittests/ut_helpers/test_ort_session_tinyllm.py index 18b0065f..febe05fe 100644 --- a/_unittests/ut_helpers/test_ort_session_tinyllm.py +++ b/_unittests/ut_helpers/test_ort_session_tinyllm.py @@ -7,10 +7,10 @@ from onnxruntime.capi import _pybind_state as ORTC from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings from onnx_diagnostic.helpers import max_diff +from onnx_diagnostic.helpers.rt_helper import make_feeds from onnx_diagnostic.helpers.ort_session import ( InferenceSessionForNumpy, InferenceSessionForTorch, - make_feeds, ) from onnx_diagnostic.torch_export_patches import bypass_export_some_errors from onnx_diagnostic.torch_models.llms import get_tiny_llm diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py index 435f328e..699ccdd7 100644 --- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py +++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py @@ -121,6 +121,49 @@ def forward(self, x: torch.Tensor, cache: MambaCache): dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]), ) + @ignore_warnings(UserWarning) + def test_exportable_dynamic_shapes_constraints(self): + import torch + + class CustomCache: + def __init__(self, shape=None): + self.cache = [torch.zeros((shape)), torch.zeros((shape))] if shape else [] + + def flatten_cache(cache): + return [cache.cache], 
["cache"] + + def unflatten_cache(values, context, output_type=None): + cache = CustomCache() + cache.cache = values[0] + return cache + + def flatten_with_keys_cache(d): + values, context = flatten_cache(d) + return [ + (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values) + ], context + + torch.utils._pytree.register_pytree_node( + CustomCache, + flatten_cache, + unflatten_cache, + serialized_type_name=f"{CustomCache.__module__}.{CustomCache.__name__}", + flatten_with_keys_fn=flatten_with_keys_cache, + ) + + class Model(torch.nn.Module): + def forward(self, x, cache): + return cache.cache[0][0, :] + x + + model = Model() + model.eval() + x, cache = torch.rand((2, 4)), CustomCache((2, 4)) + model(x, cache) + DYN = torch.export.Dim.DYNAMIC + torch.export.export( + model, (x, cache), dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}]]) + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization.py b/_unittests/ut_torch_export_patches/test_patch_serialization.py new file mode 100644 index 00000000..6c9ad057 --- /dev/null +++ b/_unittests/ut_torch_export_patches/test_patch_serialization.py @@ -0,0 +1,156 @@ +import unittest +import torch +from transformers.modeling_outputs import BaseModelOutput +from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings +from onnx_diagnostic.helpers.cache_helper import make_encoder_decoder_cache, make_dynamic_cache +from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( + bypass_export_some_errors, +) +from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy + + +class TestPatchSerialization(ExtTestCase): + @ignore_warnings(UserWarning) + def test_encoder_decoder_cache_flatten(self): + cache = make_encoder_decoder_cache( + make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), + make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), + ) + with bypass_export_some_errors(): + flat, _spec = torch.utils._pytree.tree_flatten(cache) + self.assertEqual( + "#4[T1s4x4x4,T1s4x4x4,T1s5x5x5,T1s5x5x5]", + self.string_type(flat, with_shape=True), + ) + cache2 = torch.utils._pytree.tree_unflatten(flat, _spec) + self.assertEqual( + self.string_type(cache, with_shape=True, with_min_max=True), + self.string_type(cache2, with_shape=True, with_min_max=True), + ) + + @ignore_warnings(UserWarning) + def test_encoder_decoder_cache_deepcopy(self): + cache = make_encoder_decoder_cache( + make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), + make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), + ) + with bypass_export_some_errors(): + cache2 = torch_deepcopy([cache]) + self.assertEqualAny([cache], cache2) + + @ignore_warnings(UserWarning) + def test_encoder_decoder_cache_export(self): + class Model(torch.nn.Module): + def forward(self, cache): + return cache.self_attention_cache.key_cache[0] + + cache1 = make_dynamic_cache( + [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)] + ) + cache2 = make_dynamic_cache( + [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)] + ) + + cache = make_encoder_decoder_cache(cache1, cache2) + model = Model() + model(cache) + DYN = torch.export.Dim.DYNAMIC + ds = [ + [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]], + [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]], + ] + + with bypass_export_some_errors(patch_transformers=True): + torch.export.export(model, (cache,), dynamic_shapes=(ds,)) + + 
@ignore_warnings(UserWarning) + def test_dynamic_cache_flatten(self): + cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) + with bypass_export_some_errors(): + flat, _spec = torch.utils._pytree.tree_flatten(cache) + self.assertEqual( + "#2[T1s4x4x4,T1s4x4x4]", + self.string_type(flat, with_shape=True), + ) + cache2 = torch.utils._pytree.tree_unflatten(flat, _spec) + self.assertEqual( + self.string_type(cache, with_shape=True, with_min_max=True), + self.string_type(cache2, with_shape=True, with_min_max=True), + ) + + @ignore_warnings(UserWarning) + def test_dynamic_cache_export(self): + class Model(torch.nn.Module): + def forward(self, cache): + return cache.key_cache[0] + + cache = make_dynamic_cache( + [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)] + ) + model = Model() + model(cache) + DYN = torch.export.Dim.DYNAMIC + ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]] + + with bypass_export_some_errors(): + torch.export.export(model, (cache,), dynamic_shapes=(ds,)) + + @ignore_warnings(UserWarning) + def test_dynamic_cache_deepcopy(self): + cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) + with bypass_export_some_errors(): + cache2 = torch_deepcopy([cache]) + self.assertEqualAny([cache], cache2) + + @ignore_warnings(UserWarning) + def test_base_model_output_deepcopy(self): + bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) + self.assertEqual(bo.__class__.__name__, "BaseModelOutput") + with bypass_export_some_errors(): + bo2 = torch_deepcopy([bo]) + self.assertIsInstance(bo2, list) + self.assertEqual(bo2[0].__class__.__name__, "BaseModelOutput") + self.assertEqualAny([bo], bo2) + + @ignore_warnings(UserWarning) + def test_base_model_output_string_type(self): + bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) + with bypass_export_some_errors(): + self.assertEqual( + "BaseModelOutput(last_hidden_state:T1s4x4x4)", + self.string_type(bo, with_shape=True), + ) + + @ignore_warnings(UserWarning) + def test_base_model_output_flatten(self): + bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) + with bypass_export_some_errors(): + flat, _spec = torch.utils._pytree.tree_flatten(bo) + self.assertEqual( + "#1[T1s4x4x4]", + self.string_type(flat, with_shape=True), + ) + bo2 = torch.utils._pytree.tree_unflatten(flat, _spec) + self.assertEqual( + self.string_type(bo, with_shape=True, with_min_max=True), + self.string_type(bo2, with_shape=True, with_min_max=True), + ) + + @ignore_warnings(UserWarning) + def test_base_model_output_export(self): + class Model(torch.nn.Module): + def forward(self, cache): + return cache.last_hidden_state[0] + + bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) + model = Model() + model(bo) + DYN = torch.export.Dim.DYNAMIC + ds = [{0: DYN}] + + with bypass_export_some_errors(): + torch.export.export(model, (bo,), dynamic_shapes=(ds,)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_models/test_hghub_model.py b/_unittests/ut_torch_models/test_hghub_model.py index 1314b77d..e0b40310 100644 --- a/_unittests/ut_torch_models/test_hghub_model.py +++ b/_unittests/ut_torch_models/test_hghub_model.py @@ -1,5 +1,6 @@ import pprint import unittest +import torch import transformers from onnx_diagnostic.ext_test_case import ( ExtTestCase, @@ -14,6 +15,7 @@ ) from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config from onnx_diagnostic.torch_models.hghub.hub_data import 
load_models_testing +from onnx_diagnostic.torch_export_patches import bypass_export_some_errors class TestHuggingFaceHubModel(ExtTestCase): @@ -104,6 +106,72 @@ def test_get_untrained_model_with_inputs_text2text_generation(self): raise unittest.SkipTest(f"not working for {mid!r}") model(**inputs) + @hide_stdout() + def test_get_untrained_model_with_inputs_automatic_speech_recognition(self): + mid = "openai/whisper-tiny" + data = get_untrained_model_with_inputs(mid, verbose=1) + self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)]) + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + Dim = torch.export.Dim + self.maxDiff = None + self.assertIn("{0:Dim(batch),1:Dim(seq_length)}", self.string_type(ds)) + self.assertEqualAny( + { + "decoder_input_ids": { + 0: Dim("batch", min=1, max=1024), + 1: Dim("seq_length", min=1, max=4096), + }, + "cache_position": {0: Dim("seq_length", min=1, max=4096)}, + "encoder_outputs": [{0: Dim("batch", min=1, max=1024)}], + "past_key_values": [ + [ + [ + {0: Dim("batch", min=1, max=1024)}, + {0: Dim("batch", min=1, max=1024)}, + ], + [ + {0: Dim("batch", min=1, max=1024)}, + {0: Dim("batch", min=1, max=1024)}, + ], + ], + [ + [ + {0: Dim("batch", min=1, max=1024)}, + {0: Dim("batch", min=1, max=1024)}, + ], + [ + {0: Dim("batch", min=1, max=1024)}, + {0: Dim("batch", min=1, max=1024)}, + ], + ], + ], + }, + ds, + ) + model(**inputs) + self.assertEqual( + "#1[T1r3]", + self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]), + ) + with bypass_export_some_errors(patch_transformers=True, verbose=10): + flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0] + self.assertIsInstance(flat, list) + self.assertIsInstance(flat[0], torch.Tensor) + self.assertEqual( + "#8[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]", + self.string_type(flat), + ) + torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False) + with bypass_export_some_errors(patch_transformers=True, verbose=10): + flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0] + self.assertIsInstance(flat, list) + self.assertIsInstance(flat[0], torch.Tensor) + self.assertEqual( + "#8[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]", + self.string_type(flat), + ) + torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False) + @hide_stdout() def test_get_untrained_model_with_inputs_imagetext2text_generation(self): mid = "HuggingFaceM4/tiny-random-idefics" @@ -131,6 +199,7 @@ def _diff(c1, c2): for mid in load_models_testing(): with self.subTest(mid=mid): if mid in { + "hf-internal-testing/tiny-random-BeitForImageClassification", "hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation", "hf-internal-testing/tiny-random-MoonshineForConditionalGeneration", "fxmarty/pix2struct-tiny-random", diff --git a/_unittests/ut_torch_models/try_tasks.py b/_unittests/ut_torch_models/try_tasks.py index 06bc7470..05cdb04c 100644 --- a/_unittests/ut_torch_models/try_tasks.py +++ b/_unittests/ut_torch_models/try_tasks.py @@ -82,6 +82,72 @@ def test_imagetext2text_generation(self): print(generated_text[0]) + @never_test() + def test_automatic_speech_recognition(self): + # clear&&NEVERTEST=1 python _unittests/ut_torch_models/try_tasks.py -k automatic_speech + # https://huggingface.co/openai/whisper-tiny + + from transformers import WhisperProcessor, WhisperForConditionalGeneration + from datasets import load_dataset + + """ + kwargs=dict( + cache_position:T7s4, + past_key_values:EncoderDecoderCache( + 
self_attention_cache=DynamicCache[serialized](#2[#0[],#0[]]), + cross_attention_cache=DynamicCache[serialized](#2[#0[],#0[]]) + ), + decoder_input_ids:T7s1x4, + encoder_outputs:dict(last_hidden_state:T1s1x1500x384), + use_cache:bool,return_dict:bool + ) + kwargs=dict( + cache_position:T7s1, + past_key_values:EncoderDecoderCache( + self_attention_cache=DynamicCache[serialized](#2[ + #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64], + #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64] + ]), + cross_attention_cache=DynamicCache[serialized](#2[ + #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64], + #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64] + ]), + ), + decoder_input_ids:T7s1x1, + encoder_outputs:dict(last_hidden_state:T1s1x1500x384), + use_cache:bool,return_dict:bool + ) + """ + + # load model and processor + processor = WhisperProcessor.from_pretrained("openai/whisper-tiny") + model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") + forced_decoder_ids = processor.get_decoder_prompt_ids( + language="english", task="transcribe" + ) + + # load streaming dataset and read first audio sample + ds = load_dataset( + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation" + ) + sample = ds[0]["audio"] + input_features = processor( + sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt" + ).input_features + + # generate token ids + print() + with steel_forward(model): + predicted_ids = model.generate( + input_features, forced_decoder_ids=forced_decoder_ids + ) + + # decode token ids to text + transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False) + print("--", transcription) + transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True) + print("--", transcription) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py b/onnx_diagnostic/_command_lines_parser.py index cb9bc16d..1e0cb064 100644 --- a/onnx_diagnostic/_command_lines_parser.py +++ b/onnx_diagnostic/_command_lines_parser.py @@ -315,6 +315,14 @@ def _cmd_validate(argv: List[Any]): for k, v in data["dynamic_shapes"].items(): print(f" + {k.ljust(max_length)}: {_ds_clean(v)}") else: + # Let's skip any invalid combination if known to be unsupported + if ( + "onnx" not in (args.export or "") + and "custom" not in (args.export or "") + and (args.opt or "") + ): + print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}") + return summary, _data = validate_model( model_id=args.mid, task=args.task, diff --git a/onnx_diagnostic/export/validate.py b/onnx_diagnostic/export/validate.py index 505406ed..f98b13f1 100644 --- a/onnx_diagnostic/export/validate.py +++ b/onnx_diagnostic/export/validate.py @@ -75,8 +75,8 @@ def _get(a): begin = time.perf_counter() print( f"[compare_modules] check ep with " - f"args={string_type(args, with_shape=True)}, " - f"kwargs={string_type(kwargs, with_shape=True)}..." + f"args={string_type(args, with_shape=True, with_device=True)}, " + f"kwargs={string_type(kwargs, with_shape=True, with_device=True)}..." 
) got = modep(*_get(args), **_get(kwargs)) if verbose: diff --git a/onnx_diagnostic/ext_test_case.py b/onnx_diagnostic/ext_test_case.py index 9b70ae2c..4e2d1c60 100644 --- a/onnx_diagnostic/ext_test_case.py +++ b/onnx_diagnostic/ext_test_case.py @@ -887,7 +887,18 @@ def assertEqual(self, expected: Any, value: Any, msg: str = ""): def assertEqualAny( self, expected: Any, value: Any, atol: float = 0, rtol: float = 0, msg: str = "" ): - if isinstance(expected, (tuple, list, dict)): + if expected.__class__.__name__ == "BaseModelOutput": + self.assertEqual(type(expected), type(value), msg=msg) + self.assertEqual(len(expected), len(value), msg=msg) + self.assertEqual(list(expected), list(value), msg=msg) # checks the order + self.assertEqualAny( + {k: v for k, v in expected.items()}, # noqa: C416 + {k: v for k, v in value.items()}, # noqa: C416 + atol=atol, + rtol=rtol, + msg=msg, + ) + elif isinstance(expected, (tuple, list, dict)): self.assertIsInstance(value, type(expected), msg=msg) self.assertEqual(len(expected), len(value), msg=msg) if isinstance(expected, dict): @@ -898,6 +909,7 @@ def assertEqualAny( for e, g in zip(expected, value): self.assertEqualAny(e, g, msg=msg, atol=atol, rtol=rtol) elif expected.__class__.__name__ == "DynamicCache": + self.assertEqual(type(expected), type(value), msg=msg) atts = ["key_cache", "value_cache"] self.assertEqualAny( {k: expected.__dict__.get(k, None) for k in atts}, @@ -905,11 +917,25 @@ def assertEqualAny( atol=atol, rtol=rtol, ) + elif expected.__class__.__name__ == "EncoderDecoderCache": + self.assertEqual(type(expected), type(value), msg=msg) + atts = ["self_attention_cache", "cross_attention_cache"] + self.assertEqualAny( + {k: expected.__dict__.get(k, None) for k in atts}, + {k: value.__dict__.get(k, None) for k in atts}, + atol=atol, + rtol=rtol, + ) elif isinstance(expected, (int, float, str)): self.assertEqual(expected, value, msg=msg) elif hasattr(expected, "shape"): self.assertEqual(type(expected), type(value), msg=msg) self.assertEqualArray(expected, value, msg=msg, atol=atol, rtol=rtol) + elif expected.__class__.__name__ in ("Dim", "_Dim", "_DimHintType"): + self.assertEqual(type(expected), type(value), msg=msg) + self.assertEqual(expected.__name__, value.__name__, msg=msg) + elif expected is None: + self.assertEqual(expected, value, msg=msg) else: raise AssertionError( f"Comparison not implemented for types {type(expected)} and {type(value)}" @@ -1081,7 +1107,8 @@ def assert_onnx_disc( :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch` """ from .helpers import string_type, string_diff, max_diff - from .helpers.ort_session import InferenceSessionForTorch, make_feeds + from .helpers.rt_helper import make_feeds + from .helpers.ort_session import InferenceSessionForTorch kws = dict(with_shape=True, with_min_max=verbose > 1) if verbose: @@ -1137,6 +1164,11 @@ def _debug(self): "Tells if DEBUG=1 is set up." return os.environ.get("DEBUG") in BOOLEAN_VALUES + def string_type(self, *args, **kwargs): + from .helpers import string_type + + return string_type(*args, **kwargs) + def subloop(self, *args, verbose: int = 0): "Loops over elements and calls :meth:`unittests.TestCase.subTest`." 
if len(args) == 1: diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py index 6c8a04fa..3c0bf9b7 100644 --- a/onnx_diagnostic/helpers/cache_helper.py +++ b/onnx_diagnostic/helpers/cache_helper.py @@ -132,9 +132,7 @@ def make_encoder_decoder_cache( self_attention_cache: transformers.cache_utils.DynamicCache, cross_attention_cache: transformers.cache_utils.DynamicCache, ) -> transformers.cache_utils.EncoderDecoderCache: - """ - Creates an EncoderDecoderCache. - """ + """Creates an EncoderDecoderCache.""" return transformers.cache_utils.EncoderDecoderCache( self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache ) diff --git a/onnx_diagnostic/helpers/helper.py b/onnx_diagnostic/helpers/helper.py index 81a95ec9..5c3a511c 100644 --- a/onnx_diagnostic/helpers/helper.py +++ b/onnx_diagnostic/helpers/helper.py @@ -1,6 +1,7 @@ import ast import enum import inspect +from dataclasses import is_dataclass, fields from typing import Any, Callable, Dict, List, Optional, Set import numpy as np @@ -140,6 +141,19 @@ def string_type( """ if obj is None: return "None" + if is_dataclass(obj): + values = {f.name: getattr(obj, f.name, None) for f in fields(obj)} + values = {k: v for k, v in values.items() if v is not None} + s = string_type( + values, + with_shape=with_shape, + with_min_max=with_min_max, + with_device=with_device, + ignore=ignore, + limit=limit, + ) + return f"{obj.__class__.__name__}{s[4:]}" + # tuple if isinstance(obj, tuple): if len(obj) == 1: @@ -235,6 +249,8 @@ def string_type( limit=limit, ) s = ",".join(f"{kv[0]}:{string_type(kv[1],**kws)}" for kv in obj.items()) + if all(isinstance(k, int) for k in obj): + return f"{{{s}}}" return f"dict({s})" # arrat if isinstance(obj, np.ndarray): @@ -265,7 +281,7 @@ def string_type( if isinstance(obj, torch.export.dynamic_shapes._DerivedDim): return "DerivedDim" if isinstance(obj, torch.export.dynamic_shapes._Dim): - return "Dim" + return f"Dim({obj.__name__})" if isinstance(obj, torch.SymInt): return "SymInt" if isinstance(obj, torch.SymFloat): @@ -341,6 +357,11 @@ def string_type( if isinstance(obj, slice): return "slice" + if obj == torch.export.Dim.DYNAMIC: + return "DYNAMIC" + if obj == torch.export.Dim.AUTO: + return "AUTO" + # others classes if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES: @@ -388,7 +409,7 @@ def string_type( f"dtype={obj.dtype}, shape={obj.shape})" ) - if obj.__class__.__name__ == "_DimHint": + if obj.__class__.__name__ in ("_DimHint", "_DimHintType"): return str(obj) if isinstance(obj, torch.nn.Module): diff --git a/onnx_diagnostic/helpers/ort_session.py b/onnx_diagnostic/helpers/ort_session.py index 90a70f17..39a0d45d 100644 --- a/onnx_diagnostic/helpers/ort_session.py +++ b/onnx_diagnostic/helpers/ort_session.py @@ -6,8 +6,7 @@ from torch._C import _from_dlpack import onnxruntime from onnxruntime.capi import _pybind_state as ORTC -from .cache_helper import is_cache_dynamic_registered -from .helper import size_type, string_type, flatten_object +from .helper import size_type from .onnx_helper import ( torch_dtype_to_onnx_dtype, onnx_dtype_to_np_dtype, @@ -18,43 +17,6 @@ DEVICES = {-1: ORTC.OrtDevice(ORTC.OrtDevice.cpu(), ORTC.OrtDevice.default_memory(), 0)} -def make_feeds( - proto: Union[onnx.ModelProto, List[str]], - inputs: Any, - use_numpy: bool = False, - copy: bool = False, -) -> Dict[str, Union[torch.Tensor, np.ndarray]]: - """ - Serializes the inputs to produce feeds expected - by :class:`onnxruntime.InferenceSession`. 
- - :param proto: onnx model or list of names - :param inputs: any kind of inputs - :param use_numpy: if True, converts torch tensors into numpy arrays - :param copy: a copy is made, this should be the case if the inputs is ingested - by ``OrtValue`` - :return: feeds dictionary - """ - flat = flatten_object(inputs, drop_keys=True) - assert ( - not all(isinstance(obj, torch.Tensor) for obj in flat) - or not is_cache_dynamic_registered(fast=True) - or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0]) - ), ( - f"Unexpected number of flattened objects, " - f"{string_type(flat, with_shape=True, limit=20)} != " - f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True,limit=20)}" - ) - if use_numpy: - flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat] - names = ( - [i.name for i in proto.graph.input] if isinstance(proto, onnx.ModelProto) else proto - ) - if copy: - flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat] - return dict(zip(names, flat)) - - class _InferenceSession: @classmethod diff --git a/onnx_diagnostic/helpers/rt_helper.py b/onnx_diagnostic/helpers/rt_helper.py new file mode 100644 index 00000000..8d9eca07 --- /dev/null +++ b/onnx_diagnostic/helpers/rt_helper.py @@ -0,0 +1,47 @@ +from typing import Any, Dict, List, Union +import numpy as np +import onnx +import torch +from .helper import string_type, flatten_object +from .cache_helper import is_cache_dynamic_registered + + +def make_feeds( + proto: Union[onnx.ModelProto, List[str]], + inputs: Any, + use_numpy: bool = False, + copy: bool = False, + check_flatten: bool = True, +) -> Dict[str, Union[torch.Tensor, np.ndarray]]: + """ + Serializes the inputs to produce feeds expected + by :class:`onnxruntime.InferenceSession`. 
+ + :param proto: onnx model or list of names + :param inputs: any kind of inputs + :param use_numpy: if True, converts torch tensors into numpy arrays + :param copy: a copy is made, this should be the case if the inputs is ingested + by ``OrtValue`` + :param check_flatten: if True, checks the ``torch.utils._pytree.tree_flatten`` + returns the same number of outputs + :return: feeds dictionary + """ + flat = flatten_object(inputs, drop_keys=True) + assert ( + not check_flatten + or not all(isinstance(obj, torch.Tensor) for obj in flat) + or not is_cache_dynamic_registered(fast=True) + or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0]) + ), ( + f"Unexpected number of flattened objects, " + f"{string_type(flat, with_shape=True)} != " + f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}" + ) + if use_numpy: + flat = [t.detach().cpu().numpy() if isinstance(t, torch.Tensor) else t for t in flat] + names = ( + [i.name for i in proto.graph.input] if isinstance(proto, onnx.ModelProto) else proto + ) + if copy: + flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat] + return dict(zip(names, flat)) diff --git a/onnx_diagnostic/helpers/torch_test_helper.py b/onnx_diagnostic/helpers/torch_test_helper.py index 6aa935e8..3c482bdd 100644 --- a/onnx_diagnostic/helpers/torch_test_helper.py +++ b/onnx_diagnostic/helpers/torch_test_helper.py @@ -351,7 +351,10 @@ def torch_deepcopy(value: Any) -> Any: if isinstance(value, set): return {torch_deepcopy(v) for v in value} if isinstance(value, dict): - return {k: torch_deepcopy(v) for k, v in value.items()} + if type(value) is dict: + return {k: torch_deepcopy(v) for k, v in value.items()} + # for BaseModelOutput + return value.__class__(**{k: torch_deepcopy(v) for k, v in value.items()}) if isinstance(value, np.ndarray): return value.copy() if hasattr(value, "clone"): diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 2b53f5e2..54b3e875 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -1,13 +1,8 @@ import contextlib -import pprint -from typing import Any, Callable, Dict, List, Optional, Set +from typing import Any, Callable, Dict, List, Optional from .onnx_export_serialization import ( - flatten_with_keys_dynamic_cache, - flatten_dynamic_cache, - unflatten_dynamic_cache, - flatten_mamba_cache, - flatten_with_keys_mamba_cache, - unflatten_mamba_cache, + _register_cache_serialization, + _unregister_cache_serialization, ) from .patches import patch_transformers as patch_transformers_list @@ -41,7 +36,7 @@ def patch_module_or_classes(mod, verbose: int = 0) -> Dict[type, Dict[type, Call original = cls._PATCHED_CLASS_ methods = cls._PATCHES_ if verbose: - print(f"[patch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}") + print(f"[patch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}") keep = {n: getattr(original, n, None) for n in methods} for n in methods: @@ -74,7 +69,7 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo for cls, methods in info.items(): assert cls in set_patch, f"No patch registered for {cls} in {mod} (found {set_patch})" if verbose: - print(f"[unpatch_module_or_classes] {name} - {cls.__name__}: {', '.join(methods)}") + print(f"[unpatch_module_or_classes] {name}.{cls.__name__}: {', '.join(methods)}") original = cls._PATCHED_CLASS_ for n, v in 
methods.items(): if v is None: @@ -84,156 +79,6 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo setattr(original, n, v) -PATCH_OF_PATCHES: Set[Any] = set() - - -def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: - # Cache serialization: to be moved into appropriate packages - import torch - import transformers - import packaging.version as pv - - try: - from transformers.cache_utils import DynamicCache - except ImportError: - DynamicCache = None - - try: - from transformers.cache_utils import MambaCache - except ImportError: - MambaCache = None - - # MambaCache - unregistered_mamba_cache = True - if MambaCache is not None and MambaCache in torch.utils._pytree.SUPPORTED_NODES: - if verbose > 1: - print(f"[_register_cache_serialization] {MambaCache} already registered") - # It is already registered because bypass_export_some_errors was called - # within a section already calling bypass_export_some_errors or transformers - # has updated its code to do it. - # No need to register and unregister then. - unregistered_mamba_cache = False - else: - if verbose: - print("[_register_cache_serialization] register MambaCache") - torch.utils._pytree.register_pytree_node( - MambaCache, - flatten_mamba_cache, - unflatten_mamba_cache, - serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}", - flatten_with_keys_fn=flatten_with_keys_mamba_cache, - ) - - # DynamicCache serialization is different in transformers and does not - # play way with torch.export.export. - # see test test_export_dynamic_cache_cat with NOBYPASS=1 - # :: NOBYBASS=1 python _unittests/ut_torch_export_patches/test_dynamic_class.py -k e_c - # This is caused by this line: - # torch.fx._pytree.register_pytree_flatten_spec( - # DynamicCache, _flatten_dynamic_cache_for_fx) - # so we remove it anyway - if ( - DynamicCache in torch.fx._pytree.SUPPORTED_NODES - and not PATCH_OF_PATCHES - # and pv.Version(torch.__version__) < pv.Version("2.7") - and pv.Version(transformers.__version__) >= pv.Version("4.50") - ): - if verbose: - print( - "[_register_cache_serialization] DynamicCache " - "is unregistered and registered first." - ) - _unregister(DynamicCache) - torch.utils._pytree.register_pytree_node( - DynamicCache, - flatten_dynamic_cache, - unflatten_dynamic_cache, - serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", - flatten_with_keys_fn=flatten_with_keys_dynamic_cache, - ) - if pv.Version(torch.__version__) < pv.Version("2.7"): - torch.fx._pytree.register_pytree_flatten_spec( - DynamicCache, lambda x, _: [x.key_cache, x.value_cache] - ) - # To avoid doing it multiple times. 
- PATCH_OF_PATCHES.add(DynamicCache) - - unregistered_dynamic_cache = True - if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES: - if verbose > 1: - print(f"[_register_cache_serialization] {DynamicCache} already registered") - unregistered_dynamic_cache = False - else: - if verbose: - print("[_register_cache_serialization] register DynamicCache") - torch.utils._pytree.register_pytree_node( - DynamicCache, - flatten_dynamic_cache, - unflatten_dynamic_cache, - serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", - flatten_with_keys_fn=flatten_with_keys_dynamic_cache, - ) - if pv.Version(torch.__version__) < pv.Version("2.7"): - torch.fx._pytree.register_pytree_flatten_spec( - DynamicCache, lambda x, _: [x.key_cache, x.value_cache] - ) - - # check - from ..helpers.cache_helper import make_dynamic_cache - - cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) - values, spec = torch.utils._pytree.tree_flatten(cache) - cache2 = torch.utils._pytree.tree_unflatten(values, spec) - # torch.fx._pytree.tree_flatten(cache) - assert len(cache2.key_cache) == 1 - - return dict(DynamicCache=unregistered_dynamic_cache, MambaCache=unregistered_mamba_cache) - - -def _unregister(cls: type, verbose: int = 0): - import optree - import torch - - # torch.fx._pytree._deregister_pytree_flatten_spec(cls) - if cls in torch.fx._pytree.SUPPORTED_NODES: - del torch.fx._pytree.SUPPORTED_NODES[cls] - if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH: - del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls] - if hasattr(torch.utils._pytree, "_deregister_pytree_node"): - # torch >= 2.7 - torch.utils._pytree._deregister_pytree_node(cls) - optree.unregister_pytree_node(cls, namespace="torch") - if cls in torch.utils._pytree.SUPPORTED_NODES: - import packaging.version as pv - - if pv.Version(torch.__version__) < pv.Version("2.7.0"): - del torch.utils._pytree.SUPPORTED_NODES[cls] - assert cls not in torch.utils._pytree.SUPPORTED_NODES, ( - f"{cls} was not successful unregistered " - f"from torch.utils._pytree.SUPPORTED_NODES=" - f"{pprint.pformat(list(torch.utils._pytree.SUPPORTED_NODES))}" - ) - if verbose: - print(f"[_unregister_cache_serialization] unregistered {cls.__name__}") - - -def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): - - if undo.get("MambaCache", False): - from transformers.cache_utils import MambaCache - - _unregister(MambaCache, verbose) - elif verbose > 1: - print("[_unregister_cache_serialization] skip unregister MambaCache") - - if undo.get("DynamicCache", False): - from transformers.cache_utils import DynamicCache - - _unregister(DynamicCache, verbose) - elif verbose > 1: - print("[_unregister_cache_serialization] skip unregister DynamicCache") - - @contextlib.contextmanager def register_additional_serialization_functions( patch_transformers: bool = False, verbose: int = 0 @@ -579,6 +424,9 @@ def replacement_before_exporting(args: Any) -> Any: return None if isinstance(args, (int, float)): return args + if type(args) not in {dict, tuple, list}: + # BaseModelOutput is a dict + return args if isinstance(args, dict): return {k: replacement_before_exporting(v) for k, v in args.items()} if isinstance(args, tuple): diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py index 73c56cbb..d42d2534 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +++ 
b/onnx_diagnostic/torch_export_patches/onnx_export_serialization.py @@ -1,52 +1,208 @@ -from typing import Any, Dict, List, Tuple +import pprint +from typing import Any, Callable, Dict, List, Optional, Set, Tuple +import packaging.version as pv +import optree import torch import transformers +from transformers.cache_utils import DynamicCache, MambaCache, EncoderDecoderCache +from transformers.modeling_outputs import BaseModelOutput +from ..helpers import string_type + + +PATCH_OF_PATCHES: Set[Any] = set() + + +def _register_class_serialization( + cls, + f_flatten: Callable, + f_unflatten: Callable, + f_flatten_with_keys: Callable, + f_check: Optional[Callable] = None, + verbose: int = 0, +) -> bool: + if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES: + return False + + if verbose: + print(f"[_register_cache_serialization] register {cls}") + torch.utils._pytree.register_pytree_node( + cls, + f_flatten, + f_unflatten, + serialized_type_name=f"{cls.__module__}.{cls.__name__}", + flatten_with_keys_fn=f_flatten_with_keys, + ) + if pv.Version(torch.__version__) < pv.Version("2.7"): + if verbose: + print( + f"[_register_cache_serialization] " + f"register {cls} for torch=={torch.__version__}" + ) + torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0]) + + # check + if f_check: + inst = f_check() + values, spec = torch.utils._pytree.tree_flatten(inst) + restored = torch.utils._pytree.tree_unflatten(values, spec) + assert string_type(inst, with_shape=True) == string_type(restored, with_shape=True), ( + f"Issue with registration of class {cls} " + f"inst={string_type(inst, with_shape=True)}, " + f"restored={string_type(restored, with_shape=True)}" + ) + return True + + +def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: + # DynamicCache serialization is different in transformers and does not + # play way with torch.export.export. + # see test test_export_dynamic_cache_cat with NOBYPASS=1 + # :: NOBYBASS=1 python _unittests/ut_torch_export_patches/test_dynamic_class.py -k e_c + # This is caused by this line: + # torch.fx._pytree.register_pytree_flatten_spec( + # DynamicCache, _flatten_dynamic_cache_for_fx) + # so we remove it anyway + if ( + DynamicCache in torch.utils._pytree.SUPPORTED_NODES + and DynamicCache not in PATCH_OF_PATCHES + # and pv.Version(torch.__version__) < pv.Version("2.7") + and pv.Version(transformers.__version__) >= pv.Version("4.50") + ): + if verbose: + print( + f"[_fix_registration] DynamicCache is unregistered and " + f"registered first for transformers=={transformers.__version__}" + ) + _unregister(DynamicCache, verbose=verbose) + _register_class_serialization( + DynamicCache, + flatten_dynamic_cache, + unflatten_dynamic_cache, + flatten_with_keys_dynamic_cache, + # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), + verbose=verbose, + ) + if verbose: + print("[_fix_registration] DynamicCache done.") + # To avoid doing it multiple times. + PATCH_OF_PATCHES.add(DynamicCache) + + # BaseModelOutput serialization is incomplete. + # It does not include dynamic shapes mapping. 
+ if ( + BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES + and BaseModelOutput not in PATCH_OF_PATCHES + ): + if verbose: + print( + f"[_fix_registration] BaseModelOutput is unregistered and " + f"registered first for transformers=={transformers.__version__}" + ) + _unregister(BaseModelOutput, verbose=verbose) + _register_class_serialization( + BaseModelOutput, + flatten_base_model_output, + unflatten_base_model_output, + flatten_with_keys_base_model_output, + verbose=verbose, + ) + if verbose: + print("[_fix_registration] BaseModelOutput done.") + + # To avoid doing it multiple times. + PATCH_OF_PATCHES.add(BaseModelOutput) + + unregistered_dynamic_cache = _register_class_serialization( + DynamicCache, + flatten_dynamic_cache, + unflatten_dynamic_cache, + flatten_with_keys_dynamic_cache, + # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), + verbose=verbose, + ) + unregistered_base_model_output = _register_class_serialization( + BaseModelOutput, + flatten_base_model_output, + unflatten_base_model_output, + flatten_with_keys_base_model_output, + verbose=verbose, + ) + unregistered_encode_decode_cache = _register_class_serialization( + EncoderDecoderCache, + flatten_encoder_decoder_cache, + unflatten_encoder_decoder_cache, + flatten_with_keys_encoder_decoder_cache, + verbose=verbose, + ) + unregistered_mamba_cache = _register_class_serialization( + MambaCache, + flatten_mamba_cache, + unflatten_mamba_cache, + flatten_with_keys_mamba_cache, + verbose=verbose, + ) + + return dict( + DynamicCache=unregistered_dynamic_cache, + MambaCache=unregistered_mamba_cache, + EncoderDecoderCache=unregistered_encode_decode_cache, + BaseModelOutput=unregistered_base_model_output, + ) + + +def _unregister(cls: type, verbose: int = 0): + # torch.utils._pytree._deregister_pytree_flatten_spec(cls) + if cls in torch.fx._pytree.SUPPORTED_NODES: + del torch.fx._pytree.SUPPORTED_NODES[cls] + if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH: + del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls] + if hasattr(torch.utils._pytree, "_deregister_pytree_node"): + # torch >= 2.7 + torch.utils._pytree._deregister_pytree_node(cls) + else: + if cls in torch.utils._pytree.SUPPORTED_NODES: + del torch.utils._pytree.SUPPORTED_NODES[cls] + optree.unregister_pytree_node(cls, namespace="torch") + if cls in torch.utils._pytree.SUPPORTED_NODES: + import packaging.version as pv + + if pv.Version(torch.__version__) < pv.Version("2.7.0"): + del torch.utils._pytree.SUPPORTED_NODES[cls] + assert cls not in torch.utils._pytree.SUPPORTED_NODES, ( + f"{cls} was not successful unregistered " + f"from torch.utils._pytree.SUPPORTED_NODES=" + f"{pprint.pformat(list(torch.utils._pytree.SUPPORTED_NODES))}" + ) + if verbose: + print(f"[_unregister_cache_serialization] unregistered {cls.__name__}") + + +def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): + for cls in [MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput]: + if undo.get(cls.__name__, False): + _unregister(cls, verbose) + ############ # MambaCache ############ -# self.conv_states: torch.Tensor = torch.zeros( -# config.num_hidden_layers, -# self.max_batch_size, -# self.intermediate_size, -# self.conv_kernel_size, -# device=device, -# dtype=dtype, -# ) -# self.ssm_states: torch.Tensor = torch.zeros( -# config.num_hidden_layers, -# self.max_batch_size, -# self.intermediate_size, -# self.ssm_state_size, -# device=device, -# dtype=dtype, -# ) def flatten_mamba_cache( - mamba_cache: 
transformers.cache_utils.MambaCache, + mamba_cache: MambaCache, ) -> Tuple[List[Any], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" flat = [ (k, getattr(mamba_cache, k)) - for k in [ - # "max_batch_size", # new in transformers==4.47 - # "intermediate_size", - # "ssm_state_size", - # "conv_kernel_size", - "conv_states", - "ssm_states", - ] + for k in ["conv_states", "ssm_states"] if hasattr(mamba_cache, k) ] return [f[1] for f in flat], [f[0] for f in flat] def unflatten_mamba_cache( - values: List[Any], - context: torch.utils._pytree.Context, - output_type=None, -) -> transformers.cache_utils.MambaCache: + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> MambaCache: """Restores a :class:`transformers.cache_utils.MambaCache` from python objects.""" conv_states, ssm_states = values @@ -63,8 +219,6 @@ def __init__(self): self.conv_kernel = conv_states.shape[3] self.num_hidden_layers = conv_states.shape[0] - from transformers.cache_utils import MambaCache - cache = MambaCache( _config(), max_batch_size=1, @@ -77,14 +231,12 @@ def __init__(self): return cache -def flatten_with_keys_mamba_cache(d: Dict[Any, Any]) -> Tuple[ +def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[ List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context, ]: """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects.""" - import torch - - values, context = flatten_mamba_cache(d) + values, context = flatten_mamba_cache(cache) return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context @@ -94,11 +246,9 @@ def flatten_with_keys_mamba_cache(d: Dict[Any, Any]) -> Tuple[ def flatten_dynamic_cache( - dynamic_cache: transformers.cache_utils.DynamicCache, + dynamic_cache: DynamicCache, ) -> Tuple[List[Any], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" - import transformers.cache_utils - if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"): return transformers.cache_utils._flatten_dynamic_cache(dynamic_cache) flat = [ @@ -109,28 +259,20 @@ def flatten_dynamic_cache( return [f[1] for f in flat], [f[0] for f in flat] -def flatten_with_keys_dynamic_cache(d: Dict[Any, Any]) -> Tuple[ - List[Tuple[torch.utils._pytree.KeyEntry, Any]], - torch.utils._pytree.Context, -]: +def flatten_with_keys_dynamic_cache( + dynamic_cache: DynamicCache, +) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects.""" - import torch - import transformers.cache_utils - if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"): - return transformers.cache_utils._flatten_with_keys_dynamic_cache(d) - values, context = flatten_dynamic_cache(d) + return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache) + values, context = flatten_dynamic_cache(dynamic_cache) return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context def unflatten_dynamic_cache( - values: List[Any], - context: torch.utils._pytree.Context, - output_type=None, -) -> transformers.cache_utils.DynamicCache: + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> DynamicCache: """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects.""" - import transformers.cache_utils - if hasattr(transformers.cache_utils, 
"_unflatten_dynamic_cache"): assert output_type is None, f"output_type={output_type} not supported" return transformers.cache_utils._unflatten_dynamic_cache(values, context) @@ -140,3 +282,83 @@ def unflatten_dynamic_cache( for k, v in values.items(): setattr(cache, k, v) return cache + + +##################### +# EncoderDecoderCache +##################### + + +def flatten_encoder_decoder_cache( + ec_cache: EncoderDecoderCache, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.cache_utils.EncoderDecoderCache` + with python objects. + """ + dictionary = { + "self_attention_cache": ec_cache.self_attention_cache, + "cross_attention_cache": ec_cache.cross_attention_cache, + } + return torch.utils._pytree._dict_flatten(dictionary) + + +def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[ + List[Tuple[torch.utils._pytree.KeyEntry, Any]], + torch.utils._pytree.Context, +]: + """ + Serializes a :class:`transformers.cache_utils.EncoderDecoderCache` + with python objects. + """ + dictionary = { + "self_attention_cache": ec_cache.self_attention_cache, + "cross_attention_cache": ec_cache.cross_attention_cache, + } + return torch.utils._pytree._dict_flatten_with_keys(dictionary) + + +def unflatten_encoder_decoder_cache( + values: List[Any], context: torch.utils._pytree.Context, output_type=None +) -> EncoderDecoderCache: + """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects.""" + dictionary = torch.utils._pytree._dict_unflatten(values, context) + return EncoderDecoderCache(**dictionary) + + +################# +# BaseModelOutput +################# + + +def flatten_base_model_output( + bo: BaseModelOutput, +) -> Tuple[List[Any], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.modeling_outputs.BaseModelOutput` + with python objects. + """ + return list(bo.values()), list(bo.keys()) + + +def flatten_with_keys_base_model_output( + bo: BaseModelOutput, +) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]: + """ + Serializes a :class:`transformers.modeling_outputs.BaseModelOutput` + with python objects. + """ + values, context = flatten_base_model_output(bo) + return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context + + +def unflatten_base_model_output( + values: List[Any], + context: torch.utils._pytree.Context, + output_type=None, +) -> BaseModelOutput: + """ + Restores a :class:`transformers.modeling_outputs.BaseModelOutput` + from python objects. 
+ """ + return BaseModelOutput(**dict(zip(context, values))) diff --git a/onnx_diagnostic/torch_models/hghub/hub_api.py b/onnx_diagnostic/torch_models/hghub/hub_api.py index 4a17a2d7..9f7de9c6 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_api.py +++ b/onnx_diagnostic/torch_models/hghub/hub_api.py @@ -1,4 +1,5 @@ import functools +import os from typing import Any, Dict, List, Optional, Union import transformers from huggingface_hub import HfApi, model_info @@ -33,11 +34,13 @@ def get_cached_configuration(name: str) -> Optional[transformers.PretrainedConfi assert cached, "no cached configuration, which is weird" if name in cached: return cached[name]() + if os.environ.get("NOHTTP", ""): + raise AssertionError(f"Unable to find {name!r} in {sorted(cached)}") return None def get_pretrained_config( - model_id: str, trust_remote_code: bool = True, use_cached: bool = True + model_id: str, trust_remote_code: bool = True, use_preinstalled: bool = True ) -> Any: """ Returns the config for a model_id. @@ -45,13 +48,13 @@ def get_pretrained_config( :param model_id: model id :param trust_remote_code: trust_remote_code, see :meth:`transformers.AutoConfig.from_pretrained` - :param used_cached: if cached, uses this version to avoid + :param use_preinstalled: if use_preinstalled, uses this version to avoid accessing the network, if available, it is returned by :func:`get_cached_configuration`, the cached list is mostly for unit tests :return: a configuration """ - if use_cached: + if use_preinstalled: conf = get_cached_configuration(model_id) if conf is not None: return conf diff --git a/onnx_diagnostic/torch_models/hghub/hub_data.py b/onnx_diagnostic/torch_models/hghub/hub_data.py index 9cd96557..af141b6f 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data.py @@ -119,7 +119,7 @@ VitsModel,text-to-audio Wav2Vec2ConformerForCTC,automatic-speech-recognition Wav2Vec2Model,feature-extraction - WhisperForConditionalGeneration,no-pipeline-tag + WhisperForConditionalGeneration,automatic-speech-recognition XLMModel,feature-extraction XLMRobertaForCausalLM,text-generation YolosForObjectDetection,object-detection diff --git a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py index 750b1793..85d6cb56 100644 --- a/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +++ b/onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py @@ -174,7 +174,7 @@ def _cached_hf_internal_testing_tiny_random_beitforimageclassification(): ) -def _cached_hf_internal_testing_tiny_random_convnext(): +def _ccached_hf_internal_testing_tiny_random_convnext(): "hf-internal-testing/tiny-random-convnext" t64 = textwrap.dedent( """ @@ -1334,10 +1334,10 @@ def _cached_hf_internal_testing_tiny_random_convnext(): return transformers.ConvNextConfig(**kwargs) -def _cached_fxmarty_tiny_random_gemmaforcausallm(): +def _ccached_fxmarty_tiny_random_gemmaforcausallm(): "fxmarty/tiny-random-GemmaForCausalLM" return transformers.GemmaConfig( - { + **{ "architectures": ["GemmaForCausalLM"], "attention_bias": false, "attention_dropout": 0.0, @@ -1366,7 +1366,7 @@ def _cached_fxmarty_tiny_random_gemmaforcausallm(): ) -def _cached_hf_internal_testing_tiny_random_gptneoxforcausallm(): +def _ccached_hf_internal_testing_tiny_random_gptneoxforcausallm(): "hf-internal-testing/tiny-random-GPTNeoXForCausalLM" return transformers.GPTNeoXConfig( **{ @@ -1405,7 +1405,7 @@ def 
_cached_hf_internal_testing_tiny_random_gptneoxforcausallm(): ) -def _cached_hf_internal_testing_tiny_random_graniteforcausallm(): +def _ccached_hf_internal_testing_tiny_random_graniteforcausallm(): "hf-internal-testing/tiny-random-GraniteForCausalLM" return transformers.GraniteConfig( **{ @@ -1441,7 +1441,7 @@ def _cached_hf_internal_testing_tiny_random_graniteforcausallm(): ) -def _cached_hf_internal_testing_tiny_random_hieraforimageclassification(): +def _ccached_hf_internal_testing_tiny_random_hieraforimageclassification(): "hf-internal-testing/tiny-random-HieraForImageClassification" return transformers.HieraConfig( **{ @@ -1482,7 +1482,7 @@ def _cached_hf_internal_testing_tiny_random_hieraforimageclassification(): ) -def _cached_fxmarty_tiny_llama_fast_tokenizer(): +def _ccached_fxmarty_tiny_llama_fast_tokenizer(): "fxmarty/tiny-llama-fast-tokenizer" return transformers.LlamaConfig( **{ @@ -1516,7 +1516,7 @@ def _cached_fxmarty_tiny_llama_fast_tokenizer(): ) -def _cached_sshleifer_tiny_marian_en_de(): +def _ccached_sshleifer_tiny_marian_en_de(): "sshleifer/tiny-marian-en-de" return transformers.MarianConfig( **{ @@ -1567,7 +1567,7 @@ def _cached_sshleifer_tiny_marian_en_de(): ) -def _cached_hf_internal_testing_tiny_random_maskformerforinstancesegmentation(): +def _ccached_hf_internal_testing_tiny_random_maskformerforinstancesegmentation(): "hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation" t64 = textwrap.dedent( """ @@ -1625,7 +1625,7 @@ def _cached_hf_internal_testing_tiny_random_maskformerforinstancesegmentation(): return transformers.MaskFormerConfig(**kwargs) -def _cached_echarlaix_tiny_random_mistral(): +def _ccached_echarlaix_tiny_random_mistral(): "echarlaix/tiny-random-mistral" return transformers.MistralConfig( **{ @@ -1660,7 +1660,7 @@ def _cached_echarlaix_tiny_random_mistral(): ) -def _cached_hf_internal_testing_tiny_random_mobilevit(): +def _ccached_hf_internal_testing_tiny_random_mobilevit(): "hf-internal-testing/tiny-random-mobilevit" t64 = textwrap.dedent( """ @@ -2823,7 +2823,7 @@ def _cached_hf_internal_testing_tiny_random_mobilevit(): return transformers.MobileViTConfig(**kwargs) -def _cached_hf_internal_testing_tiny_random_moonshineforconditionalgeneration(): +def _ccached_hf_internal_testing_tiny_random_moonshineforconditionalgeneration(): "hf-internal-testing/tiny-random-MoonshineForConditionalGeneration" return transformers.MoonshineConfig( **{ @@ -2859,7 +2859,7 @@ def _cached_hf_internal_testing_tiny_random_moonshineforconditionalgeneration(): ) -def _cached_hf_internal_testing_tiny_random_olmoforcausallm(): +def _ccached_hf_internal_testing_tiny_random_olmoforcausallm(): "hf-internal-testing/tiny-random-OlmoForCausalLM" return transformers.OlmoConfig( **{ @@ -2889,7 +2889,7 @@ def _cached_hf_internal_testing_tiny_random_olmoforcausallm(): ) -def _cached_hf_internal_testing_tiny_random_olmo2forcausallm(): +def _ccached_hf_internal_testing_tiny_random_olmo2forcausallm(): "hf-internal-testing/tiny-random-Olmo2ForCausalLM" return transformers.Olmo2Config( **{ @@ -2919,7 +2919,7 @@ def _cached_hf_internal_testing_tiny_random_olmo2forcausallm(): ) -def _cached_echarlaix_tiny_random_phiforcausallm(): +def _ccached_echarlaix_tiny_random_phiforcausallm(): "echarlaix/tiny-random-PhiForCausalLM" return transformers.PhiConfig( **{ @@ -2957,7 +2957,7 @@ def _cached_echarlaix_tiny_random_phiforcausallm(): ) -def _cached_xenova_tiny_random_phi3forcausallm(): +def _ccached_xenova_tiny_random_phi3forcausallm(): "Xenova/tiny-random-Phi3ForCausalLM" return 
transformers.Phi3Config( **{ @@ -2992,7 +2992,7 @@ def _cached_xenova_tiny_random_phi3forcausallm(): ) -def _cached_fxmarty_pix2struct_tiny_random(): +def _ccached_fxmarty_pix2struct_tiny_random(): "fxmarty/pix2struct-tiny-random" return transformers.Pix2StructConfig( **{ @@ -3059,7 +3059,7 @@ def _cached_fxmarty_pix2struct_tiny_random(): ) -def _cached_fxmarty_tiny_dummy_qwen2(): +def _ccached_fxmarty_tiny_dummy_qwen2(): "fxmarty/tiny-dummy-qwen2" return transformers.Qwen2Config( **{ @@ -3091,7 +3091,7 @@ def _cached_fxmarty_tiny_dummy_qwen2(): ) -def _cached_hf_internal_testing_tiny_random_vitmsnforimageclassification(): +def _ccached_hf_internal_testing_tiny_random_vitmsnforimageclassification(): "hf-internal-testing/tiny-random-ViTMSNForImageClassification" return transformers.ViTMSNConfig( **{ @@ -3116,7 +3116,7 @@ def _cached_hf_internal_testing_tiny_random_vitmsnforimageclassification(): ) -def _cached_hf_internal_testing_tiny_random_yolosmodel(): +def _ccached_hf_internal_testing_tiny_random_yolosmodel(): "hf-internal-testing/tiny-random-YolosModel" return transformers.YolosConfig( **{ @@ -3152,7 +3152,7 @@ def _cached_hf_internal_testing_tiny_random_yolosmodel(): ) -def _cached_hf_internal_testing_tiny_xlm_roberta(): +def _ccached_hf_internal_testing_tiny_xlm_roberta(): "hf-internal-testing/tiny-xlm-roberta" return transformers.XLMRobertaConfig( **{ @@ -3191,7 +3191,7 @@ def _cached_hf_internal_testing_tiny_xlm_roberta(): ) -def _cached_hf_m4_tiny_random_idefics(): +def _ccached_hf_m4_tiny_random_idefics(): "HuggingFaceM4/tiny-random-idefics" return transformers.IdeficsConfig( **{ @@ -3257,3 +3257,135 @@ def _cached_hf_m4_tiny_random_idefics(): "word_embed_proj_dim": 16, } ) + + +def _ccached_openai_whisper_tiny(): + "openai/whisper-tiny" + return transformers.WhisperConfig( + **{ + "_name_or_path": "openai/whisper-tiny", + "activation_dropout": 0.0, + "activation_function": "gelu", + "architectures": ["WhisperForConditionalGeneration"], + "attention_dropout": 0.0, + "begin_suppress_tokens": [220, 50257], + "bos_token_id": 50257, + "d_model": 384, + "decoder_attention_heads": 6, + "decoder_ffn_dim": 1536, + "decoder_layerdrop": 0.0, + "decoder_layers": 4, + "decoder_start_token_id": 50258, + "dropout": 0.0, + "encoder_attention_heads": 6, + "encoder_ffn_dim": 1536, + "encoder_layerdrop": 0.0, + "encoder_layers": 4, + "eos_token_id": 50257, + "forced_decoder_ids": [[1, 50259], [2, 50359], [3, 50363]], + "init_std": 0.02, + "is_encoder_decoder": true, + "max_length": 448, + "max_source_positions": 1500, + "max_target_positions": 448, + "model_type": "whisper", + "num_hidden_layers": 4, + "num_mel_bins": 80, + "pad_token_id": 50257, + "scale_embedding": false, + "suppress_tokens": [ + 1, + 2, + 7, + 8, + 9, + 10, + 14, + 25, + 26, + 27, + 28, + 29, + 31, + 58, + 59, + 60, + 61, + 62, + 63, + 90, + 91, + 92, + 93, + 359, + 503, + 522, + 542, + 873, + 893, + 902, + 918, + 922, + 931, + 1350, + 1853, + 1982, + 2460, + 2627, + 3246, + 3253, + 3268, + 3536, + 3846, + 3961, + 4183, + 4667, + 6585, + 6647, + 7273, + 9061, + 9383, + 10428, + 10929, + 11938, + 12033, + 12331, + 12562, + 13793, + 14157, + 14635, + 15265, + 15618, + 16553, + 16604, + 18362, + 18956, + 20075, + 21675, + 22520, + 26130, + 26161, + 26435, + 28279, + 29464, + 31650, + 32302, + 32470, + 36865, + 42863, + 47425, + 49870, + 50254, + 50258, + 50358, + 50359, + 50360, + 50361, + 50362, + ], + "torch_dtype": "float32", + "transformers_version": "4.27.0.dev0", + "use_cache": true, + "vocab_size": 51865, + } + ) diff --git 
a/onnx_diagnostic/torch_models/hghub/model_inputs.py b/onnx_diagnostic/torch_models/hghub/model_inputs.py index bfc5eebf..4aa143d8 100644 --- a/onnx_diagnostic/torch_models/hghub/model_inputs.py +++ b/onnx_diagnostic/torch_models/hghub/model_inputs.py @@ -103,6 +103,14 @@ def reduce_model_config(config: Any, task: str) -> Dict[str, Any]: config.vision_config.num_hidden_layers = min( config.vision_config.num_hidden_layers, 2 ) + elif task == "automatic-speech-recognition": + kwargs = {} + if hasattr(config, "num_decoder_layers"): + config.num_decoder_layers = min(config.num_decoder_layers, 2) + if hasattr(config, "decoder_layers"): + config.decoder_layers = min(config.decoder_layers, 2) + if hasattr(config, "num_hidden_layers"): + config.num_hidden_layers = min(config.num_hidden_layers, 2) else: raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.") @@ -274,6 +282,35 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callabl num_channels=3 if config is None else config.vision_config.num_channels, ) fct = get_inputs_for_image_text_to_text # type: ignore + elif task == "automatic-speech-recognition": + if config is not None: + check_hasattr( + config, + "d_model", + "decoder_attention_heads", + "decoder_layers", + "encoder_attention_heads", + "encoder_layers", + "max_source_positions", + "num_hidden_layers", + "vocab_size", + ) + kwargs = dict( + batch_size=2, + sequence_length=30, + dummy_max_token_id=31000 if config is None else config.vocab_size, + max_source_positions=1500 if config is None else config.max_source_positions, + d_model=384 if config is None else config.d_model, + num_hidden_layers=4 if config is None else config.num_hidden_layers, + encoder_attention_heads=6 if config is None else config.encoder_attention_heads, + encoder_layers=4 if config is None else config.encoder_layers, + decoder_attention_heads=6 if config is None else config.decoder_attention_heads, + decoder_layers=4 if config is None else config.decoder_layers, + head_dim=( + 64 if config is None else (config.d_model // config.encoder_attention_heads) + ), + ) + fct = get_inputs_for_speech_automatic_recognition # type: ignore else: raise NotImplementedError(f"Input generation for task {task!r} not implemented yet.") @@ -289,6 +326,7 @@ def get_untrained_model_with_inputs( verbose: int = 0, dynamic_rope: Optional[bool] = None, same_as_pretrained: bool = False, + use_preinstalled: bool = True, ) -> Dict[str, Any]: """ Gets a non initialized model similar to the original model @@ -305,6 +343,7 @@ def get_untrained_model_with_inputs( :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`) :param same_as_pretrained: if True, do not change the default values to get a smaller model + :param use_preinstalled: use preinstalled configurations :return: dictionary with a model, inputs, dynamic shapes, and the configuration Example: @@ -326,8 +365,10 @@ def get_untrained_model_with_inputs( """ if verbose: print(f"[get_untrained_model_with_inputs] model_id={model_id!r}") + if use_preinstalled: + print(f"[get_untrained_model_with_inputs] use preinstalled {model_id!r}") if config is None: - config = get_pretrained_config(model_id) + config = get_pretrained_config(model_id, use_preinstalled=use_preinstalled) archs = config.architectures # type: ignore assert archs is not None and len(archs) == 1, ( f"Unable to determine the architecture for model {model_id!r}, " @@ -477,6 +518,9 @@ def get_inputs_for_text_generation( :param kwargs: to overwrite the 
configuration, example ``num_hidden_layers=1``
     :return: dictionary
     """
+    if head_dim is None:
+        assert config, "head_dim is None, the value cannot be set without a configuration"
+        head_dim = config.hidden_size // config.num_attention_heads
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = torch.export.Dim("cache_length", min=1, max=4096)
@@ -717,11 +761,128 @@ def get_inputs_for_text2text_generation(
     return dict(inputs=inputs, dynamic_shapes=shapes)
 
 
+def get_inputs_for_speech_automatic_recognition(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    max_source_positions: int,
+    d_model: int,
+    num_hidden_layers: int,
+    encoder_attention_heads: int,
+    encoder_layers: int,
+    decoder_layers: int,
+    head_dim: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    **kwargs,
+):
+    """
+    Generates inputs for task ``automatic-speech-recognition``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param dummy_max_token_id: vocabulary size, random token ids are drawn below this value
+    :param max_source_positions: length of the encoder output
+    :param d_model: hidden size of the model
+    :param num_hidden_layers: number of layers stored in the caches
+    :param encoder_attention_heads: number of attention heads used to build the caches
+    :param encoder_layers: number of encoder layers, also used as the length
+        of the dummy self-attention cache
+    :param decoder_layers: number of decoder layers
+    :param head_dim: dimension of one attention head
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
+    :return: dictionary
+
+    Inputs captured from one model run, first with an empty cache,
+    then with a filled one.
+
+    ::
+
+        dict(
+            cache_position:T7s4,
+            past_key_values:EncoderDecoderCache(
+                self_attention_cache=DynamicCache[serialized](#2[#0[],#0[]]),
+                cross_attention_cache=DynamicCache[serialized](#2[#0[],#0[]])
+            ),
+            decoder_input_ids:T7s1x4,
+            encoder_outputs:BaseModelOutput(last_hidden_state:T1s1x1500x384),
+            use_cache:bool,return_dict:bool
+        )
+        dict(
+            cache_position:T7s1,
+            past_key_values:EncoderDecoderCache(
+                self_attention_cache=DynamicCache[serialized](#2[
+                    #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64],
+                    #4[T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64,T1s1x6x4x64]
+                ]),
+                cross_attention_cache=DynamicCache[serialized](#2[
+                    #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64],
+                    #4[T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64,T1s1x6x1500x64]
+                ]),
+            ),
+            decoder_input_ids:T7s1x1,
+            encoder_outputs:BaseModelOutput(last_hidden_state:T1s1x1500x384),
+            use_cache:bool,return_dict:bool
+        )
+    """
+    batch = torch.export.Dim("batch", min=1, max=1024)
+    seq_length = torch.export.Dim("seq_length", min=1, max=4096)
+
+    shapes = {
+        "decoder_input_ids": {0: batch, 1: seq_length},
+        "cache_position": {0: seq_length},
+        "encoder_outputs": [{0: batch}],
+        "past_key_values": [
+            [
+                [{0: batch} for _ in range(num_hidden_layers)],
+                [{0: batch} for _ in range(num_hidden_layers)],
+            ],
+            [
+                [{0: batch} for _ in range(num_hidden_layers)],
+                [{0: batch} for _ in range(num_hidden_layers)],
+            ],
+        ],
+    }
+    inputs = dict(
+        decoder_input_ids=torch.randint(
+            0, dummy_max_token_id, (batch_size, sequence_length)
+        ).to(torch.int64),
+        cache_position=(torch.arange(sequence_length) + 5).to(torch.int64),
+        encoder_outputs=transformers.modeling_outputs.BaseModelOutput(
+            last_hidden_state=torch.randn(batch_size, max_source_positions, d_model)
+        ),
+        past_key_values=make_encoder_decoder_cache(
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, encoder_attention_heads, encoder_layers, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, encoder_attention_heads, encoder_layers, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, encoder_attention_heads, max_source_positions, head_dim
+                        ),
+                        torch.randn(
+                            batch_size, encoder_attention_heads, max_source_positions, head_dim
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
), + # one these is selected based on the forward method signature + # encoder_last_hidden_state=torch.randn(batch_size, sequence_length2, encoder_dim), + # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim), + ) + return dict(inputs=inputs, dynamic_shapes=shapes) + + def get_get_inputs_function_for_tasks() -> Dict[str, Callable]: """Returns all the function producing dummy inputs for every task.""" return { + "automatic-speech-recognition": get_inputs_for_speech_automatic_recognition, "image-classification": get_inputs_for_image_classification, + "image-text-to-text": get_inputs_for_image_text_to_text, "text-generation": get_inputs_for_text_generation, "text2text-generation": get_inputs_for_text2text_generation, - "image-text-to-text": get_inputs_for_image_text_to_text, } diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index 9ce0a69f..d79ff78e 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -3,10 +3,11 @@ import os from typing import Any, Dict, List, Optional, Tuple, Union import time +import onnx import torch from ..helpers import max_diff, string_type, string_diff from ..helpers.helper import flatten_object -from ..helpers.ort_session import make_feeds +from ..helpers.rt_helper import make_feeds from ..helpers.torch_test_helper import to_any, torch_deepcopy from ..torch_export_patches import bypass_export_some_errors from .hghub import get_untrained_model_with_inputs @@ -23,16 +24,7 @@ def empty(value: Any) -> bool: def _ds_clean(v): - return ( - str(v) - .replace(",min=None", "") - .replace(",max=None", "") - .replace(",_factory=True", "") - .replace("", "") - .replace("_DimHint(type=<_DimHintType.DYNAMIC: 3>)", "DYNAMIC") - .replace("_DimHint(type=<_DimHintType.AUTO: 3>)", "AUTO") - ) + return string_type(v) def get_inputs_for_task(task: str, config: Optional[Any] = None) -> Dict[str, Any]: @@ -161,25 +153,25 @@ def version_summary() -> Dict[str, Union[int, float, str]]: try: import transformers - summary["version_transformers"] = transformers.__version__ + summary["version_transformers"] = getattr(transformers, "__version__", "?") except ImportError: pass try: import onnx - summary["version_onnx"] = onnx.__version__ + summary["version_onnx"] = getattr(onnx, "__version__", "?") except ImportError: pass try: import onnxscript - summary["version_onnxscript"] = onnxscript.__version__ + summary["version_onnxscript"] = getattr(onnxscript, "__version__", "?") except ImportError: pass try: import onnxruntime - summary["version_onnxruntime"] = onnxruntime.__version__ + summary["version_onnxruntime"] = getattr(onnxruntime, "__version__", "?") except ImportError: pass import onnx_diagnostic @@ -235,6 +227,7 @@ def validate_model( """ assert not trained, f"trained={trained} not supported yet" summary = version_summary() + folder_name = None if dump_folder: folder_name = _make_folder_name( model_id, exporter, optimization, dtype=dtype, device=device @@ -246,12 +239,12 @@ def validate_model( summary["dump_folder_name"] = folder_name if verbose: print(f"[validate_model] dump into {folder_name!r}") - else: - folder_name = None + if verbose: print(f"[validate_model] validate model id {model_id!r}") print("[validate_model] get dummy inputs...") summary["model_id"] = model_id + begin = time.perf_counter() if quiet: try: @@ -266,7 +259,7 @@ def validate_model( if drop_inputs: if verbose: - print(f"[validate_model] drop inputs {drop_inputs!r}") + 
print(f"[validate_model] -- drop inputs {drop_inputs!r}") print(f"[validate_model] current inputs: {string_type(data['inputs'])}") print( f"[validate_model] current dynnamic_shapes: " @@ -309,6 +302,7 @@ def validate_model( summary["model_id"] = model_id if verbose: + print("[validate_model] --") print(f"[validate_model] task={data['task']}") print(f"[validate_model] size={data['size'] / 2**20} Mb") print(f"[validate_model] n_weights={data['n_weights'] / 1e6} millions parameters") @@ -316,10 +310,11 @@ def validate_model( print(f"[validate_model] +INPUT {k}={string_type(v, with_shape=True)}") for k, v in data["dynamic_shapes"].items(): print(f"[validate_model] +SHAPE {k}={_ds_clean(v)}") + print("[validate_model] --") if do_run: if verbose: - print("[validate_model] run the model...") + print("[validate_model] -- run the model...") print(f"[validate_model] inputs={string_type(data['inputs'], with_shape=True)}") # We make a copy of the input just in case the model modifies them inplace hash_inputs = string_type(data["inputs"], with_shape=True) @@ -348,12 +343,15 @@ def validate_model( if exporter: print( - f"[validate_model] export the model with {exporter!r}, " + f"[validate_model] -- export the model with {exporter!r}, " f"optimization={optimization!r}" ) if patch: if verbose: - print("[validate_model] applies patches before exporting") + print( + f"[validate_model] applies patches before exporting " + f"stop_if_static={stop_if_static}" + ) with bypass_export_some_errors( # type: ignore patch_transformers=True, stop_if_static=stop_if_static, @@ -426,7 +424,7 @@ def validate_model( if "exported_program" in data: ep = data["exported_program"] if verbose: - print(f"[validate_model] dumps exported program in {dump_folder!r}...") + print(f"[validate_model] -- dumps exported program in {dump_folder!r}...") with open(os.path.join(dump_folder, f"{folder_name}.ep"), "w") as f: f.write(str(ep)) with open(os.path.join(dump_folder, f"{folder_name}.graph"), "w") as f: @@ -438,7 +436,10 @@ def validate_model( if verbose: print(f"[validate_model] dumps onnx program in {dump_folder!r}...") onnx_file_name = os.path.join(dump_folder, f"{folder_name}.onnx") - epo.save(onnx_file_name, external_data=True) + if isinstance(epo, onnx.model_container.ModelContainer): + epo.save(onnx_file_name, all_tensors_to_one_file=True) + else: + epo.save(onnx_file_name, external_data=True) if verbose: print("[validate_model] done (dump onnx)") if verbose: @@ -459,7 +460,7 @@ def validate_model( summary.update(summary_valid) if verbose: - print("[validate_model] done (final)") + print("[validate_model] -- done (final)") return summary, data @@ -505,6 +506,16 @@ def call_exporter( optimization=optimization, ) return summary, data + if exporter.startswith("custom-"): + # torch export + summary, data = call_torch_export_custom( + exporter=exporter, + data=data, + quiet=quiet, + verbose=verbose, + optimization=optimization, + ) + return summary, data raise NotImplementedError( f"export with {exporter!r} and optimization={optimization!r} not implemented yet" ) @@ -536,6 +547,7 @@ def call_torch_export_export( "export-strict", "export-nostrict", }, f"Unexpected value for exporter={exporter!r}" + assert not optimization, f"No optimization is implemented for exporter={exporter!r}" assert "model" in data, f"model is missing from data: {sorted(data)}" assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}" summary: Dict[str, Union[str, int, float]] = {} @@ -625,6 +637,108 @@ def call_torch_export_export( 
     return summary, data
 
 
+def validate_onnx_model(
+    data: Dict[str, Any],
+    quiet: bool = False,
+    verbose: int = 0,
+    optimization: Optional[str] = None,
+):
+    """
+    Verifies that an onnx model produces the expected outputs.
+
+    :param data: dictionary with all the necessary inputs, the dictionary must
+        contain the keys ``model`` and ``inputs_export``
+    :param quiet: catch exception or not
+    :param verbose: verbosity
+    :param optimization: optimization to do
+    :return: two dictionaries, one with some metrics,
+        another one with whatever the function produces
+    """
+    import onnxruntime
+
+    summary = {}
+    flat_inputs = flatten_object(data["inputs"], drop_keys=True)
+    d = flat_inputs[0].get_device()
+    providers = (
+        ["CPUExecutionProvider"]
+        if d < 0
+        else ["CUDAExecutionProvider", "CPUExecutionProvider"]
+    )
+    if "onnx_file_name" in data:
+        source = data["onnx_file_name"]
+        summary["onnx_filename"] = source
+        summary["onnx_size"] = os.stat(source).st_size
+    else:
+        assert (
+            "onnx_program" in data
+        ), f"onnx_program is missing from data which has {sorted(data)}"
+        source = data["onnx_program"].model_proto.SerializeToString()
+        assert len(source) < 2**31, f"The model is larger than 2 Gb: {len(source) / 2**30} Gb"
+        summary["onnx_size"] = len(source)
+    if verbose:
+        print(f"[validate_onnx_model] verify onnx model with providers {providers}...")
+
+    begin = time.perf_counter()
+    if quiet:
+        try:
+            sess = onnxruntime.InferenceSession(source, providers=providers)
+        except Exception as e:
+            summary["ERR_onnx_ort_create"] = str(e)
+            data["ERR_onnx_ort_create"] = e
+            summary["time_onnx_ort_create"] = time.perf_counter() - begin
+            return summary, data
+    else:
+        sess = onnxruntime.InferenceSession(source, providers=providers)
+
+    summary["time_onnx_ort_create"] = time.perf_counter() - begin
+    data["onnx_ort_sess"] = sess
+    if verbose:
+        print("[validate_onnx_model] done (ort_session)")
+
+    # make_feeds
+    if verbose:
+        print("[validate_onnx_model] -- make_feeds...")
+        print(f"[validate_onnx_model] inputs={string_type(data['inputs'], with_shape=True)}")
+    feeds = make_feeds(
+        [i.name for i in sess.get_inputs()],
+        data["inputs"],
+        use_numpy=True,
+        check_flatten=False,
+    )
+    if verbose:
+        print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}")
+    summary["onnx_ort_inputs"] = string_type(feeds, with_shape=True)
+    if verbose:
+        print("[validate_onnx_model] done (make_feeds)")
+
+    # run ort
+    if verbose:
+        print("[validate_onnx_model] run session...")
+    begin = time.perf_counter()
+    if quiet:
+        try:
+            got = sess.run(None, feeds)
+        except Exception as e:
+            summary["ERR_onnx_ort_run"] = str(e)
+            data["ERR_onnx_ort_run"] = e
+            summary["time_onnx_ort_run"] = time.perf_counter() - begin
+            return summary, data
+    else:
+        got = sess.run(None, feeds)
+    if verbose:
+        print("[validate_onnx_model] done (run)")
+        print(f"[validate_onnx_model] got={string_type(got, with_shape=True)}")
+
+    # compute discrepancies
+    disc = max_diff(data["expected"], got, flatten=True)
+    if verbose:
+        print(f"[validate_onnx_model] discrepancies={string_diff(disc)}")
+    for k, v in disc.items():
+        summary[f"disc_onnx_ort_run_{k}"] = v
+    return summary, data
+
+
 def call_torch_export_onnx(
     data: Dict[str, Any],
     exporter: str,
     quiet: bool = False,
     verbose: int = 0,
     optimization: Optional[str] = None,
 ):
     """
@@ -730,98 +844,170 @@ def call_torch_export_onnx(
     return summary, data
 
 
-def validate_onnx_model(
+def call_torch_export_custom(
     data: Dict[str, Any],
+    exporter: str,
     quiet: bool = False,
     verbose: int = 0,
     optimization: Optional[str] = None,
 ):
     """
-    Verifies that an onnx model
produces the same - expected outputs. + Exports a model into onnx. + If a patch must be applied, it should be before this functions. :param data: dictionary with all the necessary inputs, the dictionary must contains keys ``model`` and ``inputs_export`` + :param exporter: exporter to call :param quiet: catch exception or not :param verbose: verbosity :param optimization: optimization to do :return: two dictionaries, one with some metrics, another one with whatever the function produces """ - import onnxruntime - - summary = {} - flat_inputs = flatten_object(data["inputs"], drop_keys=True) - d = flat_inputs[0].get_device() - providers = ( - ["CPUExecutionProvider"] - if d < 0 - else ["CUDAExecutionProvider", "CPUExecutionProvider"] - ) - if "onnx_file_name" in data: - source = data["onnx_file_name"] - summary["onnx_filename"] = source - summary["onnx_size"] = os.stats(source).st_size - else: - assert ( - "onnx_program" in data - ), f"onnx_program is missing from data which has {sorted(data)}" - source = data["onnx_program"].model_proto.SerializeToString() - assert len(source) < 2**31, f"The model is highger than 2Gb: {len(source) / 2**30} Gb" - summary["onnx_size"] = len(source) + assert optimization in { + "", + "default", + "default+onnxruntime", + None, + }, f"unexpected value for optimization={optimization}" + assert exporter in { + "custom-strict", + "custom-strict-dec", + "custom-strict-all", + "custom-nostrict", + "custom-nostrict-dec", + "custom-nostrict-all", + }, f"Unexpected value for exporter={exporter!r}" + assert "model" in data, f"model is missing from data: {sorted(data)}" + assert "inputs_export" in data, f"inputs_export is missing from data: {sorted(data)}" + summary: Dict[str, Union[str, int, float]] = {} + dynamo = "nostrict" not in exporter + args, kwargs = split_args_kwargs(data["inputs_export"]) + ds = data.get("dynamic_shapes", None) if verbose: - print(f"[validate_onnx_model] verify onnx model with providers {providers}...") - - begin = time.perf_counter() - if quiet: - try: - sess = onnxruntime.InferenceSession(source, providers=providers) - except Exception as e: - summary["ERR_onnx_ort_create"] = str(e) - data["ERR_onnx_ort_create"] = e - summary["time_onnx_ort_create"] = time.perf_counter() - begin - return summary, data - else: - sess = onnxruntime.InferenceSession(source, providers=providers) + print( + f"[call_torch_export_custom] exporter={exporter!r}, " + f"optimization={optimization!r}" + ) + print(f"[call_torch_export_custom] args={string_type(args, with_shape=True)}") + print(f"[call_torch_export_custom] kwargs={string_type(kwargs, with_shape=True)}") + print(f"[call_torch_export_custom] dynamic_shapes={_ds_clean(ds)}") + print("[call_torch_export_custom] export...") + summary["export_exporter"] = exporter + summary["export_optimization"] = optimization or "" + summary["export_dynamo"] = dynamo + summary["export_args"] = string_type(args, with_shape=True) + summary["export_kwargs"] = string_type(kwargs, with_shape=True) - summary["time_onnx_ort_create"] = time.perf_counter() - begin - data["onnx_ort_sess"] = sess - if verbose: - print("[validate_onnx_model] done (ort_session)") + from experimental_experiment.torch_interpreter import to_onnx, ExportOptions + from experimental_experiment.xbuilder import OptimizationOptions - # make_feeds - if verbose: - print("[validate_onnx_model] make_feeds...") - print(f"[validate_onnx_model] inputs={string_type(data['inputs'], with_shape=True)}") - feeds = make_feeds([i.name for i in sess.get_inputs()], data["inputs"], 
use_numpy=True) - if verbose: - print(f"[validate_onnx_model] ort inputs={string_type(feeds, with_shape=True)}") - summary["onnx_ort_inputs"] = string_type(feeds, with_shape=True) - if verbose: - print("[validate_onnx_model] done (make_feeds)") + export_options = ExportOptions( + strict="nostrict" not in exporter, + decomposition_table=( + "dec" if "-dec" in exporter else ("all" if "-all" in exporter else None) + ), + ) + options = OptimizationOptions(patterns=optimization) if optimization else None - # run ort - if verbose: - print("[validate_onnx_model] run session...") begin = time.perf_counter() if quiet: try: - got = sess.run(None, feeds) + epo, opt_stats = to_onnx( + data["model"], + args, + kwargs=kwargs, + dynamic_shapes=ds, + export_options=export_options, + options=options, + optimize=bool(optimization), + large_model=True, + return_optimize_report=True, + ) except Exception as e: - summary["ERR_onnx_ort_run"] = str(e) - data["ERR_onnx_ort_run"] = e - summary["time_onnx_ort_run"] = time.perf_counter() - begin + summary["ERR_export_export"] = str(e) + data["ERR_export_export"] = e + summary["time_export_export"] = time.perf_counter() - begin return summary, data else: - got = sess.run(None, feeds) - if verbose: - print("[validate_onnx_model] done (run)") - print(f"[validate_onnx_model] got={string_type(got, with_shape=True)}") + epo, opt_stats = to_onnx( + data["model"], + args, + kwargs=kwargs, + dynamic_shapes=ds, + export_options=export_options, + options=options, + optimize=bool(optimization), + large_model=True, + return_optimize_report=True, + ) - # compute discrepancies - disc = max_diff(data["expected"], got, flatten=True) + new_stat = {} + if "optimization" in opt_stats: + added, removed, time_in = 0, 0, 0.0 + max_iter = 0 + applied = {} + matched = set() + n_applied = 0 + by_pattern = {} + by_pattern_n = {} + by_iter = {} + cst_added, cst_removed, cst_time_in = 0, 0, 0.0 + + for obs in opt_stats["optimization"]: + pattern = obs["pattern"] + if pattern == "constant_folding": + cst_added += obs.get("added", 0) + cst_removed += obs.get("removed", 0) + cst_time_in += obs.get("time_in", 0) + if pattern not in by_pattern: + by_pattern[pattern] = 0 + by_pattern_n[pattern] = 0 + by_iter[pattern] = 0 + time_in += obs.get("time_in", 0) + added += obs.get("added", 0) + removed += obs.get("removed", 0) + max_iter = max(max_iter, obs.get("iteration", 0)) + by_pattern[pattern] += obs.get("time_in", 0) + by_pattern_n[pattern] += obs.get("added", 0) - obs.get("removed", 0) + if not pattern.startswith("match"): + by_iter[pattern] = max(by_iter[pattern], obs.get("iteration", 0)) + p = obs["pattern"] + if p.startswith("match_"): + matched.add(p) + elif p.startswith("apply_"): + key = f"op_opt_{p}" + key2 = f"op_opt_maxiter_{p}" + if key not in applied: + applied[key] = 1 + applied[key2] = obs["iteration"] + else: + applied[key] += 1 + applied[key2] = max(obs["iteration"], applied[key2]) + n_applied += 1 + + new_stat.update( + dict( + onnx_opt_optimized=1, + op_opt_all_time_in=time_in, + op_opt_all_added=added, + op_opt_all_removed=removed, + op_opt_max_iter=max_iter, + op_opt_unique_matched=len(matched), + op_opt_unique_applied=len(applied), + op_opt_n_applied=n_applied, + time_export_optimization=time_in, + op_opt_export_optimization=time_in, + op_opt_cst_time_in=cst_time_in, + op_opt_cst_added=cst_added, + op_opt_cst_removed=cst_removed, + ) + ) + + summary["time_export_export"] = time.perf_counter() - begin + summary.update(new_stat) + assert epo is not None, "no onnx export was found" 
if verbose: - print(f"[validate_onnx_model] discrepancies={string_diff(disc)}") - for k, v in disc.items(): - summary[f"disc_onnx_ort_run_{k}"] = v + print("[call_torch_export_custom] done (export)") + data["onnx_program"] = epo return summary, data
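
The new ``automatic-speech-recognition`` path can be exercised end to end. The snippet
below is a sketch, not part of the patch: it assumes ``get_untrained_model_with_inputs``
returns the keys ``model``, ``inputs`` and ``dynamic_shapes`` as documented above, and it
does not guarantee that ``torch.export.export`` succeeds on ``openai/whisper-tiny`` without
further adjustments::

    import torch
    from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
    from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs

    # Untrained copy of openai/whisper-tiny plus dummy inputs and dynamic shapes
    # produced by get_inputs_for_speech_automatic_recognition.
    data = get_untrained_model_with_inputs("openai/whisper-tiny", verbose=1)
    model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

    # EncoderDecoderCache serialization is only registered inside this context.
    with bypass_export_some_errors(patch_transformers=True):
        ep = torch.export.export(
            model, (), kwargs=inputs, dynamic_shapes=ds, strict=False
        )
    print(ep)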
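
The validation loop implemented by ``validate_onnx_model`` can also be reproduced manually
with ``make_feeds`` from the new ``rt_helper`` module. Again a sketch under assumptions:
``model.onnx`` is a placeholder path for an exported model, and ``model``/``inputs`` are
taken from the previous snippet::

    import onnxruntime
    from onnx_diagnostic.helpers import max_diff, string_diff
    from onnx_diagnostic.helpers.rt_helper import make_feeds

    # Reference outputs from the torch model.
    expected = model(**inputs)

    sess = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

    # Flatten the nested inputs (caches, BaseModelOutput) into a name -> numpy array feed.
    feeds = make_feeds(
        [i.name for i in sess.get_inputs()], inputs, use_numpy=True, check_flatten=False
    )
    got = sess.run(None, feeds)

    # max_diff flattens both sides and reports absolute/relative discrepancies.
    print(string_diff(max_diff(expected, got, flatten=True)))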
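
``call_torch_export_custom`` encodes its options in the exporter name: ``custom-strict`` or
``custom-nostrict``, optionally suffixed with ``-dec`` or ``-all`` to pick a decomposition
table. The helper below is hypothetical and only mirrors that naming convention for
illustration; it is not part of the patch::

    from typing import Optional, Tuple

    def parse_custom_exporter(exporter: str) -> Tuple[bool, Optional[str]]:
        # Mirrors how call_torch_export_custom interprets the exporter name.
        strict = "nostrict" not in exporter
        decomposition = (
            "dec" if "-dec" in exporter else ("all" if "-all" in exporter else None)
        )
        return strict, decomposition

    assert parse_custom_exporter("custom-strict") == (True, None)
    assert parse_custom_exporter("custom-nostrict-dec") == (False, "dec")
    assert parse_custom_exporter("custom-nostrict-all") == (False, "all")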