Skip to content

Commit 5de8afa

Browse files
committed
add get_text_config
1 parent 1fbd28a commit 5de8afa

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,9 @@ def __init__(self):
361361
self.num_hidden_layers = len(key_value_pairs)
362362
self.dtype = dtype
363363

364+
def get_text_config(self):
365+
return self
366+
364367
cache = MambaCache(
365368
_config(),
366369
max_batch_size=key_value_pairs[0][0].shape[0],
@@ -401,6 +404,9 @@ def __init__(self):
401404
self.num_hidden_layers = len(key_value_pairs)
402405
self.sliding_window = key_value_pairs[0][0].shape[2]
403406

407+
def get_text_config(self):
408+
return self
409+
404410
cache = transformers.cache_utils.SlidingWindowCache(
405411
config=_config(),
406412
max_batch_size=key_value_pairs[0][0].shape[0],
@@ -566,6 +572,9 @@ class _config:
566572
sliding_window = _sliding_window
567573
num_key_value_heads = key_value_pairs[0][1].shape[1] # transformers 4.48.3
568574

575+
def get_text_config(self):
576+
return self
577+
569578
if layer_types:
570579
_config.layer_types = layer_types # type: ignore[attr-defined]
571580

onnx_diagnostic/helpers/helper.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1479,8 +1479,12 @@ def max_diff(
14791479
# backup function in case pytorch does not know how to serialize.
14801480
if expected.__class__.__name__ == "DynamicCache":
14811481
if got.__class__.__name__ == "DynamicCache":
1482+
from .cache_helper import CacheKeyValue
1483+
14821484
if verbose >= 6:
14831485
print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
1486+
expected = CacheKeyValue(expected)
1487+
got = CacheKeyValue(got)
14841488
return max_diff(
14851489
[expected.key_cache, expected.value_cache],
14861490
[got.key_cache, got.value_cache],

0 commit comments

Comments
 (0)