Commit cc67cf7 ("update")
1 parent: ca9dcc6

9 files changed: +223 / -253 lines

_unittests/ut_tasks/test_tasks_image_text_to_text.py

Lines changed: 4 additions & 3 deletions
@@ -20,7 +20,7 @@ def test_image_text_to_text_idefics(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "image-text-to-text")
-        self.assertIn((data["size"], data["n_weights"]), [(12742888, 3185722)])
+        self.assertIn((data["size"], data["n_weights"]), [(12628776, 3157194)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**torch_deepcopy(inputs))
         model(**data["inputs2"])
@@ -33,10 +33,11 @@ def test_image_text_to_text_idefics(self):
     @requires_transformers("4.53")
     @requires_torch("2.7.99")
     def test_image_text_to_text_gemma3(self):
-        mid = "google/gemma-3-4b-it"
+        # mid = "google/gemma-3-4b-it"
+        mid = "tiny-random/gemma-3"
         data = get_untrained_model_with_inputs(mid, verbose=1, add_second_input=True)
         self.assertEqual(data["task"], "image-text-to-text")
-        self.assertIn((data["size"], data["n_weights"]), [(34401152, 8600288)])
+        # self.assertIn((data["size"], data["n_weights"]), [(17248576, 4312144)])
         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         print("--", self.string_type(data["inputs"], with_shape=True))
         model(**torch_deepcopy(inputs))

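The two tests above pin (size, n_weights) for the untrained model returned by get_untrained_model_with_inputs. As a rough guide for regenerating such constants when the checkpoint changes, here is a minimal sketch, assuming size counts parameter bytes and n_weights counts scalar parameters (the values above are consistent with 4 bytes per weight); count_parameters is a hypothetical helper, not part of the repository:

    import torch

    def count_parameters(model: torch.nn.Module) -> tuple[int, int]:
        # size in bytes, number of scalar weights
        n_weights = sum(p.numel() for p in model.parameters())
        size = sum(p.numel() * p.element_size() for p in model.parameters())
        return size, n_weights

    # e.g. size, n_weights = count_parameters(data["model"])
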
_unittests/ut_torch_models/test_tiny_llms_bypassed.py

Lines changed: 0 additions & 7 deletions
@@ -1,17 +1,13 @@
 import copy
 import unittest
 import torch
-from transformers.cache_utils import DynamicCache
 from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings, hide_stdout
 from onnx_diagnostic.torch_models.llms import get_tiny_llm
 from onnx_diagnostic.torch_models.llms import get_phi2
 from onnx_diagnostic.helpers import string_type
 from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
 from onnx_diagnostic.torch_export_patches import torch_export_patches
 from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
-from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
-    patched_DynamicCache,
-)


 class TestTinyLlmBypassed(ExtTestCase):
@@ -29,9 +25,6 @@ def test_export_tiny_llm_2_bypassed(self):
             patch_torch=False, patch_transformers=True, catch_constraints=False, verbose=10
         ) as modificator:

-            for k in patched_DynamicCache._PATCHES_:
-                self.assertEqual(getattr(patched_DynamicCache, k), getattr(DynamicCache, k))
-
             inputs = modificator(copy.deepcopy(inputs))

             def debug():

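For context, a minimal usage sketch of the context-manager pattern this test keeps, assuming get_tiny_llm returns a dict with "model" and "inputs" as the other tests in this repository suggest:

    import copy
    from onnx_diagnostic.torch_models.llms import get_tiny_llm
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    data = get_tiny_llm()
    model, inputs = data["model"], data["inputs"]
    with torch_export_patches(
        patch_torch=False, patch_transformers=True, catch_constraints=False, verbose=0
    ) as modificator:
        # the callable returned by the context manager rewrites the inputs
        # so that they match the patched transformers classes
        model(**modificator(copy.deepcopy(inputs)))
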
onnx_diagnostic/helpers/cache_helper.py

Lines changed: 24 additions & 9 deletions
@@ -10,6 +10,18 @@
 from transformers.cache_utils import MambaCache


+class CacheKeyValue:
+    def __init__(self, cache: "Cache"):  # noqa: F821
+        if hasattr(cache, "layers"):
+            self.key_cache = [layer.keys for layer in cache.layers if layer.keys is not None]
+            self.value_cache = [
+                layer.values for layer in cache.layers if layer.values is not None
+            ]
+        else:
+            self.key_cache = cache.key_cache
+            self.value_cache = cache.value_cache
+
+
 def flatten_unflatten_for_dynamic_shapes(
     obj: Any,
     use_dict: bool = False,
@@ -221,19 +233,20 @@ def __init__(self):
         ),
     )
     cache = transformers.cache_utils.StaticCache(
-        _config(),
+        config=_config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         device=key_value_pairs[0][0].device,
         dtype=key_value_pairs[0][0].dtype,
         max_cache_len=max_cache_len,
     )
+    ca = CacheKeyValue(cache)
     for i in range(len(key_value_pairs)):
         assert (
             key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
         ), f"Shape mismatch {key_value_pairs[i][0].shape} != {key_value_pairs[i][1].shape}"
         d = key_value_pairs[i][1].shape[2]
-        cache.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
-        cache.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
+        ca.key_cache[i][:, :, :d, :] = key_value_pairs[i][0]
+        ca.value_cache[i][:, :, :d, :] = key_value_pairs[i][1]
     return cache


@@ -300,23 +313,24 @@ def __init__(self):
             self.sliding_window = key_value_pairs[0][0].shape[2]

     cache = transformers.cache_utils.SlidingWindowCache(
-        _config(),
+        config=_config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
         max_cache_len=key_value_pairs[0][0].shape[2],  # same as sliding_window
         device=key_value_pairs[0][0].device,
         dtype=key_value_pairs[0][0].dtype,
     )
+    ca = CacheKeyValue(cache)
     for i in range(len(key_value_pairs)):
-        assert cache.key_cache[i].shape == key_value_pairs[i][0].shape, (
+        assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
             f"Shape mismatch, expected {cache.key_cache[i].shape}, "
             f"got {key_value_pairs[i][0].shape}"
         )
-        cache.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
-        assert cache.value_cache[i].shape == key_value_pairs[i][1].shape, (
+        ca.key_cache[i][:, :, :, :] = key_value_pairs[i][0]
+        assert ca.value_cache[i].shape == key_value_pairs[i][1].shape, (
             f"Shape mismatch, expected {cache.value_cache[i].shape}, "
             f"got {key_value_pairs[i][1].shape}"
         )
-        cache.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
+        ca.value_cache[i][:, :, :, :] = key_value_pairs[i][1]
     return cache


@@ -373,9 +387,10 @@ class _config:
         num_heads = key_value_pairs[0][0].shape[1] if key_value_pairs else None
         head_dim = key_value_pairs[0][0].shape[-1] if key_value_pairs else None
         num_attention_heads = key_value_pairs[0][1].shape[1] if key_value_pairs else None
+        num_hidden_layers = len(key_value_pairs)

     cache = transformers.cache_utils.HybridCache(
-        _config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
+        config=_config(), max_cache_len=max_cache_len, max_batch_size=max_batch_size
     )
     for i, (key, value) in enumerate(key_value_pairs):
         cache.update(key, value, i)

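The new CacheKeyValue adapter is the piece the rest of the commit builds on: recent transformers versions store cache content in cache.layers (each layer exposing .keys and .values), while older versions expose key_cache and value_cache lists directly. A short usage sketch, assuming make_dynamic_cache from this module builds a cache from (key, value) pairs, as the other helpers here do:

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

    # two layers of (key, value) pairs with shape (batch, heads, seq_length, head_dim)
    pairs = [(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)) for _ in range(2)]
    cache = make_dynamic_cache(pairs)

    ca = CacheKeyValue(cache)  # same two attributes whatever the transformers version
    assert len(ca.key_cache) == len(ca.value_cache) == 2
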
onnx_diagnostic/helpers/helper.py

Lines changed: 35 additions & 11 deletions
@@ -564,16 +564,19 @@ def string_type(
         "StaticCache",
         "HybridCache",
     }:
+        from .cache_helper import CacheKeyValue
+
+        ca = CacheKeyValue(obj)
         kc = string_type(
-            list(obj.key_cache),
+            ca.key_cache,
             with_shape=with_shape,
             with_min_max=with_min_max,
             with_device=with_device,
             limit=limit,
             verbose=verbose,
         )
         vc = string_type(
-            list(obj.value_cache),
+            ca.value_cache,
             with_shape=with_shape,
             with_min_max=with_min_max,
             with_device=with_device,
@@ -1471,17 +1474,24 @@ def max_diff(
     # backup function in case pytorch does not know how to serialize.
     if expected.__class__.__name__ == "HybridCache":
         if got.__class__.__name__ == "HybridCache":
+            from .cache_helper import CacheKeyValue
+
             if verbose >= 6:
                 print(f"[max_diff] HybridCache: {string_type(expected)} ? {string_type(got)}")
+            cae = CacheKeyValue(expected)
+            cag = CacheKeyValue(got)
             return max_diff(
-                [expected.key_cache, expected.value_cache],
-                [got.key_cache, got.value_cache],
+                [cae.key_cache, cae.value_cache],
+                [cag.key_cache, cag.value_cache],
                 verbose=verbose,
                 hist=hist,
             )
         if isinstance(got, tuple) and len(got) == 2:
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
             return max_diff(
-                [expected.key_cache, expected.value_cache],
+                [cae.key_cache, cae.value_cache],
                 [got[0], got[1]],
                 debug_info=_debug(expected.__class__.__name__),
                 **_dkws,
@@ -1495,17 +1505,24 @@ def max_diff(

     if expected.__class__.__name__ == "StaticCache":
         if got.__class__.__name__ == "StaticCache":
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
+            cag = CacheKeyValue(got)
             if verbose >= 6:
                 print(f"[max_diff] StaticCache: {string_type(expected)} ? {string_type(got)}")
             return max_diff(
-                [expected.key_cache, expected.value_cache],
-                [got.key_cache, got.value_cache],
+                [cae.key_cache, cae.value_cache],
+                [cag.key_cache, cag.value_cache],
                 verbose=verbose,
                 hist=hist,
             )
         if isinstance(got, tuple) and len(got) == 2:
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
             return max_diff(
-                [expected.key_cache, expected.value_cache],
+                [cae.key_cache, cae.value_cache],
                 [got[0], got[1]],
                 debug_info=_debug(expected.__class__.__name__),
                 **_dkws,
@@ -1524,15 +1541,22 @@ def max_diff(
                     f"[max_diff] SlidingWindowCache: "
                     f"{string_type(expected)} ? {string_type(got)}"
                 )
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
+            cag = CacheKeyValue(got)
             return max_diff(
-                [expected.key_cache, expected.value_cache],
-                [got.key_cache, got.value_cache],
+                [cae.key_cache, cae.value_cache],
+                [cag.key_cache, cag.value_cache],
                 verbose=verbose,
                 hist=hist,
            )
         if isinstance(got, tuple) and len(got) == 2:
+            from .cache_helper import CacheKeyValue
+
+            cae = CacheKeyValue(expected)
             return max_diff(
-                [expected.key_cache, expected.value_cache],
+                [cae.key_cache, cae.value_cache],
                 [got[0], got[1]],
                 debug_info=_debug(expected.__class__.__name__),
                 **_dkws,

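string_type and max_diff now read key and value tensors through CacheKeyValue instead of touching obj.key_cache directly, so both keep working with the layer-based cache layout. A small sketch of the resulting behaviour, assuming make_static_cache accepts (key, value) pairs plus max_cache_len as shown in cache_helper.py above:

    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.helpers.helper import max_diff
    from onnx_diagnostic.helpers.cache_helper import make_static_cache

    pairs = [(torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8))]
    cache = make_static_cache(pairs, max_cache_len=16)

    print(string_type(cache, with_shape=True))  # cache rendered with tensor shapes
    print(max_diff(cache, cache))  # compares key/value tensors, zero against itself
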
onnx_diagnostic/helpers/torch_helper.py

Lines changed: 17 additions & 6 deletions
@@ -782,19 +782,30 @@ def torch_deepcopy(value: Any) -> Any:
     if hasattr(value, "clone"):
         return value.clone()
     if value.__class__.__name__ == "DynamicCache":
-        return make_dynamic_cache(
-            torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
-        )
+        from .cache_helper import CacheKeyValue
+
+        ca = CacheKeyValue(value)
+        return make_dynamic_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
     if value.__class__.__name__ == "StaticCache":
+        from .cache_helper import CacheKeyValue
+
+        ca = CacheKeyValue(value)
         return make_static_cache(
-            torch_deepcopy(list(zip(value.key_cache, value.value_cache))),
+            torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))),
             max_cache_len=value.max_cache_len,
         )
     if value.__class__.__name__ == "HybridCache":
-        return make_hybrid_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache))))
+        from .cache_helper import CacheKeyValue
+
+        ca = CacheKeyValue(value)
+        return make_hybrid_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
     if value.__class__.__name__ == "SlidingWindowCache":
+        from .cache_helper import CacheKeyValue
+
+        ca = CacheKeyValue(value)
+        return make_hybrid_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
         return make_sliding_window_cache(
-            torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
+            torch_deepcopy(list(zip(ca.key_cache, ca.value_cache)))
         )
     if value.__class__.__name__ == "EncoderDecoderCache":
         return make_encoder_decoder_cache(

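torch_deepcopy now rebuilds each cache from the (key, value) pairs exposed by CacheKeyValue, so the copy no longer depends on the cache exposing key_cache and value_cache attributes itself. A round-trip sketch under the same assumptions as the earlier snippets:

    import torch
    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue, make_dynamic_cache

    cache = make_dynamic_cache([(torch.randn(1, 2, 3, 8), torch.randn(1, 2, 3, 8))])
    copied = torch_deepcopy(cache)

    # the copy holds independent tensors with the same content
    assert torch.equal(CacheKeyValue(copied).key_cache[0], CacheKeyValue(cache).key_cache[0])
    assert CacheKeyValue(copied).key_cache[0] is not CacheKeyValue(cache).key_cache[0]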