
Commit af71ef2

Improves handling of StaticCache (#166)
* Improves handling of StaticCache
* mypy
* static
* try
* fix issues
* fix issues
* disable
* fix
* fix
1 parent 28fe237 commit af71ef2

File tree

11 files changed: +172 −31 lines

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.7.2
 +++++
 
+* :pr:`166`: improves handling of StaticCache
 * :pr:`165`: support for task text-to-image
 * :pr:`162`: improves graphs rendering for historical data
 
_unittests/ut_helpers/test_cache_helper.py

Lines changed: 6 additions & 4 deletions
@@ -175,12 +175,13 @@ def test_make_static_cache(self):
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-            ]
+            ],
+            max_cache_len=15,
         )
         text = self.string_type(cache, with_shape=True)
         self.assertEqual(
-            "StaticCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
-            "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
+            "StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
+            "value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
             text,
         )
         self.assertEqual(0, max_diff(cache, cache)["abs"])
@@ -192,7 +193,8 @@ def test_unflatten_flatten_static_cache(self):
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
                 (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-            ]
+            ],
+            max_cache_len=6,
         )
         self.assertEqual(0, max_diff(c2, c2)["abs"])
         self.assertIsInstance(c2, transformers.cache_utils.StaticCache)
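Usage note (a minimal sketch, assuming onnx_diagnostic and transformers are installed): the new max_cache_len argument controls the allocated sequence length of the cache buffers, which is why the expected type strings above change from T1s4x5x6x7 to T1s4x5x15x7.

import torch
from onnx_diagnostic.helpers.cache_helper import make_static_cache

# Three layers of (key, value) pairs, each with sequence length 6.
pairs = [(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))) for _ in range(3)]
cache = make_static_cache(pairs, max_cache_len=15)
# Buffers are allocated at max_cache_len, not at the input length.
print(cache.key_cache[0].shape)  # torch.Size([4, 5, 15, 7])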

_unittests/ut_torch_export_patches/test_patch_serialization.py

Lines changed: 65 additions & 1 deletion
@@ -10,6 +10,7 @@
 from onnx_diagnostic.helpers.cache_helper import (
     make_encoder_decoder_cache,
     make_dynamic_cache,
+    make_static_cache,
     make_sliding_window_cache,
     flatten_unflatten_for_dynamic_shapes,
 )
@@ -180,7 +181,7 @@ def test_base_sliding_window_cache_unflatten_flatten(self):
         self.assertEqualAny([cache], cache2)
 
     @ignore_warnings(UserWarning)
-    @requires_torch("2.8")
+    @requires_torch("2.7.99")
     def test_sliding_window_cache_export(self):
         class Model(torch.nn.Module):
             def forward(self, cache):
@@ -274,6 +275,69 @@ def forward(self, cache):
         with torch_export_patches():
             torch.export.export(model, (bo,), dynamic_shapes=(ds,))
 
+    @ignore_warnings(UserWarning)
+    @requires_torch("2.7.99")
+    def test_static_cache(self):
+        bo = make_static_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ],
+            max_cache_len=15,
+        )
+        self.assertEqual(bo.__class__.__name__, "StaticCache")
+        bo2 = torch_deepcopy([bo])
+        self.assertIsInstance(bo2, list)
+        self.assertEqual(
+            "StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
+            "value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
+            self.string_type(bo, with_shape=True),
+        )
+
+        with torch_export_patches():
+            # internal function
+            bo2 = torch_deepcopy([bo])
+            self.assertIsInstance(bo2, list)
+            self.assertEqual(bo2[0].__class__.__name__, "StaticCache")
+            self.assertEqualAny([bo], bo2)
+            self.assertEqual(
+                "StaticCache(key_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7], "
+                "value_cache=#3[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7])",
+                self.string_type(bo, with_shape=True),
+            )
+
+            # serialization
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            self.assertEqual(
+                "#6[T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7,T1s4x5x15x7]",
+                self.string_type(flat, with_shape=True),
+            )
+            bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(bo, with_shape=True, with_min_max=True),
+                self.string_type(bo2, with_shape=True, with_min_max=True),
+            )
+
+            # flatten_unflatten
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            unflat = flatten_unflatten_for_dynamic_shapes(bo, use_dict=True)
+            self.assertIsInstance(unflat, dict)
+            self.assertEqual(list(unflat), ["key_cache", "value_cache"])
+
+        # export
+        class Model(torch.nn.Module):
+            def forward(self, cache):
+                return cache.key_cache[0]
+
+        model = Model()
+        model(bo)
+        DYN = torch.export.Dim.DYNAMIC
+        ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]]
+
+        with torch_export_patches(patch_transformers=True, stop_if_static=1):
+            torch.export.export(model, (bo,), dynamic_shapes=(ds,))
 
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
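A hedged sketch of the round trip this new test checks (import paths assumed to be the library's usual ones): under torch_export_patches, which registers StaticCache with torch.utils._pytree, the cache flattens into six tensors, three keys followed by three values, and unflattens back to an equivalent cache.

import torch
from onnx_diagnostic.helpers.cache_helper import make_static_cache
from onnx_diagnostic.torch_export_patches import torch_export_patches

cache = make_static_cache(
    [(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))) for _ in range(3)],
    max_cache_len=15,
)
with torch_export_patches():
    flat, spec = torch.utils._pytree.tree_flatten(cache)
    assert len(flat) == 6  # 3 key tensors + 3 value tensors, each 4x5x15x7
    restored = torch.utils._pytree.tree_unflatten(flat, spec)
    assert restored.max_cache_len == cache.max_cache_len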

_unittests/ut_torch_export_patches/test_patch_torch.py

Lines changed: 12 additions & 0 deletions
@@ -204,6 +204,18 @@ def forward(self, batch_arange, head_arange, cache_position, kv_arange):
         ep = torch.export.export(Model(), inputs, dynamic_shapes=ds)
         self.assertEqualArray(causal_mask, ep.module()(*inputs))
 
+    @requires_torch("2.7")
+    def test_export_unsqueeze(self):
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                return x.unsqueeze(0).unsqueeze(2).unsqueeze(3)
+
+        x = torch.tensor([7.0, 8.0])
+        Model()(x)
+        DYN = torch.export.Dim.DYNAMIC
+        ep = torch.export.export(Model(), (x,), dynamic_shapes=({0: DYN},))
+        self.assertEqualArray(Model()(x), ep.module()(x))
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
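For reference, the shape trace of the chained unsqueeze in this test; this is the same pattern the commented-out unsqueeze variant in patch_transformers.py (further below) would rely on.

import torch

x = torch.tensor([7.0, 8.0])  # shape (2,), dimension 0 marked DYNAMIC at export
y = x.unsqueeze(0).unsqueeze(2).unsqueeze(3)
# (2,) -> (1, 2) -> (1, 2, 1) -> (1, 2, 1, 1): the dynamic axis moves to position 1
print(y.shape)  # torch.Size([1, 2, 1, 1])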

onnx_diagnostic/ext_test_case.py

Lines changed: 10 additions & 0 deletions
@@ -976,6 +976,16 @@ def assertEqualAny(
                 atol=atol,
                 rtol=rtol,
             )
+        elif expected.__class__.__name__ == "StaticCache":
+            self.assertEqual(type(expected), type(value), msg=msg)
+            self.assertEqual(expected.max_cache_len, value.max_cache_len)
+            atts = ["key_cache", "value_cache"]
+            self.assertEqualAny(
+                {k: expected.__dict__.get(k, None) for k in atts},
+                {k: value.__dict__.get(k, None) for k in atts},
+                atol=atol,
+                rtol=rtol,
+            )
         elif expected.__class__.__name__ == "EncoderDecoderCache":
             self.assertEqual(type(expected), type(value), msg=msg)
             atts = ["self_attention_cache", "cross_attention_cache"]

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 21 additions & 11 deletions
@@ -154,10 +154,12 @@ def make_dynamic_cache(
 
 def make_static_cache(
     key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    max_cache_len: Optional[int] = None,
 ) -> transformers.cache_utils.StaticCache:
     """
     Creates an instance of :class:`transformers.cache_utils.StaticCache`.
     :param key_value_pairs: list of pairs of (key, values)
+    :param max_cache_len: maximum cache length, or a value inferred from the tensors
     :return: :class:`transformers.cache_utils.StaticCache`
 
     Example:
@@ -190,24 +192,32 @@ def __init__(self):
             self.num_attention_heads = key_value_pairs[0][0].shape[1]
             self.num_hidden_layers = len(key_value_pairs)
 
+    assert max_cache_len is not None, (
+        f"max_cache_len={max_cache_len} cannot be set up "
+        f"automatically yet from shape {key_value_pairs[0][0].shape}"
+    )
+    torch._check(
+        max_cache_len >= key_value_pairs[0][0].shape[2],
+        (
+            f"max_cache_len={max_cache_len} cannot be smaller than "
+            f"shape[2]={key_value_pairs[0][0].shape[2]} in shape "
+            f"{key_value_pairs[0][0].shape}"
+        ),
+    )
     cache = transformers.cache_utils.StaticCache(
         _config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
         dtype=key_value_pairs[0][0].dtype,
-        max_cache_len=key_value_pairs[0][0].shape[2],
+        max_cache_len=max_cache_len,
     )
     for i in range(len(key_value_pairs)):
-        assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
-            f"Shape mismatch, expected {cache.key_cache[i].shape}, "
-            f"got {key_value_pairs[i][0].shape}"
-        )
-        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
-        assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
-            f"Shape mismatch, expected {cache.value_cache[i].shape}, "
-            f"got {key_value_pairs[i][1].shape}"
-        )
-        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+        assert (
+            key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
+        ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
+        d = key_value_pairs[i][1].shape[2]
+        cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
+        cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
     return cache
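A sketch of the revised copy logic, assuming transformers' StaticCache zero-initializes its buffers: the provided tensors fill only the first d positions along the sequence axis, and the rest stays as padding up to max_cache_len.

import torch
from onnx_diagnostic.helpers.cache_helper import make_static_cache

key = torch.ones((2, 3, 4, 8))  # sequence length d = 4
cache = make_static_cache([(key, key.clone())], max_cache_len=10)
print(cache.key_cache[0].shape)                         # torch.Size([2, 3, 10, 8])
print(cache.key_cache[0][:, :, :4].eq(1).all().item())  # True — copied region
print(cache.key_cache[0][:, :, 4:].eq(0).all().item())  # True — zero padding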

onnx_diagnostic/helpers/torch_helper.py

Lines changed: 6 additions & 2 deletions
@@ -735,7 +735,8 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
                     [t.to(to_value) for t in value.key_cache],
                     [t.to(to_value) for t in value.value_cache],
                 )
-            )
+            ),
+            max_cache_len=value.max_cache_len,
         )
     if value.__class__.__name__ == "EncoderDecoderCache":
         return make_encoder_decoder_cache(
@@ -784,7 +785,10 @@ def torch_deepcopy(value: Any) -> Any:
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
         )
     if value.__class__.__name__ == "StaticCache":
-        return make_static_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache))))
+        return make_static_cache(
+            torch_deepcopy(list(zip(value.key_cache, value.value_cache))),
+            max_cache_len=value.max_cache_len,
+        )
     if value.__class__.__name__ == "SlidingWindowCache":
         return make_sliding_window_cache(
            torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
onnx_diagnostic/tasks/text_generation.py

Lines changed: 13 additions & 6 deletions
@@ -109,7 +109,7 @@ def get_inputs(
     sequence_length2 = seq_length_multiple
 
     shapes = {
-        "input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        "input_ids": {0: batch, 1: "sequence_length"},
         "attention_mask": {
             0: batch,
             1: "cache+seq",  # cache_length + seq_length
@@ -188,18 +188,25 @@ def get_inputs(
                 (batch_size, num_key_value_heads, sequence_length2, head_dim)
             ).to(torch.bool),
             cache_position=torch.arange(sequence_length2).to(torch.int64),
-            past_key_values=make_cache(
+            past_key_values=make_static_cache(
                 [
                     (
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size,
+                            num_key_value_heads,
+                            sequence_length + sequence_length2,
+                            head_dim,
                         ),
                         torch.randn(
-                            batch_size, num_key_value_heads, sequence_length, head_dim
+                            batch_size,
+                            num_key_value_heads,
+                            sequence_length + sequence_length2,
+                            head_dim,
                         ),
                     )
                     for i in range(num_hidden_layers)
-                ]
+                ],
+                max_cache_len=max(sequence_length + sequence_length2, head_dim),
             ),
         )
     else:
@@ -230,7 +237,7 @@ def get_inputs(
             position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
             .to(torch.int64)
             .expand((batch_size, -1)),
-            past_key_values=make_cache(
+            past_key_values=make_cache(  # type: ignore[operator]
                 [
                     (
                         torch.randn(
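The shape arithmetic behind the new static-cache inputs, with illustrative values only (the task derives these from the model config):

# Past key/values now span the cached tokens plus the new ones...
sequence_length, sequence_length2, head_dim = 30, 4, 64
past_len = sequence_length + sequence_length2                      # 34
# ...and max_cache_len is kept at least head_dim, matching the max(...) above.
max_cache_len = max(sequence_length + sequence_length2, head_dim)  # 64
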

onnx_diagnostic/torch_export_patches/onnx_export_serialization_impl.py

Lines changed: 8 additions & 1 deletion
@@ -151,6 +151,11 @@ def flatten_static_cache(
     cache: StaticCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.StaticCache` with python objects."""
+    assert not cache.key_cache or cache.max_cache_len == cache.key_cache[0].shape[2], (
+        f"Serialization does not work when "
+        f"cache.max_cache_len={cache.max_cache_len} != "
+        f"cache.key_cache[0].shape[2]={cache.key_cache[0].shape[2]}"
+    )
     flat = [("key_cache", cache.key_cache), ("value_cache", cache.value_cache)]
     return [f[1] for f in flat], [f[0] for f in flat]
 
@@ -167,7 +172,9 @@ def unflatten_static_cache(
     values: List[Any], context: torch.utils._pytree.Context, output_type=None
 ) -> StaticCache:
     """Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
-    return make_static_cache(list(zip(values[0], values[1])))
+    return make_static_cache(
+        list(zip(values[0], values[1])), max_cache_len=values[0][0].shape[2]
+    )
 
 
 ####################
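Why the new assert: unflatten_static_cache can only recover max_cache_len from values[0][0].shape[2], so the flatten/unflatten round trip is lossless only when the allocated length and the tensor shape agree, which holds for caches built by make_static_cache since their buffers are allocated at max_cache_len. A hedged sketch (import paths assumed):

import torch
from onnx_diagnostic.helpers.cache_helper import make_static_cache
from onnx_diagnostic.torch_export_patches import torch_export_patches

cache = make_static_cache(
    [(torch.rand(1, 1, 5, 2), torch.rand(1, 1, 5, 2))], max_cache_len=5
)
# key_cache[0].shape[2] == max_cache_len == 5, so flattening is lossless:
with torch_export_patches():
    flat, spec = torch.utils._pytree.tree_flatten(cache)
    restored = torch.utils._pytree.tree_unflatten(flat, spec)
assert restored.max_cache_len == 5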

onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

Lines changed: 26 additions & 3 deletions
@@ -20,18 +20,41 @@ def patched__vmap_for_bhqkv(mask_function: Callable, bh_indices: bool = True) ->
     ]
     if bh_indices:
         dimensions.extend([(None, 0, None, None), (0, None, None, None)])
+    # reshape
     dimensions = [tuple(1 if d is None else -1 for d in shape) for shape in dimensions]
     dimensions = tuple(reversed(dimensions))
     indices = tuple(shape.index(-1) for shape in dimensions)
 
+    # unsqueeze
+    udimensions = [tuple(di for di, d in enumerate(shape) if d == 1) for shape in dimensions]
+
     def vector_mask_function(
         *args, mask_function=mask_function, dimensions=dimensions, indices=indices
     ):
-        assert len(args) == len(
-            dimensions
-        ), f"Mismatch between args={string_type(args)} and dimensions={dimensions}"
+        assert len(args) == len(dimensions) == len(udimensions), (
+            f"Mismatch between args={string_type(args)} and dimensions={dimensions} "
+            f"and udimensions={udimensions}."
+        )
+        assert len(indices) == len(args), (
+            f"Mismatch between args={string_type(args)} and indices={indices}, "
+            f"they should have the same length."
+        )
+        for a in args:
+            assert (
+                a.ndim == 1
+            ), f"Expected a tensor with 1 dimension, not {string_type(a, with_shape=True)}"
+            torch._check(a.shape[0] > 0)
+
         new_args = [a.reshape(shape) for a, shape in zip(args, dimensions)]
+        # new_args = [
+        #     a.unsqueeze(dims[0]).unsqueeze(dims[1]).unsqueeze(dims[2])
+        #     for a, dims in zip(args, udimensions)
+        # ]
         max_shape = tuple(args[i].shape[0] for i in indices)
+        # if is_torchdynamo_exporting():
+        #     for a in args:
+        #         # The exporter should export with a dimension > 1 to make sure it is dynamic.
+        #         torch._check(a.shape[0] > 1)
         expanded_args = [a.expand(max_shape) for a in new_args]
         return mask_function(*expanded_args)
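A self-contained sketch of the reshape/expand scheme this patch uses in place of torch.vmap; the names and the causal mask are illustrative, not the patched code. Each 1-D index tensor gets its length placed in a distinct axis, all four are broadcast to a common shape, and the scalar mask function is applied element-wise.

import torch

def causal(b, h, q_idx, kv_idx):
    # Element-wise causal mask; b and h stand in for the (batch, head) axes.
    return kv_idx <= q_idx

# One reshape target per argument (batch, head, q, kv), mirroring `dimensions`
# after the None -> 1 / 0 -> -1 mapping and the reversal in the patch.
dims = [(-1, 1, 1, 1), (1, -1, 1, 1), (1, 1, -1, 1), (1, 1, 1, -1)]
args = [torch.arange(2), torch.arange(4), torch.arange(3), torch.arange(5)]

new_args = [a.reshape(shape) for a, shape in zip(args, dims)]
max_shape = (2, 4, 3, 5)  # one size per argument, read off args[i].shape[0]
mask = causal(*[a.expand(max_shape) for a in new_args])
print(mask.shape)  # torch.Size([2, 4, 3, 5])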
