Commit d5e5fbc
remove _seen_tokens from the patched code (#185)

* remove _seen_tokens from the patched code
* update feature extraction
* fix ut
* fix issues
* add one more patch
* bart
* fix missing attribute
* check
* change supported version
* tris
* 53
* patches
* 53
* fix patch

1 parent ebecb67 · commit d5e5fbc

File tree: 12 files changed, +254 −22 lines

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions

@@ -4,6 +4,8 @@ Change Logs
 0.7.5
 +++++
 
+* :pr:`185`: remove the use of _seen_tokens in DynamicCache (removed in transformers>4.53),
+  updates dummy inputs for feature-extraction
 * :pr:`184`: implements side-by-side
 
 0.7.4
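
For readers hitting this change downstream: `_seen_tokens` was a private attribute of `DynamicCache`, and code that read it stops working on transformers>4.53. A minimal sketch of the migration, assuming the public `get_seq_length()` accessor (available in both old and new versions):

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
k = v = torch.randn(1, 4, 3, 8)  # (batch, heads, seq_len, head_dim)
cache.update(k, v, layer_idx=0)
# before: cache._seen_tokens (removed in transformers>4.53)
# after: the public accessor, which survives the removal
print(cache.get_seq_length())  # 3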

_unittests/ut_tasks/test_tasks.py

Lines changed: 3 additions & 2 deletions

@@ -6,7 +6,7 @@
     has_transformers,
     requires_transformers,
 )
-from onnx_diagnostic.helpers.torch_helper import to_any
+from onnx_diagnostic.helpers.torch_helper import to_any, torch_deepcopy
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
@@ -207,13 +207,14 @@ def test_fill_mask(self):
         )
 
     @hide_stdout()
+    @requires_transformers("4.53.99")
     def test_feature_extraction_bart_base(self):
         mid = "facebook/bart-base"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "feature-extraction")
         self.assertIn((data["size"], data["n_weights"]), [(557681664, 139420416)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
+        model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10):
             torch.export.export(
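
Why the first forward call is now wrapped in `torch_deepcopy` (the same change appears in the next file): a forward pass that receives a `DynamicCache` in `past_key_values` appends to it in place, so later calls and `torch.export.export` would otherwise start from an already-grown cache. A small sketch of the in-place growth, assuming standard `DynamicCache` semantics; `copy.deepcopy` stands in here for the repo's cache-aware `torch_deepcopy`:

import copy
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
cache.update(torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8), layer_idx=0)
frozen = copy.deepcopy(cache)
cache.update(torch.randn(1, 4, 2, 8), torch.randn(1, 4, 2, 8), layer_idx=0)
print(cache.get_seq_length())   # 5 -- the original grew in place
print(frozen.get_seq_length())  # 3 -- the copy is unaffected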

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 3 additions & 2 deletions

@@ -6,22 +6,23 @@
     requires_transformers,
     requires_torch,
 )
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 
 
 class TestTasksImageTextToText(ExtTestCase):
     @hide_stdout()
-    @requires_transformers("4.52")
+    @requires_transformers("4.53")
     @requires_torch("2.7.99")
     def test_image_text_to_text(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "image-text-to-text")
         self.assertIn((data["size"], data["n_weights"]), [(12742888, 3185722)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
-        model(**inputs)
+        model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
         with torch_export_patches(patch_transformers=True, verbose=10):
             torch.export.export(

_unittests/ut_tasks/try_tasks.py

Lines changed: 47 additions & 0 deletions

@@ -1,6 +1,8 @@
 import unittest
+import torch
 from onnx_diagnostic.ext_test_case import ExtTestCase, never_test
 from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 from onnx_diagnostic.helpers.torch_helper import steal_forward
 
 
@@ -378,6 +380,51 @@ def test_feature_extraction(self):
         model = BartModel.from_pretrained("facebook/bart-base")
         text = "Replace me by any text you'd like."
         encoded_input = tokenizer(text, return_tensors="pt")
+        sequence_length, sequence_length2 = 30, 4
+        sequence_length = 3
+        batch_size, encoder_attention_heads, encoder_ffn_dim = 1, 12, 64
+        batch_size, decoder_attention_heads, decoder_ffn_dim = 1, 12, 64
+        num_hidden_layers = 6
+        encoded_input["past_key_values"] = make_encoder_decoder_cache(
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
         print()
         print("-- inputs", string_type(encoded_input, with_shape=True, with_min_max=True))
         output = model(**encoded_input)
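
The added block is easier to read once condensed; the structure it builds (and that the task helper below now builds too) is an encoder-decoder cache holding one (key, value) pair per layer in each sub-cache. A compact sketch, assuming the `onnx_diagnostic` helpers keep the signatures used in this diff:

import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache

def dummy_layers(n_layers, batch, heads, seq, dim):
    # one (key, value) pair of shape (batch, heads, seq, dim) per layer
    return [
        (torch.randn(batch, heads, seq, dim), torch.randn(batch, heads, seq, dim))
        for _ in range(n_layers)
    ]

past = make_encoder_decoder_cache(
    make_dynamic_cache(dummy_layers(6, 1, 12, 3, 64)),  # sequence_length = 3
    make_dynamic_cache(dummy_layers(6, 1, 12, 4, 64)),  # sequence_length2 = 4
)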

_unittests/ut_torch_models/test_validate_models.py

Lines changed: 2 additions & 2 deletions

@@ -12,7 +12,7 @@
 
 
 class TestValidateModel(ExtTestCase):
-    @requires_transformers("4.52")
+    @requires_transformers("4.53")
     @requires_torch("2.7.99")
     @requires_experimental()
     @hide_stdout()
@@ -33,7 +33,7 @@ def test_validate_microsoft_phi4_reasoning(self):
         self.assertLess(summary["disc_onnx_ort_run_abs"], 1e-5)
         self.assertIn("onnx_filename", data)
 
-    @requires_transformers("4.52")
+    @requires_transformers("4.53")
     @requires_torch("2.7.99")
     @requires_experimental()
     @hide_stdout()
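
The version bumps here and in the neighboring test files rely on the repo's `requires_transformers` gate to skip tests on older installs. Roughly what such a gate does, as an illustrative sketch rather than the repo's actual implementation:

import unittest
from packaging.version import Version
import transformers

def requires_transformers(min_version: str):
    # skip the decorated test when the installed transformers is too old
    return unittest.skipIf(
        Version(transformers.__version__) < Version(min_version),
        f"requires transformers>={min_version}",
    )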

_unittests/ut_torch_models/test_validate_whole_models.py

Lines changed: 2 additions & 2 deletions

@@ -258,7 +258,7 @@ def test_validate_model_vit_model(self):
     @requires_torch("2.7")
     @hide_stdout()
     @ignore_warnings(FutureWarning)
-    @requires_transformers("4.51")
+    @requires_transformers("4.53")
     def test_validate_phi35_mini_instruct(self):
         mid = "microsoft/Phi-3.5-mini-instruct"
         summary, data = validate_model(
@@ -281,7 +281,7 @@ def test_validate_phi35_mini_instruct(self):
     @requires_torch("2.7")
     @hide_stdout()
     @ignore_warnings(FutureWarning)
-    @requires_transformers("4.51")
+    @requires_transformers("4.53")
     def test_validate_phi35_4k_mini_instruct(self):
         mid = "microsoft/Phi-3-mini-4k-instruct"
         summary, data = validate_model(

onnx_diagnostic/tasks/feature_extraction.py

Lines changed: 86 additions & 5 deletions

@@ -1,17 +1,15 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 
 __TASK__ = "feature-extraction"
 
 
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
-    kwargs = dict(
-        num_hidden_layers=min(config.num_hidden_layers, 2),
-        num_attention_heads=min(config.num_attention_heads, 4),
-    )
+    check_hasattr(config, "num_hidden_layers")
+    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, 2))
     update_config(config, kwargs)
     return kwargs
 
@@ -22,6 +20,12 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
+    sequence_length2: int = 3,
+    decoder_attention_heads: Optional[int] = None,
+    encoder_attention_heads: Optional[int] = None,
+    encoder_ffn_dim: Optional[int] = None,
+    decoder_ffn_dim: Optional[int] = None,
+    num_hidden_layers: Optional[int] = None,
     add_second_input: int = 1,
     **kwargs,  # unused
 ):
@@ -50,6 +54,66 @@
         ),
         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
     )
+    if (
+        encoder_attention_heads
+        and decoder_attention_heads
+        and encoder_ffn_dim
+        and decoder_ffn_dim
+        and num_hidden_layers
+    ):
+        inputs["past_key_values"] = make_encoder_decoder_cache(
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            encoder_attention_heads,
+                            sequence_length,
+                            encoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            decoder_attention_heads,
+                            sequence_length2,
+                            decoder_ffn_dim,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
+        cache_length = "cache_length_key"
+        cache_length2 = "cache_length_val"
+        shapes["past_key_values"] = [  # type: ignore[assignment]
+            [
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            ],
+            [
+                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+                [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+            ],
+        ]
+
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         assert (
@@ -61,6 +125,12 @@ def get_inputs(
         batch_size=batch_size + 1,
         sequence_length=sequence_length + add_second_input,
         dummy_max_token_id=dummy_max_token_id,
+        sequence_length2=sequence_length2,
+        decoder_attention_heads=decoder_attention_heads,
+        encoder_attention_heads=encoder_attention_heads,
+        encoder_ffn_dim=encoder_ffn_dim,
+        decoder_ffn_dim=decoder_ffn_dim,
+        num_hidden_layers=num_hidden_layers,
         add_second_input=0,
         **kwargs,
     )["inputs"]
@@ -80,4 +150,15 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         sequence_length=30,
         dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
+    for att in [
+        "decoder_attention_heads",
+        "encoder_attention_heads",
+        "encoder_ffn_dim",
+        "decoder_ffn_dim",
+        "num_hidden_layers",
+    ]:
+        if hasattr(config, att):
+            kwargs[att] = getattr(config, att)
+    kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
+    print(kwargs)
    return kwargs, get_inputs
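
Net effect of this file: whenever the config exposes the encoder/decoder attributes collected in `random_input_kwargs`, the dummy inputs gain a `past_key_values` entry and the dynamic shapes map dim 2 of every cached tensor to the string dims above. A hedged usage sketch (the model id and keys match the tests earlier in this commit; exact shapes depend on the installed versions):

from onnx_diagnostic.torch_models.hghub.model_inputs import get_untrained_model_with_inputs

data = get_untrained_model_with_inputs("facebook/bart-base")
inputs, ds = data["inputs"], data["dynamic_shapes"]
print("past_key_values" in inputs)  # True for encoder-decoder configs like BART
print(ds["past_key_values"])        # nested lists of {0: batch, 2: "cache_length_*"}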

onnx_diagnostic/tasks/text2text_generation.py

Lines changed: 2 additions & 2 deletions

@@ -69,8 +69,8 @@ def get_inputs(
     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
-    cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)
-    cache_length2 = "cache_length_val"  # torch.export.Dim("cache_length2", min=1, max=4096)
+    cache_length = "cache_length_key"
+    cache_length2 = "cache_length_val"
 
     shapes = {
         "input_ids": {0: batch, 1: seq_length},

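The deleted comments date from when these dims were `torch.export.Dim` objects; plain strings suffice because recent `torch.export` (roughly 2.6 onward, consistent with the `requires_torch` gates above) accepts named strings in `dynamic_shapes`. A small self-contained sketch:

import torch

class Double(torch.nn.Module):
    def forward(self, x):
        return x * 2

ep = torch.export.export(
    Double(),
    (torch.randn(2, 5),),
    dynamic_shapes=({0: "batch", 1: "seq_length"},),
)
print(ep)  # both dims appear as symbolic sizes
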
onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 7 additions & 1 deletion

@@ -16,6 +16,8 @@ def get_function(name: str) -> Tuple[type, Callable]:
     module_name = ".".join(spl[:-1])
     fname = spl[-1]
     mod = importlib.import_module(module_name)
+    if not hasattr(mod, fname):
+        return None, None
     return mod, getattr(mod, fname)
 
 
@@ -33,12 +35,16 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
         doc = v.__doc__.lstrip()
         if doc.startswith("manual patch"):
             continue
-        reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
+        reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
         fall = reg.findall(doc)
         assert (
             len(fall) == 1
         ), f"Unable to find patching information for {v} in \n{doc}"
         fmod, f = get_function(fall[0])
+        if fmod is None and f is None:
+            # The function does not exist in this version of transformers.
+            # No patch is needed.
+            continue
         to_patch.append({"module": fmod, "function": f, "patch": v})
 
     name = mod.__name__
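
Two behaviors combine here: patch definitions name their target in the docstring as `[patch:module.qualified_name]`, and targets absent from the installed transformers are now skipped instead of raising. A sketch of the convention; the target below is deliberately fictitious to show the skip path:

import importlib
import re

def patched_example():
    """[patch:transformers.does_not_exist_here]"""

reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
target = reg.findall(patched_example.__doc__)[0]
mod = importlib.import_module(".".join(target.split(".")[:-1]))
print(hasattr(mod, target.split(".")[-1]))  # False -> get_patches continues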
