Commit 63a1408: Support for static cache
1 parent 3974d97

22 files changed: +466, -182 lines

_unittests/ut_helpers/test_cache_helper.py
43 additions, 1 deletion
@@ -2,13 +2,14 @@
 import torch
 import transformers
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
-from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.helpers import string_type, max_diff
 from onnx_diagnostic.helpers.cache_helper import (
     flatten_unflatten_for_dynamic_shapes,
     make_dynamic_cache,
     make_encoder_decoder_cache,
     make_mamba_cache,
     make_sliding_window_cache,
+    make_static_cache,
 )
 from onnx_diagnostic.export import CoupleInputsDynamicShapes
 from onnx_diagnostic.torch_export_patches.patch_inputs import (
@@ -104,6 +105,7 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
                 ]
             ),
         )
+        self.assertEqual(0, max_diff(c2, c2)["abs"])
         self.assertIsInstance(c2, transformers.cache_utils.EncoderDecoderCache)
         flat, _spec = torch.utils._pytree.tree_flatten(c2)
         self.assertIsInstance(flat, list)
@@ -149,6 +151,7 @@ def test_make_mamba_cache(self):
             "ssm_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4])",
             text,
         )
+        self.assertEqual(0, max_diff(cache, cache)["abs"])
 
     def test_make_sliding_window_cache(self):
         cache = make_sliding_window_cache(
@@ -164,6 +167,45 @@ def test_make_sliding_window_cache(self):
             "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
             text,
         )
+        self.assertEqual(0, max_diff(cache, cache)["abs"])
+
+    def test_make_static_cache(self):
+        cache = make_static_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ]
+        )
+        text = self.string_type(cache, with_shape=True)
+        self.assertEqual(
+            "StaticCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
+            "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
+            text,
+        )
+        self.assertEqual(0, max_diff(cache, cache)["abs"])
+
+    def test_unflatten_flatten_static_cache(self):
+        with torch_export_patches(patch_transformers=True):
+            c2 = make_static_cache(
+                [
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                ]
+            )
+            self.assertEqual(0, max_diff(c2, c2)["abs"])
+            self.assertIsInstance(c2, transformers.cache_utils.StaticCache)
+            flat, _spec = torch.utils._pytree.tree_flatten(c2)
+            self.assertIsInstance(flat, list)
+            self.assertEqual(len(flat), 6)
+            unflat = flatten_unflatten_for_dynamic_shapes(c2)
+            self.assertIsInstance(unflat, list)
+            self.assertEqual(len(unflat), 2)
+            self.assertEqual(
+                "#2[#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7],#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]]",
+                self.string_type(unflat, with_shape=True),
+            )
 
 
 if __name__ == "__main__":
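
Note: the flatten/unflatten round-trip exercised above can be sketched outside the unittest harness as below. This is an illustration, not code from the commit; it assumes torch_export_patches(patch_transformers=True) registers StaticCache with torch.utils._pytree (as the with block in test_unflatten_flatten_static_cache suggests) and that max_diff accepts a pair of caches, which is how the new assertions use it.

import torch
from onnx_diagnostic.helpers import max_diff
from onnx_diagnostic.helpers.cache_helper import make_static_cache
from onnx_diagnostic.torch_export_patches import torch_export_patches

# three layers of (key, value) pairs, same shapes as in the tests above
cache = make_static_cache(
    [(torch.rand(4, 5, 6, 7), torch.rand(4, 5, 6, 7)) for _ in range(3)]
)

with torch_export_patches(patch_transformers=True):
    # one pytree leaf per key and per value tensor: 3 layers -> 6 leaves
    flat, spec = torch.utils._pytree.tree_flatten(cache)
    assert len(flat) == 6
    # unflatten rebuilds an equivalent StaticCache from the leaves
    restored = torch.utils._pytree.tree_unflatten(flat, spec)
    assert max_diff(cache, restored)["abs"] == 0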

_unittests/ut_tasks/try_tasks.py
29 additions, 0 deletions
@@ -98,6 +98,35 @@ def test_text2text_generation(self):
         )
         print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
 
+    @never_test()
+    def test_text2text_generation_static(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k text2t
+
+        import torch
+        from transformers import AutoTokenizer, AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")
+        model = AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM")
+
+        text = "def greet(user): print(f'hello <extra_id_0>!')"
+        input_ids = tokenizer(text, return_tensors="pt").input_ids
+        mask = (
+            torch.tensor([1 for i in range(input_ids.shape[1])])
+            .to(torch.int64)
+            .reshape((1, -1))
+        )
+
+        # simply generate a single sequence
+        print()
+        with steal_forward(model):
+            generated_ids = model.generate(
+                input_ids=input_ids,
+                attention_mask=mask,
+                max_new_tokens=117,
+                cache_implementation="static",
+            )
+        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+
     @never_test()
     def test_text_generation_phi4_mini(self):
         # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4_mini
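
Stripped of the steal_forward instrumentation (which captures the inputs of each forward call), the new test reduces to stock transformers generation with a preallocated cache. A minimal standalone equivalent using the same public API:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")
model = AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM")

inputs = tokenizer("def greet(user):", return_tensors="pt")
# cache_implementation="static" asks generate to preallocate a StaticCache
# of fixed length instead of growing a DynamicCache token by token
generated_ids = model.generate(
    **inputs,
    max_new_tokens=32,
    cache_implementation="static",
)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))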

_unittests/ut_torch_models/test_tiny_llms.py
37 additions, 8 deletions
@@ -1,32 +1,61 @@
 import copy
 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
+from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
 from onnx_diagnostic.torch_models.llms import get_tiny_llm
 from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 
 
 class TestTinyLlm(ExtTestCase):
-    def test_get_tiny_llm(self):
+    def test_tiny_llm_run_dynamic(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
         self.assertIn("DynamicCache", string_type(inputs))
         model(**inputs)
 
     @ignore_warnings(UserWarning)
-    @requires_transformers("4.53")
-    def test_export_tiny_llm_1(self):
+    def test_tiny_llm_export_dynamic(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
         expected = model(**copy.deepcopy(inputs))
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
-        ep = torch.export.export(
-            model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=data["dynamic_shapes"]
+        with torch_export_patches(patch_transformers=True):
+            ep = torch.export.export(
+                model,
+                (),
+                kwargs=copy.deepcopy(inputs),
+                dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
+            )
+            got = ep.module()(**inputs)
+            self.assertEqualArrayAny(expected, got)
+
+    def test_tiny_llm_run_static(self):
+        data = get_tiny_llm(use_static_cache=True)
+        model, inputs = data["model"], data["inputs"]
+        self.assertIn("StaticCache", string_type(inputs))
+        model(**inputs)
+
+    @ignore_warnings(UserWarning)
+    def test_tiny_llm_export_static(self):
+        data = get_tiny_llm(use_static_cache=True)
+        model, inputs = data["model"], data["inputs"]
+        expected = model(**copy.deepcopy(inputs))
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "cache_position"}, set(inputs)
         )
-        got = ep.module()(**inputs)
-        self.assertEqualArrayAny(expected, got)
+        with torch_export_patches(patch_transformers=True, stop_if_static=1):
+            ep = torch.export.export(
+                model,
+                (),
+                kwargs=copy.deepcopy(inputs),
+                dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
+            )
+            got = ep.module()(**inputs)
+            self.assertEqualArrayAny(expected, got)
 
 
 if __name__ == "__main__":
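
Pulled out of the test harness, the static-cache export path amounts to the sketch below. Same imports as the diff; stop_if_static=1 presumably makes the patch context raise as soon as a dimension declared dynamic is specialized to a constant, which is the failure mode a fixed-size cache is most likely to provoke.

import copy
import torch
from onnx_diagnostic.torch_models.llms import get_tiny_llm
from onnx_diagnostic.torch_export_patches import torch_export_patches
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str

# model, inputs and dynamic shapes prepared with a StaticCache
data = get_tiny_llm(use_static_cache=True)
model, inputs = data["model"], data["inputs"]

with torch_export_patches(patch_transformers=True, stop_if_static=1):
    ep = torch.export.export(
        model,
        (),
        kwargs=copy.deepcopy(inputs),
        dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
    )

# the exported program can be run like the original module
print(ep.module()(**inputs))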
