
def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
-    # MambaCache
-    unregistered_mamba_cache = True
-    if MambaCache in torch.utils._pytree.SUPPORTED_NODES:
-        if verbose > 1:
-            print(f"[_register_cache_serialization] {MambaCache} already registered")
-        # It is already registered because bypass_export_some_errors was called
-        # within a section already calling bypass_export_some_errors or transformers
-        # has updated its code to do it.
-        # No need to register and unregister then.
-        unregistered_mamba_cache = False
-    else:
-        if verbose:
-            print("[_register_cache_serialization] register MambaCache")
-        torch.utils._pytree.register_pytree_node(
-            MambaCache,
-            flatten_mamba_cache,
-            unflatten_mamba_cache,
-            serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}",
-            flatten_with_keys_fn=flatten_with_keys_mamba_cache,
-        )
-
    # DynamicCache serialization is different in transformers and does not
    # play well with torch.export.export.
    # see test test_export_dynamic_cache_cat with NOBYPASS=1
@@ -42,8 +21,8 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
    # DynamicCache, _flatten_dynamic_cache_for_fx)
    # so we remove it anyway
    if (
-        DynamicCache in torch.fx._pytree.SUPPORTED_NODES
-        and not PATCH_OF_PATCHES
+        DynamicCache in torch.utils._pytree.SUPPORTED_NODES
+        and DynamicCache not in PATCH_OF_PATCHES
        # and pv.Version(torch.__version__) < pv.Version("2.7")
        and pv.Version(transformers.__version__) >= pv.Version("4.50")
    ):
@@ -52,14 +31,19 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
5231 "[_register_cache_serialization] DynamicCache "
5332 "is unregistered and registered first."
5433 )
55- _unregister (DynamicCache )
34+ _unregister (DynamicCache , verbose = verbose )
5635 torch .utils ._pytree .register_pytree_node (
5736 DynamicCache ,
5837 flatten_dynamic_cache ,
5938 unflatten_dynamic_cache ,
6039 serialized_type_name = f"{ DynamicCache .__module__ } .{ DynamicCache .__name__ } " ,
6140 flatten_with_keys_fn = flatten_with_keys_dynamic_cache ,
6241 )
42+ if verbose :
43+ print (
44+ "[_register_cache_serialization] DynamicCache "
45+ "unregistered and registered done."
46+ )
6347 if pv .Version (torch .__version__ ) < pv .Version ("2.7" ):
6448 torch .fx ._pytree .register_pytree_flatten_spec (
6549 DynamicCache , lambda x , _ : [x .key_cache , x .value_cache ]
@@ -69,20 +53,28 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:

    # BaseModelOutput serialization is incomplete.
    # It does not include dynamic shapes mapping.
-    if BaseModelOutput in torch.fx._pytree.SUPPORTED_NODES and not PATCH_OF_PATCHES:
+    if (
+        BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES
+        and BaseModelOutput not in PATCH_OF_PATCHES
+    ):
        if verbose:
            print(
                "[_register_cache_serialization] BaseModelOutput "
                "is unregistered and registered first."
            )
-        _unregister(BaseModelOutput)
+        _unregister(BaseModelOutput, verbose=verbose)
        torch.utils._pytree.register_pytree_node(
            BaseModelOutput,
            flatten_base_model_output,
            unflatten_base_model_output,
            serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
            flatten_with_keys_fn=flatten_with_keys_base_model_output,
        )
+        if verbose:
+            print(
+                "[_register_cache_serialization] BaseModelOutput "
+                "unregistered and registered done."
+            )

        # To avoid doing it multiple times.
        PATCH_OF_PATCHES.add(BaseModelOutput)
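For context, registering a class with torch.utils._pytree is what allows it to be flattened and rebuilt during export. A minimal round-trip sketch, assuming _register_cache_serialization() has already run and using a made-up tensor shape:

import torch
import torch.utils._pytree
from transformers.modeling_outputs import BaseModelOutput

# Once BaseModelOutput is a registered pytree node, flatten/unflatten should
# restore the same structure and values.
bo = BaseModelOutput(last_hidden_state=torch.randn(2, 4, 8))
leaves, spec = torch.utils._pytree.tree_flatten(bo)
restored = torch.utils._pytree.tree_unflatten(leaves, spec)
assert isinstance(restored, BaseModelOutput)
assert torch.equal(restored.last_hidden_state, bo.last_hidden_state)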
@@ -116,49 +108,70 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
        # torch.fx._pytree.tree_flatten(cache)
        assert len(cache2.key_cache) == 1

-    # EncoderDecoderCache
-    unregistered_encode_decode_cache = True
-    if (
-        EncoderDecoderCache is not None
-        and EncoderDecoderCache in torch.utils._pytree.SUPPORTED_NODES
-    ):
+    # BaseModelOutput
+    unregistered_base_model_output = True
+    if BaseModelOutput is not None and BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES:
        if verbose > 1:
-            print(f"[_register_cache_serialization] {EncoderDecoderCache} already registered")
+            print(f"[_register_cache_serialization] {BaseModelOutput} already registered")
        # It is already registered because bypass_export_some_errors was called
        # within a section already calling bypass_export_some_errors or transformers
        # has updated its code to do it.
        # No need to register and unregister then.
-        unregistered_encode_decode_cache = False
+        unregistered_base_model_output = False
    else:
        if verbose:
-            print("[_register_cache_serialization] register EncoderDecoderCache")
+            print("[_register_cache_serialization] register BaseModelOutput")
        torch.utils._pytree.register_pytree_node(
-            EncoderDecoderCache,
+            BaseModelOutput,
            flatten_encoder_decoder_cache,
            unflatten_encoder_decoder_cache,
-            serialized_type_name=f"{EncoderDecoderCache.__module__}.{EncoderDecoderCache.__name__}",
-            flatten_with_keys_fn=flatten_with_keys_encoder_decoder_cache,
+            serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_base_model_output,
        )

-    # BaseModelOutput
-    unregistered_base_model_output = True
-    if BaseModelOutput is not None and BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES:
+    # MambaCache
+    unregistered_mamba_cache = True
+    if MambaCache in torch.utils._pytree.SUPPORTED_NODES:
        if verbose > 1:
-            print(f"[_register_cache_serialization] {BaseModelOutput} already registered")
+            print(f"[_register_cache_serialization] {MambaCache} already registered")
        # It is already registered because bypass_export_some_errors was called
        # within a section already calling bypass_export_some_errors or transformers
        # has updated its code to do it.
        # No need to register and unregister then.
-        unregistered_base_model_output = False
+        unregistered_mamba_cache = False
    else:
        if verbose:
-            print("[_register_cache_serialization] register BaseModelOutput")
+            print("[_register_cache_serialization] register MambaCache")
        torch.utils._pytree.register_pytree_node(
-            BaseModelOutput,
+            MambaCache,
+            flatten_mamba_cache,
+            unflatten_mamba_cache,
+            serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_mamba_cache,
+        )
+
+    # EncoderDecoderCache
+    unregistered_encode_decode_cache = True
+    if (
+        EncoderDecoderCache is not None
+        and EncoderDecoderCache in torch.utils._pytree.SUPPORTED_NODES
+    ):
+        if verbose > 1:
+            print(f"[_register_cache_serialization] {EncoderDecoderCache} already registered")
+        # It is already registered because bypass_export_some_errors was called
+        # within a section already calling bypass_export_some_errors or transformers
+        # has updated its code to do it.
+        # No need to register and unregister then.
+        unregistered_encode_decode_cache = False
+    else:
+        if verbose:
+            print("[_register_cache_serialization] register EncoderDecoderCache")
+        torch.utils._pytree.register_pytree_node(
+            EncoderDecoderCache,
            flatten_encoder_decoder_cache,
            unflatten_encoder_decoder_cache,
-            serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
-            flatten_with_keys_fn=flatten_with_keys_base_model_output,
+            serialized_type_name=f"{EncoderDecoderCache.__module__}.{EncoderDecoderCache.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_encoder_decoder_cache,
        )

    return dict(
@@ -170,14 +183,17 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:


def _unregister(cls: type, verbose: int = 0):
-    # torch.fx._pytree._deregister_pytree_flatten_spec(cls)
+    # torch.utils._pytree._deregister_pytree_flatten_spec(cls)
    if cls in torch.fx._pytree.SUPPORTED_NODES:
        del torch.fx._pytree.SUPPORTED_NODES[cls]
    if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH:
        del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls]
    if hasattr(torch.utils._pytree, "_deregister_pytree_node"):
        # torch >= 2.7
        torch.utils._pytree._deregister_pytree_node(cls)
+    else:
+        if cls in torch.utils._pytree.SUPPORTED_NODES:
+            del torch.utils._pytree.SUPPORTED_NODES[cls]
        optree.unregister_pytree_node(cls, namespace="torch")
        if cls in torch.utils._pytree.SUPPORTED_NODES:
            import packaging.version as pv
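A hedged usage sketch of the helper above, assuming DynamicCache was registered earlier in this module; the verbose flag only controls printing:

from transformers.cache_utils import DynamicCache

# Remove the class from both pytree registries so it can be registered again
# with the flatten/unflatten functions defined in this module.
_unregister(DynamicCache, verbose=1)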
@@ -391,7 +407,7 @@ def flatten_with_keys_base_model_output(
    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
    with python objects.
    """
-    values, context = flatten_dynamic_cache(bo)
+    values, context = flatten_base_model_output(bo)
    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
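With the corrected call above, the keyed flattening pairs each BaseModelOutput field name with its value. A small sketch of how that surfaces through torch.utils._pytree (tree_flatten_with_path is available in recent torch versions; the shape below is made up):

import torch
import torch.utils._pytree
from transformers.modeling_outputs import BaseModelOutput

bo = BaseModelOutput(last_hidden_state=torch.randn(2, 4, 8))
# Each leaf comes back with a key path such as (MappingKey(key='last_hidden_state'),).
for path, value in torch.utils._pytree.tree_flatten_with_path(bo)[0]:
    print(path, value.shape)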