
Commit 16b9ff4

eccache

1 parent 19b025c

File tree

6 files changed: +342 −178 lines changed
Lines changed: 88 additions & 0 deletions
@@ -0,0 +1,88 @@
import unittest
import torch
from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
from onnx_diagnostic.helpers.cache_helper import make_encoder_decoder_cache, make_dynamic_cache
from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
    bypass_export_some_errors,
)


class TestPatchSerialization(ExtTestCase):
    @ignore_warnings(UserWarning)
    def test_flatten_encoder_decoder_cache(self):
        cache = make_encoder_decoder_cache(
            make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
            make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
        )
        with bypass_export_some_errors():
            flat, _spec = torch.utils._pytree.tree_flatten(cache)
            self.assertEqual(
                "#4[T1s4x4x4,T1s4x4x4,T1s5x5x5,T1s5x5x5]",
                self.string_type(flat, with_shape=True),
            )
            cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
            self.assertEqual(
                self.string_type(cache, with_shape=True, with_min_max=True),
                self.string_type(cache2, with_shape=True, with_min_max=True),
            )

    @ignore_warnings(UserWarning)
    def test_export_encoder_decoder_cache(self):
        class Model(torch.nn.Module):
            def forward(self, cache):
                return cache.self_attention_cache.key_cache[0]

        cache1 = make_dynamic_cache(
            [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)]
        )
        cache2 = make_dynamic_cache(
            [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)]
        )

        cache = make_encoder_decoder_cache(cache1, cache2)
        model = Model()
        model(cache)
        DYN = torch.export.Dim.DYNAMIC
        ds = [
            [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]],
            [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]],
        ]

        with bypass_export_some_errors():
            torch.export.export(model, (cache,), dynamic_shapes=(ds,))

    @ignore_warnings(UserWarning)
    def test_flatten_dynamic_cache(self):
        cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
        with bypass_export_some_errors():
            flat, _spec = torch.utils._pytree.tree_flatten(cache)
            self.assertEqual(
                "#2[T1s4x4x4,T1s4x4x4]",
                self.string_type(flat, with_shape=True),
            )
            cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
            self.assertEqual(
                self.string_type(cache, with_shape=True, with_min_max=True),
                self.string_type(cache2, with_shape=True, with_min_max=True),
            )

    @ignore_warnings(UserWarning)
    def test_export_dynamic_cache(self):
        class Model(torch.nn.Module):
            def forward(self, cache):
                return cache.key_cache[0]

        cache = make_dynamic_cache(
            [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)]
        )
        model = Model()
        model(cache)
        DYN = torch.export.Dim.DYNAMIC
        ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]]

        with bypass_export_some_errors():
            torch.export.export(model, (cache,), dynamic_shapes=(ds,))


if __name__ == "__main__":
    unittest.main(verbosity=2)
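
A note on the expected strings in the asserts above, for readers decoding them: in onnx_diagnostic's string_type notation, "#2[...]" counts the leaves of a flattened list, "T1" marks a float32 tensor, and "s4x4x4" or "r4" give its shape or rank. A minimal sketch, assuming the module-level helper onnx_diagnostic.helpers.string_type behaves like the test's self.string_type:

import torch
from onnx_diagnostic.helpers import string_type  # assumed equivalent to self.string_type

tensors = [torch.rand((4, 4, 4)), torch.rand((4, 4, 4))]
# "#2[...]" -> two leaves; "T1" -> float32; "s4x4x4" -> the shape
print(string_type(tensors, with_shape=True))  # expected: #2[T1s4x4x4,T1s4x4x4]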

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 16 additions & 1 deletion
@@ -1,5 +1,6 @@
 import pprint
 import unittest
+import torch
 import transformers
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
@@ -14,6 +15,7 @@
 )
 from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config
 from onnx_diagnostic.torch_models.hghub.hub_data import load_models_testing
+from onnx_diagnostic.torch_export_patches import bypass_export_some_errors


 class TestHuggingFaceHubModel(ExtTestCase):
@@ -109,8 +111,21 @@ def test_get_untrained_model_with_inputs_automatic_speech_recognition(self):
         mid = "openai/whisper-tiny"
         data = get_untrained_model_with_inputs(mid, verbose=1)
         self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
-        model, inputs = data["model"], data["inputs"]
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
         model(**inputs)
+        self.assertEqual(
+            "#1[T1r3]",
+            self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]),
+        )
+        with bypass_export_some_errors(patch_transformers=True):
+            flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
+            self.assertIsInstance(flat, list)
+            self.assertIsInstance(flat[0], torch.Tensor)
+            self.assertEqual(
+                "#8[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]",
+                self.string_type(flat),
+            )
+            torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds)

     @hide_stdout()
     def test_get_untrained_model_with_inputs_imagetext2text_generation(self):
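
As context for the tree_flatten assertions above: torch's pytree utilities flatten any registered container into leaf tensors plus a spec that can rebuild the structure, which is what the patches make work for the cache classes. A small self-contained sketch using only built-in containers:

import torch

# tree_flatten walks registered pytrees (dicts, lists, tuples, and classes
# registered via register_pytree_node) and returns the leaves plus a spec.
inputs = {"a": [torch.ones(2, 3), torch.zeros(4)], "b": (torch.rand(5),)}
flat, spec = torch.utils._pytree.tree_flatten(inputs)
assert all(isinstance(t, torch.Tensor) for t in flat)  # three leaf tensors
rebuilt = torch.utils._pytree.tree_unflatten(flat, spec)
assert torch.equal(rebuilt["a"][0], inputs["a"][0])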

onnx_diagnostic/helpers/cache_helper.py

Lines changed: 1 addition & 3 deletions
@@ -132,9 +132,7 @@ def make_encoder_decoder_cache(
     self_attention_cache: transformers.cache_utils.DynamicCache,
     cross_attention_cache: transformers.cache_utils.DynamicCache,
 ) -> transformers.cache_utils.EncoderDecoderCache:
-    """
-    Creates an EncoderDecoderCache.
-    """
+    """Creates an EncoderDecoderCache."""
     return transformers.cache_utils.EncoderDecoderCache(
         self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
     )
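
For context on the companion helper used throughout the tests, here is a minimal sketch of what make_dynamic_cache plausibly does. The real helper lives in onnx_diagnostic.helpers.cache_helper; this reimplementation (make_dynamic_cache_sketch) is an assumption for illustration only:

import torch
import transformers

def make_dynamic_cache_sketch(key_value_pairs):
    # Build a DynamicCache and feed it one (key, value) pair per layer.
    cache = transformers.cache_utils.DynamicCache()
    for layer_idx, (key, value) in enumerate(key_value_pairs):
        cache.update(key, value, layer_idx)
    return cache

cache = make_dynamic_cache_sketch([(torch.rand(4, 4, 4), torch.rand(4, 4, 4))])
assert len(cache.key_cache) == 1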

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

Lines changed: 3 additions & 158 deletions
@@ -1,13 +1,8 @@
 import contextlib
-import pprint
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Callable, Dict, List, Optional
 from .onnx_export_serialization import (
-    flatten_with_keys_dynamic_cache,
-    flatten_dynamic_cache,
-    unflatten_dynamic_cache,
-    flatten_mamba_cache,
-    flatten_with_keys_mamba_cache,
-    unflatten_mamba_cache,
+    _register_cache_serialization,
+    _unregister_cache_serialization,
 )
 from .patches import patch_transformers as patch_transformers_list

@@ -84,156 +79,6 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
         setattr(original, n, v)


-PATCH_OF_PATCHES: Set[Any] = set()
-
-
-def _register_cache_serialization(verbose: int = 0) -> Dict[str, bool]:
-    # Cache serialization: to be moved into appropriate packages
-    import torch
-    import transformers
-    import packaging.version as pv
-
-    try:
-        from transformers.cache_utils import DynamicCache
-    except ImportError:
-        DynamicCache = None
-
-    try:
-        from transformers.cache_utils import MambaCache
-    except ImportError:
-        MambaCache = None
-
-    # MambaCache
-    unregistered_mamba_cache = True
-    if MambaCache is not None and MambaCache in torch.utils._pytree.SUPPORTED_NODES:
-        if verbose > 1:
-            print(f"[_register_cache_serialization] {MambaCache} already registered")
-        # It is already registered because bypass_export_some_errors was called
-        # within a section already calling bypass_export_some_errors or transformers
-        # has updated its code to do it.
-        # No need to register and unregister then.
-        unregistered_mamba_cache = False
-    else:
-        if verbose:
-            print("[_register_cache_serialization] register MambaCache")
-        torch.utils._pytree.register_pytree_node(
-            MambaCache,
-            flatten_mamba_cache,
-            unflatten_mamba_cache,
-            serialized_type_name=f"{MambaCache.__module__}.{MambaCache.__name__}",
-            flatten_with_keys_fn=flatten_with_keys_mamba_cache,
-        )
-
-    # DynamicCache serialization is different in transformers and does not
-    # play well with torch.export.export.
-    # see test test_export_dynamic_cache_cat with NOBYPASS=1
-    # :: NOBYPASS=1 python _unittests/ut_torch_export_patches/test_dynamic_class.py -k e_c
-    # This is caused by this line:
-    # torch.fx._pytree.register_pytree_flatten_spec(
-    #     DynamicCache, _flatten_dynamic_cache_for_fx)
-    # so we remove it anyway
-    if (
-        DynamicCache in torch.fx._pytree.SUPPORTED_NODES
-        and not PATCH_OF_PATCHES
-        # and pv.Version(torch.__version__) < pv.Version("2.7")
-        and pv.Version(transformers.__version__) >= pv.Version("4.50")
-    ):
-        if verbose:
-            print(
-                "[_register_cache_serialization] DynamicCache "
-                "is unregistered and re-registered first."
-            )
-        _unregister(DynamicCache)
-        torch.utils._pytree.register_pytree_node(
-            DynamicCache,
-            flatten_dynamic_cache,
-            unflatten_dynamic_cache,
-            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
-            flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
-        )
-        if pv.Version(torch.__version__) < pv.Version("2.7"):
-            torch.fx._pytree.register_pytree_flatten_spec(
-                DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
-            )
-        # To avoid doing it multiple times.
-        PATCH_OF_PATCHES.add(DynamicCache)
-
-    unregistered_dynamic_cache = True
-    if DynamicCache is not None and DynamicCache in torch.utils._pytree.SUPPORTED_NODES:
-        if verbose > 1:
-            print(f"[_register_cache_serialization] {DynamicCache} already registered")
-        unregistered_dynamic_cache = False
-    else:
-        if verbose:
-            print("[_register_cache_serialization] register DynamicCache")
-        torch.utils._pytree.register_pytree_node(
-            DynamicCache,
-            flatten_dynamic_cache,
-            unflatten_dynamic_cache,
-            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
-            flatten_with_keys_fn=flatten_with_keys_dynamic_cache,
-        )
-        if pv.Version(torch.__version__) < pv.Version("2.7"):
-            torch.fx._pytree.register_pytree_flatten_spec(
-                DynamicCache, lambda x, _: [x.key_cache, x.value_cache]
-            )
-
-    # check
-    from ..helpers.cache_helper import make_dynamic_cache
-
-    cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
-    values, spec = torch.utils._pytree.tree_flatten(cache)
-    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
-    # torch.fx._pytree.tree_flatten(cache)
-    assert len(cache2.key_cache) == 1
-
-    return dict(DynamicCache=unregistered_dynamic_cache, MambaCache=unregistered_mamba_cache)
-
-
-def _unregister(cls: type, verbose: int = 0):
-    import optree
-    import torch
-
-    # torch.fx._pytree._deregister_pytree_flatten_spec(cls)
-    if cls in torch.fx._pytree.SUPPORTED_NODES:
-        del torch.fx._pytree.SUPPORTED_NODES[cls]
-    if cls in torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH:
-        del torch.fx._pytree.SUPPORTED_NODES_EXACT_MATCH[cls]
-    if hasattr(torch.utils._pytree, "_deregister_pytree_node"):
-        # torch >= 2.7
-        torch.utils._pytree._deregister_pytree_node(cls)
-        optree.unregister_pytree_node(cls, namespace="torch")
-    if cls in torch.utils._pytree.SUPPORTED_NODES:
-        import packaging.version as pv
-
-        if pv.Version(torch.__version__) < pv.Version("2.7.0"):
-            del torch.utils._pytree.SUPPORTED_NODES[cls]
-    assert cls not in torch.utils._pytree.SUPPORTED_NODES, (
-        f"{cls} was not successfully unregistered "
-        f"from torch.utils._pytree.SUPPORTED_NODES="
-        f"{pprint.pformat(list(torch.utils._pytree.SUPPORTED_NODES))}"
-    )
-    if verbose:
-        print(f"[_unregister_cache_serialization] unregistered {cls.__name__}")
-
-
-def _unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
-
-    if undo.get("MambaCache", False):
-        from transformers.cache_utils import MambaCache
-
-        _unregister(MambaCache, verbose)
-    elif verbose > 1:
-        print("[_unregister_cache_serialization] skip unregister MambaCache")
-
-    if undo.get("DynamicCache", False):
-        from transformers.cache_utils import DynamicCache
-
-        _unregister(DynamicCache, verbose)
-    elif verbose > 1:
-        print("[_unregister_cache_serialization] skip unregister DynamicCache")
-
-
 @contextlib.contextmanager
 def register_additional_serialization_functions(
     patch_transformers: bool = False, verbose: int = 0
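
The registration logic deleted above did not disappear; it moved into onnx_export_serialization as _register_cache_serialization and _unregister_cache_serialization. As a reference for the pattern those helpers implement, here is a self-contained sketch; MyCache is a hypothetical stand-in for DynamicCache/MambaCache, not part of the library:

import torch

class MyCache:
    # Hypothetical stand-in for a transformers cache class.
    def __init__(self, key_cache=None, value_cache=None):
        self.key_cache = key_cache if key_cache is not None else []
        self.value_cache = value_cache if value_cache is not None else []

def flatten_my_cache(cache):
    # Return the dynamic children plus a static context.
    return [cache.key_cache, cache.value_cache], ["key_cache", "value_cache"]

def unflatten_my_cache(values, context):
    cache = MyCache()
    cache.key_cache, cache.value_cache = values
    return cache

torch.utils._pytree.register_pytree_node(
    MyCache,
    flatten_my_cache,
    unflatten_my_cache,
    serialized_type_name=f"{MyCache.__module__}.{MyCache.__name__}",
)

# Once registered, tree_flatten sees through the class down to the tensors.
flat, spec = torch.utils._pytree.tree_flatten(MyCache([torch.ones(2)], [torch.zeros(2)]))
assert len(flat) == 2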
