
Commit 3e6108d

Patches eager_mode for whisper-tiny (#178)
* Patches eager_mode for whisper-tiny
* fix
* fix ut
* ut
* fix
1 parent 2ae830c commit 3e6108d

7 files changed: +216 -74 lines changed

CHANGELOGS.rst

Lines changed: 2 additions & 1 deletion

@@ -4,7 +4,8 @@ Change Logs
 0.7.4
 +++++
 
-* :pr:`174`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs
+* :pr:`178`: add a patch for eager_mask to handle ``assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs``
+* :pr:`177`: changes for the next version of onnx, fixes all_dynamic_shape_from_inputs
 
 0.7.3
 +++++
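
In practice, the patch described by :pr:`178` is enabled through the torch_export_patches context manager. A minimal usage sketch, mirroring the new unit test added by this commit (get_untrained_model_with_inputs instantiates an untrained copy of the architecture together with matching inputs):

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

# build an untrained whisper-tiny plus matching inputs and dynamic shapes
data = get_untrained_model_with_inputs("openai/whisper-tiny")
model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

# the patched eager_mask is only active inside this context
with torch_export_patches(patch_transformers=True):
    ep = torch.export.export(
        model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
    )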

_scripts/test_backend_onnxruntime.py

Lines changed: 11 additions & 10 deletions

@@ -26,12 +26,13 @@ def run(self, inputs, **kwargs):
         if isinstance(inputs, numpy.ndarray):
             inputs = [inputs]
         if isinstance(inputs, list):
-            if len(inputs) == len(self._session.input_names):
-                feeds = dict(zip(self._session.input_names, inputs))
+            if len(inputs) == len(self._session.get_inputs()):
+                feeds = dict(zip([i.name for i in self._session.get_inputs()], inputs))
             else:
+                input_names = [i.name for i in self._session.get_inputs()]
                 feeds = {}
                 pos_inputs = 0
-                for inp, tshape in zip(self._session.input_names, self._session.input_types):
+                for inp, tshape in zip(input_names, self._session.input_types):
                     shape = tuple(d.dim_value for d in tshape.tensor_type.shape.dim)
                     if shape == inputs[pos_inputs].shape:
                         feeds[inp] = inputs[pos_inputs]
@@ -54,20 +55,20 @@ def is_compatible(cls, model) -> bool:
     @classmethod
     def supports_device(cls, device: str) -> bool:
         d = Device(device)
-        if d == DeviceType.CPU:
+        if d.type == DeviceType.CPU:
             return True
-        if d == DeviceType.CUDA:
-            import torch
-
-            return torch.cuda.is_available()
+        # if d.type == DeviceType.CUDA:
+        #     import torch
+        #
+        #     return torch.cuda.is_available()
         return False
 
     @classmethod
     def create_inference_session(cls, model, device):
         d = Device(device)
-        if d == DeviceType.CUDA:
+        if d.type == DeviceType.CUDA:
             providers = ["CUDAExecutionProvider"]
-        elif d == DeviceType.CPU:
+        elif d.type == DeviceType.CPU:
             providers = ["CPUExecutionProvider"]
         else:
             raise ValueError(f"Unrecognized device {device!r} or {d!r}")
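
The run method now derives input names from InferenceSession.get_inputs() instead of a non-standard input_names attribute. A minimal sketch of the same feeds construction against a plain onnxruntime session; "model.onnx" is a placeholder path:

import numpy
import onnxruntime

# get_inputs() returns NodeArg objects; each carries the input's .name
sess = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_names = [i.name for i in sess.get_inputs()]

# pair positional arrays with the session's declared inputs
arrays = [numpy.random.rand(1, 3).astype(numpy.float32)]
feeds = dict(zip(input_names, arrays))
results = sess.run(None, feeds)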

_unittests/ut_reference/test_backend_onnxruntime_evaluator.py

Lines changed: 4 additions & 4 deletions

@@ -50,9 +50,9 @@ def is_compatible(cls, model) -> bool:
     @classmethod
     def supports_device(cls, device: str) -> bool:
         d = Device(device)
-        if d == DeviceType.CPU:
+        if d.type == DeviceType.CPU:
             return True
-        if d == DeviceType.CUDA:
+        if d.type == DeviceType.CUDA:
             import torch
 
             return torch.cuda.is_available()
@@ -61,9 +61,9 @@ def supports_device(cls, device: str) -> bool:
     @classmethod
     def create_inference_session(cls, model, device):
         d = Device(device)
-        if d == DeviceType.CUDA:
+        if d.type == DeviceType.CUDA:
             providers = ["CUDAExecutionProvider"]
-        elif d == DeviceType.CPU:
+        elif d.type == DeviceType.CPU:
             providers = ["CPUExecutionProvider"]
         else:
             raise ValueError(f"Unrecognized device {device!r} or {d!r}")
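
Both backends fix the same comparison bug: Device(device) returns a Device object, which never compares equal to a DeviceType enum member, so the old checks were always false. A small sketch with onnx's backend API:

from onnx.backend.base import Device, DeviceType

d = Device("CPU")
assert d.type == DeviceType.CPU  # correct: compare the parsed device type
assert d != DeviceType.CPU       # a Device is never equal to a DeviceType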
Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+import unittest
+import torch
+from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, requires_transformers
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
+from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+class TestHuggingFaceHubModel(ExtTestCase):
+    @hide_stdout()
+    @requires_transformers("4.51")
+    def test_patch_eager_mask_open_whisper_tiny(self):
+        mid = "openai/whisper-tiny"
+        data = get_untrained_model_with_inputs(mid, verbose=1)
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        model(**torch_deepcopy(inputs))
+        with torch_export_patches(patch_transformers=True, verbose=1):
+            torch.export.export(model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds))
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
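
use_dyn_not_str is assumed here to rewrite string axis names in dynamic_shapes into torch.export dimension objects, which torch.export.export resolves more reliably than bare strings. A rough, hypothetical re-implementation for illustration only:

import torch

def dyn_not_str(dynamic_shapes):
    # replace string axis names such as "batch" with Dim.DYNAMIC
    if isinstance(dynamic_shapes, str):
        return torch.export.Dim.DYNAMIC
    if isinstance(dynamic_shapes, dict):
        return {k: dyn_not_str(v) for k, v in dynamic_shapes.items()}
    if isinstance(dynamic_shapes, (list, tuple)):
        return type(dynamic_shapes)(dyn_not_str(v) for v in dynamic_shapes)
    return dynamic_shapes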

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 24 additions & 5 deletions

@@ -2,15 +2,14 @@
 from typing import Callable
 import torch
 from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex
-from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch
-from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
-from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
-    patched__vmap_for_bhqkv as _vmap_for_bhqkv2,
-)
+from onnx_diagnostic.ext_test_case import ExtTestCase, requires_torch, requires_transformers
 
 
 class TestPatchPatchTorch(ExtTestCase):
+    @requires_transformers("4.52")
     def test_vmap(self):
+        from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
+
         f = lambda x, y: x * y + 1  # noqa: E731
         x = torch.tensor([1.0, 2.0, 3.0])
         y = torch.tensor([0.1, 0.2, 0.3])
@@ -32,7 +31,10 @@ def forward(self, x, y):
         self.assertEqualArray(Model()(x, y), ep.module()(x, y))
 
     @requires_torch("2.8")
+    @requires_transformers("4.52")
     def test_export_patched_vmap(self):
+        from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
+
         class Model(torch.nn.Module):
             def forward(self, x, y):
                 f = lambda x, y: x * y + 1  # noqa: E731
@@ -43,14 +45,20 @@ def forward(self, x, y):
         ep = torch.export.export(Model(), (x, y))
         self.assertEqualArray(Model()(x, y), ep.module()(x, y))
 
+    @requires_transformers("4.52")
     def test_vmap_outdim(self):
+        from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
+
         f = lambda x: x**2  # noqa: E731
         x = torch.randn(2, 5)
         expected = torch.vmap(f, out_dims=1)(x)
         got = patched_vmap(f, out_dims=1)(x)
         self.assertEqualArray(expected, got)
 
+    @requires_transformers("4.52")
     def test_vmap_dict(self):
+        from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
+
         f = lambda d: torch.dot(d["x"], d["y"])  # noqa: E731
         x, y = torch.randn(2, 5), torch.randn(5)
         input = {"x": x, "y": y}
@@ -60,13 +68,19 @@ def test_vmap_dict(self):
         )
         # self.assertEqualArray(_expected, got)
 
+    @requires_transformers("4.52")
     def test_vmap_tuple(self):
+        from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
+
         x, y = torch.randn(2, 5), torch.randn(5)
         expected = torch.vmap(torch.dot, in_dims=(0, None))(x, y)
         got = patched_vmap(torch.dot, in_dims=(0, None))(x, y)
         self.assertEqualArray(expected, got, atol=1e-5)
 
+    @requires_transformers("4.52")
     def test_vmap_transformers_scenario_vmap(self):
+        from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap
+
         def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
             def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
                 return padding_mask[batch_idx, kv_idx]
@@ -140,7 +154,12 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange):
         self.assertEqualArray(causal_mask, ep.moule(*inputs))
 
     @requires_torch("2.8")
+    @requires_transformers("4.53")
     def test_vmap_transformers_scenario_novmap(self):
+        from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
+            patched__vmap_for_bhqkv as _vmap_for_bhqkv2,
+        )
+
         def padding_mask_function(padding_mask: torch.Tensor) -> Callable:
             def inner_mask(batch_idx, head_idx, q_idx, kv_idx):
                 return padding_mask[batch_idx, kv_idx]
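
The recurring edit above moves patched_vmap from a module-level import into each test body: on transformers versions older than the @requires_transformers guard, importing the patch module itself can fail, which would break collection of the whole file. A sketch of the pattern, assuming the guard skips the test before the import runs:

import unittest
import torch


class TestDeferredImport(unittest.TestCase):
    def test_patched_vmap(self):
        # imported here rather than at module level, so collecting this file
        # never triggers the import on unsupported transformers versions
        from onnx_diagnostic.torch_export_patches.patches.patch_torch import patched_vmap

        f = lambda x: x * 2  # noqa: E731
        x = torch.ones(3)
        torch.testing.assert_close(torch.vmap(f)(x), patched_vmap(f)(x))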

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 56 additions & 6 deletions

@@ -420,7 +420,11 @@ def torch_export_patches(
                 patch_transformers_list, verbose=verbose
             )
 
-        if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "_vmap_for_bhqkv")
+        ):
             if verbose:
                 print(
                     "[torch_export_patches] patches "
@@ -429,6 +433,27 @@ def torch_export_patches(
             f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
             masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv
 
+        if (
+            masking_utils
+            and patch_transformers_list.patch_masking_utils
+            and hasattr(masking_utils, "eager_mask")
+        ):
+            if verbose:
+                print(
+                    "[torch_export_patches] patches "
+                    "transformers.masking_utils.eager_mask"
+                )
+            f_transformers_eager_mask = masking_utils.eager_mask
+            masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
+            if (
+                "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+                and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+                == f_transformers_eager_mask
+            ):
+                masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
+                    patch_transformers_list.patched_eager_mask
+                )
+
         if custom_patches:
             if verbose:
                 print("[torch_export_patches] applies custom patches")
@@ -511,7 +536,7 @@ torch_export_patches
 
         if custom_patches:
             if verbose:
-                print("[torch_export_patches] unpatch custom patches")
+                print("[torch_export_patches] unpatches custom patches")
             unpatch_module_or_classes(
                 custom_patches, revert_custom_patches_info, verbose=verbose
             )
@@ -526,18 +551,43 @@ torch_export_patches
             except ImportError:
                 masking_utils = None
             if verbose:
-                print("[torch_export_patches] unpatch transformers")
+                print("[torch_export_patches] unpatches transformers")
             unpatch_module_or_classes(
                 patch_transformers_list, revert_patches_info, verbose=verbose
            )
 
-            if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
+            if (
+                masking_utils
+                and patch_transformers_list.patch_masking_utils
+                and hasattr(masking_utils, "_vmap_for_bhqkv")
+            ):
+                masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
                 if verbose:
                     print(
-                        "[torch_export_patches] unpatch "
+                        "[torch_export_patches] restored "
                         "transformers.masking_utils._vmap_for_bhqkv"
                     )
-                masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
+            if (
+                masking_utils
+                and patch_transformers_list.patch_masking_utils
+                and hasattr(masking_utils, "eager_mask")
+            ):
+                f_transformers_eager_mask = masking_utils.eager_mask
+                masking_utils.eager_mask = f_transformers_eager_mask
+                if (
+                    "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+                    and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+                    == patch_transformers_list.patched_eager_mask
+                ):
+                    masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
+                        f_transformers_eager_mask
+                    )
+                if verbose:
+                    print(
+                        "[torch_export_patches] restored "
+                        "transformers.masking_utils.eager_mask"
+                    )
 
         ########
         # caches
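
The patching logic above follows a save/replace/restore discipline, touching the ALL_MASK_ATTENTION_FUNCTIONS registry only when it still points at the function being swapped out. A condensed sketch of that discipline as a standalone context manager, assuming the installed transformers ships masking_utils:

import contextlib

import transformers.masking_utils as masking_utils


@contextlib.contextmanager
def swap_eager_mask(replacement):
    original = masking_utils.eager_mask
    masking_utils.eager_mask = replacement
    # the registry may hold its own reference to the original function
    registry = masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
    if "eager" in registry and registry["eager"] == original:
        registry["eager"] = replacement
    try:
        yield
    finally:
        masking_utils.eager_mask = original
        if "eager" in registry and registry["eager"] == replacement:
            registry["eager"] = original

The real torch_export_patches additionally gates both steps on patch_transformers_list.patch_masking_utils and on hasattr(masking_utils, "eager_mask"), as the hunks above show.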
