11import contextlib
22import pprint
3- from typing import Any , Callable , Dict
3+ from typing import Any , Callable , Dict , Set
44from .onnx_export_serialization import (
55 flatten_with_keys_dynamic_cache ,
66 flatten_dynamic_cache ,
@@ -65,7 +65,7 @@ def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0
6565 setattr (original , n , v )
6666
6767
68- PATCH_OF_PATCHES = set ()
68+ PATCH_OF_PATCHES : Set [ Any ] = set ()
6969
7070
7171def _register_cache_serialization (verbose : int = 0 ) -> Dict [str , bool ]:
@@ -116,7 +116,7 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
116116 if (
117117 DynamicCache in torch .fx ._pytree .SUPPORTED_NODES
118118 and not PATCH_OF_PATCHES
119- and pv .Version (torch .__version__ ) >= pv .Version ("2.7" )
119+ # and pv.Version(torch.__version__) < pv.Version("2.7")
120120 and pv .Version (transformers .__version__ ) >= pv .Version ("4.50" )
121121 ):
122122 if verbose :
@@ -132,6 +132,10 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
132132 serialized_type_name = f"{ DynamicCache .__module__ } .{ DynamicCache .__name__ } " ,
133133 flatten_with_keys_fn = flatten_with_keys_dynamic_cache ,
134134 )
135+ if pv .Version (torch .__version__ ) < pv .Version ("2.7" ):
136+ torch .fx ._pytree .register_pytree_flatten_spec (
137+ DynamicCache , lambda x , _ : [x .key_cache , x .value_cache ]
138+ )
135139 # To avoid doing it multiple times.
136140 PATCH_OF_PATCHES .add (DynamicCache )
137141
0 commit comments