 import pprint
 from typing import Any, Dict, List, Set, Tuple
+import packaging.version as pv
 import optree
 import torch
 import transformers
-import packaging.version as pv
+from transformers.cache_utils import DynamicCache, MambaCache, EncoderDecoderCache
+from transformers.modeling_outputs import BaseModelOutput


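 # Tracks classes whose pre-existing pytree registration has already been
 # replaced once, so the replacement is not applied a second time.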
 PATCH_OF_PATCHES: Set[Any] = set()


 def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
-    # Cache serialization: to be moved into appropriate packages
-
-    try:
-        from transformers.cache_utils import DynamicCache
-    except ImportError:
-        DynamicCache = None
-
-    try:
-        from transformers.cache_utils import MambaCache
-    except ImportError:
-        MambaCache = None
-
-    try:
-        from transformers.cache_utils import EncoderDecoderCache
-    except ImportError:
-        EncoderDecoderCache = None
-
     # MambaCache
     unregistered_mamba_cache = True
-    if MambaCache is not None and MambaCache in torch.utils._pytree.SUPPORTED_NODES:
+    if MambaCache in torch.utils._pytree.SUPPORTED_NODES:
         if verbose > 1:
             print(f"[_register_cache_serialization] {MambaCache} already registered")
         # It is already registered because bypass_export_some_errors was called
@@ -82,6 +67,26 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
         # To avoid doing it multiple times.
         PATCH_OF_PATCHES.add(DynamicCache)

+    # The serialization already registered for BaseModelOutput is incomplete:
+    # it does not include the dynamic shapes mapping.
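+    # Unregister it first, then register the more complete version defined below.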
+    if BaseModelOutput in torch.fx._pytree.SUPPORTED_NODES and not PATCH_OF_PATCHES:
+        if verbose:
+            print(
+                "[_register_cache_serialization] BaseModelOutput "
+                "is unregistered and registered again."
+            )
+        _unregister(BaseModelOutput)
+        torch.utils._pytree.register_pytree_node(
+            BaseModelOutput,
+            flatten_base_model_output,
+            unflatten_base_model_output,
+            serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_base_model_output,
+        )
+
+        # To avoid doing it multiple times.
+        PATCH_OF_PATCHES.add(BaseModelOutput)
+
     unregistered_dynamic_cache = True
     if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
         if verbose > 1:
@@ -123,7 +128,7 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
         # within a section already calling bypass_export_some_errors or transformers
         # has updated its code to do it.
         # No need to register and unregister then.
-        unregistered_mamba_cache = False
+        unregistered_encode_decode_cache = False
     else:
         if verbose:
             print("[_register_cache_serialization] register EncoderDecoderCache")
@@ -135,10 +140,32 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
             flatten_with_keys_fn=flatten_with_keys_encoder_decoder_cache,
         )

+    # BaseModelOutput
+    unregistered_base_model_output = True
+    if BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES:
+        if verbose > 1:
+            print(f"[_register_cache_serialization] {BaseModelOutput} already registered")
+        # It is already registered because bypass_export_some_errors was called
+        # within a section already calling bypass_export_some_errors or transformers
+        # has updated its code to do it.
+        # No need to register and unregister then.
+        unregistered_base_model_output = False
+    else:
+        if verbose:
+            print("[_register_cache_serialization] register BaseModelOutput")
+        torch.utils._pytree.register_pytree_node(
+            BaseModelOutput,
+            flatten_base_model_output,
+            unflatten_base_model_output,
+            serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_base_model_output,
+        )
+
     return dict(
         DynamicCache=unregistered_dynamic_cache,
         MambaCache=unregistered_mamba_cache,
         EncoderDecoderCache=unregistered_encode_decode_cache,
+        BaseModelOutput=unregistered_base_model_output,
     )


@@ -167,20 +194,11 @@ def _unregister(cls: type, verbose: int = 0):


 def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
-    if undo.get("MambaCache", False):
-        _unregister(transformers.cache_utils.MambaCache, verbose)
-    elif verbose > 1:
-        print("[_unregister_cache_serialization] skip unregister MambaCache")
-
-    if undo.get("DynamicCache", False):
-        _unregister(transformers.cache_utils.DynamicCache, verbose)
-    elif verbose > 1:
-        print("[_unregister_cache_serialization] skip unregister DynamicCache")
-
-    if undo.get("EncoderDecoderCache", False):
-        _unregister(transformers.cache_utils.EncoderDecoderCache, verbose)
-    elif verbose > 1:
-        print("[_unregister_cache_serialization] skip unregister EncoderDecoderCache")
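+    # Undo only what _register_cache_serialization reported as newly registered;
+    # classes that were already registered beforehand are left alone.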
+    for cls in [MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput]:
+        if undo.get(cls.__name__, False):
+            _unregister(cls, verbose)
+        elif verbose > 1:
+            print(f"[_unregister_cache_serialization] skip unregister {cls.__name__}")


 ############
@@ -205,7 +223,7 @@ def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
 #     dtype=dtype,
 # )
 def flatten_mamba_cache(
-    mamba_cache: transformers.cache_utils.MambaCache,
+    mamba_cache: MambaCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
     flat = [
@@ -224,10 +242,8 @@ def flatten_mamba_cache(


 def unflatten_mamba_cache(
-    values: List[Any],
-    context: torch.utils._pytree.Context,
-    output_type=None,
-) -> transformers.cache_utils.MambaCache:
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> MambaCache:
     """Restores a :class:`transformers.cache_utils.MambaCache` from python objects."""
     conv_states, ssm_states = values
@@ -258,12 +274,12 @@ def __init__(self):
     return cache


-def flatten_with_keys_mamba_cache(d: Dict[Any, Any]) -> Tuple[
+def flatten_with_keys_mamba_cache(cache: MambaCache) -> Tuple[
     List[Tuple[torch.utils._pytree.KeyEntry, Any]],
     torch.utils._pytree.Context,
 ]:
     """Serializes a :class:`transformers.cache_utils.MambaCache` with python objects."""
-    values, context = flatten_mamba_cache(d)
+    values, context = flatten_mamba_cache(cache)
     return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context


@@ -273,7 +289,7 @@ def flatten_with_keys_mamba_cache(d: Dict[Any, Any]) -> Tuple[


 def flatten_dynamic_cache(
-    dynamic_cache: transformers.cache_utils.DynamicCache,
+    dynamic_cache: DynamicCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
     if hasattr(transformers.cache_utils, "_flatten_dynamic_cache"):
@@ -287,11 +303,8 @@ def flatten_dynamic_cache(


 def flatten_with_keys_dynamic_cache(
-    dynamic_cache: transformers.cache_utils.DynamicCache,
-) -> Tuple[
-    List[Tuple[torch.utils._pytree.KeyEntry, Any]],
-    torch.utils._pytree.Context,
-]:
+    dynamic_cache: DynamicCache,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
     """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
     if hasattr(transformers.cache_utils, "_flatten_with_keys_dynamic_cache"):
         return transformers.cache_utils._flatten_with_keys_dynamic_cache(dynamic_cache)
@@ -300,10 +313,8 @@ def flatten_with_keys_dynamic_cache(


 def unflatten_dynamic_cache(
-    values: List[Any],
-    context: torch.utils._pytree.Context,
-    output_type=None,
-) -> transformers.cache_utils.DynamicCache:
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> DynamicCache:
     """Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
     if hasattr(transformers.cache_utils, "_unflatten_dynamic_cache"):
         assert output_type is None, f"output_type={output_type} not supported"
@@ -322,7 +333,7 @@ def unflatten_dynamic_cache(


 def flatten_encoder_decoder_cache(
-    ec_cache: transformers.cache_utils.DynamicCache,
+    ec_cache: EncoderDecoderCache,
 ) -> Tuple[List[Any], torch.utils._pytree.Context]:
     """
     Serializes a :class:`transformers.cache_utils.EncoderDecoderCache`
@@ -335,9 +346,7 @@ def flatten_encoder_decoder_cache(
     return torch.utils._pytree._dict_flatten(dictionary)


-def flatten_with_keys_encoder_decoder_cache(
-    ec_cache: transformers.cache_utils.DynamicCache,
-) -> Tuple[
+def flatten_with_keys_encoder_decoder_cache(ec_cache: EncoderDecoderCache) -> Tuple[
     List[Tuple[torch.utils._pytree.KeyEntry, Any]],
     torch.utils._pytree.Context,
 ]:
@@ -353,10 +362,46 @@ def flatten_with_keys_encoder_decoder_cache(


 def unflatten_encoder_decoder_cache(
-    values: List[Any],
-    context: torch.utils._pytree.Context,
-    output_type=None,
-) -> transformers.cache_utils.EncoderDecoderCache:
+    values: List[Any], context: torch.utils._pytree.Context, output_type=None
+) -> EncoderDecoderCache:
     """Restores a :class:`transformers.cache_utils.EncoderDecoderCache` from python objects."""
     dictionary = torch.utils._pytree._dict_unflatten(values, context)
     return transformers.cache_utils.EncoderDecoderCache(**dictionary)
+
+
+#################
+# BaseModelOutput
+#################
+
+
+def flatten_base_model_output(
+    bo: BaseModelOutput,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
+    with python objects.
+    """
+    return list(bo.values()), list(bo.keys())
+
+
+def flatten_with_keys_base_model_output(
+    bo: BaseModelOutput,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
+    with python objects.
+    """
+    values, context = flatten_base_model_output(bo)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_base_model_output(
+    values: List[Any],
+    context: torch.utils._pytree.Context,
+    output_type=None,
+) -> BaseModelOutput:
+    """
+    Restores a :class:`transformers.modeling_outputs.BaseModelOutput`
+    from python objects.
+    """
+    return BaseModelOutput(**dict(zip(context, values)))
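+
+
+# A minimal usage sketch (an illustration, not part of this commit): register
+# the serialization functions before exporting, then restore the previous
+# state, whatever happens during the export.
+#
+#     undo = _register_cache_serialization(verbose=1)
+#     try:
+#         ep = torch.export.export(model, args, kwargs)
+#     finally:
+#         _unregister_cache_serialization(undo, verbose=1)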