Skip to content

Commit e4c9d6f

Browse files
committed
fix issues
1 parent 93b00d9 commit e4c9d6f

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,10 +110,19 @@ def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
110110
# torch.fx._pytree.register_pytree_flatten_spec(
111111
# DynamicCache, _flatten_dynamic_cache_for_fx)
112112
# so we remove it anyway
113-
if DynamicCache in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH and pv.Version(
113+
if DynamicCache in torch.fx._pytree.SUPPORTED_NODES and pv.Version(
114114
transformers.__version__
115115
) >= pv.Version("2.7"):
116-
del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[DynamicCache]
116+
if verbose:
117+
print("[_register_cache_serialization] DynamicCache is unregistered first.")
118+
_unregister(DynamicCache)
119+
torch.utils._pytree.register_pytree_node(
120+
DynamicCache,
121+
flatten_dynamic_cache,
122+
unflatten_dynamic_cache,
123+
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
124+
flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
125+
)
117126

118127
unregistered_dynamic_cache = True
119128
if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES:

0 commit comments

Comments
 (0)