File tree Expand file tree Collapse file tree 2 files changed +13
-0
lines changed
Expand file tree Collapse file tree 2 files changed +13
-0
lines changed Original file line number Diff line number Diff line change @@ -361,6 +361,9 @@ def __init__(self):
361361 self .num_hidden_layers = len (key_value_pairs )
362362 self .dtype = dtype
363363
364+ def get_text_config (self ):
365+ return self
366+
364367 cache = MambaCache (
365368 _config (),
366369 max_batch_size = key_value_pairs [0 ][0 ].shape [0 ],
@@ -401,6 +404,9 @@ def __init__(self):
401404 self .num_hidden_layers = len (key_value_pairs )
402405 self .sliding_window = key_value_pairs [0 ][0 ].shape [2 ]
403406
407+ def get_text_config (self ):
408+ return self
409+
404410 cache = transformers .cache_utils .SlidingWindowCache (
405411 config = _config (),
406412 max_batch_size = key_value_pairs [0 ][0 ].shape [0 ],
@@ -566,6 +572,9 @@ class _config:
566572 sliding_window = _sliding_window
567573 num_key_value_heads = key_value_pairs [0 ][1 ].shape [1 ] # transformers 4.48.3
568574
575+ def get_text_config (self ):
576+ return self
577+
569578 if layer_types :
570579 _config .layer_types = layer_types # type: ignore[attr-defined]
571580
Original file line number Diff line number Diff line change @@ -1479,8 +1479,12 @@ def max_diff(
14791479 # backup function in case pytorch does not know how to serialize.
14801480 if expected .__class__ .__name__ == "DynamicCache" :
14811481 if got .__class__ .__name__ == "DynamicCache" :
1482+ from .cache_helper import CacheKeyValue
1483+
14821484 if verbose >= 6 :
14831485 print (f"[max_diff] DynamicCache: { string_type (expected )} ? { string_type (got )} " )
1486+ expected = CacheKeyValue (expected )
1487+ got = CacheKeyValue (got )
14841488 return max_diff (
14851489 [expected .key_cache , expected .value_cache ],
14861490 [got .key_cache , got .value_cache ],
You can’t perform that action at this time.
0 commit comments