Skip to content

Commit 8cadd99

Browse files
committed
fix patches
1 parent a26916d commit 8cadd99

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

_unittests/ut_torch_export_patches/test_dynamic_class.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import unittest
44
from typing import Any, Dict, List, Tuple
55
import torch
6+
import transformers
67
from onnx_diagnostic.ext_test_case import (
78
ExtTestCase,
89
ignore_warnings,
@@ -16,6 +17,7 @@
1617
)
1718
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
1819
from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
20+
import onnx_diagnostic.torch_export_patches.patches.patch_transformers as patch_transformers
1921

2022

2123
class TestOnnxExportErrors(ExtTestCase):
@@ -339,7 +341,11 @@ def test_phi2_export_interpreter(self):
339341
str_inputs, string_type(inputs_copied, with_shape=True, with_min_max=True)
340342
)
341343

342-
with torch_export_patches(patch_transformers=True):
344+
with torch_export_patches(patch_transformers=True, verbose=1):
345+
self.assertEqual(
346+
transformers.masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["sdpa"],
347+
patch_transformers.patched_sdpa_mask_recent_torch,
348+
)
343349
ep = torch.export.export(
344350
model,
345351
(),

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@
3939
# Introduced in 4.52
4040
from transformers.masking_utils import (
4141
causal_mask_function,
42-
sdpa_mask,
4342
padding_mask_function,
4443
and_masks,
4544
_ignore_causal_mask_sdpa,
@@ -112,7 +111,7 @@ def patched_eager_mask(
112111
"""manual patch for function ``transformers.masking_utils.eager_mask``."""
113112
# The masks for eager attention are simply boolean mask from sdpa, casted to 0 and -inf
114113
_ = kwargs.pop("allow_is_causal_skip", None)
115-
mask = sdpa_mask(
114+
mask = patched_sdpa_mask_recent_torch(
116115
batch_size=batch_size,
117116
cache_position=cache_position,
118117
kv_length=kv_length,

0 commit comments

Comments (0)