Commit ca9dcc6

add support for hybrid cache
1 parent 7e3cc45 commit ca9dcc6

10 files changed: +449 lines added, -33 lines removed

CHANGELOGS.rst

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@ Change Logs
 0.7.6
 +++++

+* :pr:`192`: add support for Gemma-3, add serialization for HybridCache
+
 0.7.5
 +++++

_unittests/ut_helpers/test_cache_helper.py

Lines changed: 41 additions & 0 deletions
@@ -7,10 +7,12 @@
     flatten_unflatten_for_dynamic_shapes,
     make_dynamic_cache,
     make_encoder_decoder_cache,
+    make_hybrid_cache,
     make_mamba_cache,
     make_sliding_window_cache,
     make_static_cache,
 )
+from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.export import CoupleInputsDynamicShapes
 from onnx_diagnostic.torch_export_patches.patch_inputs import (
     convert_dynamic_axes_into_dynamic_shapes,
@@ -209,6 +211,45 @@ def test_unflatten_flatten_static_cache(self):
                 self.string_type(unflat, with_shape=True),
             )

+    def test_make_hybrid_cache(self):
+        cache = make_hybrid_cache(
+            [
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+            ],
+        )
+        text = self.string_type(cache, with_shape=True)
+        self.assertEqual(
+            "HybridCache(key_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7], "
+            "value_cache=#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7])",
+            text,
+        )
+        self.assertEqual(0, max_diff(cache, cache)["abs"])
+        self.assertEqual(0, max_diff(cache, torch_deepcopy(cache))["abs"])
+
+    def test_unflatten_flatten_hybrid_cache(self):
+        with torch_export_patches(patch_transformers=True):
+            c2 = make_hybrid_cache(
+                [
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                    (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
+                ],
+            )
+            self.assertEqual(0, max_diff(c2, c2)["abs"])
+            self.assertIsInstance(c2, transformers.cache_utils.HybridCache)
+            flat, _spec = torch.utils._pytree.tree_flatten(c2)
+            self.assertIsInstance(flat, list)
+            self.assertEqual(len(flat), 6)
+            unflat = flatten_unflatten_for_dynamic_shapes(c2)
+            self.assertIsInstance(unflat, list)
+            self.assertEqual(len(unflat), 2)
+            self.assertEqual(
+                "#2[#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7],#3[T1s4x5x6x7,T1s4x5x6x7,T1s4x5x6x7]]",
+                self.string_type(unflat, with_shape=True),
+            )
+

 if __name__ == "__main__":
     unittest.main(verbosity=2)

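For reference, here is a minimal standalone sketch of the round trip the new tests exercise, outside of unittest. It assumes make_hybrid_cache and flatten_unflatten_for_dynamic_shapes are importable from onnx_diagnostic.helpers.cache_helper and torch_export_patches from onnx_diagnostic.torch_export_patches, as the test imports suggest. With 3 layers of (key, value) pairs, tree_flatten yields 6 leaves and flatten_unflatten_for_dynamic_shapes regroups them into 2 lists of 3.

    import torch
    from onnx_diagnostic.helpers.cache_helper import (
        flatten_unflatten_for_dynamic_shapes,
        make_hybrid_cache,
    )
    from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed import path

    cache = make_hybrid_cache(
        [(torch.rand(4, 5, 6, 7), torch.rand(4, 5, 6, 7)) for _ in range(3)]
    )
    with torch_export_patches(patch_transformers=True):
        # 3 layers x (key, value) -> 6 flat tensors once HybridCache is registered with pytree
        flat, _spec = torch.utils._pytree.tree_flatten(cache)
        # regrouped into [key_cache, value_cache], i.e. 2 lists of 3 tensors
        unflat = flatten_unflatten_for_dynamic_shapes(cache)
    print(len(flat), len(unflat))  # 6 2
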
_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def test_image_text_to_text_idefics(self):
     @requires_transformers("4.53")
     @requires_torch("2.7.99")
     def test_image_text_to_text_gemma3(self):
-        mid = "tiny-random/gemma-3"
+        mid = "google/gemma-3-4b-it"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "image-text-to-text")
         self.assertIn((data["size"], data["n_weights"]), [(34401152, 8600288)])

_unittests/ut_tasks/try_tasks.py

Lines changed: 44 additions & 10 deletions
@@ -289,14 +289,41 @@ def test_imagetext2text_generation_idefics(self):

     @never_test()
     def test_imagetext2text_generation_gemma3(self):
+        """
+        ::
+
+            dict(input_ids:T7s1x281,
+                pixel_values:T16s1x3x896x896,
+                attention_mask:dict(full_attention:T9s1x1x281x380,sliding_attention:T9s1x1x281x380),
+                position_ids:T7s1x281,
+                past_key_values:HybridCache(
+                    key_cache=#34[T1s1x4x380x256,...],
+                    value_cache=#34[T1s1x4x380x256,...]),
+                token_type_ids:T7s1x281,
+                cache_position:T7s281,
+                logits_to_keep:1)
+            dict(input_ids:T7s1x1,
+                pixel_values:None,
+                attention_mask:dict(full_attention:T9s1x1x1x380,sliding_attention:T9s1x1x1x380),
+                position_ids:T7s1x1,
+                past_key_values:HybridCache(
+                    key_cache=#34[T1s1x4x380x256,...],
+                    value_cache=#34[T1s1x4x380x256,...]),
+                token_type_ids:T7s1x1,
+                cache_position:T7s1,
+                logits_to_keep:1)
+        """
+        from transformers import AutoProcessor, Gemma3ForConditionalGeneration
         import torch
-        from transformers import Gemma3ForConditionalGeneration, AutoProcessor

-        mid = "tiny-random/gemma-3"
-        processor = AutoProcessor.from_pretrained(mid)
+        # model_id = "tiny-random/gemma-3"
+        model_id = "google/gemma-3-4b-it"
+
         model = Gemma3ForConditionalGeneration.from_pretrained(
-            mid, torch_dtype=torch.bfloat16, device_map="auto"
-        )
+            model_id, device_map="auto"
+        ).eval()
+
+        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)

         messages = [
             {
@@ -314,19 +341,26 @@ def test_imagetext2text_generation_gemma3(self):
                 ],
             },
         ]
+
         inputs = processor.apply_chat_template(
             messages,
             add_generation_prompt=True,
             tokenize=True,
             return_dict=True,
             return_tensors="pt",
         ).to(model.device, dtype=torch.bfloat16)
-        print()
-        with steal_forward(model):
-            generated_ids = model.generate(**inputs, max_new_tokens=10)
-        decoded = processor.decode(generated_ids, skip_special_tokens=True)

-        print(decoded[0])
+        input_len = inputs["input_ids"].shape[-1]
+
+        print()
+        print(f"-- input_len={input_len}")
+        # steal forward creates a bug...
+        # with steal_forward(model), torch.inference_mode():
+        with torch.inference_mode():
+            generation = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+        generation = generation[0][input_len:]
+        decoded = processor.decode(generation, skip_special_tokens=True)
+        print(decoded)

     @never_test()
     def test_automatic_speech_recognition(self):

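The docstring added above records the runtime shapes Gemma-3 passes to forward, including a HybridCache of 34 layers with 1x4x380x256 key/value tensors. As a hedged sketch, the make_hybrid_cache helper introduced by this commit can rebuild a past_key_values with the same layout for offline experiments:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    # 34 layers of (key, value) tensors shaped (batch=1, heads=4, cache_len=380, head_dim=256),
    # matching the HybridCache recorded in the docstring above
    past_key_values = make_hybrid_cache(
        [(torch.zeros(1, 4, 380, 256), torch.zeros(1, 4, 380, 256)) for _ in range(34)]
    )
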
onnx_diagnostic/helpers/cache_helper.py

Lines changed: 62 additions & 0 deletions
@@ -318,3 +318,65 @@ def __init__(self):
         )
         cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
     return cache
+
+
+def make_hybrid_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+    max_cache_len: Optional[int] = None,
+    max_batch_size: Optional[int] = None,
+) -> transformers.cache_utils.HybridCache:
+    """
+    Creates an instance of :class:`transformers.cache_utils.HybridCache`.
+    This version is valid for ``transformers < 4.50``.
+
+    :param key_value_pairs: list of pairs of (key, values)
+    :return: :class:`transformers.cache_utils.HybridCache`
+
+    Example:
+
+    .. runpython::
+        :showcode:
+
+        import torch
+        from onnx_diagnostic.helpers import string_type
+        from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
+
+        n_layers = 2
+        bsize, nheads, slen, dim = 2, 4, 3, 7
+
+        past_key_values = make_hybrid_cache(
+            [
+                (
+                    torch.randn(bsize, nheads, slen, dim),
+                    torch.randn(bsize, nheads, slen, dim),
+                )
+                for i in range(n_layers)
+            ]
+        )
+        print(string_type(past_key_values, with_shape=True))
+    """
+    if key_value_pairs:
+        assert (
+            not max_batch_size and not max_cache_len
+        ), "key_value_pairs is not empty, do not specify max_cache_len and max_batch_size"
+        max_batch_size = key_value_pairs[0][0].shape[0]
+        max_cache_len = key_value_pairs[0][0].shape[2]
+    else:
+        assert (
+            max_batch_size and max_cache_len
+        ), "key_value_pairs is empty, max_batch_size and max_cache_len are required"
+    _ = max_cache_len
+
+    class _config:
+        max_cache_len = _
+        batch_size = max_batch_size
+        num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
+        head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
+        num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+
+    cache = transformers.cache_utils.HybridCache(
+        _config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
+    )
+    for i, (key, value) in enumerate(key_value_pairs):
+        cache.update(key, value, i)
+    return cache

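A note on the new helper's defaults: when key_value_pairs is not empty, max_batch_size and max_cache_len are inferred from the first key tensor (dimensions 0 and 2) and must not be passed explicitly. A small illustration; the expected output follows the string_type format asserted in the unit test above:

    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    # key/value tensors are (batch, num_heads, seq_len, head_dim) = (2, 4, 3, 7),
    # so max_batch_size=2 and max_cache_len=3 are inferred from the first pair
    cache = make_hybrid_cache(
        [(torch.rand(2, 4, 3, 7), torch.rand(2, 4, 3, 7)) for _ in range(2)]
    )
    print(string_type(cache, with_shape=True))
    # expected: HybridCache(key_cache=#2[T1s2x4x3x7,T1s2x4x3x7], value_cache=#2[T1s2x4x3x7,T1s2x4x3x7])
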
onnx_diagnostic/helpers/helper.py

Lines changed: 80 additions & 2 deletions
@@ -565,15 +565,15 @@ def string_type(
         "HybridCache",
     }:
         kc = string_type(
-            obj.key_cache,
+            list(obj.key_cache),
             with_shape=with_shape,
             with_min_max=with_min_max,
             with_device=with_device,
             limit=limit,
             verbose=verbose,
         )
         vc = string_type(
-            obj.value_cache,
+            list(obj.value_cache),
             with_shape=with_shape,
             with_min_max=with_min_max,
             with_device=with_device,
@@ -584,6 +584,27 @@ def string_type(
             print(f"[string_type] CACHE2:{type(obj)}")
         return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"

+    if obj.__class__.__name__ == "StaticLayer":
+        kc = string_type(
+            list(obj.keys),
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        vc = string_type(
+            list(obj.values),
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        if verbose:
+            print(f"[string_type] SL:{type(obj)}")
+        return f"{obj.__class__.__name__}(keys={kc}, values={vc})"
+
     if obj.__class__.__name__ == "EncoderDecoderCache":
         att = string_type(
             obj.self_attention_cache,
@@ -668,6 +689,24 @@ def string_type(
                 f"dtype={obj.dtype}, shape={obj.shape})"
             )

+    if obj.__class__.__name__ == "KeyValuesWrapper":
+        import transformers
+
+        assert isinstance(
+            obj, transformers.cache_utils.KeyValuesWrapper
+        ), f"Unexpected type {type(obj)}"
+        if verbose:
+            print(f"[string_type] KW0:{type(obj)}")
+        s = string_type(
+            list(obj),
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        return f"{obj.__class__.__name__}[{obj.cache_type}]{s}"
+
     if isinstance(obj, torch.nn.Module):
         if verbose:
             print(f"[string_type] MM:{type(obj)}")
@@ -1429,6 +1468,31 @@ def max_diff(
             f"level={level}"
         )

+    # backup function in case pytorch does not know how to serialize.
+    if expected.__class__.__name__ == "HybridCache":
+        if got.__class__.__name__ == "HybridCache":
+            if verbose >= 6:
+                print(f"[max_diff] HybridCache: {string_type(expected)} ? {string_type(got)}")
+            return max_diff(
+                [expected.key_cache, expected.value_cache],
+                [got.key_cache, got.value_cache],
+                verbose=verbose,
+                hist=hist,
+            )
+        if isinstance(got, tuple) and len(got) == 2:
+            return max_diff(
+                [expected.key_cache, expected.value_cache],
+                [got[0], got[1]],
+                debug_info=_debug(expected.__class__.__name__),
+                **_dkws,
+            )
+        raise AssertionError(
+            f"HybridCache not fully implemented with classes "
+            f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
+            f"and expected={string_type(expected)}, got={string_type(got)},\n"
+            f"level={level}"
+        )
+
     if expected.__class__.__name__ == "StaticCache":
         if got.__class__.__name__ == "StaticCache":
             if verbose >= 6:
@@ -1526,6 +1590,20 @@ def max_diff(
             **_dkws,
         )

+    if expected.__class__.__name__ == "KeyValuesWrapper":
+        if verbose >= 6:
+            print(f"[max_diff] KeyValuesWrapper: {string_type(expected)} ? {string_type(got)}")
+        if got.__class__.__name__ != expected.__class__.__name__:
+            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+        if got.cache_type != expected.cache_type:
+            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+        return max_diff(
+            list(expected),
+            list(got),
+            debug_info=_debug(expected.__class__.__name__),
+            **_dkws,
+        )
+
     raise AssertionError(
         f"Not implemented with implemented with expected="
         f"{string_type(expected)}, got={string_type(got)},\n"

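The new max_diff branch accepts either two HybridCache instances or a HybridCache compared against a (key_cache, value_cache) pair. A hedged sketch, assuming max_diff is exposed from onnx_diagnostic.helpers the same way string_type is:

    import torch
    from onnx_diagnostic.helpers import max_diff  # assumed export, mirroring string_type
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache

    cache = make_hybrid_cache(
        [(torch.rand(2, 4, 3, 8), torch.rand(2, 4, 3, 8)) for _ in range(2)]
    )
    # HybridCache vs HybridCache: compared through [key_cache, value_cache]
    print(max_diff(cache, cache)["abs"])  # 0
    # HybridCache vs a 2-tuple, the second branch added above
    print(max_diff(cache, (cache.key_cache, cache.value_cache))["abs"])  # 0
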
onnx_diagnostic/helpers/torch_helper.py

Lines changed: 3 additions & 0 deletions
@@ -14,6 +14,7 @@
 from .cache_helper import (
     make_dynamic_cache,
     make_encoder_decoder_cache,
+    make_hybrid_cache,
     make_sliding_window_cache,
     make_mamba_cache,
     make_static_cache,
@@ -789,6 +790,8 @@ def torch_deepcopy(value: Any) -> Any:
             torch_deepcopy(list(zip(value.key_cache, value.value_cache))),
             max_cache_len=value.max_cache_len,
         )
+    if value.__class__.__name__ == "HybridCache":
+        return make_hybrid_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache))))
     if value.__class__.__name__ == "SlidingWindowCache":
         return make_sliding_window_cache(
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))

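With this branch, torch_deepcopy rebuilds a HybridCache by re-running make_hybrid_cache on deep copies of the key/value tensors; the unit test above checks the copy matches the original through max_diff. A minimal usage sketch:

    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.helpers.cache_helper import make_hybrid_cache
    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

    cache = make_hybrid_cache([(torch.rand(1, 2, 3, 4), torch.rand(1, 2, 3, 4))])
    copy = torch_deepcopy(cache)  # a new HybridCache, not a reference to the original
    print(string_type(copy, with_shape=True))
    # expected: HybridCache(key_cache=#1[T1s1x2x3x4], value_cache=#1[T1s1x2x3x4])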