3 changes: 2 additions & 1 deletion CHANGELOGS.rst
@@ -4,7 +4,8 @@ Change Logs
0.7.4
+++++

* :pr:`174`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs
* :pr:`178`: add a patch for eager_mask to handle ``assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs``
* :pr:`177`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs

0.7.3
+++++
21 changes: 11 additions & 10 deletions _scripts/test_backend_onnxruntime.py
@@ -26,12 +26,13 @@ def run(self, inputs, **kwargs):
if isinstance(inputs, numpy.ndarray):
inputs = [inputs]
if isinstance(inputs, list):
if len(inputs) == len(self._session.input_names):
feeds = dict(zip(self._session.input_names, inputs))
if len(inputs) == len(self._session.get_inputs()):
feeds = dict(zip([i.name for i in self._session.get_inputs()], inputs))
else:
input_names = [i.name for i in self._session.get_inputs()]
feeds = {}
pos_inputs = 0
for inp, tshape in zip(self._session.input_names, self._session.input_types):
for inp, tshape in zip(input_names, self._session.input_types):
shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
if shape == inputs[pos_inputs].shape:
feeds[inp] = inputs[pos_inputs]
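The hunk above replaces the reference evaluator's input_names attribute with onnxruntime.InferenceSession.get_inputs(), whose entries carry the input names. A minimal sketch of that feed-building pattern, assuming a hypothetical model.onnx with a single float input:

import numpy
import onnxruntime as ort

# load a model and bind positional numpy inputs to the session's input names
sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_names = [i.name for i in sess.get_inputs()]        # names come from get_inputs()
inputs = [numpy.random.rand(1, 3).astype(numpy.float32)]
feeds = dict(zip(input_names, inputs))                   # positional -> named binding
results = sess.run(None, feeds)                          # None means return every output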
@@ -54,20 +55,20 @@ def is_compatible(cls, model) -> bool:
@classmethod
def supports_device(cls, device: str) -> bool:
d = Device(device)
if d == DeviceType.CPU:
if d.type == DeviceType.CPU:
return True
if d == DeviceType.CUDA:
import torch

return torch.cuda.is_available()
# if d.type == DeviceType.CUDA:
# import torch
#
# return torch.cuda.is_available()
return False

@classmethod
def create_inference_session(cls, model, device):
d = Device(device)
if d == DeviceType.CUDA:
if d.type == DeviceType.CUDA:
providers = ["CUDAExecutionProvider"]
elif d == DeviceType.CPU:
elif d.type == DeviceType.CPU:
providers = ["CPUExecutionProvider"]
else:
raise ValueError(f"Unrecognized device {device!r} or {d!r}")
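The recurring fix in this file and the next one is to compare d.type, not the Device object itself, against DeviceType members. A minimal sketch, assuming onnx's backend Device API, of how a device string is parsed:

from onnx.backend.base import Device, DeviceType

d = Device("CUDA:0")                          # parses "<type>[:<device_id>]" strings
assert d.type == DeviceType.CUDA              # the type field is what must be compared
assert d.device_id == 0
assert Device("CPU").type == DeviceType.CPU   # comparing d itself with a DeviceType never matches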
8 changes: 4 additions & 4 deletions _unittests/ut_reference/test_backend_onnxruntime_evaluator.py
@@ -50,9 +50,9 @@ def is_compatible(cls, model) -> bool:
@classmethod
def supports_device(cls, device: str) -> bool:
d = Device(device)
if d == DeviceType.CPU:
if d.type == DeviceType.CPU:
return True
if d == DeviceType.CUDA:
if d.type == DeviceType.CUDA:
import torch

return torch.cuda.is_available()
@@ -61,9 +61,9 @@ def supports_device(cls, device: str) -> bool:
@classmethod
def create_inference_session(cls, model, device):
d = Device(device)
if d == DeviceType.CUDA:
if d.type == DeviceType.CUDA:
providers = ["CUDAExecutionProvider"]
elif d == DeviceType.CPU:
elif d.type == DeviceType.CPU:
providers = ["CPUExecutionProvider"]
else:
raise ValueError(f"Unrecognized device {device!r} or {d!r}")
23 changes: 23 additions & 0 deletions _unittests/ut_torch_export_patches/test_patch_models.py
@@ -0,0 +1,23 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_transformers
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


class TestHuggingFaceHubModel(ExtTestCase):
@hide_stdout()
@requires_transformers("4.51")
def test_patch_eager_mask_open_whisper_tiny(self):
mid = "openai/whisper-tiny"
data = get_untrained_model_with_inputs(mid, verbose=1)
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
model(**torch_deepcopy(inputs))
with torch_export_patches(patch_transformers=True, verbose=1):
torch.export.export(model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds))


if __name__ == "__main__":
unittest.main(verbosity=2)
29 changes: 24 additions & 5 deletions _unittests/ut_torch_export_patches/test_patch_torch.py
@@ -2,15 +2,14 @@
from typing import Callable
import torch
from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
patched__vmap_for_bhqkv as _vmap_for_bhqkv2,
)
from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers


class TestPatchPatchTorch(ExtTestCase):
@requires_transformers("4.52")
def test_vmap(self):
from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap

f = lambda x, y: x * y + 1 # noqa: E731
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([0.1, 0.2, 0.3])
@@ -32,7 +31,10 @@ def forward(self, x, y):
self.assertEqualArray(Model()(x, y), ep.module()(x, y))

@requires_torch("2.8")
@requires_transformers("4.52")
def test_export_patched_vmap(self):
from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap

class Model(torch.nn.Module):
def forward(self, x, y):
f = lambda x, y: x * y + 1 # noqa: E731
@@ -43,14 +45,20 @@ def forward(self, x, y):
ep = torch.export.export(Model(), (x, y))
self.assertEqualArray(Model()(x, y), ep.module()(x, y))

@requires_transformers("4.52")
def test_vmap_outdim(self):
from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap

f = lambda x: x**2 # noqa: E731
x = torch.randn(2, 5)
expected = torch.vmap(f, out_dims=1)(x)
got = patched_vmap(f, out_dims=1)(x)
self.assertEqualArray(expected, got)

@requires_transformers("4.52")
def test_vmap_dict(self):
from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap

f = lambda d: torch.dot(d["x"], d["y"]) # noqa: E731
x, y = torch.randn(2, 5), torch.randn(5)
input = {"x": x, "y": y}
@@ -60,13 +68,19 @@ def test_vmap_dict(self):
)
# self.assertEqualArray(_expected, got)

@requires_transformers("4.52")
def test_vmap_tuple(self):
from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap

x, y = torch.randn(2, 5), torch.randn(5)
expected = torch.vmap(torch.dot, in_dims=(0, None))(x, y)
got = patched_vmap(torch.dot, in_dims=(0, None))(x, y)
self.assertEqualArray(expected, got, atol=1e-5)

@requires_transformers("4.52")
def test_vmap_transformers_scenario_vmap(self):
from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap

def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
return padding_mask[batch_idx, kv_idx]
@@ -140,7 +154,12 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange):
self.assertEqualArray(causal_mask, ep.module()(*inputs))

@requires_torch("2.8")
@requires_transformers("4.53")
def test_vmap_transformers_scenario_novmap(self):
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
patched__vmap_for_bhqkv as _vmap_for_bhqkv2,
)

def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
return padding_mask[batch_idx, kv_idx]
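For reference, the tests above compare patched_vmap against torch.vmap; a short standalone sketch of the torch.vmap behavior being mimicked (in_dims selects which argument is batched, out_dims chooses where the mapped axis lands in the output):

import torch

f = lambda x, y: x * y + 1                            # elementwise function to map
x, y = torch.randn(2, 5), torch.randn(5)
batched = torch.vmap(f, in_dims=(0, None))(x, y)      # map over dim 0 of x, broadcast y
print(batched.shape)                                  # torch.Size([2, 5])
moved = torch.vmap(lambda t: t ** 2, out_dims=1)(x)   # mapped axis placed at dim 1
print(moved.shape)                                     # torch.Size([5, 2])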
62 changes: 56 additions & 6 deletions onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -420,7 +420,11 @@ def torch_export_patches(
patch_transformers_list, verbose=verbose
)

if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
if (
masking_utils
and patch_transformers_list.patch_masking_utils
and hasattr(masking_utils, "_vmap_for_bhqkv")
):
if verbose:
print(
"[torch_export_patches] patches "
@@ -429,6 +433,27 @@
f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv

if (
masking_utils
and patch_transformers_list.patch_masking_utils
and hasattr(masking_utils, "eager_mask")
):
if verbose:
print(
"[torch_export_patches] patches "
"transformers.masking_utils.eager_mask"
)
f_transformers_eager_mask = masking_utils.eager_mask
masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
if (
"eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
== f_transformers_eager_mask
):
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
patch_transformers_list.patched_eager_mask
)

if custom_patches:
if verbose:
print("[torch_export_patches] applies custom patches")
@@ -511,7 +536,7 @@ def torch_export_patches(

if custom_patches:
if verbose:
print("[torch_export_patches] unpatch custom patches")
print("[torch_export_patches] unpatches custom patches")
unpatch_module_or_classes(
custom_patches, revert_custom_patches_info, verbose=verbose
)
@@ -526,18 +551,43 @@
except ImportError:
masking_utils = None
if verbose:
print("[torch_export_patches] unpatch transformers")
print("[torch_export_patches] unpatches transformers")
unpatch_module_or_classes(
patch_transformers_list, revert_patches_info, verbose=verbose
)

if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
if (
masking_utils
and patch_transformers_list.patch_masking_utils
and hasattr(masking_utils, "_vmap_for_bhqkv")
):
masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
if verbose:
print(
"[torch_export_patches] unpatch "
"[torch_export_patches] restored "
"transformers.masking_utils._vmap_for_bhqkv"
)
masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv

if (
masking_utils
and patch_transformers_list.patch_masking_utils
and hasattr(masking_utils, "eager_mask")
):
f_transformers_eager_mask = masking_utils.eager_mask
masking_utils.eager_mask = f_transformers_eager_mask
if (
"eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
== patch_transformers_list.patched_eager_mask
):
masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
f_transformers_eager_mask
)
if verbose:
print(
"[torch_export_patches] restored "
"transformers.masking_utils.eager_mask"
)

########
# caches
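The patch and unpatch branches above follow a save-swap-restore pattern for transformers.masking_utils.eager_mask and its entry in ALL_MASK_ATTENTION_FUNCTIONS. A minimal generic sketch of that pattern, with illustrative names that are not part of the PR's API:

import contextlib
import types

@contextlib.contextmanager
def swap_attribute(module: types.ModuleType, name: str, replacement, registry=None, key=None):
    # Save the original callable, install the replacement, and restore both the
    # module attribute and the optional dispatch-table entry on exit.
    original = getattr(module, name)
    setattr(module, name, replacement)
    if registry is not None and registry.get(key) is original:
        registry[key] = replacement
    try:
        yield original
    finally:
        setattr(module, name, original)
        if registry is not None and registry.get(key) is replacement:
            registry[key] = original

# hypothetical usage, mirroring what torch_export_patches does for eager_mask:
# with swap_attribute(masking_utils, "eager_mask", patched_eager_mask,
#                     registry=masking_utils.ALL_MASK_ATTENTION_FUNCTIONS, key="eager"):
#     torch.export.export(model, (), kwargs=inputs)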