diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 967c5b0c..1b2f8f44 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,8 @@ Change Logs
 0.7.10
 ++++++
 
+* :pr:`218`: patches sdpa_mask_recent_torch, which relies on _vmap_for_bhqkv
+
 0.7.9
 +++++
 
diff --git a/README.rst b/README.rst
index c876fbb2..363c7eea 100644
--- a/README.rst
+++ b/README.rst
@@ -58,7 +58,7 @@ Getting started
     git clone https://github.com/sdpython/onnx-diagnostic.git
     cd onnx-diagnostic
-    pip install -e .
+    pip install -e . -v
 
 or
 
diff --git a/_doc/examples/plot_dump_intermediate_results.py b/_doc/examples/plot_dump_intermediate_results.py
index 9fe760a3..bc3090ef 100644
--- a/_doc/examples/plot_dump_intermediate_results.py
+++ b/_doc/examples/plot_dump_intermediate_results.py
@@ -129,7 +129,7 @@
 # Let's create the ONNX model.
 
 ep = torch.export.export(model, inputs, dynamic_shapes=ds)
-epo = torch.onnx.export(ep, dynamo=True)
+epo = torch.onnx.export(ep)
 epo.optimize()
 epo.save("plot_dump_intermediate_results.onnx")
 
diff --git a/_doc/examples/plot_export_tiny_phi2.py b/_doc/examples/plot_export_tiny_phi2.py
index cc2d1df5..b4979334 100644
--- a/_doc/examples/plot_export_tiny_phi2.py
+++ b/_doc/examples/plot_export_tiny_phi2.py
@@ -126,7 +126,7 @@
 
 with torch_export_patches(patch_transformers=True):
     epo = torch.onnx.export(
-        ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes, dynamo=True
+        ep, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=dynamic_shapes
    )

 # %%
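# A minimal sketch of the export call the examples above now rely on, assuming
# a torch recent enough that the dynamo-based exporter is the default (which is
# why dynamo=True can be dropped): an ExportedProgram is passed directly.
import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2


ep = torch.export.export(TinyModel(), (torch.rand(2, 3),))
epo = torch.onnx.export(ep)  # returns an ONNXProgram, as in the examples above
epo.optimize()
epo.save("tiny_model.onnx")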
diff --git a/_unittests/ut_export/test_jit.py b/_unittests/ut_export/test_jit.py
index e4ec87f2..90302ddb 100644
--- a/_unittests/ut_export/test_jit.py
+++ b/_unittests/ut_export/test_jit.py
@@ -7,7 +7,6 @@
     ignore_warnings,
     requires_onnxscript,
 )
-from onnx_diagnostic.reference import ExtendedReferenceEvaluator
 from onnx_diagnostic.helpers.torch_helper import is_torchdynamo_exporting
 
 try:
@@ -62,7 +61,7 @@ def test_dummy_loop(self):
 
     @hide_stdout()
     @ignore_warnings(UserWarning)
-    @requires_onnxscript("0.5")
+    @requires_onnxscript("0.7")
     def test_export_loop_onnxscript(self):
         class Model(torch.nn.Module):
             def forward(self, images, position):
@@ -75,19 +74,6 @@ def forward(self, images, position):
         y = torch.arange(5, dtype=torch.int64) + 1
         expected = model(x, y)
 
-        name = self.get_dump_file("test_export_loop_onnxscript.onnx")
-        torch.onnx.export(
-            model,
-            (x, y),
-            name,
-            dynamic_axes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
-            dynamo=False,
-        )
-        ref = ExtendedReferenceEvaluator(name)
-        feeds = dict(images=x.numpy(), position=y.numpy())
-        got = ref.run(None, feeds)[0]
-        self.assertEqualArray(expected, got)
-
         DYN = torch.export.Dim.DYNAMIC
         ep = torch.export.export(
             model,
@@ -103,7 +89,6 @@ def forward(self, images, position):
             (x, y),
             name2,
             dynamic_shapes={"images": {0: "batch", 1: "maxdim"}, "position": {0: "batch"}},
-            dynamo=True,
             fallback=False,
         )
         import onnxruntime
diff --git a/_unittests/ut_helpers/test_cache_helper.py b/_unittests/ut_helpers/test_cache_helper.py
index 82439106..23563875 100644
--- a/_unittests/ut_helpers/test_cache_helper.py
+++ b/_unittests/ut_helpers/test_cache_helper.py
@@ -1,4 +1,5 @@
 import unittest
+from typing import Callable
 import torch
 import transformers
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
@@ -19,6 +20,13 @@
 )
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 
+try:
+    from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+        patched__vmap_for_bhqkv,
+    )
+except ImportError:
+    patched__vmap_for_bhqkv = None
+
 
 class TestCacheHelpers(ExtTestCase):
     def test_string_type(self):
@@ -69,7 +77,7 @@ def test_replace_by(self):
         )
 
         DYN = torch.export.Dim.DYNAMIC
-        nargs, nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
+        _nargs, _nkwargs, nds = convert_dynamic_axes_into_dynamic_shapes(
             None, args=tuple(), kwargs=kwargs, dynamic_axes=dynamic_shapes
         )
         self.assertEqual(dynamic_shapes, nds)
@@ -254,6 +262,92 @@ def test_unflatten_flatten_hybrid_cache(self):
             self.string_type(unflat, with_shape=True),
         )
 
+    @unittest.skipIf(patched__vmap_for_bhqkv is None, "transformers too old")
+    def test_cache_update_padding_mask_function_vmap(self):
+        def causal_mask_function(
+            batch_idx: int, head_idx: int, q_idx: int, kv_idx: int
+        ) -> bool:
+            return kv_idx <= q_idx
+
+        def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
+            def inner_mask(batch_idx: int, head_idx: int, q_idx: int, kv_idx: int) -> bool:
+                return padding_mask[batch_idx, kv_idx]
+
+            return inner_mask
+
+        def and_masks(*mask_functions: list[Callable]) -> Callable:
+            if not all(callable(arg) for arg in mask_functions):
+                raise RuntimeError(
+                    f"All inputs should be callable mask_functions: {mask_functions}"
+                )
+
+            def and_mask(batch_idx, head_idx, q_idx, kv_idx):
+                result = q_idx.new_ones((), dtype=torch.bool)
+                for mask in mask_functions:
+                    result = result & mask(batch_idx, head_idx, q_idx, kv_idx).to(
+                        result.device
+                    )
+                return result
+
+            return and_mask
+
+        def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
+            dimensions = [(None, None, None, 0), (None, None, 0, None)]
+            if bh_indices:
+                dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+            for dims in dimensions:
+                mask_function = torch.vmap(mask_function, in_dims=dims, out_dims=0)
+            return mask_function
+
+        class Model(torch.nn.Module):
+            def forward(self, x, mask):
+                mask_function = and_masks(causal_mask_function, padding_mask_function(mask))
+                batch_arange = torch.arange(x.shape[0])
+                head_arange = torch.arange(x.shape[3])
+                kv_arange = torch.arange(x.shape[1])
+                cache_position = torch.arange(x.shape[2])
+                f = patched__vmap_for_bhqkv(mask_function)
+                causal_mask = f(batch_arange, head_arange, cache_position, kv_arange)
+                return x + causal_mask.to(x.dtype)
+
+        inputs = {
+            "x": torch.rand((4, 4, 4, 4), dtype=torch.float32),
+            "mask": torch.ones((4, 4), dtype=torch.int64),
+        }
+        model = Model()
+        expected = model(**inputs)
+        self.assertNotEmpty(expected)
+        DYN = torch.export.Dim.DYNAMIC
+        ep = torch.export.export(
+            model,
+            (),
+            kwargs=inputs,
+            dynamic_shapes={"x": {0: DYN, 1: DYN, 2: DYN, 3: DYN}, "mask": {0: DYN, 1: DYN}},
+        )
+        self.assertNotEmpty(ep)
+
+    def test_simple_indices(self):
+        class Model(torch.nn.Module):
+            def forward(self, x, i, j):
+                return x[i, j]
+
+        inputs = (
+            torch.rand((4, 4), dtype=torch.float32),
+            torch.randint(0, 4, (4, 4, 4, 4), dtype=torch.int64),
+            torch.randint(0, 4, (4, 4, 4, 4), dtype=torch.int64),
+        )
+        model = Model()
+        expected = model(*inputs)
+        self.assertEqual(expected.shape, (4, 4, 4, 4))
+        DYN = torch.export.Dim.DYNAMIC
+        sh = {0: DYN, 1: DYN, 2: DYN, 3: DYN}
+        ep = torch.export.export(
+            model,
+            inputs,
+            dynamic_shapes=({0: DYN, 1: DYN}, sh, sh),
+        )
+        self.assertNotEmpty(ep)
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
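# A standalone sketch of the vmap construction exercised by
# test_cache_update_padding_mask_function_vmap above: stacking torch.vmap once
# per axis turns a scalar predicate (b, h, q, kv) -> bool into a function that
# produces a full (batch, heads, q_len, kv_len) boolean mask.
import torch


def causal(b, h, q, kv):
    return kv <= q


fn = causal
for dims in [(None, None, None, 0), (None, None, 0, None),
             (None, 0, None, None), (0, None, None, None)]:
    fn = torch.vmap(fn, in_dims=dims, out_dims=0)

mask = fn(torch.arange(2), torch.arange(4), torch.arange(3), torch.arange(3))
print(mask.shape)  # torch.Size([2, 4, 3, 3]), lower-triangular in the last two axes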
diff --git a/_unittests/ut_helpers/test_ort_session_tinyllm.py b/_unittests/ut_helpers/test_ort_session_tinyllm.py
index 82a52ce5..cc21c65a 100644
--- a/_unittests/ut_helpers/test_ort_session_tinyllm.py
+++ b/_unittests/ut_helpers/test_ort_session_tinyllm.py
@@ -87,7 +87,7 @@ def test_check_allruntimes_on_tiny_llm(self):
             proto = to_onnx(model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds)
         else:
             proto = torch.onnx.export(
-                model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds, dynamo=True
+                model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=ds
             ).model_proto
 
         self.dump_onnx("test_check_allruntimes_on_tiny_llm.onnx", proto)
diff --git a/_unittests/ut_torch_export_patches/test_dynamic_class.py b/_unittests/ut_torch_export_patches/test_dynamic_class.py
index 7e2a3319..da4cbd91 100644
--- a/_unittests/ut_torch_export_patches/test_dynamic_class.py
+++ b/_unittests/ut_torch_export_patches/test_dynamic_class.py
@@ -3,6 +3,11 @@
 import unittest
 from typing import Any, Dict, List, Tuple
 import torch
+
+try:
+    import transformers.masking_utils as masking_utils
+except ImportError:
+    masking_utils = None
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
     ignore_warnings,
@@ -14,7 +19,9 @@
 from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
     torch_export_patches,
 )
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
 
 
 class TestOnnxExportErrors(ExtTestCase):
@@ -305,7 +312,7 @@ def test_phi2_export_module(self):
                 model,
                 (),
                 kwargs=inputs,
-                dynamic_shapes=dyn_shapes,
+                dynamic_shapes=use_dyn_not_str(dyn_shapes),
                 strict=False,  # True works but then it fails during execution
             )
             # ep = ep.run_decompositions()
@@ -319,6 +326,7 @@ def test_phi2_export_module(self):
 
     @ignore_warnings(UserWarning)
     @requires_torch("2.9")
+    @hide_stdout()
     def test_phi2_export_interpreter(self):
         data = get_untrained_model_with_inputs("microsoft/phi-2")
         model, inputs, dyn_shapes = data["model"], data["inputs"], data["dynamic_shapes"]
@@ -338,12 +346,17 @@ def test_phi2_export_interpreter(self):
             str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
         )
 
-        with torch_export_patches(patch_transformers=True):
+        with torch_export_patches(patch_transformers=True, verbose=1):
+            if masking_utils is not None:
+                self.assertEqual(
+                    masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
+                    patch_transformers.patched_sdpa_mask_recent_torch,
+                )
             ep = torch.export.export(
                 model,
                 (),
                 kwargs=inputs,
-                dynamic_shapes=dyn_shapes,
+                dynamic_shapes=use_dyn_not_str(dyn_shapes),
                 strict=False,  # True works but then it fails during execution
             )
             # ep = ep.run_decompositions()
diff --git a/_unittests/ut_torch_export_patches/test_patch_module.py b/_unittests/ut_torch_export_patches/test_patch_module.py
index 670d3ff0..99a165a1 100644
--- a/_unittests/ut_torch_export_patches/test_patch_module.py
+++ b/_unittests/ut_torch_export_patches/test_patch_module.py
@@ -604,8 +604,8 @@ def loop_body_1(z, iv, x, y):
         rewritten_expected2 = RewrittenModel2()(x, y)
         self.assertEqualArray(expected, rewritten_expected2)
 
-        if not has_torch("2.9"):
-            raise unittest.SkipTest("skipped export, torch must be >= 2.9")
+        if not has_torch("2.10"):
+            raise unittest.SkipTest("skipped export, torch must be >= 2.10")
 
         torch.export.export(RewrittenModel2(), (x, y), dynamic_shapes=ds, strict=False)
         ep = torch.export.export(Model(), (x, y), dynamic_shapes=ds, strict=False)
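# What use_dyn_not_str (imported in test_dynamic_class.py above) is for, as a
# minimal sketch: dynamic shapes written with string axis names are rewritten so
# that every string becomes torch.export.Dim.DYNAMIC before export. The exact
# return value shown below is an assumption based on how the tests use it.
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

ds = {"input_ids": {0: "batch", 1: "seq_length"}}
print(use_dyn_not_str(ds))
# expected: {"input_ids": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}}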
diff --git a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py
index 8c54f679..a54abbe9 100644
--- a/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py
+++ b/_unittests/ut_torch_export_patches/test_patch_serialization_transformers.py
@@ -164,7 +164,7 @@ def forward(self, cache):
     def test_base_model_output_unflatten_flatten(self):
         bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
         with torch_export_patches(patch_transformers=True):
-            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            _flat, _spec = torch.utils._pytree.tree_flatten(bo)
             unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
             self.assertIsInstance(unflat, list)
             self.assertEqual("#1[T1r3]", self.string_type(unflat))
diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py
index 7ed791d1..25eb5a73 100644
--- a/_unittests/ut_torch_export_patches/test_patch_torch.py
+++ b/_unittests/ut_torch_export_patches/test_patch_torch.py
@@ -17,7 +17,7 @@ def test_vmap(self):
         got = patched_vmap(f)(x, y)
         self.assertEqualArray(expected, got)
 
-    @requires_torch("2.9")
+    @requires_torch("2.10")
     def test_export_vmap(self):
         class Model(torch.nn.Module):
             def forward(self, x, y):
@@ -206,10 +206,11 @@ def _vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callabl
 
         class Model(torch.nn.Module):
             def forward(self, batch_arange, head_arange, cache_position, kv_arange):
-                with TransformGetItemToIndex():
-                    causal_mask2 = _vmap_for_bhqkv2(mask_function)(
-                        batch_arange, head_arange, cache_position, kv_arange
-                    )
+                # with TransformGetItemToIndex():
+                # This context was ignored in 2.8 but is not anymore in 2.9.
+                causal_mask2 = _vmap_for_bhqkv2(mask_function)(
+                    batch_arange, head_arange, cache_position, kv_arange
+                )
                 return causal_mask2
 
         inputs = batch_arange, head_arange, cache_position, kv_arange
diff --git a/_unittests/ut_torch_models/test_llm_phi2.py b/_unittests/ut_torch_models/test_llm_phi2.py
index c05007a6..77002654 100644
--- a/_unittests/ut_torch_models/test_llm_phi2.py
+++ b/_unittests/ut_torch_models/test_llm_phi2.py
@@ -8,11 +8,13 @@
 )
 from onnx_diagnostic.torch_models.llms import get_phi2
 from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 
 
 class TestLlmPhi(ExtTestCase):
     def test_get_phi2(self):
-        data = get_phi2(num_hidden_layers=2)
+        data = get_phi2(num_hidden_layers=2, batch_size=2)
         model, inputs = data["model"], data["inputs"]
         self.assertIn("DynamicCache", string_type(inputs))
         model(**inputs)
@@ -20,14 +22,37 @@ def test_get_phi2(self):
     @ignore_warnings(UserWarning)
     @requires_transformers("4.54")
     @requires_torch("2.9.99")
-    def test_export_phi2_1(self):
+    def test_export_phi2_1_batch_size_1(self):
         # exporting vmap does not work
-        data = get_phi2(num_hidden_layers=2)
+        data = get_phi2(num_hidden_layers=2, batch_size=1)
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        self.assertEqual(inputs["input_ids"].shape[0], 1)
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
-        ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds)
+        with torch.fx.experimental._config.patch(
+            backed_size_oblivious=True
+        ), torch_export_patches(patch_transformers=True):
+            ep = torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
+            )
+        assert ep
+
+    @ignore_warnings(UserWarning)
+    @requires_transformers("4.54")
+    @requires_torch("2.9.99")
+    def test_export_phi2_1_batch_size_2(self):
+        # exporting vmap does not work
+        data = get_phi2(num_hidden_layers=2, batch_size=2)
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        self.assertEqual(inputs["input_ids"].shape[0], 2)
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
+        )
+        with torch_export_patches(patch_transformers=True):
+            ep = torch.export.export(
+                model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
+            )
         assert ep
diff --git a/_unittests/ut_torch_models/test_tiny_llms.py b/_unittests/ut_torch_models/test_tiny_llms.py
index cadbdb5b..ac37a7b7 100644
--- a/_unittests/ut_torch_models/test_tiny_llms.py
+++ b/_unittests/ut_torch_models/test_tiny_llms.py
@@ -29,7 +29,7 @@ def test_tiny_llm_export_dynamic(self):
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
-        with torch_export_patches(patch_transformers=True):
+        with torch_export_patches(patch_transformers=True, verbose=1):
             ep = torch.export.export(
                 model,
                 (),
diff --git a/_unittests/ut_torch_models/test_tiny_llms_bypassed.py b/_unittests/ut_torch_models/test_tiny_llms_bypassed.py
index 02506515..8c3b2bd5 100644
--- a/_unittests/ut_torch_models/test_tiny_llms_bypassed.py
+++ b/_unittests/ut_torch_models/test_tiny_llms_bypassed.py
@@ -32,7 +32,7 @@ def debug():
                 print("***", data["dynamic_shapes"])
                 import torch.export._draft_export
 
-                ep, report = torch.export._draft_export.draft_export(
+                _ep, report = torch.export._draft_export.draft_export(
                     model,
                     (),
                     kwargs=inputs,
@@ -56,12 +56,13 @@ def debug():
 
     @ignore_warnings(UserWarning)
     def test_export_phi2_2_bypassed(self):
-        data = get_phi2(num_hidden_layers=2)
+        data = get_phi2(num_hidden_layers=2, batch_size=2)
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
         model(**torch_deepcopy(inputs))
+        ds = use_dyn_not_str(ds)
         with torch_export_patches(patch_transformers=True, stop_if_static=1) as modificator:
             inputs = modificator(inputs)
             ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
diff --git a/_unittests/ut_torch_models/test_validate_whole_models.py b/_unittests/ut_torch_models/test_validate_whole_models.py
index 801b0e9c..4fd63f02 100644
--- a/_unittests/ut_torch_models/test_validate_whole_models.py
+++ b/_unittests/ut_torch_models/test_validate_whole_models.py
@@ -96,7 +96,7 @@ def test_f_validate_model_onnx_dynamo_ir(self):
         )
 
     @requires_torch("2.7")
-    @requires_onnxscript("0.5")
+    @requires_onnxscript("0.7")
     @hide_stdout()
     @ignore_warnings(FutureWarning)
     def test_g_validate_model_onnx_dynamo_os_ort(self):
diff --git a/_unittests/ut_torch_onnx/test_sbs.py b/_unittests/ut_torch_onnx/test_sbs.py
index 4c24ce5e..28f934a1 100644
--- a/_unittests/ut_torch_onnx/test_sbs.py
+++ b/_unittests/ut_torch_onnx/test_sbs.py
@@ -68,9 +68,7 @@ def forward(self, x):
         ep = torch.export.export(
             Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
         )
-        epo = torch.onnx.export(
-            ep, (x,), dynamic_shapes=({0: torch.export.Dim("batch")},), dynamo=True
-        )
+        epo = torch.onnx.export(ep, (x,), dynamic_shapes=({0: torch.export.Dim("batch")},))
         onx = epo.model_proto
         results = list(
             run_aligned(
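# The batch_size=1 test above combines two context managers: backed_size_oblivious
# keeps torch.export from specializing a dimension whose sample value is 1, and
# torch_export_patches(patch_transformers=True) swaps in exportable masking code.
# A minimal sketch of that combination, mirroring test_export_phi2_1_batch_size_1:
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.llms import get_phi2

data = get_phi2(num_hidden_layers=2, batch_size=1)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
with torch.fx.experimental._config.patch(
    backed_size_oblivious=True
), torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )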
diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py
index 4cc3c4d3..85412ec7 100644
--- a/_unittests/ut_xrun_doc/test_documentation_examples.py
+++ b/_unittests/ut_xrun_doc/test_documentation_examples.py
@@ -47,7 +47,7 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
             cmds = [sys.executable, "-u", os.path.join(fold, name)]
             p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
             res = p.communicate()
-            out, err = res
+            _out, err = res
             st = err.decode("ascii", errors="ignore")
             if st and "Traceback" in st:
                 if '"dot" not found in path.' in st:
@@ -116,6 +116,13 @@ def add_test_methods(cls):
             ):
                 reason = "unstable, let's wait for the next version"
 
+            if (
+                not reason
+                and name in {"plot_export_tiny_phi2.py"}
+                and not has_transformers("4.55")
+            ):
+                reason = "unstable, let's wait for the next version"
+
             if reason:
 
                 @unittest.skip(reason)
diff --git a/_unittests/ut_xrun_doc/test_documentation_recipes.py b/_unittests/ut_xrun_doc/test_documentation_recipes.py
index 59f1b682..b821a457 100644
--- a/_unittests/ut_xrun_doc/test_documentation_recipes.py
+++ b/_unittests/ut_xrun_doc/test_documentation_recipes.py
@@ -46,7 +46,7 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
             cmds = [sys.executable, "-u", os.path.join(fold, name)]
             p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
             res = p.communicate()
-            out, err = res
+            _out, err = res
             st = err.decode("ascii", errors="ignore")
             if st and "Traceback" in st:
                 if '"dot" not found in path.' in st:
diff --git a/_unittests/ut_xrun_doc/test_documentation_technical.py b/_unittests/ut_xrun_doc/test_documentation_technical.py
index 5dbfb661..9e450a94 100644
--- a/_unittests/ut_xrun_doc/test_documentation_technical.py
+++ b/_unittests/ut_xrun_doc/test_documentation_technical.py
@@ -41,7 +41,7 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
             cmds = [sys.executable, "-u", os.path.join(fold, name)]
             p = subprocess.Popen(cmds, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
             res = p.communicate()
-            out, err = res
+            _out, err = res
             st = err.decode("ascii", errors="ignore")
             if st and "Traceback" in st:
                 if '"dot" not found in path.' in st:
diff --git a/onnx_diagnostic/helpers/cache_helper.py b/onnx_diagnostic/helpers/cache_helper.py
index 9e993cc5..9d3a52e0 100644
--- a/onnx_diagnostic/helpers/cache_helper.py
+++ b/onnx_diagnostic/helpers/cache_helper.py
@@ -270,7 +270,7 @@ def __init__(self):
                 self.num_attention_heads = key_value_pairs[0][0].shape[1]
                 self.num_hidden_layers = len(key_value_pairs)
 
-            def get_text_config(self):
+            def get_text_config(self, *args, **kwargs):
                 return self
 
         assert max_cache_len is not None, (
@@ -366,7 +366,7 @@ def __init__(self):
                 self.num_hidden_layers = len(key_value_pairs)
                 self.dtype = dtype
 
-            def get_text_config(self):
+            def get_text_config(self, *args, **kwargs):
                 return self
 
         cache = MambaCache(
@@ -409,7 +409,7 @@ def __init__(self):
                 self.num_hidden_layers = len(key_value_pairs)
                 self.sliding_window = key_value_pairs[0][0].shape[2]
 
-            def get_text_config(self):
+            def get_text_config(self, *args, **kwargs):
                 return self
 
         cache = transformers.cache_utils.SlidingWindowCache(
@@ -577,7 +577,7 @@ class _config:
             sliding_window = _sliding_window
             num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
 
-            def get_text_config(self):
+            def get_text_config(self, *args, **kwargs):
                 return self
 
         if layer_types:
diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py
index 06ea7c8b..ee7eedc3 100644
--- a/onnx_diagnostic/helpers/onnx_helper.py
+++ b/onnx_diagnostic/helpers/onnx_helper.py
@@ -1186,7 +1186,7 @@ def shadowing_names(
             shadow |= set(i.name for i in g.input) & shadow_context
             shadow |= set(i.name for i in g.initializer) & shadow_context
             shadow |= set(i.name for i in g.sparse_initializer) & shadow_context
-            s, ps, c = shadowing_names(
+            s, _ps, c = shadowing_names(
                 g.node, verbose=verbose, existing=existing, shadow_context=existing
             )
             shadow |= s
diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py
index ec3af9c6..734aba76 100644
--- a/onnx_diagnostic/helpers/torch_helper.py
+++ b/onnx_diagnostic/helpers/torch_helper.py
@@ -543,7 +543,7 @@ def __init__(self, embedding_dim: int = 16, context_size: int = 256):
         )
 
     def forward(self, x):
-        B, T, C = x.shape
+        _B, T, C = x.shape
 
         query = self.query(x)
         key = self.key(x)
@@ -866,7 +866,7 @@ def torch_tensor_size(value: Any) -> Any:
     if value.__class__.__name__ == "MambaCache":
         return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
     if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
-        args, spec = torch.utils._pytree.tree_flatten(value)
+        args, _spec = torch.utils._pytree.tree_flatten(value)
         return sum(torch_tensor_size(a) for a in args)
 
     # We should have a code using serialization, deserialization assuming a model
diff --git a/onnx_diagnostic/reference/ops/op_scan.py b/onnx_diagnostic/reference/ops/op_scan.py
index bcf80966..7c5080ec 100644
--- a/onnx_diagnostic/reference/ops/op_scan.py
+++ b/onnx_diagnostic/reference/ops/op_scan.py
@@ -26,11 +26,11 @@ def _run(
     ):
         (
             num_loop_state_vars,
-            num_scan_outputs,
-            output_directions,
-            max_dir_out,
-            output_axes,
-            max_axe_out,
+            _num_scan_outputs,
+            _output_directions,
+            _max_dir_out,
+            _output_axes,
+            _max_axe_out,
             state_names_in,
             state_names_out,
             scan_names_in,
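# Why get_text_config gained *args and **kwargs in cache_helper.py above: newer
# transformers code may call it with arguments (decoder=True is one example, an
# assumption here), and a no-argument signature would raise TypeError. The
# pattern, with a hypothetical stub standing in for the local config classes:
class _Config:
    num_attention_heads = 8
    num_hidden_layers = 2

    def get_text_config(self, *args, **kwargs):
        # accept and ignore whatever the caller passes; the stub is its own
        # text config
        return self


cfg = _Config()
assert cfg.get_text_config(decoder=True) is cfg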
diff --git a/onnx_diagnostic/reference/ort_evaluator.py b/onnx_diagnostic/reference/ort_evaluator.py
index 2f9d749c..cba391ef 100644
--- a/onnx_diagnostic/reference/ort_evaluator.py
+++ b/onnx_diagnostic/reference/ort_evaluator.py
@@ -562,7 +562,7 @@ def _run_if(
         if key in self._cache:
             sess = self._cache[key][1]
         else:
-            self._cache[key] = onx, sess = self._get_sess_if(node, name, inputs, results)
+            self._cache[key] = _onx, sess = self._get_sess_if(node, name, inputs, results)
 
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         feeds = {name: results[name] for name in sess.input_names}
@@ -616,7 +616,7 @@ def _run_scan(
         if key in self._cache:
             sess = self._cache[key][1]
         else:
-            self._cache[key] = onx, sess = self._get_sess_scan(node, name, inputs, results)
+            self._cache[key] = _onx, sess = self._get_sess_scan(node, name, inputs, results)
 
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         feeds = {name: results[name] for name in sess.input_names}
diff --git a/onnx_diagnostic/torch_export_patches/eval/model_cases.py b/onnx_diagnostic/torch_export_patches/eval/model_cases.py
index caf4c405..6541d5ff 100644
--- a/onnx_diagnostic/torch_export_patches/eval/model_cases.py
+++ b/onnx_diagnostic/torch_export_patches/eval/model_cases.py
@@ -384,7 +384,7 @@ def add(carry: torch.Tensor, y: torch.Tensor):
 
     def forward(self, x):
         init = torch.zeros_like(x[0])
-        carry, out = torch.ops.higher_order.scan(
+        carry, _out = torch.ops.higher_order.scan(
             ControlFlowScan.add, [init], [x], additional_inputs=[]
         )
         return carry
@@ -429,7 +429,7 @@ def dist(carry: torch.Tensor, x: torch.Tensor):
         return [carry.clone(), rd]
 
     def forward(self, x):
-        carry, out = torch.ops.higher_order.scan(
+        _carry, out = torch.ops.higher_order.scan(
             ControlFlowScanCDist.dist,
             [x],
             [x],
@@ -483,7 +483,7 @@ def dist(y: torch.Tensor, scanned_x: torch.Tensor):
         return [y.clone(), rd]
 
     def forward(self, x, y):
-        carry, out = torch.ops.higher_order.scan(
+        _carry, out = torch.ops.higher_order.scan(
             ControlFlowScanCDistXY.dist,
             [y],
             [x],
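# A minimal sketch of the torch.ops.higher_order.scan convention used by the
# model cases above, assuming a torch version where the op also runs eagerly:
# the combine function returns [next_carry, per_step_output], and scan returns
# the final carry plus the stacked per-step outputs.
import torch


def add(carry: torch.Tensor, y: torch.Tensor):
    next_carry = carry + y
    return [next_carry, next_carry.clone()]


x = torch.arange(6, dtype=torch.float32).reshape(3, 2)
init = torch.zeros_like(x[0])
carry, out = torch.ops.higher_order.scan(add, [init], [x], additional_inputs=[])
print(carry.shape)  # (2,): the cumulative sum after the last step
print(out.shape)  # (3, 2): one stacked output per scanned step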
"transformers.masking_utils.sdpa_mask " + "in ALL_MASK_ATTENTION_FUNCTIONS" + ) + if ( + "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS + and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] + == f_transformers_sdpa_mask + ): + masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = ( + patch_transformers_list.patched_sdpa_mask_recent_torch + ) + if custom_patches: if verbose: print("[torch_export_patches] applies custom patches") @@ -568,12 +617,31 @@ def torch_export_patches( and hasattr(masking_utils, "_vmap_for_bhqkv") ): masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv + if verbose: print( "[torch_export_patches] restored " "transformers.masking_utils._vmap_for_bhqkv" ) + masking_utils.sdpa_mask_recent_torch = ( + f_transformers_sdpa_mask_recent_torch + ) + + if verbose: + print( + "[torch_export_patches] restored " + "transformers.masking_utils.sdpa_mask_recent_torch" + ) + + if f_transformers_sdpa_mask is not None: + masking_utils.sdpa_mask = f_transformers_sdpa_mask + if verbose: + print( + "[torch_export_patches] restored " + "transformers.masking_utils.sdpa_mask" + ) + if ( masking_utils and patch_transformers_list.patch_masking_utils @@ -581,6 +649,11 @@ def torch_export_patches( ): f_transformers_eager_mask = masking_utils.eager_mask masking_utils.eager_mask = f_transformers_eager_mask + if verbose: + print( + "[torch_export_patches] restored " + "transformers.masking_utils.eager_mask" + ) if ( "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] @@ -589,11 +662,32 @@ def torch_export_patches( masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = ( f_transformers_eager_mask ) - if verbose: - print( - "[torch_export_patches] restored " - "transformers.masking_utils.eager_mask" + if verbose: + print( + "[torch_export_patches] restored " + "transformers.masking_utils.eager_mask " + "in ALL_MASK_ATTENTION_FUNCTIONS" + ) + + if ( + masking_utils + and patch_transformers_list.patch_masking_utils + and hasattr(masking_utils, "sdpa_mask") + ): + if ( + "sdpa" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS + and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] + == patch_transformers_list.patched_sdpa_mask_recent_torch + ): + masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"] = ( + f_transformers_sdpa_mask ) + if verbose: + print( + "[torch_export_patches] restored " + "transformers.masking_utils.sdpa_mask " + "in ALL_MASK_ATTENTION_FUNCTIONS" + ) ######## # caches diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index be088fe0..9a96a2b7 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -37,7 +37,13 @@ if patch_masking_utils: # Introduced in 4.52 - from transformers.masking_utils import causal_mask_function, sdpa_mask + from transformers.masking_utils import ( + causal_mask_function, + padding_mask_function, + and_masks, + _ignore_causal_mask_sdpa, + prepare_padding_mask, + ) def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``.""" @@ -105,7 +111,7 @@ def patched_eager_mask( """manual patch for function ``transformers.masking_utils.eager_mask``.""" # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf _ = kwargs.pop("allow_is_causal_skip", None) - mask = 
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index be088fe0..9a96a2b7 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -37,7 +37,13 @@
 
 if patch_masking_utils:
     # Introduced in 4.52
-    from transformers.masking_utils import causal_mask_function, sdpa_mask
+    from transformers.masking_utils import (
+        causal_mask_function,
+        padding_mask_function,
+        and_masks,
+        _ignore_causal_mask_sdpa,
+        prepare_padding_mask,
+    )
 
     def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable:
         """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``."""
@@ -105,7 +111,7 @@ def patched_eager_mask(
         """manual patch for function ``transformers.masking_utils.eager_mask``."""
         # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
         _ = kwargs.pop("allow_is_causal_skip", None)
-        mask = sdpa_mask(
+        mask = patched_sdpa_mask_recent_torch(
            batch_size=batch_size,
            cache_position=cache_position,
            kv_length=kv_length,
@@ -125,6 +131,35 @@ def patched_eager_mask(
         mask = (~mask).to(dtype) * min_dtype
         return mask
 
+    def patched_sdpa_mask_recent_torch(
+        batch_size: int,
+        cache_position: torch.Tensor,
+        kv_length: int,
+        kv_offset: int = 0,
+        mask_function: Callable = causal_mask_function,
+        attention_mask: Optional[torch.Tensor] = None,
+        local_size: Optional[int] = None,
+        allow_is_causal_skip: bool = True,
+        **kwargs,
+    ) -> Optional[torch.Tensor]:
+        """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
+        q_length = cache_position.shape[0]
+        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+        if allow_is_causal_skip and _ignore_causal_mask_sdpa(
+            padding_mask, q_length, kv_length, kv_offset, local_size
+        ):
+            return None
+        kv_arange = torch.arange(kv_length, device=cache_position.device)
+        kv_arange += kv_offset
+        if padding_mask is not None:
+            mask_function = and_masks(mask_function, padding_mask_function(padding_mask))
+        batch_arange = torch.arange(batch_size, device=cache_position.device)
+        head_arange = torch.arange(1, device=cache_position.device)
+        causal_mask = patched__vmap_for_bhqkv(mask_function)(
+            batch_arange, head_arange, cache_position, kv_arange
+        )
+        return causal_mask
+
 
 if patch_parse_processor_args:
diff --git a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
index ceb4b16e..0d19b120 100644
--- a/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
+++ b/onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py
@@ -218,7 +218,6 @@ def unflatten_sliding_window_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> SlidingWindowCache:
     """Restores a :class:`transformers.cache_utils.SlidingWindowCache` from python objects."""
-    key_cache, value_cache = values
     return make_sliding_window_cache(list(zip(values[0], values[1])))
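# get_phi2 in the next diff can describe dynamic axes two ways; both forms are
# accepted by torch.export.export. A minimal sketch of the difference, assuming
# a single (batch, seq_length) input named input_ids:
import torch

# string names: export infers one dynamic dimension per distinct name
shapes_str = {"input_ids": {0: "batch", 1: "seq_length"}}

# explicit Dim objects: the same axes, with user-chosen bounds
batch = torch.export.Dim("batch", min=1, max=1024)
seq_length = torch.export.Dim("seq_length", min=1, max=4096)
shapes_dim = {"input_ids": {0: batch, 1: seq_length}}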
"seq_length" + cache_length = "cache_length" shapes = { "input_ids": {0: batch, 1: seq_length}, diff --git a/onnx_diagnostic/torch_models/validate.py b/onnx_diagnostic/torch_models/validate.py index f64673f3..42d60c1e 100644 --- a/onnx_diagnostic/torch_models/validate.py +++ b/onnx_diagnostic/torch_models/validate.py @@ -352,7 +352,7 @@ def validate_model( The following exporters are available: * ``export-nostrict``: run :func:`torch.export.export` (..., strict=False) - * ``onnx-dynamo``: run :func:`torch.onnx.export` (..., dynamo=True), + * ``onnx-dynamo``: run :func:`torch.onnx.export` (...), models can be optimized with ``optimization`` in ``("ir", "os_ort")`` * ``modelbuilder``: use :epkg:`ModelBuilder` to builds the onnx model * ``custom``: custom exporter (see :epkg:`experimental-experiment`), diff --git a/onnx_diagnostic/torch_onnx/sbs.py b/onnx_diagnostic/torch_onnx/sbs.py index a43fd05d..1c2451fc 100644 --- a/onnx_diagnostic/torch_onnx/sbs.py +++ b/onnx_diagnostic/torch_onnx/sbs.py @@ -205,7 +205,7 @@ def post_process(obs): Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},) ) onx = torch.onnx.export( - Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},), dynamo=True + Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},) ).model_proto results = list( map(