|
1 | 1 | import contextlib |
2 | | -import pprint |
3 | | -from typing import Any, Callable, Dict, List, Optional, Set |
| 2 | +from typing import Any, Callable, Dict, List, Optional |
4 | 3 | from .onnx_export_serialization import ( |
5 | | - flatten_with_keys_dynamic_cache, |
6 | | - flatten_dynamic_cache, |
7 | | - unflatten_dynamic_cache, |
8 | | - flatten_mamba_cache, |
9 | | - flatten_with_keys_mamba_cache, |
10 | | - unflatten_mamba_cache, |
| 4 | + _register_cache_serialization, |
| 5 | + _unregister_cache_serialization, |
11 | 6 | ) |
12 | 7 | from .patches import patch_transformers as patch_transformers_list |
13 | 8 |
|
@@ -84,156 +79,6 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo |
84 | 79 | setattr(original, n, v) |
85 | 80 |
|
86 | 81 |
|
87 | | -PATCH_OF_PATCHES: Set[Any] = set() |
88 | | - |
89 | | - |
90 | | -def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]: |
91 | | - # Cache serialization: to be moved into appropriate packages |
92 | | - import torch |
93 | | - import transformers |
94 | | - import packaging.version as pv |
95 | | - |
96 | | - try: |
97 | | - from transformers.cache_utils import DynamicCache |
98 | | - except ImportError: |
99 | | - DynamicCache = None |
100 | | - |
101 | | - try: |
102 | | - from transformers.cache_utils import MambaCache |
103 | | - except ImportError: |
104 | | - MambaCache = None |
105 | | - |
106 | | - # MambaCache |
107 | | - unregistered_mamba_cache = True |
108 | | - if MambaCache is not None and MambaCache in torch.utils._pytree.SUPPORTED_NODES: |
109 | | - if verbose > 1: |
110 | | - print(f"[_register_cache_serialization] {MambaCache} already registered") |
111 | | - # It is already registered because bypass_export_some_errors was called |
112 | | - # within a section already calling bypass_export_some_errors or transformers |
113 | | - # has updated its code to do it. |
114 | | - # No need to register and unregister then. |
115 | | - unregistered_mamba_cache = False |
116 | | - else: |
117 | | - if verbose: |
118 | | - print("[_register_cache_serialization] register MambaCache") |
119 | | - torch.utils._pytree.register_pytree_node( |
120 | | - MambaCache, |
121 | | - flatten_mamba_cache, |
122 | | - unflatten_mamba_cache, |
123 | | - serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}", |
124 | | - flatten_with_keys_fn=flatten_with_keys_mamba_cache, |
125 | | - ) |
126 | | - |
127 | | - # DynamicCache serialization is different in transformers and does not |
128 | | - # play way with torch.export.export. |
129 | | - # see test test_export_dynamic_cache_cat with NOBYPASS=1 |
130 | | - # :: NOBYBASS=1 python _unittests/ut_torch_export_patches/test_dynamic_class.py -k e_c |
131 | | - # This is caused by this line: |
132 | | - # torch.fx._pytree.register_pytree_flatten_spec( |
133 | | - # DynamicCache, _flatten_dynamic_cache_for_fx) |
134 | | - # so we remove it anyway |
135 | | - if ( |
136 | | - DynamicCache in torch.fx._pytree.SUPPORTED_NODES |
137 | | - and not PATCH_OF_PATCHES |
138 | | - # and pv.Version(torch.__version__) < pv.Version("2.7") |
139 | | - and pv.Version(transformers.__version__) >= pv.Version("4.50") |
140 | | - ): |
141 | | - if verbose: |
142 | | - print( |
143 | | - "[_register_cache_serialization] DynamicCache " |
144 | | - "is unregistered and registered first." |
145 | | - ) |
146 | | - _unregister(DynamicCache) |
147 | | - torch.utils._pytree.register_pytree_node( |
148 | | - DynamicCache, |
149 | | - flatten_dynamic_cache, |
150 | | - unflatten_dynamic_cache, |
151 | | - serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", |
152 | | - flatten_with_keys_fn=flatten_with_keys_dynamic_cache, |
153 | | - ) |
154 | | - if pv.Version(torch.__version__) < pv.Version("2.7"): |
155 | | - torch.fx._pytree.register_pytree_flatten_spec( |
156 | | - DynamicCache, lambda x, _: [x.key_cache, x.value_cache] |
157 | | - ) |
158 | | - # To avoid doing it multiple times. |
159 | | - PATCH_OF_PATCHES.add(DynamicCache) |
160 | | - |
161 | | - unregistered_dynamic_cache = True |
162 | | - if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES: |
163 | | - if verbose > 1: |
164 | | - print(f"[_register_cache_serialization] {DynamicCache} already registered") |
165 | | - unregistered_dynamic_cache = False |
166 | | - else: |
167 | | - if verbose: |
168 | | - print("[_register_cache_serialization] register DynamicCache") |
169 | | - torch.utils._pytree.register_pytree_node( |
170 | | - DynamicCache, |
171 | | - flatten_dynamic_cache, |
172 | | - unflatten_dynamic_cache, |
173 | | - serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}", |
174 | | - flatten_with_keys_fn=flatten_with_keys_dynamic_cache, |
175 | | - ) |
176 | | - if pv.Version(torch.__version__) < pv.Version("2.7"): |
177 | | - torch.fx._pytree.register_pytree_flatten_spec( |
178 | | - DynamicCache, lambda x, _: [x.key_cache, x.value_cache] |
179 | | - ) |
180 | | - |
181 | | - # check |
182 | | - from ..helpers.cache_helper import make_dynamic_cache |
183 | | - |
184 | | - cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]) |
185 | | - values, spec = torch.utils._pytree.tree_flatten(cache) |
186 | | - cache2 = torch.utils._pytree.tree_unflatten(values, spec) |
187 | | - # torch.fx._pytree.tree_flatten(cache) |
188 | | - assert len(cache2.key_cache) == 1 |
189 | | - |
190 | | - return dict(DynamicCache=unregistered_dynamic_cache, MambaCache=unregistered_mamba_cache) |
191 | | - |
192 | | - |
193 | | -def _unregister(cls: type, verbose: int = 0): |
194 | | - import optree |
195 | | - import torch |
196 | | - |
197 | | - # torch.fx._pytree._deregister_pytree_flatten_spec(cls) |
198 | | - if cls in torch.fx._pytree.SUPPORTED_NODES: |
199 | | - del torch.fx._pytree.SUPPORTED_NODES[cls] |
200 | | - if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH: |
201 | | - del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls] |
202 | | - if hasattr(torch.utils._pytree, "_deregister_pytree_node"): |
203 | | - # torch >= 2.7 |
204 | | - torch.utils._pytree._deregister_pytree_node(cls) |
205 | | - optree.unregister_pytree_node(cls, namespace="torch") |
206 | | - if cls in torch.utils._pytree.SUPPORTED_NODES: |
207 | | - import packaging.version as pv |
208 | | - |
209 | | - if pv.Version(torch.__version__) < pv.Version("2.7.0"): |
210 | | - del torch.utils._pytree.SUPPORTED_NODES[cls] |
211 | | - assert cls not in torch.utils._pytree.SUPPORTED_NODES, ( |
212 | | - f"{cls} was not successful unregistered " |
213 | | - f"from torch.utils._pytree.SUPPORTED_NODES=" |
214 | | - f"{pprint.pformat(list(torch.utils._pytree.SUPPORTED_NODES))}" |
215 | | - ) |
216 | | - if verbose: |
217 | | - print(f"[_unregister_cache_serialization] unregistered {cls.__name__}") |
218 | | - |
219 | | - |
220 | | -def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0): |
221 | | - |
222 | | - if undo.get("MambaCache", False): |
223 | | - from transformers.cache_utils import MambaCache |
224 | | - |
225 | | - _unregister(MambaCache, verbose) |
226 | | - elif verbose > 1: |
227 | | - print("[_unregister_cache_serialization] skip unregister MambaCache") |
228 | | - |
229 | | - if undo.get("DynamicCache", False): |
230 | | - from transformers.cache_utils import DynamicCache |
231 | | - |
232 | | - _unregister(DynamicCache, verbose) |
233 | | - elif verbose > 1: |
234 | | - print("[_unregister_cache_serialization] skip unregister DynamicCache") |
235 | | - |
236 | | - |
237 | 82 | @contextlib.contextmanager |
238 | 83 | def register_additional_serialization_functions( |
239 | 84 | patch_transformers: bool = False, verbose: int = 0 |
|
0 commit comments