Commit 5395ec8

EncoderDecoderCache, whisper (#48)
* custom cache
* fix config
* add whisper
* eccache
* doc
* fix ser
* fix import
* fix
* fix import
* fix **
* fix patches
* issue
* fix ci
* fix issues
* fix issues
* add custom
* fix ext
* fix http
1 parent 0e8155c commit 5395ec8

25 files changed: 1,351 additions, 385 deletions
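What the new EncoderDecoderCache support enables, in a minimal sketch (not part of the diff; it only relies on helpers that appear in the tests added by this commit, and the flatten/unflatten step requires the patches to be active):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
    from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

    # build an EncoderDecoderCache out of two DynamicCache instances
    cache = make_encoder_decoder_cache(
        make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
        make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
    )

    with bypass_export_some_errors():
        # with the serialization patches registered, the cache flattens into plain tensors...
        flat, spec = torch.utils._pytree.tree_flatten(cache)
        # ...and can be rebuilt, which is what torch.export relies on
        restored = torch.utils._pytree.tree_unflatten(flat, spec)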

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python: ['3.11', '3.12']
-        transformers: ['4.48.3', '4.51.1', 'main']
+        transformers: ['4.48.3', '4.51.2', 'main']
         torch: ['2.6', 'main']
 
     steps:

CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@ Change Logs
 0.4.0
 +++++
 
+* :pr:`48`: add support for EncoderDecoderCache, test with openai/whisper-tiny
 * :pr:`45`: improve change_dynamic_dimension to fix some dimensions
 
 0.3.0
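A sketch of the whisper-tiny scenario the new changelog entry refers to, distilled from the test added in _unittests/ut_torch_models/test_hghub_model.py below (the helper names and export call come straight from that test):

    import torch
    from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
    from onnx_diagnostic.torch_export_patches import bypass_export_some_errors

    # untrained copy of openai/whisper-tiny plus matching inputs and dynamic shapes
    data = get_untrained_model_with_inputs("openai/whisper-tiny", verbose=1)
    model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]

    with bypass_export_some_errors(patch_transformers=True):
        # the EncoderDecoderCache inputs flatten correctly only while the patches are active
        ep = torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)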

_doc/api/helpers/index.rst

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ onnx_diagnostic.helpers
     memory_peak
     onnx_helper
     ort_session
+    rt_helper
     torch_test_helper
 
 .. autofunction:: onnx_diagnostic.helpers.max_diff

_doc/api/helpers/rt_helper.rst

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+
+onnx_diagnostic.helpers.rt_helper
+=================================
+
+.. automodule:: onnx_diagnostic.helpers.rt_helper
+    :members:
+    :no-undoc-members:

_doc/examples/plot_export_tiny_phi2.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 from onnx_diagnostic import doc
 from onnx_diagnostic.helpers import max_diff, string_diff, string_type
 from onnx_diagnostic.helpers.cache_helper import is_cache_dynamic_registered
-from onnx_diagnostic.helpers.ort_session import make_feeds
+from onnx_diagnostic.helpers.rt_helper import make_feeds
 from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
 from onnx_diagnostic.torch_models.hghub import (
     get_untrained_model_with_inputs,
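make_feeds now lives in onnx_diagnostic.helpers.rt_helper (documented by the new rt_helper.rst page above). A usage sketch follows; the exact signature is an assumption here, and the model path and inputs are hypothetical, so check the rt_helper API page for the authoritative call:

    import onnx
    import torch
    from onnx_diagnostic.helpers.rt_helper import make_feeds

    onx = onnx.load("model.onnx")  # hypothetical path to an exported model
    # hypothetical inputs matching the ONNX graph's input order
    inputs = (torch.randint(0, 1024, (1, 8)), torch.ones((1, 8), dtype=torch.int64))
    # assumed signature: ONNX model (or input names) + torch inputs -> {input_name: value}
    feeds = make_feeds(onx, inputs, use_numpy=True)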

_unittests/ut_helpers/test_ort_session_tinyllm.py

Lines changed: 1 addition & 1 deletion
@@ -7,10 +7,10 @@
 from onnxruntime.capi import _pybind_state as ORTC
 from onnx_diagnostic.ext_test_case import ExtTestCase, hide_stdout, ignore_warnings
 from onnx_diagnostic.helpers import max_diff
+from onnx_diagnostic.helpers.rt_helper import make_feeds
 from onnx_diagnostic.helpers.ort_session import (
     InferenceSessionForNumpy,
     InferenceSessionForTorch,
-    make_feeds,
 )
 from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
 from onnx_diagnostic.torch_models.llms import get_tiny_llm

_unittests/ut_torch_export_patches/test_onnx_export_errors.py

Lines changed: 43 additions & 0 deletions
@@ -121,6 +121,49 @@ def forward(self, x: torch.Tensor, cache: MambaCache):
             dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}]]),
         )
 
+    @ignore_warnings(UserWarning)
+    def test_exportable_dynamic_shapes_constraints(self):
+        import torch
+
+        class CustomCache:
+            def __init__(self, shape=None):
+                self.cache = [torch.zeros((shape)), torch.zeros((shape))] if shape else []
+
+        def flatten_cache(cache):
+            return [cache.cache], ["cache"]
+
+        def unflatten_cache(values, context, output_type=None):
+            cache = CustomCache()
+            cache.cache = values[0]
+            return cache
+
+        def flatten_with_keys_cache(d):
+            values, context = flatten_cache(d)
+            return [
+                (torch.utils._pytree.MappingKey(k), v) for k, v in zip(context, values)
+            ], context
+
+        torch.utils._pytree.register_pytree_node(
+            CustomCache,
+            flatten_cache,
+            unflatten_cache,
+            serialized_type_name=f"{CustomCache.__module__}.{CustomCache.__name__}",
+            flatten_with_keys_fn=flatten_with_keys_cache,
+        )
+
+        class Model(torch.nn.Module):
+            def forward(self, x, cache):
+                return cache.cache[0][0, :] + x
+
+        model = Model()
+        model.eval()
+        x, cache = torch.rand((2, 4)), CustomCache((2, 4))
+        model(x, cache)
+        DYN = torch.export.Dim.DYNAMIC
+        torch.export.export(
+            model, (x, cache), dynamic_shapes=({0: DYN}, [[{0: DYN}, {0: DYN}]])
+        )
+
 
 if __name__ == "__main__":
     unittest.main(verbosity=2)
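For reference, a compact, self-contained variant of the pattern exercised above (the class and names below are illustrative stand-ins, not part of the commit): registering a minimal cache class as a pytree node lets torch flatten it into tensor leaves and rebuild it, which is the mechanism torch.export's dynamic_shapes mapping relies on.

    import torch

    class TinyCache:  # illustrative stand-in for the CustomCache in the test above
        def __init__(self, shape=None):
            self.cache = [torch.zeros(shape), torch.zeros(shape)] if shape else []

    def _unflatten(values, context):
        c = TinyCache()
        c.cache = values[0]
        return c

    torch.utils._pytree.register_pytree_node(
        TinyCache,
        lambda c: ([c.cache], ["cache"]),  # flatten: children + context
        _unflatten,                        # unflatten: rebuild from children
        serialized_type_name="TinyCache",
    )

    flat, spec = torch.utils._pytree.tree_flatten(TinyCache((2, 4)))  # two tensor leaves
    restored = torch.utils._pytree.tree_unflatten(flat, spec)
    assert isinstance(restored, TinyCache) and len(restored.cache) == 2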
Lines changed: 156 additions & 0 deletions
@@ -0,0 +1,156 @@
+import unittest
+import torch
+from transformers.modeling_outputs import BaseModelOutput
+from onnx_diagnostic.ext_test_case import ExtTestCase, ignore_warnings
+from onnx_diagnostic.helpers.cache_helper import make_encoder_decoder_cache, make_dynamic_cache
+from onnx_diagnostic.torch_export_patches.onnx_export_errors import (
+    bypass_export_some_errors,
+)
+from onnx_diagnostic.helpers.torch_test_helper import torch_deepcopy
+
+
+class TestPatchSerialization(ExtTestCase):
+    @ignore_warnings(UserWarning)
+    def test_encoder_decoder_cache_flatten(self):
+        cache = make_encoder_decoder_cache(
+            make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+            make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
+        )
+        with bypass_export_some_errors():
+            flat, _spec = torch.utils._pytree.tree_flatten(cache)
+            self.assertEqual(
+                "#4[T1s4x4x4,T1s4x4x4,T1s5x5x5,T1s5x5x5]",
+                self.string_type(flat, with_shape=True),
+            )
+            cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(cache, with_shape=True, with_min_max=True),
+                self.string_type(cache2, with_shape=True, with_min_max=True),
+            )
+
+    @ignore_warnings(UserWarning)
+    def test_encoder_decoder_cache_deepcopy(self):
+        cache = make_encoder_decoder_cache(
+            make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
+            make_dynamic_cache([(torch.rand((5, 5, 5)), torch.rand((5, 5, 5)))]),
+        )
+        with bypass_export_some_errors():
+            cache2 = torch_deepcopy([cache])
+            self.assertEqualAny([cache], cache2)
+
+    @ignore_warnings(UserWarning)
+    def test_encoder_decoder_cache_export(self):
+        class Model(torch.nn.Module):
+            def forward(self, cache):
+                return cache.self_attention_cache.key_cache[0]
+
+        cache1 = make_dynamic_cache(
+            [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)]
+        )
+        cache2 = make_dynamic_cache(
+            [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)]
+        )
+
+        cache = make_encoder_decoder_cache(cache1, cache2)
+        model = Model()
+        model(cache)
+        DYN = torch.export.Dim.DYNAMIC
+        ds = [
+            [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]],
+            [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]],
+        ]
+
+        with bypass_export_some_errors(patch_transformers=True):
+            torch.export.export(model, (cache,), dynamic_shapes=(ds,))
+
+    @ignore_warnings(UserWarning)
+    def test_dynamic_cache_flatten(self):
+        cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
+        with bypass_export_some_errors():
+            flat, _spec = torch.utils._pytree.tree_flatten(cache)
+            self.assertEqual(
+                "#2[T1s4x4x4,T1s4x4x4]",
+                self.string_type(flat, with_shape=True),
+            )
+            cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(cache, with_shape=True, with_min_max=True),
+                self.string_type(cache2, with_shape=True, with_min_max=True),
+            )
+
+    @ignore_warnings(UserWarning)
+    def test_dynamic_cache_export(self):
+        class Model(torch.nn.Module):
+            def forward(self, cache):
+                return cache.key_cache[0]
+
+        cache = make_dynamic_cache(
+            [(torch.randn(2, 4, 3, 7), torch.randn(2, 4, 3, 7)) for i in range(3)]
+        )
+        model = Model()
+        model(cache)
+        DYN = torch.export.Dim.DYNAMIC
+        ds = [[{0: DYN}, {0: DYN}, {0: DYN}], [{0: DYN}, {0: DYN}, {0: DYN}]]
+
+        with bypass_export_some_errors():
+            torch.export.export(model, (cache,), dynamic_shapes=(ds,))
+
+    @ignore_warnings(UserWarning)
+    def test_dynamic_cache_deepcopy(self):
+        cache = make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))])
+        with bypass_export_some_errors():
+            cache2 = torch_deepcopy([cache])
+            self.assertEqualAny([cache], cache2)
+
+    @ignore_warnings(UserWarning)
+    def test_base_model_output_deepcopy(self):
+        bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
+        self.assertEqual(bo.__class__.__name__, "BaseModelOutput")
+        with bypass_export_some_errors():
+            bo2 = torch_deepcopy([bo])
+            self.assertIsInstance(bo2, list)
+            self.assertEqual(bo2[0].__class__.__name__, "BaseModelOutput")
+            self.assertEqualAny([bo], bo2)
+
+    @ignore_warnings(UserWarning)
+    def test_base_model_output_string_type(self):
+        bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
+        with bypass_export_some_errors():
+            self.assertEqual(
+                "BaseModelOutput(last_hidden_state:T1s4x4x4)",
+                self.string_type(bo, with_shape=True),
+            )
+
+    @ignore_warnings(UserWarning)
+    def test_base_model_output_flatten(self):
+        bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
+        with bypass_export_some_errors():
+            flat, _spec = torch.utils._pytree.tree_flatten(bo)
+            self.assertEqual(
+                "#1[T1s4x4x4]",
+                self.string_type(flat, with_shape=True),
+            )
+            bo2 = torch.utils._pytree.tree_unflatten(flat, _spec)
+            self.assertEqual(
+                self.string_type(bo, with_shape=True, with_min_max=True),
+                self.string_type(bo2, with_shape=True, with_min_max=True),
+            )
+
+    @ignore_warnings(UserWarning)
+    def test_base_model_output_export(self):
+        class Model(torch.nn.Module):
+            def forward(self, cache):
+                return cache.last_hidden_state[0]
+
+        bo = BaseModelOutput(last_hidden_state=torch.rand((4, 4, 4)))
+        model = Model()
+        model(bo)
+        DYN = torch.export.Dim.DYNAMIC
+        ds = [{0: DYN}]
+
+        with bypass_export_some_errors():
+            torch.export.export(model, (bo,), dynamic_shapes=(ds,))
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)

_unittests/ut_torch_models/test_hghub_model.py

Lines changed: 69 additions & 0 deletions
@@ -1,5 +1,6 @@
 import pprint
 import unittest
+import torch
 import transformers
 from onnx_diagnostic.ext_test_case import (
     ExtTestCase,
@@ -14,6 +15,7 @@
 )
 from onnx_diagnostic.torch_models.hghub.hub_api import get_pretrained_config
 from onnx_diagnostic.torch_models.hghub.hub_data import load_models_testing
+from onnx_diagnostic.torch_export_patches import bypass_export_some_errors
 
 
 class TestHuggingFaceHubModel(ExtTestCase):
@@ -104,6 +106,72 @@ def test_get_untrained_model_with_inputs_text2text_generation(self):
             raise unittest.SkipTest(f"not working for {mid!r}")
         model(**inputs)
 
+    @hide_stdout()
+    def test_get_untrained_model_with_inputs_automatic_speech_recognition(self):
+        mid = "openai/whisper-tiny"
+        data = get_untrained_model_with_inputs(mid, verbose=1)
+        self.assertIn((data["size"], data["n_weights"]), [(132115968, 33028992)])
+        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+        Dim = torch.export.Dim
+        self.maxDiff = None
+        self.assertIn("{0:Dim(batch),1:Dim(seq_length)}", self.string_type(ds))
+        self.assertEqualAny(
+            {
+                "decoder_input_ids": {
+                    0: Dim("batch", min=1, max=1024),
+                    1: Dim("seq_length", min=1, max=4096),
+                },
+                "cache_position": {0: Dim("seq_length", min=1, max=4096)},
+                "encoder_outputs": [{0: Dim("batch", min=1, max=1024)}],
+                "past_key_values": [
+                    [
+                        [
+                            {0: Dim("batch", min=1, max=1024)},
+                            {0: Dim("batch", min=1, max=1024)},
+                        ],
+                        [
+                            {0: Dim("batch", min=1, max=1024)},
+                            {0: Dim("batch", min=1, max=1024)},
+                        ],
+                    ],
+                    [
+                        [
+                            {0: Dim("batch", min=1, max=1024)},
+                            {0: Dim("batch", min=1, max=1024)},
+                        ],
+                        [
+                            {0: Dim("batch", min=1, max=1024)},
+                            {0: Dim("batch", min=1, max=1024)},
+                        ],
+                    ],
+                ],
+            },
+            ds,
+        )
+        model(**inputs)
+        self.assertEqual(
+            "#1[T1r3]",
+            self.string_type(torch.utils._pytree.tree_flatten(inputs["encoder_outputs"])[0]),
+        )
+        with bypass_export_some_errors(patch_transformers=True, verbose=10):
+            flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
+            self.assertIsInstance(flat, list)
+            self.assertIsInstance(flat[0], torch.Tensor)
+            self.assertEqual(
+                "#8[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]",
+                self.string_type(flat),
+            )
+            torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
+        with bypass_export_some_errors(patch_transformers=True, verbose=10):
+            flat = torch.utils._pytree.tree_flatten(inputs["past_key_values"])[0]
+            self.assertIsInstance(flat, list)
+            self.assertIsInstance(flat[0], torch.Tensor)
+            self.assertEqual(
+                "#8[T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4,T1r4]",
+                self.string_type(flat),
+            )
+            torch.export.export(model, (), kwargs=inputs, dynamic_shapes=ds, strict=False)
+
     @hide_stdout()
     def test_get_untrained_model_with_inputs_imagetext2text_generation(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
@@ -131,6 +199,7 @@ def _diff(c1, c2):
         for mid in load_models_testing():
             with self.subTest(mid=mid):
                 if mid in {
+                    "hf-internal-testing/tiny-random-BeitForImageClassification",
                     "hf-internal-testing/tiny-random-MaskFormerForInstanceSegmentation",
                     "hf-internal-testing/tiny-random-MoonshineForConditionalGeneration",
                     "fxmarty/pix2struct-tiny-random",