Skip to content

Commit 4afcc1b

Browse files
committed
fix seq_length
1 parent d8eee07 commit 4afcc1b

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import unittest
22
import torch
3+
from transformers.cache_utils import DynamicCache
34
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
45
from onnx_diagnostic.torch_models.llms import get_tiny_llm
56
from onnx_diagnostic.helpers import string_type
67
from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
8+
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
9+
patched_DynamicCache,
10+
)
711

812

913
class TestTinyLlm(ExtTestCase):
@@ -35,8 +39,12 @@ def test_export_tiny_llm_2_bypassed(self):
3539
)
3640

3741
with bypass_export_some_errors(
38-
patch_torch=False, patch_transformers=True, catch_constraints=False
42+
patch_torch=False, patch_transformers=True, catch_constraints=False, verbose=10
3943
) as modificator:
44+
45+
for k in patched_DynamicCache._PATCHES_:
46+
self.assertEqual(getattr(patched_DynamicCache, k), getattr(DynamicCache, k))
47+
4048
inputs = modificator(inputs)
4149

4250
def debug():

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ class patched_DynamicCache:
102102
`transformers/#36652 <https://github.com/huggingface/transformers/pull/36652>`_.
103103
"""
104104

105-
_PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits"]
105+
_PATCHES_ = ["reorder_cache", "update", "crop", "from_batch_splits", "get_seq_length"]
106106
_PATCHED_CLASS_ = transformers.cache_utils.DynamicCache
107107

108108
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:

0 commit comments

Comments
 (0)