11import pprint
2- from typing import Any , Dict , List , Set , Tuple
2+ from typing import Any , Callable , Dict , List , Optional , Set , Tuple
33import packaging .version as pv
44import optree
55import torch
66import transformers
77from transformers .cache_utils import DynamicCache , MambaCache , EncoderDecoderCache
88from transformers .modeling_outputs import BaseModelOutput
9+ from ..helpers import string_type
910
1011
1112PATCH_OF_PATCHES : Set [Any ] = set ()
1213
1314
15+ def _register_class_serialization (
16+ cls ,
17+ f_flatten : Callable ,
18+ f_unflatten : Callable ,
19+ f_flatten_with_keys : Callable ,
20+ f_check : Optional [Callable ] = None ,
21+ verbose : int = 0 ,
22+ ) -> bool :
23+ if cls is not None and cls in torch .utils ._pytree .SUPPORTED_NODES :
24+ return False
25+
26+ if verbose :
27+ print (f"[_register_cache_serialization] register { cls } " )
28+ torch .utils ._pytree .register_pytree_node (
29+ cls ,
30+ f_flatten ,
31+ f_unflatten ,
32+ serialized_type_name = f"{ cls .__module__ } .{ cls .__name__ } " ,
33+ flatten_with_keys_fn = f_flatten_with_keys ,
34+ )
35+ if pv .Version (torch .__version__ ) < pv .Version ("2.7" ):
36+ if verbose :
37+ print (
38+ f"[_register_cache_serialization] "
39+ f"register { cls } for torch=={ torch .__version__ } "
40+ )
41+ torch .fx ._pytree .register_pytree_flatten_spec (cls , lambda x , _ : f_flatten (x )[0 ])
42+
43+ # check
44+ if f_check :
45+ inst = f_check ()
46+ values , spec = torch .utils ._pytree .tree_flatten (inst )
47+ restored = torch .utils ._pytree .tree_unflatten (values , spec )
48+ assert string_type (inst , with_shape = True ) == string_type (restored , with_shape = True ), (
49+ f"Issue with registration of class { cls } "
50+ f"inst={ string_type (inst , with_shape = True )} , "
51+ f"restored={ string_type (restored , with_shape = True )} "
52+ )
53+ return True
54+
55+
1456def _register_cache_serialization (verbose : int = 0 ) -> Dict [str , bool ]:
1557 # DynamicCache serialization is different in transformers and does not
1658 # play way with torch.export.export.
@@ -28,26 +70,20 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
2870 ):
2971 if verbose :
3072 print (
31- "[_register_cache_serialization ] DynamicCache "
32- "is unregistered and registered first. "
73+ f"[_fix_registration ] DynamicCache is unregistered and "
74+ f" registered first for transformers== { transformers . __version__ } "
3375 )
3476 _unregister (DynamicCache , verbose = verbose )
35- torch . utils . _pytree . register_pytree_node (
77+ _register_class_serialization (
3678 DynamicCache ,
3779 flatten_dynamic_cache ,
3880 unflatten_dynamic_cache ,
39- serialized_type_name = f"{ DynamicCache .__module__ } .{ DynamicCache .__name__ } " ,
40- flatten_with_keys_fn = flatten_with_keys_dynamic_cache ,
81+ flatten_with_keys_dynamic_cache ,
82+         # f_check=lambda: make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
83+ verbose = verbose ,
4184 )
4285 if verbose :
43- print (
44- "[_register_cache_serialization] DynamicCache "
45- "unregistered and registered done."
46- )
47- if pv .Version (torch .__version__ ) < pv .Version ("2.7" ):
48- torch .fx ._pytree .register_pytree_flatten_spec (
49- DynamicCache , lambda x , _ : [x .key_cache , x .value_cache ]
50- )
86+ print ("[_fix_registration] DynamicCache done." )
5187 # To avoid doing it multiple times.
5288 PATCH_OF_PATCHES .add (DynamicCache )
5389
@@ -59,120 +95,52 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
5995 ):
6096 if verbose :
6197 print (
62- "[_register_cache_serialization ] BaseModelOutput "
63- "is unregistered and registered first. "
98+ f"[_fix_registration ] BaseModelOutput is unregistered and "
99+ f" registered first for transformers== { transformers . __version__ } "
64100 )
65101 _unregister (BaseModelOutput , verbose = verbose )
66- torch . utils . _pytree . register_pytree_node (
102+ _register_class_serialization (
67103 BaseModelOutput ,
68104 flatten_base_model_output ,
69105 unflatten_base_model_output ,
70- serialized_type_name = f" { BaseModelOutput . __module__ } . { BaseModelOutput . __name__ } " ,
71- flatten_with_keys_fn = flatten_with_keys_base_model_output ,
106+ flatten_with_keys_base_model_output ,
107+ verbose = verbose ,
72108 )
73109 if verbose :
74- print (
75- "[_register_cache_serialization] BaseModelOutput "
76- "unregistered and registered done."
77- )
110+ print ("[_fix_registration] BaseModelOutput done." )
78111
79112 # To avoid doing it multiple times.
80113 PATCH_OF_PATCHES .add (BaseModelOutput )
81114
82- unregistered_dynamic_cache = True
83- if DynamicCache is not None and DynamicCache in torch .utils ._pytree .SUPPORTED_NODES :
84- if verbose > 1 :
85- print (f"[_register_cache_serialization] { DynamicCache } already registered" )
86- unregistered_dynamic_cache = False
87- else :
88- if verbose :
89- print ("[_register_cache_serialization] register DynamicCache" )
90- torch .utils ._pytree .register_pytree_node (
91- DynamicCache ,
92- flatten_dynamic_cache ,
93- unflatten_dynamic_cache ,
94- serialized_type_name = f"{ DynamicCache .__module__ } .{ DynamicCache .__name__ } " ,
95- flatten_with_keys_fn = flatten_with_keys_dynamic_cache ,
96- )
97- if pv .Version (torch .__version__ ) < pv .Version ("2.7" ):
98- torch .fx ._pytree .register_pytree_flatten_spec (
99- DynamicCache , lambda x , _ : [x .key_cache , x .value_cache ]
100- )
101-
102- # check
103- from ..helpers .cache_helper import make_dynamic_cache
104-
105- cache = make_dynamic_cache ([(torch .rand ((4 , 4 , 4 )), torch .rand ((4 , 4 , 4 )))])
106- values , spec = torch .utils ._pytree .tree_flatten (cache )
107- cache2 = torch .utils ._pytree .tree_unflatten (values , spec )
108- # torch.fx._pytree.tree_flatten(cache)
109- assert len (cache2 .key_cache ) == 1
110-
111- # BaseModelOutput
112- unregistered_base_model_output = True
113- if BaseModelOutput is not None and BaseModelOutput in torch .utils ._pytree .SUPPORTED_NODES :
114- if verbose > 1 :
115- print (f"[_register_cache_serialization] { BaseModelOutput } already registered" )
116- # It is already registered because bypass_export_some_errors was called
117- # within a section already calling bypass_export_some_errors or transformers
118- # has updated its code to do it.
119- # No need to register and unregister then.
120- unregistered_base_model_output = False
121- else :
122- if verbose :
123- print ("[_register_cache_serialization] register BaseModelOutput" )
124- torch .utils ._pytree .register_pytree_node (
125- BaseModelOutput ,
126- flatten_encoder_decoder_cache ,
127- unflatten_encoder_decoder_cache ,
128- serialized_type_name = f"{ BaseModelOutput .__module__ } .{ BaseModelOutput .__name__ } " ,
129- flatten_with_keys_fn = flatten_with_keys_base_model_output ,
130- )
131-
132- # MambaCache
133- unregistered_mamba_cache = True
134- if MambaCache in torch .utils ._pytree .SUPPORTED_NODES :
135- if verbose > 1 :
136- print (f"[_register_cache_serialization] { MambaCache } already registered" )
137- # It is already registered because bypass_export_some_errors was called
138- # within a section already calling bypass_export_some_errors or transformers
139- # has updated its code to do it.
140- # No need to register and unregister then.
141- unregistered_mamba_cache = False
142- else :
143- if verbose :
144- print ("[_register_cache_serialization] register MambaCache" )
145- torch .utils ._pytree .register_pytree_node (
146- MambaCache ,
147- flatten_mamba_cache ,
148- unflatten_mamba_cache ,
149- serialized_type_name = f"{ MambaCache .__module__ } .{ MambaCache .__name__ } " ,
150- flatten_with_keys_fn = flatten_with_keys_mamba_cache ,
151- )
152-
153- # EncoderDecoderCache
154- unregistered_encode_decode_cache = True
155- if (
156- EncoderDecoderCache is not None
157- and EncoderDecoderCache in torch .utils ._pytree .SUPPORTED_NODES
158- ):
159- if verbose > 1 :
160- print (f"[_register_cache_serialization] { EncoderDecoderCache } already registered" )
161- # It is already registered because bypass_export_some_errors was called
162- # within a section already calling bypass_export_some_errors or transformers
163- # has updated its code to do it.
164- # No need to register and unregister then.
165- unregistered_encode_decode_cache = False
166- else :
167- if verbose :
168- print ("[_register_cache_serialization] register EncoderDecoderCache" )
169- torch .utils ._pytree .register_pytree_node (
170- EncoderDecoderCache ,
171- flatten_encoder_decoder_cache ,
172- unflatten_encoder_decoder_cache ,
173- serialized_type_name = f"{ EncoderDecoderCache .__module__ } .{ EncoderDecoderCache .__name__ } " ,
174- flatten_with_keys_fn = flatten_with_keys_encoder_decoder_cache ,
175- )
115+ unregistered_dynamic_cache = _register_class_serialization (
116+ DynamicCache ,
117+ flatten_dynamic_cache ,
118+ unflatten_dynamic_cache ,
119+ flatten_with_keys_dynamic_cache ,
120+         # f_check=lambda: make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
121+ verbose = verbose ,
122+ )
123+ unregistered_base_model_output = _register_class_serialization (
124+ BaseModelOutput ,
125+ flatten_base_model_output ,
126+ unflatten_base_model_output ,
127+ flatten_with_keys_base_model_output ,
128+ verbose = verbose ,
129+ )
130+ unregistered_encode_decode_cache = _register_class_serialization (
131+ EncoderDecoderCache ,
132+ flatten_encoder_decoder_cache ,
133+ unflatten_encoder_decoder_cache ,
134+ flatten_with_keys_encoder_decoder_cache ,
135+ verbose = verbose ,
136+ )
137+ unregistered_mamba_cache = _register_class_serialization (
138+ MambaCache ,
139+ flatten_mamba_cache ,
140+ unflatten_mamba_cache ,
141+ flatten_with_keys_mamba_cache ,
142+ verbose = verbose ,
143+ )
176144
177145 return dict (
178146 DynamicCache = unregistered_dynamic_cache ,
@@ -213,8 +181,6 @@ def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
213181 for cls in [MambaCache , DynamicCache , EncoderDecoderCache , BaseModelOutput ]:
214182 if undo .get (cls .__name__ , False ):
215183 _unregister (cls , verbose )
216- elif verbose > 1 :
217- print (f"[_unregister_cache_serialization] skip unregister { cls .__name__ } " )
218184
219185
220186############
0 commit comments