diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index b6d36fa3..26ab8af8 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,9 @@ Change Logs 0.4.3 +++++ +* :pr:`75`: renames bypass_export_some_errors into torch_export_patches, keeps the old name +* :pr:`74`: increases the list of class/architectures + 0.4.2 +++++ diff --git a/README.rst b/README.rst index cf240828..e89ded82 100644 --- a/README.rst +++ b/README.rst @@ -30,13 +30,13 @@ it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using .. code-block:: python - with bypass_export_some_errors(patch_transformers=True) as f: + with torch_export_patches(patch_transformers=True) as f: ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes) # ... It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...). See `documentation of onnx-diagnostic `_ and -`bypass_export_some_errors `_. +`torch_export_patches `_. Getting started +++++++++++++++ diff --git a/_doc/api/torch_export_patches/index.rst b/_doc/api/torch_export_patches/index.rst index a536d9dc..b35e8c62 100644 --- a/_doc/api/torch_export_patches/index.rst +++ b/_doc/api/torch_export_patches/index.rst @@ -13,6 +13,6 @@ onnx_diagnostic.torch_export_patches :members: :no-undoc-members: -.. autofunction:: onnx_diagnostic.torch_export_patches.bypass_export_some_errors +.. autofunction:: onnx_diagnostic.torch_export_patches.torch_export_patches .. 
autofunction:: onnx_diagnostic.torch_export_patches.register_additional_serialization_functions diff --git a/_doc/examples/plot_export_hub_codellama.py b/_doc/examples/plot_export_hub_codellama.py index 533ad0dd..bf932afc 100644 --- a/_doc/examples/plot_export_hub_codellama.py +++ b/_doc/examples/plot_export_hub_codellama.py @@ -29,7 +29,7 @@ get_pretrained_config, task_from_id, ) -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str model_id = "codellama/CodeLlama-7b-Python-hf" @@ -90,9 +90,9 @@ # # The model uses :class:`transformers.cache_utils.DynamicCache`. # It still requires patches to be exportable (control flow). -# See :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors` +# See :func:`onnx_diagnostic.torch_export_patches.torch_export_patches` -with bypass_export_some_errors(patch_transformers=True) as f: +with torch_export_patches(patch_transformers=True) as f: ep = torch.export.export( model, (), diff --git a/_doc/examples/plot_export_locate_issue.py b/_doc/examples/plot_export_locate_issue.py index 3cfbdad1..3c1006aa 100644 --- a/_doc/examples/plot_export_locate_issue.py +++ b/_doc/examples/plot_export_locate_issue.py @@ -26,7 +26,7 @@ import traceback import torch from onnx_diagnostic import doc -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import torch_export_patches class ModelWithIssue(torch.nn.Module): @@ -80,12 +80,12 @@ def forward(self, x: torch.Tensor, ys: list[torch.Tensor]): # Stop when a dynamic dimension turns static # ========================================== # -# We use :func:`bypass_export_some_errors -# ` +# We use :func:`torch_export_patches +# ` # to replace torch implementation by a new one raising the exception # mentioned in previous section. 
-with bypass_export_some_errors(stop_if_static=1, verbose=1): +with torch_export_patches(stop_if_static=1, verbose=1): try: torch.export.export(model, inputs, dynamic_shapes=dyn_shapes) except (AssertionError, torch._dynamo.exc.TorchRuntimeError) as e: diff --git a/_doc/examples/plot_export_tiny_llm_patched.py b/_doc/examples/plot_export_tiny_llm_patched.py index f1576baf..60a20d15 100644 --- a/_doc/examples/plot_export_tiny_llm_patched.py +++ b/_doc/examples/plot_export_tiny_llm_patched.py @@ -69,7 +69,7 @@ from onnx_diagnostic import doc from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered from onnx_diagnostic.helpers import string_type -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_models.llms import get_tiny_llm @@ -101,10 +101,10 @@ # %% # If they are not registered, function -# func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors` +# func:`onnx_diagnostic.torch_export_patches.torch_export_patches` # should take care of it. Then we export. 
-with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator: +with torch_export_patches(patch_transformers=True, verbose=10) as modificator: assert is_cache_dynamic_registered() # it must be true here ep = torch.export.export( untrained_model, @@ -126,7 +126,7 @@ cloned_inputs = copy.deepcopy(inputs) -with bypass_export_some_errors(patch_transformers=True, verbose=10) as modificator: +with torch_export_patches(patch_transformers=True, verbose=10) as modificator: ep = torch.export.export( model, (), diff --git a/_doc/examples/plot_export_tiny_phi2.py b/_doc/examples/plot_export_tiny_phi2.py index ddcf1480..b91bdcb4 100644 --- a/_doc/examples/plot_export_tiny_phi2.py +++ b/_doc/examples/plot_export_tiny_phi2.py @@ -26,7 +26,7 @@ from onnx_diagnostic.helpers import max_diff, string_diff, string_type from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered from onnx_diagnostic.helpers.rt_helper import make_feeds -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str from onnx_diagnostic.torch_models.hghub import ( get_untrained_model_with_inputs, @@ -76,7 +76,7 @@ # ++++++ -with bypass_export_some_errors(patch_transformers=True) as modificator: +with torch_export_patches(patch_transformers=True) as modificator: # Unnecessary steps but useful in case of an error # We check the cache is registered. @@ -110,7 +110,7 @@ # applies :meth:`torch.export.ExportedProgram.run_decompositions` # may export local pieces of the model again. 
-with bypass_export_some_errors(patch_transformers=True): +with torch_export_patches(patch_transformers=True): epo = torch.onnx.export( ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True ) diff --git a/_doc/examples/plot_export_with_dynamic_cache.py b/_doc/examples/plot_export_with_dynamic_cache.py index ed7d9ae7..5dfa9c3f 100644 --- a/_doc/examples/plot_export_with_dynamic_cache.py +++ b/_doc/examples/plot_export_with_dynamic_cache.py @@ -30,7 +30,7 @@ make_dynamic_cache, ) from onnx_diagnostic.export import ModelInputs -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import torch_export_patches class Model(torch.nn.Module): @@ -99,14 +99,14 @@ def forward(self, cache, z): # And finally the export. # The export is simple if ``transformers>=4.50``, otherwise, # transformers needs to be patched. -# :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors` +# :func:`onnx_diagnostic.torch_export_patches.torch_export_patches` # registers functions to serialize ``DynamicCache``. This one is modified to make # the shape inference implemented in :epkg:`torch` happy. 
if has_transformers("4.50"): ep = torch.export.export(model, inputs[0], dynamic_shapes=ds[0], strict=False) else: - with bypass_export_some_errors(patch_transformers=True) as modificator: + with torch_export_patches(patch_transformers=True) as modificator: ep = torch.export.export( model, modificator(inputs[0]), dynamic_shapes=ds[0], strict=False ) diff --git a/_doc/index.rst b/_doc/index.rst index 917df6b1..6d4792dd 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -24,12 +24,12 @@ Sources available at `github/onnx-diagnostic =4.52") - with bypass_export_some_errors(patch_transformers=True, verbose=10): + with torch_export_patches(patch_transformers=True, verbose=10): torch.export.export( model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False ) diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py index a99db3f4..731371ef 100644 --- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py +++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py @@ -2,7 +2,7 @@ import torch from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, has_transformers, has_torch from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str @@ -20,7 +20,7 @@ def test_image_text_to_text(self): raise unittest.SkipTest("The model has control flow.") if not has_torch("2.7.99"): raise unittest.SkipTest("sym_max does not work with dynamic dimension") - with bypass_export_some_errors(patch_transformers=True, verbose=10): + with torch_export_patches(patch_transformers=True, verbose=10): torch.export.export( model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False ) diff --git 
a/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py b/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py index b441646a..bbfb34ab 100644 --- a/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py +++ b/_unittests/ut_tasks/test_tasks_zero_shot_image_classification.py @@ -2,7 +2,7 @@ import torch from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_torch from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str @@ -17,7 +17,7 @@ def test_zero_shot_image_classification(self): model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] model(**inputs) model(**data["inputs2"]) - with bypass_export_some_errors(patch_transformers=True, verbose=10): + with torch_export_patches(patch_transformers=True, verbose=10): torch.export.export( model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds), strict=False ) diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py index 32831134..1cdb19eb 100644 --- a/_unittests/ut_torch_export_patches/test_dynamic_class.py +++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py @@ -13,7 +13,7 @@ from onnx_diagnostic.helpers import string_type from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( - bypass_export_some_errors, + torch_export_patches, ) from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs @@ -56,7 +56,7 @@ def forward(self, x, cache): DYN = torch.export.Dim.DYNAMIC # patching - with bypass_export_some_errors(patch_transformers=True): + with torch_export_patches(patch_transformers=True): got = 
model(*inputs) self.assertEqualArray(expected, got) ep = torch.export.export( @@ -260,7 +260,7 @@ def forward(self, x, dc): self.assertEqualArray(expected, got) return - with bypass_export_some_errors(patch_transformers=True): + with torch_export_patches(patch_transformers=True): ep = torch.export.export(model, (), kwargs=inputs) args, _spec = torch.utils._pytree.tree_flatten(inputs) @@ -292,7 +292,7 @@ def test_phi2_export_module(self): str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True) ) - with bypass_export_some_errors(patch_transformers=True): + with torch_export_patches(patch_transformers=True): ep = torch.export.export( model, (), @@ -330,7 +330,7 @@ def test_phi2_export_interpreter(self): str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True) ) - with bypass_export_some_errors(patch_transformers=True): + with torch_export_patches(patch_transformers=True): ep = torch.export.export( model, (), diff --git a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py index 699ccdd7..b9adbc7d 100644 --- a/_unittests/ut_torch_export_patches/test_onnx_export_errors.py +++ b/_unittests/ut_torch_export_patches/test_onnx_export_errors.py @@ -10,7 +10,7 @@ ) from onnx_diagnostic.helpers import string_type from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( - bypass_export_some_errors, + torch_export_patches, ) @@ -34,7 +34,7 @@ def __init__(self): cache = MambaCache(_config(), max_batch_size=1, device="cpu") - with bypass_export_some_errors(verbose=1): + with torch_export_patches(verbose=1): values, spec = py_pytree.tree_flatten(cache) cache2 = py_pytree.tree_unflatten(values, spec) self.assertEqual(cache.max_batch_size, cache2.max_batch_size) @@ -78,7 +78,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache): model = Model() model(x, cache) - with bypass_export_some_errors(verbose=1): + with torch_export_patches(verbose=1): cache = 
MambaCache(_config(), max_batch_size=1, device="cpu") torch.export.export(Model(), (x, cache)) @@ -113,7 +113,7 @@ def forward(self, x: torch.Tensor, cache: MambaCache): model(x, cache) DYN = torch.export.Dim.DYNAMIC - with bypass_export_some_errors(): + with torch_export_patches(): cache = MambaCache(_config(), max_batch_size=2, device="cpu") torch.export.export( Model(), diff --git a/_unittests/ut_torch_export_patches/test_patch_base_class.py b/_unittests/ut_torch_export_patches/test_patch_base_class.py index 7dc4c708..6e10e91f 100644 --- a/_unittests/ut_torch_export_patches/test_patch_base_class.py +++ b/_unittests/ut_torch_export_patches/test_patch_base_class.py @@ -1,7 +1,10 @@ import unittest import torch from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import ( + torch_export_patches, + bypass_export_some_errors, +) class TestPatchBaseClass(ExtTestCase): @@ -73,6 +76,28 @@ def m1(self, x): model = Model() x = torch.arange(4) self.assertEqualArray(x * x, model(x)) + with torch_export_patches(custom_patches=[patched_Model], verbose=10): + self.assertEqualArray(x**3, model(x)) + + @hide_stdout() + def test_bypass_export_some_errors(self): + class Model2(torch.nn.Module): + def m2(self, x): + return x * x + + def forward(self, x): + return self.m2(x) + + class patched_Model: + _PATCHED_CLASS_ = Model2 + _PATCHES_ = ["m2"] + + def m2(self, x): + return x**3 + + model = Model2() + x = torch.arange(4) + self.assertEqualArray(x * x, model(x)) with bypass_export_some_errors(custom_patches=[patched_Model], verbose=10): self.assertEqualArray(x**3, model(x)) diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization.py b/_unittests/ut_torch_export_patches/test_patch_serialization.py index 619bd76d..1e0bcb12 100644 --- a/_unittests/ut_torch_export_patches/test_patch_serialization.py +++ 
b/_unittests/ut_torch_export_patches/test_patch_serialization.py @@ -9,7 +9,7 @@ flatten_unflatten_for_dynamic_shapes, ) from onnx_diagnostic.torch_export_patches.onnx_export_errors import ( - bypass_export_some_errors, + torch_export_patches, ) from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy @@ -21,7 +21,7 @@ def test_encoder_decoder_cache_flatten(self): make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), ) - with bypass_export_some_errors(): + with torch_export_patches(): flat, _spec = torch.utils._pytree.tree_flatten(cache) self.assertEqual( "#4[T1s4x4x4,T1s4x4x4,T1s5x5x5,T1s5x5x5]", @@ -39,7 +39,7 @@ def test_encoder_decoder_cache_deepcopy(self): make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]), make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]), ) - with bypass_export_some_errors(): + with torch_export_patches(): cache2 = torch_deepcopy([cache]) self.assertEqualAny([cache], cache2) @@ -65,13 +65,13 @@ def forward(self, cache): [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]], ] - with bypass_export_some_errors(patch_transformers=True): + with torch_export_patches(patch_transformers=True): torch.export.export(model, (cache,), dynamic_shapes=(ds,)) @ignore_warnings(UserWarning) def test_dynamic_cache_flatten(self): cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) - with bypass_export_some_errors(): + with torch_export_patches(): flat, _spec = torch.utils._pytree.tree_flatten(cache) self.assertEqual( "#2[T1s4x4x4,T1s4x4x4]", @@ -97,13 +97,13 @@ def forward(self, cache): DYN = torch.export.Dim.DYNAMIC ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]] - with bypass_export_some_errors(): + with torch_export_patches(): torch.export.export(model, (cache,), dynamic_shapes=(ds,)) @ignore_warnings(UserWarning) def test_dynamic_cache_deepcopy(self): cache = 
make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) - with bypass_export_some_errors(): + with torch_export_patches(): cache2 = torch_deepcopy([cache]) self.assertEqualAny([cache], cache2) @@ -111,7 +111,7 @@ def test_dynamic_cache_deepcopy(self): def test_base_model_output_deepcopy(self): bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) self.assertEqual(bo.__class__.__name__, "BaseModelOutput") - with bypass_export_some_errors(): + with torch_export_patches(): bo2 = torch_deepcopy([bo]) self.assertIsInstance(bo2, list) self.assertEqual(bo2[0].__class__.__name__, "BaseModelOutput") @@ -120,7 +120,7 @@ def test_base_model_output_deepcopy(self): @ignore_warnings(UserWarning) def test_base_model_output_string_type(self): bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) - with bypass_export_some_errors(): + with torch_export_patches(): self.assertEqual( "BaseModelOutput(last_hidden_state:T1s4x4x4)", self.string_type(bo, with_shape=True), @@ -129,7 +129,7 @@ def test_base_model_output_string_type(self): @ignore_warnings(UserWarning) def test_base_model_output_flatten(self): bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) - with bypass_export_some_errors(): + with torch_export_patches(): flat, _spec = torch.utils._pytree.tree_flatten(bo) self.assertEqual( "#1[T1s4x4x4]", @@ -153,13 +153,13 @@ def forward(self, cache): DYN = torch.export.Dim.DYNAMIC ds = [{0: DYN}] - with bypass_export_some_errors(): + with torch_export_patches(): torch.export.export(model, (bo,), dynamic_shapes=(ds,)) @ignore_warnings(UserWarning) def test_base_model_output_unflatten_flatten(self): bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4))) - with bypass_export_some_errors(patch_transformers=True): + with torch_export_patches(patch_transformers=True): flat, _spec = torch.utils._pytree.tree_flatten(bo) unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True) self.assertIsInstance(unflat, dict) @@ -170,7 +170,7 @@ def 
test_base_sliding_window_cache_unflatten_flatten(self): cache = make_sliding_window_cache( [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))] ) - with bypass_export_some_errors(): + with torch_export_patches(): cache2 = torch_deepcopy([cache]) self.assertEqualAny([cache], cache2) @@ -192,7 +192,7 @@ def forward(self, cache): DYN = torch.export.Dim.DYNAMIC ds = [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]] - with bypass_export_some_errors(patch_transformers=True): + with torch_export_patches(patch_transformers=True): torch.export.export(model, (cache,), dynamic_shapes=(ds,)) @ignore_warnings(UserWarning) @@ -200,7 +200,7 @@ def test_sliding_window_cache_flatten(self): cache = make_sliding_window_cache( [(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))] ) - with bypass_export_some_errors(): + with torch_export_patches(): flat, _spec = torch.utils._pytree.tree_flatten(cache) self.assertEqual( "#2[T1s4x4x4x4,T1s4x4x4x4]", diff --git a/_unittests/ut_torch_models/test_hghub_model.py b/_unittests/ut_torch_models/test_hghub_model.py index 6083812f..7b2b49e1 100644 --- a/_unittests/ut_torch_models/test_hghub_model.py +++ b/_unittests/ut_torch_models/test_hghub_model.py @@ -10,7 +10,7 @@ from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config from onnx_diagnostic.torch_models.hghub.hub_data import load_models_testing -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import torch_export_patches class TestHuggingFaceHubModel(ExtTestCase): @@ -91,7 +91,7 @@ def test_get_untrained_model_with_inputs_clip_vit(self): mid = "openai/clip-vit-base-patch16" data = get_untrained_model_with_inputs(mid, verbose=1) model, inputs = data["model"], data["inputs"] - with bypass_export_some_errors(patch_transformers=True): + with torch_export_patches(patch_transformers=True): model(**inputs) # different 
expected value for different version of transformers self.assertIn((data["size"], data["n_weights"]), [(188872708, 47218177)]) diff --git a/_unittests/ut_torch_models/test_tiny_llms_bypassed.py b/_unittests/ut_torch_models/test_tiny_llms_bypassed.py index 4ec3e943..35b3c224 100644 --- a/_unittests/ut_torch_models/test_tiny_llms_bypassed.py +++ b/_unittests/ut_torch_models/test_tiny_llms_bypassed.py @@ -6,7 +6,7 @@ from onnx_diagnostic.torch_models.llms import get_tiny_llm from onnx_diagnostic.torch_models.llms import get_phi2 from onnx_diagnostic.helpers import string_type -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import torch_export_patches from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( patched_DynamicCache, ) @@ -22,7 +22,7 @@ def test_export_tiny_llm_2_bypassed(self): {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs) ) - with bypass_export_some_errors( + with torch_export_patches( patch_torch=False, patch_transformers=True, catch_constraints=False, verbose=10 ) as modificator: @@ -61,11 +61,11 @@ def test_export_phi2_2_bypassed(self): self.assertEqual( {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs) ) - with bypass_export_some_errors(patch_transformers=True) as modificator: + with torch_export_patches(patch_transformers=True) as modificator: inputs = modificator(inputs) ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False) assert ep - with bypass_export_some_errors(patch_transformers=True) as modificator: + with torch_export_patches(patch_transformers=True) as modificator: inputs = modificator(inputs) ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False) assert ep diff --git a/_unittests/ut_torch_models/test_tiny_llms_onnx.py b/_unittests/ut_torch_models/test_tiny_llms_onnx.py index eda2d3de..48c13f06 100644 --- 
a/_unittests/ut_torch_models/test_tiny_llms_onnx.py +++ b/_unittests/ut_torch_models/test_tiny_llms_onnx.py @@ -10,7 +10,7 @@ requires_transformers, ) from onnx_diagnostic.torch_models.llms import get_tiny_llm -from onnx_diagnostic.torch_export_patches import bypass_export_some_errors +from onnx_diagnostic.torch_export_patches import torch_export_patches try: from experimental_experiment.torch_interpreter import to_onnx, ExportOptions @@ -58,7 +58,7 @@ def test_onnx_export_tiny_llm_xdbg(self): self.assertEqual( {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs) ) - with bypass_export_some_errors(patch_transformers=True): + with torch_export_patches(patch_transformers=True): onx = to_onnx( model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], verbose=1 ) @@ -74,7 +74,7 @@ def test_bypass_onnx_export_tiny_llm_official_nopositionids(self): del inputs["position_ids"] del ds["position_ids"] self.assertEqual({"attention_mask", "past_key_values", "input_ids"}, set(inputs)) - with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator: + with torch_export_patches(patch_transformers=True, verbose=1) as modificator: new_inputs = modificator(copy.deepcopy(inputs)) ep = torch.onnx.export( model, @@ -101,7 +101,7 @@ def test_bypass_onnx_export_tiny_llm_official_full(self): self.assertEqual( {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs) ) - with bypass_export_some_errors(patch_transformers=True, verbose=1) as modificator: + with torch_export_patches(patch_transformers=True, verbose=1) as modificator: new_inputs = modificator(copy.deepcopy(inputs)) ep = torch.onnx.export( model, @@ -134,7 +134,7 @@ def test_bypass_onnx_export_tiny_llm_xdbg(self): self.assertEqual( {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs) ) - with bypass_export_some_errors(patch_transformers=True, verbose=2) as modificator: + with torch_export_patches(patch_transformers=True, 
verbose=2) as modificator: new_inputs = modificator(inputs) onx = to_onnx( model, diff --git a/onnx_diagnostic/torch_export_patches/__init__.py b/onnx_diagnostic/torch_export_patches/__init__.py index ff978ee3..0b241ba0 100644 --- a/onnx_diagnostic/torch_export_patches/__init__.py +++ b/onnx_diagnostic/torch_export_patches/__init__.py @@ -1,4 +1,8 @@ from .onnx_export_errors import ( - bypass_export_some_errors, + torch_export_patches, register_additional_serialization_functions, ) + + +# bypass_export_some_errors is the first name given to the patches. +bypass_export_some_errors = torch_export_patches # type: ignore diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 54b3e875..27cbfe14 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -93,7 +93,7 @@ def register_additional_serialization_functions( @contextlib.contextmanager -def bypass_export_some_errors( +def torch_export_patches( patch_sympy: bool = True, patch_torch: bool = True, patch_transformers: bool = False, @@ -145,13 +145,13 @@ def bypass_export_some_errors( :: - with bypass_export_some_errors(patch_transformers=True) as modificator: + with torch_export_patches(patch_transformers=True) as modificator: inputs = modificator(inputs) onx = to_onnx(..., inputs, ...) :: - with bypass_export_some_errors(patch_transformers=True) as modificator: + with torch_export_patches(patch_transformers=True) as modificator: inputs = modificator(inputs) onx = torch.onnx.export(..., inputs, ...) @@ -159,7 +159,7 @@ def bypass_export_some_errors( :: - with bypass_export_some_errors(patch_transformers=True) as modificator: + with torch_export_patches(patch_transformers=True) as modificator: inputs = modificator(inputs) ep = torch.export.export(..., inputs, ...) 
@@ -190,7 +190,7 @@ def bypass_export_some_errors( if verbose: print( - "[bypass_export_some_errors] replace torch.jit.isinstance, " + "[torch_export_patches] replace torch.jit.isinstance, " "torch._dynamo.mark_static_address" ) @@ -210,8 +210,8 @@ def bypass_export_some_errors( f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None) if verbose: - print(f"[bypass_export_some_errors] sympy.__version__={sympy.__version__!r}") - print("[bypass_export_some_errors] patch sympy") + print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}") + print("[torch_export_patches] patch sympy") sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}" @@ -228,9 +228,9 @@ def bypass_export_some_errors( ) if verbose: - print(f"[bypass_export_some_errors] torch.__version__={torch.__version__!r}") - print(f"[bypass_export_some_errors] stop_if_static={stop_if_static!r}") - print("[bypass_export_some_errors] patch pytorch") + print(f"[torch_export_patches] torch.__version__={torch.__version__!r}") + print(f"[torch_export_patches] stop_if_static={stop_if_static!r}") + print("[torch_export_patches] patch pytorch") # torch.jit.isinstance f_jit_isinstance = torch.jit.isinstance @@ -252,7 +252,7 @@ def bypass_export_some_errors( # torch._export.non_strict_utils.produce_guards_and_solve_constraints if catch_constraints: if verbose: - print("[bypass_export_some_errors] modifies shape constraints") + print("[torch_export_patches] modifies shape constraints") f_produce_guards_and_solve_constraints = ( torch._export.non_strict_utils.produce_guards_and_solve_constraints ) @@ -277,22 +277,20 @@ def bypass_export_some_errors( ShapeEnv._log_guard_remember = ShapeEnv._log_guard if verbose: - print( - "[bypass_export_some_errors] assert when a dynamic dimension turns static" - ) - print("[bypass_export_some_errors] replaces ShapeEnv._set_replacement") + print("[torch_export_patches] assert when a dynamic dimension turns static") + 
print("[torch_export_patches] replaces ShapeEnv._set_replacement") f_shape_env__set_replacement = ShapeEnv._set_replacement ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement if verbose: - print("[bypass_export_some_errors] replaces ShapeEnv._log_guard") + print("[torch_export_patches] replaces ShapeEnv._log_guard") f_shape_env__log_guard = ShapeEnv._log_guard ShapeEnv._log_guard = patched_ShapeEnv._log_guard if stop_if_static > 1: if verbose: - print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen") + print("[torch_export_patches] replaces ShapeEnv._check_frozen") f_shape_env__check_frozen = ShapeEnv._check_frozen ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen @@ -305,7 +303,7 @@ def bypass_export_some_errors( import transformers print( - f"[bypass_export_some_errors] transformers.__version__=" + f"[torch_export_patches] transformers.__version__=" f"{transformers.__version__!r}" ) revert_patches_info = patch_module_or_classes( @@ -314,7 +312,7 @@ def bypass_export_some_errors( if custom_patches: if verbose: - print("[bypass_export_some_errors] applies custom patches") + print("[torch_export_patches] applies custom patches") revert_custom_patches_info = patch_module_or_classes( custom_patches, verbose=verbose ) @@ -326,7 +324,7 @@ def bypass_export_some_errors( fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x) if verbose: - print("[bypass_export_some_errors] done patching") + print("[torch_export_patches] done patching") try: yield fct_callable @@ -336,7 +334,7 @@ def bypass_export_some_errors( ####### if verbose: - print("[bypass_export_some_errors] remove patches") + print("[torch_export_patches] remove patches") if patch_sympy: # tracked by https://github.com/pytorch/pytorch/issues/143494 @@ -346,7 +344,7 @@ def bypass_export_some_errors( delattr(sympy.core.numbers.IntegerConstant, "name") if verbose: - print("[bypass_export_some_errors] restored sympy functions") + print("[torch_export_patches] 
restored sympy functions") ####### # torch @@ -362,22 +360,22 @@ def bypass_export_some_errors( torch._meta_registrations._broadcast_shapes = f__broadcast_shapes if verbose: - print("[bypass_export_some_errors] restored pytorch functions") + print("[torch_export_patches] restored pytorch functions") if stop_if_static: if verbose: - print("[bypass_export_some_errors] restored ShapeEnv._set_replacement") + print("[torch_export_patches] restored ShapeEnv._set_replacement") ShapeEnv._set_replacement = f_shape_env__set_replacement if verbose: - print("[bypass_export_some_errors] restored ShapeEnv._log_guard") + print("[torch_export_patches] restored ShapeEnv._log_guard") ShapeEnv._log_guard = f_shape_env__log_guard if stop_if_static > 1: if verbose: - print("[bypass_export_some_errors] restored ShapeEnv._check_frozen") + print("[torch_export_patches] restored ShapeEnv._check_frozen") ShapeEnv._check_frozen = f_shape_env__check_frozen if catch_constraints: @@ -389,11 +387,11 @@ def bypass_export_some_errors( f__check_input_constraints_for_graph ) if verbose: - print("[bypass_export_some_errors] restored shape constraints") + print("[torch_export_patches] restored shape constraints") if custom_patches: if verbose: - print("[bypass_export_some_errors] unpatch custom patches") + print("[torch_export_patches] unpatch custom patches") unpatch_module_or_classes( custom_patches, revert_custom_patches_info, verbose=verbose ) @@ -404,7 +402,7 @@ def bypass_export_some_errors( if patch_transformers: if verbose: - print("[bypass_export_some_errors] unpatch transformers") + print("[torch_export_patches] unpatch transformers") unpatch_module_or_classes( patch_transformers_list, revert_patches_info, verbose=verbose ) diff --git a/onnx_diagnostic/torch_models/test_helper.py b/onnx_diagnostic/torch_models/test_helper.py index a39266c2..aa93ece7 100644 --- a/onnx_diagnostic/torch_models/test_helper.py +++ b/onnx_diagnostic/torch_models/test_helper.py @@ -12,7 +12,7 @@ from 
..helpers.torch_test_helper import to_any, torch_deepcopy from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes from ..tasks import random_input_kwargs -from ..torch_export_patches import bypass_export_some_errors +from ..torch_export_patches import torch_export_patches from ..torch_export_patches.patch_inputs import use_dyn_not_str from .hghub import get_untrained_model_with_inputs @@ -242,9 +242,9 @@ def validate_model( depend on the the exporter :param quiet: if quiet, catches exception if any issue :param patch: applies patches (``patch_transformers=True``) before exporting, - see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors` + see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches` :param stop_if_static: stops if a dynamic dimension becomes static, - see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors` + see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches` :param dump_folder: dumps everything in a subfolder of this one :param drop_inputs: drops this list of inputs (given their names) :param ortfusiontype: runs ort fusion, the parameters defines the fusion type, @@ -417,7 +417,7 @@ def validate_model( f"[validate_model] applies patches before exporting " f"stop_if_static={stop_if_static}" ) - with bypass_export_some_errors( # type: ignore + with torch_export_patches( # type: ignore patch_transformers=True, stop_if_static=stop_if_static, verbose=max(0, verbose - 1),