Skip to content

Commit 0e8b0c7

Browse files
committed
fix patches
1 parent f482c6a commit 0e8b0c7

File tree

2 files changed

+66
-50
lines changed

2 files changed

+66
-50
lines changed

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def test_get_untrained_model_with_inputs_automatic_speech_recognition(self):
117117
"#1[T1r3]",
118118
self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]),
119119
)
120-
with bypass_export_some_errors(patch_transformers=True):
120+
with bypass_export_some_errors(patch_transformers=True, verbose=10):
121121
flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
122122
self.assertIsInstance(flat, list)
123123
self.assertIsInstance(flat[0], torch.Tensor)

onnx_diagnostic/torch_export_patches/onnx_export_serialization.py

Lines changed: 65 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,27 +12,6 @@
1212

1313

1414
def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
15-
# MambaCache
16-
unregistered_mamba_cache = True
17-
if MambaCache in torch.utils._pytree.SUPPORTED_NODES:
18-
if verbose > 1:
19-
print(f"[_register_cache_serialization] {MambaCache} already registered")
20-
# It is already registered because bypass_export_some_errors was called
21-
# within a section already calling bypass_export_some_errors or transformers
22-
# has updated its code to do it.
23-
# No need to register and unregister then.
24-
unregistered_mamba_cache = False
25-
else:
26-
if verbose:
27-
print("[_register_cache_serialization] register MambaCache")
28-
torch.utils._pytree.register_pytree_node(
29-
MambaCache,
30-
flatten_mamba_cache,
31-
unflatten_mamba_cache,
32-
serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}",
33-
flatten_with_keys_fn=flatten_with_keys_mamba_cache,
34-
)
35-
3615
# DynamicCache serialization is different in transformers and does not
3716
# play well with torch.export.export.
3817
# see test test_export_dynamic_cache_cat with NOBYPASS=1
@@ -42,8 +21,8 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
4221
# DynamicCache, _flatten_dynamic_cache_for_fx)
4322
# so we remove it anyway
4423
if (
45-
DynamicCache in torch.fx._pytree.SUPPORTED_NODES
46-
and not PATCH_OF_PATCHES
24+
DynamicCache in torch.utils._pytree.SUPPORTED_NODES
25+
and DynamicCache not in PATCH_OF_PATCHES
4726
# and pv.Version(torch.__version__) < pv.Version("2.7")
4827
and pv.Version(transformers.__version__) >= pv.Version("4.50")
4928
):
@@ -52,14 +31,19 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
5231
"[_register_cache_serialization] DynamicCache "
5332
"is unregistered and registered first."
5433
)
55-
_unregister(DynamicCache)
34+
_unregister(DynamicCache, verbose=verbose)
5635
torch.utils._pytree.register_pytree_node(
5736
DynamicCache,
5837
flatten_dynamic_cache,
5938
unflatten_dynamic_cache,
6039
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
6140
flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
6241
)
42+
if verbose:
43+
print(
44+
"[_register_cache_serialization] DynamicCache "
45+
"unregistered and registered done."
46+
)
6347
if pv.Version(torch.__version__) < pv.Version("2.7"):
6448
torch.fx._pytree.register_pytree_flatten_spec(
6549
DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
@@ -69,20 +53,28 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
6953

7054
# BaseModelOutput serialization is incomplete.
7155
# It does not include dynamic shapes mapping.
72-
if BaseModelOutput in torch.fx._pytree.SUPPORTED_NODES and not PATCH_OF_PATCHES:
56+
if (
57+
BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES
58+
and BaseModelOutput not in PATCH_OF_PATCHES
59+
):
7360
if verbose:
7461
print(
7562
"[_register_cache_serialization] BaseModelOutput "
7663
"is unregistered and registered first."
7764
)
78-
_unregister(BaseModelOutput)
65+
_unregister(BaseModelOutput, verbose=verbose)
7966
torch.utils._pytree.register_pytree_node(
8067
BaseModelOutput,
8168
flatten_base_model_output,
8269
unflatten_base_model_output,
8370
serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
8471
flatten_with_keys_fn=flatten_with_keys_base_model_output,
8572
)
73+
if verbose:
74+
print(
75+
"[_register_cache_serialization] BaseModelOutput "
76+
"unregistered and registered done."
77+
)
8678

8779
# To avoid doing it multiple times.
8880
PATCH_OF_PATCHES.add(BaseModelOutput)
@@ -116,49 +108,70 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
116108
# torch.fx._pytree.tree_flatten(cache)
117109
assert len(cache2.key_cache) == 1
118110

119-
# EncoderDecoderCache
120-
unregistered_encode_decode_cache = True
121-
if (
122-
EncoderDecoderCache is not None
123-
and EncoderDecoderCache in torch.utils._pytree.SUPPORTED_NODES
124-
):
111+
# BaseModelOutput
112+
unregistered_base_model_output = True
113+
if BaseModelOutput is not None and BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES:
125114
if verbose > 1:
126-
print(f"[_register_cache_serialization] {EncoderDecoderCache} already registered")
115+
print(f"[_register_cache_serialization] {BaseModelOutput} already registered")
127116
# It is already registered because bypass_export_some_errors was called
128117
# within a section already calling bypass_export_some_errors or transformers
129118
# has updated its code to do it.
130119
# No need to register and unregister then.
131-
unregistered_encode_decode_cache = False
120+
unregistered_base_model_output = False
132121
else:
133122
if verbose:
134-
print("[_register_cache_serialization] register EncoderDecoderCache")
123+
print("[_register_cache_serialization] register BaseModelOutput")
135124
torch.utils._pytree.register_pytree_node(
136-
EncoderDecoderCache,
125+
BaseModelOutput,
137126
flatten_encoder_decoder_cache,
138127
unflatten_encoder_decoder_cache,
139-
serialized_type_name=f"{EncoderDecoderCache.__module__}.{EncoderDecoderCache.__name__}",
140-
flatten_with_keys_fn=flatten_with_keys_encoder_decoder_cache,
128+
serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
129+
flatten_with_keys_fn=flatten_with_keys_base_model_output,
141130
)
142131

143-
# BaseModelOutput
144-
unregistered_base_model_output = True
145-
if BaseModelOutput is not None and BaseModelOutput in torch.utils._pytree.SUPPORTED_NODES:
132+
# MambaCache
133+
unregistered_mamba_cache = True
134+
if MambaCache in torch.utils._pytree.SUPPORTED_NODES:
146135
if verbose > 1:
147-
print(f"[_register_cache_serialization] {BaseModelOutput} already registered")
136+
print(f"[_register_cache_serialization] {MambaCache} already registered")
148137
# It is already registered because bypass_export_some_errors was called
149138
# within a section already calling bypass_export_some_errors or transformers
150139
# has updated its code to do it.
151140
# No need to register and unregister then.
152-
unregistered_base_model_output = False
141+
unregistered_mamba_cache = False
153142
else:
154143
if verbose:
155-
print("[_register_cache_serialization] register BaseModelOutput")
144+
print("[_register_cache_serialization] register MambaCache")
156145
torch.utils._pytree.register_pytree_node(
157-
BaseModelOutput,
146+
MambaCache,
147+
flatten_mamba_cache,
148+
unflatten_mamba_cache,
149+
serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}",
150+
flatten_with_keys_fn=flatten_with_keys_mamba_cache,
151+
)
152+
153+
# EncoderDecoderCache
154+
unregistered_encode_decode_cache = True
155+
if (
156+
EncoderDecoderCache is not None
157+
and EncoderDecoderCache in torch.utils._pytree.SUPPORTED_NODES
158+
):
159+
if verbose > 1:
160+
print(f"[_register_cache_serialization] {EncoderDecoderCache} already registered")
161+
# It is already registered because bypass_export_some_errors was called
162+
# within a section already calling bypass_export_some_errors or transformers
163+
# has updated its code to do it.
164+
# No need to register and unregister then.
165+
unregistered_encode_decode_cache = False
166+
else:
167+
if verbose:
168+
print("[_register_cache_serialization] register EncoderDecoderCache")
169+
torch.utils._pytree.register_pytree_node(
170+
EncoderDecoderCache,
158171
flatten_encoder_decoder_cache,
159172
unflatten_encoder_decoder_cache,
160-
serialized_type_name=f"{BaseModelOutput.__module__}.{BaseModelOutput.__name__}",
161-
flatten_with_keys_fn=flatten_with_keys_base_model_output,
173+
serialized_type_name=f"{EncoderDecoderCache.__module__}.{EncoderDecoderCache.__name__}",
174+
flatten_with_keys_fn=flatten_with_keys_encoder_decoder_cache,
162175
)
163176

164177
return dict(
@@ -170,14 +183,17 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
170183

171184

172185
def _unregister(cls: type, verbose: int = 0):
173-
# torch.fx._pytree._deregister_pytree_flatten_spec(cls)
186+
# torch.utils._pytree._deregister_pytree_flatten_spec(cls)
174187
if cls in torch.fx._pytree.SUPPORTED_NODES:
175188
del torch.fx._pytree.SUPPORTED_NODES[cls]
176189
if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH:
177190
del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls]
178191
if hasattr(torch.utils._pytree, "_deregister_pytree_node"):
179192
# torch >= 2.7
180193
torch.utils._pytree._deregister_pytree_node(cls)
194+
else:
195+
if cls in torch.utils._pytree.SUPPORTED_NODES:
196+
del torch.utils._pytree.SUPPORTED_NODES[cls]
181197
optree.unregister_pytree_node(cls, namespace="torch")
182198
if cls in torch.utils._pytree.SUPPORTED_NODES:
183199
import packaging.version as pv
@@ -391,7 +407,7 @@ def flatten_with_keys_base_model_output(
391407
Serializes a :class:`transformers.modeling_outputs.BaseModelOutput`
392408
with python objects.
393409
"""
394-
values, context = flatten_dynamic_cache(bo)
410+
values, context = flatten_base_model_output(bo)
395411
return [(torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)], context
396412

397413

0 commit comments

Comments (0)