diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index c3778de4..3ead0e8a 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,7 +4,8 @@ Change Logs 0.7.4 +++++ -* :pr:`174`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs +* :pr:`178`: add a patch for eager_mask to handle ``assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs`` +* :pr:`177`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs 0.7.3 +++++ diff --git a/_scripts/test_backend_onnxruntime.py b/_scripts/test_backend_onnxruntime.py index 48bb1777..9ae959f2 100644 --- a/_scripts/test_backend_onnxruntime.py +++ b/_scripts/test_backend_onnxruntime.py @@ -26,12 +26,13 @@ def run(self, inputs, **kwargs): if isinstance(inputs, numpy.ndarray): inputs = [inputs] if isinstance(inputs, list): - if len(inputs) == len(self._session.input_names): - feeds = dict(zip(self._session.input_names, inputs)) + if len(inputs) == len(self._session.get_inputs()): + feeds = dict(zip([i.name for i in self._session.get_inputs()], inputs)) else: + input_names = [i.name for i in self._session.get_inputs()] feeds = {} pos_inputs = 0 - for inp, tshape in zip(self._session.input_names, self._session.input_types): + for inp, tshape in zip(input_names, self._session.input_types): shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim) if shape == inputs[pos_inputs].shape: feeds[inp] = inputs[pos_inputs] @@ -54,20 +55,20 @@ def is_compatible(cls, model) -> bool: @classmethod def supports_device(cls, device: str) -> bool: d = Device(device) - if d == DeviceType.CPU: + if d.type == DeviceType.CPU: return True - if d == DeviceType.CUDA: - import torch - - return torch.cuda.is_available() + # if d.type == DeviceType.CUDA: + # import torch + # + # return torch.cuda.is_available() return False @classmethod def create_inference_session(cls, model, device): d = Device(device) - if d == DeviceType.CUDA: + if d.type == DeviceType.CUDA: providers = ["CUDAExecutionProvider"] - 
elif d == DeviceType.CPU: + elif d.type == DeviceType.CPU: providers = ["CPUExecutionProvider"] else: raise ValueError(f"Unrecognized device {device!r} or {d!r}") diff --git a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py index 9f596c4f..521aab6f 100644 --- a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py +++ b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py @@ -50,9 +50,9 @@ def is_compatible(cls, model) -> bool: @classmethod def supports_device(cls, device: str) -> bool: d = Device(device) - if d == DeviceType.CPU: + if d.type == DeviceType.CPU: return True - if d == DeviceType.CUDA: + if d.type == DeviceType.CUDA: import torch return torch.cuda.is_available() @@ -61,9 +61,9 @@ def supports_device(cls, device: str) -> bool: @classmethod def create_inference_session(cls, model, device): d = Device(device) - if d == DeviceType.CUDA: + if d.type == DeviceType.CUDA: providers = ["CUDAExecutionProvider"] - elif d == DeviceType.CPU: + elif d.type == DeviceType.CPU: providers = ["CPUExecutionProvider"] else: raise ValueError(f"Unrecognized device {device!r} or {d!r}") diff --git a/_unittests/ut_torch_export_patches/test_patch_models.py b/_unittests/ut_torch_export_patches/test_patch_models.py new file mode 100644 index 00000000..8b09e249 --- /dev/null +++ b/_unittests/ut_torch_export_patches/test_patch_models.py @@ -0,0 +1,23 @@ +import unittest +import torch +from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_transformers +from onnx_diagnostic.helpers.torch_helper import torch_deepcopy +from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs +from onnx_diagnostic.torch_export_patches import torch_export_patches +from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str + + +class TestHuggingFaceHubModel(ExtTestCase): + @hide_stdout() + @requires_transformers("4.51") + def 
test_patch_eager_mask_open_whisper_tiny(self): + mid = "openai/whisper-tiny" + data = get_untrained_model_with_inputs(mid, verbose=1) + model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"] + model(**torch_deepcopy(inputs)) + with torch_export_patches(patch_transformers=True, verbose=1): + torch.export.export(model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_torch_export_patches/test_patch_torch.py b/_unittests/ut_torch_export_patches/test_patch_torch.py index 99c2465a..7ed791d1 100644 --- a/_unittests/ut_torch_export_patches/test_patch_torch.py +++ b/_unittests/ut_torch_export_patches/test_patch_torch.py @@ -2,15 +2,14 @@ from typing import Callable import torch from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex -from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch -from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap -from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( - patched__vmap_for_bhqkv as _vmap_for_bhqkv2, -) +from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers class TestPatchPatchTorch(ExtTestCase): + @requires_transformers("4.52") def test_vmap(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + f = lambda x, y: x * y + 1 # noqa: E731 x = torch.tensor([1.0, 2.0, 3.0]) y = torch.tensor([0.1, 0.2, 0.3]) @@ -32,7 +31,10 @@ def forward(self, x, y): self.assertEqualArray(Model()(x, y), ep.module()(x, y)) @requires_torch("2.8") + @requires_transformers("4.52") def test_export_patched_vmap(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + class Model(torch.nn.Module): def forward(self, x, y): f = lambda x, y: x * y + 1 # noqa: E731 @@ -43,14 +45,20 @@ def forward(self, x, y): ep = torch.export.export(Model(), (x, y)) 
self.assertEqualArray(Model()(x, y), ep.module()(x, y)) + @requires_transformers("4.52") def test_vmap_outdim(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + f = lambda x: x**2 # noqa: E731 x = torch.randn(2, 5) expected = torch.vmap(f, out_dims=1)(x) got = patched_vmap(f, out_dims=1)(x) self.assertEqualArray(expected, got) + @requires_transformers("4.52") def test_vmap_dict(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + f = lambda d: torch.dot(d["x"], d["y"]) # noqa: E731 x, y = torch.randn(2, 5), torch.randn(5) input = {"x": x, "y": y} @@ -60,13 +68,19 @@ def test_vmap_dict(self): ) # self.assertEqualArray(_expected, got) + @requires_transformers("4.52") def test_vmap_tuple(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + x, y = torch.randn(2, 5), torch.randn(5) expected = torch.vmap(torch.dot, in_dims=(0, None))(x, y) got = patched_vmap(torch.dot, in_dims=(0, None))(x, y) self.assertEqualArray(expected, got, atol=1e-5) + @requires_transformers("4.52") def test_vmap_transformers_scenario_vmap(self): + from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap + def padding_mask_function(padding_mask: torch.Tensor) -> Callable: def inner_mask(batch_idx, head_idx, q_idx, kv_idx): return padding_mask[batch_idx, kv_idx] @@ -140,7 +154,12 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange): self.assertEqualArray(causal_mask, ep.moule(*inputs)) @requires_torch("2.8") + @requires_transformers("4.53") def test_vmap_transformers_scenario_novmap(self): + from onnx_diagnostic.torch_export_patches.patches.patch_transformers import ( + patched__vmap_for_bhqkv as _vmap_for_bhqkv2, + ) + def padding_mask_function(padding_mask: torch.Tensor) -> Callable: def inner_mask(batch_idx, head_idx, q_idx, kv_idx): return padding_mask[batch_idx, kv_idx] diff --git 
a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py index 058041e5..ed8ece0b 100644 --- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py +++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py @@ -420,7 +420,11 @@ def torch_export_patches( patch_transformers_list, verbose=verbose ) - if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"): + if ( + masking_utils + and patch_transformers_list.patch_masking_utils + and hasattr(masking_utils, "_vmap_for_bhqkv") + ): if verbose: print( "[torch_export_patches] patches " @@ -429,6 +433,27 @@ def torch_export_patches( f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv + if ( + masking_utils + and patch_transformers_list.patch_masking_utils + and hasattr(masking_utils, "eager_mask") + ): + if verbose: + print( + "[torch_export_patches] patches " + "transformers.masking_utils.eager_mask" + ) + f_transformers_eager_mask = masking_utils.eager_mask + masking_utils.eager_mask = patch_transformers_list.patched_eager_mask + if ( + "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS + and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] + == f_transformers_eager_mask + ): + masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = ( + patch_transformers_list.patched_eager_mask + ) + if custom_patches: if verbose: print("[torch_export_patches] applies custom patches") @@ -511,7 +536,7 @@ def torch_export_patches( if custom_patches: if verbose: - print("[torch_export_patches] unpatch custom patches") + print("[torch_export_patches] unpatches custom patches") unpatch_module_or_classes( custom_patches, revert_custom_patches_info, verbose=verbose ) @@ -526,18 +551,43 @@ def torch_export_patches( except ImportError: masking_utils = None if verbose: - print("[torch_export_patches] unpatch transformers") + print("[torch_export_patches] unpatches 
transformers") unpatch_module_or_classes( patch_transformers_list, revert_patches_info, verbose=verbose ) - if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"): + if ( + masking_utils + and patch_transformers_list.patch_masking_utils + and hasattr(masking_utils, "_vmap_for_bhqkv") + ): + masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv if verbose: print( - "[torch_export_patches] unpatch " + "[torch_export_patches] restored " "transformers.masking_utils._vmap_for_bhqkv" ) - masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv + + if ( + masking_utils + and patch_transformers_list.patch_masking_utils + and hasattr(masking_utils, "eager_mask") + ): + # restores the original function saved when the patch was applied + masking_utils.eager_mask = f_transformers_eager_mask + if ( + "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS + and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] + == patch_transformers_list.patched_eager_mask + ): + masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = ( + f_transformers_eager_mask + ) + if verbose: + print( + "[torch_export_patches] restored " + "transformers.masking_utils.eager_mask" + ) ######## # caches diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py index d06939ac..3ebc9e8f 100644 --- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py @@ -7,59 +7,107 @@ import transformers from transformers.modeling_attn_mask_utils import AttentionMaskConverter from transformers.cache_utils import StaticCache, Cache, DynamicCache + +try: + import transformers.masking_utils + + patch_masking_utils = True +except ImportError: + patch_masking_utils = False + from ...ext_test_case import has_transformers from ...helpers.torch_helper import is_torchdynamo_exporting -def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: 
bool = True) -> Callable: - """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``.""" - from ...helpers import string_type - - dimensions: List[Tuple[Optional[int], ...]] = [ - (None, None, None, 0), - (None, None, 0, None), - ] - if bh_indices: - dimensions.extend([(None, 0, None, None), (0, None, None, None)]) - # reshape - dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions] - dimensions = tuple(reversed(dimensions)) - indices = tuple(shape.index(-1) for shape in dimensions) - - # unsqueeze - udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions] - - def vector_mask_function( - *args, mask_function=mask_function, dimensions=dimensions, indices=indices - ): - assert len(args) == len(dimensions) == len(udimensions), ( - f"Mismatch between args={string_type(args)} and dimensions={dimensions} " - f"and udimensions={udimensions}." - ) - assert len(indices) == len(args), ( - f"Mismatch between args={string_type(args)} and indices={indices}, " - f"they should have the same length." 
+if patch_masking_utils: + # Introduced in 4.52 + from transformers.masking_utils import causal_mask_function, sdpa_mask + + def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) -> Callable: + """manual patch for function ``transformers.masking_utils._vmap_for_bhqkv``.""" + from ...helpers import string_type + + dimensions: List[Tuple[Optional[int], ...]] = [ + (None, None, None, 0), + (None, None, 0, None), + ] + if bh_indices: + dimensions.extend([(None, 0, None, None), (0, None, None, None)]) + # reshape + dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions] + dimensions = tuple(reversed(dimensions)) + indices = tuple(shape.index(-1) for shape in dimensions) + + # unsqueeze + udimensions = [ + tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions + ] + + def vector_mask_function( + *args, mask_function=mask_function, dimensions=dimensions, indices=indices + ): + assert len(args) == len(dimensions) == len(udimensions), ( + f"Mismatch between args={string_type(args)} and dimensions={dimensions} " + f"and udimensions={udimensions}." + ) + assert len(indices) == len(args), ( + f"Mismatch between args={string_type(args)} and indices={indices}, " + f"they should have the same length." + ) + for a in args: + assert ( + a.ndim == 1 + ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}" + torch._check(a.shape[0] > 0) + + new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)] + # new_args = [ + # a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2]) + # for a, dims in zip(args, udimensions) + # ] + max_shape = tuple(args[i].shape[0] for i in indices) + # if is_torchdynamo_exporting(): + # for a in args: + # # The exporter should export with a dimension > 1 + # # to make sure it is dynamic. 
+ # torch._check(a.shape[0] > 1) + expanded_args = [a.expand(max_shape) for a in new_args] + return mask_function(*expanded_args) + + return vector_mask_function + + def patched_eager_mask( + batch_size: int, + cache_position: torch.Tensor, + kv_length: int, + kv_offset: int = 0, + mask_function: Callable = causal_mask_function, + attention_mask: Optional[torch.Tensor] = None, + dtype: torch.dtype = torch.float32, + **kwargs, + ) -> torch.Tensor: + """manual patch for function ``transformers.masking_utils.eager_mask``.""" + # The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf + _ = kwargs.pop("allow_is_causal_skip", None) + mask = sdpa_mask( + batch_size=batch_size, + cache_position=cache_position, + kv_length=kv_length, + kv_offset=kv_offset, + mask_function=mask_function, + attention_mask=attention_mask, + allow_is_causal_skip=False, + allow_torch_fix=False, + **kwargs, ) - for a in args: - assert ( - a.ndim == 1 - ), f"Expected a tensor with 1 dimension not {string_type(a, with_shape=True)}" - torch._check(a.shape[0] > 0) - - new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)] - # new_args = [ - # a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2]) - # for a, dims in zip(args, udimensions) - # ] - max_shape = tuple(args[i].shape[0] for i in indices) - # if is_torchdynamo_exporting(): - # for a in args: - # # The exporter should export with a dimension > 1 to make sure it is dynamic. - # torch._check(a.shape[0] > 1) - expanded_args = [a.expand(max_shape) for a in new_args] - return mask_function(*expanded_args) - - return vector_mask_function + min_dtype = torch.finfo(dtype).min + # The patched line. + # we need 0s where the tokens should be taken into account, + # and -inf otherwise (mask is already of boolean type) + # mask = + # torch.where(mask, torch.tensor(0.0, device=mask.device, dtype=dtype), min_dtype) + mask = (~mask).to(dtype) * min_dtype + return mask def _patch_make_causal_mask(