@@ -65,9 +65,13 @@ def unpatch_module(mod, info: Dict[type, Dict[type, Callable]], verbose: int = 0
6565 setattr (original , n , v )
6666
6767
68+ PATCH_OF_PATCHES = set ()
69+
70+
6871def _register_cache_serialization (verbose : int = 0 ) -> Dict [str , bool ]:
6972 # Cache serialization: to be moved into appropriate packages
7073 import torch
74+ import transformers
7175 import packaging .version as pv
7276
7377 try :
@@ -109,11 +113,17 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
109113 # torch.fx._pytree.register_pytree_flatten_spec(
110114 # DynamicCache, _flatten_dynamic_cache_for_fx)
111115 # so we remove it anyway
112- if DynamicCache in torch .fx ._pytree .SUPPORTED_NODES and pv .Version (
113- torch .__version__
114- ) >= pv .Version ("2.7" ):
116+ if (
117+ DynamicCache in torch .fx ._pytree .SUPPORTED_NODES
118+ and not PATCH_OF_PATCHES
119+ and pv .Version (torch .__version__ ) >= pv .Version ("2.7" )
120+ and pv .Version (transformers .__version__ ) >= pv .Version ("4.50" )
121+ ):
115122 if verbose :
116- print ("[_register_cache_serialization] DynamicCache is unregistered first." )
123+ print (
124+ "[_register_cache_serialization] DynamicCache "
125+ "is unregistered and registered first."
126+ )
117127 _unregister (DynamicCache )
118128 torch .utils ._pytree .register_pytree_node (
119129 DynamicCache ,
@@ -122,6 +132,8 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
122132 serialized_type_name = f"{ DynamicCache .__module__ } .{ DynamicCache .__name__ } " ,
123133 flatten_with_keys_fn = flatten_with_keys_dynamic_cache ,
124134 )
135+ # Mark DynamicCache as re-registered so the `not PATCH_OF_PATCHES` guard skips this block on later calls.
136+ PATCH_OF_PATCHES .add (DynamicCache )
125137
126138 unregistered_dynamic_cache = True
127139 if DynamicCache is not None and DynamicCache in torch .utils ._pytree .SUPPORTED_NODES :
@@ -138,9 +150,10 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
138150 serialized_type_name = f"{ DynamicCache .__module__ } .{ DynamicCache .__name__ } " ,
139151 flatten_with_keys_fn = flatten_with_keys_dynamic_cache ,
140152 )
141- torch .fx ._pytree .register_pytree_flatten_spec (
142- DynamicCache , lambda x , _ : [x .key_cache , x .value_cache ]
143- )
153+ if pv .Version (torch .__version__ ) < pv .Version ("2.7" ):
154+ torch .fx ._pytree .register_pytree_flatten_spec (
155+ DynamicCache , lambda x , _ : [x .key_cache , x .value_cache ]
156+ )
144157
145158 # check
146159 from ..helpers .cache_helper import make_dynamic_cache
0 commit comments