Commit 1f0ca98

Support for static cache (#149)
* Support for static cache
* fix cache
* fix ut
* fix static
* fix issues
* fix
* fix
Parent: 3974d97

27 files changed: +499 −189 lines

_doc/conf.py

Lines changed: 1 addition & 0 deletions

@@ -136,6 +136,7 @@ def linkcode_resolve(domain, info):
     ("py:class", "transformers.cache_utils.EncoderDecoderCache"),
     ("py:class", "transformers.cache_utils.MambaCache"),
     ("py:class", "transformers.cache_utils.SlidingWindowCache"),
+    ("py:class", "transformers.cache_utils.StaticCache"),
     ("py:class", "transformers.configuration_utils.PretrainedConfig"),
     ("py:class", "transformers.modeling_outputs.BaseModelOutput"),
     ("py:class", "transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding"),

_doc/examples/plot_export_tiny_llm.py

Lines changed: 11 additions & 2 deletions

@@ -33,6 +33,7 @@
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.torch_helper import steal_forward
 from onnx_diagnostic.torch_models.llms import get_tiny_llm
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


 MODEL_NAME = "arnir0/Tiny-LLM"
@@ -131,7 +132,11 @@ def _forward_(*args, _f=None, **kwargs):

 try:
     ep = torch.export.export(
-        untrained_model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
+        untrained_model,
+        (),
+        kwargs=cloned_inputs,
+        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
+        strict=False,
     )
     print("It worked:")
     print(ep)
@@ -166,7 +171,11 @@ def _forward_(*args, _f=None, **kwargs):

 try:
     ep = torch.export.export(
-        model, (), kwargs=cloned_inputs, dynamic_shapes=dynamic_shapes, strict=False
+        model,
+        (),
+        kwargs=cloned_inputs,
+        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
+        strict=False,
     )
     print("It worked:")
     print(ep)
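
Note: both hunks above replace dynamic_shapes with use_dyn_not_str(dynamic_shapes) when calling torch.export.export. A plausible reading, sketched below as a hypothetical re-implementation, is that the helper rewrites string axis names such as "batch" into torch.export.Dim.DYNAMIC, which torch.export accepts without further naming constraints; the real helper in onnx_diagnostic.torch_export_patches.patch_inputs may do more than this.

    # Hypothetical re-implementation of use_dyn_not_str, for illustration only.
    import torch

    def use_dyn_not_str_sketch(ds):
        # Recursively replace string axis names with torch.export.Dim.DYNAMIC.
        if isinstance(ds, str):
            return torch.export.Dim.DYNAMIC
        if isinstance(ds, dict):
            return {k: use_dyn_not_str_sketch(v) for k, v in ds.items()}
        if isinstance(ds, (list, tuple)):
            return type(ds)(use_dyn_not_str_sketch(v) for v in ds)
        return ds

    # Example: {"input_ids": {0: "batch", 1: "seq"}} becomes
    # {"input_ids": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC}}.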

_doc/examples/plot_export_tiny_llm_patched.py

Lines changed: 3 additions & 2 deletions

@@ -70,6 +70,7 @@
 from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 from onnx_diagnostic.torch_models.llms import get_tiny_llm


@@ -110,7 +111,7 @@
         untrained_model,
         (),
         kwargs=modificator(cloned_inputs),
-        dynamic_shapes=dynamic_shapes,
+        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
         strict=False,  # mandatory for torch==2.6
     )
     print("It worked:")
@@ -131,7 +132,7 @@
         model,
         (),
         kwargs=modificator(cloned_inputs),
-        dynamic_shapes=dynamic_shapes,
+        dynamic_shapes=use_dyn_not_str(dynamic_shapes),
         strict=False,  # mandatory for torch==2.6
     )
     print("It worked:")

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 43 additions & 1 deletion

@@ -2,13 +2,14 @@
 import torch
 import transformers
 from onnx_diagnostic.ext_test_case import ExtTestCase, requires_transformers
-from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.helpers import string_type, max_diff
 from onnx_diagnostic.helpers.cache_helper import (
     flatten_unflatten_for_dynamic_shapes,
     make_dynamic_cache,
     make_encoder_decoder_cache,
     make_mamba_cache,
     make_sliding_window_cache,
+    make_static_cache,
 )
 from onnx_diagnostic.export import CoupleInputsDynamicShapes
 from onnx_diagnostic.torch_export_patches.patch_inputs import (
@@ -104,6 +105,7 @@ def test_unflatten_flatten_encoder_decoder_cache(self):
                 ]
             ),
         )
+        self.assertEqual(0, max_diff(c2, c2)["abs"])
         self.assertIsInstance(c2, transformers.cache_utils.EncoderDecoderCache)
         flat, _spec = torch.utils._pytree.tree_flatten(c2)
         self.assertIsInstance(flat, list)
@@ -149,6 +151,7 @@ def test_make_mamba_cache(self):
             "ssm_states=#3[T1s4x4x4,T1s4x4x4,T1s4x4x4])",
             text,
         )
+        self.assertEqual(0, max_diff(cache, cache)["abs"])

     def test_make_sliding_window_cache(self):
         cache = make_sliding_window_cache(
@@ -164,6 +167,45 @@ def test_make_sliding_window_cache(self):
             "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
             text,
         )
+        self.assertEqual(0, max_diff(cache, cache)["abs"])
+
+    def test_make_static_cache(self):
+        cache = make_static_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ]
+        )
+        text = self.string_type(cache, with_shape=True)
+        self.assertEqual(
+            "StaticCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
+            "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
+            text,
+        )
+        self.assertEqual(0, max_diff(cache, cache)["abs"])
+
+    def test_unflatten_flatten_static_cache(self):
+        with torch_export_patches(patch_transformers=True):
+            c2 = make_static_cache(
+                [
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                ]
+            )
+            self.assertEqual(0, max_diff(c2, c2)["abs"])
+            self.assertIsInstance(c2, transformers.cache_utils.StaticCache)
+            flat, _spec = torch.utils._pytree.tree_flatten(c2)
+            self.assertIsInstance(flat, list)
+            self.assertEqual(len(flat), 6)
+            unflat = flatten_unflatten_for_dynamic_shapes(c2)
+            self.assertIsInstance(unflat, list)
+            self.assertEqual(len(unflat), 2)
+            self.assertEqual(
+                "#2[#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7],#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]]",
+                self.string_type(unflat, with_shape=True),
+            )


 if __name__ == "__main__":

_unittests/ut_tasks/try_tasks.py

Lines changed: 29 additions & 0 deletions

@@ -98,6 +98,35 @@ def test_text2text_generation(self):
         )
         print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))

+    @never_test()
+    def test_text2text_generation_static(self):
+        # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k text2t
+
+        import torch
+        from transformers import AutoTokenizer, AutoModelForCausalLM
+
+        tokenizer = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")
+        model = AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM")
+
+        text = "def greet(user): print(f'hello <extra_id_0>!')"
+        input_ids = tokenizer(text, return_tensors="pt").input_ids
+        mask = (
+            torch.tensor([1 for i in range(input_ids.shape[1])])
+            .to(torch.int64)
+            .reshape((1, -1))
+        )
+
+        # simply generate a single sequence
+        print()
+        with steal_forward(model):
+            generated_ids = model.generate(
+                input_ids=input_ids,
+                attention_mask=mask,
+                max_new_tokens=117,
+                cache_implementation="static",
+            )
+        print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
+
     @never_test()
     def test_text_generation_phi4_mini(self):
         # clear&&NEVERTEST=1 python _unittests/ut_tasks/try_tasks.py -k phi4_mini
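
Note: the new manual test exercises generation with a preallocated KV cache by passing cache_implementation="static" to generate(), which asks transformers to use a StaticCache instead of the default DynamicCache. A standalone trimmed variant, dropping the repo's steal_forward tracing helper (prompt and token count are illustrative):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("arnir0/Tiny-LLM")
    model = AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM")

    input_ids = tokenizer("def greet(user):", return_tensors="pt").input_ids
    mask = torch.ones_like(input_ids, dtype=torch.int64)  # attend to all tokens
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=mask,
        max_new_tokens=32,
        cache_implementation="static",  # preallocated, fixed-size KV cache
    )
    print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))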

_unittests/ut_torch_models/test_tiny_llms.py

Lines changed: 46 additions & 8 deletions

@@ -1,32 +1,70 @@
 import copy
 import unittest
 import torch
-from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, requires_transformers
+from onnx_diagnostic.ext_test_case import (
+    ExtTestCase,
+    ignore_warnings,
+    requires_transformers,
+    requires_torch,
+)
 from onnx_diagnostic.torch_models.llms import get_tiny_llm
 from onnx_diagnostic.helpers import string_type
+from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str


 class TestTinyLlm(ExtTestCase):
-    def test_get_tiny_llm(self):
+    def test_tiny_llm_run_dynamic(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
         self.assertIn("DynamicCache", string_type(inputs))
         model(**inputs)

     @ignore_warnings(UserWarning)
-    @requires_transformers("4.53")
-    def test_export_tiny_llm_1(self):
+    @requires_torch("2.8")
+    def test_tiny_llm_export_dynamic(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
         expected = model(**copy.deepcopy(inputs))
         self.assertEqual(
             {"attention_mask", "past_key_values", "input_ids", "position_ids"}, set(inputs)
         )
-        ep = torch.export.export(
-            model, (), kwargs=copy.deepcopy(inputs), dynamic_shapes=data["dynamic_shapes"]
+        with torch_export_patches(patch_transformers=True):
+            ep = torch.export.export(
+                model,
+                (),
+                kwargs=copy.deepcopy(inputs),
+                dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
+            )
+            got = ep.module()(**inputs)
+            self.assertEqualArrayAny(expected, got)
+
+    @requires_transformers("4.52")
+    def test_tiny_llm_run_static(self):
+        data = get_tiny_llm(use_static_cache=True)
+        model, inputs = data["model"], data["inputs"]
+        self.assertIn("StaticCache", string_type(inputs))
+        model(**inputs)
+
+    @ignore_warnings(UserWarning)
+    @requires_transformers("4.52")
+    @requires_torch("2.8")
+    def test_tiny_llm_export_static(self):
+        data = get_tiny_llm(use_static_cache=True)
+        model, inputs = data["model"], data["inputs"]
+        expected = model(**copy.deepcopy(inputs))
+        self.assertEqual(
+            {"attention_mask", "past_key_values", "input_ids", "cache_position"}, set(inputs)
         )
-        got = ep.module()(**inputs)
-        self.assertEqualArrayAny(expected, got)
+        with torch_export_patches(patch_transformers=True, stop_if_static=1):
+            ep = torch.export.export(
+                model,
+                (),
+                kwargs=copy.deepcopy(inputs),
+                dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
+            )
+            got = ep.module()(**inputs)
+            self.assertEqualArrayAny(expected, got)


 if __name__ == "__main__":

_unittests/ut_torch_models/test_tiny_llms_bypassed.py

Lines changed: 8 additions & 2 deletions

@@ -2,19 +2,21 @@
 import unittest
 import torch
 from transformers.cache_utils import DynamicCache
-from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
+from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout
 from onnx_diagnostic.torch_models.llms import get_tiny_llm
 from onnx_diagnostic.torch_models.llms import get_phi2
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.torch_export_patches import torch_export_patches
+from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
 from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
     patched_DynamicCache,
 )


 class TestTinyLlmBypassed(ExtTestCase):
     @ignore_warnings(UserWarning)
+    @hide_stdout()
     def test_export_tiny_llm_2_bypassed(self):
         data = get_tiny_llm()
         model, inputs = data["model"], data["inputs"]
@@ -50,7 +52,11 @@ def debug():
         debug()

         ep = torch.export.export(
-            model, (), kwargs=inputs, dynamic_shapes=data["dynamic_shapes"], strict=False
+            model,
+            (),
+            kwargs=inputs,
+            dynamic_shapes=use_dyn_not_str(data["dynamic_shapes"]),
+            strict=False,
         )
         got = ep.module()(**inputs)
         self.assertEqualArrayAny(expected, got)
