Skip to content

Commit a66a03d

Browse files
committed
another quick fix
1 parent f2f050f commit a66a03d

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

onnx_diagnostic/helpers/helper.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,10 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
931931
return flatten_object(list(x.items()), drop_keys=drop_keys)
932932

933933
if x.__class__.__name__ in {"DynamicCache", "StaticCache"}:
934-
res = flatten_object(x.key_cache) + flatten_object(x.value_cache)
934+
from .cache_helper import CacheKeyValue
935+
936+
kc = CacheKeyValue(x)
937+
res = flatten_object(kc.key_cache) + flatten_object(kc.value_cache)
935938
return tuple(res)
936939
if x.__class__.__name__ == "EncoderDecoderCache":
937940
res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)

0 commit comments

Comments
 (0)