     StaticCache,
 )
 from transformers.modeling_outputs import BaseModelOutput
+
+try:
+    from diffusers.models.autoencoders.vae import DecoderOutput, EncoderOutput
+    from diffusers.models.unets.unet_1d import UNet1DOutput
+    from diffusers.models.unets.unet_2d import UNet2DOutput
+    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
+    from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput
+except ImportError as e:
+    try:
+        import diffusers
+    except ImportError:
+        diffusers = None
+    DecoderOutput, EncoderOutput = None, None
+    UNet1DOutput, UNet2DOutput = None, None
+    UNet2DConditionOutput, UNet3DConditionOutput = None, None
+    if diffusers:
+        raise e
+
 from ..helpers import string_type
 from ..helpers.cache_helper import make_static_cache


 PATCH_OF_PATCHES: Set[Any] = set()
+WRONG_REGISTRATIONS: Dict[type, Union[str, None]] = {
+    DynamicCache: "4.50",
+    BaseModelOutput: None,
+    UNet2DConditionOutput: None,
+}


 def register_class_serialization(
@@ -40,10 +63,12 @@ def register_class_serialization(
     :return: registered or not
     """
     if cls is not None and cls in torch.utils._pytree.SUPPORTED_NODES:
+        if verbose:
+            print(f"[register_class_serialization] already registered {cls.__name__}")
         return False

     if verbose:
-        print(f"[register_cache_serialization] register {cls}")
+        print(f"[register_class_serialization] ---------- register {cls.__name__}")
     torch.utils._pytree.register_pytree_node(
         cls,
         f_flatten,
@@ -54,8 +79,8 @@ def register_class_serialization(
     if pv.Version(torch.__version__) < pv.Version("2.7"):
         if verbose:
             print(
-                f"[register_cache_serialization] "
-                f"register {cls} for torch=={torch.__version__}"
+                f"[register_class_serialization] "
+                f"---------- register {cls.__name__} for torch=={torch.__version__}"
             )
         torch.fx._pytree.register_pytree_flatten_spec(cls, lambda x, _: f_flatten(x)[0])

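For reference, a minimal sketch (not part of the commit) of how the helper above is meant to be called for a custom container class; MyOutput and its flatten helpers are hypothetical, and the positional signature is the one used by serialization_functions below.

# Hypothetical example: registering a small dataclass with register_class_serialization.
from dataclasses import dataclass
from typing import Any, List, Tuple

import torch


@dataclass
class MyOutput:
    logits: torch.Tensor
    hidden: torch.Tensor


def flatten_my_output(obj: MyOutput) -> Tuple[List[Any], torch.utils._pytree.Context]:
    # Values first, then the context (field names) needed to rebuild the object.
    return [obj.logits, obj.hidden], ["logits", "hidden"]


def unflatten_my_output(values: List[Any], context: torch.utils._pytree.Context) -> MyOutput:
    return MyOutput(**dict(zip(context, values)))


def flatten_with_keys_my_output(obj: MyOutput):
    values, context = flatten_my_output(obj)
    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context


register_class_serialization(
    MyOutput, flatten_my_output, unflatten_my_output, flatten_with_keys_my_output, verbose=1
)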
@@ -77,6 +102,8 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     Registers many classes with :func:`register_class_serialization`.
     Returns information needed to undo the registration.
     """
+    registration_functions = serialization_functions(verbose=verbose)
+
     # DynamicCache serialization is different in transformers and does not
     # play well with torch.export.export.
     # see test test_export_dynamic_cache_cat with NOBYPASS=1
@@ -85,109 +112,102 @@ def register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
     # torch.fx._pytree.register_pytree_flatten_spec(
     #    DynamicCache, _flatten_dynamic_cache_for_fx)
     # so we remove it anyway
-    if (
-        DynamicCache in torch.utils._pytree.SUPPORTED_NODES
-        and DynamicCache not in PATCH_OF_PATCHES
-        # and pv.Version(torch.__version__) < pv.Version("2.7")
-        and pv.Version(transformers.__version__) >= pv.Version("4.50")
-    ):
-        if verbose:
-            print(
-                f"[_fix_registration] DynamicCache is unregistered and "
-                f"registered first for transformers=={transformers.__version__}"
-            )
-        unregister(DynamicCache, verbose=verbose)
-        register_class_serialization(
-            DynamicCache,
-            flatten_dynamic_cache,
-            unflatten_dynamic_cache,
-            flatten_with_keys_dynamic_cache,
-            # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-            verbose=verbose,
-        )
-        if verbose:
-            print("[_fix_registration] DynamicCache done.")
-        # To avoid doing it multiple times.
-        PATCH_OF_PATCHES.add(DynamicCache)
-
     # BaseModelOutput serialization is incomplete.
     # It does not include dynamic shapes mapping.
-    if (
-        BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES
-        and BaseModelOutput not in PATCH_OF_PATCHES
-    ):
-        if verbose:
-            print(
-                f"[_fix_registration] BaseModelOutput is unregistered and "
-                f"registered first for transformers=={transformers.__version__}"
+    for cls, version in WRONG_REGISTRATIONS.items():
+        if (
+            cls in torch.utils._pytree.SUPPORTED_NODES
+            and cls not in PATCH_OF_PATCHES
+            # and pv.Version(torch.__version__) < pv.Version("2.7")
+            and (
+                version is None or pv.Version(transformers.__version__) >= pv.Version(version)
             )
-        unregister(BaseModelOutput, verbose=verbose)
-        register_class_serialization(
-            BaseModelOutput,
-            flatten_base_model_output,
-            unflatten_base_model_output,
-            flatten_with_keys_base_model_output,
-            verbose=verbose,
-        )
-        if verbose:
-            print("[_fix_registration] BaseModelOutput done.")
-
-        # To avoid doing it multiple times.
-        PATCH_OF_PATCHES.add(BaseModelOutput)
-
-    return serialization_functions(verbose=verbose)
-
-
-def serialization_functions(verbose: int = 0) -> Dict[str, Union[Callable, int]]:
+        ):
+            assert cls in registration_functions, (
+                f"{cls} has no registration functions mapped to it, "
+                f"available {sorted(c.__name__ for c in registration_functions)}"
+            )
+            if verbose:
+                print(
+                    f"[_fix_registration] {cls.__name__} is unregistered and "
+                    f"registered first"
+                )
+            unregister_class_serialization(cls, verbose=verbose)
+            registration_functions[cls](verbose=verbose)
+            if verbose:
+                print(f"[_fix_registration] {cls.__name__} done.")
+            # To avoid doing it multiple times.
+            PATCH_OF_PATCHES.add(cls)
+
+    # Register the remaining classes (already registered ones are skipped).
+    done = {}
+    for k, v in registration_functions.items():
+        done[k] = v(verbose=verbose)
+    return done
+
+
+def serialization_functions(verbose: int = 0) -> Dict[type, Callable[..., bool]]:
     """Returns the mapping of serialization functions."""
-    return dict(
-        DynamicCache=register_class_serialization(
+    transformers_classes = {
+        DynamicCache: lambda verbose=verbose: register_class_serialization(
             DynamicCache,
             flatten_dynamic_cache,
             unflatten_dynamic_cache,
             flatten_with_keys_dynamic_cache,
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
-        MambaCache=register_class_serialization(
+        MambaCache: lambda verbose=verbose: register_class_serialization(
             MambaCache,
             flatten_mamba_cache,
             unflatten_mamba_cache,
             flatten_with_keys_mamba_cache,
             verbose=verbose,
         ),
-        EncoderDecoderCache=register_class_serialization(
+        EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
             EncoderDecoderCache,
             flatten_encoder_decoder_cache,
             unflatten_encoder_decoder_cache,
             flatten_with_keys_encoder_decoder_cache,
             verbose=verbose,
         ),
-        BaseModelOutput=register_class_serialization(
+        BaseModelOutput: lambda verbose=verbose: register_class_serialization(
             BaseModelOutput,
             flatten_base_model_output,
             unflatten_base_model_output,
             flatten_with_keys_base_model_output,
             verbose=verbose,
         ),
-        SlidingWindowCache=register_class_serialization(
+        SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
             SlidingWindowCache,
             flatten_sliding_window_cache,
             unflatten_sliding_window_cache,
             flatten_with_keys_sliding_window_cache,
             verbose=verbose,
         ),
-        StaticCache=register_class_serialization(
+        StaticCache: lambda verbose=verbose: register_class_serialization(
             StaticCache,
             flatten_static_cache,
             unflatten_static_cache,
             flatten_with_keys_static_cache,
             verbose=verbose,
         ),
-    )
+    }
+    if UNet2DConditionOutput:
+        diffusers_classes = {
+            UNet2DConditionOutput: lambda verbose=verbose: register_class_serialization(
+                UNet2DConditionOutput,
+                flatten_unet_2d_condition_output,
+                unflatten_unet_2d_condition_output,
+                flatten_with_keys_unet_2d_condition_output,
+                verbose=verbose,
+            )
+        }
+        transformers_classes.update(diffusers_classes)
+    return transformers_classes


-def unregister(cls: type, verbose: int = 0):
+def unregister_class_serialization(cls: type, verbose: int = 0):
     """Undo the registration."""
     # torch.utils._pytree._deregister_pytree_flatten_spec(cls)
     if cls in torch.fx._pytree.SUPPORTED_NODES:
@@ -217,9 +237,10 @@ def unregister(cls: type, verbose: int = 0):

-def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
+def unregister_cache_serialization(undo: Dict[type, bool], verbose: int = 0):
     """Undo all registrations."""
-    for cls in [MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput]:
-        if undo.get(cls.__name__, False):
-            unregister(cls, verbose)
+    cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache, BaseModelOutput} | set(undo)
+    for cls in cls_ensemble:
+        if undo.get(cls, False):
+            unregister_class_serialization(cls, verbose)


 ############
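Typical usage of the register/unregister pair is unchanged by this refactor; a minimal sketch (the export call is a placeholder, not from the commit):

# Hypothetical usage: register, export, then undo the registrations.
undo = register_cache_serialization(verbose=1)
try:
    # export a model whose inputs/outputs use the registered cache classes, e.g.
    # torch.export.export(model, (inputs,))
    ...
finally:
    unregister_cache_serialization(undo, verbose=1)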
@@ -478,3 +499,41 @@ def unflatten_base_model_output(
     from python objects.
     """
     return BaseModelOutput(**dict(zip(context, values)))
+
+
+#######################
+# UNet2DConditionOutput
+#######################
+
+
+def flatten_unet_2d_condition_output(
+    obj: UNet2DConditionOutput,
+) -> Tuple[List[Any], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`
+    with python objects.
+    """
+    return list(obj.values()), list(obj.keys())
+
+
+def flatten_with_keys_unet_2d_condition_output(
+    obj: UNet2DConditionOutput,
+) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+    """
+    Serializes a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`
+    with python objects.
+    """
+    values, context = flatten_unet_2d_condition_output(obj)
+    return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
+
+
+def unflatten_unet_2d_condition_output(
+    values: List[Any],
+    context: torch.utils._pytree.Context,
+    output_type=None,
+) -> UNet2DConditionOutput:
+    """
+    Restores a :class:`diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput`
+    from python objects.
+    """
+    return UNet2DConditionOutput(**dict(zip(context, values)))
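Once registered, the new flatten/unflatten pair can be sanity-checked with a pytree round trip; a minimal sketch (not from the commit, requires diffusers, tensor shape arbitrary):

# Hypothetical check: a registered UNet2DConditionOutput round-trips through pytree.
import torch

if UNet2DConditionOutput is not None:
    register_class_serialization(
        UNet2DConditionOutput,
        flatten_unet_2d_condition_output,
        unflatten_unet_2d_condition_output,
        flatten_with_keys_unet_2d_condition_output,
    )
    out = UNet2DConditionOutput(sample=torch.randn(1, 4, 8, 8))
    leaves, spec = torch.utils._pytree.tree_flatten(out)
    restored = torch.utils._pytree.tree_unflatten(leaves, spec)
    assert torch.equal(restored.sample, out.sample)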